Proper fix for model discovery per provider
This commit is contained in:
@@ -160,66 +160,29 @@ class ModelProviderRegistry:
|
||||
Returns:
|
||||
Dict mapping model names to provider types
|
||||
"""
|
||||
models = {}
|
||||
instance = cls()
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models: dict[str, ProviderType] = {}
|
||||
instance = cls()
|
||||
|
||||
for provider_type in instance._providers:
|
||||
provider = cls.get_provider(provider_type)
|
||||
if provider:
|
||||
# Get supported models based on provider type
|
||||
if hasattr(provider, "SUPPORTED_MODELS"):
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
# Handle both base models (dict configs) and aliases (string values)
|
||||
if isinstance(config, str):
|
||||
# This is an alias - check if the target model would be allowed
|
||||
target_model = config
|
||||
if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
|
||||
logging.debug(f"Alias {model_name} -> {target_model} filtered by restrictions")
|
||||
continue
|
||||
# Allow the alias
|
||||
models[model_name] = provider_type
|
||||
else:
|
||||
# This is a base model with config dict
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||
continue
|
||||
models[model_name] = provider_type
|
||||
elif provider_type == ProviderType.OPENROUTER:
|
||||
# OpenRouter uses a registry system instead of SUPPORTED_MODELS
|
||||
if hasattr(provider, "_registry") and provider._registry:
|
||||
for model_name in provider._registry.list_models():
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||
continue
|
||||
if not provider:
|
||||
continue
|
||||
|
||||
models[model_name] = provider_type
|
||||
elif provider_type == ProviderType.CUSTOM:
|
||||
# Custom provider also uses a registry system (shared with OpenRouter)
|
||||
if hasattr(provider, "_registry") and provider._registry:
|
||||
# Get all models from the registry
|
||||
all_models = provider._registry.list_models()
|
||||
aliases = provider._registry.list_aliases()
|
||||
try:
|
||||
available = provider.list_models(respect_restrictions=respect_restrictions)
|
||||
except NotImplementedError:
|
||||
logging.warning("Provider %s does not implement list_models", provider_type)
|
||||
continue
|
||||
|
||||
# Add models that are validated by the custom provider
|
||||
for model_name in all_models + aliases:
|
||||
# Use the provider's validation logic to determine if this model
|
||||
# is appropriate for the custom endpoint
|
||||
if provider.validate_model_name(model_name):
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(
|
||||
provider_type, model_name
|
||||
):
|
||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||
continue
|
||||
|
||||
models[model_name] = provider_type
|
||||
for model_name in available:
|
||||
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
||||
logging.debug("Model %s filtered by restrictions", model_name)
|
||||
continue
|
||||
models[model_name] = provider_type
|
||||
|
||||
return models
|
||||
|
||||
|
||||
Reference in New Issue
Block a user