refactor: removed hard-coded checks, use model capabilities instead

This commit is contained in:
Fahad
2025-10-02 08:32:51 +04:00
parent bb138e2fb5
commit 250545e34f
4 changed files with 25 additions and 77 deletions

View File

@@ -402,8 +402,9 @@ class DIALModelProvider(OpenAICompatibleProvider):
if not self.validate_model_name(model_name): if not self.validate_model_name(model_name):
raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}") raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
# Validate parameters # Validate parameters and fetch capabilities
self.validate_parameters(model_name, temperature) self.validate_parameters(model_name, temperature)
capabilities = self.get_capabilities(model_name)
# Prepare messages # Prepare messages
messages = [] messages = []
@@ -414,7 +415,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
if prompt: if prompt:
user_message_content.append({"type": "text", "text": prompt}) user_message_content.append({"type": "text", "text": prompt})
if images and self._supports_vision(model_name): if images and capabilities.supports_images:
for img_path in images: for img_path in images:
processed_image = self._process_image(img_path) processed_image = self._process_image(img_path)
if processed_image: if processed_image:
@@ -437,13 +438,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
"messages": messages, "messages": messages,
} }
# Check model capabilities # Determine temperature support from capabilities
try: supports_temperature = capabilities.supports_temperature
capabilities = self.get_capabilities(model_name)
supports_temperature = capabilities.supports_temperature
except Exception as e:
logger.debug(f"Failed to check temperature support for {model_name}: {e}")
supports_temperature = True
# Add temperature parameter if supported # Add temperature parameter if supported
if supports_temperature: if supports_temperature:
@@ -513,23 +509,6 @@ class DIALModelProvider(OpenAICompatibleProvider):
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}" f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
) )
def _supports_vision(self, model_name: str) -> bool:
"""Check if the model supports vision (image processing).
Args:
model_name: Model name to check
Returns:
True if model supports vision, False otherwise
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name in self.SUPPORTED_MODELS:
return self.SUPPORTED_MODELS[resolved_name].supports_images
# Fall back to parent implementation for unknown models
return super()._supports_vision(model_name)
def close(self): def close(self):
"""Clean up HTTP clients when provider is closed.""" """Clean up HTTP clients when provider is closed."""
logger.info("Closing DIAL provider HTTP clients...") logger.info("Closing DIAL provider HTTP clients...")

View File

@@ -181,9 +181,10 @@ class GeminiModelProvider(ModelProvider):
**kwargs, **kwargs,
) -> ModelResponse: ) -> ModelResponse:
"""Generate content using Gemini model.""" """Generate content using Gemini model."""
# Validate parameters # Validate parameters and fetch capabilities
resolved_name = self._resolve_model_name(model_name) resolved_name = self._resolve_model_name(model_name)
self.validate_parameters(model_name, temperature) self.validate_parameters(model_name, temperature)
capabilities = self.get_capabilities(model_name)
# Prepare content parts (text and potentially images) # Prepare content parts (text and potentially images)
parts = [] parts = []
@@ -197,7 +198,7 @@ class GeminiModelProvider(ModelProvider):
parts.append({"text": full_prompt}) parts.append({"text": full_prompt})
# Add images if provided and model supports vision # Add images if provided and model supports vision
if images and self._supports_vision(resolved_name): if images and capabilities.supports_images:
for image_path in images: for image_path in images:
try: try:
image_part = self._process_image(image_path) image_part = self._process_image(image_path)
@@ -207,7 +208,7 @@ class GeminiModelProvider(ModelProvider):
logger.warning(f"Failed to process image {image_path}: {e}") logger.warning(f"Failed to process image {image_path}: {e}")
# Continue with other images and text # Continue with other images and text
continue continue
elif images and not self._supports_vision(resolved_name): elif images and not capabilities.supports_images:
logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)") logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)")
# Create contents structure # Create contents structure
@@ -224,7 +225,6 @@ class GeminiModelProvider(ModelProvider):
generation_config.max_output_tokens = max_output_tokens generation_config.max_output_tokens = max_output_tokens
# Add thinking configuration for models that support it # Add thinking configuration for models that support it
capabilities = self.get_capabilities(model_name)
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS: if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
# Get model's max thinking tokens and calculate actual budget # Get model's max thinking tokens and calculate actual budget
model_config = self.SUPPORTED_MODELS.get(resolved_name) model_config = self.SUPPORTED_MODELS.get(resolved_name)
@@ -457,18 +457,6 @@ class GeminiModelProvider(ModelProvider):
return usage return usage
def _supports_vision(self, model_name: str) -> bool:
"""Check if the model supports vision (image processing)."""
# Gemini 2.5 models support vision
vision_models = {
"gemini-2.5-flash",
"gemini-2.5-pro",
"gemini-2.0-flash",
"gemini-1.5-pro",
"gemini-1.5-flash",
}
return model_name in vision_models
def _is_error_retryable(self, error: Exception) -> bool: def _is_error_retryable(self, error: Exception) -> bool:
"""Determine if an error should be retried based on structured error codes. """Determine if an error should be retried based on structured error codes.

View File

@@ -482,12 +482,19 @@ class OpenAICompatibleProvider(ModelProvider):
if system_prompt: if system_prompt:
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
# Resolve capabilities once for vision/temperature checks
try:
capabilities = self.get_capabilities(model_name)
except Exception as exc:
logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}")
capabilities = None
# Prepare user message with text and potentially images # Prepare user message with text and potentially images
user_content = [] user_content = []
user_content.append({"type": "text", "text": prompt}) user_content.append({"type": "text", "text": prompt})
# Add images if provided and model supports vision # Add images if provided and model supports vision
if images and self._supports_vision(model_name): if images and capabilities and capabilities.supports_images:
for image_path in images: for image_path in images:
try: try:
image_content = self._process_image(image_path) image_content = self._process_image(image_path)
@@ -497,7 +504,7 @@ class OpenAICompatibleProvider(ModelProvider):
logging.warning(f"Failed to process image {image_path}: {e}") logging.warning(f"Failed to process image {image_path}: {e}")
# Continue with other images and text # Continue with other images and text
continue continue
elif images and not self._supports_vision(model_name): elif images and (not capabilities or not capabilities.supports_images):
logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)") logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
# Add user message # Add user message
@@ -727,31 +734,6 @@ class OpenAICompatibleProvider(ModelProvider):
""" """
return False return False
def _supports_vision(self, model_name: str) -> bool:
"""Check if the model supports vision (image processing).
Default implementation for OpenAI-compatible providers.
Subclasses should override with specific model support.
"""
# Common vision-capable models - only include models that actually support images
vision_models = {
"gpt-5",
"gpt-5-mini",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4-vision-preview",
"gpt-4.1-2025-04-14",
"o3",
"o3-mini",
"o3-pro",
"o4-mini",
# Note: Claude models would be handled by a separate provider
}
supports = model_name.lower() in vision_models
logging.debug(f"Model '{model_name}' vision support: {supports}")
return supports
def _is_error_retryable(self, error: Exception) -> bool: def _is_error_retryable(self, error: Exception) -> bool:
"""Determine if an error should be retried based on structured error codes. """Determine if an error should be retried based on structured error codes.

View File

@@ -140,17 +140,16 @@ class TestDIALProvider:
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False) @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
@patch("utils.model_restrictions._restriction_service", None) @patch("utils.model_restrictions._restriction_service", None)
def test_supports_vision(self): def test_supports_vision(self):
"""Test vision support detection.""" """Test vision support detection through model capabilities."""
provider = DIALModelProvider("test-key") provider = DIALModelProvider("test-key")
# Test models with vision support assert provider.get_capabilities("o3-2025-04-16").supports_images is True
assert provider._supports_vision("o3-2025-04-16") is True assert provider.get_capabilities("o3").supports_images is True # Via resolution
assert provider._supports_vision("o3") is True # Via resolution assert provider.get_capabilities("anthropic.claude-opus-4.1-20250805-v1:0").supports_images is True
assert provider._supports_vision("anthropic.claude-opus-4.1-20250805-v1:0") is True assert provider.get_capabilities("gemini-2.5-pro-preview-05-06").supports_images is True
assert provider._supports_vision("gemini-2.5-pro-preview-05-06") is True
# Test unknown model (falls back to parent implementation) with pytest.raises(ValueError):
assert provider._supports_vision("unknown-model") is False provider.get_capabilities("unknown-model")
@patch("openai.OpenAI") # Mock the OpenAI class directly from openai module @patch("openai.OpenAI") # Mock the OpenAI class directly from openai module
def test_generate_content_with_alias(self, mock_openai_class): def test_generate_content_with_alias(self, mock_openai_class):