refactor: moved image related code out of base provider into a separate utility

This commit is contained in:
Fahad
2025-10-02 11:23:15 +04:00
parent a254ff2220
commit 14a35afa1d
5 changed files with 122 additions and 152 deletions

View File

@@ -1,17 +1,12 @@
"""Base interfaces and common behaviour for model providers.""" """Base interfaces and common behaviour for model providers."""
import base64
import binascii
import logging import logging
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from tools.models import ToolModelCategory from tools.models import ToolModelCategory
from utils.file_types import IMAGES, get_image_mime_type
from .shared import ModelCapabilities, ModelResponse, ProviderType from .shared import ModelCapabilities, ModelResponse, ProviderType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -43,9 +38,6 @@ class ModelProvider(ABC):
# All concrete providers must define their supported models # All concrete providers must define their supported models
MODEL_CAPABILITIES: dict[str, Any] = {} MODEL_CAPABILITIES: dict[str, Any] = {}
# Default maximum image size in MB
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
def __init__(self, api_key: str, **kwargs): def __init__(self, api_key: str, **kwargs):
"""Initialize the provider with API key and optional configuration.""" """Initialize the provider with API key and optional configuration."""
self.api_key = api_key self.api_key = api_key
@@ -167,17 +159,7 @@ class ModelProvider(ABC):
lowercase: bool = False, lowercase: bool = False,
unique: bool = False, unique: bool = False,
) -> list[str]: ) -> list[str]:
"""Return formatted model names supported by this provider. """Return formatted model names supported by this provider."""
Args:
respect_restrictions: Apply provider restriction policy.
include_aliases: Include aliases alongside canonical model names.
lowercase: Normalize returned names to lowercase.
unique: Deduplicate names after formatting.
Returns:
List of model names formatted according to the provided options.
"""
model_configs = self.get_model_configurations() model_configs = self.get_model_configurations()
if not model_configs: if not model_configs:
@@ -206,77 +188,6 @@ class ModelProvider(ABC):
unique=unique, unique=unique,
) )
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
"""Provider-independent image validation.
Args:
image_path: Path to image file or data URL
max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)
Returns:
Tuple of (image_bytes, mime_type)
Raises:
ValueError: If image is invalid
Examples:
# Validate a file path
image_bytes, mime_type = provider.validate_image("/path/to/image.png")
# Validate a data URL
image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")
# Validate with custom size limit
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
"""
# Use default if not specified
if max_size_mb is None:
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
if image_path.startswith("data:"):
# Parse data URL: data:image/png;base64,iVBORw0...
try:
header, data = image_path.split(",", 1)
mime_type = header.split(";")[0].split(":")[1]
except (ValueError, IndexError) as e:
raise ValueError(f"Invalid data URL format: {e}")
# Validate MIME type using IMAGES constant
valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
if mime_type not in valid_mime_types:
raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")
# Decode base64 data
try:
image_bytes = base64.b64decode(data)
except binascii.Error as e:
raise ValueError(f"Invalid base64 data: {e}")
else:
# Handle file path
# Read file first to check if it exists
try:
with open(image_path, "rb") as f:
image_bytes = f.read()
except FileNotFoundError:
raise ValueError(f"Image file not found: {image_path}")
except Exception as e:
raise ValueError(f"Failed to read image file: {e}")
# Validate extension
ext = os.path.splitext(image_path)[1].lower()
if ext not in IMAGES:
raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")
# Get MIME type
mime_type = get_image_mime_type(ext)
# Validate size
size_mb = len(image_bytes) / (1024 * 1024)
if size_mb > max_size_mb:
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
return image_bytes, mime_type
def close(self): def close(self):
"""Clean up any resources held by the provider. """Clean up any resources held by the provider.

View File

@@ -11,6 +11,8 @@ if TYPE_CHECKING:
from google import genai from google import genai
from google.genai import types from google.genai import types
from utils.image_utils import validate_image
from .base import ModelProvider from .base import ModelProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
@@ -529,7 +531,7 @@ class GeminiModelProvider(ModelProvider):
"""Process an image for Gemini API.""" """Process an image for Gemini API."""
try: try:
# Use base class validation # Use base class validation
image_bytes, mime_type = self.validate_image(image_path) image_bytes, mime_type = validate_image(image_path)
# For data URLs, extract the base64 data directly # For data URLs, extract the base64 data directly
if image_path.startswith("data:"): if image_path.startswith("data:"):

