From 250545e34f8d4f8026bfebb3171f3c2bc40f4692 Mon Sep 17 00:00:00 2001 From: Fahad Date: Thu, 2 Oct 2025 08:32:51 +0400 Subject: [PATCH] refactor: removed hard coded checks, use model capabilities instead --- providers/dial.py | 31 +++++------------------------ providers/gemini.py | 20 ++++--------------- providers/openai_compatible.py | 36 +++++++++------------------------- tests/test_dial_provider.py | 15 +++++++------- 4 files changed, 25 insertions(+), 77 deletions(-) diff --git a/providers/dial.py b/providers/dial.py index 8ca5b9c..59910cc 100644 --- a/providers/dial.py +++ b/providers/dial.py @@ -402,8 +402,9 @@ class DIALModelProvider(OpenAICompatibleProvider): if not self.validate_model_name(model_name): raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}") - # Validate parameters + # Validate parameters and fetch capabilities self.validate_parameters(model_name, temperature) + capabilities = self.get_capabilities(model_name) # Prepare messages messages = [] @@ -414,7 +415,7 @@ class DIALModelProvider(OpenAICompatibleProvider): if prompt: user_message_content.append({"type": "text", "text": prompt}) - if images and self._supports_vision(model_name): + if images and capabilities.supports_images: for img_path in images: processed_image = self._process_image(img_path) if processed_image: @@ -437,13 +438,8 @@ class DIALModelProvider(OpenAICompatibleProvider): "messages": messages, } - # Check model capabilities - try: - capabilities = self.get_capabilities(model_name) - supports_temperature = capabilities.supports_temperature - except Exception as e: - logger.debug(f"Failed to check temperature support for {model_name}: {e}") - supports_temperature = True + # Determine temperature support from capabilities + supports_temperature = capabilities.supports_temperature # Add temperature parameter if supported if supports_temperature: @@ -513,23 +509,6 @@ class DIALModelProvider(OpenAICompatibleProvider): f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}" ) - def _supports_vision(self, model_name: str) -> bool: - """Check if the model supports vision (image processing). - - Args: - model_name: Model name to check - - Returns: - True if model supports vision, False otherwise - """ - resolved_name = self._resolve_model_name(model_name) - - if resolved_name in self.SUPPORTED_MODELS: - return self.SUPPORTED_MODELS[resolved_name].supports_images - - # Fall back to parent implementation for unknown models - return super()._supports_vision(model_name) - def close(self): """Clean up HTTP clients when provider is closed.""" logger.info("Closing DIAL provider HTTP clients...") diff --git a/providers/gemini.py b/providers/gemini.py index 9f2bc26..44f947d 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -181,9 +181,10 @@ class GeminiModelProvider(ModelProvider): **kwargs, ) -> ModelResponse: """Generate content using Gemini model.""" - # Validate parameters + # Validate parameters and fetch capabilities resolved_name = self._resolve_model_name(model_name) self.validate_parameters(model_name, temperature) + capabilities = self.get_capabilities(model_name) # Prepare content parts (text and potentially images) parts = [] @@ -197,7 +198,7 @@ class GeminiModelProvider(ModelProvider): parts.append({"text": full_prompt}) # Add images if provided and model supports vision - if images and self._supports_vision(resolved_name): + if images and capabilities.supports_images: for image_path in images: try: image_part = self._process_image(image_path) @@ -207,7 +208,7 @@ class GeminiModelProvider(ModelProvider): logger.warning(f"Failed to process image {image_path}: {e}") # Continue with other images and text continue - elif images and not self._supports_vision(resolved_name): + elif images and not capabilities.supports_images: logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)") # Create contents structure @@ -224,7 +225,6 @@ class GeminiModelProvider(ModelProvider): generation_config.max_output_tokens = max_output_tokens # Add thinking configuration for models that support it - capabilities = self.get_capabilities(model_name) if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS: # Get model's max thinking tokens and calculate actual budget model_config = self.SUPPORTED_MODELS.get(resolved_name) @@ -457,18 +457,6 @@ class GeminiModelProvider(ModelProvider): return usage - def _supports_vision(self, model_name: str) -> bool: - """Check if the model supports vision (image processing).""" - # Gemini 2.5 models support vision - vision_models = { - "gemini-2.5-flash", - "gemini-2.5-pro", - "gemini-2.0-flash", - "gemini-1.5-pro", - "gemini-1.5-flash", - } - return model_name in vision_models - def _is_error_retryable(self, error: Exception) -> bool: """Determine if an error should be retried based on structured error codes. diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py index 701c84f..d26ea13 100644 --- a/providers/openai_compatible.py +++ b/providers/openai_compatible.py @@ -482,12 +482,19 @@ class OpenAICompatibleProvider(ModelProvider): if system_prompt: messages.append({"role": "system", "content": system_prompt}) + # Resolve capabilities once for vision/temperature checks + try: + capabilities = self.get_capabilities(model_name) + except Exception as exc: + logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}") + capabilities = None + # Prepare user message with text and potentially images user_content = [] user_content.append({"type": "text", "text": prompt}) # Add images if provided and model supports vision - if images and self._supports_vision(model_name): + if images and capabilities and capabilities.supports_images: for image_path in images: try: image_content = self._process_image(image_path) @@ -497,7 +504,7 @@ class OpenAICompatibleProvider(ModelProvider): logging.warning(f"Failed to process image {image_path}: {e}") # Continue with other images and text continue - elif images and not self._supports_vision(model_name): + elif images and (not capabilities or not capabilities.supports_images): logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)") # Add user message @@ -727,31 +734,6 @@ class OpenAICompatibleProvider(ModelProvider): """ return False - def _supports_vision(self, model_name: str) -> bool: - """Check if the model supports vision (image processing). - - Default implementation for OpenAI-compatible providers. - Subclasses should override with specific model support. - """ - # Common vision-capable models - only include models that actually support images - vision_models = { - "gpt-5", - "gpt-5-mini", - "gpt-4o", - "gpt-4o-mini", - "gpt-4-turbo", - "gpt-4-vision-preview", - "gpt-4.1-2025-04-14", - "o3", - "o3-mini", - "o3-pro", - "o4-mini", - # Note: Claude models would be handled by a separate provider - } - supports = model_name.lower() in vision_models - logging.debug(f"Model '{model_name}' vision support: {supports}") - return supports - def _is_error_retryable(self, error: Exception) -> bool: """Determine if an error should be retried based on structured error codes. diff --git a/tests/test_dial_provider.py b/tests/test_dial_provider.py index 3423c7c..9062a18 100644 --- a/tests/test_dial_provider.py +++ b/tests/test_dial_provider.py @@ -140,17 +140,16 @@ class TestDIALProvider: @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False) @patch("utils.model_restrictions._restriction_service", None) def test_supports_vision(self): - """Test vision support detection.""" + """Test vision support detection through model capabilities.""" provider = DIALModelProvider("test-key") - # Test models with vision support - assert provider._supports_vision("o3-2025-04-16") is True - assert provider._supports_vision("o3") is True # Via resolution - assert provider._supports_vision("anthropic.claude-opus-4.1-20250805-v1:0") is True - assert provider._supports_vision("gemini-2.5-pro-preview-05-06") is True + assert provider.get_capabilities("o3-2025-04-16").supports_images is True + assert provider.get_capabilities("o3").supports_images is True # Via resolution + assert provider.get_capabilities("anthropic.claude-opus-4.1-20250805-v1:0").supports_images is True + assert provider.get_capabilities("gemini-2.5-pro-preview-05-06").supports_images is True - # Test unknown model (falls back to parent implementation) - assert provider._supports_vision("unknown-model") is False + with pytest.raises(ValueError): + provider.get_capabilities("unknown-model") @patch("openai.OpenAI") # Mock the OpenAI class directly from openai module def test_generate_content_with_alias(self, mock_openai_class):