fix: listmodels to always honor restricted models
fix: restrictions should resolve canonical names for openrouter fix: tools now correctly return restricted list by presenting model names in schema fix: tests updated to ensure these manage their expected env vars properly perf: cache model alias resolution to avoid repeated checks
This commit is contained in:
11
AGENTS.md
11
AGENTS.md
@@ -1,5 +1,9 @@
|
||||
# Repository Guidelines
|
||||
|
||||
See `requirements.txt` and `requirements-dev.txt`
|
||||
|
||||
Also read CLAUDE.md and CLAUDE.local.md if available.
|
||||
|
||||
## Project Structure & Module Organization
|
||||
Zen MCP Server centers on `server.py`, which exposes MCP entrypoints and coordinates multi-model workflows.
|
||||
Feature-specific tools live in `tools/`, provider integrations in `providers/`, and shared helpers in `utils/`.
|
||||
@@ -14,6 +18,13 @@ Authoritative documentation and samples live in `docs/`, and runtime diagnostics
|
||||
- `python communication_simulator_test.py --quick` – smoke-test orchestration across tools and providers.
|
||||
- `./run_integration_tests.sh [--with-simulator]` – exercise provider-dependent flows against remote or Ollama models.
|
||||
|
||||
For example, this is how we run an individual / all tests:
|
||||
|
||||
```
|
||||
.zen_venv/bin/activate && pytest tests/test_auto_mode_model_listing.py -q
|
||||
.zen_venv/bin/activate && pytest -q
|
||||
```
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
Target Python 3.9+ with Black and isort using a 120-character line limit; Ruff enforces pycodestyle, pyflakes, bugbear, comprehension, and pyupgrade rules. Prefer explicit type hints, snake_case modules, and imperative commit-time docstrings. Extend workflows by defining hook or abstract methods instead of checking `hasattr()`/`getattr()`—inheritance-backed contracts keep behavior discoverable and testable.
|
||||
|
||||
|
||||
@@ -73,6 +73,8 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
|
||||
logging.info(f"Initializing Custom provider with endpoint: {base_url}")
|
||||
|
||||
self._alias_cache: dict[str, str] = {}
|
||||
|
||||
super().__init__(api_key, base_url=base_url, **kwargs)
|
||||
|
||||
# Initialize model registry (shared with OpenRouter for consistent aliases)
|
||||
@@ -120,11 +122,18 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve registry aliases and strip version tags for local models."""
|
||||
|
||||
cache_key = model_name.lower()
|
||||
if cache_key in self._alias_cache:
|
||||
return self._alias_cache[cache_key]
|
||||
|
||||
config = self._registry.resolve(model_name)
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
|
||||
resolved = config.model_name
|
||||
self._alias_cache[cache_key] = resolved
|
||||
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||
return resolved
|
||||
|
||||
if ":" in model_name:
|
||||
base_model = model_name.split(":")[0]
|
||||
@@ -132,11 +141,16 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
|
||||
base_config = self._registry.resolve(base_model)
|
||||
if base_config:
|
||||
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
||||
return base_config.model_name
|
||||
logging.debug("Resolved base model '%s' to '%s'", base_model, base_config.model_name)
|
||||
resolved = base_config.model_name
|
||||
self._alias_cache[cache_key] = resolved
|
||||
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||
return resolved
|
||||
self._alias_cache[cache_key] = base_model
|
||||
return base_model
|
||||
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
self._alias_cache[cache_key] = model_name
|
||||
return model_name
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
|
||||
@@ -39,6 +39,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
base_url: Base URL for the API endpoint
|
||||
**kwargs: Additional configuration options including timeout
|
||||
"""
|
||||
self._allowed_alias_cache: dict[str, str] = {}
|
||||
super().__init__(api_key, **kwargs)
|
||||
self._client = None
|
||||
self.base_url = base_url
|
||||
@@ -74,9 +75,33 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
canonical = canonical_name.lower()
|
||||
|
||||
if requested not in self.allowed_models and canonical not in self.allowed_models:
|
||||
raise ValueError(
|
||||
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
||||
)
|
||||
allowed = False
|
||||
for allowed_entry in list(self.allowed_models):
|
||||
normalized_resolved = self._allowed_alias_cache.get(allowed_entry)
|
||||
if normalized_resolved is None:
|
||||
try:
|
||||
resolved_name = self._resolve_model_name(allowed_entry)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not resolved_name:
|
||||
continue
|
||||
|
||||
normalized_resolved = resolved_name.lower()
|
||||
self._allowed_alias_cache[allowed_entry] = normalized_resolved
|
||||
|
||||
if normalized_resolved == canonical:
|
||||
# Canonical match discovered via alias resolution – mark as allowed and
|
||||
# memoise the canonical entry for future lookups.
|
||||
allowed = True
|
||||
self._allowed_alias_cache[canonical] = canonical
|
||||
self.allowed_models.add(canonical)
|
||||
break
|
||||
|
||||
if not allowed:
|
||||
raise ValueError(
|
||||
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
||||
)
|
||||
|
||||
def _parse_allowed_models(self) -> Optional[set[str]]:
|
||||
"""Parse allowed models from environment variable.
|
||||
@@ -94,6 +119,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
|
||||
if models:
|
||||
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
|
||||
self._allowed_alias_cache = {}
|
||||
return models
|
||||
|
||||
# Log info if no allow-list configured for proxy providers
|
||||
|
||||
@@ -50,6 +50,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
self._alias_cache: dict[str, str] = {}
|
||||
super().__init__(api_key, base_url=base_url, **kwargs)
|
||||
|
||||
# Initialize model registry
|
||||
@@ -178,13 +179,21 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve aliases defined in the OpenRouter registry."""
|
||||
|
||||
cache_key = model_name.lower()
|
||||
if cache_key in self._alias_cache:
|
||||
return self._alias_cache[cache_key]
|
||||
|
||||
config = self._registry.resolve(model_name)
|
||||
if config:
|
||||
if config.model_name != model_name:
|
||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||
return config.model_name
|
||||
logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
|
||||
resolved = config.model_name
|
||||
self._alias_cache[cache_key] = resolved
|
||||
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||
return resolved
|
||||
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
self._alias_cache[cache_key] = model_name
|
||||
return model_name
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
|
||||
@@ -205,6 +205,18 @@ class ModelProviderRegistry:
|
||||
logging.warning("Provider %s does not implement list_models", provider_type)
|
||||
continue
|
||||
|
||||
if restriction_service and restriction_service.has_restrictions(provider_type):
|
||||
restricted_display = cls._collect_restricted_display_names(
|
||||
provider,
|
||||
provider_type,
|
||||
available,
|
||||
restriction_service,
|
||||
)
|
||||
if restricted_display:
|
||||
for model_name in restricted_display:
|
||||
models[model_name] = provider_type
|
||||
continue
|
||||
|
||||
for model_name in available:
|
||||
# =====================================================================================
|
||||
# CRITICAL: Prevent double restriction filtering (Fixed Issue #98)
|
||||
@@ -227,6 +239,50 @@ class ModelProviderRegistry:
|
||||
|
||||
return models
|
||||
|
||||
@classmethod
|
||||
def _collect_restricted_display_names(
|
||||
cls,
|
||||
provider: ModelProvider,
|
||||
provider_type: ProviderType,
|
||||
available: list[str],
|
||||
restriction_service,
|
||||
) -> list[str] | None:
|
||||
"""Derive the human-facing model list when restrictions are active."""
|
||||
|
||||
allowed_models = restriction_service.get_allowed_models(provider_type)
|
||||
if not allowed_models:
|
||||
return None
|
||||
|
||||
allowed_details: list[tuple[str, int]] = []
|
||||
|
||||
for model_name in sorted(allowed_models):
|
||||
try:
|
||||
capabilities = provider.get_capabilities(model_name)
|
||||
except (AttributeError, ValueError):
|
||||
continue
|
||||
|
||||
try:
|
||||
rank = capabilities.get_effective_capability_rank()
|
||||
rank_value = float(rank)
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
rank_value = 0.0
|
||||
|
||||
allowed_details.append((model_name, rank_value))
|
||||
|
||||
if allowed_details:
|
||||
allowed_details.sort(key=lambda item: (-item[1], item[0]))
|
||||
return [name for name, _ in allowed_details]
|
||||
|
||||
# Fallback: intersect the allowlist with the provider-advertised names.
|
||||
available_lookup = {name.lower(): name for name in available}
|
||||
display_names: list[str] = []
|
||||
for model_name in sorted(allowed_models):
|
||||
lowered = model_name.lower()
|
||||
if lowered in available_lookup:
|
||||
display_names.append(available_lookup[lowered])
|
||||
|
||||
return display_names
|
||||
|
||||
@classmethod
|
||||
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
|
||||
"""Get list of available model names, optionally filtered by provider.
|
||||
|
||||
18
server.py
18
server.py
@@ -492,15 +492,25 @@ def configure_providers():
|
||||
|
||||
# Register providers in priority order:
|
||||
# 1. Native APIs first (most direct and efficient)
|
||||
registered_providers = []
|
||||
|
||||
if has_native_apis:
|
||||
if gemini_key and gemini_key != "your_gemini_api_key_here":
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
registered_providers.append(ProviderType.GOOGLE.value)
|
||||
logger.debug(f"Registered provider: {ProviderType.GOOGLE.value}")
|
||||
if openai_key and openai_key != "your_openai_api_key_here":
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
registered_providers.append(ProviderType.OPENAI.value)
|
||||
logger.debug(f"Registered provider: {ProviderType.OPENAI.value}")
|
||||
if xai_key and xai_key != "your_xai_api_key_here":
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
registered_providers.append(ProviderType.XAI.value)
|
||||
logger.debug(f"Registered provider: {ProviderType.XAI.value}")
|
||||
if dial_key and dial_key != "your_dial_api_key_here":
|
||||
ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)
|
||||
registered_providers.append(ProviderType.DIAL.value)
|
||||
logger.debug(f"Registered provider: {ProviderType.DIAL.value}")
|
||||
|
||||
# 2. Custom provider second (for local/private models)
|
||||
if has_custom:
|
||||
@@ -511,10 +521,18 @@ def configure_providers():
|
||||
return CustomProvider(api_key=api_key or "", base_url=base_url) # Use provided API key or empty string
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
|
||||
registered_providers.append(ProviderType.CUSTOM.value)
|
||||
logger.debug(f"Registered provider: {ProviderType.CUSTOM.value}")
|
||||
|
||||
# 3. OpenRouter last (catch-all for everything else)
|
||||
if has_openrouter:
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||
registered_providers.append(ProviderType.OPENROUTER.value)
|
||||
logger.debug(f"Registered provider: {ProviderType.OPENROUTER.value}")
|
||||
|
||||
# Log all registered providers
|
||||
if registered_providers:
|
||||
logger.info(f"Registered providers: {', '.join(registered_providers)}")
|
||||
|
||||
# Require at least one valid provider
|
||||
if not valid_providers:
|
||||
|
||||
@@ -63,27 +63,30 @@ class TestAliasTargetRestrictions:
|
||||
assert provider.validate_model_name("o4mini")
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
|
||||
def test_restriction_policy_allows_only_alias_when_alias_specified(self):
|
||||
"""Test that restriction policy allows only the alias when just alias is specified.
|
||||
|
||||
If you restrict to 'mini' (which is an alias for gpt-5-mini),
|
||||
only the alias should work, not other models.
|
||||
This is the correct restrictive behavior.
|
||||
"""
|
||||
# Clear cached restriction service
|
||||
def test_restriction_policy_alias_allows_canonical(self):
|
||||
"""Alias-only allowlists should permit both the alias and its canonical target."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
# Only the alias should be allowed
|
||||
assert provider.validate_model_name("mini")
|
||||
# Direct target for this alias should NOT be allowed (mini -> gpt-5-mini)
|
||||
assert not provider.validate_model_name("gpt-5-mini")
|
||||
# Other models should NOT be allowed
|
||||
assert provider.validate_model_name("gpt-5-mini")
|
||||
assert not provider.validate_model_name("o4-mini")
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"})
|
||||
def test_restriction_policy_alias_allows_short_name(self):
|
||||
"""Common aliases like 'gpt5' should allow their canonical forms."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
assert provider.validate_model_name("gpt5")
|
||||
assert provider.validate_model_name("gpt-5")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target
|
||||
def test_gemini_restriction_policy_allows_alias_when_target_allowed(self):
|
||||
"""Test Gemini restriction policy allows alias when target is allowed."""
|
||||
@@ -99,19 +102,16 @@ class TestAliasTargetRestrictions:
|
||||
assert provider.validate_model_name("flash")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only
|
||||
def test_gemini_restriction_policy_allows_only_alias_when_alias_specified(self):
|
||||
"""Test Gemini restriction policy allows only alias when just alias is specified."""
|
||||
# Clear cached restriction service
|
||||
def test_gemini_restriction_policy_alias_allows_canonical(self):
|
||||
"""Gemini alias allowlists should permit canonical forms."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Only the alias should be allowed
|
||||
assert provider.validate_model_name("flash")
|
||||
# Direct target should NOT be allowed
|
||||
assert not provider.validate_model_name("gemini-2.5-flash")
|
||||
assert provider.validate_model_name("gemini-2.5-flash")
|
||||
|
||||
def test_restriction_service_validation_includes_all_targets(self):
|
||||
"""Test that restriction service validation knows about all aliases and targets."""
|
||||
@@ -153,6 +153,30 @@ class TestAliasTargetRestrictions:
|
||||
assert provider.validate_model_name("o4-mini") # target
|
||||
assert provider.validate_model_name("o4mini") # alias for o4-mini
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"}, clear=True)
|
||||
def test_service_alias_allows_canonical_openai(self):
|
||||
"""ModelRestrictionService should permit canonical names resolved from aliases."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
service = ModelRestrictionService()
|
||||
|
||||
assert service.is_allowed(ProviderType.OPENAI, "gpt-5")
|
||||
assert provider.validate_model_name("gpt-5")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}, clear=True)
|
||||
def test_service_alias_allows_canonical_gemini(self):
|
||||
"""Gemini alias allowlists should permit canonical forms."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
service = ModelRestrictionService()
|
||||
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
|
||||
assert provider.validate_model_name("gemini-2.5-flash")
|
||||
|
||||
def test_alias_target_policy_regression_prevention(self):
|
||||
"""Regression test to ensure aliases and targets are both validated properly.
|
||||
|
||||
|
||||
@@ -106,19 +106,35 @@ class TestAutoMode:
|
||||
|
||||
def test_tool_schema_in_normal_mode(self):
|
||||
"""Test that tool schemas don't require model in normal mode"""
|
||||
# This test uses the default from conftest.py which sets non-auto mode
|
||||
# The conftest.py mock_provider_availability fixture ensures the model is available
|
||||
tool = ChatTool()
|
||||
schema = tool.get_input_schema()
|
||||
# Save original
|
||||
original = os.environ.get("DEFAULT_MODEL", "")
|
||||
|
||||
# Model should not be required when default model is configured
|
||||
assert "model" not in schema["required"]
|
||||
try:
|
||||
# Set to a specific model (not auto mode)
|
||||
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
|
||||
import config
|
||||
|
||||
# Model field should have simpler description
|
||||
model_schema = schema["properties"]["model"]
|
||||
assert "enum" not in model_schema
|
||||
assert "listmodels" in model_schema["description"]
|
||||
assert "default model" in model_schema["description"].lower()
|
||||
importlib.reload(config)
|
||||
|
||||
tool = ChatTool()
|
||||
schema = tool.get_input_schema()
|
||||
|
||||
# Model should not be required when default model is configured
|
||||
assert "model" not in schema["required"]
|
||||
|
||||
# Model field should have simpler description
|
||||
model_schema = schema["properties"]["model"]
|
||||
assert "enum" not in model_schema
|
||||
assert "listmodels" in model_schema["description"]
|
||||
assert "default model" in model_schema["description"].lower()
|
||||
|
||||
finally:
|
||||
# Restore
|
||||
if original:
|
||||
os.environ["DEFAULT_MODEL"] = original
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
importlib.reload(config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_requires_model_parameter(self):
|
||||
|
||||
203
tests/test_auto_mode_model_listing.py
Normal file
203
tests/test_auto_mode_model_listing.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Tests covering model restriction-aware error messaging in auto mode."""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
import utils.model_restrictions as model_restrictions
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
|
||||
def _extract_available_models(message: str) -> list[str]:
|
||||
"""Parse the available model list from the error message."""
|
||||
|
||||
marker = "Available models: "
|
||||
if marker not in message:
|
||||
raise AssertionError(f"Expected '{marker}' in message: {message}")
|
||||
|
||||
start = message.index(marker) + len(marker)
|
||||
end = message.find(". Suggested", start)
|
||||
if end == -1:
|
||||
end = len(message)
|
||||
|
||||
available_segment = message[start:end].strip()
|
||||
if not available_segment:
|
||||
return []
|
||||
|
||||
return [item.strip() for item in available_segment.split(",")]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_registry():
|
||||
"""Ensure registry and restriction service state is isolated."""
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
model_restrictions._restriction_service = None
|
||||
yield
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
model_restrictions._restriction_service = None
|
||||
|
||||
|
||||
def _register_core_providers(*, include_xai: bool = False):
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||
if include_xai:
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
|
||||
"""Error payload should surface only the allowed models for each provider."""
|
||||
|
||||
monkeypatch.setenv("DEFAULT_MODEL", "auto")
|
||||
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
|
||||
monkeypatch.delenv("XAI_API_KEY", raising=False)
|
||||
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
|
||||
try:
|
||||
import dotenv
|
||||
|
||||
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
monkeypatch.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro")
|
||||
monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")
|
||||
monkeypatch.setenv("OPENROUTER_ALLOWED_MODELS", "gpt5nano")
|
||||
monkeypatch.setenv("XAI_ALLOWED_MODELS", "")
|
||||
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
_register_core_providers()
|
||||
|
||||
import server
|
||||
|
||||
importlib.reload(server)
|
||||
|
||||
# Reload may have re-applied .env overrides; enforce our test configuration
|
||||
for key, value in (
|
||||
("DEFAULT_MODEL", "auto"),
|
||||
("GEMINI_API_KEY", "test-gemini"),
|
||||
("OPENAI_API_KEY", "test-openai"),
|
||||
("OPENROUTER_API_KEY", "test-openrouter"),
|
||||
("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro"),
|
||||
("OPENAI_ALLOWED_MODELS", "gpt-5"),
|
||||
("OPENROUTER_ALLOWED_MODELS", "gpt5nano"),
|
||||
("XAI_ALLOWED_MODELS", ""),
|
||||
):
|
||||
monkeypatch.setenv(key, value)
|
||||
|
||||
for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
model_restrictions._restriction_service = None
|
||||
server.configure_providers()
|
||||
|
||||
result = asyncio.run(
|
||||
server.handle_call_tool(
|
||||
"chat",
|
||||
{
|
||||
"model": "gpt5mini",
|
||||
"prompt": "Tell me about your strengths",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
payload = json.loads(result[0].text)
|
||||
assert payload["status"] == "error"
|
||||
|
||||
available_models = _extract_available_models(payload["content"])
|
||||
assert set(available_models) == {"gemini-2.5-pro", "gpt-5", "gpt5nano", "openai/gpt-5-nano"}
|
||||
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, reset_registry):
|
||||
"""When no restrictions are set, the full high-capability catalogue should appear."""
|
||||
|
||||
monkeypatch.setenv("DEFAULT_MODEL", "auto")
|
||||
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
|
||||
monkeypatch.setenv("XAI_API_KEY", "test-xai")
|
||||
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
|
||||
try:
|
||||
import dotenv
|
||||
|
||||
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
for var in (
|
||||
"GOOGLE_ALLOWED_MODELS",
|
||||
"OPENAI_ALLOWED_MODELS",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
"XAI_ALLOWED_MODELS",
|
||||
"DIAL_ALLOWED_MODELS",
|
||||
):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
_register_core_providers(include_xai=True)
|
||||
|
||||
import server
|
||||
|
||||
importlib.reload(server)
|
||||
|
||||
for key, value in (
|
||||
("DEFAULT_MODEL", "auto"),
|
||||
("GEMINI_API_KEY", "test-gemini"),
|
||||
("OPENAI_API_KEY", "test-openai"),
|
||||
("OPENROUTER_API_KEY", "test-openrouter"),
|
||||
):
|
||||
monkeypatch.setenv(key, value)
|
||||
|
||||
for var in (
|
||||
"GOOGLE_ALLOWED_MODELS",
|
||||
"OPENAI_ALLOWED_MODELS",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
"XAI_ALLOWED_MODELS",
|
||||
"DIAL_ALLOWED_MODELS",
|
||||
"CUSTOM_API_URL",
|
||||
"CUSTOM_API_KEY",
|
||||
):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
model_restrictions._restriction_service = None
|
||||
server.configure_providers()
|
||||
|
||||
result = asyncio.run(
|
||||
server.handle_call_tool(
|
||||
"chat",
|
||||
{
|
||||
"model": "dummymodel",
|
||||
"prompt": "Hi there",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
payload = json.loads(result[0].text)
|
||||
assert payload["status"] == "error"
|
||||
|
||||
available_models = _extract_available_models(payload["content"])
|
||||
assert "gemini-2.5-pro" in available_models
|
||||
assert "gpt-5" in available_models
|
||||
assert "grok-4" in available_models
|
||||
assert len(available_models) >= 5
|
||||
@@ -3,6 +3,7 @@ Tests for dynamic context request and collaboration features
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -157,95 +158,120 @@ class TestDynamicContextRequests:
|
||||
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
||||
async def test_clarification_with_suggested_action(self, mock_get_provider, analyze_tool):
|
||||
"""Test clarification request with suggested next action"""
|
||||
clarification_json = json.dumps(
|
||||
{
|
||||
"status": "files_required_to_continue",
|
||||
"mandatory_instructions": "I need to see the database configuration to analyze the connection error",
|
||||
"files_needed": ["config/database.yml", "src/db.py"],
|
||||
"suggested_next_action": {
|
||||
"tool": "analyze",
|
||||
"args": {
|
||||
"prompt": "Analyze database connection timeout issue",
|
||||
"relevant_files": [
|
||||
"/config/database.yml",
|
||||
"/src/db.py",
|
||||
"/logs/error.log",
|
||||
],
|
||||
import importlib
|
||||
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Ensure deterministic model configuration for this test regardless of previous suites
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
|
||||
original_default = os.environ.get("DEFAULT_MODEL")
|
||||
|
||||
try:
|
||||
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
clarification_json = json.dumps(
|
||||
{
|
||||
"status": "files_required_to_continue",
|
||||
"mandatory_instructions": "I need to see the database configuration to analyze the connection error",
|
||||
"files_needed": ["config/database.yml", "src/db.py"],
|
||||
"suggested_next_action": {
|
||||
"tool": "analyze",
|
||||
"args": {
|
||||
"prompt": "Analyze database connection timeout issue",
|
||||
"relevant_files": [
|
||||
"/config/database.yml",
|
||||
"/src/db.py",
|
||||
"/logs/error.log",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await analyze_tool.execute(
|
||||
{
|
||||
"step": "Analyze database connection timeout issue",
|
||||
"step_number": 1,
|
||||
"total_steps": 1,
|
||||
"next_step_required": False,
|
||||
"findings": "Initial database timeout analysis",
|
||||
"relevant_files": ["/absolute/logs/error.log"],
|
||||
}
|
||||
)
|
||||
result = await analyze_tool.execute(
|
||||
{
|
||||
"step": "Analyze database connection timeout issue",
|
||||
"step_number": 1,
|
||||
"total_steps": 1,
|
||||
"next_step_required": False,
|
||||
"findings": "Initial database timeout analysis",
|
||||
"relevant_files": ["/absolute/logs/error.log"],
|
||||
}
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result) == 1
|
||||
|
||||
response_data = json.loads(result[0].text)
|
||||
response_data = json.loads(result[0].text)
|
||||
|
||||
# Workflow tools should either promote clarification status or handle it in expert analysis
|
||||
if response_data["status"] == "files_required_to_continue":
|
||||
# Clarification was properly promoted to main status
|
||||
# Check if mandatory_instructions is at top level or in content
|
||||
if "mandatory_instructions" in response_data:
|
||||
assert "database configuration" in response_data["mandatory_instructions"]
|
||||
assert "files_needed" in response_data
|
||||
assert "config/database.yml" in response_data["files_needed"]
|
||||
assert "src/db.py" in response_data["files_needed"]
|
||||
elif "content" in response_data:
|
||||
# Parse content JSON for workflow tools
|
||||
try:
|
||||
content_json = json.loads(response_data["content"])
|
||||
assert "mandatory_instructions" in content_json
|
||||
# Workflow tools should either promote clarification status or handle it in expert analysis
|
||||
if response_data["status"] == "files_required_to_continue":
|
||||
# Clarification was properly promoted to main status
|
||||
# Check if mandatory_instructions is at top level or in content
|
||||
if "mandatory_instructions" in response_data:
|
||||
assert "database configuration" in response_data["mandatory_instructions"]
|
||||
assert "files_needed" in response_data
|
||||
assert "config/database.yml" in response_data["files_needed"]
|
||||
assert "src/db.py" in response_data["files_needed"]
|
||||
elif "content" in response_data:
|
||||
# Parse content JSON for workflow tools
|
||||
try:
|
||||
content_json = json.loads(response_data["content"])
|
||||
assert "mandatory_instructions" in content_json
|
||||
assert (
|
||||
"database configuration" in content_json["mandatory_instructions"]
|
||||
or "database" in content_json["mandatory_instructions"]
|
||||
)
|
||||
assert "files_needed" in content_json
|
||||
files_needed_str = str(content_json["files_needed"])
|
||||
assert (
|
||||
"config/database.yml" in files_needed_str
|
||||
or "config" in files_needed_str
|
||||
or "database" in files_needed_str
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Content is not JSON, check if it contains required text
|
||||
content = response_data["content"]
|
||||
assert "database configuration" in content or "config" in content
|
||||
elif response_data["status"] == "calling_expert_analysis":
|
||||
# Clarification may be handled in expert analysis section
|
||||
if "expert_analysis" in response_data:
|
||||
expert_analysis = response_data["expert_analysis"]
|
||||
expert_content = str(expert_analysis)
|
||||
assert (
|
||||
"database configuration" in content_json["mandatory_instructions"]
|
||||
or "database" in content_json["mandatory_instructions"]
|
||||
"database configuration" in expert_content
|
||||
or "config/database.yml" in expert_content
|
||||
or "files_required_to_continue" in expert_content
|
||||
)
|
||||
assert "files_needed" in content_json
|
||||
files_needed_str = str(content_json["files_needed"])
|
||||
assert (
|
||||
"config/database.yml" in files_needed_str
|
||||
or "config" in files_needed_str
|
||||
or "database" in files_needed_str
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Content is not JSON, check if it contains required text
|
||||
content = response_data["content"]
|
||||
assert "database configuration" in content or "config" in content
|
||||
elif response_data["status"] == "calling_expert_analysis":
|
||||
# Clarification may be handled in expert analysis section
|
||||
if "expert_analysis" in response_data:
|
||||
expert_analysis = response_data["expert_analysis"]
|
||||
expert_content = str(expert_analysis)
|
||||
assert (
|
||||
"database configuration" in expert_content
|
||||
or "config/database.yml" in expert_content
|
||||
or "files_required_to_continue" in expert_content
|
||||
)
|
||||
else:
|
||||
# Some other status - ensure it's a valid workflow response
|
||||
assert "step_number" in response_data
|
||||
else:
|
||||
# Some other status - ensure it's a valid workflow response
|
||||
assert "step_number" in response_data
|
||||
|
||||
# Check for suggested next action
|
||||
if "suggested_next_action" in response_data:
|
||||
action = response_data["suggested_next_action"]
|
||||
assert action["tool"] == "analyze"
|
||||
# Check for suggested next action
|
||||
if "suggested_next_action" in response_data:
|
||||
action = response_data["suggested_next_action"]
|
||||
assert action["tool"] == "analyze"
|
||||
finally:
|
||||
if original_default is not None:
|
||||
os.environ["DEFAULT_MODEL"] = original_default
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
|
||||
def test_tool_output_model_serialization(self):
|
||||
"""Test ToolOutput model serialization"""
|
||||
|
||||
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from providers.base import ModelProvider
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from providers.shared import ModelCapabilities, ProviderType
|
||||
from tools.listmodels import ListModelsTool
|
||||
|
||||
|
||||
@@ -23,10 +23,63 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
self.mock_openrouter = MagicMock(spec=ModelProvider)
|
||||
self.mock_openrouter.provider_type = ProviderType.OPENROUTER
|
||||
|
||||
def make_capabilities(
|
||||
canonical: str, friendly: str, *, aliases=None, context: int = 200_000
|
||||
) -> ModelCapabilities:
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name=canonical,
|
||||
friendly_name=friendly,
|
||||
intelligence_score=20,
|
||||
description=friendly,
|
||||
aliases=aliases or [],
|
||||
context_window=context,
|
||||
max_output_tokens=context,
|
||||
supports_extended_thinking=True,
|
||||
)
|
||||
|
||||
opus_caps = make_capabilities(
|
||||
"anthropic/claude-opus-4-20240229",
|
||||
"Claude Opus",
|
||||
aliases=["opus"],
|
||||
)
|
||||
sonnet_caps = make_capabilities(
|
||||
"anthropic/claude-sonnet-4-20240229",
|
||||
"Claude Sonnet",
|
||||
aliases=["sonnet"],
|
||||
)
|
||||
deepseek_caps = make_capabilities(
|
||||
"deepseek/deepseek-r1-0528:free",
|
||||
"DeepSeek R1",
|
||||
aliases=[],
|
||||
)
|
||||
qwen_caps = make_capabilities(
|
||||
"qwen/qwen3-235b-a22b-04-28:free",
|
||||
"Qwen3",
|
||||
aliases=[],
|
||||
)
|
||||
|
||||
self._openrouter_caps_map = {
|
||||
"anthropic/claude-opus-4": opus_caps,
|
||||
"opus": opus_caps,
|
||||
"anthropic/claude-opus-4-20240229": opus_caps,
|
||||
"anthropic/claude-sonnet-4": sonnet_caps,
|
||||
"sonnet": sonnet_caps,
|
||||
"anthropic/claude-sonnet-4-20240229": sonnet_caps,
|
||||
"deepseek/deepseek-r1-0528:free": deepseek_caps,
|
||||
"qwen/qwen3-235b-a22b-04-28:free": qwen_caps,
|
||||
}
|
||||
|
||||
self.mock_openrouter.get_capabilities.side_effect = self._openrouter_caps_map.__getitem__
|
||||
self.mock_openrouter.get_capabilities_by_rank.return_value = []
|
||||
self.mock_openrouter.list_models.return_value = []
|
||||
|
||||
# Create mock Gemini provider for comparison
|
||||
self.mock_gemini = MagicMock(spec=ModelProvider)
|
||||
self.mock_gemini.provider_type = ProviderType.GOOGLE
|
||||
self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
||||
self.mock_gemini.get_capabilities_by_rank.return_value = []
|
||||
self.mock_gemini.get_capabilities_by_rank.return_value = []
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after tests."""
|
||||
@@ -159,7 +212,7 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
for line in lines:
|
||||
if "OpenRouter" in line and "✅" in line:
|
||||
openrouter_section_found = True
|
||||
elif "Available Models" in line and openrouter_section_found:
|
||||
elif ("Models (policy restricted)" in line or "Available Models" in line) and openrouter_section_found:
|
||||
in_openrouter_section = True
|
||||
elif in_openrouter_section:
|
||||
# Check for lines with model names in backticks
|
||||
@@ -179,11 +232,11 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
|
||||
)
|
||||
|
||||
# Verify list_models was called with respect_restrictions=True
|
||||
self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)
|
||||
# Verify we did not fall back to unrestricted listing
|
||||
self.mock_openrouter.list_models.assert_not_called()
|
||||
|
||||
# Check for restriction note
|
||||
self.assertIn("Restricted to models matching:", result)
|
||||
self.assertIn("OpenRouter models restricted by", result)
|
||||
|
||||
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True)
|
||||
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
||||
|
||||
@@ -121,38 +121,59 @@ class TestModelMetadataContinuation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_previous_assistant_turn_defaults(self):
|
||||
"""Test behavior when there's no previous assistant turn."""
|
||||
thread_id = create_thread("chat", {"prompt": "test"})
|
||||
# Save and set DEFAULT_MODEL for test
|
||||
import importlib
|
||||
import os
|
||||
|
||||
# Only add user turns
|
||||
add_turn(thread_id, "user", "First question")
|
||||
add_turn(thread_id, "user", "Second question")
|
||||
original_default = os.environ.get("DEFAULT_MODEL", "")
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
import config
|
||||
import utils.model_context
|
||||
|
||||
arguments = {"continuation_id": thread_id}
|
||||
importlib.reload(config)
|
||||
importlib.reload(utils.model_context)
|
||||
|
||||
# Mock dependencies
|
||||
with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
|
||||
mock_calc.return_value = MagicMock(
|
||||
total_tokens=200000,
|
||||
content_tokens=160000,
|
||||
response_tokens=40000,
|
||||
file_tokens=64000,
|
||||
history_tokens=64000,
|
||||
)
|
||||
try:
|
||||
thread_id = create_thread("chat", {"prompt": "test"})
|
||||
|
||||
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
|
||||
mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)
|
||||
# Only add user turns
|
||||
add_turn(thread_id, "user", "First question")
|
||||
add_turn(thread_id, "user", "Second question")
|
||||
|
||||
# Call the actual function
|
||||
enhanced_args = await reconstruct_thread_context(arguments)
|
||||
arguments = {"continuation_id": thread_id}
|
||||
|
||||
# Should not have set a model
|
||||
assert enhanced_args.get("model") is None
|
||||
# Mock dependencies
|
||||
with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
|
||||
mock_calc.return_value = MagicMock(
|
||||
total_tokens=200000,
|
||||
content_tokens=160000,
|
||||
response_tokens=40000,
|
||||
file_tokens=64000,
|
||||
history_tokens=64000,
|
||||
)
|
||||
|
||||
# ModelContext should use DEFAULT_MODEL
|
||||
model_context = ModelContext.from_arguments(enhanced_args)
|
||||
from config import DEFAULT_MODEL
|
||||
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
|
||||
mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)
|
||||
|
||||
assert model_context.model_name == DEFAULT_MODEL
|
||||
# Call the actual function
|
||||
enhanced_args = await reconstruct_thread_context(arguments)
|
||||
|
||||
# Should not have set a model
|
||||
assert enhanced_args.get("model") is None
|
||||
|
||||
# ModelContext should use DEFAULT_MODEL
|
||||
model_context = ModelContext.from_arguments(enhanced_args)
|
||||
from config import DEFAULT_MODEL
|
||||
|
||||
assert model_context.model_name == DEFAULT_MODEL
|
||||
finally:
|
||||
# Restore original value
|
||||
if original_default:
|
||||
os.environ["DEFAULT_MODEL"] = original_default
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
importlib.reload(config)
|
||||
importlib.reload(utils.model_context)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_model_overrides_previous_turn(self):
|
||||
|
||||
@@ -49,17 +49,32 @@ class TestModelRestrictionService:
|
||||
def test_load_multiple_models_restriction(self):
|
||||
"""Test loading multiple allowed models."""
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||
service = ModelRestrictionService()
|
||||
# Instantiate providers so alias resolution for allow-lists is available
|
||||
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Check OpenAI models
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Check Google models
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
def fake_get_provider(provider_type, force_new=False):
|
||||
mapping = {
|
||||
ProviderType.OPENAI: openai_provider,
|
||||
ProviderType.GOOGLE: gemini_provider,
|
||||
}
|
||||
return mapping.get(provider_type)
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):
|
||||
|
||||
service = ModelRestrictionService()
|
||||
|
||||
# Check OpenAI models
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
|
||||
# Check Google models
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
|
||||
def test_case_insensitive_and_whitespace_handling(self):
|
||||
"""Test that model names are case-insensitive and whitespace is trimmed."""
|
||||
@@ -111,13 +126,17 @@ class TestModelRestrictionService:
|
||||
|
||||
def test_shorthand_names_in_restrictions(self):
|
||||
"""Test that shorthand names work in restrictions."""
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o3-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||
# Instantiate providers so the registry can resolve aliases
|
||||
OpenAIModelProvider(api_key="test-key")
|
||||
GeminiModelProvider(api_key="test-key")
|
||||
|
||||
service = ModelRestrictionService()
|
||||
|
||||
# When providers check models, they pass both resolved and original names
|
||||
# OpenAI: 'mini' shorthand allows o4-mini
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "mini") # How providers actually call it
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o4-mini") # Direct check without original (for testing)
|
||||
# OpenAI: 'o4mini' shorthand allows o4-mini
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini") # How providers actually call it
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini") # Canonical should also be allowed
|
||||
|
||||
# OpenAI: o3-mini allowed directly
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
@@ -280,19 +299,25 @@ class TestProviderIntegration:
|
||||
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Test case: Only alias "flash" is allowed, not the full name
|
||||
# If parameters are in wrong order, this test will catch it
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Should allow "flash" alias
|
||||
assert provider.validate_model_name("flash")
|
||||
with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):
|
||||
|
||||
# Should allow getting capabilities for "flash"
|
||||
capabilities = provider.get_capabilities("flash")
|
||||
assert capabilities.model_name == "gemini-2.5-flash"
|
||||
# Test case: Only alias "flash" is allowed, not the full name
|
||||
# If parameters are in wrong order, this test will catch it
|
||||
|
||||
# Test the edge case: Try to use full model name when only alias is allowed
|
||||
# This should NOT be allowed - only the alias "flash" is in the restriction list
|
||||
assert not provider.validate_model_name("gemini-2.5-flash")
|
||||
# Should allow "flash" alias
|
||||
assert provider.validate_model_name("flash")
|
||||
|
||||
# Should allow getting capabilities for "flash"
|
||||
capabilities = provider.get_capabilities("flash")
|
||||
assert capabilities.model_name == "gemini-2.5-flash"
|
||||
|
||||
# Canonical form should also be allowed now that alias is on the allowlist
|
||||
assert provider.validate_model_name("gemini-2.5-flash")
|
||||
# Unrelated models remain blocked
|
||||
assert not provider.validate_model_name("pro")
|
||||
assert not provider.validate_model_name("gemini-2.5-pro")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
|
||||
def test_gemini_parameter_order_edge_case_full_name_only(self):
|
||||
@@ -570,17 +595,27 @@ class TestShorthandRestrictions:
|
||||
|
||||
# Test OpenAI provider
|
||||
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||
assert openai_provider.validate_model_name("mini") # Should work with shorthand
|
||||
# When restricting to "mini", you can't use "o4-mini" directly - this is correct behavior
|
||||
assert not openai_provider.validate_model_name("o4-mini") # Not allowed - only shorthand is allowed
|
||||
assert not openai_provider.validate_model_name("o3-mini") # Not allowed
|
||||
|
||||
# Test Gemini provider
|
||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
|
||||
# Same for Gemini - if you restrict to "flash", you can't use the full name
|
||||
assert not gemini_provider.validate_model_name("gemini-2.5-flash") # Not allowed
|
||||
assert not gemini_provider.validate_model_name("pro") # Not allowed
|
||||
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
def registry_side_effect(provider_type, force_new=False):
|
||||
mapping = {
|
||||
ProviderType.OPENAI: openai_provider,
|
||||
ProviderType.GOOGLE: gemini_provider,
|
||||
}
|
||||
return mapping.get(provider_type)
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
|
||||
assert openai_provider.validate_model_name("mini") # Should work with shorthand
|
||||
assert openai_provider.validate_model_name("gpt-5-mini") # Canonical resolved from shorthand
|
||||
assert not openai_provider.validate_model_name("o4-mini") # Unrelated model still blocked
|
||||
assert not openai_provider.validate_model_name("o3-mini")
|
||||
|
||||
# Test Gemini provider
|
||||
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
|
||||
assert gemini_provider.validate_model_name("gemini-2.5-flash") # Canonical allowed
|
||||
assert not gemini_provider.validate_model_name("pro") # Not allowed
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
|
||||
def test_multiple_shorthands_for_same_model(self):
|
||||
@@ -596,9 +631,9 @@ class TestShorthandRestrictions:
|
||||
assert openai_provider.validate_model_name("mini") # mini -> o4-mini
|
||||
assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini
|
||||
|
||||
# Resolved names work only if explicitly allowed
|
||||
# Resolved names should be allowed when their shorthands are present
|
||||
assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed
|
||||
assert not openai_provider.validate_model_name("o3-mini") # Not explicitly allowed, only shorthand
|
||||
assert openai_provider.validate_model_name("o3-mini") # Allowed via shorthand
|
||||
|
||||
# Other models should not work
|
||||
assert not openai_provider.validate_model_name("o3")
|
||||
|
||||
@@ -260,9 +260,10 @@ class TestOpenRouterAutoMode:
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
|
||||
mock_provider_class = Mock()
|
||||
mock_provider_instance = Mock(spec=["get_provider_type", "list_models"])
|
||||
mock_provider_instance = Mock(spec=["get_provider_type", "list_models", "get_all_model_capabilities"])
|
||||
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
|
||||
mock_provider_instance.list_models.return_value = []
|
||||
mock_provider_instance.get_all_model_capabilities.return_value = {}
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)
|
||||
|
||||
@@ -293,13 +293,7 @@ class TestOpenRouterAliasRestrictions:
|
||||
# o3 -> openai/o3
|
||||
# gpt4.1 -> should not exist (expected to be filtered out)
|
||||
|
||||
expected_models = {
|
||||
"openai/o3-mini",
|
||||
"google/gemini-2.5-pro",
|
||||
"google/gemini-2.5-flash",
|
||||
"openai/o4-mini",
|
||||
"openai/o3",
|
||||
}
|
||||
expected_models = {"o3-mini", "pro", "flash", "o4-mini", "o3"}
|
||||
|
||||
available_model_names = set(available_models.keys())
|
||||
|
||||
@@ -355,9 +349,11 @@ class TestOpenRouterAliasRestrictions:
|
||||
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||
|
||||
expected_models = {
|
||||
"openai/o3-mini", # from alias
|
||||
"o3-mini", # alias
|
||||
"openai/o3-mini", # canonical
|
||||
"anthropic/claude-opus-4.1", # full name
|
||||
"google/gemini-2.5-flash", # from alias
|
||||
"flash", # alias
|
||||
"google/gemini-2.5-flash", # canonical
|
||||
}
|
||||
|
||||
available_model_names = set(available_models.keys())
|
||||
|
||||
@@ -83,9 +83,18 @@ class ListModelsTool(BaseTool):
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
output_lines = ["# Available AI Models\n"]
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
restricted_models_by_provider: dict[ProviderType, list[str]] = {}
|
||||
|
||||
if restriction_service:
|
||||
restricted_map = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||
for model_name, provider_type in restricted_map.items():
|
||||
restricted_models_by_provider.setdefault(provider_type, []).append(model_name)
|
||||
|
||||
# Map provider types to friendly names and their models
|
||||
provider_info = {
|
||||
ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
|
||||
@@ -94,6 +103,43 @@ class ListModelsTool(BaseTool):
|
||||
ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
|
||||
}
|
||||
|
||||
def format_model_entry(provider, display_name: str) -> list[str]:
|
||||
try:
|
||||
capabilities = provider.get_capabilities(display_name)
|
||||
except ValueError:
|
||||
return [f"- `{display_name}` *(not recognized by provider)*"]
|
||||
|
||||
canonical = capabilities.model_name
|
||||
if canonical.lower() == display_name.lower():
|
||||
header = f"- `{canonical}`"
|
||||
else:
|
||||
header = f"- `{display_name}` → `{canonical}`"
|
||||
|
||||
try:
|
||||
context_value = capabilities.context_window or 0
|
||||
except AttributeError:
|
||||
context_value = 0
|
||||
try:
|
||||
context_value = int(context_value)
|
||||
except (TypeError, ValueError):
|
||||
context_value = 0
|
||||
|
||||
if context_value >= 1_000_000:
|
||||
context_str = f"{context_value // 1_000_000}M context"
|
||||
elif context_value >= 1_000:
|
||||
context_str = f"{context_value // 1_000}K context"
|
||||
elif context_value > 0:
|
||||
context_str = f"{context_value} context"
|
||||
else:
|
||||
context_str = "unknown context"
|
||||
|
||||
try:
|
||||
description = capabilities.description or "No description available"
|
||||
except AttributeError:
|
||||
description = "No description available"
|
||||
lines = [header, f" - {context_str}", f" - {description}"]
|
||||
return lines
|
||||
|
||||
# Check each native provider type
|
||||
for provider_type, info in provider_info.items():
|
||||
# Check if provider is enabled
|
||||
@@ -104,30 +150,49 @@ class ListModelsTool(BaseTool):
|
||||
|
||||
if is_configured:
|
||||
output_lines.append("**Status**: Configured and available")
|
||||
output_lines.append("\n**Models**:")
|
||||
has_restrictions = bool(restriction_service and restriction_service.has_restrictions(provider_type))
|
||||
|
||||
aliases = []
|
||||
for model_name, capabilities in provider.get_capabilities_by_rank():
|
||||
description = capabilities.description or "No description available"
|
||||
context_window = capabilities.context_window
|
||||
if has_restrictions:
|
||||
restricted_names = sorted(set(restricted_models_by_provider.get(provider_type, [])))
|
||||
|
||||
if context_window >= 1_000_000:
|
||||
context_str = f"{context_window // 1_000_000}M context"
|
||||
elif context_window >= 1_000:
|
||||
context_str = f"{context_window // 1_000}K context"
|
||||
if restricted_names:
|
||||
output_lines.append("\n**Models (policy restricted)**:")
|
||||
for model_name in restricted_names:
|
||||
output_lines.extend(format_model_entry(provider, model_name))
|
||||
else:
|
||||
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
|
||||
output_lines.append("\n*No models are currently allowed by restriction policy.*")
|
||||
else:
|
||||
output_lines.append("\n**Models**:")
|
||||
|
||||
output_lines.append(f"- `{model_name}` - {context_str}")
|
||||
output_lines.append(f" - {description}")
|
||||
aliases = []
|
||||
for model_name, capabilities in provider.get_capabilities_by_rank():
|
||||
try:
|
||||
description = capabilities.description or "No description available"
|
||||
except AttributeError:
|
||||
description = "No description available"
|
||||
|
||||
for alias in capabilities.aliases or []:
|
||||
if alias != model_name:
|
||||
aliases.append(f"- `{alias}` → `{model_name}`")
|
||||
try:
|
||||
context_window = capabilities.context_window or 0
|
||||
except AttributeError:
|
||||
context_window = 0
|
||||
|
||||
if aliases:
|
||||
output_lines.append("\n**Aliases**:")
|
||||
output_lines.extend(sorted(aliases))
|
||||
if context_window >= 1_000_000:
|
||||
context_str = f"{context_window // 1_000_000}M context"
|
||||
elif context_window >= 1_000:
|
||||
context_str = f"{context_window // 1_000}K context"
|
||||
else:
|
||||
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
|
||||
|
||||
output_lines.append(f"- `{model_name}` - {context_str}")
|
||||
output_lines.append(f" - {description}")
|
||||
|
||||
for alias in capabilities.aliases or []:
|
||||
if alias != model_name:
|
||||
aliases.append(f"- `{alias}` → `{model_name}`")
|
||||
|
||||
if aliases:
|
||||
output_lines.append("\n**Aliases**:")
|
||||
output_lines.extend(sorted(aliases))
|
||||
else:
|
||||
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
|
||||
|
||||
@@ -144,19 +209,10 @@ class ListModelsTool(BaseTool):
|
||||
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
|
||||
|
||||
try:
|
||||
# Get OpenRouter provider from registry to properly apply restrictions
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||
if provider:
|
||||
# Get models with restrictions applied
|
||||
available_models = provider.list_models(respect_restrictions=True)
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
# Group by provider and retain ranking information for consistent ordering
|
||||
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
|
||||
|
||||
def _format_context(tokens: int) -> str:
|
||||
if not tokens:
|
||||
return "?"
|
||||
@@ -166,53 +222,83 @@ class ListModelsTool(BaseTool):
|
||||
return f"{tokens // 1_000}K"
|
||||
return str(tokens)
|
||||
|
||||
for model_name in available_models:
|
||||
config = registry.resolve(model_name)
|
||||
provider_name = "other"
|
||||
if config and "/" in config.model_name:
|
||||
provider_name = config.model_name.split("/")[0]
|
||||
elif "/" in model_name:
|
||||
provider_name = model_name.split("/")[0]
|
||||
has_restrictions = bool(
|
||||
restriction_service and restriction_service.has_restrictions(ProviderType.OPENROUTER)
|
||||
)
|
||||
|
||||
providers_models.setdefault(provider_name, [])
|
||||
if has_restrictions:
|
||||
restricted_names = sorted(set(restricted_models_by_provider.get(ProviderType.OPENROUTER, [])))
|
||||
|
||||
rank = config.get_effective_capability_rank() if config else 0
|
||||
providers_models[provider_name].append((rank, model_name, config))
|
||||
output_lines.append("\n**Models (policy restricted)**:")
|
||||
if restricted_names:
|
||||
for model_name in restricted_names:
|
||||
try:
|
||||
caps = provider.get_capabilities(model_name)
|
||||
except ValueError:
|
||||
output_lines.append(f"- `{model_name}` *(not recognized by provider)*")
|
||||
continue
|
||||
|
||||
output_lines.append("\n**Available Models**:")
|
||||
for provider_name, models in sorted(providers_models.items()):
|
||||
output_lines.append(f"\n*{provider_name.title()}:*")
|
||||
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
|
||||
if config:
|
||||
context_str = _format_context(config.context_window)
|
||||
context_value = int(caps.context_window or 0)
|
||||
context_str = _format_context(context_value)
|
||||
suffix_parts = [f"{context_str} context"]
|
||||
if getattr(config, "supports_extended_thinking", False):
|
||||
if caps.supports_extended_thinking:
|
||||
suffix_parts.append("thinking")
|
||||
suffix = ", ".join(suffix_parts)
|
||||
output_lines.append(f"- `{alias}` → `{config.model_name}` (score {rank}, {suffix})")
|
||||
else:
|
||||
output_lines.append(f"- `{alias}` (score {rank})")
|
||||
|
||||
total_models = len(available_models)
|
||||
# Show all models - no truncation message needed
|
||||
arrow = ""
|
||||
if caps.model_name.lower() != model_name.lower():
|
||||
arrow = f" → `{caps.model_name}`"
|
||||
|
||||
# Check if restrictions are applied
|
||||
restriction_service = None
|
||||
try:
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
score = caps.get_effective_capability_rank()
|
||||
output_lines.append(f"- `{model_name}`{arrow} (score {score}, {suffix})")
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
|
||||
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
|
||||
output_lines.append(
|
||||
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking OpenRouter restrictions: {e}")
|
||||
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) or set()
|
||||
if allowed_set:
|
||||
output_lines.append(
|
||||
f"\n*OpenRouter models restricted by OPENROUTER_ALLOWED_MODELS: {', '.join(sorted(allowed_set))}*"
|
||||
)
|
||||
else:
|
||||
output_lines.append("- *No models allowed by current restriction policy.*")
|
||||
else:
|
||||
available_models = provider.list_models(respect_restrictions=True)
|
||||
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
|
||||
|
||||
for model_name in available_models:
|
||||
config = registry.resolve(model_name)
|
||||
provider_name = "other"
|
||||
if config and "/" in config.model_name:
|
||||
provider_name = config.model_name.split("/")[0]
|
||||
elif "/" in model_name:
|
||||
provider_name = model_name.split("/")[0]
|
||||
|
||||
providers_models.setdefault(provider_name, [])
|
||||
|
||||
rank = config.get_effective_capability_rank() if config else 0
|
||||
providers_models[provider_name].append((rank, model_name, config))
|
||||
|
||||
output_lines.append("\n**Available Models**:")
|
||||
for provider_name, models in sorted(providers_models.items()):
|
||||
output_lines.append(f"\n*{provider_name.title()}:*")
|
||||
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
|
||||
if config:
|
||||
context_str = _format_context(getattr(config, "context_window", 0))
|
||||
suffix_parts = [f"{context_str} context"]
|
||||
if getattr(config, "supports_extended_thinking", False):
|
||||
suffix_parts.append("thinking")
|
||||
suffix = ", ".join(suffix_parts)
|
||||
|
||||
arrow = ""
|
||||
if config.model_name.lower() != alias.lower():
|
||||
arrow = f" → `{config.model_name}`"
|
||||
|
||||
output_lines.append(f"- `{alias}`{arrow} (score {rank}, {suffix})")
|
||||
else:
|
||||
output_lines.append(f"- `{alias}` (score {rank})")
|
||||
else:
|
||||
output_lines.append("**Error**: Could not load OpenRouter provider")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error listing OpenRouter models: %s", e)
|
||||
output_lines.append(f"**Error loading models**: {str(e)}")
|
||||
else:
|
||||
output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")
|
||||
|
||||
@@ -22,6 +22,7 @@ Example:
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from providers.shared import ProviderType
|
||||
@@ -58,6 +59,7 @@ class ModelRestrictionService:
|
||||
def __init__(self):
|
||||
"""Initialize the restriction service by loading from environment."""
|
||||
self.restrictions: dict[ProviderType, set[str]] = {}
|
||||
self._alias_resolution_cache: dict[ProviderType, dict[str, str]] = defaultdict(dict)
|
||||
self._load_from_env()
|
||||
|
||||
def _load_from_env(self) -> None:
|
||||
@@ -79,6 +81,7 @@ class ModelRestrictionService:
|
||||
|
||||
if models:
|
||||
self.restrictions[provider_type] = models
|
||||
self._alias_resolution_cache[provider_type] = {}
|
||||
logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
|
||||
else:
|
||||
# All entries were empty after cleaning - treat as no restrictions
|
||||
@@ -150,7 +153,41 @@ class ModelRestrictionService:
|
||||
names_to_check.add(original_name.lower())
|
||||
|
||||
# If any of the names is in the allowed set, it's allowed
|
||||
return any(name in allowed_set for name in names_to_check)
|
||||
if any(name in allowed_set for name in names_to_check):
|
||||
return True
|
||||
|
||||
# Attempt to resolve canonical names for allowed aliases using provider metadata.
|
||||
try:
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
provider = ModelProviderRegistry.get_provider(provider_type)
|
||||
except Exception: # pragma: no cover - registry lookup failure shouldn't break validation
|
||||
provider = None
|
||||
|
||||
if provider:
|
||||
cache = self._alias_resolution_cache.setdefault(provider_type, {})
|
||||
|
||||
for allowed_entry in list(allowed_set):
|
||||
normalized_resolved = cache.get(allowed_entry)
|
||||
|
||||
if not normalized_resolved:
|
||||
try:
|
||||
resolved = provider._resolve_model_name(allowed_entry)
|
||||
except Exception: # pragma: no cover - resolution failures are treated as non-matches
|
||||
continue
|
||||
|
||||
if not resolved:
|
||||
continue
|
||||
|
||||
normalized_resolved = resolved.lower()
|
||||
cache[allowed_entry] = normalized_resolved
|
||||
|
||||
if normalized_resolved in names_to_check:
|
||||
allowed_set.add(normalized_resolved)
|
||||
cache[normalized_resolved] = normalized_resolved
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user