refactor: Extract image validation to provider base class

Consolidates duplicated image validation logic from individual providers
into a reusable base class method. This improves maintainability and
ensures consistent validation across all providers.

- Added validate_image() method to ModelProvider base class
- Supports both file paths and data URLs
- Validates image format, size, and MIME types
- Added DEFAULT_MAX_IMAGE_SIZE_MB class constant (20MB)
- Refactored Gemini and OpenAI providers to use base validation
- Added comprehensive test suite with 19 tests
- Used minimal mocking approach with concrete test provider class
This commit is contained in:
Nate Parsons
2025-07-10 22:35:07 -07:00
parent ad6b216265
commit 70d6cf8b54
4 changed files with 409 additions and 35 deletions

View File

@@ -196,6 +196,9 @@ class ModelProvider(ABC):
# All concrete providers must define their supported models
SUPPORTED_MODELS: dict[str, Any] = {}
# Default maximum image size in MB
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
def __init__(self, api_key: str, **kwargs):
"""Initialize the provider with API key and optional configuration."""
self.api_key = api_key
@@ -433,6 +436,83 @@ class ModelProvider(ABC):
return list(all_models)
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
"""Provider-independent image validation.
Args:
image_path: Path to image file or data URL
max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)
Returns:
Tuple of (image_bytes, mime_type)
Raises:
ValueError: If image is invalid
Examples:
# Validate a file path
image_bytes, mime_type = provider.validate_image("/path/to/image.png")
# Validate a data URL
image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")
# Validate with custom size limit
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
"""
import base64
import os
from utils.file_types import IMAGES, get_image_mime_type
# Use default if not specified
if max_size_mb is None:
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
if image_path.startswith("data:"):
# Parse data URL: data:image/png;base64,iVBORw0...
try:
header, data = image_path.split(",", 1)
mime_type = header.split(";")[0].split(":")[1]
except (ValueError, IndexError) as e:
raise ValueError(f"Invalid data URL format: {e}")
# Validate MIME type using IMAGES constant
valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
if mime_type not in valid_mime_types:
raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")
# Decode base64 data
try:
image_bytes = base64.b64decode(data)
except Exception as e:
raise ValueError(f"Invalid base64 data: {e}")
else:
# Handle file path
if not os.path.exists(image_path):
raise ValueError(f"Image file not found: {image_path}")
# Validate extension
ext = os.path.splitext(image_path)[1].lower()
if ext not in IMAGES:
raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")
# Get MIME type
mime_type = get_image_mime_type(ext)
# Read file
try:
with open(image_path, "rb") as f:
image_bytes = f.read()
except Exception as e:
raise ValueError(f"Failed to read image file: {e}")
# Validate size
size_mb = len(image_bytes) / (1024 * 1024)
if size_mb > max_size_mb:
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
return image_bytes, mime_type
def close(self):
"""Clean up any resources held by the provider.