Files
my-pal-mcp-server/tests/test_image_validation.py
Nate Parsons 70d6cf8b54 refactor: Extract image validation to provider base class
Consolidates duplicated image validation logic from individual providers
into a reusable base class method. This improves maintainability and
ensures consistent validation across all providers.

- Added validate_image() method to ModelProvider base class
- Supports both file paths and data URLs
- Validates image format, size, and MIME types
- Added DEFAULT_MAX_IMAGE_SIZE_MB class constant (20MB)
- Refactored Gemini and OpenAI providers to use base validation
- Added comprehensive test suite with 19 tests
- Used minimal mocking approach with concrete test provider class
2025-07-12 21:51:24 -07:00

304 lines
12 KiB
Python

"""Tests for provider-independent image validation."""
import base64
import os
import tempfile
from typing import Optional
from unittest.mock import Mock, patch
import pytest
from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType
class MinimalTestProvider(ModelProvider):
"""Minimal concrete provider for testing base class methods."""
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def count_tokens(self, text: str, model_name: str) -> int:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def get_provider_type(self) -> ProviderType:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def validate_model_name(self, model_name: str) -> bool:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
def supports_thinking_mode(self, model_name: str) -> bool:
"""Not needed for image validation tests."""
raise NotImplementedError("Not needed for image validation tests")
class TestImageValidation:
"""Test suite for image validation functionality."""
def setup_method(self) -> None:
"""Set up test fixtures."""
# Create a minimal concrete provider instance for testing base class methods
self.provider = MinimalTestProvider(api_key="test-key")
def test_validate_data_url_valid(self) -> None:
"""Test validation of valid data URL."""
# Create a small test image (1x1 PNG)
test_image_data = base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
)
data_url = f"data:image/png;base64,{base64.b64encode(test_image_data).decode()}"
image_bytes, mime_type = self.provider.validate_image(data_url)
assert image_bytes == test_image_data
assert mime_type == "image/png"
@pytest.mark.parametrize(
"invalid_url,expected_error",
[
("data:image/png", "Invalid data URL format"), # Missing base64 part
("data:image/png;base64", "Invalid data URL format"), # Missing data
("data:text/plain;base64,dGVzdA==", "Unsupported image type"), # Not an image
],
)
def test_validate_data_url_invalid_format(self, invalid_url: str, expected_error: str) -> None:
"""Test validation of malformed data URL."""
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(invalid_url)
assert expected_error in str(excinfo.value)
def test_non_data_url_treated_as_file_path(self) -> None:
"""Test that non-data URLs are treated as file paths."""
# Test case that's not a data URL at all
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image("image/png;base64,abc123")
assert "Image file not found" in str(excinfo.value) # Treated as file path
def test_validate_data_url_unsupported_type(self) -> None:
"""Test validation of unsupported image type in data URL."""
data_url = "" # BMP format
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url)
assert "Unsupported image type: image/bmp" in str(excinfo.value)
def test_validate_data_url_invalid_base64(self) -> None:
"""Test validation of data URL with invalid base64."""
data_url = "data:image/png;base64,@@@invalid@@@"
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url)
assert "Invalid base64 data" in str(excinfo.value)
def test_validate_large_data_url(self) -> None:
"""Test validation of large data URL to ensure size limits work."""
# Create a large image (21MB)
large_data = b"x" * (21 * 1024 * 1024) # 21MB
# Encode as base64 and create data URL
import base64
encoded_data = base64.b64encode(large_data).decode()
data_url = f"data:image/png;base64,{encoded_data}"
# Should fail with default 20MB limit
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(data_url)
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
# Should succeed with higher limit
image_bytes, mime_type = self.provider.validate_image(data_url, max_size_mb=25.0)
assert len(image_bytes) == len(large_data)
assert mime_type == "image/png"
def test_validate_file_path_valid(self) -> None:
"""Test validation of valid image file."""
# Create a temporary image file
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
# Write a small test PNG
test_image_data = base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
)
tmp_file.write(test_image_data)
tmp_file_path = tmp_file.name
try:
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
assert image_bytes == test_image_data
assert mime_type == "image/png"
finally:
os.unlink(tmp_file_path)
def test_validate_file_path_not_found(self) -> None:
"""Test validation of non-existent file."""
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image("/path/to/nonexistent/image.png")
assert "Image file not found" in str(excinfo.value)
def test_validate_file_path_unsupported_extension(self) -> None:
"""Test validation of file with unsupported extension."""
with tempfile.NamedTemporaryFile(suffix=".bmp", delete=False) as tmp_file:
tmp_file.write(b"dummy data")
tmp_file_path = tmp_file.name
try:
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path)
assert "Unsupported image format: .bmp" in str(excinfo.value)
finally:
os.unlink(tmp_file_path)
def test_validate_file_path_read_error(self) -> None:
"""Test validation when file cannot be read."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
tmp_file_path = tmp_file.name
# Remove the file but keep the path
os.unlink(tmp_file_path)
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path)
assert "Image file not found" in str(excinfo.value)
def test_validate_image_size_limit(self) -> None:
"""Test validation of image size limits."""
# Create a large "image" (just random data)
large_data = b"x" * (21 * 1024 * 1024) # 21MB
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
tmp_file.write(large_data)
tmp_file_path = tmp_file.name
try:
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path, max_size_mb=20.0)
assert "Image too large: 21.0MB (max: 20.0MB)" in str(excinfo.value)
finally:
os.unlink(tmp_file_path)
def test_validate_image_custom_size_limit(self) -> None:
"""Test validation with custom size limit."""
# Create a 2MB "image"
data = b"x" * (2 * 1024 * 1024)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
tmp_file.write(data)
tmp_file_path = tmp_file.name
try:
# Should fail with 1MB limit
with pytest.raises(ValueError) as excinfo:
self.provider.validate_image(tmp_file_path, max_size_mb=1.0)
assert "Image too large: 2.0MB (max: 1.0MB)" in str(excinfo.value)
# Should succeed with 3MB limit
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=3.0)
assert len(image_bytes) == len(data)
assert mime_type == "image/png"
finally:
os.unlink(tmp_file_path)
def test_validate_image_default_size_limit(self) -> None:
"""Test validation with default size limit (None)."""
# Create a small image that's under the default limit
data = b"x" * (1024 * 1024) # 1MB
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
tmp_file.write(data)
tmp_file_path = tmp_file.name
try:
# Should succeed with default limit (20MB)
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
assert len(image_bytes) == len(data)
assert mime_type == "image/jpeg"
# Should also succeed when explicitly passing None
image_bytes, mime_type = self.provider.validate_image(tmp_file_path, max_size_mb=None)
assert len(image_bytes) == len(data)
assert mime_type == "image/jpeg"
finally:
os.unlink(tmp_file_path)
def test_validate_all_supported_formats(self) -> None:
"""Test validation of all supported image formats."""
supported_formats = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
}
for ext, expected_mime in supported_formats.items():
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
tmp_file.write(b"dummy image data")
tmp_file_path = tmp_file.name
try:
image_bytes, mime_type = self.provider.validate_image(tmp_file_path)
assert mime_type == expected_mime
assert image_bytes == b"dummy image data"
finally:
os.unlink(tmp_file_path)
class TestProviderIntegration:
"""Test image validation integration with different providers."""
@patch("providers.gemini.logger")
def test_gemini_provider_uses_validation(self, mock_logger: Mock) -> None:
"""Test that Gemini provider uses the base validation."""
from providers.gemini import GeminiModelProvider
# Create a provider instance
provider = GeminiModelProvider(api_key="test-key")
# Test with non-existent file
result = provider._process_image("/nonexistent/image.png")
assert result is None
mock_logger.warning.assert_called_with("Image file not found: /nonexistent/image.png")
@patch("providers.openai_compatible.logging")
def test_openai_compatible_provider_uses_validation(self, mock_logging: Mock) -> None:
"""Test that OpenAI-compatible providers use the base validation."""
from providers.xai import XAIModelProvider
# Create a provider instance (XAI inherits from OpenAICompatibleProvider)
provider = XAIModelProvider(api_key="test-key")
# Test with non-existent file
result = provider._process_image("/nonexistent/image.png")
assert result is None
mock_logging.warning.assert_called_with("Image file not found: /nonexistent/image.png")
def test_data_url_preservation(self) -> None:
"""Test that data URLs are properly preserved through validation."""
from providers.xai import XAIModelProvider
provider = XAIModelProvider(api_key="test-key")
# Valid data URL
data_url = ""
result = provider._process_image(data_url)
assert result is not None
assert result["type"] == "image_url"
assert result["image_url"]["url"] == data_url