refactor: moved image related code out of base provider into a separate utility
This commit is contained in:
@@ -1,17 +1,12 @@
|
||||
"""Base interfaces and common behaviour for model providers."""
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tools.models import ToolModelCategory
|
||||
|
||||
from utils.file_types import IMAGES, get_image_mime_type
|
||||
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -43,9 +38,6 @@ class ModelProvider(ABC):
|
||||
# All concrete providers must define their supported models
|
||||
MODEL_CAPABILITIES: dict[str, Any] = {}
|
||||
|
||||
# Default maximum image size in MB
|
||||
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize the provider with API key and optional configuration."""
|
||||
self.api_key = api_key
|
||||
@@ -167,17 +159,7 @@ class ModelProvider(ABC):
|
||||
lowercase: bool = False,
|
||||
unique: bool = False,
|
||||
) -> list[str]:
|
||||
"""Return formatted model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Apply provider restriction policy.
|
||||
include_aliases: Include aliases alongside canonical model names.
|
||||
lowercase: Normalize returned names to lowercase.
|
||||
unique: Deduplicate names after formatting.
|
||||
|
||||
Returns:
|
||||
List of model names formatted according to the provided options.
|
||||
"""
|
||||
"""Return formatted model names supported by this provider."""
|
||||
|
||||
model_configs = self.get_model_configurations()
|
||||
if not model_configs:
|
||||
@@ -206,77 +188,6 @@ class ModelProvider(ABC):
|
||||
unique=unique,
|
||||
)
|
||||
|
||||
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
|
||||
"""Provider-independent image validation.
|
||||
|
||||
Args:
|
||||
image_path: Path to image file or data URL
|
||||
max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)
|
||||
|
||||
Returns:
|
||||
Tuple of (image_bytes, mime_type)
|
||||
|
||||
Raises:
|
||||
ValueError: If image is invalid
|
||||
|
||||
Examples:
|
||||
# Validate a file path
|
||||
image_bytes, mime_type = provider.validate_image("/path/to/image.png")
|
||||
|
||||
# Validate a data URL
|
||||
image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")
|
||||
|
||||
# Validate with custom size limit
|
||||
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
|
||||
"""
|
||||
# Use default if not specified
|
||||
if max_size_mb is None:
|
||||
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
|
||||
|
||||
if image_path.startswith("data:"):
|
||||
# Parse data URL: data:image/png;base64,iVBORw0...
|
||||
try:
|
||||
header, data = image_path.split(",", 1)
|
||||
mime_type = header.split(";")[0].split(":")[1]
|
||||
except (ValueError, IndexError) as e:
|
||||
raise ValueError(f"Invalid data URL format: {e}")
|
||||
|
||||
# Validate MIME type using IMAGES constant
|
||||
valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
|
||||
if mime_type not in valid_mime_types:
|
||||
raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")
|
||||
|
||||
# Decode base64 data
|
||||
try:
|
||||
image_bytes = base64.b64decode(data)
|
||||
except binascii.Error as e:
|
||||
raise ValueError(f"Invalid base64 data: {e}")
|
||||
else:
|
||||
# Handle file path
|
||||
# Read file first to check if it exists
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"Image file not found: {image_path}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to read image file: {e}")
|
||||
|
||||
# Validate extension
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
if ext not in IMAGES:
|
||||
raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")
|
||||
|
||||
# Get MIME type
|
||||
mime_type = get_image_mime_type(ext)
|
||||
|
||||
# Validate size
|
||||
size_mb = len(image_bytes) / (1024 * 1024)
|
||||
if size_mb > max_size_mb:
|
||||
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
|
||||
|
||||
return image_bytes, mime_type
|
||||
|
||||
def close(self):
|
||||
"""Clean up any resources held by the provider.
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ if TYPE_CHECKING:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from utils.image_utils import validate_image
|
||||
|
||||
from .base import ModelProvider
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||
|
||||
@@ -529,7 +531,7 @@ class GeminiModelProvider(ModelProvider):
|
||||
"""Process an image for Gemini API."""
|
||||
try:
|
||||
# Use base class validation
|
||||
image_bytes, mime_type = self.validate_image(image_path)
|
||||
image_bytes, mime_type = validate_image(image_path)
|
||||
|
||||
# For data URLs, extract the base64 data directly
|
||||
if image_path.startswith("data:"):
|
||||
|
||||
@@ -11,6 +11,8 @@ from urllib.parse import urlparse
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from utils.image_utils import validate_image
|
||||
|
||||
from .base import ModelProvider
|
||||
from .shared import (
|
||||
ModelCapabilities,
|
||||
@@ -830,12 +832,12 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
try:
|
||||
if image_path.startswith("data:"):
|
||||
# Validate the data URL
|
||||
self.validate_image(image_path)
|
||||
validate_image(image_path)
|
||||
# Handle data URL: data:image/png;base64,iVBORw0...
|
||||
return {"type": "image_url", "image_url": {"url": image_path}}
|
||||
else:
|
||||
# Use base class validation
|
||||
image_bytes, mime_type = self.validate_image(image_path)
|
||||
image_bytes, mime_type = validate_image(image_path)
|
||||
|
||||
# Read and encode the image
|
||||
import base64
|
||||
|
||||
Reference in New Issue
Block a user