refactor: removed hook from base class, turned into helper static method instead
This commit is contained in:
@@ -152,22 +152,6 @@ class ModelProvider(ABC):
|
|||||||
return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
|
return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def get_all_model_aliases(self) -> dict[str, list[str]]:
|
|
||||||
"""Get all model aliases for this provider.
|
|
||||||
|
|
||||||
This is a hook method that subclasses can override to provide
|
|
||||||
aliases from different sources.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping model names to their list of aliases
|
|
||||||
"""
|
|
||||||
# Default implementation extracts from ModelCapabilities objects
|
|
||||||
aliases = {}
|
|
||||||
for model_name, capabilities in self.get_model_configurations().items():
|
|
||||||
if capabilities.aliases:
|
|
||||||
aliases[model_name] = capabilities.aliases
|
|
||||||
return aliases
|
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
"""Resolve model shorthand to full name.
|
"""Resolve model shorthand to full name.
|
||||||
|
|
||||||
@@ -195,9 +179,9 @@ class ModelProvider(ABC):
|
|||||||
if base_model.lower() == model_name_lower:
|
if base_model.lower() == model_name_lower:
|
||||||
return base_model
|
return base_model
|
||||||
|
|
||||||
# Check aliases from the hook method
|
# Check aliases from the model configurations
|
||||||
all_aliases = self.get_all_model_aliases()
|
alias_map = ModelCapabilities.collect_aliases(model_configs)
|
||||||
for base_model, aliases in all_aliases.items():
|
for base_model, aliases in alias_map.items():
|
||||||
if any(alias.lower() == model_name_lower for alias in aliases):
|
if any(alias.lower() == model_name_lower for alias in aliases):
|
||||||
return base_model
|
return base_model
|
||||||
|
|
||||||
@@ -232,9 +216,9 @@ class ModelProvider(ABC):
|
|||||||
# Add the base model
|
# Add the base model
|
||||||
models.append(model_name)
|
models.append(model_name)
|
||||||
|
|
||||||
# Get aliases from the hook method
|
# Add aliases derived from the model configurations
|
||||||
all_aliases = self.get_all_model_aliases()
|
alias_map = ModelCapabilities.collect_aliases(model_configs)
|
||||||
for model_name, aliases in all_aliases.items():
|
for model_name, aliases in alias_map.items():
|
||||||
# Only add aliases for models that passed restriction check
|
# Only add aliases for models that passed restriction check
|
||||||
if model_name in models:
|
if model_name in models:
|
||||||
models.extend(aliases)
|
models.extend(aliases)
|
||||||
@@ -259,9 +243,8 @@ class ModelProvider(ABC):
|
|||||||
for model_name in model_configs:
|
for model_name in model_configs:
|
||||||
all_models.add(model_name.lower())
|
all_models.add(model_name.lower())
|
||||||
|
|
||||||
# Get aliases from the hook method and add them
|
# Add aliases derived from the model configurations
|
||||||
all_aliases = self.get_all_model_aliases()
|
for aliases in ModelCapabilities.collect_aliases(model_configs).values():
|
||||||
for _model_name, aliases in all_aliases.items():
|
|
||||||
for alias in aliases:
|
for alias in aliases:
|
||||||
all_models.add(alias.lower())
|
all_models.add(alias.lower())
|
||||||
|
|
||||||
|
|||||||
@@ -367,13 +367,3 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
configs[model_name] = config
|
configs[model_name] = config
|
||||||
|
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
def get_all_model_aliases(self) -> dict[str, list[str]]:
|
|
||||||
"""Get all model aliases from the registry.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping model names to their list of aliases
|
|
||||||
"""
|
|
||||||
# Since aliases are now included in the configurations,
|
|
||||||
# we can use the base class implementation
|
|
||||||
return super().get_all_model_aliases()
|
|
||||||
|
|||||||
@@ -299,13 +299,3 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
configs[model_name] = config
|
configs[model_name] = config
|
||||||
|
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
def get_all_model_aliases(self) -> dict[str, list[str]]:
|
|
||||||
"""Get all model aliases from the registry.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping model names to their list of aliases
|
|
||||||
"""
|
|
||||||
# Since aliases are now included in the configurations,
|
|
||||||
# we can use the base class implementation
|
|
||||||
return super().get_all_model_aliases()
|
|
||||||
|
|||||||
@@ -32,3 +32,13 @@ class ModelCapabilities:
|
|||||||
temperature_constraint: TemperatureConstraint = field(
|
temperature_constraint: TemperatureConstraint = field(
|
||||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def collect_aliases(model_configs: dict[str, "ModelCapabilities"]) -> dict[str, list[str]]:
|
||||||
|
"""Build a mapping of model name to aliases from capability configs."""
|
||||||
|
|
||||||
|
return {
|
||||||
|
base_model: capabilities.aliases
|
||||||
|
for base_model, capabilities in model_configs.items()
|
||||||
|
if capabilities.aliases
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user