refactor: moved image related code out of base provider into a separate utility

This commit is contained in:
Fahad
2025-10-02 11:23:15 +04:00
parent a254ff2220
commit 14a35afa1d
5 changed files with 122 additions and 152 deletions

View File

@@ -1,17 +1,12 @@
"""Base interfaces and common behaviour for model providers."""
import base64
import binascii
import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from utils.file_types import IMAGES, get_image_mime_type
from .shared import ModelCapabilities, ModelResponse, ProviderType
logger = logging.getLogger(__name__)
@@ -43,9 +38,6 @@ class ModelProvider(ABC):
# All concrete providers must define their supported models
MODEL_CAPABILITIES: dict[str, Any] = {}
# Default maximum image size in MB
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
def __init__(self, api_key: str, **kwargs):
"""Initialize the provider with API key and optional configuration."""
self.api_key = api_key
@@ -167,17 +159,7 @@ class ModelProvider(ABC):
lowercase: bool = False,
unique: bool = False,
) -> list[str]:
"""Return formatted model names supported by this provider.
Args:
respect_restrictions: Apply provider restriction policy.
include_aliases: Include aliases alongside canonical model names.
lowercase: Normalize returned names to lowercase.
unique: Deduplicate names after formatting.
Returns:
List of model names formatted according to the provided options.
"""
"""Return formatted model names supported by this provider."""
model_configs = self.get_model_configurations()
if not model_configs:
@@ -206,77 +188,6 @@ class ModelProvider(ABC):
unique=unique,
)
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
"""Provider-independent image validation.
Args:
image_path: Path to image file or data URL
max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)
Returns:
Tuple of (image_bytes, mime_type)
Raises:
ValueError: If image is invalid
Examples:
# Validate a file path
image_bytes, mime_type = provider.validate_image("/path/to/image.png")
# Validate a data URL
image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")
# Validate with custom size limit
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
"""
# Use default if not specified
if max_size_mb is None:
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
if image_path.startswith("data:"):
# Parse data URL: ...
try:
header, data = image_path.split(",", 1)
mime_type = header.split(";")[0].split(":")[1]
except (ValueError, IndexError) as e:
raise ValueError(f"Invalid data URL format: {e}")
# Validate MIME type using IMAGES constant
valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
if mime_type not in valid_mime_types:
raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")
# Decode base64 data
try:
image_bytes = base64.b64decode(data)
except binascii.Error as e:
raise ValueError(f"Invalid base64 data: {e}")
else:
# Handle file path
# Read file first to check if it exists
try:
with open(image_path, "rb") as f:
image_bytes = f.read()
except FileNotFoundError:
raise ValueError(f"Image file not found: {image_path}")
except Exception as e:
raise ValueError(f"Failed to read image file: {e}")
# Validate extension
ext = os.path.splitext(image_path)[1].lower()
if ext not in IMAGES:
raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")
# Get MIME type
mime_type = get_image_mime_type(ext)
# Validate size
size_mb = len(image_bytes) / (1024 * 1024)
if size_mb > max_size_mb:
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
return image_bytes, mime_type
def close(self):
"""Clean up any resources held by the provider.

View File

@@ -11,6 +11,8 @@ if TYPE_CHECKING:
from google import genai
from google.genai import types
from utils.image_utils import validate_image
from .base import ModelProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
@@ -529,7 +531,7 @@ class GeminiModelProvider(ModelProvider):
"""Process an image for Gemini API."""
try:
# Use base class validation
image_bytes, mime_type = self.validate_image(image_path)
image_bytes, mime_type = validate_image(image_path)
# For data URLs, extract the base64 data directly
if image_path.startswith("data:"):

View File

@@ -11,6 +11,8 @@ from urllib.parse import urlparse
from openai import OpenAI
from utils.image_utils import validate_image
from .base import ModelProvider
from .shared import (
ModelCapabilities,
@@ -830,12 +832,12 @@ class OpenAICompatibleProvider(ModelProvider):
try:
if image_path.startswith("data:"):
# Validate the data URL
self.validate_image(image_path)
validate_image(image_path)
# Handle data URL: ...
return {"type": "image_url", "image_url": {"url": image_path}}
else:
# Use base class validation
image_bytes, mime_type = self.validate_image(image_path)
image_bytes, mime_type = validate_image(image_path)
# Read and encode the image
import base64