"""Base class for OpenAI-compatible API providers."""

import ipaddress
import logging
import os
import socket
from abc import abstractmethod
from typing import Optional
from urllib.parse import urlparse

from openai import OpenAI

from .base import (
    ModelCapabilities,
    ModelProvider,
    ModelResponse,
    ProviderType,
)


class OpenAICompatibleProvider(ModelProvider):
    """Base class for any provider using an OpenAI-compatible API.

    This includes:
    - Direct OpenAI API
    - OpenRouter
    - Any other OpenAI-compatible endpoint
    """

    DEFAULT_HEADERS = {}
    FRIENDLY_NAME = "OpenAI Compatible"

    def __init__(self, api_key: str, base_url: Optional[str] = None, **kwargs):
        """Initialize the provider with an API key and optional base URL.

        Args:
            api_key: API key for authentication
            base_url: Base URL for the API endpoint
            **kwargs: Additional configuration options
        """
        super().__init__(api_key, **kwargs)
        self._client = None
        self.base_url = base_url
        self.organization = kwargs.get("organization")
        self.allowed_models = self._parse_allowed_models()

        # Validate the base URL for security (SSRF protection)
        if self.base_url:
            self._validate_base_url()

        # Warn if an external URL is used without authentication
        if self.base_url and not self._is_localhost_url() and not api_key:
            logging.warning(
                f"Using external URL '{self.base_url}' without an API key. "
                "This may be insecure. Consider setting an API key for authentication."
            )

    def _parse_allowed_models(self) -> Optional[set[str]]:
        """Parse the allowed-models list from an environment variable.

        Returns:
            Set of allowed model names (lowercase), or None if not configured
        """
        # Get the provider-specific allowed models from the environment
        provider_type = self.get_provider_type().value.upper()
        env_var = f"{provider_type}_ALLOWED_MODELS"
        models_str = os.getenv(env_var, "")

        if models_str:
            # Parse and normalize to lowercase for case-insensitive comparison
            models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
            if models:
                logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
                return models

        # Warn if no allow-list is configured for proxy providers
        if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
            logging.warning(
                f"No model allow-list configured for {self.FRIENDLY_NAME}. "
                f"Set {env_var} to restrict model access and control costs."
            )

        return None
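
    # Illustrative sketch (not part of the shipped module): for a hypothetical
    # provider whose ProviderType value is "openrouter", the allow-list is read
    # from OPENROUTER_ALLOWED_MODELS and matched case-insensitively:
    #
    #     os.environ["OPENROUTER_ALLOWED_MODELS"] = "gpt-4o-mini, Claude-3-Haiku"
    #     provider = MyOpenRouterProvider(api_key="sk-...")
    #     provider.allowed_models  # {"gpt-4o-mini", "claude-3-haiku"}
    #
    # MyOpenRouterProvider and the model names above are hypothetical examples.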

    def _is_localhost_url(self) -> bool:
        """Check whether the base URL points to localhost.

        Returns:
            True if the URL is localhost, False otherwise
        """
        if not self.base_url:
            return False

        try:
            parsed = urlparse(self.base_url)
            hostname = parsed.hostname

            # Check for common localhost patterns
            if hostname in ["localhost", "127.0.0.1", "::1"]:
                return True

            return False
        except Exception:
            return False

    def _validate_base_url(self) -> None:
        """Validate the base URL for security (SSRF protection).

        Raises:
            ValueError: If the URL is invalid or potentially unsafe
        """
        if not self.base_url:
            return

        try:
            parsed = urlparse(self.base_url)

            # Check the URL scheme - only allow http/https
            if parsed.scheme not in ("http", "https"):
                raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")

            # Check that a hostname exists
            if not parsed.hostname:
                raise ValueError("URL must include a hostname")

            # Check the port - default to the scheme's standard port when absent
            port = parsed.port
            if port is None:
                port = 443 if parsed.scheme == "https" else 80

            # Allow common HTTP ports and some alternative API ports
            allowed_ports = {80, 443, 8080, 8443, 4000, 3000}
            if port not in allowed_ports:
                raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")

            # Check against the domain allow-list, if configured
            allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
            allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]

            if allowed_domains:
                hostname_lower = parsed.hostname.lower()
                if not any(
                    hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
                ):
                    raise ValueError(
                        f"Domain not in allow-list: {parsed.hostname}. Allowed domains: {allowed_domains}"
                    )

            # Resolve the hostname and reject private or otherwise restricted IPs.
            # Skip localhost addresses, which are commonly used for development.
            if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
                try:
                    # Get all IP addresses the hostname resolves to
                    addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)

                    for _family, _, _, _, sockaddr in addr_info:
                        ip_str = sockaddr[0]
                        try:
                            ip = ipaddress.ip_address(ip_str)

                            # Check for dangerous IP ranges
                            if (
                                ip.is_private
                                or ip.is_loopback
                                or ip.is_link_local
                                or ip.is_multicast
                                or ip.is_reserved
                                or ip.is_unspecified
                            ):
                                raise ValueError(
                                    f"URL resolves to restricted IP address: {ip_str}. "
                                    "This could be a security risk (SSRF)."
                                )
                        except ValueError as ve:
                            # ipaddress.ip_address raises ValueError for malformed
                            # addresses; re-raise only our own security error
                            if "restricted IP address" in str(ve):
                                raise
                            continue

                except socket.gaierror as e:
                    # A hostname that cannot be resolved is suspicious
                    raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")

        except Exception as e:
            if isinstance(e, ValueError):
                raise
            raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")

    @property
    def client(self):
        """Lazily initialize the OpenAI client with security checks."""
        if self._client is None:
            client_kwargs = {
                "api_key": self.api_key,
            }

            if self.base_url:
                client_kwargs["base_url"] = self.base_url

            if self.organization:
                client_kwargs["organization"] = self.organization

            # Add default headers, if any
            if self.DEFAULT_HEADERS:
                client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

            self._client = OpenAI(**client_kwargs)

        return self._client

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the OpenAI-compatible API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Validate the model name against the allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(
                f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}"
            )

        # Validate parameters
        self.validate_parameters(model_name, temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Prepare completion parameters
        completion_params = {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
        }

        # Add max tokens if specified
        if max_output_tokens:
            completion_params["max_tokens"] = max_output_tokens

        # Pass through supported OpenAI-specific parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                completion_params[key] = value

        try:
            # Generate the completion
            response = self.client.chat.completions.create(**completion_params)

            # Extract content and usage
            content = response.choices[0].message.content
            usage = self._extract_usage(response)

            return ModelResponse(
                content=content,
                usage=usage,
                model_name=model_name,
                friendly_name=self.FRIENDLY_NAME,
                provider=self.get_provider_type(),
                metadata={
                    "finish_reason": response.choices[0].finish_reason,
                    "model": response.model,  # Actual model used
                    "id": response.id,
                    "created": response.created,
                },
            )

        except Exception as e:
            # Log the error and re-raise with more context
            error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from e
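
    # Usage sketch (hypothetical subclass and model name, assuming it
    # implements the abstract methods declared further below):
    #
    #     provider = MyProvider(api_key="sk-...")
    #     response = provider.generate_content(
    #         prompt="Summarize this changelog.",
    #         model_name="gpt-4o-mini",
    #         system_prompt="You are a concise release-notes assistant.",
    #         temperature=0.2,
    #         max_output_tokens=256,
    #     )
    #     print(response.content, response.usage)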

    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text.

        Uses a layered approach:
        1. Try a provider-specific token counting endpoint
        2. Try tiktoken for known model families
        3. Fall back to character-based estimation

        Args:
            text: Text to count tokens for
            model_name: Model name for tokenizer selection

        Returns:
            Estimated token count
        """
        # 1. Check if the provider has a remote token counting endpoint
        if hasattr(self, "count_tokens_remote"):
            try:
                return self.count_tokens_remote(text, model_name)
            except Exception as e:
                logging.debug(f"Remote token counting failed: {e}")

        # 2. Try tiktoken for known models
        try:
            import tiktoken

            # Try to get the encoding for the specific model
            try:
                encoding = tiktoken.encoding_for_model(model_name)
            except KeyError:
                # Unknown model: fall back to cl100k_base, which covers the
                # GPT-4 and GPT-3.5 families
                encoding = tiktoken.get_encoding("cl100k_base")

            return len(encoding.encode(text))

        except Exception as e:
            # tiktoken may be missing (ImportError) or fail on unusual input
            logging.debug(f"Tiktoken not available or failed: {e}")

        # 3. Fall back to character-based estimation
        logging.warning(
            f"No specific tokenizer available for '{model_name}'. "
            "Using character-based estimation (~4 chars per token)."
        )
        return len(text) // 4
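
    # Rough illustration of the fallback estimate (values are hypothetical):
    # a 1,000-character prompt with no tokenizer available is reported as
    # 1000 // 4 = 250 tokens. This is only an approximation and can deviate
    # noticeably for code or non-English text.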

    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
        """Validate model parameters.

        For proxy providers, this may rely on generic capabilities.

        Args:
            model_name: Model to validate for
            temperature: Temperature to validate
            **kwargs: Additional parameters to validate
        """
        try:
            capabilities = self.get_capabilities(model_name)

            # Check whether we are using generic capabilities
            if hasattr(capabilities, "_is_generic"):
                logging.debug(
                    f"Using generic parameter validation for {model_name}. "
                    "Actual model constraints may differ."
                )

            # Validate temperature using the parent class method
            super().validate_parameters(model_name, temperature, **kwargs)

        except Exception as e:
            # Proxy providers may not expose accurate capabilities,
            # so log a warning instead of failing
            logging.warning(f"Parameter validation limited for {model_name}: {e}")

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from an OpenAI response.

        Args:
            response: OpenAI API response object

        Returns:
            Dictionary with usage statistics
        """
        usage = {}

        if hasattr(response, "usage") and response.usage:
            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)

        return usage
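
    # For a typical chat completion the returned dictionary looks like
    # (values are illustrative):
    #
    #     {"input_tokens": 42, "output_tokens": 128, "total_tokens": 170}
    #
    # An empty dict is returned when the response carries no usage data.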

    @abstractmethod
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def get_provider_type(self) -> ProviderType:
        """Get the provider type.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def validate_model_name(self, model_name: str) -> bool:
        """Validate whether the model name is supported.

        Must be implemented by subclasses.
        """
        pass

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Defaults to False for OpenAI-compatible providers.
        """
        return False
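

# Minimal subclass sketch (hypothetical, for illustration only): a concrete
# provider implements the three abstract methods above. The ProviderType
# member, the ModelCapabilities constructor arguments, and the class name
# are assumptions, not part of this module.
#
#     class ExampleProvider(OpenAICompatibleProvider):
#         FRIENDLY_NAME = "Example"
#
#         def get_provider_type(self) -> ProviderType:
#             return ProviderType.OPENAI  # or a dedicated member, if one exists
#
#         def get_capabilities(self, model_name: str) -> ModelCapabilities:
#             return ModelCapabilities()  # real fields depend on .base definitions
#
#         def validate_model_name(self, model_name: str) -> bool:
#             if self.allowed_models is None:
#                 return True
#             return model_name.lower() in self.allowed_models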