fix: listmodels to always honor restricted models
fix: restrictions should resolve canonical names for OpenRouter
fix: tools now correctly return the restricted list by presenting model names in the schema
fix: tests updated to ensure they manage their expected env vars properly
perf: cache model alias resolution to avoid repeated checks
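The alias-resolution caching mentioned in the last line is not visible in the hunks below; a minimal sketch of the general technique, using a hypothetical lookup table and resolver name rather than the project's real registry API:

    import functools

    # Hypothetical alias table standing in for the real model-registry lookup.
    ALIAS_TABLE = {"flash": "gemini-2.5-flash", "sonnet": "anthropic/claude-sonnet-4"}

    @functools.lru_cache(maxsize=None)
    def resolve_alias(alias: str) -> str:
        # The underlying lookup runs once per distinct alias; later calls are cache hits.
        return ALIAS_TABLE.get(alias.lower(), alias)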
@@ -83,9 +83,18 @@ class ListModelsTool(BaseTool):
         from providers.openrouter_registry import OpenRouterModelRegistry
+        from providers.registry import ModelProviderRegistry
         from providers.shared import ProviderType
+        from utils.model_restrictions import get_restriction_service

         output_lines = ["# Available AI Models\n"]

+        restriction_service = get_restriction_service()
+        restricted_models_by_provider: dict[ProviderType, list[str]] = {}
+
+        if restriction_service:
+            restricted_map = ModelProviderRegistry.get_available_models(respect_restrictions=True)
+            for model_name, provider_type in restricted_map.items():
+                restricted_models_by_provider.setdefault(provider_type, []).append(model_name)

         # Map provider types to friendly names and their models
         provider_info = {
             ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
@@ -94,6 +103,43 @@ class ListModelsTool(BaseTool):
             ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
         }

+        def format_model_entry(provider, display_name: str) -> list[str]:
+            try:
+                capabilities = provider.get_capabilities(display_name)
+            except ValueError:
+                return [f"- `{display_name}` *(not recognized by provider)*"]
+
+            canonical = capabilities.model_name
+            if canonical.lower() == display_name.lower():
+                header = f"- `{canonical}`"
+            else:
+                header = f"- `{display_name}` → `{canonical}`"
+
+            try:
+                context_value = capabilities.context_window or 0
+            except AttributeError:
+                context_value = 0
+            try:
+                context_value = int(context_value)
+            except (TypeError, ValueError):
+                context_value = 0
+
+            if context_value >= 1_000_000:
+                context_str = f"{context_value // 1_000_000}M context"
+            elif context_value >= 1_000:
+                context_str = f"{context_value // 1_000}K context"
+            elif context_value > 0:
+                context_str = f"{context_value} context"
+            else:
+                context_str = "unknown context"
+
+            try:
+                description = capabilities.description or "No description available"
+            except AttributeError:
+                description = "No description available"
+            lines = [header, f" - {context_str}", f" - {description}"]
+            return lines
+
         # Check each native provider type
         for provider_type, info in provider_info.items():
             # Check if provider is enabled
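As a rough illustration of what the new format_model_entry helper above returns, here is a minimal sketch with a stub provider; the stub class and its capability values are invented for demonstration and are not part of this change:

    from types import SimpleNamespace

    class StubProvider:
        # Minimal stand-in exposing the get_capabilities interface the helper relies on.
        def get_capabilities(self, name: str):
            if name not in ("flash", "gemini-2.5-flash"):
                raise ValueError(name)
            return SimpleNamespace(
                model_name="gemini-2.5-flash",
                context_window=1_048_576,
                description="Example capability record",
            )

    # format_model_entry(StubProvider(), "flash") would yield roughly:
    #   ["- `flash` → `gemini-2.5-flash`", " - 1M context", " - Example capability record"]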
@@ -104,30 +150,49 @@ class ListModelsTool(BaseTool):

             if is_configured:
                 output_lines.append("**Status**: Configured and available")
-                output_lines.append("\n**Models**:")
+                has_restrictions = bool(restriction_service and restriction_service.has_restrictions(provider_type))

-                aliases = []
-                for model_name, capabilities in provider.get_capabilities_by_rank():
-                    description = capabilities.description or "No description available"
-                    context_window = capabilities.context_window
+                if has_restrictions:
+                    restricted_names = sorted(set(restricted_models_by_provider.get(provider_type, [])))

-                    if context_window >= 1_000_000:
-                        context_str = f"{context_window // 1_000_000}M context"
-                    elif context_window >= 1_000:
-                        context_str = f"{context_window // 1_000}K context"
-                    else:
-                        context_str = f"{context_window} context" if context_window > 0 else "unknown context"
+                    if restricted_names:
+                        output_lines.append("\n**Models (policy restricted)**:")
+                        for model_name in restricted_names:
+                            output_lines.extend(format_model_entry(provider, model_name))
+                    else:
+                        output_lines.append("\n*No models are currently allowed by restriction policy.*")
+                else:
+                    output_lines.append("\n**Models**:")

-                    output_lines.append(f"- `{model_name}` - {context_str}")
-                    output_lines.append(f" - {description}")
+                    aliases = []
+                    for model_name, capabilities in provider.get_capabilities_by_rank():
+                        try:
+                            description = capabilities.description or "No description available"
+                        except AttributeError:
+                            description = "No description available"
+                        try:
+                            context_window = capabilities.context_window or 0
+                        except AttributeError:
+                            context_window = 0

-                    for alias in capabilities.aliases or []:
-                        if alias != model_name:
-                            aliases.append(f"- `{alias}` → `{model_name}`")
+                        if context_window >= 1_000_000:
+                            context_str = f"{context_window // 1_000_000}M context"
+                        elif context_window >= 1_000:
+                            context_str = f"{context_window // 1_000}K context"
+                        else:
+                            context_str = f"{context_window} context" if context_window > 0 else "unknown context"

-                if aliases:
-                    output_lines.append("\n**Aliases**:")
-                    output_lines.extend(sorted(aliases))
+                        output_lines.append(f"- `{model_name}` - {context_str}")
+                        output_lines.append(f" - {description}")
+
+                        for alias in capabilities.aliases or []:
+                            if alias != model_name:
+                                aliases.append(f"- `{alias}` → `{model_name}`")
+
+                    if aliases:
+                        output_lines.append("\n**Aliases**:")
+                        output_lines.extend(sorted(aliases))
             else:
                 output_lines.append(f"**Status**: Not configured (set {info['env_key']})")

@@ -144,19 +209,10 @@ class ListModelsTool(BaseTool):
             output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")

             try:
-                # Get OpenRouter provider from registry to properly apply restrictions
-                from providers.registry import ModelProviderRegistry
-                from providers.shared import ProviderType
-
                 provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
                 if provider:
-                    # Get models with restrictions applied
-                    available_models = provider.list_models(respect_restrictions=True)
                     registry = OpenRouterModelRegistry()

-                    # Group by provider and retain ranking information for consistent ordering
-                    providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
-
                     def _format_context(tokens: int) -> str:
                         if not tokens:
                             return "?"
@@ -166,53 +222,83 @@ class ListModelsTool(BaseTool):
                             return f"{tokens // 1_000}K"
                         return str(tokens)

-                    for model_name in available_models:
-                        config = registry.resolve(model_name)
-                        provider_name = "other"
-                        if config and "/" in config.model_name:
-                            provider_name = config.model_name.split("/")[0]
-                        elif "/" in model_name:
-                            provider_name = model_name.split("/")[0]
-
-                        providers_models.setdefault(provider_name, [])
-
-                        rank = config.get_effective_capability_rank() if config else 0
-                        providers_models[provider_name].append((rank, model_name, config))
-
-                    output_lines.append("\n**Available Models**:")
-                    for provider_name, models in sorted(providers_models.items()):
-                        output_lines.append(f"\n*{provider_name.title()}:*")
-                        for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
-                            if config:
-                                context_str = _format_context(config.context_window)
-                                suffix_parts = [f"{context_str} context"]
-                                if getattr(config, "supports_extended_thinking", False):
-                                    suffix_parts.append("thinking")
-                                suffix = ", ".join(suffix_parts)
-                                output_lines.append(f"- `{alias}` → `{config.model_name}` (score {rank}, {suffix})")
-                            else:
-                                output_lines.append(f"- `{alias}` (score {rank})")
-
-                    total_models = len(available_models)
-                    # Show all models - no truncation message needed
-
-                    # Check if restrictions are applied
-                    restriction_service = None
-                    try:
-                        from utils.model_restrictions import get_restriction_service
-
-                        restriction_service = get_restriction_service()
-                        if restriction_service.has_restrictions(ProviderType.OPENROUTER):
-                            allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
-                            output_lines.append(
-                                f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
-                            )
-                    except Exception as e:
-                        logger.warning(f"Error checking OpenRouter restrictions: {e}")
+                    has_restrictions = bool(
+                        restriction_service and restriction_service.has_restrictions(ProviderType.OPENROUTER)
+                    )
+                    if has_restrictions:
+                        restricted_names = sorted(set(restricted_models_by_provider.get(ProviderType.OPENROUTER, [])))
+                        output_lines.append("\n**Models (policy restricted)**:")
+                        if restricted_names:
+                            for model_name in restricted_names:
+                                try:
+                                    caps = provider.get_capabilities(model_name)
+                                except ValueError:
+                                    output_lines.append(f"- `{model_name}` *(not recognized by provider)*")
+                                    continue
+
+                                context_value = int(caps.context_window or 0)
+                                context_str = _format_context(context_value)
+                                suffix_parts = [f"{context_str} context"]
+                                if caps.supports_extended_thinking:
+                                    suffix_parts.append("thinking")
+                                suffix = ", ".join(suffix_parts)
+
+                                arrow = ""
+                                if caps.model_name.lower() != model_name.lower():
+                                    arrow = f" → `{caps.model_name}`"
+
+                                score = caps.get_effective_capability_rank()
+                                output_lines.append(f"- `{model_name}`{arrow} (score {score}, {suffix})")
+
+                            allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) or set()
+                            if allowed_set:
+                                output_lines.append(
+                                    f"\n*OpenRouter models restricted by OPENROUTER_ALLOWED_MODELS: {', '.join(sorted(allowed_set))}*"
+                                )
+                        else:
+                            output_lines.append("- *No models allowed by current restriction policy.*")
+                    else:
+                        available_models = provider.list_models(respect_restrictions=True)
+                        providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
+
+                        for model_name in available_models:
+                            config = registry.resolve(model_name)
+                            provider_name = "other"
+                            if config and "/" in config.model_name:
+                                provider_name = config.model_name.split("/")[0]
+                            elif "/" in model_name:
+                                provider_name = model_name.split("/")[0]
+
+                            providers_models.setdefault(provider_name, [])
+
+                            rank = config.get_effective_capability_rank() if config else 0
+                            providers_models[provider_name].append((rank, model_name, config))
+
+                        output_lines.append("\n**Available Models**:")
+                        for provider_name, models in sorted(providers_models.items()):
+                            output_lines.append(f"\n*{provider_name.title()}:*")
+                            for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
+                                if config:
+                                    context_str = _format_context(getattr(config, "context_window", 0))
+                                    suffix_parts = [f"{context_str} context"]
+                                    if getattr(config, "supports_extended_thinking", False):
+                                        suffix_parts.append("thinking")
+                                    suffix = ", ".join(suffix_parts)
+
+                                    arrow = ""
+                                    if config.model_name.lower() != alias.lower():
+                                        arrow = f" → `{config.model_name}`"
+
+                                    output_lines.append(f"- `{alias}`{arrow} (score {rank}, {suffix})")
+                                else:
+                                    output_lines.append(f"- `{alias}` (score {rank})")
                 else:
                     output_lines.append("**Error**: Could not load OpenRouter provider")

             except Exception as e:
                 logger.exception("Error listing OpenRouter models: %s", e)
                 output_lines.append(f"**Error loading models**: {str(e)}")
         else:
             output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")

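For context on how the restricted branch above gets exercised, a minimal usage sketch; it assumes the restriction service reads OPENROUTER_ALLOWED_MODELS when it is constructed and that no service instance has been cached yet:

    import os

    from providers.shared import ProviderType
    from utils.model_restrictions import get_restriction_service

    os.environ["OPENROUTER_ALLOWED_MODELS"] = "opus,sonnet"

    service = get_restriction_service()
    if service and service.has_restrictions(ProviderType.OPENROUTER):
        # Expected to print the allow-list that drives the "policy restricted" output.
        print(sorted(service.get_allowed_models(ProviderType.OPENROUTER)))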