fix: respect OPENROUTER_ALLOWED_MODELS in listmodels tool (#89)

* fix: respect OPENROUTER_ALLOWED_MODELS in listmodels tool

- Modified listmodels tool to use provider's list_models() method with respect_restrictions=True
- This ensures only models allowed by OPENROUTER_ALLOWED_MODELS are shown
- Added note indicating when model restrictions are active
- Fixed total model count to also respect restrictions

Previously, the tool was directly accessing the OpenRouter registry and showing all ~200 models regardless of the OPENROUTER_ALLOWED_MODELS setting.

* test: add tests for listmodels OpenRouter restrictions

- Test that listmodels respects OPENROUTER_ALLOWED_MODELS setting
- Test shows only allowed models when restrictions are set
- Test shows all models when no restrictions are set
- Verify proper use of respect_restrictions parameter

* correcting test

* test: fix test expectations for listmodels

- Update tests to parse JSON response format
- Fix model counting logic to handle provider grouping
- Adjust expectations based on actual tool behavior (max 5 models per provider)
- Tests now properly validate both restricted and unrestricted scenarios

* style: fix code formatting issues

- Applied ruff, black, and isort formatting
- Fixed import order and removed trailing whitespace
- All code quality checks now pass

* fix: improve exception handling based on code review feedback

- Added proper logging for exceptions instead of silent pass
- Import logging module and create logger instance
- Log warnings when error checking OpenRouter restrictions
- Log warnings when error getting total available models
- Maintains backward compatibility while improving debuggability

---------

Co-authored-by: Patryk Ciechanski <patryk.ciechanski@inetum.com>
This commit is contained in:
PCITI
2025-06-20 22:14:21 +02:00
committed by GitHub
parent 69a3121452
commit 76edd30e9a
2 changed files with 303 additions and 25 deletions

View File

@@ -0,0 +1,240 @@
"""Test listmodels tool respects model restrictions."""
import asyncio
import os
import unittest
from unittest.mock import MagicMock, patch
from providers.base import ModelProvider, ProviderType
from providers.registry import ModelProviderRegistry
from tools.listmodels import ListModelsTool
class TestListModelsRestrictions(unittest.TestCase):
    """Test that listmodels respects OPENROUTER_ALLOWED_MODELS.

    The tool must request models from the OpenRouter provider with
    ``respect_restrictions=True`` so only allowed models are listed, and it
    must surface a note in its output when restrictions are active.
    """

    def setUp(self):
        """Set up test environment with mock providers."""
        # Clear any existing registry state so the per-test mocks take effect.
        ModelProviderRegistry.clear_cache()

        # Mock OpenRouter provider; each test configures list_models itself.
        self.mock_openrouter = MagicMock(spec=ModelProvider)
        self.mock_openrouter.provider_type = ProviderType.OPENROUTER

        # Mock Gemini provider as an unrestricted comparison point.
        self.mock_gemini = MagicMock(spec=ModelProvider)
        self.mock_gemini.provider_type = ProviderType.GOOGLE
        self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]

    def tearDown(self):
        """Clean up registry caches and test environment variables."""
        ModelProviderRegistry.clear_cache()
        for key in ["OPENROUTER_ALLOWED_MODELS", "OPENROUTER_API_KEY", "GEMINI_API_KEY"]:
            os.environ.pop(key, None)

    def _execute_tool(self):
        """Run the listmodels tool to completion and return its text content.

        The tool emits a JSON payload whose ``content`` field holds the
        human-readable model listing that the tests below parse.
        """
        import json

        tool = ListModelsTool()
        # asyncio.run manages event-loop creation and teardown for us,
        # replacing the manual new_event_loop/set_event_loop/close dance.
        result_contents = asyncio.run(tool.execute({}))
        return json.loads(result_contents[0].text)["content"]

    @patch.dict(
        os.environ,
        {
            "OPENROUTER_API_KEY": "test-key",
            "OPENROUTER_ALLOWED_MODELS": "opus,sonnet,deepseek/deepseek-r1-0528:free,qwen/qwen3-235b-a22b-04-28:free",
            "GEMINI_API_KEY": "gemini-test-key",
        },
    )
    @patch("utils.model_restrictions.get_restriction_service")
    @patch("providers.openrouter_registry.OpenRouterModelRegistry")
    @patch.object(ModelProviderRegistry, "get_available_models")
    @patch.object(ModelProviderRegistry, "get_provider")
    def test_listmodels_respects_openrouter_restrictions(
        self, mock_get_provider, mock_get_models, mock_registry_class, mock_get_restriction
    ):
        """Test that listmodels only shows allowed OpenRouter models."""
        # When restrictions are respected, the provider returns only the
        # four allowed models: two with aliases and two full names.
        self.mock_openrouter.list_models.return_value = [
            "anthropic/claude-3-opus-20240229",  # Has alias "opus"
            "anthropic/claude-3-sonnet-20240229",  # Has alias "sonnet"
            "deepseek/deepseek-r1-0528:free",  # No alias, full name
            "qwen/qwen3-235b-a22b-04-28:free",  # No alias, full name
        ]

        # Mock registry instance used by the tool to resolve model configs.
        mock_registry = MagicMock()
        mock_registry_class.return_value = mock_registry

        def resolve_side_effect(model_name):
            # Return a config for aliased models, None for the rest.
            if "opus" in model_name.lower():
                config = MagicMock()
                config.model_name = "anthropic/claude-3-opus-20240229"
                config.context_window = 200000
                return config
            elif "sonnet" in model_name.lower():
                config = MagicMock()
                config.model_name = "anthropic/claude-3-sonnet-20240229"
                config.context_window = 200000
                return config
            return None  # No config for models without aliases

        mock_registry.resolve.side_effect = resolve_side_effect

        def get_provider_side_effect(provider_type, force_new=False):
            if provider_type == ProviderType.OPENROUTER:
                return self.mock_openrouter
            elif provider_type == ProviderType.GOOGLE:
                return self.mock_gemini
            return None

        mock_get_provider.side_effect = get_provider_side_effect

        # Mock available models across both providers.
        mock_get_models.return_value = {
            "gemini-2.5-flash": ProviderType.GOOGLE,
            "gemini-2.5-pro": ProviderType.GOOGLE,
            "anthropic/claude-3-opus-20240229": ProviderType.OPENROUTER,
            "anthropic/claude-3-sonnet-20240229": ProviderType.OPENROUTER,
            "deepseek/deepseek-r1-0528:free": ProviderType.OPENROUTER,
            "qwen/qwen3-235b-a22b-04-28:free": ProviderType.OPENROUTER,
        }

        # Restriction service reports active restrictions on OpenRouter.
        mock_restriction_service = MagicMock()
        mock_restriction_service.has_restrictions.return_value = True
        mock_restriction_service.get_allowed_models.return_value = {
            "opus",
            "sonnet",
            "deepseek/deepseek-r1-0528:free",
            "qwen/qwen3-235b-a22b-04-28:free",
        }
        mock_get_restriction.return_value = mock_restriction_service

        result = self._execute_tool()

        # Walk the rendered output and collect model names listed under the
        # OpenRouter "Available Models" section.
        openrouter_section_found = False
        openrouter_models = []
        in_openrouter_section = False
        for line in result.split("\n"):
            # NOTE: previously this checked `"OpenRouter" in line and "" in line`;
            # `"" in line` is always True, so the condition reduces to this.
            if "OpenRouter" in line:
                openrouter_section_found = True
            elif "Available Models" in line and openrouter_section_found:
                in_openrouter_section = True
            elif in_openrouter_section and line.strip().startswith("- "):
                # Model lines come in two formats:
                #   - `model-name` → `full-name` (context)
                #   - `model-name`
                line_content = line.strip()[2:]  # Remove "- "
                if "`" in line_content:
                    # Extract the content between the first pair of backticks.
                    openrouter_models.append(line_content.split("`")[1])

        self.assertTrue(openrouter_section_found, "OpenRouter section not found")
        self.assertEqual(
            len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
        )

        # The tool must have requested the restricted model list.
        self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)

        # The output should flag that restrictions are active.
        self.assertIn("Restricted to models matching:", result)

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"})
    @patch("providers.openrouter_registry.OpenRouterModelRegistry")
    @patch.object(ModelProviderRegistry, "get_provider")
    def test_listmodels_shows_all_models_without_restrictions(self, mock_get_provider, mock_registry_class):
        """Test that listmodels shows all models when no restrictions are set."""
        # Simulate 50 models spread across 5 providers (provider0..provider4).
        all_models = [f"provider{i//10}/model-{i}" for i in range(50)]
        self.mock_openrouter.list_models.return_value = all_models

        mock_registry = MagicMock()
        mock_registry_class.return_value = mock_registry
        mock_registry.resolve.return_value = None  # No configs for simplicity

        def get_provider_side_effect(provider_type, force_new=False):
            if provider_type == ProviderType.OPENROUTER:
                return self.mock_openrouter
            elif provider_type == ProviderType.GOOGLE:
                return self.mock_gemini
            return None

        mock_get_provider.side_effect = get_provider_side_effect

        result = self._execute_tool()

        # Count model bullet lines between the OpenRouter section and the
        # Custom/Local API section that follows it.
        openrouter_section_found = False
        openrouter_model_count = 0
        for line in result.split("\n"):
            # NOTE: previously this checked `"OpenRouter" in line and "" in line`;
            # `"" in line` is always True, so the condition reduces to this.
            if "OpenRouter" in line:
                openrouter_section_found = True
            elif "Custom/Local API" in line:
                # End of OpenRouter section
                break
            elif openrouter_section_found and line.strip().startswith("- ") and "`" in line:
                openrouter_model_count += 1

        # The tool shows models grouped by provider, max 5 per provider,
        # with an overall cap of 20 shown.
        self.assertGreaterEqual(
            openrouter_model_count, 5, f"Expected at least 5 OpenRouter models shown, found {openrouter_model_count}"
        )
        self.assertLessEqual(
            openrouter_model_count, 20, f"Expected at most 20 OpenRouter models shown, found {openrouter_model_count}"
        )

        # The truncation message should advertise the remaining models.
        self.assertIn("more models available", result)

        # respect_restrictions=True is always passed, even when no
        # restrictions are configured.
        self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)

        # No restriction note should appear without restrictions.
        self.assertNotIn("Restricted to models matching:", result)
