fix: listmodels to always honor restricted models

fix: restrictions should resolve canonical names for openrouter
fix: tools now correctly return restricted list by presenting model names in schema
fix: tests updated to ensure these manage their expected env vars properly
perf: cache model alias resolution to avoid repeated checks
This commit is contained in:
Fahad
2025-10-04 13:46:22 +04:00
parent 054e34e31c
commit 4015e917ed
17 changed files with 885 additions and 253 deletions

View File

@@ -1,5 +1,9 @@
# Repository Guidelines
See `requirements.txt` and `requirements-dev.txt`
Also read CLAUDE.md and CLAUDE.local.md if available.
## Project Structure & Module Organization
Zen MCP Server centers on `server.py`, which exposes MCP entrypoints and coordinates multi-model workflows.
Feature-specific tools live in `tools/`, provider integrations in `providers/`, and shared helpers in `utils/`.
@@ -14,6 +18,13 @@ Authoritative documentation and samples live in `docs/`, and runtime diagnostics
- `python communication_simulator_test.py --quick` smoke-test orchestration across tools and providers.
- `./run_integration_tests.sh [--with-simulator]` exercise provider-dependent flows against remote or Ollama models.
For example, this is how we run an individual / all tests:
```
.zen_venv/bin/activate && pytest tests/test_auto_mode_model_listing.py -q
.zen_venv/bin/activate && pytest -q
```
## Coding Style & Naming Conventions
Target Python 3.9+ with Black and isort using a 120-character line limit; Ruff enforces pycodestyle, pyflakes, bugbear, comprehension, and pyupgrade rules. Prefer explicit type hints, snake_case modules, and imperative commit-time docstrings. Extend workflows by defining hook or abstract methods instead of checking `hasattr()`/`getattr()`—inheritance-backed contracts keep behavior discoverable and testable.

View File

@@ -73,6 +73,8 @@ class CustomProvider(OpenAICompatibleProvider):
logging.info(f"Initializing Custom provider with endpoint: {base_url}")
self._alias_cache: dict[str, str] = {}
super().__init__(api_key, base_url=base_url, **kwargs)
# Initialize model registry (shared with OpenRouter for consistent aliases)
@@ -120,11 +122,18 @@ class CustomProvider(OpenAICompatibleProvider):
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve registry aliases and strip version tags for local models."""
cache_key = model_name.lower()
if cache_key in self._alias_cache:
return self._alias_cache[cache_key]
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
resolved = config.model_name
self._alias_cache[cache_key] = resolved
self._alias_cache.setdefault(resolved.lower(), resolved)
return resolved
if ":" in model_name:
base_model = model_name.split(":")[0]
@@ -132,11 +141,16 @@ class CustomProvider(OpenAICompatibleProvider):
base_config = self._registry.resolve(base_model)
if base_config:
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
return base_config.model_name
logging.debug("Resolved base model '%s' to '%s'", base_model, base_config.model_name)
resolved = base_config.model_name
self._alias_cache[cache_key] = resolved
self._alias_cache.setdefault(resolved.lower(), resolved)
return resolved
self._alias_cache[cache_key] = base_model
return base_model
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
self._alias_cache[cache_key] = model_name
return model_name
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:

View File

@@ -39,6 +39,7 @@ class OpenAICompatibleProvider(ModelProvider):
base_url: Base URL for the API endpoint
**kwargs: Additional configuration options including timeout
"""
self._allowed_alias_cache: dict[str, str] = {}
super().__init__(api_key, **kwargs)
self._client = None
self.base_url = base_url
@@ -74,9 +75,33 @@ class OpenAICompatibleProvider(ModelProvider):
canonical = canonical_name.lower()
if requested not in self.allowed_models and canonical not in self.allowed_models:
raise ValueError(
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
)
allowed = False
for allowed_entry in list(self.allowed_models):
normalized_resolved = self._allowed_alias_cache.get(allowed_entry)
if normalized_resolved is None:
try:
resolved_name = self._resolve_model_name(allowed_entry)
except Exception:
continue
if not resolved_name:
continue
normalized_resolved = resolved_name.lower()
self._allowed_alias_cache[allowed_entry] = normalized_resolved
if normalized_resolved == canonical:
# Canonical match discovered via alias resolution mark as allowed and
# memoise the canonical entry for future lookups.
allowed = True
self._allowed_alias_cache[canonical] = canonical
self.allowed_models.add(canonical)
break
if not allowed:
raise ValueError(
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
)
def _parse_allowed_models(self) -> Optional[set[str]]:
"""Parse allowed models from environment variable.
@@ -94,6 +119,7 @@ class OpenAICompatibleProvider(ModelProvider):
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
if models:
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
self._allowed_alias_cache = {}
return models
# Log info if no allow-list configured for proxy providers

View File

@@ -50,6 +50,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
**kwargs: Additional configuration
"""
base_url = "https://openrouter.ai/api/v1"
self._alias_cache: dict[str, str] = {}
super().__init__(api_key, base_url=base_url, **kwargs)
# Initialize model registry
@@ -178,13 +179,21 @@ class OpenRouterProvider(OpenAICompatibleProvider):
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve aliases defined in the OpenRouter registry."""
cache_key = model_name.lower()
if cache_key in self._alias_cache:
return self._alias_cache[cache_key]
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
resolved = config.model_name
self._alias_cache[cache_key] = resolved
self._alias_cache.setdefault(resolved.lower(), resolved)
return resolved
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
self._alias_cache[cache_key] = model_name
return model_name
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:

View File

@@ -205,6 +205,18 @@ class ModelProviderRegistry:
logging.warning("Provider %s does not implement list_models", provider_type)
continue
if restriction_service and restriction_service.has_restrictions(provider_type):
restricted_display = cls._collect_restricted_display_names(
provider,
provider_type,
available,
restriction_service,
)
if restricted_display:
for model_name in restricted_display:
models[model_name] = provider_type
continue
for model_name in available:
# =====================================================================================
# CRITICAL: Prevent double restriction filtering (Fixed Issue #98)
@@ -227,6 +239,50 @@ class ModelProviderRegistry:
return models
@classmethod
def _collect_restricted_display_names(
cls,
provider: ModelProvider,
provider_type: ProviderType,
available: list[str],
restriction_service,
) -> list[str] | None:
"""Derive the human-facing model list when restrictions are active."""
allowed_models = restriction_service.get_allowed_models(provider_type)
if not allowed_models:
return None
allowed_details: list[tuple[str, int]] = []
for model_name in sorted(allowed_models):
try:
capabilities = provider.get_capabilities(model_name)
except (AttributeError, ValueError):
continue
try:
rank = capabilities.get_effective_capability_rank()
rank_value = float(rank)
except (AttributeError, TypeError, ValueError):
rank_value = 0.0
allowed_details.append((model_name, rank_value))
if allowed_details:
allowed_details.sort(key=lambda item: (-item[1], item[0]))
return [name for name, _ in allowed_details]
# Fallback: intersect the allowlist with the provider-advertised names.
available_lookup = {name.lower(): name for name in available}
display_names: list[str] = []
for model_name in sorted(allowed_models):
lowered = model_name.lower()
if lowered in available_lookup:
display_names.append(available_lookup[lowered])
return display_names
@classmethod
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
"""Get list of available model names, optionally filtered by provider.

