fix: listmodels to always honor restricted models
fix: restrictions should resolve canonical names for openrouter
fix: tools now correctly return restricted list by presenting model names in schema
fix: tests updated to ensure these manage their expected env vars properly
perf: cache model alias resolution to avoid repeated checks
This commit is contained in:
11
AGENTS.md
11
AGENTS.md
@@ -1,5 +1,9 @@
|
|||||||
# Repository Guidelines
|
# Repository Guidelines
|
||||||
|
|
||||||
|
See `requirements.txt` and `requirements-dev.txt`
|
||||||
|
|
||||||
|
Also read CLAUDE.md and CLAUDE.local.md if available.
|
||||||
|
|
||||||
## Project Structure & Module Organization
|
## Project Structure & Module Organization
|
||||||
Zen MCP Server centers on `server.py`, which exposes MCP entrypoints and coordinates multi-model workflows.
|
Zen MCP Server centers on `server.py`, which exposes MCP entrypoints and coordinates multi-model workflows.
|
||||||
Feature-specific tools live in `tools/`, provider integrations in `providers/`, and shared helpers in `utils/`.
|
Feature-specific tools live in `tools/`, provider integrations in `providers/`, and shared helpers in `utils/`.
|
||||||
@@ -14,6 +18,13 @@ Authoritative documentation and samples live in `docs/`, and runtime diagnostics
|
|||||||
- `python communication_simulator_test.py --quick` – smoke-test orchestration across tools and providers.
|
- `python communication_simulator_test.py --quick` – smoke-test orchestration across tools and providers.
|
||||||
- `./run_integration_tests.sh [--with-simulator]` – exercise provider-dependent flows against remote or Ollama models.
|
- `./run_integration_tests.sh [--with-simulator]` – exercise provider-dependent flows against remote or Ollama models.
|
||||||
|
|
||||||
|
For example, this is how to run a single test file or the entire test suite:
|
||||||
|
|
||||||
|
```
|
||||||
|
source .zen_venv/bin/activate && pytest tests/test_auto_mode_model_listing.py -q
|
||||||
|
source .zen_venv/bin/activate && pytest -q
|
||||||
|
```
|
||||||
|
|
||||||
## Coding Style & Naming Conventions
|
## Coding Style & Naming Conventions
|
||||||
Target Python 3.9+ with Black and isort using a 120-character line limit; Ruff enforces pycodestyle, pyflakes, bugbear, comprehension, and pyupgrade rules. Prefer explicit type hints, snake_case modules, and imperative commit-time docstrings. Extend workflows by defining hook or abstract methods instead of checking `hasattr()`/`getattr()`—inheritance-backed contracts keep behavior discoverable and testable.
|
Target Python 3.9+ with Black and isort using a 120-character line limit; Ruff enforces pycodestyle, pyflakes, bugbear, comprehension, and pyupgrade rules. Prefer explicit type hints, snake_case modules, and imperative commit-time docstrings. Extend workflows by defining hook or abstract methods instead of checking `hasattr()`/`getattr()`—inheritance-backed contracts keep behavior discoverable and testable.
|
||||||
|
|
||||||
|
|||||||
@@ -73,6 +73,8 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
|
|
||||||
logging.info(f"Initializing Custom provider with endpoint: {base_url}")
|
logging.info(f"Initializing Custom provider with endpoint: {base_url}")
|
||||||
|
|
||||||
|
self._alias_cache: dict[str, str] = {}
|
||||||
|
|
||||||
super().__init__(api_key, base_url=base_url, **kwargs)
|
super().__init__(api_key, base_url=base_url, **kwargs)
|
||||||
|
|
||||||
# Initialize model registry (shared with OpenRouter for consistent aliases)
|
# Initialize model registry (shared with OpenRouter for consistent aliases)
|
||||||
@@ -120,11 +122,18 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
def _resolve_model_name(self, model_name: str) -> str:
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
"""Resolve registry aliases and strip version tags for local models."""
|
"""Resolve registry aliases and strip version tags for local models."""
|
||||||
|
|
||||||
|
cache_key = model_name.lower()
|
||||||
|
if cache_key in self._alias_cache:
|
||||||
|
return self._alias_cache[cache_key]
|
||||||
|
|
||||||
config = self._registry.resolve(model_name)
|
config = self._registry.resolve(model_name)
|
||||||
if config:
|
if config:
|
||||||
if config.model_name != model_name:
|
if config.model_name != model_name:
|
||||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
|
||||||
return config.model_name
|
resolved = config.model_name
|
||||||
|
self._alias_cache[cache_key] = resolved
|
||||||
|
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||||
|
return resolved
|
||||||
|
|
||||||
if ":" in model_name:
|
if ":" in model_name:
|
||||||
base_model = model_name.split(":")[0]
|
base_model = model_name.split(":")[0]
|
||||||
@@ -132,11 +141,16 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
|
|
||||||
base_config = self._registry.resolve(base_model)
|
base_config = self._registry.resolve(base_model)
|
||||||
if base_config:
|
if base_config:
|
||||||
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
logging.debug("Resolved base model '%s' to '%s'", base_model, base_config.model_name)
|
||||||
return base_config.model_name
|
resolved = base_config.model_name
|
||||||
|
self._alias_cache[cache_key] = resolved
|
||||||
|
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||||
|
return resolved
|
||||||
|
self._alias_cache[cache_key] = base_model
|
||||||
return base_model
|
return base_model
|
||||||
|
|
||||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||||
|
self._alias_cache[cache_key] = model_name
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
base_url: Base URL for the API endpoint
|
base_url: Base URL for the API endpoint
|
||||||
**kwargs: Additional configuration options including timeout
|
**kwargs: Additional configuration options including timeout
|
||||||
"""
|
"""
|
||||||
|
self._allowed_alias_cache: dict[str, str] = {}
|
||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
self._client = None
|
self._client = None
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
@@ -74,9 +75,33 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
canonical = canonical_name.lower()
|
canonical = canonical_name.lower()
|
||||||
|
|
||||||
if requested not in self.allowed_models and canonical not in self.allowed_models:
|
if requested not in self.allowed_models and canonical not in self.allowed_models:
|
||||||
raise ValueError(
|
allowed = False
|
||||||
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
for allowed_entry in list(self.allowed_models):
|
||||||
)
|
normalized_resolved = self._allowed_alias_cache.get(allowed_entry)
|
||||||
|
if normalized_resolved is None:
|
||||||
|
try:
|
||||||
|
resolved_name = self._resolve_model_name(allowed_entry)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not resolved_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_resolved = resolved_name.lower()
|
||||||
|
self._allowed_alias_cache[allowed_entry] = normalized_resolved
|
||||||
|
|
||||||
|
if normalized_resolved == canonical:
|
||||||
|
# Canonical match discovered via alias resolution – mark as allowed and
|
||||||
|
# memoise the canonical entry for future lookups.
|
||||||
|
allowed = True
|
||||||
|
self._allowed_alias_cache[canonical] = canonical
|
||||||
|
self.allowed_models.add(canonical)
|
||||||
|
break
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
||||||
|
)
|
||||||
|
|
||||||
def _parse_allowed_models(self) -> Optional[set[str]]:
|
def _parse_allowed_models(self) -> Optional[set[str]]:
|
||||||
"""Parse allowed models from environment variable.
|
"""Parse allowed models from environment variable.
|
||||||
@@ -94,6 +119,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
|
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
|
||||||
if models:
|
if models:
|
||||||
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
|
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
|
||||||
|
self._allowed_alias_cache = {}
|
||||||
return models
|
return models
|
||||||
|
|
||||||
# Log info if no allow-list configured for proxy providers
|
# Log info if no allow-list configured for proxy providers
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
**kwargs: Additional configuration
|
**kwargs: Additional configuration
|
||||||
"""
|
"""
|
||||||
base_url = "https://openrouter.ai/api/v1"
|
base_url = "https://openrouter.ai/api/v1"
|
||||||
|
self._alias_cache: dict[str, str] = {}
|
||||||
super().__init__(api_key, base_url=base_url, **kwargs)
|
super().__init__(api_key, base_url=base_url, **kwargs)
|
||||||
|
|
||||||
# Initialize model registry
|
# Initialize model registry
|
||||||
@@ -178,13 +179,21 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
def _resolve_model_name(self, model_name: str) -> str:
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
"""Resolve aliases defined in the OpenRouter registry."""
|
"""Resolve aliases defined in the OpenRouter registry."""
|
||||||
|
|
||||||
|
cache_key = model_name.lower()
|
||||||
|
if cache_key in self._alias_cache:
|
||||||
|
return self._alias_cache[cache_key]
|
||||||
|
|
||||||
config = self._registry.resolve(model_name)
|
config = self._registry.resolve(model_name)
|
||||||
if config:
|
if config:
|
||||||
if config.model_name != model_name:
|
if config.model_name != model_name:
|
||||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
|
||||||
return config.model_name
|
resolved = config.model_name
|
||||||
|
self._alias_cache[cache_key] = resolved
|
||||||
|
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||||
|
return resolved
|
||||||
|
|
||||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||||
|
self._alias_cache[cache_key] = model_name
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
|||||||
@@ -205,6 +205,18 @@ class ModelProviderRegistry:
|
|||||||
logging.warning("Provider %s does not implement list_models", provider_type)
|
logging.warning("Provider %s does not implement list_models", provider_type)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if restriction_service and restriction_service.has_restrictions(provider_type):
|
||||||
|
restricted_display = cls._collect_restricted_display_names(
|
||||||
|
provider,
|
||||||
|
provider_type,
|
||||||
|
available,
|
||||||
|
restriction_service,
|
||||||
|
)
|
||||||
|
if restricted_display:
|
||||||
|
for model_name in restricted_display:
|
||||||
|
models[model_name] = provider_type
|
||||||
|
continue
|
||||||
|
|
||||||
for model_name in available:
|
for model_name in available:
|
||||||
# =====================================================================================
|
# =====================================================================================
|
||||||
# CRITICAL: Prevent double restriction filtering (Fixed Issue #98)
|
# CRITICAL: Prevent double restriction filtering (Fixed Issue #98)
|
||||||
@@ -227,6 +239,50 @@ class ModelProviderRegistry:
|
|||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _collect_restricted_display_names(
|
||||||
|
cls,
|
||||||
|
provider: ModelProvider,
|
||||||
|
provider_type: ProviderType,
|
||||||
|
available: list[str],
|
||||||
|
restriction_service,
|
||||||
|
) -> list[str] | None:
|
||||||
|
"""Derive the human-facing model list when restrictions are active."""
|
||||||
|
|
||||||
|
allowed_models = restriction_service.get_allowed_models(provider_type)
|
||||||
|
if not allowed_models:
|
||||||
|
return None
|
||||||
|
|
||||||
|
allowed_details: list[tuple[str, int]] = []
|
||||||
|
|
||||||
|
for model_name in sorted(allowed_models):
|
||||||
|
try:
|
||||||
|
capabilities = provider.get_capabilities(model_name)
|
||||||
|
except (AttributeError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
rank = capabilities.get_effective_capability_rank()
|
||||||
|
rank_value = float(rank)
|
||||||
|
except (AttributeError, TypeError, ValueError):
|
||||||
|
rank_value = 0.0
|
||||||
|
|
||||||
|
allowed_details.append((model_name, rank_value))
|
||||||
|
|
||||||
|
if allowed_details:
|
||||||
|
allowed_details.sort(key=lambda item: (-item[1], item[0]))
|
||||||
|
return [name for name, _ in allowed_details]
|
||||||
|
|
||||||
|
# Fallback: intersect the allowlist with the provider-advertised names.
|
||||||
|
available_lookup = {name.lower(): name for name in available}
|
||||||
|
display_names: list[str] = []
|
||||||
|
for model_name in sorted(allowed_models):
|
||||||
|
lowered = model_name.lower()
|
||||||
|
if lowered in available_lookup:
|
||||||
|
display_names.append(available_lookup[lowered])
|
||||||
|
|
||||||
|
return display_names
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
|
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
|
||||||
"""Get list of available model names, optionally filtered by provider.
|
"""Get list of available model names, optionally filtered by provider.
|
||||||
|
|||||||
18
server.py
18
server.py
@@ -492,15 +492,25 @@ def configure_providers():
|
|||||||
|
|
||||||
# Register providers in priority order:
|
# Register providers in priority order:
|
||||||
# 1. Native APIs first (most direct and efficient)
|
# 1. Native APIs first (most direct and efficient)
|
||||||
|
registered_providers = []
|
||||||
|
|
||||||
if has_native_apis:
|
if has_native_apis:
|
||||||
if gemini_key and gemini_key != "your_gemini_api_key_here":
|
if gemini_key and gemini_key != "your_gemini_api_key_here":
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
registered_providers.append(ProviderType.GOOGLE.value)
|
||||||
|
logger.debug(f"Registered provider: {ProviderType.GOOGLE.value}")
|
||||||
if openai_key and openai_key != "your_openai_api_key_here":
|
if openai_key and openai_key != "your_openai_api_key_here":
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
registered_providers.append(ProviderType.OPENAI.value)
|
||||||
|
logger.debug(f"Registered provider: {ProviderType.OPENAI.value}")
|
||||||
if xai_key and xai_key != "your_xai_api_key_here":
|
if xai_key and xai_key != "your_xai_api_key_here":
|
||||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||||
|
registered_providers.append(ProviderType.XAI.value)
|
||||||
|
logger.debug(f"Registered provider: {ProviderType.XAI.value}")
|
||||||
if dial_key and dial_key != "your_dial_api_key_here":
|
if dial_key and dial_key != "your_dial_api_key_here":
|
||||||
ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)
|
||||||
|
registered_providers.append(ProviderType.DIAL.value)
|
||||||
|
logger.debug(f"Registered provider: {ProviderType.DIAL.value}")
|
||||||
|
|
||||||
# 2. Custom provider second (for local/private models)
|
# 2. Custom provider second (for local/private models)
|
||||||
if has_custom:
|
if has_custom:
|
||||||
@@ -511,10 +521,18 @@ def configure_providers():
|
|||||||
return CustomProvider(api_key=api_key or "", base_url=base_url) # Use provided API key or empty string
|
return CustomProvider(api_key=api_key or "", base_url=base_url) # Use provided API key or empty string
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
|
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
|
||||||
|
registered_providers.append(ProviderType.CUSTOM.value)
|
||||||
|
logger.debug(f"Registered provider: {ProviderType.CUSTOM.value}")
|
||||||
|
|
||||||
# 3. OpenRouter last (catch-all for everything else)
|
# 3. OpenRouter last (catch-all for everything else)
|
||||||
if has_openrouter:
|
if has_openrouter:
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||||
|
registered_providers.append(ProviderType.OPENROUTER.value)
|
||||||
|
logger.debug(f"Registered provider: {ProviderType.OPENROUTER.value}")
|
||||||
|
|
||||||
|
# Log all registered providers
|
||||||
|
if registered_providers:
|
||||||
|
logger.info(f"Registered providers: {', '.join(registered_providers)}")
|
||||||
|
|
||||||
# Require at least one valid provider
|
# Require at least one valid provider
|
||||||
if not valid_providers:
|
if not valid_providers:
|
||||||
|
|||||||
@@ -63,27 +63,30 @@ class TestAliasTargetRestrictions:
|
|||||||
assert provider.validate_model_name("o4mini")
|
assert provider.validate_model_name("o4mini")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
|
||||||
def test_restriction_policy_allows_only_alias_when_alias_specified(self):
|
def test_restriction_policy_alias_allows_canonical(self):
|
||||||
"""Test that restriction policy allows only the alias when just alias is specified.
|
"""Alias-only allowlists should permit both the alias and its canonical target."""
|
||||||
|
|
||||||
If you restrict to 'mini' (which is an alias for gpt-5-mini),
|
|
||||||
only the alias should work, not other models.
|
|
||||||
This is the correct restrictive behavior.
|
|
||||||
"""
|
|
||||||
# Clear cached restriction service
|
|
||||||
import utils.model_restrictions
|
import utils.model_restrictions
|
||||||
|
|
||||||
utils.model_restrictions._restriction_service = None
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
provider = OpenAIModelProvider(api_key="test-key")
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# Only the alias should be allowed
|
|
||||||
assert provider.validate_model_name("mini")
|
assert provider.validate_model_name("mini")
|
||||||
# Direct target for this alias should NOT be allowed (mini -> gpt-5-mini)
|
assert provider.validate_model_name("gpt-5-mini")
|
||||||
assert not provider.validate_model_name("gpt-5-mini")
|
|
||||||
# Other models should NOT be allowed
|
|
||||||
assert not provider.validate_model_name("o4-mini")
|
assert not provider.validate_model_name("o4-mini")
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"})
|
||||||
|
def test_restriction_policy_alias_allows_short_name(self):
|
||||||
|
"""Common aliases like 'gpt5' should allow their canonical forms."""
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
assert provider.validate_model_name("gpt5")
|
||||||
|
assert provider.validate_model_name("gpt-5")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target
|
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target
|
||||||
def test_gemini_restriction_policy_allows_alias_when_target_allowed(self):
|
def test_gemini_restriction_policy_allows_alias_when_target_allowed(self):
|
||||||
"""Test Gemini restriction policy allows alias when target is allowed."""
|
"""Test Gemini restriction policy allows alias when target is allowed."""
|
||||||
@@ -99,19 +102,16 @@ class TestAliasTargetRestrictions:
|
|||||||
assert provider.validate_model_name("flash")
|
assert provider.validate_model_name("flash")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only
|
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only
|
||||||
def test_gemini_restriction_policy_allows_only_alias_when_alias_specified(self):
|
def test_gemini_restriction_policy_alias_allows_canonical(self):
|
||||||
"""Test Gemini restriction policy allows only alias when just alias is specified."""
|
"""Gemini alias allowlists should permit canonical forms."""
|
||||||
# Clear cached restriction service
|
|
||||||
import utils.model_restrictions
|
import utils.model_restrictions
|
||||||
|
|
||||||
utils.model_restrictions._restriction_service = None
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# Only the alias should be allowed
|
|
||||||
assert provider.validate_model_name("flash")
|
assert provider.validate_model_name("flash")
|
||||||
# Direct target should NOT be allowed
|
assert provider.validate_model_name("gemini-2.5-flash")
|
||||||
assert not provider.validate_model_name("gemini-2.5-flash")
|
|
||||||
|
|
||||||
def test_restriction_service_validation_includes_all_targets(self):
|
def test_restriction_service_validation_includes_all_targets(self):
|
||||||
"""Test that restriction service validation knows about all aliases and targets."""
|
"""Test that restriction service validation knows about all aliases and targets."""
|
||||||
@@ -153,6 +153,30 @@ class TestAliasTargetRestrictions:
|
|||||||
assert provider.validate_model_name("o4-mini") # target
|
assert provider.validate_model_name("o4-mini") # target
|
||||||
assert provider.validate_model_name("o4mini") # alias for o4-mini
|
assert provider.validate_model_name("o4mini") # alias for o4-mini
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"}, clear=True)
|
||||||
|
def test_service_alias_allows_canonical_openai(self):
|
||||||
|
"""ModelRestrictionService should permit canonical names resolved from aliases."""
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
|
||||||
|
assert service.is_allowed(ProviderType.OPENAI, "gpt-5")
|
||||||
|
assert provider.validate_model_name("gpt-5")
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}, clear=True)
|
||||||
|
def test_service_alias_allows_canonical_gemini(self):
|
||||||
|
"""Gemini alias allowlists should permit canonical forms."""
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
|
||||||
|
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
|
||||||
|
assert provider.validate_model_name("gemini-2.5-flash")
|
||||||
|
|
||||||
def test_alias_target_policy_regression_prevention(self):
|
def test_alias_target_policy_regression_prevention(self):
|
||||||
"""Regression test to ensure aliases and targets are both validated properly.
|
"""Regression test to ensure aliases and targets are both validated properly.
|
||||||
|
|
||||||
|
|||||||
@@ -106,19 +106,35 @@ class TestAutoMode:
|
|||||||
|
|
||||||
def test_tool_schema_in_normal_mode(self):
|
def test_tool_schema_in_normal_mode(self):
|
||||||
"""Test that tool schemas don't require model in normal mode"""
|
"""Test that tool schemas don't require model in normal mode"""
|
||||||
# This test uses the default from conftest.py which sets non-auto mode
|
# Save original
|
||||||
# The conftest.py mock_provider_availability fixture ensures the model is available
|
original = os.environ.get("DEFAULT_MODEL", "")
|
||||||
tool = ChatTool()
|
|
||||||
schema = tool.get_input_schema()
|
|
||||||
|
|
||||||
# Model should not be required when default model is configured
|
try:
|
||||||
assert "model" not in schema["required"]
|
# Set to a specific model (not auto mode)
|
||||||
|
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
|
||||||
|
import config
|
||||||
|
|
||||||
# Model field should have simpler description
|
importlib.reload(config)
|
||||||
model_schema = schema["properties"]["model"]
|
|
||||||
assert "enum" not in model_schema
|
tool = ChatTool()
|
||||||
assert "listmodels" in model_schema["description"]
|
schema = tool.get_input_schema()
|
||||||
assert "default model" in model_schema["description"].lower()
|
|
||||||
|
# Model should not be required when default model is configured
|
||||||
|
assert "model" not in schema["required"]
|
||||||
|
|
||||||
|
# Model field should have simpler description
|
||||||
|
model_schema = schema["properties"]["model"]
|
||||||
|
assert "enum" not in model_schema
|
||||||
|
assert "listmodels" in model_schema["description"]
|
||||||
|
assert "default model" in model_schema["description"].lower()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore
|
||||||
|
if original:
|
||||||
|
os.environ["DEFAULT_MODEL"] = original
|
||||||
|
else:
|
||||||
|
os.environ.pop("DEFAULT_MODEL", None)
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_auto_mode_requires_model_parameter(self):
|
async def test_auto_mode_requires_model_parameter(self):
|
||||||
|
|||||||
203
tests/test_auto_mode_model_listing.py
Normal file
203
tests/test_auto_mode_model_listing.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""Tests covering model restriction-aware error messaging in auto mode."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import utils.model_restrictions as model_restrictions
|
||||||
|
from providers.gemini import GeminiModelProvider
|
||||||
|
from providers.openai_provider import OpenAIModelProvider
|
||||||
|
from providers.openrouter import OpenRouterProvider
|
||||||
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
from providers.xai import XAIModelProvider
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_available_models(message: str) -> list[str]:
|
||||||
|
"""Parse the available model list from the error message."""
|
||||||
|
|
||||||
|
marker = "Available models: "
|
||||||
|
if marker not in message:
|
||||||
|
raise AssertionError(f"Expected '{marker}' in message: {message}")
|
||||||
|
|
||||||
|
start = message.index(marker) + len(marker)
|
||||||
|
end = message.find(". Suggested", start)
|
||||||
|
if end == -1:
|
||||||
|
end = len(message)
|
||||||
|
|
||||||
|
available_segment = message[start:end].strip()
|
||||||
|
if not available_segment:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [item.strip() for item in available_segment.split(",")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def reset_registry():
|
||||||
|
"""Ensure registry and restriction service state is isolated."""
|
||||||
|
|
||||||
|
ModelProviderRegistry.reset_for_testing()
|
||||||
|
model_restrictions._restriction_service = None
|
||||||
|
yield
|
||||||
|
ModelProviderRegistry.reset_for_testing()
|
||||||
|
model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
|
||||||
|
def _register_core_providers(*, include_xai: bool = False):
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||||
|
if include_xai:
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_mock_provider
|
||||||
|
def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
|
||||||
|
"""Error payload should surface only the allowed models for each provider."""
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEFAULT_MODEL", "auto")
|
||||||
|
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
|
||||||
|
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
|
||||||
|
monkeypatch.delenv("XAI_API_KEY", raising=False)
|
||||||
|
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
|
||||||
|
try:
|
||||||
|
import dotenv
|
||||||
|
|
||||||
|
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
monkeypatch.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro")
|
||||||
|
monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")
|
||||||
|
monkeypatch.setenv("OPENROUTER_ALLOWED_MODELS", "gpt5nano")
|
||||||
|
monkeypatch.setenv("XAI_ALLOWED_MODELS", "")
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
_register_core_providers()
|
||||||
|
|
||||||
|
import server
|
||||||
|
|
||||||
|
importlib.reload(server)
|
||||||
|
|
||||||
|
# Reload may have re-applied .env overrides; enforce our test configuration
|
||||||
|
for key, value in (
|
||||||
|
("DEFAULT_MODEL", "auto"),
|
||||||
|
("GEMINI_API_KEY", "test-gemini"),
|
||||||
|
("OPENAI_API_KEY", "test-openai"),
|
||||||
|
("OPENROUTER_API_KEY", "test-openrouter"),
|
||||||
|
("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro"),
|
||||||
|
("OPENAI_ALLOWED_MODELS", "gpt-5"),
|
||||||
|
("OPENROUTER_ALLOWED_MODELS", "gpt5nano"),
|
||||||
|
("XAI_ALLOWED_MODELS", ""),
|
||||||
|
):
|
||||||
|
monkeypatch.setenv(key, value)
|
||||||
|
|
||||||
|
for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
|
||||||
|
monkeypatch.delenv(var, raising=False)
|
||||||
|
|
||||||
|
ModelProviderRegistry.reset_for_testing()
|
||||||
|
model_restrictions._restriction_service = None
|
||||||
|
server.configure_providers()
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
server.handle_call_tool(
|
||||||
|
"chat",
|
||||||
|
{
|
||||||
|
"model": "gpt5mini",
|
||||||
|
"prompt": "Tell me about your strengths",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
payload = json.loads(result[0].text)
|
||||||
|
assert payload["status"] == "error"
|
||||||
|
|
||||||
|
available_models = _extract_available_models(payload["content"])
|
||||||
|
assert set(available_models) == {"gemini-2.5-pro", "gpt-5", "gpt5nano", "openai/gpt-5-nano"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_mock_provider
def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, reset_registry):
    """When no restrictions are set, the full high-capability catalogue should appear.

    The test configures API keys for Gemini, OpenAI, OpenRouter and X.AI,
    removes every ``*_ALLOWED_MODELS`` variable, reloads ``config`` and
    ``server``, and then calls the ``chat`` tool with an unknown model name.
    The model list parsed from the resulting error payload must contain the
    flagship model of each native provider and at least five entries overall.

    Args:
        monkeypatch: pytest fixture used for all environment/attribute changes
            so they are undone automatically after the test.
        reset_registry: fixture requested for its side effects only; it is not
            referenced in the body.
    """
    # Every value the assertions depend on lives here so that the initial
    # setup and the post-reload enforcement below cannot drift apart.
    provider_env = (
        ("DEFAULT_MODEL", "auto"),
        ("GEMINI_API_KEY", "test-gemini"),
        ("OPENAI_API_KEY", "test-openai"),
        ("OPENROUTER_API_KEY", "test-openrouter"),
        ("XAI_API_KEY", "test-xai"),
    )
    allow_list_vars = (
        "GOOGLE_ALLOWED_MODELS",
        "OPENAI_ALLOWED_MODELS",
        "OPENROUTER_ALLOWED_MODELS",
        "XAI_ALLOWED_MODELS",
        "DIAL_ALLOWED_MODELS",
    )

    for key, value in provider_env:
        monkeypatch.setenv(key, value)
    monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")

    try:
        import dotenv
    except ModuleNotFoundError:
        # python-dotenv is not installed, so there is nothing to stub out.
        dotenv = None
    if dotenv is not None:
        # Make any .env lookup report only the override flag, switched off.
        monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})

    for var in allow_list_vars:
        monkeypatch.delenv(var, raising=False)

    import config

    importlib.reload(config)

    _register_core_providers(include_xai=True)

    import server

    importlib.reload(server)

    # The reloads may have touched the environment again (the sibling test
    # notes that .env overrides can be re-applied), so enforce the test
    # configuration once more. XAI_API_KEY is included because the "grok-4"
    # assertion below depends on the X.AI provider being configured.
    for key, value in provider_env:
        monkeypatch.setenv(key, value)

    for var in (*allow_list_vars, "CUSTOM_API_URL", "CUSTOM_API_KEY"):
        monkeypatch.delenv(var, raising=False)

    # Rebuild provider state from the environment prepared above.
    ModelProviderRegistry.reset_for_testing()
    model_restrictions._restriction_service = None
    server.configure_providers()

    result = asyncio.run(
        server.handle_call_tool(
            "chat",
            {
                "model": "dummymodel",
                "prompt": "Hi there",
            },
        )
    )

    assert len(result) == 1
    payload = json.loads(result[0].text)
    assert payload["status"] == "error"

    available_models = _extract_available_models(payload["content"])
    assert "gemini-2.5-pro" in available_models
    assert "gpt-5" in available_models
    assert "grok-4" in available_models
    assert len(available_models) >= 5
|
||||||
@@ -3,6 +3,7 @@ Tests for dynamic context request and collaboration features
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -157,95 +158,120 @@ class TestDynamicContextRequests:
|
|||||||
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
||||||
async def test_clarification_with_suggested_action(self, mock_get_provider, analyze_tool):
|
async def test_clarification_with_suggested_action(self, mock_get_provider, analyze_tool):
|
||||||
"""Test clarification request with suggested next action"""
|
"""Test clarification request with suggested next action"""
|
||||||
clarification_json = json.dumps(
|
import importlib
|
||||||
{
|
|
||||||
"status": "files_required_to_continue",
|
from providers.registry import ModelProviderRegistry
|
||||||
"mandatory_instructions": "I need to see the database configuration to analyze the connection error",
|
|
||||||
"files_needed": ["config/database.yml", "src/db.py"],
|
# Ensure deterministic model configuration for this test regardless of previous suites
|
||||||
"suggested_next_action": {
|
ModelProviderRegistry.reset_for_testing()
|
||||||
"tool": "analyze",
|
|
||||||
"args": {
|
original_default = os.environ.get("DEFAULT_MODEL")
|
||||||
"prompt": "Analyze database connection timeout issue",
|
|
||||||
"relevant_files": [
|
try:
|
||||||
"/config/database.yml",
|
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
|
||||||
"/src/db.py",
|
import config
|
||||||
"/logs/error.log",
|
|
||||||
],
|
importlib.reload(config)
|
||||||
|
|
||||||
|
clarification_json = json.dumps(
|
||||||
|
{
|
||||||
|
"status": "files_required_to_continue",
|
||||||
|
"mandatory_instructions": "I need to see the database configuration to analyze the connection error",
|
||||||
|
"files_needed": ["config/database.yml", "src/db.py"],
|
||||||
|
"suggested_next_action": {
|
||||||
|
"tool": "analyze",
|
||||||
|
"args": {
|
||||||
|
"prompt": "Analyze database connection timeout issue",
|
||||||
|
"relevant_files": [
|
||||||
|
"/config/database.yml",
|
||||||
|
"/src/db.py",
|
||||||
|
"/logs/error.log",
|
||||||
|
],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
ensure_ascii=False,
|
||||||
ensure_ascii=False,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
mock_provider = create_mock_provider()
|
mock_provider = create_mock_provider()
|
||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
|
content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await analyze_tool.execute(
|
result = await analyze_tool.execute(
|
||||||
{
|
{
|
||||||
"step": "Analyze database connection timeout issue",
|
"step": "Analyze database connection timeout issue",
|
||||||
"step_number": 1,
|
"step_number": 1,
|
||||||
"total_steps": 1,
|
"total_steps": 1,
|
||||||
"next_step_required": False,
|
"next_step_required": False,
|
||||||
"findings": "Initial database timeout analysis",
|
"findings": "Initial database timeout analysis",
|
||||||
"relevant_files": ["/absolute/logs/error.log"],
|
"relevant_files": ["/absolute/logs/error.log"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
|
||||||
response_data = json.loads(result[0].text)
|
response_data = json.loads(result[0].text)
|
||||||
|
|
||||||
# Workflow tools should either promote clarification status or handle it in expert analysis
|
# Workflow tools should either promote clarification status or handle it in expert analysis
|
||||||
if response_data["status"] == "files_required_to_continue":
|
if response_data["status"] == "files_required_to_continue":
|
||||||
# Clarification was properly promoted to main status
|
# Clarification was properly promoted to main status
|
||||||
# Check if mandatory_instructions is at top level or in content
|
# Check if mandatory_instructions is at top level or in content
|
||||||
if "mandatory_instructions" in response_data:
|
if "mandatory_instructions" in response_data:
|
||||||
assert "database configuration" in response_data["mandatory_instructions"]
|
assert "database configuration" in response_data["mandatory_instructions"]
|
||||||
assert "files_needed" in response_data
|
assert "files_needed" in response_data
|
||||||
assert "config/database.yml" in response_data["files_needed"]
|
assert "config/database.yml" in response_data["files_needed"]
|
||||||
assert "src/db.py" in response_data["files_needed"]
|
assert "src/db.py" in response_data["files_needed"]
|
||||||
elif "content" in response_data:
|
elif "content" in response_data:
|
||||||
# Parse content JSON for workflow tools
|
# Parse content JSON for workflow tools
|
||||||
try:
|
try:
|
||||||
content_json = json.loads(response_data["content"])
|
content_json = json.loads(response_data["content"])
|
||||||
assert "mandatory_instructions" in content_json
|
assert "mandatory_instructions" in content_json
|
||||||
|
assert (
|
||||||
|
"database configuration" in content_json["mandatory_instructions"]
|
||||||
|
or "database" in content_json["mandatory_instructions"]
|
||||||
|
)
|
||||||
|
assert "files_needed" in content_json
|
||||||
|
files_needed_str = str(content_json["files_needed"])
|
||||||
|
assert (
|
||||||
|
"config/database.yml" in files_needed_str
|
||||||
|
or "config" in files_needed_str
|
||||||
|
or "database" in files_needed_str
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Content is not JSON, check if it contains required text
|
||||||
|
content = response_data["content"]
|
||||||
|
assert "database configuration" in content or "config" in content
|
||||||
|
elif response_data["status"] == "calling_expert_analysis":
|
||||||
|
# Clarification may be handled in expert analysis section
|
||||||
|
if "expert_analysis" in response_data:
|
||||||
|
expert_analysis = response_data["expert_analysis"]
|
||||||
|
expert_content = str(expert_analysis)
|
||||||
assert (
|
assert (
|
||||||
"database configuration" in content_json["mandatory_instructions"]
|
"database configuration" in expert_content
|
||||||
or "database" in content_json["mandatory_instructions"]
|
or "config/database.yml" in expert_content
|
||||||
|
or "files_required_to_continue" in expert_content
|
||||||
)
|
)
|
||||||
assert "files_needed" in content_json
|
else:
|
||||||
files_needed_str = str(content_json["files_needed"])
|
# Some other status - ensure it's a valid workflow response
|
||||||
assert (
|
assert "step_number" in response_data
|
||||||
"config/database.yml" in files_needed_str
|
|
||||||
or "config" in files_needed_str
|
|
||||||
or "database" in files_needed_str
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# Content is not JSON, check if it contains required text
|
|
||||||
content = response_data["content"]
|
|
||||||
assert "database configuration" in content or "config" in content
|
|
||||||
elif response_data["status"] == "calling_expert_analysis":
|
|
||||||
# Clarification may be handled in expert analysis section
|
|
||||||
if "expert_analysis" in response_data:
|
|
||||||
expert_analysis = response_data["expert_analysis"]
|
|
||||||
expert_content = str(expert_analysis)
|
|
||||||
assert (
|
|
||||||
"database configuration" in expert_content
|
|
||||||
or "config/database.yml" in expert_content
|
|
||||||
or "files_required_to_continue" in expert_content
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Some other status - ensure it's a valid workflow response
|
|
||||||
assert "step_number" in response_data
|
|
||||||
|
|
||||||
# Check for suggested next action
|
# Check for suggested next action
|
||||||
if "suggested_next_action" in response_data:
|
if "suggested_next_action" in response_data:
|
||||||
action = response_data["suggested_next_action"]
|
action = response_data["suggested_next_action"]
|
||||||
assert action["tool"] == "analyze"
|
assert action["tool"] == "analyze"
|
||||||
|
finally:
|
||||||
|
if original_default is not None:
|
||||||
|
os.environ["DEFAULT_MODEL"] = original_default
|
||||||
|
else:
|
||||||
|
os.environ.pop("DEFAULT_MODEL", None)
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
importlib.reload(config)
|
||||||
|
ModelProviderRegistry.reset_for_testing()
|
||||||
|
|
||||||
def test_tool_output_model_serialization(self):
|
def test_tool_output_model_serialization(self):
|
||||||
"""Test ToolOutput model serialization"""
|
"""Test ToolOutput model serialization"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
from providers.base import ModelProvider
|
from providers.base import ModelProvider
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ModelCapabilities, ProviderType
|
||||||
from tools.listmodels import ListModelsTool
|
from tools.listmodels import ListModelsTool
|
||||||
|
|
||||||
|
|
||||||
@@ -23,10 +23,63 @@ class TestListModelsRestrictions(unittest.TestCase):
|
|||||||
self.mock_openrouter = MagicMock(spec=ModelProvider)
|
self.mock_openrouter = MagicMock(spec=ModelProvider)
|
||||||
self.mock_openrouter.provider_type = ProviderType.OPENROUTER
|
self.mock_openrouter.provider_type = ProviderType.OPENROUTER
|
||||||
|
|
||||||
|
def make_capabilities(
|
||||||
|
canonical: str, friendly: str, *, aliases=None, context: int = 200_000
|
||||||
|
) -> ModelCapabilities:
|
||||||
|
return ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENROUTER,
|
||||||
|
model_name=canonical,
|
||||||
|
friendly_name=friendly,
|
||||||
|
intelligence_score=20,
|
||||||
|
description=friendly,
|
||||||
|
aliases=aliases or [],
|
||||||
|
context_window=context,
|
||||||
|
max_output_tokens=context,
|
||||||
|
supports_extended_thinking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
opus_caps = make_capabilities(
|
||||||
|
"anthropic/claude-opus-4-20240229",
|
||||||
|
"Claude Opus",
|
||||||
|
aliases=["opus"],
|
||||||
|
)
|
||||||
|
sonnet_caps = make_capabilities(
|
||||||
|
"anthropic/claude-sonnet-4-20240229",
|
||||||
|
"Claude Sonnet",
|
||||||
|
aliases=["sonnet"],
|
||||||
|
)
|
||||||
|
deepseek_caps = make_capabilities(
|
||||||
|
"deepseek/deepseek-r1-0528:free",
|
||||||
|
"DeepSeek R1",
|
||||||
|
aliases=[],
|
||||||
|
)
|
||||||
|
qwen_caps = make_capabilities(
|
||||||
|
"qwen/qwen3-235b-a22b-04-28:free",
|
||||||
|
"Qwen3",
|
||||||
|
aliases=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._openrouter_caps_map = {
|
||||||
|
"anthropic/claude-opus-4": opus_caps,
|
||||||
|
"opus": opus_caps,
|
||||||
|
"anthropic/claude-opus-4-20240229": opus_caps,
|
||||||
|
"anthropic/claude-sonnet-4": sonnet_caps,
|
||||||
|
"sonnet": sonnet_caps,
|
||||||
|
"anthropic/claude-sonnet-4-20240229": sonnet_caps,
|
||||||
|
"deepseek/deepseek-r1-0528:free": deepseek_caps,
|
||||||
|
"qwen/qwen3-235b-a22b-04-28:free": qwen_caps,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.mock_openrouter.get_capabilities.side_effect = self._openrouter_caps_map.__getitem__
|
||||||
|
self.mock_openrouter.get_capabilities_by_rank.return_value = []
|
||||||
|
self.mock_openrouter.list_models.return_value = []
|
||||||
|
|
||||||
# Create mock Gemini provider for comparison
|
# Create mock Gemini provider for comparison
|
||||||
self.mock_gemini = MagicMock(spec=ModelProvider)
|
self.mock_gemini = MagicMock(spec=ModelProvider)
|
||||||
self.mock_gemini.provider_type = ProviderType.GOOGLE
|
self.mock_gemini.provider_type = ProviderType.GOOGLE
|
||||||
self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
||||||
|
self.mock_gemini.get_capabilities_by_rank.return_value = []
|
||||||
|
self.mock_gemini.get_capabilities_by_rank.return_value = []
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Clean up after tests."""
|
"""Clean up after tests."""
|
||||||
@@ -159,7 +212,7 @@ class TestListModelsRestrictions(unittest.TestCase):
|
|||||||
for line in lines:
|
for line in lines:
|
||||||
if "OpenRouter" in line and "✅" in line:
|
if "OpenRouter" in line and "✅" in line:
|
||||||
openrouter_section_found = True
|
openrouter_section_found = True
|
||||||
elif "Available Models" in line and openrouter_section_found:
|
elif ("Models (policy restricted)" in line or "Available Models" in line) and openrouter_section_found:
|
||||||
in_openrouter_section = True
|
in_openrouter_section = True
|
||||||
elif in_openrouter_section:
|
elif in_openrouter_section:
|
||||||
# Check for lines with model names in backticks
|
# Check for lines with model names in backticks
|
||||||
@@ -179,11 +232,11 @@ class TestListModelsRestrictions(unittest.TestCase):
|
|||||||
len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
|
len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify list_models was called with respect_restrictions=True
|
# Verify we did not fall back to unrestricted listing
|
||||||
self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)
|
self.mock_openrouter.list_models.assert_not_called()
|
||||||
|
|
||||||
# Check for restriction note
|
# Check for restriction note
|
||||||
self.assertIn("Restricted to models matching:", result)
|
self.assertIn("OpenRouter models restricted by", result)
|
||||||
|
|
||||||
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True)
|
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True)
|
||||||
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
||||||
|
|||||||
@@ -121,38 +121,59 @@ class TestModelMetadataContinuation:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_previous_assistant_turn_defaults(self):
|
async def test_no_previous_assistant_turn_defaults(self):
|
||||||
"""Test behavior when there's no previous assistant turn."""
|
"""Test behavior when there's no previous assistant turn."""
|
||||||
thread_id = create_thread("chat", {"prompt": "test"})
|
# Save and set DEFAULT_MODEL for test
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
|
||||||
# Only add user turns
|
original_default = os.environ.get("DEFAULT_MODEL", "")
|
||||||
add_turn(thread_id, "user", "First question")
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
add_turn(thread_id, "user", "Second question")
|
import config
|
||||||
|
import utils.model_context
|
||||||
|
|
||||||
arguments = {"continuation_id": thread_id}
|
importlib.reload(config)
|
||||||
|
importlib.reload(utils.model_context)
|
||||||
|
|
||||||
# Mock dependencies
|
try:
|
||||||
with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
|
thread_id = create_thread("chat", {"prompt": "test"})
|
||||||
mock_calc.return_value = MagicMock(
|
|
||||||
total_tokens=200000,
|
|
||||||
content_tokens=160000,
|
|
||||||
response_tokens=40000,
|
|
||||||
file_tokens=64000,
|
|
||||||
history_tokens=64000,
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
|
# Only add user turns
|
||||||
mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)
|
add_turn(thread_id, "user", "First question")
|
||||||
|
add_turn(thread_id, "user", "Second question")
|
||||||
|
|
||||||
# Call the actual function
|
arguments = {"continuation_id": thread_id}
|
||||||
enhanced_args = await reconstruct_thread_context(arguments)
|
|
||||||
|
|
||||||
# Should not have set a model
|
# Mock dependencies
|
||||||
assert enhanced_args.get("model") is None
|
with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
|
||||||
|
mock_calc.return_value = MagicMock(
|
||||||
|
total_tokens=200000,
|
||||||
|
content_tokens=160000,
|
||||||
|
response_tokens=40000,
|
||||||
|
file_tokens=64000,
|
||||||
|
history_tokens=64000,
|
||||||
|
)
|
||||||
|
|
||||||
# ModelContext should use DEFAULT_MODEL
|
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
|
||||||
model_context = ModelContext.from_arguments(enhanced_args)
|
mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)
|
||||||
from config import DEFAULT_MODEL
|
|
||||||
|
|
||||||
assert model_context.model_name == DEFAULT_MODEL
|
# Call the actual function
|
||||||
|
enhanced_args = await reconstruct_thread_context(arguments)
|
||||||
|
|
||||||
|
# Should not have set a model
|
||||||
|
assert enhanced_args.get("model") is None
|
||||||
|
|
||||||
|
# ModelContext should use DEFAULT_MODEL
|
||||||
|
model_context = ModelContext.from_arguments(enhanced_args)
|
||||||
|
from config import DEFAULT_MODEL
|
||||||
|
|
||||||
|
assert model_context.model_name == DEFAULT_MODEL
|
||||||
|
finally:
|
||||||
|
# Restore original value
|
||||||
|
if original_default:
|
||||||
|
os.environ["DEFAULT_MODEL"] = original_default
|
||||||
|
else:
|
||||||
|
os.environ.pop("DEFAULT_MODEL", None)
|
||||||
|
importlib.reload(config)
|
||||||
|
importlib.reload(utils.model_context)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_explicit_model_overrides_previous_turn(self):
|
async def test_explicit_model_overrides_previous_turn(self):
|
||||||
|
|||||||
@@ -49,17 +49,32 @@ class TestModelRestrictionService:
|
|||||||
def test_load_multiple_models_restriction(self):
|
def test_load_multiple_models_restriction(self):
|
||||||
"""Test loading multiple allowed models."""
|
"""Test loading multiple allowed models."""
|
||||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||||
service = ModelRestrictionService()
|
# Instantiate providers so alias resolution for allow-lists is available
|
||||||
|
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# Check OpenAI models
|
from providers.registry import ModelProviderRegistry
|
||||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
|
||||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
|
||||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
|
||||||
|
|
||||||
# Check Google models
|
def fake_get_provider(provider_type, force_new=False):
|
||||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
mapping = {
|
||||||
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
ProviderType.OPENAI: openai_provider,
|
||||||
assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
ProviderType.GOOGLE: gemini_provider,
|
||||||
|
}
|
||||||
|
return mapping.get(provider_type)
|
||||||
|
|
||||||
|
with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):
|
||||||
|
|
||||||
|
service = ModelRestrictionService()
|
||||||
|
|
||||||
|
# Check OpenAI models
|
||||||
|
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||||
|
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
||||||
|
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||||
|
|
||||||
|
# Check Google models
|
||||||
|
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||||
|
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||||
|
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||||
|
|
||||||
def test_case_insensitive_and_whitespace_handling(self):
|
def test_case_insensitive_and_whitespace_handling(self):
|
||||||
"""Test that model names are case-insensitive and whitespace is trimmed."""
|
"""Test that model names are case-insensitive and whitespace is trimmed."""
|
||||||
@@ -111,13 +126,17 @@ class TestModelRestrictionService:
|
|||||||
|
|
||||||
def test_shorthand_names_in_restrictions(self):
|
def test_shorthand_names_in_restrictions(self):
|
||||||
"""Test that shorthand names work in restrictions."""
|
"""Test that shorthand names work in restrictions."""
|
||||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o3-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||||
|
# Instantiate providers so the registry can resolve aliases
|
||||||
|
OpenAIModelProvider(api_key="test-key")
|
||||||
|
GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
service = ModelRestrictionService()
|
service = ModelRestrictionService()
|
||||||
|
|
||||||
# When providers check models, they pass both resolved and original names
|
# When providers check models, they pass both resolved and original names
|
||||||
# OpenAI: 'mini' shorthand allows o4-mini
|
# OpenAI: 'o4mini' shorthand allows o4-mini
|
||||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "mini") # How providers actually call it
|
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini") # How providers actually call it
|
||||||
assert not service.is_allowed(ProviderType.OPENAI, "o4-mini") # Direct check without original (for testing)
|
assert service.is_allowed(ProviderType.OPENAI, "o4-mini") # Canonical should also be allowed
|
||||||
|
|
||||||
# OpenAI: o3-mini allowed directly
|
# OpenAI: o3-mini allowed directly
|
||||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||||
@@ -280,19 +299,25 @@ class TestProviderIntegration:
|
|||||||
|
|
||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
# Test case: Only alias "flash" is allowed, not the full name
|
from providers.registry import ModelProviderRegistry
|
||||||
# If parameters are in wrong order, this test will catch it
|
|
||||||
|
|
||||||
# Should allow "flash" alias
|
with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):
|
||||||
assert provider.validate_model_name("flash")
|
|
||||||
|
|
||||||
# Should allow getting capabilities for "flash"
|
# Test case: Only alias "flash" is allowed, not the full name
|
||||||
capabilities = provider.get_capabilities("flash")
|
# If parameters are in wrong order, this test will catch it
|
||||||
assert capabilities.model_name == "gemini-2.5-flash"
|
|
||||||
|
|
||||||
# Test the edge case: Try to use full model name when only alias is allowed
|
# Should allow "flash" alias
|
||||||
# This should NOT be allowed - only the alias "flash" is in the restriction list
|
assert provider.validate_model_name("flash")
|
||||||
assert not provider.validate_model_name("gemini-2.5-flash")
|
|
||||||
|
# Should allow getting capabilities for "flash"
|
||||||
|
capabilities = provider.get_capabilities("flash")
|
||||||
|
assert capabilities.model_name == "gemini-2.5-flash"
|
||||||
|
|
||||||
|
# Canonical form should also be allowed now that alias is on the allowlist
|
||||||
|
assert provider.validate_model_name("gemini-2.5-flash")
|
||||||
|
# Unrelated models remain blocked
|
||||||
|
assert not provider.validate_model_name("pro")
|
||||||
|
assert not provider.validate_model_name("gemini-2.5-pro")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
|
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
|
||||||
def test_gemini_parameter_order_edge_case_full_name_only(self):
|
def test_gemini_parameter_order_edge_case_full_name_only(self):
|
||||||
@@ -570,17 +595,27 @@ class TestShorthandRestrictions:
|
|||||||
|
|
||||||
# Test OpenAI provider
|
# Test OpenAI provider
|
||||||
openai_provider = OpenAIModelProvider(api_key="test-key")
|
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||||
assert openai_provider.validate_model_name("mini") # Should work with shorthand
|
|
||||||
# When restricting to "mini", you can't use "o4-mini" directly - this is correct behavior
|
|
||||||
assert not openai_provider.validate_model_name("o4-mini") # Not allowed - only shorthand is allowed
|
|
||||||
assert not openai_provider.validate_model_name("o3-mini") # Not allowed
|
|
||||||
|
|
||||||
# Test Gemini provider
|
|
||||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||||
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
|
|
||||||
# Same for Gemini - if you restrict to "flash", you can't use the full name
|
from providers.registry import ModelProviderRegistry
|
||||||
assert not gemini_provider.validate_model_name("gemini-2.5-flash") # Not allowed
|
|
||||||
assert not gemini_provider.validate_model_name("pro") # Not allowed
|
def registry_side_effect(provider_type, force_new=False):
|
||||||
|
mapping = {
|
||||||
|
ProviderType.OPENAI: openai_provider,
|
||||||
|
ProviderType.GOOGLE: gemini_provider,
|
||||||
|
}
|
||||||
|
return mapping.get(provider_type)
|
||||||
|
|
||||||
|
with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
|
||||||
|
assert openai_provider.validate_model_name("mini") # Should work with shorthand
|
||||||
|
assert openai_provider.validate_model_name("gpt-5-mini") # Canonical resolved from shorthand
|
||||||
|
assert not openai_provider.validate_model_name("o4-mini") # Unrelated model still blocked
|
||||||
|
assert not openai_provider.validate_model_name("o3-mini")
|
||||||
|
|
||||||
|
# Test Gemini provider
|
||||||
|
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
|
||||||
|
assert gemini_provider.validate_model_name("gemini-2.5-flash") # Canonical allowed
|
||||||
|
assert not gemini_provider.validate_model_name("pro") # Not allowed
|
||||||
|
|
||||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
|
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
|
||||||
def test_multiple_shorthands_for_same_model(self):
|
def test_multiple_shorthands_for_same_model(self):
|
||||||
@@ -596,9 +631,9 @@ class TestShorthandRestrictions:
|
|||||||
assert openai_provider.validate_model_name("mini") # mini -> o4-mini
|
assert openai_provider.validate_model_name("mini") # mini -> o4-mini
|
||||||
assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini
|
assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini
|
||||||
|
|
||||||
# Resolved names work only if explicitly allowed
|
# Resolved names should be allowed when their shorthands are present
|
||||||
assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed
|
assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed
|
||||||
assert not openai_provider.validate_model_name("o3-mini") # Not explicitly allowed, only shorthand
|
assert openai_provider.validate_model_name("o3-mini") # Allowed via shorthand
|
||||||
|
|
||||||
# Other models should not work
|
# Other models should not work
|
||||||
assert not openai_provider.validate_model_name("o3")
|
assert not openai_provider.validate_model_name("o3")
|
||||||
|
|||||||
@@ -260,9 +260,10 @@ class TestOpenRouterAutoMode:
|
|||||||
os.environ["DEFAULT_MODEL"] = "auto"
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
|
||||||
mock_provider_class = Mock()
|
mock_provider_class = Mock()
|
||||||
mock_provider_instance = Mock(spec=["get_provider_type", "list_models"])
|
mock_provider_instance = Mock(spec=["get_provider_type", "list_models", "get_all_model_capabilities"])
|
||||||
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
|
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
|
||||||
mock_provider_instance.list_models.return_value = []
|
mock_provider_instance.list_models.return_value = []
|
||||||
|
mock_provider_instance.get_all_model_capabilities.return_value = {}
|
||||||
mock_provider_class.return_value = mock_provider_instance
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)
|
||||||
|
|||||||
@@ -293,13 +293,7 @@ class TestOpenRouterAliasRestrictions:
|
|||||||
# o3 -> openai/o3
|
# o3 -> openai/o3
|
||||||
# gpt4.1 -> should not exist (expected to be filtered out)
|
# gpt4.1 -> should not exist (expected to be filtered out)
|
||||||
|
|
||||||
expected_models = {
|
expected_models = {"o3-mini", "pro", "flash", "o4-mini", "o3"}
|
||||||
"openai/o3-mini",
|
|
||||||
"google/gemini-2.5-pro",
|
|
||||||
"google/gemini-2.5-flash",
|
|
||||||
"openai/o4-mini",
|
|
||||||
"openai/o3",
|
|
||||||
}
|
|
||||||
|
|
||||||
available_model_names = set(available_models.keys())
|
available_model_names = set(available_models.keys())
|
||||||
|
|
||||||
@@ -355,9 +349,11 @@ class TestOpenRouterAliasRestrictions:
|
|||||||
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
|
||||||
expected_models = {
|
expected_models = {
|
||||||
"openai/o3-mini", # from alias
|
"o3-mini", # alias
|
||||||
|
"openai/o3-mini", # canonical
|
||||||
"anthropic/claude-opus-4.1", # full name
|
"anthropic/claude-opus-4.1", # full name
|
||||||
"google/gemini-2.5-flash", # from alias
|
"flash", # alias
|
||||||
|
"google/gemini-2.5-flash", # canonical
|
||||||
}
|
}
|
||||||
|
|
||||||
available_model_names = set(available_models.keys())
|
available_model_names = set(available_models.keys())
|
||||||
|
|||||||
@@ -83,9 +83,18 @@ class ListModelsTool(BaseTool):
|
|||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
output_lines = ["# Available AI Models\n"]
|
output_lines = ["# Available AI Models\n"]
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service()
|
||||||
|
restricted_models_by_provider: dict[ProviderType, list[str]] = {}
|
||||||
|
|
||||||
|
if restriction_service:
|
||||||
|
restricted_map = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
for model_name, provider_type in restricted_map.items():
|
||||||
|
restricted_models_by_provider.setdefault(provider_type, []).append(model_name)
|
||||||
|
|
||||||
# Map provider types to friendly names and their models
|
# Map provider types to friendly names and their models
|
||||||
provider_info = {
|
provider_info = {
|
||||||
ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
|
ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
|
||||||
@@ -94,6 +103,43 @@ class ListModelsTool(BaseTool):
|
|||||||
ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
|
ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def format_model_entry(provider, display_name: str) -> list[str]:
    """Render one model as Markdown bullet lines for the model listing.

    Args:
        provider: Object exposing ``get_capabilities(name)``. A ``ValueError``
            raised by that call is taken to mean the provider does not know
            the model.
        display_name: Model name or alias exactly as it should be shown.

    Returns:
        A single "not recognized" bullet when the provider rejects the name.
        Otherwise three lines: a header (``display_name → canonical`` when the
        two differ, ignoring case), a context-window summary, and the
        description.
    """
    try:
        capabilities = provider.get_capabilities(display_name)
    except ValueError:
        return [f"- `{display_name}` *(not recognized by provider)*"]

    # Every capability attribute is read defensively: a partial capabilities
    # object must degrade to placeholder text rather than abort the listing.
    canonical = getattr(capabilities, "model_name", None) or display_name
    if canonical.lower() == display_name.lower():
        header = f"- `{canonical}`"
    else:
        header = f"- `{display_name}` → `{canonical}`"

    try:
        context_value = int(getattr(capabilities, "context_window", 0) or 0)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric or non-finite window sizes are reported as unknown.
        context_value = 0

    if context_value >= 1_000_000:
        context_str = f"{context_value // 1_000_000}M context"
    elif context_value >= 1_000:
        context_str = f"{context_value // 1_000}K context"
    elif context_value > 0:
        context_str = f"{context_value} context"
    else:
        context_str = "unknown context"

    description = getattr(capabilities, "description", None) or "No description available"
    return [header, f" - {context_str}", f" - {description}"]
|
||||||
|
|
||||||
# Check each native provider type
|
# Check each native provider type
|
||||||
for provider_type, info in provider_info.items():
|
for provider_type, info in provider_info.items():
|
||||||
# Check if provider is enabled
|
# Check if provider is enabled
|
||||||
@@ -104,30 +150,49 @@ class ListModelsTool(BaseTool):
|
|||||||
|
|
||||||
if is_configured:
|
if is_configured:
|
||||||
output_lines.append("**Status**: Configured and available")
|
output_lines.append("**Status**: Configured and available")
|
||||||
output_lines.append("\n**Models**:")
|
has_restrictions = bool(restriction_service and restriction_service.has_restrictions(provider_type))
|
||||||
|
|
||||||
aliases = []
|
if has_restrictions:
|
||||||
for model_name, capabilities in provider.get_capabilities_by_rank():
|
restricted_names = sorted(set(restricted_models_by_provider.get(provider_type, [])))
|
||||||
description = capabilities.description or "No description available"
|
|
||||||
context_window = capabilities.context_window
|
|
||||||
|
|
||||||
if context_window >= 1_000_000:
|
if restricted_names:
|
||||||
context_str = f"{context_window // 1_000_000}M context"
|
output_lines.append("\n**Models (policy restricted)**:")
|
||||||
elif context_window >= 1_000:
|
for model_name in restricted_names:
|
||||||
context_str = f"{context_window // 1_000}K context"
|
output_lines.extend(format_model_entry(provider, model_name))
|
||||||
else:
|
else:
|
||||||
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
|
output_lines.append("\n*No models are currently allowed by restriction policy.*")
|
||||||
|
else:
|
||||||
|
output_lines.append("\n**Models**:")
|
||||||
|
|
||||||
output_lines.append(f"- `{model_name}` - {context_str}")
|
aliases = []
|
||||||
output_lines.append(f" - {description}")
|
for model_name, capabilities in provider.get_capabilities_by_rank():
|
||||||
|
try:
|
||||||
|
description = capabilities.description or "No description available"
|
||||||
|
except AttributeError:
|
||||||
|
description = "No description available"
|
||||||
|
|
||||||
for alias in capabilities.aliases or []:
|
try:
|
||||||
if alias != model_name:
|
context_window = capabilities.context_window or 0
|
||||||
aliases.append(f"- `{alias}` → `{model_name}`")
|
except AttributeError:
|
||||||
|
context_window = 0
|
||||||
|
|
||||||
if aliases:
|
if context_window >= 1_000_000:
|
||||||
output_lines.append("\n**Aliases**:")
|
context_str = f"{context_window // 1_000_000}M context"
|
||||||
output_lines.extend(sorted(aliases))
|
elif context_window >= 1_000:
|
||||||
|
context_str = f"{context_window // 1_000}K context"
|
||||||
|
else:
|
||||||
|
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
|
||||||
|
|
||||||
|
output_lines.append(f"- `{model_name}` - {context_str}")
|
||||||
|
output_lines.append(f" - {description}")
|
||||||
|
|
||||||
|
for alias in capabilities.aliases or []:
|
||||||
|
if alias != model_name:
|
||||||
|
aliases.append(f"- `{alias}` → `{model_name}`")
|
||||||
|
|
||||||
|
if aliases:
|
||||||
|
output_lines.append("\n**Aliases**:")
|
||||||
|
output_lines.extend(sorted(aliases))
|
||||||
else:
|
else:
|
||||||
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
|
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
|
||||||
|
|
||||||
@@ -144,19 +209,10 @@ class ListModelsTool(BaseTool):
|
|||||||
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
|
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get OpenRouter provider from registry to properly apply restrictions
|
|
||||||
from providers.registry import ModelProviderRegistry
|
|
||||||
from providers.shared import ProviderType
|
|
||||||
|
|
||||||
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||||
if provider:
|
if provider:
|
||||||
# Get models with restrictions applied
|
|
||||||
available_models = provider.list_models(respect_restrictions=True)
|
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# Group by provider and retain ranking information for consistent ordering
|
|
||||||
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
|
|
||||||
|
|
||||||
def _format_context(tokens: int) -> str:
|
def _format_context(tokens: int) -> str:
|
||||||
if not tokens:
|
if not tokens:
|
||||||
return "?"
|
return "?"
|
||||||
@@ -166,53 +222,83 @@ class ListModelsTool(BaseTool):
|
|||||||
return f"{tokens // 1_000}K"
|
return f"{tokens // 1_000}K"
|
||||||
return str(tokens)
|
return str(tokens)
|
||||||
|
|
||||||
for model_name in available_models:
|
has_restrictions = bool(
|
||||||
config = registry.resolve(model_name)
|
restriction_service and restriction_service.has_restrictions(ProviderType.OPENROUTER)
|
||||||
provider_name = "other"
|
)
|
||||||
if config and "/" in config.model_name:
|
|
||||||
provider_name = config.model_name.split("/")[0]
|
|
||||||
elif "/" in model_name:
|
|
||||||
provider_name = model_name.split("/")[0]
|
|
||||||
|
|
||||||
providers_models.setdefault(provider_name, [])
|
if has_restrictions:
|
||||||
|
restricted_names = sorted(set(restricted_models_by_provider.get(ProviderType.OPENROUTER, [])))
|
||||||
|
|
||||||
rank = config.get_effective_capability_rank() if config else 0
|
output_lines.append("\n**Models (policy restricted)**:")
|
||||||
providers_models[provider_name].append((rank, model_name, config))
|
if restricted_names:
|
||||||
|
for model_name in restricted_names:
|
||||||
|
try:
|
||||||
|
caps = provider.get_capabilities(model_name)
|
||||||
|
except ValueError:
|
||||||
|
output_lines.append(f"- `{model_name}` *(not recognized by provider)*")
|
||||||
|
continue
|
||||||
|
|
||||||
output_lines.append("\n**Available Models**:")
|
context_value = int(caps.context_window or 0)
|
||||||
for provider_name, models in sorted(providers_models.items()):
|
context_str = _format_context(context_value)
|
||||||
output_lines.append(f"\n*{provider_name.title()}:*")
|
|
||||||
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
|
|
||||||
if config:
|
|
||||||
context_str = _format_context(config.context_window)
|
|
||||||
suffix_parts = [f"{context_str} context"]
|
suffix_parts = [f"{context_str} context"]
|
||||||
if getattr(config, "supports_extended_thinking", False):
|
if caps.supports_extended_thinking:
|
||||||
suffix_parts.append("thinking")
|
suffix_parts.append("thinking")
|
||||||
suffix = ", ".join(suffix_parts)
|
suffix = ", ".join(suffix_parts)
|
||||||
output_lines.append(f"- `{alias}` → `{config.model_name}` (score {rank}, {suffix})")
|
|
||||||
else:
|
|
||||||
output_lines.append(f"- `{alias}` (score {rank})")
|
|
||||||
|
|
||||||
total_models = len(available_models)
|
arrow = ""
|
||||||
# Show all models - no truncation message needed
|
if caps.model_name.lower() != model_name.lower():
|
||||||
|
arrow = f" → `{caps.model_name}`"
|
||||||
|
|
||||||
# Check if restrictions are applied
|
score = caps.get_effective_capability_rank()
|
||||||
restriction_service = None
|
output_lines.append(f"- `{model_name}`{arrow} (score {score}, {suffix})")
|
||||||
try:
|
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service()
|
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) or set()
|
||||||
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
|
if allowed_set:
|
||||||
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
|
output_lines.append(
|
||||||
output_lines.append(
|
f"\n*OpenRouter models restricted by OPENROUTER_ALLOWED_MODELS: {', '.join(sorted(allowed_set))}*"
|
||||||
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
|
)
|
||||||
)
|
else:
|
||||||
except Exception as e:
|
output_lines.append("- *No models allowed by current restriction policy.*")
|
||||||
logger.warning(f"Error checking OpenRouter restrictions: {e}")
|
else:
|
||||||
|
available_models = provider.list_models(respect_restrictions=True)
|
||||||
|
providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {}
|
||||||
|
|
||||||
|
for model_name in available_models:
|
||||||
|
config = registry.resolve(model_name)
|
||||||
|
provider_name = "other"
|
||||||
|
if config and "/" in config.model_name:
|
||||||
|
provider_name = config.model_name.split("/")[0]
|
||||||
|
elif "/" in model_name:
|
||||||
|
provider_name = model_name.split("/")[0]
|
||||||
|
|
||||||
|
providers_models.setdefault(provider_name, [])
|
||||||
|
|
||||||
|
rank = config.get_effective_capability_rank() if config else 0
|
||||||
|
providers_models[provider_name].append((rank, model_name, config))
|
||||||
|
|
||||||
|
output_lines.append("\n**Available Models**:")
|
||||||
|
for provider_name, models in sorted(providers_models.items()):
|
||||||
|
output_lines.append(f"\n*{provider_name.title()}:*")
|
||||||
|
for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])):
|
||||||
|
if config:
|
||||||
|
context_str = _format_context(getattr(config, "context_window", 0))
|
||||||
|
suffix_parts = [f"{context_str} context"]
|
||||||
|
if getattr(config, "supports_extended_thinking", False):
|
||||||
|
suffix_parts.append("thinking")
|
||||||
|
suffix = ", ".join(suffix_parts)
|
||||||
|
|
||||||
|
arrow = ""
|
||||||
|
if config.model_name.lower() != alias.lower():
|
||||||
|
arrow = f" → `{config.model_name}`"
|
||||||
|
|
||||||
|
output_lines.append(f"- `{alias}`{arrow} (score {rank}, {suffix})")
|
||||||
|
else:
|
||||||
|
output_lines.append(f"- `{alias}` (score {rank})")
|
||||||
else:
|
else:
|
||||||
output_lines.append("**Error**: Could not load OpenRouter provider")
|
output_lines.append("**Error**: Could not load OpenRouter provider")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.exception("Error listing OpenRouter models: %s", e)
|
||||||
output_lines.append(f"**Error loading models**: {str(e)}")
|
output_lines.append(f"**Error loading models**: {str(e)}")
|
||||||
else:
|
else:
|
||||||
output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")
|
output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ Example:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from providers.shared import ProviderType
|
from providers.shared import ProviderType
|
||||||
@@ -58,6 +59,7 @@ class ModelRestrictionService:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the restriction service by loading from environment."""
|
"""Initialize the restriction service by loading from environment."""
|
||||||
self.restrictions: dict[ProviderType, set[str]] = {}
|
self.restrictions: dict[ProviderType, set[str]] = {}
|
||||||
|
self._alias_resolution_cache: dict[ProviderType, dict[str, str]] = defaultdict(dict)
|
||||||
self._load_from_env()
|
self._load_from_env()
|
||||||
|
|
||||||
def _load_from_env(self) -> None:
|
def _load_from_env(self) -> None:
|
||||||
@@ -79,6 +81,7 @@ class ModelRestrictionService:
|
|||||||
|
|
||||||
if models:
|
if models:
|
||||||
self.restrictions[provider_type] = models
|
self.restrictions[provider_type] = models
|
||||||
|
self._alias_resolution_cache[provider_type] = {}
|
||||||
logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
|
logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
|
||||||
else:
|
else:
|
||||||
# All entries were empty after cleaning - treat as no restrictions
|
# All entries were empty after cleaning - treat as no restrictions
|
||||||
@@ -150,7 +153,41 @@ class ModelRestrictionService:
|
|||||||
names_to_check.add(original_name.lower())
|
names_to_check.add(original_name.lower())
|
||||||
|
|
||||||
# If any of the names is in the allowed set, it's allowed
|
# If any of the names is in the allowed set, it's allowed
|
||||||
return any(name in allowed_set for name in names_to_check)
|
if any(name in allowed_set for name in names_to_check):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Attempt to resolve canonical names for allowed aliases using provider metadata.
|
||||||
|
try:
|
||||||
|
from providers.registry import ModelProviderRegistry
|
||||||
|
|
||||||
|
provider = ModelProviderRegistry.get_provider(provider_type)
|
||||||
|
except Exception: # pragma: no cover - registry lookup failure shouldn't break validation
|
||||||
|
provider = None
|
||||||
|
|
||||||
|
if provider:
|
||||||
|
cache = self._alias_resolution_cache.setdefault(provider_type, {})
|
||||||
|
|
||||||
|
for allowed_entry in list(allowed_set):
|
||||||
|
normalized_resolved = cache.get(allowed_entry)
|
||||||
|
|
||||||
|
if not normalized_resolved:
|
||||||
|
try:
|
||||||
|
resolved = provider._resolve_model_name(allowed_entry)
|
||||||
|
except Exception: # pragma: no cover - resolution failures are treated as non-matches
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not resolved:
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_resolved = resolved.lower()
|
||||||
|
cache[allowed_entry] = normalized_resolved
|
||||||
|
|
||||||
|
if normalized_resolved in names_to_check:
|
||||||
|
allowed_set.add(normalized_resolved)
|
||||||
|
cache[normalized_resolved] = normalized_resolved
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
|
def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user