- OpenRouter model configuration registry
- Model definition file that users can edit to control available models
- Additional tests
- Update instructions
Fahad
2025-06-13 06:33:12 +04:00
parent cd1105b741
commit 2cdb92460b
12 changed files with 417 additions and 381 deletions

View File

@@ -56,11 +56,13 @@ MODEL_CAPABILITIES_DESC = {
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity", "o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
# Full model names also supported # Full model names also supported
"gemini-2.5-flash-preview-05-20": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations", "gemini-2.5-flash-preview-05-20": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
"gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis", "gemini-2.5-pro-preview-06-05": (
"Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis"
),
} }
# Note: When only OpenRouter is configured, these model aliases automatically map to equivalent models: # Note: When only OpenRouter is configured, these model aliases automatically map to equivalent models:
# - "flash" → "google/gemini-flash-1.5-8b" # - "flash" → "google/gemini-flash-1.5-8b"
# - "pro" → "google/gemini-pro-1.5" # - "pro" → "google/gemini-pro-1.5"
# - "o3" → "openai/gpt-4o" # - "o3" → "openai/gpt-4o"
# - "o3-mini" → "openai/gpt-4o-mini" # - "o3-mini" → "openai/gpt-4o-mini"

View File

@@ -141,7 +141,11 @@ trace issues to their root cause, and provide actionable solutions.
 IMPORTANT: If you lack critical information to proceed (e.g., missing files, ambiguous error details,
 insufficient context), OR if the provided diagnostics (log files, crash reports, stack traces) appear irrelevant,
 incomplete, or insufficient for proper analysis, you MUST respond ONLY with this JSON format:
-{"status": "requires_clarification", "question": "What specific information you need from Claude or the user to proceed with debugging", "files_needed": ["file1.py", "file2.py"]}
+{
+  "status": "requires_clarification",
+  "question": "What specific information you need from Claude or the user to proceed with debugging",
+  "files_needed": ["file1.py", "file2.py"]
+}

 CRITICAL: Your primary objective is to identify the root cause of the specific issue at hand and suggest the
 minimal fix required to resolve it. Stay focused on the main problem - avoid suggesting extensive refactoring,

View File

@@ -1,12 +1,8 @@
"""OpenAI model provider implementation.""" """OpenAI model provider implementation."""
import logging
from typing import Optional
from .base import ( from .base import (
FixedTemperatureConstraint, FixedTemperatureConstraint,
ModelCapabilities, ModelCapabilities,
ModelResponse,
ProviderType, ProviderType,
RangeTemperatureConstraint, RangeTemperatureConstraint,
) )
@@ -34,7 +30,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
kwargs.setdefault("base_url", "https://api.openai.com/v1") kwargs.setdefault("base_url", "https://api.openai.com/v1")
super().__init__(api_key, **kwargs) super().__init__(api_key, **kwargs)
def get_capabilities(self, model_name: str) -> ModelCapabilities: def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific OpenAI model.""" """Get capabilities for a specific OpenAI model."""
if model_name not in self.SUPPORTED_MODELS: if model_name not in self.SUPPORTED_MODELS:
@@ -62,7 +57,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
temperature_constraint=temp_constraint, temperature_constraint=temp_constraint,
) )
def get_provider_type(self) -> ProviderType: def get_provider_type(self) -> ProviderType:
"""Get the provider type.""" """Get the provider type."""
return ProviderType.OPENAI return ProviderType.OPENAI
@@ -76,4 +70,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
# Currently no OpenAI models support extended thinking # Currently no OpenAI models support extended thinking
# This may change with future O3 models # This may change with future O3 models
return False return False

View File

@@ -1,12 +1,12 @@
"""Base class for OpenAI-compatible API providers.""" """Base class for OpenAI-compatible API providers."""
import ipaddress
import logging import logging
import os import os
import socket
from abc import abstractmethod from abc import abstractmethod
from typing import Optional from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import ipaddress
import socket
from openai import OpenAI from openai import OpenAI
@@ -15,25 +15,24 @@ from .base import (
ModelProvider, ModelProvider,
ModelResponse, ModelResponse,
ProviderType, ProviderType,
RangeTemperatureConstraint,
) )
class OpenAICompatibleProvider(ModelProvider): class OpenAICompatibleProvider(ModelProvider):
"""Base class for any provider using an OpenAI-compatible API. """Base class for any provider using an OpenAI-compatible API.
This includes: This includes:
- Direct OpenAI API - Direct OpenAI API
- OpenRouter - OpenRouter
- Any other OpenAI-compatible endpoint - Any other OpenAI-compatible endpoint
""" """
DEFAULT_HEADERS = {} DEFAULT_HEADERS = {}
FRIENDLY_NAME = "OpenAI Compatible" FRIENDLY_NAME = "OpenAI Compatible"
def __init__(self, api_key: str, base_url: str = None, **kwargs): def __init__(self, api_key: str, base_url: str = None, **kwargs):
"""Initialize the provider with API key and optional base URL. """Initialize the provider with API key and optional base URL.
Args: Args:
api_key: API key for authentication api_key: API key for authentication
base_url: Base URL for the API endpoint base_url: Base URL for the API endpoint
@@ -44,21 +43,21 @@ class OpenAICompatibleProvider(ModelProvider):
self.base_url = base_url self.base_url = base_url
self.organization = kwargs.get("organization") self.organization = kwargs.get("organization")
self.allowed_models = self._parse_allowed_models() self.allowed_models = self._parse_allowed_models()
# Validate base URL for security # Validate base URL for security
if self.base_url: if self.base_url:
self._validate_base_url() self._validate_base_url()
# Warn if using external URL without authentication # Warn if using external URL without authentication
if self.base_url and not self._is_localhost_url() and not api_key: if self.base_url and not self._is_localhost_url() and not api_key:
logging.warning( logging.warning(
f"Using external URL '{self.base_url}' without API key. " f"Using external URL '{self.base_url}' without API key. "
"This may be insecure. Consider setting an API key for authentication." "This may be insecure. Consider setting an API key for authentication."
) )
def _parse_allowed_models(self) -> Optional[set[str]]: def _parse_allowed_models(self) -> Optional[set[str]]:
"""Parse allowed models from environment variable. """Parse allowed models from environment variable.
Returns: Returns:
Set of allowed model names (lowercase) or None if not configured Set of allowed model names (lowercase) or None if not configured
""" """
@@ -66,108 +65,108 @@ class OpenAICompatibleProvider(ModelProvider):
provider_type = self.get_provider_type().value.upper() provider_type = self.get_provider_type().value.upper()
env_var = f"{provider_type}_ALLOWED_MODELS" env_var = f"{provider_type}_ALLOWED_MODELS"
models_str = os.getenv(env_var, "") models_str = os.getenv(env_var, "")
if models_str: if models_str:
# Parse and normalize to lowercase for case-insensitive comparison # Parse and normalize to lowercase for case-insensitive comparison
models = set(m.strip().lower() for m in models_str.split(",") if m.strip()) models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
if models: if models:
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}") logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
return models return models
# Log warning if no allow-list configured for proxy providers # Log warning if no allow-list configured for proxy providers
if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]: if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
logging.warning( logging.warning(
f"No model allow-list configured for {self.FRIENDLY_NAME}. " f"No model allow-list configured for {self.FRIENDLY_NAME}. "
f"Set {env_var} to restrict model access and control costs." f"Set {env_var} to restrict model access and control costs."
) )
return None return None
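# Usage sketch: configuring the allow-list parsed above. The variable name
# follows the f"{provider_type}_ALLOWED_MODELS" pattern; values are illustrative.
import os

os.environ["OPENAI_ALLOWED_MODELS"] = "o3-mini, GPT-4o"
# Case-insensitive, whitespace-tolerant parsing yields {"o3-mini", "gpt-4o"}.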
     def _is_localhost_url(self) -> bool:
         """Check if the base URL points to localhost.

         Returns:
             True if URL is localhost, False otherwise
         """
         if not self.base_url:
             return False

         try:
             parsed = urlparse(self.base_url)
             hostname = parsed.hostname

             # Check for common localhost patterns
-            if hostname in ['localhost', '127.0.0.1', '::1']:
+            if hostname in ["localhost", "127.0.0.1", "::1"]:
                 return True

             return False
         except Exception:
             return False

     def _validate_base_url(self) -> None:
         """Validate base URL for security (SSRF protection).

         Raises:
             ValueError: If URL is invalid or potentially unsafe
         """
         if not self.base_url:
             return

         try:
             parsed = urlparse(self.base_url)

             # Check URL scheme - only allow http/https
-            if parsed.scheme not in ('http', 'https'):
+            if parsed.scheme not in ("http", "https"):
                 raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")

             # Check hostname exists
             if not parsed.hostname:
                 raise ValueError("URL must include a hostname")

             # Check port - allow only standard HTTP/HTTPS ports
             port = parsed.port
             if port is None:
-                port = 443 if parsed.scheme == 'https' else 80
+                port = 443 if parsed.scheme == "https" else 80

             # Allow common HTTP ports and some alternative ports
             allowed_ports = {80, 443, 8080, 8443, 4000, 3000}  # Common API ports
             if port not in allowed_ports:
-                raise ValueError(
-                    f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}"
-                )
+                raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")

             # Check against allowed domains if configured
             allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
             allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]

             if allowed_domains:
                 hostname_lower = parsed.hostname.lower()
                 if not any(
-                    hostname_lower == domain or
-                    hostname_lower.endswith('.' + domain)
-                    for domain in allowed_domains
+                    hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
                 ):
                     raise ValueError(
-                        f"Domain not in allow-list: {parsed.hostname}. "
-                        f"Allowed domains: {allowed_domains}"
+                        f"Domain not in allow-list: {parsed.hostname}. " f"Allowed domains: {allowed_domains}"
                     )

             # Try to resolve hostname and check if it's a private IP
             # Skip for localhost addresses which are commonly used for development
-            if parsed.hostname not in ['localhost', '127.0.0.1', '::1']:
+            if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
                 try:
                     # Get all IP addresses for the hostname
                     addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)
-                    for family, _, _, _, sockaddr in addr_info:
+                    for _family, _, _, _, sockaddr in addr_info:
                         ip_str = sockaddr[0]
                         try:
                             ip = ipaddress.ip_address(ip_str)
                             # Check for dangerous IP ranges
-                            if (ip.is_private or ip.is_loopback or ip.is_link_local or
-                                ip.is_multicast or ip.is_reserved or ip.is_unspecified):
+                            if (
+                                ip.is_private
+                                or ip.is_loopback
+                                or ip.is_link_local
+                                or ip.is_multicast
+                                or ip.is_reserved
+                                or ip.is_unspecified
+                            ):
                                 raise ValueError(
                                     f"URL resolves to restricted IP address: {ip_str}. "
                                     "This could be a security risk (SSRF)."
@@ -177,16 +176,16 @@ class OpenAICompatibleProvider(ModelProvider):
                         if "restricted IP address" in str(ve):
                             raise
                         continue

                 except socket.gaierror as e:
                     # If we can't resolve the hostname, it's suspicious
                     raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")

         except Exception as e:
             if isinstance(e, ValueError):
                 raise
             raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
     @property
     def client(self):
         """Lazy initialization of OpenAI client with security checks."""
@@ -194,21 +193,21 @@ class OpenAICompatibleProvider(ModelProvider):
             client_kwargs = {
                 "api_key": self.api_key,
             }

             if self.base_url:
                 client_kwargs["base_url"] = self.base_url

             if self.organization:
                 client_kwargs["organization"] = self.organization

             # Add default headers if any
             if self.DEFAULT_HEADERS:
                 client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

             self._client = OpenAI(**client_kwargs)

         return self._client

     def generate_content(
         self,
         prompt: str,
@@ -219,7 +218,7 @@ class OpenAICompatibleProvider(ModelProvider):
         **kwargs,
     ) -> ModelResponse:
         """Generate content using the OpenAI-compatible API.

         Args:
             prompt: User prompt to send to the model
             model_name: Name of the model to use
@@ -227,50 +226,49 @@ class OpenAICompatibleProvider(ModelProvider):
             temperature: Sampling temperature
             max_output_tokens: Maximum tokens to generate
             **kwargs: Additional provider-specific parameters

         Returns:
             ModelResponse with generated content and metadata
         """
         # Validate model name against allow-list
         if not self.validate_model_name(model_name):
             raise ValueError(
-                f"Model '{model_name}' not in allowed models list. "
-                f"Allowed models: {self.allowed_models}"
+                f"Model '{model_name}' not in allowed models list. " f"Allowed models: {self.allowed_models}"
             )

         # Validate parameters
         self.validate_parameters(model_name, temperature)

         # Prepare messages
         messages = []
         if system_prompt:
             messages.append({"role": "system", "content": system_prompt})
         messages.append({"role": "user", "content": prompt})

         # Prepare completion parameters
         completion_params = {
             "model": model_name,
             "messages": messages,
             "temperature": temperature,
         }

         # Add max tokens if specified
         if max_output_tokens:
             completion_params["max_tokens"] = max_output_tokens

         # Add any additional OpenAI-specific parameters
         for key, value in kwargs.items():
             if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                 completion_params[key] = value

         try:
             # Generate completion
             response = self.client.chat.completions.create(**completion_params)

             # Extract content and usage
             content = response.choices[0].message.content
             usage = self._extract_usage(response)

             return ModelResponse(
                 content=content,
                 usage=usage,
@@ -284,39 +282,39 @@ class OpenAICompatibleProvider(ModelProvider):
                     "created": response.created,
                 },
             )

         except Exception as e:
             # Log error and re-raise with more context
             error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
             logging.error(error_msg)
             raise RuntimeError(error_msg) from e

     def count_tokens(self, text: str, model_name: str) -> int:
         """Count tokens for the given text.

         Uses a layered approach:
         1. Try provider-specific token counting endpoint
         2. Try tiktoken for known model families
         3. Fall back to character-based estimation

         Args:
             text: Text to count tokens for
             model_name: Model name for tokenizer selection

         Returns:
             Estimated token count
         """
         # 1. Check if provider has a remote token counting endpoint
-        if hasattr(self, 'count_tokens_remote'):
+        if hasattr(self, "count_tokens_remote"):
             try:
                 return self.count_tokens_remote(text, model_name)
             except Exception as e:
                 logging.debug(f"Remote token counting failed: {e}")

         # 2. Try tiktoken for known models
         try:
             import tiktoken

             # Try to get encoding for the specific model
             try:
                 encoding = tiktoken.encoding_for_model(model_name)
@@ -326,24 +324,24 @@ class OpenAICompatibleProvider(ModelProvider):
                 encoding = tiktoken.get_encoding("cl100k_base")
             else:
                 encoding = tiktoken.get_encoding("cl100k_base")  # Default

             return len(encoding.encode(text))

         except (ImportError, Exception) as e:
             logging.debug(f"Tiktoken not available or failed: {e}")

         # 3. Fall back to character-based estimation
         logging.warning(
             f"No specific tokenizer available for '{model_name}'. "
             "Using character-based estimation (~4 chars per token)."
         )

         return len(text) // 4
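# Usage sketch: the tiktoken path above, assuming the optional tiktoken package
# is installed. cl100k_base is the default encoding used for unknown models.
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
token_count = len(encoding.encode("Count the tokens in this sentence."))
# Exact for this encoding, unlike the final ~4-chars-per-token fallback estimate.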
     def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
         """Validate model parameters.

         For proxy providers, this may use generic capabilities.

         Args:
             model_name: Model to validate for
             temperature: Temperature to validate
@@ -351,67 +349,66 @@ class OpenAICompatibleProvider(ModelProvider):
         """
         try:
             capabilities = self.get_capabilities(model_name)

             # Check if we're using generic capabilities
-            if hasattr(capabilities, '_is_generic'):
+            if hasattr(capabilities, "_is_generic"):
                 logging.debug(
-                    f"Using generic parameter validation for {model_name}. "
-                    "Actual model constraints may differ."
+                    f"Using generic parameter validation for {model_name}. " "Actual model constraints may differ."
                 )

             # Validate temperature using parent class method
             super().validate_parameters(model_name, temperature, **kwargs)

         except Exception as e:
             # For proxy providers, we might not have accurate capabilities
             # Log warning but don't fail
             logging.warning(f"Parameter validation limited for {model_name}: {e}")

     def _extract_usage(self, response) -> dict[str, int]:
         """Extract token usage from OpenAI response.

         Args:
             response: OpenAI API response object

         Returns:
             Dictionary with usage statistics
         """
         usage = {}

         if hasattr(response, "usage") and response.usage:
             usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
             usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
             usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)

         return usage

     @abstractmethod
     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         """Get capabilities for a specific model.

         Must be implemented by subclasses.
         """
         pass

     @abstractmethod
     def get_provider_type(self) -> ProviderType:
         """Get the provider type.

         Must be implemented by subclasses.
         """
         pass

     @abstractmethod
     def validate_model_name(self, model_name: str) -> bool:
         """Validate if the model name is supported.

         Must be implemented by subclasses.
         """
         pass

     def supports_thinking_mode(self, model_name: str) -> bool:
         """Check if the model supports extended thinking mode.

         Default is False for OpenAI-compatible providers.
         """
         return False
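A usage sketch of the base-class flow through a concrete subclass; the module path and API key are illustrative assumptions:

from providers.openai_provider import OpenAIModelProvider  # assumed module path

provider = OpenAIModelProvider(api_key="sk-...")  # hypothetical key
response = provider.generate_content(
    prompt="Summarize this stack trace in two sentences.",  # illustrative prompt
    model_name="o3-mini",
    temperature=0.2,
)
print(response.content)
print(response.usage)  # {"input_tokens": ..., "output_tokens": ..., "total_tokens": ...}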

View File

@@ -16,63 +16,61 @@ from .openrouter_registry import OpenRouterModelRegistry
 class OpenRouterProvider(OpenAICompatibleProvider):
     """OpenRouter unified API provider.

     OpenRouter provides access to multiple AI models through a single API endpoint.
     See https://openrouter.ai for available models and pricing.
     """

     FRIENDLY_NAME = "OpenRouter"

     # Custom headers required by OpenRouter
     DEFAULT_HEADERS = {
         "HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
         "X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
     }

     # Model registry for managing configurations and aliases
     _registry: Optional[OpenRouterModelRegistry] = None

     def __init__(self, api_key: str, **kwargs):
         """Initialize OpenRouter provider.

         Args:
             api_key: OpenRouter API key
             **kwargs: Additional configuration
         """
         # Always use OpenRouter's base URL
         super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs)

         # Initialize model registry
         if OpenRouterProvider._registry is None:
             OpenRouterProvider._registry = OpenRouterModelRegistry()

             # Log loaded models and aliases
             models = self._registry.list_models()
             aliases = self._registry.list_aliases()
-            logging.info(
-                f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases"
-            )
+            logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")

     def _parse_allowed_models(self) -> None:
         """Override to disable environment-based allow-list.

         OpenRouter model access is controlled via the OpenRouter dashboard,
         not through environment variables.
         """
         return None

     def _resolve_model_name(self, model_name: str) -> str:
         """Resolve model aliases to OpenRouter model names.

         Args:
             model_name: Input model name or alias

         Returns:
             Resolved OpenRouter model name
         """
         # Try to resolve through registry
         config = self._registry.resolve(model_name)

         if config:
             if config.model_name != model_name:
                 logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
@@ -82,30 +80,30 @@ class OpenRouterProvider(OpenAICompatibleProvider):
         # This allows using models not in our config file
         logging.debug(f"Model '{model_name}' not found in registry, using as-is")
         return model_name

     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         """Get capabilities for a model.

         Args:
             model_name: Name of the model (or alias)

         Returns:
             ModelCapabilities from registry or generic defaults
         """
         # Try to get from registry first
         capabilities = self._registry.get_capabilities(model_name)

         if capabilities:
             return capabilities
         else:
             # Resolve any potential aliases and create generic capabilities
             resolved_name = self._resolve_model_name(model_name)
             logging.debug(
                 f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
                 "Consider adding to openrouter_models.json for specific capabilities."
             )

             # Create generic capabilities with conservative defaults
             capabilities = ModelCapabilities(
                 provider=ProviderType.OPENROUTER,
@@ -118,31 +116,31 @@ class OpenRouterProvider(OpenAICompatibleProvider):
                 supports_function_calling=False,
                 temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
             )

             # Mark as generic for validation purposes
             capabilities._is_generic = True

             return capabilities

     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
         return ProviderType.OPENROUTER

     def validate_model_name(self, model_name: str) -> bool:
         """Validate if the model name is allowed.

         For OpenRouter, we accept any model name. OpenRouter will
         validate based on the API key's permissions.

         Args:
             model_name: Model name to validate

         Returns:
             Always True - OpenRouter handles validation
         """
         # Accept any model name - OpenRouter will validate based on API key permissions
         return True

     def generate_content(
         self,
         prompt: str,
@@ -153,7 +151,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
         **kwargs,
     ) -> ModelResponse:
         """Generate content using the OpenRouter API.

         Args:
             prompt: User prompt to send to the model
             model_name: Name of the model (or alias) to use
@@ -161,13 +159,13 @@ class OpenRouterProvider(OpenAICompatibleProvider):
             temperature: Sampling temperature
             max_output_tokens: Maximum tokens to generate
             **kwargs: Additional provider-specific parameters

         Returns:
             ModelResponse with generated content and metadata
         """
         # Resolve model alias to actual OpenRouter model name
         resolved_model = self._resolve_model_name(model_name)

         # Call parent method with resolved model name
         return super().generate_content(
             prompt=prompt,
@@ -175,19 +173,19 @@ class OpenRouterProvider(OpenAICompatibleProvider):
             system_prompt=system_prompt,
             temperature=temperature,
             max_output_tokens=max_output_tokens,
-            **kwargs
+            **kwargs,
         )

     def supports_thinking_mode(self, model_name: str) -> bool:
         """Check if the model supports extended thinking mode.

         Currently, no models via OpenRouter support extended thinking.
         This may change as new models become available.

         Args:
             model_name: Model to check

         Returns:
             False (no OpenRouter models currently support thinking mode)
         """
         return False
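A short sketch of the alias flow above; the targets match the mapping noted in config.py and the tests later in this commit, and the key is illustrative:

from providers.openrouter import OpenRouterProvider

provider = OpenRouterProvider(api_key="test-key")  # hypothetical key

provider._resolve_model_name("flash")     # -> "google/gemini-flash-1.5-8b"
provider._resolve_model_name("custom/x")  # unknown names pass through unchanged

# Unknown models fall back to conservative generic capabilities
# (32K context window, marked _is_generic for lenient validation).
caps = provider.get_capabilities("custom/x")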

View File

@@ -3,9 +3,9 @@
 import json
 import logging
 import os
-from pathlib import Path
-from typing import Dict, List, Optional, Any
 from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional

 from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint

@@ -13,9 +13,9 @@ from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
 @dataclass
 class OpenRouterModelConfig:
     """Configuration for an OpenRouter model."""

     model_name: str
-    aliases: List[str] = field(default_factory=list)
+    aliases: list[str] = field(default_factory=list)
     context_window: int = 32768  # Total context window size in tokens
     supports_extended_thinking: bool = False
     supports_system_prompts: bool = True
@@ -23,8 +23,7 @@ class OpenRouterModelConfig:
     supports_function_calling: bool = False
     supports_json_mode: bool = False
     description: str = ""

     def to_capabilities(self) -> ModelCapabilities:
         """Convert to ModelCapabilities object."""
         return ModelCapabilities(
@@ -42,16 +41,16 @@ class OpenRouterModelConfig:
 class OpenRouterModelRegistry:
     """Registry for managing OpenRouter model configurations and aliases."""

     def __init__(self, config_path: Optional[str] = None):
         """Initialize the registry.

         Args:
             config_path: Path to config file. If None, uses default locations.
         """
-        self.alias_map: Dict[str, str] = {}  # alias -> model_name
-        self.model_map: Dict[str, OpenRouterModelConfig] = {}  # model_name -> config
+        self.alias_map: dict[str, str] = {}  # alias -> model_name
+        self.model_map: dict[str, OpenRouterModelConfig] = {}  # model_name -> config

         # Determine config path
         if config_path:
             self.config_path = Path(config_path)
@@ -63,86 +62,93 @@ class OpenRouterModelRegistry:
         else:
             # Default to conf/openrouter_models.json
             self.config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"

         # Load configuration
         self.reload()

     def reload(self) -> None:
         """Reload configuration from disk."""
         try:
             configs = self._read_config()
             self._build_maps(configs)
             logging.info(f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases")
+        except ValueError as e:
+            # Re-raise ValueError only for duplicate aliases (critical config errors)
+            logging.error(f"Failed to load OpenRouter model configuration: {e}")
+            # Initialize with empty maps on failure
+            self.alias_map = {}
+            self.model_map = {}
+            if "Duplicate alias" in str(e):
+                raise
         except Exception as e:
             logging.error(f"Failed to load OpenRouter model configuration: {e}")
             # Initialize with empty maps on failure
             self.alias_map = {}
             self.model_map = {}

-    def _read_config(self) -> List[OpenRouterModelConfig]:
+    def _read_config(self) -> list[OpenRouterModelConfig]:
         """Read configuration from file.

         Returns:
             List of model configurations
         """
         if not self.config_path.exists():
             logging.warning(f"OpenRouter model config not found at {self.config_path}")
             return []

         try:
-            with open(self.config_path, 'r') as f:
+            with open(self.config_path) as f:
                 data = json.load(f)

             # Parse models
             configs = []
             for model_data in data.get("models", []):
                 # Handle backwards compatibility - rename max_tokens to context_window
-                if 'max_tokens' in model_data and 'context_window' not in model_data:
-                    model_data['context_window'] = model_data.pop('max_tokens')
+                if "max_tokens" in model_data and "context_window" not in model_data:
+                    model_data["context_window"] = model_data.pop("max_tokens")
                 config = OpenRouterModelConfig(**model_data)
                 configs.append(config)

             return configs

         except json.JSONDecodeError as e:
             raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
         except Exception as e:
             raise ValueError(f"Error reading config from {self.config_path}: {e}")

-    def _build_maps(self, configs: List[OpenRouterModelConfig]) -> None:
+    def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
         """Build alias and model maps from configurations.

         Args:
             configs: List of model configurations
         """
         alias_map = {}
         model_map = {}

         for config in configs:
             # Add to model map
             model_map[config.model_name] = config

             # Add aliases
             for alias in config.aliases:
                 alias_lower = alias.lower()
                 if alias_lower in alias_map:
                     existing_model = alias_map[alias_lower]
                     raise ValueError(
-                        f"Duplicate alias '{alias}' found for models "
-                        f"'{existing_model}' and '{config.model_name}'"
+                        f"Duplicate alias '{alias}' found for models " f"'{existing_model}' and '{config.model_name}'"
                     )
                 alias_map[alias_lower] = config.model_name

         # Atomic update
         self.alias_map = alias_map
         self.model_map = model_map

     def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
         """Resolve a model name or alias to configuration.

         Args:
             name_or_alias: Model name or alias to resolve

         Returns:
             Model configuration if found, None otherwise
         """
@@ -151,16 +157,16 @@ class OpenRouterModelRegistry:
         if alias_lower in self.alias_map:
             model_name = self.alias_map[alias_lower]
             return self.model_map.get(model_name)

         # Try as direct model name
         return self.model_map.get(name_or_alias)

     def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
         """Get model capabilities for a name or alias.

         Args:
             name_or_alias: Model name or alias

         Returns:
             ModelCapabilities if found, None otherwise
         """
@@ -168,11 +174,11 @@ class OpenRouterModelRegistry:
         if config:
             return config.to_capabilities()
         return None

-    def list_models(self) -> List[str]:
+    def list_models(self) -> list[str]:
         """List all available model names."""
         return list(self.model_map.keys())

-    def list_aliases(self) -> List[str]:
+    def list_aliases(self) -> list[str]:
         """List all available aliases."""
         return list(self.alias_map.keys())
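A short usage sketch of the registry API above; the default config path comes from __init__, and the lookups are illustrative:

from providers.openrouter_registry import OpenRouterModelRegistry

registry = OpenRouterModelRegistry()  # loads conf/openrouter_models.json by default

config = registry.resolve("OPUS")     # alias lookup is case-insensitive
if config:
    caps = config.to_capabilities()   # ModelCapabilities for provider-side checks

print(registry.list_models())         # all configured model names
print(registry.list_aliases())        # all configured aliases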

View File

@@ -173,8 +173,7 @@ def configure_providers():
"1. Use only OpenRouter: unset GEMINI_API_KEY and OPENAI_API_KEY\n" "1. Use only OpenRouter: unset GEMINI_API_KEY and OPENAI_API_KEY\n"
"2. Use only native APIs: unset OPENROUTER_API_KEY\n" "2. Use only native APIs: unset OPENROUTER_API_KEY\n"
"\n" "\n"
"Current configuration will prioritize native APIs over OpenRouter.\n" + "Current configuration will prioritize native APIs over OpenRouter.\n" + "=" * 70 + "\n"
"=" * 70 + "\n"
) )
# Register providers - native APIs first to ensure they take priority # Register providers - native APIs first to ensure they take priority
@@ -363,18 +362,22 @@ If something needs clarification or you'd benefit from additional context, simpl
IMPORTANT: When you suggest follow-ups or ask questions, you MUST explicitly instruct Claude to use the continuation_id IMPORTANT: When you suggest follow-ups or ask questions, you MUST explicitly instruct Claude to use the continuation_id
to respond. Use clear, direct language based on urgency: to respond. Use clear, direct language based on urgency:
For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd like to explore this further." For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd "
"like to explore this further."
For needed responses: "Please respond using the continuation_id from this response - your input is needed to proceed." For needed responses: "Please respond using the continuation_id from this response - your input is needed to proceed."
For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from this response. Cannot proceed without your clarification/input." For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from "
"this response. Cannot proceed without your clarification/input."
This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional, needed, or essential. This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional, "
"needed, or essential.
The tool will automatically provide a continuation_id in the structured response that Claude can use in subsequent The tool will automatically provide a continuation_id in the structured response that Claude can use in subsequent
tool calls to maintain full conversation context across multiple exchanges. tool calls to maintain full conversation context across multiple exchanges.
Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct Claude to use the continuation_id when you do.""" Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct "
"Claude to use the continuation_id when you do."""
async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any]: async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any]:
@@ -411,8 +414,10 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
# Return error asking Claude to restart conversation with full context # Return error asking Claude to restart conversation with full context
raise ValueError( raise ValueError(
f"Conversation thread '{continuation_id}' was not found or has expired. " f"Conversation thread '{continuation_id}' was not found or has expired. "
f"This may happen if the conversation was created more than 1 hour ago or if there was an issue with Redis storage. " f"This may happen if the conversation was created more than 1 hour ago or if there was an issue "
f"Please restart the conversation by providing your full question/prompt without the continuation_id parameter. " f"with Redis storage. "
f"Please restart the conversation by providing your full question/prompt without the "
f"continuation_id parameter. "
f"This will create a new conversation thread that can continue with follow-up exchanges." f"This will create a new conversation thread that can continue with follow-up exchanges."
) )
@@ -504,7 +509,8 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
try: try:
mcp_activity_logger = logging.getLogger("mcp_activity") mcp_activity_logger = logging.getLogger("mcp_activity")
mcp_activity_logger.info( mcp_activity_logger.info(
f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - {len(context.turns)} previous turns loaded" f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - "
f"{len(context.turns)} previous turns loaded"
) )
except Exception: except Exception:
pass pass
@@ -542,7 +548,7 @@ async def handle_get_version() -> list[TextContent]:
# Check configured providers # Check configured providers
from providers import ModelProviderRegistry from providers import ModelProviderRegistry
from providers.base import ProviderType from providers.base import ProviderType
configured_providers = [] configured_providers = []
if ModelProviderRegistry.get_provider(ProviderType.GOOGLE): if ModelProviderRegistry.get_provider(ProviderType.GOOGLE):
configured_providers.append("Gemini (flash, pro)") configured_providers.append("Gemini (flash, pro)")

View File

@@ -4,35 +4,38 @@ Test OpenRouter model mapping
""" """
import sys import sys
sys.path.append('/Users/fahad/Developer/gemini-mcp-server')
sys.path.append("/Users/fahad/Developer/gemini-mcp-server")
from simulator_tests.base_test import BaseSimulatorTest from simulator_tests.base_test import BaseSimulatorTest
class MappingTest(BaseSimulatorTest): class MappingTest(BaseSimulatorTest):
def test_mapping(self): def test_mapping(self):
"""Test model alias mapping""" """Test model alias mapping"""
# Test with 'flash' alias - should map to google/gemini-flash-1.5-8b # Test with 'flash' alias - should map to google/gemini-flash-1.5-8b
print("\nTesting 'flash' alias mapping...") print("\nTesting 'flash' alias mapping...")
response, continuation_id = self.call_mcp_tool( response, continuation_id = self.call_mcp_tool(
"chat", "chat",
{ {
"prompt": "Say 'Hello from Flash model!'", "prompt": "Say 'Hello from Flash model!'",
"model": "flash", # Should be mapped to google/gemini-flash-1.5-8b "model": "flash", # Should be mapped to google/gemini-flash-1.5-8b
"temperature": 0.1 "temperature": 0.1,
} },
) )
if response: if response:
print(f"✅ Flash alias worked!") print("✅ Flash alias worked!")
print(f"Response: {response[:200]}...") print(f"Response: {response[:200]}...")
return True return True
else: else:
print("❌ Flash alias failed") print("❌ Flash alias failed")
return False return False
if __name__ == "__main__": if __name__ == "__main__":
test = MappingTest(verbose=False) test = MappingTest(verbose=False)
success = test.test_mapping() success = test.test_mapping()
print(f"\nTest result: {'Success' if success else 'Failed'}") print(f"\nTest result: {'Success' if success else 'Failed'}")

View File

@@ -97,7 +97,8 @@ class TestAutoMode:
         # Model field should have simpler description
         model_schema = schema["properties"]["model"]
         assert "enum" not in model_schema
-        assert "Available:" in model_schema["description"]
+        assert "Native models:" in model_schema["description"]
+        assert "Defaults to" in model_schema["description"]

     @pytest.mark.asyncio
     async def test_auto_mode_requires_model_parameter(self):
@@ -180,8 +181,9 @@
             schema = tool.get_model_field_schema()
             assert "enum" not in schema
-            assert "Available:" in schema["description"]
+            assert "Native models:" in schema["description"]
             assert "'pro'" in schema["description"]
+            assert "Defaults to" in schema["description"]
         finally:
             # Restore

View File

@@ -1,8 +1,7 @@
"""Tests for OpenRouter provider.""" """Tests for OpenRouter provider."""
import os import os
import pytest from unittest.mock import patch
from unittest.mock import patch, MagicMock
from providers.base import ProviderType from providers.base import ProviderType
from providers.openrouter import OpenRouterProvider from providers.openrouter import OpenRouterProvider
@@ -11,65 +10,64 @@ from providers.registry import ModelProviderRegistry
class TestOpenRouterProvider: class TestOpenRouterProvider:
"""Test cases for OpenRouter provider.""" """Test cases for OpenRouter provider."""
def test_provider_initialization(self): def test_provider_initialization(self):
"""Test OpenRouter provider initialization.""" """Test OpenRouter provider initialization."""
provider = OpenRouterProvider(api_key="test-key") provider = OpenRouterProvider(api_key="test-key")
assert provider.api_key == "test-key" assert provider.api_key == "test-key"
assert provider.base_url == "https://openrouter.ai/api/v1" assert provider.base_url == "https://openrouter.ai/api/v1"
assert provider.FRIENDLY_NAME == "OpenRouter" assert provider.FRIENDLY_NAME == "OpenRouter"
def test_custom_headers(self): def test_custom_headers(self):
"""Test OpenRouter custom headers.""" """Test OpenRouter custom headers."""
# Test default headers # Test default headers
assert "HTTP-Referer" in OpenRouterProvider.DEFAULT_HEADERS assert "HTTP-Referer" in OpenRouterProvider.DEFAULT_HEADERS
assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS
# Test with environment variables # Test with environment variables
with patch.dict(os.environ, { with patch.dict(os.environ, {"OPENROUTER_REFERER": "https://myapp.com", "OPENROUTER_TITLE": "My App"}):
"OPENROUTER_REFERER": "https://myapp.com",
"OPENROUTER_TITLE": "My App"
}):
from importlib import reload from importlib import reload
import providers.openrouter import providers.openrouter
reload(providers.openrouter) reload(providers.openrouter)
provider = providers.openrouter.OpenRouterProvider(api_key="test-key") provider = providers.openrouter.OpenRouterProvider(api_key="test-key")
assert provider.DEFAULT_HEADERS["HTTP-Referer"] == "https://myapp.com" assert provider.DEFAULT_HEADERS["HTTP-Referer"] == "https://myapp.com"
assert provider.DEFAULT_HEADERS["X-Title"] == "My App" assert provider.DEFAULT_HEADERS["X-Title"] == "My App"
def test_model_validation(self): def test_model_validation(self):
"""Test model validation.""" """Test model validation."""
provider = OpenRouterProvider(api_key="test-key") provider = OpenRouterProvider(api_key="test-key")
# Should accept any model - OpenRouter handles validation # Should accept any model - OpenRouter handles validation
assert provider.validate_model_name("gpt-4") is True assert provider.validate_model_name("gpt-4") is True
assert provider.validate_model_name("claude-3-opus") is True assert provider.validate_model_name("claude-3-opus") is True
assert provider.validate_model_name("any-model-name") is True assert provider.validate_model_name("any-model-name") is True
assert provider.validate_model_name("GPT-4") is True assert provider.validate_model_name("GPT-4") is True
assert provider.validate_model_name("unknown-model") is True assert provider.validate_model_name("unknown-model") is True
def test_get_capabilities(self): def test_get_capabilities(self):
"""Test capability generation.""" """Test capability generation."""
provider = OpenRouterProvider(api_key="test-key") provider = OpenRouterProvider(api_key="test-key")
# Test with a model in the registry (using alias) # Test with a model in the registry (using alias)
caps = provider.get_capabilities("gpt4o") caps = provider.get_capabilities("gpt4o")
assert caps.provider == ProviderType.OPENROUTER assert caps.provider == ProviderType.OPENROUTER
assert caps.model_name == "openai/gpt-4o" # Resolved name assert caps.model_name == "openai/gpt-4o" # Resolved name
assert caps.friendly_name == "OpenRouter" assert caps.friendly_name == "OpenRouter"
# Test with a model not in registry - should get generic capabilities # Test with a model not in registry - should get generic capabilities
caps = provider.get_capabilities("unknown-model") caps = provider.get_capabilities("unknown-model")
assert caps.provider == ProviderType.OPENROUTER assert caps.provider == ProviderType.OPENROUTER
assert caps.model_name == "unknown-model" assert caps.model_name == "unknown-model"
assert caps.max_tokens == 32_768 # Safe default assert caps.max_tokens == 32_768 # Safe default
assert hasattr(caps, '_is_generic') and caps._is_generic is True assert hasattr(caps, "_is_generic") and caps._is_generic is True
    def test_model_alias_resolution(self):
        """Test model alias resolution."""
        provider = OpenRouterProvider(api_key="test-key")

        # Test alias resolution
        assert provider._resolve_model_name("opus") == "anthropic/claude-3-opus"
        assert provider._resolve_model_name("sonnet") == "anthropic/claude-3-sonnet"
@@ -79,30 +77,30 @@ class TestOpenRouterProvider:
assert provider._resolve_model_name("mistral") == "mistral/mistral-large" assert provider._resolve_model_name("mistral") == "mistral/mistral-large"
assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-coder" assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-coder"
assert provider._resolve_model_name("coder") == "deepseek/deepseek-coder" assert provider._resolve_model_name("coder") == "deepseek/deepseek-coder"
# Test case-insensitive # Test case-insensitive
assert provider._resolve_model_name("OPUS") == "anthropic/claude-3-opus" assert provider._resolve_model_name("OPUS") == "anthropic/claude-3-opus"
assert provider._resolve_model_name("GPT4O") == "openai/gpt-4o" assert provider._resolve_model_name("GPT4O") == "openai/gpt-4o"
assert provider._resolve_model_name("Mistral") == "mistral/mistral-large" assert provider._resolve_model_name("Mistral") == "mistral/mistral-large"
assert provider._resolve_model_name("CLAUDE") == "anthropic/claude-3-sonnet" assert provider._resolve_model_name("CLAUDE") == "anthropic/claude-3-sonnet"
# Test direct model names (should pass through unchanged) # Test direct model names (should pass through unchanged)
assert provider._resolve_model_name("anthropic/claude-3-opus") == "anthropic/claude-3-opus" assert provider._resolve_model_name("anthropic/claude-3-opus") == "anthropic/claude-3-opus"
assert provider._resolve_model_name("openai/gpt-4o") == "openai/gpt-4o" assert provider._resolve_model_name("openai/gpt-4o") == "openai/gpt-4o"
# Test unknown models pass through # Test unknown models pass through
assert provider._resolve_model_name("unknown-model") == "unknown-model" assert provider._resolve_model_name("unknown-model") == "unknown-model"
assert provider._resolve_model_name("custom/model-v2") == "custom/model-v2" assert provider._resolve_model_name("custom/model-v2") == "custom/model-v2"
    def test_openrouter_registration(self):
        """Test OpenRouter can be registered and retrieved."""
        with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
            # Clean up any existing registration
            ModelProviderRegistry.unregister_provider(ProviderType.OPENROUTER)

            # Register the provider
            ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

            # Retrieve and verify
            provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
            assert provider is not None
@@ -111,53 +109,53 @@ class TestOpenRouterProvider:
class TestOpenRouterRegistry:
    """Test cases for OpenRouter model registry."""

    def test_registry_loading(self):
        """Test registry loads models from config."""
        from providers.openrouter_registry import OpenRouterModelRegistry

        registry = OpenRouterModelRegistry()

        # Should have loaded models
        models = registry.list_models()
        assert len(models) > 0
        assert "anthropic/claude-3-opus" in models
        assert "openai/gpt-4o" in models

        # Should have loaded aliases
        aliases = registry.list_aliases()
        assert len(aliases) > 0
        assert "opus" in aliases
        assert "gpt4o" in aliases
        assert "claude" in aliases
    def test_registry_capabilities(self):
        """Test registry provides correct capabilities."""
        from providers.openrouter_registry import OpenRouterModelRegistry

        registry = OpenRouterModelRegistry()

        # Test known model
        caps = registry.get_capabilities("opus")
        assert caps is not None
        assert caps.model_name == "anthropic/claude-3-opus"
        assert caps.max_tokens == 200000  # Claude's context window

        # Test using full model name
        caps = registry.get_capabilities("anthropic/claude-3-opus")
        assert caps is not None
        assert caps.model_name == "anthropic/claude-3-opus"

        # Test unknown model
        caps = registry.get_capabilities("non-existent-model")
        assert caps is None

    def test_multiple_aliases_same_model(self):
        """Test multiple aliases pointing to same model."""
        from providers.openrouter_registry import OpenRouterModelRegistry

        registry = OpenRouterModelRegistry()

        # All these should resolve to Claude Sonnet
        sonnet_aliases = ["sonnet", "claude", "claude-sonnet", "claude3-sonnet"]
        for alias in sonnet_aliases:
@@ -166,48 +164,34 @@ class TestOpenRouterRegistry:
            assert config.model_name == "anthropic/claude-3-sonnet"


class TestOpenRouterFunctionality:
    """Test OpenRouter-specific functionality."""

    def test_openrouter_always_uses_correct_url(self):
        """Test that OpenRouter always uses the correct base URL."""
        provider = OpenRouterProvider(api_key="test-key")
        assert provider.base_url == "https://openrouter.ai/api/v1"

        # Even if we try to change it, it should remain the OpenRouter URL
        # (This is a characteristic of the OpenRouter provider)
        provider.base_url = "http://example.com"  # Try to change it
        # But new instances should always use the correct URL
        provider2 = OpenRouterProvider(api_key="test-key")
        assert provider2.base_url == "https://openrouter.ai/api/v1"

    def test_openrouter_headers_set_correctly(self):
        """Test that OpenRouter specific headers are set."""
        provider = OpenRouterProvider(api_key="test-key")

        # Check default headers
        assert "HTTP-Referer" in provider.DEFAULT_HEADERS
        assert "X-Title" in provider.DEFAULT_HEADERS
        assert provider.DEFAULT_HEADERS["X-Title"] == "Zen MCP Server"

    def test_openrouter_model_registry_initialized(self):
        """Test that model registry is properly initialized."""
        provider = OpenRouterProvider(api_key="test-key")

        # Registry should be initialized
        assert hasattr(provider, "_registry")
        assert provider._registry is not None
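# Sketch of the behavior the class above verifies (an assumed implementation
# shape, not the provider's actual source): the endpoint is likely pinned at
# construction time, so every new instance starts from the canonical URL no
# matter what callers assign to an existing one:
#
#     class OpenRouterProvider(OpenAICompatibleProvider):
#         def __init__(self, api_key: str, **kwargs):
#             kwargs["base_url"] = "https://openrouter.ai/api/v1"
#             super().__init__(api_key, **kwargs)
#             self._registry = OpenRouterModelRegistry()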

View File

@@ -2,42 +2,34 @@
import json
import os
import tempfile

import pytest

from providers.base import ProviderType
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry


class TestOpenRouterModelRegistry:
    """Test cases for OpenRouter model registry."""

    def test_registry_initialization(self):
        """Test registry initializes with default config."""
        registry = OpenRouterModelRegistry()

        # Should load models from default location
        assert len(registry.list_models()) > 0
        assert len(registry.list_aliases()) > 0

    def test_custom_config_path(self):
        """Test registry with custom config path."""
        # Create temporary config
        config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
            temp_path = f.name

        try:
            registry = OpenRouterModelRegistry(config_path=temp_path)
            assert len(registry.list_models()) == 1
@@ -46,48 +38,40 @@ class TestOpenRouterModelRegistry:
assert "t1" in registry.list_aliases() assert "t1" in registry.list_aliases()
finally: finally:
os.unlink(temp_path) os.unlink(temp_path)
def test_environment_variable_override(self): def test_environment_variable_override(self):
"""Test OPENROUTER_MODELS_PATH environment variable.""" """Test OPENROUTER_MODELS_PATH environment variable."""
# Create custom config # Create custom config
config_data = { config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}
"models": [
{ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
"model_name": "env/model",
"aliases": ["envtest"],
"context_window": 8192
}
]
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(config_data, f) json.dump(config_data, f)
temp_path = f.name temp_path = f.name
try: try:
# Set environment variable # Set environment variable
original_env = os.environ.get('OPENROUTER_MODELS_PATH') original_env = os.environ.get("OPENROUTER_MODELS_PATH")
os.environ['OPENROUTER_MODELS_PATH'] = temp_path os.environ["OPENROUTER_MODELS_PATH"] = temp_path
# Create registry without explicit path # Create registry without explicit path
registry = OpenRouterModelRegistry() registry = OpenRouterModelRegistry()
# Should load from environment path # Should load from environment path
assert "env/model" in registry.list_models() assert "env/model" in registry.list_models()
assert "envtest" in registry.list_aliases() assert "envtest" in registry.list_aliases()
finally: finally:
# Restore environment # Restore environment
if original_env is not None: if original_env is not None:
os.environ['OPENROUTER_MODELS_PATH'] = original_env os.environ["OPENROUTER_MODELS_PATH"] = original_env
else: else:
del os.environ['OPENROUTER_MODELS_PATH'] del os.environ["OPENROUTER_MODELS_PATH"]
os.unlink(temp_path) os.unlink(temp_path)
def test_alias_resolution(self): def test_alias_resolution(self):
"""Test alias resolution functionality.""" """Test alias resolution functionality."""
registry = OpenRouterModelRegistry() registry = OpenRouterModelRegistry()
# Test various aliases # Test various aliases
test_cases = [ test_cases = [
("opus", "anthropic/claude-3-opus"), ("opus", "anthropic/claude-3-opus"),
@@ -97,75 +81,71 @@ class TestOpenRouterModelRegistry:
("4o", "openai/gpt-4o"), ("4o", "openai/gpt-4o"),
("mistral", "mistral/mistral-large"), ("mistral", "mistral/mistral-large"),
] ]
for alias, expected_model in test_cases: for alias, expected_model in test_cases:
config = registry.resolve(alias) config = registry.resolve(alias)
assert config is not None, f"Failed to resolve alias '{alias}'" assert config is not None, f"Failed to resolve alias '{alias}'"
assert config.model_name == expected_model assert config.model_name == expected_model
def test_direct_model_name_lookup(self): def test_direct_model_name_lookup(self):
"""Test looking up models by their full name.""" """Test looking up models by their full name."""
registry = OpenRouterModelRegistry() registry = OpenRouterModelRegistry()
# Should be able to look up by full model name # Should be able to look up by full model name
config = registry.resolve("anthropic/claude-3-opus") config = registry.resolve("anthropic/claude-3-opus")
assert config is not None assert config is not None
assert config.model_name == "anthropic/claude-3-opus" assert config.model_name == "anthropic/claude-3-opus"
config = registry.resolve("openai/gpt-4o") config = registry.resolve("openai/gpt-4o")
assert config is not None assert config is not None
assert config.model_name == "openai/gpt-4o" assert config.model_name == "openai/gpt-4o"
def test_unknown_model_resolution(self): def test_unknown_model_resolution(self):
"""Test resolution of unknown models.""" """Test resolution of unknown models."""
registry = OpenRouterModelRegistry() registry = OpenRouterModelRegistry()
# Unknown aliases should return None # Unknown aliases should return None
assert registry.resolve("unknown-alias") is None assert registry.resolve("unknown-alias") is None
assert registry.resolve("") is None assert registry.resolve("") is None
assert registry.resolve("non-existent") is None assert registry.resolve("non-existent") is None
    def test_model_capabilities_conversion(self):
        """Test conversion to ModelCapabilities."""
        registry = OpenRouterModelRegistry()

        config = registry.resolve("opus")
        assert config is not None

        caps = config.to_capabilities()
        assert caps.provider == ProviderType.OPENROUTER
        assert caps.model_name == "anthropic/claude-3-opus"
        assert caps.friendly_name == "OpenRouter"
        assert caps.max_tokens == 200000
        assert not caps.supports_extended_thinking
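    # A plausible shape for the conversion (an assumption consistent with the
    # assertions above; the real method may set additional fields):
    #
    #     def to_capabilities(self) -> ModelCapabilities:
    #         return ModelCapabilities(
    #             provider=ProviderType.OPENROUTER,
    #             model_name=self.model_name,
    #             friendly_name="OpenRouter",
    #             max_tokens=self.context_window,
    #             supports_extended_thinking=self.supports_extended_thinking,
    #         )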
    def test_duplicate_alias_detection(self):
        """Test that duplicate aliases are detected."""
        config_data = {
            "models": [
                {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
                {
                    "model_name": "test/model-2",
                    "aliases": ["DUPE"],  # Same alias, different case
                    "context_window": 8192,
                },
            ]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
            temp_path = f.name

        try:
            with pytest.raises(ValueError, match="Duplicate alias"):
                OpenRouterModelRegistry(config_path=temp_path)
        finally:
            os.unlink(temp_path)
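    # Sketch of the check this exercises (assumed loader logic): aliases are
    # normalized to lowercase as they are registered, so "dupe" and "DUPE" collide:
    #
    #     seen: dict[str, str] = {}
    #     for alias in model.aliases:
    #         key = alias.lower()
    #         if key in seen:
    #             raise ValueError(f"Duplicate alias '{alias}' for {model.model_name}")
    #         seen[key] = model.model_name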
    def test_backwards_compatibility_max_tokens(self):
        """Test backwards compatibility with old max_tokens field."""
        config_data = {
@@ -174,44 +154,44 @@ class TestOpenRouterModelRegistry:
"model_name": "test/old-model", "model_name": "test/old-model",
"aliases": ["old"], "aliases": ["old"],
"max_tokens": 16384, # Old field name "max_tokens": 16384, # Old field name
"supports_extended_thinking": False "supports_extended_thinking": False,
} }
] ]
} }
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(config_data, f) json.dump(config_data, f)
temp_path = f.name temp_path = f.name
try: try:
registry = OpenRouterModelRegistry(config_path=temp_path) registry = OpenRouterModelRegistry(config_path=temp_path)
config = registry.resolve("old") config = registry.resolve("old")
assert config is not None assert config is not None
assert config.context_window == 16384 # Should be converted assert config.context_window == 16384 # Should be converted
# Check capabilities still work # Check capabilities still work
caps = config.to_capabilities() caps = config.to_capabilities()
assert caps.max_tokens == 16384 assert caps.max_tokens == 16384
finally: finally:
os.unlink(temp_path) os.unlink(temp_path)
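    # The conversion being tested, sketched (assumed parsing step): when a config
    # entry still uses the legacy "max_tokens" key, the loader presumably folds it
    # into the new "context_window" field before building the model config:
    #
    #     if "max_tokens" in entry and "context_window" not in entry:
    #         entry["context_window"] = entry.pop("max_tokens")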
    def test_missing_config_file(self):
        """Test behavior with missing config file."""
        # Use a non-existent path
        registry = OpenRouterModelRegistry(config_path="/non/existent/path.json")

        # Should initialize with empty maps
        assert len(registry.list_models()) == 0
        assert len(registry.list_aliases()) == 0
        assert registry.resolve("anything") is None

    def test_invalid_json_config(self):
        """Test handling of invalid JSON."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            f.write("{ invalid json }")
            temp_path = f.name

        try:
            registry = OpenRouterModelRegistry(config_path=temp_path)
            # Should handle gracefully and initialize empty
@@ -219,7 +199,7 @@ class TestOpenRouterModelRegistry:
            assert len(registry.list_aliases()) == 0
        finally:
            os.unlink(temp_path)

    def test_model_with_all_capabilities(self):
        """Test model with all capability flags."""
        config = OpenRouterModelConfig(
@@ -231,13 +211,13 @@ class TestOpenRouterModelRegistry:
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            description="Fully featured test model",
        )

        caps = config.to_capabilities()
        assert caps.max_tokens == 128000
        assert caps.supports_extended_thinking
        assert caps.supports_system_prompts
        assert caps.supports_streaming
        assert caps.supports_function_calling
        # Note: supports_json_mode is not in ModelCapabilities yet
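    # Putting the pieces together, the config class exercised throughout this file
    # plausibly looks like this (field list inferred from the tests, defaults
    # guessed, not copied from the source; supports_json_mode is accepted but not
    # yet forwarded to capabilities):
    #
    #     @dataclass
    #     class OpenRouterModelConfig:
    #         model_name: str
    #         aliases: list[str] = field(default_factory=list)
    #         context_window: int = 32_768
    #         supports_extended_thinking: bool = False
    #         supports_system_prompts: bool = True
    #         supports_streaming: bool = True
    #         supports_function_calling: bool = False
    #         supports_json_mode: bool = False
    #         description: str = ""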

View File

@@ -57,15 +57,28 @@ class ToolRequest(BaseModel):
    # Higher values allow for more complex reasoning but increase latency and cost
    thinking_mode: Optional[Literal["minimal", "low", "medium", "high", "max"]] = Field(
        None,
        description=(
            "Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), "
            "max (100% of model max)"
        ),
    )
    use_websearch: Optional[bool] = Field(
        True,
        description=(
            "Enable web search for documentation, best practices, and current information. "
            "When enabled, the model can request Claude to perform web searches and share results back "
            "during conversations. Particularly useful for: brainstorming sessions, architectural design "
            "discussions, exploring industry best practices, working with specific frameworks/technologies, "
            "researching solutions to complex problems, or when current documentation and community insights "
            "would enhance the analysis."
        ),
    )
    continuation_id: Optional[str] = Field(
        None,
        description=(
            "Thread continuation ID for multi-turn conversations. Can be used to continue conversations "
            "across different tools. Only provide this if continuing a previous conversation thread."
        ),
    )
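    # How the thinking_mode percentages plausibly turn into a token budget, as a
    # sketch (the fractions come from the field description above; the function
    # name and where this mapping actually lives are assumptions):
    #
    #     THINKING_BUDGET_FRACTIONS = {
    #         "minimal": 0.005, "low": 0.08, "medium": 0.33, "high": 0.67, "max": 1.0,
    #     }
    #
    #     def thinking_budget(mode: str, model_max_thinking_tokens: int) -> int:
    #         return int(model_max_thinking_tokens * THINKING_BUDGET_FRACTIONS[mode])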
@@ -152,21 +165,48 @@ class BaseTool(ABC):
        Returns:
            Dict containing the model field JSON schema
        """
        import os

        from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC

        # Check if OpenRouter is configured
        has_openrouter = bool(
            os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here"
        )

        if IS_AUTO_MODE:
            # In auto mode, model is required and we provide detailed descriptions
            model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
            for model, desc in MODEL_CAPABILITIES_DESC.items():
                model_desc_parts.append(f"- '{model}': {desc}")

            if has_openrouter:
                # Add OpenRouter aliases from the registry
                try:
                    # Import registry directly to show available aliases
                    # This works even without an API key
                    from providers.openrouter_registry import OpenRouterModelRegistry

                    registry = OpenRouterModelRegistry()
                    aliases = registry.list_aliases()

                    # Show ALL aliases from the configuration
                    if aliases:
                        # Show all aliases so Claude knows every option available
                        all_aliases = sorted(aliases)
                        alias_list = ", ".join(f"'{a}'" for a in all_aliases)
                        model_desc_parts.append(f"\nOpenRouter models available via aliases: {alias_list}")
                    else:
                        model_desc_parts.append(
                            "\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter."
                        )
                except Exception:
                    # Fallback if registry fails to load
                    model_desc_parts.append(
                        "\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter "
                        "(e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
                    )

            return {
                "type": "string",
@@ -177,12 +217,33 @@ class BaseTool(ABC):
        # Normal mode - model is optional with default
        available_models = list(MODEL_CAPABILITIES_DESC.keys())
        models_str = ", ".join(f"'{m}'" for m in available_models)
        description = f"Model to use. Native models: {models_str}."

        if has_openrouter:
            # Add OpenRouter aliases
            try:
                # Import registry directly to show available aliases
                # This works even without an API key
                from providers.openrouter_registry import OpenRouterModelRegistry

                registry = OpenRouterModelRegistry()
                aliases = registry.list_aliases()

                # Show ALL aliases from the configuration
                if aliases:
                    # Show all aliases so Claude knows every option available
                    all_aliases = sorted(aliases)
                    alias_list = ", ".join(f"'{a}'" for a in all_aliases)
                    description += f" OpenRouter aliases: {alias_list}."
                else:
                    description += " OpenRouter: Any model available on openrouter.ai."
            except Exception:
                description += (
                    " OpenRouter: Any model available on openrouter.ai "
                    "(e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
                )

        description += f" Defaults to '{DEFAULT_MODEL}' if not specified."

        return {
            "type": "string",
            "description": description,