WIP - OpenRouter support and related refactoring

.env.example
@@ -8,6 +8,14 @@ GEMINI_API_KEY=your_gemini_api_key_here
 # Get your OpenAI API key from: https://platform.openai.com/api-keys
 OPENAI_API_KEY=your_openai_api_key_here
 
+# Optional: OpenRouter for access to multiple models
+# Get your OpenRouter API key from: https://openrouter.ai/
+OPENROUTER_API_KEY=your_openrouter_api_key_here
+
+# Optional: Restrict which models can be used via OpenRouter (recommended for cost control)
+# Example: OPENROUTER_ALLOWED_MODELS=gpt-4,claude-3-opus,mistral-large
+OPENROUTER_ALLOWED_MODELS=
+
 # Optional: Default model to use
 # Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini'
 # When set to 'auto', Claude will select the best model for each task
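The allow-list above is a plain comma-separated value. As a minimal sketch (not part of the commit), this is how such a list can be parsed, matching the trim-and-lowercase normalization the provider code later in this commit applies:

```python
import os
from typing import Optional


def parse_allowed_models(env_var: str = "OPENROUTER_ALLOWED_MODELS") -> Optional[set[str]]:
    """Return a lowercase set of allowed model names, or None if unrestricted."""
    raw = os.getenv(env_var, "")
    models = {m.strip().lower() for m in raw.split(",") if m.strip()}
    return models or None


# Entries are trimmed and lowercased, so "GPT-4, claude-3-opus" matches "gpt-4"
os.environ["OPENROUTER_ALLOWED_MODELS"] = "GPT-4, claude-3-opus"
assert parse_allowed_models() == {"gpt-4", "claude-3-opus"}
```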

README.md
@@ -100,6 +100,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performance
 ### 1. Get API Keys (at least one required)
 - **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models.
 - **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) to get an API key for O3 model access.
+- **OpenRouter**: Visit [OpenRouter](https://openrouter.ai/) for access to multiple models through one API. [Setup Guide](docs/openrouter.md)
 
 ### 2. Clone and Set Up
 
@@ -125,12 +126,13 @@ cd zen-mcp-server
 # Edit .env to add your API keys (if not already set in environment)
 nano .env
 
-# The file will contain:
+# The file will contain (at least one should be set):
 # GEMINI_API_KEY=your-gemini-api-key-here   # For Gemini models
 # OPENAI_API_KEY=your-openai-api-key-here   # For O3 model
+# OPENROUTER_API_KEY=your-openrouter-key    # For OpenRouter (see docs/openrouter.md)
 # WORKSPACE_ROOT=/Users/your-username (automatically configured)
 
-# Note: At least one API key is required (Gemini or OpenAI)
+# Note: At least one API key is required
 ```
 
 ### 4. Configure Claude
@@ -742,6 +744,7 @@ OPENAI_API_KEY=your-openai-key      # Enables O3, O3-mini
 | **`flash`** (Gemini 2.0 Flash) | Google | 1M tokens | Ultra-fast responses | Quick checks, formatting, simple analysis |
 | **`o3`** | OpenAI | 200K tokens | Strong logical reasoning | Debugging logic errors, systematic analysis |
 | **`o3-mini`** | OpenAI | 200K tokens | Balanced speed/quality | Moderate complexity tasks |
+| **Any model** | OpenRouter | Varies | Access to GPT-4, Claude, Llama, etc. | User-specified or based on task requirements |
 
 **Manual Model Selection:**
 You can specify a default model instead of auto mode:

docker-compose.yml
@@ -31,6 +31,9 @@ services:
     environment:
       - GEMINI_API_KEY=${GEMINI_API_KEY:-}
       - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      # OpenRouter support
+      - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
+      - OPENROUTER_ALLOWED_MODELS=${OPENROUTER_ALLOWED_MODELS:-}
       - DEFAULT_MODEL=${DEFAULT_MODEL:-auto}
       - DEFAULT_THINKING_MODE_THINKDEEP=${DEFAULT_THINKING_MODE_THINKDEEP:-high}
       - REDIS_URL=redis://redis:6379/0

docs/openrouter.md (new file)
@@ -0,0 +1,52 @@
# OpenRouter Setup

OpenRouter provides unified access to multiple AI models (GPT-4, Claude, Mistral, etc.) through a single API.

## Quick Start

### 1. Get API Key
1. Sign up at [openrouter.ai](https://openrouter.ai/)
2. Create an API key from your dashboard
3. Add credits to your account

### 2. Set Environment Variable
```bash
# Add to your .env file
OPENROUTER_API_KEY=your-openrouter-api-key
```

That's it! Docker Compose already includes all necessary configuration.

### 3. Use Any Model
```
# Examples
"Use gpt-4 via zen to review this code"
"Use claude-3-opus via zen to debug this error"
"Use mistral-large via zen to optimize this algorithm"
```

## Cost Control (Recommended)

Restrict which models can be used to prevent unexpected charges:

```bash
# Add to .env file - only allow specific models
OPENROUTER_ALLOWED_MODELS=gpt-4,claude-3-sonnet,mistral-large
```

Check current model pricing at [openrouter.ai/models](https://openrouter.ai/models).

## Available Models

Popular models available through OpenRouter:
- **GPT-4** - OpenAI's most capable model
- **Claude 3** - Anthropic's models (Opus, Sonnet, Haiku)
- **Mistral** - Including Mistral Large
- **Llama 3** - Meta's open models
- Many more at [openrouter.ai/models](https://openrouter.ai/models)

## Troubleshooting

- **"Model not found"**: Check the exact model name at openrouter.ai/models
- **"Insufficient credits"**: Add credits to your OpenRouter account
- **"Model not in allow-list"**: Update `OPENROUTER_ALLOWED_MODELS` in .env
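A rough illustration of the allow-list behavior this guide describes, mirroring the commit's own tests (assumes the zen-mcp-server package is importable and network access is available for the provider's base-URL check):

```python
import os
from unittest.mock import patch

from providers.openrouter import OpenRouterProvider

with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "gpt-4,claude-3-opus"}):
    provider = OpenRouterProvider(api_key="test-key")
    assert provider.validate_model_name("GPT-4") is True           # matching is case-insensitive
    assert provider.validate_model_name("gpt-3.5-turbo") is False  # rejected: not on the list
```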
providers/__init__.py
@@ -3,6 +3,8 @@
 from .base import ModelCapabilities, ModelProvider, ModelResponse
 from .gemini import GeminiModelProvider
 from .openai import OpenAIModelProvider
+from .openai_compatible import OpenAICompatibleProvider
+from .openrouter import OpenRouterProvider
 from .registry import ModelProviderRegistry
 
 __all__ = [
@@ -12,4 +14,6 @@ __all__ = [
     "ModelProviderRegistry",
     "GeminiModelProvider",
     "OpenAIModelProvider",
+    "OpenAICompatibleProvider",
+    "OpenRouterProvider",
 ]
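With these exports, downstream code can import the new classes from the package root; a quick sanity check of the relationship this commit establishes:

```python
from providers import OpenAICompatibleProvider, OpenRouterProvider

# OpenRouter is one concrete implementation of the new shared base class
assert issubclass(OpenRouterProvider, OpenAICompatibleProvider)
```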

providers/base.py
@@ -11,6 +11,7 @@ class ProviderType(Enum):
 
     GOOGLE = "google"
     OPENAI = "openai"
+    OPENROUTER = "openrouter"
 
 
 class TemperatureConstraint(ABC):
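The enum value matters beyond identification: the shared base class introduced later in this commit derives the allow-list variable name from it in `_parse_allowed_models()`. A quick check of that derivation:

```python
from providers.base import ProviderType

# _parse_allowed_models() builds the env var name from the provider type
env_var = f"{ProviderType.OPENROUTER.value.upper()}_ALLOWED_MODELS"
assert env_var == "OPENROUTER_ALLOWED_MODELS"
```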

providers/openai.py
@@ -3,20 +3,18 @@
 import logging
 from typing import Optional
 
-from openai import OpenAI
-
 from .base import (
     FixedTemperatureConstraint,
     ModelCapabilities,
-    ModelProvider,
     ModelResponse,
     ProviderType,
     RangeTemperatureConstraint,
 )
+from .openai_compatible import OpenAICompatibleProvider
 
 
-class OpenAIModelProvider(ModelProvider):
-    """OpenAI model provider implementation."""
+class OpenAIModelProvider(OpenAICompatibleProvider):
+    """Official OpenAI API provider (api.openai.com)."""
 
     # Model configurations
     SUPPORTED_MODELS = {
@@ -32,23 +30,10 @@ class OpenAIModelProvider(ModelProvider):
 
     def __init__(self, api_key: str, **kwargs):
         """Initialize OpenAI provider with API key."""
+        # Set default OpenAI base URL, allow override for regions/custom endpoints
+        kwargs.setdefault("base_url", "https://api.openai.com/v1")
         super().__init__(api_key, **kwargs)
-        self._client = None
-        self.base_url = kwargs.get("base_url")  # Support custom endpoints
-        self.organization = kwargs.get("organization")
-
-    @property
-    def client(self):
-        """Lazy initialization of OpenAI client."""
-        if self._client is None:
-            client_kwargs = {"api_key": self.api_key}
-            if self.base_url:
-                client_kwargs["base_url"] = self.base_url
-            if self.organization:
-                client_kwargs["organization"] = self.organization
-
-            self._client = OpenAI(**client_kwargs)
-        return self._client
 
     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         """Get capabilities for a specific OpenAI model."""
@@ -77,79 +62,6 @@ class OpenAIModelProvider(ModelProvider):
             temperature_constraint=temp_constraint,
         )
 
-    def generate_content(
-        self,
-        prompt: str,
-        model_name: str,
-        system_prompt: Optional[str] = None,
-        temperature: float = 0.7,
-        max_output_tokens: Optional[int] = None,
-        **kwargs,
-    ) -> ModelResponse:
-        """Generate content using OpenAI model."""
-        # Validate parameters
-        self.validate_parameters(model_name, temperature)
-
-        # Prepare messages
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": prompt})
-
-        # Prepare completion parameters
-        completion_params = {
-            "model": model_name,
-            "messages": messages,
-            "temperature": temperature,
-        }
-
-        # Add max tokens if specified
-        if max_output_tokens:
-            completion_params["max_tokens"] = max_output_tokens
-
-        # Add any additional OpenAI-specific parameters
-        for key, value in kwargs.items():
-            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop"]:
-                completion_params[key] = value
-
-        try:
-            # Generate completion
-            response = self.client.chat.completions.create(**completion_params)
-
-            # Extract content and usage
-            content = response.choices[0].message.content
-            usage = self._extract_usage(response)
-
-            return ModelResponse(
-                content=content,
-                usage=usage,
-                model_name=model_name,
-                friendly_name="OpenAI",
-                provider=ProviderType.OPENAI,
-                metadata={
-                    "finish_reason": response.choices[0].finish_reason,
-                    "model": response.model,  # Actual model used (in case of fallbacks)
-                    "id": response.id,
-                    "created": response.created,
-                },
-            )
-
-        except Exception as e:
-            # Log error and re-raise with more context
-            error_msg = f"OpenAI API error for model {model_name}: {str(e)}"
-            logging.error(error_msg)
-            raise RuntimeError(error_msg) from e
-
-    def count_tokens(self, text: str, model_name: str) -> int:
-        """Count tokens for the given text.
-
-        Note: For accurate token counting, we should use tiktoken library.
-        This is a simplified estimation.
-        """
-        # TODO: Implement proper token counting with tiktoken
-        # For now, use rough estimation
-        # O3 models ~4 chars per token
-        return len(text) // 4
-
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
@@ -165,13 +77,3 @@ class OpenAIModelProvider(ModelProvider):
         # This may change with future O3 models
         return False
-
-    def _extract_usage(self, response) -> dict[str, int]:
-        """Extract token usage from OpenAI response."""
-        usage = {}
-
-        if hasattr(response, "usage") and response.usage:
-            usage["input_tokens"] = response.usage.prompt_tokens
-            usage["output_tokens"] = response.usage.completion_tokens
-            usage["total_tokens"] = response.usage.total_tokens
-
-        return usage
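The net effect of this refactor: OpenAIModelProvider keeps only its model catalog and capability logic, while client construction, generation, token counting, and usage extraction move into the shared base class introduced below. A rough sketch of the preserved behavior (assumes the package is importable; constructing a provider triggers the base-URL security check, which needs network access):

```python
from providers.openai import OpenAIModelProvider

provider = OpenAIModelProvider(api_key="sk-test")

# The default endpoint is now injected via kwargs.setdefault(...) in __init__
assert provider.base_url == "https://api.openai.com/v1"

# client, generate_content(), count_tokens(), and _extract_usage() are all
# inherited from OpenAICompatibleProvider rather than duplicated here.
```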

providers/openai_compatible.py (new file)
@@ -0,0 +1,417 @@
"""Base class for OpenAI-compatible API providers."""

import logging
import os
from abc import abstractmethod
from typing import Optional
from urllib.parse import urlparse
import ipaddress
import socket

from openai import OpenAI

from .base import (
    ModelCapabilities,
    ModelProvider,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)


class OpenAICompatibleProvider(ModelProvider):
    """Base class for any provider using an OpenAI-compatible API.

    This includes:
    - Direct OpenAI API
    - OpenRouter
    - Any other OpenAI-compatible endpoint
    """

    DEFAULT_HEADERS = {}
    FRIENDLY_NAME = "OpenAI Compatible"

    def __init__(self, api_key: str, base_url: str = None, **kwargs):
        """Initialize the provider with API key and optional base URL.

        Args:
            api_key: API key for authentication
            base_url: Base URL for the API endpoint
            **kwargs: Additional configuration options
        """
        super().__init__(api_key, **kwargs)
        self._client = None
        self.base_url = base_url
        self.organization = kwargs.get("organization")
        self.allowed_models = self._parse_allowed_models()

        # Validate base URL for security
        if self.base_url:
            self._validate_base_url()

        # Warn if using external URL without authentication
        if self.base_url and not self._is_localhost_url() and not api_key:
            logging.warning(
                f"Using external URL '{self.base_url}' without API key. "
                "This may be insecure. Consider setting an API key for authentication."
            )

    def _parse_allowed_models(self) -> Optional[set[str]]:
        """Parse allowed models from environment variable.

        Returns:
            Set of allowed model names (lowercase) or None if not configured
        """
        # Get provider-specific allowed models
        provider_type = self.get_provider_type().value.upper()
        env_var = f"{provider_type}_ALLOWED_MODELS"
        models_str = os.getenv(env_var, "")

        if models_str:
            # Parse and normalize to lowercase for case-insensitive comparison
            models = set(m.strip().lower() for m in models_str.split(",") if m.strip())
            if models:
                logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
                return models

        # Log warning if no allow-list configured for proxy providers
        if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
            logging.warning(
                f"No model allow-list configured for {self.FRIENDLY_NAME}. "
                f"Set {env_var} to restrict model access and control costs."
            )

        return None

    def _is_localhost_url(self) -> bool:
        """Check if the base URL points to localhost.

        Returns:
            True if URL is localhost, False otherwise
        """
        if not self.base_url:
            return False

        try:
            parsed = urlparse(self.base_url)
            hostname = parsed.hostname

            # Check for common localhost patterns
            if hostname in ["localhost", "127.0.0.1", "::1"]:
                return True

            return False
        except Exception:
            return False

    def _validate_base_url(self) -> None:
        """Validate base URL for security (SSRF protection).

        Raises:
            ValueError: If URL is invalid or potentially unsafe
        """
        if not self.base_url:
            return

        try:
            parsed = urlparse(self.base_url)

            # Check URL scheme - only allow http/https
            if parsed.scheme not in ("http", "https"):
                raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")

            # Check hostname exists
            if not parsed.hostname:
                raise ValueError("URL must include a hostname")

            # Check port - allow only standard HTTP/HTTPS ports
            port = parsed.port
            if port is None:
                port = 443 if parsed.scheme == "https" else 80

            # Allow common HTTP ports and some alternative ports
            allowed_ports = {80, 443, 8080, 8443, 4000, 3000}  # Common API ports
            if port not in allowed_ports:
                raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")

            # Check against allowed domains if configured
            allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
            allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]

            if allowed_domains:
                hostname_lower = parsed.hostname.lower()
                if not any(
                    hostname_lower == domain or hostname_lower.endswith("." + domain)
                    for domain in allowed_domains
                ):
                    raise ValueError(
                        f"Domain not in allow-list: {parsed.hostname}. "
                        f"Allowed domains: {allowed_domains}"
                    )

            # Try to resolve hostname and check if it's a private IP
            # Skip for localhost addresses which are commonly used for development
            if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
                try:
                    # Get all IP addresses for the hostname
                    addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)

                    for family, _, _, _, sockaddr in addr_info:
                        ip_str = sockaddr[0]
                        try:
                            ip = ipaddress.ip_address(ip_str)

                            # Check for dangerous IP ranges
                            if (ip.is_private or ip.is_loopback or ip.is_link_local or
                                    ip.is_multicast or ip.is_reserved or ip.is_unspecified):
                                raise ValueError(
                                    f"URL resolves to restricted IP address: {ip_str}. "
                                    "This could be a security risk (SSRF)."
                                )
                        except ValueError as ve:
                            # Invalid IP address format or restricted IP - re-raise if it's our security error
                            if "restricted IP address" in str(ve):
                                raise
                            continue

                except socket.gaierror as e:
                    # If we can't resolve the hostname, it's suspicious
                    raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")

        except Exception as e:
            if isinstance(e, ValueError):
                raise
            raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")

    @property
    def client(self):
        """Lazy initialization of OpenAI client with security checks."""
        if self._client is None:
            client_kwargs = {
                "api_key": self.api_key,
            }

            if self.base_url:
                client_kwargs["base_url"] = self.base_url

            if self.organization:
                client_kwargs["organization"] = self.organization

            # Add default headers if any
            if self.DEFAULT_HEADERS:
                client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

            self._client = OpenAI(**client_kwargs)

        return self._client

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the OpenAI-compatible API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(
                f"Model '{model_name}' not in allowed models list. "
                f"Allowed models: {self.allowed_models}"
            )

        # Validate parameters
        self.validate_parameters(model_name, temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Prepare completion parameters
        completion_params = {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
        }

        # Add max tokens if specified
        if max_output_tokens:
            completion_params["max_tokens"] = max_output_tokens

        # Add any additional OpenAI-specific parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                completion_params[key] = value

        try:
            # Generate completion
            response = self.client.chat.completions.create(**completion_params)

            # Extract content and usage
            content = response.choices[0].message.content
            usage = self._extract_usage(response)

            return ModelResponse(
                content=content,
                usage=usage,
                model_name=model_name,
                friendly_name=self.FRIENDLY_NAME,
                provider=self.get_provider_type(),
                metadata={
                    "finish_reason": response.choices[0].finish_reason,
                    "model": response.model,  # Actual model used
                    "id": response.id,
                    "created": response.created,
                },
            )

        except Exception as e:
            # Log error and re-raise with more context
            error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text.

        Uses a layered approach:
        1. Try provider-specific token counting endpoint
        2. Try tiktoken for known model families
        3. Fall back to character-based estimation

        Args:
            text: Text to count tokens for
            model_name: Model name for tokenizer selection

        Returns:
            Estimated token count
        """
        # 1. Check if provider has a remote token counting endpoint
        if hasattr(self, "count_tokens_remote"):
            try:
                return self.count_tokens_remote(text, model_name)
            except Exception as e:
                logging.debug(f"Remote token counting failed: {e}")

        # 2. Try tiktoken for known models
        try:
            import tiktoken

            # Try to get encoding for the specific model
            try:
                encoding = tiktoken.encoding_for_model(model_name)
            except KeyError:
                # Try common encodings based on model patterns
                if "gpt-4" in model_name or "gpt-3.5" in model_name:
                    encoding = tiktoken.get_encoding("cl100k_base")
                else:
                    encoding = tiktoken.get_encoding("cl100k_base")  # Default

            return len(encoding.encode(text))

        except (ImportError, Exception) as e:
            logging.debug(f"Tiktoken not available or failed: {e}")

        # 3. Fall back to character-based estimation
        logging.warning(
            f"No specific tokenizer available for '{model_name}'. "
            "Using character-based estimation (~4 chars per token)."
        )
        return len(text) // 4

    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
        """Validate model parameters.

        For proxy providers, this may use generic capabilities.

        Args:
            model_name: Model to validate for
            temperature: Temperature to validate
            **kwargs: Additional parameters to validate
        """
        try:
            capabilities = self.get_capabilities(model_name)

            # Check if we're using generic capabilities
            if hasattr(capabilities, "_is_generic"):
                logging.debug(
                    f"Using generic parameter validation for {model_name}. "
                    "Actual model constraints may differ."
                )

            # Validate temperature using parent class method
            super().validate_parameters(model_name, temperature, **kwargs)

        except Exception as e:
            # For proxy providers, we might not have accurate capabilities
            # Log warning but don't fail
            logging.warning(f"Parameter validation limited for {model_name}: {e}")

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from OpenAI response.

        Args:
            response: OpenAI API response object

        Returns:
            Dictionary with usage statistics
        """
        usage = {}

        if hasattr(response, "usage") and response.usage:
            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)

        return usage

    @abstractmethod
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def get_provider_type(self) -> ProviderType:
        """Get the provider type.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported.

        Must be implemented by subclasses.
        """
        pass

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Default is False for OpenAI-compatible providers.
        """
        return False
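To make the architecture concrete, here is a hypothetical minimal subclass (the name and choices are illustrative, not part of the commit): only the three abstract methods need implementing, and client setup, SSRF checks, generation, and token counting are all inherited. The ModelCapabilities fields follow the same pattern the OpenRouter provider below uses.

```python
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
from providers.openai_compatible import OpenAICompatibleProvider


class ExampleCompatibleProvider(OpenAICompatibleProvider):
    """Hypothetical provider for some OpenAI-compatible endpoint."""

    FRIENDLY_NAME = "Example"

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        # Conservative generic defaults, same pattern as OpenRouterProvider
        return ModelCapabilities(
            provider=ProviderType.OPENAI,  # a real provider would add its own enum member
            model_name=model_name,
            friendly_name=self.FRIENDLY_NAME,
            max_tokens=32_768,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
        )

    def get_provider_type(self) -> ProviderType:
        return ProviderType.OPENAI  # placeholder for the sketch

    def validate_model_name(self, model_name: str) -> bool:
        return True  # accept anything; the remote API rejects unknown models
```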

providers/openrouter.py (new file)
@@ -0,0 +1,119 @@
"""OpenRouter provider implementation."""

import logging
import os

from .base import (
    ModelCapabilities,
    ProviderType,
    RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider


class OpenRouterProvider(OpenAICompatibleProvider):
    """OpenRouter unified API provider.

    OpenRouter provides access to multiple AI models through a single API endpoint.
    See https://openrouter.ai for available models and pricing.
    """

    FRIENDLY_NAME = "OpenRouter"

    # Custom headers required by OpenRouter
    DEFAULT_HEADERS = {
        "HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
        "X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize OpenRouter provider.

        Args:
            api_key: OpenRouter API key
            **kwargs: Additional configuration
        """
        # Always use OpenRouter's base URL
        super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs)

        # Log warning about model allow-list if not configured
        if not self.allowed_models:
            logging.warning(
                "OpenRouter provider initialized without model allow-list. "
                "Consider setting OPENROUTER_ALLOWED_MODELS environment variable "
                "to restrict model access and control costs."
            )

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a model.

        Since OpenRouter supports many models dynamically, we return
        generic capabilities with conservative defaults.

        Args:
            model_name: Name of the model

        Returns:
            Generic ModelCapabilities with warnings logged
        """
        logging.warning(
            f"Using generic capabilities for '{model_name}' via OpenRouter. "
            "Actual model capabilities may differ. Consider querying OpenRouter's "
            "/models endpoint for accurate information."
        )

        # Create generic capabilities with conservative defaults
        capabilities = ModelCapabilities(
            provider=ProviderType.OPENROUTER,
            model_name=model_name,
            friendly_name=self.FRIENDLY_NAME,
            max_tokens=32_768,  # Conservative default
            supports_extended_thinking=False,  # Most models don't support this
            supports_system_prompts=True,  # Most models support this
            supports_streaming=True,
            supports_function_calling=False,  # Varies by model
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
        )

        # Mark as generic for validation purposes
        capabilities._is_generic = True

        return capabilities

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.OPENROUTER

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is allowed.

        For OpenRouter, we accept any model name unless an allow-list
        is configured via OPENROUTER_ALLOWED_MODELS environment variable.

        Args:
            model_name: Model name to validate

        Returns:
            True if model is allowed, False otherwise
        """
        if self.allowed_models:
            # Case-insensitive validation against allow-list
            return model_name.lower() in self.allowed_models

        # Accept any model if no allow-list configured
        # The API will return an error if the model doesn't exist
        return True

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Currently, no models via OpenRouter support extended thinking.
        This may change as new models become available.

        Args:
            model_name: Model to check

        Returns:
            False (no OpenRouter models currently support thinking mode)
        """
        return False
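A quick sketch of the inherited SSRF guard in action, mirroring the tests added at the bottom of this commit (assumes pytest and the package are importable; the constructor's own URL check needs network access):

```python
import pytest

from providers.openrouter import OpenRouterProvider

provider = OpenRouterProvider(api_key="test-key")

# Pointing the base URL at a link-local/metadata address is rejected
provider.base_url = "http://169.254.169.254/api/v1"
with pytest.raises(ValueError, match="restricted IP"):
    provider._validate_base_url()
```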
providers/registry.py
@@ -117,6 +117,7 @@ class ModelProviderRegistry:
         key_mapping = {
             ProviderType.GOOGLE: "GEMINI_API_KEY",
             ProviderType.OPENAI: "OPENAI_API_KEY",
+            ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
         }
 
         env_var = key_mapping.get(provider_type)
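With this mapping, registration and retrieval work the same for OpenRouter as for the other providers. A sketch along the lines of the test added in this commit (the exact retrieval semantics are the registry's; this assumes it reads the key via the OPENROUTER_API_KEY mapping):

```python
import os
from unittest.mock import patch

from providers.base import ProviderType
from providers.openrouter import OpenRouterProvider
from providers.registry import ModelProviderRegistry

with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
    ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
    provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
    assert isinstance(provider, OpenRouterProvider)
```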

server.py
@@ -131,6 +131,7 @@ def configure_providers():
     from providers.base import ProviderType
     from providers.gemini import GeminiModelProvider
     from providers.openai import OpenAIModelProvider
+    from providers.openrouter import OpenRouterProvider
 
     valid_providers = []
 
@@ -148,12 +149,21 @@ def configure_providers():
         valid_providers.append("OpenAI (o3)")
         logger.info("OpenAI API key found - o3 model available")
 
+    # Check for OpenRouter API key
+    openrouter_key = os.getenv("OPENROUTER_API_KEY")
+    if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
+        ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
+        valid_providers.append("OpenRouter")
+        logger.info("OpenRouter API key found - Multiple models available via OpenRouter")
+
     # Require at least one valid provider
     if not valid_providers:
         raise ValueError(
             "At least one API key is required. Please set either:\n"
             "- GEMINI_API_KEY for Gemini models\n"
-            "- OPENAI_API_KEY for OpenAI o3 model"
+            "- OPENAI_API_KEY for OpenAI o3 model\n"
+            "- OPENROUTER_API_KEY for OpenRouter (multiple models)"
         )
 
     logger.info(f"Available providers: {', '.join(valid_providers)}")

Setup script
@@ -36,8 +36,6 @@ else
     else
         echo "⚠️  Found GEMINI_API_KEY in environment, but sed not available. Please update .env manually."
     fi
-else
-    echo "⚠️  GEMINI_API_KEY not found in environment. Please edit .env and add your API key."
 fi
 
 if [ -n "${OPENAI_API_KEY:-}" ]; then
@@ -48,8 +46,16 @@ else
     else
         echo "⚠️  Found OPENAI_API_KEY in environment, but sed not available. Please update .env manually."
     fi
+fi
+
+if [ -n "${OPENROUTER_API_KEY:-}" ]; then
+    # Replace the placeholder API key with the actual value
+    if command -v sed >/dev/null 2>&1; then
+        sed -i.bak "s/your_openrouter_api_key_here/$OPENROUTER_API_KEY/" .env && rm .env.bak
+        echo "✅ Updated .env with existing OPENROUTER_API_KEY from environment"
     else
-    echo "⚠️  OPENAI_API_KEY not found in environment. Please edit .env and add your API key."
+        echo "⚠️  Found OPENROUTER_API_KEY in environment, but sed not available. Please update .env manually."
+    fi
 fi
 
 # Update WORKSPACE_ROOT to use current user's home directory
@@ -92,6 +98,7 @@ source .env 2>/dev/null || true
 
 VALID_GEMINI_KEY=false
 VALID_OPENAI_KEY=false
+VALID_OPENROUTER_KEY=false
 
 # Check if GEMINI_API_KEY is set and not the placeholder
 if [ -n "${GEMINI_API_KEY:-}" ] && [ "$GEMINI_API_KEY" != "your_gemini_api_key_here" ]; then
@@ -105,18 +112,26 @@ if [ -n "${OPENAI_API_KEY:-}" ] && [ "$OPENAI_API_KEY" != "your_openai_api_key_here" ]; then
     echo "✅ Valid OPENAI_API_KEY found"
 fi
 
+# Check if OPENROUTER_API_KEY is set and not the placeholder
+if [ -n "${OPENROUTER_API_KEY:-}" ] && [ "$OPENROUTER_API_KEY" != "your_openrouter_api_key_here" ]; then
+    VALID_OPENROUTER_KEY=true
+    echo "✅ Valid OPENROUTER_API_KEY found"
+fi
+
 # Require at least one valid API key
-if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ]; then
+if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ] && [ "$VALID_OPENROUTER_KEY" = false ]; then
     echo ""
     echo "❌ ERROR: At least one valid API key is required!"
     echo ""
     echo "Please edit the .env file and set at least one of:"
     echo "  - GEMINI_API_KEY (get from https://makersuite.google.com/app/apikey)"
     echo "  - OPENAI_API_KEY (get from https://platform.openai.com/api-keys)"
+    echo "  - OPENROUTER_API_KEY (get from https://openrouter.ai/)"
     echo ""
     echo "Example:"
     echo "  GEMINI_API_KEY=your-actual-api-key-here"
     echo "  OPENAI_API_KEY=sk-your-actual-openai-key-here"
+    echo "  OPENROUTER_API_KEY=sk-or-your-actual-openrouter-key-here"
     echo ""
     exit 1
 fi
@@ -228,7 +243,7 @@ show_configuration_steps() {
     echo ""
     echo "🔄 Next steps:"
     NEEDS_KEY_UPDATE=false
-    if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null; then
+    if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null || grep -q "your_openrouter_api_key_here" .env 2>/dev/null; then
         NEEDS_KEY_UPDATE=true
     fi
 
@@ -236,6 +251,7 @@ show_configuration_steps() {
         echo "1. Edit .env and replace placeholder API keys with actual ones"
         echo "   - GEMINI_API_KEY: your-gemini-api-key-here"
         echo "   - OPENAI_API_KEY: your-openai-api-key-here"
+        echo "   - OPENROUTER_API_KEY: your-openrouter-api-key-here (optional)"
         echo "2. Restart services: $COMPOSE_CMD restart"
         echo "3. Copy the configuration below to your Claude Desktop config if required:"
     else

tests/test_openrouter_provider.py (new file)
@@ -0,0 +1,138 @@
"""Tests for OpenRouter provider."""

import os
import pytest
from unittest.mock import patch, MagicMock

from providers.base import ProviderType
from providers.openrouter import OpenRouterProvider
from providers.registry import ModelProviderRegistry


class TestOpenRouterProvider:
    """Test cases for OpenRouter provider."""

    def test_provider_initialization(self):
        """Test OpenRouter provider initialization."""
        provider = OpenRouterProvider(api_key="test-key")
        assert provider.api_key == "test-key"
        assert provider.base_url == "https://openrouter.ai/api/v1"
        assert provider.FRIENDLY_NAME == "OpenRouter"

    def test_custom_headers(self):
        """Test OpenRouter custom headers."""
        # Test default headers
        assert "HTTP-Referer" in OpenRouterProvider.DEFAULT_HEADERS
        assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS

        # Test with environment variables
        with patch.dict(os.environ, {
            "OPENROUTER_REFERER": "https://myapp.com",
            "OPENROUTER_TITLE": "My App"
        }):
            from importlib import reload
            import providers.openrouter
            reload(providers.openrouter)

            provider = providers.openrouter.OpenRouterProvider(api_key="test-key")
            assert provider.DEFAULT_HEADERS["HTTP-Referer"] == "https://myapp.com"
            assert provider.DEFAULT_HEADERS["X-Title"] == "My App"

    def test_model_validation_without_allowlist(self):
        """Test model validation without allow-list."""
        provider = OpenRouterProvider(api_key="test-key")

        # Should accept any model when no allow-list
        assert provider.validate_model_name("gpt-4") is True
        assert provider.validate_model_name("claude-3-opus") is True
        assert provider.validate_model_name("any-model-name") is True

    def test_model_validation_with_allowlist(self):
        """Test model validation with allow-list."""
        with patch.dict(os.environ, {
            "OPENROUTER_ALLOWED_MODELS": "gpt-4,claude-3-opus,mistral-large"
        }):
            provider = OpenRouterProvider(api_key="test-key")

            # Test allowed models (case-insensitive)
            assert provider.validate_model_name("gpt-4") is True
            assert provider.validate_model_name("GPT-4") is True
            assert provider.validate_model_name("claude-3-opus") is True
            assert provider.validate_model_name("MISTRAL-LARGE") is True

            # Test disallowed models
            assert provider.validate_model_name("gpt-3.5-turbo") is False
            assert provider.validate_model_name("unauthorized-model") is False

    def test_get_capabilities(self):
        """Test capability generation returns generic capabilities."""
        provider = OpenRouterProvider(api_key="test-key")

        # Should return generic capabilities for any model
        caps = provider.get_capabilities("gpt-4")
        assert caps.provider == ProviderType.OPENROUTER
        assert caps.model_name == "gpt-4"
        assert caps.friendly_name == "OpenRouter"
        assert caps.max_tokens == 32_768  # Safe default
        assert hasattr(caps, "_is_generic") and caps._is_generic is True

    def test_openrouter_registration(self):
        """Test OpenRouter can be registered and retrieved."""
        with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
            # Clean up any existing registration
            ModelProviderRegistry.unregister_provider(ProviderType.OPENROUTER)

            # Register the provider
            ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

            # Retrieve and verify
            provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
            assert provider is not None
            assert isinstance(provider, OpenRouterProvider)


class TestOpenRouterSSRFProtection:
    """Test SSRF protection for OpenRouter."""

    def test_url_validation_rejects_private_ips(self):
        """Test that private IPs are rejected."""
        provider = OpenRouterProvider(api_key="test-key")

        # List of private/dangerous IPs to test
        dangerous_urls = [
            "http://192.168.1.1/api/v1",
            "http://10.0.0.1/api/v1",
            "http://172.16.0.1/api/v1",
            "http://169.254.169.254/api/v1",  # AWS metadata
            "http://[::1]/api/v1",  # IPv6 localhost
            "http://0.0.0.0/api/v1",
        ]

        for url in dangerous_urls:
            with pytest.raises(ValueError, match="restricted IP|Invalid"):
                provider.base_url = url
                provider._validate_base_url()

    def test_url_validation_allows_public_domains(self):
        """Test that legitimate public domains are allowed."""
        provider = OpenRouterProvider(api_key="test-key")

        # OpenRouter's actual domain should always be allowed
        provider.base_url = "https://openrouter.ai/api/v1"
        provider._validate_base_url()  # Should not raise

    def test_invalid_url_schemes_rejected(self):
        """Test that non-HTTP(S) schemes are rejected."""
        provider = OpenRouterProvider(api_key="test-key")

        invalid_urls = [
            "ftp://example.com/api",
            "file:///etc/passwd",
            "gopher://example.com",
            "javascript:alert(1)",
        ]

        for url in invalid_urls:
            with pytest.raises(ValueError, match="Invalid URL scheme"):
                provider.base_url = url
                provider._validate_base_url()