Use ModelCapabilities consistently instead of dictionaries

Moved aliases as part of SUPPORTED_MODELS instead of shorthand, more in line with how custom_models are declared
Further refactoring to cleanup some code
This commit is contained in:
Fahad
2025-06-23 16:58:59 +04:00
parent e94c028a3f
commit 498ea88293
16 changed files with 850 additions and 605 deletions

View File

@@ -268,65 +268,55 @@ class CustomProvider(OpenAICompatibleProvider):
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode.
Most custom/local models don't support extended thinking.
Args:
model_name: Model to check
Returns:
False (custom models generally don't support thinking mode)
True if model supports thinking mode, False otherwise
"""
# Check if model is in registry
config = self._registry.resolve(model_name) if self._registry else None
if config and config.is_custom:
# Trust the config from custom_models.json
return config.supports_extended_thinking
# Default to False for unknown models
return False
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
"""Get model configurations from the registry.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
For CustomProvider, we convert registry configurations to ModelCapabilities objects.
Returns:
List of model names available from this provider
Dictionary mapping model names to their ModelCapabilities objects
"""
from utils.model_restrictions import get_restriction_service
from .base import ProviderType
restriction_service = get_restriction_service() if respect_restrictions else None
models = []
configs = {}
if self._registry:
# Get all models from the registry
all_models = self._registry.list_models()
aliases = self._registry.list_aliases()
# Add models that are validated by the custom provider
for model_name in all_models + aliases:
# Use the provider's validation logic to determine if this model
# is appropriate for the custom endpoint
# Get all models from registry
for model_name in self._registry.list_models():
# Only include custom models that this provider validates
if self.validate_model_name(model_name):
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
continue
config = self._registry.resolve(model_name)
if config and config.is_custom:
# Convert OpenRouterModelConfig to ModelCapabilities
capabilities = config.to_capabilities()
# Override provider type to CUSTOM for local models
capabilities.provider = ProviderType.CUSTOM
capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
configs[model_name] = capabilities
models.append(model_name)
return configs
return models
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
def get_all_model_aliases(self) -> dict[str, list[str]]:
"""Get all model aliases from the registry.
Returns:
List of all model names and alias targets known by this provider
Dictionary mapping model names to their list of aliases
"""
all_models = set()
if self._registry:
# Get all models and aliases from the registry
all_models.update(model.lower() for model in self._registry.list_models())
all_models.update(alias.lower() for alias in self._registry.list_aliases())
# For each alias, also add its target
for alias in self._registry.list_aliases():
config = self._registry.resolve(alias)
if config:
all_models.add(config.model_name.lower())
return list(all_models)
# Since aliases are now included in the configurations,
# we can use the base class implementation
return super().get_all_model_aliases()