View File

@@ -11,6 +11,8 @@ from urllib.parse import urlparse
from openai import OpenAI from openai import OpenAI
from utils.image_utils import validate_image
from .base import ModelProvider from .base import ModelProvider
from .shared import ( from .shared import (
ModelCapabilities, ModelCapabilities,
@@ -830,12 +832,12 @@ class OpenAICompatibleProvider(ModelProvider):
try: try:
if image_path.startswith("data:"): if image_path.startswith("data:"):
# Validate the data URL # Validate the data URL
self.validate_image(image_path) validate_image(image_path)
# Handle data URL: data:image/png;base64,iVBORw0... # Handle data URL: data:image/png;base64,iVBORw0...
return {"type": "image_url", "image_url": {"url": image_path}} return {"type": "image_url", "image_url": {"url": image_path}}
else: else:
# Use base class validation # Use base class validation
image_bytes, mime_type = self.validate_image(image_path) image_bytes, mime_type = validate_image(image_path)
# Read and encode the image # Read and encode the image
import base64 import base64

View File

@@ -1,57 +1,18 @@
"""Tests for provider-independent image validation.""" """Tests for image validation utility helpers."""
import base64 import base64
import os import os
import tempfile import tempfile
from typing import Optional
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
from providers.base import ModelProvider from utils.image_utils import DEFAULT_MAX_IMAGE_SIZE_MB, validate_image
from providers.shared import ModelCapabilities, ModelResponse, ProviderType
class MinimalTestProvider(ModelProvider):
"""Minimal concrete provider for testing base class methods."""
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def count_tokens(self, text: str, model_name: str) -> int:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def get_provider_type(self) -> ProviderType:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def validate_model_name(self, model_name: str) -> bool:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
class TestImageValidation: class TestImageValidation:
"""Test suite for image validation functionality.""" """Test suite for image validation functionality."""
def setup_method(self) -> None:
"""Set up test fixtures."""
# Create a minimal concrete provider instance for testing base class methods
self.provider = MinimalTestProvider(api_key="test-key")
def test_validate_data_url_valid(self) -> None: def test_validate_data_url_valid(self) -> None:
"""Test validation of valid data URL.""" """Test validation of valid data URL."""
# Create a small test image (1x1 PNG) # Create a small test image (1x1 PNG)
@@ -60,7 +21,7 @@ class TestImageValidation:
) )
data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}" data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}"
image_bytes, mime_type = self.provider.validate_image(data_url) image_bytes, mime_type = validate_image(data_url)
assert image_bytes == test_image_data assert image_bytes == test_image_data
assert mime_type == "image/png" assert mime_type == "image/png"
@@ -76,14 +37,14 @@ class TestImageValidation:
def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None: def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None:
"""Test validation of malformed data URL.""" """Test validation of malformed data URL."""
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(invalid_url) validate_image(invalid_url)
assert expected_error in str(excinfo.value) assert expected_error in str(excinfo.value)
def test_non_data_url_treated_as_file_path(self) -> None: def test_non_data_url_treated_as_file_path(self) -> None:
"""Test that non-data URLs are treated as file paths.""" """Test that non-data URLs are treated as file paths."""
# Test case that's not a data URL at all # Test case that's not a data URL at all
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image("image/png;base64,abc123") validate_image("image/png;base64,abc123")
assert "Image file not found" in str(excinfo.value) # Treated as file path assert "Image file not found" in str(excinfo.value) # Treated as file path
def test_validate_data_url_unsupported_type(self) -> None: def test_validate_data_url_unsupported_type(self) -> None:
@@ -91,7 +52,7 @@ class TestImageValidation:
data_url = "data:image/bmp;base64,Qk0=" # BMP format data_url = "data:image/bmp;base64,Qk0=" # BMP format
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url) validate_image(data_url)
assert "Unsupported image type: image/bmp" in str(excinfo.value) assert "Unsupported image type: image/bmp" in str(excinfo.value)
def test_validate_data_url_invalid_base64(self) -> None: def test_validate_data_url_invalid_base64(self) -> None:
@@ -99,7 +60,7 @@ class TestImageValidation:
data_url = "data:image/png;base64,@@@invalid@@@" data_url = "data:image/png;base64,@@@invalid@@@"
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url) validate_image(data_url)
assert "Invalid base64 data" in str(excinfo.value) assert "Invalid base64 data" in str(excinfo.value)
def test_validate_large_data_url(self) -> None: def test_validate_large_data_url(self) -> None:
@@ -115,11 +76,11 @@ class TestImageValidation:
# Should fail with default 20MB limit # Should fail with default 20MB limit
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url) validate_image(data_url)
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value) assert f"Image too large: 21.0MB (max: {DEFAULT_MAX_IMAGE_SIZE_MB:.1f}MB)" in str(excinfo.value)
# Should succeed with higher limit # Should succeed with higher limit
image_bytes, mime_type = self.provider.validate_image(data_url, max_size_mb=25.0) image_bytes, mime_type = validate_image(data_url, max_size_mb=25.0)
assert len(image_bytes) == len(large_data) assert len(image_bytes) == len(large_data)
assert mime_type == "image/png" assert mime_type == "image/png"
@@ -135,7 +96,7 @@ class TestImageValidation:
tmp_file_path = tmp_file.name tmp_file_path = tmp_file.name
try: try:
image_bytes, mime_type = self.provider.validate_image(tmp_file_path) image_bytes, mime_type = validate_image(tmp_file_path)
assert image_bytes == test_image_data assert image_bytes == test_image_data
assert mime_type == "image/png" assert mime_type == "image/png"
@@ -145,7 +106,7 @@ class TestImageValidation:
def test_validate_file_path_not_found(self) -> None: def test_validate_file_path_not_found(self) -> None:
"""Test validation of non-existent file.""" """Test validation of non-existent file."""
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image("/path/to/nonexistent/image.png") validate_image("/path/to/nonexistent/image.png")
assert "Image file not found" in str(excinfo.value) assert "Image file not found" in str(excinfo.value)
def test_validate_file_path_unsupported_extension(self) -> None: def test_validate_file_path_unsupported_extension(self) -> None:
@@ -156,7 +117,7 @@ class TestImageValidation:
try: try:
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path) validate_image(tmp_file_path)
assert "Unsupported image format: .bmp" in str(excinfo.value) assert "Unsupported image format: .bmp" in str(excinfo.value)
finally: finally:
os.unlink(tmp_file_path) os.unlink(tmp_file_path)
@@ -170,7 +131,7 @@ class TestImageValidation:
os.unlink(tmp_file_path) os.unlink(tmp_file_path)
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path) validate_image(tmp_file_path)
assert "Image file not found" in str(excinfo.value) assert "Image file not found" in str(excinfo.value)
def test_validate_image_size_limit(self) -> None: def test_validate_image_size_limit(self) -> None:
@@ -184,7 +145,7 @@ class TestImageValidation:
try: try:
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path, max_size_mb=20.0) validate_image(tmp_file_path, max_size_mb=20.0)
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value) assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
finally: finally:
os.unlink(tmp_file_path) os.unlink(tmp_file_path)
@@ -201,11 +162,11 @@ class TestImageValidation:
try: try:
# Should fail with 1MB limit # Should fail with 1MB limit
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path, max_size_mb=1.0) validate_image(tmp_file_path, max_size_mb=1.0)
assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value) assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value)
# Should succeed with 3MB limit # Should succeed with 3MB limit
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=3.0) image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=3.0)
assert len(image_bytes) == len(data) assert len(image_bytes) == len(data)
assert mime_type == "image/png" assert mime_type == "image/png"
finally: finally:
@@ -222,12 +183,12 @@ class TestImageValidation:
try: try:
# Should succeed with default limit (20MB) # Should succeed with default limit (20MB)
image_bytes, mime_type = self.provider.validate_image(tmp_file_path) image_bytes, mime_type = validate_image(tmp_file_path)
assert len(image_bytes) == len(data) assert len(image_bytes) == len(data)
assert mime_type == "image/jpeg" assert mime_type == "image/jpeg"
# Should also succeed when explicitly passing None # Should also succeed when explicitly passing None
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=None) image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=None)
assert len(image_bytes) == len(data) assert len(image_bytes) == len(data)
assert mime_type == "image/jpeg" assert mime_type == "image/jpeg"
finally: finally:
@@ -249,7 +210,7 @@ class TestImageValidation:
tmp_file_path = tmp_file.name tmp_file_path = tmp_file.name
try: try:
image_bytes, mime_type = self.provider.validate_image(tmp_file_path) image_bytes, mime_type = validate_image(tmp_file_path)
assert mime_type == expected_mime assert mime_type == expected_mime
assert image_bytes == b"dummy image data" assert image_bytes == b"dummy image data"
finally: finally:

