refactor: removed hard coded checks, use model capabilities instead
This commit is contained in:
@@ -402,8 +402,9 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
if not self.validate_model_name(model_name):
|
||||
raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
|
||||
|
||||
# Validate parameters
|
||||
# Validate parameters and fetch capabilities
|
||||
self.validate_parameters(model_name, temperature)
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
|
||||
# Prepare messages
|
||||
messages = []
|
||||
@@ -414,7 +415,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
if prompt:
|
||||
user_message_content.append({"type": "text", "text": prompt})
|
||||
|
||||
if images and self._supports_vision(model_name):
|
||||
if images and capabilities.supports_images:
|
||||
for img_path in images:
|
||||
processed_image = self._process_image(img_path)
|
||||
if processed_image:
|
||||
@@ -437,13 +438,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
# Check model capabilities
|
||||
try:
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
supports_temperature = capabilities.supports_temperature
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to check temperature support for {model_name}: {e}")
|
||||
supports_temperature = True
|
||||
# Determine temperature support from capabilities
|
||||
supports_temperature = capabilities.supports_temperature
|
||||
|
||||
# Add temperature parameter if supported
|
||||
if supports_temperature:
|
||||
@@ -513,23 +509,6 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
|
||||
)
|
||||
|
||||
def _supports_vision(self, model_name: str) -> bool:
|
||||
"""Check if the model supports vision (image processing).
|
||||
|
||||
Args:
|
||||
model_name: Model name to check
|
||||
|
||||
Returns:
|
||||
True if model supports vision, False otherwise
|
||||
"""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name in self.SUPPORTED_MODELS:
|
||||
return self.SUPPORTED_MODELS[resolved_name].supports_images
|
||||
|
||||
# Fall back to parent implementation for unknown models
|
||||
return super()._supports_vision(model_name)
|
||||
|
||||
def close(self):
|
||||
"""Clean up HTTP clients when provider is closed."""
|
||||
logger.info("Closing DIAL provider HTTP clients...")
|
||||
|
||||
Reference in New Issue
Block a user