Merge pull request #192 from nsp/refactor-image-validation

refactor: Extract image validation to provider base class
This commit is contained in:
Beehive Innovations
2025-08-07 23:19:41 -07:00
committed by GitHub
4 changed files with 408 additions and 35 deletions

View File

@@ -1,6 +1,9 @@
"""Base model provider interface and data classes.""" """Base model provider interface and data classes."""
import base64
import binascii
import logging import logging
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
@@ -9,6 +12,8 @@ from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from tools.models import ToolModelCategory from tools.models import ToolModelCategory
from utils.file_types import IMAGES, get_image_mime_type
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -186,6 +191,9 @@ class ModelProvider(ABC):
# All concrete providers must define their supported models # All concrete providers must define their supported models
SUPPORTED_MODELS: dict[str, Any] = {} SUPPORTED_MODELS: dict[str, Any] = {}
# Default maximum image size in MB
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
def __init__(self, api_key: str, **kwargs): def __init__(self, api_key: str, **kwargs):
"""Initialize the provider with API key and optional configuration.""" """Initialize the provider with API key and optional configuration."""
self.api_key = api_key self.api_key = api_key
@@ -420,6 +428,77 @@ class ModelProvider(ABC):
return list(all_models) return list(all_models)
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
"""Provider-independent image validation.
Args:
image_path: Path to image file or data URL
max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)
Returns:
Tuple of (image_bytes, mime_type)
Raises:
ValueError: If image is invalid
Examples:
# Validate a file path
image_bytes, mime_type = provider.validate_image("/path/to/image.png")
# Validate a data URL
image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")
# Validate with custom size limit
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
"""
# Use default if not specified
if max_size_mb is None:
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
if image_path.startswith("data:"):
# Parse data URL: data:image/png;base64,iVBORw0...
try:
header, data = image_path.split(",", 1)
mime_type = header.split(";")[0].split(":")[1]
except (ValueError, IndexError) as e:
raise ValueError(f"Invalid data URL format: {e}")
# Validate MIME type using IMAGES constant
valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
if mime_type not in valid_mime_types:
raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")
# Decode base64 data
try:
image_bytes = base64.b64decode(data)
except binascii.Error as e:
raise ValueError(f"Invalid base64 data: {e}")
else:
# Handle file path
# Read file first to check if it exists
try:
with open(image_path, "rb") as f:
image_bytes = f.read()
except FileNotFoundError:
raise ValueError(f"Image file not found: {image_path}")
except Exception as e:
raise ValueError(f"Failed to read image file: {e}")
# Validate extension
ext = os.path.splitext(image_path)[1].lower()
if ext not in IMAGES:
raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")
# Get MIME type
mime_type = get_image_mime_type(ext)
# Validate size
size_mb = len(image_bytes) / (1024 * 1024)
if size_mb > max_size_mb:
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
return image_bytes, mime_type
def close(self): def close(self):
"""Clean up any resources held by the provider. """Clean up any resources held by the provider.

View File

@@ -2,7 +2,6 @@
import base64 import base64
import logging import logging
import os
import time import time
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
@@ -443,28 +442,22 @@ class GeminiModelProvider(ModelProvider):
def _process_image(self, image_path: str) -> Optional[dict]: def _process_image(self, image_path: str) -> Optional[dict]:
"""Process an image for Gemini API.""" """Process an image for Gemini API."""
try: try:
if image_path.startswith("data:image/"): # Use base class validation
# Handle data URL: data:image/png;base64,iVBORw0... image_bytes, mime_type = self.validate_image(image_path)
header, data = image_path.split(",", 1)
mime_type = header.split(";")[0].split(":")[1] # For data URLs, extract the base64 data directly
if image_path.startswith("data:"):
# Extract base64 data from data URL
_, data = image_path.split(",", 1)
return {"inline_data": {"mime_type": mime_type, "data": data}} return {"inline_data": {"mime_type": mime_type, "data": data}}
else: else:
# Handle file path # For file paths, encode the bytes
from utils.file_types import get_image_mime_type image_data = base64.b64encode(image_bytes).decode()
if not os.path.exists(image_path):
logger.warning(f"Image file not found: {image_path}")
return None
# Detect MIME type from file extension using centralized mappings
ext = os.path.splitext(image_path)[1].lower()
mime_type = get_image_mime_type(ext)
# Read and encode the image
with open(image_path, "rb") as f:
image_data = base64.b64encode(f.read()).decode()
return {"inline_data": {"mime_type": mime_type, "data": image_data}} return {"inline_data": {"mime_type": mime_type, "data": image_data}}
except ValueError as e:
logger.warning(str(e))
return None
except Exception as e: except Exception as e:
logger.error(f"Error processing image {image_path}: {e}") logger.error(f"Error processing image {image_path}: {e}")
return None return None