94
utils/image_utils.py Normal file
View File

@@ -0,0 +1,94 @@
"""Utility helpers for validating image inputs."""
import base64
import binascii
import os
from collections.abc import Iterable
from typing import Optional

from utils.file_types import IMAGES, get_image_mime_type
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
__all__ = ["DEFAULT_MAX_IMAGE_SIZE_MB", "validate_image"]
def _valid_mime_types() -> Iterable[str]:
    """Return a lazy iterable of the MIME types for every extension in ``IMAGES``.

    Each extension in the ``IMAGES`` whitelist is passed through
    ``get_image_mime_type``. The result is single-use; callers that need to
    test membership more than once should materialise it first.
    """
    return map(get_image_mime_type, IMAGES)
def validate_image(image_path: str, max_size_mb: Optional[float] = None) -> tuple[bytes, str]:
    """Validate a user-supplied image path or data URL.

    Inputs starting with ``"data:"`` are handled as data URLs; anything else
    is treated as a filesystem path.

    Args:
        image_path: Either a filesystem path or a data URL.
        max_size_mb: Size limit in megabytes. ``None`` (the default) selects
            ``DEFAULT_MAX_IMAGE_SIZE_MB``.

    Returns:
        A tuple ``(image_bytes, mime_type)`` ready for upstream providers.

    Raises:
        ValueError: When the image is missing, malformed, or exceeds limits.
    """
    # Only None selects the default; an explicit limit is passed through as-is.
    limit_mb = DEFAULT_MAX_IMAGE_SIZE_MB if max_size_mb is None else max_size_mb

    if image_path.startswith("data:"):
        return _validate_data_url(image_path, limit_mb)
    return _validate_file_path(image_path, limit_mb)