View File

@@ -492,15 +492,25 @@ def configure_providers():
# Register providers in priority order:
# 1. Native APIs first (most direct and efficient)
registered_providers = []
if has_native_apis:
if gemini_key and gemini_key != "your_gemini_api_key_here":
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
registered_providers.append(ProviderType.GOOGLE.value)
logger.debug(f"Registered provider: {ProviderType.GOOGLE.value}")
if openai_key and openai_key != "your_openai_api_key_here":
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
registered_providers.append(ProviderType.OPENAI.value)
logger.debug(f"Registered provider: {ProviderType.OPENAI.value}")
if xai_key and xai_key != "your_xai_api_key_here":
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
registered_providers.append(ProviderType.XAI.value)
logger.debug(f"Registered provider: {ProviderType.XAI.value}")
if dial_key and dial_key != "your_dial_api_key_here":
ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)
registered_providers.append(ProviderType.DIAL.value)
logger.debug(f"Registered provider: {ProviderType.DIAL.value}")
# 2. Custom provider second (for local/private models)
if has_custom:
@@ -511,10 +521,18 @@ def configure_providers():
return CustomProvider(api_key=api_key or "", base_url=base_url) # Use provided API key or empty string
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
registered_providers.append(ProviderType.CUSTOM.value)
logger.debug(f"Registered provider: {ProviderType.CUSTOM.value}")
# 3. OpenRouter last (catch-all for everything else)
if has_openrouter:
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
registered_providers.append(ProviderType.OPENROUTER.value)
logger.debug(f"Registered provider: {ProviderType.OPENROUTER.value}")
# Log all registered providers
if registered_providers:
logger.info(f"Registered providers: {', '.join(registered_providers)}")
# Require at least one valid provider
if not valid_providers:

View File

@@ -63,27 +63,30 @@ class TestAliasTargetRestrictions:
assert provider.validate_model_name("o4mini")
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
def test_restriction_policy_allows_only_alias_when_alias_specified(self):
"""Test that restriction policy allows only the alias when just alias is specified.
If you restrict to 'mini' (which is an alias for gpt-5-mini),
only the alias should work, not other models.
This is the correct restrictive behavior.
"""
# Clear cached restriction service
def test_restriction_policy_alias_allows_canonical(self):
"""Alias-only allowlists should permit both the alias and its canonical target."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = OpenAIModelProvider(api_key="test-key")
# Only the alias should be allowed
assert provider.validate_model_name("mini")
# Direct target for this alias should NOT be allowed (mini -> gpt-5-mini)
assert not provider.validate_model_name("gpt-5-mini")
# Other models should NOT be allowed
assert provider.validate_model_name("gpt-5-mini")
assert not provider.validate_model_name("o4-mini")
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"})
def test_restriction_policy_alias_allows_short_name(self):
"""Common aliases like 'gpt5' should allow their canonical forms."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = OpenAIModelProvider(api_key="test-key")
assert provider.validate_model_name("gpt5")
assert provider.validate_model_name("gpt-5")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target
def test_gemini_restriction_policy_allows_alias_when_target_allowed(self):
"""Test Gemini restriction policy allows alias when target is allowed."""
@@ -99,19 +102,16 @@ class TestAliasTargetRestrictions:
assert provider.validate_model_name("flash")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only
def test_gemini_restriction_policy_allows_only_alias_when_alias_specified(self):
"""Test Gemini restriction policy allows only alias when just alias is specified."""
# Clear cached restriction service
def test_gemini_restriction_policy_alias_allows_canonical(self):
"""Gemini alias allowlists should permit canonical forms."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = GeminiModelProvider(api_key="test-key")
# Only the alias should be allowed
assert provider.validate_model_name("flash")
# Direct target should NOT be allowed
assert not provider.validate_model_name("gemini-2.5-flash")
assert provider.validate_model_name("gemini-2.5-flash")
def test_restriction_service_validation_includes_all_targets(self):
"""Test that restriction service validation knows about all aliases and targets."""
@@ -153,6 +153,30 @@ class TestAliasTargetRestrictions:
assert provider.validate_model_name("o4-mini") # target
assert provider.validate_model_name("o4mini") # alias for o4-mini
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"}, clear=True)
def test_service_alias_allows_canonical_openai(self):
"""ModelRestrictionService should permit canonical names resolved from aliases."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = OpenAIModelProvider(api_key="test-key")
service = ModelRestrictionService()
assert service.is_allowed(ProviderType.OPENAI, "gpt-5")
assert provider.validate_model_name("gpt-5")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}, clear=True)
def test_service_alias_allows_canonical_gemini(self):
"""Gemini alias allowlists should permit canonical forms."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = GeminiModelProvider(api_key="test-key")
service = ModelRestrictionService()
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
assert provider.validate_model_name("gemini-2.5-flash")
def test_alias_target_policy_regression_prevention(self):
"""Regression test to ensure aliases and targets are both validated properly.

