fix: listmodels to always honor restricted models

fix: restrictions should resolve canonical names for openrouter
fix: tools now correctly return restricted list by presenting model names in schema
fix: tests updated to ensure these manage their expected env vars properly
perf: cache model alias resolution to avoid repeated checks
This commit is contained in:
Fahad
2025-10-04 13:46:22 +04:00
parent 054e34e31c
commit 4015e917ed
17 changed files with 885 additions and 253 deletions

View File

@@ -22,6 +22,7 @@ Example:
import logging
import os
from collections import defaultdict
from typing import Optional
from providers.shared import ProviderType
@@ -58,6 +59,7 @@ class ModelRestrictionService:
def __init__(self):
"""Initialize the restriction service by loading from environment."""
self.restrictions: dict[ProviderType, set[str]] = {}
self._alias_resolution_cache: dict[ProviderType, dict[str, str]] = defaultdict(dict)
self._load_from_env()
def _load_from_env(self) -> None:
@@ -79,6 +81,7 @@ class ModelRestrictionService:
if models:
self.restrictions[provider_type] = models
self._alias_resolution_cache[provider_type] = {}
logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
else:
# All entries were empty after cleaning - treat as no restrictions
@@ -150,7 +153,41 @@ class ModelRestrictionService:
names_to_check.add(original_name.lower())
# If any of the names is in the allowed set, it's allowed
return any(name in allowed_set for name in names_to_check)
if any(name in allowed_set for name in names_to_check):
return True
# Attempt to resolve canonical names for allowed aliases using provider metadata.
try:
from providers.registry import ModelProviderRegistry
provider = ModelProviderRegistry.get_provider(provider_type)
except Exception: # pragma: no cover - registry lookup failure shouldn't break validation
provider = None
if provider:
cache = self._alias_resolution_cache.setdefault(provider_type, {})
for allowed_entry in list(allowed_set):
normalized_resolved = cache.get(allowed_entry)
if not normalized_resolved:
try:
resolved = provider._resolve_model_name(allowed_entry)
except Exception: # pragma: no cover - resolution failures are treated as non-matches
continue
if not resolved:
continue
normalized_resolved = resolved.lower()
cache[allowed_entry] = normalized_resolved
if normalized_resolved in names_to_check:
allowed_set.add(normalized_resolved)
cache[normalized_resolved] = normalized_resolved
return True
return False
def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
"""