Proper fix for model discovery per provider
This commit is contained in:
@@ -221,3 +221,27 @@ class ModelProvider(ABC):
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
"""Check if the model supports extended thinking mode."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
This is used for validation purposes to ensure restriction policies
|
||||
can validate against both aliases and their target model names.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -276,3 +276,56 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
False (custom models generally don't support thinking mode)
|
||||
"""
|
||||
return False
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
if self._registry:
|
||||
# Get all models from the registry
|
||||
all_models = self._registry.list_models()
|
||||
aliases = self._registry.list_aliases()
|
||||
|
||||
# Add models that are validated by the custom provider
|
||||
for model_name in all_models + aliases:
|
||||
# Use the provider's validation logic to determine if this model
|
||||
# is appropriate for the custom endpoint
|
||||
if self.validate_model_name(model_name):
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
|
||||
models.append(model_name)
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
if self._registry:
|
||||
# Get all models and aliases from the registry
|
||||
all_models.update(model.lower() for model in self._registry.list_models())
|
||||
all_models.update(alias.lower() for alias in self._registry.list_aliases())
|
||||
|
||||
# For each alias, also add its target
|
||||
for alias in self._registry.list_aliases():
|
||||
config = self._registry.resolve(alias)
|
||||
if config:
|
||||
all_models.add(config.model_name.lower())
|
||||
|
||||
return list(all_models)
|
||||
|
||||
@@ -287,6 +287,56 @@ class GeminiModelProvider(ModelProvider):
|
||||
|
||||
return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Handle both base models (dict configs) and aliases (string values)
|
||||
if isinstance(config, str):
|
||||
# This is an alias - check if the target model would be allowed
|
||||
target_model = config
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
|
||||
continue
|
||||
# Allow the alias
|
||||
models.append(model_name)
|
||||
else:
|
||||
# This is a base model with config dict
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
models.append(model_name)
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Add the model name itself
|
||||
all_models.add(model_name.lower())
|
||||
|
||||
# If it's an alias (string value), add the target model too
|
||||
if isinstance(config, str):
|
||||
all_models.add(config.lower())
|
||||
|
||||
return list(all_models)
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name."""
|
||||
# Check if it's a shorthand
|
||||
|
||||
@@ -163,6 +163,56 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
# This may change with future O3 models
|
||||
return False
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Handle both base models (dict configs) and aliases (string values)
|
||||
if isinstance(config, str):
|
||||
# This is an alias - check if the target model would be allowed
|
||||
target_model = config
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
|
||||
continue
|
||||
# Allow the alias
|
||||
models.append(model_name)
|
||||
else:
|
||||
# This is a base model with config dict
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
models.append(model_name)
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Add the model name itself
|
||||
all_models.add(model_name.lower())
|
||||
|
||||
# If it's an alias (string value), add the target model too
|
||||
if isinstance(config, str):
|
||||
all_models.add(config.lower())
|
||||
|
||||
return list(all_models)
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name."""
|
||||
# Check if it's a shorthand
|
||||
|
||||
@@ -190,3 +190,48 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
False (no OpenRouter models currently support thinking mode)
|
||||
"""
|
||||
return False
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
if self._registry:
|
||||
for model_name in self._registry.list_models():
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
|
||||
models.append(model_name)
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
if self._registry:
|
||||
# Get all models and aliases from the registry
|
||||
all_models.update(model.lower() for model in self._registry.list_models())
|
||||
all_models.update(alias.lower() for alias in self._registry.list_aliases())
|
||||
|
||||
# For each alias, also add its target
|
||||
for alias in self._registry.list_aliases():
|
||||
config = self._registry.resolve(alias)
|
||||
if config:
|
||||
all_models.add(config.model_name.lower())
|
||||
|
||||
return list(all_models)
|
||||
|
||||
@@ -160,66 +160,29 @@ class ModelProviderRegistry:
|
||||
Returns:
|
||||
Dict mapping model names to provider types
|
||||
"""
|
||||
models = {}
|
||||
instance = cls()
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models: dict[str, ProviderType] = {}
|
||||
instance = cls()
|
||||
|
||||
for provider_type in instance._providers:
|
||||
provider = cls.get_provider(provider_type)
|
||||
if provider:
|
||||
# Get supported models based on provider type
|
||||
if hasattr(provider, "SUPPORTED_MODELS"):
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
# Handle both base models (dict configs) and aliases (string values)
|
||||
if isinstance(config, str):
|
||||
# This is an alias - check if the target model would be allowed
|
||||
target_model = config
|
||||
if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
|
||||
logging.debug(f"Alias {model_name} -> {target_model} filtered by restrictions")
|
||||
continue
|
||||
# Allow the alias
|
||||
models[model_name] = provider_type
|
||||
else:
|
||||
# This is a base model with config dict
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||
continue
|
||||
models[model_name] = provider_type
|
||||
elif provider_type == ProviderType.OPENROUTER:
|
||||
# OpenRouter uses a registry system instead of SUPPORTED_MODELS
|
||||
if hasattr(provider, "_registry") and provider._registry:
|
||||
for model_name in provider._registry.list_models():
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||
continue
|
||||
if not provider:
|
||||
continue
|
||||
|
||||
models[model_name] = provider_type
|
||||
elif provider_type == ProviderType.CUSTOM:
|
||||
# Custom provider also uses a registry system (shared with OpenRouter)
|
||||
if hasattr(provider, "_registry") and provider._registry:
|
||||
# Get all models from the registry
|
||||
all_models = provider._registry.list_models()
|
||||
aliases = provider._registry.list_aliases()
|
||||
try:
|
||||
available = provider.list_models(respect_restrictions=respect_restrictions)
|
||||
except NotImplementedError:
|
||||
logging.warning("Provider %s does not implement list_models", provider_type)
|
||||
continue
|
||||
|
||||
# Add models that are validated by the custom provider
|
||||
for model_name in all_models + aliases:
|
||||
# Use the provider's validation logic to determine if this model
|
||||
# is appropriate for the custom endpoint
|
||||
if provider.validate_model_name(model_name):
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(
|
||||
provider_type, model_name
|
||||
):
|
||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||
continue
|
||||
|
||||
models[model_name] = provider_type
|
||||
for model_name in available:
|
||||
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
||||
logging.debug("Model %s filtered by restrictions", model_name)
|
||||
continue
|
||||
models[model_name] = provider_type
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@@ -126,6 +126,56 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
||||
# This may change with future GROK model releases
|
||||
return False
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Handle both base models (dict configs) and aliases (string values)
|
||||
if isinstance(config, str):
|
||||
# This is an alias - check if the target model would be allowed
|
||||
target_model = config
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
|
||||
continue
|
||||
# Allow the alias
|
||||
models.append(model_name)
|
||||
else:
|
||||
# This is a base model with config dict
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
models.append(model_name)
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Add the model name itself
|
||||
all_models.add(model_name.lower())
|
||||
|
||||
# If it's an alias (string value), add the target model too
|
||||
if isinstance(config, str):
|
||||
all_models.add(config.lower())
|
||||
|
||||
return list(all_models)
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name."""
|
||||
# Check if it's a shorthand
|
||||
|
||||
Reference in New Issue
Block a user