fix: listmodels to always honor restricted models

fix: restrictions should resolve canonical names for openrouter
fix: tools now correctly return restricted list by presenting model names in schema
fix: tests updated to ensure these manage their expected env vars properly
perf: cache model alias resolution to avoid repeated checks
This commit is contained in:
Fahad
2025-10-04 13:46:22 +04:00
parent 054e34e31c
commit 4015e917ed
17 changed files with 885 additions and 253 deletions

View File

@@ -83,9 +83,18 @@ class ListModelsTool(BaseTool):
from providers.openrouter_registry import OpenRouterModelRegistry
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from utils.model_restrictions import get_restriction_service
output_lines = ["# Available AI Models\n"]
restriction_service = get_restriction_service()
restricted_models_by_provider: dict[ProviderType, list[str]] = {}
if restriction_service:
restricted_map = ModelProviderRegistry.get_available_models(respect_restrictions=True)
for model_name, provider_type in restricted_map.items():
restricted_models_by_provider.setdefault(provider_type, []).append(model_name)
# Map provider types to friendly names and their models
provider_info = {
ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
@@ -94,6 +103,43 @@ class ListModelsTool(BaseTool):
ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
}
def format_model_entry(provider, display_name: str) -> list[str]:
    """Render one model as Markdown bullet lines for the model listing.

    Args:
        provider: Object exposing ``get_capabilities(name)``. The returned
            capabilities object must have ``model_name``; ``context_window``
            and ``description`` are optional.
        display_name: Model name (or alias) to show to the user.

    Returns:
        A single "not recognized" bullet when the provider raises
        ``ValueError`` for ``display_name``; otherwise three lines: the
        header bullet, the context-window line, and the description line.
    """
    try:
        capabilities = provider.get_capabilities(display_name)
    except ValueError:
        return [f"- `{display_name}` *(not recognized by provider)*"]

    # Show "name → canonical" only when the displayed name differs from the
    # canonical one (case-insensitively).
    canonical = capabilities.model_name
    if canonical.lower() == display_name.lower():
        header = f"- `{canonical}`"
    else:
        header = f"- `{display_name}` → `{canonical}`"

    # A missing, falsy, or non-numeric context window is reported as unknown.
    # OverflowError covers non-finite floats (int(float("inf"))), which would
    # otherwise escape this fallback and abort the listing.
    raw_context = getattr(capabilities, "context_window", None) or 0
    try:
        context_value = int(raw_context)
    except (TypeError, ValueError, OverflowError):
        context_value = 0

    if context_value >= 1_000_000:
        context_str = f"{context_value // 1_000_000}M context"
    elif context_value >= 1_000:
        context_str = f"{context_value // 1_000}K context"
    elif context_value > 0:
        context_str = f"{context_value} context"
    else:
        context_str = "unknown context"

    description = getattr(capabilities, "description", None) or "No description available"

    return [header, f" - {context_str}", f" - {description}"]
