refactor: Extract image validation to provider base class

Consolidates duplicated image validation logic from individual providers
into a reusable base class method. This improves maintainability and
ensures consistent validation across all providers.

- Added validate_image() method to ModelProvider base class
- Supports both file paths and data URLs
- Validates image format, size, and MIME types
- Added DEFAULT_MAX_IMAGE_SIZE_MB class constant (20MB)
- Refactored Gemini and OpenAI providers to use base validation
- Added comprehensive test suite with 19 tests
- Used minimal mocking approach with concrete test provider class
This commit is contained in:
Nate Parsons
2025-07-10 22:35:07 -07:00
parent ad6b216265
commit 70d6cf8b54
4 changed files with 409 additions and 35 deletions

View File

@@ -196,6 +196,9 @@ class ModelProvider(ABC):
# All concrete providers must define their supported models
SUPPORTED_MODELS: dict[str, Any] = {}
# Default maximum image size in MB
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
def __init__(self, api_key: str, **kwargs):
"""Initialize the provider with API key and optional configuration."""
self.api_key = api_key
@@ -433,6 +436,83 @@ class ModelProvider(ABC):
return list(all_models)
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
"""Provider-independent image validation.
Args:
image_path: Path to image file or data URL
max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)
Returns:
Tuple of (image_bytes, mime_type)
Raises:
ValueError: If image is invalid
Examples:
# Validate a file path
image_bytes, mime_type = provider.validate_image("/path/to/image.png")
# Validate a data URL
image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")
# Validate with custom size limit
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
"""
import base64
import os
from utils.file_types import IMAGES, get_image_mime_type
# Use default if not specified
if max_size_mb is None:
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
if image_path.startswith("data:"):
# Parse data URL: data:image/png;base64,iVBORw0...
try:
header, data = image_path.split(",", 1)
mime_type = header.split(";")[0].split(":")[1]
except (ValueError, IndexError) as e:
raise ValueError(f"Invalid data URL format: {e}")
# Validate MIME type using IMAGES constant
valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
if mime_type not in valid_mime_types:
raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")
# Decode base64 data
try:
image_bytes = base64.b64decode(data)
except Exception as e:
raise ValueError(f"Invalid base64 data: {e}")
else:
# Handle file path
if not os.path.exists(image_path):
raise ValueError(f"Image file not found: {image_path}")
# Validate extension
ext = os.path.splitext(image_path)[1].lower()
if ext not in IMAGES:
raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")
# Get MIME type
mime_type = get_image_mime_type(ext)
# Read file
try:
with open(image_path, "rb") as f:
image_bytes = f.read()
except Exception as e:
raise ValueError(f"Failed to read image file: {e}")
# Validate size
size_mb = len(image_bytes) / (1024 * 1024)
if size_mb > max_size_mb:
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
return image_bytes, mime_type
def close(self):
"""Clean up any resources held by the provider.

View File

@@ -2,7 +2,6 @@
import base64
import logging
import os
import time
from typing import Optional
@@ -440,28 +439,22 @@ class GeminiModelProvider(ModelProvider):
def _process_image(self, image_path: str) -> Optional[dict]:
"""Process an image for Gemini API."""
try:
if image_path.startswith("data:image/"):
# Handle data URL: data:image/png;base64,iVBORw0...
header, data = image_path.split(",", 1)
mime_type = header.split(";")[0].split(":")[1]
# Use base class validation
image_bytes, mime_type = self.validate_image(image_path)
# For data URLs, extract the base64 data directly
if image_path.startswith("data:"):
# Extract base64 data from data URL
_, data = image_path.split(",", 1)
return {"inline_data": {"mime_type": mime_type, "data": data}}
else:
# Handle file path
from utils.file_types import get_image_mime_type
if not os.path.exists(image_path):
logger.warning(f"Image file not found: {image_path}")
return None
# Detect MIME type from file extension using centralized mappings
ext = os.path.splitext(image_path)[1].lower()
mime_type = get_image_mime_type(ext)
# Read and encode the image
with open(image_path, "rb") as f:
image_data = base64.b64encode(f.read()).decode()
# For file paths, encode the bytes
image_data = base64.b64encode(image_bytes).decode()
return {"inline_data": {"mime_type": mime_type, "data": image_data}}
except ValueError as e:
logger.warning(str(e))
return None
except Exception as e:
logger.error(f"Error processing image {image_path}: {e}")
return None

View File

@@ -1,6 +1,5 @@
"""Base class for OpenAI-compatible API providers."""
import base64
import ipaddress
import logging
import os
@@ -788,30 +787,29 @@ class OpenAICompatibleProvider(ModelProvider):
def _process_image(self, image_path: str) -> Optional[dict]:
"""Process an image for OpenAI-compatible API."""
try:
if image_path.startswith("data:image/"):
if image_path.startswith("data:"):
# Validate the data URL
self.validate_image(image_path)
# Handle data URL: data:image/png;base64,iVBORw0...
return {"type": "image_url", "image_url": {"url": image_path}}
else:
# Handle file path
if not os.path.exists(image_path):
logging.warning(f"Image file not found: {image_path}")
return None
# Detect MIME type from file extension using centralized mappings
from utils.file_types import get_image_mime_type
ext = os.path.splitext(image_path)[1].lower()
mime_type = get_image_mime_type(ext)
logging.debug(f"Processing image '{image_path}' with extension '{ext}' as MIME type '{mime_type}'")
# Use base class validation
image_bytes, mime_type = self.validate_image(image_path)
# Read and encode the image
with open(image_path, "rb") as f:
image_data = base64.b64encode(f.read()).decode()
import base64
image_data = base64.b64encode(image_bytes).decode()
logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'")
# Create data URL for OpenAI API
data_url = f"data:{mime_type};base64,{image_data}"
return {"type": "image_url", "image_url": {"url": data_url}}
except ValueError as e:
logging.warning(str(e))
return None
except Exception as e:
logging.error(f"Error processing image {image_path}: {e}")
return None