View File

@@ -1,6 +1,5 @@
"""Base class for OpenAI-compatible API providers.""" """Base class for OpenAI-compatible API providers."""
import base64
import copy import copy
import ipaddress import ipaddress
import logging import logging
@@ -847,30 +846,29 @@ class OpenAICompatibleProvider(ModelProvider):
def _process_image(self, image_path: str) -> Optional[dict]: def _process_image(self, image_path: str) -> Optional[dict]:
"""Process an image for OpenAI-compatible API.""" """Process an image for OpenAI-compatible API."""
try: try:
if image_path.startswith("data:image/"): if image_path.startswith("data:"):
# Validate the data URL
self.validate_image(image_path)
# Handle data URL: data:image/png;base64,iVBORw0... # Handle data URL: data:image/png;base64,iVBORw0...
return {"type": "image_url", "image_url": {"url": image_path}} return {"type": "image_url", "image_url": {"url": image_path}}
else: else:
# Handle file path # Use base class validation
if not os.path.exists(image_path): image_bytes, mime_type = self.validate_image(image_path)
logging.warning(f"Image file not found: {image_path}")
return None
# Detect MIME type from file extension using centralized mappings
from utils.file_types import get_image_mime_type
ext = os.path.splitext(image_path)[1].lower()
mime_type = get_image_mime_type(ext)
logging.debug(f"Processing image '{image_path}' with extension '{ext}' as MIME type '{mime_type}'")
# Read and encode the image # Read and encode the image
with open(image_path, "rb") as f: import base64
image_data = base64.b64encode(f.read()).decode()
image_data = base64.b64encode(image_bytes).decode()
logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'")
# Create data URL for OpenAI API # Create data URL for OpenAI API
data_url = f"data:{mime_type};base64,{image_data}" data_url = f"data:{mime_type};base64,{image_data}"
return {"type": "image_url", "image_url": {"url": data_url}} return {"type": "image_url", "image_url": {"url": data_url}}
except ValueError as e:
logging.warning(str(e))
return None
except Exception as e: except Exception as e:
logging.error(f"Error processing image {image_path}: {e}") logging.error(f"Error processing image {image_path}: {e}")
return None return None

View File