View File

@@ -106,19 +106,35 @@ class TestAutoMode:
def test_tool_schema_in_normal_mode(self):
"""Test that tool schemas don't require model in normal mode"""
# This test uses the default from conftest.py which sets non-auto mode
# The conftest.py mock_provider_availability fixture ensures the model is available
tool = ChatTool()
schema = tool.get_input_schema()
# Save original
original = os.environ.get("DEFAULT_MODEL", "")
# Model should not be required when default model is configured
assert "model" not in schema["required"]
try:
# Set to a specific model (not auto mode)
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
import config
# Model field should have simpler description
model_schema = schema["properties"]["model"]
assert "enum" not in model_schema
assert "listmodels" in model_schema["description"]
assert "default model" in model_schema["description"].lower()
importlib.reload(config)
tool = ChatTool()
schema = tool.get_input_schema()
# Model should not be required when default model is configured
assert "model" not in schema["required"]
# Model field should have simpler description
model_schema = schema["properties"]["model"]
assert "enum" not in model_schema
assert "listmodels" in model_schema["description"]
assert "default model" in model_schema["description"].lower()
finally:
# Restore
if original:
os.environ["DEFAULT_MODEL"] = original
else:
os.environ.pop("DEFAULT_MODEL", None)
importlib.reload(config)
@pytest.mark.asyncio
async def test_auto_mode_requires_model_parameter(self):

View File

@@ -0,0 +1,203 @@
"""Tests covering model restriction-aware error messaging in auto mode."""
import asyncio
import importlib
import json
import pytest
import utils.model_restrictions as model_restrictions
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.openrouter import OpenRouterProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from providers.xai import XAIModelProvider
def _extract_available_models(message: str) -> list[str]:
"""Parse the available model list from the error message."""
marker = "Available models: "
if marker not in message:
raise AssertionError(f"Expected '{marker}' in message: {message}")
start = message.index(marker) + len(marker)
end = message.find(". Suggested", start)
if end == -1:
end = len(message)
available_segment = message[start:end].strip()
if not available_segment:
return []
return [item.strip() for item in available_segment.split(",")]
@pytest.fixture
def reset_registry():
"""Ensure registry and restriction service state is isolated."""
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
yield
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
def _register_core_providers(*, include_xai: bool = False):
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
if include_xai:
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
@pytest.mark.no_mock_provider
def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
"""Error payload should surface only the allowed models for each provider."""
monkeypatch.setenv("DEFAULT_MODEL", "auto")
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
monkeypatch.delenv("XAI_API_KEY", raising=False)
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
try:
import dotenv
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
except ModuleNotFoundError:
pass
monkeypatch.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro")
monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")
monkeypatch.setenv("OPENROUTER_ALLOWED_MODELS", "gpt5nano")
monkeypatch.setenv("XAI_ALLOWED_MODELS", "")
import config
importlib.reload(config)
_register_core_providers()
import server
importlib.reload(server)
# Reload may have re-applied .env overrides; enforce our test configuration
for key, value in (
("DEFAULT_MODEL", "auto"),
("GEMINI_API_KEY", "test-gemini"),
("OPENAI_API_KEY", "test-openai"),
("OPENROUTER_API_KEY", "test-openrouter"),
("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro"),
("OPENAI_ALLOWED_MODELS", "gpt-5"),
("OPENROUTER_ALLOWED_MODELS", "gpt5nano"),
("XAI_ALLOWED_MODELS", ""),
):
monkeypatch.setenv(key, value)
for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
monkeypatch.delenv(var, raising=False)
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
server.configure_providers()
result = asyncio.run(
server.handle_call_tool(
"chat",
{
"model": "gpt5mini",
"prompt": "Tell me about your strengths",
},
)
)
assert len(result) == 1
payload = json.loads(result[0].text)
assert payload["status"] == "error"
available_models = _extract_available_models(payload["content"])
assert set(available_models) == {"gemini-2.5-pro", "gpt-5", "gpt5nano", "openai/gpt-5-nano"}
@pytest.mark.no_mock_provider
def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, reset_registry):
"""When no restrictions are set, the full high-capability catalogue should appear."""
monkeypatch.setenv("DEFAULT_MODEL", "auto")
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
monkeypatch.setenv("XAI_API_KEY", "test-xai")
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
try:
import dotenv
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
except ModuleNotFoundError:
pass
for var in (
"GOOGLE_ALLOWED_MODELS",
"OPENAI_ALLOWED_MODELS",
"OPENROUTER_ALLOWED_MODELS",
"XAI_ALLOWED_MODELS",
"DIAL_ALLOWED_MODELS",
):
monkeypatch.delenv(var, raising=False)
import config
importlib.reload(config)
_register_core_providers(include_xai=True)
import server
importlib.reload(server)
for key, value in (
("DEFAULT_MODEL", "auto"),
("GEMINI_API_KEY", "test-gemini"),
("OPENAI_API_KEY", "test-openai"),
("OPENROUTER_API_KEY", "test-openrouter"),
):
monkeypatch.setenv(key, value)
for var in (
"GOOGLE_ALLOWED_MODELS",
"OPENAI_ALLOWED_MODELS",
"OPENROUTER_ALLOWED_MODELS",
"XAI_ALLOWED_MODELS",
"DIAL_ALLOWED_MODELS",
"CUSTOM_API_URL",
"CUSTOM_API_KEY",
):
monkeypatch.delenv(var, raising=False)
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
server.configure_providers()
result = asyncio.run(
server.handle_call_tool(
"chat",
{
"model": "dummymodel",
"prompt": "Hi there",
},
)
)
assert len(result) == 1
payload = json.loads(result[0].text)
assert payload["status"] == "error"
available_models = _extract_available_models(payload["content"])
assert "gemini-2.5-pro" in available_models
assert "gpt-5" in available_models
assert "grok-4" in available_models
assert len(available_models) >= 5

