fix: respect OPENROUTER_ALLOWED_MODELS in listmodels tool (#89)
* fix: respect OPENROUTER_ALLOWED_MODELS in listmodels tool - Modified listmodels tool to use provider's list_models() method with respect_restrictions=True - This ensures only models allowed by OPENROUTER_ALLOWED_MODELS are shown - Added note indicating when model restrictions are active - Fixed total model count to also respect restrictions Previously, the tool was directly accessing the OpenRouter registry and showing all ~200 models regardless of the OPENROUTER_ALLOWED_MODELS setting. * test: add tests for listmodels OpenRouter restrictions - Test that listmodels respects OPENROUTER_ALLOWED_MODELS setting - Test shows only allowed models when restrictions are set - Test shows all models when no restrictions are set - Verify proper use of respect_restrictions parameter * correcting test * test: fix test expectations for listmodels - Update tests to parse JSON response format - Fix model counting logic to handle provider grouping - Adjust expectations based on actual tool behavior (max 5 models per provider) - Tests now properly validate both restricted and unrestricted scenarios * style: fix code formatting issues - Applied ruff, black, and isort formatting - Fixed import order and removed trailing whitespace - All code quality checks now pass * fix: improve exception handling based on code review feedback - Added proper logging for exceptions instead of silent pass - Import logging module and create logger instance - Log warnings when error checking OpenRouter restrictions - Log warnings when error getting total available models - Maintains backward compatibility while improving debuggability --------- Co-authored-by: Patryk Ciechanski <patryk.ciechanski@inetum.com>
This commit is contained in:
@@ -6,6 +6,7 @@ organized by their provider (Gemini, OpenAI, X.AI, OpenRouter, Custom).
|
||||
It shows which providers are configured and what models can be used.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -14,6 +15,8 @@ from mcp.types import TextContent
|
||||
from tools.base import BaseTool, ToolRequest
|
||||
from tools.models import ToolModelCategory, ToolOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ListModelsTool(BaseTool):
|
||||
"""
|
||||
@@ -156,29 +159,63 @@ class ListModelsTool(BaseTool):
|
||||
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
|
||||
|
||||
try:
|
||||
registry = OpenRouterModelRegistry()
|
||||
aliases = registry.list_aliases()
|
||||
# Get OpenRouter provider from registry to properly apply restrictions
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Group by provider for better organization
|
||||
providers_models = {}
|
||||
for alias in aliases[:20]: # Limit to first 20 to avoid overwhelming output
|
||||
config = registry.resolve(alias)
|
||||
if config and not (hasattr(config, "is_custom") and config.is_custom):
|
||||
# Extract provider from model_name
|
||||
provider = config.model_name.split("/")[0] if "/" in config.model_name else "other"
|
||||
if provider not in providers_models:
|
||||
providers_models[provider] = []
|
||||
providers_models[provider].append((alias, config))
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||
if provider:
|
||||
# Get models with restrictions applied
|
||||
available_models = provider.list_models(respect_restrictions=True)
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
output_lines.append("\n**Available Models** (showing top 20):")
|
||||
for provider, models in sorted(providers_models.items()):
|
||||
output_lines.append(f"\n*{provider.title()}:*")
|
||||
for alias, config in models[:5]: # Limit each provider to 5 models
|
||||
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
|
||||
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
|
||||
# Group by provider for better organization
|
||||
providers_models = {}
|
||||
for model_name in available_models[:20]: # Limit to first 20 to avoid overwhelming output
|
||||
# Try to resolve to get config details
|
||||
config = registry.resolve(model_name)
|
||||
if config:
|
||||
# Extract provider from model_name
|
||||
provider_name = config.model_name.split("/")[0] if "/" in config.model_name else "other"
|
||||
if provider_name not in providers_models:
|
||||
providers_models[provider_name] = []
|
||||
providers_models[provider_name].append((model_name, config))
|
||||
else:
|
||||
# Model without config - add with basic info
|
||||
provider_name = model_name.split("/")[0] if "/" in model_name else "other"
|
||||
if provider_name not in providers_models:
|
||||
providers_models[provider_name] = []
|
||||
providers_models[provider_name].append((model_name, None))
|
||||
|
||||
total_models = len(aliases)
|
||||
output_lines.append(f"\n...and {total_models - 20} more models available")
|
||||
output_lines.append("\n**Available Models** (showing top 20):")
|
||||
for provider_name, models in sorted(providers_models.items()):
|
||||
output_lines.append(f"\n*{provider_name.title()}:*")
|
||||
for alias, config in models[:5]: # Limit each provider to 5 models
|
||||
if config:
|
||||
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
|
||||
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
|
||||
else:
|
||||
output_lines.append(f"- `{alias}`")
|
||||
|
||||
total_models = len(available_models)
|
||||
if total_models > 20:
|
||||
output_lines.append(f"\n...and {total_models - 20} more models available")
|
||||
|
||||
# Check if restrictions are applied
|
||||
restriction_service = None
|
||||
try:
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
|
||||
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
|
||||
output_lines.append(
|
||||
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking OpenRouter restrictions: {e}")
|
||||
else:
|
||||
output_lines.append("**Error**: Could not load OpenRouter provider")
|
||||
|
||||
except Exception as e:
|
||||
output_lines.append(f"**Error loading models**: {str(e)}")
|
||||
@@ -244,13 +281,14 @@ class ListModelsTool(BaseTool):
|
||||
|
||||
# Get total available models
|
||||
try:
|
||||
from tools.analyze import AnalyzeTool
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
tool = AnalyzeTool()
|
||||
total_models = len(tool._get_available_models())
|
||||
# Get all available models respecting restrictions
|
||||
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||
total_models = len(available_models)
|
||||
output_lines.append(f"**Total Available Models**: {total_models}")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting total available models: {e}")
|
||||
|
||||
# Add usage tips
|
||||
output_lines.append("\n**Usage Tips**:")
|
||||
|
||||
Reference in New Issue
Block a user