diff --git a/docs/advanced-usage.md b/docs/advanced-usage.md index 173c75b..79a24ed 100644 --- a/docs/advanced-usage.md +++ b/docs/advanced-usage.md @@ -84,6 +84,9 @@ OPENAI_ALLOWED_MODELS=o4-mini,o3-mini # Only allow specific Gemini models GOOGLE_ALLOWED_MODELS=flash +# Only allow specific OpenRouter models +OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral + # Use shorthand names or full model names OPENAI_ALLOWED_MODELS=mini,o3-mini # mini = o4-mini ``` @@ -99,17 +102,21 @@ OPENAI_ALLOWED_MODELS=mini,o3-mini # mini = o4-mini # Cost control - only cheap models OPENAI_ALLOWED_MODELS=o4-mini GOOGLE_ALLOWED_MODELS=flash +OPENROUTER_ALLOWED_MODELS=haiku,sonnet # Single model per provider OPENAI_ALLOWED_MODELS=o4-mini GOOGLE_ALLOWED_MODELS=pro +OPENROUTER_ALLOWED_MODELS=opus ``` **Notes:** - Applies to all usage including auto mode - Case-insensitive, whitespace tolerant - Server warns about typos at startup -- Only affects native providers (not OpenRouter/Custom) +- `OPENAI_ALLOWED_MODELS` and `GOOGLE_ALLOWED_MODELS` only affect native providers +- `OPENROUTER_ALLOWED_MODELS` affects OpenRouter models accessed via custom provider (where `is_custom: false` in custom_models.json) +- Custom local models (`is_custom: true`) are not affected by any restrictions ## Thinking Modes diff --git a/docs/testing.md b/docs/testing.md index 4601bfe..b001e00 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -84,7 +84,7 @@ isort . ## What Each Test Suite Covers -### Unit Tests (256 tests) +### Unit Tests Test isolated components and functions: - **Provider functionality**: Model initialization, API interactions, capability checks - **Tool operations**: All MCP tools (chat, analyze, debug, etc.) @@ -92,7 +92,7 @@ Test isolated components and functions: - **File handling**: Path validation, token limits, deduplication - **Auto mode**: Model selection logic and fallback behavior -### Simulator Tests (14 tests) +### Simulator Tests Validate real-world usage scenarios by simulating actual Claude prompts: - **Basic conversations**: Multi-turn chat functionality with real prompts - **Cross-tool continuation**: Context preservation across different tools diff --git a/providers/custom.py b/providers/custom.py index 7d2feab..c551311 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -128,8 +128,21 @@ class CustomProvider(OpenAICompatibleProvider): capabilities = self._registry.get_capabilities(model_name) if capabilities: - # Update provider type to CUSTOM - capabilities.provider = ProviderType.CUSTOM + # Check if this is an OpenRouter model and apply restrictions + config = self._registry.resolve(model_name) + if config and not config.is_custom: + # This is an OpenRouter model, check restrictions + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + if not restriction_service.is_allowed(ProviderType.OPENROUTER, config.model_name, model_name): + raise ValueError(f"OpenRouter model '{model_name}' is not allowed by restriction policy.") + + # Update provider type to OPENROUTER for OpenRouter models + capabilities.provider = ProviderType.OPENROUTER + else: + # Update provider type to CUSTOM for local custom models + capabilities.provider = ProviderType.CUSTOM return capabilities else: # Resolve any potential aliases and create generic capabilities @@ -188,12 +201,23 @@ class CustomProvider(OpenAICompatibleProvider): logging.debug(f"Model '{model_name}' -> '{model_id}' validated via registry (custom model)") return True else: - # This is a cloud/OpenRouter model - if OpenRouter is available, defer to it + # This is a cloud/OpenRouter model - check restrictions if available if openrouter_available: - logging.debug(f"Model '{model_name}' -> '{model_id}' deferred to OpenRouter (cloud model)") + # Check if OpenRouter model is allowed by restrictions + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + if not restriction_service.is_allowed(ProviderType.OPENROUTER, model_id, model_name): + logging.debug(f"Model '{model_name}' -> '{model_id}' blocked by OpenRouter restrictions") + return False + + logging.debug( + f"Model '{model_name}' -> '{model_id}' validated via OpenRouter (passes restrictions)" + ) + return True else: logging.debug(f"Model '{model_name}' -> '{model_id}' rejected (cloud model, no OpenRouter)") - return False + return False # Handle version tags for unknown models (e.g., "my-model:latest") clean_model_name = model_name diff --git a/tests/test_custom_provider.py b/tests/test_custom_provider.py index 7e6e660..5a0275f 100644 --- a/tests/test_custom_provider.py +++ b/tests/test_custom_provider.py @@ -46,10 +46,15 @@ class TestCustomProvider: """Test get_capabilities returns registry capabilities when available.""" provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1") - # Test with a model that should be in the registry + # Test with a model that should be in the registry (OpenRouter model) capabilities = provider.get_capabilities("llama") - assert capabilities.provider == ProviderType.CUSTOM + assert capabilities.provider == ProviderType.OPENROUTER # llama is an OpenRouter model (is_custom=false) + assert capabilities.context_window > 0 + + # Test with a custom model (is_custom=true) + capabilities = provider.get_capabilities("local-llama") + assert capabilities.provider == ProviderType.CUSTOM # local-llama has is_custom=true assert capabilities.context_window > 0 def test_get_capabilities_generic_fallback(self): diff --git a/tests/test_model_restrictions.py b/tests/test_model_restrictions.py index 176852d..4867dd4 100644 --- a/tests/test_model_restrictions.py +++ b/tests/test_model_restrictions.py @@ -24,10 +24,13 @@ class TestModelRestrictionService: assert service.is_allowed(ProviderType.OPENAI, "o3-mini") assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05") assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20") + assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus") + assert service.is_allowed(ProviderType.OPENROUTER, "openai/o3") # Should have no restrictions assert not service.has_restrictions(ProviderType.OPENAI) assert not service.has_restrictions(ProviderType.GOOGLE) + assert not service.has_restrictions(ProviderType.OPENROUTER) def test_load_single_model_restriction(self): """Test loading a single allowed model.""" @@ -39,8 +42,9 @@ class TestModelRestrictionService: assert not service.is_allowed(ProviderType.OPENAI, "o3") assert not service.is_allowed(ProviderType.OPENAI, "o4-mini") - # Google should have no restrictions + # Google and OpenRouter should have no restrictions assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05") + assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus") def test_load_multiple_models_restriction(self): """Test loading multiple allowed models.""" @@ -146,6 +150,68 @@ class TestModelRestrictionService: assert "o4-mimi" in caplog.text assert "not a recognized" in caplog.text + def test_openrouter_model_restrictions(self): + """Test OpenRouter model restrictions functionality.""" + with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet"}): + service = ModelRestrictionService() + + # Should only allow specified OpenRouter models + assert service.is_allowed(ProviderType.OPENROUTER, "opus") + assert service.is_allowed(ProviderType.OPENROUTER, "sonnet") + assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus", "opus") # With original name + assert not service.is_allowed(ProviderType.OPENROUTER, "haiku") + assert not service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-haiku") + assert not service.is_allowed(ProviderType.OPENROUTER, "mistral-large") + + # Other providers should have no restrictions + assert service.is_allowed(ProviderType.OPENAI, "o3") + assert service.is_allowed(ProviderType.GOOGLE, "pro") + + # Should have restrictions for OpenRouter + assert service.has_restrictions(ProviderType.OPENROUTER) + assert not service.has_restrictions(ProviderType.OPENAI) + assert not service.has_restrictions(ProviderType.GOOGLE) + + def test_openrouter_filter_models(self): + """Test filtering OpenRouter models based on restrictions.""" + with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,mistral"}): + service = ModelRestrictionService() + + models = ["opus", "sonnet", "haiku", "mistral", "llama"] + filtered = service.filter_models(ProviderType.OPENROUTER, models) + + assert filtered == ["opus", "mistral"] + + def test_combined_provider_restrictions(self): + """Test that restrictions work correctly when set for multiple providers.""" + with patch.dict( + os.environ, + { + "OPENAI_ALLOWED_MODELS": "o3-mini", + "GOOGLE_ALLOWED_MODELS": "flash", + "OPENROUTER_ALLOWED_MODELS": "opus,sonnet", + }, + ): + service = ModelRestrictionService() + + # OpenAI restrictions + assert service.is_allowed(ProviderType.OPENAI, "o3-mini") + assert not service.is_allowed(ProviderType.OPENAI, "o3") + + # Google restrictions + assert service.is_allowed(ProviderType.GOOGLE, "flash") + assert not service.is_allowed(ProviderType.GOOGLE, "pro") + + # OpenRouter restrictions + assert service.is_allowed(ProviderType.OPENROUTER, "opus") + assert service.is_allowed(ProviderType.OPENROUTER, "sonnet") + assert not service.is_allowed(ProviderType.OPENROUTER, "haiku") + + # All providers should have restrictions + assert service.has_restrictions(ProviderType.OPENAI) + assert service.has_restrictions(ProviderType.GOOGLE) + assert service.has_restrictions(ProviderType.OPENROUTER) + class TestProviderIntegration: """Test integration with actual providers.""" @@ -195,6 +261,96 @@ class TestProviderIntegration: assert "not allowed by restriction policy" in str(exc_info.value) +class TestCustomProviderOpenRouterRestrictions: + """Test custom provider integration with OpenRouter restrictions.""" + + @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet", "OPENROUTER_API_KEY": "test-key"}) + def test_custom_provider_respects_openrouter_restrictions(self): + """Test that custom provider respects OpenRouter restrictions for cloud models.""" + # Clear any cached restriction service + import utils.model_restrictions + + utils.model_restrictions._restriction_service = None + + from providers.custom import CustomProvider + + provider = CustomProvider(base_url="http://test.com/v1") + + # Should validate allowed OpenRouter models (is_custom=false) + assert provider.validate_model_name("opus") + assert provider.validate_model_name("sonnet") + + # Should not validate disallowed OpenRouter models + assert not provider.validate_model_name("haiku") + + # Should still validate custom models (is_custom=true) regardless of restrictions + assert provider.validate_model_name("local-llama") # This has is_custom=true + + @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus", "OPENROUTER_API_KEY": "test-key"}) + def test_custom_provider_openrouter_capabilities_restrictions(self): + """Test that custom provider's get_capabilities respects OpenRouter restrictions.""" + # Clear any cached restriction service + import utils.model_restrictions + + utils.model_restrictions._restriction_service = None + + from providers.custom import CustomProvider + + provider = CustomProvider(base_url="http://test.com/v1") + + # Should work for allowed OpenRouter model + capabilities = provider.get_capabilities("opus") + assert capabilities.provider == ProviderType.OPENROUTER + + # Should raise for disallowed OpenRouter model + with pytest.raises(ValueError) as exc_info: + provider.get_capabilities("haiku") + assert "not allowed by restriction policy" in str(exc_info.value) + + # Should still work for custom models (is_custom=true) + capabilities = provider.get_capabilities("local-llama") + assert capabilities.provider == ProviderType.CUSTOM + + @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus"}, clear=False) + def test_custom_provider_no_openrouter_key_ignores_restrictions(self): + """Test that when OpenRouter key is not set, cloud models are rejected regardless of restrictions.""" + # Make sure OPENROUTER_API_KEY is not set + if "OPENROUTER_API_KEY" in os.environ: + del os.environ["OPENROUTER_API_KEY"] + # Clear any cached restriction service + import utils.model_restrictions + + utils.model_restrictions._restriction_service = None + + from providers.custom import CustomProvider + + provider = CustomProvider(base_url="http://test.com/v1") + + # Should not validate OpenRouter models when key is not available + assert not provider.validate_model_name("opus") # Even though it's in allowed list + assert not provider.validate_model_name("haiku") + + # Should still validate custom models + assert provider.validate_model_name("local-llama") + + @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "", "OPENROUTER_API_KEY": "test-key"}) + def test_custom_provider_empty_restrictions_allows_all_openrouter(self): + """Test that empty OPENROUTER_ALLOWED_MODELS allows all OpenRouter models.""" + # Clear any cached restriction service + import utils.model_restrictions + + utils.model_restrictions._restriction_service = None + + from providers.custom import CustomProvider + + provider = CustomProvider(base_url="http://test.com/v1") + + # Should validate all OpenRouter models when restrictions are empty + assert provider.validate_model_name("opus") + assert provider.validate_model_name("sonnet") + assert provider.validate_model_name("haiku") + + class TestRegistryIntegration: """Test integration with ModelProviderRegistry.""" diff --git a/tools/analyze.py b/tools/analyze.py index 90e771a..f0c13d6 100644 --- a/tools/analyze.py +++ b/tools/analyze.py @@ -41,7 +41,8 @@ class AnalyzeTool(BaseTool): "Supports both individual files and entire directories. " "Use this when you need to analyze files, examine code, or understand specific aspects of a codebase. " "Perfect for: codebase exploration, dependency analysis, pattern detection. " - "Always uses file paths for clean terminal output." + "Always uses file paths for clean terminal output. " + "Note: If you're not currently using a top-tier model such as Opus 4 or above, these tools can provide enhanced capabilities." ) def get_input_schema(self) -> dict[str, Any]: diff --git a/tools/chat.py b/tools/chat.py index e2fa501..0328ba4 100644 --- a/tools/chat.py +++ b/tools/chat.py @@ -43,7 +43,8 @@ class ChatTool(BaseTool): "collaborative brainstorming, validating your checklists and approaches, exploring alternatives. " "Also great for: explanations, comparisons, general development questions. " "Use this when you want to ask questions, brainstorm ideas, get opinions, discuss topics, " - "share your thinking, or need explanations about concepts and approaches." + "share your thinking, or need explanations about concepts and approaches. " + "Note: If you're not currently using a top-tier model such as Opus 4 or above, these tools can provide enhanced capabilities." ) def get_input_schema(self) -> dict[str, Any]: diff --git a/tools/codereview.py b/tools/codereview.py index 031bb74..1dfd480 100644 --- a/tools/codereview.py +++ b/tools/codereview.py @@ -78,7 +78,8 @@ class CodeReviewTool(BaseTool): "Supports focused reviews: security, performance, or quick checks. " "Choose thinking_mode based on review scope: 'low' for small code snippets, " "'medium' for standard files/modules (default), 'high' for complex systems/architectures, " - "'max' for critical security audits or large codebases requiring deepest analysis." + "'max' for critical security audits or large codebases requiring deepest analysis. " + "Note: If you're not currently using a top-tier model such as Opus 4 or above, these tools can provide enhanced capabilities." ) def get_input_schema(self) -> dict[str, Any]: diff --git a/tools/debug.py b/tools/debug.py index a3962b0..3d244c0 100644 --- a/tools/debug.py +++ b/tools/debug.py @@ -49,7 +49,8 @@ class DebugIssueTool(BaseTool): "code files as absolute paths. The more context, the better the debugging analysis. " "Choose thinking_mode based on issue complexity: 'low' for simple errors, " "'medium' for standard debugging (default), 'high' for complex system issues, " - "'max' for extremely challenging bugs requiring deepest analysis." + "'max' for extremely challenging bugs requiring deepest analysis. " + "Note: If you're not currently using a top-tier model such as Opus 4 or above, these tools can provide enhanced capabilities." ) def get_input_schema(self) -> dict[str, Any]: diff --git a/tools/precommit.py b/tools/precommit.py index a82eefb..b00e317 100644 --- a/tools/precommit.py +++ b/tools/precommit.py @@ -99,7 +99,8 @@ class Precommit(BaseTool): "whenever the user mentions committing or when changes are complete. " "Choose thinking_mode based on changeset size: 'low' for small focused changes, " "'medium' for standard commits (default), 'high' for large feature branches or complex refactoring, " - "'max' for critical releases or when reviewing extensive changes across multiple systems." + "'max' for critical releases or when reviewing extensive changes across multiple systems. " + "Note: If you're not currently using a top-tier model such as Opus 4 or above, these tools can provide enhanced capabilities." ) def get_input_schema(self) -> dict[str, Any]: diff --git a/tools/thinkdeep.py b/tools/thinkdeep.py index c5636f5..2cf01f9 100644 --- a/tools/thinkdeep.py +++ b/tools/thinkdeep.py @@ -47,7 +47,8 @@ class ThinkDeepTool(BaseTool): "IMPORTANT: Choose the appropriate thinking_mode based on task complexity - " "'low' for quick analysis, 'medium' for standard problems, 'high' for complex issues (default), " "'max' for extremely complex challenges requiring deepest analysis. " - "When in doubt, err on the side of a higher mode for truly deep thought and evaluation." + "When in doubt, err on the side of a higher mode for truly deep thought and evaluation. " + "Note: If you're not currently using a top-tier model such as Opus 4 or above, these tools can provide enhanced capabilities." ) def get_input_schema(self) -> dict[str, Any]: diff --git a/utils/model_restrictions.py b/utils/model_restrictions.py index c06ebcc..22e7d70 100644 --- a/utils/model_restrictions.py +++ b/utils/model_restrictions.py @@ -9,10 +9,12 @@ standardization purposes. Environment Variables: - OPENAI_ALLOWED_MODELS: Comma-separated list of allowed OpenAI models - GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models +- OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models Example: OPENAI_ALLOWED_MODELS=o3-mini,o4-mini GOOGLE_ALLOWED_MODELS=flash + OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral """ import logging @@ -38,6 +40,7 @@ class ModelRestrictionService: ENV_VARS = { ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS", ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS", + ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS", } def __init__(self):