View File

@@ -3,6 +3,7 @@ Tests for dynamic context request and collaboration features
"""
import json
import os
from unittest.mock import Mock, patch
import pytest
@@ -157,95 +158,120 @@ class TestDynamicContextRequests:
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
async def test_clarification_with_suggested_action(self, mock_get_provider, analyze_tool):
"""Test clarification request with suggested next action"""
clarification_json = json.dumps(
{
"status": "files_required_to_continue",
"mandatory_instructions": "I need to see the database configuration to analyze the connection error",
"files_needed": ["config/database.yml", "src/db.py"],
"suggested_next_action": {
"tool": "analyze",
"args": {
"prompt": "Analyze database connection timeout issue",
"relevant_files": [
"/config/database.yml",
"/src/db.py",
"/logs/error.log",
],
import importlib
from providers.registry import ModelProviderRegistry
# Ensure deterministic model configuration for this test regardless of previous suites
ModelProviderRegistry.reset_for_testing()
original_default = os.environ.get("DEFAULT_MODEL")
try:
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
import config
importlib.reload(config)
clarification_json = json.dumps(
{
"status": "files_required_to_continue",
"mandatory_instructions": "I need to see the database configuration to analyze the connection error",
"files_needed": ["config/database.yml", "src/db.py"],
"suggested_next_action": {
"tool": "analyze",
"args": {
"prompt": "Analyze database connection timeout issue",
"relevant_files": [
"/config/database.yml",
"/src/db.py",
"/logs/error.log",
],
},
},
},
},
ensure_ascii=False,
)
ensure_ascii=False,
)
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
result = await analyze_tool.execute(
{
"step": "Analyze database connection timeout issue",
"step_number": 1,
"total_steps": 1,
"next_step_required": False,
"findings": "Initial database timeout analysis",
"relevant_files": ["/absolute/logs/error.log"],
}
)
result = await analyze_tool.execute(
{
"step": "Analyze database connection timeout issue",
"step_number": 1,
"total_steps": 1,
"next_step_required": False,
"findings": "Initial database timeout analysis",
"relevant_files": ["/absolute/logs/error.log"],
}
)
assert len(result) == 1
assert len(result) == 1
response_data = json.loads(result[0].text)
response_data = json.loads(result[0].text)
# Workflow tools should either promote clarification status or handle it in expert analysis
if response_data["status"] == "files_required_to_continue":
# Clarification was properly promoted to main status
# Check if mandatory_instructions is at top level or in content
if "mandatory_instructions" in response_data:
assert "database configuration" in response_data["mandatory_instructions"]
assert "files_needed" in response_data
assert "config/database.yml" in response_data["files_needed"]
assert "src/db.py" in response_data["files_needed"]
elif "content" in response_data:
# Parse content JSON for workflow tools
try:
content_json = json.loads(response_data["content"])
assert "mandatory_instructions" in content_json
# Workflow tools should either promote clarification status or handle it in expert analysis
if response_data["status"] == "files_required_to_continue":
# Clarification was properly promoted to main status
# Check if mandatory_instructions is at top level or in content
if "mandatory_instructions" in response_data:
assert "database configuration" in response_data["mandatory_instructions"]
assert "files_needed" in response_data
assert "config/database.yml" in response_data["files_needed"]
assert "src/db.py" in response_data["files_needed"]
elif "content" in response_data:
# Parse content JSON for workflow tools
try:
content_json = json.loads(response_data["content"])
assert "mandatory_instructions" in content_json
assert (
"database configuration" in content_json["mandatory_instructions"]
or "database" in content_json["mandatory_instructions"]
)
assert "files_needed" in content_json
files_needed_str = str(content_json["files_needed"])
assert (
"config/database.yml" in files_needed_str
or "config" in files_needed_str
or "database" in files_needed_str
)
except json.JSONDecodeError:
# Content is not JSON, check if it contains required text
content = response_data["content"]
assert "database configuration" in content or "config" in content
elif response_data["status"] == "calling_expert_analysis":
# Clarification may be handled in expert analysis section
if "expert_analysis" in response_data:
expert_analysis = response_data["expert_analysis"]
expert_content = str(expert_analysis)
assert (
"database configuration" in content_json["mandatory_instructions"]
or "database" in content_json["mandatory_instructions"]
"database configuration" in expert_content
or "config/database.yml" in expert_content
or "files_required_to_continue" in expert_content
)
assert "files_needed" in content_json
files_needed_str = str(content_json["files_needed"])
assert (
"config/database.yml" in files_needed_str
or "config" in files_needed_str
or "database" in files_needed_str
)
except json.JSONDecodeError:
# Content is not JSON, check if it contains required text
content = response_data["content"]
assert "database configuration" in content or "config" in content
elif response_data["status"] == "calling_expert_analysis":
# Clarification may be handled in expert analysis section
if "expert_analysis" in response_data:
expert_analysis = response_data["expert_analysis"]
expert_content = str(expert_analysis)
assert (
"database configuration" in expert_content
or "config/database.yml" in expert_content
or "files_required_to_continue" in expert_content
)
else:
# Some other status - ensure it's a valid workflow response
assert "step_number" in response_data
else:
# Some other status - ensure it's a valid workflow response
assert "step_number" in response_data
# Check for suggested next action
if "suggested_next_action" in response_data:
action = response_data["suggested_next_action"]
assert action["tool"] == "analyze"
# Check for suggested next action
if "suggested_next_action" in response_data:
action = response_data["suggested_next_action"]
assert action["tool"] == "analyze"
finally:
if original_default is not None:
os.environ["DEFAULT_MODEL"] = original_default
else:
os.environ.pop("DEFAULT_MODEL", None)
import config
importlib.reload(config)
ModelProviderRegistry.reset_for_testing()
def test_tool_output_model_serialization(self):
"""Test ToolOutput model serialization"""

View File

@@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
from providers.base import ModelProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from providers.shared import ModelCapabilities, ProviderType
from tools.listmodels import ListModelsTool
@@ -23,10 +23,63 @@ class TestListModelsRestrictions(unittest.TestCase):
self.mock_openrouter = MagicMock(spec=ModelProvider)
self.mock_openrouter.provider_type = ProviderType.OPENROUTER
def make_capabilities(
canonical: str, friendly: str, *, aliases=None, context: int = 200_000
) -> ModelCapabilities:
return ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=canonical,
friendly_name=friendly,
intelligence_score=20,
description=friendly,
aliases=aliases or [],
context_window=context,
max_output_tokens=context,
supports_extended_thinking=True,
)
opus_caps = make_capabilities(
"anthropic/claude-opus-4-20240229",
"Claude Opus",
aliases=["opus"],
)
sonnet_caps = make_capabilities(
"anthropic/claude-sonnet-4-20240229",
"Claude Sonnet",
aliases=["sonnet"],
)
deepseek_caps = make_capabilities(
"deepseek/deepseek-r1-0528:free",
"DeepSeek R1",
aliases=[],
)
qwen_caps = make_capabilities(
"qwen/qwen3-235b-a22b-04-28:free",
"Qwen3",
aliases=[],
)
self._openrouter_caps_map = {
"anthropic/claude-opus-4": opus_caps,
"opus": opus_caps,
"anthropic/claude-opus-4-20240229": opus_caps,
"anthropic/claude-sonnet-4": sonnet_caps,
"sonnet": sonnet_caps,
"anthropic/claude-sonnet-4-20240229": sonnet_caps,
"deepseek/deepseek-r1-0528:free": deepseek_caps,
"qwen/qwen3-235b-a22b-04-28:free": qwen_caps,
}
self.mock_openrouter.get_capabilities.side_effect = self._openrouter_caps_map.__getitem__
self.mock_openrouter.get_capabilities_by_rank.return_value = []
self.mock_openrouter.list_models.return_value = []
# Create mock Gemini provider for comparison
self.mock_gemini = MagicMock(spec=ModelProvider)
self.mock_gemini.provider_type = ProviderType.GOOGLE
self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]
self.mock_gemini.get_capabilities_by_rank.return_value = []
self.mock_gemini.get_capabilities_by_rank.return_value = []
def tearDown(self):
"""Clean up after tests."""
@@ -159,7 +212,7 @@ class TestListModelsRestrictions(unittest.TestCase):
for line in lines:
if "OpenRouter" in line and "" in line:
openrouter_section_found = True
elif "Available Models" in line and openrouter_section_found:
elif ("Models (policy restricted)" in line or "Available Models" in line) and openrouter_section_found:
in_openrouter_section = True
elif in_openrouter_section:
# Check for lines with model names in backticks
@@ -179,11 +232,11 @@ class TestListModelsRestrictions(unittest.TestCase):
len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
)
# Verify list_models was called with respect_restrictions=True
self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)
# Verify we did not fall back to unrestricted listing
self.mock_openrouter.list_models.assert_not_called()
# Check for restriction note
self.assertIn("Restricted to models matching:", result)
self.assertIn("OpenRouter models restricted by", result)
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True)
@patch("providers.openrouter_registry.OpenRouterModelRegistry")

View File

@@ -121,38 +121,59 @@ class TestModelMetadataContinuation:
@pytest.mark.asyncio
async def test_no_previous_assistant_turn_defaults(self):
"""Test behavior when there's no previous assistant turn."""
thread_id = create_thread("chat", {"prompt": "test"})
# Save and set DEFAULT_MODEL for test
import importlib
import os
# Only add user turns
add_turn(thread_id, "user", "First question")
add_turn(thread_id, "user", "Second question")
original_default = os.environ.get("DEFAULT_MODEL", "")
os.environ["DEFAULT_MODEL"] = "auto"
import config
import utils.model_context
arguments = {"continuation_id": thread_id}
importlib.reload(config)
importlib.reload(utils.model_context)
# Mock dependencies
with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
mock_calc.return_value = MagicMock(
total_tokens=200000,
content_tokens=160000,
response_tokens=40000,
file_tokens=64000,
history_tokens=64000,
)
try:
thread_id = create_thread("chat", {"prompt": "test"})
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)
# Only add user turns
add_turn(thread_id, "user", "First question")
add_turn(thread_id, "user", "Second question")
# Call the actual function
enhanced_args = await reconstruct_thread_context(arguments)
arguments = {"continuation_id": thread_id}
# Should not have set a model
assert enhanced_args.get("model") is None
# Mock dependencies
with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
mock_calc.return_value = MagicMock(
total_tokens=200000,
content_tokens=160000,
response_tokens=40000,
file_tokens=64000,
history_tokens=64000,
)
# ModelContext should use DEFAULT_MODEL
model_context = ModelContext.from_arguments(enhanced_args)
from config import DEFAULT_MODEL
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)
assert model_context.model_name == DEFAULT_MODEL
# Call the actual function
enhanced_args = await reconstruct_thread_context(arguments)
# Should not have set a model
assert enhanced_args.get("model") is None
# ModelContext should use DEFAULT_MODEL
model_context = ModelContext.from_arguments(enhanced_args)
from config import DEFAULT_MODEL
assert model_context.model_name == DEFAULT_MODEL
finally:
# Restore original value
if original_default:
os.environ["DEFAULT_MODEL"] = original_default
else:
os.environ.pop("DEFAULT_MODEL", None)
importlib.reload(config)
importlib.reload(utils.model_context)
@pytest.mark.asyncio
async def test_explicit_model_overrides_previous_turn(self):

