- OpenRouter model configuration registry
- Model definition file for users to be able to control
- Additional tests
- Update instructions
This commit is contained in:
Fahad
2025-06-13 06:33:12 +04:00
parent cd1105b741
commit 2cdb92460b
12 changed files with 417 additions and 381 deletions

View File

@@ -1,12 +1,12 @@
"""Base class for OpenAI-compatible API providers."""
import ipaddress
import logging
import os
import socket
from abc import abstractmethod
from typing import Optional
from urllib.parse import urlparse
import ipaddress
import socket
from openai import OpenAI
@@ -15,25 +15,24 @@ from .base import (
ModelProvider,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
class OpenAICompatibleProvider(ModelProvider):
"""Base class for any provider using an OpenAI-compatible API.
This includes:
- Direct OpenAI API
- OpenRouter
- Any other OpenAI-compatible endpoint
"""
DEFAULT_HEADERS = {}
FRIENDLY_NAME = "OpenAI Compatible"
def __init__(self, api_key: str, base_url: str = None, **kwargs):
"""Initialize the provider with API key and optional base URL.
Args:
api_key: API key for authentication
base_url: Base URL for the API endpoint
@@ -44,21 +43,21 @@ class OpenAICompatibleProvider(ModelProvider):
self.base_url = base_url
self.organization = kwargs.get("organization")
self.allowed_models = self._parse_allowed_models()
# Validate base URL for security
if self.base_url:
self._validate_base_url()
# Warn if using external URL without authentication
if self.base_url and not self._is_localhost_url() and not api_key:
logging.warning(
f"Using external URL '{self.base_url}' without API key. "
"This may be insecure. Consider setting an API key for authentication."
)
def _parse_allowed_models(self) -> Optional[set[str]]:
"""Parse allowed models from environment variable.
Returns:
Set of allowed model names (lowercase) or None if not configured
"""
@@ -66,108 +65,108 @@ class OpenAICompatibleProvider(ModelProvider):
provider_type = self.get_provider_type().value.upper()
env_var = f"{provider_type}_ALLOWED_MODELS"
models_str = os.getenv(env_var, "")
if models_str:
# Parse and normalize to lowercase for case-insensitive comparison
models = set(m.strip().lower() for m in models_str.split(",") if m.strip())
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
if models:
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
return models
# Log warning if no allow-list configured for proxy providers
if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
logging.warning(
f"No model allow-list configured for {self.FRIENDLY_NAME}. "
f"Set {env_var} to restrict model access and control costs."
)
return None
def _is_localhost_url(self) -> bool:
"""Check if the base URL points to localhost.
Returns:
True if URL is localhost, False otherwise
"""
if not self.base_url:
return False
try:
parsed = urlparse(self.base_url)
hostname = parsed.hostname
# Check for common localhost patterns
if hostname in ['localhost', '127.0.0.1', '::1']:
if hostname in ["localhost", "127.0.0.1", "::1"]:
return True
return False
except Exception:
return False
def _validate_base_url(self) -> None:
"""Validate base URL for security (SSRF protection).
Raises:
ValueError: If URL is invalid or potentially unsafe
"""
if not self.base_url:
return
try:
parsed = urlparse(self.base_url)
# Check URL scheme - only allow http/https
if parsed.scheme not in ('http', 'https'):
if parsed.scheme not in ("http", "https"):
raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
# Check hostname exists
if not parsed.hostname:
raise ValueError("URL must include a hostname")
# Check port - allow only standard HTTP/HTTPS ports
port = parsed.port
if port is None:
port = 443 if parsed.scheme == 'https' else 80
port = 443 if parsed.scheme == "https" else 80
# Allow common HTTP ports and some alternative ports
allowed_ports = {80, 443, 8080, 8443, 4000, 3000} # Common API ports
if port not in allowed_ports:
raise ValueError(
f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}"
)
raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")
# Check against allowed domains if configured
allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]
if allowed_domains:
hostname_lower = parsed.hostname.lower()
if not any(
hostname_lower == domain or
hostname_lower.endswith('.' + domain)
for domain in allowed_domains
hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
):
raise ValueError(
f"Domain not in allow-list: {parsed.hostname}. "
f"Allowed domains: {allowed_domains}"
f"Domain not in allow-list: {parsed.hostname}. " f"Allowed domains: {allowed_domains}"
)
# Try to resolve hostname and check if it's a private IP
# Skip for localhost addresses which are commonly used for development
if parsed.hostname not in ['localhost', '127.0.0.1', '::1']:
if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
try:
# Get all IP addresses for the hostname
addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)
for family, _, _, _, sockaddr in addr_info:
for _family, _, _, _, sockaddr in addr_info:
ip_str = sockaddr[0]
try:
ip = ipaddress.ip_address(ip_str)
# Check for dangerous IP ranges
if (ip.is_private or ip.is_loopback or ip.is_link_local or
ip.is_multicast or ip.is_reserved or ip.is_unspecified):
if (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
):
raise ValueError(
f"URL resolves to restricted IP address: {ip_str}. "
"This could be a security risk (SSRF)."
@@ -177,16 +176,16 @@ class OpenAICompatibleProvider(ModelProvider):
if "restricted IP address" in str(ve):
raise
continue
except socket.gaierror as e:
# If we can't resolve the hostname, it's suspicious
raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
@property
def client(self):
"""Lazy initialization of OpenAI client with security checks."""
@@ -194,21 +193,21 @@ class OpenAICompatibleProvider(ModelProvider):
client_kwargs = {
"api_key": self.api_key,
}
if self.base_url:
client_kwargs["base_url"] = self.base_url
if self.organization:
client_kwargs["organization"] = self.organization
# Add default headers if any
if self.DEFAULT_HEADERS:
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
self._client = OpenAI(**client_kwargs)
return self._client
def generate_content(
self,
prompt: str,
@@ -219,7 +218,7 @@ class OpenAICompatibleProvider(ModelProvider):
**kwargs,
) -> ModelResponse:
"""Generate content using the OpenAI-compatible API.
Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
@@ -227,50 +226,49 @@ class OpenAICompatibleProvider(ModelProvider):
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Validate model name against allow-list
if not self.validate_model_name(model_name):
raise ValueError(
f"Model '{model_name}' not in allowed models list. "
f"Allowed models: {self.allowed_models}"
f"Model '{model_name}' not in allowed models list. " f"Allowed models: {self.allowed_models}"
)
# Validate parameters
self.validate_parameters(model_name, temperature)
# Prepare messages
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# Prepare completion parameters
completion_params = {
"model": model_name,
"messages": messages,
"temperature": temperature,
}
# Add max tokens if specified
if max_output_tokens:
completion_params["max_tokens"] = max_output_tokens
# Add any additional OpenAI-specific parameters
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
completion_params[key] = value
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
@@ -284,39 +282,39 @@ class OpenAICompatibleProvider(ModelProvider):
"created": response.created,
},
)
except Exception as e:
# Log error and re-raise with more context
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from e
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text.
Uses a layered approach:
1. Try provider-specific token counting endpoint
2. Try tiktoken for known model families
3. Fall back to character-based estimation
Args:
text: Text to count tokens for
model_name: Model name for tokenizer selection
Returns:
Estimated token count
"""
# 1. Check if provider has a remote token counting endpoint
if hasattr(self, 'count_tokens_remote'):
if hasattr(self, "count_tokens_remote"):
try:
return self.count_tokens_remote(text, model_name)
except Exception as e:
logging.debug(f"Remote token counting failed: {e}")
# 2. Try tiktoken for known models
try:
import tiktoken
# Try to get encoding for the specific model
try:
encoding = tiktoken.encoding_for_model(model_name)
@@ -326,24 +324,24 @@ class OpenAICompatibleProvider(ModelProvider):
encoding = tiktoken.get_encoding("cl100k_base")
else:
encoding = tiktoken.get_encoding("cl100k_base") # Default
return len(encoding.encode(text))
except (ImportError, Exception) as e:
logging.debug(f"Tiktoken not available or failed: {e}")
# 3. Fall back to character-based estimation
logging.warning(
f"No specific tokenizer available for '{model_name}'. "
"Using character-based estimation (~4 chars per token)."
)
return len(text) // 4
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters.
For proxy providers, this may use generic capabilities.
Args:
model_name: Model to validate for
temperature: Temperature to validate
@@ -351,67 +349,66 @@ class OpenAICompatibleProvider(ModelProvider):
"""
try:
capabilities = self.get_capabilities(model_name)
# Check if we're using generic capabilities
if hasattr(capabilities, '_is_generic'):
if hasattr(capabilities, "_is_generic"):
logging.debug(
f"Using generic parameter validation for {model_name}. "
"Actual model constraints may differ."
f"Using generic parameter validation for {model_name}. " "Actual model constraints may differ."
)
# Validate temperature using parent class method
super().validate_parameters(model_name, temperature, **kwargs)
except Exception as e:
# For proxy providers, we might not have accurate capabilities
# Log warning but don't fail
logging.warning(f"Parameter validation limited for {model_name}: {e}")
def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from OpenAI response.
Args:
response: OpenAI API response object
Returns:
Dictionary with usage statistics
"""
usage = {}
if hasattr(response, "usage") and response.usage:
usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0)
usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0)
usage["total_tokens"] = getattr(response.usage, "total_tokens", 0)
return usage
@abstractmethod
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def get_provider_type(self) -> ProviderType:
"""Get the provider type.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported.
Must be implemented by subclasses.
"""
pass
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode.
Default is False for OpenAI-compatible providers.
"""
return False
return False