Proper fix for model discovery per provider

This commit is contained in:
Fahad
2025-06-18 07:16:10 +04:00
parent 5199dd6ead
commit dad1e2d74e
15 changed files with 1250 additions and 65 deletions

View File

@@ -221,3 +221,27 @@ class ModelProvider(ABC):
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
pass
@abstractmethod
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
Returns:
List of model names available from this provider
"""
pass
@abstractmethod
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
This is used for validation purposes to ensure restriction policies
can validate against both aliases and their target model names.
Returns:
List of all model names and alias targets known by this provider
"""
pass

View File

@@ -276,3 +276,56 @@ class CustomProvider(OpenAICompatibleProvider):
False (custom models generally don't support thinking mode)
"""
return False
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
Returns:
List of model names available from this provider
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models = []
if self._registry:
# Get all models from the registry
all_models = self._registry.list_models()
aliases = self._registry.list_aliases()
# Add models that are validated by the custom provider
for model_name in all_models + aliases:
# Use the provider's validation logic to determine if this model
# is appropriate for the custom endpoint
if self.validate_model_name(model_name):
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
continue
models.append(model_name)
return models
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
Returns:
List of all model names and alias targets known by this provider
"""
all_models = set()
if self._registry:
# Get all models and aliases from the registry
all_models.update(model.lower() for model in self._registry.list_models())
all_models.update(alias.lower() for alias in self._registry.list_aliases())
# For each alias, also add its target
for alias in self._registry.list_aliases():
config = self._registry.resolve(alias)
if config:
all_models.add(config.model_name.lower())
return list(all_models)

View File

@@ -287,6 +287,56 @@ class GeminiModelProvider(ModelProvider):
return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
Returns:
List of model names available from this provider
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models = []
for model_name, config in self.SUPPORTED_MODELS.items():
# Handle both base models (dict configs) and aliases (string values)
if isinstance(config, str):
# This is an alias - check if the target model would be allowed
target_model = config
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
continue
# Allow the alias
models.append(model_name)
else:
# This is a base model with config dict
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
continue
models.append(model_name)
return models
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
Returns:
List of all model names and alias targets known by this provider
"""
all_models = set()
for model_name, config in self.SUPPORTED_MODELS.items():
# Add the model name itself
all_models.add(model_name.lower())
# If it's an alias (string value), add the target model too
if isinstance(config, str):
all_models.add(config.lower())
return list(all_models)
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name."""
# Check if it's a shorthand

View File

@@ -163,6 +163,56 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
# This may change with future O3 models
return False
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
Returns:
List of model names available from this provider
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models = []
for model_name, config in self.SUPPORTED_MODELS.items():
# Handle both base models (dict configs) and aliases (string values)
if isinstance(config, str):
# This is an alias - check if the target model would be allowed
target_model = config
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
continue
# Allow the alias
models.append(model_name)
else:
# This is a base model with config dict
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
continue
models.append(model_name)
return models
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
Returns:
List of all model names and alias targets known by this provider
"""
all_models = set()
for model_name, config in self.SUPPORTED_MODELS.items():
# Add the model name itself
all_models.add(model_name.lower())
# If it's an alias (string value), add the target model too
if isinstance(config, str):
all_models.add(config.lower())
return list(all_models)
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name."""
# Check if it's a shorthand

View File

@@ -190,3 +190,48 @@ class OpenRouterProvider(OpenAICompatibleProvider):
False (no OpenRouter models currently support thinking mode)
"""
return False
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
Returns:
List of model names available from this provider
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models = []
if self._registry:
for model_name in self._registry.list_models():
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
continue
models.append(model_name)
return models
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
Returns:
List of all model names and alias targets known by this provider
"""
all_models = set()
if self._registry:
# Get all models and aliases from the registry
all_models.update(model.lower() for model in self._registry.list_models())
all_models.update(alias.lower() for alias in self._registry.list_aliases())
# For each alias, also add its target
for alias in self._registry.list_aliases():
config = self._registry.resolve(alias)
if config:
all_models.add(config.model_name.lower())
return list(all_models)

View File

@@ -160,66 +160,29 @@ class ModelProviderRegistry:
Returns:
Dict mapping model names to provider types
"""
models = {}
instance = cls()
# Import here to avoid circular imports
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models: dict[str, ProviderType] = {}
instance = cls()
for provider_type in instance._providers:
provider = cls.get_provider(provider_type)
if provider:
# Get supported models based on provider type
if hasattr(provider, "SUPPORTED_MODELS"):
for model_name, config in provider.SUPPORTED_MODELS.items():
# Handle both base models (dict configs) and aliases (string values)
if isinstance(config, str):
# This is an alias - check if the target model would be allowed
target_model = config
if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
logging.debug(f"Alias {model_name} -> {target_model} filtered by restrictions")
continue
# Allow the alias
models[model_name] = provider_type
else:
# This is a base model with config dict
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
logging.debug(f"Model {model_name} filtered by restrictions")
continue
models[model_name] = provider_type
elif provider_type == ProviderType.OPENROUTER:
# OpenRouter uses a registry system instead of SUPPORTED_MODELS
if hasattr(provider, "_registry") and provider._registry:
for model_name in provider._registry.list_models():
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
logging.debug(f"Model {model_name} filtered by restrictions")
continue
if not provider:
continue
models[model_name] = provider_type
elif provider_type == ProviderType.CUSTOM:
# Custom provider also uses a registry system (shared with OpenRouter)
if hasattr(provider, "_registry") and provider._registry:
# Get all models from the registry
all_models = provider._registry.list_models()
aliases = provider._registry.list_aliases()
try:
available = provider.list_models(respect_restrictions=respect_restrictions)
except NotImplementedError:
logging.warning("Provider %s does not implement list_models", provider_type)
continue
# Add models that are validated by the custom provider
for model_name in all_models + aliases:
# Use the provider's validation logic to determine if this model
# is appropriate for the custom endpoint
if provider.validate_model_name(model_name):
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(
provider_type, model_name
):
logging.debug(f"Model {model_name} filtered by restrictions")
continue
models[model_name] = provider_type
for model_name in available:
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
logging.debug("Model %s filtered by restrictions", model_name)
continue
models[model_name] = provider_type
return models

View File

@@ -126,6 +126,56 @@ class XAIModelProvider(OpenAICompatibleProvider):
# This may change with future GROK model releases
return False
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
Returns:
List of model names available from this provider
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models = []
for model_name, config in self.SUPPORTED_MODELS.items():
# Handle both base models (dict configs) and aliases (string values)
if isinstance(config, str):
# This is an alias - check if the target model would be allowed
target_model = config
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
continue
# Allow the alias
models.append(model_name)
else:
# This is a base model with config dict
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
continue
models.append(model_name)
return models
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
Returns:
List of all model names and alias targets known by this provider
"""
all_models = set()
for model_name, config in self.SUPPORTED_MODELS.items():
# Add the model name itself
all_models.add(model_name.lower())
# If it's an alias (string value), add the target model too
if isinstance(config, str):
all_models.add(config.lower())
return list(all_models)
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name."""
# Check if it's a shorthand