Merge branch 'BeehiveInnovations:main' into feat/comprehensive-project-improvements
.env.example (15 lines changed)

@@ -2,12 +2,27 @@
 # Copy this file to .env and fill in your values
 
 # API Keys - At least one is required
+#
+# IMPORTANT: Use EITHER OpenRouter OR native APIs (Gemini/OpenAI), not both!
+# Having both creates ambiguity about which provider serves each model.
+#
+# Option 1: Use native APIs (recommended for direct access)
 # Get your Gemini API key from: https://makersuite.google.com/app/apikey
 GEMINI_API_KEY=your_gemini_api_key_here
 
 # Get your OpenAI API key from: https://platform.openai.com/api-keys
 OPENAI_API_KEY=your_openai_api_key_here
 
+# Option 2: Use OpenRouter for access to multiple models through one API
+# Get your OpenRouter API key from: https://openrouter.ai/
+# If using OpenRouter, comment out the native API keys above
+OPENROUTER_API_KEY=your_openrouter_api_key_here
+
+# Optional: Restrict which models can be used via OpenRouter (recommended for cost control)
+# Example: OPENROUTER_ALLOWED_MODELS=gpt-4,claude-3-opus,mistral-large
+# Leave empty to allow ANY model (not recommended - risk of high costs)
+OPENROUTER_ALLOWED_MODELS=
+
 # Optional: Default model to use
 # Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini'
 # When set to 'auto', Claude will select the best model for each task
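A quick sketch of how a comma-separated allow-list such as `OPENROUTER_ALLOWED_MODELS` is parsed; this mirrors `_parse_allowed_models` in `providers/openai_compatible.py` later in this diff, and the sample value is made up:

```python
# Sample value is hypothetical; parsing mirrors _parse_allowed_models below
models_str = "gpt-4, Claude-3-Opus,mistral-large"
allowed = {m.strip().lower() for m in models_str.split(",") if m.strip()}
print(allowed)  # {'gpt-4', 'claude-3-opus', 'mistral-large'}
```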
conf/openrouter_models.json (new file, 168 lines)

{
  "_README": {
    "description": "OpenRouter model configuration for Zen MCP Server",
    "documentation": "https://github.com/BeehiveInnovations/zen-mcp-server/blob/main/docs/openrouter.md",
    "instructions": [
      "Add new models by copying an existing entry and modifying it",
      "Aliases are case-insensitive and should be unique across all models",
      "context_window is the model's total context window size in tokens (input + output)",
      "Set supports_* flags based on the model's actual capabilities",
      "Models not listed here will use generic defaults (32K context window, basic features)"
    ],
    "field_descriptions": {
      "model_name": "The official OpenRouter model identifier (e.g., 'anthropic/claude-3-opus')",
      "aliases": "Array of short names users can type instead of the full model name",
      "context_window": "Total number of tokens the model can process (input + output combined)",
      "supports_extended_thinking": "Whether the model supports extended reasoning tokens (currently none do via OpenRouter)",
      "supports_json_mode": "Whether the model can guarantee valid JSON output",
      "supports_function_calling": "Whether the model supports function/tool calling",
      "description": "Human-readable description of the model"
    },
    "example_custom_model": {
      "model_name": "vendor/model-name-version",
      "aliases": ["shortname", "nickname", "abbrev"],
      "context_window": 128000,
      "supports_extended_thinking": false,
      "supports_json_mode": true,
      "supports_function_calling": true,
      "description": "Brief description of the model"
    }
  },
  "models": [
    {
      "model_name": "openai/gpt-4o",
      "aliases": ["gpt4o", "4o", "gpt-4o"],
      "context_window": 128000,
      "supports_extended_thinking": false,
      "supports_json_mode": true,
      "supports_function_calling": true,
      "description": "OpenAI's most capable model, GPT-4 Omni"
    },
    {
      "model_name": "openai/gpt-4o-mini",
      "aliases": ["gpt4o-mini", "4o-mini", "gpt-4o-mini"],
      "context_window": 128000,
      "supports_extended_thinking": false,
      "supports_json_mode": true,
      "supports_function_calling": true,
      "description": "Smaller, faster version of GPT-4o"
    },
    {
      "model_name": "anthropic/claude-3-opus",
      "aliases": ["opus", "claude-opus", "claude3-opus", "claude-3-opus"],
      "context_window": 200000,
      "supports_extended_thinking": false,
      "supports_json_mode": false,
      "supports_function_calling": false,
      "description": "Claude 3 Opus - Most capable Claude model"
    },
    {
      "model_name": "anthropic/claude-3-sonnet",
      "aliases": ["sonnet", "claude-sonnet", "claude3-sonnet", "claude-3-sonnet", "claude"],
      "context_window": 200000,
      "supports_extended_thinking": false,
      "supports_json_mode": false,
      "supports_function_calling": false,
      "description": "Claude 3 Sonnet - Balanced performance"
    },
    {
      "model_name": "anthropic/claude-3-haiku",
      "aliases": ["haiku", "claude-haiku", "claude3-haiku", "claude-3-haiku"],
      "context_window": 200000,
      "supports_extended_thinking": false,
      "supports_json_mode": false,
      "supports_function_calling": false,
      "description": "Claude 3 Haiku - Fast and efficient"
    },
    {
      "model_name": "google/gemini-pro-1.5",
      "aliases": ["pro", "gemini-pro", "gemini", "pro-openrouter"],
      "context_window": 1048576,
      "supports_extended_thinking": false,
      "supports_json_mode": true,
      "supports_function_calling": false,
      "description": "Google's Gemini Pro 1.5 via OpenRouter"
    },
    {
      "model_name": "google/gemini-flash-1.5-8b",
      "aliases": ["flash", "gemini-flash", "flash-openrouter", "flash-8b"],
      "context_window": 1048576,
      "supports_extended_thinking": false,
      "supports_json_mode": true,
      "supports_function_calling": false,
      "description": "Google's Gemini Flash 1.5 8B via OpenRouter"
    },
    {
      "model_name": "mistral/mistral-large",
      "aliases": ["mistral-large", "mistral"],
      "context_window": 128000,
      "supports_extended_thinking": false,
      "supports_json_mode": true,
      "supports_function_calling": true,
      "description": "Mistral's largest model"
    },
    {
      "model_name": "meta-llama/llama-3-70b",
      "aliases": ["llama", "llama3-70b", "llama-70b", "llama3"],
      "context_window": 8192,
      "supports_extended_thinking": false,
      "supports_json_mode": false,
      "supports_function_calling": false,
      "description": "Meta's Llama 3 70B model"
    },
    {
      "model_name": "cohere/command-r-plus",
      "aliases": ["command-r-plus", "command-r", "cohere"],
      "context_window": 128000,
      "supports_extended_thinking": false,
      "supports_json_mode": false,
      "supports_function_calling": true,
      "description": "Cohere's Command R Plus model"
    },
    {
      "model_name": "deepseek/deepseek-coder",
      "aliases": ["deepseek-coder", "deepseek", "coder"],
      "context_window": 16384,
      "supports_extended_thinking": false,
      "supports_json_mode": false,
      "supports_function_calling": false,
      "description": "DeepSeek's coding-focused model"
    },
    {
      "model_name": "perplexity/llama-3-sonar-large-32k-online",
      "aliases": ["perplexity", "sonar", "perplexity-online"],
      "context_window": 32768,
      "supports_extended_thinking": false,
      "supports_json_mode": false,
      "supports_function_calling": false,
      "description": "Perplexity's online model with web search"
    },
    {
      "model_name": "openai/o3",
      "aliases": ["o3"],
      "context_window": 200000,
      "supports_extended_thinking": false,
      "supports_json_mode": true,
      "supports_function_calling": true,
      "description": "OpenAI's o3 model - well-rounded and powerful across domains"
    },
    {
      "model_name": "openai/o3-mini",
      "aliases": ["o3-mini", "o3mini"],
      "context_window": 200000,
      "supports_extended_thinking": false,
      "supports_json_mode": true,
      "supports_function_calling": true,
      "description": "OpenAI's o3-mini reasoning model - cost-efficient with STEM performance"
    },
    {
      "model_name": "openai/o3-mini-high",
      "aliases": ["o3-mini-high", "o3mini-high"],
      "context_window": 200000,
      "supports_extended_thinking": false,
      "supports_json_mode": true,
      "supports_function_calling": true,
      "description": "OpenAI's o3-mini with high reasoning effort - optimized for complex problems"
    }
  ]
}
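Since aliases must be unique across all models, a standalone sanity check over this file can catch collisions before the server loads it. A minimal sketch, assuming the default config location (the server enforces the same rule in `OpenRouterModelRegistry._build_maps`, later in this diff):

```python
import json

# Assumes the default config path; duplicate detection mirrors _build_maps
with open("conf/openrouter_models.json") as f:
    data = json.load(f)

seen: dict[str, str] = {}
for model in data["models"]:
    for alias in model.get("aliases", []):
        key = alias.lower()  # aliases are case-insensitive
        if key in seen:
            raise ValueError(f"Duplicate alias {alias!r}: {seen[key]} vs {model['model_name']}")
        seen[key] = model["model_name"]
print(f"{len(seen)} unique aliases across {len(data['models'])} models")
```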
config.py (34 lines changed)

@@ -13,8 +13,8 @@ import os
 # Version and metadata
 # These values are used in server responses and for tracking releases
 # IMPORTANT: This is the single source of truth for version and author info
-__version__ = "4.0.0"  # Semantic versioning: MAJOR.MINOR.PATCH
-__updated__ = "2025-06-12"  # Last update date in ISO format
+__version__ = "4.1.0"  # Semantic versioning: MAJOR.MINOR.PATCH
+__updated__ = "2025-06-13"  # Last update date in ISO format
 __author__ = "Fahad Gilani"  # Primary maintainer
 
 # Model configuration
@@ -24,26 +24,6 @@ __author__ = "Fahad Gilani"  # Primary maintainer
 # Special value "auto" means Claude should pick the best model for each task
 DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
 
-# Validate DEFAULT_MODEL and set to "auto" if invalid
-# Only include actually supported models from providers
-VALID_MODELS = [
-    "auto",
-    "flash",
-    "pro",
-    "o3",
-    "o3-mini",
-    "gemini-2.5-flash-preview-05-20",
-    "gemini-2.5-pro-preview-06-05",
-]
-if DEFAULT_MODEL not in VALID_MODELS:
-    import logging
-
-    logger = logging.getLogger(__name__)
-    logger.warning(
-        f"Invalid DEFAULT_MODEL '{DEFAULT_MODEL}'. Setting to 'auto'. Valid options: {', '.join(VALID_MODELS)}"
-    )
-    DEFAULT_MODEL = "auto"
-
 # Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
 IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"
 
@@ -56,9 +36,17 @@ MODEL_CAPABILITIES_DESC = {
     "o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
     # Full model names also supported
     "gemini-2.5-flash-preview-05-20": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
-    "gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
+    "gemini-2.5-pro-preview-06-05": (
+        "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis"
+    ),
 }
 
+# Note: When only OpenRouter is configured, these model aliases automatically map to equivalent models:
+# - "flash" → "google/gemini-flash-1.5-8b"
+# - "pro" → "google/gemini-pro-1.5"
+# - "o3" → "openai/gpt-4o"
+# - "o3-mini" → "openai/gpt-4o-mini"
+
 # Token allocation for Gemini Pro (1M total capacity)
 # MAX_CONTEXT_TOKENS: Total model capacity
 # MAX_CONTENT_TOKENS: Available for prompts, conversation history, and files
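The alias note above, written out as a plain mapping for illustration only (this dict is not how the server stores the fallback):

```python
# Illustration only: the alias -> OpenRouter model fallback described in the comment above
OPENROUTER_ALIAS_FALLBACK = {
    "flash": "google/gemini-flash-1.5-8b",
    "pro": "google/gemini-pro-1.5",
    "o3": "openai/gpt-4o",
    "o3-mini": "openai/gpt-4o-mini",
}
```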
docker-compose.yml

@@ -31,6 +31,9 @@ services:
     environment:
       - GEMINI_API_KEY=${GEMINI_API_KEY:-}
       - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      # OpenRouter support
+      - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
+      - OPENROUTER_MODELS_PATH=${OPENROUTER_MODELS_PATH:-}
      - DEFAULT_MODEL=${DEFAULT_MODEL:-auto}
      - DEFAULT_THINKING_MODE_THINKDEEP=${DEFAULT_THINKING_MODE_THINKDEEP:-high}
      - REDIS_URL=redis://redis:6379/0
docs/openrouter.md (new file, 122 lines)

# OpenRouter Setup

OpenRouter provides unified access to multiple AI models (GPT-4, Claude, Mistral, etc.) through a single API.

## When to Use OpenRouter

**Use OpenRouter when you want:**
- Access to models not available through native APIs (GPT-4, Claude, Mistral, etc.)
- Simplified billing across multiple model providers
- Experimentation with various models without separate API keys

**Use native APIs (Gemini/OpenAI) when you want:**
- Direct access to specific providers without an intermediary
- Potentially lower latency and costs
- Access to the latest model features immediately upon release

**Important:** Don't use both OpenRouter and native APIs simultaneously - this creates ambiguity about which provider serves each model.

## Model Aliases

The server uses `conf/openrouter_models.json` to map convenient aliases to OpenRouter model names. Some popular aliases:

| Alias | Maps to OpenRouter Model |
|-------|-------------------------|
| `opus` | `anthropic/claude-3-opus` |
| `sonnet`, `claude` | `anthropic/claude-3-sonnet` |
| `haiku` | `anthropic/claude-3-haiku` |
| `gpt4o`, `4o` | `openai/gpt-4o` |
| `gpt4o-mini`, `4o-mini` | `openai/gpt-4o-mini` |
| `gemini`, `pro-openrouter` | `google/gemini-pro-1.5` |
| `flash-openrouter` | `google/gemini-flash-1.5-8b` |
| `mistral` | `mistral/mistral-large` |
| `deepseek`, `coder` | `deepseek/deepseek-coder` |
| `perplexity` | `perplexity/llama-3-sonar-large-32k-online` |

View the full list in [`conf/openrouter_models.json`](conf/openrouter_models.json).

**Note:** While you can use any OpenRouter model by its full name, models not in the config file will use generic capabilities (32K context window, no extended thinking, etc.), which may not match the model's actual capabilities. For best results, add new models to the config file with their proper specifications.

## Quick Start

### 1. Get API Key
1. Sign up at [openrouter.ai](https://openrouter.ai/)
2. Create an API key from your dashboard
3. Add credits to your account

### 2. Set Environment Variable
```bash
# Add to your .env file
OPENROUTER_API_KEY=your-openrouter-api-key
```

> **Note:** Control which models can be used directly in your OpenRouter dashboard at [openrouter.ai](https://openrouter.ai/).
> This gives you centralized control over model access and spending limits.

That's it! Docker Compose already includes all necessary configuration.

### 3. Use Models

**Using model aliases (from conf/openrouter_models.json):**
```
# Use short aliases:
"Use opus for deep analysis"         # → anthropic/claude-3-opus
"Use sonnet to review this code"     # → anthropic/claude-3-sonnet
"Use gpt4o via zen to analyze this"  # → openai/gpt-4o
"Use mistral via zen to optimize"    # → mistral/mistral-large
```

**Using full model names:**
```
# Any model available on OpenRouter:
"Use anthropic/claude-3-opus via zen for deep analysis"
"Use openai/gpt-4o via zen to debug this"
"Use deepseek/deepseek-coder via zen to generate code"
```

Check current model pricing at [openrouter.ai/models](https://openrouter.ai/models).

## Model Configuration

The server uses `conf/openrouter_models.json` to define model aliases and capabilities. You can:

1. **Use the default configuration** - Includes popular models with convenient aliases
2. **Customize the configuration** - Add your own models and aliases
3. **Override the config path** - Set the `OPENROUTER_MODELS_PATH` environment variable

### Adding Custom Models

Edit `conf/openrouter_models.json` to add new models:

```json
{
  "model_name": "vendor/model-name",
  "aliases": ["short-name", "nickname"],
  "context_window": 128000,
  "supports_extended_thinking": false,
  "supports_json_mode": true,
  "supports_function_calling": true,
  "description": "Model description"
}
```

**Field explanations:**
- `context_window`: Total tokens the model can process (input + output combined)
- `supports_extended_thinking`: Whether the model has extended reasoning capabilities
- `supports_json_mode`: Whether the model can guarantee valid JSON output
- `supports_function_calling`: Whether the model supports function/tool calling

## Available Models

Popular models available through OpenRouter:
- **GPT-4** - OpenAI's most capable model
- **Claude 3** - Anthropic's models (Opus, Sonnet, Haiku)
- **Mistral** - Including Mistral Large
- **Llama 3** - Meta's open models
- Many more at [openrouter.ai/models](https://openrouter.ai/models)

## Troubleshooting

- **"Model not found"**: Check the exact model name at [openrouter.ai/models](https://openrouter.ai/models)
- **"Insufficient credits"**: Add credits to your OpenRouter account
- **"Model not available"**: Check your OpenRouter dashboard for model access permissions
@@ -141,7 +141,11 @@ trace issues to their root cause, and provide actionable solutions.
 IMPORTANT: If you lack critical information to proceed (e.g., missing files, ambiguous error details,
 insufficient context), OR if the provided diagnostics (log files, crash reports, stack traces) appear irrelevant,
 incomplete, or insufficient for proper analysis, you MUST respond ONLY with this JSON format:
-{"status": "requires_clarification", "question": "What specific information you need from Claude or the user to proceed with debugging", "files_needed": ["file1.py", "file2.py"]}
+{
+  "status": "requires_clarification",
+  "question": "What specific information you need from Claude or the user to proceed with debugging",
+  "files_needed": ["file1.py", "file2.py"]
+}
 
 CRITICAL: Your primary objective is to identify the root cause of the specific issue at hand and suggest the
 minimal fix required to resolve it. Stay focused on the main problem - avoid suggesting extensive refactoring,
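A hypothetical caller-side check for the clarification shape above; the function name and error handling are assumptions, not part of this commit:

```python
import json

def needs_clarification(raw: str) -> bool:
    """Return True if the model replied with the requires_clarification JSON."""
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        return False  # ordinary (non-JSON) analysis output
    return payload.get("status") == "requires_clarification"
```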
providers/__init__.py

@@ -3,6 +3,8 @@
 from .base import ModelCapabilities, ModelProvider, ModelResponse
 from .gemini import GeminiModelProvider
 from .openai import OpenAIModelProvider
+from .openai_compatible import OpenAICompatibleProvider
+from .openrouter import OpenRouterProvider
 from .registry import ModelProviderRegistry
 
 __all__ = [
@@ -12,4 +14,6 @@ __all__ = [
     "ModelProviderRegistry",
     "GeminiModelProvider",
     "OpenAIModelProvider",
+    "OpenAICompatibleProvider",
+    "OpenRouterProvider",
 ]
providers/base.py

@@ -11,6 +11,7 @@ class ProviderType(Enum):
 
     GOOGLE = "google"
     OPENAI = "openai"
+    OPENROUTER = "openrouter"
 
 
 class TemperatureConstraint(ABC):
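The enum value is what drives the per-provider environment variable names; a small sketch (assuming it runs from the repo root so the `providers` package is importable):

```python
from providers.base import ProviderType

# Mirrors the naming scheme in _parse_allowed_models (providers/openai_compatible.py)
provider_type = ProviderType.OPENROUTER.value.upper()  # "OPENROUTER"
env_var = f"{provider_type}_ALLOWED_MODELS"
print(env_var)  # OPENROUTER_ALLOWED_MODELS
```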
providers/openai.py

@@ -1,22 +1,16 @@
 """OpenAI model provider implementation."""
 
-import logging
-from typing import Optional
-
-from openai import OpenAI
-
 from .base import (
     FixedTemperatureConstraint,
     ModelCapabilities,
-    ModelProvider,
-    ModelResponse,
     ProviderType,
     RangeTemperatureConstraint,
 )
+from .openai_compatible import OpenAICompatibleProvider
 
 
-class OpenAIModelProvider(ModelProvider):
-    """OpenAI model provider implementation."""
+class OpenAIModelProvider(OpenAICompatibleProvider):
+    """Official OpenAI API provider (api.openai.com)."""
 
     # Model configurations
     SUPPORTED_MODELS = {
@@ -32,23 +26,9 @@ class OpenAIModelProvider(ModelProvider):
 
     def __init__(self, api_key: str, **kwargs):
         """Initialize OpenAI provider with API key."""
+        # Set default OpenAI base URL, allow override for regions/custom endpoints
+        kwargs.setdefault("base_url", "https://api.openai.com/v1")
         super().__init__(api_key, **kwargs)
-        self._client = None
-        self.base_url = kwargs.get("base_url")  # Support custom endpoints
-        self.organization = kwargs.get("organization")
-
-    @property
-    def client(self):
-        """Lazy initialization of OpenAI client."""
-        if self._client is None:
-            client_kwargs = {"api_key": self.api_key}
-            if self.base_url:
-                client_kwargs["base_url"] = self.base_url
-            if self.organization:
-                client_kwargs["organization"] = self.organization
-
-            self._client = OpenAI(**client_kwargs)
-        return self._client
 
     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         """Get capabilities for a specific OpenAI model."""
@@ -77,80 +57,6 @@ class OpenAIModelProvider(ModelProvider):
             temperature_constraint=temp_constraint,
         )
 
-    def generate_content(
-        self,
-        prompt: str,
-        model_name: str,
-        system_prompt: Optional[str] = None,
-        temperature: float = 0.7,
-        max_output_tokens: Optional[int] = None,
-        **kwargs,
-    ) -> ModelResponse:
-        """Generate content using OpenAI model."""
-        # Validate parameters
-        self.validate_parameters(model_name, temperature)
-
-        # Prepare messages
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": prompt})
-
-        # Prepare completion parameters
-        completion_params = {
-            "model": model_name,
-            "messages": messages,
-            "temperature": temperature,
-        }
-
-        # Add max tokens if specified
-        if max_output_tokens:
-            completion_params["max_tokens"] = max_output_tokens
-
-        # Add any additional OpenAI-specific parameters
-        for key, value in kwargs.items():
-            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop"]:
-                completion_params[key] = value
-
-        try:
-            # Generate completion
-            response = self.client.chat.completions.create(**completion_params)
-
-            # Extract content and usage
-            content = response.choices[0].message.content
-            usage = self._extract_usage(response)
-
-            return ModelResponse(
-                content=content,
-                usage=usage,
-                model_name=model_name,
-                friendly_name="OpenAI",
-                provider=ProviderType.OPENAI,
-                metadata={
-                    "finish_reason": response.choices[0].finish_reason,
-                    "model": response.model,  # Actual model used (in case of fallbacks)
-                    "id": response.id,
-                    "created": response.created,
-                },
-            )
-
-        except Exception as e:
-            # Log error and re-raise with more context
-            error_msg = f"OpenAI API error for model {model_name}: {str(e)}"
-            logging.error(error_msg)
-            raise RuntimeError(error_msg) from e
-
-    def count_tokens(self, text: str, model_name: str) -> int:
-        """Count tokens for the given text.
-
-        Note: For accurate token counting, we should use tiktoken library.
-        This is a simplified estimation.
-        """
-        # TODO: Implement proper token counting with tiktoken
-        # For now, use rough estimation
-        # O3 models ~4 chars per token
-        return len(text) // 4
-
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
         return ProviderType.OPENAI
@@ -164,14 +70,3 @@ class OpenAIModelProvider(ModelProvider):
         # Currently no OpenAI models support extended thinking
         # This may change with future O3 models
         return False
-
-    def _extract_usage(self, response) -> dict[str, int]:
-        """Extract token usage from OpenAI response."""
-        usage = {}
-
-        if hasattr(response, "usage") and response.usage:
-            usage["input_tokens"] = response.usage.prompt_tokens
-            usage["output_tokens"] = response.usage.completion_tokens
-            usage["total_tokens"] = response.usage.total_tokens
-
-        return usage
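Hypothetical usage after the refactor; the API key is a placeholder, and client handling plus `generate_content` now come from `OpenAICompatibleProvider`:

```python
from providers.openai import OpenAIModelProvider

provider = OpenAIModelProvider(api_key="sk-...")  # base_url defaults to https://api.openai.com/v1
response = provider.generate_content(prompt="Say hello", model_name="o3-mini")
print(response.content, response.usage)
```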
providers/openai_compatible.py (new file, 414 lines)

"""Base class for OpenAI-compatible API providers."""

import ipaddress
import logging
import os
import socket
from abc import abstractmethod
from typing import Optional
from urllib.parse import urlparse

from openai import OpenAI

from .base import (
    ModelCapabilities,
    ModelProvider,
    ModelResponse,
    ProviderType,
)


class OpenAICompatibleProvider(ModelProvider):
    """Base class for any provider using an OpenAI-compatible API.

    This includes:
    - Direct OpenAI API
    - OpenRouter
    - Any other OpenAI-compatible endpoint
    """

    DEFAULT_HEADERS = {}
    FRIENDLY_NAME = "OpenAI Compatible"

    def __init__(self, api_key: str, base_url: str = None, **kwargs):
        """Initialize the provider with API key and optional base URL.

        Args:
            api_key: API key for authentication
            base_url: Base URL for the API endpoint
            **kwargs: Additional configuration options
        """
        super().__init__(api_key, **kwargs)
        self._client = None
        self.base_url = base_url
        self.organization = kwargs.get("organization")
        self.allowed_models = self._parse_allowed_models()

        # Validate base URL for security
        if self.base_url:
            self._validate_base_url()

        # Warn if using external URL without authentication
        if self.base_url and not self._is_localhost_url() and not api_key:
            logging.warning(
                f"Using external URL '{self.base_url}' without API key. "
                "This may be insecure. Consider setting an API key for authentication."
            )

    def _parse_allowed_models(self) -> Optional[set[str]]:
        """Parse allowed models from environment variable.

        Returns:
            Set of allowed model names (lowercase) or None if not configured
        """
        # Get provider-specific allowed models
        provider_type = self.get_provider_type().value.upper()
        env_var = f"{provider_type}_ALLOWED_MODELS"
        models_str = os.getenv(env_var, "")

        if models_str:
            # Parse and normalize to lowercase for case-insensitive comparison
            models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
            if models:
                logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
                return models

        # Log warning if no allow-list configured for proxy providers
        if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
            logging.warning(
                f"No model allow-list configured for {self.FRIENDLY_NAME}. "
                f"Set {env_var} to restrict model access and control costs."
            )

        return None

    def _is_localhost_url(self) -> bool:
        """Check if the base URL points to localhost.

        Returns:
            True if URL is localhost, False otherwise
        """
        if not self.base_url:
            return False

        try:
            parsed = urlparse(self.base_url)
            hostname = parsed.hostname

            # Check for common localhost patterns
            if hostname in ["localhost", "127.0.0.1", "::1"]:
                return True

            return False
        except Exception:
            return False

    def _validate_base_url(self) -> None:
        """Validate base URL for security (SSRF protection).

        Raises:
            ValueError: If URL is invalid or potentially unsafe
        """
        if not self.base_url:
            return

        try:
            parsed = urlparse(self.base_url)

            # Check URL scheme - only allow http/https
            if parsed.scheme not in ("http", "https"):
                raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")

            # Check hostname exists
            if not parsed.hostname:
                raise ValueError("URL must include a hostname")

            # Check port - allow only standard HTTP/HTTPS ports
            port = parsed.port
            if port is None:
                port = 443 if parsed.scheme == "https" else 80

            # Allow common HTTP ports and some alternative ports
            allowed_ports = {80, 443, 8080, 8443, 4000, 3000}  # Common API ports
            if port not in allowed_ports:
                raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")

            # Check against allowed domains if configured
            allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
            allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]

            if allowed_domains:
                hostname_lower = parsed.hostname.lower()
                if not any(
                    hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
                ):
                    raise ValueError(
                        f"Domain not in allow-list: {parsed.hostname}. Allowed domains: {allowed_domains}"
                    )

            # Try to resolve hostname and check if it's a private IP
            # Skip for localhost addresses which are commonly used for development
            if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
                try:
                    # Get all IP addresses for the hostname
                    addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)

                    for _family, _, _, _, sockaddr in addr_info:
                        ip_str = sockaddr[0]
                        try:
                            ip = ipaddress.ip_address(ip_str)

                            # Check for dangerous IP ranges
                            if (
                                ip.is_private
                                or ip.is_loopback
                                or ip.is_link_local
                                or ip.is_multicast
                                or ip.is_reserved
                                or ip.is_unspecified
                            ):
                                raise ValueError(
                                    f"URL resolves to restricted IP address: {ip_str}. "
                                    "This could be a security risk (SSRF)."
                                )
                        except ValueError as ve:
                            # Invalid IP address format or restricted IP - re-raise if it's our security error
                            if "restricted IP address" in str(ve):
                                raise
                            continue

                except socket.gaierror as e:
                    # If we can't resolve the hostname, it's suspicious
                    raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")

        except Exception as e:
            if isinstance(e, ValueError):
                raise
            raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")

    @property
    def client(self):
        """Lazy initialization of OpenAI client with security checks."""
        if self._client is None:
            client_kwargs = {
                "api_key": self.api_key,
            }

            if self.base_url:
                client_kwargs["base_url"] = self.base_url

            if self.organization:
                client_kwargs["organization"] = self.organization

            # Add default headers if any
            if self.DEFAULT_HEADERS:
                client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

            self._client = OpenAI(**client_kwargs)

        return self._client

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the OpenAI-compatible API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(
                f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}"
            )

        # Validate parameters
        self.validate_parameters(model_name, temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Prepare completion parameters
        completion_params = {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
        }

        # Add max tokens if specified
        if max_output_tokens:
            completion_params["max_tokens"] = max_output_tokens

        # Add any additional OpenAI-specific parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                completion_params[key] = value

        try:
            # Generate completion
            response = self.client.chat.completions.create(**completion_params)

            # Extract content and usage
            content = response.choices[0].message.content
            usage = self._extract_usage(response)

            return ModelResponse(
                content=content,
                usage=usage,
                model_name=model_name,
                friendly_name=self.FRIENDLY_NAME,
                provider=self.get_provider_type(),
                metadata={
                    "finish_reason": response.choices[0].finish_reason,
                    "model": response.model,  # Actual model used
                    "id": response.id,
                    "created": response.created,
                },
            )

        except Exception as e:
            # Log error and re-raise with more context
            error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text.

        Uses a layered approach:
        1. Try provider-specific token counting endpoint
        2. Try tiktoken for known model families
        3. Fall back to character-based estimation

        Args:
            text: Text to count tokens for
            model_name: Model name for tokenizer selection

        Returns:
            Estimated token count
        """
        # 1. Check if provider has a remote token counting endpoint
        if hasattr(self, "count_tokens_remote"):
            try:
                return self.count_tokens_remote(text, model_name)
            except Exception as e:
                logging.debug(f"Remote token counting failed: {e}")

        # 2. Try tiktoken for known models
        try:
            import tiktoken

            # Try to get encoding for the specific model
            try:
                encoding = tiktoken.encoding_for_model(model_name)
            except KeyError:
                # Try common encodings based on model patterns
                if "gpt-4" in model_name or "gpt-3.5" in model_name:
                    encoding = tiktoken.get_encoding("cl100k_base")
                else:
                    encoding = tiktoken.get_encoding("cl100k_base")  # Default

            return len(encoding.encode(text))

        except (ImportError, Exception) as e:
            logging.debug(f"Tiktoken not available or failed: {e}")

        # 3. Fall back to character-based estimation
        logging.warning(
            f"No specific tokenizer available for '{model_name}'. "
            "Using character-based estimation (~4 chars per token)."
        )
        return len(text) // 4

    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
        """Validate model parameters.

        For proxy providers, this may use generic capabilities.

        Args:
            model_name: Model to validate for
            temperature: Temperature to validate
            **kwargs: Additional parameters to validate
        """
        try:
            capabilities = self.get_capabilities(model_name)

            # Check if we're using generic capabilities
            if hasattr(capabilities, "_is_generic"):
                logging.debug(
                    f"Using generic parameter validation for {model_name}. Actual model constraints may differ."
                )

            # Validate temperature using parent class method
            super().validate_parameters(model_name, temperature, **kwargs)

        except Exception as e:
            # For proxy providers, we might not have accurate capabilities
            # Log warning but don't fail
            logging.warning(f"Parameter validation limited for {model_name}: {e}")

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from OpenAI response.

        Args:
            response: OpenAI API response object

        Returns:
            Dictionary with usage statistics
        """
        usage = {}

        if hasattr(response, "usage") and response.usage:
            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)

        return usage

    @abstractmethod
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def get_provider_type(self) -> ProviderType:
        """Get the provider type.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported.

        Must be implemented by subclasses.
        """
        pass

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Default is False for OpenAI-compatible providers.
        """
        return False
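A minimal sketch of a new backend built on this base class; the class name, endpoint URL, and capability values are placeholders, and it reuses `ProviderType.OPENAI` only because the sketch adds no new enum member:

```python
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
from providers.openai_compatible import OpenAICompatibleProvider


class LocalServerProvider(OpenAICompatibleProvider):
    """Hypothetical provider for a local OpenAI-compatible server."""

    FRIENDLY_NAME = "Local Server"

    def __init__(self, api_key: str, **kwargs):
        # localhost skips the SSRF resolution check; port 8080 is in allowed_ports
        kwargs.setdefault("base_url", "http://localhost:8080/v1")
        super().__init__(api_key, **kwargs)

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        # Conservative placeholder capabilities for any model name
        return ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name=model_name,
            friendly_name=self.FRIENDLY_NAME,
            max_tokens=32_768,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
        )

    def get_provider_type(self) -> ProviderType:
        return ProviderType.OPENAI

    def validate_model_name(self, model_name: str) -> bool:
        return True  # let the local server decide
```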
providers/openrouter.py (new file, 191 lines)

"""OpenRouter provider implementation."""

import logging
import os
from typing import Optional

from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry


class OpenRouterProvider(OpenAICompatibleProvider):
    """OpenRouter unified API provider.

    OpenRouter provides access to multiple AI models through a single API endpoint.
    See https://openrouter.ai for available models and pricing.
    """

    FRIENDLY_NAME = "OpenRouter"

    # Custom headers required by OpenRouter
    DEFAULT_HEADERS = {
        "HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
        "X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
    }

    # Model registry for managing configurations and aliases
    _registry: Optional[OpenRouterModelRegistry] = None

    def __init__(self, api_key: str, **kwargs):
        """Initialize OpenRouter provider.

        Args:
            api_key: OpenRouter API key
            **kwargs: Additional configuration
        """
        # Always use OpenRouter's base URL
        super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs)

        # Initialize model registry
        if OpenRouterProvider._registry is None:
            OpenRouterProvider._registry = OpenRouterModelRegistry()

            # Log loaded models and aliases
            models = self._registry.list_models()
            aliases = self._registry.list_aliases()
            logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")

    def _parse_allowed_models(self) -> None:
        """Override to disable environment-based allow-list.

        OpenRouter model access is controlled via the OpenRouter dashboard,
        not through environment variables.
        """
        return None

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model aliases to OpenRouter model names.

        Args:
            model_name: Input model name or alias

        Returns:
            Resolved OpenRouter model name
        """
        # Try to resolve through registry
        config = self._registry.resolve(model_name)

        if config:
            if config.model_name != model_name:
                logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
            return config.model_name
        else:
            # If not found in registry, return as-is
            # This allows using models not in our config file
            logging.debug(f"Model '{model_name}' not found in registry, using as-is")
            return model_name

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a model.

        Args:
            model_name: Name of the model (or alias)

        Returns:
            ModelCapabilities from registry or generic defaults
        """
        # Try to get from registry first
        capabilities = self._registry.get_capabilities(model_name)

        if capabilities:
            return capabilities
        else:
            # Resolve any potential aliases and create generic capabilities
            resolved_name = self._resolve_model_name(model_name)

            logging.debug(
                f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
                "Consider adding to openrouter_models.json for specific capabilities."
            )

            # Create generic capabilities with conservative defaults
            capabilities = ModelCapabilities(
                provider=ProviderType.OPENROUTER,
                model_name=resolved_name,
                friendly_name=self.FRIENDLY_NAME,
                max_tokens=32_768,  # Conservative default context window
                supports_extended_thinking=False,
                supports_system_prompts=True,
                supports_streaming=True,
                supports_function_calling=False,
                temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
            )

            # Mark as generic for validation purposes
            capabilities._is_generic = True

            return capabilities

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.OPENROUTER

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is allowed.

        For OpenRouter, we accept any model name. OpenRouter will
        validate based on the API key's permissions.

        Args:
            model_name: Model name to validate

        Returns:
            Always True - OpenRouter handles validation
        """
        # Accept any model name - OpenRouter will validate based on API key permissions
        return True

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the OpenRouter API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model (or alias) to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Resolve model alias to actual OpenRouter model name
        resolved_model = self._resolve_model_name(model_name)

        # Call parent method with resolved model name
        return super().generate_content(
            prompt=prompt,
            model_name=resolved_model,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Currently, no models via OpenRouter support extended thinking.
        This may change as new models become available.

        Args:
            model_name: Model to check

        Returns:
            False (no OpenRouter models currently support thinking mode)
        """
        return False
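Hypothetical usage; the key is a placeholder, and the alias comes from `conf/openrouter_models.json`:

```python
from providers.openrouter import OpenRouterProvider

provider = OpenRouterProvider(api_key="sk-or-...")
response = provider.generate_content(
    prompt="Review this function for bugs",
    model_name="opus",  # registry resolves this to "anthropic/claude-3-opus"
)
print(response.model_name)  # the resolved OpenRouter model name
```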
184
providers/openrouter_registry.py
Normal file
184
providers/openrouter_registry.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""OpenRouter model registry for managing model configurations and aliases."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
from typing import Optional

from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint


@dataclass
class OpenRouterModelConfig:
    """Configuration for an OpenRouter model."""

    model_name: str
    aliases: list[str] = field(default_factory=list)
    context_window: int = 32768  # Total context window size in tokens
    supports_extended_thinking: bool = False
    supports_system_prompts: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_json_mode: bool = False
    description: str = ""

    def to_capabilities(self) -> ModelCapabilities:
        """Convert to ModelCapabilities object."""
        return ModelCapabilities(
            provider=ProviderType.OPENROUTER,
            model_name=self.model_name,
            friendly_name="OpenRouter",
            max_tokens=self.context_window,  # ModelCapabilities still uses max_tokens
            supports_extended_thinking=self.supports_extended_thinking,
            supports_system_prompts=self.supports_system_prompts,
            supports_streaming=self.supports_streaming,
            supports_function_calling=self.supports_function_calling,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
        )


class OpenRouterModelRegistry:
    """Registry for managing OpenRouter model configurations and aliases."""

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the registry.

        Args:
            config_path: Path to config file. If None, uses default locations.
        """
        self.alias_map: dict[str, str] = {}  # alias -> model_name
        self.model_map: dict[str, OpenRouterModelConfig] = {}  # model_name -> config

        # Determine config path
        if config_path:
            self.config_path = Path(config_path)
        else:
            # Check environment variable first
            env_path = os.getenv("OPENROUTER_MODELS_PATH")
            if env_path:
                self.config_path = Path(env_path)
            else:
                # Default to conf/openrouter_models.json
                self.config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"

        # Load configuration
        self.reload()

    def reload(self) -> None:
        """Reload configuration from disk."""
        try:
            configs = self._read_config()
            self._build_maps(configs)
            logging.info(f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases")
        except ValueError as e:
            # Re-raise ValueError only for duplicate aliases (critical config errors)
            logging.error(f"Failed to load OpenRouter model configuration: {e}")
            # Initialize with empty maps on failure
            self.alias_map = {}
            self.model_map = {}
            if "Duplicate alias" in str(e):
                raise
        except Exception as e:
            logging.error(f"Failed to load OpenRouter model configuration: {e}")
            # Initialize with empty maps on failure
            self.alias_map = {}
            self.model_map = {}

    def _read_config(self) -> list[OpenRouterModelConfig]:
        """Read configuration from file.

        Returns:
            List of model configurations
        """
        if not self.config_path.exists():
            logging.warning(f"OpenRouter model config not found at {self.config_path}")
            return []

        try:
            with open(self.config_path) as f:
                data = json.load(f)

            # Parse models
            configs = []
            for model_data in data.get("models", []):
                # Handle backwards compatibility - rename max_tokens to context_window
                if "max_tokens" in model_data and "context_window" not in model_data:
                    model_data["context_window"] = model_data.pop("max_tokens")

                config = OpenRouterModelConfig(**model_data)
                configs.append(config)

            return configs
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
        except Exception as e:
            raise ValueError(f"Error reading config from {self.config_path}: {e}")

    def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
        """Build alias and model maps from configurations.

        Args:
            configs: List of model configurations
        """
        alias_map = {}
        model_map = {}

        for config in configs:
            # Add to model map
            model_map[config.model_name] = config

            # Add aliases
            for alias in config.aliases:
                alias_lower = alias.lower()
                if alias_lower in alias_map:
                    existing_model = alias_map[alias_lower]
                    raise ValueError(
                        f"Duplicate alias '{alias}' found for models " f"'{existing_model}' and '{config.model_name}'"
                    )
                alias_map[alias_lower] = config.model_name

        # Atomic update
        self.alias_map = alias_map
        self.model_map = model_map

    def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
        """Resolve a model name or alias to configuration.

        Args:
            name_or_alias: Model name or alias to resolve

        Returns:
            Model configuration if found, None otherwise
        """
        # Try alias first (case-insensitive)
        alias_lower = name_or_alias.lower()
        if alias_lower in self.alias_map:
            model_name = self.alias_map[alias_lower]
            return self.model_map.get(model_name)

        # Try as direct model name
        return self.model_map.get(name_or_alias)

    def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
        """Get model capabilities for a name or alias.

        Args:
            name_or_alias: Model name or alias

        Returns:
            ModelCapabilities if found, None otherwise
        """
        config = self.resolve(name_or_alias)
        if config:
            return config.to_capabilities()
        return None

    def list_models(self) -> list[str]:
        """List all available model names."""
        return list(self.model_map.keys())

    def list_aliases(self) -> list[str]:
        """List all available aliases."""
        return list(self.alias_map.keys())
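As a quick illustration of how this registry is meant to be consumed (a minimal sketch based only on the class above; the "opus" alias is an example entry from conf/openrouter_models.json, not a guaranteed one):

registry = OpenRouterModelRegistry()

caps = registry.get_capabilities("opus")  # alias lookup is case-insensitive
if caps:
    print(caps.model_name)  # e.g. "anthropic/claude-3-opus"
    print(caps.max_tokens)  # context window carried over from the config

print(registry.list_models())   # full OpenRouter model identifiers
print(registry.list_aliases())  # lowercase aliases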
@@ -117,6 +117,7 @@ class ModelProviderRegistry:
         key_mapping = {
             ProviderType.GOOGLE: "GEMINI_API_KEY",
             ProviderType.OPENAI: "OPENAI_API_KEY",
+            ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
         }
 
         env_var = key_mapping.get(provider_type)
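For context, the new entry is consumed by the surrounding key lookup; a minimal sketch of that flow (illustrative only, assuming os is imported in this module):

env_var = key_mapping.get(ProviderType.OPENROUTER)  # -> "OPENROUTER_API_KEY"
api_key = os.getenv(env_var) if env_var else None   # provider counts as configured only if the key is set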
88
server.py
@@ -125,39 +125,83 @@ def configure_providers():
     At least one valid API key (Gemini or OpenAI) is required.
 
     Raises:
-        ValueError: If no valid API keys are found
+        ValueError: If no valid API keys are found or conflicting configurations detected
     """
     from providers import ModelProviderRegistry
     from providers.base import ProviderType
     from providers.gemini import GeminiModelProvider
     from providers.openai import OpenAIModelProvider
+    from providers.openrouter import OpenRouterProvider
 
     valid_providers = []
+    has_native_apis = False
+    has_openrouter = False
 
     # Check for Gemini API key
     gemini_key = os.getenv("GEMINI_API_KEY")
     if gemini_key and gemini_key != "your_gemini_api_key_here":
-        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
         valid_providers.append("Gemini")
+        has_native_apis = True
         logger.info("Gemini API key found - Gemini models available")
 
     # Check for OpenAI API key
     openai_key = os.getenv("OPENAI_API_KEY")
     if openai_key and openai_key != "your_openai_api_key_here":
-        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
         valid_providers.append("OpenAI (o3)")
+        has_native_apis = True
         logger.info("OpenAI API key found - o3 model available")
 
+    # Check for OpenRouter API key
+    openrouter_key = os.getenv("OPENROUTER_API_KEY")
+    if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
+        valid_providers.append("OpenRouter")
+        has_openrouter = True
+        logger.info("OpenRouter API key found - Multiple models available via OpenRouter")
+
+    # Check for conflicting configuration
+    if has_native_apis and has_openrouter:
+        logger.warning(
+            "\n" + "=" * 70 + "\n"
+            "WARNING: Both OpenRouter and native API keys detected!\n"
+            "\n"
+            "This creates ambiguity about which provider will be used for models\n"
+            "available through both APIs (e.g., 'o3' could come from OpenAI or OpenRouter).\n"
+            "\n"
+            "RECOMMENDATION: Use EITHER OpenRouter OR native APIs, not both.\n"
+            "\n"
+            "To fix this:\n"
+            "1. Use only OpenRouter: unset GEMINI_API_KEY and OPENAI_API_KEY\n"
+            "2. Use only native APIs: unset OPENROUTER_API_KEY\n"
+            "\n"
+            "Current configuration will prioritize native APIs over OpenRouter.\n" + "=" * 70 + "\n"
+        )
+
+    # Register providers - native APIs first to ensure they take priority
+    if has_native_apis:
+        if gemini_key and gemini_key != "your_gemini_api_key_here":
+            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+        if openai_key and openai_key != "your_openai_api_key_here":
+            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+
+    # Register OpenRouter last so native APIs take precedence
+    if has_openrouter:
+        ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
+
     # Require at least one valid provider
     if not valid_providers:
         raise ValueError(
             "At least one API key is required. Please set either:\n"
             "- GEMINI_API_KEY for Gemini models\n"
-            "- OPENAI_API_KEY for OpenAI o3 model"
+            "- OPENAI_API_KEY for OpenAI o3 model\n"
+            "- OPENROUTER_API_KEY for OpenRouter (multiple models)"
         )
 
     logger.info(f"Available providers: {', '.join(valid_providers)}")
+
+    # Log provider priority if both are configured
+    if has_native_apis and has_openrouter:
+        logger.info("Provider priority: Native APIs (Gemini, OpenAI) will be checked before OpenRouter")
 
 
 @server.list_tools()
 async def handle_list_tools() -> list[Tool]:
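A sketch of the lookup priority that falls out of this registration order (hedged illustration; get_provider appears elsewhere in this diff, the strings are ours):

from providers import ModelProviderRegistry
from providers.base import ProviderType

# Native providers are registered first, so they win for overlapping models such as "o3".
if ModelProviderRegistry.get_provider(ProviderType.OPENAI):
    source = "native OpenAI"
elif ModelProviderRegistry.get_provider(ProviderType.OPENROUTER):
    source = "OpenRouter"
else:
    source = "no provider configured"
print(f"'o3' would be served by: {source}")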
@@ -318,18 +362,22 @@ If something needs clarification or you'd benefit from additional context, simpl
 IMPORTANT: When you suggest follow-ups or ask questions, you MUST explicitly instruct Claude to use the continuation_id
 to respond. Use clear, direct language based on urgency:
 
-For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd like to explore this further."
+For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd "
+"like to explore this further."
 
 For needed responses: "Please respond using the continuation_id from this response - your input is needed to proceed."
 
-For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from this response. Cannot proceed without your clarification/input."
+For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from "
+"this response. Cannot proceed without your clarification/input."
 
-This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional, needed, or essential.
+This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional, "
+"needed, or essential.
 
 The tool will automatically provide a continuation_id in the structured response that Claude can use in subsequent
 tool calls to maintain full conversation context across multiple exchanges.
 
-Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct Claude to use the continuation_id when you do."""
+Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct "
+"Claude to use the continuation_id when you do."""
 
 
 async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any]:
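At the tool-call level, the continuation flow this text describes looks roughly like the following (a sketch; call_mcp_tool mirrors the simulator-test helper used later in this diff, and the tuple shape is assumed from those tests):

def continuation_demo(client):
    # First call: the server returns a continuation_id alongside the response.
    response, continuation_id = client.call_mcp_tool(
        "chat", {"prompt": "Remember this number: 42.", "model": "sonnet"}
    )
    # Follow-up call: passing continuation_id back restores the thread context.
    followup, _ = client.call_mcp_tool(
        "chat", {"prompt": "What was the number?", "continuation_id": continuation_id}
    )
    return followup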
@@ -366,8 +414,10 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
         # Return error asking Claude to restart conversation with full context
         raise ValueError(
             f"Conversation thread '{continuation_id}' was not found or has expired. "
-            f"This may happen if the conversation was created more than 1 hour ago or if there was an issue with Redis storage. "
-            f"Please restart the conversation by providing your full question/prompt without the continuation_id parameter. "
+            f"This may happen if the conversation was created more than 1 hour ago or if there was an issue "
+            f"with Redis storage. "
+            f"Please restart the conversation by providing your full question/prompt without the "
+            f"continuation_id parameter. "
             f"This will create a new conversation thread that can continue with follow-up exchanges."
         )
@@ -459,7 +509,8 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
         try:
             mcp_activity_logger = logging.getLogger("mcp_activity")
             mcp_activity_logger.info(
-                f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - {len(context.turns)} previous turns loaded"
+                f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - "
+                f"{len(context.turns)} previous turns loaded"
             )
         except Exception:
             pass
@@ -494,6 +545,18 @@ async def handle_get_version() -> list[TextContent]:
         "available_tools": list(TOOLS.keys()) + ["get_version"],
     }
 
+    # Check configured providers
+    from providers import ModelProviderRegistry
+    from providers.base import ProviderType
+
+    configured_providers = []
+    if ModelProviderRegistry.get_provider(ProviderType.GOOGLE):
+        configured_providers.append("Gemini (flash, pro)")
+    if ModelProviderRegistry.get_provider(ProviderType.OPENAI):
+        configured_providers.append("OpenAI (o3, o3-mini)")
+    if ModelProviderRegistry.get_provider(ProviderType.OPENROUTER):
+        configured_providers.append("OpenRouter (configured via conf/openrouter_models.json)")
+
     # Format the information in a human-readable way
     text = f"""Zen MCP Server v{__version__}
 Updated: {__updated__}
@@ -506,6 +569,9 @@ Configuration:
 - Python: {version_info["python_version"]}
 - Started: {version_info["server_started"]}
 
+Configured Providers:
+{chr(10).join(f" - {provider}" for provider in configured_providers)}
+
 Available Tools:
 {chr(10).join(f" - {tool}" for tool in version_info["available_tools"])}
 
@@ -36,8 +36,6 @@ else
     else
         echo "⚠️ Found GEMINI_API_KEY in environment, but sed not available. Please update .env manually."
     fi
-else
-    echo "⚠️ GEMINI_API_KEY not found in environment. Please edit .env and add your API key."
 fi
 
 if [ -n "${OPENAI_API_KEY:-}" ]; then
@@ -48,8 +46,16 @@ else
     else
         echo "⚠️ Found OPENAI_API_KEY in environment, but sed not available. Please update .env manually."
     fi
-else
-    echo "⚠️ OPENAI_API_KEY not found in environment. Please edit .env and add your API key."
 fi
 
+if [ -n "${OPENROUTER_API_KEY:-}" ]; then
+    # Replace the placeholder API key with the actual value
+    if command -v sed >/dev/null 2>&1; then
+        sed -i.bak "s/your_openrouter_api_key_here/$OPENROUTER_API_KEY/" .env && rm .env.bak
+        echo "✅ Updated .env with existing OPENROUTER_API_KEY from environment"
+    else
+        echo "⚠️ Found OPENROUTER_API_KEY in environment, but sed not available. Please update .env manually."
+    fi
+fi
+
 # Update WORKSPACE_ROOT to use current user's home directory
@@ -92,6 +98,7 @@ source .env 2>/dev/null || true
 
 VALID_GEMINI_KEY=false
 VALID_OPENAI_KEY=false
+VALID_OPENROUTER_KEY=false
 
 # Check if GEMINI_API_KEY is set and not the placeholder
 if [ -n "${GEMINI_API_KEY:-}" ] && [ "$GEMINI_API_KEY" != "your_gemini_api_key_here" ]; then
@@ -105,18 +112,55 @@ if [ -n "${OPENAI_API_KEY:-}" ] && [ "$OPENAI_API_KEY" != "your_openai_api_key_h
     echo "✅ Valid OPENAI_API_KEY found"
 fi
 
+# Check if OPENROUTER_API_KEY is set and not the placeholder
+if [ -n "${OPENROUTER_API_KEY:-}" ] && [ "$OPENROUTER_API_KEY" != "your_openrouter_api_key_here" ]; then
+    VALID_OPENROUTER_KEY=true
+    echo "✅ Valid OPENROUTER_API_KEY found"
+fi
+
+# Check for conflicting configuration
+if [ "$VALID_OPENROUTER_KEY" = true ] && ([ "$VALID_GEMINI_KEY" = true ] || [ "$VALID_OPENAI_KEY" = true ]); then
+    echo ""
+    echo "⚠️ WARNING: Conflicting API configuration detected!"
+    echo ""
+    echo "You have configured both:"
+    echo "  - OpenRouter API key"
+    if [ "$VALID_GEMINI_KEY" = true ]; then
+        echo "  - Native Gemini API key"
+    fi
+    if [ "$VALID_OPENAI_KEY" = true ]; then
+        echo "  - Native OpenAI API key"
+    fi
+    echo ""
+    echo "This creates ambiguity about which provider to use for models available"
+    echo "through multiple APIs (e.g., 'o3' could come from OpenAI or OpenRouter)."
+    echo ""
+    echo "RECOMMENDATION: Use EITHER OpenRouter OR native APIs, not both."
+    echo ""
+    echo "To fix this, edit .env and:"
+    echo "  Option 1: Use only OpenRouter - comment out GEMINI_API_KEY and OPENAI_API_KEY"
+    echo "  Option 2: Use only native APIs - comment out OPENROUTER_API_KEY"
+    echo ""
+    echo "The server will start anyway, but native APIs will take priority over OpenRouter."
+    echo ""
+    # Give user time to read the warning
+    sleep 3
+fi
+
 # Require at least one valid API key
-if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ]; then
+if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ] && [ "$VALID_OPENROUTER_KEY" = false ]; then
     echo ""
     echo "❌ ERROR: At least one valid API key is required!"
     echo ""
     echo "Please edit the .env file and set at least one of:"
     echo "  - GEMINI_API_KEY (get from https://makersuite.google.com/app/apikey)"
     echo "  - OPENAI_API_KEY (get from https://platform.openai.com/api-keys)"
+    echo "  - OPENROUTER_API_KEY (get from https://openrouter.ai/)"
     echo ""
     echo "Example:"
     echo "  GEMINI_API_KEY=your-actual-api-key-here"
     echo "  OPENAI_API_KEY=sk-your-actual-openai-key-here"
+    echo "  OPENROUTER_API_KEY=sk-or-your-actual-openrouter-key-here"
     echo ""
     exit 1
 fi
@@ -193,14 +237,14 @@ fi
 
 # Build and start services
 echo "  - Building Zen MCP Server image..."
-if $COMPOSE_CMD build --no-cache >/dev/null 2>&1; then
+if $COMPOSE_CMD build >/dev/null 2>&1; then
     echo "✅ Docker image built successfully!"
 else
     echo "❌ Failed to build Docker image. Run '$COMPOSE_CMD build' manually to see errors."
     exit 1
 fi
 
-echo "  - Starting Redis (needed for conversation memory)... please wait"
+echo "  - Starting all services (Redis + Zen MCP Server)..."
 if $COMPOSE_CMD up -d >/dev/null 2>&1; then
     echo "✅ Services started successfully!"
 else
@@ -208,10 +252,6 @@ else
     exit 1
 fi
 
-# Wait for services to be healthy
-echo "  - Waiting for Redis to be ready..."
-sleep 3
-
 # Check service status
 if $COMPOSE_CMD ps --format table | grep -q "Up" 2>/dev/null || false; then
     echo "✅ All services are running!"
@@ -228,7 +268,7 @@ show_configuration_steps() {
     echo ""
     echo "🔄 Next steps:"
     NEEDS_KEY_UPDATE=false
-    if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null; then
+    if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null || grep -q "your_openrouter_api_key_here" .env 2>/dev/null; then
         NEEDS_KEY_UPDATE=true
     fi
 
@@ -236,6 +276,7 @@ show_configuration_steps() {
     echo "1. Edit .env and replace placeholder API keys with actual ones"
     echo "   - GEMINI_API_KEY: your-gemini-api-key-here"
     echo "   - OPENAI_API_KEY: your-openai-api-key-here"
+    echo "   - OPENROUTER_API_KEY: your-openrouter-api-key-here (optional)"
     echo "2. Restart services: $COMPOSE_CMD restart"
     echo "3. Copy the configuration below to your Claude Desktop config if required:"
 else
@@ -14,6 +14,8 @@ from .test_cross_tool_continuation import CrossToolContinuationTest
 from .test_logs_validation import LogsValidationTest
 from .test_model_thinking_config import TestModelThinkingConfig
 from .test_o3_model_selection import O3ModelSelectionTest
+from .test_openrouter_fallback import OpenRouterFallbackTest
+from .test_openrouter_models import OpenRouterModelsTest
 from .test_per_tool_deduplication import PerToolDeduplicationTest
 from .test_redis_validation import RedisValidationTest
 from .test_token_allocation_validation import TokenAllocationValidationTest
@@ -29,6 +31,8 @@ TEST_REGISTRY = {
     "redis_validation": RedisValidationTest,
     "model_thinking_config": TestModelThinkingConfig,
     "o3_model_selection": O3ModelSelectionTest,
+    "openrouter_fallback": OpenRouterFallbackTest,
+    "openrouter_models": OpenRouterModelsTest,
     "token_allocation_validation": TokenAllocationValidationTest,
     "conversation_chain_validation": ConversationChainValidationTest,
 }
@@ -44,6 +48,8 @@ __all__ = [
     "RedisValidationTest",
     "TestModelThinkingConfig",
     "O3ModelSelectionTest",
+    "OpenRouterFallbackTest",
+    "OpenRouterModelsTest",
     "TokenAllocationValidationTest",
     "ConversationChainValidationTest",
     "TEST_REGISTRY",
@@ -45,6 +45,35 @@ class O3ModelSelectionTest(BaseSimulatorTest):
         try:
             self.logger.info(" Test: O3 model selection and usage validation")
 
+            # Check which API keys are configured
+            check_cmd = [
+                "docker",
+                "exec",
+                self.container_name,
+                "python",
+                "-c",
+                'import os; print(f\'OPENAI_KEY:{bool(os.environ.get("OPENAI_API_KEY"))}|OPENROUTER_KEY:{bool(os.environ.get("OPENROUTER_API_KEY"))}\')',
+            ]
+            result = subprocess.run(check_cmd, capture_output=True, text=True)
+
+            has_openai = False
+            has_openrouter = False
+
+            if result.returncode == 0:
+                output = result.stdout.strip()
+                if "OPENAI_KEY:True" in output:
+                    has_openai = True
+                if "OPENROUTER_KEY:True" in output:
+                    has_openrouter = True
+
+            # If only OpenRouter is configured, adjust test expectations
+            if has_openrouter and not has_openai:
+                self.logger.info(" ℹ️ Only OpenRouter configured - O3 models will be routed through OpenRouter")
+                return self._run_openrouter_o3_test()
+
+            # Original test for when OpenAI is configured
+            self.logger.info(" ℹ️ OpenAI API configured - expecting direct OpenAI API calls")
+
             # Setup test files for later use
             self.setup_test_files()
 
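The key-detection step above boils down to probing the container's environment; a standalone sketch of that pattern (the container name here is hypothetical):

import subprocess

cmd = [
    "docker", "exec", "zen-mcp-server", "python", "-c",
    'import os; print(bool(os.environ.get("OPENROUTER_API_KEY")))',
]
out = subprocess.run(cmd, capture_output=True, text=True)
has_openrouter = out.returncode == 0 and "True" in out.stdout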
@@ -192,6 +221,129 @@ def multiply(x, y):
         finally:
             self.cleanup_test_files()
 
+    def _run_openrouter_o3_test(self) -> bool:
+        """Test O3 model selection when using OpenRouter"""
+        try:
+            # Setup test files
+            self.setup_test_files()
+
+            # Test 1: O3 model via OpenRouter
+            self.logger.info(" 1: Testing O3 model via OpenRouter")
+
+            response1, _ = self.call_mcp_tool(
+                "chat",
+                {
+                    "prompt": "Simple test: What is 2 + 2? Just give a brief answer.",
+                    "model": "o3",
+                    "temperature": 1.0,
+                },
+            )
+
+            if not response1:
+                self.logger.error(" ❌ O3 model test via OpenRouter failed")
+                return False
+
+            self.logger.info(" ✅ O3 model call via OpenRouter completed")
+
+            # Test 2: O3-mini model via OpenRouter
+            self.logger.info(" 2: Testing O3-mini model via OpenRouter")
+
+            response2, _ = self.call_mcp_tool(
+                "chat",
+                {
+                    "prompt": "Simple test: What is 3 + 3? Just give a brief answer.",
+                    "model": "o3-mini",
+                    "temperature": 1.0,
+                },
+            )
+
+            if not response2:
+                self.logger.error(" ❌ O3-mini model test via OpenRouter failed")
+                return False
+
+            self.logger.info(" ✅ O3-mini model call via OpenRouter completed")
+
+            # Test 3: Codereview with O3 via OpenRouter
+            self.logger.info(" 3: Testing O3 with codereview tool via OpenRouter")
+
+            test_code = """def add(a, b):
+    return a + b
+
+def multiply(x, y):
+    return x * y
+"""
+            test_file = self.create_additional_test_file("simple_math.py", test_code)
+
+            response3, _ = self.call_mcp_tool(
+                "codereview",
+                {
+                    "files": [test_file],
+                    "prompt": "Quick review of this simple code",
+                    "model": "o3",
+                    "temperature": 1.0,
+                },
+            )
+
+            if not response3:
+                self.logger.error(" ❌ O3 with codereview tool via OpenRouter failed")
+                return False
+
+            self.logger.info(" ✅ O3 with codereview tool via OpenRouter completed")
+
+            # Validate OpenRouter usage in logs
+            self.logger.info(" 4: Validating OpenRouter usage in logs")
+            logs = self.get_recent_server_logs()
+
+            # Check for OpenRouter API calls
+            openrouter_api_logs = [
+                line
+                for line in logs.split("\n")
+                if "openrouter" in line.lower() and ("API" in line or "request" in line)
+            ]
+
+            # Check for model resolution through OpenRouter
+            openrouter_model_logs = [
+                line for line in logs.split("\n") if "openrouter" in line.lower() and ("o3" in line or "model" in line)
+            ]
+
+            # Check for successful responses
+            openrouter_response_logs = [
+                line for line in logs.split("\n") if "openrouter" in line.lower() and "response" in line
+            ]
+
+            self.logger.info(f" OpenRouter API logs: {len(openrouter_api_logs)}")
+            self.logger.info(f" OpenRouter model logs: {len(openrouter_model_logs)}")
+            self.logger.info(f" OpenRouter response logs: {len(openrouter_response_logs)}")
+
+            # Success criteria for OpenRouter
+            openrouter_used = len(openrouter_api_logs) >= 3 or len(openrouter_model_logs) >= 3
+            all_calls_succeeded = response1 and response2 and response3
+
+            success_criteria = [
+                ("All O3 model calls succeeded", all_calls_succeeded),
+                ("OpenRouter provider was used", openrouter_used),
+            ]
+
+            passed_criteria = sum(1 for _, passed in success_criteria if passed)
+            self.logger.info(f" Success criteria met: {passed_criteria}/{len(success_criteria)}")
+
+            for criterion, passed in success_criteria:
+                status = "✅" if passed else "❌"
+                self.logger.info(f" {status} {criterion}")
+
+            if passed_criteria == len(success_criteria):
+                self.logger.info(" ✅ O3 model selection via OpenRouter passed")
+                return True
+            else:
+                self.logger.error(" ❌ O3 model selection via OpenRouter failed")
+                return False
+
+        except Exception as e:
+            self.logger.error(f"OpenRouter O3 test failed: {e}")
+            return False
+        finally:
+            self.cleanup_test_files()
+
 
 def main():
     """Run the O3 model selection tests"""
241
simulator_tests/test_openrouter_fallback.py
Normal file
@@ -0,0 +1,241 @@
#!/usr/bin/env python3
"""
OpenRouter Fallback Test

Tests that verify the system correctly falls back to OpenRouter when:
- Only OPENROUTER_API_KEY is configured
- Native models (flash, pro) are requested but map to OpenRouter equivalents
- Auto mode correctly selects OpenRouter models
"""

import subprocess

from .base_test import BaseSimulatorTest


class OpenRouterFallbackTest(BaseSimulatorTest):
    """Test OpenRouter fallback behavior when it's the only provider"""

    @property
    def test_name(self) -> str:
        return "openrouter_fallback"

    @property
    def test_description(self) -> str:
        return "OpenRouter fallback behavior when only provider"

    def get_recent_server_logs(self) -> str:
        """Get recent server logs from the log file directly"""
        try:
            cmd = ["docker", "exec", self.container_name, "tail", "-n", "300", "/tmp/mcp_server.log"]
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                return result.stdout
            else:
                self.logger.warning(f"Failed to read server logs: {result.stderr}")
                return ""
        except Exception as e:
            self.logger.error(f"Failed to get server logs: {e}")
            return ""

    def run_test(self) -> bool:
        """Test OpenRouter fallback behavior"""
        try:
            self.logger.info("Test: OpenRouter fallback behavior when only provider available")

            # Check if OpenRouter API key is configured
            check_cmd = [
                "docker",
                "exec",
                self.container_name,
                "python",
                "-c",
                'import os; print("OPENROUTER_KEY:" + str(bool(os.environ.get("OPENROUTER_API_KEY"))))',
            ]
            result = subprocess.run(check_cmd, capture_output=True, text=True)

            if result.returncode == 0 and "OPENROUTER_KEY:False" in result.stdout:
                self.logger.info(" ⚠️ OpenRouter API key not configured - skipping test")
                self.logger.info(" ℹ️ This test requires OPENROUTER_API_KEY to be set in .env")
                return True  # Return True to indicate test is skipped, not failed

            # Setup test files
            self.setup_test_files()

            # Test 1: Auto mode should work with OpenRouter
            self.logger.info(" 1: Testing auto mode with OpenRouter as only provider")

            response1, continuation_id = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "What is 2 + 2? Give a brief answer.",
                    # No model specified - should use auto mode
                    "temperature": 0.1,
                },
            )

            if not response1:
                self.logger.error(" ❌ Auto mode with OpenRouter failed")
                return False

            self.logger.info(" ✅ Auto mode call completed with OpenRouter")

            # Test 2: Flash model should map to OpenRouter equivalent
            self.logger.info(" 2: Testing flash model mapping to OpenRouter")

            # Use codereview tool to test a different tool type
            test_code = """def calculate_sum(numbers):
    total = 0
    for num in numbers:
        total += num
    return total"""

            test_file = self.create_additional_test_file("sum_function.py", test_code)

            response2, _ = self.call_mcp_tool(
                "codereview",
                {
                    "files": [test_file],
                    "prompt": "Quick review of this sum function",
                    "model": "flash",
                    "temperature": 0.1,
                },
            )

            if not response2:
                self.logger.error(" ❌ Flash model mapping to OpenRouter failed")
                return False

            self.logger.info(" ✅ Flash model successfully mapped to OpenRouter")

            # Test 3: Pro model should map to OpenRouter equivalent
            self.logger.info(" 3: Testing pro model mapping to OpenRouter")

            response3, _ = self.call_mcp_tool(
                "analyze",
                {
                    "files": [self.test_files["python"]],
                    "prompt": "Analyze the structure of this Python code",
                    "model": "pro",
                    "temperature": 0.1,
                },
            )

            if not response3:
                self.logger.error(" ❌ Pro model mapping to OpenRouter failed")
                return False

            self.logger.info(" ✅ Pro model successfully mapped to OpenRouter")

            # Test 4: Debug tool with OpenRouter
            self.logger.info(" 4: Testing debug tool with OpenRouter")

            response4, _ = self.call_mcp_tool(
                "debug",
                {
                    "prompt": "Why might a function return None instead of a value?",
                    "model": "flash",  # Should map to OpenRouter
                    "temperature": 0.1,
                },
            )

            if not response4:
                self.logger.error(" ❌ Debug tool with OpenRouter failed")
                return False

            self.logger.info(" ✅ Debug tool working with OpenRouter")

            # Test 5: Validate logs show OpenRouter is being used
            self.logger.info(" 5: Validating OpenRouter is the active provider")
            logs = self.get_recent_server_logs()

            # Check for provider fallback logs
            fallback_logs = [
                line
                for line in logs.split("\n")
                if "No Gemini API key found" in line
                or "No OpenAI API key found" in line
                or "Only OpenRouter available" in line
                or "Using OpenRouter" in line
            ]

            # Check for OpenRouter provider initialization
            provider_logs = [
                line
                for line in logs.split("\n")
                if "OpenRouter provider" in line or "OpenRouterProvider" in line or "openrouter.ai/api/v1" in line
            ]

            # Check for model resolution through OpenRouter
            model_resolution_logs = [
                line
                for line in logs.split("\n")
                if ("Resolved model" in line and "via OpenRouter" in line)
                or ("Model alias" in line and "resolved to" in line)
                or ("flash" in line and "gemini-flash" in line)
                or ("pro" in line and "gemini-pro" in line)
            ]

            # Log findings
            self.logger.info(f" Fallback indication logs: {len(fallback_logs)}")
            self.logger.info(f" OpenRouter provider logs: {len(provider_logs)}")
            self.logger.info(f" Model resolution logs: {len(model_resolution_logs)}")

            # Sample logs for debugging
            if self.verbose:
                if fallback_logs:
                    self.logger.debug(" 📋 Sample fallback logs:")
                    for log in fallback_logs[:3]:
                        self.logger.debug(f" {log}")

                if provider_logs:
                    self.logger.debug(" 📋 Sample provider logs:")
                    for log in provider_logs[:3]:
                        self.logger.debug(f" {log}")

            # Success criteria
            openrouter_active = len(provider_logs) > 0
            models_resolved = len(model_resolution_logs) > 0
            all_tools_worked = True  # We checked this above

            success_criteria = [
                ("OpenRouter provider active", openrouter_active),
                ("Models resolved through OpenRouter", models_resolved),
                ("All tools worked with OpenRouter", all_tools_worked),
            ]

            passed_criteria = sum(1 for _, passed in success_criteria if passed)
            self.logger.info(f" Success criteria met: {passed_criteria}/{len(success_criteria)}")

            for criterion, passed in success_criteria:
                status = "✅" if passed else "❌"
                self.logger.info(f" {status} {criterion}")

            if passed_criteria >= 2:  # At least 2 out of 3 criteria
                self.logger.info(" ✅ OpenRouter fallback test passed")
                return True
            else:
                self.logger.error(" ❌ OpenRouter fallback test failed")
                return False

        except Exception as e:
            self.logger.error(f"OpenRouter fallback test failed: {e}")
            return False
        finally:
            self.cleanup_test_files()


def main():
    """Run the OpenRouter fallback tests"""
    import sys

    verbose = "--verbose" in sys.argv or "-v" in sys.argv
    test = OpenRouterFallbackTest(verbose=verbose)

    success = test.run_test()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
275
simulator_tests/test_openrouter_models.py
Normal file
@@ -0,0 +1,275 @@
#!/usr/bin/env python3
"""
OpenRouter Model Tests

Tests that verify OpenRouter functionality including:
- Model alias resolution (flash, pro, o3, etc. map to OpenRouter equivalents)
- Multiple OpenRouter models work correctly
- Conversation continuity works with OpenRouter models
- Error handling when models are not available
"""

import subprocess

from .base_test import BaseSimulatorTest


class OpenRouterModelsTest(BaseSimulatorTest):
    """Test OpenRouter model functionality and alias mapping"""

    @property
    def test_name(self) -> str:
        return "openrouter_models"

    @property
    def test_description(self) -> str:
        return "OpenRouter model functionality and alias mapping"

    def get_recent_server_logs(self) -> str:
        """Get recent server logs from the log file directly"""
        try:
            # Read logs directly from the log file
            cmd = ["docker", "exec", self.container_name, "tail", "-n", "500", "/tmp/mcp_server.log"]
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                return result.stdout
            else:
                self.logger.warning(f"Failed to read server logs: {result.stderr}")
                return ""
        except Exception as e:
            self.logger.error(f"Failed to get server logs: {e}")
            return ""

    def run_test(self) -> bool:
        """Test OpenRouter model functionality"""
        try:
            self.logger.info("Test: OpenRouter model functionality and alias mapping")

            # Check if OpenRouter API key is configured
            check_cmd = [
                "docker",
                "exec",
                self.container_name,
                "python",
                "-c",
                'import os; print("OPENROUTER_KEY:" + str(bool(os.environ.get("OPENROUTER_API_KEY"))))',
            ]
            result = subprocess.run(check_cmd, capture_output=True, text=True)

            if result.returncode == 0 and "OPENROUTER_KEY:False" in result.stdout:
                self.logger.info(" ⚠️ OpenRouter API key not configured - skipping test")
                self.logger.info(" ℹ️ This test requires OPENROUTER_API_KEY to be set in .env")
                return True  # Return True to indicate test is skipped, not failed

            # Setup test files for later use
            self.setup_test_files()

            # Test 1: Flash alias mapping to OpenRouter
            self.logger.info(" 1: Testing 'flash' alias (should map to google/gemini-flash-1.5-8b)")

            response1, continuation_id = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Flash model!' and nothing else.",
                    "model": "flash",
                    "temperature": 0.1,
                },
            )

            if not response1:
                self.logger.error(" ❌ Flash alias test failed")
                return False

            self.logger.info(" ✅ Flash alias call completed")
            if continuation_id:
                self.logger.info(f" ✅ Got continuation_id: {continuation_id}")

            # Test 2: Pro alias mapping to OpenRouter
            self.logger.info(" 2: Testing 'pro' alias (should map to google/gemini-pro-1.5)")

            response2, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Pro model!' and nothing else.",
                    "model": "pro",
                    "temperature": 0.1,
                },
            )

            if not response2:
                self.logger.error(" ❌ Pro alias test failed")
                return False

            self.logger.info(" ✅ Pro alias call completed")

            # Test 3: O3 alias mapping to OpenRouter (should map to openai/gpt-4o)
            self.logger.info(" 3: Testing 'o3' alias (should map to openai/gpt-4o)")

            response3, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from O3 model!' and nothing else.",
                    "model": "o3",
                    "temperature": 0.1,
                },
            )

            if not response3:
                self.logger.error(" ❌ O3 alias test failed")
                return False

            self.logger.info(" ✅ O3 alias call completed")

            # Test 4: Direct OpenRouter model name
            self.logger.info(" 4: Testing direct OpenRouter model name (anthropic/claude-3-haiku)")

            response4, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Claude Haiku!' and nothing else.",
                    "model": "anthropic/claude-3-haiku",
                    "temperature": 0.1,
                },
            )

            if not response4:
                self.logger.error(" ❌ Direct OpenRouter model test failed")
                return False

            self.logger.info(" ✅ Direct OpenRouter model call completed")

            # Test 5: OpenRouter alias from config
            self.logger.info(" 5: Testing OpenRouter alias from config ('opus' -> anthropic/claude-3-opus)")

            response5, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Opus!' and nothing else.",
                    "model": "opus",
                    "temperature": 0.1,
                },
            )

            if not response5:
                self.logger.error(" ❌ OpenRouter alias test failed")
                return False

            self.logger.info(" ✅ OpenRouter alias call completed")

            # Test 6: Conversation continuity with OpenRouter models
            self.logger.info(" 6: Testing conversation continuity with OpenRouter")

            response6, new_continuation_id = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Remember this number: 42. What number did I just tell you?",
                    "model": "sonnet",  # Claude Sonnet via OpenRouter
                    "temperature": 0.1,
                },
            )

            if not response6 or not new_continuation_id:
                self.logger.error(" ❌ Failed to start conversation with continuation_id")
                return False

            # Continue the conversation
            response7, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "What was the number I told you earlier?",
                    "model": "sonnet",
                    "continuation_id": new_continuation_id,
                    "temperature": 0.1,
                },
            )

            if not response7:
                self.logger.error(" ❌ Failed to continue conversation")
                return False

            # Check if the model remembered the number
            if "42" in response7:
                self.logger.info(" ✅ Conversation continuity working with OpenRouter")
            else:
                self.logger.warning(" ⚠️ Model may not have remembered the number")

            # Test 7: Validate OpenRouter API usage from logs
            self.logger.info(" 7: Validating OpenRouter API usage in logs")
            logs = self.get_recent_server_logs()

            # Check for OpenRouter API calls
            openrouter_logs = [line for line in logs.split("\n") if "openrouter" in line.lower()]
            openrouter_api_logs = [line for line in logs.split("\n") if "openrouter.ai/api/v1" in line]

            # Check for specific model mappings
            flash_mapping_logs = [
                line
                for line in logs.split("\n")
                if ("flash" in line and "google/gemini-flash" in line)
                or ("Resolved model" in line and "google/gemini-flash" in line)
            ]

            pro_mapping_logs = [
                line
                for line in logs.split("\n")
                if ("pro" in line and "google/gemini-pro" in line)
                or ("Resolved model" in line and "google/gemini-pro" in line)
            ]

            # Log findings
            self.logger.info(f" OpenRouter-related logs: {len(openrouter_logs)}")
            self.logger.info(f" OpenRouter API logs: {len(openrouter_api_logs)}")
            self.logger.info(f" Flash mapping logs: {len(flash_mapping_logs)}")
            self.logger.info(f" Pro mapping logs: {len(pro_mapping_logs)}")

            # Sample log output for debugging
            if self.verbose and openrouter_logs:
                self.logger.debug(" 📋 Sample OpenRouter logs:")
                for log in openrouter_logs[:5]:
                    self.logger.debug(f" {log}")

            # Success criteria
            openrouter_api_used = len(openrouter_api_logs) > 0
            models_mapped = len(flash_mapping_logs) > 0 or len(pro_mapping_logs) > 0

            success_criteria = [
                ("OpenRouter API calls made", openrouter_api_used),
                ("Model aliases mapped correctly", models_mapped),
                ("All model calls succeeded", True),  # We already checked this above
            ]

            passed_criteria = sum(1 for _, passed in success_criteria if passed)
            self.logger.info(f" Success criteria met: {passed_criteria}/{len(success_criteria)}")

            for criterion, passed in success_criteria:
                status = "✅" if passed else "❌"
                self.logger.info(f" {status} {criterion}")

            if passed_criteria >= 2:  # At least 2 out of 3 criteria
                self.logger.info(" ✅ OpenRouter model tests passed")
                return True
            else:
                self.logger.error(" ❌ OpenRouter model tests failed")
                return False

        except Exception as e:
            self.logger.error(f"OpenRouter model test failed: {e}")
            return False
        finally:
            self.cleanup_test_files()


def main():
    """Run the OpenRouter model tests"""
    import sys

    verbose = "--verbose" in sys.argv or "-v" in sys.argv
    test = OpenRouterModelsTest(verbose=verbose)

    success = test.run_test()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
41
test_mapping.py
Normal file
@@ -0,0 +1,41 @@
#!/usr/bin/env python3
"""
Test OpenRouter model mapping
"""

import sys

sys.path.append("/Users/fahad/Developer/gemini-mcp-server")

from simulator_tests.base_test import BaseSimulatorTest


class MappingTest(BaseSimulatorTest):
    def test_mapping(self):
        """Test model alias mapping"""

        # Test with 'flash' alias - should map to google/gemini-flash-1.5-8b
        print("\nTesting 'flash' alias mapping...")

        response, continuation_id = self.call_mcp_tool(
            "chat",
            {
                "prompt": "Say 'Hello from Flash model!'",
                "model": "flash",  # Should be mapped to google/gemini-flash-1.5-8b
                "temperature": 0.1,
            },
        )

        if response:
            print("✅ Flash alias worked!")
            print(f"Response: {response[:200]}...")
            return True
        else:
            print("❌ Flash alias failed")
            return False


if __name__ == "__main__":
    test = MappingTest(verbose=False)
    success = test.test_mapping()
    print(f"\nTest result: {'Success' if success else 'Failed'}")
113
test_model_mapping.py
Executable file
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""
Simple test script to demonstrate model mapping through the MCP server.
Tests how model aliases (flash, pro, o3) are mapped to OpenRouter models.
"""

import json
import subprocess
import sys
from typing import Any


def call_mcp_server(model: str, message: str = "Hello, which model are you?") -> dict[str, Any]:
    """Call the MCP server with a specific model and return the response."""

    # Prepare the request
    request = {
        "jsonrpc": "2.0",
        "method": "completion",
        "params": {"model": model, "messages": [{"role": "user", "content": message}], "max_tokens": 100},
        "id": 1,
    }

    # Call the server
    cmd = [sys.executable, "server.py"]

    try:
        # Send request to stdin and capture output
        process = subprocess.Popen(
            cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )

        stdout, stderr = process.communicate(input=json.dumps(request))

        if process.returncode != 0:
            return {"error": f"Server returned non-zero exit code: {process.returncode}", "stderr": stderr}

        # Parse the response
        try:
            response = json.loads(stdout)
            return response
        except json.JSONDecodeError:
            return {"error": "Failed to parse JSON response", "stdout": stdout, "stderr": stderr}

    except Exception as e:
        return {"error": f"Failed to call server: {str(e)}"}


def extract_model_info(response: dict[str, Any]) -> dict[str, str]:
    """Extract model information from the response."""

    if "error" in response:
        return {"status": "error", "message": response.get("error", "Unknown error")}

    # Look for result in the response
    result = response.get("result", {})

    # Extract relevant information
    info = {"status": "success", "provider": "unknown", "model": "unknown"}

    # Try to find provider and model info in the response
    # This might be in metadata or debug info depending on server implementation
    if "metadata" in result:
        metadata = result["metadata"]
        info["provider"] = metadata.get("provider", "unknown")
        info["model"] = metadata.get("model", "unknown")

    # Also check if the model info is in the response content itself
    if "content" in result:
        content = result["content"]
        # Simple heuristic to detect OpenRouter models
        if "openrouter" in content.lower() or any(x in content.lower() for x in ["claude", "gpt", "gemini"]):
            info["provider"] = "openrouter"

    return info


def main():
    """Test model mapping for different aliases."""

    print("Model Mapping Test for MCP Server")
    print("=" * 50)
    print()

    # Test models
    test_models = ["flash", "pro", "o3"]

    for model in test_models:
        print(f"Testing model: '{model}'")
        print("-" * 30)

        response = call_mcp_server(model)
        model_info = extract_model_info(response)

        if model_info["status"] == "error":
            print(f"  ❌ Error: {model_info['message']}")
        else:
            print(f"  ✓ Provider: {model_info['provider']}")
            print(f"  ✓ Model: {model_info['model']}")

        # Print raw response for debugging
        if "--debug" in sys.argv:
            print("\nDebug - Raw Response:")
            print(json.dumps(response, indent=2))

        print()

    print("\nNote: This test assumes the MCP server is configured with OpenRouter.")
    print("The actual model mappings depend on the server configuration.")


if __name__ == "__main__":
    main()
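To exercise the script above, run it directly; passing --debug (checked via sys.argv in main()) prints the raw JSON-RPC response. A minimal sketch of reusing its helpers for a single alias, assuming server.py is in the working directory and an OpenRouter key is configured:

    # Sketch only: reuses the helpers defined in test_model_mapping.py above.
    from test_model_mapping import call_mcp_server, extract_model_info

    info = extract_model_info(call_mcp_server("flash"))
    print(info)  # expected shape: {"status": ..., "provider": ..., "model": ...}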
@@ -97,7 +97,8 @@ class TestAutoMode:
         # Model field should have simpler description
         model_schema = schema["properties"]["model"]
         assert "enum" not in model_schema
-        assert "Available:" in model_schema["description"]
+        assert "Native models:" in model_schema["description"]
+        assert "Defaults to" in model_schema["description"]
 
     @pytest.mark.asyncio
     async def test_auto_mode_requires_model_parameter(self):
@@ -180,8 +181,9 @@ class TestAutoMode:
 
         schema = tool.get_model_field_schema()
         assert "enum" not in schema
-        assert "Available:" in schema["description"]
+        assert "Native models:" in schema["description"]
         assert "'pro'" in schema["description"]
+        assert "Defaults to" in schema["description"]
 
         finally:
             # Restore
197
tests/test_openrouter_provider.py
Normal file
197
tests/test_openrouter_provider.py
Normal file
@@ -0,0 +1,197 @@
"""Tests for OpenRouter provider."""

import os
from unittest.mock import patch

from providers.base import ProviderType
from providers.openrouter import OpenRouterProvider
from providers.registry import ModelProviderRegistry


class TestOpenRouterProvider:
    """Test cases for OpenRouter provider."""

    def test_provider_initialization(self):
        """Test OpenRouter provider initialization."""
        provider = OpenRouterProvider(api_key="test-key")
        assert provider.api_key == "test-key"
        assert provider.base_url == "https://openrouter.ai/api/v1"
        assert provider.FRIENDLY_NAME == "OpenRouter"

    def test_custom_headers(self):
        """Test OpenRouter custom headers."""
        # Test default headers
        assert "HTTP-Referer" in OpenRouterProvider.DEFAULT_HEADERS
        assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS

        # Test with environment variables
        with patch.dict(os.environ, {"OPENROUTER_REFERER": "https://myapp.com", "OPENROUTER_TITLE": "My App"}):
            from importlib import reload

            import providers.openrouter

            reload(providers.openrouter)

            provider = providers.openrouter.OpenRouterProvider(api_key="test-key")
            assert provider.DEFAULT_HEADERS["HTTP-Referer"] == "https://myapp.com"
            assert provider.DEFAULT_HEADERS["X-Title"] == "My App"

    def test_model_validation(self):
        """Test model validation."""
        provider = OpenRouterProvider(api_key="test-key")

        # Should accept any model - OpenRouter handles validation
        assert provider.validate_model_name("gpt-4") is True
        assert provider.validate_model_name("claude-3-opus") is True
        assert provider.validate_model_name("any-model-name") is True
        assert provider.validate_model_name("GPT-4") is True
        assert provider.validate_model_name("unknown-model") is True

    def test_get_capabilities(self):
        """Test capability generation."""
        provider = OpenRouterProvider(api_key="test-key")

        # Test with a model in the registry (using alias)
        caps = provider.get_capabilities("gpt4o")
        assert caps.provider == ProviderType.OPENROUTER
        assert caps.model_name == "openai/gpt-4o"  # Resolved name
        assert caps.friendly_name == "OpenRouter"

        # Test with a model not in registry - should get generic capabilities
        caps = provider.get_capabilities("unknown-model")
        assert caps.provider == ProviderType.OPENROUTER
        assert caps.model_name == "unknown-model"
        assert caps.max_tokens == 32_768  # Safe default
        assert hasattr(caps, "_is_generic") and caps._is_generic is True

    def test_model_alias_resolution(self):
        """Test model alias resolution."""
        provider = OpenRouterProvider(api_key="test-key")

        # Test alias resolution
        assert provider._resolve_model_name("opus") == "anthropic/claude-3-opus"
        assert provider._resolve_model_name("sonnet") == "anthropic/claude-3-sonnet"
        assert provider._resolve_model_name("gpt4o") == "openai/gpt-4o"
        assert provider._resolve_model_name("4o") == "openai/gpt-4o"
        assert provider._resolve_model_name("claude") == "anthropic/claude-3-sonnet"
        assert provider._resolve_model_name("mistral") == "mistral/mistral-large"
        assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-coder"
        assert provider._resolve_model_name("coder") == "deepseek/deepseek-coder"

        # Test case-insensitive
        assert provider._resolve_model_name("OPUS") == "anthropic/claude-3-opus"
        assert provider._resolve_model_name("GPT4O") == "openai/gpt-4o"
        assert provider._resolve_model_name("Mistral") == "mistral/mistral-large"
        assert provider._resolve_model_name("CLAUDE") == "anthropic/claude-3-sonnet"

        # Test direct model names (should pass through unchanged)
        assert provider._resolve_model_name("anthropic/claude-3-opus") == "anthropic/claude-3-opus"
        assert provider._resolve_model_name("openai/gpt-4o") == "openai/gpt-4o"

        # Test unknown models pass through
        assert provider._resolve_model_name("unknown-model") == "unknown-model"
        assert provider._resolve_model_name("custom/model-v2") == "custom/model-v2"

    def test_openrouter_registration(self):
        """Test OpenRouter can be registered and retrieved."""
        with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
            # Clean up any existing registration
            ModelProviderRegistry.unregister_provider(ProviderType.OPENROUTER)

            # Register the provider
            ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

            # Retrieve and verify
            provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
            assert provider is not None
            assert isinstance(provider, OpenRouterProvider)


class TestOpenRouterRegistry:
    """Test cases for OpenRouter model registry."""

    def test_registry_loading(self):
        """Test registry loads models from config."""
        from providers.openrouter_registry import OpenRouterModelRegistry

        registry = OpenRouterModelRegistry()

        # Should have loaded models
        models = registry.list_models()
        assert len(models) > 0
        assert "anthropic/claude-3-opus" in models
        assert "openai/gpt-4o" in models

        # Should have loaded aliases
        aliases = registry.list_aliases()
        assert len(aliases) > 0
        assert "opus" in aliases
        assert "gpt4o" in aliases
        assert "claude" in aliases

    def test_registry_capabilities(self):
        """Test registry provides correct capabilities."""
        from providers.openrouter_registry import OpenRouterModelRegistry

        registry = OpenRouterModelRegistry()

        # Test known model
        caps = registry.get_capabilities("opus")
        assert caps is not None
        assert caps.model_name == "anthropic/claude-3-opus"
        assert caps.max_tokens == 200000  # Claude's context window

        # Test using full model name
        caps = registry.get_capabilities("anthropic/claude-3-opus")
        assert caps is not None
        assert caps.model_name == "anthropic/claude-3-opus"

        # Test unknown model
        caps = registry.get_capabilities("non-existent-model")
        assert caps is None

    def test_multiple_aliases_same_model(self):
        """Test multiple aliases pointing to same model."""
        from providers.openrouter_registry import OpenRouterModelRegistry

        registry = OpenRouterModelRegistry()

        # All these should resolve to Claude Sonnet
        sonnet_aliases = ["sonnet", "claude", "claude-sonnet", "claude3-sonnet"]
        for alias in sonnet_aliases:
            config = registry.resolve(alias)
            assert config is not None
            assert config.model_name == "anthropic/claude-3-sonnet"


class TestOpenRouterFunctionality:
    """Test OpenRouter-specific functionality."""

    def test_openrouter_always_uses_correct_url(self):
        """Test that OpenRouter always uses the correct base URL."""
        provider = OpenRouterProvider(api_key="test-key")
        assert provider.base_url == "https://openrouter.ai/api/v1"

        # Even if we try to change it, it should remain the OpenRouter URL
        # (This is a characteristic of the OpenRouter provider)
        provider.base_url = "http://example.com"  # Try to change it
        # But new instances should always use the correct URL
        provider2 = OpenRouterProvider(api_key="test-key")
        assert provider2.base_url == "https://openrouter.ai/api/v1"

    def test_openrouter_headers_set_correctly(self):
        """Test that OpenRouter specific headers are set."""
        provider = OpenRouterProvider(api_key="test-key")

        # Check default headers
        assert "HTTP-Referer" in provider.DEFAULT_HEADERS
        assert "X-Title" in provider.DEFAULT_HEADERS
        assert provider.DEFAULT_HEADERS["X-Title"] == "Zen MCP Server"

    def test_openrouter_model_registry_initialized(self):
        """Test that model registry is properly initialized."""
        provider = OpenRouterProvider(api_key="test-key")

        # Registry should be initialized
        assert hasattr(provider, "_registry")
        assert provider._registry is not None
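The resolution contract these tests pin down is small: aliases resolve case-insensitively through the registry, and anything unrecognized passes through to OpenRouter unchanged. A minimal sketch of that behavior (illustrative only; the real logic lives in providers/openrouter.py):

    # Illustrative lookup contract asserted by the tests above, not the actual code.
    def resolve_model_name(name: str, alias_map: dict[str, str]) -> str:
        # Case-insensitive alias lookup; unknown names pass through unchanged.
        return alias_map.get(name.lower(), name)

    aliases = {"opus": "anthropic/claude-3-opus", "gpt4o": "openai/gpt-4o"}
    assert resolve_model_name("OPUS", aliases) == "anthropic/claude-3-opus"
    assert resolve_model_name("custom/model-v2", aliases) == "custom/model-v2"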
223 tests/test_openrouter_registry.py Normal file
@@ -0,0 +1,223 @@
"""Tests for OpenRouter model registry functionality."""

import json
import os
import tempfile

import pytest

from providers.base import ProviderType
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry


class TestOpenRouterModelRegistry:
    """Test cases for OpenRouter model registry."""

    def test_registry_initialization(self):
        """Test registry initializes with default config."""
        registry = OpenRouterModelRegistry()

        # Should load models from default location
        assert len(registry.list_models()) > 0
        assert len(registry.list_aliases()) > 0

    def test_custom_config_path(self):
        """Test registry with custom config path."""
        # Create temporary config
        config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
            temp_path = f.name

        try:
            registry = OpenRouterModelRegistry(config_path=temp_path)
            assert len(registry.list_models()) == 1
            assert "test/model-1" in registry.list_models()
            assert "test1" in registry.list_aliases()
            assert "t1" in registry.list_aliases()
        finally:
            os.unlink(temp_path)

    def test_environment_variable_override(self):
        """Test OPENROUTER_MODELS_PATH environment variable."""
        # Create custom config
        config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
            temp_path = f.name

        try:
            # Set environment variable
            original_env = os.environ.get("OPENROUTER_MODELS_PATH")
            os.environ["OPENROUTER_MODELS_PATH"] = temp_path

            # Create registry without explicit path
            registry = OpenRouterModelRegistry()

            # Should load from environment path
            assert "env/model" in registry.list_models()
            assert "envtest" in registry.list_aliases()

        finally:
            # Restore environment
            if original_env is not None:
                os.environ["OPENROUTER_MODELS_PATH"] = original_env
            else:
                del os.environ["OPENROUTER_MODELS_PATH"]
            os.unlink(temp_path)

    def test_alias_resolution(self):
        """Test alias resolution functionality."""
        registry = OpenRouterModelRegistry()

        # Test various aliases
        test_cases = [
            ("opus", "anthropic/claude-3-opus"),
            ("OPUS", "anthropic/claude-3-opus"),  # Case insensitive
            ("claude", "anthropic/claude-3-sonnet"),
            ("gpt4o", "openai/gpt-4o"),
            ("4o", "openai/gpt-4o"),
            ("mistral", "mistral/mistral-large"),
        ]

        for alias, expected_model in test_cases:
            config = registry.resolve(alias)
            assert config is not None, f"Failed to resolve alias '{alias}'"
            assert config.model_name == expected_model

    def test_direct_model_name_lookup(self):
        """Test looking up models by their full name."""
        registry = OpenRouterModelRegistry()

        # Should be able to look up by full model name
        config = registry.resolve("anthropic/claude-3-opus")
        assert config is not None
        assert config.model_name == "anthropic/claude-3-opus"

        config = registry.resolve("openai/gpt-4o")
        assert config is not None
        assert config.model_name == "openai/gpt-4o"

    def test_unknown_model_resolution(self):
        """Test resolution of unknown models."""
        registry = OpenRouterModelRegistry()

        # Unknown aliases should return None
        assert registry.resolve("unknown-alias") is None
        assert registry.resolve("") is None
        assert registry.resolve("non-existent") is None

    def test_model_capabilities_conversion(self):
        """Test conversion to ModelCapabilities."""
        registry = OpenRouterModelRegistry()

        config = registry.resolve("opus")
        assert config is not None

        caps = config.to_capabilities()
        assert caps.provider == ProviderType.OPENROUTER
        assert caps.model_name == "anthropic/claude-3-opus"
        assert caps.friendly_name == "OpenRouter"
        assert caps.max_tokens == 200000
        assert not caps.supports_extended_thinking

    def test_duplicate_alias_detection(self):
        """Test that duplicate aliases are detected."""
        config_data = {
            "models": [
                {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
                {
                    "model_name": "test/model-2",
                    "aliases": ["DUPE"],  # Same alias, different case
                    "context_window": 8192,
                },
            ]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
            temp_path = f.name

        try:
            with pytest.raises(ValueError, match="Duplicate alias"):
                OpenRouterModelRegistry(config_path=temp_path)
        finally:
            os.unlink(temp_path)

    def test_backwards_compatibility_max_tokens(self):
        """Test backwards compatibility with old max_tokens field."""
        config_data = {
            "models": [
                {
                    "model_name": "test/old-model",
                    "aliases": ["old"],
                    "max_tokens": 16384,  # Old field name
                    "supports_extended_thinking": False,
                }
            ]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
            temp_path = f.name

        try:
            registry = OpenRouterModelRegistry(config_path=temp_path)
            config = registry.resolve("old")

            assert config is not None
            assert config.context_window == 16384  # Should be converted

            # Check capabilities still work
            caps = config.to_capabilities()
            assert caps.max_tokens == 16384
        finally:
            os.unlink(temp_path)

    def test_missing_config_file(self):
        """Test behavior with missing config file."""
        # Use a non-existent path
        registry = OpenRouterModelRegistry(config_path="/non/existent/path.json")

        # Should initialize with empty maps
        assert len(registry.list_models()) == 0
        assert len(registry.list_aliases()) == 0
        assert registry.resolve("anything") is None

    def test_invalid_json_config(self):
        """Test handling of invalid JSON."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            f.write("{ invalid json }")
            temp_path = f.name

        try:
            registry = OpenRouterModelRegistry(config_path=temp_path)
            # Should handle gracefully and initialize empty
            assert len(registry.list_models()) == 0
            assert len(registry.list_aliases()) == 0
        finally:
            os.unlink(temp_path)

    def test_model_with_all_capabilities(self):
        """Test model with all capability flags."""
        config = OpenRouterModelConfig(
            model_name="test/full-featured",
            aliases=["full"],
            context_window=128000,
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            description="Fully featured test model",
        )

        caps = config.to_capabilities()
        assert caps.max_tokens == 128000
        assert caps.supports_extended_thinking
        assert caps.supports_system_prompts
        assert caps.supports_streaming
        assert caps.supports_function_calling
        # Note: supports_json_mode is not in ModelCapabilities yet
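For reference, the fixtures in these tests mirror the entry shape of conf/openrouter_models.json. A representative entry, written as the Python dict it is before serialization (the model name and values here are made up for illustration):

    # Illustrative registry entry; field names match the config schema, values are hypothetical.
    entry = {
        "model_name": "vendor/example-model",  # official OpenRouter identifier
        "aliases": ["example", "ex"],          # case-insensitive, unique across models
        "context_window": 128000,              # total tokens (input + output)
        "supports_extended_thinking": False,
        "supports_json_mode": True,
        "supports_function_calling": True,
    }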
111 tools/base.py
@@ -57,15 +57,28 @@ class ToolRequest(BaseModel):
     # Higher values allow for more complex reasoning but increase latency and cost
     thinking_mode: Optional[Literal["minimal", "low", "medium", "high", "max"]] = Field(
         None,
-        description="Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), max (100% of model max)",
+        description=(
+            "Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), "
+            "max (100% of model max)"
+        ),
     )
     use_websearch: Optional[bool] = Field(
         True,
-        description="Enable web search for documentation, best practices, and current information. When enabled, the model can request Claude to perform web searches and share results back during conversations. Particularly useful for: brainstorming sessions, architectural design discussions, exploring industry best practices, working with specific frameworks/technologies, researching solutions to complex problems, or when current documentation and community insights would enhance the analysis.",
+        description=(
+            "Enable web search for documentation, best practices, and current information. "
+            "When enabled, the model can request Claude to perform web searches and share results back "
+            "during conversations. Particularly useful for: brainstorming sessions, architectural design "
+            "discussions, exploring industry best practices, working with specific frameworks/technologies, "
+            "researching solutions to complex problems, or when current documentation and community insights "
+            "would enhance the analysis."
+        ),
     )
     continuation_id: Optional[str] = Field(
         None,
-        description="Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
+        description=(
+            "Thread continuation ID for multi-turn conversations. Can be used to continue conversations "
+            "across different tools. Only provide this if continuing a previous conversation thread."
+        ),
     )
 
 
@@ -152,14 +165,77 @@ class BaseTool(ABC):
         Returns:
             Dict containing the model field JSON schema
         """
+        import os
+
         from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
 
+        # Check if OpenRouter is configured
+        has_openrouter = bool(
+            os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here"
+        )
+
         if IS_AUTO_MODE:
             # In auto mode, model is required and we provide detailed descriptions
             model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
             for model, desc in MODEL_CAPABILITIES_DESC.items():
                 model_desc_parts.append(f"- '{model}': {desc}")
+
+            if has_openrouter:
+                # Add OpenRouter models with descriptions
+                try:
+                    import logging
+
+                    from providers.openrouter_registry import OpenRouterModelRegistry
+
+                    registry = OpenRouterModelRegistry()
+
+                    # Group models by their model_name to avoid duplicates
+                    seen_models = set()
+                    model_configs = []
+
+                    for alias in registry.list_aliases():
+                        config = registry.resolve(alias)
+                        if config and config.model_name not in seen_models:
+                            seen_models.add(config.model_name)
+                            model_configs.append((alias, config))
+
+                    # Sort by context window (descending) then by alias
+                    model_configs.sort(key=lambda x: (-x[1].context_window, x[0]))
+
+                    if model_configs:
+                        model_desc_parts.append("\nOpenRouter models (use these aliases):")
+                        for alias, config in model_configs[:10]:  # Limit to top 10
+                            # Format context window in human-readable form
+                            context_tokens = config.context_window
+                            if context_tokens >= 1_000_000:
+                                context_str = f"{context_tokens // 1_000_000}M"
+                            elif context_tokens >= 1_000:
+                                context_str = f"{context_tokens // 1_000}K"
+                            else:
+                                context_str = str(context_tokens)
+
+                            # Build description line
+                            if config.description:
+                                desc = f"- '{alias}' ({context_str} context): {config.description}"
+                            else:
+                                # Fallback to showing the model name if no description
+                                desc = f"- '{alias}' ({context_str} context): {config.model_name}"
+                            model_desc_parts.append(desc)
+
+                        # Add note about additional models if any were cut off
+                        total_models = len(model_configs)
+                        if total_models > 10:
+                            model_desc_parts.append(f"... and {total_models - 10} more models available")
+                except Exception as e:
+                    # Log for debugging but don't fail
+                    import logging
+
+                    logging.debug(f"Failed to load OpenRouter model descriptions: {e}")
+                    # Fallback to simple message
+                    model_desc_parts.append(
+                        "\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter."
+                    )
+
             return {
                 "type": "string",
                 "description": "\n".join(model_desc_parts),
@@ -169,9 +245,36 @@ class BaseTool(ABC):
             # Normal mode - model is optional with default
             available_models = list(MODEL_CAPABILITIES_DESC.keys())
             models_str = ", ".join(f"'{m}'" for m in available_models)
+
+            description = f"Model to use. Native models: {models_str}."
+            if has_openrouter:
+                # Add OpenRouter aliases
+                try:
+                    # Import registry directly to show available aliases
+                    # This works even without an API key
+                    from providers.openrouter_registry import OpenRouterModelRegistry
+
+                    registry = OpenRouterModelRegistry()
+                    aliases = registry.list_aliases()
+
+                    # Show ALL aliases from the configuration
+                    if aliases:
+                        # Show all aliases so Claude knows every option available
+                        all_aliases = sorted(aliases)
+                        alias_list = ", ".join(f"'{a}'" for a in all_aliases)
+                        description += f" OpenRouter aliases: {alias_list}."
+                    else:
+                        description += " OpenRouter: Any model available on openrouter.ai."
+                except Exception:
+                    description += (
+                        " OpenRouter: Any model available on openrouter.ai "
+                        "(e.g., 'gpt-4', 'claude-3-opus', 'mistral-large')."
+                    )
+
+            description += f" Defaults to '{DEFAULT_MODEL}' if not specified."
+
             return {
                 "type": "string",
-                "description": f"Model to use. Available: {models_str}. Defaults to '{DEFAULT_MODEL}' if not specified.",
+                "description": description,
             }
 
     def get_default_temperature(self) -> float:
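With these changes, the normal-mode schema description is assembled from the native model list, the OpenRouter aliases (when a key is configured), and the default. An illustrative result follows; the exact aliases and default depend on the loaded registry and DEFAULT_MODEL:

    # Illustrative description produced by get_model_field_schema() in normal mode:
    # "Model to use. Native models: 'pro', 'flash', 'o3', 'o3-mini'.
    #  OpenRouter aliases: '4o', 'claude', 'gpt4o', 'opus', ....
    #  Defaults to 'auto' if not specified."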