fix: listmodels to always honor restricted models
fix: restrictions should resolve canonical names for openrouter fix: tools now correctly return restricted list by presenting model names in schema fix: tests updated to ensure these manage their expected env vars properly perf: cache model alias resolution to avoid repeated checks
This commit is contained in:
@@ -22,6 +22,7 @@ Example:
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from providers.shared import ProviderType
|
||||
@@ -58,6 +59,7 @@ class ModelRestrictionService:
|
||||
def __init__(self):
|
||||
"""Initialize the restriction service by loading from environment."""
|
||||
self.restrictions: dict[ProviderType, set[str]] = {}
|
||||
self._alias_resolution_cache: dict[ProviderType, dict[str, str]] = defaultdict(dict)
|
||||
self._load_from_env()
|
||||
|
||||
def _load_from_env(self) -> None:
|
||||
@@ -79,6 +81,7 @@ class ModelRestrictionService:
|
||||
|
||||
if models:
|
||||
self.restrictions[provider_type] = models
|
||||
self._alias_resolution_cache[provider_type] = {}
|
||||
logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
|
||||
else:
|
||||
# All entries were empty after cleaning - treat as no restrictions
|
||||
@@ -150,7 +153,41 @@ class ModelRestrictionService:
|
||||
names_to_check.add(original_name.lower())
|
||||
|
||||
# If any of the names is in the allowed set, it's allowed
|
||||
return any(name in allowed_set for name in names_to_check)
|
||||
if any(name in allowed_set for name in names_to_check):
|
||||
return True
|
||||
|
||||
# Attempt to resolve canonical names for allowed aliases using provider metadata.
|
||||
try:
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
provider = ModelProviderRegistry.get_provider(provider_type)
|
||||
except Exception: # pragma: no cover - registry lookup failure shouldn't break validation
|
||||
provider = None
|
||||
|
||||
if provider:
|
||||
cache = self._alias_resolution_cache.setdefault(provider_type, {})
|
||||
|
||||
for allowed_entry in list(allowed_set):
|
||||
normalized_resolved = cache.get(allowed_entry)
|
||||
|
||||
if not normalized_resolved:
|
||||
try:
|
||||
resolved = provider._resolve_model_name(allowed_entry)
|
||||
except Exception: # pragma: no cover - resolution failures are treated as non-matches
|
||||
continue
|
||||
|
||||
if not resolved:
|
||||
continue
|
||||
|
||||
normalized_resolved = resolved.lower()
|
||||
cache[allowed_entry] = normalized_resolved
|
||||
|
||||
if normalized_resolved in names_to_check:
|
||||
allowed_set.add(normalized_resolved)
|
||||
cache[normalized_resolved] = normalized_resolved
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user