View File

@@ -49,17 +49,32 @@ class TestModelRestrictionService:
def test_load_multiple_models_restriction(self):
"""Test loading multiple allowed models."""
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
service = ModelRestrictionService()
# Instantiate providers so alias resolution for allow-lists is available
openai_provider = OpenAIModelProvider(api_key="test-key")
gemini_provider = GeminiModelProvider(api_key="test-key")
# Check OpenAI models
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
assert not service.is_allowed(ProviderType.OPENAI, "o3")
from providers.registry import ModelProviderRegistry
# Check Google models
assert service.is_allowed(ProviderType.GOOGLE, "flash")
assert service.is_allowed(ProviderType.GOOGLE, "pro")
assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
def fake_get_provider(provider_type, force_new=False):
mapping = {
ProviderType.OPENAI: openai_provider,
ProviderType.GOOGLE: gemini_provider,
}
return mapping.get(provider_type)
with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):
service = ModelRestrictionService()
# Check OpenAI models
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
assert not service.is_allowed(ProviderType.OPENAI, "o3")
# Check Google models
assert service.is_allowed(ProviderType.GOOGLE, "flash")
assert service.is_allowed(ProviderType.GOOGLE, "pro")
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
def test_case_insensitive_and_whitespace_handling(self):
"""Test that model names are case-insensitive and whitespace is trimmed."""
@@ -111,13 +126,17 @@ class TestModelRestrictionService:
def test_shorthand_names_in_restrictions(self):
"""Test that shorthand names work in restrictions."""
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o3-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
# Instantiate providers so the registry can resolve aliases
OpenAIModelProvider(api_key="test-key")
GeminiModelProvider(api_key="test-key")
service = ModelRestrictionService()
# When providers check models, they pass both resolved and original names
# OpenAI: 'mini' shorthand allows o4-mini
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "mini") # How providers actually call it
assert not service.is_allowed(ProviderType.OPENAI, "o4-mini") # Direct check without original (for testing)
# OpenAI: 'o4mini' shorthand allows o4-mini
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini") # How providers actually call it
assert service.is_allowed(ProviderType.OPENAI, "o4-mini") # Canonical should also be allowed
# OpenAI: o3-mini allowed directly
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
@@ -280,19 +299,25 @@ class TestProviderIntegration:
provider = GeminiModelProvider(api_key="test-key")
# Test case: Only alias "flash" is allowed, not the full name
# If parameters are in wrong order, this test will catch it
from providers.registry import ModelProviderRegistry
# Should allow "flash" alias
assert provider.validate_model_name("flash")
with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):
# Should allow getting capabilities for "flash"
capabilities = provider.get_capabilities("flash")
assert capabilities.model_name == "gemini-2.5-flash"
# Test case: Only alias "flash" is allowed, not the full name
# If parameters are in wrong order, this test will catch it
# Test the edge case: Try to use full model name when only alias is allowed
# This should NOT be allowed - only the alias "flash" is in the restriction list
assert not provider.validate_model_name("gemini-2.5-flash")
# Should allow "flash" alias
assert provider.validate_model_name("flash")
# Should allow getting capabilities for "flash"
capabilities = provider.get_capabilities("flash")
assert capabilities.model_name == "gemini-2.5-flash"
# Canonical form should also be allowed now that alias is on the allowlist
assert provider.validate_model_name("gemini-2.5-flash")
# Unrelated models remain blocked
assert not provider.validate_model_name("pro")
assert not provider.validate_model_name("gemini-2.5-pro")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
def test_gemini_parameter_order_edge_case_full_name_only(self):
@@ -570,17 +595,27 @@ class TestShorthandRestrictions:
# Test OpenAI provider
openai_provider = OpenAIModelProvider(api_key="test-key")
assert openai_provider.validate_model_name("mini") # Should work with shorthand
# When restricting to "mini", you can't use "o4-mini" directly - this is correct behavior
assert not openai_provider.validate_model_name("o4-mini") # Not allowed - only shorthand is allowed
assert not openai_provider.validate_model_name("o3-mini") # Not allowed
# Test Gemini provider
gemini_provider = GeminiModelProvider(api_key="test-key")
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
# Same for Gemini - if you restrict to "flash", you can't use the full name
assert not gemini_provider.validate_model_name("gemini-2.5-flash") # Not allowed
assert not gemini_provider.validate_model_name("pro") # Not allowed
from providers.registry import ModelProviderRegistry
def registry_side_effect(provider_type, force_new=False):
mapping = {
ProviderType.OPENAI: openai_provider,
ProviderType.GOOGLE: gemini_provider,
}
return mapping.get(provider_type)
with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
assert openai_provider.validate_model_name("mini") # Should work with shorthand
assert openai_provider.validate_model_name("gpt-5-mini") # Canonical resolved from shorthand
assert not openai_provider.validate_model_name("o4-mini") # Unrelated model still blocked
assert not openai_provider.validate_model_name("o3-mini")
# Test Gemini provider
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
assert gemini_provider.validate_model_name("gemini-2.5-flash") # Canonical allowed
assert not gemini_provider.validate_model_name("pro") # Not allowed
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
def test_multiple_shorthands_for_same_model(self):
@@ -596,9 +631,9 @@ class TestShorthandRestrictions:
assert openai_provider.validate_model_name("mini") # mini -> o4-mini
assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini
# Resolved names work only if explicitly allowed
# Resolved names should be allowed when their shorthands are present
assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed
assert not openai_provider.validate_model_name("o3-mini") # Not explicitly allowed, only shorthand
assert openai_provider.validate_model_name("o3-mini") # Allowed via shorthand
# Other models should not work
assert not openai_provider.validate_model_name("o3")

