WIP
- OpenRouter model configuration registry
- Model definition file that users can edit to control which models are exposed
- Additional tests
- Update instructions
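For context, a hedged sketch of what an entry in the user-editable conf/openrouter_models.json could look like. The field names come from the OpenRouterModelConfig dataclass in this commit; the model name, aliases, and values are hypothetical, not shipped entries:

```json
{
  "models": [
    {
      "model_name": "vendor/example-model",
      "aliases": ["example", "ex"],
      "context_window": 200000,
      "supports_function_calling": true,
      "description": "Hypothetical entry; omitted fields fall back to the dataclass defaults"
    }
  ]
}
```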
@@ -1,12 +1,8 @@
"""OpenAI model provider implementation."""

import logging
from typing import Optional

from .base import (
    FixedTemperatureConstraint,
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)
@@ -34,7 +30,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        kwargs.setdefault("base_url", "https://api.openai.com/v1")
        super().__init__(api_key, **kwargs)

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific OpenAI model."""
        if model_name not in self.SUPPORTED_MODELS:
@@ -62,7 +57,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
            temperature_constraint=temp_constraint,
        )

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.OPENAI
@@ -76,4 +70,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        # Currently no OpenAI models support extended thinking
        # This may change with future O3 models
        return False
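Before the shared base class below, a minimal usage sketch of this provider. The module path, placeholder API key, and model name are assumptions for illustration, not taken from the diff:

```python
from providers.openai import OpenAIModelProvider  # assumed module path

provider = OpenAIModelProvider(api_key="sk-...")  # base_url defaults to https://api.openai.com/v1
caps = provider.get_capabilities("o3-mini")  # rejected unless listed in SUPPORTED_MODELS
print(caps.provider, provider.get_provider_type())
```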
@@ -1,12 +1,12 @@
"""Base class for OpenAI-compatible API providers."""

import ipaddress
import logging
import os
import socket
from abc import abstractmethod
from typing import Optional
from urllib.parse import urlparse

from openai import OpenAI
@@ -15,25 +15,24 @@ from .base import (
    ModelProvider,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)


class OpenAICompatibleProvider(ModelProvider):
    """Base class for any provider using an OpenAI-compatible API.

    This includes:
    - Direct OpenAI API
    - OpenRouter
    - Any other OpenAI-compatible endpoint
    """

    DEFAULT_HEADERS = {}
    FRIENDLY_NAME = "OpenAI Compatible"

    def __init__(self, api_key: str, base_url: str = None, **kwargs):
        """Initialize the provider with API key and optional base URL.

        Args:
            api_key: API key for authentication
            base_url: Base URL for the API endpoint
@@ -44,21 +43,21 @@ class OpenAICompatibleProvider(ModelProvider):
        self.base_url = base_url
        self.organization = kwargs.get("organization")
        self.allowed_models = self._parse_allowed_models()

        # Validate base URL for security
        if self.base_url:
            self._validate_base_url()

        # Warn if using external URL without authentication
        if self.base_url and not self._is_localhost_url() and not api_key:
            logging.warning(
                f"Using external URL '{self.base_url}' without API key. "
                "This may be insecure. Consider setting an API key for authentication."
            )

    def _parse_allowed_models(self) -> Optional[set[str]]:
        """Parse allowed models from environment variable.

        Returns:
            Set of allowed model names (lowercase) or None if not configured
        """
@@ -66,108 +65,108 @@ class OpenAICompatibleProvider(ModelProvider):
        provider_type = self.get_provider_type().value.upper()
        env_var = f"{provider_type}_ALLOWED_MODELS"
        models_str = os.getenv(env_var, "")

        if models_str:
            # Parse and normalize to lowercase for case-insensitive comparison
            models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
            if models:
                logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
                return models

        # Log warning if no allow-list configured for proxy providers
        if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
            logging.warning(
                f"No model allow-list configured for {self.FRIENDLY_NAME}. "
                f"Set {env_var} to restrict model access and control costs."
            )

        return None

    def _is_localhost_url(self) -> bool:
        """Check if the base URL points to localhost.

        Returns:
            True if URL is localhost, False otherwise
        """
        if not self.base_url:
            return False

        try:
            parsed = urlparse(self.base_url)
            hostname = parsed.hostname

            # Check for common localhost patterns
            if hostname in ["localhost", "127.0.0.1", "::1"]:
                return True

            return False
        except Exception:
            return False

    def _validate_base_url(self) -> None:
        """Validate base URL for security (SSRF protection).

        Raises:
            ValueError: If URL is invalid or potentially unsafe
        """
        if not self.base_url:
            return

        try:
            parsed = urlparse(self.base_url)

            # Check URL scheme - only allow http/https
            if parsed.scheme not in ("http", "https"):
                raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")

            # Check hostname exists
            if not parsed.hostname:
                raise ValueError("URL must include a hostname")

            # Check port - allow only standard HTTP/HTTPS ports
            port = parsed.port
            if port is None:
                port = 443 if parsed.scheme == "https" else 80

            # Allow common HTTP ports and some alternative ports
            allowed_ports = {80, 443, 8080, 8443, 4000, 3000}  # Common API ports
            if port not in allowed_ports:
                raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")

            # Check against allowed domains if configured
            allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
            allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]

            if allowed_domains:
                hostname_lower = parsed.hostname.lower()
                if not any(
                    hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
                ):
                    raise ValueError(
                        f"Domain not in allow-list: {parsed.hostname}. " f"Allowed domains: {allowed_domains}"
                    )

            # Try to resolve hostname and check if it's a private IP
            # Skip for localhost addresses which are commonly used for development
            if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
                try:
                    # Get all IP addresses for the hostname
                    addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)

                    for _family, _, _, _, sockaddr in addr_info:
                        ip_str = sockaddr[0]
                        try:
                            ip = ipaddress.ip_address(ip_str)

                            # Check for dangerous IP ranges
                            if (
                                ip.is_private
                                or ip.is_loopback
                                or ip.is_link_local
                                or ip.is_multicast
                                or ip.is_reserved
                                or ip.is_unspecified
                            ):
                                raise ValueError(
                                    f"URL resolves to restricted IP address: {ip_str}. "
                                    "This could be a security risk (SSRF)."
@@ -177,16 +176,16 @@ class OpenAICompatibleProvider(ModelProvider):
                            if "restricted IP address" in str(ve):
                                raise
                            continue

                except socket.gaierror as e:
                    # If we can't resolve the hostname, it's suspicious
                    raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")

        except Exception as e:
            if isinstance(e, ValueError):
                raise
            raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")

    @property
    def client(self):
        """Lazy initialization of OpenAI client with security checks."""
@@ -194,21 +193,21 @@ class OpenAICompatibleProvider(ModelProvider):
            client_kwargs = {
                "api_key": self.api_key,
            }

            if self.base_url:
                client_kwargs["base_url"] = self.base_url

            if self.organization:
                client_kwargs["organization"] = self.organization

            # Add default headers if any
            if self.DEFAULT_HEADERS:
                client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

            self._client = OpenAI(**client_kwargs)

        return self._client

    def generate_content(
        self,
        prompt: str,
@@ -219,7 +218,7 @@ class OpenAICompatibleProvider(ModelProvider):
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the OpenAI-compatible API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
@@ -227,50 +226,49 @@ class OpenAICompatibleProvider(ModelProvider):
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(
                f"Model '{model_name}' not in allowed models list. " f"Allowed models: {self.allowed_models}"
            )

        # Validate parameters
        self.validate_parameters(model_name, temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Prepare completion parameters
        completion_params = {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
        }

        # Add max tokens if specified
        if max_output_tokens:
            completion_params["max_tokens"] = max_output_tokens

        # Add any additional OpenAI-specific parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                completion_params[key] = value

        try:
            # Generate completion
            response = self.client.chat.completions.create(**completion_params)

            # Extract content and usage
            content = response.choices[0].message.content
            usage = self._extract_usage(response)

            return ModelResponse(
                content=content,
                usage=usage,
@@ -284,39 +282,39 @@ class OpenAICompatibleProvider(ModelProvider):
                    "created": response.created,
                },
            )

        except Exception as e:
            # Log error and re-raise with more context
            error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text.

        Uses a layered approach:
        1. Try provider-specific token counting endpoint
        2. Try tiktoken for known model families
        3. Fall back to character-based estimation

        Args:
            text: Text to count tokens for
            model_name: Model name for tokenizer selection

        Returns:
            Estimated token count
        """
        # 1. Check if provider has a remote token counting endpoint
        if hasattr(self, "count_tokens_remote"):
            try:
                return self.count_tokens_remote(text, model_name)
            except Exception as e:
                logging.debug(f"Remote token counting failed: {e}")

        # 2. Try tiktoken for known models
        try:
            import tiktoken

            # Try to get encoding for the specific model
            try:
                encoding = tiktoken.encoding_for_model(model_name)
@@ -326,24 +324,24 @@ class OpenAICompatibleProvider(ModelProvider):
                    encoding = tiktoken.get_encoding("cl100k_base")
                else:
                    encoding = tiktoken.get_encoding("cl100k_base")  # Default

            return len(encoding.encode(text))

        except (ImportError, Exception) as e:
            logging.debug(f"Tiktoken not available or failed: {e}")

        # 3. Fall back to character-based estimation
        logging.warning(
            f"No specific tokenizer available for '{model_name}'. "
            "Using character-based estimation (~4 chars per token)."
        )
        return len(text) // 4

    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
        """Validate model parameters.

        For proxy providers, this may use generic capabilities.

        Args:
            model_name: Model to validate for
            temperature: Temperature to validate
@@ -351,67 +349,66 @@ class OpenAICompatibleProvider(ModelProvider):
        """
        try:
            capabilities = self.get_capabilities(model_name)

            # Check if we're using generic capabilities
            if hasattr(capabilities, "_is_generic"):
                logging.debug(
                    f"Using generic parameter validation for {model_name}. " "Actual model constraints may differ."
                )

            # Validate temperature using parent class method
            super().validate_parameters(model_name, temperature, **kwargs)

        except Exception as e:
            # For proxy providers, we might not have accurate capabilities
            # Log warning but don't fail
            logging.warning(f"Parameter validation limited for {model_name}: {e}")

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from OpenAI response.

        Args:
            response: OpenAI API response object

        Returns:
            Dictionary with usage statistics
        """
        usage = {}

        if hasattr(response, "usage") and response.usage:
            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)

        return usage

    @abstractmethod
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def get_provider_type(self) -> ProviderType:
        """Get the provider type.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported.

        Must be implemented by subclasses.
        """
        pass

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Default is False for OpenAI-compatible providers.
        """
        return False
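The allow-list and SSRF checks above are driven by environment variables; a short sketch of how they could be set (the provider prefix and model names are illustrative):

```python
import os

# {PROVIDER}_ALLOWED_MODELS is read by _parse_allowed_models(); comma-separated,
# matched case-insensitively against requested model names.
os.environ["OPENAI_ALLOWED_MODELS"] = "o3-mini, gpt-4.1"

# ALLOWED_BASE_DOMAINS restricts custom base URLs in _validate_base_url();
# subdomains of each listed domain are also accepted.
os.environ["ALLOWED_BASE_DOMAINS"] = "api.example.com"
```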
@@ -16,63 +16,61 @@ from .openrouter_registry import OpenRouterModelRegistry


class OpenRouterProvider(OpenAICompatibleProvider):
    """OpenRouter unified API provider.

    OpenRouter provides access to multiple AI models through a single API endpoint.
    See https://openrouter.ai for available models and pricing.
    """

    FRIENDLY_NAME = "OpenRouter"

    # Custom headers required by OpenRouter
    DEFAULT_HEADERS = {
        "HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
        "X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
    }

    # Model registry for managing configurations and aliases
    _registry: Optional[OpenRouterModelRegistry] = None

    def __init__(self, api_key: str, **kwargs):
        """Initialize OpenRouter provider.

        Args:
            api_key: OpenRouter API key
            **kwargs: Additional configuration
        """
        # Always use OpenRouter's base URL
        super().__init__(api_key, base_url="https://openrouter.ai/api/v1", **kwargs)

        # Initialize model registry
        if OpenRouterProvider._registry is None:
            OpenRouterProvider._registry = OpenRouterModelRegistry()

        # Log loaded models and aliases
        models = self._registry.list_models()
        aliases = self._registry.list_aliases()
        logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")

    def _parse_allowed_models(self) -> None:
        """Override to disable environment-based allow-list.

        OpenRouter model access is controlled via the OpenRouter dashboard,
        not through environment variables.
        """
        return None

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model aliases to OpenRouter model names.

        Args:
            model_name: Input model name or alias

        Returns:
            Resolved OpenRouter model name
        """
        # Try to resolve through registry
        config = self._registry.resolve(model_name)

        if config:
            if config.model_name != model_name:
                logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
@@ -82,30 +80,30 @@ class OpenRouterProvider(OpenAICompatibleProvider):
        # This allows using models not in our config file
        logging.debug(f"Model '{model_name}' not found in registry, using as-is")
        return model_name

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a model.

        Args:
            model_name: Name of the model (or alias)

        Returns:
            ModelCapabilities from registry or generic defaults
        """
        # Try to get from registry first
        capabilities = self._registry.get_capabilities(model_name)

        if capabilities:
            return capabilities
        else:
            # Resolve any potential aliases and create generic capabilities
            resolved_name = self._resolve_model_name(model_name)

            logging.debug(
                f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
                "Consider adding to openrouter_models.json for specific capabilities."
            )

            # Create generic capabilities with conservative defaults
            capabilities = ModelCapabilities(
                provider=ProviderType.OPENROUTER,
@@ -118,31 +116,31 @@ class OpenRouterProvider(OpenAICompatibleProvider):
                supports_function_calling=False,
                temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
            )

            # Mark as generic for validation purposes
            capabilities._is_generic = True

            return capabilities

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.OPENROUTER

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is allowed.

        For OpenRouter, we accept any model name. OpenRouter will
        validate based on the API key's permissions.

        Args:
            model_name: Model name to validate

        Returns:
            Always True - OpenRouter handles validation
        """
        # Accept any model name - OpenRouter will validate based on API key permissions
        return True

    def generate_content(
        self,
        prompt: str,
@@ -153,7 +151,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the OpenRouter API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model (or alias) to use
@@ -161,13 +159,13 @@ class OpenRouterProvider(OpenAICompatibleProvider):
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Resolve model alias to actual OpenRouter model name
        resolved_model = self._resolve_model_name(model_name)

        # Call parent method with resolved model name
        return super().generate_content(
            prompt=prompt,
@@ -175,19 +173,19 @@ class OpenRouterProvider(OpenAICompatibleProvider):
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Currently, no models via OpenRouter support extended thinking.
        This may change as new models become available.

        Args:
            model_name: Model to check

        Returns:
            False (no OpenRouter models currently support thinking mode)
        """
        return False
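A usage sketch of the registry that backs _resolve_model_name above. The alias and resolved name are hypothetical placeholders, not entries from the shipped config file:

```python
from providers.openrouter_registry import OpenRouterModelRegistry  # assumed module path

registry = OpenRouterModelRegistry()  # loads conf/openrouter_models.json by default

config = registry.resolve("example-alias")  # aliases are matched lowercase
if config:
    print(f"resolved to {config.model_name} ({config.context_window}-token context)")
else:
    # OpenRouterProvider passes unknown names through to OpenRouter unchanged
    print("not in registry; name would be used as-is")
```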
@@ -3,9 +3,9 @@
import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
@@ -13,9 +13,9 @@ from .base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
@dataclass
class OpenRouterModelConfig:
    """Configuration for an OpenRouter model."""

    model_name: str
    aliases: list[str] = field(default_factory=list)
    context_window: int = 32768  # Total context window size in tokens
    supports_extended_thinking: bool = False
    supports_system_prompts: bool = True
@@ -23,8 +23,7 @@ class OpenRouterModelConfig:
    supports_function_calling: bool = False
    supports_json_mode: bool = False
    description: str = ""

    def to_capabilities(self) -> ModelCapabilities:
        """Convert to ModelCapabilities object."""
        return ModelCapabilities(
@@ -42,16 +41,16 @@


class OpenRouterModelRegistry:
    """Registry for managing OpenRouter model configurations and aliases."""

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the registry.

        Args:
            config_path: Path to config file. If None, uses default locations.
        """
        self.alias_map: dict[str, str] = {}  # alias -> model_name
        self.model_map: dict[str, OpenRouterModelConfig] = {}  # model_name -> config

        # Determine config path
        if config_path:
            self.config_path = Path(config_path)
@@ -63,86 +62,93 @@ class OpenRouterModelRegistry:
        else:
            # Default to conf/openrouter_models.json
            self.config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"

        # Load configuration
        self.reload()

    def reload(self) -> None:
        """Reload configuration from disk."""
        try:
            configs = self._read_config()
            self._build_maps(configs)
            logging.info(f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases")
        except ValueError as e:
            # Re-raise ValueError only for duplicate aliases (critical config errors)
            logging.error(f"Failed to load OpenRouter model configuration: {e}")
            # Initialize with empty maps on failure
            self.alias_map = {}
            self.model_map = {}
            if "Duplicate alias" in str(e):
                raise
        except Exception as e:
            logging.error(f"Failed to load OpenRouter model configuration: {e}")
            # Initialize with empty maps on failure
            self.alias_map = {}
            self.model_map = {}

    def _read_config(self) -> list[OpenRouterModelConfig]:
        """Read configuration from file.

        Returns:
            List of model configurations
        """
        if not self.config_path.exists():
            logging.warning(f"OpenRouter model config not found at {self.config_path}")
            return []

        try:
            with open(self.config_path) as f:
                data = json.load(f)

            # Parse models
            configs = []
            for model_data in data.get("models", []):
                # Handle backwards compatibility - rename max_tokens to context_window
                if "max_tokens" in model_data and "context_window" not in model_data:
                    model_data["context_window"] = model_data.pop("max_tokens")

                config = OpenRouterModelConfig(**model_data)
                configs.append(config)

            return configs
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
        except Exception as e:
            raise ValueError(f"Error reading config from {self.config_path}: {e}")

    def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
        """Build alias and model maps from configurations.

        Args:
            configs: List of model configurations
        """
        alias_map = {}
        model_map = {}

        for config in configs:
            # Add to model map
            model_map[config.model_name] = config

            # Add aliases
            for alias in config.aliases:
                alias_lower = alias.lower()
                if alias_lower in alias_map:
                    existing_model = alias_map[alias_lower]
                    raise ValueError(
                        f"Duplicate alias '{alias}' found for models " f"'{existing_model}' and '{config.model_name}'"
                    )
                alias_map[alias_lower] = config.model_name

        # Atomic update
        self.alias_map = alias_map
        self.model_map = model_map

    def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
        """Resolve a model name or alias to configuration.

        Args:
            name_or_alias: Model name or alias to resolve

        Returns:
            Model configuration if found, None otherwise
        """
@@ -151,16 +157,16 @@ class OpenRouterModelRegistry:
        if alias_lower in self.alias_map:
            model_name = self.alias_map[alias_lower]
            return self.model_map.get(model_name)

        # Try as direct model name
        return self.model_map.get(name_or_alias)

    def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
        """Get model capabilities for a name or alias.

        Args:
            name_or_alias: Model name or alias

        Returns:
            ModelCapabilities if found, None otherwise
        """
@@ -168,11 +174,11 @@ class OpenRouterModelRegistry:
        if config:
            return config.to_capabilities()
        return None

    def list_models(self) -> list[str]:
        """List all available model names."""
        return list(self.model_map.keys())

    def list_aliases(self) -> list[str]:
        """List all available aliases."""
        return list(self.alias_map.keys())