fix: listmodels to always honor restricted models
fix: restrictions should resolve canonical names for openrouter fix: tools now correctly return restricted list by presenting model names in schema fix: tests updated to ensure these manage their expected env vars properly perf: cache model alias resolution to avoid repeated checks
This commit is contained in:
@@ -73,6 +73,8 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
|
||||
logging.info(f"Initializing Custom provider with endpoint: {base_url}")
|
||||
|
||||
self._alias_cache: dict[str, str] = {}
|
||||
|
||||
super().__init__(api_key, base_url=base_url, **kwargs)
|
||||
|
||||
# Initialize model registry (shared with OpenRouter for consistent aliases)
|
||||
@@ -120,11 +122,18 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve registry aliases and strip version tags for local models."""
|
||||
|
||||
cache_key = model_name.lower()
|
||||
if cache_key in self._alias_cache:
|
||||
return self._alias_cache[cache_key]
|
||||
|
||||
config = self._registry.resolve(model_name)
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
|
||||
resolved = config.model_name
|
||||
self._alias_cache[cache_key] = resolved
|
||||
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||
return resolved
|
||||
|
||||
if ":" in model_name:
|
||||
base_model = model_name.split(":")[0]
|
||||
@@ -132,11 +141,16 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
|
||||
base_config = self._registry.resolve(base_model)
|
||||
if base_config:
|
||||
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
||||
return base_config.model_name
|
||||
logging.debug("Resolved base model '%s' to '%s'", base_model, base_config.model_name)
|
||||
resolved = base_config.model_name
|
||||
self._alias_cache[cache_key] = resolved
|
||||
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||
return resolved
|
||||
self._alias_cache[cache_key] = base_model
|
||||
return base_model
|
||||
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
self._alias_cache[cache_key] = model_name
|
||||
return model_name
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
|
||||
@@ -39,6 +39,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
base_url: Base URL for the API endpoint
|
||||
**kwargs: Additional configuration options including timeout
|
||||
"""
|
||||
self._allowed_alias_cache: dict[str, str] = {}
|
||||
super().__init__(api_key, **kwargs)
|
||||
self._client = None
|
||||
self.base_url = base_url
|
||||
@@ -74,9 +75,33 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
canonical = canonical_name.lower()
|
||||
|
||||
if requested not in self.allowed_models and canonical not in self.allowed_models:
|
||||
raise ValueError(
|
||||
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
||||
)
|
||||
allowed = False
|
||||
for allowed_entry in list(self.allowed_models):
|
||||
normalized_resolved = self._allowed_alias_cache.get(allowed_entry)
|
||||
if normalized_resolved is None:
|
||||
try:
|
||||
resolved_name = self._resolve_model_name(allowed_entry)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not resolved_name:
|
||||
continue
|
||||
|
||||
normalized_resolved = resolved_name.lower()
|
||||
self._allowed_alias_cache[allowed_entry] = normalized_resolved
|
||||
|
||||
if normalized_resolved == canonical:
|
||||
# Canonical match discovered via alias resolution – mark as allowed and
|
||||
# memoise the canonical entry for future lookups.
|
||||
allowed = True
|
||||
self._allowed_alias_cache[canonical] = canonical
|
||||
self.allowed_models.add(canonical)
|
||||
break
|
||||
|
||||
if not allowed:
|
||||
raise ValueError(
|
||||
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
||||
)
|
||||
|
||||
def _parse_allowed_models(self) -> Optional[set[str]]:
|
||||
"""Parse allowed models from environment variable.
|
||||
@@ -94,6 +119,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
|
||||
if models:
|
||||
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
|
||||
self._allowed_alias_cache = {}
|
||||
return models
|
||||
|
||||
# Log info if no allow-list configured for proxy providers
|
||||
|
||||
@@ -50,6 +50,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
self._alias_cache: dict[str, str] = {}
|
||||
super().__init__(api_key, base_url=base_url, **kwargs)
|
||||
|
||||
# Initialize model registry
|
||||
@@ -178,13 +179,21 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve aliases defined in the OpenRouter registry."""
|
||||
|
||||
cache_key = model_name.lower()
|
||||
if cache_key in self._alias_cache:
|
||||
return self._alias_cache[cache_key]
|
||||
|
||||
config = self._registry.resolve(model_name)
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
|
||||
resolved = config.model_name
|
||||
self._alias_cache[cache_key] = resolved
|
||||
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||
return resolved
|
||||
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
self._alias_cache[cache_key] = model_name
|
||||
return model_name
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
|
||||
@@ -205,6 +205,18 @@ class ModelProviderRegistry:
|
||||
logging.warning("Provider %s does not implement list_models", provider_type)
|
||||
continue
|
||||
|
||||
if restriction_service and restriction_service.has_restrictions(provider_type):
|
||||
restricted_display = cls._collect_restricted_display_names(
|
||||
provider,
|
||||
provider_type,
|
||||
available,
|
||||
restriction_service,
|
||||
)
|
||||
if restricted_display:
|
||||
for model_name in restricted_display:
|
||||
models[model_name] = provider_type
|
||||
continue
|
||||
|
||||
for model_name in available:
|
||||
# =====================================================================================
|
||||
# CRITICAL: Prevent double restriction filtering (Fixed Issue #98)
|
||||
@@ -227,6 +239,50 @@ class ModelProviderRegistry:
|
||||
|
||||
return models
|
||||
|
||||
@classmethod
|
||||
def _collect_restricted_display_names(
|
||||
cls,
|
||||
provider: ModelProvider,
|
||||
provider_type: ProviderType,
|
||||
available: list[str],
|
||||
restriction_service,
|
||||
) -> list[str] | None:
|
||||
"""Derive the human-facing model list when restrictions are active."""
|
||||
|
||||
allowed_models = restriction_service.get_allowed_models(provider_type)
|
||||
if not allowed_models:
|
||||
return None
|
||||
|
||||
allowed_details: list[tuple[str, int]] = []
|
||||
|
||||
for model_name in sorted(allowed_models):
|
||||
try:
|
||||
capabilities = provider.get_capabilities(model_name)
|
||||
except (AttributeError, ValueError):
|
||||
continue
|
||||
|
||||
try:
|
||||
rank = capabilities.get_effective_capability_rank()
|
||||
rank_value = float(rank)
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
rank_value = 0.0
|
||||
|
||||
allowed_details.append((model_name, rank_value))
|
||||
|
||||
if allowed_details:
|
||||
allowed_details.sort(key=lambda item: (-item[1], item[0]))
|
||||
return [name for name, _ in allowed_details]
|
||||
|
||||
# Fallback: intersect the allowlist with the provider-advertised names.
|
||||
available_lookup = {name.lower(): name for name in available}
|
||||
display_names: list[str] = []
|
||||
for model_name in sorted(allowed_models):
|
||||
lowered = model_name.lower()
|
||||
if lowered in available_lookup:
|
||||
display_names.append(available_lookup[lowered])
|
||||
|
||||
return display_names
|
||||
|
||||
@classmethod
|
||||
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
|
||||
"""Get list of available model names, optionally filtered by provider.
|
||||
|
||||
Reference in New Issue
Block a user