Use ModelCapabilities consistently instead of dictionaries
Moved aliases as part of SUPPORTED_MODELS instead of shorthand, more in line with how custom_models are declared Further refactoring to cleanup some code
This commit is contained in:
@@ -268,65 +268,55 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
"""Check if the model supports extended thinking mode.
|
||||
|
||||
Most custom/local models don't support extended thinking.
|
||||
|
||||
Args:
|
||||
model_name: Model to check
|
||||
|
||||
Returns:
|
||||
False (custom models generally don't support thinking mode)
|
||||
True if model supports thinking mode, False otherwise
|
||||
"""
|
||||
# Check if model is in registry
|
||||
config = self._registry.resolve(model_name) if self._registry else None
|
||||
if config and config.is_custom:
|
||||
# Trust the config from custom_models.json
|
||||
return config.supports_extended_thinking
|
||||
|
||||
# Default to False for unknown models
|
||||
return False
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||
"""Get model configurations from the registry.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
For CustomProvider, we convert registry configurations to ModelCapabilities objects.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
from .base import ProviderType
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
configs = {}
|
||||
|
||||
if self._registry:
|
||||
# Get all models from the registry
|
||||
all_models = self._registry.list_models()
|
||||
aliases = self._registry.list_aliases()
|
||||
|
||||
# Add models that are validated by the custom provider
|
||||
for model_name in all_models + aliases:
|
||||
# Use the provider's validation logic to determine if this model
|
||||
# is appropriate for the custom endpoint
|
||||
# Get all models from registry
|
||||
for model_name in self._registry.list_models():
|
||||
# Only include custom models that this provider validates
|
||||
if self.validate_model_name(model_name):
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and config.is_custom:
|
||||
# Convert OpenRouterModelConfig to ModelCapabilities
|
||||
capabilities = config.to_capabilities()
|
||||
# Override provider type to CUSTOM for local models
|
||||
capabilities.provider = ProviderType.CUSTOM
|
||||
capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
|
||||
configs[model_name] = capabilities
|
||||
|
||||
models.append(model_name)
|
||||
return configs
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
def get_all_model_aliases(self) -> dict[str, list[str]]:
|
||||
"""Get all model aliases from the registry.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
Dictionary mapping model names to their list of aliases
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
if self._registry:
|
||||
# Get all models and aliases from the registry
|
||||
all_models.update(model.lower() for model in self._registry.list_models())
|
||||
all_models.update(alias.lower() for alias in self._registry.list_aliases())
|
||||
|
||||
# For each alias, also add its target
|
||||
for alias in self._registry.list_aliases():
|
||||
config = self._registry.resolve(alias)
|
||||
if config:
|
||||
all_models.add(config.model_name.lower())
|
||||
|
||||
return list(all_models)
|
||||
# Since aliases are now included in the configurations,
|
||||
# we can use the base class implementation
|
||||
return super().get_all_model_aliases()
|
||||
|
||||
Reference in New Issue
Block a user