fix: respect OPENROUTER_ALLOWED_MODELS in listmodels tool (#89)
* fix: respect OPENROUTER_ALLOWED_MODELS in listmodels tool - Modified listmodels tool to use provider's list_models() method with respect_restrictions=True - This ensures only models allowed by OPENROUTER_ALLOWED_MODELS are shown - Added note indicating when model restrictions are active - Fixed total model count to also respect restrictions Previously, the tool was directly accessing the OpenRouter registry and showing all ~200 models regardless of the OPENROUTER_ALLOWED_MODELS setting. * test: add tests for listmodels OpenRouter restrictions - Test that listmodels respects OPENROUTER_ALLOWED_MODELS setting - Test shows only allowed models when restrictions are set - Test shows all models when no restrictions are set - Verify proper use of respect_restrictions parameter * correcting test * test: fix test expectations for listmodels - Update tests to parse JSON response format - Fix model counting logic to handle provider grouping - Adjust expectations based on actual tool behavior (max 5 models per provider) - Tests now properly validate both restricted and unrestricted scenarios * style: fix code formatting issues - Applied ruff, black, and isort formatting - Fixed import order and removed trailing whitespace - All code quality checks now pass * fix: improve exception handling based on code review feedback - Added proper logging for exceptions instead of silent pass - Import logging module and create logger instance - Log warnings when error checking OpenRouter restrictions - Log warnings when error getting total available models - Maintains backward compatibility while improving debuggability --------- Co-authored-by: Patryk Ciechanski <patryk.ciechanski@inetum.com>
This commit is contained in:
240
tests/test_listmodels_restrictions.py
Normal file
240
tests/test_listmodels_restrictions.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Test listmodels tool respects model restrictions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from providers.base import ModelProvider, ProviderType
|
||||||
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from tools.listmodels import ListModelsTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestListModelsRestrictions(unittest.TestCase):
|
||||||
|
"""Test that listmodels respects OPENROUTER_ALLOWED_MODELS."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test environment."""
|
||||||
|
# Clear any existing registry state
|
||||||
|
ModelProviderRegistry.clear_cache()
|
||||||
|
|
||||||
|
# Create mock OpenRouter provider
|
||||||
|
self.mock_openrouter = MagicMock(spec=ModelProvider)
|
||||||
|
self.mock_openrouter.provider_type = ProviderType.OPENROUTER
|
||||||
|
|
||||||
|
# Create mock Gemini provider for comparison
|
||||||
|
self.mock_gemini = MagicMock(spec=ModelProvider)
|
||||||
|
self.mock_gemini.provider_type = ProviderType.GOOGLE
|
||||||
|
self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up after tests."""
|
||||||
|
ModelProviderRegistry.clear_cache()
|
||||||
|
# Clean up environment variables
|
||||||
|
for key in ["OPENROUTER_ALLOWED_MODELS", "OPENROUTER_API_KEY", "GEMINI_API_KEY"]:
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"OPENROUTER_API_KEY": "test-key",
|
||||||
|
"OPENROUTER_ALLOWED_MODELS": "opus,sonnet,deepseek/deepseek-r1-0528:free,qwen/qwen3-235b-a22b-04-28:free",
|
||||||
|
"GEMINI_API_KEY": "gemini-test-key",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@patch("utils.model_restrictions.get_restriction_service")
|
||||||
|
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
||||||
|
@patch.object(ModelProviderRegistry, "get_available_models")
|
||||||
|
@patch.object(ModelProviderRegistry, "get_provider")
|
||||||
|
def test_listmodels_respects_openrouter_restrictions(
|
||||||
|
self, mock_get_provider, mock_get_models, mock_registry_class, mock_get_restriction
|
||||||
|
):
|
||||||
|
"""Test that listmodels only shows allowed OpenRouter models."""
|
||||||
|
# Set up mock to return only allowed models when restrictions are respected
|
||||||
|
# Include both aliased models and full model names without aliases
|
||||||
|
self.mock_openrouter.list_models.return_value = [
|
||||||
|
"anthropic/claude-3-opus-20240229", # Has alias "opus"
|
||||||
|
"anthropic/claude-3-sonnet-20240229", # Has alias "sonnet"
|
||||||
|
"deepseek/deepseek-r1-0528:free", # No alias, full name
|
||||||
|
"qwen/qwen3-235b-a22b-04-28:free", # No alias, full name
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock registry instance
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry_class.return_value = mock_registry
|
||||||
|
|
||||||
|
# Mock resolve method - return config for aliased models, None for others
|
||||||
|
def resolve_side_effect(model_name):
|
||||||
|
if "opus" in model_name.lower():
|
||||||
|
config = MagicMock()
|
||||||
|
config.model_name = "anthropic/claude-3-opus-20240229"
|
||||||
|
config.context_window = 200000
|
||||||
|
return config
|
||||||
|
elif "sonnet" in model_name.lower():
|
||||||
|
config = MagicMock()
|
||||||
|
config.model_name = "anthropic/claude-3-sonnet-20240229"
|
||||||
|
config.context_window = 200000
|
||||||
|
return config
|
||||||
|
return None # No config for models without aliases
|
||||||
|
|
||||||
|
mock_registry.resolve.side_effect = resolve_side_effect
|
||||||
|
|
||||||
|
# Mock provider registry
|
||||||
|
def get_provider_side_effect(provider_type, force_new=False):
|
||||||
|
if provider_type == ProviderType.OPENROUTER:
|
||||||
|
return self.mock_openrouter
|
||||||
|
elif provider_type == ProviderType.GOOGLE:
|
||||||
|
return self.mock_gemini
|
||||||
|
return None
|
||||||
|
|
||||||
|
mock_get_provider.side_effect = get_provider_side_effect
|
||||||
|
|
||||||
|
# Mock available models
|
||||||
|
mock_get_models.return_value = {
|
||||||
|
"gemini-2.5-flash": ProviderType.GOOGLE,
|
||||||
|
"gemini-2.5-pro": ProviderType.GOOGLE,
|
||||||
|
"anthropic/claude-3-opus-20240229": ProviderType.OPENROUTER,
|
||||||
|
"anthropic/claude-3-sonnet-20240229": ProviderType.OPENROUTER,
|
||||||
|
"deepseek/deepseek-r1-0528:free": ProviderType.OPENROUTER,
|
||||||
|
"qwen/qwen3-235b-a22b-04-28:free": ProviderType.OPENROUTER,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mock restriction service
|
||||||
|
mock_restriction_service = MagicMock()
|
||||||
|
mock_restriction_service.has_restrictions.return_value = True
|
||||||
|
mock_restriction_service.get_allowed_models.return_value = {
|
||||||
|
"opus",
|
||||||
|
"sonnet",
|
||||||
|
"deepseek/deepseek-r1-0528:free",
|
||||||
|
"qwen/qwen3-235b-a22b-04-28:free",
|
||||||
|
}
|
||||||
|
mock_get_restriction.return_value = mock_restriction_service
|
||||||
|
|
||||||
|
# Create tool and execute
|
||||||
|
tool = ListModelsTool()
|
||||||
|
# Execute asynchronously
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
result_contents = loop.run_until_complete(tool.execute({}))
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
# Extract text content from result
|
||||||
|
result_text = result_contents[0].text
|
||||||
|
|
||||||
|
# Parse JSON response
|
||||||
|
import json
|
||||||
|
|
||||||
|
result_json = json.loads(result_text)
|
||||||
|
result = result_json["content"]
|
||||||
|
|
||||||
|
# Parse the output
|
||||||
|
lines = result.split("\n")
|
||||||
|
|
||||||
|
# Check that OpenRouter section exists
|
||||||
|
openrouter_section_found = False
|
||||||
|
openrouter_models = []
|
||||||
|
in_openrouter_section = False
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if "OpenRouter" in line and "✅" in line:
|
||||||
|
openrouter_section_found = True
|
||||||
|
elif "Available Models" in line and openrouter_section_found:
|
||||||
|
in_openrouter_section = True
|
||||||
|
elif in_openrouter_section and line.strip().startswith("- "):
|
||||||
|
# Extract model name from various line formats:
|
||||||
|
# - `model-name` → `full-name` (context)
|
||||||
|
# - `model-name`
|
||||||
|
line_content = line.strip()[2:] # Remove "- "
|
||||||
|
if "`" in line_content:
|
||||||
|
# Extract content between first pair of backticks
|
||||||
|
model_name = line_content.split("`")[1]
|
||||||
|
openrouter_models.append(model_name)
|
||||||
|
|
||||||
|
self.assertTrue(openrouter_section_found, "OpenRouter section not found")
|
||||||
|
self.assertEqual(
|
||||||
|
len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify list_models was called with respect_restrictions=True
|
||||||
|
self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)
|
||||||
|
|
||||||
|
# Check for restriction note
|
||||||
|
self.assertIn("Restricted to models matching:", result)
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"})
|
||||||
|
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
||||||
|
@patch.object(ModelProviderRegistry, "get_provider")
|
||||||
|
def test_listmodels_shows_all_models_without_restrictions(self, mock_get_provider, mock_registry_class):
|
||||||
|
"""Test that listmodels shows all models when no restrictions are set."""
|
||||||
|
# Set up mock to return many models when no restrictions
|
||||||
|
all_models = [f"provider{i//10}/model-{i}" for i in range(50)] # Simulate 50 models from different providers
|
||||||
|
self.mock_openrouter.list_models.return_value = all_models
|
||||||
|
|
||||||
|
# Mock registry instance
|
||||||
|
mock_registry = MagicMock()
|
||||||
|
mock_registry_class.return_value = mock_registry
|
||||||
|
mock_registry.resolve.return_value = None # No configs for simplicity
|
||||||
|
|
||||||
|
# Mock provider registry
|
||||||
|
def get_provider_side_effect(provider_type, force_new=False):
|
||||||
|
if provider_type == ProviderType.OPENROUTER:
|
||||||
|
return self.mock_openrouter
|
||||||
|
elif provider_type == ProviderType.GOOGLE:
|
||||||
|
return self.mock_gemini
|
||||||
|
return None
|
||||||
|
|
||||||
|
mock_get_provider.side_effect = get_provider_side_effect
|
||||||
|
|
||||||
|
# Create tool and execute
|
||||||
|
tool = ListModelsTool()
|
||||||
|
# Execute asynchronously
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
result_contents = loop.run_until_complete(tool.execute({}))
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
# Extract text content from result
|
||||||
|
result_text = result_contents[0].text
|
||||||
|
|
||||||
|
# Parse JSON response
|
||||||
|
import json
|
||||||
|
|
||||||
|
result_json = json.loads(result_text)
|
||||||
|
result = result_json["content"]
|
||||||
|
|
||||||
|
# Count OpenRouter models specifically
|
||||||
|
lines = result.split("\n")
|
||||||
|
openrouter_section_found = False
|
||||||
|
openrouter_model_count = 0
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if "OpenRouter" in line and "✅" in line:
|
||||||
|
openrouter_section_found = True
|
||||||
|
elif "Custom/Local API" in line:
|
||||||
|
# End of OpenRouter section
|
||||||
|
break
|
||||||
|
elif openrouter_section_found and line.strip().startswith("- ") and "`" in line:
|
||||||
|
openrouter_model_count += 1
|
||||||
|
|
||||||
|
# The tool shows models grouped by provider, max 5 per provider, total max 20
|
||||||
|
# With 50 models from 5 providers, we expect around 5*5=25, but capped at 20
|
||||||
|
self.assertGreaterEqual(
|
||||||
|
openrouter_model_count, 5, f"Expected at least 5 OpenRouter models shown, found {openrouter_model_count}"
|
||||||
|
)
|
||||||
|
self.assertLessEqual(
|
||||||
|
openrouter_model_count, 20, f"Expected at most 20 OpenRouter models shown, found {openrouter_model_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should show "and X more models available" message
|
||||||
|
self.assertIn("more models available", result)
|
||||||
|
|
||||||
|
# Verify list_models was called with respect_restrictions=True
|
||||||
|
# (even without restrictions, we always pass True)
|
||||||
|
self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)
|
||||||
|
|
||||||
|
# Should NOT have restriction note when no restrictions are set
|
||||||
|
self.assertNotIn("Restricted to models matching:", result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -6,6 +6,7 @@ organized by their provider (Gemini, OpenAI, X.AI, OpenRouter, Custom).
|
|||||||
It shows which providers are configured and what models can be used.
|
It shows which providers are configured and what models can be used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@@ -14,6 +15,8 @@ from mcp.types import TextContent
|
|||||||
from tools.base import BaseTool, ToolRequest
|
from tools.base import BaseTool, ToolRequest
|
||||||
from tools.models import ToolModelCategory, ToolOutput
|
from tools.models import ToolModelCategory, ToolOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ListModelsTool(BaseTool):
|
class ListModelsTool(BaseTool):
|
||||||
"""
|
"""
|
||||||
@@ -156,29 +159,63 @@ class ListModelsTool(BaseTool):
|
|||||||
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
|
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
registry = OpenRouterModelRegistry()
|
# Get OpenRouter provider from registry to properly apply restrictions
|
||||||
aliases = registry.list_aliases()
|
from providers.base import ProviderType
|
||||||
|
from providers.registry import ModelProviderRegistry
|
||||||
|
|
||||||
# Group by provider for better organization
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||||
providers_models = {}
|
if provider:
|
||||||
for alias in aliases[:20]: # Limit to first 20 to avoid overwhelming output
|
# Get models with restrictions applied
|
||||||
config = registry.resolve(alias)
|
available_models = provider.list_models(respect_restrictions=True)
|
||||||
if config and not (hasattr(config, "is_custom") and config.is_custom):
|
registry = OpenRouterModelRegistry()
|
||||||
# Extract provider from model_name
|
|
||||||
provider = config.model_name.split("/")[0] if "/" in config.model_name else "other"
|
|
||||||
if provider not in providers_models:
|
|
||||||
providers_models[provider] = []
|
|
||||||
providers_models[provider].append((alias, config))
|
|
||||||
|
|
||||||
output_lines.append("\n**Available Models** (showing top 20):")
|
# Group by provider for better organization
|
||||||
for provider, models in sorted(providers_models.items()):
|
providers_models = {}
|
||||||
output_lines.append(f"\n*{provider.title()}:*")
|
for model_name in available_models[:20]: # Limit to first 20 to avoid overwhelming output
|
||||||
for alias, config in models[:5]: # Limit each provider to 5 models
|
# Try to resolve to get config details
|
||||||
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
|
config = registry.resolve(model_name)
|
||||||
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
|
if config:
|
||||||
|
# Extract provider from model_name
|
||||||
|
provider_name = config.model_name.split("/")[0] if "/" in config.model_name else "other"
|
||||||
|
if provider_name not in providers_models:
|
||||||
|
providers_models[provider_name] = []
|
||||||
|
providers_models[provider_name].append((model_name, config))
|
||||||
|
else:
|
||||||
|
# Model without config - add with basic info
|
||||||
|
provider_name = model_name.split("/")[0] if "/" in model_name else "other"
|
||||||
|
if provider_name not in providers_models:
|
||||||
|
providers_models[provider_name] = []
|
||||||
|
providers_models[provider_name].append((model_name, None))
|
||||||
|
|
||||||
total_models = len(aliases)
|
output_lines.append("\n**Available Models** (showing top 20):")
|
||||||
output_lines.append(f"\n...and {total_models - 20} more models available")
|
for provider_name, models in sorted(providers_models.items()):
|
||||||
|
output_lines.append(f"\n*{provider_name.title()}:*")
|
||||||
|
for alias, config in models[:5]: # Limit each provider to 5 models
|
||||||
|
if config:
|
||||||
|
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
|
||||||
|
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
|
||||||
|
else:
|
||||||
|
output_lines.append(f"- `{alias}`")
|
||||||
|
|
||||||
|
total_models = len(available_models)
|
||||||
|
if total_models > 20:
|
||||||
|
output_lines.append(f"\n...and {total_models - 20} more models available")
|
||||||
|
|
||||||
|
# Check if restrictions are applied
|
||||||
|
restriction_service = None
|
||||||
|
try:
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service()
|
||||||
|
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
|
||||||
|
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
|
||||||
|
output_lines.append(
|
||||||
|
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error checking OpenRouter restrictions: {e}")
|
||||||
|
else:
|
||||||
|
output_lines.append("**Error**: Could not load OpenRouter provider")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
output_lines.append(f"**Error loading models**: {str(e)}")
|
output_lines.append(f"**Error loading models**: {str(e)}")
|
||||||
@@ -244,13 +281,14 @@ class ListModelsTool(BaseTool):
|
|||||||
|
|
||||||
# Get total available models
|
# Get total available models
|
||||||
try:
|
try:
|
||||||
from tools.analyze import AnalyzeTool
|
from providers.registry import ModelProviderRegistry
|
||||||
|
|
||||||
tool = AnalyzeTool()
|
# Get all available models respecting restrictions
|
||||||
total_models = len(tool._get_available_models())
|
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
total_models = len(available_models)
|
||||||
output_lines.append(f"**Total Available Models**: {total_models}")
|
output_lines.append(f"**Total Available Models**: {total_models}")
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"Error getting total available models: {e}")
|
||||||
|
|
||||||
# Add usage tips
|
# Add usage tips
|
||||||
output_lines.append("\n**Usage Tips**:")
|
output_lines.append("\n**Usage Tips**:")
|
||||||
|
|||||||
Reference in New Issue
Block a user