diff --git a/AGENTS.md b/AGENTS.md index 60cd3d1..5225ded 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,5 +1,9 @@ # Repository Guidelines +See `requirements.txt` and `requirements-dev.txt` for runtime and development dependencies. + +Also read `CLAUDE.md` and `CLAUDE.local.md` if available. + ## Project Structure & Module Organization Zen MCP Server centers on `server.py`, which exposes MCP entrypoints and coordinates multi-model workflows. Feature-specific tools live in `tools/`, provider integrations in `providers/`, and shared helpers in `utils/`. @@ -14,6 +18,13 @@ Authoritative documentation and samples live in `docs/`, and runtime diagnostics - `python communication_simulator_test.py --quick` – smoke-test orchestration across tools and providers. - `./run_integration_tests.sh [--with-simulator]` – exercise provider-dependent flows against remote or Ollama models. +For example, run an individual test file or the full suite like this: + +```bash +source .zen_venv/bin/activate && pytest tests/test_auto_mode_model_listing.py -q +source .zen_venv/bin/activate && pytest -q +``` + ## Coding Style & Naming Conventions Target Python 3.9+ with Black and isort using a 120-character line limit; Ruff enforces pycodestyle, pyflakes, bugbear, comprehension, and pyupgrade rules. Prefer explicit type hints, snake_case modules, and imperative commit-time docstrings. Extend workflows by defining hook or abstract methods instead of checking `hasattr()`/`getattr()`—inheritance-backed contracts keep behavior discoverable and testable. diff --git a/providers/custom.py b/providers/custom.py index 8255e39..104fc6e 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -73,6 +73,8 @@ class CustomProvider(OpenAICompatibleProvider): logging.info(f"Initializing Custom provider with endpoint: {base_url}") + self._alias_cache: dict[str, str] = {} + super().__init__(api_key, base_url=base_url, **kwargs) # Initialize model registry (shared with OpenRouter for consistent aliases) @@ -120,11 +122,18 @@ class CustomProvider(OpenAICompatibleProvider): def _resolve_model_name(self, model_name: str) -> str: """Resolve registry aliases and strip version tags for local models.""" + cache_key = model_name.lower() + if cache_key in self._alias_cache: + return self._alias_cache[cache_key] + config = self._registry.resolve(model_name) if config: if config.model_name != model_name: - logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'") - return config.model_name + logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name) + resolved = config.model_name + self._alias_cache[cache_key] = resolved + self._alias_cache.setdefault(resolved.lower(), resolved) + return resolved if ":" in model_name: base_model = model_name.split(":")[0] @@ -132,11 +141,16 @@ class CustomProvider(OpenAICompatibleProvider): base_config = self._registry.resolve(base_model) if base_config: - logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'") - return base_config.model_name + logging.debug("Resolved base model '%s' to '%s'", base_model, base_config.model_name) + resolved = base_config.model_name + self._alias_cache[cache_key] = resolved + self._alias_cache.setdefault(resolved.lower(), resolved) + return resolved + self._alias_cache[cache_key] = base_model return base_model logging.debug(f"Model '{model_name}' not found in registry, using as-is") + self._alias_cache[cache_key] = model_name return model_name def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: diff --git a/providers/openai_compatible.py
b/providers/openai_compatible.py index 7f49837..6b4a08a 100644 --- a/providers/openai_compatible.py +++ b/providers/openai_compatible.py @@ -39,6 +39,7 @@ class OpenAICompatibleProvider(ModelProvider): base_url: Base URL for the API endpoint **kwargs: Additional configuration options including timeout """ + self._allowed_alias_cache: dict[str, str] = {} super().__init__(api_key, **kwargs) self._client = None self.base_url = base_url @@ -74,9 +75,33 @@ class OpenAICompatibleProvider(ModelProvider): canonical = canonical_name.lower() if requested not in self.allowed_models and canonical not in self.allowed_models: - raise ValueError( - f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}" - ) + allowed = False + for allowed_entry in list(self.allowed_models): + normalized_resolved = self._allowed_alias_cache.get(allowed_entry) + if normalized_resolved is None: + try: + resolved_name = self._resolve_model_name(allowed_entry) + except Exception: + continue + + if not resolved_name: + continue + + normalized_resolved = resolved_name.lower() + self._allowed_alias_cache[allowed_entry] = normalized_resolved + + if normalized_resolved == canonical: + # Canonical match discovered via alias resolution – mark as allowed and + # memoise the canonical entry for future lookups. + allowed = True + self._allowed_alias_cache[canonical] = canonical + self.allowed_models.add(canonical) + break + + if not allowed: + raise ValueError( + f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}" + ) def _parse_allowed_models(self) -> Optional[set[str]]: """Parse allowed models from environment variable. @@ -94,6 +119,7 @@ class OpenAICompatibleProvider(ModelProvider): models = {m.strip().lower() for m in models_str.split(",") if m.strip()} if models: logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}") + self._allowed_alias_cache = {} return models # Log info if no allow-list configured for proxy providers diff --git a/providers/openrouter.py b/providers/openrouter.py index 12b10c7..c4b54b5 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -50,6 +50,7 @@ class OpenRouterProvider(OpenAICompatibleProvider): **kwargs: Additional configuration """ base_url = "https://openrouter.ai/api/v1" + self._alias_cache: dict[str, str] = {} super().__init__(api_key, base_url=base_url, **kwargs) # Initialize model registry @@ -178,13 +179,21 @@ class OpenRouterProvider(OpenAICompatibleProvider): def _resolve_model_name(self, model_name: str) -> str: """Resolve aliases defined in the OpenRouter registry.""" + cache_key = model_name.lower() + if cache_key in self._alias_cache: + return self._alias_cache[cache_key] + config = self._registry.resolve(model_name) if config: if config.model_name != model_name: - logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'") - return config.model_name + logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name) + resolved = config.model_name + self._alias_cache[cache_key] = resolved + self._alias_cache.setdefault(resolved.lower(), resolved) + return resolved logging.debug(f"Model '{model_name}' not found in registry, using as-is") + self._alias_cache[cache_key] = model_name return model_name def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: diff --git a/providers/registry.py b/providers/registry.py index 6f412ff..f5865d5 100644 --- a/providers/registry.py +++ 
b/providers/registry.py @@ -205,6 +205,18 @@ class ModelProviderRegistry: logging.warning("Provider %s does not implement list_models", provider_type) continue + if restriction_service and restriction_service.has_restrictions(provider_type): + restricted_display = cls._collect_restricted_display_names( + provider, + provider_type, + available, + restriction_service, + ) + if restricted_display: + for model_name in restricted_display: + models[model_name] = provider_type + continue + for model_name in available: # ===================================================================================== # CRITICAL: Prevent double restriction filtering (Fixed Issue #98) @@ -227,6 +239,50 @@ class ModelProviderRegistry: return models + @classmethod + def _collect_restricted_display_names( + cls, + provider: ModelProvider, + provider_type: ProviderType, + available: list[str], + restriction_service, + ) -> list[str] | None: + """Derive the human-facing model list when restrictions are active.""" + + allowed_models = restriction_service.get_allowed_models(provider_type) + if not allowed_models: + return None + + allowed_details: list[tuple[str, int]] = [] + + for model_name in sorted(allowed_models): + try: + capabilities = provider.get_capabilities(model_name) + except (AttributeError, ValueError): + continue + + try: + rank = capabilities.get_effective_capability_rank() + rank_value = float(rank) + except (AttributeError, TypeError, ValueError): + rank_value = 0.0 + + allowed_details.append((model_name, rank_value)) + + if allowed_details: + allowed_details.sort(key=lambda item: (-item[1], item[0])) + return [name for name, _ in allowed_details] + + # Fallback: intersect the allowlist with the provider-advertised names. + available_lookup = {name.lower(): name for name in available} + display_names: list[str] = [] + for model_name in sorted(allowed_models): + lowered = model_name.lower() + if lowered in available_lookup: + display_names.append(available_lookup[lowered]) + + return display_names + @classmethod def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]: """Get list of available model names, optionally filtered by provider. diff --git a/server.py b/server.py index a8bf47e..077e47f 100644 --- a/server.py +++ b/server.py @@ -492,15 +492,25 @@ def configure_providers(): # Register providers in priority order: # 1. Native APIs first (most direct and efficient) + registered_providers = [] + if has_native_apis: if gemini_key and gemini_key != "your_gemini_api_key_here": ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider) + registered_providers.append(ProviderType.GOOGLE.value) + logger.debug(f"Registered provider: {ProviderType.GOOGLE.value}") if openai_key and openai_key != "your_openai_api_key_here": ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider) + registered_providers.append(ProviderType.OPENAI.value) + logger.debug(f"Registered provider: {ProviderType.OPENAI.value}") if xai_key and xai_key != "your_xai_api_key_here": ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider) + registered_providers.append(ProviderType.XAI.value) + logger.debug(f"Registered provider: {ProviderType.XAI.value}") if dial_key and dial_key != "your_dial_api_key_here": ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider) + registered_providers.append(ProviderType.DIAL.value) + logger.debug(f"Registered provider: {ProviderType.DIAL.value}") # 2. 
Custom provider second (for local/private models) if has_custom: @@ -511,10 +521,18 @@ def configure_providers(): return CustomProvider(api_key=api_key or "", base_url=base_url) # Use provided API key or empty string ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory) + registered_providers.append(ProviderType.CUSTOM.value) + logger.debug(f"Registered provider: {ProviderType.CUSTOM.value}") # 3. OpenRouter last (catch-all for everything else) if has_openrouter: ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) + registered_providers.append(ProviderType.OPENROUTER.value) + logger.debug(f"Registered provider: {ProviderType.OPENROUTER.value}") + + # Log all registered providers + if registered_providers: + logger.info(f"Registered providers: {', '.join(registered_providers)}") # Require at least one valid provider if not valid_providers: diff --git a/tests/test_alias_target_restrictions.py b/tests/test_alias_target_restrictions.py index e779acd..77f2799 100644 --- a/tests/test_alias_target_restrictions.py +++ b/tests/test_alias_target_restrictions.py @@ -63,27 +63,30 @@ class TestAliasTargetRestrictions: assert provider.validate_model_name("o4mini") @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only - def test_restriction_policy_allows_only_alias_when_alias_specified(self): - """Test that restriction policy allows only the alias when just alias is specified. - - If you restrict to 'mini' (which is an alias for gpt-5-mini), - only the alias should work, not other models. - This is the correct restrictive behavior. - """ - # Clear cached restriction service + def test_restriction_policy_alias_allows_canonical(self): + """Alias-only allowlists should permit both the alias and its canonical target.""" import utils.model_restrictions utils.model_restrictions._restriction_service = None provider = OpenAIModelProvider(api_key="test-key") - # Only the alias should be allowed assert provider.validate_model_name("mini") - # Direct target for this alias should NOT be allowed (mini -> gpt-5-mini) - assert not provider.validate_model_name("gpt-5-mini") - # Other models should NOT be allowed + assert provider.validate_model_name("gpt-5-mini") assert not provider.validate_model_name("o4-mini") + @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"}) + def test_restriction_policy_alias_allows_short_name(self): + """Common aliases like 'gpt5' should allow their canonical forms.""" + import utils.model_restrictions + + utils.model_restrictions._restriction_service = None + + provider = OpenAIModelProvider(api_key="test-key") + + assert provider.validate_model_name("gpt5") + assert provider.validate_model_name("gpt-5") + @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target def test_gemini_restriction_policy_allows_alias_when_target_allowed(self): """Test Gemini restriction policy allows alias when target is allowed.""" @@ -99,19 +102,16 @@ class TestAliasTargetRestrictions: assert provider.validate_model_name("flash") @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only - def test_gemini_restriction_policy_allows_only_alias_when_alias_specified(self): - """Test Gemini restriction policy allows only alias when just alias is specified.""" - # Clear cached restriction service + def test_gemini_restriction_policy_alias_allows_canonical(self): + """Gemini alias allowlists should permit canonical forms.""" import utils.model_restrictions 
utils.model_restrictions._restriction_service = None provider = GeminiModelProvider(api_key="test-key") - # Only the alias should be allowed assert provider.validate_model_name("flash") - # Direct target should NOT be allowed - assert not provider.validate_model_name("gemini-2.5-flash") + assert provider.validate_model_name("gemini-2.5-flash") def test_restriction_service_validation_includes_all_targets(self): """Test that restriction service validation knows about all aliases and targets.""" @@ -153,6 +153,30 @@ class TestAliasTargetRestrictions: assert provider.validate_model_name("o4-mini") # target assert provider.validate_model_name("o4mini") # alias for o4-mini + @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"}, clear=True) + def test_service_alias_allows_canonical_openai(self): + """ModelRestrictionService should permit canonical names resolved from aliases.""" + import utils.model_restrictions + + utils.model_restrictions._restriction_service = None + provider = OpenAIModelProvider(api_key="test-key") + service = ModelRestrictionService() + + assert service.is_allowed(ProviderType.OPENAI, "gpt-5") + assert provider.validate_model_name("gpt-5") + + @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}, clear=True) + def test_service_alias_allows_canonical_gemini(self): + """Gemini alias allowlists should permit canonical forms.""" + import utils.model_restrictions + + utils.model_restrictions._restriction_service = None + provider = GeminiModelProvider(api_key="test-key") + service = ModelRestrictionService() + + assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash") + assert provider.validate_model_name("gemini-2.5-flash") + def test_alias_target_policy_regression_prevention(self): """Regression test to ensure aliases and targets are both validated properly. 
diff --git a/tests/test_auto_mode.py b/tests/test_auto_mode.py index 602aed9..da104df 100644 --- a/tests/test_auto_mode.py +++ b/tests/test_auto_mode.py @@ -106,19 +106,35 @@ class TestAutoMode: def test_tool_schema_in_normal_mode(self): """Test that tool schemas don't require model in normal mode""" - # This test uses the default from conftest.py which sets non-auto mode - # The conftest.py mock_provider_availability fixture ensures the model is available - tool = ChatTool() - schema = tool.get_input_schema() + # Save original + original = os.environ.get("DEFAULT_MODEL", "") - # Model should not be required when default model is configured - assert "model" not in schema["required"] + try: + # Set to a specific model (not auto mode) + os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash" + import config - # Model field should have simpler description - model_schema = schema["properties"]["model"] - assert "enum" not in model_schema - assert "listmodels" in model_schema["description"] - assert "default model" in model_schema["description"].lower() + importlib.reload(config) + + tool = ChatTool() + schema = tool.get_input_schema() + + # Model should not be required when default model is configured + assert "model" not in schema["required"] + + # Model field should have simpler description + model_schema = schema["properties"]["model"] + assert "enum" not in model_schema + assert "listmodels" in model_schema["description"] + assert "default model" in model_schema["description"].lower() + + finally: + # Restore + if original: + os.environ["DEFAULT_MODEL"] = original + else: + os.environ.pop("DEFAULT_MODEL", None) + importlib.reload(config) @pytest.mark.asyncio async def test_auto_mode_requires_model_parameter(self): diff --git a/tests/test_auto_mode_model_listing.py b/tests/test_auto_mode_model_listing.py new file mode 100644 index 0000000..e2f0008 --- /dev/null +++ b/tests/test_auto_mode_model_listing.py @@ -0,0 +1,203 @@ +"""Tests covering model restriction-aware error messaging in auto mode.""" + +import asyncio +import importlib +import json + +import pytest + +import utils.model_restrictions as model_restrictions +from providers.gemini import GeminiModelProvider +from providers.openai_provider import OpenAIModelProvider +from providers.openrouter import OpenRouterProvider +from providers.registry import ModelProviderRegistry +from providers.shared import ProviderType +from providers.xai import XAIModelProvider + + +def _extract_available_models(message: str) -> list[str]: + """Parse the available model list from the error message.""" + + marker = "Available models: " + if marker not in message: + raise AssertionError(f"Expected '{marker}' in message: {message}") + + start = message.index(marker) + len(marker) + end = message.find(". 
Suggested", start) + if end == -1: + end = len(message) + + available_segment = message[start:end].strip() + if not available_segment: + return [] + + return [item.strip() for item in available_segment.split(",")] + + +@pytest.fixture +def reset_registry(): + """Ensure registry and restriction service state is isolated.""" + + ModelProviderRegistry.reset_for_testing() + model_restrictions._restriction_service = None + yield + ModelProviderRegistry.reset_for_testing() + model_restrictions._restriction_service = None + + +def _register_core_providers(*, include_xai: bool = False): + ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider) + ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider) + ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) + if include_xai: + ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider) + + +@pytest.mark.no_mock_provider +def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry): + """Error payload should surface only the allowed models for each provider.""" + + monkeypatch.setenv("DEFAULT_MODEL", "auto") + monkeypatch.setenv("GEMINI_API_KEY", "test-gemini") + monkeypatch.setenv("OPENAI_API_KEY", "test-openai") + monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter") + monkeypatch.delenv("XAI_API_KEY", raising=False) + monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false") + try: + import dotenv + + monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"}) + except ModuleNotFoundError: + pass + + monkeypatch.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro") + monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "gpt-5") + monkeypatch.setenv("OPENROUTER_ALLOWED_MODELS", "gpt5nano") + monkeypatch.setenv("XAI_ALLOWED_MODELS", "") + + import config + + importlib.reload(config) + + _register_core_providers() + + import server + + importlib.reload(server) + + # Reload may have re-applied .env overrides; enforce our test configuration + for key, value in ( + ("DEFAULT_MODEL", "auto"), + ("GEMINI_API_KEY", "test-gemini"), + ("OPENAI_API_KEY", "test-openai"), + ("OPENROUTER_API_KEY", "test-openrouter"), + ("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro"), + ("OPENAI_ALLOWED_MODELS", "gpt-5"), + ("OPENROUTER_ALLOWED_MODELS", "gpt5nano"), + ("XAI_ALLOWED_MODELS", ""), + ): + monkeypatch.setenv(key, value) + + for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"): + monkeypatch.delenv(var, raising=False) + + ModelProviderRegistry.reset_for_testing() + model_restrictions._restriction_service = None + server.configure_providers() + + result = asyncio.run( + server.handle_call_tool( + "chat", + { + "model": "gpt5mini", + "prompt": "Tell me about your strengths", + }, + ) + ) + + assert len(result) == 1 + payload = json.loads(result[0].text) + assert payload["status"] == "error" + + available_models = _extract_available_models(payload["content"]) + assert set(available_models) == {"gemini-2.5-pro", "gpt-5", "gpt5nano", "openai/gpt-5-nano"} + + +@pytest.mark.no_mock_provider +def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, reset_registry): + """When no restrictions are set, the full high-capability catalogue should appear.""" + + monkeypatch.setenv("DEFAULT_MODEL", "auto") + monkeypatch.setenv("GEMINI_API_KEY", "test-gemini") + monkeypatch.setenv("OPENAI_API_KEY", "test-openai") + monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter") + 
monkeypatch.setenv("XAI_API_KEY", "test-xai") + monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false") + try: + import dotenv + + monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"}) + except ModuleNotFoundError: + pass + + for var in ( + "GOOGLE_ALLOWED_MODELS", + "OPENAI_ALLOWED_MODELS", + "OPENROUTER_ALLOWED_MODELS", + "XAI_ALLOWED_MODELS", + "DIAL_ALLOWED_MODELS", + ): + monkeypatch.delenv(var, raising=False) + + import config + + importlib.reload(config) + + _register_core_providers(include_xai=True) + + import server + + importlib.reload(server) + + for key, value in ( + ("DEFAULT_MODEL", "auto"), + ("GEMINI_API_KEY", "test-gemini"), + ("OPENAI_API_KEY", "test-openai"), + ("OPENROUTER_API_KEY", "test-openrouter"), + ): + monkeypatch.setenv(key, value) + + for var in ( + "GOOGLE_ALLOWED_MODELS", + "OPENAI_ALLOWED_MODELS", + "OPENROUTER_ALLOWED_MODELS", + "XAI_ALLOWED_MODELS", + "DIAL_ALLOWED_MODELS", + "CUSTOM_API_URL", + "CUSTOM_API_KEY", + ): + monkeypatch.delenv(var, raising=False) + + ModelProviderRegistry.reset_for_testing() + model_restrictions._restriction_service = None + server.configure_providers() + + result = asyncio.run( + server.handle_call_tool( + "chat", + { + "model": "dummymodel", + "prompt": "Hi there", + }, + ) + ) + + assert len(result) == 1 + payload = json.loads(result[0].text) + assert payload["status"] == "error" + + available_models = _extract_available_models(payload["content"]) + assert "gemini-2.5-pro" in available_models + assert "gpt-5" in available_models + assert "grok-4" in available_models + assert len(available_models) >= 5 diff --git a/tests/test_collaboration.py b/tests/test_collaboration.py index 399fe41..41a3534 100644 --- a/tests/test_collaboration.py +++ b/tests/test_collaboration.py @@ -3,6 +3,7 @@ Tests for dynamic context request and collaboration features """ import json +import os from unittest.mock import Mock, patch import pytest @@ -157,95 +158,120 @@ class TestDynamicContextRequests: @patch("tools.shared.base_tool.BaseTool.get_model_provider") async def test_clarification_with_suggested_action(self, mock_get_provider, analyze_tool): """Test clarification request with suggested next action""" - clarification_json = json.dumps( - { - "status": "files_required_to_continue", - "mandatory_instructions": "I need to see the database configuration to analyze the connection error", - "files_needed": ["config/database.yml", "src/db.py"], - "suggested_next_action": { - "tool": "analyze", - "args": { - "prompt": "Analyze database connection timeout issue", - "relevant_files": [ - "/config/database.yml", - "/src/db.py", - "/logs/error.log", - ], + import importlib + + from providers.registry import ModelProviderRegistry + + # Ensure deterministic model configuration for this test regardless of previous suites + ModelProviderRegistry.reset_for_testing() + + original_default = os.environ.get("DEFAULT_MODEL") + + try: + os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash" + import config + + importlib.reload(config) + + clarification_json = json.dumps( + { + "status": "files_required_to_continue", + "mandatory_instructions": "I need to see the database configuration to analyze the connection error", + "files_needed": ["config/database.yml", "src/db.py"], + "suggested_next_action": { + "tool": "analyze", + "args": { + "prompt": "Analyze database connection timeout issue", + "relevant_files": [ + "/config/database.yml", + "/src/db.py", + "/logs/error.log", + ], + }, }, }, - }, - 
ensure_ascii=False, - ) + ensure_ascii=False, + ) - mock_provider = create_mock_provider() - mock_provider.get_provider_type.return_value = Mock(value="google") - mock_provider.generate_content.return_value = Mock( - content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={} - ) - mock_get_provider.return_value = mock_provider + mock_provider = create_mock_provider() + mock_provider.get_provider_type.return_value = Mock(value="google") + mock_provider.generate_content.return_value = Mock( + content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={} + ) + mock_get_provider.return_value = mock_provider - result = await analyze_tool.execute( - { - "step": "Analyze database connection timeout issue", - "step_number": 1, - "total_steps": 1, - "next_step_required": False, - "findings": "Initial database timeout analysis", - "relevant_files": ["/absolute/logs/error.log"], - } - ) + result = await analyze_tool.execute( + { + "step": "Analyze database connection timeout issue", + "step_number": 1, + "total_steps": 1, + "next_step_required": False, + "findings": "Initial database timeout analysis", + "relevant_files": ["/absolute/logs/error.log"], + } + ) - assert len(result) == 1 + assert len(result) == 1 - response_data = json.loads(result[0].text) + response_data = json.loads(result[0].text) - # Workflow tools should either promote clarification status or handle it in expert analysis - if response_data["status"] == "files_required_to_continue": - # Clarification was properly promoted to main status - # Check if mandatory_instructions is at top level or in content - if "mandatory_instructions" in response_data: - assert "database configuration" in response_data["mandatory_instructions"] - assert "files_needed" in response_data - assert "config/database.yml" in response_data["files_needed"] - assert "src/db.py" in response_data["files_needed"] - elif "content" in response_data: - # Parse content JSON for workflow tools - try: - content_json = json.loads(response_data["content"]) - assert "mandatory_instructions" in content_json + # Workflow tools should either promote clarification status or handle it in expert analysis + if response_data["status"] == "files_required_to_continue": + # Clarification was properly promoted to main status + # Check if mandatory_instructions is at top level or in content + if "mandatory_instructions" in response_data: + assert "database configuration" in response_data["mandatory_instructions"] + assert "files_needed" in response_data + assert "config/database.yml" in response_data["files_needed"] + assert "src/db.py" in response_data["files_needed"] + elif "content" in response_data: + # Parse content JSON for workflow tools + try: + content_json = json.loads(response_data["content"]) + assert "mandatory_instructions" in content_json + assert ( + "database configuration" in content_json["mandatory_instructions"] + or "database" in content_json["mandatory_instructions"] + ) + assert "files_needed" in content_json + files_needed_str = str(content_json["files_needed"]) + assert ( + "config/database.yml" in files_needed_str + or "config" in files_needed_str + or "database" in files_needed_str + ) + except json.JSONDecodeError: + # Content is not JSON, check if it contains required text + content = response_data["content"] + assert "database configuration" in content or "config" in content + elif response_data["status"] == "calling_expert_analysis": + # Clarification may be handled in expert analysis section + if "expert_analysis" 
in response_data: + expert_analysis = response_data["expert_analysis"] + expert_content = str(expert_analysis) assert ( - "database configuration" in content_json["mandatory_instructions"] - or "database" in content_json["mandatory_instructions"] + "database configuration" in expert_content + or "config/database.yml" in expert_content + or "files_required_to_continue" in expert_content ) - assert "files_needed" in content_json - files_needed_str = str(content_json["files_needed"]) - assert ( - "config/database.yml" in files_needed_str - or "config" in files_needed_str - or "database" in files_needed_str - ) - except json.JSONDecodeError: - # Content is not JSON, check if it contains required text - content = response_data["content"] - assert "database configuration" in content or "config" in content - elif response_data["status"] == "calling_expert_analysis": - # Clarification may be handled in expert analysis section - if "expert_analysis" in response_data: - expert_analysis = response_data["expert_analysis"] - expert_content = str(expert_analysis) - assert ( - "database configuration" in expert_content - or "config/database.yml" in expert_content - or "files_required_to_continue" in expert_content - ) - else: - # Some other status - ensure it's a valid workflow response - assert "step_number" in response_data + else: + # Some other status - ensure it's a valid workflow response + assert "step_number" in response_data - # Check for suggested next action - if "suggested_next_action" in response_data: - action = response_data["suggested_next_action"] - assert action["tool"] == "analyze" + # Check for suggested next action + if "suggested_next_action" in response_data: + action = response_data["suggested_next_action"] + assert action["tool"] == "analyze" + finally: + if original_default is not None: + os.environ["DEFAULT_MODEL"] = original_default + else: + os.environ.pop("DEFAULT_MODEL", None) + + import config + + importlib.reload(config) + ModelProviderRegistry.reset_for_testing() def test_tool_output_model_serialization(self): """Test ToolOutput model serialization""" diff --git a/tests/test_listmodels_restrictions.py b/tests/test_listmodels_restrictions.py index 06a78f8..9cf3cf4 100644 --- a/tests/test_listmodels_restrictions.py +++ b/tests/test_listmodels_restrictions.py @@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch from providers.base import ModelProvider from providers.registry import ModelProviderRegistry -from providers.shared import ProviderType +from providers.shared import ModelCapabilities, ProviderType from tools.listmodels import ListModelsTool @@ -23,10 +23,63 @@ class TestListModelsRestrictions(unittest.TestCase): self.mock_openrouter = MagicMock(spec=ModelProvider) self.mock_openrouter.provider_type = ProviderType.OPENROUTER + def make_capabilities( + canonical: str, friendly: str, *, aliases=None, context: int = 200_000 + ) -> ModelCapabilities: + return ModelCapabilities( + provider=ProviderType.OPENROUTER, + model_name=canonical, + friendly_name=friendly, + intelligence_score=20, + description=friendly, + aliases=aliases or [], + context_window=context, + max_output_tokens=context, + supports_extended_thinking=True, + ) + + opus_caps = make_capabilities( + "anthropic/claude-opus-4-20240229", + "Claude Opus", + aliases=["opus"], + ) + sonnet_caps = make_capabilities( + "anthropic/claude-sonnet-4-20240229", + "Claude Sonnet", + aliases=["sonnet"], + ) + deepseek_caps = make_capabilities( + "deepseek/deepseek-r1-0528:free", + "DeepSeek R1", + aliases=[], + ) + 
qwen_caps = make_capabilities( + "qwen/qwen3-235b-a22b-04-28:free", + "Qwen3", + aliases=[], + ) + + self._openrouter_caps_map = { + "anthropic/claude-opus-4": opus_caps, + "opus": opus_caps, + "anthropic/claude-opus-4-20240229": opus_caps, + "anthropic/claude-sonnet-4": sonnet_caps, + "sonnet": sonnet_caps, + "anthropic/claude-sonnet-4-20240229": sonnet_caps, + "deepseek/deepseek-r1-0528:free": deepseek_caps, + "qwen/qwen3-235b-a22b-04-28:free": qwen_caps, + } + + self.mock_openrouter.get_capabilities.side_effect = self._openrouter_caps_map.__getitem__ + self.mock_openrouter.get_capabilities_by_rank.return_value = [] + self.mock_openrouter.list_models.return_value = [] + # Create mock Gemini provider for comparison self.mock_gemini = MagicMock(spec=ModelProvider) self.mock_gemini.provider_type = ProviderType.GOOGLE self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"] + self.mock_gemini.get_capabilities_by_rank.return_value = [] + self.mock_gemini.get_capabilities_by_rank.return_value = [] def tearDown(self): """Clean up after tests.""" @@ -159,7 +212,7 @@ class TestListModelsRestrictions(unittest.TestCase): for line in lines: if "OpenRouter" in line and "✅" in line: openrouter_section_found = True - elif "Available Models" in line and openrouter_section_found: + elif ("Models (policy restricted)" in line or "Available Models" in line) and openrouter_section_found: in_openrouter_section = True elif in_openrouter_section: # Check for lines with model names in backticks @@ -179,11 +232,11 @@ class TestListModelsRestrictions(unittest.TestCase): len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}" ) - # Verify list_models was called with respect_restrictions=True - self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True) + # Verify we did not fall back to unrestricted listing + self.mock_openrouter.list_models.assert_not_called() # Check for restriction note - self.assertIn("Restricted to models matching:", result) + self.assertIn("OpenRouter models restricted by", result) @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True) @patch("providers.openrouter_registry.OpenRouterModelRegistry") diff --git a/tests/test_model_metadata_continuation.py b/tests/test_model_metadata_continuation.py index 5065804..e190441 100644 --- a/tests/test_model_metadata_continuation.py +++ b/tests/test_model_metadata_continuation.py @@ -121,38 +121,59 @@ class TestModelMetadataContinuation: @pytest.mark.asyncio async def test_no_previous_assistant_turn_defaults(self): """Test behavior when there's no previous assistant turn.""" - thread_id = create_thread("chat", {"prompt": "test"}) + # Save and set DEFAULT_MODEL for test + import importlib + import os - # Only add user turns - add_turn(thread_id, "user", "First question") - add_turn(thread_id, "user", "Second question") + original_default = os.environ.get("DEFAULT_MODEL", "") + os.environ["DEFAULT_MODEL"] = "auto" + import config + import utils.model_context - arguments = {"continuation_id": thread_id} + importlib.reload(config) + importlib.reload(utils.model_context) - # Mock dependencies - with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc: - mock_calc.return_value = MagicMock( - total_tokens=200000, - content_tokens=160000, - response_tokens=40000, - file_tokens=64000, - history_tokens=64000, - ) + try: + thread_id = create_thread("chat", {"prompt": "test"}) - with 
patch("utils.conversation_memory.build_conversation_history") as mock_build: - mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000) + # Only add user turns + add_turn(thread_id, "user", "First question") + add_turn(thread_id, "user", "Second question") - # Call the actual function - enhanced_args = await reconstruct_thread_context(arguments) + arguments = {"continuation_id": thread_id} - # Should not have set a model - assert enhanced_args.get("model") is None + # Mock dependencies + with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc: + mock_calc.return_value = MagicMock( + total_tokens=200000, + content_tokens=160000, + response_tokens=40000, + file_tokens=64000, + history_tokens=64000, + ) - # ModelContext should use DEFAULT_MODEL - model_context = ModelContext.from_arguments(enhanced_args) - from config import DEFAULT_MODEL + with patch("utils.conversation_memory.build_conversation_history") as mock_build: + mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000) - assert model_context.model_name == DEFAULT_MODEL + # Call the actual function + enhanced_args = await reconstruct_thread_context(arguments) + + # Should not have set a model + assert enhanced_args.get("model") is None + + # ModelContext should use DEFAULT_MODEL + model_context = ModelContext.from_arguments(enhanced_args) + from config import DEFAULT_MODEL + + assert model_context.model_name == DEFAULT_MODEL + finally: + # Restore original value + if original_default: + os.environ["DEFAULT_MODEL"] = original_default + else: + os.environ.pop("DEFAULT_MODEL", None) + importlib.reload(config) + importlib.reload(utils.model_context) @pytest.mark.asyncio async def test_explicit_model_overrides_previous_turn(self): diff --git a/tests/test_model_restrictions.py b/tests/test_model_restrictions.py index 4277463..6096764 100644 --- a/tests/test_model_restrictions.py +++ b/tests/test_model_restrictions.py @@ -49,17 +49,32 @@ class TestModelRestrictionService: def test_load_multiple_models_restriction(self): """Test loading multiple allowed models.""" with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}): - service = ModelRestrictionService() + # Instantiate providers so alias resolution for allow-lists is available + openai_provider = OpenAIModelProvider(api_key="test-key") + gemini_provider = GeminiModelProvider(api_key="test-key") - # Check OpenAI models - assert service.is_allowed(ProviderType.OPENAI, "o3-mini") - assert service.is_allowed(ProviderType.OPENAI, "o4-mini") - assert not service.is_allowed(ProviderType.OPENAI, "o3") + from providers.registry import ModelProviderRegistry - # Check Google models - assert service.is_allowed(ProviderType.GOOGLE, "flash") - assert service.is_allowed(ProviderType.GOOGLE, "pro") - assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro") + def fake_get_provider(provider_type, force_new=False): + mapping = { + ProviderType.OPENAI: openai_provider, + ProviderType.GOOGLE: gemini_provider, + } + return mapping.get(provider_type) + + with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider): + + service = ModelRestrictionService() + + # Check OpenAI models + assert service.is_allowed(ProviderType.OPENAI, "o3-mini") + assert service.is_allowed(ProviderType.OPENAI, "o4-mini") + assert not service.is_allowed(ProviderType.OPENAI, "o3") + + # Check Google models + assert service.is_allowed(ProviderType.GOOGLE, "flash") + assert 
service.is_allowed(ProviderType.GOOGLE, "pro") + assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro") def test_case_insensitive_and_whitespace_handling(self): """Test that model names are case-insensitive and whitespace is trimmed.""" @@ -111,13 +126,17 @@ class TestModelRestrictionService: def test_shorthand_names_in_restrictions(self): """Test that shorthand names work in restrictions.""" - with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o3-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}): + with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}): + # Instantiate providers so the registry can resolve aliases + OpenAIModelProvider(api_key="test-key") + GeminiModelProvider(api_key="test-key") + service = ModelRestrictionService() # When providers check models, they pass both resolved and original names - # OpenAI: 'mini' shorthand allows o4-mini - assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "mini") # How providers actually call it - assert not service.is_allowed(ProviderType.OPENAI, "o4-mini") # Direct check without original (for testing) + # OpenAI: 'o4mini' shorthand allows o4-mini + assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini") # How providers actually call it + assert service.is_allowed(ProviderType.OPENAI, "o4-mini") # Canonical should also be allowed # OpenAI: o3-mini allowed directly assert service.is_allowed(ProviderType.OPENAI, "o3-mini") @@ -280,19 +299,25 @@ class TestProviderIntegration: provider = GeminiModelProvider(api_key="test-key") - # Test case: Only alias "flash" is allowed, not the full name - # If parameters are in wrong order, this test will catch it + from providers.registry import ModelProviderRegistry - # Should allow "flash" alias - assert provider.validate_model_name("flash") + with patch.object(ModelProviderRegistry, "get_provider", return_value=provider): - # Should allow getting capabilities for "flash" - capabilities = provider.get_capabilities("flash") - assert capabilities.model_name == "gemini-2.5-flash" + # Test case: Only alias "flash" is allowed, not the full name + # If parameters are in wrong order, this test will catch it - # Test the edge case: Try to use full model name when only alias is allowed - # This should NOT be allowed - only the alias "flash" is in the restriction list - assert not provider.validate_model_name("gemini-2.5-flash") + # Should allow "flash" alias + assert provider.validate_model_name("flash") + + # Should allow getting capabilities for "flash" + capabilities = provider.get_capabilities("flash") + assert capabilities.model_name == "gemini-2.5-flash" + + # Canonical form should also be allowed now that alias is on the allowlist + assert provider.validate_model_name("gemini-2.5-flash") + # Unrelated models remain blocked + assert not provider.validate_model_name("pro") + assert not provider.validate_model_name("gemini-2.5-pro") @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) def test_gemini_parameter_order_edge_case_full_name_only(self): @@ -570,17 +595,27 @@ class TestShorthandRestrictions: # Test OpenAI provider openai_provider = OpenAIModelProvider(api_key="test-key") - assert openai_provider.validate_model_name("mini") # Should work with shorthand - # When restricting to "mini", you can't use "o4-mini" directly - this is correct behavior - assert not openai_provider.validate_model_name("o4-mini") # Not allowed - only shorthand is allowed - assert not 
openai_provider.validate_model_name("o3-mini") # Not allowed - - # Test Gemini provider gemini_provider = GeminiModelProvider(api_key="test-key") - assert gemini_provider.validate_model_name("flash") # Should work with shorthand - # Same for Gemini - if you restrict to "flash", you can't use the full name - assert not gemini_provider.validate_model_name("gemini-2.5-flash") # Not allowed - assert not gemini_provider.validate_model_name("pro") # Not allowed + + from providers.registry import ModelProviderRegistry + + def registry_side_effect(provider_type, force_new=False): + mapping = { + ProviderType.OPENAI: openai_provider, + ProviderType.GOOGLE: gemini_provider, + } + return mapping.get(provider_type) + + with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect): + assert openai_provider.validate_model_name("mini") # Should work with shorthand + assert openai_provider.validate_model_name("gpt-5-mini") # Canonical resolved from shorthand + assert not openai_provider.validate_model_name("o4-mini") # Unrelated model still blocked + assert not openai_provider.validate_model_name("o3-mini") + + # Test Gemini provider + assert gemini_provider.validate_model_name("flash") # Should work with shorthand + assert gemini_provider.validate_model_name("gemini-2.5-flash") # Canonical allowed + assert not gemini_provider.validate_model_name("pro") # Not allowed @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"}) def test_multiple_shorthands_for_same_model(self): @@ -596,9 +631,9 @@ class TestShorthandRestrictions: assert openai_provider.validate_model_name("mini") # mini -> o4-mini assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini - # Resolved names work only if explicitly allowed + # Resolved names should be allowed when their shorthands are present assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed - assert not openai_provider.validate_model_name("o3-mini") # Not explicitly allowed, only shorthand + assert openai_provider.validate_model_name("o3-mini") # Allowed via shorthand # Other models should not work assert not openai_provider.validate_model_name("o3") diff --git a/tests/test_openrouter_provider.py b/tests/test_openrouter_provider.py index 7717418..f38d3e8 100644 --- a/tests/test_openrouter_provider.py +++ b/tests/test_openrouter_provider.py @@ -260,9 +260,10 @@ class TestOpenRouterAutoMode: os.environ["DEFAULT_MODEL"] = "auto" mock_provider_class = Mock() - mock_provider_instance = Mock(spec=["get_provider_type", "list_models"]) + mock_provider_instance = Mock(spec=["get_provider_type", "list_models", "get_all_model_capabilities"]) mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER mock_provider_instance.list_models.return_value = [] + mock_provider_instance.get_all_model_capabilities.return_value = {} mock_provider_class.return_value = mock_provider_instance ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class) diff --git a/tests/test_provider_routing_bugs.py b/tests/test_provider_routing_bugs.py index dec2f83..d2b7133 100644 --- a/tests/test_provider_routing_bugs.py +++ b/tests/test_provider_routing_bugs.py @@ -293,13 +293,7 @@ class TestOpenRouterAliasRestrictions: # o3 -> openai/o3 # gpt4.1 -> should not exist (expected to be filtered out) - expected_models = { - "openai/o3-mini", - "google/gemini-2.5-pro", - "google/gemini-2.5-flash", - "openai/o4-mini", - "openai/o3", - } + expected_models = {"o3-mini", "pro", "flash", 
"o4-mini", "o3"} available_model_names = set(available_models.keys()) @@ -355,9 +349,11 @@ class TestOpenRouterAliasRestrictions: available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True) expected_models = { - "openai/o3-mini", # from alias + "o3-mini", # alias + "openai/o3-mini", # canonical "anthropic/claude-opus-4.1", # full name - "google/gemini-2.5-flash", # from alias + "flash", # alias + "google/gemini-2.5-flash", # canonical } available_model_names = set(available_models.keys()) diff --git a/tools/listmodels.py b/tools/listmodels.py index e5ef05b..bc45bf3 100644 --- a/tools/listmodels.py +++ b/tools/listmodels.py @@ -83,9 +83,18 @@ class ListModelsTool(BaseTool): from providers.openrouter_registry import OpenRouterModelRegistry from providers.registry import ModelProviderRegistry from providers.shared import ProviderType + from utils.model_restrictions import get_restriction_service output_lines = ["# Available AI Models\n"] + restriction_service = get_restriction_service() + restricted_models_by_provider: dict[ProviderType, list[str]] = {} + + if restriction_service: + restricted_map = ModelProviderRegistry.get_available_models(respect_restrictions=True) + for model_name, provider_type in restricted_map.items(): + restricted_models_by_provider.setdefault(provider_type, []).append(model_name) + # Map provider types to friendly names and their models provider_info = { ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"}, @@ -94,6 +103,43 @@ class ListModelsTool(BaseTool): ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"}, } + def format_model_entry(provider, display_name: str) -> list[str]: + try: + capabilities = provider.get_capabilities(display_name) + except ValueError: + return [f"- `{display_name}` *(not recognized by provider)*"] + + canonical = capabilities.model_name + if canonical.lower() == display_name.lower(): + header = f"- `{canonical}`" + else: + header = f"- `{display_name}` → `{canonical}`" + + try: + context_value = capabilities.context_window or 0 + except AttributeError: + context_value = 0 + try: + context_value = int(context_value) + except (TypeError, ValueError): + context_value = 0 + + if context_value >= 1_000_000: + context_str = f"{context_value // 1_000_000}M context" + elif context_value >= 1_000: + context_str = f"{context_value // 1_000}K context" + elif context_value > 0: + context_str = f"{context_value} context" + else: + context_str = "unknown context" + + try: + description = capabilities.description or "No description available" + except AttributeError: + description = "No description available" + lines = [header, f" - {context_str}", f" - {description}"] + return lines + # Check each native provider type for provider_type, info in provider_info.items(): # Check if provider is enabled @@ -104,30 +150,49 @@ class ListModelsTool(BaseTool): if is_configured: output_lines.append("**Status**: Configured and available") - output_lines.append("\n**Models**:") + has_restrictions = bool(restriction_service and restriction_service.has_restrictions(provider_type)) - aliases = [] - for model_name, capabilities in provider.get_capabilities_by_rank(): - description = capabilities.description or "No description available" - context_window = capabilities.context_window + if has_restrictions: + restricted_names = sorted(set(restricted_models_by_provider.get(provider_type, []))) - if context_window >= 1_000_000: - context_str = f"{context_window // 1_000_000}M context" - elif context_window >= 
1_000: - context_str = f"{context_window // 1_000}K context" + if restricted_names: + output_lines.append("\n**Models (policy restricted)**:") + for model_name in restricted_names: + output_lines.extend(format_model_entry(provider, model_name)) else: - context_str = f"{context_window} context" if context_window > 0 else "unknown context" + output_lines.append("\n*No models are currently allowed by restriction policy.*") + else: + output_lines.append("\n**Models**:") - output_lines.append(f"- `{model_name}` - {context_str}") - output_lines.append(f" - {description}") + aliases = [] + for model_name, capabilities in provider.get_capabilities_by_rank(): + try: + description = capabilities.description or "No description available" + except AttributeError: + description = "No description available" - for alias in capabilities.aliases or []: - if alias != model_name: - aliases.append(f"- `{alias}` → `{model_name}`") + try: + context_window = capabilities.context_window or 0 + except AttributeError: + context_window = 0 - if aliases: - output_lines.append("\n**Aliases**:") - output_lines.extend(sorted(aliases)) + if context_window >= 1_000_000: + context_str = f"{context_window // 1_000_000}M context" + elif context_window >= 1_000: + context_str = f"{context_window // 1_000}K context" + else: + context_str = f"{context_window} context" if context_window > 0 else "unknown context" + + output_lines.append(f"- `{model_name}` - {context_str}") + output_lines.append(f" - {description}") + + for alias in capabilities.aliases or []: + if alias != model_name: + aliases.append(f"- `{alias}` → `{model_name}`") + + if aliases: + output_lines.append("\n**Aliases**:") + output_lines.extend(sorted(aliases)) else: output_lines.append(f"**Status**: Not configured (set {info['env_key']})") @@ -144,19 +209,10 @@ class ListModelsTool(BaseTool): output_lines.append("**Description**: Access to multiple cloud AI providers via unified API") try: - # Get OpenRouter provider from registry to properly apply restrictions - from providers.registry import ModelProviderRegistry - from providers.shared import ProviderType - provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER) if provider: - # Get models with restrictions applied - available_models = provider.list_models(respect_restrictions=True) registry = OpenRouterModelRegistry() - # Group by provider and retain ranking information for consistent ordering - providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {} - def _format_context(tokens: int) -> str: if not tokens: return "?" 
@@ -166,53 +222,83 @@ class ListModelsTool(BaseTool): return f"{tokens // 1_000}K" return str(tokens) - for model_name in available_models: - config = registry.resolve(model_name) - provider_name = "other" - if config and "/" in config.model_name: - provider_name = config.model_name.split("/")[0] - elif "/" in model_name: - provider_name = model_name.split("/")[0] + has_restrictions = bool( + restriction_service and restriction_service.has_restrictions(ProviderType.OPENROUTER) + ) - providers_models.setdefault(provider_name, []) + if has_restrictions: + restricted_names = sorted(set(restricted_models_by_provider.get(ProviderType.OPENROUTER, []))) - rank = config.get_effective_capability_rank() if config else 0 - providers_models[provider_name].append((rank, model_name, config)) + output_lines.append("\n**Models (policy restricted)**:") + if restricted_names: + for model_name in restricted_names: + try: + caps = provider.get_capabilities(model_name) + except ValueError: + output_lines.append(f"- `{model_name}` *(not recognized by provider)*") + continue - output_lines.append("\n**Available Models**:") - for provider_name, models in sorted(providers_models.items()): - output_lines.append(f"\n*{provider_name.title()}:*") - for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])): - if config: - context_str = _format_context(config.context_window) + context_value = int(caps.context_window or 0) + context_str = _format_context(context_value) suffix_parts = [f"{context_str} context"] - if getattr(config, "supports_extended_thinking", False): + if caps.supports_extended_thinking: suffix_parts.append("thinking") suffix = ", ".join(suffix_parts) - output_lines.append(f"- `{alias}` → `{config.model_name}` (score {rank}, {suffix})") - else: - output_lines.append(f"- `{alias}` (score {rank})") - total_models = len(available_models) - # Show all models - no truncation message needed + arrow = "" + if caps.model_name.lower() != model_name.lower(): + arrow = f" → `{caps.model_name}`" - # Check if restrictions are applied - restriction_service = None - try: - from utils.model_restrictions import get_restriction_service + score = caps.get_effective_capability_rank() + output_lines.append(f"- `{model_name}`{arrow} (score {score}, {suffix})") - restriction_service = get_restriction_service() - if restriction_service.has_restrictions(ProviderType.OPENROUTER): - allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) - output_lines.append( - f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}" - ) - except Exception as e: - logger.warning(f"Error checking OpenRouter restrictions: {e}") + allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) or set() + if allowed_set: + output_lines.append( + f"\n*OpenRouter models restricted by OPENROUTER_ALLOWED_MODELS: {', '.join(sorted(allowed_set))}*" + ) + else: + output_lines.append("- *No models allowed by current restriction policy.*") + else: + available_models = provider.list_models(respect_restrictions=True) + providers_models: dict[str, list[tuple[int, str, Optional[Any]]]] = {} + + for model_name in available_models: + config = registry.resolve(model_name) + provider_name = "other" + if config and "/" in config.model_name: + provider_name = config.model_name.split("/")[0] + elif "/" in model_name: + provider_name = model_name.split("/")[0] + + providers_models.setdefault(provider_name, []) + + rank = config.get_effective_capability_rank() if config else 0 + 
providers_models[provider_name].append((rank, model_name, config)) + + output_lines.append("\n**Available Models**:") + for provider_name, models in sorted(providers_models.items()): + output_lines.append(f"\n*{provider_name.title()}:*") + for rank, alias, config in sorted(models, key=lambda item: (-item[0], item[1])): + if config: + context_str = _format_context(getattr(config, "context_window", 0)) + suffix_parts = [f"{context_str} context"] + if getattr(config, "supports_extended_thinking", False): + suffix_parts.append("thinking") + suffix = ", ".join(suffix_parts) + + arrow = "" + if config.model_name.lower() != alias.lower(): + arrow = f" → `{config.model_name}`" + + output_lines.append(f"- `{alias}`{arrow} (score {rank}, {suffix})") + else: + output_lines.append(f"- `{alias}` (score {rank})") else: output_lines.append("**Error**: Could not load OpenRouter provider") except Exception as e: + logger.exception("Error listing OpenRouter models: %s", e) output_lines.append(f"**Error loading models**: {str(e)}") else: output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)") diff --git a/utils/model_restrictions.py b/utils/model_restrictions.py index 8b0984e..6fbcc45 100644 --- a/utils/model_restrictions.py +++ b/utils/model_restrictions.py @@ -22,6 +22,7 @@ Example: import logging import os +from collections import defaultdict from typing import Optional from providers.shared import ProviderType @@ -58,6 +59,7 @@ class ModelRestrictionService: def __init__(self): """Initialize the restriction service by loading from environment.""" self.restrictions: dict[ProviderType, set[str]] = {} + self._alias_resolution_cache: dict[ProviderType, dict[str, str]] = defaultdict(dict) self._load_from_env() def _load_from_env(self) -> None: @@ -79,6 +81,7 @@ class ModelRestrictionService: if models: self.restrictions[provider_type] = models + self._alias_resolution_cache[provider_type] = {} logger.info(f"{provider_type.value} allowed models: {sorted(models)}") else: # All entries were empty after cleaning - treat as no restrictions @@ -150,7 +153,41 @@ class ModelRestrictionService: names_to_check.add(original_name.lower()) # If any of the names is in the allowed set, it's allowed - return any(name in allowed_set for name in names_to_check) + if any(name in allowed_set for name in names_to_check): + return True + + # Attempt to resolve canonical names for allowed aliases using provider metadata. + try: + from providers.registry import ModelProviderRegistry + + provider = ModelProviderRegistry.get_provider(provider_type) + except Exception: # pragma: no cover - registry lookup failure shouldn't break validation + provider = None + + if provider: + cache = self._alias_resolution_cache.setdefault(provider_type, {}) + + for allowed_entry in list(allowed_set): + normalized_resolved = cache.get(allowed_entry) + + if not normalized_resolved: + try: + resolved = provider._resolve_model_name(allowed_entry) + except Exception: # pragma: no cover - resolution failures are treated as non-matches + continue + + if not resolved: + continue + + normalized_resolved = resolved.lower() + cache[allowed_entry] = normalized_resolved + + if normalized_resolved in names_to_check: + allowed_set.add(normalized_resolved) + cache[normalized_resolved] = normalized_resolved + return True + + return False def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]: """