refactor: moved image related code out of base provider into a separate utility

This commit is contained in:
Fahad
2025-10-02 11:23:15 +04:00
parent a254ff2220
commit 14a35afa1d
5 changed files with 122 additions and 152 deletions

View File

@@ -1,57 +1,18 @@
"""Tests for provider-independent image validation."""
"""Tests for image validation utility helpers."""
import base64
import os
import tempfile
from typing import Optional
from unittest.mock import Mock, patch
import pytest
from providers.base import ModelProvider
from providers.shared import ModelCapabilities, ModelResponse, ProviderType
class MinimalTestProvider(ModelProvider):
"""Minimal concrete provider for testing base class methods."""
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def count_tokens(self, text: str, model_name: str) -> int:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def get_provider_type(self) -> ProviderType:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def validate_model_name(self, model_name: str) -> bool:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
from utils.image_utils import DEFAULT_MAX_IMAGE_SIZE_MB, validate_image
class TestImageValidation:
"""Test suite for image validation functionality."""
def setup_method(self) -> None:
"""Set up test fixtures."""
# Create a minimal concrete provider instance for testing base class methods
self.provider = MinimalTestProvider(api_key="test-key")
def test_validate_data_url_valid(self) -> None:
"""Test validation of valid data URL."""
# Create a small test image (1x1 PNG)
@@ -60,7 +21,7 @@ class TestImageValidation:
)
data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}"
image_bytes, mime_type = self.provider.validate_image(data_url)
image_bytes, mime_type = validate_image(data_url)
assert image_bytes == test_image_data
assert mime_type == "image/png"
@@ -76,14 +37,14 @@ class TestImageValidation:
def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None:
"""Test validation of malformed data URL."""
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(invalid_url)
validate_image(invalid_url)
assert expected_error in str(excinfo.value)
def test_non_data_url_treated_as_file_path(self) -> None:
"""Test that non-data URLs are treated as file paths."""
# Test case that's not a data URL at all
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image("image/png;base64,abc123")
validate_image("image/png;base64,abc123")
assert "Image file not found" in str(excinfo.value) # Treated as file path
def test_validate_data_url_unsupported_type(self) -> None:
@@ -91,7 +52,7 @@ class TestImageValidation:
data_url = "data:image/bmp;base64,Qk0=" # BMP format
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url)
validate_image(data_url)
assert "Unsupported image type: image/bmp" in str(excinfo.value)
def test_validate_data_url_invalid_base64(self) -> None:
@@ -99,7 +60,7 @@ class TestImageValidation:
data_url = "data:image/png;base64,@@@invalid@@@"
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url)
validate_image(data_url)
assert "Invalid base64 data" in str(excinfo.value)
def test_validate_large_data_url(self) -> None:
@@ -115,11 +76,11 @@ class TestImageValidation:
# Should fail with default 20MB limit
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url)
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
validate_image(data_url)
assert f"Image too large: 21.0MB (max: {DEFAULT_MAX_IMAGE_SIZE_MB:.1f}MB)" in str(excinfo.value)
# Should succeed with higher limit
image_bytes, mime_type = self.provider.validate_image(data_url, max_size_mb=25.0)
image_bytes, mime_type = validate_image(data_url, max_size_mb=25.0)
assert len(image_bytes) == len(large_data)
assert mime_type == "image/png"
@@ -135,7 +96,7 @@ class TestImageValidation:
tmp_file_path = tmp_file.name
try:
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
image_bytes, mime_type = validate_image(tmp_file_path)
assert image_bytes == test_image_data
assert mime_type == "image/png"
@@ -145,7 +106,7 @@ class TestImageValidation:
def test_validate_file_path_not_found(self) -> None:
"""Test validation of non-existent file."""
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image("/path/to/nonexistent/image.png")
validate_image("/path/to/nonexistent/image.png")
assert "Image file not found" in str(excinfo.value)
def test_validate_file_path_unsupported_extension(self) -> None:
@@ -156,7 +117,7 @@ class TestImageValidation:
try:
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path)
validate_image(tmp_file_path)
assert "Unsupported image format: .bmp" in str(excinfo.value)
finally:
os.unlink(tmp_file_path)
@@ -170,7 +131,7 @@ class TestImageValidation:
os.unlink(tmp_file_path)
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path)
validate_image(tmp_file_path)
assert "Image file not found" in str(excinfo.value)
def test_validate_image_size_limit(self) -> None:
@@ -184,7 +145,7 @@ class TestImageValidation:
try:
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path, max_size_mb=20.0)
validate_image(tmp_file_path, max_size_mb=20.0)
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
finally:
os.unlink(tmp_file_path)
@@ -201,11 +162,11 @@ class TestImageValidation:
try:
# Should fail with 1MB limit
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path, max_size_mb=1.0)
validate_image(tmp_file_path, max_size_mb=1.0)
assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value)
# Should succeed with 3MB limit
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=3.0)
image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=3.0)
assert len(image_bytes) == len(data)
assert mime_type == "image/png"
finally:
@@ -222,12 +183,12 @@ class TestImageValidation:
try:
# Should succeed with default limit (20MB)
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
image_bytes, mime_type = validate_image(tmp_file_path)
assert len(image_bytes) == len(data)
assert mime_type == "image/jpeg"
# Should also succeed when explicitly passing None
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=None)
image_bytes, mime_type = validate_image(tmp_file_path, max_size_mb=None)
assert len(image_bytes) == len(data)
assert mime_type == "image/jpeg"
finally:
@@ -249,7 +210,7 @@ class TestImageValidation:
tmp_file_path = tmp_file.name
try:
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
image_bytes, mime_type = validate_image(tmp_file_path)
assert mime_type == expected_mime
assert image_bytes == b"dummy image data"
finally: