Use ModelCapabilities consistently instead of dictionaries
Moved aliases as part of SUPPORTED_MODELS instead of shorthand, more in line with how custom_models are declared Further refactoring to cleanup some code
This commit is contained in:
@@ -140,6 +140,19 @@ class ModelCapabilities:
|
||||
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
|
||||
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
|
||||
|
||||
# Additional fields for comprehensive model information
|
||||
description: str = "" # Human-readable description of the model
|
||||
aliases: list[str] = field(default_factory=list) # Alternative names/shortcuts for the model
|
||||
|
||||
# JSON mode support (for providers that support structured output)
|
||||
supports_json_mode: bool = False
|
||||
|
||||
# Thinking mode support (for models with thinking capabilities)
|
||||
max_thinking_tokens: int = 0 # Maximum thinking tokens for extended reasoning models
|
||||
|
||||
# Custom model flag (for models that only work with custom endpoints)
|
||||
is_custom: bool = False # Whether this model requires custom API endpoints
|
||||
|
||||
# Temperature constraint object - preferred way to define temperature limits
|
||||
temperature_constraint: TemperatureConstraint = field(
|
||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
@@ -251,7 +264,7 @@ class ModelProvider(ABC):
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
|
||||
# Check if model supports temperature at all
|
||||
if hasattr(capabilities, "supports_temperature") and not capabilities.supports_temperature:
|
||||
if not capabilities.supports_temperature:
|
||||
return None
|
||||
|
||||
# Get temperature range
|
||||
@@ -290,19 +303,109 @@ class ModelProvider(ABC):
|
||||
"""Check if the model supports extended thinking mode."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||
"""Get model configurations for this provider.
|
||||
|
||||
This is a hook method that subclasses can override to provide
|
||||
their model configurations from different sources.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
# Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
|
||||
if hasattr(self, "SUPPORTED_MODELS"):
|
||||
return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
|
||||
return {}
|
||||
|
||||
def get_all_model_aliases(self) -> dict[str, list[str]]:
|
||||
"""Get all model aliases for this provider.
|
||||
|
||||
This is a hook method that subclasses can override to provide
|
||||
aliases from different sources.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their list of aliases
|
||||
"""
|
||||
# Default implementation extracts from ModelCapabilities objects
|
||||
aliases = {}
|
||||
for model_name, capabilities in self.get_model_configurations().items():
|
||||
if capabilities.aliases:
|
||||
aliases[model_name] = capabilities.aliases
|
||||
return aliases
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name.
|
||||
|
||||
This implementation uses the hook methods to support different
|
||||
model configuration sources.
|
||||
|
||||
Args:
|
||||
model_name: Model name that may be an alias
|
||||
|
||||
Returns:
|
||||
Resolved model name
|
||||
"""
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_model_configurations()
|
||||
|
||||
# First check if it's already a base model name (case-sensitive exact match)
|
||||
if model_name in model_configs:
|
||||
return model_name
|
||||
|
||||
# Check case-insensitively for both base models and aliases
|
||||
model_name_lower = model_name.lower()
|
||||
|
||||
# Check base model names case-insensitively
|
||||
for base_model in model_configs:
|
||||
if base_model.lower() == model_name_lower:
|
||||
return base_model
|
||||
|
||||
# Check aliases from the hook method
|
||||
all_aliases = self.get_all_model_aliases()
|
||||
for base_model, aliases in all_aliases.items():
|
||||
if any(alias.lower() == model_name_lower for alias in aliases):
|
||||
return base_model
|
||||
|
||||
# If not found, return as-is
|
||||
return model_name
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
This implementation uses the get_model_configurations() hook
|
||||
to support different model configuration sources.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
pass
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_model_configurations()
|
||||
|
||||
for model_name in model_configs:
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
|
||||
# Add the base model
|
||||
models.append(model_name)
|
||||
|
||||
# Get aliases from the hook method
|
||||
all_aliases = self.get_all_model_aliases()
|
||||
for model_name, aliases in all_aliases.items():
|
||||
# Only add aliases for models that passed restriction check
|
||||
if model_name in models:
|
||||
models.extend(aliases)
|
||||
|
||||
return models
|
||||
|
||||
@abstractmethod
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
@@ -312,21 +415,22 @@ class ModelProvider(ABC):
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
pass
|
||||
all_models = set()
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name.
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_model_configurations()
|
||||
|
||||
Base implementation returns the model name unchanged.
|
||||
Subclasses should override to provide alias resolution.
|
||||
# Add all base model names
|
||||
for model_name in model_configs:
|
||||
all_models.add(model_name.lower())
|
||||
|
||||
Args:
|
||||
model_name: Model name that may be an alias
|
||||
# Get aliases from the hook method and add them
|
||||
all_aliases = self.get_all_model_aliases()
|
||||
for _model_name, aliases in all_aliases.items():
|
||||
for alias in aliases:
|
||||
all_models.add(alias.lower())
|
||||
|
||||
Returns:
|
||||
Resolved model name
|
||||
"""
|
||||
return model_name
|
||||
return list(all_models)
|
||||
|
||||
def close(self):
|
||||
"""Clean up any resources held by the provider.
|
||||
|
||||
Reference in New Issue
Block a user