diff --git a/tests/test_listmodels_restrictions.py b/tests/test_listmodels_restrictions.py new file mode 100644 index 0000000..f92dfbd --- /dev/null +++ b/tests/test_listmodels_restrictions.py @@ -0,0 +1,240 @@ +"""Test listmodels tool respects model restrictions.""" + +import asyncio +import os +import unittest +from unittest.mock import MagicMock, patch + +from providers.base import ModelProvider, ProviderType +from providers.registry import ModelProviderRegistry +from tools.listmodels import ListModelsTool + + +class TestListModelsRestrictions(unittest.TestCase): + """Test that listmodels respects OPENROUTER_ALLOWED_MODELS.""" + + def setUp(self): + """Set up test environment.""" + # Clear any existing registry state + ModelProviderRegistry.clear_cache() + + # Create mock OpenRouter provider + self.mock_openrouter = MagicMock(spec=ModelProvider) + self.mock_openrouter.provider_type = ProviderType.OPENROUTER + + # Create mock Gemini provider for comparison + self.mock_gemini = MagicMock(spec=ModelProvider) + self.mock_gemini.provider_type = ProviderType.GOOGLE + self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"] + + def tearDown(self): + """Clean up after tests.""" + ModelProviderRegistry.clear_cache() + # Clean up environment variables + for key in ["OPENROUTER_ALLOWED_MODELS", "OPENROUTER_API_KEY", "GEMINI_API_KEY"]: + os.environ.pop(key, None) + + @patch.dict( + os.environ, + { + "OPENROUTER_API_KEY": "test-key", + "OPENROUTER_ALLOWED_MODELS": "opus,sonnet,deepseek/deepseek-r1-0528:free,qwen/qwen3-235b-a22b-04-28:free", + "GEMINI_API_KEY": "gemini-test-key", + }, + ) + @patch("utils.model_restrictions.get_restriction_service") + @patch("providers.openrouter_registry.OpenRouterModelRegistry") + @patch.object(ModelProviderRegistry, "get_available_models") + @patch.object(ModelProviderRegistry, "get_provider") + def test_listmodels_respects_openrouter_restrictions( + self, mock_get_provider, mock_get_models, mock_registry_class, mock_get_restriction + ): + """Test that listmodels only shows allowed OpenRouter models.""" + # Set up mock to return only allowed models when restrictions are respected + # Include both aliased models and full model names without aliases + self.mock_openrouter.list_models.return_value = [ + "anthropic/claude-3-opus-20240229", # Has alias "opus" + "anthropic/claude-3-sonnet-20240229", # Has alias "sonnet" + "deepseek/deepseek-r1-0528:free", # No alias, full name + "qwen/qwen3-235b-a22b-04-28:free", # No alias, full name + ] + + # Mock registry instance + mock_registry = MagicMock() + mock_registry_class.return_value = mock_registry + + # Mock resolve method - return config for aliased models, None for others + def resolve_side_effect(model_name): + if "opus" in model_name.lower(): + config = MagicMock() + config.model_name = "anthropic/claude-3-opus-20240229" + config.context_window = 200000 + return config + elif "sonnet" in model_name.lower(): + config = MagicMock() + config.model_name = "anthropic/claude-3-sonnet-20240229" + config.context_window = 200000 + return config + return None # No config for models without aliases + + mock_registry.resolve.side_effect = resolve_side_effect + + # Mock provider registry + def get_provider_side_effect(provider_type, force_new=False): + if provider_type == ProviderType.OPENROUTER: + return self.mock_openrouter + elif provider_type == ProviderType.GOOGLE: + return self.mock_gemini + return None + + mock_get_provider.side_effect = get_provider_side_effect + + # Mock available models + mock_get_models.return_value = { + "gemini-2.5-flash": ProviderType.GOOGLE, + "gemini-2.5-pro": ProviderType.GOOGLE, + "anthropic/claude-3-opus-20240229": ProviderType.OPENROUTER, + "anthropic/claude-3-sonnet-20240229": ProviderType.OPENROUTER, + "deepseek/deepseek-r1-0528:free": ProviderType.OPENROUTER, + "qwen/qwen3-235b-a22b-04-28:free": ProviderType.OPENROUTER, + } + + # Mock restriction service + mock_restriction_service = MagicMock() + mock_restriction_service.has_restrictions.return_value = True + mock_restriction_service.get_allowed_models.return_value = { + "opus", + "sonnet", + "deepseek/deepseek-r1-0528:free", + "qwen/qwen3-235b-a22b-04-28:free", + } + mock_get_restriction.return_value = mock_restriction_service + + # Create tool and execute + tool = ListModelsTool() + # Execute asynchronously + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + result_contents = loop.run_until_complete(tool.execute({})) + loop.close() + + # Extract text content from result + result_text = result_contents[0].text + + # Parse JSON response + import json + + result_json = json.loads(result_text) + result = result_json["content"] + + # Parse the output + lines = result.split("\n") + + # Check that OpenRouter section exists + openrouter_section_found = False + openrouter_models = [] + in_openrouter_section = False + + for line in lines: + if "OpenRouter" in line and "✅" in line: + openrouter_section_found = True + elif "Available Models" in line and openrouter_section_found: + in_openrouter_section = True + elif in_openrouter_section and line.strip().startswith("- "): + # Extract model name from various line formats: + # - `model-name` → `full-name` (context) + # - `model-name` + line_content = line.strip()[2:] # Remove "- " + if "`" in line_content: + # Extract content between first pair of backticks + model_name = line_content.split("`")[1] + openrouter_models.append(model_name) + + self.assertTrue(openrouter_section_found, "OpenRouter section not found") + self.assertEqual( + len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}" + ) + + # Verify list_models was called with respect_restrictions=True + self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True) + + # Check for restriction note + self.assertIn("Restricted to models matching:", result) + + @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}) + @patch("providers.openrouter_registry.OpenRouterModelRegistry") + @patch.object(ModelProviderRegistry, "get_provider") + def test_listmodels_shows_all_models_without_restrictions(self, mock_get_provider, mock_registry_class): + """Test that listmodels shows all models when no restrictions are set.""" + # Set up mock to return many models when no restrictions + all_models = [f"provider{i//10}/model-{i}" for i in range(50)] # Simulate 50 models from different providers + self.mock_openrouter.list_models.return_value = all_models + + # Mock registry instance + mock_registry = MagicMock() + mock_registry_class.return_value = mock_registry + mock_registry.resolve.return_value = None # No configs for simplicity + + # Mock provider registry + def get_provider_side_effect(provider_type, force_new=False): + if provider_type == ProviderType.OPENROUTER: + return self.mock_openrouter + elif provider_type == ProviderType.GOOGLE: + return self.mock_gemini + return None + + mock_get_provider.side_effect = get_provider_side_effect + + # Create tool and execute + tool = ListModelsTool() + # Execute asynchronously + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + result_contents = loop.run_until_complete(tool.execute({})) + loop.close() + + # Extract text content from result + result_text = result_contents[0].text + + # Parse JSON response + import json + + result_json = json.loads(result_text) + result = result_json["content"] + + # Count OpenRouter models specifically + lines = result.split("\n") + openrouter_section_found = False + openrouter_model_count = 0 + + for line in lines: + if "OpenRouter" in line and "✅" in line: + openrouter_section_found = True + elif "Custom/Local API" in line: + # End of OpenRouter section + break + elif openrouter_section_found and line.strip().startswith("- ") and "`" in line: + openrouter_model_count += 1 + + # The tool shows models grouped by provider, max 5 per provider, total max 20 + # With 50 models from 5 providers, we expect around 5*5=25, but capped at 20 + self.assertGreaterEqual( + openrouter_model_count, 5, f"Expected at least 5 OpenRouter models shown, found {openrouter_model_count}" + ) + self.assertLessEqual( + openrouter_model_count, 20, f"Expected at most 20 OpenRouter models shown, found {openrouter_model_count}" + ) + + # Should show "and X more models available" message + self.assertIn("more models available", result) + + # Verify list_models was called with respect_restrictions=True + # (even without restrictions, we always pass True) + self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True) + + # Should NOT have restriction note when no restrictions are set + self.assertNotIn("Restricted to models matching:", result) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/listmodels.py b/tools/listmodels.py index c1ee10c..06c2e37 100644 --- a/tools/listmodels.py +++ b/tools/listmodels.py @@ -6,6 +6,7 @@ organized by their provider (Gemini, OpenAI, X.AI, OpenRouter, Custom). It shows which providers are configured and what models can be used. """ +import logging import os from typing import Any, Optional @@ -14,6 +15,8 @@ from mcp.types import TextContent from tools.base import BaseTool, ToolRequest from tools.models import ToolModelCategory, ToolOutput +logger = logging.getLogger(__name__) + class ListModelsTool(BaseTool): """ @@ -156,29 +159,63 @@ class ListModelsTool(BaseTool): output_lines.append("**Description**: Access to multiple cloud AI providers via unified API") try: - registry = OpenRouterModelRegistry() - aliases = registry.list_aliases() + # Get OpenRouter provider from registry to properly apply restrictions + from providers.base import ProviderType + from providers.registry import ModelProviderRegistry - # Group by provider for better organization - providers_models = {} - for alias in aliases[:20]: # Limit to first 20 to avoid overwhelming output - config = registry.resolve(alias) - if config and not (hasattr(config, "is_custom") and config.is_custom): - # Extract provider from model_name - provider = config.model_name.split("/")[0] if "/" in config.model_name else "other" - if provider not in providers_models: - providers_models[provider] = [] - providers_models[provider].append((alias, config)) + provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER) + if provider: + # Get models with restrictions applied + available_models = provider.list_models(respect_restrictions=True) + registry = OpenRouterModelRegistry() - output_lines.append("\n**Available Models** (showing top 20):") - for provider, models in sorted(providers_models.items()): - output_lines.append(f"\n*{provider.title()}:*") - for alias, config in models[:5]: # Limit each provider to 5 models - context_str = f"{config.context_window // 1000}K" if config.context_window else "?" - output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)") + # Group by provider for better organization + providers_models = {} + for model_name in available_models[:20]: # Limit to first 20 to avoid overwhelming output + # Try to resolve to get config details + config = registry.resolve(model_name) + if config: + # Extract provider from model_name + provider_name = config.model_name.split("/")[0] if "/" in config.model_name else "other" + if provider_name not in providers_models: + providers_models[provider_name] = [] + providers_models[provider_name].append((model_name, config)) + else: + # Model without config - add with basic info + provider_name = model_name.split("/")[0] if "/" in model_name else "other" + if provider_name not in providers_models: + providers_models[provider_name] = [] + providers_models[provider_name].append((model_name, None)) - total_models = len(aliases) - output_lines.append(f"\n...and {total_models - 20} more models available") + output_lines.append("\n**Available Models** (showing top 20):") + for provider_name, models in sorted(providers_models.items()): + output_lines.append(f"\n*{provider_name.title()}:*") + for alias, config in models[:5]: # Limit each provider to 5 models + if config: + context_str = f"{config.context_window // 1000}K" if config.context_window else "?" + output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)") + else: + output_lines.append(f"- `{alias}`") + + total_models = len(available_models) + if total_models > 20: + output_lines.append(f"\n...and {total_models - 20} more models available") + + # Check if restrictions are applied + restriction_service = None + try: + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + if restriction_service.has_restrictions(ProviderType.OPENROUTER): + allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER) + output_lines.append( + f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}" + ) + except Exception as e: + logger.warning(f"Error checking OpenRouter restrictions: {e}") + else: + output_lines.append("**Error**: Could not load OpenRouter provider") except Exception as e: output_lines.append(f"**Error loading models**: {str(e)}") @@ -244,13 +281,14 @@ class ListModelsTool(BaseTool): # Get total available models try: - from tools.analyze import AnalyzeTool + from providers.registry import ModelProviderRegistry - tool = AnalyzeTool() - total_models = len(tool._get_available_models()) + # Get all available models respecting restrictions + available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True) + total_models = len(available_models) output_lines.append(f"**Total Available Models**: {total_models}") - except Exception: - pass + except Exception as e: + logger.warning(f"Error getting total available models: {e}") # Add usage tips output_lines.append("\n**Usage Tips**:")