refactor: moved image related code out of base provider into a separate utility
This commit is contained in:
@@ -1,57 +1,18 @@
|
||||
"""Tests for provider-independent image validation."""
|
||||
"""Tests for image validation utility helpers."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ModelProvider
|
||||
from providers.shared import ModelCapabilities, ModelResponse, ProviderType
|
||||
|
||||
|
||||
class MinimalTestProvider(ModelProvider):
|
||||
"""Minimal concrete provider for testing base class methods."""
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
prompt: str,
|
||||
model_name: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
def count_tokens(self, text: str, model_name: str) -> int:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Not needed for image validation tests."""
|
||||
raise NotImplementedError("Not needed for image validation tests")
|
||||
from utils.image_utils import DEFAULT_MAX_IMAGE_SIZE_MB, validate_image
|
||||
|
||||
|
||||
class TestImageValidation:
|
||||
"""Test suite for image validation functionality."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
"""Set up test fixtures."""
|
||||
# Create a minimal concrete provider instance for testing base class methods
|
||||
self.provider = MinimalTestProvider(api_key="test-key")
|
||||
|
||||
def test_validate_data_url_valid(self) -> None:
|
||||
"""Test validation of valid data URL."""
|
||||
# Create a small test image (1x1 PNG)
|
||||
@@ -60,7 +21,7 @@ class TestImageValidation:
|
||||
)
|
||||
data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}"
|
||||
|
||||
image_bytes, mime_type = self.provider.validate_image(data_url)
|
||||
image_bytes, mime_type = validate_image(data_url)
|
||||
|
||||
assert image_bytes == test_image_data
|
||||
assert mime_type == "image/png"
|
||||
@@ -76,14 +37,14 @@ class TestImageValidation:
|
||||
def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None:
|
||||
"""Test validation of malformed data URL."""
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(invalid_url)
|
||||
validate_image(invalid_url)
|
||||
assert expected_error in str(excinfo.value)
|
||||
|
||||
def test_non_data_url_treated_as_file_path(self) -> None:
|
||||
"""Test that non-data URLs are treated as file paths."""
|
||||
# Test case that's not a data URL at all
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image("image/png;base64,abc123")
|
||||
validate_image("image/png;base64,abc123")
|
||||
assert "Image file not found" in str(excinfo.value) # Treated as file path
|
||||
|
||||
def test_validate_data_url_unsupported_type(self) -> None:
|
||||
@@ -91,7 +52,7 @@ class TestImageValidation:
|
||||
data_url = "data:image/bmp;base64,Qk0=" # BMP format
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(data_url)
|
||||
validate_image(data_url)
|
||||
assert "Unsupported image type: image/bmp" in str(excinfo.value)
|
||||
|
||||
def test_validate_data_url_invalid_base64(self) -> None:
|
||||
@@ -99,7 +60,7 @@ class TestImageValidation:
|
||||
data_url = "data:image/png;base64,@@@invalid@@@"
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(data_url)
|
||||
validate_image(data_url)
|
||||
assert "Invalid base64 data" in str(excinfo.value)
|
||||
|
||||
def test_validate_large_data_url(self) -> None:
|
||||
@@ -115,11 +76,11 @@ class TestImageValidation:
|
||||
|
||||
# Should fail with default 20MB limit
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(data_url)
|
||||
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
|
||||
validate_image(data_url)
|
||||
assert f"Image too large: 21.0MB (max: {DEFAULT_MAX_IMAGE_SIZE_MB:.1f}MB)" in str(excinfo.value)
|
||||
|
||||
# Should succeed with higher limit
|
||||
image_bytes, mime_type = self.provider.validate_image(data_url, max_size_mb=25.0)
|
||||
image_bytes, mime_type = validate_image(data_url, max_size_mb=25.0)
|
||||
assert len(image_bytes) == len(large_data)
|
||||
assert mime_type == "image/png"
|
||||
|
||||
@@ -135,7 +96,7 @@ class TestImageValidation:
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
|
||||
image_bytes, mime_type = validate_image(tmp_file_path)
|
||||
|
||||
assert image_bytes == test_image_data
|
||||
assert mime_type == "image/png"
|
||||
@@ -145,7 +106,7 @@ class TestImageValidation:
|
||||
def test_validate_file_path_not_found(self) -> None:
|
||||
"""Test validation of non-existent file."""
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image("/path/to/nonexistent/image.png")
|
||||
validate_image("/path/to/nonexistent/image.png")
|
||||
assert "Image file not found" in str(excinfo.value)
|
||||
|
||||
def test_validate_file_path_unsupported_extension(self) -> None:
|
||||
@@ -156,7 +117,7 @@ class TestImageValidation:
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(tmp_file_path)
|
||||
validate_image(tmp_file_path)
|
||||
assert "Unsupported image format: .bmp" in str(excinfo.value)
|
||||
finally:
|
||||
os.unlink(tmp_file_path)
|
||||
@@ -170,7 +131,7 @@ class TestImageValidation:
|
||||
os.unlink(tmp_file_path)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(tmp_file_path)
|
||||
validate_image(tmp_file_path)
|
||||
assert "Image file not found" in str(excinfo.value)
|
||||
|
||||
def test_validate_image_size_limit(self) -> None:
|
||||
@@ -184,7 +145,7 @@ class TestImageValidation:
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(tmp_file_path, max_size_mb=20.0)
|
||||
validate_image(tmp_file_path, max_size_mb=20.0)
|
||||
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
|
||||
finally:
|
||||
os.unlink(tmp_file_path)
|
||||
@@ -201,11 +162,11 @@ class TestImageValidation:
|
||||
try:
|
||||
# Should fail with 1MB limit
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
self.provider.validate_image(tmp_file_path, max_size_mb=1.0)
|
||||
validate_image(tmp_file_path, max_size_mb=1.0)
|
||||
assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value)
|
||||
|
||||
# Should succeed with 3MB limit
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=3.0)
|
||||
image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=3.0)
|
||||
assert len(image_bytes) == len(data)
|
||||
assert mime_type == "image/png"
|
||||
finally:
|
||||
@@ -222,12 +183,12 @@ class TestImageValidation:
|
||||
|
||||
try:
|
||||
# Should succeed with default limit (20MB)
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
|
||||
image_bytes, mime_type = validate_image(tmp_file_path)
|
||||
assert len(image_bytes) == len(data)
|
||||
assert mime_type == "image/jpeg"
|
||||
|
||||
# Should also succeed when explicitly passing None
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=None)
|
||||
image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=None)
|
||||
assert len(image_bytes) == len(data)
|
||||
assert mime_type == "image/jpeg"
|
||||
finally:
|
||||
@@ -249,7 +210,7 @@ class TestImageValidation:
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
|
||||
image_bytes, mime_type = validate_image(tmp_file_path)
|
||||
assert mime_type == expected_mime
|
||||
assert image_bytes == b"dummy image data"
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user