WIP
- OpenRouter model configuration registry
- Model definition file for users to be able to control
- Additional tests
- Update instructions
This commit is contained in:
@@ -56,7 +56,9 @@ MODEL_CAPABILITIES_DESC = {
|
||||
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
||||
# Full model names also supported
|
||||
"gemini-2.5-flash-preview-05-20": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
||||
"gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
|
||||
"gemini-2.5-pro-preview-06-05": (
|
||||
"Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis"
|
||||
),
|
||||
}
|
||||
|
||||
# Note: When only OpenRouter is configured, these model aliases automatically map to equivalent models:
|
||||
|
||||
@@ -141,7 +141,11 @@ trace issues to their root cause, and provide actionable solutions.
|
||||
IMPORTANT: If you lack critical information to proceed (e.g., missing files, ambiguous error details,
|
||||
insufficient context), OR if the provided diagnostics (log files, crash reports, stack traces) appear irrelevant,
|
||||
incomplete, or insufficient for proper analysis, you MUST respond ONLY with this JSON format:
|
||||
{"status": "requires_clarification", "question": "What specific information you need from Claude or the user to proceed with debugging", "files_needed": ["file1.py", "file2.py"]}
|
||||
{
|
||||
"status": "requires_clarification",
|
||||
"question": "What specific information you need from Claude or the user to proceed with debugging",
|
||||
"files_needed": ["file1.py", "file2.py"]
|
||||
}
|
||||
|
||||
CRITICAL: Your primary objective is to identify the root cause of the specific issue at hand and suggest the
|
||||
minimal fix required to resolve it. Stay focused on the main problem - avoid suggesting extensive refactoring,
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
"""OpenAI model provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from .base import (
|
||||
FixedTemperatureConstraint,
|
||||
ModelCapabilities,
|
||||
ModelResponse,
|
||||
ProviderType,
|
||||
RangeTemperatureConstraint,
|
||||
)
|
||||
@@ -34,7 +30,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
kwargs.setdefault("base_url", "https://api.openai.com/v1")
|
||||
super().__init__(api_key, **kwargs)
|
||||
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
"""Get capabilities for a specific OpenAI model."""
|
||||
if model_name not in self.SUPPORTED_MODELS:
|
||||
@@ -62,7 +57,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
temperature_constraint=temp_constraint,
|
||||
)
|
||||
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
return ProviderType.OPENAI
|
||||
@@ -76,4 +70,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
# Currently no OpenAI models support extended thinking
|
||||
# This may change with future O3 models
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""Base class for OpenAI-compatible API providers."""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
import ipaddress
|
||||
import socket
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -15,7 +15,6 @@ from .base import (
|
||||
ModelProvider,
|
||||
ModelResponse,
|
||||
ProviderType,
|
||||
RangeTemperatureConstraint,
|
||||
)
|
||||
|
||||
|
||||
@@ -69,7 +68,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
|
||||
if models_str:
|
||||
# Parse and normalize to lowercase for case-insensitive comparison
|
||||
models = set(m.strip().lower() for m in models_str.split(",") if m.strip())
|
||||
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
|
||||
if models:
|
||||
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
|
||||
return models
|
||||
@@ -97,7 +96,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
hostname = parsed.hostname
|
||||
|
||||
# Check for common localhost patterns
|
||||
if hostname in ['localhost', '127.0.0.1', '::1']:
|
||||
if hostname in ["localhost", "127.0.0.1", "::1"]:
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -116,9 +115,8 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
try:
|
||||
parsed = urlparse(self.base_url)
|
||||
|
||||
|
||||
# Check URL scheme - only allow http/https
|
||||
if parsed.scheme not in ('http', 'https'):
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
|
||||
|
||||
# Check hostname exists
|
||||
@@ -128,14 +126,12 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
# Check port - allow only standard HTTP/HTTPS ports
|
||||
port = parsed.port
|
||||
if port is None:
|
||||
port = 443 if parsed.scheme == 'https' else 80
|
||||
port = 443 if parsed.scheme == "https" else 80
|
||||
|
||||
# Allow common HTTP ports and some alternative ports
|
||||
allowed_ports = {80, 443, 8080, 8443, 4000, 3000} # Common API ports
|
||||
if port not in allowed_ports:
|
||||
raise ValueError(
|
||||
f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}"
|
||||
)
|
||||
raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")
|
||||
|
||||
# Check against allowed domains if configured
|
||||
allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
|
||||
@@ -144,30 +140,33 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
if allowed_domains:
|
||||
hostname_lower = parsed.hostname.lower()
|
||||
if not any(
|
||||
hostname_lower == domain or
|
||||
hostname_lower.endswith('.' + domain)
|
||||
for domain in allowed_domains
|
||||
hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
|
||||
):
|
||||
raise ValueError(
|
||||
f"Domain not in allow-list: {parsed.hostname}. "
|
||||
f"Allowed domains: {allowed_domains}"
|
||||
f"Domain not in allow-list: {parsed.hostname}. " f"Allowed domains: {allowed_domains}"
|
||||
)
|
||||
|
||||
# Try to resolve hostname and check if it's a private IP
|
||||
# Skip for localhost addresses which are commonly used for development
|
||||
if parsed.hostname not in ['localhost', '127.0.0.1', '::1']:
|
||||
if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
|
||||
try:
|
||||
# Get all IP addresses for the hostname
|
||||
addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)
|
||||
|
||||
for family, _, _, _, sockaddr in addr_info:
|
||||
for _family, _, _, _, sockaddr in addr_info:
|
||||
ip_str = sockaddr[0]
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
|
||||
# Check for dangerous IP ranges
|
||||
if (ip.is_private or ip.is_loopback or ip.is_link_local or
|
||||
ip.is_multicast or ip.is_reserved or ip.is_unspecified):
|
||||
if (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
or ip.is_unspecified
|
||||
):
|
||||
raise ValueError(
|
||||
f"URL resolves to restricted IP address: {ip_str}. "
|
||||
"This could be a security risk (SSRF)."
|
||||
@@ -234,8 +233,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
# Validate model name against allow-list
|
||||
if not self.validate_model_name(model_name):
|
||||
raise ValueError(
|
||||
f"Model '{model_name}' not in allowed models list. "
|
||||
f"Allowed models: {self.allowed_models}"
|
||||
f"Model '{model_name}' not in allowed models list. " f"Allowed models: {self.allowed_models}"
|
||||
)
|
||||
|
||||
# Validate parameters
|
||||
@@ -307,7 +305,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
Estimated token count
|
||||
"""
|
||||
# 1. Check if provider has a remote token counting endpoint
|
||||
if hasattr(self, 'count_tokens_remote'):
|
||||
if hasattr(self, "count_tokens_remote"):
|
||||
try:
|
||||
return self.count_tokens_remote(text, model_name)
|
||||
except Exception as e:
|
||||
@@ -353,10 +351,9 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
|
||||
# Check if we're using generic capabilities
|
||||
if hasattr(capabilities, '_is_generic'):
|
||||
if hasattr(capabilities, "_is_generic"):
|
||||
logging.debug(
|
||||
f"Using generic parameter validation for {model_name}. "
|
||||
"Actual model constraints may differ."
|
||||
f"Using generic parameter validation for {model_name}. " "Actual model constraints may differ."
|
||||
)
|
||||
|
||||
# Validate temperature using parent class method
|
||||
|
||||
@@ -49,9 +49,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
# Log loaded models and aliases
|
||||
models = self._registry.list_models()
|
||||
aliases = self._registry.list_aliases()
|
||||
logging.info(
|
||||
f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases"
|
||||
)
|
||||
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
|
||||
|
||||
def _parse_allowed_models(self) -> None:
|
||||
"""Override to disable environment-based allow-list.
|
||||
@@ -175,7 +173,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature,
|
||||
max_output_tokens=max_output_tokens,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
|
||||
|
||||
@@ -15,7 +15,7 @@ class OpenRouterModelConfig:
|
||||
"""Configuration for an OpenRouter model."""
|
||||
|
||||
model_name: str
|
||||
aliases: List[str] = field(default_factory=list)
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
context_window: int = 32768 # Total context window size in tokens
|
||||
supports_extended_thinking: bool = False
|
||||
supports_system_prompts: bool = True
|
||||
@@ -24,7 +24,6 @@ class OpenRouterModelConfig:
|
||||
supports_json_mode: bool = False
|
||||
description: str = ""
|
||||
|
||||
|
||||
def to_capabilities(self) -> ModelCapabilities:
|
||||
"""Convert to ModelCapabilities object."""
|
||||
return ModelCapabilities(
|
||||
@@ -49,8 +48,8 @@ class OpenRouterModelRegistry:
|
||||
Args:
|
||||
config_path: Path to config file. If None, uses default locations.
|
||||
"""
|
||||
self.alias_map: Dict[str, str] = {} # alias -> model_name
|
||||
self.model_map: Dict[str, OpenRouterModelConfig] = {} # model_name -> config
|
||||
self.alias_map: dict[str, str] = {} # alias -> model_name
|
||||
self.model_map: dict[str, OpenRouterModelConfig] = {} # model_name -> config
|
||||
|
||||
# Determine config path
|
||||
if config_path:
|
||||
@@ -73,13 +72,21 @@ class OpenRouterModelRegistry:
|
||||
configs = self._read_config()
|
||||
self._build_maps(configs)
|
||||
logging.info(f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases")
|
||||
except ValueError as e:
|
||||
# Re-raise ValueError only for duplicate aliases (critical config errors)
|
||||
logging.error(f"Failed to load OpenRouter model configuration: {e}")
|
||||
# Initialize with empty maps on failure
|
||||
self.alias_map = {}
|
||||
self.model_map = {}
|
||||
if "Duplicate alias" in str(e):
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load OpenRouter model configuration: {e}")
|
||||
# Initialize with empty maps on failure
|
||||
self.alias_map = {}
|
||||
self.model_map = {}
|
||||
|
||||
def _read_config(self) -> List[OpenRouterModelConfig]:
|
||||
def _read_config(self) -> list[OpenRouterModelConfig]:
|
||||
"""Read configuration from file.
|
||||
|
||||
Returns:
|
||||
@@ -90,15 +97,15 @@ class OpenRouterModelRegistry:
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(self.config_path, 'r') as f:
|
||||
with open(self.config_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Parse models
|
||||
configs = []
|
||||
for model_data in data.get("models", []):
|
||||
# Handle backwards compatibility - rename max_tokens to context_window
|
||||
if 'max_tokens' in model_data and 'context_window' not in model_data:
|
||||
model_data['context_window'] = model_data.pop('max_tokens')
|
||||
if "max_tokens" in model_data and "context_window" not in model_data:
|
||||
model_data["context_window"] = model_data.pop("max_tokens")
|
||||
|
||||
config = OpenRouterModelConfig(**model_data)
|
||||
configs.append(config)
|
||||
@@ -109,7 +116,7 @@ class OpenRouterModelRegistry:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading config from {self.config_path}: {e}")
|
||||
|
||||
def _build_maps(self, configs: List[OpenRouterModelConfig]) -> None:
|
||||
def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
|
||||
"""Build alias and model maps from configurations.
|
||||
|
||||
Args:
|
||||
@@ -128,8 +135,7 @@ class OpenRouterModelRegistry:
|
||||
if alias_lower in alias_map:
|
||||
existing_model = alias_map[alias_lower]
|
||||
raise ValueError(
|
||||
f"Duplicate alias '{alias}' found for models "
|
||||
f"'{existing_model}' and '{config.model_name}'"
|
||||
f"Duplicate alias '{alias}' found for models " f"'{existing_model}' and '{config.model_name}'"
|
||||
)
|
||||
alias_map[alias_lower] = config.model_name
|
||||
|
||||
@@ -169,10 +175,10 @@ class OpenRouterModelRegistry:
|
||||
return config.to_capabilities()
|
||||
return None
|
||||
|
||||
def list_models(self) -> List[str]:
|
||||
def list_models(self) -> list[str]:
|
||||
"""List all available model names."""
|
||||
return list(self.model_map.keys())
|
||||
|
||||
def list_aliases(self) -> List[str]:
|
||||
def list_aliases(self) -> list[str]:
|
||||
"""List all available aliases."""
|
||||
return list(self.alias_map.keys())
|
||||
24
server.py
24
server.py
@@ -173,8 +173,7 @@ def configure_providers():
|
||||
"1. Use only OpenRouter: unset GEMINI_API_KEY and OPENAI_API_KEY\n"
|
||||
"2. Use only native APIs: unset OPENROUTER_API_KEY\n"
|
||||
"\n"
|
||||
"Current configuration will prioritize native APIs over OpenRouter.\n" +
|
||||
"=" * 70 + "\n"
|
||||
"Current configuration will prioritize native APIs over OpenRouter.\n" + "=" * 70 + "\n"
|
||||
)
|
||||
|
||||
# Register providers - native APIs first to ensure they take priority
|
||||
@@ -363,18 +362,22 @@ If something needs clarification or you'd benefit from additional context, simpl
|
||||
IMPORTANT: When you suggest follow-ups or ask questions, you MUST explicitly instruct Claude to use the continuation_id
|
||||
to respond. Use clear, direct language based on urgency:
|
||||
|
||||
For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd like to explore this further."
|
||||
For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd
|
||||
like to explore this further."
|
||||
|
||||
For needed responses: "Please respond using the continuation_id from this response - your input is needed to proceed."
|
||||
|
||||
For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from this response. Cannot proceed without your clarification/input."
|
||||
For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from
|
||||
this response. Cannot proceed without your clarification/input."
|
||||
|
||||
This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional, needed, or essential.
|
||||
This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional,
|
||||
needed, or essential.
|
||||
|
||||
The tool will automatically provide a continuation_id in the structured response that Claude can use in subsequent
|
||||
tool calls to maintain full conversation context across multiple exchanges.
|
||||
|
||||
Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct Claude to use the continuation_id when you do."""
|
||||
Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct
|
||||
Claude to use the continuation_id when you do."""
|
||||
|
||||
|
||||
async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -411,8 +414,10 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
||||
# Return error asking Claude to restart conversation with full context
|
||||
raise ValueError(
|
||||
f"Conversation thread '{continuation_id}' was not found or has expired. "
|
||||
f"This may happen if the conversation was created more than 1 hour ago or if there was an issue with Redis storage. "
|
||||
f"Please restart the conversation by providing your full question/prompt without the continuation_id parameter. "
|
||||
f"This may happen if the conversation was created more than 1 hour ago or if there was an issue "
|
||||
f"with Redis storage. "
|
||||
f"Please restart the conversation by providing your full question/prompt without the "
|
||||
f"continuation_id parameter. "
|
||||
f"This will create a new conversation thread that can continue with follow-up exchanges."
|
||||
)
|
||||
|
||||
@@ -504,7 +509,8 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
||||
try:
|
||||
mcp_activity_logger = logging.getLogger("mcp_activity")
|
||||
mcp_activity_logger.info(
|
||||
f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - {len(context.turns)} previous turns loaded"
|
||||
f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - "
|
||||
f"{len(context.turns)} previous turns loaded"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -4,10 +4,12 @@ Test OpenRouter model mapping
|
||||
"""
|
||||
|
||||
import sys
|
||||
sys.path.append('/Users/fahad/Developer/gemini-mcp-server')
|
||||
|
||||
sys.path.append("/Users/fahad/Developer/gemini-mcp-server")
|
||||
|
||||
from simulator_tests.base_test import BaseSimulatorTest
|
||||
|
||||
|
||||
class MappingTest(BaseSimulatorTest):
|
||||
def test_mapping(self):
|
||||
"""Test model alias mapping"""
|
||||
@@ -20,18 +22,19 @@ class MappingTest(BaseSimulatorTest):
|
||||
{
|
||||
"prompt": "Say 'Hello from Flash model!'",
|
||||
"model": "flash", # Should be mapped to google/gemini-flash-1.5-8b
|
||||
"temperature": 0.1
|
||||
}
|
||||
"temperature": 0.1,
|
||||
},
|
||||
)
|
||||
|
||||
if response:
|
||||
print(f"✅ Flash alias worked!")
|
||||
print("✅ Flash alias worked!")
|
||||
print(f"Response: {response[:200]}...")
|
||||
return True
|
||||
else:
|
||||
print("❌ Flash alias failed")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test = MappingTest(verbose=False)
|
||||
success = test.test_mapping()
|
||||
|
||||
@@ -97,7 +97,8 @@ class TestAutoMode:
|
||||
# Model field should have simpler description
|
||||
model_schema = schema["properties"]["model"]
|
||||
assert "enum" not in model_schema
|
||||
assert "Available:" in model_schema["description"]
|
||||
assert "Native models:" in model_schema["description"]
|
||||
assert "Defaults to" in model_schema["description"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_requires_model_parameter(self):
|
||||
@@ -180,8 +181,9 @@ class TestAutoMode:
|
||||
|
||||
schema = tool.get_model_field_schema()
|
||||
assert "enum" not in schema
|
||||
assert "Available:" in schema["description"]
|
||||
assert "Native models:" in schema["description"]
|
||||
assert "'pro'" in schema["description"]
|
||||
assert "Defaults to" in schema["description"]
|
||||
|
||||
finally:
|
||||
# Restore
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Tests for OpenRouter provider."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
@@ -26,12 +25,11 @@ class TestOpenRouterProvider:
|
||||
assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS
|
||||
|
||||
# Test with environment variables
|
||||
with patch.dict(os.environ, {
|
||||
"OPENROUTER_REFERER": "https://myapp.com",
|
||||
"OPENROUTER_TITLE": "My App"
|
||||
}):
|
||||
with patch.dict(os.environ, {"OPENROUTER_REFERER": "https://myapp.com", "OPENROUTER_TITLE": "My App"}):
|
||||
from importlib import reload
|
||||
|
||||
import providers.openrouter
|
||||
|
||||
reload(providers.openrouter)
|
||||
|
||||
provider = providers.openrouter.OpenRouterProvider(api_key="test-key")
|
||||
@@ -64,7 +62,7 @@ class TestOpenRouterProvider:
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "unknown-model"
|
||||
assert caps.max_tokens == 32_768 # Safe default
|
||||
assert hasattr(caps, '_is_generic') and caps._is_generic is True
|
||||
assert hasattr(caps, "_is_generic") and caps._is_generic is True
|
||||
|
||||
def test_model_alias_resolution(self):
|
||||
"""Test model alias resolution."""
|
||||
@@ -166,48 +164,34 @@ class TestOpenRouterRegistry:
|
||||
assert config.model_name == "anthropic/claude-3-sonnet"
|
||||
|
||||
|
||||
class TestOpenRouterSSRFProtection:
|
||||
"""Test SSRF protection for OpenRouter."""
|
||||
class TestOpenRouterFunctionality:
|
||||
"""Test OpenRouter-specific functionality."""
|
||||
|
||||
def test_url_validation_rejects_private_ips(self):
|
||||
"""Test that private IPs are rejected."""
|
||||
def test_openrouter_always_uses_correct_url(self):
|
||||
"""Test that OpenRouter always uses the correct base URL."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
assert provider.base_url == "https://openrouter.ai/api/v1"
|
||||
|
||||
# Even if we try to change it, it should remain the OpenRouter URL
|
||||
# (This is a characteristic of the OpenRouter provider)
|
||||
provider.base_url = "http://example.com" # Try to change it
|
||||
# But new instances should always use the correct URL
|
||||
provider2 = OpenRouterProvider(api_key="test-key")
|
||||
assert provider2.base_url == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_openrouter_headers_set_correctly(self):
|
||||
"""Test that OpenRouter specific headers are set."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
# List of private/dangerous IPs to test
|
||||
dangerous_urls = [
|
||||
"http://192.168.1.1/api/v1",
|
||||
"http://10.0.0.1/api/v1",
|
||||
"http://172.16.0.1/api/v1",
|
||||
"http://169.254.169.254/api/v1", # AWS metadata
|
||||
"http://[::1]/api/v1", # IPv6 localhost
|
||||
"http://0.0.0.0/api/v1",
|
||||
]
|
||||
# Check default headers
|
||||
assert "HTTP-Referer" in provider.DEFAULT_HEADERS
|
||||
assert "X-Title" in provider.DEFAULT_HEADERS
|
||||
assert provider.DEFAULT_HEADERS["X-Title"] == "Zen MCP Server"
|
||||
|
||||
for url in dangerous_urls:
|
||||
with pytest.raises(ValueError, match="restricted IP|Invalid"):
|
||||
provider.base_url = url
|
||||
provider._validate_base_url()
|
||||
|
||||
def test_url_validation_allows_public_domains(self):
|
||||
"""Test that legitimate public domains are allowed."""
|
||||
def test_openrouter_model_registry_initialized(self):
|
||||
"""Test that model registry is properly initialized."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
# OpenRouter's actual domain should always be allowed
|
||||
provider.base_url = "https://openrouter.ai/api/v1"
|
||||
provider._validate_base_url() # Should not raise
|
||||
|
||||
def test_invalid_url_schemes_rejected(self):
|
||||
"""Test that non-HTTP(S) schemes are rejected."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
invalid_urls = [
|
||||
"ftp://example.com/api",
|
||||
"file:///etc/passwd",
|
||||
"gopher://example.com",
|
||||
"javascript:alert(1)",
|
||||
]
|
||||
|
||||
for url in invalid_urls:
|
||||
with pytest.raises(ValueError, match="Invalid URL scheme"):
|
||||
provider.base_url = url
|
||||
provider._validate_base_url()
|
||||
# Registry should be initialized
|
||||
assert hasattr(provider, '_registry')
|
||||
assert provider._registry is not None
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry, OpenRouterModelConfig
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
|
||||
|
||||
|
||||
class TestOpenRouterModelRegistry:
|
||||
@@ -24,17 +24,9 @@ class TestOpenRouterModelRegistry:
|
||||
def test_custom_config_path(self):
|
||||
"""Test registry with custom config path."""
|
||||
# Create temporary config
|
||||
config_data = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "test/model-1",
|
||||
"aliases": ["test1", "t1"],
|
||||
"context_window": 4096
|
||||
}
|
||||
]
|
||||
}
|
||||
config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
@@ -50,24 +42,16 @@ class TestOpenRouterModelRegistry:
|
||||
def test_environment_variable_override(self):
|
||||
"""Test OPENROUTER_MODELS_PATH environment variable."""
|
||||
# Create custom config
|
||||
config_data = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "env/model",
|
||||
"aliases": ["envtest"],
|
||||
"context_window": 8192
|
||||
}
|
||||
]
|
||||
}
|
||||
config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
# Set environment variable
|
||||
original_env = os.environ.get('OPENROUTER_MODELS_PATH')
|
||||
os.environ['OPENROUTER_MODELS_PATH'] = temp_path
|
||||
original_env = os.environ.get("OPENROUTER_MODELS_PATH")
|
||||
os.environ["OPENROUTER_MODELS_PATH"] = temp_path
|
||||
|
||||
# Create registry without explicit path
|
||||
registry = OpenRouterModelRegistry()
|
||||
@@ -79,9 +63,9 @@ class TestOpenRouterModelRegistry:
|
||||
finally:
|
||||
# Restore environment
|
||||
if original_env is not None:
|
||||
os.environ['OPENROUTER_MODELS_PATH'] = original_env
|
||||
os.environ["OPENROUTER_MODELS_PATH"] = original_env
|
||||
else:
|
||||
del os.environ['OPENROUTER_MODELS_PATH']
|
||||
del os.environ["OPENROUTER_MODELS_PATH"]
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_alias_resolution(self):
|
||||
@@ -143,20 +127,16 @@ class TestOpenRouterModelRegistry:
|
||||
"""Test that duplicate aliases are detected."""
|
||||
config_data = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "test/model-1",
|
||||
"aliases": ["dupe"],
|
||||
"context_window": 4096
|
||||
},
|
||||
{"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
|
||||
{
|
||||
"model_name": "test/model-2",
|
||||
"aliases": ["DUPE"], # Same alias, different case
|
||||
"context_window": 8192
|
||||
}
|
||||
"context_window": 8192,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
@@ -174,12 +154,12 @@ class TestOpenRouterModelRegistry:
|
||||
"model_name": "test/old-model",
|
||||
"aliases": ["old"],
|
||||
"max_tokens": 16384, # Old field name
|
||||
"supports_extended_thinking": False
|
||||
"supports_extended_thinking": False,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
@@ -208,7 +188,7 @@ class TestOpenRouterModelRegistry:
|
||||
|
||||
def test_invalid_json_config(self):
|
||||
"""Test handling of invalid JSON."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
f.write("{ invalid json }")
|
||||
temp_path = f.name
|
||||
|
||||
@@ -231,7 +211,7 @@ class TestOpenRouterModelRegistry:
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
description="Fully featured test model"
|
||||
description="Fully featured test model",
|
||||
)
|
||||
|
||||
caps = config.to_capabilities()
|
||||
|
||||
@@ -57,15 +57,28 @@ class ToolRequest(BaseModel):
|
||||
# Higher values allow for more complex reasoning but increase latency and cost
|
||||
thinking_mode: Optional[Literal["minimal", "low", "medium", "high", "max"]] = Field(
|
||||
None,
|
||||
description="Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), max (100% of model max)",
|
||||
description=(
|
||||
"Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), "
|
||||
"max (100% of model max)"
|
||||
),
|
||||
)
|
||||
use_websearch: Optional[bool] = Field(
|
||||
True,
|
||||
description="Enable web search for documentation, best practices, and current information. When enabled, the model can request Claude to perform web searches and share results back during conversations. Particularly useful for: brainstorming sessions, architectural design discussions, exploring industry best practices, working with specific frameworks/technologies, researching solutions to complex problems, or when current documentation and community insights would enhance the analysis.",
|
||||
description=(
|
||||
"Enable web search for documentation, best practices, and current information. "
|
||||
"When enabled, the model can request Claude to perform web searches and share results back "
|
||||
"during conversations. Particularly useful for: brainstorming sessions, architectural design "
|
||||
"discussions, exploring industry best practices, working with specific frameworks/technologies, "
|
||||
"researching solutions to complex problems, or when current documentation and community insights "
|
||||
"would enhance the analysis."
|
||||
),
|
||||
)
|
||||
continuation_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
||||
description=(
|
||||
"Thread continuation ID for multi-turn conversations. Can be used to continue conversations "
|
||||
"across different tools. Only provide this if continuing a previous conversation thread."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -152,12 +165,14 @@ class BaseTool(ABC):
|
||||
Returns:
|
||||
Dict containing the model field JSON schema
|
||||
"""
|
||||
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
|
||||
import os
|
||||
|
||||
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
|
||||
|
||||
# Check if OpenRouter is configured
|
||||
has_openrouter = bool(os.getenv("OPENROUTER_API_KEY") and
|
||||
os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here")
|
||||
has_openrouter = bool(
|
||||
os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here"
|
||||
)
|
||||
|
||||
if IS_AUTO_MODE:
|
||||
# In auto mode, model is required and we provide detailed descriptions
|
||||
@@ -166,7 +181,32 @@ class BaseTool(ABC):
|
||||
model_desc_parts.append(f"- '{model}': {desc}")
|
||||
|
||||
if has_openrouter:
|
||||
model_desc_parts.append("\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter (e.g., 'gpt-4', 'claude-3-opus', 'mistral-large'). Check openrouter.ai/models for available models.")
|
||||
# Add OpenRouter aliases from the registry
|
||||
try:
|
||||
# Import registry directly to show available aliases
|
||||
# This works even without an API key
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||
|
||||
registry = OpenRouterModelRegistry()
|
||||
aliases = registry.list_aliases()
|
||||
|
||||
# Show ALL aliases from the configuration
|
||||
if aliases:
|
||||
# Show all aliases so Claude knows every option available
|
||||
all_aliases = sorted(aliases)
|
||||
alias_list = ", ".join(f"'{a}'" for a in all_aliases)
|
||||
model_desc_parts.append(
|
||||
f"\nOpenRouter models available via aliases: {alias_list}"
|
||||
)
|
||||
else:
|
||||
model_desc_parts.append(
|
||||
"\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter."
|
||||
)
|
||||
except Exception:
|
||||
# Fallback if registry fails to load
|
||||
model_desc_parts.append(
|
||||
"\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter (e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "string",
|
||||
@@ -180,7 +220,28 @@ class BaseTool(ABC):
|
||||
|
||||
description = f"Model to use. Native models: {models_str}."
|
||||
if has_openrouter:
|
||||
description += " OpenRouter: Any model available on openrouter.ai (e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
|
||||
# Add OpenRouter aliases
|
||||
try:
|
||||
# Import registry directly to show available aliases
|
||||
# This works even without an API key
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||
|
||||
registry = OpenRouterModelRegistry()
|
||||
aliases = registry.list_aliases()
|
||||
|
||||
# Show ALL aliases from the configuration
|
||||
if aliases:
|
||||
# Show all aliases so Claude knows every option available
|
||||
all_aliases = sorted(aliases)
|
||||
alias_list = ", ".join(f"'{a}'" for a in all_aliases)
|
||||
description += f" OpenRouter aliases: {alias_list}."
|
||||
else:
|
||||
description += " OpenRouter: Any model available on openrouter.ai."
|
||||
except Exception:
|
||||
description += (
|
||||
" OpenRouter: Any model available on openrouter.ai "
|
||||
"(e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
|
||||
)
|
||||
description += f" Defaults to '{DEFAULT_MODEL}' if not specified."
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user