fix: make listmodels always honor restricted models

fix: resolve canonical model names for OpenRouter restrictions
fix: tools now correctly return the restricted list by presenting allowed model names in the schema
fix: tests updated to manage their expected env vars properly
perf: cache model alias resolution to avoid repeated checks
Fahad
2025-10-04 13:46:22 +04:00
parent 054e34e31c
commit 4015e917ed
17 changed files with 885 additions and 253 deletions

View File

@@ -1,5 +1,9 @@
# Repository Guidelines
See `requirements.txt` and `requirements-dev.txt`
Also read CLAUDE.md and CLAUDE.local.md if available.
## Project Structure & Module Organization
Zen MCP Server centers on `server.py`, which exposes MCP entrypoints and coordinates multi-model workflows.
Feature-specific tools live in `tools/`, provider integrations in `providers/`, and shared helpers in `utils/`.
@@ -14,6 +18,13 @@ Authoritative documentation and samples live in `docs/`, and runtime diagnostics
- `python communication_simulator_test.py --quick` smoke-test orchestration across tools and providers.
- `./run_integration_tests.sh [--with-simulator]` exercise provider-dependent flows against remote or Ollama models.
For example, this is how we run an individual / all tests:
```
source .zen_venv/bin/activate && pytest tests/test_auto_mode_model_listing.py -q
source .zen_venv/bin/activate && pytest -q
```
## Coding Style & Naming Conventions
Target Python 3.9+ with Black and isort using a 120-character line limit; Ruff enforces pycodestyle, pyflakes, bugbear, comprehension, and pyupgrade rules. Prefer explicit type hints, snake_case modules, and imperative commit-time docstrings. Extend workflows by defining hook or abstract methods instead of checking `hasattr()`/`getattr()`—inheritance-backed contracts keep behavior discoverable and testable.

View File

@@ -73,6 +73,8 @@ class CustomProvider(OpenAICompatibleProvider):
logging.info(f"Initializing Custom provider with endpoint: {base_url}") logging.info(f"Initializing Custom provider with endpoint: {base_url}")
self._alias_cache: dict[str, str] = {}
super().__init__(api_key, base_url=base_url, **kwargs) super().__init__(api_key, base_url=base_url, **kwargs)
# Initialize model registry (shared with OpenRouter for consistent aliases) # Initialize model registry (shared with OpenRouter for consistent aliases)
@@ -120,11 +122,18 @@ class CustomProvider(OpenAICompatibleProvider):
def _resolve_model_name(self, model_name: str) -> str:
    """Resolve registry aliases and strip version tags for local models."""
    cache_key = model_name.lower()
    if cache_key in self._alias_cache:
        return self._alias_cache[cache_key]

    config = self._registry.resolve(model_name)
    if config:
        if config.model_name != model_name:
            logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
        resolved = config.model_name
        self._alias_cache[cache_key] = resolved
        self._alias_cache.setdefault(resolved.lower(), resolved)
        return resolved

    if ":" in model_name:
        base_model = model_name.split(":")[0]
@@ -132,11 +141,16 @@ class CustomProvider(OpenAICompatibleProvider):
        base_config = self._registry.resolve(base_model)
        if base_config:
            logging.debug("Resolved base model '%s' to '%s'", base_model, base_config.model_name)
            resolved = base_config.model_name
            self._alias_cache[cache_key] = resolved
            self._alias_cache.setdefault(resolved.lower(), resolved)
            return resolved
        self._alias_cache[cache_key] = base_model
        return base_model

    logging.debug(f"Model '{model_name}' not found in registry, using as-is")
    self._alias_cache[cache_key] = model_name
    return model_name

def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
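The memoisation added above is small enough to show in isolation. A minimal sketch of the pattern, assuming only a resolver callable; the class and names below are illustrative, not the project's registry API:

```
# Sketch of case-insensitive alias memoisation. The resolver is any callable
# that maps an alias to a canonical name (or returns None when unknown).
from typing import Callable, Optional


class AliasCache:
    def __init__(self, resolve_fn: Callable[[str], Optional[str]]):
        self._resolve_fn = resolve_fn
        self._cache: dict[str, str] = {}

    def resolve(self, name: str) -> str:
        key = name.lower()
        if key in self._cache:
            return self._cache[key]
        resolved = self._resolve_fn(name) or name
        # Remember both the requested spelling and the canonical spelling so
        # later lookups by either form hit the cache.
        self._cache[key] = resolved
        self._cache.setdefault(resolved.lower(), resolved)
        return resolved


aliases = AliasCache({"mini": "gpt-5-mini", "flash": "gemini-2.5-flash"}.get)
assert aliases.resolve("mini") == "gpt-5-mini"
assert aliases.resolve("gpt-5-mini") == "gpt-5-mini"  # served from the cache
```

Caching the canonical spelling as well means a later request for the resolved name is answered without touching the registry again.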

View File

@@ -39,6 +39,7 @@ class OpenAICompatibleProvider(ModelProvider):
    base_url: Base URL for the API endpoint
    **kwargs: Additional configuration options including timeout
"""
self._allowed_alias_cache: dict[str, str] = {}
super().__init__(api_key, **kwargs)
self._client = None
self.base_url = base_url
@@ -74,9 +75,33 @@ class OpenAICompatibleProvider(ModelProvider):
canonical = canonical_name.lower()
if requested not in self.allowed_models and canonical not in self.allowed_models:
    allowed = False
    for allowed_entry in list(self.allowed_models):
        normalized_resolved = self._allowed_alias_cache.get(allowed_entry)
if normalized_resolved is None:
try:
resolved_name = self._resolve_model_name(allowed_entry)
except Exception:
continue
if not resolved_name:
continue
normalized_resolved = resolved_name.lower()
self._allowed_alias_cache[allowed_entry] = normalized_resolved
if normalized_resolved == canonical:
# Canonical match discovered via alias resolution; mark the model as allowed and
# memoise the canonical entry for future lookups.
allowed = True
self._allowed_alias_cache[canonical] = canonical
self.allowed_models.add(canonical)
break
if not allowed:
raise ValueError(
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
)
def _parse_allowed_models(self) -> Optional[set[str]]:
    """Parse allowed models from environment variable.
@@ -94,6 +119,7 @@ class OpenAICompatibleProvider(ModelProvider):
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
if models:
    logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
self._allowed_alias_cache = {}
    return models

# Log info if no allow-list configured for proxy providers
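Stripped of the provider classes, the allow-list check above reduces to a small function. A hedged sketch under the assumption that a resolver callable is available; the names are illustrative, not the actual API:

```
# Sketch: decide whether a requested model passes an allow-list whose entries
# may be aliases, memoising resolved entries in a plain dict.
from typing import Callable, Optional


def is_model_allowed(
    requested: str,
    canonical: str,
    allowed: set[str],
    resolve: Callable[[str], Optional[str]],
    cache: dict[str, str],
) -> bool:
    req, canon = requested.lower(), canonical.lower()
    if req in allowed or canon in allowed:
        return True
    for entry in list(allowed):
        resolved = cache.get(entry)
        if resolved is None:
            target = resolve(entry)
            if not target:
                continue
            resolved = target.lower()
            cache[entry] = resolved
        if resolved == canon:
            # The allow-listed alias points at this canonical model; remember
            # the canonical form so the next check short-circuits.
            allowed.add(canon)
            return True
    return False


cache: dict[str, str] = {}
allowed = {"mini"}
resolver = {"mini": "gpt-5-mini"}.get
assert is_model_allowed("gpt-5-mini", "gpt-5-mini", allowed, resolver, cache)
assert not is_model_allowed("o4-mini", "o4-mini", allowed, resolver, cache)
```

Memoising resolved allow-list entries keeps repeated validation calls cheap, which is what the `_allowed_alias_cache` above is for.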

View File

@@ -50,6 +50,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
    **kwargs: Additional configuration
"""
base_url = "https://openrouter.ai/api/v1"
self._alias_cache: dict[str, str] = {}
super().__init__(api_key, base_url=base_url, **kwargs)

# Initialize model registry
@@ -178,13 +179,21 @@ class OpenRouterProvider(OpenAICompatibleProvider):
def _resolve_model_name(self, model_name: str) -> str:
    """Resolve aliases defined in the OpenRouter registry."""
cache_key = model_name.lower()
if cache_key in self._alias_cache:
return self._alias_cache[cache_key]
config = self._registry.resolve(model_name)
if config:
    if config.model_name != model_name:
        logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
    resolved = config.model_name
self._alias_cache[cache_key] = resolved
self._alias_cache.setdefault(resolved.lower(), resolved)
return resolved
logging.debug(f"Model '{model_name}' not found in registry, using as-is") logging.debug(f"Model '{model_name}' not found in registry, using as-is")
self._alias_cache[cache_key] = model_name
return model_name

def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:

View File

@@ -205,6 +205,18 @@ class ModelProviderRegistry:
logging.warning("Provider %s does not implement list_models", provider_type) logging.warning("Provider %s does not implement list_models", provider_type)
continue continue
if restriction_service and restriction_service.has_restrictions(provider_type):
restricted_display = cls._collect_restricted_display_names(
provider,
provider_type,
available,
restriction_service,
)
if restricted_display:
for model_name in restricted_display:
models[model_name] = provider_type
continue
for model_name in available:
    # =====================================================================================
    # CRITICAL: Prevent double restriction filtering (Fixed Issue #98)
@@ -227,6 +239,50 @@ class ModelProviderRegistry:
return models
@classmethod
def _collect_restricted_display_names(
cls,
provider: ModelProvider,
provider_type: ProviderType,
available: list[str],
restriction_service,
) -> list[str] | None:
"""Derive the human-facing model list when restrictions are active."""
allowed_models = restriction_service.get_allowed_models(provider_type)
if not allowed_models:
return None
allowed_details: list[tuple[str, float]] = []
for model_name in sorted(allowed_models):
try:
capabilities = provider.get_capabilities(model_name)
except (AttributeError, ValueError):
continue
try:
rank = capabilities.get_effective_capability_rank()
rank_value = float(rank)
except (AttributeError, TypeError, ValueError):
rank_value = 0.0
allowed_details.append((model_name, rank_value))
if allowed_details:
allowed_details.sort(key=lambda item: (-item[1], item[0]))
return [name for name, _ in allowed_details]
# Fallback: intersect the allowlist with the provider-advertised names.
available_lookup = {name.lower(): name for name in available}
display_names: list[str] = []
for model_name in sorted(allowed_models):
lowered = model_name.lower()
if lowered in available_lookup:
display_names.append(available_lookup[lowered])
return display_names
@classmethod
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
    """Get list of available model names, optionally filtered by provider.

View File

@@ -492,15 +492,25 @@ def configure_providers():
# Register providers in priority order:
# 1. Native APIs first (most direct and efficient)
registered_providers = []
if has_native_apis:
    if gemini_key and gemini_key != "your_gemini_api_key_here":
        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
        registered_providers.append(ProviderType.GOOGLE.value)
        logger.debug(f"Registered provider: {ProviderType.GOOGLE.value}")
    if openai_key and openai_key != "your_openai_api_key_here":
        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
        registered_providers.append(ProviderType.OPENAI.value)
        logger.debug(f"Registered provider: {ProviderType.OPENAI.value}")
    if xai_key and xai_key != "your_xai_api_key_here":
        ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
        registered_providers.append(ProviderType.XAI.value)
        logger.debug(f"Registered provider: {ProviderType.XAI.value}")
    if dial_key and dial_key != "your_dial_api_key_here":
        ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)
        registered_providers.append(ProviderType.DIAL.value)
        logger.debug(f"Registered provider: {ProviderType.DIAL.value}")

# 2. Custom provider second (for local/private models)
if has_custom:
@@ -511,10 +521,18 @@ def configure_providers():
        return CustomProvider(api_key=api_key or "", base_url=base_url)  # Use provided API key or empty string

    ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
    registered_providers.append(ProviderType.CUSTOM.value)
    logger.debug(f"Registered provider: {ProviderType.CUSTOM.value}")

# 3. OpenRouter last (catch-all for everything else)
if has_openrouter:
    ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
    registered_providers.append(ProviderType.OPENROUTER.value)
    logger.debug(f"Registered provider: {ProviderType.OPENROUTER.value}")

# Log all registered providers
if registered_providers:
    logger.info(f"Registered providers: {', '.join(registered_providers)}")
# Require at least one valid provider
if not valid_providers:

View File

@@ -63,27 +63,30 @@ class TestAliasTargetRestrictions:
assert provider.validate_model_name("o4mini") assert provider.validate_model_name("o4mini")
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
def test_restriction_policy_allows_only_alias_when_alias_specified(self): def test_restriction_policy_alias_allows_canonical(self):
"""Test that restriction policy allows only the alias when just alias is specified. """Alias-only allowlists should permit both the alias and its canonical target."""
If you restrict to 'mini' (which is an alias for gpt-5-mini),
only the alias should work, not other models.
This is the correct restrictive behavior.
"""
# Clear cached restriction service
import utils.model_restrictions import utils.model_restrictions
utils.model_restrictions._restriction_service = None utils.model_restrictions._restriction_service = None
provider = OpenAIModelProvider(api_key="test-key") provider = OpenAIModelProvider(api_key="test-key")
# Only the alias should be allowed
assert provider.validate_model_name("mini") assert provider.validate_model_name("mini")
# Direct target for this alias should NOT be allowed (mini -> gpt-5-mini) assert provider.validate_model_name("gpt-5-mini")
assert not provider.validate_model_name("gpt-5-mini")
# Other models should NOT be allowed
assert not provider.validate_model_name("o4-mini") assert not provider.validate_model_name("o4-mini")
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"})
def test_restriction_policy_alias_allows_short_name(self):
"""Common aliases like 'gpt5' should allow their canonical forms."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = OpenAIModelProvider(api_key="test-key")
assert provider.validate_model_name("gpt5")
assert provider.validate_model_name("gpt-5")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target
def test_gemini_restriction_policy_allows_alias_when_target_allowed(self): def test_gemini_restriction_policy_allows_alias_when_target_allowed(self):
"""Test Gemini restriction policy allows alias when target is allowed.""" """Test Gemini restriction policy allows alias when target is allowed."""
@@ -99,19 +102,16 @@ class TestAliasTargetRestrictions:
assert provider.validate_model_name("flash") assert provider.validate_model_name("flash")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only
def test_gemini_restriction_policy_allows_only_alias_when_alias_specified(self): def test_gemini_restriction_policy_alias_allows_canonical(self):
"""Test Gemini restriction policy allows only alias when just alias is specified.""" """Gemini alias allowlists should permit canonical forms."""
# Clear cached restriction service
import utils.model_restrictions import utils.model_restrictions
utils.model_restrictions._restriction_service = None utils.model_restrictions._restriction_service = None
provider = GeminiModelProvider(api_key="test-key") provider = GeminiModelProvider(api_key="test-key")
# Only the alias should be allowed
assert provider.validate_model_name("flash") assert provider.validate_model_name("flash")
# Direct target should NOT be allowed assert provider.validate_model_name("gemini-2.5-flash")
assert not provider.validate_model_name("gemini-2.5-flash")
def test_restriction_service_validation_includes_all_targets(self): def test_restriction_service_validation_includes_all_targets(self):
"""Test that restriction service validation knows about all aliases and targets.""" """Test that restriction service validation knows about all aliases and targets."""
@@ -153,6 +153,30 @@ class TestAliasTargetRestrictions:
assert provider.validate_model_name("o4-mini") # target assert provider.validate_model_name("o4-mini") # target
assert provider.validate_model_name("o4mini") # alias for o4-mini assert provider.validate_model_name("o4mini") # alias for o4-mini
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"}, clear=True)
def test_service_alias_allows_canonical_openai(self):
"""ModelRestrictionService should permit canonical names resolved from aliases."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = OpenAIModelProvider(api_key="test-key")
service = ModelRestrictionService()
assert service.is_allowed(ProviderType.OPENAI, "gpt-5")
assert provider.validate_model_name("gpt-5")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}, clear=True)
def test_service_alias_allows_canonical_gemini(self):
"""Gemini alias allowlists should permit canonical forms."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = GeminiModelProvider(api_key="test-key")
service = ModelRestrictionService()
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
assert provider.validate_model_name("gemini-2.5-flash")
def test_alias_target_policy_regression_prevention(self):
    """Regression test to ensure aliases and targets are both validated properly.

View File

@@ -106,19 +106,35 @@ class TestAutoMode:
def test_tool_schema_in_normal_mode(self):
    """Test that tool schemas don't require model in normal mode"""
    # Save original
    original = os.environ.get("DEFAULT_MODEL", "")

    try:
        # Set to a specific model (not auto mode)
        os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
        import config

        importlib.reload(config)

        tool = ChatTool()
        schema = tool.get_input_schema()

        # Model should not be required when default model is configured
        assert "model" not in schema["required"]

        # Model field should have simpler description
        model_schema = schema["properties"]["model"]
        assert "enum" not in model_schema
        assert "listmodels" in model_schema["description"]
        assert "default model" in model_schema["description"].lower()
    finally:
        # Restore
        if original:
            os.environ["DEFAULT_MODEL"] = original
        else:
            os.environ.pop("DEFAULT_MODEL", None)
        importlib.reload(config)
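The same isolation could plausibly be written with pytest's `monkeypatch` fixture instead of manual save and restore. A sketch only, not what this commit does; `ChatTool` and `config` are assumed to be importable exactly as in the surrounding test module:

```
import importlib

import config  # assumed: the project's config module, as used by the real tests


def test_tool_schema_in_normal_mode_alt(monkeypatch):
    # monkeypatch.setenv restores DEFAULT_MODEL automatically at teardown; the
    # module reload still has to be undone by hand, hence the finally block.
    monkeypatch.setenv("DEFAULT_MODEL", "gemini-2.5-flash")
    importlib.reload(config)
    try:
        tool = ChatTool()  # ChatTool is imported at module scope in the real test file
        schema = tool.get_input_schema()
        assert "model" not in schema["required"]
    finally:
        monkeypatch.undo()  # put DEFAULT_MODEL back before re-reading config
        importlib.reload(config)
```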
@pytest.mark.asyncio
async def test_auto_mode_requires_model_parameter(self):

View File

@@ -0,0 +1,203 @@
"""Tests covering model restriction-aware error messaging in auto mode."""
import asyncio
import importlib
import json
import pytest
import utils.model_restrictions as model_restrictions
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.openrouter import OpenRouterProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from providers.xai import XAIModelProvider
def _extract_available_models(message: str) -> list[str]:
"""Parse the available model list from the error message."""
marker = "Available models: "
if marker not in message:
raise AssertionError(f"Expected '{marker}' in message: {message}")
start = message.index(marker) + len(marker)
end = message.find(". Suggested", start)
if end == -1:
end = len(message)
available_segment = message[start:end].strip()
if not available_segment:
return []
return [item.strip() for item in available_segment.split(",")]
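For instance, given an error payload shaped like the ones asserted below, the helper returns the comma-separated names between the marker and the suggestion sentence. The message text here is made up purely for illustration:

```
message = (
    "Model 'gpt5mini' is not available with current API keys. "
    "Available models: gemini-2.5-pro, gpt-5. Suggested next step: pick one of these."
)
assert _extract_available_models(message) == ["gemini-2.5-pro", "gpt-5"]
```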
@pytest.fixture
def reset_registry():
"""Ensure registry and restriction service state is isolated."""
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
yield
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
def _register_core_providers(*, include_xai: bool = False):
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
if include_xai:
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
@pytest.mark.no_mock_provider
def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
"""Error payload should surface only the allowed models for each provider."""
monkeypatch.setenv("DEFAULT_MODEL", "auto")
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
monkeypatch.delenv("XAI_API_KEY", raising=False)
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
try:
import dotenv
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
except ModuleNotFoundError:
pass
monkeypatch.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro")
monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")
monkeypatch.setenv("OPENROUTER_ALLOWED_MODELS", "gpt5nano")
monkeypatch.setenv("XAI_ALLOWED_MODELS", "")
import config
importlib.reload(config)
_register_core_providers()
import server
importlib.reload(server)
# Reload may have re-applied .env overrides; enforce our test configuration
for key, value in (
("DEFAULT_MODEL", "auto"),
("GEMINI_API_KEY", "test-gemini"),
("OPENAI_API_KEY", "test-openai"),
("OPENROUTER_API_KEY", "test-openrouter"),
("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro"),
("OPENAI_ALLOWED_MODELS", "gpt-5"),
("OPENROUTER_ALLOWED_MODELS", "gpt5nano"),
("XAI_ALLOWED_MODELS", ""),
):
monkeypatch.setenv(key, value)
for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
monkeypatch.delenv(var, raising=False)
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
server.configure_providers()
result = asyncio.run(
server.handle_call_tool(
"chat",
{
"model": "gpt5mini",
"prompt": "Tell me about your strengths",
},
)
)
assert len(result) == 1
payload = json.loads(result[0].text)
assert payload["status"] == "error"
available_models = _extract_available_models(payload["content"])
assert set(available_models) == {"gemini-2.5-pro", "gpt-5", "gpt5nano", "openai/gpt-5-nano"}
@pytest.mark.no_mock_provider
def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, reset_registry):
"""When no restrictions are set, the full high-capability catalogue should appear."""
monkeypatch.setenv("DEFAULT_MODEL", "auto")
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
monkeypatch.setenv("XAI_API_KEY", "test-xai")
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
try:
import dotenv
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
except ModuleNotFoundError:
pass
for var in (
"GOOGLE_ALLOWED_MODELS",
"OPENAI_ALLOWED_MODELS",
"OPENROUTER_ALLOWED_MODELS",
"XAI_ALLOWED_MODELS",
"DIAL_ALLOWED_MODELS",
):
monkeypatch.delenv(var, raising=False)
import config
importlib.reload(config)
_register_core_providers(include_xai=True)
import server
importlib.reload(server)
for key, value in (
("DEFAULT_MODEL", "auto"),
("GEMINI_API_KEY", "test-gemini"),
("OPENAI_API_KEY", "test-openai"),
("OPENROUTER_API_KEY", "test-openrouter"),
):
monkeypatch.setenv(key, value)
for var in (
"GOOGLE_ALLOWED_MODELS",
"OPENAI_ALLOWED_MODELS",
"OPENROUTER_ALLOWED_MODELS",
"XAI_ALLOWED_MODELS",
"DIAL_ALLOWED_MODELS",
"CUSTOM_API_URL",
"CUSTOM_API_KEY",
):
monkeypatch.delenv(var, raising=False)
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
server.configure_providers()
result = asyncio.run(
server.handle_call_tool(
"chat",
{
"model": "dummymodel",
"prompt": "Hi there",
},
)
)
assert len(result) == 1
payload = json.loads(result[0].text)
assert payload["status"] == "error"
available_models = _extract_available_models(payload["content"])
assert "gemini-2.5-pro" in available_models
assert "gpt-5" in available_models
assert "grok-4" in available_models
assert len(available_models) >= 5

View File

@@ -3,6 +3,7 @@ Tests for dynamic context request and collaboration features
""" """
import json import json
import os
from unittest.mock import Mock, patch

import pytest
@@ -157,95 +158,120 @@ class TestDynamicContextRequests:
@patch("tools.shared.base_tool.BaseTool.get_model_provider") @patch("tools.shared.base_tool.BaseTool.get_model_provider")
async def test_clarification_with_suggested_action(self, mock_get_provider, analyze_tool): async def test_clarification_with_suggested_action(self, mock_get_provider, analyze_tool):
"""Test clarification request with suggested next action""" """Test clarification request with suggested next action"""
    import importlib

    from providers.registry import ModelProviderRegistry

    # Ensure deterministic model configuration for this test regardless of previous suites
    ModelProviderRegistry.reset_for_testing()

    original_default = os.environ.get("DEFAULT_MODEL")

    try:
        os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
        import config

        importlib.reload(config)

        clarification_json = json.dumps(
            {
                "status": "files_required_to_continue",
                "mandatory_instructions": "I need to see the database configuration to analyze the connection error",
                "files_needed": ["config/database.yml", "src/db.py"],
                "suggested_next_action": {
                    "tool": "analyze",
                    "args": {
                        "prompt": "Analyze database connection timeout issue",
                        "relevant_files": [
                            "/config/database.yml",
                            "/src/db.py",
                            "/logs/error.log",
                        ],
                    },
                },
            },
            ensure_ascii=False,
        )

        mock_provider = create_mock_provider()
        mock_provider.get_provider_type.return_value = Mock(value="google")
        mock_provider.generate_content.return_value = Mock(
            content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
        )
        mock_get_provider.return_value = mock_provider

        result = await analyze_tool.execute(
            {
                "step": "Analyze database connection timeout issue",
                "step_number": 1,
                "total_steps": 1,
                "next_step_required": False,
                "findings": "Initial database timeout analysis",
                "relevant_files": ["/absolute/logs/error.log"],
            }
        )

        assert len(result) == 1

        response_data = json.loads(result[0].text)
        # Workflow tools should either promote clarification status or handle it in expert analysis
        if response_data["status"] == "files_required_to_continue":
            # Clarification was properly promoted to main status
            # Check if mandatory_instructions is at top level or in content
            if "mandatory_instructions" in response_data:
                assert "database configuration" in response_data["mandatory_instructions"]
                assert "files_needed" in response_data
                assert "config/database.yml" in response_data["files_needed"]
                assert "src/db.py" in response_data["files_needed"]
            elif "content" in response_data:
                # Parse content JSON for workflow tools
                try:
                    content_json = json.loads(response_data["content"])
                    assert "mandatory_instructions" in content_json
                    assert (
                        "database configuration" in content_json["mandatory_instructions"]
                        or "database" in content_json["mandatory_instructions"]
                    )
                    assert "files_needed" in content_json
                    files_needed_str = str(content_json["files_needed"])
                    assert (
                        "config/database.yml" in files_needed_str
                        or "config" in files_needed_str
                        or "database" in files_needed_str
                    )
                except json.JSONDecodeError:
                    # Content is not JSON, check if it contains required text
                    content = response_data["content"]
                    assert "database configuration" in content or "config" in content
        elif response_data["status"] == "calling_expert_analysis":
            # Clarification may be handled in expert analysis section
            if "expert_analysis" in response_data:
                expert_analysis = response_data["expert_analysis"]
                expert_content = str(expert_analysis)
                assert (
                    "database configuration" in expert_content
                    or "config/database.yml" in expert_content
                    or "files_required_to_continue" in expert_content
                )
        else:
            # Some other status - ensure it's a valid workflow response
            assert "step_number" in response_data

        # Check for suggested next action
        if "suggested_next_action" in response_data:
            action = response_data["suggested_next_action"]
            assert action["tool"] == "analyze"
    finally:
        if original_default is not None:
            os.environ["DEFAULT_MODEL"] = original_default
        else:
            os.environ.pop("DEFAULT_MODEL", None)
        import config

        importlib.reload(config)
        ModelProviderRegistry.reset_for_testing()
def test_tool_output_model_serialization(self):
    """Test ToolOutput model serialization"""

View File

@@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
from providers.base import ModelProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ModelCapabilities, ProviderType
from tools.listmodels import ListModelsTool
@@ -23,10 +23,63 @@ class TestListModelsRestrictions(unittest.TestCase):
self.mock_openrouter = MagicMock(spec=ModelProvider)
self.mock_openrouter.provider_type = ProviderType.OPENROUTER
def make_capabilities(
canonical: str, friendly: str, *, aliases=None, context: int = 200_000
) -> ModelCapabilities:
return ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=canonical,
friendly_name=friendly,
intelligence_score=20,
description=friendly,
aliases=aliases or [],
context_window=context,
max_output_tokens=context,
supports_extended_thinking=True,
)
opus_caps = make_capabilities(
"anthropic/claude-opus-4-20240229",
"Claude Opus",
aliases=["opus"],
)
sonnet_caps = make_capabilities(
"anthropic/claude-sonnet-4-20240229",
"Claude Sonnet",
aliases=["sonnet"],
)
deepseek_caps = make_capabilities(
"deepseek/deepseek-r1-0528:free",
"DeepSeek R1",
aliases=[],
)
qwen_caps = make_capabilities(
"qwen/qwen3-235b-a22b-04-28:free",
"Qwen3",
aliases=[],
)
self._openrouter_caps_map = {
"anthropic/claude-opus-4": opus_caps,
"opus": opus_caps,
"anthropic/claude-opus-4-20240229": opus_caps,
"anthropic/claude-sonnet-4": sonnet_caps,
"sonnet": sonnet_caps,
"anthropic/claude-sonnet-4-20240229": sonnet_caps,
"deepseek/deepseek-r1-0528:free": deepseek_caps,
"qwen/qwen3-235b-a22b-04-28:free": qwen_caps,
}
self.mock_openrouter.get_capabilities.side_effect = self._openrouter_caps_map.__getitem__
self.mock_openrouter.get_capabilities_by_rank.return_value = []
self.mock_openrouter.list_models.return_value = []
# Create mock Gemini provider for comparison
self.mock_gemini = MagicMock(spec=ModelProvider)
self.mock_gemini.provider_type = ProviderType.GOOGLE
self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]
self.mock_gemini.get_capabilities_by_rank.return_value = []

def tearDown(self):
    """Clean up after tests."""
@@ -159,7 +212,7 @@ class TestListModelsRestrictions(unittest.TestCase):
for line in lines:
    if "OpenRouter" in line and "" in line:
        openrouter_section_found = True
    elif ("Models (policy restricted)" in line or "Available Models" in line) and openrouter_section_found:
        in_openrouter_section = True
    elif in_openrouter_section:
        # Check for lines with model names in backticks
@@ -179,11 +232,11 @@ class TestListModelsRestrictions(unittest.TestCase):
    len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
)

# Verify we did not fall back to unrestricted listing
self.mock_openrouter.list_models.assert_not_called()

# Check for restriction note
self.assertIn("OpenRouter models restricted by", result)

@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True)
@patch("providers.openrouter_registry.OpenRouterModelRegistry")

View File

@@ -121,38 +121,59 @@ class TestModelMetadataContinuation:
@pytest.mark.asyncio
async def test_no_previous_assistant_turn_defaults(self):
    """Test behavior when there's no previous assistant turn."""
    # Save and set DEFAULT_MODEL for test
    import importlib
    import os

    original_default = os.environ.get("DEFAULT_MODEL", "")
    os.environ["DEFAULT_MODEL"] = "auto"
    import config
    import utils.model_context

    importlib.reload(config)
    importlib.reload(utils.model_context)

    try:
        thread_id = create_thread("chat", {"prompt": "test"})

        # Only add user turns
        add_turn(thread_id, "user", "First question")
        add_turn(thread_id, "user", "Second question")

        arguments = {"continuation_id": thread_id}

        # Mock dependencies
        with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
            mock_calc.return_value = MagicMock(
                total_tokens=200000,
                content_tokens=160000,
                response_tokens=40000,
                file_tokens=64000,
                history_tokens=64000,
            )

            with patch("utils.conversation_memory.build_conversation_history") as mock_build:
                mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)

                # Call the actual function
                enhanced_args = await reconstruct_thread_context(arguments)

                # Should not have set a model
                assert enhanced_args.get("model") is None

                # ModelContext should use DEFAULT_MODEL
                model_context = ModelContext.from_arguments(enhanced_args)
                from config import DEFAULT_MODEL

                assert model_context.model_name == DEFAULT_MODEL
    finally:
        # Restore original value
        if original_default:
            os.environ["DEFAULT_MODEL"] = original_default
        else:
            os.environ.pop("DEFAULT_MODEL", None)
        importlib.reload(config)
        importlib.reload(utils.model_context)
@pytest.mark.asyncio
async def test_explicit_model_overrides_previous_turn(self):

View File

@@ -49,17 +49,32 @@ class TestModelRestrictionService:
def test_load_multiple_models_restriction(self):
    """Test loading multiple allowed models."""
    with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
        # Instantiate providers so alias resolution for allow-lists is available
        openai_provider = OpenAIModelProvider(api_key="test-key")
        gemini_provider = GeminiModelProvider(api_key="test-key")

        from providers.registry import ModelProviderRegistry

        def fake_get_provider(provider_type, force_new=False):
            mapping = {
                ProviderType.OPENAI: openai_provider,
                ProviderType.GOOGLE: gemini_provider,
            }
            return mapping.get(provider_type)

        with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):
            service = ModelRestrictionService()

            # Check OpenAI models
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
            assert not service.is_allowed(ProviderType.OPENAI, "o3")

            # Check Google models
            assert service.is_allowed(ProviderType.GOOGLE, "flash")
            assert service.is_allowed(ProviderType.GOOGLE, "pro")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
def test_case_insensitive_and_whitespace_handling(self):
    """Test that model names are case-insensitive and whitespace is trimmed."""
@@ -111,13 +126,17 @@ class TestModelRestrictionService:
def test_shorthand_names_in_restrictions(self):
    """Test that shorthand names work in restrictions."""
    with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
        # Instantiate providers so the registry can resolve aliases
        OpenAIModelProvider(api_key="test-key")
        GeminiModelProvider(api_key="test-key")

        service = ModelRestrictionService()

        # When providers check models, they pass both resolved and original names
        # OpenAI: 'o4mini' shorthand allows o4-mini
        assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini")  # How providers actually call it
        assert service.is_allowed(ProviderType.OPENAI, "o4-mini")  # Canonical should also be allowed

        # OpenAI: o3-mini allowed directly
        assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
@@ -280,19 +299,25 @@ class TestProviderIntegration:
    provider = GeminiModelProvider(api_key="test-key")

    from providers.registry import ModelProviderRegistry

    with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):
        # Test case: Only alias "flash" is allowed, not the full name
        # If parameters are in wrong order, this test will catch it

        # Should allow "flash" alias
        assert provider.validate_model_name("flash")

        # Should allow getting capabilities for "flash"
        capabilities = provider.get_capabilities("flash")
        assert capabilities.model_name == "gemini-2.5-flash"

        # Canonical form should also be allowed now that alias is on the allowlist
        assert provider.validate_model_name("gemini-2.5-flash")

        # Unrelated models remain blocked
        assert not provider.validate_model_name("pro")
        assert not provider.validate_model_name("gemini-2.5-pro")

@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
def test_gemini_parameter_order_edge_case_full_name_only(self):
@@ -570,17 +595,27 @@ class TestShorthandRestrictions:
    # Test OpenAI provider
    openai_provider = OpenAIModelProvider(api_key="test-key")

    gemini_provider = GeminiModelProvider(api_key="test-key")

    from providers.registry import ModelProviderRegistry

    def registry_side_effect(provider_type, force_new=False):
        mapping = {
            ProviderType.OPENAI: openai_provider,
            ProviderType.GOOGLE: gemini_provider,
        }
        return mapping.get(provider_type)

    with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
        assert openai_provider.validate_model_name("mini")  # Should work with shorthand
        assert openai_provider.validate_model_name("gpt-5-mini")  # Canonical resolved from shorthand
        assert not openai_provider.validate_model_name("o4-mini")  # Unrelated model still blocked
        assert not openai_provider.validate_model_name("o3-mini")

        # Test Gemini provider
        assert gemini_provider.validate_model_name("flash")  # Should work with shorthand
        assert gemini_provider.validate_model_name("gemini-2.5-flash")  # Canonical allowed
        assert not gemini_provider.validate_model_name("pro")  # Not allowed

@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
def test_multiple_shorthands_for_same_model(self):
@@ -596,9 +631,9 @@ class TestShorthandRestrictions:
assert openai_provider.validate_model_name("mini") # mini -> o4-mini assert openai_provider.validate_model_name("mini") # mini -> o4-mini
assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini
# Resolved names work only if explicitly allowed # Resolved names should be allowed when their shorthands are present
assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed
assert not openai_provider.validate_model_name("o3-mini") # Not explicitly allowed, only shorthand assert openai_provider.validate_model_name("o3-mini") # Allowed via shorthand
# Other models should not work # Other models should not work
assert not openai_provider.validate_model_name("o3") assert not openai_provider.validate_model_name("o3")

View File

@@ -260,9 +260,10 @@ class TestOpenRouterAutoMode:
os.environ["DEFAULT_MODEL"] = "auto" os.environ["DEFAULT_MODEL"] = "auto"
mock_provider_class = Mock() mock_provider_class = Mock()
mock_provider_instance = Mock(spec=["get_provider_type", "list_models"]) mock_provider_instance = Mock(spec=["get_provider_type", "list_models", "get_all_model_capabilities"])
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
mock_provider_instance.list_models.return_value = [] mock_provider_instance.list_models.return_value = []
mock_provider_instance.get_all_model_capabilities.return_value = {}
mock_provider_class.return_value = mock_provider_instance

ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)

View File

@@ -293,13 +293,7 @@ class TestOpenRouterAliasRestrictions:
# o3 -> openai/o3
# gpt4.1 -> should not exist (expected to be filtered out)
expected_models = {"o3-mini", "pro", "flash", "o4-mini", "o3"}

available_model_names = set(available_models.keys())
@@ -355,9 +349,11 @@ class TestOpenRouterAliasRestrictions:
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)

expected_models = {
    "o3-mini",  # alias
    "openai/o3-mini",  # canonical
    "anthropic/claude-opus-4.1",  # full name
    "flash",  # alias
    "google/gemini-2.5-flash",  # canonical
}

available_model_names = set(available_models.keys())

View File

@@ -83,9 +83,18 @@ class ListModelsTool(BaseTool):
from providers.openrouter_registry import OpenRouterModelRegistry
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from utils.model_restrictions import get_restriction_service
output_lines = ["# Available AI Models\n"] output_lines = ["# Available AI Models\n"]
restriction_service = get_restriction_service()
restricted_models_by_provider: dict[ProviderType, list[str]] = {}
if restriction_service:
restricted_map = ModelProviderRegistry.get_available_models(respect_restrictions=True)
for model_name, provider_type in restricted_map.items():
restricted_models_by_provider.setdefault(provider_type, []).append(model_name)
# Map provider types to friendly names and their models
provider_info = {
    ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
@@ -94,6 +103,43 @@ class ListModelsTool(BaseTool):
    ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
}
def format_model_entry(provider, display_name: str) -> list[str]:
try:
capabilities = provider.get_capabilities(display_name)
except ValueError:
return [f"- `{display_name}` *(not recognized by provider)*"]
canonical = capabilities.model_name
if canonical.lower() == display_name.lower():
header = f"- `{canonical}`"
else:
header = f"- `{display_name}` → `{canonical}`"
try:
context_value = capabilities.context_window or 0
except AttributeError:
context_value = 0
try:
context_value = int(context_value)
except (TypeError, ValueError):
context_value = 0
if context_value >= 1_000_000:
context_str = f"{context_value // 1_000_000}M context"
elif context_value >= 1_000:
context_str = f"{context_value // 1_000}K context"
elif context_value > 0:
context_str = f"{context_value} context"
else:
context_str = "unknown context"
try:
description = capabilities.description or "No description available"
except AttributeError:
description = "No description available"
lines = [header, f" - {context_str}", f" - {description}"]
return lines
# Check each native provider type
for provider_type, info in provider_info.items():
    # Check if provider is enabled
@@ -104,30 +150,49 @@ class ListModelsTool(BaseTool):
if is_configured:
    output_lines.append("**Status**: Configured and available")

    has_restrictions = bool(restriction_service and restriction_service.has_restrictions(provider_type))

    if has_restrictions:
        restricted_names = sorted(set(restricted_models_by_provider.get(provider_type, [])))

        if restricted_names:
            output_lines.append("\n**Models (policy restricted)**:")
            for model_name in restricted_names:
                output_lines.extend(format_model_entry(provider, model_name))
        else:
            output_lines.append("\n*No models are currently allowed by restriction policy.*")
    else:
        output_lines.append("\n**Models**:")

        aliases = []
        for model_name, capabilities in provider.get_capabilities_by_rank():
            try:
                description = capabilities.description or "No description available"
            except AttributeError:
                description = "No description available"

            try:
                context_window = capabilities.context_window or 0
            except AttributeError:
                context_window = 0

            if context_window >= 1_000_000:
                context_str = f"{context_window // 1_000_000}M context"
            elif context_window >= 1_000:
                context_str = f"{context_window // 1_000}K context"
            else:
                context_str = f"{context_window} context" if context_window > 0 else "unknown context"

            output_lines.append(f"- `{model_name}` - {context_str}")
            output_lines.append(f"  - {description}")

            for alias in capabilities.aliases or []:
                if alias != model_name:
                    aliases.append(f"- `{alias}` → `{model_name}`")

        if aliases:
            output_lines.append("\n**Aliases**:")
            output_lines.extend(sorted(aliases))
else:
    output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
@@ -144,19 +209,10 @@ class ListModelsTool(BaseTool):
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API") output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
try: try:
# Get OpenRouter provider from registry to properly apply restrictions
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER) provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
if provider: if provider:
# Get models with restrictions applied
available_models = provider.list_models(respect_restrictions=True)
registry = OpenRouterModelRegistry() registry = OpenRouterModelRegistry()
# Group by provider and retain ranking information for consistent ordering
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
def _format_context(tokens: int) -> str: def _format_context(tokens: int) -> str:
if not tokens: if not tokens:
return "?" return "?"
@@ -166,53 +222,83 @@ class ListModelsTool(BaseTool):
return f"{tokens // 1_000}K" return f"{tokens // 1_000}K"
return str(tokens) return str(tokens)
for model_name in available_models: has_restrictions = bool(
config = registry.resolve(model_name) restriction_service and restriction_service.has_restrictions(ProviderType.OPENROUTER)
provider_name = "other" )
if config and "/" in config.model_name:
provider_name = config.model_name.split("/")[0]
elif "/" in model_name:
provider_name = model_name.split("/")[0]
providers_models.setdefault(provider_name, []) if has_restrictions:
restricted_names = sorted(set(restricted_models_by_provider.get(ProviderType.OPENROUTER, [])))
rank = config.get_effective_capability_rank() if config else 0 output_lines.append("\n**Models (policy restricted)**:")
providers_models[provider_name].append((rank, model_name, config)) if restricted_names:
for model_name in restricted_names:
try:
caps = provider.get_capabilities(model_name)
except ValueError:
output_lines.append(f"- `{model_name}` *(not recognized by provider)*")
continue
output_lines.append("\n**Available Models**:") context_value = int(caps.context_window or 0)
for provider_name, models in sorted(providers_models.items()): context_str = _format_context(context_value)
output_lines.append(f"\n*{provider_name.title()}:*")
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
if config:
context_str = _format_context(config.context_window)
suffix_parts = [f"{context_str} context"] suffix_parts = [f"{context_str} context"]
if getattr(config, "supports_extended_thinking", False): if caps.supports_extended_thinking:
suffix_parts.append("thinking") suffix_parts.append("thinking")
suffix = ", ".join(suffix_parts) suffix = ", ".join(suffix_parts)
output_lines.append(f"- `{alias}` → `{config.model_name}` (score {rank}, {suffix})")
else:
output_lines.append(f"- `{alias}` (score {rank})")
total_models = len(available_models) arrow = ""
# Show all models - no truncation message needed if caps.model_name.lower() != model_name.lower():
arrow = f" → `{caps.model_name}`"
# Check if restrictions are applied score = caps.get_effective_capability_rank()
restriction_service = None output_lines.append(f"- `{model_name}`{arrow} (score {score}, {suffix})")
try:
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) or set()
if restriction_service.has_restrictions(ProviderType.OPENROUTER): if allowed_set:
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) output_lines.append(
output_lines.append( f"\n*OpenRouter models restricted by OPENROUTER_ALLOWED_MODELS: {', '.join(sorted(allowed_set))}*"
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}" )
) else:
except Exception as e: output_lines.append("- *No models allowed by current restriction policy.*")
logger.warning(f"Error checking OpenRouter restrictions: {e}") else:
available_models = provider.list_models(respect_restrictions=True)
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
for model_name in available_models:
config = registry.resolve(model_name)
provider_name = "other"
if config and "/" in config.model_name:
provider_name = config.model_name.split("/")[0]
elif "/" in model_name:
provider_name = model_name.split("/")[0]
providers_models.setdefault(provider_name, [])
rank = config.get_effective_capability_rank() if config else 0
providers_models[provider_name].append((rank, model_name, config))
output_lines.append("\n**Available Models**:")
for provider_name, models in sorted(providers_models.items()):
output_lines.append(f"\n*{provider_name.title()}:*")
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
if config:
context_str = _format_context(getattr(config, "context_window", 0))
suffix_parts = [f"{context_str} context"]
if getattr(config, "supports_extended_thinking", False):
suffix_parts.append("thinking")
suffix = ", ".join(suffix_parts)
arrow = ""
if config.model_name.lower() != alias.lower():
arrow = f" → `{config.model_name}`"
output_lines.append(f"- `{alias}`{arrow} (score {rank}, {suffix})")
else:
output_lines.append(f"- `{alias}` (score {rank})")
else: else:
output_lines.append("**Error**: Could not load OpenRouter provider") output_lines.append("**Error**: Could not load OpenRouter provider")
except Exception as e: except Exception as e:
logger.exception("Error listing OpenRouter models: %s", e)
output_lines.append(f"**Error loading models**: {str(e)}") output_lines.append(f"**Error loading models**: {str(e)}")
else: else:
output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)") output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")

View File

@@ -22,6 +22,7 @@ Example:
import logging
import os
from collections import defaultdict
from typing import Optional

from providers.shared import ProviderType
@@ -58,6 +59,7 @@ class ModelRestrictionService:
    def __init__(self):
        """Initialize the restriction service by loading from environment."""
        self.restrictions: dict[ProviderType, set[str]] = {}
        self._alias_resolution_cache: dict[ProviderType, dict[str, str]] = defaultdict(dict)
        self._load_from_env()

    def _load_from_env(self) -> None:
@@ -79,6 +81,7 @@ class ModelRestrictionService:
            if models:
                self.restrictions[provider_type] = models
                self._alias_resolution_cache[provider_type] = {}
                logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
            else:
                # All entries were empty after cleaning - treat as no restrictions
@@ -150,7 +153,41 @@ class ModelRestrictionService:
                names_to_check.add(original_name.lower())

        # If any of the names is in the allowed set, it's allowed
        if any(name in allowed_set for name in names_to_check):
            return True

        # Attempt to resolve canonical names for allowed aliases using provider metadata.
        try:
            from providers.registry import ModelProviderRegistry

            provider = ModelProviderRegistry.get_provider(provider_type)
        except Exception:  # pragma: no cover - registry lookup failure shouldn't break validation
            provider = None

        if provider:
            cache = self._alias_resolution_cache.setdefault(provider_type, {})
            for allowed_entry in list(allowed_set):
                normalized_resolved = cache.get(allowed_entry)
                if not normalized_resolved:
                    try:
                        resolved = provider._resolve_model_name(allowed_entry)
                    except Exception:  # pragma: no cover - resolution failures are treated as non-matches
                        continue
                    if not resolved:
                        continue
                    normalized_resolved = resolved.lower()
                    cache[allowed_entry] = normalized_resolved
                if normalized_resolved in names_to_check:
                    allowed_set.add(normalized_resolved)
                    cache[normalized_resolved] = normalized_resolved
                    return True

        return False

    def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
        """