diff --git a/providers/base.py b/providers/base.py index aff8705..991642c 100644 --- a/providers/base.py +++ b/providers/base.py @@ -196,6 +196,9 @@ class ModelProvider(ABC): # All concrete providers must define their supported models SUPPORTED_MODELS: dict[str, Any] = {} + # Default maximum image size in MB + DEFAULT_MAX_IMAGE_SIZE_MB = 20.0 + def __init__(self, api_key: str, **kwargs): """Initialize the provider with API key and optional configuration.""" self.api_key = api_key @@ -433,6 +436,83 @@ class ModelProvider(ABC): return list(all_models) + def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]: + """Provider-independent image validation. + + Args: + image_path: Path to image file or data URL + max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB) + + Returns: + Tuple of (image_bytes, mime_type) + + Raises: + ValueError: If image is invalid + + Examples: + # Validate a file path + image_bytes, mime_type = provider.validate_image("/path/to/image.png") + + # Validate a data URL + image_bytes, mime_type = provider.validate_image("data:image/png;base64,...") + + # Validate with custom size limit + image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0) + """ + import base64 + import os + + from utils.file_types import IMAGES, get_image_mime_type + + # Use default if not specified + if max_size_mb is None: + max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB + + if image_path.startswith("data:"): + # Parse data URL: data:image/png;base64,iVBORw0... + try: + header, data = image_path.split(",", 1) + mime_type = header.split(";")[0].split(":")[1] + except (ValueError, IndexError) as e: + raise ValueError(f"Invalid data URL format: {e}") + + # Validate MIME type using IMAGES constant + valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES] + if mime_type not in valid_mime_types: + raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}") + + # Decode base64 data + try: + image_bytes = base64.b64decode(data) + except Exception as e: + raise ValueError(f"Invalid base64 data: {e}") + else: + # Handle file path + if not os.path.exists(image_path): + raise ValueError(f"Image file not found: {image_path}") + + # Validate extension + ext = os.path.splitext(image_path)[1].lower() + if ext not in IMAGES: + raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}") + + # Get MIME type + mime_type = get_image_mime_type(ext) + + # Read file + try: + with open(image_path, "rb") as f: + image_bytes = f.read() + except Exception as e: + raise ValueError(f"Failed to read image file: {e}") + + # Validate size + size_mb = len(image_bytes) / (1024 * 1024) + if size_mb > max_size_mb: + raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)") + + return image_bytes, mime_type + def close(self): """Clean up any resources held by the provider. diff --git a/providers/gemini.py b/providers/gemini.py index 51916b0..783e3dc 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -2,7 +2,6 @@ import base64 import logging -import os import time from typing import Optional @@ -440,28 +439,22 @@ class GeminiModelProvider(ModelProvider): def _process_image(self, image_path: str) -> Optional[dict]: """Process an image for Gemini API.""" try: - if image_path.startswith("data:image/"): - # Handle data URL: data:image/png;base64,iVBORw0... - header, data = image_path.split(",", 1) - mime_type = header.split(";")[0].split(":")[1] + # Use base class validation + image_bytes, mime_type = self.validate_image(image_path) + + # For data URLs, extract the base64 data directly + if image_path.startswith("data:"): + # Extract base64 data from data URL + _, data = image_path.split(",", 1) return {"inline_data": {"mime_type": mime_type, "data": data}} else: - # Handle file path - from utils.file_types import get_image_mime_type - - if not os.path.exists(image_path): - logger.warning(f"Image file not found: {image_path}") - return None - - # Detect MIME type from file extension using centralized mappings - ext = os.path.splitext(image_path)[1].lower() - mime_type = get_image_mime_type(ext) - - # Read and encode the image - with open(image_path, "rb") as f: - image_data = base64.b64encode(f.read()).decode() - + # For file paths, encode the bytes + image_data = base64.b64encode(image_bytes).decode() return {"inline_data": {"mime_type": mime_type, "data": image_data}} + + except ValueError as e: + logger.warning(str(e)) + return None except Exception as e: logger.error(f"Error processing image {image_path}: {e}") return None diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py index 88cbb26..6e8617c 100644 --- a/providers/openai_compatible.py +++ b/providers/openai_compatible.py @@ -1,6 +1,5 @@ """Base class for OpenAI-compatible API providers.""" -import base64 import ipaddress import logging import os @@ -788,30 +787,29 @@ class OpenAICompatibleProvider(ModelProvider): def _process_image(self, image_path: str) -> Optional[dict]: """Process an image for OpenAI-compatible API.""" try: - if image_path.startswith("data:image/"): + if image_path.startswith("data:"): + # Validate the data URL + self.validate_image(image_path) # Handle data URL: data:image/png;base64,iVBORw0... return {"type": "image_url", "image_url": {"url": image_path}} else: - # Handle file path - if not os.path.exists(image_path): - logging.warning(f"Image file not found: {image_path}") - return None - - # Detect MIME type from file extension using centralized mappings - from utils.file_types import get_image_mime_type - - ext = os.path.splitext(image_path)[1].lower() - mime_type = get_image_mime_type(ext) - logging.debug(f"Processing image '{image_path}' with extension '{ext}' as MIME type '{mime_type}'") + # Use base class validation + image_bytes, mime_type = self.validate_image(image_path) # Read and encode the image - with open(image_path, "rb") as f: - image_data = base64.b64encode(f.read()).decode() + import base64 + + image_data = base64.b64encode(image_bytes).decode() + logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'") # Create data URL for OpenAI API data_url = f"data:{mime_type};base64,{image_data}" return {"type": "image_url", "image_url": {"url": data_url}} + + except ValueError as e: + logging.warning(str(e)) + return None except Exception as e: logging.error(f"Error processing image {image_path}: {e}") return None diff --git a/tests/test_image_validation.py b/tests/test_image_validation.py new file mode 100644 index 0000000..e1fb36d --- /dev/null +++ b/tests/test_image_validation.py @@ -0,0 +1,303 @@ +"""Tests for provider-independent image validation.""" + +import base64 +import os +import tempfile +from typing import Optional +from unittest.mock import Mock, patch + +import pytest + +from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType + + +class MinimalTestProvider(ModelProvider): + """Minimal concrete provider for testing base class methods.""" + + def get_capabilities(self, model_name: str) -> ModelCapabilities: + """Not needed for image validation tests.""" + raise NotImplementedError("Not needed for image validation tests") + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.7, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> ModelResponse: + """Not needed for image validation tests.""" + raise NotImplementedError("Not needed for image validation tests") + + def count_tokens(self, text: str, model_name: str) -> int: + """Not needed for image validation tests.""" + raise NotImplementedError("Not needed for image validation tests") + + def get_provider_type(self) -> ProviderType: + """Not needed for image validation tests.""" + raise NotImplementedError("Not needed for image validation tests") + + def validate_model_name(self, model_name: str) -> bool: + """Not needed for image validation tests.""" + raise NotImplementedError("Not needed for image validation tests") + + def supports_thinking_mode(self, model_name: str) -> bool: + """Not needed for image validation tests.""" + raise NotImplementedError("Not needed for image validation tests") + + +class TestImageValidation: + """Test suite for image validation functionality.""" + + def setup_method(self) -> None: + """Set up test fixtures.""" + # Create a minimal concrete provider instance for testing base class methods + self.provider = MinimalTestProvider(api_key="test-key") + + def test_validate_data_url_valid(self) -> None: + """Test validation of valid data URL.""" + # Create a small test image (1x1 PNG) + test_image_data = base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" + ) + data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}" + + image_bytes, mime_type = self.provider.validate_image(data_url) + + assert image_bytes == test_image_data + assert mime_type == "image/png" + + @pytest.mark.parametrize( + "invalid_url,expected_error", + [ + ("data:image/png", "Invalid data URL format"), # Missing base64 part + ("data:image/png;base64", "Invalid data URL format"), # Missing data + ("data:text/plain;base64,dGVzdA==", "Unsupported image type"), # Not an image + ], + ) + def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None: + """Test validation of malformed data URL.""" + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image(invalid_url) + assert expected_error in str(excinfo.value) + + def test_non_data_url_treated_as_file_path(self) -> None: + """Test that non-data URLs are treated as file paths.""" + # Test case that's not a data URL at all + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image("image/png;base64,abc123") + assert "Image file not found" in str(excinfo.value) # Treated as file path + + def test_validate_data_url_unsupported_type(self) -> None: + """Test validation of unsupported image type in data URL.""" + data_url = "data:image/bmp;base64,Qk0=" # BMP format + + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image(data_url) + assert "Unsupported image type: image/bmp" in str(excinfo.value) + + def test_validate_data_url_invalid_base64(self) -> None: + """Test validation of data URL with invalid base64.""" + data_url = "data:image/png;base64,@@@invalid@@@" + + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image(data_url) + assert "Invalid base64 data" in str(excinfo.value) + + def test_validate_large_data_url(self) -> None: + """Test validation of large data URL to ensure size limits work.""" + # Create a large image (21MB) + large_data = b"x" * (21 * 1024 * 1024) # 21MB + + # Encode as base64 and create data URL + import base64 + + encoded_data = base64.b64encode(large_data).decode() + data_url = f"data:image/png;base64,{encoded_data}" + + # Should fail with default 20MB limit + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image(data_url) + assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value) + + # Should succeed with higher limit + image_bytes, mime_type = self.provider.validate_image(data_url, max_size_mb=25.0) + assert len(image_bytes) == len(large_data) + assert mime_type == "image/png" + + def test_validate_file_path_valid(self) -> None: + """Test validation of valid image file.""" + # Create a temporary image file + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file: + # Write a small test PNG + test_image_data = base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" + ) + tmp_file.write(test_image_data) + tmp_file_path = tmp_file.name + + try: + image_bytes, mime_type = self.provider.validate_image(tmp_file_path) + + assert image_bytes == test_image_data + assert mime_type == "image/png" + finally: + os.unlink(tmp_file_path) + + def test_validate_file_path_not_found(self) -> None: + """Test validation of non-existent file.""" + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image("/path/to/nonexistent/image.png") + assert "Image file not found" in str(excinfo.value) + + def test_validate_file_path_unsupported_extension(self) -> None: + """Test validation of file with unsupported extension.""" + with tempfile.NamedTemporaryFile(suffix=".bmp", delete=False) as tmp_file: + tmp_file.write(b"dummy data") + tmp_file_path = tmp_file.name + + try: + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image(tmp_file_path) + assert "Unsupported image format: .bmp" in str(excinfo.value) + finally: + os.unlink(tmp_file_path) + + def test_validate_file_path_read_error(self) -> None: + """Test validation when file cannot be read.""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file: + tmp_file_path = tmp_file.name + + # Remove the file but keep the path + os.unlink(tmp_file_path) + + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image(tmp_file_path) + assert "Image file not found" in str(excinfo.value) + + def test_validate_image_size_limit(self) -> None: + """Test validation of image size limits.""" + # Create a large "image" (just random data) + large_data = b"x" * (21 * 1024 * 1024) # 21MB + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file: + tmp_file.write(large_data) + tmp_file_path = tmp_file.name + + try: + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image(tmp_file_path, max_size_mb=20.0) + assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value) + finally: + os.unlink(tmp_file_path) + + def test_validate_image_custom_size_limit(self) -> None: + """Test validation with custom size limit.""" + # Create a 2MB "image" + data = b"x" * (2 * 1024 * 1024) + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file: + tmp_file.write(data) + tmp_file_path = tmp_file.name + + try: + # Should fail with 1MB limit + with pytest.raises(ValueError) as excinfo: + self.provider.validate_image(tmp_file_path, max_size_mb=1.0) + assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value) + + # Should succeed with 3MB limit + image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=3.0) + assert len(image_bytes) == len(data) + assert mime_type == "image/png" + finally: + os.unlink(tmp_file_path) + + def test_validate_image_default_size_limit(self) -> None: + """Test validation with default size limit (None).""" + # Create a small image that's under the default limit + data = b"x" * (1024 * 1024) # 1MB + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file: + tmp_file.write(data) + tmp_file_path = tmp_file.name + + try: + # Should succeed with default limit (20MB) + image_bytes, mime_type = self.provider.validate_image(tmp_file_path) + assert len(image_bytes) == len(data) + assert mime_type == "image/jpeg" + + # Should also succeed when explicitly passing None + image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=None) + assert len(image_bytes) == len(data) + assert mime_type == "image/jpeg" + finally: + os.unlink(tmp_file_path) + + def test_validate_all_supported_formats(self) -> None: + """Test validation of all supported image formats.""" + supported_formats = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".webp": "image/webp", + } + + for ext, expected_mime in supported_formats.items(): + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file: + tmp_file.write(b"dummy image data") + tmp_file_path = tmp_file.name + + try: + image_bytes, mime_type = self.provider.validate_image(tmp_file_path) + assert mime_type == expected_mime + assert image_bytes == b"dummy image data" + finally: + os.unlink(tmp_file_path) + + +class TestProviderIntegration: + """Test image validation integration with different providers.""" + + @patch("providers.gemini.logger") + def test_gemini_provider_uses_validation(self, mock_logger: Mock) -> None: + """Test that Gemini provider uses the base validation.""" + from providers.gemini import GeminiModelProvider + + # Create a provider instance + provider = GeminiModelProvider(api_key="test-key") + + # Test with non-existent file + result = provider._process_image("/nonexistent/image.png") + assert result is None + mock_logger.warning.assert_called_with("Image file not found: /nonexistent/image.png") + + @patch("providers.openai_compatible.logging") + def test_openai_compatible_provider_uses_validation(self, mock_logging: Mock) -> None: + """Test that OpenAI-compatible providers use the base validation.""" + from providers.xai import XAIModelProvider + + # Create a provider instance (XAI inherits from OpenAICompatibleProvider) + provider = XAIModelProvider(api_key="test-key") + + # Test with non-existent file + result = provider._process_image("/nonexistent/image.png") + assert result is None + mock_logging.warning.assert_called_with("Image file not found: /nonexistent/image.png") + + def test_data_url_preservation(self) -> None: + """Test that data URLs are properly preserved through validation.""" + from providers.xai import XAIModelProvider + + provider = XAIModelProvider(api_key="test-key") + + # Valid data URL + data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" + + result = provider._process_image(data_url) + assert result is not None + assert result["type"] == "image_url" + assert result["image_url"]["url"] == data_url