refactor: remove hard-coded checks; use model capabilities instead
This commit is contained in:
@@ -402,8 +402,9 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
if not self.validate_model_name(model_name):
|
if not self.validate_model_name(model_name):
|
||||||
raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
|
raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
|
||||||
|
|
||||||
# Validate parameters
|
# Validate parameters and fetch capabilities
|
||||||
self.validate_parameters(model_name, temperature)
|
self.validate_parameters(model_name, temperature)
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
# Prepare messages
|
# Prepare messages
|
||||||
messages = []
|
messages = []
|
||||||
@@ -414,7 +415,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
if prompt:
|
if prompt:
|
||||||
user_message_content.append({"type": "text", "text": prompt})
|
user_message_content.append({"type": "text", "text": prompt})
|
||||||
|
|
||||||
if images and self._supports_vision(model_name):
|
if images and capabilities.supports_images:
|
||||||
for img_path in images:
|
for img_path in images:
|
||||||
processed_image = self._process_image(img_path)
|
processed_image = self._process_image(img_path)
|
||||||
if processed_image:
|
if processed_image:
|
||||||
@@ -437,13 +438,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
"messages": messages,
|
"messages": messages,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Check model capabilities
|
# Determine temperature support from capabilities
|
||||||
try:
|
supports_temperature = capabilities.supports_temperature
|
||||||
capabilities = self.get_capabilities(model_name)
|
|
||||||
supports_temperature = capabilities.supports_temperature
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Failed to check temperature support for {model_name}: {e}")
|
|
||||||
supports_temperature = True
|
|
||||||
|
|
||||||
# Add temperature parameter if supported
|
# Add temperature parameter if supported
|
||||||
if supports_temperature:
|
if supports_temperature:
|
||||||
@@ -513,23 +509,6 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
|||||||
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
|
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _supports_vision(self, model_name: str) -> bool:
|
|
||||||
"""Check if the model supports vision (image processing).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Model name to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if model supports vision, False otherwise
|
|
||||||
"""
|
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
|
||||||
|
|
||||||
if resolved_name in self.SUPPORTED_MODELS:
|
|
||||||
return self.SUPPORTED_MODELS[resolved_name].supports_images
|
|
||||||
|
|
||||||
# Fall back to parent implementation for unknown models
|
|
||||||
return super()._supports_vision(model_name)
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Clean up HTTP clients when provider is closed."""
|
"""Clean up HTTP clients when provider is closed."""
|
||||||
logger.info("Closing DIAL provider HTTP clients...")
|
logger.info("Closing DIAL provider HTTP clients...")
|
||||||
|
|||||||
@@ -181,9 +181,10 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Generate content using Gemini model."""
|
"""Generate content using Gemini model."""
|
||||||
# Validate parameters
|
# Validate parameters and fetch capabilities
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
self.validate_parameters(model_name, temperature)
|
self.validate_parameters(model_name, temperature)
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
# Prepare content parts (text and potentially images)
|
# Prepare content parts (text and potentially images)
|
||||||
parts = []
|
parts = []
|
||||||
@@ -197,7 +198,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
parts.append({"text": full_prompt})
|
parts.append({"text": full_prompt})
|
||||||
|
|
||||||
# Add images if provided and model supports vision
|
# Add images if provided and model supports vision
|
||||||
if images and self._supports_vision(resolved_name):
|
if images and capabilities.supports_images:
|
||||||
for image_path in images:
|
for image_path in images:
|
||||||
try:
|
try:
|
||||||
image_part = self._process_image(image_path)
|
image_part = self._process_image(image_path)
|
||||||
@@ -207,7 +208,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
logger.warning(f"Failed to process image {image_path}: {e}")
|
logger.warning(f"Failed to process image {image_path}: {e}")
|
||||||
# Continue with other images and text
|
# Continue with other images and text
|
||||||
continue
|
continue
|
||||||
elif images and not self._supports_vision(resolved_name):
|
elif images and not capabilities.supports_images:
|
||||||
logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)")
|
logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)")
|
||||||
|
|
||||||
# Create contents structure
|
# Create contents structure
|
||||||
@@ -224,7 +225,6 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
generation_config.max_output_tokens = max_output_tokens
|
generation_config.max_output_tokens = max_output_tokens
|
||||||
|
|
||||||
# Add thinking configuration for models that support it
|
# Add thinking configuration for models that support it
|
||||||
capabilities = self.get_capabilities(model_name)
|
|
||||||
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
|
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
|
||||||
# Get model's max thinking tokens and calculate actual budget
|
# Get model's max thinking tokens and calculate actual budget
|
||||||
model_config = self.SUPPORTED_MODELS.get(resolved_name)
|
model_config = self.SUPPORTED_MODELS.get(resolved_name)
|
||||||
@@ -457,18 +457,6 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
|
|
||||||
return usage
|
return usage
|
||||||
|
|
||||||
def _supports_vision(self, model_name: str) -> bool:
|
|
||||||
"""Check if the model supports vision (image processing)."""
|
|
||||||
# Gemini 2.5 models support vision
|
|
||||||
vision_models = {
|
|
||||||
"gemini-2.5-flash",
|
|
||||||
"gemini-2.5-pro",
|
|
||||||
"gemini-2.0-flash",
|
|
||||||
"gemini-1.5-pro",
|
|
||||||
"gemini-1.5-flash",
|
|
||||||
}
|
|
||||||
return model_name in vision_models
|
|
||||||
|
|
||||||
def _is_error_retryable(self, error: Exception) -> bool:
|
def _is_error_retryable(self, error: Exception) -> bool:
|
||||||
"""Determine if an error should be retried based on structured error codes.
|
"""Determine if an error should be retried based on structured error codes.
|
||||||
|
|
||||||
|
|||||||
@@ -482,12 +482,19 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
if system_prompt:
|
if system_prompt:
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
|
# Resolve capabilities once for vision/temperature checks
|
||||||
|
try:
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
except Exception as exc:
|
||||||
|
logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}")
|
||||||
|
capabilities = None
|
||||||
|
|
||||||
# Prepare user message with text and potentially images
|
# Prepare user message with text and potentially images
|
||||||
user_content = []
|
user_content = []
|
||||||
user_content.append({"type": "text", "text": prompt})
|
user_content.append({"type": "text", "text": prompt})
|
||||||
|
|
||||||
# Add images if provided and model supports vision
|
# Add images if provided and model supports vision
|
||||||
if images and self._supports_vision(model_name):
|
if images and capabilities and capabilities.supports_images:
|
||||||
for image_path in images:
|
for image_path in images:
|
||||||
try:
|
try:
|
||||||
image_content = self._process_image(image_path)
|
image_content = self._process_image(image_path)
|
||||||
@@ -497,7 +504,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
logging.warning(f"Failed to process image {image_path}: {e}")
|
logging.warning(f"Failed to process image {image_path}: {e}")
|
||||||
# Continue with other images and text
|
# Continue with other images and text
|
||||||
continue
|
continue
|
||||||
elif images and not self._supports_vision(model_name):
|
elif images and (not capabilities or not capabilities.supports_images):
|
||||||
logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
|
logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
|
||||||
|
|
||||||
# Add user message
|
# Add user message
|
||||||
@@ -727,31 +734,6 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _supports_vision(self, model_name: str) -> bool:
|
|
||||||
"""Check if the model supports vision (image processing).
|
|
||||||
|
|
||||||
Default implementation for OpenAI-compatible providers.
|
|
||||||
Subclasses should override with specific model support.
|
|
||||||
"""
|
|
||||||
# Common vision-capable models - only include models that actually support images
|
|
||||||
vision_models = {
|
|
||||||
"gpt-5",
|
|
||||||
"gpt-5-mini",
|
|
||||||
"gpt-4o",
|
|
||||||
"gpt-4o-mini",
|
|
||||||
"gpt-4-turbo",
|
|
||||||
"gpt-4-vision-preview",
|
|
||||||
"gpt-4.1-2025-04-14",
|
|
||||||
"o3",
|
|
||||||
"o3-mini",
|
|
||||||
"o3-pro",
|
|
||||||
"o4-mini",
|
|
||||||
# Note: Claude models would be handled by a separate provider
|
|
||||||
}
|
|
||||||
supports = model_name.lower() in vision_models
|
|
||||||
logging.debug(f"Model '{model_name}' vision support: {supports}")
|
|
||||||
return supports
|
|
||||||
|
|
||||||
def _is_error_retryable(self, error: Exception) -> bool:
|
def _is_error_retryable(self, error: Exception) -> bool:
|
||||||
"""Determine if an error should be retried based on structured error codes.
|
"""Determine if an error should be retried based on structured error codes.
|
||||||
|
|
||||||
|
|||||||
@@ -140,17 +140,16 @@ class TestDIALProvider:
|
|||||||
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
|
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
|
||||||
@patch("utils.model_restrictions._restriction_service", None)
|
@patch("utils.model_restrictions._restriction_service", None)
|
||||||
def test_supports_vision(self):
|
def test_supports_vision(self):
|
||||||
"""Test vision support detection."""
|
"""Test vision support detection through model capabilities."""
|
||||||
provider = DIALModelProvider("test-key")
|
provider = DIALModelProvider("test-key")
|
||||||
|
|
||||||
# Test models with vision support
|
assert provider.get_capabilities("o3-2025-04-16").supports_images is True
|
||||||
assert provider._supports_vision("o3-2025-04-16") is True
|
assert provider.get_capabilities("o3").supports_images is True # Via resolution
|
||||||
assert provider._supports_vision("o3") is True # Via resolution
|
assert provider.get_capabilities("anthropic.claude-opus-4.1-20250805-v1:0").supports_images is True
|
||||||
assert provider._supports_vision("anthropic.claude-opus-4.1-20250805-v1:0") is True
|
assert provider.get_capabilities("gemini-2.5-pro-preview-05-06").supports_images is True
|
||||||
assert provider._supports_vision("gemini-2.5-pro-preview-05-06") is True
|
|
||||||
|
|
||||||
# Test unknown model (falls back to parent implementation)
|
with pytest.raises(ValueError):
|
||||||
assert provider._supports_vision("unknown-model") is False
|
provider.get_capabilities("unknown-model")
|
||||||
|
|
||||||
@patch("openai.OpenAI") # Mock the OpenAI class directly from openai module
|
@patch("openai.OpenAI") # Mock the OpenAI class directly from openai module
|
||||||
def test_generate_content_with_alias(self, mock_openai_class):
|
def test_generate_content_with_alias(self, mock_openai_class):
|
||||||
|
|||||||
Reference in New Issue
Block a user