# Check each native provider type
for provider_type, info in provider_info.items():
# Check if provider is enabled
@@ -104,30 +150,49 @@ class ListModelsTool(BaseTool):
if is_configured:
output_lines.append("**Status**: Configured and available")
output_lines.append("\n**Models**:")
has_restrictions = bool(restriction_service and restriction_service.has_restrictions(provider_type))
aliases = []
for model_name, capabilities in provider.get_capabilities_by_rank():
description = capabilities.description or "No description available"
context_window = capabilities.context_window
if has_restrictions:
restricted_names = sorted(set(restricted_models_by_provider.get(provider_type, [])))
if context_window >= 1_000_000:
context_str = f"{context_window // 1_000_000}M context"
elif context_window >= 1_000:
context_str = f"{context_window // 1_000}K context"
if restricted_names:
output_lines.append("\n**Models (policy restricted)**:")
for model_name in restricted_names:
output_lines.extend(format_model_entry(provider, model_name))
else:
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
output_lines.append("\n*No models are currently allowed by restriction policy.*")
else:
output_lines.append("\n**Models**:")
output_lines.append(f"- `{model_name}` - {context_str}")
output_lines.append(f" - {description}")
aliases = []
for model_name, capabilities in provider.get_capabilities_by_rank():
try:
description = capabilities.description or "No description available"
except AttributeError:
description = "No description available"
for alias in capabilities.aliases or []:
if alias != model_name:
aliases.append(f"- `{alias}` → `{model_name}`")
try:
context_window = capabilities.context_window or 0
except AttributeError:
context_window = 0
if aliases:
output_lines.append("\n**Aliases**:")
output_lines.extend(sorted(aliases))
if context_window >= 1_000_000:
context_str = f"{context_window // 1_000_000}M context"
elif context_window >= 1_000:
context_str = f"{context_window // 1_000}K context"
else:
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
output_lines.append(f"- `{model_name}` - {context_str}")
output_lines.append(f" - {description}")
for alias in capabilities.aliases or []:
if alias != model_name:
aliases.append(f"- `{alias}` → `{model_name}`")
if aliases:
output_lines.append("\n**Aliases**:")
output_lines.extend(sorted(aliases))
else:
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
@@ -144,19 +209,10 @@ class ListModelsTool(BaseTool):
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
try:
# Get OpenRouter provider from registry to properly apply restrictions
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
if provider:
# Get models with restrictions applied
available_models = provider.list_models(respect_restrictions=True)
registry = OpenRouterModelRegistry()
# Group by provider and retain ranking information for consistent ordering
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
def _format_context(tokens: int) -> str:
if not tokens:
return "?"
@@ -166,53 +222,83 @@ class ListModelsTool(BaseTool):
return f"{tokens // 1_000}K"
return str(tokens)
for model_name in available_models:
config = registry.resolve(model_name)
provider_name = "other"
if config and "/" in config.model_name:
provider_name = config.model_name.split("/")[0]
elif "/" in model_name:
provider_name = model_name.split("/")[0]
has_restrictions = bool(
restriction_service and restriction_service.has_restrictions(ProviderType.OPENROUTER)
)
providers_models.setdefault(provider_name, [])
if has_restrictions:
restricted_names = sorted(set(restricted_models_by_provider.get(ProviderType.OPENROUTER, [])))
rank = config.get_effective_capability_rank() if config else 0
providers_models[provider_name].append((rank, model_name, config))
output_lines.append("\n**Models (policy restricted)**:")
if restricted_names:
for model_name in restricted_names:
try:
caps = provider.get_capabilities(model_name)
except ValueError:
output_lines.append(f"- `{model_name}` *(not recognized by provider)*")
continue
output_lines.append("\n**Available Models**:")
for provider_name, models in sorted(providers_models.items()):
output_lines.append(f"\n*{provider_name.title()}:*")
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
if config:
context_str = _format_context(config.context_window)
context_value = int(caps.context_window or 0)
context_str = _format_context(context_value)
suffix_parts = [f"{context_str} context"]
if getattr(config, "supports_extended_thinking", False):
if caps.supports_extended_thinking:
suffix_parts.append("thinking")
suffix = ", ".join(suffix_parts)
output_lines.append(f"- `{alias}` → `{config.model_name}` (score {rank}, {suffix})")
else:
output_lines.append(f"- `{alias}` (score {rank})")
total_models = len(available_models)
# Show all models - no truncation message needed
arrow = ""
if caps.model_name.lower() != model_name.lower():
arrow = f" → `{caps.model_name}`"
# Check if restrictions are applied
restriction_service = None
try:
from utils.model_restrictions import get_restriction_service
score = caps.get_effective_capability_rank()
output_lines.append(f"- `{model_name}`{arrow} (score {score}, {suffix})")
restriction_service = get_restriction_service()
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
output_lines.append(
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
)
except Exception as e:
logger.warning(f"Error checking OpenRouter restrictions: {e}")
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) or set()
if allowed_set:
output_lines.append(
f"\n*OpenRouter models restricted by OPENROUTER_ALLOWED_MODELS: {', '.join(sorted(allowed_set))}*"
)
else:
output_lines.append("- *No models allowed by current restriction policy.*")
else:
available_models = provider.list_models(respect_restrictions=True)
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
for model_name in available_models:
config = registry.resolve(model_name)
provider_name = "other"
if config and "/" in config.model_name:
provider_name = config.model_name.split("/")[0]
elif "/" in model_name:
provider_name = model_name.split("/")[0]
providers_models.setdefault(provider_name, [])
rank = config.get_effective_capability_rank() if config else 0
providers_models[provider_name].append((rank, model_name, config))
output_lines.append("\n**Available Models**:")
for provider_name, models in sorted(providers_models.items()):
output_lines.append(f"\n*{provider_name.title()}:*")
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
if config:
context_str = _format_context(getattr(config, "context_window", 0))
suffix_parts = [f"{context_str} context"]
if getattr(config, "supports_extended_thinking", False):
suffix_parts.append("thinking")
suffix = ", ".join(suffix_parts)
arrow = ""
if config.model_name.lower() != alias.lower():
arrow = f" → `{config.model_name}`"
output_lines.append(f"- `{alias}`{arrow} (score {rank}, {suffix})")
else:
output_lines.append(f"- `{alias}` (score {rank})")
else:
output_lines.append("**Error**: Could not load OpenRouter provider")
except Exception as e:
logger.exception("Error listing OpenRouter models: %s", e)
output_lines.append(f"**Error loading models**: {str(e)}")
else:
output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")