View File

@@ -260,9 +260,10 @@ class TestOpenRouterAutoMode:
os.environ["DEFAULT_MODEL"] = "auto"
mock_provider_class = Mock()
mock_provider_instance = Mock(spec=["get_provider_type", "list_models"])
mock_provider_instance = Mock(spec=["get_provider_type", "list_models", "get_all_model_capabilities"])
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
mock_provider_instance.list_models.return_value = []
mock_provider_instance.get_all_model_capabilities.return_value = {}
mock_provider_class.return_value = mock_provider_instance
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)

View File

@@ -293,13 +293,7 @@ class TestOpenRouterAliasRestrictions:
# o3 -> openai/o3
# gpt4.1 -> should not exist (expected to be filtered out)
expected_models = {
"openai/o3-mini",
"google/gemini-2.5-pro",
"google/gemini-2.5-flash",
"openai/o4-mini",
"openai/o3",
}
expected_models = {"o3-mini", "pro", "flash", "o4-mini", "o3"}
available_model_names = set(available_models.keys())
@@ -355,9 +349,11 @@ class TestOpenRouterAliasRestrictions:
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
expected_models = {
"openai/o3-mini", # from alias
"o3-mini", # alias
"openai/o3-mini", # canonical
"anthropic/claude-opus-4.1", # full name
"google/gemini-2.5-flash", # from alias
"flash", # alias
"google/gemini-2.5-flash", # canonical
}
available_model_names = set(available_models.keys())

