"""Base class for OpenAI-compatible API providers."""
|
||
|
||
import copy
|
||
import ipaddress
|
||
import logging
|
||
from typing import Optional
|
||
from urllib.parse import urlparse
|
||
|
||
from openai import OpenAI
|
||
|
||
from utils.env import get_env, suppress_env_vars
|
||
from utils.image_utils import validate_image
|
||
|
||
from .base import ModelProvider
|
||
from .shared import (
|
||
ModelCapabilities,
|
||
ModelResponse,
|
||
ProviderType,
|
||
)
|
||
|
||
|
||

class OpenAICompatibleProvider(ModelProvider):
    """Shared implementation for OpenAI API lookalikes.

    The class owns HTTP client configuration (timeouts, proxy hardening,
    custom headers) and normalises the OpenAI SDK responses into
    :class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
    provide capability metadata and any provider-specific request tweaks.
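
    Example (illustrative sketch; the subclass, key, and endpoint are
    hypothetical)::

        provider = SomeOpenAICompatibleSubclass(
            api_key="sk-demo",
            base_url="http://localhost:8000/v1",
        )
        reply = provider.generate_content(prompt="Hello", model_name="some-model")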
    """

    DEFAULT_HEADERS = {}
    FRIENDLY_NAME = "OpenAI Compatible"

    def __init__(self, api_key: str, base_url: Optional[str] = None, **kwargs):
        """Initialize the provider with API key and optional base URL.

        Args:
            api_key: API key for authentication
            base_url: Base URL for the API endpoint
            **kwargs: Additional configuration options including timeout
        """
        self._allowed_alias_cache: dict[str, str] = {}
        super().__init__(api_key, **kwargs)
        self._client = None
        self.base_url = base_url
        self.organization = kwargs.get("organization")
        self.allowed_models = self._parse_allowed_models()

        # Configure timeouts - especially important for custom/local endpoints
        self.timeout_config = self._configure_timeouts(**kwargs)

        # Validate base URL for security
        if self.base_url:
            self._validate_base_url()

        # Warn if using external URL without authentication
        if self.base_url and not self._is_localhost_url() and not api_key:
            logging.warning(
                f"Using external URL '{self.base_url}' without API key. "
                "This may be insecure. Consider setting an API key for authentication."
            )

    def _ensure_model_allowed(
        self,
        capabilities: ModelCapabilities,
        canonical_name: str,
        requested_name: str,
    ) -> None:
        """Respect provider-specific allowlists before default restriction checks."""

        super()._ensure_model_allowed(capabilities, canonical_name, requested_name)

        if self.allowed_models is not None:
            requested = requested_name.lower()
            canonical = canonical_name.lower()

            if requested not in self.allowed_models and canonical not in self.allowed_models:
                allowed = False
                for allowed_entry in list(self.allowed_models):
                    normalized_resolved = self._allowed_alias_cache.get(allowed_entry)
                    if normalized_resolved is None:
                        try:
                            resolved_name = self._resolve_model_name(allowed_entry)
                        except Exception:
                            continue

                        if not resolved_name:
                            continue

                        normalized_resolved = resolved_name.lower()
                        self._allowed_alias_cache[allowed_entry] = normalized_resolved

                    if normalized_resolved == canonical:
                        # Canonical match discovered via alias resolution – mark as allowed and
                        # memoise the canonical entry for future lookups.
                        allowed = True
                        self._allowed_alias_cache[canonical] = canonical
                        self.allowed_models.add(canonical)
                        break

                if not allowed:
                    raise ValueError(
                        f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
                    )

    def _parse_allowed_models(self) -> Optional[set[str]]:
        """Parse allowed models from environment variable.

        Returns:
            Set of allowed model names (lowercase) or None if not configured
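
        Example:
            For a provider whose type value is ``custom``, setting
            ``CUSTOM_ALLOWED_MODELS=llama-3.1-8b,qwen-2.5`` (model names
            illustrative) restricts this provider to exactly those two models.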
        """
        # Get provider-specific allowed models
        provider_type = self.get_provider_type().value.upper()
        env_var = f"{provider_type}_ALLOWED_MODELS"
        models_str = get_env(env_var, "") or ""

        if models_str:
            # Parse and normalize to lowercase for case-insensitive comparison
            models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
            if models:
                logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
                self._allowed_alias_cache = {}
                return models

        # Log info if no allow-list configured for proxy providers
        if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
            logging.info(
                f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. "
                f"To restrict access, set {env_var} with comma-separated model names."
            )

        return None

    def _configure_timeouts(self, **kwargs):
        """Configure timeout settings based on provider type and custom settings.

        Custom URLs and local models often need longer timeouts due to:
        - Network latency on local networks
        - Extended thinking models taking longer to respond
        - Local inference being slower than cloud APIs

        Returns:
            httpx.Timeout object with appropriate timeout settings
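
        Example:
            ``CUSTOM_READ_TIMEOUT=1200`` (value illustrative) raises the read
            timeout to 20 minutes; an explicit ``read_timeout=300.0`` kwarg
            takes precedence over the environment variable.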
        """
        import httpx

        # Default timeouts - more generous for custom/local endpoints
        default_connect = 30.0  # 30 seconds for connection (vs OpenAI's 5s)
        default_read = 600.0  # 10 minutes for reading (same as OpenAI default)
        default_write = 600.0  # 10 minutes for writing
        default_pool = 600.0  # 10 minutes for pool

        # For custom/local URLs, use even longer timeouts
        if self.base_url and self._is_localhost_url():
            default_connect = 60.0  # 1 minute for local connections
            default_read = 1800.0  # 30 minutes for local models (extended thinking)
            default_write = 1800.0  # 30 minutes for local models
            default_pool = 1800.0  # 30 minutes for local models
            logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
        elif self.base_url:
            default_connect = 45.0  # 45 seconds for custom remote endpoints
            default_read = 900.0  # 15 minutes for custom remote endpoints
            default_write = 900.0  # 15 minutes for custom remote endpoints
            default_pool = 900.0  # 15 minutes for custom remote endpoints
            logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")

        # Allow overrides via kwargs first, then environment variables, then the defaults above
        connect_timeout = kwargs.get("connect_timeout")
        if connect_timeout is None:
            connect_timeout_raw = get_env("CUSTOM_CONNECT_TIMEOUT")
            connect_timeout = float(connect_timeout_raw) if connect_timeout_raw is not None else float(default_connect)

        read_timeout = kwargs.get("read_timeout")
        if read_timeout is None:
            read_timeout_raw = get_env("CUSTOM_READ_TIMEOUT")
            read_timeout = float(read_timeout_raw) if read_timeout_raw is not None else float(default_read)

        write_timeout = kwargs.get("write_timeout")
        if write_timeout is None:
            write_timeout_raw = get_env("CUSTOM_WRITE_TIMEOUT")
            write_timeout = float(write_timeout_raw) if write_timeout_raw is not None else float(default_write)

        pool_timeout = kwargs.get("pool_timeout")
        if pool_timeout is None:
            pool_timeout_raw = get_env("CUSTOM_POOL_TIMEOUT")
            pool_timeout = float(pool_timeout_raw) if pool_timeout_raw is not None else float(default_pool)

        timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)

        logging.debug(
            f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
            f"Write: {write_timeout}s, Pool: {pool_timeout}s"
        )

        return timeout

    def _is_localhost_url(self) -> bool:
        """Check if the base URL points to localhost or local network.

        Returns:
            True if URL is localhost or local network, False otherwise
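
        Example:
            ``http://127.0.0.1:8000/v1`` and ``http://192.168.1.5:8080``
            (addresses illustrative) are treated as local, while
            ``https://api.example.com`` is not.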
        """
        if not self.base_url:
            return False

        try:
            parsed = urlparse(self.base_url)
            hostname = parsed.hostname

            # Check for common localhost patterns
            if hostname in ["localhost", "127.0.0.1", "::1"]:
                return True

            # Check for private network ranges (local network)
            if hostname:
                try:
                    ip = ipaddress.ip_address(hostname)
                    return ip.is_private or ip.is_loopback
                except ValueError:
                    # Not an IP address, might be a hostname
                    pass

            return False
        except Exception:
            return False

    def _validate_base_url(self) -> None:
        """Validate base URL for security (SSRF protection).

        Raises:
            ValueError: If URL is invalid or potentially unsafe
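
        Example:
            ``ftp://example.com/v1`` is rejected for its scheme, and
            ``http://example.com:99999`` for its out-of-range port
            (URLs illustrative).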
        """
        if not self.base_url:
            return

        try:
            parsed = urlparse(self.base_url)

            # Check URL scheme - only allow http/https
            if parsed.scheme not in ("http", "https"):
                raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")

            # Check hostname exists
            if not parsed.hostname:
                raise ValueError("URL must include a hostname")

            # Check port is valid (if specified)
            port = parsed.port
            if port is not None and (port < 1 or port > 65535):
                raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
        except Exception as e:
            if isinstance(e, ValueError):
                raise
            raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")

    @property
    def client(self):
        """Lazy initialization of OpenAI client with security checks and timeout configuration."""
        if self._client is None:
            import httpx

            proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]

            with suppress_env_vars(*proxy_env_vars):
                try:
                    # Create a custom httpx client that explicitly avoids proxy parameters
                    timeout_config = (
                        self.timeout_config
                        if hasattr(self, "timeout_config") and self.timeout_config
                        else httpx.Timeout(30.0)
                    )

                    # Create httpx client with minimal config to avoid proxy conflicts
                    # Note: proxies parameter was removed in httpx 0.28.0
                    # Check for test transport injection
                    if hasattr(self, "_test_transport"):
                        # Use custom transport for testing (HTTP recording/replay)
                        http_client = httpx.Client(
                            transport=self._test_transport,
                            timeout=timeout_config,
                            follow_redirects=True,
                        )
                    else:
                        # Normal production client
                        http_client = httpx.Client(
                            timeout=timeout_config,
                            follow_redirects=True,
                        )

                    # Keep client initialization minimal to avoid proxy parameter conflicts
                    client_kwargs = {
                        "api_key": self.api_key,
                        "http_client": http_client,
                    }

                    if self.base_url:
                        client_kwargs["base_url"] = self.base_url

                    if self.organization:
                        client_kwargs["organization"] = self.organization

                    # Add default headers if any
                    if self.DEFAULT_HEADERS:
                        client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

                    logging.debug(
                        "OpenAI client initialized with custom httpx client and timeout: %s",
                        timeout_config,
                    )

                    # Create OpenAI client with custom httpx client
                    self._client = OpenAI(**client_kwargs)

                except Exception as e:
                    # If all else fails, try absolute minimal client without custom httpx
                    logging.warning(
                        "Failed to create client with custom httpx, falling back to minimal config: %s",
                        e,
                    )
                    try:
                        minimal_kwargs = {"api_key": self.api_key}
                        if self.base_url:
                            minimal_kwargs["base_url"] = self.base_url
                        self._client = OpenAI(**minimal_kwargs)
                    except Exception as fallback_error:
                        logging.error("Even minimal OpenAI client creation failed: %s", fallback_error)
                        raise

        return self._client

    def _sanitize_for_logging(self, params: dict) -> dict:
        """Sanitize sensitive data from parameters before logging.

        Args:
            params: Dictionary of API parameters

        Returns:
            dict: Sanitized copy of parameters safe for logging
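
        Example:
            A 500-character message text is logged as its first 100 characters
            plus ``... [truncated]``, and top-level ``api_key`` and
            ``authorization`` keys are dropped.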
        """
        sanitized = copy.deepcopy(params)

        # Sanitize messages content
        if "input" in sanitized:
            for msg in sanitized.get("input", []):
                if isinstance(msg, dict) and "content" in msg:
                    for content_item in msg.get("content", []):
                        if isinstance(content_item, dict) and "text" in content_item:
                            # Truncate long text and add ellipsis
                            text = content_item["text"]
                            if len(text) > 100:
                                content_item["text"] = text[:100] + "... [truncated]"

        # Remove any API keys that might be in headers/auth
        sanitized.pop("api_key", None)
        sanitized.pop("authorization", None)

        return sanitized

    def _safe_extract_output_text(self, response) -> str:
        """Safely extract output_text from o3-pro response with validation.

        Args:
            response: Response object from OpenAI SDK

        Returns:
            str: The output text content

        Raises:
            ValueError: If output_text is missing, None, or not a string
        """
        logging.debug(f"Response object type: {type(response)}")
        logging.debug(f"Response attributes: {dir(response)}")

        if not hasattr(response, "output_text"):
            raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")

        content = response.output_text
        logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")

        if content is None:
            raise ValueError("o3-pro returned None for output_text")

        if not isinstance(content, str):
            raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")

        return content

    def _generate_with_responses_endpoint(
        self,
        model_name: str,
        messages: list,
        temperature: float,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the /v1/responses endpoint for o3-pro via OpenAI library."""
        # Convert messages to the correct format for responses endpoint
        input_messages = []

        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")

            if role == "system":
                # For o3-pro, system messages should be handled carefully to avoid policy violations
                # Instead of prefixing with "System:", we'll include the system content naturally
                input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
            elif role == "user":
                input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
            elif role == "assistant":
                input_messages.append({"role": "assistant", "content": [{"type": "output_text", "text": content}]})

        # Prepare completion parameters for responses endpoint
        # Based on OpenAI documentation, use nested reasoning object for responses endpoint
        completion_params = {
            "model": model_name,
            "input": input_messages,
            "reasoning": {"effort": "medium"},  # Use nested object for responses endpoint
            "store": True,
        }

        # Add max tokens if specified (using max_completion_tokens for responses endpoint)
        if max_output_tokens:
            completion_params["max_completion_tokens"] = max_output_tokens

        # For responses endpoint, we only add parameters that are explicitly supported
        # Remove unsupported chat completion parameters that may cause API errors

        # Retry logic with progressive delays
        max_retries = 4
        retry_delays = [1, 3, 5, 8]
        attempt_counter = {"value": 0}

        def _attempt() -> ModelResponse:
            attempt_counter["value"] += 1
            import json

            sanitized_params = self._sanitize_for_logging(completion_params)
            logging.info(
                f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
            )

            response = self.client.responses.create(**completion_params)

            content = self._safe_extract_output_text(response)

            usage = None
            if hasattr(response, "usage"):
                usage = self._extract_usage(response)
            elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
                input_tokens = getattr(response, "input_tokens", 0) or 0
                output_tokens = getattr(response, "output_tokens", 0) or 0
                usage = {
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens,
                }

            return ModelResponse(
                content=content,
                usage=usage,
                model_name=model_name,
                friendly_name=self.FRIENDLY_NAME,
                provider=self.get_provider_type(),
                metadata={
                    "model": getattr(response, "model", model_name),
                    "id": getattr(response, "id", ""),
                    "created": getattr(response, "created_at", 0),
                    "endpoint": "responses",
                },
            )

        try:
            return self._run_with_retries(
                operation=_attempt,
                max_attempts=max_retries,
                delays=retry_delays,
                log_prefix="o3-pro responses endpoint",
            )
        except Exception as exc:
            attempts = max(attempt_counter["value"], 1)
            error_msg = f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from exc

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: Optional[int] = None,
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the OpenAI-compatible API.

        Args:
            prompt: User prompt to send to the model
            model_name: Canonical model name or its alias
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            images: Optional list of image paths or data URLs to include with the prompt (for vision models)
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
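
        Example (model name illustrative)::

            response = provider.generate_content(
                prompt="Summarize this diff",
                model_name="gpt-4o-mini",
                system_prompt="You are a code reviewer.",
                temperature=0.2,
            )
            print(response.content)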
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")

        capabilities: Optional[ModelCapabilities]
        try:
            capabilities = self.get_capabilities(model_name)
        except Exception as exc:
            logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}")
            capabilities = None

        # Get effective temperature for this model from capabilities when available
        if capabilities:
            effective_temperature = capabilities.get_effective_temperature(temperature)
            if effective_temperature is not None and effective_temperature != temperature:
                logging.debug(
                    f"Adjusting temperature from {temperature} to {effective_temperature} for model {model_name}"
                )
        else:
            effective_temperature = temperature

        # Only validate if temperature is not None (meaning the model supports it)
        if effective_temperature is not None:
            # Validate parameters with the effective temperature
            self.validate_parameters(model_name, effective_temperature)

        # Resolve to canonical model name
        resolved_model = self._resolve_model_name(model_name)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Prepare user message with text and potentially images
        user_content = []
        user_content.append({"type": "text", "text": prompt})

        # Add images if provided and model supports vision
        if images and capabilities and capabilities.supports_images:
            for image_path in images:
                try:
                    image_content = self._process_image(image_path)
                    if image_content:
                        user_content.append(image_content)
                except Exception as e:
                    logging.warning(f"Failed to process image {image_path}: {e}")
                    # Continue with other images and text
                    continue
        elif images and (not capabilities or not capabilities.supports_images):
            logging.warning(f"Model {resolved_model} does not support images, ignoring {len(images)} image(s)")

        # Add user message
        if len(user_content) == 1:
            # Only text content, use simple string format for compatibility
            messages.append({"role": "user", "content": prompt})
        else:
            # Text + images, use content array format
            messages.append({"role": "user", "content": user_content})

        # Prepare completion parameters
        # Always disable streaming for OpenRouter
        # MCP doesn't use streaming, and this avoids issues with O3 model access
        completion_params = {
            "model": resolved_model,
            "messages": messages,
            "stream": False,
        }

        # Use the effective temperature we calculated earlier
        supports_sampling = effective_temperature is not None

        if supports_sampling:
            completion_params["temperature"] = effective_temperature

        # Add max tokens if specified and model supports it
        # O3/O4 models that don't support temperature also don't support max_tokens
        if max_output_tokens and supports_sampling:
            completion_params["max_tokens"] = max_output_tokens

        # Add any additional OpenAI-specific parameters
        # Use capabilities to filter parameters for reasoning models
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                # Reasoning models (those that don't support temperature) also don't support these parameters
                if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty", "stream"]:
                    continue  # Skip unsupported parameters for reasoning models
                completion_params[key] = value

        # Check if this model needs the Responses API endpoint
        # Prefer capability metadata; fall back to static map when capabilities unavailable
        use_responses_api = False
        if capabilities is not None:
            use_responses_api = getattr(capabilities, "use_openai_response_api", False)
        else:
            static_capabilities = self.get_all_model_capabilities().get(resolved_model)
            if static_capabilities is not None:
                use_responses_api = getattr(static_capabilities, "use_openai_response_api", False)

        if use_responses_api:
            # These models require the /v1/responses endpoint for stateful context
            # If it fails, we should not fall back to chat/completions
            return self._generate_with_responses_endpoint(
                model_name=resolved_model,
                messages=messages,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                **kwargs,
            )

        # Retry logic with progressive delays
        max_retries = 4  # Total of 4 attempts
        retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
        attempt_counter = {"value": 0}

        def _attempt() -> ModelResponse:
            attempt_counter["value"] += 1
            response = self.client.chat.completions.create(**completion_params)

            content = response.choices[0].message.content
            usage = self._extract_usage(response)

            return ModelResponse(
                content=content,
                usage=usage,
                model_name=resolved_model,
                friendly_name=self.FRIENDLY_NAME,
                provider=self.get_provider_type(),
                metadata={
                    "finish_reason": response.choices[0].finish_reason,
                    "model": response.model,
                    "id": response.id,
                    "created": response.created,
                },
            )

        try:
            return self._run_with_retries(
                operation=_attempt,
                max_attempts=max_retries,
                delays=retry_delays,
                log_prefix=f"{self.FRIENDLY_NAME} API ({resolved_model})",
            )
        except Exception as exc:
            attempts = max(attempt_counter["value"], 1)
            error_msg = (
                f"{self.FRIENDLY_NAME} API error for model {resolved_model} after {attempts} attempt"
                f"{'s' if attempts > 1 else ''}: {exc}"
            )
            logging.error(error_msg)
            raise RuntimeError(error_msg) from exc

    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
        """Validate model parameters.

        For proxy providers, this may use generic capabilities.

        Args:
            model_name: Canonical model name or its alias
            temperature: Temperature to validate
            **kwargs: Additional parameters to validate
        """
        try:
            capabilities = self.get_capabilities(model_name)

            # Check if we're using generic capabilities
            if hasattr(capabilities, "_is_generic"):
                logging.debug(
                    f"Using generic parameter validation for {model_name}. Actual model constraints may differ."
                )

            # Validate temperature using parent class method
            super().validate_parameters(model_name, temperature, **kwargs)

        except Exception as e:
            # For proxy providers, we might not have accurate capabilities
            # Log warning but don't fail
            logging.warning(f"Parameter validation limited for {model_name}: {e}")

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from OpenAI response.

        Args:
            response: OpenAI API response object

        Returns:
            Dictionary with usage statistics
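
        Example:
            ``{"input_tokens": 12, "output_tokens": 48, "total_tokens": 60}``
            (counts illustrative); an empty dict is returned when the response
            carries no usage data.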
        """
        usage = {}

        if hasattr(response, "usage") and response.usage:
            # Safely extract token counts with None handling
            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0

        return usage

    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens using OpenAI-compatible tokenizer tables when available."""

        resolved_model = self._resolve_model_name(model_name)

        try:
            import tiktoken

            try:
                encoding = tiktoken.encoding_for_model(resolved_model)
            except KeyError:
                encoding = tiktoken.get_encoding("cl100k_base")

            return len(encoding.encode(text))

        except Exception as exc:
            logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)

        return super().count_tokens(text, model_name)

    def _is_error_retryable(self, error: Exception) -> bool:
        """Determine if an error should be retried based on structured error codes.

        Uses OpenAI API error structure instead of text pattern matching for reliability.

        Args:
            error: Exception from OpenAI API call

        Returns:
            True if error should be retried, False otherwise
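
        Example:
            A 429 whose payload contains ``'type': 'tokens'`` (request too
            large) is treated as permanent and not retried, while a plain
            requests-per-minute 429 is retried.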
        """
        error_str = str(error).lower()

        # Check for 429 errors first - these need special handling
        if "429" in error_str:
            # Try to extract structured error information
            error_type = None
            error_code = None

            # Parse structured error from OpenAI API response
            # Format: "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded', ...}}"
            try:
                import ast
                import json
                import re

                # Extract JSON part from error string using regex
                # Look for pattern: {...} (from first { to last })
                json_match = re.search(r"\{.*\}", str(error))
                if json_match:
                    json_like_str = json_match.group(0)

                    # First try: parse as Python literal (handles single quotes safely)
                    try:
                        error_data = ast.literal_eval(json_like_str)
                    except (ValueError, SyntaxError):
                        # Fallback: try JSON parsing with simple quote replacement
                        # (for cases where it's already valid JSON or simple replacements work)
                        json_str = json_like_str.replace("'", '"')
                        error_data = json.loads(json_str)

                    if "error" in error_data:
                        error_info = error_data["error"]
                        error_type = error_info.get("type")
                        error_code = error_info.get("code")

            except (json.JSONDecodeError, ValueError, SyntaxError, AttributeError):
                # Fall back to checking hasattr for OpenAI SDK exception objects
                if hasattr(error, "response") and hasattr(error.response, "json"):
                    try:
                        response_data = error.response.json()
                        if "error" in response_data:
                            error_info = response_data["error"]
                            error_type = error_info.get("type")
                            error_code = error_info.get("code")
                    except Exception:
                        pass

            # Determine if 429 is retryable based on structured error codes
            if error_type == "tokens":
                # Token-related 429s are typically non-retryable (request too large)
                logging.debug(f"Non-retryable 429: token-related error (type={error_type}, code={error_code})")
                return False
            elif error_code in ["invalid_request_error", "context_length_exceeded"]:
                # These are permanent failures
                logging.debug(f"Non-retryable 429: permanent failure (type={error_type}, code={error_code})")
                return False
            else:
                # Other 429s (like requests per minute) are retryable
                logging.debug(f"Retryable 429: rate limiting (type={error_type}, code={error_code})")
                return True

        # For non-429 errors, check if they're retryable
        retryable_indicators = [
            "timeout",
            "connection",
            "network",
            "temporary",
            "unavailable",
            "retry",
            "408",  # Request timeout
            "500",  # Internal server error
            "502",  # Bad gateway
            "503",  # Service unavailable
            "504",  # Gateway timeout
            "ssl",  # SSL errors
            "handshake",  # Handshake failures
        ]

        return any(indicator in error_str for indicator in retryable_indicators)

    def _process_image(self, image_path: str) -> Optional[dict]:
        """Process an image for OpenAI-compatible API."""
        try:
            if image_path.startswith("data:"):
                # Validate the data URL
                validate_image(image_path)
                # Handle data URL: data:image/png;base64,iVBORw0...
                return {"type": "image_url", "image_url": {"url": image_path}}
            else:
                # Use base class validation
                image_bytes, mime_type = validate_image(image_path)

                # Read and encode the image
                import base64

                image_data = base64.b64encode(image_bytes).decode()
                logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'")

                # Create data URL for OpenAI API
                data_url = f"data:{mime_type};base64,{image_data}"

                return {"type": "image_url", "image_url": {"url": data_url}}

        except ValueError as e:
            logging.warning(str(e))
            return None
        except Exception as e:
            logging.error(f"Error processing image {image_path}: {e}")
            return None