WIP
- OpenRouter model configuration registry - Model definition file for users to be able to control - Additional tests - Update instructions
This commit is contained in:
@@ -56,11 +56,13 @@ MODEL_CAPABILITIES_DESC = {
|
|||||||
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
||||||
# Full model names also supported
|
# Full model names also supported
|
||||||
"gemini-2.5-flash-preview-05-20": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
"gemini-2.5-flash-preview-05-20": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
||||||
"gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
|
"gemini-2.5-pro-preview-06-05": (
|
||||||
|
"Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Note: When only OpenRouter is configured, these model aliases automatically map to equivalent models:
|
# Note: When only OpenRouter is configured, these model aliases automatically map to equivalent models:
|
||||||
# - "flash" → "google/gemini-flash-1.5-8b"
|
# - "flash" → "google/gemini-flash-1.5-8b"
|
||||||
# - "pro" → "google/gemini-pro-1.5"
|
# - "pro" → "google/gemini-pro-1.5"
|
||||||
# - "o3" → "openai/gpt-4o"
|
# - "o3" → "openai/gpt-4o"
|
||||||
# - "o3-mini" → "openai/gpt-4o-mini"
|
# - "o3-mini" → "openai/gpt-4o-mini"
|
||||||
|
|||||||
@@ -141,7 +141,11 @@ trace issues to their root cause, and provide actionable solutions.
|
|||||||
IMPORTANT: If you lack critical information to proceed (e.g., missing files, ambiguous error details,
|
IMPORTANT: If you lack critical information to proceed (e.g., missing files, ambiguous error details,
|
||||||
insufficient context), OR if the provided diagnostics (log files, crash reports, stack traces) appear irrelevant,
|
insufficient context), OR if the provided diagnostics (log files, crash reports, stack traces) appear irrelevant,
|
||||||
incomplete, or insufficient for proper analysis, you MUST respond ONLY with this JSON format:
|
incomplete, or insufficient for proper analysis, you MUST respond ONLY with this JSON format:
|
||||||
{"status": "requires_clarification", "question": "What specific information you need from Claude or the user to proceed with debugging", "files_needed": ["file1.py", "file2.py"]}
|
{
|
||||||
|
"status": "requires_clarification",
|
||||||
|
"question": "What specific information you need from Claude or the user to proceed with debugging",
|
||||||
|
"files_needed": ["file1.py", "file2.py"]
|
||||||
|
}
|
||||||
|
|
||||||
CRITICAL: Your primary objective is to identify the root cause of the specific issue at hand and suggest the
|
CRITICAL: Your primary objective is to identify the root cause of the specific issue at hand and suggest the
|
||||||
minimal fix required to resolve it. Stay focused on the main problem - avoid suggesting extensive refactoring,
|
minimal fix required to resolve it. Stay focused on the main problem - avoid suggesting extensive refactoring,
|
||||||
|
|||||||
@@ -1,12 +1,8 @@
|
|||||||
"""OpenAI model provider implementation."""
|
"""OpenAI model provider implementation."""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
FixedTemperatureConstraint,
|
FixedTemperatureConstraint,
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ModelResponse,
|
|
||||||
ProviderType,
|
ProviderType,
|
||||||
RangeTemperatureConstraint,
|
RangeTemperatureConstraint,
|
||||||
)
|
)
|
||||||
@@ -34,7 +30,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
kwargs.setdefault("base_url", "https://api.openai.com/v1")
|
kwargs.setdefault("base_url", "https://api.openai.com/v1")
|
||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
"""Get capabilities for a specific OpenAI model."""
|
"""Get capabilities for a specific OpenAI model."""
|
||||||
if model_name not in self.SUPPORTED_MODELS:
|
if model_name not in self.SUPPORTED_MODELS:
|
||||||
@@ -62,7 +57,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
temperature_constraint=temp_constraint,
|
temperature_constraint=temp_constraint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
return ProviderType.OPENAI
|
return ProviderType.OPENAI
|
||||||
@@ -76,4 +70,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
# Currently no OpenAI models support extended thinking
|
# Currently no OpenAI models support extended thinking
|
||||||
# This may change with future O3 models
|
# This may change with future O3 models
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
"""Base class for OpenAI-compatible API providers."""
|
"""Base class for OpenAI-compatible API providers."""
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import socket
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
import ipaddress
|
|
||||||
import socket
|
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
@@ -15,25 +15,24 @@ from .base import (
|
|||||||
ModelProvider,
|
ModelProvider,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
RangeTemperatureConstraint,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompatibleProvider(ModelProvider):
|
class OpenAICompatibleProvider(ModelProvider):
|
||||||
"""Base class for any provider using an OpenAI-compatible API.
|
"""Base class for any provider using an OpenAI-compatible API.
|
||||||
|
|
||||||
This includes:
|
This includes:
|
||||||
- Direct OpenAI API
|
- Direct OpenAI API
|
||||||
- OpenRouter
|
- OpenRouter
|
||||||
- Any other OpenAI-compatible endpoint
|
- Any other OpenAI-compatible endpoint
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_HEADERS = {}
|
DEFAULT_HEADERS = {}
|
||||||
FRIENDLY_NAME = "OpenAI Compatible"
|
FRIENDLY_NAME = "OpenAI Compatible"
|
||||||
|
|
||||||
def __init__(self, api_key: str, base_url: str = None, **kwargs):
|
def __init__(self, api_key: str, base_url: str = None, **kwargs):
|
||||||
"""Initialize the provider with API key and optional base URL.
|
"""Initialize the provider with API key and optional base URL.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: API key for authentication
|
api_key: API key for authentication
|
||||||
base_url: Base URL for the API endpoint
|
base_url: Base URL for the API endpoint
|
||||||
@@ -44,21 +43,21 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.organization = kwargs.get("organization")
|
self.organization = kwargs.get("organization")
|
||||||
self.allowed_models = self._parse_allowed_models()
|
self.allowed_models = self._parse_allowed_models()
|
||||||
|
|
||||||
# Validate base URL for security
|
# Validate base URL for security
|
||||||
if self.base_url:
|
if self.base_url:
|
||||||
self._validate_base_url()
|
self._validate_base_url()
|
||||||
|
|
||||||
# Warn if using external URL without authentication
|
# Warn if using external URL without authentication
|
||||||
if self.base_url and not self._is_localhost_url() and not api_key:
|
if self.base_url and not self._is_localhost_url() and not api_key:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Using external URL '{self.base_url}' without API key. "
|
f"Using external URL '{self.base_url}' without API key. "
|
||||||
"This may be insecure. Consider setting an API key for authentication."
|
"This may be insecure. Consider setting an API key for authentication."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_allowed_models(self) -> Optional[set[str]]:
|
def _parse_allowed_models(self) -> Optional[set[str]]:
|
||||||
"""Parse allowed models from environment variable.
|
"""Parse allowed models from environment variable.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Set of allowed model names (lowercase) or None if not configured
|
Set of allowed model names (lowercase) or None if not configured
|
||||||
"""
|
"""
|
||||||
@@ -66,108 +65,108 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
provider_type = self.get_provider_type().value.upper()
|
provider_type = self.get_provider_type().value.upper()
|
||||||
env_var = f"{provider_type}_ALLOWED_MODELS"
|
env_var = f"{provider_type}_ALLOWED_MODELS"
|
||||||
models_str = os.getenv(env_var, "")
|
models_str = os.getenv(env_var, "")
|
||||||
|
|
||||||
if models_str:
|
if models_str:
|
||||||
# Parse and normalize to lowercase for case-insensitive comparison
|
# Parse and normalize to lowercase for case-insensitive comparison
|
||||||
models = set(m.strip().lower() for m in models_str.split(",") if m.strip())
|
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
|
||||||
if models:
|
if models:
|
||||||
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
|
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
|
||||||
return models
|
return models
|
||||||
|
|
||||||
# Log warning if no allow-list configured for proxy providers
|
# Log warning if no allow-list configured for proxy providers
|
||||||
if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
|
if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"No model allow-list configured for {self.FRIENDLY_NAME}. "
|
f"No model allow-list configured for {self.FRIENDLY_NAME}. "
|
||||||
f"Set {env_var} to restrict model access and control costs."
|
f"Set {env_var} to restrict model access and control costs."
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _is_localhost_url(self) -> bool:
|
def _is_localhost_url(self) -> bool:
|
||||||
"""Check if the base URL points to localhost.
|
"""Check if the base URL points to localhost.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if URL is localhost, False otherwise
|
True if URL is localhost, False otherwise
|
||||||
"""
|
"""
|
||||||
if not self.base_url:
|
if not self.base_url:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
parsed = urlparse(self.base_url)
|
parsed = urlparse(self.base_url)
|
||||||
hostname = parsed.hostname
|
hostname = parsed.hostname
|
||||||
|
|
||||||
# Check for common localhost patterns
|
# Check for common localhost patterns
|
||||||
if hostname in ['localhost', '127.0.0.1', '::1']:
|
if hostname in ["localhost", "127.0.0.1", "::1"]:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _validate_base_url(self) -> None:
|
def _validate_base_url(self) -> None:
|
||||||
"""Validate base URL for security (SSRF protection).
|
"""Validate base URL for security (SSRF protection).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If URL is invalid or potentially unsafe
|
ValueError: If URL is invalid or potentially unsafe
|
||||||
"""
|
"""
|
||||||
if not self.base_url:
|
if not self.base_url:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
parsed = urlparse(self.base_url)
|
parsed = urlparse(self.base_url)
|
||||||
|
|
||||||
|
|
||||||
# Check URL scheme - only allow http/https
|
# Check URL scheme - only allow http/https
|
||||||
if parsed.scheme not in ('http', 'https'):
|
if parsed.scheme not in ("http", "https"):
|
||||||
raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
|
raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
|
||||||
|
|
||||||
# Check hostname exists
|
# Check hostname exists
|
||||||
if not parsed.hostname:
|
if not parsed.hostname:
|
||||||
raise ValueError("URL must include a hostname")
|
raise ValueError("URL must include a hostname")
|
||||||
|
|
||||||
# Check port - allow only standard HTTP/HTTPS ports
|
# Check port - allow only standard HTTP/HTTPS ports
|
||||||
port = parsed.port
|
port = parsed.port
|
||||||
if port is None:
|
if port is None:
|
||||||
port = 443 if parsed.scheme == 'https' else 80
|
port = 443 if parsed.scheme == "https" else 80
|
||||||
|
|
||||||
# Allow common HTTP ports and some alternative ports
|
# Allow common HTTP ports and some alternative ports
|
||||||
allowed_ports = {80, 443, 8080, 8443, 4000, 3000} # Common API ports
|
allowed_ports = {80, 443, 8080, 8443, 4000, 3000} # Common API ports
|
||||||
if port not in allowed_ports:
|
if port not in allowed_ports:
|
||||||
raise ValueError(
|
raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")
|
||||||
f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check against allowed domains if configured
|
# Check against allowed domains if configured
|
||||||
allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
|
allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
|
||||||
allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]
|
allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]
|
||||||
|
|
||||||
if allowed_domains:
|
if allowed_domains:
|
||||||
hostname_lower = parsed.hostname.lower()
|
hostname_lower = parsed.hostname.lower()
|
||||||
if not any(
|
if not any(
|
||||||
hostname_lower == domain or
|
hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
|
||||||
hostname_lower.endswith('.' + domain)
|
|
||||||
for domain in allowed_domains
|
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Domain not in allow-list: {parsed.hostname}. "
|
f"Domain not in allow-list: {parsed.hostname}. " f"Allowed domains: {allowed_domains}"
|
||||||
f"Allowed domains: {allowed_domains}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to resolve hostname and check if it's a private IP
|
# Try to resolve hostname and check if it's a private IP
|
||||||
# Skip for localhost addresses which are commonly used for development
|
# Skip for localhost addresses which are commonly used for development
|
||||||
if parsed.hostname not in ['localhost', '127.0.0.1', '::1']:
|
if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
|
||||||
try:
|
try:
|
||||||
# Get all IP addresses for the hostname
|
# Get all IP addresses for the hostname
|
||||||
addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)
|
addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)
|
||||||
|
|
||||||
for family, _, _, _, sockaddr in addr_info:
|
for _family, _, _, _, sockaddr in addr_info:
|
||||||
ip_str = sockaddr[0]
|
ip_str = sockaddr[0]
|
||||||
try:
|
try:
|
||||||
ip = ipaddress.ip_address(ip_str)
|
ip = ipaddress.ip_address(ip_str)
|
||||||
|
|
||||||
# Check for dangerous IP ranges
|
# Check for dangerous IP ranges
|
||||||
if (ip.is_private or ip.is_loopback or ip.is_link_local or
|
if (
|
||||||
ip.is_multicast or ip.is_reserved or ip.is_unspecified):
|
ip.is_private
|
||||||
|
or ip.is_loopback
|
||||||
|
or ip.is_link_local
|
||||||
|
or ip.is_multicast
|
||||||
|
or ip.is_reserved
|
||||||
|
or ip.is_unspecified
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"URL resolves to restricted IP address: {ip_str}. "
|
f"URL resolves to restricted IP address: {ip_str}. "
|
||||||
"This could be a security risk (SSRF)."
|
"This could be a security risk (SSRF)."
|
||||||
@@ -177,16 +176,16 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
if "restricted IP address" in str(ve):
|
if "restricted IP address" in str(ve):
|
||||||
raise
|
raise
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except socket.gaierror as e:
|
except socket.gaierror as e:
|
||||||
# If we can't resolve the hostname, it's suspicious
|
# If we can't resolve the hostname, it's suspicious
|
||||||
raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")
|
raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, ValueError):
|
if isinstance(e, ValueError):
|
||||||
raise
|
raise
|
||||||
raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
|
raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""Lazy initialization of OpenAI client with security checks."""
|
"""Lazy initialization of OpenAI client with security checks."""
|
||||||
@@ -194,21 +193,21 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
client_kwargs = {
|
client_kwargs = {
|
||||||
"api_key": self.api_key,
|
"api_key": self.api_key,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.base_url:
|
if self.base_url:
|
||||||
client_kwargs["base_url"] = self.base_url
|
client_kwargs["base_url"] = self.base_url
|
||||||
|
|
||||||
if self.organization:
|
if self.organization:
|
||||||
client_kwargs["organization"] = self.organization
|
client_kwargs["organization"] = self.organization
|
||||||
|
|
||||||
# Add default headers if any
|
# Add default headers if any
|
||||||
if self.DEFAULT_HEADERS:
|
if self.DEFAULT_HEADERS:
|
||||||
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
|
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
|
||||||
|
|
||||||
self._client = OpenAI(**client_kwargs)
|
self._client = OpenAI(**client_kwargs)
|
||||||
|
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -219,7 +218,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Generate content using the OpenAI-compatible API.
|
"""Generate content using the OpenAI-compatible API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: User prompt to send to the model
|
prompt: User prompt to send to the model
|
||||||
model_name: Name of the model to use
|
model_name: Name of the model to use
|
||||||
@@ -227,50 +226,49 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
temperature: Sampling temperature
|
temperature: Sampling temperature
|
||||||
max_output_tokens: Maximum tokens to generate
|
max_output_tokens: Maximum tokens to generate
|
||||||
**kwargs: Additional provider-specific parameters
|
**kwargs: Additional provider-specific parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelResponse with generated content and metadata
|
ModelResponse with generated content and metadata
|
||||||
"""
|
"""
|
||||||
# Validate model name against allow-list
|
# Validate model name against allow-list
|
||||||
if not self.validate_model_name(model_name):
|
if not self.validate_model_name(model_name):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Model '{model_name}' not in allowed models list. "
|
f"Model '{model_name}' not in allowed models list. " f"Allowed models: {self.allowed_models}"
|
||||||
f"Allowed models: {self.allowed_models}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate parameters
|
# Validate parameters
|
||||||
self.validate_parameters(model_name, temperature)
|
self.validate_parameters(model_name, temperature)
|
||||||
|
|
||||||
# Prepare messages
|
# Prepare messages
|
||||||
messages = []
|
messages = []
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
messages.append({"role": "user", "content": prompt})
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
# Prepare completion parameters
|
# Prepare completion parameters
|
||||||
completion_params = {
|
completion_params = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add max tokens if specified
|
# Add max tokens if specified
|
||||||
if max_output_tokens:
|
if max_output_tokens:
|
||||||
completion_params["max_tokens"] = max_output_tokens
|
completion_params["max_tokens"] = max_output_tokens
|
||||||
|
|
||||||
# Add any additional OpenAI-specific parameters
|
# Add any additional OpenAI-specific parameters
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
|
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
|
||||||
completion_params[key] = value
|
completion_params[key] = value
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate completion
|
# Generate completion
|
||||||
response = self.client.chat.completions.create(**completion_params)
|
response = self.client.chat.completions.create(**completion_params)
|
||||||
|
|
||||||
# Extract content and usage
|
# Extract content and usage
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
usage = self._extract_usage(response)
|
usage = self._extract_usage(response)
|
||||||
|
|
||||||
return ModelResponse(
|
return ModelResponse(
|
||||||
content=content,
|
content=content,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@@ -284,39 +282,39 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
"created": response.created,
|
"created": response.created,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log error and re-raise with more context
|
# Log error and re-raise with more context
|
||||||
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
|
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise RuntimeError(error_msg) from e
|
raise RuntimeError(error_msg) from e
|
||||||
|
|
||||||
def count_tokens(self, text: str, model_name: str) -> int:
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
"""Count tokens for the given text.
|
"""Count tokens for the given text.
|
||||||
|
|
||||||
Uses a layered approach:
|
Uses a layered approach:
|
||||||
1. Try provider-specific token counting endpoint
|
1. Try provider-specific token counting endpoint
|
||||||
2. Try tiktoken for known model families
|
2. Try tiktoken for known model families
|
||||||
3. Fall back to character-based estimation
|
3. Fall back to character-based estimation
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Text to count tokens for
|
text: Text to count tokens for
|
||||||
model_name: Model name for tokenizer selection
|
model_name: Model name for tokenizer selection
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Estimated token count
|
Estimated token count
|
||||||
"""
|
"""
|
||||||
# 1. Check if provider has a remote token counting endpoint
|
# 1. Check if provider has a remote token counting endpoint
|
||||||
if hasattr(self, 'count_tokens_remote'):
|
if hasattr(self, "count_tokens_remote"):
|
||||||
try:
|
try:
|
||||||
return self.count_tokens_remote(text, model_name)
|
return self.count_tokens_remote(text, model_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.debug(f"Remote token counting failed: {e}")
|
logging.debug(f"Remote token counting failed: {e}")
|
||||||
|
|
||||||
# 2. Try tiktoken for known models
|
# 2. Try tiktoken for known models
|
||||||
try:
|
try:
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
# Try to get encoding for the specific model
|
# Try to get encoding for the specific model
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.encoding_for_model(model_name)
|
encoding = tiktoken.encoding_for_model(model_name)
|
||||||
@@ -326,24 +324,24 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
encoding = tiktoken.get_encoding("cl100k_base")
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
else:
|
else:
|
||||||
encoding = tiktoken.get_encoding("cl100k_base") # Default
|
encoding = tiktoken.get_encoding("cl100k_base") # Default
|
||||||
|
|
||||||
return len(encoding.encode(text))
|
return len(encoding.encode(text))
|
||||||
|
|
||||||
except (ImportError, Exception) as e:
|
except (ImportError, Exception) as e:
|
||||||
logging.debug(f"Tiktoken not available or failed: {e}")
|
logging.debug(f"Tiktoken not available or failed: {e}")
|
||||||
|
|
||||||
# 3. Fall back to character-based estimation
|
# 3. Fall back to character-based estimation
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"No specific tokenizer available for '{model_name}'. "
|
f"No specific tokenizer available for '{model_name}'. "
|
||||||
"Using character-based estimation (~4 chars per token)."
|
"Using character-based estimation (~4 chars per token)."
|
||||||
)
|
)
|
||||||
return len(text) // 4
|
return len(text) // 4
|
||||||
|
|
||||||
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
||||||
"""Validate model parameters.
|
"""Validate model parameters.
|
||||||
|
|
||||||
For proxy providers, this may use generic capabilities.
|
For proxy providers, this may use generic capabilities.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Model to validate for
|
model_name: Model to validate for
|
||||||
temperature: Temperature to validate
|
temperature: Temperature to validate
|
||||||
@@ -351,67 +349,66 @@ class OpenAICompatibleProvider(ModelProvider):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
capabilities = self.get_capabilities(model_name)
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
# Check if we're using generic capabilities
|
# Check if we're using generic capabilities
|
||||||
if hasattr(capabilities, '_is_generic'):
|
if hasattr(capabilities, "_is_generic"):
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Using generic parameter validation for {model_name}. "
|
f"Using generic parameter validation for {model_name}. " "Actual model constraints may differ."
|
||||||
"Actual model constraints may differ."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate temperature using parent class method
|
# Validate temperature using parent class method
|
||||||
super().validate_parameters(model_name, temperature, **kwargs)
|
super().validate_parameters(model_name, temperature, **kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# For proxy providers, we might not have accurate capabilities
|
# For proxy providers, we might not have accurate capabilities
|
||||||
# Log warning but don't fail
|
# Log warning but don't fail
|
||||||
logging.warning(f"Parameter validation limited for {model_name}: {e}")
|
logging.warning(f"Parameter validation limited for {model_name}: {e}")
|
||||||
|
|
||||||
def _extract_usage(self, response) -> dict[str, int]:
|
def _extract_usage(self, response) -> dict[str, int]:
|
||||||
"""Extract token usage from OpenAI response.
|
"""Extract token usage from OpenAI response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: OpenAI API response object
|
response: OpenAI API response object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with usage statistics
|
Dictionary with usage statistics
|
||||||
"""
|
"""
|
||||||
usage = {}
|
usage = {}
|
||||||
|
|
||||||
if hasattr(response, "usage") and response.usage:
|
if hasattr(response, "usage") and response.usage:
|
||||||
usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
|
usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
|
||||||
usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
|
usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
|
||||||
usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)
|
usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)
|
||||||
|
|
||||||
return usage
|
return usage
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
"""Get capabilities for a specific model.
|
"""Get capabilities for a specific model.
|
||||||
|
|
||||||
Must be implemented by subclasses.
|
Must be implemented by subclasses.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type.
|
"""Get the provider type.
|
||||||
|
|
||||||
Must be implemented by subclasses.
|
Must be implemented by subclasses.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
"""Validate if the model name is supported.
|
"""Validate if the model name is supported.
|
||||||
|
|
||||||
Must be implemented by subclasses.
|
Must be implemented by subclasses.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
"""Check if the model supports extended thinking mode.
|
"""Check if the model supports extended thinking mode.
|
||||||
|
|
||||||
Default is False for OpenAI-compatible providers.
|
Default is False for OpenAI-compatible providers.
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -16,63 +16,61 @@ from .openrouter_registry import OpenRouterModelRegistry
|
|||||||
|
|
||||||
class OpenRouterProvider(OpenAICompatibleProvider):
|
class OpenRouterProvider(OpenAICompatibleProvider):
|
||||||
"""OpenRouter unified API provider.
|
"""OpenRouter unified API provider.
|
||||||
|
|
||||||
OpenRouter provides access to multiple AI models through a single API endpoint.
|
OpenRouter provides access to multiple AI models through a single API endpoint.
|
||||||
See https://openrouter.ai for available models and pricing.
|
See https://openrouter.ai for available models and pricing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FRIENDLY_NAME = "OpenRouter"
|
FRIENDLY_NAME = "OpenRouter"
|
||||||
|
|
||||||
# Custom headers required by OpenRouter
|
# Custom headers required by OpenRouter
|
||||||
DEFAULT_HEADERS = {
|
DEFAULT_HEADERS = {
|
||||||
"HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
|
"HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
|
||||||
"X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
|
"X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Model registry for managing configurations and aliases
|
# Model registry for managing configurations and aliases
|
||||||
_registry: Optional[OpenRouterModelRegistry] = None
|
_registry: Optional[OpenRouterModelRegistry] = None
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize OpenRouter provider.
|
"""Initialize OpenRouter provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: OpenRouter API key
|
api_key: OpenRouter API key
|
||||||
**kwargs: Additional configuration
|
**kwargs: Additional configuration
|
||||||
"""
|
"""
|
||||||
# Always use OpenRouter's base URL
|
# Always use OpenRouter's base URL
|
||||||
super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs)
|
super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs)
|
||||||
|
|
||||||
# Initialize model registry
|
# Initialize model registry
|
||||||
if OpenRouterProvider._registry is None:
|
if OpenRouterProvider._registry is None:
|
||||||
OpenRouterProvider._registry = OpenRouterModelRegistry()
|
OpenRouterProvider._registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# Log loaded models and aliases
|
# Log loaded models and aliases
|
||||||
models = self._registry.list_models()
|
models = self._registry.list_models()
|
||||||
aliases = self._registry.list_aliases()
|
aliases = self._registry.list_aliases()
|
||||||
logging.info(
|
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
|
||||||
f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_allowed_models(self) -> None:
|
def _parse_allowed_models(self) -> None:
|
||||||
"""Override to disable environment-based allow-list.
|
"""Override to disable environment-based allow-list.
|
||||||
|
|
||||||
OpenRouter model access is controlled via the OpenRouter dashboard,
|
OpenRouter model access is controlled via the OpenRouter dashboard,
|
||||||
not through environment variables.
|
not through environment variables.
|
||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
"""Resolve model aliases to OpenRouter model names.
|
"""Resolve model aliases to OpenRouter model names.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Input model name or alias
|
model_name: Input model name or alias
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Resolved OpenRouter model name
|
Resolved OpenRouter model name
|
||||||
"""
|
"""
|
||||||
# Try to resolve through registry
|
# Try to resolve through registry
|
||||||
config = self._registry.resolve(model_name)
|
config = self._registry.resolve(model_name)
|
||||||
|
|
||||||
if config:
|
if config:
|
||||||
if config.model_name != model_name:
|
if config.model_name != model_name:
|
||||||
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||||
@@ -82,30 +80,30 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
# This allows using models not in our config file
|
# This allows using models not in our config file
|
||||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
"""Get capabilities for a model.
|
"""Get capabilities for a model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Name of the model (or alias)
|
model_name: Name of the model (or alias)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelCapabilities from registry or generic defaults
|
ModelCapabilities from registry or generic defaults
|
||||||
"""
|
"""
|
||||||
# Try to get from registry first
|
# Try to get from registry first
|
||||||
capabilities = self._registry.get_capabilities(model_name)
|
capabilities = self._registry.get_capabilities(model_name)
|
||||||
|
|
||||||
if capabilities:
|
if capabilities:
|
||||||
return capabilities
|
return capabilities
|
||||||
else:
|
else:
|
||||||
# Resolve any potential aliases and create generic capabilities
|
# Resolve any potential aliases and create generic capabilities
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
|
f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
|
||||||
"Consider adding to openrouter_models.json for specific capabilities."
|
"Consider adding to openrouter_models.json for specific capabilities."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create generic capabilities with conservative defaults
|
# Create generic capabilities with conservative defaults
|
||||||
capabilities = ModelCapabilities(
|
capabilities = ModelCapabilities(
|
||||||
provider=ProviderType.OPENROUTER,
|
provider=ProviderType.OPENROUTER,
|
||||||
@@ -118,31 +116,31 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
supports_function_calling=False,
|
supports_function_calling=False,
|
||||||
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mark as generic for validation purposes
|
# Mark as generic for validation purposes
|
||||||
capabilities._is_generic = True
|
capabilities._is_generic = True
|
||||||
|
|
||||||
return capabilities
|
return capabilities
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
return ProviderType.OPENROUTER
|
return ProviderType.OPENROUTER
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
"""Validate if the model name is allowed.
|
"""Validate if the model name is allowed.
|
||||||
|
|
||||||
For OpenRouter, we accept any model name. OpenRouter will
|
For OpenRouter, we accept any model name. OpenRouter will
|
||||||
validate based on the API key's permissions.
|
validate based on the API key's permissions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Model name to validate
|
model_name: Model name to validate
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Always True - OpenRouter handles validation
|
Always True - OpenRouter handles validation
|
||||||
"""
|
"""
|
||||||
# Accept any model name - OpenRouter will validate based on API key permissions
|
# Accept any model name - OpenRouter will validate based on API key permissions
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -153,7 +151,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Generate content using the OpenRouter API.
|
"""Generate content using the OpenRouter API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: User prompt to send to the model
|
prompt: User prompt to send to the model
|
||||||
model_name: Name of the model (or alias) to use
|
model_name: Name of the model (or alias) to use
|
||||||
@@ -161,13 +159,13 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
temperature: Sampling temperature
|
temperature: Sampling temperature
|
||||||
max_output_tokens: Maximum tokens to generate
|
max_output_tokens: Maximum tokens to generate
|
||||||
**kwargs: Additional provider-specific parameters
|
**kwargs: Additional provider-specific parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelResponse with generated content and metadata
|
ModelResponse with generated content and metadata
|
||||||
"""
|
"""
|
||||||
# Resolve model alias to actual OpenRouter model name
|
# Resolve model alias to actual OpenRouter model name
|
||||||
resolved_model = self._resolve_model_name(model_name)
|
resolved_model = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
# Call parent method with resolved model name
|
# Call parent method with resolved model name
|
||||||
return super().generate_content(
|
return super().generate_content(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@@ -175,19 +173,19 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_output_tokens=max_output_tokens,
|
max_output_tokens=max_output_tokens,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
"""Check if the model supports extended thinking mode.
|
"""Check if the model supports extended thinking mode.
|
||||||
|
|
||||||
Currently, no models via OpenRouter support extended thinking.
|
Currently, no models via OpenRouter support extended thinking.
|
||||||
This may change as new models become available.
|
This may change as new models become available.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Model to check
|
model_name: Model to check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
False (no OpenRouter models currently support thinking mode)
|
False (no OpenRouter models currently support thinking mode)
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -3,9 +3,9 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
|
from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
|
||||||
|
|
||||||
@@ -13,9 +13,9 @@ from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
|
|||||||
@dataclass
|
@dataclass
|
||||||
class OpenRouterModelConfig:
|
class OpenRouterModelConfig:
|
||||||
"""Configuration for an OpenRouter model."""
|
"""Configuration for an OpenRouter model."""
|
||||||
|
|
||||||
model_name: str
|
model_name: str
|
||||||
aliases: List[str] = field(default_factory=list)
|
aliases: list[str] = field(default_factory=list)
|
||||||
context_window: int = 32768 # Total context window size in tokens
|
context_window: int = 32768 # Total context window size in tokens
|
||||||
supports_extended_thinking: bool = False
|
supports_extended_thinking: bool = False
|
||||||
supports_system_prompts: bool = True
|
supports_system_prompts: bool = True
|
||||||
@@ -23,8 +23,7 @@ class OpenRouterModelConfig:
|
|||||||
supports_function_calling: bool = False
|
supports_function_calling: bool = False
|
||||||
supports_json_mode: bool = False
|
supports_json_mode: bool = False
|
||||||
description: str = ""
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
def to_capabilities(self) -> ModelCapabilities:
|
def to_capabilities(self) -> ModelCapabilities:
|
||||||
"""Convert to ModelCapabilities object."""
|
"""Convert to ModelCapabilities object."""
|
||||||
return ModelCapabilities(
|
return ModelCapabilities(
|
||||||
@@ -42,16 +41,16 @@ class OpenRouterModelConfig:
|
|||||||
|
|
||||||
class OpenRouterModelRegistry:
|
class OpenRouterModelRegistry:
|
||||||
"""Registry for managing OpenRouter model configurations and aliases."""
|
"""Registry for managing OpenRouter model configurations and aliases."""
|
||||||
|
|
||||||
def __init__(self, config_path: Optional[str] = None):
|
def __init__(self, config_path: Optional[str] = None):
|
||||||
"""Initialize the registry.
|
"""Initialize the registry.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_path: Path to config file. If None, uses default locations.
|
config_path: Path to config file. If None, uses default locations.
|
||||||
"""
|
"""
|
||||||
self.alias_map: Dict[str, str] = {} # alias -> model_name
|
self.alias_map: dict[str, str] = {} # alias -> model_name
|
||||||
self.model_map: Dict[str, OpenRouterModelConfig] = {} # model_name -> config
|
self.model_map: dict[str, OpenRouterModelConfig] = {} # model_name -> config
|
||||||
|
|
||||||
# Determine config path
|
# Determine config path
|
||||||
if config_path:
|
if config_path:
|
||||||
self.config_path = Path(config_path)
|
self.config_path = Path(config_path)
|
||||||
@@ -63,86 +62,93 @@ class OpenRouterModelRegistry:
|
|||||||
else:
|
else:
|
||||||
# Default to conf/openrouter_models.json
|
# Default to conf/openrouter_models.json
|
||||||
self.config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"
|
self.config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"
|
||||||
|
|
||||||
# Load configuration
|
# Load configuration
|
||||||
self.reload()
|
self.reload()
|
||||||
|
|
||||||
def reload(self) -> None:
|
def reload(self) -> None:
|
||||||
"""Reload configuration from disk."""
|
"""Reload configuration from disk."""
|
||||||
try:
|
try:
|
||||||
configs = self._read_config()
|
configs = self._read_config()
|
||||||
self._build_maps(configs)
|
self._build_maps(configs)
|
||||||
logging.info(f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases")
|
logging.info(f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases")
|
||||||
|
except ValueError as e:
|
||||||
|
# Re-raise ValueError only for duplicate aliases (critical config errors)
|
||||||
|
logging.error(f"Failed to load OpenRouter model configuration: {e}")
|
||||||
|
# Initialize with empty maps on failure
|
||||||
|
self.alias_map = {}
|
||||||
|
self.model_map = {}
|
||||||
|
if "Duplicate alias" in str(e):
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to load OpenRouter model configuration: {e}")
|
logging.error(f"Failed to load OpenRouter model configuration: {e}")
|
||||||
# Initialize with empty maps on failure
|
# Initialize with empty maps on failure
|
||||||
self.alias_map = {}
|
self.alias_map = {}
|
||||||
self.model_map = {}
|
self.model_map = {}
|
||||||
|
|
||||||
def _read_config(self) -> List[OpenRouterModelConfig]:
|
def _read_config(self) -> list[OpenRouterModelConfig]:
|
||||||
"""Read configuration from file.
|
"""Read configuration from file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of model configurations
|
List of model configurations
|
||||||
"""
|
"""
|
||||||
if not self.config_path.exists():
|
if not self.config_path.exists():
|
||||||
logging.warning(f"OpenRouter model config not found at {self.config_path}")
|
logging.warning(f"OpenRouter model config not found at {self.config_path}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(self.config_path, 'r') as f:
|
with open(self.config_path) as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
# Parse models
|
# Parse models
|
||||||
configs = []
|
configs = []
|
||||||
for model_data in data.get("models", []):
|
for model_data in data.get("models", []):
|
||||||
# Handle backwards compatibility - rename max_tokens to context_window
|
# Handle backwards compatibility - rename max_tokens to context_window
|
||||||
if 'max_tokens' in model_data and 'context_window' not in model_data:
|
if "max_tokens" in model_data and "context_window" not in model_data:
|
||||||
model_data['context_window'] = model_data.pop('max_tokens')
|
model_data["context_window"] = model_data.pop("max_tokens")
|
||||||
|
|
||||||
config = OpenRouterModelConfig(**model_data)
|
config = OpenRouterModelConfig(**model_data)
|
||||||
configs.append(config)
|
configs.append(config)
|
||||||
|
|
||||||
return configs
|
return configs
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
|
raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error reading config from {self.config_path}: {e}")
|
raise ValueError(f"Error reading config from {self.config_path}: {e}")
|
||||||
|
|
||||||
def _build_maps(self, configs: List[OpenRouterModelConfig]) -> None:
|
def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
|
||||||
"""Build alias and model maps from configurations.
|
"""Build alias and model maps from configurations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
configs: List of model configurations
|
configs: List of model configurations
|
||||||
"""
|
"""
|
||||||
alias_map = {}
|
alias_map = {}
|
||||||
model_map = {}
|
model_map = {}
|
||||||
|
|
||||||
for config in configs:
|
for config in configs:
|
||||||
# Add to model map
|
# Add to model map
|
||||||
model_map[config.model_name] = config
|
model_map[config.model_name] = config
|
||||||
|
|
||||||
# Add aliases
|
# Add aliases
|
||||||
for alias in config.aliases:
|
for alias in config.aliases:
|
||||||
alias_lower = alias.lower()
|
alias_lower = alias.lower()
|
||||||
if alias_lower in alias_map:
|
if alias_lower in alias_map:
|
||||||
existing_model = alias_map[alias_lower]
|
existing_model = alias_map[alias_lower]
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Duplicate alias '{alias}' found for models "
|
f"Duplicate alias '{alias}' found for models " f"'{existing_model}' and '{config.model_name}'"
|
||||||
f"'{existing_model}' and '{config.model_name}'"
|
|
||||||
)
|
)
|
||||||
alias_map[alias_lower] = config.model_name
|
alias_map[alias_lower] = config.model_name
|
||||||
|
|
||||||
# Atomic update
|
# Atomic update
|
||||||
self.alias_map = alias_map
|
self.alias_map = alias_map
|
||||||
self.model_map = model_map
|
self.model_map = model_map
|
||||||
|
|
||||||
def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
|
def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
|
||||||
"""Resolve a model name or alias to configuration.
|
"""Resolve a model name or alias to configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name_or_alias: Model name or alias to resolve
|
name_or_alias: Model name or alias to resolve
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model configuration if found, None otherwise
|
Model configuration if found, None otherwise
|
||||||
"""
|
"""
|
||||||
@@ -151,16 +157,16 @@ class OpenRouterModelRegistry:
|
|||||||
if alias_lower in self.alias_map:
|
if alias_lower in self.alias_map:
|
||||||
model_name = self.alias_map[alias_lower]
|
model_name = self.alias_map[alias_lower]
|
||||||
return self.model_map.get(model_name)
|
return self.model_map.get(model_name)
|
||||||
|
|
||||||
# Try as direct model name
|
# Try as direct model name
|
||||||
return self.model_map.get(name_or_alias)
|
return self.model_map.get(name_or_alias)
|
||||||
|
|
||||||
def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||||
"""Get model capabilities for a name or alias.
|
"""Get model capabilities for a name or alias.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name_or_alias: Model name or alias
|
name_or_alias: Model name or alias
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelCapabilities if found, None otherwise
|
ModelCapabilities if found, None otherwise
|
||||||
"""
|
"""
|
||||||
@@ -168,11 +174,11 @@ class OpenRouterModelRegistry:
|
|||||||
if config:
|
if config:
|
||||||
return config.to_capabilities()
|
return config.to_capabilities()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def list_models(self) -> List[str]:
|
def list_models(self) -> list[str]:
|
||||||
"""List all available model names."""
|
"""List all available model names."""
|
||||||
return list(self.model_map.keys())
|
return list(self.model_map.keys())
|
||||||
|
|
||||||
def list_aliases(self) -> List[str]:
|
def list_aliases(self) -> list[str]:
|
||||||
"""List all available aliases."""
|
"""List all available aliases."""
|
||||||
return list(self.alias_map.keys())
|
return list(self.alias_map.keys())
|
||||||
|
|||||||
26
server.py
26
server.py
@@ -173,8 +173,7 @@ def configure_providers():
|
|||||||
"1. Use only OpenRouter: unset GEMINI_API_KEY and OPENAI_API_KEY\n"
|
"1. Use only OpenRouter: unset GEMINI_API_KEY and OPENAI_API_KEY\n"
|
||||||
"2. Use only native APIs: unset OPENROUTER_API_KEY\n"
|
"2. Use only native APIs: unset OPENROUTER_API_KEY\n"
|
||||||
"\n"
|
"\n"
|
||||||
"Current configuration will prioritize native APIs over OpenRouter.\n" +
|
"Current configuration will prioritize native APIs over OpenRouter.\n" + "=" * 70 + "\n"
|
||||||
"=" * 70 + "\n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register providers - native APIs first to ensure they take priority
|
# Register providers - native APIs first to ensure they take priority
|
||||||
@@ -363,18 +362,22 @@ If something needs clarification or you'd benefit from additional context, simpl
|
|||||||
IMPORTANT: When you suggest follow-ups or ask questions, you MUST explicitly instruct Claude to use the continuation_id
|
IMPORTANT: When you suggest follow-ups or ask questions, you MUST explicitly instruct Claude to use the continuation_id
|
||||||
to respond. Use clear, direct language based on urgency:
|
to respond. Use clear, direct language based on urgency:
|
||||||
|
|
||||||
For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd like to explore this further."
|
For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd "
|
||||||
|
"like to explore this further."
|
||||||
|
|
||||||
For needed responses: "Please respond using the continuation_id from this response - your input is needed to proceed."
|
For needed responses: "Please respond using the continuation_id from this response - your input is needed to proceed."
|
||||||
|
|
||||||
For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from this response. Cannot proceed without your clarification/input."
|
For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from "
|
||||||
|
"this response. Cannot proceed without your clarification/input."
|
||||||
|
|
||||||
This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional, needed, or essential.
|
This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional, "
|
||||||
|
"needed, or essential.
|
||||||
|
|
||||||
The tool will automatically provide a continuation_id in the structured response that Claude can use in subsequent
|
The tool will automatically provide a continuation_id in the structured response that Claude can use in subsequent
|
||||||
tool calls to maintain full conversation context across multiple exchanges.
|
tool calls to maintain full conversation context across multiple exchanges.
|
||||||
|
|
||||||
Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct Claude to use the continuation_id when you do."""
|
Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct "
|
||||||
|
"Claude to use the continuation_id when you do."""
|
||||||
|
|
||||||
|
|
||||||
async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any]:
|
async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any]:
|
||||||
@@ -411,8 +414,10 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
|||||||
# Return error asking Claude to restart conversation with full context
|
# Return error asking Claude to restart conversation with full context
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Conversation thread '{continuation_id}' was not found or has expired. "
|
f"Conversation thread '{continuation_id}' was not found or has expired. "
|
||||||
f"This may happen if the conversation was created more than 1 hour ago or if there was an issue with Redis storage. "
|
f"This may happen if the conversation was created more than 1 hour ago or if there was an issue "
|
||||||
f"Please restart the conversation by providing your full question/prompt without the continuation_id parameter. "
|
f"with Redis storage. "
|
||||||
|
f"Please restart the conversation by providing your full question/prompt without the "
|
||||||
|
f"continuation_id parameter. "
|
||||||
f"This will create a new conversation thread that can continue with follow-up exchanges."
|
f"This will create a new conversation thread that can continue with follow-up exchanges."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -504,7 +509,8 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
|||||||
try:
|
try:
|
||||||
mcp_activity_logger = logging.getLogger("mcp_activity")
|
mcp_activity_logger = logging.getLogger("mcp_activity")
|
||||||
mcp_activity_logger.info(
|
mcp_activity_logger.info(
|
||||||
f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - {len(context.turns)} previous turns loaded"
|
f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - "
|
||||||
|
f"{len(context.turns)} previous turns loaded"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -542,7 +548,7 @@ async def handle_get_version() -> list[TextContent]:
|
|||||||
# Check configured providers
|
# Check configured providers
|
||||||
from providers import ModelProviderRegistry
|
from providers import ModelProviderRegistry
|
||||||
from providers.base import ProviderType
|
from providers.base import ProviderType
|
||||||
|
|
||||||
configured_providers = []
|
configured_providers = []
|
||||||
if ModelProviderRegistry.get_provider(ProviderType.GOOGLE):
|
if ModelProviderRegistry.get_provider(ProviderType.GOOGLE):
|
||||||
configured_providers.append("Gemini (flash, pro)")
|
configured_providers.append("Gemini (flash, pro)")
|
||||||
|
|||||||
@@ -4,35 +4,38 @@ Test OpenRouter model mapping
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.path.append('/Users/fahad/Developer/gemini-mcp-server')
|
|
||||||
|
sys.path.append("/Users/fahad/Developer/gemini-mcp-server")
|
||||||
|
|
||||||
from simulator_tests.base_test import BaseSimulatorTest
|
from simulator_tests.base_test import BaseSimulatorTest
|
||||||
|
|
||||||
|
|
||||||
class MappingTest(BaseSimulatorTest):
|
class MappingTest(BaseSimulatorTest):
|
||||||
def test_mapping(self):
|
def test_mapping(self):
|
||||||
"""Test model alias mapping"""
|
"""Test model alias mapping"""
|
||||||
|
|
||||||
# Test with 'flash' alias - should map to google/gemini-flash-1.5-8b
|
# Test with 'flash' alias - should map to google/gemini-flash-1.5-8b
|
||||||
print("\nTesting 'flash' alias mapping...")
|
print("\nTesting 'flash' alias mapping...")
|
||||||
|
|
||||||
response, continuation_id = self.call_mcp_tool(
|
response, continuation_id = self.call_mcp_tool(
|
||||||
"chat",
|
"chat",
|
||||||
{
|
{
|
||||||
"prompt": "Say 'Hello from Flash model!'",
|
"prompt": "Say 'Hello from Flash model!'",
|
||||||
"model": "flash", # Should be mapped to google/gemini-flash-1.5-8b
|
"model": "flash", # Should be mapped to google/gemini-flash-1.5-8b
|
||||||
"temperature": 0.1
|
"temperature": 0.1,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
print(f"✅ Flash alias worked!")
|
print("✅ Flash alias worked!")
|
||||||
print(f"Response: {response[:200]}...")
|
print(f"Response: {response[:200]}...")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
print("❌ Flash alias failed")
|
print("❌ Flash alias failed")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test = MappingTest(verbose=False)
|
test = MappingTest(verbose=False)
|
||||||
success = test.test_mapping()
|
success = test.test_mapping()
|
||||||
print(f"\nTest result: {'Success' if success else 'Failed'}")
|
print(f"\nTest result: {'Success' if success else 'Failed'}")
|
||||||
|
|||||||
@@ -97,7 +97,8 @@ class TestAutoMode:
|
|||||||
# Model field should have simpler description
|
# Model field should have simpler description
|
||||||
model_schema = schema["properties"]["model"]
|
model_schema = schema["properties"]["model"]
|
||||||
assert "enum" not in model_schema
|
assert "enum" not in model_schema
|
||||||
assert "Available:" in model_schema["description"]
|
assert "Native models:" in model_schema["description"]
|
||||||
|
assert "Defaults to" in model_schema["description"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_auto_mode_requires_model_parameter(self):
|
async def test_auto_mode_requires_model_parameter(self):
|
||||||
@@ -180,8 +181,9 @@ class TestAutoMode:
|
|||||||
|
|
||||||
schema = tool.get_model_field_schema()
|
schema = tool.get_model_field_schema()
|
||||||
assert "enum" not in schema
|
assert "enum" not in schema
|
||||||
assert "Available:" in schema["description"]
|
assert "Native models:" in schema["description"]
|
||||||
assert "'pro'" in schema["description"]
|
assert "'pro'" in schema["description"]
|
||||||
|
assert "Defaults to" in schema["description"]
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore
|
# Restore
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Tests for OpenRouter provider."""
|
"""Tests for OpenRouter provider."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import pytest
|
from unittest.mock import patch
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
from providers.base import ProviderType
|
from providers.base import ProviderType
|
||||||
from providers.openrouter import OpenRouterProvider
|
from providers.openrouter import OpenRouterProvider
|
||||||
@@ -11,65 +10,64 @@ from providers.registry import ModelProviderRegistry
|
|||||||
|
|
||||||
class TestOpenRouterProvider:
|
class TestOpenRouterProvider:
|
||||||
"""Test cases for OpenRouter provider."""
|
"""Test cases for OpenRouter provider."""
|
||||||
|
|
||||||
def test_provider_initialization(self):
|
def test_provider_initialization(self):
|
||||||
"""Test OpenRouter provider initialization."""
|
"""Test OpenRouter provider initialization."""
|
||||||
provider = OpenRouterProvider(api_key="test-key")
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
assert provider.api_key == "test-key"
|
assert provider.api_key == "test-key"
|
||||||
assert provider.base_url == "https://openrouter.ai/api/v1"
|
assert provider.base_url == "https://openrouter.ai/api/v1"
|
||||||
assert provider.FRIENDLY_NAME == "OpenRouter"
|
assert provider.FRIENDLY_NAME == "OpenRouter"
|
||||||
|
|
||||||
def test_custom_headers(self):
|
def test_custom_headers(self):
|
||||||
"""Test OpenRouter custom headers."""
|
"""Test OpenRouter custom headers."""
|
||||||
# Test default headers
|
# Test default headers
|
||||||
assert "HTTP-Referer" in OpenRouterProvider.DEFAULT_HEADERS
|
assert "HTTP-Referer" in OpenRouterProvider.DEFAULT_HEADERS
|
||||||
assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS
|
assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS
|
||||||
|
|
||||||
# Test with environment variables
|
# Test with environment variables
|
||||||
with patch.dict(os.environ, {
|
with patch.dict(os.environ, {"OPENROUTER_REFERER": "https://myapp.com", "OPENROUTER_TITLE": "My App"}):
|
||||||
"OPENROUTER_REFERER": "https://myapp.com",
|
|
||||||
"OPENROUTER_TITLE": "My App"
|
|
||||||
}):
|
|
||||||
from importlib import reload
|
from importlib import reload
|
||||||
|
|
||||||
import providers.openrouter
|
import providers.openrouter
|
||||||
|
|
||||||
reload(providers.openrouter)
|
reload(providers.openrouter)
|
||||||
|
|
||||||
provider = providers.openrouter.OpenRouterProvider(api_key="test-key")
|
provider = providers.openrouter.OpenRouterProvider(api_key="test-key")
|
||||||
assert provider.DEFAULT_HEADERS["HTTP-Referer"] == "https://myapp.com"
|
assert provider.DEFAULT_HEADERS["HTTP-Referer"] == "https://myapp.com"
|
||||||
assert provider.DEFAULT_HEADERS["X-Title"] == "My App"
|
assert provider.DEFAULT_HEADERS["X-Title"] == "My App"
|
||||||
|
|
||||||
def test_model_validation(self):
|
def test_model_validation(self):
|
||||||
"""Test model validation."""
|
"""Test model validation."""
|
||||||
provider = OpenRouterProvider(api_key="test-key")
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
|
|
||||||
# Should accept any model - OpenRouter handles validation
|
# Should accept any model - OpenRouter handles validation
|
||||||
assert provider.validate_model_name("gpt-4") is True
|
assert provider.validate_model_name("gpt-4") is True
|
||||||
assert provider.validate_model_name("claude-3-opus") is True
|
assert provider.validate_model_name("claude-3-opus") is True
|
||||||
assert provider.validate_model_name("any-model-name") is True
|
assert provider.validate_model_name("any-model-name") is True
|
||||||
assert provider.validate_model_name("GPT-4") is True
|
assert provider.validate_model_name("GPT-4") is True
|
||||||
assert provider.validate_model_name("unknown-model") is True
|
assert provider.validate_model_name("unknown-model") is True
|
||||||
|
|
||||||
def test_get_capabilities(self):
|
def test_get_capabilities(self):
|
||||||
"""Test capability generation."""
|
"""Test capability generation."""
|
||||||
provider = OpenRouterProvider(api_key="test-key")
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
|
|
||||||
# Test with a model in the registry (using alias)
|
# Test with a model in the registry (using alias)
|
||||||
caps = provider.get_capabilities("gpt4o")
|
caps = provider.get_capabilities("gpt4o")
|
||||||
assert caps.provider == ProviderType.OPENROUTER
|
assert caps.provider == ProviderType.OPENROUTER
|
||||||
assert caps.model_name == "openai/gpt-4o" # Resolved name
|
assert caps.model_name == "openai/gpt-4o" # Resolved name
|
||||||
assert caps.friendly_name == "OpenRouter"
|
assert caps.friendly_name == "OpenRouter"
|
||||||
|
|
||||||
# Test with a model not in registry - should get generic capabilities
|
# Test with a model not in registry - should get generic capabilities
|
||||||
caps = provider.get_capabilities("unknown-model")
|
caps = provider.get_capabilities("unknown-model")
|
||||||
assert caps.provider == ProviderType.OPENROUTER
|
assert caps.provider == ProviderType.OPENROUTER
|
||||||
assert caps.model_name == "unknown-model"
|
assert caps.model_name == "unknown-model"
|
||||||
assert caps.max_tokens == 32_768 # Safe default
|
assert caps.max_tokens == 32_768 # Safe default
|
||||||
assert hasattr(caps, '_is_generic') and caps._is_generic is True
|
assert hasattr(caps, "_is_generic") and caps._is_generic is True
|
||||||
|
|
||||||
def test_model_alias_resolution(self):
|
def test_model_alias_resolution(self):
|
||||||
"""Test model alias resolution."""
|
"""Test model alias resolution."""
|
||||||
provider = OpenRouterProvider(api_key="test-key")
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
|
|
||||||
# Test alias resolution
|
# Test alias resolution
|
||||||
assert provider._resolve_model_name("opus") == "anthropic/claude-3-opus"
|
assert provider._resolve_model_name("opus") == "anthropic/claude-3-opus"
|
||||||
assert provider._resolve_model_name("sonnet") == "anthropic/claude-3-sonnet"
|
assert provider._resolve_model_name("sonnet") == "anthropic/claude-3-sonnet"
|
||||||
@@ -79,30 +77,30 @@ class TestOpenRouterProvider:
|
|||||||
assert provider._resolve_model_name("mistral") == "mistral/mistral-large"
|
assert provider._resolve_model_name("mistral") == "mistral/mistral-large"
|
||||||
assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-coder"
|
assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-coder"
|
||||||
assert provider._resolve_model_name("coder") == "deepseek/deepseek-coder"
|
assert provider._resolve_model_name("coder") == "deepseek/deepseek-coder"
|
||||||
|
|
||||||
# Test case-insensitive
|
# Test case-insensitive
|
||||||
assert provider._resolve_model_name("OPUS") == "anthropic/claude-3-opus"
|
assert provider._resolve_model_name("OPUS") == "anthropic/claude-3-opus"
|
||||||
assert provider._resolve_model_name("GPT4O") == "openai/gpt-4o"
|
assert provider._resolve_model_name("GPT4O") == "openai/gpt-4o"
|
||||||
assert provider._resolve_model_name("Mistral") == "mistral/mistral-large"
|
assert provider._resolve_model_name("Mistral") == "mistral/mistral-large"
|
||||||
assert provider._resolve_model_name("CLAUDE") == "anthropic/claude-3-sonnet"
|
assert provider._resolve_model_name("CLAUDE") == "anthropic/claude-3-sonnet"
|
||||||
|
|
||||||
# Test direct model names (should pass through unchanged)
|
# Test direct model names (should pass through unchanged)
|
||||||
assert provider._resolve_model_name("anthropic/claude-3-opus") == "anthropic/claude-3-opus"
|
assert provider._resolve_model_name("anthropic/claude-3-opus") == "anthropic/claude-3-opus"
|
||||||
assert provider._resolve_model_name("openai/gpt-4o") == "openai/gpt-4o"
|
assert provider._resolve_model_name("openai/gpt-4o") == "openai/gpt-4o"
|
||||||
|
|
||||||
# Test unknown models pass through
|
# Test unknown models pass through
|
||||||
assert provider._resolve_model_name("unknown-model") == "unknown-model"
|
assert provider._resolve_model_name("unknown-model") == "unknown-model"
|
||||||
assert provider._resolve_model_name("custom/model-v2") == "custom/model-v2"
|
assert provider._resolve_model_name("custom/model-v2") == "custom/model-v2"
|
||||||
|
|
||||||
def test_openrouter_registration(self):
|
def test_openrouter_registration(self):
|
||||||
"""Test OpenRouter can be registered and retrieved."""
|
"""Test OpenRouter can be registered and retrieved."""
|
||||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||||
# Clean up any existing registration
|
# Clean up any existing registration
|
||||||
ModelProviderRegistry.unregister_provider(ProviderType.OPENROUTER)
|
ModelProviderRegistry.unregister_provider(ProviderType.OPENROUTER)
|
||||||
|
|
||||||
# Register the provider
|
# Register the provider
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||||
|
|
||||||
# Retrieve and verify
|
# Retrieve and verify
|
||||||
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||||
assert provider is not None
|
assert provider is not None
|
||||||
@@ -111,53 +109,53 @@ class TestOpenRouterProvider:
|
|||||||
|
|
||||||
class TestOpenRouterRegistry:
|
class TestOpenRouterRegistry:
|
||||||
"""Test cases for OpenRouter model registry."""
|
"""Test cases for OpenRouter model registry."""
|
||||||
|
|
||||||
def test_registry_loading(self):
|
def test_registry_loading(self):
|
||||||
"""Test registry loads models from config."""
|
"""Test registry loads models from config."""
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||||
|
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# Should have loaded models
|
# Should have loaded models
|
||||||
models = registry.list_models()
|
models = registry.list_models()
|
||||||
assert len(models) > 0
|
assert len(models) > 0
|
||||||
assert "anthropic/claude-3-opus" in models
|
assert "anthropic/claude-3-opus" in models
|
||||||
assert "openai/gpt-4o" in models
|
assert "openai/gpt-4o" in models
|
||||||
|
|
||||||
# Should have loaded aliases
|
# Should have loaded aliases
|
||||||
aliases = registry.list_aliases()
|
aliases = registry.list_aliases()
|
||||||
assert len(aliases) > 0
|
assert len(aliases) > 0
|
||||||
assert "opus" in aliases
|
assert "opus" in aliases
|
||||||
assert "gpt4o" in aliases
|
assert "gpt4o" in aliases
|
||||||
assert "claude" in aliases
|
assert "claude" in aliases
|
||||||
|
|
||||||
def test_registry_capabilities(self):
|
def test_registry_capabilities(self):
|
||||||
"""Test registry provides correct capabilities."""
|
"""Test registry provides correct capabilities."""
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||||
|
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# Test known model
|
# Test known model
|
||||||
caps = registry.get_capabilities("opus")
|
caps = registry.get_capabilities("opus")
|
||||||
assert caps is not None
|
assert caps is not None
|
||||||
assert caps.model_name == "anthropic/claude-3-opus"
|
assert caps.model_name == "anthropic/claude-3-opus"
|
||||||
assert caps.max_tokens == 200000 # Claude's context window
|
assert caps.max_tokens == 200000 # Claude's context window
|
||||||
|
|
||||||
# Test using full model name
|
# Test using full model name
|
||||||
caps = registry.get_capabilities("anthropic/claude-3-opus")
|
caps = registry.get_capabilities("anthropic/claude-3-opus")
|
||||||
assert caps is not None
|
assert caps is not None
|
||||||
assert caps.model_name == "anthropic/claude-3-opus"
|
assert caps.model_name == "anthropic/claude-3-opus"
|
||||||
|
|
||||||
# Test unknown model
|
# Test unknown model
|
||||||
caps = registry.get_capabilities("non-existent-model")
|
caps = registry.get_capabilities("non-existent-model")
|
||||||
assert caps is None
|
assert caps is None
|
||||||
|
|
||||||
def test_multiple_aliases_same_model(self):
|
def test_multiple_aliases_same_model(self):
|
||||||
"""Test multiple aliases pointing to same model."""
|
"""Test multiple aliases pointing to same model."""
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||||
|
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# All these should resolve to Claude Sonnet
|
# All these should resolve to Claude Sonnet
|
||||||
sonnet_aliases = ["sonnet", "claude", "claude-sonnet", "claude3-sonnet"]
|
sonnet_aliases = ["sonnet", "claude", "claude-sonnet", "claude3-sonnet"]
|
||||||
for alias in sonnet_aliases:
|
for alias in sonnet_aliases:
|
||||||
@@ -166,48 +164,34 @@ class TestOpenRouterRegistry:
|
|||||||
assert config.model_name == "anthropic/claude-3-sonnet"
|
assert config.model_name == "anthropic/claude-3-sonnet"
|
||||||
|
|
||||||
|
|
||||||
class TestOpenRouterSSRFProtection:
|
class TestOpenRouterFunctionality:
|
||||||
"""Test SSRF protection for OpenRouter."""
|
"""Test OpenRouter-specific functionality."""
|
||||||
|
|
||||||
def test_url_validation_rejects_private_ips(self):
|
def test_openrouter_always_uses_correct_url(self):
|
||||||
"""Test that private IPs are rejected."""
|
"""Test that OpenRouter always uses the correct base URL."""
|
||||||
provider = OpenRouterProvider(api_key="test-key")
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
|
assert provider.base_url == "https://openrouter.ai/api/v1"
|
||||||
# List of private/dangerous IPs to test
|
|
||||||
dangerous_urls = [
|
# Even if we try to change it, it should remain the OpenRouter URL
|
||||||
"http://192.168.1.1/api/v1",
|
# (This is a characteristic of the OpenRouter provider)
|
||||||
"http://10.0.0.1/api/v1",
|
provider.base_url = "http://example.com" # Try to change it
|
||||||
"http://172.16.0.1/api/v1",
|
# But new instances should always use the correct URL
|
||||||
"http://169.254.169.254/api/v1", # AWS metadata
|
provider2 = OpenRouterProvider(api_key="test-key")
|
||||||
"http://[::1]/api/v1", # IPv6 localhost
|
assert provider2.base_url == "https://openrouter.ai/api/v1"
|
||||||
"http://0.0.0.0/api/v1",
|
|
||||||
]
|
def test_openrouter_headers_set_correctly(self):
|
||||||
|
"""Test that OpenRouter specific headers are set."""
|
||||||
for url in dangerous_urls:
|
|
||||||
with pytest.raises(ValueError, match="restricted IP|Invalid"):
|
|
||||||
provider.base_url = url
|
|
||||||
provider._validate_base_url()
|
|
||||||
|
|
||||||
def test_url_validation_allows_public_domains(self):
|
|
||||||
"""Test that legitimate public domains are allowed."""
|
|
||||||
provider = OpenRouterProvider(api_key="test-key")
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
|
|
||||||
# OpenRouter's actual domain should always be allowed
|
# Check default headers
|
||||||
provider.base_url = "https://openrouter.ai/api/v1"
|
assert "HTTP-Referer" in provider.DEFAULT_HEADERS
|
||||||
provider._validate_base_url() # Should not raise
|
assert "X-Title" in provider.DEFAULT_HEADERS
|
||||||
|
assert provider.DEFAULT_HEADERS["X-Title"] == "Zen MCP Server"
|
||||||
def test_invalid_url_schemes_rejected(self):
|
|
||||||
"""Test that non-HTTP(S) schemes are rejected."""
|
def test_openrouter_model_registry_initialized(self):
|
||||||
|
"""Test that model registry is properly initialized."""
|
||||||
provider = OpenRouterProvider(api_key="test-key")
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
|
|
||||||
invalid_urls = [
|
# Registry should be initialized
|
||||||
"ftp://example.com/api",
|
assert hasattr(provider, '_registry')
|
||||||
"file:///etc/passwd",
|
assert provider._registry is not None
|
||||||
"gopher://example.com",
|
|
||||||
"javascript:alert(1)",
|
|
||||||
]
|
|
||||||
|
|
||||||
for url in invalid_urls:
|
|
||||||
with pytest.raises(ValueError, match="Invalid URL scheme"):
|
|
||||||
provider.base_url = url
|
|
||||||
provider._validate_base_url()
|
|
||||||
|
|||||||
@@ -2,42 +2,34 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pytest
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry, OpenRouterModelConfig
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
from providers.base import ProviderType
|
||||||
|
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
|
||||||
|
|
||||||
|
|
||||||
class TestOpenRouterModelRegistry:
|
class TestOpenRouterModelRegistry:
|
||||||
"""Test cases for OpenRouter model registry."""
|
"""Test cases for OpenRouter model registry."""
|
||||||
|
|
||||||
def test_registry_initialization(self):
|
def test_registry_initialization(self):
|
||||||
"""Test registry initializes with default config."""
|
"""Test registry initializes with default config."""
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# Should load models from default location
|
# Should load models from default location
|
||||||
assert len(registry.list_models()) > 0
|
assert len(registry.list_models()) > 0
|
||||||
assert len(registry.list_aliases()) > 0
|
assert len(registry.list_aliases()) > 0
|
||||||
|
|
||||||
def test_custom_config_path(self):
|
def test_custom_config_path(self):
|
||||||
"""Test registry with custom config path."""
|
"""Test registry with custom config path."""
|
||||||
# Create temporary config
|
# Create temporary config
|
||||||
config_data = {
|
config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}
|
||||||
"models": [
|
|
||||||
{
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||||
"model_name": "test/model-1",
|
|
||||||
"aliases": ["test1", "t1"],
|
|
||||||
"context_window": 4096
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
temp_path = f.name
|
temp_path = f.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
registry = OpenRouterModelRegistry(config_path=temp_path)
|
registry = OpenRouterModelRegistry(config_path=temp_path)
|
||||||
assert len(registry.list_models()) == 1
|
assert len(registry.list_models()) == 1
|
||||||
@@ -46,48 +38,40 @@ class TestOpenRouterModelRegistry:
|
|||||||
assert "t1" in registry.list_aliases()
|
assert "t1" in registry.list_aliases()
|
||||||
finally:
|
finally:
|
||||||
os.unlink(temp_path)
|
os.unlink(temp_path)
|
||||||
|
|
||||||
def test_environment_variable_override(self):
|
def test_environment_variable_override(self):
|
||||||
"""Test OPENROUTER_MODELS_PATH environment variable."""
|
"""Test OPENROUTER_MODELS_PATH environment variable."""
|
||||||
# Create custom config
|
# Create custom config
|
||||||
config_data = {
|
config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}
|
||||||
"models": [
|
|
||||||
{
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||||
"model_name": "env/model",
|
|
||||||
"aliases": ["envtest"],
|
|
||||||
"context_window": 8192
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
temp_path = f.name
|
temp_path = f.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set environment variable
|
# Set environment variable
|
||||||
original_env = os.environ.get('OPENROUTER_MODELS_PATH')
|
original_env = os.environ.get("OPENROUTER_MODELS_PATH")
|
||||||
os.environ['OPENROUTER_MODELS_PATH'] = temp_path
|
os.environ["OPENROUTER_MODELS_PATH"] = temp_path
|
||||||
|
|
||||||
# Create registry without explicit path
|
# Create registry without explicit path
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# Should load from environment path
|
# Should load from environment path
|
||||||
assert "env/model" in registry.list_models()
|
assert "env/model" in registry.list_models()
|
||||||
assert "envtest" in registry.list_aliases()
|
assert "envtest" in registry.list_aliases()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore environment
|
# Restore environment
|
||||||
if original_env is not None:
|
if original_env is not None:
|
||||||
os.environ['OPENROUTER_MODELS_PATH'] = original_env
|
os.environ["OPENROUTER_MODELS_PATH"] = original_env
|
||||||
else:
|
else:
|
||||||
del os.environ['OPENROUTER_MODELS_PATH']
|
del os.environ["OPENROUTER_MODELS_PATH"]
|
||||||
os.unlink(temp_path)
|
os.unlink(temp_path)
|
||||||
|
|
||||||
def test_alias_resolution(self):
|
def test_alias_resolution(self):
|
||||||
"""Test alias resolution functionality."""
|
"""Test alias resolution functionality."""
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# Test various aliases
|
# Test various aliases
|
||||||
test_cases = [
|
test_cases = [
|
||||||
("opus", "anthropic/claude-3-opus"),
|
("opus", "anthropic/claude-3-opus"),
|
||||||
@@ -97,75 +81,71 @@ class TestOpenRouterModelRegistry:
|
|||||||
("4o", "openai/gpt-4o"),
|
("4o", "openai/gpt-4o"),
|
||||||
("mistral", "mistral/mistral-large"),
|
("mistral", "mistral/mistral-large"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for alias, expected_model in test_cases:
|
for alias, expected_model in test_cases:
|
||||||
config = registry.resolve(alias)
|
config = registry.resolve(alias)
|
||||||
assert config is not None, f"Failed to resolve alias '{alias}'"
|
assert config is not None, f"Failed to resolve alias '{alias}'"
|
||||||
assert config.model_name == expected_model
|
assert config.model_name == expected_model
|
||||||
|
|
||||||
def test_direct_model_name_lookup(self):
|
def test_direct_model_name_lookup(self):
|
||||||
"""Test looking up models by their full name."""
|
"""Test looking up models by their full name."""
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# Should be able to look up by full model name
|
# Should be able to look up by full model name
|
||||||
config = registry.resolve("anthropic/claude-3-opus")
|
config = registry.resolve("anthropic/claude-3-opus")
|
||||||
assert config is not None
|
assert config is not None
|
||||||
assert config.model_name == "anthropic/claude-3-opus"
|
assert config.model_name == "anthropic/claude-3-opus"
|
||||||
|
|
||||||
config = registry.resolve("openai/gpt-4o")
|
config = registry.resolve("openai/gpt-4o")
|
||||||
assert config is not None
|
assert config is not None
|
||||||
assert config.model_name == "openai/gpt-4o"
|
assert config.model_name == "openai/gpt-4o"
|
||||||
|
|
||||||
def test_unknown_model_resolution(self):
|
def test_unknown_model_resolution(self):
|
||||||
"""Test resolution of unknown models."""
|
"""Test resolution of unknown models."""
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
# Unknown aliases should return None
|
# Unknown aliases should return None
|
||||||
assert registry.resolve("unknown-alias") is None
|
assert registry.resolve("unknown-alias") is None
|
||||||
assert registry.resolve("") is None
|
assert registry.resolve("") is None
|
||||||
assert registry.resolve("non-existent") is None
|
assert registry.resolve("non-existent") is None
|
||||||
|
|
||||||
def test_model_capabilities_conversion(self):
|
def test_model_capabilities_conversion(self):
|
||||||
"""Test conversion to ModelCapabilities."""
|
"""Test conversion to ModelCapabilities."""
|
||||||
registry = OpenRouterModelRegistry()
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
config = registry.resolve("opus")
|
config = registry.resolve("opus")
|
||||||
assert config is not None
|
assert config is not None
|
||||||
|
|
||||||
caps = config.to_capabilities()
|
caps = config.to_capabilities()
|
||||||
assert caps.provider == ProviderType.OPENROUTER
|
assert caps.provider == ProviderType.OPENROUTER
|
||||||
assert caps.model_name == "anthropic/claude-3-opus"
|
assert caps.model_name == "anthropic/claude-3-opus"
|
||||||
assert caps.friendly_name == "OpenRouter"
|
assert caps.friendly_name == "OpenRouter"
|
||||||
assert caps.max_tokens == 200000
|
assert caps.max_tokens == 200000
|
||||||
assert not caps.supports_extended_thinking
|
assert not caps.supports_extended_thinking
|
||||||
|
|
||||||
def test_duplicate_alias_detection(self):
|
def test_duplicate_alias_detection(self):
|
||||||
"""Test that duplicate aliases are detected."""
|
"""Test that duplicate aliases are detected."""
|
||||||
config_data = {
|
config_data = {
|
||||||
"models": [
|
"models": [
|
||||||
{
|
{"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
|
||||||
"model_name": "test/model-1",
|
|
||||||
"aliases": ["dupe"],
|
|
||||||
"context_window": 4096
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"model_name": "test/model-2",
|
"model_name": "test/model-2",
|
||||||
"aliases": ["DUPE"], # Same alias, different case
|
"aliases": ["DUPE"], # Same alias, different case
|
||||||
"context_window": 8192
|
"context_window": 8192,
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
temp_path = f.name
|
temp_path = f.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with pytest.raises(ValueError, match="Duplicate alias"):
|
with pytest.raises(ValueError, match="Duplicate alias"):
|
||||||
OpenRouterModelRegistry(config_path=temp_path)
|
OpenRouterModelRegistry(config_path=temp_path)
|
||||||
finally:
|
finally:
|
||||||
os.unlink(temp_path)
|
os.unlink(temp_path)
|
||||||
|
|
||||||
def test_backwards_compatibility_max_tokens(self):
|
def test_backwards_compatibility_max_tokens(self):
|
||||||
"""Test backwards compatibility with old max_tokens field."""
|
"""Test backwards compatibility with old max_tokens field."""
|
||||||
config_data = {
|
config_data = {
|
||||||
@@ -174,44 +154,44 @@ class TestOpenRouterModelRegistry:
|
|||||||
"model_name": "test/old-model",
|
"model_name": "test/old-model",
|
||||||
"aliases": ["old"],
|
"aliases": ["old"],
|
||||||
"max_tokens": 16384, # Old field name
|
"max_tokens": 16384, # Old field name
|
||||||
"supports_extended_thinking": False
|
"supports_extended_thinking": False,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
temp_path = f.name
|
temp_path = f.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
registry = OpenRouterModelRegistry(config_path=temp_path)
|
registry = OpenRouterModelRegistry(config_path=temp_path)
|
||||||
config = registry.resolve("old")
|
config = registry.resolve("old")
|
||||||
|
|
||||||
assert config is not None
|
assert config is not None
|
||||||
assert config.context_window == 16384 # Should be converted
|
assert config.context_window == 16384 # Should be converted
|
||||||
|
|
||||||
# Check capabilities still work
|
# Check capabilities still work
|
||||||
caps = config.to_capabilities()
|
caps = config.to_capabilities()
|
||||||
assert caps.max_tokens == 16384
|
assert caps.max_tokens == 16384
|
||||||
finally:
|
finally:
|
||||||
os.unlink(temp_path)
|
os.unlink(temp_path)
|
||||||
|
|
||||||
def test_missing_config_file(self):
|
def test_missing_config_file(self):
|
||||||
"""Test behavior with missing config file."""
|
"""Test behavior with missing config file."""
|
||||||
# Use a non-existent path
|
# Use a non-existent path
|
||||||
registry = OpenRouterModelRegistry(config_path="/non/existent/path.json")
|
registry = OpenRouterModelRegistry(config_path="/non/existent/path.json")
|
||||||
|
|
||||||
# Should initialize with empty maps
|
# Should initialize with empty maps
|
||||||
assert len(registry.list_models()) == 0
|
assert len(registry.list_models()) == 0
|
||||||
assert len(registry.list_aliases()) == 0
|
assert len(registry.list_aliases()) == 0
|
||||||
assert registry.resolve("anything") is None
|
assert registry.resolve("anything") is None
|
||||||
|
|
||||||
def test_invalid_json_config(self):
|
def test_invalid_json_config(self):
|
||||||
"""Test handling of invalid JSON."""
|
"""Test handling of invalid JSON."""
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||||
f.write("{ invalid json }")
|
f.write("{ invalid json }")
|
||||||
temp_path = f.name
|
temp_path = f.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
registry = OpenRouterModelRegistry(config_path=temp_path)
|
registry = OpenRouterModelRegistry(config_path=temp_path)
|
||||||
# Should handle gracefully and initialize empty
|
# Should handle gracefully and initialize empty
|
||||||
@@ -219,7 +199,7 @@ class TestOpenRouterModelRegistry:
|
|||||||
assert len(registry.list_aliases()) == 0
|
assert len(registry.list_aliases()) == 0
|
||||||
finally:
|
finally:
|
||||||
os.unlink(temp_path)
|
os.unlink(temp_path)
|
||||||
|
|
||||||
def test_model_with_all_capabilities(self):
|
def test_model_with_all_capabilities(self):
|
||||||
"""Test model with all capability flags."""
|
"""Test model with all capability flags."""
|
||||||
config = OpenRouterModelConfig(
|
config = OpenRouterModelConfig(
|
||||||
@@ -231,13 +211,13 @@ class TestOpenRouterModelRegistry:
|
|||||||
supports_streaming=True,
|
supports_streaming=True,
|
||||||
supports_function_calling=True,
|
supports_function_calling=True,
|
||||||
supports_json_mode=True,
|
supports_json_mode=True,
|
||||||
description="Fully featured test model"
|
description="Fully featured test model",
|
||||||
)
|
)
|
||||||
|
|
||||||
caps = config.to_capabilities()
|
caps = config.to_capabilities()
|
||||||
assert caps.max_tokens == 128000
|
assert caps.max_tokens == 128000
|
||||||
assert caps.supports_extended_thinking
|
assert caps.supports_extended_thinking
|
||||||
assert caps.supports_system_prompts
|
assert caps.supports_system_prompts
|
||||||
assert caps.supports_streaming
|
assert caps.supports_streaming
|
||||||
assert caps.supports_function_calling
|
assert caps.supports_function_calling
|
||||||
# Note: supports_json_mode is not in ModelCapabilities yet
|
# Note: supports_json_mode is not in ModelCapabilities yet
|
||||||
|
|||||||
@@ -57,15 +57,28 @@ class ToolRequest(BaseModel):
|
|||||||
# Higher values allow for more complex reasoning but increase latency and cost
|
# Higher values allow for more complex reasoning but increase latency and cost
|
||||||
thinking_mode: Optional[Literal["minimal", "low", "medium", "high", "max"]] = Field(
|
thinking_mode: Optional[Literal["minimal", "low", "medium", "high", "max"]] = Field(
|
||||||
None,
|
None,
|
||||||
description="Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), max (100% of model max)",
|
description=(
|
||||||
|
"Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), "
|
||||||
|
"max (100% of model max)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
use_websearch: Optional[bool] = Field(
|
use_websearch: Optional[bool] = Field(
|
||||||
True,
|
True,
|
||||||
description="Enable web search for documentation, best practices, and current information. When enabled, the model can request Claude to perform web searches and share results back during conversations. Particularly useful for: brainstorming sessions, architectural design discussions, exploring industry best practices, working with specific frameworks/technologies, researching solutions to complex problems, or when current documentation and community insights would enhance the analysis.",
|
description=(
|
||||||
|
"Enable web search for documentation, best practices, and current information. "
|
||||||
|
"When enabled, the model can request Claude to perform web searches and share results back "
|
||||||
|
"during conversations. Particularly useful for: brainstorming sessions, architectural design "
|
||||||
|
"discussions, exploring industry best practices, working with specific frameworks/technologies, "
|
||||||
|
"researching solutions to complex problems, or when current documentation and community insights "
|
||||||
|
"would enhance the analysis."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
continuation_id: Optional[str] = Field(
|
continuation_id: Optional[str] = Field(
|
||||||
None,
|
None,
|
||||||
description="Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
description=(
|
||||||
|
"Thread continuation ID for multi-turn conversations. Can be used to continue conversations "
|
||||||
|
"across different tools. Only provide this if continuing a previous conversation thread."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -152,21 +165,48 @@ class BaseTool(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict containing the model field JSON schema
|
Dict containing the model field JSON schema
|
||||||
"""
|
"""
|
||||||
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
|
||||||
|
|
||||||
# Check if OpenRouter is configured
|
# Check if OpenRouter is configured
|
||||||
has_openrouter = bool(os.getenv("OPENROUTER_API_KEY") and
|
has_openrouter = bool(
|
||||||
os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here")
|
os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here"
|
||||||
|
)
|
||||||
|
|
||||||
if IS_AUTO_MODE:
|
if IS_AUTO_MODE:
|
||||||
# In auto mode, model is required and we provide detailed descriptions
|
# In auto mode, model is required and we provide detailed descriptions
|
||||||
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
|
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
|
||||||
for model, desc in MODEL_CAPABILITIES_DESC.items():
|
for model, desc in MODEL_CAPABILITIES_DESC.items():
|
||||||
model_desc_parts.append(f"- '{model}': {desc}")
|
model_desc_parts.append(f"- '{model}': {desc}")
|
||||||
|
|
||||||
if has_openrouter:
|
if has_openrouter:
|
||||||
model_desc_parts.append("\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter (e.g., 'gpt-4', 'claude-3-opus', 'mistral-large'). Check openrouter.ai/models for available models.")
|
# Add OpenRouter aliases from the registry
|
||||||
|
try:
|
||||||
|
# Import registry directly to show available aliases
|
||||||
|
# This works even without an API key
|
||||||
|
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||||
|
|
||||||
|
registry = OpenRouterModelRegistry()
|
||||||
|
aliases = registry.list_aliases()
|
||||||
|
|
||||||
|
# Show ALL aliases from the configuration
|
||||||
|
if aliases:
|
||||||
|
# Show all aliases so Claude knows every option available
|
||||||
|
all_aliases = sorted(aliases)
|
||||||
|
alias_list = ", ".join(f"'{a}'" for a in all_aliases)
|
||||||
|
model_desc_parts.append(
|
||||||
|
f"\nOpenRouter models available via aliases: {alias_list}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_desc_parts.append(
|
||||||
|
"\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter."
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Fallback if registry fails to load
|
||||||
|
model_desc_parts.append(
|
||||||
|
"\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter (e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -177,12 +217,33 @@ class BaseTool(ABC):
|
|||||||
# Normal mode - model is optional with default
|
# Normal mode - model is optional with default
|
||||||
available_models = list(MODEL_CAPABILITIES_DESC.keys())
|
available_models = list(MODEL_CAPABILITIES_DESC.keys())
|
||||||
models_str = ", ".join(f"'{m}'" for m in available_models)
|
models_str = ", ".join(f"'{m}'" for m in available_models)
|
||||||
|
|
||||||
description = f"Model to use. Native models: {models_str}."
|
description = f"Model to use. Native models: {models_str}."
|
||||||
if has_openrouter:
|
if has_openrouter:
|
||||||
description += " OpenRouter: Any model available on openrouter.ai (e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
|
# Add OpenRouter aliases
|
||||||
|
try:
|
||||||
|
# Import registry directly to show available aliases
|
||||||
|
# This works even without an API key
|
||||||
|
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||||
|
|
||||||
|
registry = OpenRouterModelRegistry()
|
||||||
|
aliases = registry.list_aliases()
|
||||||
|
|
||||||
|
# Show ALL aliases from the configuration
|
||||||
|
if aliases:
|
||||||
|
# Show all aliases so Claude knows every option available
|
||||||
|
all_aliases = sorted(aliases)
|
||||||
|
alias_list = ", ".join(f"'{a}'" for a in all_aliases)
|
||||||
|
description += f" OpenRouter aliases: {alias_list}."
|
||||||
|
else:
|
||||||
|
description += " OpenRouter: Any model available on openrouter.ai."
|
||||||
|
except Exception:
|
||||||
|
description += (
|
||||||
|
" OpenRouter: Any model available on openrouter.ai "
|
||||||
|
"(e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
|
||||||
|
)
|
||||||
description += f" Defaults to '{DEFAULT_MODEL}' if not specified."
|
description += f" Defaults to '{DEFAULT_MODEL}' if not specified."
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": description,
|
"description": description,
|
||||||
|
|||||||
Reference in New Issue
Block a user