refactor: Extract image validation to provider base class
Consolidates duplicated image validation logic from individual providers into a reusable base class method. This improves maintainability and ensures consistent validation across all providers. - Added validate_image() method to ModelProvider base class - Supports both file paths and data URLs - Validates image format, size, and MIME types - Added DEFAULT_MAX_IMAGE_SIZE_MB class constant (20MB) - Refactored Gemini and OpenAI providers to use base validation - Added comprehensive test suite with 19 tests - Used minimal mocking approach with concrete test provider class
This commit is contained in:
@@ -196,6 +196,9 @@ class ModelProvider(ABC):
|
||||
# All concrete providers must define their supported models
|
||||
SUPPORTED_MODELS: dict[str, Any] = {}
|
||||
|
||||
# Default maximum image size in MB
|
||||
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize the provider with API key and optional configuration."""
|
||||
self.api_key = api_key
|
||||
@@ -433,6 +436,83 @@ class ModelProvider(ABC):
|
||||
|
||||
return list(all_models)
|
||||
|
||||
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
|
||||
"""Provider-independent image validation.
|
||||
|
||||
Args:
|
||||
image_path: Path to image file or data URL
|
||||
max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)
|
||||
|
||||
Returns:
|
||||
Tuple of (image_bytes, mime_type)
|
||||
|
||||
Raises:
|
||||
ValueError: If image is invalid
|
||||
|
||||
Examples:
|
||||
# Validate a file path
|
||||
image_bytes, mime_type = provider.validate_image("/path/to/image.png")
|
||||
|
||||
# Validate a data URL
|
||||
image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")
|
||||
|
||||
# Validate with custom size limit
|
||||
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
|
||||
from utils.file_types import IMAGES, get_image_mime_type
|
||||
|
||||
# Use default if not specified
|
||||
if max_size_mb is None:
|
||||
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
|
||||
|
||||
if image_path.startswith("data:"):
|
||||
# Parse data URL: data:image/png;base64,iVBORw0...
|
||||
try:
|
||||
header, data = image_path.split(",", 1)
|
||||
mime_type = header.split(";")[0].split(":")[1]
|
||||
except (ValueError, IndexError) as e:
|
||||
raise ValueError(f"Invalid data URL format: {e}")
|
||||
|
||||
# Validate MIME type using IMAGES constant
|
||||
valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
|
||||
if mime_type not in valid_mime_types:
|
||||
raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")
|
||||
|
||||
# Decode base64 data
|
||||
try:
|
||||
image_bytes = base64.b64decode(data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid base64 data: {e}")
|
||||
else:
|
||||
# Handle file path
|
||||
if not os.path.exists(image_path):
|
||||
raise ValueError(f"Image file not found: {image_path}")
|
||||
|
||||
# Validate extension
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
if ext not in IMAGES:
|
||||
raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")
|
||||
|
||||
# Get MIME type
|
||||
mime_type = get_image_mime_type(ext)
|
||||
|
||||
# Read file
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to read image file: {e}")
|
||||
|
||||
# Validate size
|
||||
size_mb = len(image_bytes) / (1024 * 1024)
|
||||
if size_mb > max_size_mb:
|
||||
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
|
||||
|
||||
return image_bytes, mime_type
|
||||
|
||||
def close(self):
|
||||
"""Clean up any resources held by the provider.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user