- OpenRouter model configuration registry
- Model definition file that users can edit to control available models and aliases
- Additional tests
- Update instructions
Fahad
2025-06-13 06:33:12 +04:00
parent cd1105b741
commit 2cdb92460b
12 changed files with 417 additions and 381 deletions

View File

@@ -1,12 +1,8 @@
"""OpenAI model provider implementation."""
import logging
from typing import Optional
from .base import (
FixedTemperatureConstraint,
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
@@ -34,7 +30,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
kwargs.setdefault("base_url", "https://api.openai.com/v1")
super().__init__(api_key, **kwargs)
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific OpenAI model."""
if model_name not in self.SUPPORTED_MODELS:
@@ -62,7 +57,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
temperature_constraint=temp_constraint,
)
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.OPENAI
@@ -76,4 +70,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
# Currently no OpenAI models support extended thinking
# This may change with future O3 models
return False
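
For orientation, a brief usage sketch of the provider above (the API key and model names are placeholders):

provider = OpenAIModelProvider(api_key="sk-...")  # base_url defaults to https://api.openai.com/v1
provider.supports_thinking_mode("gpt-4.1")        # False: no OpenAI model supports extended thinking yet
# get_capabilities("unknown-model") fails the SUPPORTED_MODELS membership check shown above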

View File

@@ -1,12 +1,12 @@
"""Base class for OpenAI-compatible API providers."""
+import ipaddress
import logging
import os
+import socket
from abc import abstractmethod
from typing import Optional
from urllib.parse import urlparse
-import ipaddress
-import socket
from openai import OpenAI
@@ -15,25 +15,24 @@ from .base import (
ModelProvider,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
class OpenAICompatibleProvider(ModelProvider):
"""Base class for any provider using an OpenAI-compatible API.
This includes:
- Direct OpenAI API
- OpenRouter
- Any other OpenAI-compatible endpoint
"""
DEFAULT_HEADERS = {}
FRIENDLY_NAME = "OpenAI Compatible"
def __init__(self, api_key: str, base_url: str = None, **kwargs):
"""Initialize the provider with API key and optional base URL.
Args:
api_key: API key for authentication
base_url: Base URL for the API endpoint
@@ -44,21 +43,21 @@ class OpenAICompatibleProvider(ModelProvider):
self.base_url = base_url
self.organization = kwargs.get("organization")
self.allowed_models = self._parse_allowed_models()
# Validate base URL for security
if self.base_url:
self._validate_base_url()
# Warn if using external URL without authentication
if self.base_url and not self._is_localhost_url() and not api_key:
logging.warning(
f"Using external URL '{self.base_url}' without API key. "
"This may be insecure. Consider setting an API key for authentication."
)
def _parse_allowed_models(self) -> Optional[set[str]]:
"""Parse allowed models from environment variable.
Returns:
Set of allowed model names (lowercase) or None if not configured
"""
@@ -66,108 +65,108 @@ class OpenAICompatibleProvider(ModelProvider):
provider_type = self.get_provider_type().value.upper()
env_var = f"{provider_type}_ALLOWED_MODELS"
models_str = os.getenv(env_var, "")
if models_str:
# Parse and normalize to lowercase for case-insensitive comparison
-models = set(m.strip().lower() for m in models_str.split(",") if m.strip())
+models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
if models:
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
return models
# Log warning if no allow-list configured for proxy providers
if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
logging.warning(
f"No model allow-list configured for {self.FRIENDLY_NAME}. "
f"Set {env_var} to restrict model access and control costs."
)
return None
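
To illustrate the allow-list contract, a minimal sketch assuming the OPENAI provider type (the env var name follows the {PROVIDER}_ALLOWED_MODELS pattern above; the model names are placeholders):

import os

os.environ["OPENAI_ALLOWED_MODELS"] = "o3-mini, GPT-4.1"
# _parse_allowed_models() normalizes this to the lowercase set
# {"o3-mini", "gpt-4.1"}, so later model-name checks are case-insensitive.
# An unset or empty variable yields None (no restriction), with a warning
# logged for proxy providers.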
def _is_localhost_url(self) -> bool:
"""Check if the base URL points to localhost.
Returns:
True if URL is localhost, False otherwise
"""
if not self.base_url:
return False
try:
parsed = urlparse(self.base_url)
hostname = parsed.hostname
# Check for common localhost patterns
-if hostname in ['localhost', '127.0.0.1', '::1']:
+if hostname in ["localhost", "127.0.0.1", "::1"]:
return True
return False
except Exception:
return False
def _validate_base_url(self) -> None:
"""Validate base URL for security (SSRF protection).
Raises:
ValueError: If URL is invalid or potentially unsafe
"""
if not self.base_url:
return
try:
parsed = urlparse(self.base_url)
# Check URL scheme - only allow http/https
-if parsed.scheme not in ('http', 'https'):
+if parsed.scheme not in ("http", "https"):
raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
# Check hostname exists
if not parsed.hostname:
raise ValueError("URL must include a hostname")
# Check port - allow only standard HTTP/HTTPS ports
port = parsed.port
if port is None:
-port = 443 if parsed.scheme == 'https' else 80
+port = 443 if parsed.scheme == "https" else 80
# Allow common HTTP ports and some alternative ports
allowed_ports = {80, 443, 8080, 8443, 4000, 3000} # Common API ports
if port not in allowed_ports:
-raise ValueError(
-    f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}"
-)
+raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")
# Check against allowed domains if configured
allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]
if allowed_domains:
hostname_lower = parsed.hostname.lower()
if not any(
-    hostname_lower == domain or
-    hostname_lower.endswith('.' + domain)
-    for domain in allowed_domains
+    hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
):
raise ValueError(
f"Domain not in allow-list: {parsed.hostname}. "
f"Allowed domains: {allowed_domains}"
f"Domain not in allow-list: {parsed.hostname}. " f"Allowed domains: {allowed_domains}"
)
# Try to resolve hostname and check if it's a private IP
# Skip for localhost addresses which are commonly used for development
-if parsed.hostname not in ['localhost', '127.0.0.1', '::1']:
+if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
try:
# Get all IP addresses for the hostname
addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)
-for family, _, _, _, sockaddr in addr_info:
+for _family, _, _, _, sockaddr in addr_info:
ip_str = sockaddr[0]
try:
ip = ipaddress.ip_address(ip_str)
# Check for dangerous IP ranges
-if (ip.is_private or ip.is_loopback or ip.is_link_local or
-    ip.is_multicast or ip.is_reserved or ip.is_unspecified):
+if (
+    ip.is_private
+    or ip.is_loopback
+    or ip.is_link_local
+    or ip.is_multicast
+    or ip.is_reserved
+    or ip.is_unspecified
+):
raise ValueError(
f"URL resolves to restricted IP address: {ip_str}. "
"This could be a security risk (SSRF)."
@@ -177,16 +176,16 @@ class OpenAICompatibleProvider(ModelProvider):
if "restricted IP address" in str(ve):
raise
continue
except socket.gaierror as e:
# If we can't resolve the hostname, it's suspicious
raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
@property
def client(self):
"""Lazy initialization of OpenAI client with security checks."""
@@ -194,21 +193,21 @@ class OpenAICompatibleProvider(ModelProvider):
client_kwargs = {
"api_key": self.api_key,
}
if self.base_url:
client_kwargs["base_url"] = self.base_url
if self.organization:
client_kwargs["organization"] = self.organization
# Add default headers if any
if self.DEFAULT_HEADERS:
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
self._client = OpenAI(**client_kwargs)
return self._client
def generate_content(
self,
prompt: str,
@@ -219,7 +218,7 @@ class OpenAICompatibleProvider(ModelProvider):
**kwargs,
) -> ModelResponse:
"""Generate content using the OpenAI-compatible API.
Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
@@ -227,50 +226,49 @@ class OpenAICompatibleProvider(ModelProvider):
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Validate model name against allow-list
if not self.validate_model_name(model_name):
raise ValueError(
f"Model '{model_name}' not in allowed models list. "
f"Allowed models: {self.allowed_models}"
f"Model '{model_name}' not in allowed models list. " f"Allowed models: {self.allowed_models}"
)
# Validate parameters
self.validate_parameters(model_name, temperature)
# Prepare messages
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# Prepare completion parameters
completion_params = {
"model": model_name,
"messages": messages,
"temperature": temperature,
}
# Add max tokens if specified
if max_output_tokens:
completion_params["max_tokens"] = max_output_tokens
# Add any additional OpenAI-specific parameters
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
completion_params[key] = value
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
@@ -284,39 +282,39 @@ class OpenAICompatibleProvider(ModelProvider):
"created": response.created,
},
)
except Exception as e:
# Log error and re-raise with more context
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from e
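
A hedged call sketch for a concrete subclass (the key, prompt, and model name are placeholders; the fields follow the ModelResponse construction above):

response = provider.generate_content(
    prompt="Summarize the design in two sentences.",
    model_name="gpt-4.1",
    system_prompt="You are a concise reviewer.",
    temperature=0.2,
    max_output_tokens=200,
)
print(response.content)
print(response.usage.get("total_tokens", 0))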
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text.
Uses a layered approach:
1. Try provider-specific token counting endpoint
2. Try tiktoken for known model families
3. Fall back to character-based estimation
Args:
text: Text to count tokens for
model_name: Model name for tokenizer selection
Returns:
Estimated token count
"""
# 1. Check if provider has a remote token counting endpoint
-if hasattr(self, 'count_tokens_remote'):
+if hasattr(self, "count_tokens_remote"):
try:
return self.count_tokens_remote(text, model_name)
except Exception as e:
logging.debug(f"Remote token counting failed: {e}")
# 2. Try tiktoken for known models
try:
import tiktoken
# Try to get encoding for the specific model
try:
encoding = tiktoken.encoding_for_model(model_name)
@@ -326,24 +324,24 @@ class OpenAICompatibleProvider(ModelProvider):
encoding = tiktoken.get_encoding("cl100k_base")
else:
encoding = tiktoken.get_encoding("cl100k_base") # Default
return len(encoding.encode(text))
except (ImportError, Exception) as e:
logging.debug(f"Tiktoken not available or failed: {e}")
# 3. Fall back to character-based estimation
logging.warning(
f"No specific tokenizer available for '{model_name}'. "
"Using character-based estimation (~4 chars per token)."
)
return len(text) // 4
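
Illustrative behavior of the three-layer fallback (results indicative only):

provider.count_tokens("Hello, world!", "gpt-4")     # known model: exact tiktoken count
provider.count_tokens("Hello, world!", "my-model")  # unknown model: cl100k_base if tiktoken
                                                    # is installed, else len(text) // 4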
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters.
For proxy providers, this may use generic capabilities.
Args:
model_name: Model to validate for
temperature: Temperature to validate
@@ -351,67 +349,66 @@ class OpenAICompatibleProvider(ModelProvider):
"""
try:
capabilities = self.get_capabilities(model_name)
# Check if we're using generic capabilities
-if hasattr(capabilities, '_is_generic'):
+if hasattr(capabilities, "_is_generic"):
logging.debug(
f"Using generic parameter validation for {model_name}. "
"Actual model constraints may differ."
f"Using generic parameter validation for {model_name}. " "Actual model constraints may differ."
)
# Validate temperature using parent class method
super().validate_parameters(model_name, temperature, **kwargs)
except Exception as e:
# For proxy providers, we might not have accurate capabilities
# Log warning but don't fail
logging.warning(f"Parameter validation limited for {model_name}: {e}")
def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from OpenAI response.
Args:
response: OpenAI API response object
Returns:
Dictionary with usage statistics
"""
usage = {}
if hasattr(response, "usage") and response.usage:
usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)
return usage
@abstractmethod
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def get_provider_type(self) -> ProviderType:
"""Get the provider type.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported.
Must be implemented by subclasses.
"""
pass
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode.
Default is False for OpenAI-compatible providers.
"""
-return False
\ No newline at end of file
+return False

View File

@@ -16,63 +16,61 @@ from .openrouter_registry import OpenRouterModelRegistry
class OpenRouterProvider(OpenAICompatibleProvider):
"""OpenRouter unified API provider.
OpenRouter provides access to multiple AI models through a single API endpoint.
See https://openrouter.ai for available models and pricing.
"""
FRIENDLY_NAME = "OpenRouter"
# Custom headers required by OpenRouter
DEFAULT_HEADERS = {
"HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
"X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
}
# Model registry for managing configurations and aliases
_registry: Optional[OpenRouterModelRegistry] = None
def __init__(self, api_key: str, **kwargs):
"""Initialize OpenRouter provider.
Args:
api_key: OpenRouter API key
**kwargs: Additional configuration
"""
# Always use OpenRouter's base URL
super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs)
# Initialize model registry
if OpenRouterProvider._registry is None:
OpenRouterProvider._registry = OpenRouterModelRegistry()
# Log loaded models and aliases
models = self._registry.list_models()
aliases = self._registry.list_aliases()
-logging.info(
-    f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases"
-)
+logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
def _parse_allowed_models(self) -> None:
"""Override to disable environment-based allow-list.
OpenRouter model access is controlled via the OpenRouter dashboard,
not through environment variables.
"""
return None
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model aliases to OpenRouter model names.
Args:
model_name: Input model name or alias
Returns:
Resolved OpenRouter model name
"""
# Try to resolve through registry
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
@@ -82,30 +80,30 @@ class OpenRouterProvider(OpenAICompatibleProvider):
# This allows using models not in our config file
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
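
Illustrative resolution behavior, assuming a hypothetical "opus" alias defined in conf/openrouter_models.json:

provider._resolve_model_name("opus")              # -> the configured model_name (logged at info level)
provider._resolve_model_name("vendor/new-model")  # not in the registry -> returned unchanged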
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a model.
Args:
model_name: Name of the model (or alias)
Returns:
ModelCapabilities from registry or generic defaults
"""
# Try to get from registry first
capabilities = self._registry.get_capabilities(model_name)
if capabilities:
return capabilities
else:
# Resolve any potential aliases and create generic capabilities
resolved_name = self._resolve_model_name(model_name)
logging.debug(
f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
"Consider adding to openrouter_models.json for specific capabilities."
)
# Create generic capabilities with conservative defaults
capabilities = ModelCapabilities(
provider=ProviderType.OPENROUTER,
@@ -118,31 +116,31 @@ class OpenRouterProvider(OpenAICompatibleProvider):
supports_function_calling=False,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
)
# Mark as generic for validation purposes
capabilities._is_generic = True
return capabilities
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.OPENROUTER
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is allowed.
For OpenRouter, we accept any model name. OpenRouter will
validate based on the API key's permissions.
Args:
model_name: Model name to validate
Returns:
Always True - OpenRouter handles validation
"""
# Accept any model name - OpenRouter will validate based on API key permissions
return True
def generate_content(
self,
prompt: str,
@@ -153,7 +151,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
**kwargs,
) -> ModelResponse:
"""Generate content using the OpenRouter API.
Args:
prompt: User prompt to send to the model
model_name: Name of the model (or alias) to use
@@ -161,13 +159,13 @@ class OpenRouterProvider(OpenAICompatibleProvider):
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Resolve model alias to actual OpenRouter model name
resolved_model = self._resolve_model_name(model_name)
# Call parent method with resolved model name
return super().generate_content(
prompt=prompt,
@@ -175,19 +173,19 @@ class OpenRouterProvider(OpenAICompatibleProvider):
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
-**kwargs
+**kwargs,
)
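
Illustrative call through the alias path (the alias is hypothetical):

response = provider.generate_content(prompt="Hello", model_name="opus")
# "opus" is resolved via the registry first, then the request is delegated
# to the OpenAI-compatible base implementation.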
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode.
Currently, no models via OpenRouter support extended thinking.
This may change as new models become available.
Args:
model_name: Model to check
Returns:
False (no OpenRouter models currently support thinking mode)
"""
-return False
\ No newline at end of file
+return False

View File

@@ -3,9 +3,9 @@
import json
import logging
import os
-from pathlib import Path
-from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional
from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
@@ -13,9 +13,9 @@ from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
@dataclass
class OpenRouterModelConfig:
"""Configuration for an OpenRouter model."""
model_name: str
-aliases: List[str] = field(default_factory=list)
+aliases: list[str] = field(default_factory=list)
context_window: int = 32768 # Total context window size in tokens
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
@@ -23,8 +23,7 @@ class OpenRouterModelConfig:
supports_function_calling: bool = False
supports_json_mode: bool = False
description: str = ""
def to_capabilities(self) -> ModelCapabilities:
"""Convert to ModelCapabilities object."""
return ModelCapabilities(
@@ -42,16 +41,16 @@ class OpenRouterModelConfig:
class OpenRouterModelRegistry:
"""Registry for managing OpenRouter model configurations and aliases."""
def __init__(self, config_path: Optional[str] = None):
"""Initialize the registry.
Args:
config_path: Path to config file. If None, uses default locations.
"""
-self.alias_map: Dict[str, str] = {}  # alias -> model_name
-self.model_map: Dict[str, OpenRouterModelConfig] = {}  # model_name -> config
+self.alias_map: dict[str, str] = {}  # alias -> model_name
+self.model_map: dict[str, OpenRouterModelConfig] = {}  # model_name -> config
# Determine config path
if config_path:
self.config_path = Path(config_path)
@@ -63,86 +62,93 @@ class OpenRouterModelRegistry:
else:
# Default to conf/openrouter_models.json
self.config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"
# Load configuration
self.reload()
def reload(self) -> None:
"""Reload configuration from disk."""
try:
configs = self._read_config()
self._build_maps(configs)
logging.info(f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases")
except ValueError as e:
# Re-raise ValueError only for duplicate aliases (critical config errors)
logging.error(f"Failed to load OpenRouter model configuration: {e}")
# Initialize with empty maps on failure
self.alias_map = {}
self.model_map = {}
if "Duplicate alias" in str(e):
raise
except Exception as e:
logging.error(f"Failed to load OpenRouter model configuration: {e}")
# Initialize with empty maps on failure
self.alias_map = {}
self.model_map = {}
-def _read_config(self) -> List[OpenRouterModelConfig]:
+def _read_config(self) -> list[OpenRouterModelConfig]:
"""Read configuration from file.
Returns:
List of model configurations
"""
if not self.config_path.exists():
logging.warning(f"OpenRouter model config not found at {self.config_path}")
return []
try:
-with open(self.config_path, 'r') as f:
+with open(self.config_path) as f:
data = json.load(f)
# Parse models
configs = []
for model_data in data.get("models", []):
# Handle backwards compatibility - rename max_tokens to context_window
-if 'max_tokens' in model_data and 'context_window' not in model_data:
-    model_data['context_window'] = model_data.pop('max_tokens')
+if "max_tokens" in model_data and "context_window" not in model_data:
+    model_data["context_window"] = model_data.pop("max_tokens")
config = OpenRouterModelConfig(**model_data)
configs.append(config)
return configs
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
except Exception as e:
raise ValueError(f"Error reading config from {self.config_path}: {e}")
-def _build_maps(self, configs: List[OpenRouterModelConfig]) -> None:
+def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
"""Build alias and model maps from configurations.
Args:
configs: List of model configurations
"""
alias_map = {}
model_map = {}
for config in configs:
# Add to model map
model_map[config.model_name] = config
# Add aliases
for alias in config.aliases:
alias_lower = alias.lower()
if alias_lower in alias_map:
existing_model = alias_map[alias_lower]
raise ValueError(
f"Duplicate alias '{alias}' found for models "
f"'{existing_model}' and '{config.model_name}'"
f"Duplicate alias '{alias}' found for models " f"'{existing_model}' and '{config.model_name}'"
)
alias_map[alias_lower] = config.model_name
# Atomic update
self.alias_map = alias_map
self.model_map = model_map
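
A quick sketch of the duplicate-alias guard with two hypothetical configs (alias comparison is case-insensitive):

# OpenRouterModelConfig(model_name="vendor/model-a", aliases=["fast"])
# OpenRouterModelConfig(model_name="vendor/model-b", aliases=["FAST"])
# -> ValueError: Duplicate alias 'FAST' found for models 'vendor/model-a' and 'vendor/model-b'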
def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
"""Resolve a model name or alias to configuration.
Args:
name_or_alias: Model name or alias to resolve
Returns:
Model configuration if found, None otherwise
"""
@@ -151,16 +157,16 @@ class OpenRouterModelRegistry:
if alias_lower in self.alias_map:
model_name = self.alias_map[alias_lower]
return self.model_map.get(model_name)
# Try as direct model name
return self.model_map.get(name_or_alias)
def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
"""Get model capabilities for a name or alias.
Args:
name_or_alias: Model name or alias
Returns:
ModelCapabilities if found, None otherwise
"""
@@ -168,11 +174,11 @@ class OpenRouterModelRegistry:
if config:
return config.to_capabilities()
return None
-def list_models(self) -> List[str]:
+def list_models(self) -> list[str]:
"""List all available model names."""
return list(self.model_map.keys())
-def list_aliases(self) -> List[str]:
+def list_aliases(self) -> list[str]:
"""List all available aliases."""
-return list(self.alias_map.keys())
\ No newline at end of file
+return list(self.alias_map.keys())
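
Putting the registry together, a hedged usage sketch (the alias is the hypothetical one from the sample entry above):

registry = OpenRouterModelRegistry()         # loads conf/openrouter_models.json by default
config = registry.resolve("EXAMPLE")         # alias lookup is case-insensitive; None if unknown
caps = registry.get_capabilities("example")  # ModelCapabilities, or None if not configured
print(registry.list_models(), registry.list_aliases())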