diff --git a/providers/base.py b/providers/base.py index 8cd2091..e8e54f9 100644 --- a/providers/base.py +++ b/providers/base.py @@ -1,17 +1,12 @@ """Base interfaces and common behaviour for model providers.""" -import base64 -import binascii import logging -import os from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from tools.models import ToolModelCategory -from utils.file_types import IMAGES, get_image_mime_type - from .shared import ModelCapabilities, ModelResponse, ProviderType logger = logging.getLogger(__name__) @@ -43,9 +38,6 @@ class ModelProvider(ABC): # All concrete providers must define their supported models MODEL_CAPABILITIES: dict[str, Any] = {} - # Default maximum image size in MB - DEFAULT_MAX_IMAGE_SIZE_MB = 20.0 - def __init__(self, api_key: str, **kwargs): """Initialize the provider with API key and optional configuration.""" self.api_key = api_key @@ -167,17 +159,7 @@ class ModelProvider(ABC): lowercase: bool = False, unique: bool = False, ) -> list[str]: - """Return formatted model names supported by this provider. - - Args: - respect_restrictions: Apply provider restriction policy. - include_aliases: Include aliases alongside canonical model names. - lowercase: Normalize returned names to lowercase. - unique: Deduplicate names after formatting. - - Returns: - List of model names formatted according to the provided options. - """ + """Return formatted model names supported by this provider.""" model_configs = self.get_model_configurations() if not model_configs: @@ -206,77 +188,6 @@ class ModelProvider(ABC): unique=unique, ) - def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]: - """Provider-independent image validation. 
- - Args: - image_path: Path to image file or data URL - max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB) - - Returns: - Tuple of (image_bytes, mime_type) - - Raises: - ValueError: If image is invalid - - Examples: - # Validate a file path - image_bytes, mime_type = provider.validate_image("/path/to/image.png") - - # Validate a data URL - image_bytes, mime_type = provider.validate_image("data:image/png;base64,...") - - # Validate with custom size limit - image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0) - """ - # Use default if not specified - if max_size_mb is None: - max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB - - if image_path.startswith("data:"): - # Parse data URL: ... - try: - header, data = image_path.split(",", 1) - mime_type = header.split(";")[0].split(":")[1] - except (ValueError, IndexError) as e: - raise ValueError(f"Invalid data URL format: {e}") - - # Validate MIME type using IMAGES constant - valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES] - if mime_type not in valid_mime_types: - raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}") - - # Decode base64 data - try: - image_bytes = base64.b64decode(data) - except binascii.Error as e: - raise ValueError(f"Invalid base64 data: {e}") - else: - # Handle file path - # Read file first to check if it exists - try: - with open(image_path, "rb") as f: - image_bytes = f.read() - except FileNotFoundError: - raise ValueError(f"Image file not found: {image_path}") - except Exception as e: - raise ValueError(f"Failed to read image file: {e}") - - # Validate extension - ext = os.path.splitext(image_path)[1].lower() - if ext not in IMAGES: - raise ValueError(f"Unsupported image format: {ext}. 
Supported formats: {', '.join(sorted(IMAGES))}") - - # Get MIME type - mime_type = get_image_mime_type(ext) - - # Validate size - size_mb = len(image_bytes) / (1024 * 1024) - if size_mb > max_size_mb: - raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)") - - return image_bytes, mime_type - def close(self): """Clean up any resources held by the provider. diff --git a/providers/gemini.py b/providers/gemini.py index d136f55..2bdc4da 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -11,6 +11,8 @@ if TYPE_CHECKING: from google import genai from google.genai import types +from utils.image_utils import validate_image + from .base import ModelProvider from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint @@ -529,7 +531,7 @@ class GeminiModelProvider(ModelProvider): """Process an image for Gemini API.""" try: - # Use base class validation - image_bytes, mime_type = self.validate_image(image_path) + # Validate with the shared image utility + image_bytes, mime_type = validate_image(image_path) # For data URLs, extract the base64 data directly if image_path.startswith("data:"): diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py index 6a0454f..fd04e7d 100644 --- a/providers/openai_compatible.py +++ b/providers/openai_compatible.py @@ -11,6 +11,8 @@ from urllib.parse import urlparse from openai import OpenAI +from utils.image_utils import validate_image + from .base import ModelProvider from .shared import ( ModelCapabilities, @@ -830,12 +832,12 @@ class OpenAICompatibleProvider(ModelProvider): try: if image_path.startswith("data:"): # Validate the data URL - self.validate_image(image_path) + validate_image(image_path) # Handle data URL: ... 
return {"type": "image_url", "image_url": {"url": image_path}} else: - # Use base class validation - image_bytes, mime_type = self.validate_image(image_path) + # Validate with the shared image utility + image_bytes, mime_type = validate_image(image_path) # Read and encode the image import base64 diff --git a/tests/test_image_validation.py b/tests/test_image_validation.py index 9734f6c..c4454c8 100644 --- a/tests/test_image_validation.py +++ b/tests/test_image_validation.py @@ -1,57 +1,18 @@ -"""Tests for provider-independent image validation.""" +"""Tests for image validation utility helpers.""" import base64 import os import tempfile -from typing import Optional from unittest.mock import Mock, patch import pytest -from providers.base import ModelProvider -from providers.shared import ModelCapabilities, ModelResponse, ProviderType - - -class MinimalTestProvider(ModelProvider): - """Minimal concrete provider for testing base class methods.""" - - def get_capabilities(self, model_name: str) -> ModelCapabilities: - """Not needed for image validation tests.""" - raise NotImplementedError("Not needed for image validation tests") - - def generate_content( - self, - prompt: str, - model_name: str, - system_prompt: Optional[str] = None, - temperature: float = 0.7, - max_output_tokens: Optional[int] = None, - **kwargs, - ) -> ModelResponse: - """Not needed for image validation tests.""" - raise NotImplementedError("Not needed for image validation tests") - - def count_tokens(self, text: str, model_name: str) -> int: - """Not needed for image validation tests.""" - raise NotImplementedError("Not needed for image validation tests") - - def get_provider_type(self) -> ProviderType: - """Not needed for image validation tests.""" - raise NotImplementedError("Not needed for image validation tests") - - def validate_model_name(self, model_name: str) -> bool: - """Not needed for image validation tests.""" - raise NotImplementedError("Not needed for image validation tests") +from utils.image_utils import 
DEFAULT_MAX_IMAGE_SIZE_MB, validate_image class TestImageValidation: """Test suite for image validation functionality.""" - def setup_method(self) -> None: - """Set up test fixtures.""" - # Create a minimal concrete provider instance for testing base class methods - self.provider = MinimalTestProvider(api_key="test-key") - def test_validate_data_url_valid(self) -> None: """Test validation of valid data URL.""" # Create a small test image (1x1 PNG) @@ -60,7 +21,7 @@ class TestImageValidation: ) data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}" - image_bytes, mime_type = self.provider.validate_image(data_url) + image_bytes, mime_type = validate_image(data_url) assert image_bytes == test_image_data assert mime_type == "image/png" @@ -76,14 +37,14 @@ class TestImageValidation: def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None: """Test validation of malformed data URL.""" with pytest.raises(ValueError) as excinfo: - self.provider.validate_image(invalid_url) + validate_image(invalid_url) assert expected_error in str(excinfo.value) def test_non_data_url_treated_as_file_path(self) -> None: """Test that non-data URLs are treated as file paths.""" # Test case that's not a data URL at all with pytest.raises(ValueError) as excinfo: - self.provider.validate_image("image/png;base64,abc123") + validate_image("image/png;base64,abc123") assert "Image file not found" in str(excinfo.value) # Treated as file path def test_validate_data_url_unsupported_type(self) -> None: @@ -91,7 +52,7 @@ class TestImageValidation: data_url = "" # BMP format with pytest.raises(ValueError) as excinfo: - self.provider.validate_image(data_url) + validate_image(data_url) assert "Unsupported image type: image/bmp" in str(excinfo.value) def test_validate_data_url_invalid_base64(self) -> None: @@ -99,7 +60,7 @@ class TestImageValidation: data_url = "data:image/png;base64,@@@invalid@@@" with pytest.raises(ValueError) as excinfo: - 
self.provider.validate_image(data_url) + validate_image(data_url) assert "Invalid base64 data" in str(excinfo.value) def test_validate_large_data_url(self) -> None: @@ -115,11 +76,11 @@ class TestImageValidation: # Should fail with default 20MB limit with pytest.raises(ValueError) as excinfo: - self.provider.validate_image(data_url) - assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value) + validate_image(data_url) + assert f"Image too large: 21.0MB (max: {DEFAULT_MAX_IMAGE_SIZE_MB:.1f}MB)" in str(excinfo.value) # Should succeed with higher limit - image_bytes, mime_type = self.provider.validate_image(data_url, max_size_mb=25.0) + image_bytes, mime_type = validate_image(data_url, max_size_mb=25.0) assert len(image_bytes) == len(large_data) assert mime_type == "image/png" @@ -135,7 +96,7 @@ class TestImageValidation: tmp_file_path = tmp_file.name try: - image_bytes, mime_type = self.provider.validate_image(tmp_file_path) + image_bytes, mime_type = validate_image(tmp_file_path) assert image_bytes == test_image_data assert mime_type == "image/png" @@ -145,7 +106,7 @@ class TestImageValidation: def test_validate_file_path_not_found(self) -> None: """Test validation of non-existent file.""" with pytest.raises(ValueError) as excinfo: - self.provider.validate_image("/path/to/nonexistent/image.png") + validate_image("/path/to/nonexistent/image.png") assert "Image file not found" in str(excinfo.value) def test_validate_file_path_unsupported_extension(self) -> None: @@ -156,7 +117,7 @@ class TestImageValidation: try: with pytest.raises(ValueError) as excinfo: - self.provider.validate_image(tmp_file_path) + validate_image(tmp_file_path) assert "Unsupported image format: .bmp" in str(excinfo.value) finally: os.unlink(tmp_file_path) @@ -170,7 +131,7 @@ class TestImageValidation: os.unlink(tmp_file_path) with pytest.raises(ValueError) as excinfo: - self.provider.validate_image(tmp_file_path) + validate_image(tmp_file_path) assert "Image file not found" in 
str(excinfo.value) def test_validate_image_size_limit(self) -> None: @@ -184,7 +145,7 @@ class TestImageValidation: try: with pytest.raises(ValueError) as excinfo: - self.provider.validate_image(tmp_file_path, max_size_mb=20.0) + validate_image(tmp_file_path, max_size_mb=20.0) assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value) finally: os.unlink(tmp_file_path) @@ -201,11 +162,11 @@ class TestImageValidation: try: # Should fail with 1MB limit with pytest.raises(ValueError) as excinfo: - self.provider.validate_image(tmp_file_path, max_size_mb=1.0) + validate_image(tmp_file_path, max_size_mb=1.0) assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value) # Should succeed with 3MB limit - image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=3.0) + image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=3.0) assert len(image_bytes) == len(data) assert mime_type == "image/png" finally: @@ -222,12 +183,12 @@ class TestImageValidation: try: # Should succeed with default limit (20MB) - image_bytes, mime_type = self.provider.validate_image(tmp_file_path) + image_bytes, mime_type = validate_image(tmp_file_path) assert len(image_bytes) == len(data) assert mime_type == "image/jpeg" # Should also succeed when explicitly passing None - image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=None) + image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=None) assert len(image_bytes) == len(data) assert mime_type == "image/jpeg" finally: @@ -249,7 +210,7 @@ class TestImageValidation: tmp_file_path = tmp_file.name try: - image_bytes, mime_type = self.provider.validate_image(tmp_file_path) + image_bytes, mime_type = validate_image(tmp_file_path) assert mime_type == expected_mime assert image_bytes == b"dummy image data" finally: diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 0000000..621ea9a --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,94 
@@
+"""Utility helpers for validating image inputs."""
+
+import base64
+import binascii
+import os
+from collections.abc import Iterable
+from typing import Optional
+
+from utils.file_types import IMAGES, get_image_mime_type
+
+DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
+
+__all__ = ["DEFAULT_MAX_IMAGE_SIZE_MB", "validate_image"]
+
+
+def _valid_mime_types() -> Iterable[str]:
+    """Return the MIME types permitted by the IMAGES whitelist."""
+    return (get_image_mime_type(ext) for ext in IMAGES)
+
+
+def validate_image(image_path: str, max_size_mb: Optional[float] = None) -> tuple[bytes, str]:
+    """Validate a user-supplied image path or data URL.
+
+    Args:
+        image_path: Either a filesystem path or a data URL.
+        max_size_mb: Optional size limit (defaults to ``DEFAULT_MAX_IMAGE_SIZE_MB``).
+
+    Returns:
+        A tuple ``(image_bytes, mime_type)`` ready for upstream providers.
+
+    Raises:
+        ValueError: When the image is missing, malformed, or exceeds limits.
+    """
+    if max_size_mb is None:
+        max_size_mb = DEFAULT_MAX_IMAGE_SIZE_MB
+
+    if image_path.startswith("data:"):
+        return _validate_data_url(image_path, max_size_mb)
+
+    return _validate_file_path(image_path, max_size_mb)
+
+
+def _validate_data_url(image_data_url: str, max_size_mb: float) -> tuple[bytes, str]:
+    """Validate a data URL and return image bytes plus MIME type."""
+    try:
+        # The header is "data:<mime>[;base64]"; the payload follows the first comma.
+        header, data = image_data_url.split(",", 1)
+        mime_type = header.split(";")[0].split(":")[1]
+    except (ValueError, IndexError) as exc:
+        raise ValueError(f"Invalid data URL format: {exc}") from exc
+
+    # Materialise once: the list feeds both the membership test and the message.
+    valid_mime_types = list(_valid_mime_types())
+    if mime_type not in valid_mime_types:
+        supported = ", ".join(valid_mime_types)
+        raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {supported}")
+
+    try:
+        # NOTE: without validate=True, non-alphabet characters are discarded, not rejected.
+        image_bytes = base64.b64decode(data)
+    except binascii.Error as exc:
+        raise ValueError(f"Invalid base64 data: {exc}") from exc
+
+    # The limit applies to the decoded bytes, not to the base64 text.
+    _validate_size(image_bytes, max_size_mb)
+    return image_bytes, mime_type
+
+
+def _validate_file_path(file_path: str, max_size_mb: float) -> tuple[bytes, str]:
+    """Validate an image loaded from the filesystem."""
+    try:
+        # Read before checking the extension so a missing file is reported first.
+        with open(file_path, "rb") as handle:
+            image_bytes = handle.read()
+    except FileNotFoundError as exc:
+        raise ValueError(f"Image file not found: {file_path}") from exc
+    except OSError as exc:
+        raise ValueError(f"Failed to read image file: {exc}") from exc
+
+    ext = os.path.splitext(file_path)[1].lower()
+    if ext not in IMAGES:
+        supported = ", ".join(sorted(IMAGES))
+        raise ValueError(f"Unsupported image format: {ext}. Supported formats: {supported}")
+
+    mime_type = get_image_mime_type(ext)
+    _validate_size(image_bytes, max_size_mb)
+    return image_bytes, mime_type
+
+
+def _validate_size(image_bytes: bytes, max_size_mb: float) -> None:
+    """Ensure the image does not exceed the configured size limit."""
+    size_mb = len(image_bytes) / (1024 * 1024)
+    if size_mb > max_size_mb:
+        raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")