refactor: Extract image validation to provider base class
Consolidates duplicated image validation logic from individual providers into a reusable base class method. This improves maintainability and ensures consistent validation across all providers. - Added validate_image() method to ModelProvider base class - Supports both file paths and data URLs - Validates image format, size, and MIME types - Added DEFAULT_MAX_IMAGE_SIZE_MB class constant (20MB) - Refactored Gemini and OpenAI providers to use base validation - Added comprehensive test suite with 19 tests - Used minimal mocking approach with concrete test provider class
This commit is contained in:
@@ -196,6 +196,9 @@ class ModelProvider(ABC):
|
||||
# All concrete providers must define their supported models
|
||||
SUPPORTED_MODELS: dict[str, Any] = {}
|
||||
|
||||
# Default maximum image size in MB
|
||||
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize the provider with API key and optional configuration."""
|
||||
self.api_key = api_key
|
||||
@@ -433,6 +436,83 @@ class ModelProvider(ABC):
|
||||
|
||||
return list(all_models)
|
||||
|
||||
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
|
||||
"""Provider-independent image validation.
|
||||
|
||||
Args:
|
||||
image_path: Path to image file or data URL
|
||||
max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)
|
||||
|
||||
Returns:
|
||||
Tuple of (image_bytes, mime_type)
|
||||
|
||||
Raises:
|
||||
ValueError: If image is invalid
|
||||
|
||||
Examples:
|
||||
# Validate a file path
|
||||
image_bytes, mime_type = provider.validate_image("/path/to/image.png")
|
||||
|
||||
# Validate a data URL
|
||||
image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")
|
||||
|
||||
# Validate with custom size limit
|
||||
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
|
||||
from utils.file_types import IMAGES, get_image_mime_type
|
||||
|
||||
# Use default if not specified
|
||||
if max_size_mb is None:
|
||||
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
|
||||
|
||||
if image_path.startswith("data:"):
|
||||
# Parse data URL: ...
|
||||
try:
|
||||
header, data = image_path.split(",", 1)
|
||||
mime_type = header.split(";")[0].split(":")[1]
|
||||
except (ValueError, IndexError) as e:
|
||||
raise ValueError(f"Invalid data URL format: {e}")
|
||||
|
||||
# Validate MIME type using IMAGES constant
|
||||
valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
|
||||
if mime_type not in valid_mime_types:
|
||||
raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")
|
||||
|
||||
# Decode base64 data
|
||||
try:
|
||||
image_bytes = base64.b64decode(data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid base64 data: {e}")
|
||||
else:
|
||||
# Handle file path
|
||||
if not os.path.exists(image_path):
|
||||
raise ValueError(f"Image file not found: {image_path}")
|
||||
|
||||
# Validate extension
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
if ext not in IMAGES:
|
||||
raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")
|
||||
|
||||
# Get MIME type
|
||||
mime_type = get_image_mime_type(ext)
|
||||
|
||||
# Read file
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to read image file: {e}")
|
||||
|
||||
# Validate size
|
||||
size_mb = len(image_bytes) / (1024 * 1024)
|
||||
if size_mb > max_size_mb:
|
||||
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
|
||||
|
||||
return image_bytes, mime_type
|
||||
|
||||
def close(self):
|
||||
"""Clean up any resources held by the provider.
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
@@ -440,28 +439,22 @@ class GeminiModelProvider(ModelProvider):
|
||||
def _process_image(self, image_path: str) -> Optional[dict]:
|
||||
"""Process an image for Gemini API."""
|
||||
try:
|
||||
if image_path.startswith("...
|
||||
header, data = image_path.split(",", 1)
|
||||
mime_type = header.split(";")[0].split(":")[1]
|
||||
# Use base class validation
|
||||
image_bytes, mime_type = self.validate_image(image_path)
|
||||
|
||||
# For data URLs, extract the base64 data directly
|
||||
if image_path.startswith("data:"):
|
||||
# Extract base64 data from data URL
|
||||
_, data = image_path.split(",", 1)
|
||||
return {"inline_data": {"mime_type": mime_type, "data": data}}
|
||||
else:
|
||||
# Handle file path
|
||||
from utils.file_types import get_image_mime_type
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
logger.warning(f"Image file not found: {image_path}")
|
||||
return None
|
||||
|
||||
# Detect MIME type from file extension using centralized mappings
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
mime_type = get_image_mime_type(ext)
|
||||
|
||||
# Read and encode the image
|
||||
with open(image_path, "rb") as f:
|
||||
image_data = base64.b64encode(f.read()).decode()
|
||||
|
||||
# For file paths, encode the bytes
|
||||
image_data = base64.b64encode(image_bytes).decode()
|
||||
return {"inline_data": {"mime_type": mime_type, "data": image_data}}
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(str(e))
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image {image_path}: {e}")
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Base class for OpenAI-compatible API providers."""
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
@@ -788,30 +787,29 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
def _process_image(self, image_path: str) -> Optional[dict]:
|
||||
"""Process an image for OpenAI-compatible API."""
|
||||
try:
|
||||
if image_path.startswith("...
|
||||
return {"type": "image_url", "image_url": {"url": image_path}}
|
||||
else:
|
||||
# Handle file path
|
||||
if not os.path.exists(image_path):
|
||||
logging.warning(f"Image file not found: {image_path}")
|
||||
return None
|
||||
|
||||
# Detect MIME type from file extension using centralized mappings
|
||||
from utils.file_types import get_image_mime_type
|
||||
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
mime_type = get_image_mime_type(ext)
|
||||
logging.debug(f"Processing image '{image_path}' with extension '{ext}' as MIME type '{mime_type}'")
|
||||
# Use base class validation
|
||||
image_bytes, mime_type = self.validate_image(image_path)
|
||||
|
||||
# Read and encode the image
|
||||
with open(image_path, "rb") as f:
|
||||
image_data = base64.b64encode(f.read()).decode()
|
||||
import base64
|
||||
|
||||
image_data = base64.b64encode(image_bytes).decode()
|
||||
logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'")
|
||||
|
||||
# Create data URL for OpenAI API
|
||||
data_url = f"data:{mime_type};base64,{image_data}"
|
||||
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
except ValueError as e:
|
||||
logging.warning(str(e))
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing image {image_path}: {e}")
|
||||
return None
|
||||
|
||||
303
tests/test_image_validation.py
Normal file
303
tests/test_image_validation.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Tests for provider-independent image validation."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType
|
||||
|
||||
|
||||
class MinimalTestProvider(ModelProvider):
|
||||
"""Minimal concrete provider for testing base class methods."""
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
prompt: str,
|
||||
model_name: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
def count_tokens(self, text: str, model_name: str) -> int:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
|
||||
class TestImageValidation:
|
||||
"""Test suite for image validation functionality."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
"""Set up test fixtures."""
|
||||
# Create a minimal concrete provider instance for testing base class methods
|
||||
self.provider = MinimalTestProvider(api_key="test-key")
|
||||
|
||||
def test_validate_data_url_valid(self) -> None:
|
||||
"""Test validation of valid data URL."""
|
||||
# Create a small test image (1x1 PNG)
|
||||
test_image_data = base64.b64decode(
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
|
||||
)
|
||||
data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}"
|
||||
|
||||
image_bytes, mime_type = self.provider.validate_image(data_url)
|
||||
|
||||
assert image_bytes == test_image_data
|
||||
assert mime_type == "image/png"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_url,expected_error",
|
||||
[
|
||||
("data:image/png", "Invalid data URL format"), # Missing base64 part
|
||||
("data:image/png;base64", "Invalid data URL format"), # Missing data
|
||||
("data:text/plain;base64,dGVzdA==", "Unsupported image type"), # Not an image
|
||||
],
|
||||
)
|
||||
def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None:
|
||||
"""Test validation of malformed data URL."""
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(invalid_url)
|
||||
assert expected_error in str(excinfo.value)
|
||||
|
||||
def test_non_data_url_treated_as_file_path(self) -> None:
|
||||
"""Test that non-data URLs are treated as file paths."""
|
||||
# Test case that's not a data URL at all
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image("image/png;base64,abc123")
|
||||
assert "Image file not found" in str(excinfo.value) # Treated as file path
|
||||
|
||||
def test_validate_data_url_unsupported_type(self) -> None:
|
||||
"""Test validation of unsupported image type in data URL."""
|
||||
data_url = "" # BMP format
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(data_url)
|
||||
assert "Unsupported image type: image/bmp" in str(excinfo.value)
|
||||
|
||||
def test_validate_data_url_invalid_base64(self) -> None:
|
||||
"""Test validation of data URL with invalid base64."""
|
||||
data_url = "data:image/png;base64,@@@invalid@@@"
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(data_url)
|
||||
assert "Invalid base64 data" in str(excinfo.value)
|
||||
|
||||
def test_validate_large_data_url(self) -> None:
|
||||
"""Test validation of large data URL to ensure size limits work."""
|
||||
# Create a large image (21MB)
|
||||
large_data = b"x" * (21 * 1024 * 1024) # 21MB
|
||||
|
||||
# Encode as base64 and create data URL
|
||||
import base64
|
||||
|
||||
encoded_data = base64.b64encode(large_data).decode()
|
||||
data_url = f"data:image/png;base64,{encoded_data}"
|
||||
|
||||
# Should fail with default 20MB limit
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(data_url)
|
||||
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
|
||||
|
||||
# Should succeed with higher limit
|
||||
image_bytes, mime_type = self.provider.validate_image(data_url, max_size_mb=25.0)
|
||||
assert len(image_bytes) == len(large_data)
|
||||
assert mime_type == "image/png"
|
||||
|
||||
def test_validate_file_path_valid(self) -> None:
|
||||
"""Test validation of valid image file."""
|
||||
# Create a temporary image file
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
|
||||
# Write a small test PNG
|
||||
test_image_data = base64.b64decode(
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
|
||||
)
|
||||
tmp_file.write(test_image_data)
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
|
||||
|
||||
assert image_bytes == test_image_data
|
||||
assert mime_type == "image/png"
|
||||
finally:
|
||||
os.unlink(tmp_file_path)
|
||||
|
||||
def test_validate_file_path_not_found(self) -> None:
|
||||
"""Test validation of non-existent file."""
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image("/path/to/nonexistent/image.png")
|
||||
assert "Image file not found" in str(excinfo.value)
|
||||
|
||||
def test_validate_file_path_unsupported_extension(self) -> None:
|
||||
"""Test validation of file with unsupported extension."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".bmp", delete=False) as tmp_file:
|
||||
tmp_file.write(b"dummy data")
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(tmp_file_path)
|
||||
assert "Unsupported image format: .bmp" in str(excinfo.value)
|
||||
finally:
|
||||
os.unlink(tmp_file_path)
|
||||
|
||||
def test_validate_file_path_read_error(self) -> None:
|
||||
"""Test validation when file cannot be read."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
# Remove the file but keep the path
|
||||
os.unlink(tmp_file_path)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(tmp_file_path)
|
||||
assert "Image file not found" in str(excinfo.value)
|
||||
|
||||
def test_validate_image_size_limit(self) -> None:
|
||||
"""Test validation of image size limits."""
|
||||
# Create a large "image" (just random data)
|
||||
large_data = b"x" * (21 * 1024 * 1024) # 21MB
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
|
||||
tmp_file.write(large_data)
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(tmp_file_path, max_size_mb=20.0)
|
||||
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
|
||||
finally:
|
||||
os.unlink(tmp_file_path)
|
||||
|
||||
def test_validate_image_custom_size_limit(self) -> None:
|
||||
"""Test validation with custom size limit."""
|
||||
# Create a 2MB "image"
|
||||
data = b"x" * (2 * 1024 * 1024)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
|
||||
tmp_file.write(data)
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
# Should fail with 1MB limit
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(tmp_file_path, max_size_mb=1.0)
|
||||
assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value)
|
||||
|
||||
# Should succeed with 3MB limit
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=3.0)
|
||||
assert len(image_bytes) == len(data)
|
||||
assert mime_type == "image/png"
|
||||
finally:
|
||||
os.unlink(tmp_file_path)
|
||||
|
||||
def test_validate_image_default_size_limit(self) -> None:
|
||||
"""Test validation with default size limit (None)."""
|
||||
# Create a small image that's under the default limit
|
||||
data = b"x" * (1024 * 1024) # 1MB
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
|
||||
tmp_file.write(data)
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
# Should succeed with default limit (20MB)
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
|
||||
assert len(image_bytes) == len(data)
|
||||
assert mime_type == "image/jpeg"
|
||||
|
||||
# Should also succeed when explicitly passing None
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=None)
|
||||
assert len(image_bytes) == len(data)
|
||||
assert mime_type == "image/jpeg"
|
||||
finally:
|
||||
os.unlink(tmp_file_path)
|
||||
|
||||
def test_validate_all_supported_formats(self) -> None:
|
||||
"""Test validation of all supported image formats."""
|
||||
supported_formats = {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
}
|
||||
|
||||
for ext, expected_mime in supported_formats.items():
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
|
||||
tmp_file.write(b"dummy image data")
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
|
||||
assert mime_type == expected_mime
|
||||
assert image_bytes == b"dummy image data"
|
||||
finally:
|
||||
os.unlink(tmp_file_path)
|
||||
|
||||
|
||||
class TestProviderIntegration:
|
||||
"""Test image validation integration with different providers."""
|
||||
|
||||
@patch("providers.gemini.logger")
|
||||
def test_gemini_provider_uses_validation(self, mock_logger: Mock) -> None:
|
||||
"""Test that Gemini provider uses the base validation."""
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
# Create a provider instance
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Test with non-existent file
|
||||
result = provider._process_image("/nonexistent/image.png")
|
||||
assert result is None
|
||||
mock_logger.warning.assert_called_with("Image file not found: /nonexistent/image.png")
|
||||
|
||||
@patch("providers.openai_compatible.logging")
|
||||
def test_openai_compatible_provider_uses_validation(self, mock_logging: Mock) -> None:
|
||||
"""Test that OpenAI-compatible providers use the base validation."""
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
# Create a provider instance (XAI inherits from OpenAICompatibleProvider)
|
||||
provider = XAIModelProvider(api_key="test-key")
|
||||
|
||||
# Test with non-existent file
|
||||
result = provider._process_image("/nonexistent/image.png")
|
||||
assert result is None
|
||||
mock_logging.warning.assert_called_with("Image file not found: /nonexistent/image.png")
|
||||
|
||||
def test_data_url_preservation(self) -> None:
|
||||
"""Test that data URLs are properly preserved through validation."""
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
provider = XAIModelProvider(api_key="test-key")
|
||||
|
||||
# Valid data URL
|
||||
data_url = ""
|
||||
|
||||
result = provider._process_image(data_url)
|
||||
assert result is not None
|
||||
assert result["type"] == "image_url"
|
||||
assert result["image_url"]["url"] == data_url
|
||||
Reference in New Issue
Block a user