Breaking change: openrouter_models.json -> custom_models.json
* Support for Custom URLs and custom models, including locally hosted models such as ollama * Support for native + openrouter + local models (i.e. dozens of models) means you can start delegating sub-tasks to particular models or work to local models such as localizations or other boring work etc. * Several tests added * precommit to also include untracked (new) files * Logfile auto rollover * Improved logging
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
@@ -36,7 +35,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
Args:
|
||||
api_key: API key for authentication
|
||||
base_url: Base URL for the API endpoint
|
||||
**kwargs: Additional configuration options
|
||||
**kwargs: Additional configuration options including timeout
|
||||
"""
|
||||
super().__init__(api_key, **kwargs)
|
||||
self._client = None
|
||||
@@ -44,6 +43,9 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
self.organization = kwargs.get("organization")
|
||||
self.allowed_models = self._parse_allowed_models()
|
||||
|
||||
# Configure timeouts - especially important for custom/local endpoints
|
||||
self.timeout_config = self._configure_timeouts(**kwargs)
|
||||
|
||||
# Validate base URL for security
|
||||
if self.base_url:
|
||||
self._validate_base_url()
|
||||
@@ -82,11 +84,59 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
|
||||
return None
|
||||
|
||||
def _is_localhost_url(self) -> bool:
|
||||
"""Check if the base URL points to localhost.
|
||||
def _configure_timeouts(self, **kwargs):
|
||||
"""Configure timeout settings based on provider type and custom settings.
|
||||
|
||||
Custom URLs and local models often need longer timeouts due to:
|
||||
- Network latency on local networks
|
||||
- Extended thinking models taking longer to respond
|
||||
- Local inference being slower than cloud APIs
|
||||
|
||||
Returns:
|
||||
True if URL is localhost, False otherwise
|
||||
httpx.Timeout object with appropriate timeout settings
|
||||
"""
|
||||
import httpx
|
||||
|
||||
# Default timeouts - more generous for custom/local endpoints
|
||||
default_connect = 30.0 # 30 seconds for connection (vs OpenAI's 5s)
|
||||
default_read = 600.0 # 10 minutes for reading (same as OpenAI default)
|
||||
default_write = 600.0 # 10 minutes for writing
|
||||
default_pool = 600.0 # 10 minutes for pool
|
||||
|
||||
# For custom/local URLs, use even longer timeouts
|
||||
if self.base_url and self._is_localhost_url():
|
||||
default_connect = 60.0 # 1 minute for local connections
|
||||
default_read = 1800.0 # 30 minutes for local models (extended thinking)
|
||||
default_write = 1800.0 # 30 minutes for local models
|
||||
default_pool = 1800.0 # 30 minutes for local models
|
||||
logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
|
||||
elif self.base_url:
|
||||
default_connect = 45.0 # 45 seconds for custom remote endpoints
|
||||
default_read = 900.0 # 15 minutes for custom remote endpoints
|
||||
default_write = 900.0 # 15 minutes for custom remote endpoints
|
||||
default_pool = 900.0 # 15 minutes for custom remote endpoints
|
||||
logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")
|
||||
|
||||
# Allow override via kwargs or environment variables in future, for now...
|
||||
connect_timeout = kwargs.get("connect_timeout", float(os.getenv("CUSTOM_CONNECT_TIMEOUT", default_connect)))
|
||||
read_timeout = kwargs.get("read_timeout", float(os.getenv("CUSTOM_READ_TIMEOUT", default_read)))
|
||||
write_timeout = kwargs.get("write_timeout", float(os.getenv("CUSTOM_WRITE_TIMEOUT", default_write)))
|
||||
pool_timeout = kwargs.get("pool_timeout", float(os.getenv("CUSTOM_POOL_TIMEOUT", default_pool)))
|
||||
|
||||
timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)
|
||||
|
||||
logging.debug(
|
||||
f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
|
||||
f"Write: {write_timeout}s, Pool: {pool_timeout}s"
|
||||
)
|
||||
|
||||
return timeout
|
||||
|
||||
def _is_localhost_url(self) -> bool:
|
||||
"""Check if the base URL points to localhost or local network.
|
||||
|
||||
Returns:
|
||||
True if URL is localhost or local network, False otherwise
|
||||
"""
|
||||
if not self.base_url:
|
||||
return False
|
||||
@@ -99,6 +149,19 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
if hostname in ["localhost", "127.0.0.1", "::1"]:
|
||||
return True
|
||||
|
||||
# Check for Docker internal hostnames (like host.docker.internal)
|
||||
if hostname and ("docker.internal" in hostname or "host.docker.internal" in hostname):
|
||||
return True
|
||||
|
||||
# Check for private network ranges (local network)
|
||||
if hostname:
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
return ip.is_private or ip.is_loopback
|
||||
except ValueError:
|
||||
# Not an IP address, might be a hostname
|
||||
pass
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
@@ -123,64 +186,10 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
if not parsed.hostname:
|
||||
raise ValueError("URL must include a hostname")
|
||||
|
||||
# Check port - allow only standard HTTP/HTTPS ports
|
||||
# Check port is valid (if specified)
|
||||
port = parsed.port
|
||||
if port is None:
|
||||
port = 443 if parsed.scheme == "https" else 80
|
||||
|
||||
# Allow common HTTP ports and some alternative ports
|
||||
allowed_ports = {80, 443, 8080, 8443, 4000, 3000} # Common API ports
|
||||
if port not in allowed_ports:
|
||||
raise ValueError(f"Port {port} not allowed. Allowed ports: {sorted(allowed_ports)}")
|
||||
|
||||
# Check against allowed domains if configured
|
||||
allowed_domains = os.getenv("ALLOWED_BASE_DOMAINS", "").split(",")
|
||||
allowed_domains = [d.strip().lower() for d in allowed_domains if d.strip()]
|
||||
|
||||
if allowed_domains:
|
||||
hostname_lower = parsed.hostname.lower()
|
||||
if not any(
|
||||
hostname_lower == domain or hostname_lower.endswith("." + domain) for domain in allowed_domains
|
||||
):
|
||||
raise ValueError(
|
||||
f"Domain not in allow-list: {parsed.hostname}. " f"Allowed domains: {allowed_domains}"
|
||||
)
|
||||
|
||||
# Try to resolve hostname and check if it's a private IP
|
||||
# Skip for localhost addresses which are commonly used for development
|
||||
if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]:
|
||||
try:
|
||||
# Get all IP addresses for the hostname
|
||||
addr_info = socket.getaddrinfo(parsed.hostname, port, proto=socket.IPPROTO_TCP)
|
||||
|
||||
for _family, _, _, _, sockaddr in addr_info:
|
||||
ip_str = sockaddr[0]
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
|
||||
# Check for dangerous IP ranges
|
||||
if (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
or ip.is_unspecified
|
||||
):
|
||||
raise ValueError(
|
||||
f"URL resolves to restricted IP address: {ip_str}. "
|
||||
"This could be a security risk (SSRF)."
|
||||
)
|
||||
except ValueError as ve:
|
||||
# Invalid IP address format or restricted IP - re-raise if it's our security error
|
||||
if "restricted IP address" in str(ve):
|
||||
raise
|
||||
continue
|
||||
|
||||
except socket.gaierror as e:
|
||||
# If we can't resolve the hostname, it's suspicious
|
||||
raise ValueError(f"Cannot resolve hostname '{parsed.hostname}': {e}")
|
||||
|
||||
if port is not None and (port < 1 or port > 65535):
|
||||
raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
@@ -188,7 +197,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy initialization of OpenAI client with security checks."""
|
||||
"""Lazy initialization of OpenAI client with security checks and timeout configuration."""
|
||||
if self._client is None:
|
||||
client_kwargs = {
|
||||
"api_key": self.api_key,
|
||||
@@ -204,6 +213,11 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
if self.DEFAULT_HEADERS:
|
||||
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
|
||||
|
||||
# Add configured timeout settings
|
||||
if hasattr(self, "timeout_config") and self.timeout_config:
|
||||
client_kwargs["timeout"] = self.timeout_config
|
||||
logging.debug(f"OpenAI client initialized with custom timeout: {self.timeout_config}")
|
||||
|
||||
self._client = OpenAI(**client_kwargs)
|
||||
|
||||
return self._client
|
||||
|
||||
Reference in New Issue
Block a user