def _validate_data_url(image_data_url: str, max_size_mb: float) -> tuple[bytes, str]:
"""Validate a data URL and return image bytes plus MIME type."""
try:
header, data = image_data_url.split(",", 1)
mime_type = header.split(";")[0].split(":")[1]
except (ValueError, IndexError) as exc:
raise ValueError(f"Invalid data URL format: {exc}")
valid_mime_types = list(_valid_mime_types())
if mime_type not in valid_mime_types:
raise ValueError(
"Unsupported image type: {mime}. Supported types: {supported}".format(
mime=mime_type, supported=", ".join(valid_mime_types)
)
)
try:
image_bytes = base64.b64decode(data)
except binascii.Error as exc:
raise ValueError(f"Invalid base64 data: {exc}")
_validate_size(image_bytes, max_size_mb)
return image_bytes, mime_type
def _validate_file_path(file_path: str, max_size_mb: float) -> tuple[bytes, str]:
"""Validate an image loaded from the filesystem."""
try:
with open(file_path, "rb") as handle:
image_bytes = handle.read()
except FileNotFoundError:
raise ValueError(f"Image file not found: {file_path}")
except OSError as exc:
raise ValueError(f"Failed to read image file: {exc}")
ext = os.path.splitext(file_path)[1].lower()
if ext not in IMAGES:
raise ValueError(
"Unsupported image format: {ext}. Supported formats: {supported}".format(
ext=ext, supported=", ".join(sorted(IMAGES))
)
)
mime_type = get_image_mime_type(ext)
_validate_size(image_bytes, max_size_mb)
return image_bytes, mime_type
def _validate_size(image_bytes: bytes, max_size_mb: float) -> None:
"""Ensure the image does not exceed the configured size limit."""
size_mb = len(image_bytes) / (1024 * 1024)
if size_mb > max_size_mb:
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")