# Allow running this test module directly (outside a pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()

View File

@@ -6,6 +6,7 @@ organized by their provider (Gemini, OpenAI, X.AI, OpenRouter, Custom).
It shows which providers are configured and what models can be used.
"""
import logging
import os
from typing import Any, Optional
@@ -14,6 +15,8 @@ from mcp.types import TextContent
from tools.base import BaseTool, ToolRequest
from tools.models import ToolModelCategory, ToolOutput
logger = logging.getLogger(__name__)
class ListModelsTool(BaseTool):
"""
@@ -156,29 +159,63 @@ class ListModelsTool(BaseTool):
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
try:
registry = OpenRouterModelRegistry()
aliases = registry.list_aliases()
# Get OpenRouter provider from registry to properly apply restrictions
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
# Group by provider for better organization
providers_models = {}
for alias in aliases[:20]: # Limit to first 20 to avoid overwhelming output
config = registry.resolve(alias)
if config and not (hasattr(config, "is_custom") and config.is_custom):
# Extract provider from model_name
provider = config.model_name.split("/")[0] if "/" in config.model_name else "other"
if provider not in providers_models:
providers_models[provider] = []
providers_models[provider].append((alias, config))
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
if provider:
# Get models with restrictions applied
available_models = provider.list_models(respect_restrictions=True)
registry = OpenRouterModelRegistry()
output_lines.append("\n**Available Models** (showing top 20):")
for provider, models in sorted(providers_models.items()):
output_lines.append(f"\n*{provider.title()}:*")
for alias, config in models[:5]: # Limit each provider to 5 models
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
# Group by provider for better organization
providers_models = {}
for model_name in available_models[:20]: # Limit to first 20 to avoid overwhelming output
# Try to resolve to get config details
config = registry.resolve(model_name)
if config:
# Extract provider from model_name
provider_name = config.model_name.split("/")[0] if "/" in config.model_name else "other"
if provider_name not in providers_models:
providers_models[provider_name] = []
providers_models[provider_name].append((model_name, config))
else:
# Model without config - add with basic info
provider_name = model_name.split("/")[0] if "/" in model_name else "other"
if provider_name not in providers_models:
providers_models[provider_name] = []
providers_models[provider_name].append((model_name, None))
total_models = len(aliases)
output_lines.append(f"\n...and {total_models - 20} more models available")
output_lines.append("\n**Available Models** (showing top 20):")
for provider_name, models in sorted(providers_models.items()):
output_lines.append(f"\n*{provider_name.title()}:*")
for alias, config in models[:5]: # Limit each provider to 5 models
if config:
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
else:
output_lines.append(f"- `{alias}`")
total_models = len(available_models)
if total_models > 20:
output_lines.append(f"\n...and {total_models - 20} more models available")
# Check if restrictions are applied
restriction_service = None
try:
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
output_lines.append(
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
)
except Exception as e:
logger.warning(f"Error checking OpenRouter restrictions: {e}")
else:
output_lines.append("**Error**: Could not load OpenRouter provider")
except Exception as e:
output_lines.append(f"**Error loading models**: {str(e)}")
@@ -244,13 +281,14 @@ class ListModelsTool(BaseTool):
# Get total available models
try:
from tools.analyze import AnalyzeTool
from providers.registry import ModelProviderRegistry
tool = AnalyzeTool()
total_models = len(tool._get_available_models())
# Get all available models respecting restrictions
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
total_models = len(available_models)
output_lines.append(f"**Total Available Models**: {total_models}")
except Exception:
pass
except Exception as e:
logger.warning(f"Error getting total available models: {e}")
# Add usage tips
output_lines.append("\n**Usage Tips**:")