fix: respect OPENROUTER_ALLOWED_MODELS in listmodels tool (#89)
* fix: respect OPENROUTER_ALLOWED_MODELS in listmodels tool - Modified listmodels tool to use provider's list_models() method with respect_restrictions=True - This ensures only models allowed by OPENROUTER_ALLOWED_MODELS are shown - Added note indicating when model restrictions are active - Fixed total model count to also respect restrictions Previously, the tool was directly accessing the OpenRouter registry and showing all ~200 models regardless of the OPENROUTER_ALLOWED_MODELS setting. * test: add tests for listmodels OpenRouter restrictions - Test that listmodels respects OPENROUTER_ALLOWED_MODELS setting - Test shows only allowed models when restrictions are set - Test shows all models when no restrictions are set - Verify proper use of respect_restrictions parameter * correcting test * test: fix test expectations for listmodels - Update tests to parse JSON response format - Fix model counting logic to handle provider grouping - Adjust expectations based on actual tool behavior (max 5 models per provider) - Tests now properly validate both restricted and unrestricted scenarios * style: fix code formatting issues - Applied ruff, black, and isort formatting - Fixed import order and removed trailing whitespace - All code quality checks now pass * fix: improve exception handling based on code review feedback - Added proper logging for exceptions instead of silent pass - Import logging module and create logger instance - Log warnings when error checking OpenRouter restrictions - Log warnings when error getting total available models - Maintains backward compatibility while improving debuggability --------- Co-authored-by: Patryk Ciechanski <patryk.ciechanski@inetum.com>
This commit is contained in:
240
tests/test_listmodels_restrictions.py
Normal file
240
tests/test_listmodels_restrictions.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Test listmodels tool respects model restrictions."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from providers.base import ModelProvider, ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from tools.listmodels import ListModelsTool
|
||||
|
||||
|
||||
class TestListModelsRestrictions(unittest.TestCase):
|
||||
"""Test that listmodels respects OPENROUTER_ALLOWED_MODELS."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
# Clear any existing registry state
|
||||
ModelProviderRegistry.clear_cache()
|
||||
|
||||
# Create mock OpenRouter provider
|
||||
self.mock_openrouter = MagicMock(spec=ModelProvider)
|
||||
self.mock_openrouter.provider_type = ProviderType.OPENROUTER
|
||||
|
||||
# Create mock Gemini provider for comparison
|
||||
self.mock_gemini = MagicMock(spec=ModelProvider)
|
||||
self.mock_gemini.provider_type = ProviderType.GOOGLE
|
||||
self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after tests."""
|
||||
ModelProviderRegistry.clear_cache()
|
||||
# Clean up environment variables
|
||||
for key in ["OPENROUTER_ALLOWED_MODELS", "OPENROUTER_API_KEY", "GEMINI_API_KEY"]:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"OPENROUTER_API_KEY": "test-key",
|
||||
"OPENROUTER_ALLOWED_MODELS": "opus,sonnet,deepseek/deepseek-r1-0528:free,qwen/qwen3-235b-a22b-04-28:free",
|
||||
"GEMINI_API_KEY": "gemini-test-key",
|
||||
},
|
||||
)
|
||||
@patch("utils.model_restrictions.get_restriction_service")
|
||||
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
||||
@patch.object(ModelProviderRegistry, "get_available_models")
|
||||
@patch.object(ModelProviderRegistry, "get_provider")
|
||||
def test_listmodels_respects_openrouter_restrictions(
|
||||
self, mock_get_provider, mock_get_models, mock_registry_class, mock_get_restriction
|
||||
):
|
||||
"""Test that listmodels only shows allowed OpenRouter models."""
|
||||
# Set up mock to return only allowed models when restrictions are respected
|
||||
# Include both aliased models and full model names without aliases
|
||||
self.mock_openrouter.list_models.return_value = [
|
||||
"anthropic/claude-3-opus-20240229", # Has alias "opus"
|
||||
"anthropic/claude-3-sonnet-20240229", # Has alias "sonnet"
|
||||
"deepseek/deepseek-r1-0528:free", # No alias, full name
|
||||
"qwen/qwen3-235b-a22b-04-28:free", # No alias, full name
|
||||
]
|
||||
|
||||
# Mock registry instance
|
||||
mock_registry = MagicMock()
|
||||
mock_registry_class.return_value = mock_registry
|
||||
|
||||
# Mock resolve method - return config for aliased models, None for others
|
||||
def resolve_side_effect(model_name):
|
||||
if "opus" in model_name.lower():
|
||||
config = MagicMock()
|
||||
config.model_name = "anthropic/claude-3-opus-20240229"
|
||||
config.context_window = 200000
|
||||
return config
|
||||
elif "sonnet" in model_name.lower():
|
||||
config = MagicMock()
|
||||
config.model_name = "anthropic/claude-3-sonnet-20240229"
|
||||
config.context_window = 200000
|
||||
return config
|
||||
return None # No config for models without aliases
|
||||
|
||||
mock_registry.resolve.side_effect = resolve_side_effect
|
||||
|
||||
# Mock provider registry
|
||||
def get_provider_side_effect(provider_type, force_new=False):
|
||||
if provider_type == ProviderType.OPENROUTER:
|
||||
return self.mock_openrouter
|
||||
elif provider_type == ProviderType.GOOGLE:
|
||||
return self.mock_gemini
|
||||
return None
|
||||
|
||||
mock_get_provider.side_effect = get_provider_side_effect
|
||||
|
||||
# Mock available models
|
||||
mock_get_models.return_value = {
|
||||
"gemini-2.5-flash": ProviderType.GOOGLE,
|
||||
"gemini-2.5-pro": ProviderType.GOOGLE,
|
||||
"anthropic/claude-3-opus-20240229": ProviderType.OPENROUTER,
|
||||
"anthropic/claude-3-sonnet-20240229": ProviderType.OPENROUTER,
|
||||
"deepseek/deepseek-r1-0528:free": ProviderType.OPENROUTER,
|
||||
"qwen/qwen3-235b-a22b-04-28:free": ProviderType.OPENROUTER,
|
||||
}
|
||||
|
||||
# Mock restriction service
|
||||
mock_restriction_service = MagicMock()
|
||||
mock_restriction_service.has_restrictions.return_value = True
|
||||
mock_restriction_service.get_allowed_models.return_value = {
|
||||
"opus",
|
||||
"sonnet",
|
||||
"deepseek/deepseek-r1-0528:free",
|
||||
"qwen/qwen3-235b-a22b-04-28:free",
|
||||
}
|
||||
mock_get_restriction.return_value = mock_restriction_service
|
||||
|
||||
# Create tool and execute
|
||||
tool = ListModelsTool()
|
||||
# Execute asynchronously
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
result_contents = loop.run_until_complete(tool.execute({}))
|
||||
loop.close()
|
||||
|
||||
# Extract text content from result
|
||||
result_text = result_contents[0].text
|
||||
|
||||
# Parse JSON response
|
||||
import json
|
||||
|
||||
result_json = json.loads(result_text)
|
||||
result = result_json["content"]
|
||||
|
||||
# Parse the output
|
||||
lines = result.split("\n")
|
||||
|
||||
# Check that OpenRouter section exists
|
||||
openrouter_section_found = False
|
||||
openrouter_models = []
|
||||
in_openrouter_section = False
|
||||
|
||||
for line in lines:
|
||||
if "OpenRouter" in line and "✅" in line:
|
||||
openrouter_section_found = True
|
||||
elif "Available Models" in line and openrouter_section_found:
|
||||
in_openrouter_section = True
|
||||
elif in_openrouter_section and line.strip().startswith("- "):
|
||||
# Extract model name from various line formats:
|
||||
# - `model-name` → `full-name` (context)
|
||||
# - `model-name`
|
||||
line_content = line.strip()[2:] # Remove "- "
|
||||
if "`" in line_content:
|
||||
# Extract content between first pair of backticks
|
||||
model_name = line_content.split("`")[1]
|
||||
openrouter_models.append(model_name)
|
||||
|
||||
self.assertTrue(openrouter_section_found, "OpenRouter section not found")
|
||||
self.assertEqual(
|
||||
len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
|
||||
)
|
||||
|
||||
# Verify list_models was called with respect_restrictions=True
|
||||
self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)
|
||||
|
||||
# Check for restriction note
|
||||
self.assertIn("Restricted to models matching:", result)
|
||||
|
||||
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"})
|
||||
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
||||
@patch.object(ModelProviderRegistry, "get_provider")
|
||||
def test_listmodels_shows_all_models_without_restrictions(self, mock_get_provider, mock_registry_class):
|
||||
"""Test that listmodels shows all models when no restrictions are set."""
|
||||
# Set up mock to return many models when no restrictions
|
||||
all_models = [f"provider{i//10}/model-{i}" for i in range(50)] # Simulate 50 models from different providers
|
||||
self.mock_openrouter.list_models.return_value = all_models
|
||||
|
||||
# Mock registry instance
|
||||
mock_registry = MagicMock()
|
||||
mock_registry_class.return_value = mock_registry
|
||||
mock_registry.resolve.return_value = None # No configs for simplicity
|
||||
|
||||
# Mock provider registry
|
||||
def get_provider_side_effect(provider_type, force_new=False):
|
||||
if provider_type == ProviderType.OPENROUTER:
|
||||
return self.mock_openrouter
|
||||
elif provider_type == ProviderType.GOOGLE:
|
||||
return self.mock_gemini
|
||||
return None
|
||||
|
||||
mock_get_provider.side_effect = get_provider_side_effect
|
||||
|
||||
# Create tool and execute
|
||||
tool = ListModelsTool()
|
||||
# Execute asynchronously
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
result_contents = loop.run_until_complete(tool.execute({}))
|
||||
loop.close()
|
||||
|
||||
# Extract text content from result
|
||||
result_text = result_contents[0].text
|
||||
|
||||
# Parse JSON response
|
||||
import json
|
||||
|
||||
result_json = json.loads(result_text)
|
||||
result = result_json["content"]
|
||||
|
||||
# Count OpenRouter models specifically
|
||||
lines = result.split("\n")
|
||||
openrouter_section_found = False
|
||||
openrouter_model_count = 0
|
||||
|
||||
for line in lines:
|
||||
if "OpenRouter" in line and "✅" in line:
|
||||
openrouter_section_found = True
|
||||
elif "Custom/Local API" in line:
|
||||
# End of OpenRouter section
|
||||
break
|
||||
elif openrouter_section_found and line.strip().startswith("- ") and "`" in line:
|
||||
openrouter_model_count += 1
|
||||
|
||||
# The tool shows models grouped by provider, max 5 per provider, total max 20
|
||||
# With 50 models from 5 providers, we expect around 5*5=25, but capped at 20
|
||||
self.assertGreaterEqual(
|
||||
openrouter_model_count, 5, f"Expected at least 5 OpenRouter models shown, found {openrouter_model_count}"
|
||||
)
|
||||
self.assertLessEqual(
|
||||
openrouter_model_count, 20, f"Expected at most 20 OpenRouter models shown, found {openrouter_model_count}"
|
||||
)
|
||||
|
||||
# Should show "and X more models available" message
|
||||
self.assertIn("more models available", result)
|
||||
|
||||
# Verify list_models was called with respect_restrictions=True
|
||||
# (even without restrictions, we always pass True)
|
||||
self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)
|
||||
|
||||
# Should NOT have restriction note when no restrictions are set
|
||||
self.assertNotIn("Restricted to models matching:", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -6,6 +6,7 @@ organized by their provider (Gemini, OpenAI, X.AI, OpenRouter, Custom).
|
||||
It shows which providers are configured and what models can be used.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -14,6 +15,8 @@ from mcp.types import TextContent
|
||||
from tools.base import BaseTool, ToolRequest
|
||||
from tools.models import ToolModelCategory, ToolOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ListModelsTool(BaseTool):
|
||||
"""
|
||||
@@ -156,29 +159,63 @@ class ListModelsTool(BaseTool):
|
||||
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
|
||||
|
||||
try:
|
||||
registry = OpenRouterModelRegistry()
|
||||
aliases = registry.list_aliases()
|
||||
# Get OpenRouter provider from registry to properly apply restrictions
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Group by provider for better organization
|
||||
providers_models = {}
|
||||
for alias in aliases[:20]: # Limit to first 20 to avoid overwhelming output
|
||||
config = registry.resolve(alias)
|
||||
if config and not (hasattr(config, "is_custom") and config.is_custom):
|
||||
# Extract provider from model_name
|
||||
provider = config.model_name.split("/")[0] if "/" in config.model_name else "other"
|
||||
if provider not in providers_models:
|
||||
providers_models[provider] = []
|
||||
providers_models[provider].append((alias, config))
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||
if provider:
|
||||
# Get models with restrictions applied
|
||||
available_models = provider.list_models(respect_restrictions=True)
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
output_lines.append("\n**Available Models** (showing top 20):")
|
||||
for provider, models in sorted(providers_models.items()):
|
||||
output_lines.append(f"\n*{provider.title()}:*")
|
||||
for alias, config in models[:5]: # Limit each provider to 5 models
|
||||
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
|
||||
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
|
||||
# Group by provider for better organization
|
||||
providers_models = {}
|
||||
for model_name in available_models[:20]: # Limit to first 20 to avoid overwhelming output
|
||||
# Try to resolve to get config details
|
||||
config = registry.resolve(model_name)
|
||||
if config:
|
||||
# Extract provider from model_name
|
||||
provider_name = config.model_name.split("/")[0] if "/" in config.model_name else "other"
|
||||
if provider_name not in providers_models:
|
||||
providers_models[provider_name] = []
|
||||
providers_models[provider_name].append((model_name, config))
|
||||
else:
|
||||
# Model without config - add with basic info
|
||||
provider_name = model_name.split("/")[0] if "/" in model_name else "other"
|
||||
if provider_name not in providers_models:
|
||||
providers_models[provider_name] = []
|
||||
providers_models[provider_name].append((model_name, None))
|
||||
|
||||
total_models = len(aliases)
|
||||
output_lines.append(f"\n...and {total_models - 20} more models available")
|
||||
output_lines.append("\n**Available Models** (showing top 20):")
|
||||
for provider_name, models in sorted(providers_models.items()):
|
||||
output_lines.append(f"\n*{provider_name.title()}:*")
|
||||
for alias, config in models[:5]: # Limit each provider to 5 models
|
||||
if config:
|
||||
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
|
||||
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
|
||||
else:
|
||||
output_lines.append(f"- `{alias}`")
|
||||
|
||||
total_models = len(available_models)
|
||||
if total_models > 20:
|
||||
output_lines.append(f"\n...and {total_models - 20} more models available")
|
||||
|
||||
# Check if restrictions are applied
|
||||
restriction_service = None
|
||||
try:
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
|
||||
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
|
||||
output_lines.append(
|
||||
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking OpenRouter restrictions: {e}")
|
||||
else:
|
||||
output_lines.append("**Error**: Could not load OpenRouter provider")
|
||||
|
||||
except Exception as e:
|
||||
output_lines.append(f"**Error loading models**: {str(e)}")
|
||||
@@ -244,13 +281,14 @@ class ListModelsTool(BaseTool):
|
||||
|
||||
# Get total available models
|
||||
try:
|
||||
from tools.analyze import AnalyzeTool
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
tool = AnalyzeTool()
|
||||
total_models = len(tool._get_available_models())
|
||||
# Get all available models respecting restrictions
|
||||
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||
total_models = len(available_models)
|
||||
output_lines.append(f"**Total Available Models**: {total_models}")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting total available models: {e}")
|
||||
|
||||
# Add usage tips
|
||||
output_lines.append("\n**Usage Tips**:")
|
||||
|
||||
Reference in New Issue
Block a user