refactor: moved image-related code out of the base provider into a separate utility
This commit is contained in:
@@ -1,17 +1,12 @@
|
|||||||
"""Base interfaces and common behaviour for model providers."""
|
"""Base interfaces and common behaviour for model providers."""
|
||||||
|
|
||||||
import base64
|
|
||||||
import binascii
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tools.models import ToolModelCategory
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
from utils.file_types import IMAGES, get_image_mime_type
|
|
||||||
|
|
||||||
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -43,9 +38,6 @@ class ModelProvider(ABC):
|
|||||||
# All concrete providers must define their supported models
|
# All concrete providers must define their supported models
|
||||||
MODEL_CAPABILITIES: dict[str, Any] = {}
|
MODEL_CAPABILITIES: dict[str, Any] = {}
|
||||||
|
|
||||||
# Default maximum image size in MB
|
|
||||||
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
|
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize the provider with API key and optional configuration."""
|
"""Initialize the provider with API key and optional configuration."""
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -167,17 +159,7 @@ class ModelProvider(ABC):
|
|||||||
lowercase: bool = False,
|
lowercase: bool = False,
|
||||||
unique: bool = False,
|
unique: bool = False,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Return formatted model names supported by this provider.
|
"""Return formatted model names supported by this provider."""
|
||||||
|
|
||||||
Args:
|
|
||||||
respect_restrictions: Apply provider restriction policy.
|
|
||||||
include_aliases: Include aliases alongside canonical model names.
|
|
||||||
lowercase: Normalize returned names to lowercase.
|
|
||||||
unique: Deduplicate names after formatting.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model names formatted according to the provided options.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_configs = self.get_model_configurations()
|
model_configs = self.get_model_configurations()
|
||||||
if not model_configs:
|
if not model_configs:
|
||||||
@@ -206,77 +188,6 @@ class ModelProvider(ABC):
|
|||||||
unique=unique,
|
unique=unique,
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_image(self, image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
|
|
||||||
"""Provider-independent image validation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_path: Path to image file or data URL
|
|
||||||
max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (image_bytes, mime_type)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If image is invalid
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
# Validate a file path
|
|
||||||
image_bytes, mime_type = provider.validate_image("/path/to/image.png")
|
|
||||||
|
|
||||||
# Validate a data URL
|
|
||||||
image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")
|
|
||||||
|
|
||||||
# Validate with custom size limit
|
|
||||||
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
|
|
||||||
"""
|
|
||||||
# Use default if not specified
|
|
||||||
if max_size_mb is None:
|
|
||||||
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
|
|
||||||
|
|
||||||
if image_path.startswith("data:"):
|
|
||||||
# Parse data URL: data:image/png;base64,iVBORw0...
|
|
||||||
try:
|
|
||||||
header, data = image_path.split(",", 1)
|
|
||||||
mime_type = header.split(";")[0].split(":")[1]
|
|
||||||
except (ValueError, IndexError) as e:
|
|
||||||
raise ValueError(f"Invalid data URL format: {e}")
|
|
||||||
|
|
||||||
# Validate MIME type using IMAGES constant
|
|
||||||
valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
|
|
||||||
if mime_type not in valid_mime_types:
|
|
||||||
raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")
|
|
||||||
|
|
||||||
# Decode base64 data
|
|
||||||
try:
|
|
||||||
image_bytes = base64.b64decode(data)
|
|
||||||
except binascii.Error as e:
|
|
||||||
raise ValueError(f"Invalid base64 data: {e}")
|
|
||||||
else:
|
|
||||||
# Handle file path
|
|
||||||
# Read file first to check if it exists
|
|
||||||
try:
|
|
||||||
with open(image_path, "rb") as f:
|
|
||||||
image_bytes = f.read()
|
|
||||||
except FileNotFoundError:
|
|
||||||
raise ValueError(f"Image file not found: {image_path}")
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Failed to read image file: {e}")
|
|
||||||
|
|
||||||
# Validate extension
|
|
||||||
ext = os.path.splitext(image_path)[1].lower()
|
|
||||||
if ext not in IMAGES:
|
|
||||||
raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")
|
|
||||||
|
|
||||||
# Get MIME type
|
|
||||||
mime_type = get_image_mime_type(ext)
|
|
||||||
|
|
||||||
# Validate size
|
|
||||||
size_mb = len(image_bytes) / (1024 * 1024)
|
|
||||||
if size_mb > max_size_mb:
|
|
||||||
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
|
|
||||||
|
|
||||||
return image_bytes, mime_type
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Clean up any resources held by the provider.
|
"""Clean up any resources held by the provider.
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ if TYPE_CHECKING:
|
|||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
|
from utils.image_utils import validate_image
|
||||||
|
|
||||||
from .base import ModelProvider
|
from .base import ModelProvider
|
||||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||||
|
|
||||||
@@ -529,7 +531,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
"""Process an image for Gemini API."""
|
"""Process an image for Gemini API."""
|
||||||
try:
|
try:
|
||||||
# Use base class validation
|
# Use base class validation
|
||||||
image_bytes, mime_type = self.validate_image(image_path)
|
image_bytes, mime_type = validate_image(image_path)
|
||||||
|
|
||||||
# For data URLs, extract the base64 data directly
|
# For data URLs, extract the base64 data directly
|
||||||
if image_path.startswith("data:"):
|
if image_path.startswith("data:"):
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from utils.image_utils import validate_image
|
||||||
|
|
||||||
from .base import ModelProvider
|
from .base import ModelProvider
|
||||||
from .shared import (
|
from .shared import (
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
@@ -830,12 +832,12 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
try:
|
try:
|
||||||
if image_path.startswith("data:"):
|
if image_path.startswith("data:"):
|
||||||
# Validate the data URL
|
# Validate the data URL
|
||||||
self.validate_image(image_path)
|
validate_image(image_path)
|
||||||
# Handle data URL: data:image/png;base64,iVBORw0...
|
# Handle data URL: data:image/png;base64,iVBORw0...
|
||||||
return {"type": "image_url", "image_url": {"url": image_path}}
|
return {"type": "image_url", "image_url": {"url": image_path}}
|
||||||
else:
|
else:
|
||||||
# Use base class validation
|
# Use base class validation
|
||||||
image_bytes, mime_type = self.validate_image(image_path)
|
image_bytes, mime_type = validate_image(image_path)
|
||||||
|
|
||||||
# Read and encode the image
|
# Read and encode the image
|
||||||
import base64
|
import base64
|
||||||
|
|||||||
@@ -1,57 +1,18 @@
|
|||||||
"""Tests for provider-independent image validation."""
|
"""Tests for image validation utility helpers."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Optional
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ModelProvider
|
from utils.image_utils import DEFAULT_MAX_IMAGE_SIZE_MB, validate_image
|
||||||
from providers.shared import ModelCapabilities, ModelResponse, ProviderType
|
|
||||||
|
|
||||||
|
|
||||||
class MinimalTestProvider(ModelProvider):
|
|
||||||
"""Minimal concrete provider for testing base class methods."""
|
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
|
||||||
"""Not needed for image validation tests."""
|
|
||||||
raise NotImplementedError("Not needed for image validation tests")
|
|
||||||
|
|
||||||
def generate_content(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
model_name: str,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
temperature: float = 0.7,
|
|
||||||
max_output_tokens: Optional[int] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> ModelResponse:
|
|
||||||
"""Not needed for image validation tests."""
|
|
||||||
raise NotImplementedError("Not needed for image validation tests")
|
|
||||||
|
|
||||||
def count_tokens(self, text: str, model_name: str) -> int:
|
|
||||||
"""Not needed for image validation tests."""
|
|
||||||
raise NotImplementedError("Not needed for image validation tests")
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
|
||||||
"""Not needed for image validation tests."""
|
|
||||||
raise NotImplementedError("Not needed for image validation tests")
|
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
|
||||||
"""Not needed for image validation tests."""
|
|
||||||
raise NotImplementedError("Not needed for image validation tests")
|
|
||||||
|
|
||||||
|
|
||||||
class TestImageValidation:
|
class TestImageValidation:
|
||||||
"""Test suite for image validation functionality."""
|
"""Test suite for image validation functionality."""
|
||||||
|
|
||||||
def setup_method(self) -> None:
|
|
||||||
"""Set up test fixtures."""
|
|
||||||
# Create a minimal concrete provider instance for testing base class methods
|
|
||||||
self.provider = MinimalTestProvider(api_key="test-key")
|
|
||||||
|
|
||||||
def test_validate_data_url_valid(self) -> None:
|
def test_validate_data_url_valid(self) -> None:
|
||||||
"""Test validation of valid data URL."""
|
"""Test validation of valid data URL."""
|
||||||
# Create a small test image (1x1 PNG)
|
# Create a small test image (1x1 PNG)
|
||||||
@@ -60,7 +21,7 @@ class TestImageValidation:
|
|||||||
)
|
)
|
||||||
data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}"
|
data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}"
|
||||||
|
|
||||||
image_bytes, mime_type = self.provider.validate_image(data_url)
|
image_bytes, mime_type = validate_image(data_url)
|
||||||
|
|
||||||
assert image_bytes == test_image_data
|
assert image_bytes == test_image_data
|
||||||
assert mime_type == "image/png"
|
assert mime_type == "image/png"
|
||||||
@@ -76,14 +37,14 @@ class TestImageValidation:
|
|||||||
def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None:
|
def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None:
|
||||||
"""Test validation of malformed data URL."""
|
"""Test validation of malformed data URL."""
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image(invalid_url)
|
validate_image(invalid_url)
|
||||||
assert expected_error in str(excinfo.value)
|
assert expected_error in str(excinfo.value)
|
||||||
|
|
||||||
def test_non_data_url_treated_as_file_path(self) -> None:
|
def test_non_data_url_treated_as_file_path(self) -> None:
|
||||||
"""Test that non-data URLs are treated as file paths."""
|
"""Test that non-data URLs are treated as file paths."""
|
||||||
# Test case that's not a data URL at all
|
# Test case that's not a data URL at all
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image("image/png;base64,abc123")
|
validate_image("image/png;base64,abc123")
|
||||||
assert "Image file not found" in str(excinfo.value) # Treated as file path
|
assert "Image file not found" in str(excinfo.value) # Treated as file path
|
||||||
|
|
||||||
def test_validate_data_url_unsupported_type(self) -> None:
|
def test_validate_data_url_unsupported_type(self) -> None:
|
||||||
@@ -91,7 +52,7 @@ class TestImageValidation:
|
|||||||
data_url = "data:image/bmp;base64,Qk0=" # BMP format
|
data_url = "data:image/bmp;base64,Qk0=" # BMP format
|
||||||
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image(data_url)
|
validate_image(data_url)
|
||||||
assert "Unsupported image type: image/bmp" in str(excinfo.value)
|
assert "Unsupported image type: image/bmp" in str(excinfo.value)
|
||||||
|
|
||||||
def test_validate_data_url_invalid_base64(self) -> None:
|
def test_validate_data_url_invalid_base64(self) -> None:
|
||||||
@@ -99,7 +60,7 @@ class TestImageValidation:
|
|||||||
data_url = "data:image/png;base64,@@@invalid@@@"
|
data_url = "data:image/png;base64,@@@invalid@@@"
|
||||||
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image(data_url)
|
validate_image(data_url)
|
||||||
assert "Invalid base64 data" in str(excinfo.value)
|
assert "Invalid base64 data" in str(excinfo.value)
|
||||||
|
|
||||||
def test_validate_large_data_url(self) -> None:
|
def test_validate_large_data_url(self) -> None:
|
||||||
@@ -115,11 +76,11 @@ class TestImageValidation:
|
|||||||
|
|
||||||
# Should fail with default 20MB limit
|
# Should fail with default 20MB limit
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image(data_url)
|
validate_image(data_url)
|
||||||
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
|
assert f"Image too large: 21.0MB (max: {DEFAULT_MAX_IMAGE_SIZE_MB:.1f}MB)" in str(excinfo.value)
|
||||||
|
|
||||||
# Should succeed with higher limit
|
# Should succeed with higher limit
|
||||||
image_bytes, mime_type = self.provider.validate_image(data_url, max_size_mb=25.0)
|
image_bytes, mime_type = validate_image(data_url, max_size_mb=25.0)
|
||||||
assert len(image_bytes) == len(large_data)
|
assert len(image_bytes) == len(large_data)
|
||||||
assert mime_type == "image/png"
|
assert mime_type == "image/png"
|
||||||
|
|
||||||
@@ -135,7 +96,7 @@ class TestImageValidation:
|
|||||||
tmp_file_path = tmp_file.name
|
tmp_file_path = tmp_file.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
|
image_bytes, mime_type = validate_image(tmp_file_path)
|
||||||
|
|
||||||
assert image_bytes == test_image_data
|
assert image_bytes == test_image_data
|
||||||
assert mime_type == "image/png"
|
assert mime_type == "image/png"
|
||||||
@@ -145,7 +106,7 @@ class TestImageValidation:
|
|||||||
def test_validate_file_path_not_found(self) -> None:
|
def test_validate_file_path_not_found(self) -> None:
|
||||||
"""Test validation of non-existent file."""
|
"""Test validation of non-existent file."""
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image("/path/to/nonexistent/image.png")
|
validate_image("/path/to/nonexistent/image.png")
|
||||||
assert "Image file not found" in str(excinfo.value)
|
assert "Image file not found" in str(excinfo.value)
|
||||||
|
|
||||||
def test_validate_file_path_unsupported_extension(self) -> None:
|
def test_validate_file_path_unsupported_extension(self) -> None:
|
||||||
@@ -156,7 +117,7 @@ class TestImageValidation:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image(tmp_file_path)
|
validate_image(tmp_file_path)
|
||||||
assert "Unsupported image format: .bmp" in str(excinfo.value)
|
assert "Unsupported image format: .bmp" in str(excinfo.value)
|
||||||
finally:
|
finally:
|
||||||
os.unlink(tmp_file_path)
|
os.unlink(tmp_file_path)
|
||||||
@@ -170,7 +131,7 @@ class TestImageValidation:
|
|||||||
os.unlink(tmp_file_path)
|
os.unlink(tmp_file_path)
|
||||||
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image(tmp_file_path)
|
validate_image(tmp_file_path)
|
||||||
assert "Image file not found" in str(excinfo.value)
|
assert "Image file not found" in str(excinfo.value)
|
||||||
|
|
||||||
def test_validate_image_size_limit(self) -> None:
|
def test_validate_image_size_limit(self) -> None:
|
||||||
@@ -184,7 +145,7 @@ class TestImageValidation:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image(tmp_file_path, max_size_mb=20.0)
|
validate_image(tmp_file_path, max_size_mb=20.0)
|
||||||
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
|
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
|
||||||
finally:
|
finally:
|
||||||
os.unlink(tmp_file_path)
|
os.unlink(tmp_file_path)
|
||||||
@@ -201,11 +162,11 @@ class TestImageValidation:
|
|||||||
try:
|
try:
|
||||||
# Should fail with 1MB limit
|
# Should fail with 1MB limit
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
self.provider.validate_image(tmp_file_path, max_size_mb=1.0)
|
validate_image(tmp_file_path, max_size_mb=1.0)
|
||||||
assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value)
|
assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value)
|
||||||
|
|
||||||
# Should succeed with 3MB limit
|
# Should succeed with 3MB limit
|
||||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=3.0)
|
image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=3.0)
|
||||||
assert len(image_bytes) == len(data)
|
assert len(image_bytes) == len(data)
|
||||||
assert mime_type == "image/png"
|
assert mime_type == "image/png"
|
||||||
finally:
|
finally:
|
||||||
@@ -222,12 +183,12 @@ class TestImageValidation:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Should succeed with default limit (20MB)
|
# Should succeed with default limit (20MB)
|
||||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
|
image_bytes, mime_type = validate_image(tmp_file_path)
|
||||||
assert len(image_bytes) == len(data)
|
assert len(image_bytes) == len(data)
|
||||||
assert mime_type == "image/jpeg"
|
assert mime_type == "image/jpeg"
|
||||||
|
|
||||||
# Should also succeed when explicitly passing None
|
# Should also succeed when explicitly passing None
|
||||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=None)
|
image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=None)
|
||||||
assert len(image_bytes) == len(data)
|
assert len(image_bytes) == len(data)
|
||||||
assert mime_type == "image/jpeg"
|
assert mime_type == "image/jpeg"
|
||||||
finally:
|
finally:
|
||||||
@@ -249,7 +210,7 @@ class TestImageValidation:
|
|||||||
tmp_file_path = tmp_file.name
|
tmp_file_path = tmp_file.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
|
image_bytes, mime_type = validate_image(tmp_file_path)
|
||||||
assert mime_type == expected_mime
|
assert mime_type == expected_mime
|
||||||
assert image_bytes == b"dummy image data"
|
assert image_bytes == b"dummy image data"
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
94
utils/image_utils.py
Normal file
94
utils/image_utils.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
"""Utility helpers for validating image inputs."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import binascii
|
||||||
|
import os
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
from utils.file_types import IMAGES, get_image_mime_type
|
||||||
|
|
||||||
|
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
|
||||||
|
|
||||||
|
__all__ = ["DEFAULT_MAX_IMAGE_SIZE_MB", "validate_image"]
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_mime_types() -> Iterable[str]:
    """Return the MIME types permitted by the IMAGES whitelist.

    Each extension in ``IMAGES`` is mapped through ``get_image_mime_type``.
    Several extensions may map to the same MIME type, so the result is
    de-duplicated (first occurrence wins, iteration order of ``IMAGES`` is
    otherwise preserved) to keep "supported types" messages free of repeats.

    Returns:
        A tuple of distinct MIME type strings.
    """
    # dict.fromkeys gives an order-preserving de-duplication in one pass.
    return tuple(dict.fromkeys(get_image_mime_type(ext) for ext in IMAGES))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_image(image_path: str, max_size_mb: "float | None" = None) -> tuple[bytes, str]:
    """Validate a user-supplied image path or data URL.

    Inputs starting with ``data:`` are handled as data URLs; everything else
    is treated as a filesystem path.

    Args:
        image_path: Either a filesystem path or a data URL.
        max_size_mb: Optional size limit in MB. ``None`` selects
            ``DEFAULT_MAX_IMAGE_SIZE_MB``.

    Returns:
        A tuple ``(image_bytes, mime_type)``.

    Raises:
        ValueError: When ``image_path`` is not a string, or the image is
            missing, malformed, of an unsupported type, or exceeds the limit.
    """
    # Keep the documented contract: bad input surfaces as ValueError rather
    # than an AttributeError from ``.startswith`` on a non-string.
    if not isinstance(image_path, str):
        raise ValueError(f"Image path must be a string, got {type(image_path).__name__}")

    if max_size_mb is None:
        max_size_mb = DEFAULT_MAX_IMAGE_SIZE_MB

    if image_path.startswith("data:"):
        return _validate_data_url(image_path, max_size_mb)

    return _validate_file_path(image_path, max_size_mb)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_data_url(image_data_url: str, max_size_mb: float) -> tuple[bytes, str]:
    """Validate a data URL and return image bytes plus MIME type.

    The URL is expected to look like ``data:<mime>;base64,<payload>``. The
    MIME type is taken from the text between ``:`` and the first ``;`` of the
    header, and the payload after the first comma is base64-decoded.

    Args:
        image_data_url: The full data URL string.
        max_size_mb: Maximum decoded size in MB.

    Returns:
        A tuple ``(image_bytes, mime_type)``.

    Raises:
        ValueError: If the URL cannot be split into header and payload, the
            MIME type is not in the whitelist, the payload is not valid
            base64, or the decoded image exceeds ``max_size_mb``.
    """
    try:
        header, data = image_data_url.split(",", 1)
        mime_type = header.split(";")[0].split(":")[1]
    except (ValueError, IndexError) as exc:
        # Chain explicitly so the parsing failure is recorded as the cause.
        raise ValueError(f"Invalid data URL format: {exc}") from exc

    valid_mime_types = list(_valid_mime_types())
    if mime_type not in valid_mime_types:
        supported = ", ".join(valid_mime_types)
        raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {supported}")

    try:
        image_bytes = base64.b64decode(data)
    except binascii.Error as exc:
        raise ValueError(f"Invalid base64 data: {exc}") from exc

    _validate_size(image_bytes, max_size_mb)
    return image_bytes, mime_type
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_file_path(file_path: str, max_size_mb: float) -> tuple[bytes, str]:
    """Validate an image loaded from the filesystem.

    The file is read first, so a missing or unreadable file is reported
    before any extension or size problem. The MIME type is derived from the
    lower-cased file extension, not from the file contents.

    Args:
        file_path: Path of the image file to read.
        max_size_mb: Maximum file size in MB.

    Returns:
        A tuple ``(image_bytes, mime_type)``.

    Raises:
        ValueError: If the file does not exist, cannot be read, has an
            extension outside ``IMAGES``, or exceeds ``max_size_mb``.
    """
    try:
        with open(file_path, "rb") as handle:
            image_bytes = handle.read()
    except FileNotFoundError as exc:
        # Chain explicitly so the original OS error stays attached as the cause.
        raise ValueError(f"Image file not found: {file_path}") from exc
    except OSError as exc:
        raise ValueError(f"Failed to read image file: {exc}") from exc

    ext = os.path.splitext(file_path)[1].lower()
    if ext not in IMAGES:
        supported = ", ".join(sorted(IMAGES))
        raise ValueError(f"Unsupported image format: {ext}. Supported formats: {supported}")

    mime_type = get_image_mime_type(ext)
    _validate_size(image_bytes, max_size_mb)
    return image_bytes, mime_type
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_size(image_bytes: bytes, max_size_mb: float) -> None:
    """Ensure the image does not exceed the configured size limit.

    Args:
        image_bytes: Raw image content whose length is checked.
        max_size_mb: Maximum allowed size in MB (1 MB = 1024 * 1024 bytes).

    Raises:
        ValueError: If ``len(image_bytes)`` converted to MB is greater than
            ``max_size_mb``. A size exactly equal to the limit is accepted.
    """
    size_mb = len(image_bytes) / (1024 * 1024)
    if size_mb > max_size_mb:
        # float() keeps the message uniform ("max: 20.0MB") when the caller
        # passes an integer limit; float limits render exactly as before.
        raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {float(max_size_mb)}MB)")
|
||||||
Reference in New Issue
Block a user