Proper fix for model discovery per provider

This commit is contained in:
Fahad
2025-06-18 07:16:10 +04:00
parent 5199dd6ead
commit dad1e2d74e
15 changed files with 1250 additions and 65 deletions

View File

@@ -160,66 +160,29 @@ class ModelProviderRegistry:
Returns:
Dict mapping model names to provider types
"""
models = {}
instance = cls()
# Import here to avoid circular imports
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models: dict[str, ProviderType] = {}
instance = cls()
for provider_type in instance._providers:
provider = cls.get_provider(provider_type)
if provider:
# Get supported models based on provider type
if hasattr(provider, "SUPPORTED_MODELS"):
for model_name, config in provider.SUPPORTED_MODELS.items():
# Handle both base models (dict configs) and aliases (string values)
if isinstance(config, str):
# This is an alias - check if the target model would be allowed
target_model = config
if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
logging.debug(f"Alias {model_name} -> {target_model} filtered by restrictions")
continue
# Allow the alias
models[model_name] = provider_type
else:
# This is a base model with config dict
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
logging.debug(f"Model {model_name} filtered by restrictions")
continue
models[model_name] = provider_type
elif provider_type == ProviderType.OPENROUTER:
# OpenRouter uses a registry system instead of SUPPORTED_MODELS
if hasattr(provider, "_registry") and provider._registry:
for model_name in provider._registry.list_models():
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
logging.debug(f"Model {model_name} filtered by restrictions")
continue
if not provider:
continue
models[model_name] = provider_type
elif provider_type == ProviderType.CUSTOM:
# Custom provider also uses a registry system (shared with OpenRouter)
if hasattr(provider, "_registry") and provider._registry:
# Get all models from the registry
all_models = provider._registry.list_models()
aliases = provider._registry.list_aliases()
try:
available = provider.list_models(respect_restrictions=respect_restrictions)
except NotImplementedError:
logging.warning("Provider %s does not implement list_models", provider_type)
continue
# Add models that are validated by the custom provider
for model_name in all_models + aliases:
# Use the provider's validation logic to determine if this model
# is appropriate for the custom endpoint
if provider.validate_model_name(model_name):
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(
provider_type, model_name
):
logging.debug(f"Model {model_name} filtered by restrictions")
continue
models[model_name] = provider_type
for model_name in available:
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
logging.debug("Model %s filtered by restrictions", model_name)
continue
models[model_name] = provider_type
return models