From 52b45f2b03ceb674b63bbb104243801be6e1e70a Mon Sep 17 00:00:00 2001
From: Fahad
Date: Thu, 12 Jun 2025 22:17:11 +0400
Subject: [PATCH] WIP - OpenRouter support and related refactoring

---
 .env.example                      |   8 +
 README.md                         |   7 +-
 docker-compose.yml                |   3 +
 docs/openrouter.md                |  52 ++++
 providers/__init__.py             |   4 +
 providers/base.py                 |   1 +
 providers/openai.py               | 108 +-------
 providers/openai_compatible.py    | 417 ++++++++++++++++++++++++++++++
 providers/openrouter.py           | 119 +++++++++
 providers/registry.py             |   1 +
 server.py                         |  12 +-
 setup-docker.sh                   |  28 +-
 tests/test_openrouter_provider.py | 138 ++++++++++
 13 files changed, 786 insertions(+), 112 deletions(-)
 create mode 100644 docs/openrouter.md
 create mode 100644 providers/openai_compatible.py
 create mode 100644 providers/openrouter.py
 create mode 100644 tests/test_openrouter_provider.py

diff --git a/.env.example b/.env.example
index c53d379..6962404 100644
--- a/.env.example
+++ b/.env.example
@@ -8,6 +8,14 @@ GEMINI_API_KEY=your_gemini_api_key_here
 # Get your OpenAI API key from: https://platform.openai.com/api-keys
 OPENAI_API_KEY=your_openai_api_key_here
 
+# Optional: OpenRouter for access to multiple models
+# Get your OpenRouter API key from: https://openrouter.ai/
+OPENROUTER_API_KEY=your_openrouter_api_key_here
+
+# Optional: Restrict which models can be used via OpenRouter (recommended for cost control)
+# Example: OPENROUTER_ALLOWED_MODELS=gpt-4,claude-3-opus,mistral-large
+OPENROUTER_ALLOWED_MODELS=
+
 # Optional: Default model to use
 # Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini'
 # When set to 'auto', Claude will select the best model for each task
diff --git a/README.md b/README.md
index 0abd47e..2bef327 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
 ### 1. Get API Keys (at least one required)
 - **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models.
 - **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) to get an API key for O3 model access.
+- **OpenRouter**: Visit [OpenRouter](https://openrouter.ai/) for access to multiple models through one API. [Setup Guide](docs/openrouter.md)
 
 ### 2. Clone and Set Up
 
 ```bash
 # Clone to your preferred location
 cd zen-mcp-server
 
 # Edit .env to add your API keys (if not already set in environment)
 nano .env
 
-# The file will contain:
+# The file will contain the following (at least one must be set):
 # GEMINI_API_KEY=your-gemini-api-key-here # For Gemini models
 # OPENAI_API_KEY=your-openai-api-key-here # For O3 model
+# OPENROUTER_API_KEY=your-openrouter-key # For OpenRouter (see docs/openrouter.md)
 # WORKSPACE_ROOT=/Users/your-username (automatically configured)
 
-# Note: At least one API key is required (Gemini or OpenAI)
+# Note: At least one API key is required
 ```
 
 ### 4. Configure Claude
...
 OPENAI_API_KEY=your-openai-key # Enables O3, O3-mini
 
 | **`flash`** (Gemini 2.0 Flash) | Google | 1M tokens | Ultra-fast responses | Quick checks, formatting, simple analysis |
 | **`o3`** | OpenAI | 200K tokens | Strong logical reasoning | Debugging logic errors, systematic analysis |
 | **`o3-mini`** | OpenAI | 200K tokens | Balanced speed/quality | Moderate complexity tasks |
+| **Any model** | OpenRouter | Varies | Access to GPT-4, Claude, Llama, etc. | User-specified or based on task requirements |
 
 **Manual Model Selection:** You can specify a default model instead of auto mode:
 
diff --git a/docker-compose.yml b/docker-compose.yml
index 812a492..0a4920f 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -31,6 +31,9 @@ services:
     environment:
       - GEMINI_API_KEY=${GEMINI_API_KEY:-}
       - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      # OpenRouter support
+      - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
+      - OPENROUTER_ALLOWED_MODELS=${OPENROUTER_ALLOWED_MODELS:-}
       - DEFAULT_MODEL=${DEFAULT_MODEL:-auto}
       - DEFAULT_THINKING_MODE_THINKDEEP=${DEFAULT_THINKING_MODE_THINKDEEP:-high}
       - REDIS_URL=redis://redis:6379/0
diff --git a/docs/openrouter.md b/docs/openrouter.md
new file mode 100644
index 0000000..c081591
--- /dev/null
+++ b/docs/openrouter.md
@@ -0,0 +1,52 @@
+# OpenRouter Setup
+
+OpenRouter provides unified access to multiple AI models (GPT-4, Claude, Mistral, etc.) through a single API.
+
+## Quick Start
+
+### 1. Get API Key
+1. Sign up at [openrouter.ai](https://openrouter.ai/)
+2. Create an API key from your dashboard
+3. Add credits to your account
+
+### 2. Set Environment Variable
+```bash
+# Add to your .env file
+OPENROUTER_API_KEY=your-openrouter-api-key
+```
+
+That's it! Docker Compose already includes all necessary configuration.
+
+### 3. Use Any Model
+```
+# Examples
+"Use gpt-4 via zen to review this code"
+"Use claude-3-opus via zen to debug this error"
+"Use mistral-large via zen to optimize this algorithm"
+```
+
+## Cost Control (Recommended)
+
+Restrict which models can be used to prevent unexpected charges:
+
+```bash
+# Add to .env file - only allow specific models
+OPENROUTER_ALLOWED_MODELS=gpt-4,claude-3-sonnet,mistral-large
+```
+
+Check current model pricing at [openrouter.ai/models](https://openrouter.ai/models).
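+
+The allow-list is parsed once at provider start-up (see `_parse_allowed_models` in `providers/openai_compatible.py`): names are split on commas and matched case-insensitively. A minimal sketch of that check (the standalone function names below are illustrative, not part of the codebase):
+
+```python
+import os
+
+def parse_allowed_models():
+    """Parse OPENROUTER_ALLOWED_MODELS into a lowercase set (None = unrestricted)."""
+    raw = os.getenv("OPENROUTER_ALLOWED_MODELS", "")
+    models = {m.strip().lower() for m in raw.split(",") if m.strip()}
+    return models or None
+
+def is_model_allowed(model_name, allowed):
+    """Names compare lowercased; with no allow-list, any model is accepted."""
+    return allowed is None or model_name.lower() in allowed
+
+print(is_model_allowed("GPT-4", parse_allowed_models()))
+```
+
+A request for a model outside the list fails before any API call is made, so typos and unapproved models cost nothing.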
+
+## Available Models
+
+Popular models available through OpenRouter:
+- **GPT-4** - OpenAI's most capable model
+- **Claude 3** - Anthropic's models (Opus, Sonnet, Haiku)
+- **Mistral** - Including Mistral Large
+- **Llama 3** - Meta's open models
+- Many more at [openrouter.ai/models](https://openrouter.ai/models)
+
+## Troubleshooting
+
+- **"Model not found"**: Check exact model name at openrouter.ai/models
+- **"Insufficient credits"**: Add credits to your OpenRouter account
+- **"Model not in allow-list"**: Update `OPENROUTER_ALLOWED_MODELS` in .env
\ No newline at end of file
diff --git a/providers/__init__.py b/providers/__init__.py
index 2ca6162..b36b92e 100644
--- a/providers/__init__.py
+++ b/providers/__init__.py
@@ -3,6 +3,8 @@
 from .base import ModelCapabilities, ModelProvider, ModelResponse
 from .gemini import GeminiModelProvider
 from .openai import OpenAIModelProvider
+from .openai_compatible import OpenAICompatibleProvider
+from .openrouter import OpenRouterProvider
 from .registry import ModelProviderRegistry
 
 __all__ = [
@@ -12,4 +14,6 @@ __all__ = [
     "ModelProviderRegistry",
     "GeminiModelProvider",
     "OpenAIModelProvider",
+    "OpenAICompatibleProvider",
+    "OpenRouterProvider",
 ]
diff --git a/providers/base.py b/providers/base.py
index c61ab87..0908fd1 100644
--- a/providers/base.py
+++ b/providers/base.py
@@ -11,6 +11,7 @@ class ProviderType(Enum):
 
     GOOGLE = "google"
     OPENAI = "openai"
+    OPENROUTER = "openrouter"
 
 
 class TemperatureConstraint(ABC):
diff --git a/providers/openai.py b/providers/openai.py
index 6139ad6..e49e295 100644
--- a/providers/openai.py
+++ b/providers/openai.py
@@ -3,20 +3,18 @@
 import logging
 from typing import Optional
 
-from openai import OpenAI
-
 from .base import (
     FixedTemperatureConstraint,
     ModelCapabilities,
-    ModelProvider,
     ModelResponse,
     ProviderType,
     RangeTemperatureConstraint,
 )
+from .openai_compatible import OpenAICompatibleProvider
 
 
-class OpenAIModelProvider(ModelProvider):
-    """OpenAI model provider implementation."""
+class OpenAIModelProvider(OpenAICompatibleProvider):
+    """Official OpenAI API provider (api.openai.com)."""
 
     # Model configurations
     SUPPORTED_MODELS = {
@@ -32,23 +30,10 @@
 
     def __init__(self, api_key: str, **kwargs):
         """Initialize OpenAI provider with API key."""
+        # Set default OpenAI base URL, allow override for regions/custom endpoints
+        kwargs.setdefault("base_url", "https://api.openai.com/v1")
         super().__init__(api_key, **kwargs)
-        self._client = None
-        self.base_url = kwargs.get("base_url")  # Support custom endpoints
-        self.organization = kwargs.get("organization")
 
-    @property
-    def client(self):
-        """Lazy initialization of OpenAI client."""
-        if self._client is None:
-            client_kwargs = {"api_key": self.api_key}
-            if self.base_url:
-                client_kwargs["base_url"] = self.base_url
-            if self.organization:
-                client_kwargs["organization"] = self.organization
-
-            self._client = OpenAI(**client_kwargs)
-        return self._client
 
     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         """Get capabilities for a specific OpenAI model."""
@@ -77,79 +62,6 @@
             temperature_constraint=temp_constraint,
         )
 
-    def generate_content(
-        self,
-        prompt: str,
-        model_name: str,
-        system_prompt: Optional[str] = None,
-        temperature: float = 0.7,
-        max_output_tokens: Optional[int] = None,
-        **kwargs,
-    ) -> ModelResponse:
-        """Generate content using OpenAI model."""
-        # Validate parameters
-        self.validate_parameters(model_name, temperature)
-
-        # Prepare messages
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": prompt})
-
-        # Prepare completion parameters
-        completion_params = {
-            "model": model_name,
-            "messages": messages,
-            "temperature": temperature,
-        }
-
-        # Add max tokens if specified
-        if max_output_tokens:
-            completion_params["max_tokens"] = max_output_tokens
-
-        # Add any additional OpenAI-specific parameters
-        for key, value in kwargs.items():
-            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop"]:
-                completion_params[key] = value
-
-        try:
-            # Generate completion
-            response = self.client.chat.completions.create(**completion_params)
-
-            # Extract content and usage
-            content = response.choices[0].message.content
-            usage = self._extract_usage(response)
-
-            return ModelResponse(
-                content=content,
-                usage=usage,
-                model_name=model_name,
-                friendly_name="OpenAI",
-                provider=ProviderType.OPENAI,
-                metadata={
-                    "finish_reason": response.choices[0].finish_reason,
-                    "model": response.model,  # Actual model used (in case of fallbacks)
-                    "id": response.id,
-                    "created": response.created,
-                },
-            )
-
-        except Exception as e:
-            # Log error and re-raise with more context
-            error_msg = f"OpenAI API error for model {model_name}: {str(e)}"
-            logging.error(error_msg)
-            raise RuntimeError(error_msg) from e
-
-    def count_tokens(self, text: str, model_name: str) -> int:
-        """Count tokens for the given text.
-
-        Note: For accurate token counting, we should use tiktoken library.
-        This is a simplified estimation.
-        """
-        # TODO: Implement proper token counting with tiktoken
-        # For now, use rough estimation
-        # O3 models ~4 chars per token
-        return len(text) // 4
 
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
 
         # This may change with future O3 models
         return False
 
-    def _extract_usage(self, response) -> dict[str, int]:
-        """Extract token usage from OpenAI response."""
-        usage = {}
-
-        if hasattr(response, "usage") and response.usage:
-            usage["input_tokens"] = response.usage.prompt_tokens
-            usage["output_tokens"] = response.usage.completion_tokens
-            usage["total_tokens"] = response.usage.total_tokens
-
-        return usage
diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py
new file mode 100644
index 0000000..3008582
--- /dev/null
+++ b/providers/openai_compatible.py
@@ -0,0 +1,417 @@
+"""Base class for OpenAI-compatible API providers."""
+
+import logging
+import os
+from abc import abstractmethod
+from typing import Optional
+from urllib.parse import urlparse
+import ipaddress
+import socket
+
+from openai import OpenAI
+
+from .base import (
+    ModelCapabilities,
+    ModelProvider,
+    ModelResponse,
+    ProviderType,
+    RangeTemperatureConstraint,
+)
+
+
+class OpenAICompatibleProvider(ModelProvider):
+    """Base class for any provider using an OpenAI-compatible API.
+
+    This includes:
+    - Direct OpenAI API
+    - OpenRouter
+    - Any other OpenAI-compatible endpoint
+    """
+
+    DEFAULT_HEADERS = {}
+    FRIENDLY_NAME = "OpenAI Compatible"
+
+    def __init__(self, api_key: str, base_url: str = None, **kwargs):
+        """Initialize the provider with API key and optional base URL.
+
+        Args:
+            api_key: API key for authentication
+            base_url: Base URL for the API endpoint
+            **kwargs: Additional configuration options
+        """
+        super().__init__(api_key, **kwargs)
+        self._client = None
+        self.base_url = base_url
+        self.organization = kwargs.get("organization")
+        self.allowed_models = self._parse_allowed_models()
+
+        # Validate base URL for security
+        if self.base_url:
+            self._validate_base_url()
+
+        # Warn if using external URL without authentication
+        if self.base_url and not self._is_localhost_url() and not api_key:
+            logging.warning(
+                f"Using external URL '{self.base_url}' without API key. "
+                "This may be insecure. Consider setting an API key for authentication."
+            )
+
+    def _parse_allowed_models(self) -> Optional[set[str]]:
+        """Parse allowed models from environment variable.
+
+        Returns:
+            Set of allowed model names (lowercase) or None if not configured
+        """
+        # Get provider-specific allowed models
+        provider_type = self.get_provider_type().value.upper()
+        env_var = f"{provider_type}_ALLOWED_MODELS"
+        models_str = os.getenv(env_var, "")
+
+        if models_str:
+            # Parse and normalize to lowercase for case-insensitive comparison
+            models = set(m.strip().lower() for m in models_str.split(",") if m.strip())
+            if models:
+                logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
+                return models
+
+        # Log warning if no allow-list configured for proxy providers
+        if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
+            logging.warning(
+                f"No model allow-list configured for {self.FRIENDLY_NAME}. "
+                f"Set {env_var} to restrict model access and control costs."
+            )
+
+        return None
+
+    def _is_localhost_url(self) -> bool:
+        """Check if the base URL points to localhost.
+
+        Returns:
+            True if URL is localhost, False otherwise
+        """
+        if not self.base_url:
+            return False
+
+        try:
+            parsed = urlparse(self.base_url)
+            hostname = parsed.hostname
+
+            # Check for common localhost patterns
+            if hostname in ['localhost', '127.0.0.1', '::1']:
+                return True
+
+            return False
+        except Exception:
+            return False
+
+    def _validate_base_url(self) -> None:
+        """Validate base URL for security (SSRF protection).
+
+        Raises:
+            ValueError: If URL is invalid or potentially unsafe
+        """
+        if not self.base_url:
+            return
+
+        try:
+            parsed = urlparse(self.base_url)
+
+            # Check URL scheme - only allow http/https
+            if parsed.scheme not in ('http', 'https'):
+                raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
+
+            # Check hostname exists
+            if not parsed.hostname:
+                raise ValueError("URL must include a hostname")
+
+            # Check port - allow only standard HTTP/HTTPS ports
+            port = parsed.port
+            if port is None:
+                port = 443 if parsed.scheme == 'https' else 80
+
+            # Allow common HTTP ports and some alternative ports
+            allowed_ports = {80, 443, 8080, 8443, 4000, 3000}  # Common API ports
+            if port not in allowed_ports:
+                raise ValueError(
+                    f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}"
+                )
+
+            # Check against allowed domains if configured
+            allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
+            allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]
+
+            if allowed_domains:
+                hostname_lower = parsed.hostname.lower()
+                if not any(
+                    hostname_lower == domain or
+                    hostname_lower.endswith('.' + domain)
+                    for domain in allowed_domains
+                ):
+                    raise ValueError(
+                        f"Domain not in allow-list: {parsed.hostname}. "
" + f"Allowed domains: {allowed_domains}" + ) + + # Try to resolve hostname and check if it's a private IP + # Skip for localhost addresses which are commonly used for development + if parsed.hostname not in ['localhost', '127.0.0.1', '::1']: + try: + # Get all IP addresses for the hostname + addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP) + + for family, _, _, _, sockaddr in addr_info: + ip_str = sockaddr[0] + try: + ip = ipaddress.ip_address(ip_str) + + # Check for dangerous IP ranges + if (ip.is_private or ip.is_loopback or ip.is_link_local or + ip.is_multicast or ip.is_reserved or ip.is_unspecified): + raise ValueError( + f"URL resolves to restricted IP address: {ip_str}. " + "This could be a security risk (SSRF)." + ) + except ValueError as ve: + # Invalid IP address format or restricted IP - re-raise if it's our security error + if "restricted IP address" in str(ve): + raise + continue + + except socket.gaierror as e: + # If we can't resolve the hostname, it's suspicious + raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}") + + except Exception as e: + if isinstance(e, ValueError): + raise + raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}") + + @property + def client(self): + """Lazy initialization of OpenAI client with security checks.""" + if self._client is None: + client_kwargs = { + "api_key": self.api_key, + } + + if self.base_url: + client_kwargs["base_url"] = self.base_url + + if self.organization: + client_kwargs["organization"] = self.organization + + # Add default headers if any + if self.DEFAULT_HEADERS: + client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy() + + self._client = OpenAI(**client_kwargs) + + return self._client + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.7, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> ModelResponse: + """Generate content using the OpenAI-compatible API. + + Args: + prompt: User prompt to send to the model + model_name: Name of the model to use + system_prompt: Optional system prompt for model behavior + temperature: Sampling temperature + max_output_tokens: Maximum tokens to generate + **kwargs: Additional provider-specific parameters + + Returns: + ModelResponse with generated content and metadata + """ + # Validate model name against allow-list + if not self.validate_model_name(model_name): + raise ValueError( + f"Model '{model_name}' not in allowed models list. 
" + f"Allowed models: {self.allowed_models}" + ) + + # Validate parameters + self.validate_parameters(model_name, temperature) + + # Prepare messages + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + # Prepare completion parameters + completion_params = { + "model": model_name, + "messages": messages, + "temperature": temperature, + } + + # Add max tokens if specified + if max_output_tokens: + completion_params["max_tokens"] = max_output_tokens + + # Add any additional OpenAI-specific parameters + for key, value in kwargs.items(): + if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]: + completion_params[key] = value + + try: + # Generate completion + response = self.client.chat.completions.create(**completion_params) + + # Extract content and usage + content = response.choices[0].message.content + usage = self._extract_usage(response) + + return ModelResponse( + content=content, + usage=usage, + model_name=model_name, + friendly_name=self.FRIENDLY_NAME, + provider=self.get_provider_type(), + metadata={ + "finish_reason": response.choices[0].finish_reason, + "model": response.model, # Actual model used + "id": response.id, + "created": response.created, + }, + ) + + except Exception as e: + # Log error and re-raise with more context + error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}" + logging.error(error_msg) + raise RuntimeError(error_msg) from e + + def count_tokens(self, text: str, model_name: str) -> int: + """Count tokens for the given text. + + Uses a layered approach: + 1. Try provider-specific token counting endpoint + 2. Try tiktoken for known model families + 3. Fall back to character-based estimation + + Args: + text: Text to count tokens for + model_name: Model name for tokenizer selection + + Returns: + Estimated token count + """ + # 1. Check if provider has a remote token counting endpoint + if hasattr(self, 'count_tokens_remote'): + try: + return self.count_tokens_remote(text, model_name) + except Exception as e: + logging.debug(f"Remote token counting failed: {e}") + + # 2. Try tiktoken for known models + try: + import tiktoken + + # Try to get encoding for the specific model + try: + encoding = tiktoken.encoding_for_model(model_name) + except KeyError: + # Try common encodings based on model patterns + if "gpt-4" in model_name or "gpt-3.5" in model_name: + encoding = tiktoken.get_encoding("cl100k_base") + else: + encoding = tiktoken.get_encoding("cl100k_base") # Default + + return len(encoding.encode(text)) + + except (ImportError, Exception) as e: + logging.debug(f"Tiktoken not available or failed: {e}") + + # 3. Fall back to character-based estimation + logging.warning( + f"No specific tokenizer available for '{model_name}'. " + "Using character-based estimation (~4 chars per token)." + ) + return len(text) // 4 + + def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None: + """Validate model parameters. + + For proxy providers, this may use generic capabilities. + + Args: + model_name: Model to validate for + temperature: Temperature to validate + **kwargs: Additional parameters to validate + """ + try: + capabilities = self.get_capabilities(model_name) + + # Check if we're using generic capabilities + if hasattr(capabilities, '_is_generic'): + logging.debug( + f"Using generic parameter validation for {model_name}. " + "Actual model constraints may differ." 
+ ) + + # Validate temperature using parent class method + super().validate_parameters(model_name, temperature, **kwargs) + + except Exception as e: + # For proxy providers, we might not have accurate capabilities + # Log warning but don't fail + logging.warning(f"Parameter validation limited for {model_name}: {e}") + + def _extract_usage(self, response) -> dict[str, int]: + """Extract token usage from OpenAI response. + + Args: + response: OpenAI API response object + + Returns: + Dictionary with usage statistics + """ + usage = {} + + if hasattr(response, "usage") and response.usage: + usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) + usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) + usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) + + return usage + + @abstractmethod + def get_capabilities(self, model_name: str) -> ModelCapabilities: + """Get capabilities for a specific model. + + Must be implemented by subclasses. + """ + pass + + @abstractmethod + def get_provider_type(self) -> ProviderType: + """Get the provider type. + + Must be implemented by subclasses. + """ + pass + + @abstractmethod + def validate_model_name(self, model_name: str) -> bool: + """Validate if the model name is supported. + + Must be implemented by subclasses. + """ + pass + + def supports_thinking_mode(self, model_name: str) -> bool: + """Check if the model supports extended thinking mode. + + Default is False for OpenAI-compatible providers. + """ + return False \ No newline at end of file diff --git a/providers/openrouter.py b/providers/openrouter.py new file mode 100644 index 0000000..657e810 --- /dev/null +++ b/providers/openrouter.py @@ -0,0 +1,119 @@ +"""OpenRouter provider implementation.""" + +import logging +import os + +from .base import ( + ModelCapabilities, + ProviderType, + RangeTemperatureConstraint, +) +from .openai_compatible import OpenAICompatibleProvider + + +class OpenRouterProvider(OpenAICompatibleProvider): + """OpenRouter unified API provider. + + OpenRouter provides access to multiple AI models through a single API endpoint. + See https://openrouter.ai for available models and pricing. + """ + + FRIENDLY_NAME = "OpenRouter" + + # Custom headers required by OpenRouter + DEFAULT_HEADERS = { + "HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"), + "X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"), + } + + def __init__(self, api_key: str, **kwargs): + """Initialize OpenRouter provider. + + Args: + api_key: OpenRouter API key + **kwargs: Additional configuration + """ + # Always use OpenRouter's base URL + super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs) + + # Log warning about model allow-list if not configured + if not self.allowed_models: + logging.warning( + "OpenRouter provider initialized without model allow-list. " + "Consider setting OPENROUTER_ALLOWED_MODELS environment variable " + "to restrict model access and control costs." + ) + + def get_capabilities(self, model_name: str) -> ModelCapabilities: + """Get capabilities for a model. + + Since OpenRouter supports many models dynamically, we return + generic capabilities with conservative defaults. + + Args: + model_name: Name of the model + + Returns: + Generic ModelCapabilities with warnings logged + """ + logging.warning( + f"Using generic capabilities for '{model_name}' via OpenRouter. " + "Actual model capabilities may differ. 
+            "Actual model capabilities may differ. Consider querying OpenRouter's "
+            "/models endpoint for accurate information."
+        )
+
+        # Create generic capabilities with conservative defaults
+        capabilities = ModelCapabilities(
+            provider=ProviderType.OPENROUTER,
+            model_name=model_name,
+            friendly_name=self.FRIENDLY_NAME,
+            max_tokens=32_768,  # Conservative default
+            supports_extended_thinking=False,  # Most models don't support this
+            supports_system_prompts=True,  # Most models support this
+            supports_streaming=True,
+            supports_function_calling=False,  # Varies by model
+            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
+        )
+
+        # Mark as generic for validation purposes
+        capabilities._is_generic = True
+
+        return capabilities
+
+    def get_provider_type(self) -> ProviderType:
+        """Get the provider type."""
+        return ProviderType.OPENROUTER
+
+    def validate_model_name(self, model_name: str) -> bool:
+        """Validate if the model name is allowed.
+
+        For OpenRouter, we accept any model name unless an allow-list
+        is configured via OPENROUTER_ALLOWED_MODELS environment variable.
+
+        Args:
+            model_name: Model name to validate
+
+        Returns:
+            True if model is allowed, False otherwise
+        """
+        if self.allowed_models:
+            # Case-insensitive validation against allow-list
+            return model_name.lower() in self.allowed_models
+
+        # Accept any model if no allow-list configured
+        # The API will return an error if the model doesn't exist
+        return True
+
+    def supports_thinking_mode(self, model_name: str) -> bool:
+        """Check if the model supports extended thinking mode.
+
+        Currently, no models via OpenRouter support extended thinking.
+        This may change as new models become available.
+
+        Args:
+            model_name: Model to check
+
+        Returns:
+            False (no OpenRouter models currently support thinking mode)
+        """
+        return False
\ No newline at end of file
diff --git a/providers/registry.py b/providers/registry.py
index 8d126b2..c9fe184 100644
--- a/providers/registry.py
+++ b/providers/registry.py
@@ -117,6 +117,7 @@
         key_mapping = {
             ProviderType.GOOGLE: "GEMINI_API_KEY",
             ProviderType.OPENAI: "OPENAI_API_KEY",
+            ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
         }
 
         env_var = key_mapping.get(provider_type)
diff --git a/server.py b/server.py
index 49d376b..680774f 100644
--- a/server.py
+++ b/server.py
@@ -131,6 +131,7 @@ def configure_providers():
     from providers.base import ProviderType
     from providers.gemini import GeminiModelProvider
     from providers.openai import OpenAIModelProvider
+    from providers.openrouter import OpenRouterProvider
 
     valid_providers = []
 
@@ -148,12 +149,21 @@
         valid_providers.append("OpenAI (o3)")
         logger.info("OpenAI API key found - o3 model available")
 
+    # Check for OpenRouter API key
+    openrouter_key = os.getenv("OPENROUTER_API_KEY")
+    if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
+        ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
+        valid_providers.append("OpenRouter")
+        logger.info("OpenRouter API key found - Multiple models available via OpenRouter")
+
+    # Require at least one valid provider
     if not valid_providers:
         raise ValueError(
-            "At least one API key is required. Please set either:\n"
+            "At least one API key is required. Please set at least one of:\n"
             "- GEMINI_API_KEY for Gemini models\n"
-            "- OPENAI_API_KEY for OpenAI o3 model"
+            "- OPENAI_API_KEY for OpenAI o3 model\n"
+            "- OPENROUTER_API_KEY for OpenRouter (multiple models)"
         )
 
     logger.info(f"Available providers: {', '.join(valid_providers)}")
diff --git a/setup-docker.sh b/setup-docker.sh
index 4f489c4..0ac8cbc 100755
--- a/setup-docker.sh
+++ b/setup-docker.sh
@@ -36,8 +36,6 @@ else
         else
             echo "⚠️  Found GEMINI_API_KEY in environment, but sed not available. Please update .env manually."
         fi
-    else
-        echo "⚠️  GEMINI_API_KEY not found in environment. Please edit .env and add your API key."
     fi
 
     if [ -n "${OPENAI_API_KEY:-}" ]; then
 
         else
             echo "⚠️  Found OPENAI_API_KEY in environment, but sed not available. Please update .env manually."
         fi
-    else
-        echo "⚠️  OPENAI_API_KEY not found in environment. Please edit .env and add your API key."
+    fi
+
+    if [ -n "${OPENROUTER_API_KEY:-}" ]; then
+        # Replace the placeholder API key with the actual value
+        if command -v sed >/dev/null 2>&1; then
+            sed -i.bak "s/your_openrouter_api_key_here/$OPENROUTER_API_KEY/" .env && rm .env.bak
+            echo "✅ Updated .env with existing OPENROUTER_API_KEY from environment"
+        else
+            echo "⚠️  Found OPENROUTER_API_KEY in environment, but sed not available. Please update .env manually."
+        fi
     fi
 
     # Update WORKSPACE_ROOT to use current user's home directory
@@ -92,6 +98,7 @@ source .env 2>/dev/null || true
 
 VALID_GEMINI_KEY=false
 VALID_OPENAI_KEY=false
+VALID_OPENROUTER_KEY=false
 
 # Check if GEMINI_API_KEY is set and not the placeholder
 if [ -n "${GEMINI_API_KEY:-}" ] && [ "$GEMINI_API_KEY" != "your_gemini_api_key_here" ]; then
 
 if [ -n "${OPENAI_API_KEY:-}" ] && [ "$OPENAI_API_KEY" != "your_openai_api_key_h
     echo "✅ Valid OPENAI_API_KEY found"
 fi
 
+# Check if OPENROUTER_API_KEY is set and not the placeholder
+if [ -n "${OPENROUTER_API_KEY:-}" ] && [ "$OPENROUTER_API_KEY" != "your_openrouter_api_key_here" ]; then
+    VALID_OPENROUTER_KEY=true
+    echo "✅ Valid OPENROUTER_API_KEY found"
+fi
+
 # Require at least one valid API key
-if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ]; then
+if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ] && [ "$VALID_OPENROUTER_KEY" = false ]; then
     echo ""
     echo "❌ ERROR: At least one valid API key is required!"
     echo ""
     echo "Please edit the .env file and set at least one of:"
     echo "  - GEMINI_API_KEY (get from https://makersuite.google.com/app/apikey)"
     echo "  - OPENAI_API_KEY (get from https://platform.openai.com/api-keys)"
+    echo "  - OPENROUTER_API_KEY (get from https://openrouter.ai/)"
     echo ""
     echo "Example:"
     echo "  GEMINI_API_KEY=your-actual-api-key-here"
     echo "  OPENAI_API_KEY=sk-your-actual-openai-key-here"
+    echo "  OPENROUTER_API_KEY=sk-or-your-actual-openrouter-key-here"
     echo ""
     exit 1
 fi
@@ -228,7 +243,7 @@ show_configuration_steps() {
     echo ""
     echo "🔄 Next steps:"
     NEEDS_KEY_UPDATE=false
-    if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null; then
+    if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null || grep -q "your_openrouter_api_key_here" .env 2>/dev/null; then
         NEEDS_KEY_UPDATE=true
     fi
 
     if [ "$NEEDS_KEY_UPDATE" = true ]; then
         echo "1. Edit .env and replace placeholder API keys with actual ones"
         echo "   - GEMINI_API_KEY: your-gemini-api-key-here"
         echo "   - OPENAI_API_KEY: your-openai-api-key-here"
+        echo "   - OPENROUTER_API_KEY: your-openrouter-api-key-here (optional)"
         echo "2. Restart services: $COMPOSE_CMD restart"
         echo "3. Copy the configuration below to your Claude Desktop config if required:"
     else
diff --git a/tests/test_openrouter_provider.py b/tests/test_openrouter_provider.py
new file mode 100644
index 0000000..73c4787
--- /dev/null
+++ b/tests/test_openrouter_provider.py
@@ -0,0 +1,138 @@
+"""Tests for OpenRouter provider."""
+
+import os
+import pytest
+from unittest.mock import patch, MagicMock
+
+from providers.base import ProviderType
+from providers.openrouter import OpenRouterProvider
+from providers.registry import ModelProviderRegistry
+
+
+class TestOpenRouterProvider:
+    """Test cases for OpenRouter provider."""
+
+    def test_provider_initialization(self):
+        """Test OpenRouter provider initialization."""
+        provider = OpenRouterProvider(api_key="test-key")
+        assert provider.api_key == "test-key"
+        assert provider.base_url == "https://openrouter.ai/api/v1"
+        assert provider.FRIENDLY_NAME == "OpenRouter"
+
+    def test_custom_headers(self):
+        """Test OpenRouter custom headers."""
+        # Test default headers
+        assert "HTTP-Referer" in OpenRouterProvider.DEFAULT_HEADERS
+        assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS
+
+        # Test with environment variables
+        with patch.dict(os.environ, {
+            "OPENROUTER_REFERER": "https://myapp.com",
+            "OPENROUTER_TITLE": "My App"
+        }):
+            from importlib import reload
+            import providers.openrouter
+            reload(providers.openrouter)
+
+            provider = providers.openrouter.OpenRouterProvider(api_key="test-key")
+            assert provider.DEFAULT_HEADERS["HTTP-Referer"] == "https://myapp.com"
+            assert provider.DEFAULT_HEADERS["X-Title"] == "My App"
+
+        # Reload once more outside the patched environment so the
+        # patched headers do not leak into other tests
+        reload(providers.openrouter)
+
+    def test_model_validation_without_allowlist(self):
+        """Test model validation without allow-list."""
+        provider = OpenRouterProvider(api_key="test-key")
+
+        # Should accept any model when no allow-list
+        assert provider.validate_model_name("gpt-4") is True
+        assert provider.validate_model_name("claude-3-opus") is True
+        assert provider.validate_model_name("any-model-name") is True
+
+    def test_model_validation_with_allowlist(self):
+        """Test model validation with allow-list."""
+        with patch.dict(os.environ, {
+            "OPENROUTER_ALLOWED_MODELS": "gpt-4,claude-3-opus,mistral-large"
+        }):
+            provider = OpenRouterProvider(api_key="test-key")
+
+            # Test allowed models (case-insensitive)
+            assert provider.validate_model_name("gpt-4") is True
+            assert provider.validate_model_name("GPT-4") is True
+            assert provider.validate_model_name("claude-3-opus") is True
+            assert provider.validate_model_name("MISTRAL-LARGE") is True
+
+            # Test disallowed models
+            assert provider.validate_model_name("gpt-3.5-turbo") is False
+            assert provider.validate_model_name("unauthorized-model") is False
+
+    def test_get_capabilities(self):
+        """Test capability generation returns generic capabilities."""
+        provider = OpenRouterProvider(api_key="test-key")
+
+        # Should return generic capabilities for any model
+        caps = provider.get_capabilities("gpt-4")
+        assert caps.provider == ProviderType.OPENROUTER
+        assert caps.model_name == "gpt-4"
+        assert caps.friendly_name == "OpenRouter"
+        assert caps.max_tokens == 32_768  # Safe default
+        assert hasattr(caps, '_is_generic') and caps._is_generic is True
+
+    def test_openrouter_registration(self):
+        """Test OpenRouter can be registered and retrieved."""
+        with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
+            # Clean up any existing registration
+            ModelProviderRegistry.unregister_provider(ProviderType.OPENROUTER)
+
+            # Register the provider
+            ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
+
+            # Retrieve and verify
+            provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
+            assert provider is not None
+            assert isinstance(provider, OpenRouterProvider)
+
+
+class TestOpenRouterSSRFProtection:
+    """Test SSRF protection for OpenRouter."""
+
+    def test_url_validation_rejects_private_ips(self):
+        """Test that private IPs are rejected."""
+        provider = OpenRouterProvider(api_key="test-key")
+
+        # List of private/dangerous IPs to test.
+        # Note: localhost forms (localhost, 127.0.0.1, [::1]) are deliberately
+        # absent -- _validate_base_url skips them to keep local development working.
+        dangerous_urls = [
+            "http://192.168.1.1/api/v1",
+            "http://10.0.0.1/api/v1",
+            "http://172.16.0.1/api/v1",
+            "http://169.254.169.254/api/v1",  # AWS metadata
+            "http://0.0.0.0/api/v1",
+        ]
+
+        for url in dangerous_urls:
+            with pytest.raises(ValueError, match="restricted IP|Invalid"):
+                provider.base_url = url
+                provider._validate_base_url()
+
+    def test_url_validation_allows_public_domains(self):
+        """Test that legitimate public domains are allowed."""
+        provider = OpenRouterProvider(api_key="test-key")
+
+        # OpenRouter's actual domain should always be allowed
+        provider.base_url = "https://openrouter.ai/api/v1"
+        provider._validate_base_url()  # Should not raise
+
+    def test_invalid_url_schemes_rejected(self):
+        """Test that non-HTTP(S) schemes are rejected."""
+        provider = OpenRouterProvider(api_key="test-key")
+
+        invalid_urls = [
+            "ftp://example.com/api",
+            "file:///etc/passwd",
+            "gopher://example.com",
+            "javascript:alert(1)",
+        ]
+
+        for url in invalid_urls:
+            with pytest.raises(ValueError, match="Invalid URL scheme"):
+                provider.base_url = url
+                provider._validate_base_url()
\ No newline at end of file
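
For reviewers, a minimal usage sketch of the provider this patch adds (not part of the diff). It assumes a valid OPENROUTER_API_KEY in the environment and uses only methods introduced above (`validate_model_name`, `generate_content`):

```python
import os

from providers.openrouter import OpenRouterProvider

provider = OpenRouterProvider(api_key=os.environ["OPENROUTER_API_KEY"])

# Any model name passes unless OPENROUTER_ALLOWED_MODELS restricts it.
if provider.validate_model_name("mistral-large"):
    response = provider.generate_content(
        prompt="Summarize this change in one sentence.",
        model_name="mistral-large",
        temperature=0.7,
    )
    print(response.content)
    print(response.usage)  # {'input_tokens': ..., 'output_tokens': ..., 'total_tokens': ...}
```

Expect the generic-capabilities warning from `get_capabilities` to fire on each call; that is the intended behavior for proxy models until per-model metadata is added.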