@@ -0,0 +1,303 @@
"""Tests for provider-independent image validation."""
import base64
import os
import tempfile
from typing import Optional
from unittest.mock import Mock, patch
import pytest
from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType
class MinimalTestProvider(ModelProvider):
"""Minimal concrete provider for testing base class methods."""
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def count_tokens(self, text: str, model_name: str) -> int:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def get_provider_type(self) -> ProviderType:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def validate_model_name(self, model_name: str) -> bool:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def supports_thinking_mode(self, model_name: str) -> bool:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
class TestImageValidation:
"""Test suite for image validation functionality."""
def setup_method(self) -> None:
"""Set up test fixtures."""
# Create a minimal concrete provider instance for testing base class methods
self.provider = MinimalTestProvider(api_key="test-key")
def test_validate_data_url_valid(self) -> None:
"""Test validation of valid data URL."""
# Create a small test image (1x1 PNG)
test_image_data = base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
)
data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}"
image_bytes, mime_type = self.provider.validate_image(data_url)
assert image_bytes == test_image_data
assert mime_type == "image/png"
@pytest.mark.parametrize(
"invalid_url,expected_error",
[
("data:image/png", "Invalid data URL format"), # Missing base64 part
("data:image/png;base64", "Invalid data URL format"), # Missing data
("data:text/plain;base64,dGVzdA==", "Unsupported image type"), # Not an image
],
)
def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None:
"""Test validation of malformed data URL."""
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(invalid_url)
assert expected_error in str(excinfo.value)
def test_non_data_url_treated_as_file_path(self) -> None:
"""Test that non-data URLs are treated as file paths."""
# Test case that's not a data URL at all
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image("image/png;base64,abc123")
assert "Image file not found" in str(excinfo.value) # Treated as file path
def test_validate_data_url_unsupported_type(self) -> None:
"""Test validation of unsupported image type in data URL."""
data_url = "data:image/bmp;base64,Qk0=" # BMP format
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url)
assert "Unsupported image type: image/bmp" in str(excinfo.value)
def test_validate_data_url_invalid_base64(self) -> None:
"""Test validation of data URL with invalid base64."""
data_url = "data:image/png;base64,@@@invalid@@@"
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url)
assert "Invalid base64 data" in str(excinfo.value)
def test_validate_large_data_url(self) -> None:
"""Test validation of large data URL to ensure size limits work."""
# Create a large image (21MB)
large_data = b"x" * (21 * 1024 * 1024) # 21MB
# Encode as base64 and create data URL
import base64
encoded_data = base64.b64encode(large_data).decode()
data_url = f"data:image/png;base64,{encoded_data}"
# Should fail with default 20MB limit
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url)
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
# Should succeed with higher limit
image_bytes, mime_type = self.provider.validate_image(data_url, max_size_mb=25.0)
assert len(image_bytes) == len(large_data)
assert mime_type == "image/png"
def test_validate_file_path_valid(self) -> None:
"""Test validation of valid image file."""
# Create a temporary image file
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
# Write a small test PNG
test_image_data = base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
)
tmp_file.write(test_image_data)
tmp_file_path = tmp_file.name
try:
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
assert image_bytes == test_image_data
assert mime_type == "image/png"
finally:
os.unlink(tmp_file_path)
def test_validate_file_path_not_found(self) -> None:
"""Test validation of non-existent file."""
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image("/path/to/nonexistent/image.png")
assert "Image file not found" in str(excinfo.value)
def test_validate_file_path_unsupported_extension(self) -> None:
"""Test validation of file with unsupported extension."""
with tempfile.NamedTemporaryFile(suffix=".bmp", delete=False) as tmp_file:
tmp_file.write(b"dummy data")
tmp_file_path = tmp_file.name
try:
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path)
assert "Unsupported image format: .bmp" in str(excinfo.value)
finally:
os.unlink(tmp_file_path)
def test_validate_file_path_read_error(self) -> None:
"""Test validation when file cannot be read."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
tmp_file_path = tmp_file.name
# Remove the file but keep the path
os.unlink(tmp_file_path)
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path)
assert "Image file not found" in str(excinfo.value)
def test_validate_image_size_limit(self) -> None:
"""Test validation of image size limits."""
# Create a large "image" (just random data)
large_data = b"x" * (21 * 1024 * 1024) # 21MB
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
tmp_file.write(large_data)
tmp_file_path = tmp_file.name
try:
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path, max_size_mb=20.0)
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
finally:
os.unlink(tmp_file_path)
def test_validate_image_custom_size_limit(self) -> None:
"""Test validation with custom size limit."""
# Create a 2MB "image"
data = b"x" * (2 * 1024 * 1024)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
tmp_file.write(data)
tmp_file_path = tmp_file.name
try:
# Should fail with 1MB limit
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path, max_size_mb=1.0)
assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value)
# Should succeed with 3MB limit
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=3.0)
assert len(image_bytes) == len(data)
assert mime_type == "image/png"
finally:
os.unlink(tmp_file_path)
def test_validate_image_default_size_limit(self) -> None:
"""Test validation with default size limit (None)."""
# Create a small image that's under the default limit
data = b"x" * (1024 * 1024) # 1MB
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
tmp_file.write(data)
tmp_file_path = tmp_file.name
try:
# Should succeed with default limit (20MB)
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
assert len(image_bytes) == len(data)
assert mime_type == "image/jpeg"
# Should also succeed when explicitly passing None
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=None)
assert len(image_bytes) == len(data)
assert mime_type == "image/jpeg"
finally:
os.unlink(tmp_file_path)
def test_validate_all_supported_formats(self) -> None:
"""Test validation of all supported image formats."""
supported_formats = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
}
for ext, expected_mime in supported_formats.items():
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
tmp_file.write(b"dummy image data")
tmp_file_path = tmp_file.name
try:
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
assert mime_type == expected_mime
assert image_bytes == b"dummy image data"
finally:
os.unlink(tmp_file_path)
class TestProviderIntegration:
"""Test image validation integration with different providers."""
@patch("providers.gemini.logger")
def test_gemini_provider_uses_validation(self, mock_logger: Mock) -> None:
"""Test that Gemini provider uses the base validation."""
from providers.gemini import GeminiModelProvider
# Create a provider instance
provider = GeminiModelProvider(api_key="test-key")
# Test with non-existent file
result = provider._process_image("/nonexistent/image.png")
assert result is None
mock_logger.warning.assert_called_with("Image file not found: /nonexistent/image.png")
@patch("providers.openai_compatible.logging")
def test_openai_compatible_provider_uses_validation(self, mock_logging: Mock) -> None:
"""Test that OpenAI-compatible providers use the base validation."""
from providers.xai import XAIModelProvider
# Create a provider instance (XAI inherits from OpenAICompatibleProvider)
provider = XAIModelProvider(api_key="test-key")
# Test with non-existent file
result = provider._process_image("/nonexistent/image.png")
assert result is None
mock_logging.warning.assert_called_with("Image file not found: /nonexistent/image.png")
def test_data_url_preservation(self) -> None:
"""Test that data URLs are properly preserved through validation."""
from providers.xai import XAIModelProvider
provider = XAIModelProvider(api_key="test-key")
# Valid data URL
data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
result = provider._process_image(data_url)
assert result is not None
assert result["type"] == "image_url"
assert result["image_url"]["url"] == data_url