View File

@@ -83,9 +83,18 @@ class ListModelsTool(BaseTool):
from providers.openrouter_registry import OpenRouterModelRegistry
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from utils.model_restrictions import get_restriction_service
output_lines = ["# Available AI Models\n"]
restriction_service = get_restriction_service()
restricted_models_by_provider: dict[ProviderType, list[str]] = {}
if restriction_service:
restricted_map = ModelProviderRegistry.get_available_models(respect_restrictions=True)
for model_name, provider_type in restricted_map.items():
restricted_models_by_provider.setdefault(provider_type, []).append(model_name)
# Map provider types to friendly names and their models
provider_info = {
ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
@@ -94,6 +103,43 @@ class ListModelsTool(BaseTool):
ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
}
def format_model_entry(provider, display_name: str) -> list[str]:
try:
capabilities = provider.get_capabilities(display_name)
except ValueError:
return [f"- `{display_name}` *(not recognized by provider)*"]
canonical = capabilities.model_name
if canonical.lower() == display_name.lower():
header = f"- `{canonical}`"
else:
header = f"- `{display_name}` → `{canonical}`"
try:
context_value = capabilities.context_window or 0
except AttributeError:
context_value = 0
try:
context_value = int(context_value)
except (TypeError, ValueError):
context_value = 0
if context_value >= 1_000_000:
context_str = f"{context_value // 1_000_000}M context"
elif context_value >= 1_000:
context_str = f"{context_value // 1_000}K context"
elif context_value > 0:
context_str = f"{context_value} context"
else:
context_str = "unknown context"
try:
description = capabilities.description or "No description available"
except AttributeError:
description = "No description available"
lines = [header, f" - {context_str}", f" - {description}"]
return lines
# Check each native provider type
for provider_type, info in provider_info.items():
# Check if provider is enabled
@@ -104,30 +150,49 @@ class ListModelsTool(BaseTool):
if is_configured:
output_lines.append("**Status**: Configured and available")
output_lines.append("\n**Models**:")
has_restrictions = bool(restriction_service and restriction_service.has_restrictions(provider_type))
aliases = []
for model_name, capabilities in provider.get_capabilities_by_rank():
description = capabilities.description or "No description available"
context_window = capabilities.context_window
if has_restrictions:
restricted_names = sorted(set(restricted_models_by_provider.get(provider_type, [])))
if context_window >= 1_000_000:
context_str = f"{context_window // 1_000_000}M context"
elif context_window >= 1_000:
context_str = f"{context_window // 1_000}K context"
if restricted_names:
output_lines.append("\n**Models (policy restricted)**:")
for model_name in restricted_names:
output_lines.extend(format_model_entry(provider, model_name))
else:
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
output_lines.append("\n*No models are currently allowed by restriction policy.*")
else:
output_lines.append("\n**Models**:")
output_lines.append(f"- `{model_name}` - {context_str}")
output_lines.append(f" - {description}")
aliases = []
for model_name, capabilities in provider.get_capabilities_by_rank():
try:
description = capabilities.description or "No description available"
except AttributeError:
description = "No description available"
for alias in capabilities.aliases or []:
if alias != model_name:
aliases.append(f"- `{alias}` → `{model_name}`")
try:
context_window = capabilities.context_window or 0
except AttributeError:
context_window = 0
if aliases:
output_lines.append("\n**Aliases**:")
output_lines.extend(sorted(aliases))
if context_window >= 1_000_000:
context_str = f"{context_window // 1_000_000}M context"
elif context_window >= 1_000:
context_str = f"{context_window // 1_000}K context"
else:
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
output_lines.append(f"- `{model_name}` - {context_str}")
output_lines.append(f" - {description}")
for alias in capabilities.aliases or []:
if alias != model_name:
aliases.append(f"- `{alias}` → `{model_name}`")
if aliases:
output_lines.append("\n**Aliases**:")
output_lines.extend(sorted(aliases))
else:
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
@@ -144,19 +209,10 @@ class ListModelsTool(BaseTool):
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
try:
# Get OpenRouter provider from registry to properly apply restrictions
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
if provider:
# Get models with restrictions applied
available_models = provider.list_models(respect_restrictions=True)
registry = OpenRouterModelRegistry()
# Group by provider and retain ranking information for consistent ordering
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
def _format_context(tokens: int) -> str:
if not tokens:
return "?"
@@ -166,53 +222,83 @@ class ListModelsTool(BaseTool):
return f"{tokens // 1_000}K"
return str(tokens)
for model_name in available_models:
config = registry.resolve(model_name)
provider_name = "other"
if config and "/" in config.model_name:
provider_name = config.model_name.split("/")[0]
elif "/" in model_name:
provider_name = model_name.split("/")[0]
has_restrictions = bool(
restriction_service and restriction_service.has_restrictions(ProviderType.OPENROUTER)
)
providers_models.setdefault(provider_name, [])
if has_restrictions:
restricted_names = sorted(set(restricted_models_by_provider.get(ProviderType.OPENROUTER, [])))
rank = config.get_effective_capability_rank() if config else 0
providers_models[provider_name].append((rank, model_name, config))
output_lines.append("\n**Models (policy restricted)**:")
if restricted_names:
for model_name in restricted_names:
try:
caps = provider.get_capabilities(model_name)
except ValueError:
output_lines.append(f"- `{model_name}` *(not recognized by provider)*")
continue
output_lines.append("\n**Available Models**:")
for provider_name, models in sorted(providers_models.items()):
output_lines.append(f"\n*{provider_name.title()}:*")
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
if config:
context_str = _format_context(config.context_window)
context_value = int(caps.context_window or 0)
context_str = _format_context(context_value)
suffix_parts = [f"{context_str} context"]
if getattr(config, "supports_extended_thinking", False):
if caps.supports_extended_thinking:
suffix_parts.append("thinking")
suffix = ", ".join(suffix_parts)
output_lines.append(f"- `{alias}` → `{config.model_name}` (score {rank}, {suffix})")
else:
output_lines.append(f"- `{alias}` (score {rank})")
total_models = len(available_models)
# Show all models - no truncation message needed
arrow = ""
if caps.model_name.lower() != model_name.lower():
arrow = f" → `{caps.model_name}`"
# Check if restrictions are applied
restriction_service = None
try:
from utils.model_restrictions import get_restriction_service
score = caps.get_effective_capability_rank()
output_lines.append(f"- `{model_name}`{arrow} (score {score}, {suffix})")
restriction_service = get_restriction_service()
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
output_lines.append(
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
)
except Exception as e:
logger.warning(f"Error checking OpenRouter restrictions: {e}")
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) or set()
if allowed_set:
output_lines.append(
f"\n*OpenRouter models restricted by OPENROUTER_ALLOWED_MODELS: {', '.join(sorted(allowed_set))}*"
)
else:
output_lines.append("- *No models allowed by current restriction policy.*")
else:
available_models = provider.list_models(respect_restrictions=True)
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
for model_name in available_models:
config = registry.resolve(model_name)
provider_name = "other"
if config and "/" in config.model_name:
provider_name = config.model_name.split("/")[0]
elif "/" in model_name:
provider_name = model_name.split("/")[0]
providers_models.setdefault(provider_name, [])
rank = config.get_effective_capability_rank() if config else 0
providers_models[provider_name].append((rank, model_name, config))
output_lines.append("\n**Available Models**:")
for provider_name, models in sorted(providers_models.items()):
output_lines.append(f"\n*{provider_name.title()}:*")
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
if config:
context_str = _format_context(getattr(config, "context_window", 0))
suffix_parts = [f"{context_str} context"]
if getattr(config, "supports_extended_thinking", False):
suffix_parts.append("thinking")
suffix = ", ".join(suffix_parts)
arrow = ""
if config.model_name.lower() != alias.lower():
arrow = f" → `{config.model_name}`"
output_lines.append(f"- `{alias}`{arrow} (score {rank}, {suffix})")
else:
output_lines.append(f"- `{alias}` (score {rank})")
else:
output_lines.append("**Error**: Could not load OpenRouter provider")
except Exception as e:
logger.exception("Error listing OpenRouter models: %s", e)
output_lines.append(f"**Error loading models**: {str(e)}")
else:
output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")

