diff --git a/providers/base.py b/providers/base.py index ff290aa..4efe9d9 100644 --- a/providers/base.py +++ b/providers/base.py @@ -152,22 +152,6 @@ class ModelProvider(ABC): return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)} return {} - def get_all_model_aliases(self) -> dict[str, list[str]]: - """Get all model aliases for this provider. - - This is a hook method that subclasses can override to provide - aliases from different sources. - - Returns: - Dictionary mapping model names to their list of aliases - """ - # Default implementation extracts from ModelCapabilities objects - aliases = {} - for model_name, capabilities in self.get_model_configurations().items(): - if capabilities.aliases: - aliases[model_name] = capabilities.aliases - return aliases - def _resolve_model_name(self, model_name: str) -> str: """Resolve model shorthand to full name. @@ -195,9 +179,9 @@ class ModelProvider(ABC): if base_model.lower() == model_name_lower: return base_model - # Check aliases from the hook method - all_aliases = self.get_all_model_aliases() - for base_model, aliases in all_aliases.items(): + # Check aliases from the model configurations + alias_map = ModelCapabilities.collect_aliases(model_configs) + for base_model, aliases in alias_map.items(): if any(alias.lower() == model_name_lower for alias in aliases): return base_model @@ -232,9 +216,9 @@ class ModelProvider(ABC): # Add the base model models.append(model_name) - # Get aliases from the hook method - all_aliases = self.get_all_model_aliases() - for model_name, aliases in all_aliases.items(): + # Add aliases derived from the model configurations + alias_map = ModelCapabilities.collect_aliases(model_configs) + for model_name, aliases in alias_map.items(): # Only add aliases for models that passed restriction check if model_name in models: models.extend(aliases) @@ -259,9 +243,8 @@ class ModelProvider(ABC): for model_name in model_configs: all_models.add(model_name.lower()) - # Get aliases from the hook method and add them - all_aliases = self.get_all_model_aliases() - for _model_name, aliases in all_aliases.items(): + # Add aliases derived from the model configurations + for aliases in ModelCapabilities.collect_aliases(model_configs).values(): for alias in aliases: all_models.add(alias.lower()) diff --git a/providers/custom.py b/providers/custom.py index d7c4f37..3f6f813 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -367,13 +367,3 @@ class CustomProvider(OpenAICompatibleProvider): configs[model_name] = config return configs - - def get_all_model_aliases(self) -> dict[str, list[str]]: - """Get all model aliases from the registry. - - Returns: - Dictionary mapping model names to their list of aliases - """ - # Since aliases are now included in the configurations, - # we can use the base class implementation - return super().get_all_model_aliases() diff --git a/providers/openrouter.py b/providers/openrouter.py index 5360b87..fdbbc62 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -299,13 +299,3 @@ class OpenRouterProvider(OpenAICompatibleProvider): configs[model_name] = config return configs - - def get_all_model_aliases(self) -> dict[str, list[str]]: - """Get all model aliases from the registry. - - Returns: - Dictionary mapping model names to their list of aliases - """ - # Since aliases are now included in the configurations, - # we can use the base class implementation - return super().get_all_model_aliases() diff --git a/providers/shared/model_capabilities.py b/providers/shared/model_capabilities.py index f68d304..105a8fc 100644 --- a/providers/shared/model_capabilities.py +++ b/providers/shared/model_capabilities.py @@ -32,3 +32,13 @@ class ModelCapabilities: temperature_constraint: TemperatureConstraint = field( default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3) ) + + @staticmethod + def collect_aliases(model_configs: dict[str, "ModelCapabilities"]) -> dict[str, list[str]]: + """Build a mapping of model name to aliases from capability configs.""" + + return { + base_model: capabilities.aliases + for base_model, capabilities in model_configs.items() + if capabilities.aliases + }