Fixed restriction checks for OpenRouter
This commit is contained in:
Fahad
2025-06-23 15:23:55 +04:00
parent b4852c825f
commit e94c028a3f
9 changed files with 246 additions and 60 deletions

View File

@@ -50,14 +50,6 @@ class OpenRouterProvider(OpenAICompatibleProvider):
aliases = self._registry.list_aliases()
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
def _parse_allowed_models(self) -> None:
"""Override to disable environment-based allow-list.
OpenRouter model access is controlled via the OpenRouter dashboard,
not through environment variables.
"""
return None
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model aliases to OpenRouter model names.
@@ -130,16 +122,34 @@ class OpenRouterProvider(OpenAICompatibleProvider):
As the catch-all provider, OpenRouter accepts any model name that wasn't
handled by higher-priority providers. OpenRouter will validate based on
the API key's permissions.
the API key's permissions and local restrictions.
Args:
model_name: Model name to validate
Returns:
Always True - OpenRouter is the catch-all provider
True if model is allowed, False if restricted
"""
# Accept any model name - OpenRouter is the fallback provider
# Higher priority providers (native APIs, custom endpoints) get first chance
# Check model restrictions if configured
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if restriction_service:
# Check if model name itself is allowed
if restriction_service.is_allowed(self.get_provider_type(), model_name):
return True
# Also check aliases - model_name might be an alias
model_config = self._registry.resolve(model_name)
if model_config and model_config.aliases:
for alias in model_config.aliases:
if restriction_service.is_allowed(self.get_provider_type(), alias):
return True
# If restrictions are configured and model/alias not in allowed list, reject
return False
# No restrictions configured - accept any model name as the fallback provider
return True
def generate_content(

View File

@@ -129,7 +129,6 @@ class ModelProviderRegistry:
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
for provider_type in PROVIDER_PRIORITY_ORDER:
logging.debug(f"Checking provider_type: {provider_type}")
if provider_type in instance._providers:
logging.debug(f"Found {provider_type} in registry")
# Get or create provider instance