- OpenRouter model configuration registry
- User-editable model definition file for controlling available models and aliases
- Additional tests
- Update instructions
Fahad
2025-06-13 06:33:12 +04:00
parent cd1105b741
commit 2cdb92460b
12 changed files with 417 additions and 381 deletions
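
The user-editable model definition file referenced above is conf/openrouter_models.json (overridable via the OPENROUTER_MODELS_PATH environment variable). A minimal sketch of its shape, using the fields accepted by OpenRouterModelConfig in this commit; the concrete model and alias values are illustrative:

# Sketch: writing a minimal conf/openrouter_models.json entry. Fields mirror
# OpenRouterModelConfig; the model name and alias below are examples only.
import json

example_config = {
    "models": [
        {
            "model_name": "anthropic/claude-3-opus",
            "aliases": ["opus"],
            "context_window": 200000,
            "supports_extended_thinking": False,
            "description": "Anthropic Claude 3 Opus via OpenRouter",
        }
    ]
}

with open("conf/openrouter_models.json", "w") as f:
    json.dump(example_config, f, indent=2)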

View File

@@ -56,11 +56,13 @@ MODEL_CAPABILITIES_DESC = {
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
# Full model names also supported
"gemini-2.5-flash-preview-05-20": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
"gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
"gemini-2.5-pro-preview-06-05": (
"Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis"
),
}
# Note: When only OpenRouter is configured, these model aliases automatically map to equivalent models:
# - "flash" → "google/gemini-flash-1.5-8b"
# - "flash" → "google/gemini-flash-1.5-8b"
# - "pro" → "google/gemini-pro-1.5"
# - "o3" → "openai/gpt-4o"
# - "o3-mini" → "openai/gpt-4o-mini"

View File

@@ -141,7 +141,11 @@ trace issues to their root cause, and provide actionable solutions.
IMPORTANT: If you lack critical information to proceed (e.g., missing files, ambiguous error details,
insufficient context), OR if the provided diagnostics (log files, crash reports, stack traces) appear irrelevant,
incomplete, or insufficient for proper analysis, you MUST respond ONLY with this JSON format:
{"status": "requires_clarification", "question": "What specific information you need from Claude or the user to proceed with debugging", "files_needed": ["file1.py", "file2.py"]}
{
"status": "requires_clarification",
"question": "What specific information you need from Claude or the user to proceed with debugging",
"files_needed": ["file1.py", "file2.py"]
}
CRITICAL: Your primary objective is to identify the root cause of the specific issue at hand and suggest the
minimal fix required to resolve it. Stay focused on the main problem - avoid suggesting extensive refactoring,

View File

@@ -1,12 +1,8 @@
"""OpenAI model provider implementation."""
import logging
from typing import Optional
from .base import (
FixedTemperatureConstraint,
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
@@ -34,7 +30,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
kwargs.setdefault("base_url", "https://api.openai.com/v1")
super().__init__(api_key, **kwargs)
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific OpenAI model."""
if model_name not in self.SUPPORTED_MODELS:
@@ -62,7 +57,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
temperature_constraint=temp_constraint,
)
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.OPENAI
@@ -76,4 +70,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
# Currently no OpenAI models support extended thinking
# This may change with future O3 models
return False

View File

@@ -1,12 +1,12 @@
"""Base class for OpenAI-compatible API providers."""
import ipaddress
import logging
import os
import socket
from abc import abstractmethod
from typing import Optional
from urllib.parse import urlparse
import ipaddress
import socket
from openai import OpenAI
@@ -15,25 +15,24 @@ from .base import (
ModelProvider,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
class OpenAICompatibleProvider(ModelProvider):
"""Base class for any provider using an OpenAI-compatible API.
This includes:
- Direct OpenAI API
- OpenRouter
- Any other OpenAI-compatible endpoint
"""
DEFAULT_HEADERS = {}
FRIENDLY_NAME = "OpenAI Compatible"
def __init__(self, api_key: str, base_url: str = None, **kwargs):
"""Initialize the provider with API key and optional base URL.
Args:
api_key: API key for authentication
base_url: Base URL for the API endpoint
@@ -44,21 +43,21 @@ class OpenAICompatibleProvider(ModelProvider):
self.base_url = base_url
self.organization = kwargs.get("organization")
self.allowed_models = self._parse_allowed_models()
# Validate base URL for security
if self.base_url:
self._validate_base_url()
# Warn if using external URL without authentication
if self.base_url and not self._is_localhost_url() and not api_key:
logging.warning(
f"Using external URL '{self.base_url}' without API key. "
"This may be insecure. Consider setting an API key for authentication."
)
def _parse_allowed_models(self) -> Optional[set[str]]:
"""Parse allowed models from environment variable.
Returns:
Set of allowed model names (lowercase) or None if not configured
"""
@@ -66,108 +65,108 @@ class OpenAICompatibleProvider(ModelProvider):
provider_type = self.get_provider_type().value.upper()
env_var = f"{provider_type}_ALLOWED_MODELS"
models_str = os.getenv(env_var, "")
if models_str:
# Parse and normalize to lowercase for case-insensitive comparison
models = set(m.strip().lower() for m in models_str.split(",") if m.strip())
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
if models:
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
return models
# Log warning if no allow-list configured for proxy providers
if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
logging.warning(
f"No model allow-list configured for {self.FRIENDLY_NAME}. "
f"Set {env_var} to restrict model access and control costs."
)
return None
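# Example (a sketch): with OPENAI_ALLOWED_MODELS set, the OpenAI provider's
# allow-list is parsed and lowercased at construction time; the import path
# below is assumed.
#
#     import os
#     from providers.openai import OpenAIModelProvider
#
#     os.environ["OPENAI_ALLOWED_MODELS"] = "O3-Mini, gpt-4o"
#     provider = OpenAIModelProvider(api_key="test-key")
#     assert provider.allowed_models == {"o3-mini", "gpt-4o"}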
def _is_localhost_url(self) -> bool:
"""Check if the base URL points to localhost.
Returns:
True if URL is localhost, False otherwise
"""
if not self.base_url:
return False
try:
parsed = urlparse(self.base_url)
hostname = parsed.hostname
# Check for common localhost patterns
if hostname in ['localhost', '127.0.0.1', '::1']:
if hostname in ["localhost", "127.0.0.1", "::1"]:
return True
return False
except Exception:
return False
def _validate_base_url(self) -> None:
"""Validate base URL for security (SSRF protection).
Raises:
ValueError: If URL is invalid or potentially unsafe
"""
if not self.base_url:
return
try:
parsed = urlparse(self.base_url)
# Check URL scheme - only allow http/https
if parsed.scheme not in ('http', 'https'):
if parsed.scheme not in ("http", "https"):
raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
# Check hostname exists
if not parsed.hostname:
raise ValueError("URL must include a hostname")
# Check port - allow only standard HTTP/HTTPS ports
port = parsed.port
if port is None:
port = 443 if parsed.scheme == 'https' else 80
port = 443 if parsed.scheme == "https" else 80
# Allow common HTTP ports and some alternative ports
allowed_ports = {80, 443, 8080, 8443, 4000, 3000} # Common API ports
if port not in allowed_ports:
raise ValueError(
f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}"
)
raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")
# Check against allowed domains if configured
allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]
if allowed_domains:
hostname_lower = parsed.hostname.lower()
if not any(
hostname_lower == domain or
hostname_lower.endswith('.' + domain)
for domain in allowed_domains
hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
):
raise ValueError(
f"Domain not in allow-list: {parsed.hostname}. "
f"Allowed domains: {allowed_domains}"
f"Domain not in allow-list: {parsed.hostname}. " f"Allowed domains: {allowed_domains}"
)
# Try to resolve hostname and check if it's a private IP
# Skip for localhost addresses which are commonly used for development
if parsed.hostname not in ['localhost', '127.0.0.1', '::1']:
if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
try:
# Get all IP addresses for the hostname
addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)
for family, _, _, _, sockaddr in addr_info:
for _family, _, _, _, sockaddr in addr_info:
ip_str = sockaddr[0]
try:
ip = ipaddress.ip_address(ip_str)
# Check for dangerous IP ranges
if (ip.is_private or ip.is_loopback or ip.is_link_local or
ip.is_multicast or ip.is_reserved or ip.is_unspecified):
if (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
):
raise ValueError(
f"URL resolves to restricted IP address: {ip_str}. "
"This could be a security risk (SSRF)."
@@ -177,16 +176,16 @@ class OpenAICompatibleProvider(ModelProvider):
if "restricted IP address" in str(ve):
raise
continue
except socket.gaierror as e:
# If we can't resolve the hostname, it's suspicious
raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
@property
def client(self):
"""Lazy initialization of OpenAI client with security checks."""
@@ -194,21 +193,21 @@ class OpenAICompatibleProvider(ModelProvider):
client_kwargs = {
"api_key": self.api_key,
}
if self.base_url:
client_kwargs["base_url"] = self.base_url
if self.organization:
client_kwargs["organization"] = self.organization
# Add default headers if any
if self.DEFAULT_HEADERS:
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
self._client = OpenAI(**client_kwargs)
return self._client
def generate_content(
self,
prompt: str,
@@ -219,7 +218,7 @@ class OpenAICompatibleProvider(ModelProvider):
**kwargs,
) -> ModelResponse:
"""Generate content using the OpenAI-compatible API.
Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
@@ -227,50 +226,49 @@ class OpenAICompatibleProvider(ModelProvider):
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Validate model name against allow-list
if not self.validate_model_name(model_name):
raise ValueError(
f"Model '{model_name}' not in allowed models list. "
f"Allowed models: {self.allowed_models}"
f"Model '{model_name}' not in allowed models list. " f"Allowed models: {self.allowed_models}"
)
# Validate parameters
self.validate_parameters(model_name, temperature)
# Prepare messages
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# Prepare completion parameters
completion_params = {
"model": model_name,
"messages": messages,
"temperature": temperature,
}
# Add max tokens if specified
if max_output_tokens:
completion_params["max_tokens"] = max_output_tokens
# Add any additional OpenAI-specific parameters
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
completion_params[key] = value
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
@@ -284,39 +282,39 @@ class OpenAICompatibleProvider(ModelProvider):
"created": response.created,
},
)
except Exception as e:
# Log error and re-raise with more context
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from e
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text.
Uses a layered approach:
1. Try provider-specific token counting endpoint
2. Try tiktoken for known model families
3. Fall back to character-based estimation
Args:
text: Text to count tokens for
model_name: Model name for tokenizer selection
Returns:
Estimated token count
"""
# 1. Check if provider has a remote token counting endpoint
if hasattr(self, 'count_tokens_remote'):
if hasattr(self, "count_tokens_remote"):
try:
return self.count_tokens_remote(text, model_name)
except Exception as e:
logging.debug(f"Remote token counting failed: {e}")
# 2. Try tiktoken for known models
try:
import tiktoken
# Try to get encoding for the specific model
try:
encoding = tiktoken.encoding_for_model(model_name)
@@ -326,24 +324,24 @@ class OpenAICompatibleProvider(ModelProvider):
encoding = tiktoken.get_encoding("cl100k_base")
else:
encoding = tiktoken.get_encoding("cl100k_base") # Default
return len(encoding.encode(text))
except (ImportError, Exception) as e:
logging.debug(f"Tiktoken not available or failed: {e}")
# 3. Fall back to character-based estimation
logging.warning(
f"No specific tokenizer available for '{model_name}'. "
"Using character-based estimation (~4 chars per token)."
)
return len(text) // 4
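# Usage sketch: for an unrecognized model name, tiktoken's cl100k_base
# encoding is used when tiktoken is installed; otherwise the character
# estimate applies (the model name below is illustrative):
#
#     text = "Hello from the token counter"
#     count = provider.count_tokens(text, model_name="some/unknown-model")
#     # cl100k_base token count if tiktoken is available, else len(text) // 4 == 7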
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters.
For proxy providers, this may use generic capabilities.
Args:
model_name: Model to validate for
temperature: Temperature to validate
@@ -351,67 +349,66 @@ class OpenAICompatibleProvider(ModelProvider):
"""
try:
capabilities = self.get_capabilities(model_name)
# Check if we're using generic capabilities
if hasattr(capabilities, '_is_generic'):
if hasattr(capabilities, "_is_generic"):
logging.debug(
f"Using generic parameter validation for {model_name}. "
"Actual model constraints may differ."
f"Using generic parameter validation for {model_name}. " "Actual model constraints may differ."
)
# Validate temperature using parent class method
super().validate_parameters(model_name, temperature, **kwargs)
except Exception as e:
# For proxy providers, we might not have accurate capabilities
# Log warning but don't fail
logging.warning(f"Parameter validation limited for {model_name}: {e}")
def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from OpenAI response.
Args:
response: OpenAI API response object
Returns:
Dictionary with usage statistics
"""
usage = {}
if hasattr(response, "usage") and response.usage:
usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)
return usage
@abstractmethod
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def get_provider_type(self) -> ProviderType:
"""Get the provider type.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported.
Must be implemented by subclasses.
"""
pass
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode.
Default is False for OpenAI-compatible providers.
"""
return False

View File

@@ -16,63 +16,61 @@ from .openrouter_registry import OpenRouterModelRegistry
class OpenRouterProvider(OpenAICompatibleProvider):
"""OpenRouter unified API provider.
OpenRouter provides access to multiple AI models through a single API endpoint.
See https://openrouter.ai for available models and pricing.
"""
FRIENDLY_NAME = "OpenRouter"
# Custom headers required by OpenRouter
DEFAULT_HEADERS = {
"HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
"X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
}
# Model registry for managing configurations and aliases
_registry: Optional[OpenRouterModelRegistry] = None
def __init__(self, api_key: str, **kwargs):
"""Initialize OpenRouter provider.
Args:
api_key: OpenRouter API key
**kwargs: Additional configuration
"""
# Always use OpenRouter's base URL
super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs)
# Initialize model registry
if OpenRouterProvider._registry is None:
OpenRouterProvider._registry = OpenRouterModelRegistry()
# Log loaded models and aliases
models = self._registry.list_models()
aliases = self._registry.list_aliases()
logging.info(
f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases"
)
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
def _parse_allowed_models(self) -> None:
"""Override to disable environment-based allow-list.
OpenRouter model access is controlled via the OpenRouter dashboard,
not through environment variables.
"""
return None
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model aliases to OpenRouter model names.
Args:
model_name: Input model name or alias
Returns:
Resolved OpenRouter model name
"""
# Try to resolve through registry
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
@@ -82,30 +80,30 @@ class OpenRouterProvider(OpenAICompatibleProvider):
# This allows using models not in our config file
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
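# Example resolutions, mirroring the unit tests added in this commit:
#
#     provider = OpenRouterProvider(api_key="test-key")
#     provider._resolve_model_name("opus")             # "anthropic/claude-3-opus"
#     provider._resolve_model_name("OPUS")             # case-insensitive
#     provider._resolve_model_name("custom/model-v2")  # unknown: passes through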
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a model.
Args:
model_name: Name of the model (or alias)
Returns:
ModelCapabilities from registry or generic defaults
"""
# Try to get from registry first
capabilities = self._registry.get_capabilities(model_name)
if capabilities:
return capabilities
else:
# Resolve any potential aliases and create generic capabilities
resolved_name = self._resolve_model_name(model_name)
logging.debug(
f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
"Consider adding to openrouter_models.json for specific capabilities."
)
# Create generic capabilities with conservative defaults
capabilities = ModelCapabilities(
provider=ProviderType.OPENROUTER,
@@ -118,31 +116,31 @@ class OpenRouterProvider(OpenAICompatibleProvider):
supports_function_calling=False,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
)
# Mark as generic for validation purposes
capabilities._is_generic = True
return capabilities
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.OPENROUTER
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is allowed.
For OpenRouter, we accept any model name. OpenRouter will
validate based on the API key's permissions.
Args:
model_name: Model name to validate
Returns:
Always True - OpenRouter handles validation
"""
# Accept any model name - OpenRouter will validate based on API key permissions
return True
def generate_content(
self,
prompt: str,
@@ -153,7 +151,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
**kwargs,
) -> ModelResponse:
"""Generate content using the OpenRouter API.
Args:
prompt: User prompt to send to the model
model_name: Name of the model (or alias) to use
@@ -161,13 +159,13 @@ class OpenRouterProvider(OpenAICompatibleProvider):
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Resolve model alias to actual OpenRouter model name
resolved_model = self._resolve_model_name(model_name)
# Call parent method with resolved model name
return super().generate_content(
prompt=prompt,
@@ -175,19 +173,19 @@ class OpenRouterProvider(OpenAICompatibleProvider):
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs
**kwargs,
)
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode.
Currently, no models via OpenRouter support extended thinking.
This may change as new models become available.
Args:
model_name: Model to check
Returns:
False (no OpenRouter models currently support thinking mode)
"""
return False

View File

@@ -3,9 +3,9 @@
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
@@ -13,9 +13,9 @@ from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
@dataclass
class OpenRouterModelConfig:
"""Configuration for an OpenRouter model."""
model_name: str
aliases: List[str] = field(default_factory=list)
aliases: list[str] = field(default_factory=list)
context_window: int = 32768 # Total context window size in tokens
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
@@ -23,8 +23,7 @@ class OpenRouterModelConfig:
supports_function_calling: bool = False
supports_json_mode: bool = False
description: str = ""
def to_capabilities(self) -> ModelCapabilities:
"""Convert to ModelCapabilities object."""
return ModelCapabilities(
@@ -42,16 +41,16 @@ class OpenRouterModelConfig:
class OpenRouterModelRegistry:
"""Registry for managing OpenRouter model configurations and aliases."""
def __init__(self, config_path: Optional[str] = None):
"""Initialize the registry.
Args:
config_path: Path to config file. If None, uses default locations.
"""
self.alias_map: Dict[str, str] = {} # alias -> model_name
self.model_map: Dict[str, OpenRouterModelConfig] = {} # model_name -> config
self.alias_map: dict[str, str] = {} # alias -> model_name
self.model_map: dict[str, OpenRouterModelConfig] = {} # model_name -> config
# Determine config path
if config_path:
self.config_path = Path(config_path)
@@ -63,86 +62,93 @@ class OpenRouterModelRegistry:
else:
# Default to conf/openrouter_models.json
self.config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"
# Load configuration
self.reload()
def reload(self) -> None:
"""Reload configuration from disk."""
try:
configs = self._read_config()
self._build_maps(configs)
logging.info(f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases")
except ValueError as e:
# Re-raise ValueError only for duplicate aliases (critical config errors)
logging.error(f"Failed to load OpenRouter model configuration: {e}")
# Initialize with empty maps on failure
self.alias_map = {}
self.model_map = {}
if "Duplicate alias" in str(e):
raise
except Exception as e:
logging.error(f"Failed to load OpenRouter model configuration: {e}")
# Initialize with empty maps on failure
self.alias_map = {}
self.model_map = {}
def _read_config(self) -> List[OpenRouterModelConfig]:
def _read_config(self) -> list[OpenRouterModelConfig]:
"""Read configuration from file.
Returns:
List of model configurations
"""
if not self.config_path.exists():
logging.warning(f"OpenRouter model config not found at {self.config_path}")
return []
try:
with open(self.config_path, 'r') as f:
with open(self.config_path) as f:
data = json.load(f)
# Parse models
configs = []
for model_data in data.get("models", []):
# Handle backwards compatibility - rename max_tokens to context_window
if 'max_tokens' in model_data and 'context_window' not in model_data:
model_data['context_window'] = model_data.pop('max_tokens')
if "max_tokens" in model_data and "context_window" not in model_data:
model_data["context_window"] = model_data.pop("max_tokens")
config = OpenRouterModelConfig(**model_data)
configs.append(config)
return configs
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
except Exception as e:
raise ValueError(f"Error reading config from {self.config_path}: {e}")
def _build_maps(self, configs: List[OpenRouterModelConfig]) -> None:
def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
"""Build alias and model maps from configurations.
Args:
configs: List of model configurations
"""
alias_map = {}
model_map = {}
for config in configs:
# Add to model map
model_map[config.model_name] = config
# Add aliases
for alias in config.aliases:
alias_lower = alias.lower()
if alias_lower in alias_map:
existing_model = alias_map[alias_lower]
raise ValueError(
f"Duplicate alias '{alias}' found for models "
f"'{existing_model}' and '{config.model_name}'"
f"Duplicate alias '{alias}' found for models " f"'{existing_model}' and '{config.model_name}'"
)
alias_map[alias_lower] = config.model_name
# Atomic update
self.alias_map = alias_map
self.model_map = model_map
def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
"""Resolve a model name or alias to configuration.
Args:
name_or_alias: Model name or alias to resolve
Returns:
Model configuration if found, None otherwise
"""
@@ -151,16 +157,16 @@ class OpenRouterModelRegistry:
if alias_lower in self.alias_map:
model_name = self.alias_map[alias_lower]
return self.model_map.get(model_name)
# Try as direct model name
return self.model_map.get(name_or_alias)
def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
"""Get model capabilities for a name or alias.
Args:
name_or_alias: Model name or alias
Returns:
ModelCapabilities if found, None otherwise
"""
@@ -168,11 +174,11 @@ class OpenRouterModelRegistry:
if config:
return config.to_capabilities()
return None
def list_models(self) -> List[str]:
def list_models(self) -> list[str]:
"""List all available model names."""
return list(self.model_map.keys())
def list_aliases(self) -> List[str]:
def list_aliases(self) -> list[str]:
"""List all available aliases."""
return list(self.alias_map.keys())

View File

@@ -173,8 +173,7 @@ def configure_providers():
"1. Use only OpenRouter: unset GEMINI_API_KEY and OPENAI_API_KEY\n"
"2. Use only native APIs: unset OPENROUTER_API_KEY\n"
"\n"
"Current configuration will prioritize native APIs over OpenRouter.\n" +
"=" * 70 + "\n"
"Current configuration will prioritize native APIs over OpenRouter.\n" + "=" * 70 + "\n"
)
# Register providers - native APIs first to ensure they take priority
@@ -363,18 +362,22 @@ If something needs clarification or you'd benefit from additional context, simpl
IMPORTANT: When you suggest follow-ups or ask questions, you MUST explicitly instruct Claude to use the continuation_id
to respond. Use clear, direct language based on urgency:
For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd like to explore this further."
For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd "
"like to explore this further."
For needed responses: "Please respond using the continuation_id from this response - your input is needed to proceed."
For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from this response. Cannot proceed without your clarification/input."
For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from "
"this response. Cannot proceed without your clarification/input."
This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional, needed, or essential.
This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional,
needed, or essential.
The tool will automatically provide a continuation_id in the structured response that Claude can use in subsequent
tool calls to maintain full conversation context across multiple exchanges.
Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct Claude to use the continuation_id when you do."""
Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct
Claude to use the continuation_id when you do."""
async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any]:
@@ -411,8 +414,10 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
# Return error asking Claude to restart conversation with full context
raise ValueError(
f"Conversation thread '{continuation_id}' was not found or has expired. "
f"This may happen if the conversation was created more than 1 hour ago or if there was an issue with Redis storage. "
f"Please restart the conversation by providing your full question/prompt without the continuation_id parameter. "
f"This may happen if the conversation was created more than 1 hour ago or if there was an issue "
f"with Redis storage. "
f"Please restart the conversation by providing your full question/prompt without the "
f"continuation_id parameter. "
f"This will create a new conversation thread that can continue with follow-up exchanges."
)
@@ -504,7 +509,8 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
try:
mcp_activity_logger = logging.getLogger("mcp_activity")
mcp_activity_logger.info(
f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - {len(context.turns)} previous turns loaded"
f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - "
f"{len(context.turns)} previous turns loaded"
)
except Exception:
pass
@@ -542,7 +548,7 @@ async def handle_get_version() -> list[TextContent]:
# Check configured providers
from providers import ModelProviderRegistry
from providers.base import ProviderType
configured_providers = []
if ModelProviderRegistry.get_provider(ProviderType.GOOGLE):
configured_providers.append("Gemini (flash, pro)")

View File

@@ -4,35 +4,38 @@ Test OpenRouter model mapping
"""
import sys
sys.path.append('/Users/fahad/Developer/gemini-mcp-server')
sys.path.append("/Users/fahad/Developer/gemini-mcp-server")
from simulator_tests.base_test import BaseSimulatorTest
class MappingTest(BaseSimulatorTest):
def test_mapping(self):
"""Test model alias mapping"""
# Test with 'flash' alias - should map to google/gemini-flash-1.5-8b
print("\nTesting 'flash' alias mapping...")
response, continuation_id = self.call_mcp_tool(
"chat",
{
"prompt": "Say 'Hello from Flash model!'",
"model": "flash", # Should be mapped to google/gemini-flash-1.5-8b
"temperature": 0.1
}
"temperature": 0.1,
},
)
if response:
print(f"✅ Flash alias worked!")
print("✅ Flash alias worked!")
print(f"Response: {response[:200]}...")
return True
else:
print("❌ Flash alias failed")
return False
if __name__ == "__main__":
test = MappingTest(verbose=False)
success = test.test_mapping()
print(f"\nTest result: {'Success' if success else 'Failed'}")
print(f"\nTest result: {'Success' if success else 'Failed'}")

View File

@@ -97,7 +97,8 @@ class TestAutoMode:
# Model field should have simpler description
model_schema = schema["properties"]["model"]
assert "enum" not in model_schema
assert "Available:" in model_schema["description"]
assert "Native models:" in model_schema["description"]
assert "Defaults to" in model_schema["description"]
@pytest.mark.asyncio
async def test_auto_mode_requires_model_parameter(self):
@@ -180,8 +181,9 @@ class TestAutoMode:
schema = tool.get_model_field_schema()
assert "enum" not in schema
assert "Available:" in schema["description"]
assert "Native models:" in schema["description"]
assert "'pro'" in schema["description"]
assert "Defaults to" in schema["description"]
finally:
# Restore

View File

@@ -1,8 +1,7 @@
"""Tests for OpenRouter provider."""
import os
import pytest
from unittest.mock import patch, MagicMock
from unittest.mock import patch
from providers.base import ProviderType
from providers.openrouter import OpenRouterProvider
@@ -11,65 +10,64 @@ from providers.registry import ModelProviderRegistry
class TestOpenRouterProvider:
"""Test cases for OpenRouter provider."""
def test_provider_initialization(self):
"""Test OpenRouter provider initialization."""
provider = OpenRouterProvider(api_key="test-key")
assert provider.api_key == "test-key"
assert provider.base_url == "https://openrouter.ai/api/v1"
assert provider.FRIENDLY_NAME == "OpenRouter"
def test_custom_headers(self):
"""Test OpenRouter custom headers."""
# Test default headers
assert "HTTP-Referer" in OpenRouterProvider.DEFAULT_HEADERS
assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS
# Test with environment variables
with patch.dict(os.environ, {
"OPENROUTER_REFERER": "https://myapp.com",
"OPENROUTER_TITLE": "My App"
}):
with patch.dict(os.environ, {"OPENROUTER_REFERER": "https://myapp.com", "OPENROUTER_TITLE": "My App"}):
from importlib import reload
import providers.openrouter
reload(providers.openrouter)
provider = providers.openrouter.OpenRouterProvider(api_key="test-key")
assert provider.DEFAULT_HEADERS["HTTP-Referer"] == "https://myapp.com"
assert provider.DEFAULT_HEADERS["X-Title"] == "My App"
def test_model_validation(self):
"""Test model validation."""
provider = OpenRouterProvider(api_key="test-key")
# Should accept any model - OpenRouter handles validation
assert provider.validate_model_name("gpt-4") is True
assert provider.validate_model_name("claude-3-opus") is True
assert provider.validate_model_name("any-model-name") is True
assert provider.validate_model_name("GPT-4") is True
assert provider.validate_model_name("unknown-model") is True
def test_get_capabilities(self):
"""Test capability generation."""
provider = OpenRouterProvider(api_key="test-key")
# Test with a model in the registry (using alias)
caps = provider.get_capabilities("gpt4o")
assert caps.provider == ProviderType.OPENROUTER
assert caps.model_name == "openai/gpt-4o" # Resolved name
assert caps.friendly_name == "OpenRouter"
# Test with a model not in registry - should get generic capabilities
caps = provider.get_capabilities("unknown-model")
assert caps.provider == ProviderType.OPENROUTER
assert caps.model_name == "unknown-model"
assert caps.max_tokens == 32_768 # Safe default
assert hasattr(caps, '_is_generic') and caps._is_generic is True
assert hasattr(caps, "_is_generic") and caps._is_generic is True
def test_model_alias_resolution(self):
"""Test model alias resolution."""
provider = OpenRouterProvider(api_key="test-key")
# Test alias resolution
assert provider._resolve_model_name("opus") == "anthropic/claude-3-opus"
assert provider._resolve_model_name("sonnet") == "anthropic/claude-3-sonnet"
@@ -79,30 +77,30 @@ class TestOpenRouterProvider:
assert provider._resolve_model_name("mistral") == "mistral/mistral-large"
assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-coder"
assert provider._resolve_model_name("coder") == "deepseek/deepseek-coder"
# Test case-insensitive
assert provider._resolve_model_name("OPUS") == "anthropic/claude-3-opus"
assert provider._resolve_model_name("GPT4O") == "openai/gpt-4o"
assert provider._resolve_model_name("Mistral") == "mistral/mistral-large"
assert provider._resolve_model_name("CLAUDE") == "anthropic/claude-3-sonnet"
# Test direct model names (should pass through unchanged)
assert provider._resolve_model_name("anthropic/claude-3-opus") == "anthropic/claude-3-opus"
assert provider._resolve_model_name("openai/gpt-4o") == "openai/gpt-4o"
# Test unknown models pass through
assert provider._resolve_model_name("unknown-model") == "unknown-model"
assert provider._resolve_model_name("custom/model-v2") == "custom/model-v2"
def test_openrouter_registration(self):
"""Test OpenRouter can be registered and retrieved."""
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
# Clean up any existing registration
ModelProviderRegistry.unregister_provider(ProviderType.OPENROUTER)
# Register the provider
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
# Retrieve and verify
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
assert provider is not None
@@ -111,53 +109,53 @@ class TestOpenRouterProvider:
class TestOpenRouterRegistry:
"""Test cases for OpenRouter model registry."""
def test_registry_loading(self):
"""Test registry loads models from config."""
from providers.openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
# Should have loaded models
models = registry.list_models()
assert len(models) > 0
assert "anthropic/claude-3-opus" in models
assert "openai/gpt-4o" in models
# Should have loaded aliases
aliases = registry.list_aliases()
assert len(aliases) > 0
assert "opus" in aliases
assert "gpt4o" in aliases
assert "claude" in aliases
def test_registry_capabilities(self):
"""Test registry provides correct capabilities."""
from providers.openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
# Test known model
caps = registry.get_capabilities("opus")
assert caps is not None
assert caps.model_name == "anthropic/claude-3-opus"
assert caps.max_tokens == 200000 # Claude's context window
# Test using full model name
caps = registry.get_capabilities("anthropic/claude-3-opus")
assert caps is not None
assert caps.model_name == "anthropic/claude-3-opus"
# Test unknown model
caps = registry.get_capabilities("non-existent-model")
assert caps is None
def test_multiple_aliases_same_model(self):
"""Test multiple aliases pointing to same model."""
from providers.openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
# All these should resolve to Claude Sonnet
sonnet_aliases = ["sonnet", "claude", "claude-sonnet", "claude3-sonnet"]
for alias in sonnet_aliases:
@@ -166,48 +164,34 @@ class TestOpenRouterRegistry:
assert config.model_name == "anthropic/claude-3-sonnet"
class TestOpenRouterSSRFProtection:
"""Test SSRF protection for OpenRouter."""
def test_url_validation_rejects_private_ips(self):
"""Test that private IPs are rejected."""
class TestOpenRouterFunctionality:
"""Test OpenRouter-specific functionality."""
def test_openrouter_always_uses_correct_url(self):
"""Test that OpenRouter always uses the correct base URL."""
provider = OpenRouterProvider(api_key="test-key")
# List of private/dangerous IPs to test
dangerous_urls = [
"http://192.168.1.1/api/v1",
"http://10.0.0.1/api/v1",
"http://172.16.0.1/api/v1",
"http://169.254.169.254/api/v1", # AWS metadata
"http://[::1]/api/v1", # IPv6 localhost
"http://0.0.0.0/api/v1",
]
for url in dangerous_urls:
with pytest.raises(ValueError, match="restricted IP|Invalid"):
provider.base_url = url
provider._validate_base_url()
def test_url_validation_allows_public_domains(self):
"""Test that legitimate public domains are allowed."""
assert provider.base_url == "https://openrouter.ai/api/v1"
# Reassigning base_url only affects this instance; the constructor
# always sets the OpenRouter URL, so new instances are unaffected
provider.base_url = "http://example.com" # Try to change it
# But new instances should always use the correct URL
provider2 = OpenRouterProvider(api_key="test-key")
assert provider2.base_url == "https://openrouter.ai/api/v1"
def test_openrouter_headers_set_correctly(self):
"""Test that OpenRouter specific headers are set."""
provider = OpenRouterProvider(api_key="test-key")
# OpenRouter's actual domain should always be allowed
provider.base_url = "https://openrouter.ai/api/v1"
provider._validate_base_url() # Should not raise
def test_invalid_url_schemes_rejected(self):
"""Test that non-HTTP(S) schemes are rejected."""
# Check default headers
assert "HTTP-Referer" in provider.DEFAULT_HEADERS
assert "X-Title" in provider.DEFAULT_HEADERS
assert provider.DEFAULT_HEADERS["X-Title"] == "Zen MCP Server"
def test_openrouter_model_registry_initialized(self):
"""Test that model registry is properly initialized."""
provider = OpenRouterProvider(api_key="test-key")
invalid_urls = [
"ftp://example.com/api",
"file:///etc/passwd",
"gopher://example.com",
"javascript:alert(1)",
]
for url in invalid_urls:
with pytest.raises(ValueError, match="Invalid URL scheme"):
provider.base_url = url
provider._validate_base_url()
# Registry should be initialized
assert hasattr(provider, '_registry')
assert provider._registry is not None

View File

@@ -2,42 +2,34 @@
import json
import os
import pytest
import tempfile
from pathlib import Path
from providers.openrouter_registry import OpenRouterModelRegistry, OpenRouterModelConfig
import pytest
from providers.base import ProviderType
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
class TestOpenRouterModelRegistry:
"""Test cases for OpenRouter model registry."""
def test_registry_initialization(self):
"""Test registry initializes with default config."""
registry = OpenRouterModelRegistry()
# Should load models from default location
assert len(registry.list_models()) > 0
assert len(registry.list_aliases()) > 0
def test_custom_config_path(self):
"""Test registry with custom config path."""
# Create temporary config
config_data = {
"models": [
{
"model_name": "test/model-1",
"aliases": ["test1", "t1"],
"context_window": 4096
}
]
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(config_data, f)
temp_path = f.name
try:
registry = OpenRouterModelRegistry(config_path=temp_path)
assert len(registry.list_models()) == 1
@@ -46,48 +38,40 @@ class TestOpenRouterModelRegistry:
assert "t1" in registry.list_aliases()
finally:
os.unlink(temp_path)
def test_environment_variable_override(self):
"""Test OPENROUTER_MODELS_PATH environment variable."""
# Create custom config
config_data = {
"models": [
{
"model_name": "env/model",
"aliases": ["envtest"],
"context_window": 8192
}
]
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(config_data, f)
temp_path = f.name
try:
# Set environment variable
original_env = os.environ.get('OPENROUTER_MODELS_PATH')
os.environ['OPENROUTER_MODELS_PATH'] = temp_path
original_env = os.environ.get("OPENROUTER_MODELS_PATH")
os.environ["OPENROUTER_MODELS_PATH"] = temp_path
# Create registry without explicit path
registry = OpenRouterModelRegistry()
# Should load from environment path
assert "env/model" in registry.list_models()
assert "envtest" in registry.list_aliases()
finally:
# Restore environment
if original_env is not None:
os.environ['OPENROUTER_MODELS_PATH'] = original_env
os.environ["OPENROUTER_MODELS_PATH"] = original_env
else:
del os.environ['OPENROUTER_MODELS_PATH']
del os.environ["OPENROUTER_MODELS_PATH"]
os.unlink(temp_path)
def test_alias_resolution(self):
"""Test alias resolution functionality."""
registry = OpenRouterModelRegistry()
# Test various aliases
test_cases = [
("opus", "anthropic/claude-3-opus"),
@@ -97,75 +81,71 @@ class TestOpenRouterModelRegistry:
("4o", "openai/gpt-4o"),
("mistral", "mistral/mistral-large"),
]
for alias, expected_model in test_cases:
config = registry.resolve(alias)
assert config is not None, f"Failed to resolve alias '{alias}'"
assert config.model_name == expected_model
def test_direct_model_name_lookup(self):
"""Test looking up models by their full name."""
registry = OpenRouterModelRegistry()
# Should be able to look up by full model name
config = registry.resolve("anthropic/claude-3-opus")
assert config is not None
assert config.model_name == "anthropic/claude-3-opus"
config = registry.resolve("openai/gpt-4o")
assert config is not None
assert config.model_name == "openai/gpt-4o"
def test_unknown_model_resolution(self):
"""Test resolution of unknown models."""
registry = OpenRouterModelRegistry()
# Unknown aliases should return None
assert registry.resolve("unknown-alias") is None
assert registry.resolve("") is None
assert registry.resolve("non-existent") is None
def test_model_capabilities_conversion(self):
"""Test conversion to ModelCapabilities."""
registry = OpenRouterModelRegistry()
config = registry.resolve("opus")
assert config is not None
caps = config.to_capabilities()
assert caps.provider == ProviderType.OPENROUTER
assert caps.model_name == "anthropic/claude-3-opus"
assert caps.friendly_name == "OpenRouter"
assert caps.max_tokens == 200000
assert not caps.supports_extended_thinking
def test_duplicate_alias_detection(self):
"""Test that duplicate aliases are detected."""
config_data = {
"models": [
{
"model_name": "test/model-1",
"aliases": ["dupe"],
"context_window": 4096
},
{"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
{
"model_name": "test/model-2",
"aliases": ["DUPE"], # Same alias, different case
"context_window": 8192
}
"context_window": 8192,
},
]
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(config_data, f)
temp_path = f.name
try:
with pytest.raises(ValueError, match="Duplicate alias"):
OpenRouterModelRegistry(config_path=temp_path)
finally:
os.unlink(temp_path)
def test_backwards_compatibility_max_tokens(self):
"""Test backwards compatibility with old max_tokens field."""
config_data = {
@@ -174,44 +154,44 @@ class TestOpenRouterModelRegistry:
"model_name": "test/old-model",
"aliases": ["old"],
"max_tokens": 16384, # Old field name
"supports_extended_thinking": False
"supports_extended_thinking": False,
}
]
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(config_data, f)
temp_path = f.name
try:
registry = OpenRouterModelRegistry(config_path=temp_path)
config = registry.resolve("old")
assert config is not None
assert config.context_window == 16384 # Should be converted
# Check capabilities still work
caps = config.to_capabilities()
assert caps.max_tokens == 16384
finally:
os.unlink(temp_path)
def test_missing_config_file(self):
"""Test behavior with missing config file."""
# Use a non-existent path
registry = OpenRouterModelRegistry(config_path="/non/existent/path.json")
# Should initialize with empty maps
assert len(registry.list_models()) == 0
assert len(registry.list_aliases()) == 0
assert registry.resolve("anything") is None
def test_invalid_json_config(self):
"""Test handling of invalid JSON."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
f.write("{ invalid json }")
temp_path = f.name
try:
registry = OpenRouterModelRegistry(config_path=temp_path)
# Should handle gracefully and initialize empty
@@ -219,7 +199,7 @@ class TestOpenRouterModelRegistry:
assert len(registry.list_aliases()) == 0
finally:
os.unlink(temp_path)
def test_model_with_all_capabilities(self):
"""Test model with all capability flags."""
config = OpenRouterModelConfig(
@@ -231,13 +211,13 @@ class TestOpenRouterModelRegistry:
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
description="Fully featured test model"
description="Fully featured test model",
)
caps = config.to_capabilities()
assert caps.max_tokens == 128000
assert caps.supports_extended_thinking
assert caps.supports_system_prompts
assert caps.supports_streaming
assert caps.supports_function_calling
# Note: supports_json_mode is not in ModelCapabilities yet

View File

@@ -57,15 +57,28 @@ class ToolRequest(BaseModel):
# Higher values allow for more complex reasoning but increase latency and cost
thinking_mode: Optional[Literal["minimal", "low", "medium", "high", "max"]] = Field(
None,
description="Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), max (100% of model max)",
description=(
"Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), "
"max (100% of model max)"
),
)
use_websearch: Optional[bool] = Field(
True,
description="Enable web search for documentation, best practices, and current information. When enabled, the model can request Claude to perform web searches and share results back during conversations. Particularly useful for: brainstorming sessions, architectural design discussions, exploring industry best practices, working with specific frameworks/technologies, researching solutions to complex problems, or when current documentation and community insights would enhance the analysis.",
description=(
"Enable web search for documentation, best practices, and current information. "
"When enabled, the model can request Claude to perform web searches and share results back "
"during conversations. Particularly useful for: brainstorming sessions, architectural design "
"discussions, exploring industry best practices, working with specific frameworks/technologies, "
"researching solutions to complex problems, or when current documentation and community insights "
"would enhance the analysis."
),
)
continuation_id: Optional[str] = Field(
None,
description="Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
description=(
"Thread continuation ID for multi-turn conversations. Can be used to continue conversations "
"across different tools. Only provide this if continuing a previous conversation thread."
),
)
@@ -152,21 +165,48 @@ class BaseTool(ABC):
Returns:
Dict containing the model field JSON schema
"""
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
import os
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
# Check if OpenRouter is configured
has_openrouter = bool(os.getenv("OPENROUTER_API_KEY") and
os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here")
has_openrouter = bool(
os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here"
)
if IS_AUTO_MODE:
# In auto mode, model is required and we provide detailed descriptions
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
for model, desc in MODEL_CAPABILITIES_DESC.items():
model_desc_parts.append(f"- '{model}': {desc}")
if has_openrouter:
model_desc_parts.append("\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter (e.g., 'gpt-4', 'claude-3-opus', 'mistral-large'). Check openrouter.ai/models for available models.")
# Add OpenRouter aliases from the registry
try:
# Import registry directly to show available aliases
# This works even without an API key
from providers.openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
aliases = registry.list_aliases()
# Show ALL aliases from the configuration
if aliases:
# Show all aliases so Claude knows every option available
all_aliases = sorted(aliases)
alias_list = ", ".join(f"'{a}'" for a in all_aliases)
model_desc_parts.append(
f"\nOpenRouter models available via aliases: {alias_list}"
)
else:
model_desc_parts.append(
"\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter."
)
except Exception:
# Fallback if registry fails to load
model_desc_parts.append(
"\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter (e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
)
return {
"type": "string",
@@ -177,12 +217,33 @@ class BaseTool(ABC):
# Normal mode - model is optional with default
available_models = list(MODEL_CAPABILITIES_DESC.keys())
models_str = ", ".join(f"'{m}'" for m in available_models)
description = f"Model to use. Native models: {models_str}."
if has_openrouter:
description += " OpenRouter: Any model available on openrouter.ai (e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
# Add OpenRouter aliases
try:
# Import registry directly to show available aliases
# This works even without an API key
from providers.openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
aliases = registry.list_aliases()
# Show ALL aliases from the configuration
if aliases:
# Show all aliases so Claude knows every option available
all_aliases = sorted(aliases)
alias_list = ", ".join(f"'{a}'" for a in all_aliases)
description += f" OpenRouter aliases: {alias_list}."
else:
description += " OpenRouter: Any model available on openrouter.ai."
except Exception:
description += (
" OpenRouter: Any model available on openrouter.ai "
"(e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
)
description += f" Defaults to '{DEFAULT_MODEL}' if not specified."
return {
"type": "string",
"description": description,