View File

@@ -22,6 +22,7 @@ Example:
import logging
import os
from collections import defaultdict
from typing import Optional
from providers.shared import ProviderType
@@ -58,6 +59,7 @@ class ModelRestrictionService:
def __init__(self):
"""Initialize the restriction service by loading from environment."""
self.restrictions: dict[ProviderType, set[str]] = {}
self._alias_resolution_cache: dict[ProviderType, dict[str, str]] = defaultdict(dict)
self._load_from_env()
def _load_from_env(self) -> None:
@@ -79,6 +81,7 @@ class ModelRestrictionService:
if models:
self.restrictions[provider_type] = models
self._alias_resolution_cache[provider_type] = {}
logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
else:
# All entries were empty after cleaning - treat as no restrictions
@@ -150,7 +153,41 @@ class ModelRestrictionService:
names_to_check.add(original_name.lower())
# If any of the names is in the allowed set, it's allowed
return any(name in allowed_set for name in names_to_check)
if any(name in allowed_set for name in names_to_check):
return True
# Attempt to resolve canonical names for allowed aliases using provider metadata.
try:
from providers.registry import ModelProviderRegistry
provider = ModelProviderRegistry.get_provider(provider_type)
except Exception: # pragma: no cover - registry lookup failure shouldn't break validation
provider = None
if provider:
cache = self._alias_resolution_cache.setdefault(provider_type, {})
for allowed_entry in list(allowed_set):
normalized_resolved = cache.get(allowed_entry)
if not normalized_resolved:
try:
resolved = provider._resolve_model_name(allowed_entry)
except Exception: # pragma: no cover - resolution failures are treated as non-matches
continue
if not resolved:
continue
normalized_resolved = resolved.lower()
cache[allowed_entry] = normalized_resolved
if normalized_resolved in names_to_check:
allowed_set.add(normalized_resolved)
cache[normalized_resolved] = normalized_resolved
return True
return False
def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
"""