fix: updated tests to override env variables they need instead of relying on the current values from .env
846 lines
34 KiB
Python
846 lines
34 KiB
Python
"""Base class for OpenAI-compatible API providers."""
|
||
|
||
import copy
|
||
import ipaddress
|
||
import logging
|
||
from typing import Optional
|
||
from urllib.parse import urlparse
|
||
|
||
from openai import OpenAI
|
||
|
||
from utils.env import get_env
|
||
from utils.image_utils import validate_image
|
||
|
||
from .base import ModelProvider
|
||
from .shared import (
|
||
ModelCapabilities,
|
||
ModelResponse,
|
||
ProviderType,
|
||
)
|
||
|
||
|
||
class OpenAICompatibleProvider(ModelProvider):
    """Shared implementation for OpenAI API lookalikes.

    The class owns HTTP client configuration (timeouts, proxy hardening,
    custom headers) and normalises the OpenAI SDK responses into
    :class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
    provide capability metadata and any provider-specific request tweaks.
    """

    # Extra HTTP headers sent with every request. Subclasses override this; the
    # client only ever receives a copy, so the class-level dict is never mutated here.
    DEFAULT_HEADERS = {}
    FRIENDLY_NAME = "OpenAI Compatible"

    # Canonical model names that must be called through /v1/responses instead of
    # /v1/chat/completions.
    _RESPONSES_API_MODELS = frozenset({"o3-pro", "gpt-5-codex"})

    # Optional caller kwargs forwarded verbatim to chat.completions.create().
    # "stream" is deliberately absent: responses are always consumed non-streamed.
    _PASSTHROUGH_PARAMS = ("top_p", "frequency_penalty", "presence_penalty", "seed", "stop")
    # Subset of the passthrough params that reasoning models (no temperature) reject.
    _SAMPLING_ONLY_PARAMS = ("top_p", "frequency_penalty", "presence_penalty")

    # HTTP statuses treated as transient server/gateway failures.
    _RETRYABLE_STATUS_CODES = (408, 500, 502, 503, 504)
    # 429 error codes that will not succeed however often the request is retried.
    _NON_RETRYABLE_429_CODES = ("invalid_request_error", "context_length_exceeded", "insufficient_quota")
    # Lower-case message fragments that indicate a transient network/server problem.
    _RETRYABLE_ERROR_KEYWORDS = (
        "timeout",
        "timed out",
        "connection",
        "network",
        "temporary",
        "unavailable",
        "retry",
        "ssl",  # SSL errors
        "handshake",  # Handshake failures
    )

    def __init__(self, api_key: str, base_url: Optional[str] = None, **kwargs):
        """Initialize the provider with API key and optional base URL.

        Args:
            api_key: API key for authentication
            base_url: Base URL for the API endpoint
            **kwargs: Additional configuration options, e.g. ``organization`` and
                the ``connect_timeout``/``read_timeout``/``write_timeout``/
                ``pool_timeout`` overrides (seconds).

        Raises:
            ValueError: If ``base_url`` is set but malformed or uses a non-http(s) scheme.
        """
        # Created before super().__init__() — presumably because the base
        # initializer may already consult allow-list state; keep this ordering.
        self._allowed_alias_cache: dict[str, str] = {}
        super().__init__(api_key, **kwargs)
        self._client = None
        self.base_url = base_url
        self.organization = kwargs.get("organization")
        self.allowed_models = self._parse_allowed_models()

        # Configure timeouts - especially important for custom/local endpoints
        self.timeout_config = self._configure_timeouts(**kwargs)

        # Validate base URL for security
        if self.base_url:
            self._validate_base_url()

        # Warn if using external URL without authentication
        if self.base_url and not self._is_localhost_url() and not api_key:
            logging.warning(
                f"Using external URL '{self.base_url}' without API key. "
                "This may be insecure. Consider setting an API key for authentication."
            )

    def _ensure_model_allowed(
        self,
        capabilities: ModelCapabilities,
        canonical_name: str,
        requested_name: str,
    ) -> None:
        """Respect provider-specific allowlists before default restriction checks.

        A model passes when either the requested name or its canonical name is in
        ``allowed_models`` (case-insensitive), or when an allow-list entry is an
        alias that resolves to the same canonical name. Resolved aliases are
        memoised in ``_allowed_alias_cache`` and the canonical name is added to
        ``allowed_models`` so later checks short-circuit.

        Raises:
            ValueError: If the model is not permitted by the allow-list.
        """
        super()._ensure_model_allowed(capabilities, canonical_name, requested_name)

        if self.allowed_models is None:
            return

        requested = requested_name.lower()
        canonical = canonical_name.lower()
        if requested in self.allowed_models or canonical in self.allowed_models:
            return

        # Iterate over a snapshot: a match below adds to allowed_models.
        for allowed_entry in list(self.allowed_models):
            normalized_resolved = self._allowed_alias_cache.get(allowed_entry)
            if normalized_resolved is None:
                try:
                    resolved_name = self._resolve_model_name(allowed_entry)
                except Exception as exc:
                    # Unknown entries are simply not aliases of this model.
                    logging.debug(f"Could not resolve allow-list entry '{allowed_entry}': {exc}")
                    continue
                if not resolved_name:
                    continue
                normalized_resolved = resolved_name.lower()
                self._allowed_alias_cache[allowed_entry] = normalized_resolved

            if normalized_resolved == canonical:
                # Canonical match discovered via alias resolution – memoise the
                # canonical entry for future lookups.
                self._allowed_alias_cache[canonical] = canonical
                self.allowed_models.add(canonical)
                return

        raise ValueError(
            f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
        )

    def _parse_allowed_models(self) -> Optional[set[str]]:
        """Parse allowed models from the ``<PROVIDER>_ALLOWED_MODELS`` env variable.

        Returns:
            Set of allowed model names (lowercase) or None if not configured
        """
        provider_type = self.get_provider_type().value.upper()
        env_var = f"{provider_type}_ALLOWED_MODELS"
        models_str = get_env(env_var, "") or ""

        if models_str:
            # Parse and normalize to lowercase for case-insensitive comparison
            models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
            if models:
                logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
                self._allowed_alias_cache = {}
                return models

        # Log info if no allow-list configured for proxy providers
        if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
            logging.info(
                f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. "
                f"To restrict access, set {env_var} with comma-separated model names."
            )

        return None

    @staticmethod
    def _resolve_timeout(explicit, env_var: str, default: float):
        """Return a timeout from an explicit kwarg, then ``env_var``, then ``default``.

        Explicit values are returned untouched. Empty/whitespace env values are
        treated as unset; unparsable ones are logged and replaced by ``default``
        instead of aborting provider construction.
        """
        if explicit is not None:
            return explicit

        raw = get_env(env_var)
        if raw is None or not str(raw).strip():
            return float(default)

        try:
            return float(raw)
        except ValueError:
            logging.warning(f"Ignoring invalid {env_var}={raw!r}; using default of {default}s")
            return float(default)

    def _configure_timeouts(self, **kwargs):
        """Configure timeout settings based on provider type and custom settings.

        Custom URLs and local models often need longer timeouts due to:
        - Network latency on local networks
        - Extended thinking models taking longer to respond
        - Local inference being slower than cloud APIs

        Each timeout can be overridden via the ``connect_timeout``/``read_timeout``/
        ``write_timeout``/``pool_timeout`` kwargs, or the matching
        ``CUSTOM_*_TIMEOUT`` environment variables.

        Returns:
            httpx.Timeout object with appropriate timeout settings
        """
        import httpx

        if self.base_url and self._is_localhost_url():
            # Local models (extended thinking) get the most generous budget.
            default_connect, default_read, default_write, default_pool = 60.0, 1800.0, 1800.0, 1800.0
            logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
        elif self.base_url:
            default_connect, default_read, default_write, default_pool = 45.0, 900.0, 900.0, 900.0
            logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")
        else:
            # 30s connect (vs OpenAI's 5s); 10 minutes otherwise (OpenAI read default).
            default_connect, default_read, default_write, default_pool = 30.0, 600.0, 600.0, 600.0

        connect_timeout = self._resolve_timeout(kwargs.get("connect_timeout"), "CUSTOM_CONNECT_TIMEOUT", default_connect)
        read_timeout = self._resolve_timeout(kwargs.get("read_timeout"), "CUSTOM_READ_TIMEOUT", default_read)
        write_timeout = self._resolve_timeout(kwargs.get("write_timeout"), "CUSTOM_WRITE_TIMEOUT", default_write)
        pool_timeout = self._resolve_timeout(kwargs.get("pool_timeout"), "CUSTOM_POOL_TIMEOUT", default_pool)

        timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)

        logging.debug(
            f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
            f"Write: {write_timeout}s, Pool: {pool_timeout}s"
        )

        return timeout

    def _is_localhost_url(self) -> bool:
        """Check if the base URL points to localhost or a private/loopback IP.

        Returns:
            True for ``localhost``, loopback addresses and private-range IP
            literals; False otherwise (including any hostname other than
            ``localhost``, and unparsable URLs).
        """
        if not self.base_url:
            return False

        try:
            hostname = urlparse(self.base_url).hostname

            # Check for common localhost patterns
            if hostname in ["localhost", "127.0.0.1", "::1"]:
                return True

            # Check for private network ranges (local network)
            if hostname:
                try:
                    ip = ipaddress.ip_address(hostname)
                except ValueError:
                    # Not an IP address, might be a hostname
                    return False
                return ip.is_private or ip.is_loopback

            return False
        except Exception:
            return False

    def _validate_base_url(self) -> None:
        """Validate base URL for security (SSRF protection).

        Raises:
            ValueError: If URL is invalid or potentially unsafe
        """
        if not self.base_url:
            return

        try:
            parsed = urlparse(self.base_url)

            # Check URL scheme - only allow http/https
            if parsed.scheme not in ("http", "https"):
                raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")

            # Check hostname exists
            if not parsed.hostname:
                raise ValueError("URL must include a hostname")

            # Check port is valid (if specified). ``parsed.port`` itself raises
            # ValueError for non-numeric or out-of-range ports.
            port = parsed.port
            if port is not None and (port < 1 or port > 65535):
                raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
        except ValueError:
            raise
        except Exception as e:
            raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}") from e

    @property
    def client(self):
        """Lazily create the OpenAI client with proxy hardening and timeout configuration.

        Proxy environment variables are removed while the client is built so
        httpx does not pick them up, then restored. If the custom httpx client
        setup fails, a minimal ``OpenAI(api_key, base_url)`` client is tried.

        NOTE(review): the proxy workaround mutates process-wide ``os.environ``
        and is not thread-safe.
        """
        if self._client is None:
            import os

            import httpx

            proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]
            # Temporarily disable proxy environment variables; restored in ``finally``.
            original_env = {var: os.environ.pop(var) for var in proxy_env_vars if var in os.environ}

            http_client = None
            try:
                timeout_config = getattr(self, "timeout_config", None) or httpx.Timeout(30.0)

                # Create httpx client with minimal config to avoid proxy conflicts
                # Note: proxies parameter was removed in httpx 0.28.0
                http_client_options = {"timeout": timeout_config, "follow_redirects": True}
                if hasattr(self, "_test_transport"):
                    # Use custom transport for testing (HTTP recording/replay)
                    http_client_options["transport"] = self._test_transport
                http_client = httpx.Client(**http_client_options)

                # Keep client initialization minimal to avoid proxy parameter conflicts
                client_kwargs = {"api_key": self.api_key, "http_client": http_client}
                if self.base_url:
                    client_kwargs["base_url"] = self.base_url
                if self.organization:
                    client_kwargs["organization"] = self.organization
                if self.DEFAULT_HEADERS:
                    client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

                logging.debug(f"OpenAI client initialized with custom httpx client and timeout: {timeout_config}")

                self._client = OpenAI(**client_kwargs)

            except Exception as e:
                # If all else fails, try absolute minimal client without custom httpx
                logging.warning(f"Failed to create client with custom httpx, falling back to minimal config: {e}")
                if http_client is not None:
                    # The custom client is abandoned; release its connection pool.
                    http_client.close()
                try:
                    minimal_kwargs = {"api_key": self.api_key}
                    if self.base_url:
                        minimal_kwargs["base_url"] = self.base_url
                    self._client = OpenAI(**minimal_kwargs)
                except Exception as fallback_error:
                    logging.error(f"Even minimal OpenAI client creation failed: {fallback_error}")
                    raise
            finally:
                # Restore original proxy environment variables
                os.environ.update(original_env)

        return self._client

    def _sanitize_for_logging(self, params: dict) -> dict:
        """Return a deep copy of request params that is safe to log.

        Text parts of ``input`` messages longer than 100 characters are
        truncated, as are long string ``image_url`` values (base64 data URLs),
        and any top-level ``api_key``/``authorization`` entries are removed.

        Args:
            params: Dictionary of API parameters

        Returns:
            dict: Sanitized copy of parameters safe for logging
        """
        sanitized = copy.deepcopy(params)

        for msg in sanitized.get("input") or []:
            if not isinstance(msg, dict):
                continue
            content = msg.get("content")
            if not isinstance(content, (list, tuple)):
                continue
            for content_item in content:
                if not isinstance(content_item, dict):
                    continue
                text = content_item.get("text")
                if isinstance(text, str) and len(text) > 100:
                    content_item["text"] = text[:100] + "... [truncated]"
                image_url = content_item.get("image_url")
                if isinstance(image_url, str) and len(image_url) > 100:
                    content_item["image_url"] = image_url[:100] + "... [truncated]"

        # Remove any API keys that might be in headers/auth
        sanitized.pop("api_key", None)
        sanitized.pop("authorization", None)

        return sanitized

    def _safe_extract_output_text(self, response) -> str:
        """Safely extract ``output_text`` from a Responses API response with validation.

        Args:
            response: Response object from OpenAI SDK

        Returns:
            str: The output text content

        Raises:
            ValueError: If output_text is missing, None, or not a string
        """
        logging.debug("Response object type: %s", type(response))
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            # dir() can be large; only build it when it will actually be logged.
            logging.debug("Response attributes: %s", dir(response))

        if not hasattr(response, "output_text"):
            raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")

        content = response.output_text
        logging.debug("Extracted output_text: '%s' (type: %s)", content, type(content))

        if content is None:
            raise ValueError("o3-pro returned None for output_text")

        if not isinstance(content, str):
            raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")

        return content

    @staticmethod
    def _to_responses_content(content, text_type: str) -> list:
        """Convert chat-completions message content into Responses API content parts.

        Plain content becomes a single ``{"type": text_type, "text": content}``
        part. Content arrays (text + images) are mapped part by part:
        ``text`` -> ``text_type`` and ``image_url`` -> ``input_image``; any other
        dict part is passed through unchanged.
        """
        if not isinstance(content, list):
            return [{"type": text_type, "text": content}]

        parts = []
        for item in content:
            if not isinstance(item, dict):
                parts.append({"type": text_type, "text": str(item)})
                continue
            item_type = item.get("type")
            if item_type == "text":
                parts.append({"type": text_type, "text": item.get("text", "")})
            elif item_type == "image_url":
                image = item.get("image_url")
                url = image.get("url") if isinstance(image, dict) else image
                if url:
                    parts.append({"type": "input_image", "image_url": url})
            else:
                parts.append(item)
        return parts

    def _generate_with_responses_endpoint(
        self,
        model_name: str,
        messages: list,
        temperature: float,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content through ``client.responses.create`` (/v1/responses).

        Used for models in ``_RESPONSES_API_MODELS``. ``temperature`` and extra
        ``kwargs`` are accepted for signature compatibility but not sent: the
        request carries only model, input, a fixed medium reasoning effort,
        ``store=True`` and, when given, ``max_output_tokens``.

        Raises:
            RuntimeError: If every retry attempt fails.
        """
        import json

        # Convert chat-style messages to Responses API input items.
        input_messages = []
        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")
            if role in ("system", "user"):
                # System content is sent as a user turn rather than prefixed with
                # "System:", to avoid policy-violation responses.
                input_messages.append({"role": "user", "content": self._to_responses_content(content, "input_text")})
            elif role == "assistant":
                input_messages.append(
                    {"role": "assistant", "content": self._to_responses_content(content, "output_text")}
                )

        completion_params = {
            "model": model_name,
            "input": input_messages,
            "reasoning": {"effort": "medium"},  # Use nested object for responses endpoint
            "store": True,
        }

        # The Responses API names its output cap ``max_output_tokens``
        # (``max_completion_tokens`` is a chat-completions parameter and is rejected).
        if max_output_tokens:
            completion_params["max_output_tokens"] = max_output_tokens

        # Retry logic with progressive delays
        max_retries = 4
        retry_delays = [1, 3, 5, 8]
        attempt_counter = {"value": 0}

        def _attempt() -> ModelResponse:
            attempt_counter["value"] += 1

            sanitized_params = self._sanitize_for_logging(completion_params)
            logging.info(
                f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
            )

            response = self.client.responses.create(**completion_params)
            content = self._safe_extract_output_text(response)

            usage = None
            if hasattr(response, "usage"):
                usage = self._extract_usage(response)
            elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
                input_tokens = getattr(response, "input_tokens", 0) or 0
                output_tokens = getattr(response, "output_tokens", 0) or 0
                usage = {
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens,
                }

            return ModelResponse(
                content=content,
                usage=usage,
                model_name=model_name,
                friendly_name=self.FRIENDLY_NAME,
                provider=self.get_provider_type(),
                metadata={
                    "model": getattr(response, "model", model_name),
                    "id": getattr(response, "id", ""),
                    "created": getattr(response, "created_at", 0),
                    "endpoint": "responses",
                },
            )

        try:
            return self._run_with_retries(
                operation=_attempt,
                max_attempts=max_retries,
                delays=retry_delays,
                log_prefix="o3-pro responses endpoint",
            )
        except Exception as exc:
            attempts = max(attempt_counter["value"], 1)
            error_msg = f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from exc

    def _build_chat_messages(
        self,
        prompt: str,
        system_prompt: Optional[str],
        images: Optional[list[str]],
        capabilities: Optional[ModelCapabilities],
        resolved_model: str,
    ) -> list[dict]:
        """Build the chat-completions ``messages`` list (system prompt, then user turn).

        Images are attached only when ``capabilities.supports_images`` is true;
        images that fail to process are logged and skipped. A text-only user
        turn uses plain string content for compatibility with simple backends.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        user_content = [{"type": "text", "text": prompt}]

        if images and capabilities and capabilities.supports_images:
            for image_path in images:
                try:
                    image_content = self._process_image(image_path)
                except Exception as e:
                    # Continue with other images and text
                    logging.warning(f"Failed to process image {image_path}: {e}")
                    continue
                if image_content:
                    user_content.append(image_content)
        elif images:
            logging.warning(f"Model {resolved_model} does not support images, ignoring {len(images)} image(s)")

        if len(user_content) == 1:
            # Only text content, use simple string format for compatibility
            messages.append({"role": "user", "content": prompt})
        else:
            # Text + images, use content array format
            messages.append({"role": "user", "content": user_content})

        return messages

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: Optional[int] = None,
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the OpenAI-compatible API.

        Args:
            prompt: User prompt to send to the model
            model_name: Canonical model name or its alias
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            images: Optional list of image paths or data URLs to include with the prompt (for vision models)
            **kwargs: Additional provider-specific parameters; only ``top_p``,
                ``frequency_penalty``, ``presence_penalty``, ``seed`` and ``stop``
                are forwarded (the first three only for models that accept a temperature).

        Returns:
            ModelResponse with generated content and metadata

        Raises:
            ValueError: If the model is not allowed.
            RuntimeError: If every retry attempt fails.
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")

        capabilities: Optional[ModelCapabilities]
        try:
            capabilities = self.get_capabilities(model_name)
        except Exception as exc:
            logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}")
            capabilities = None

        # Get effective temperature for this model from capabilities when available;
        # None means the model does not accept a temperature at all.
        if capabilities:
            effective_temperature = capabilities.get_effective_temperature(temperature)
            if effective_temperature is not None and effective_temperature != temperature:
                logging.debug(
                    f"Adjusting temperature from {temperature} to {effective_temperature} for model {model_name}"
                )
        else:
            effective_temperature = temperature

        if effective_temperature is not None:
            self.validate_parameters(model_name, effective_temperature)

        resolved_model = self._resolve_model_name(model_name)
        messages = self._build_chat_messages(prompt, system_prompt, images, capabilities, resolved_model)

        # Streaming is always disabled: the response handling below expects a
        # single completed response object.
        completion_params = {
            "model": resolved_model,
            "messages": messages,
            "stream": False,
        }

        supports_sampling = effective_temperature is not None
        if supports_sampling:
            completion_params["temperature"] = effective_temperature
            # O3/O4 models that don't support temperature also don't support max_tokens
            if max_output_tokens:
                completion_params["max_tokens"] = max_output_tokens

        if kwargs.get("stream"):
            logging.debug(f"Ignoring stream=True for {resolved_model}: streaming responses are not supported")

        for key, value in kwargs.items():
            if key not in self._PASSTHROUGH_PARAMS:
                continue
            if not supports_sampling and key in self._SAMPLING_ONLY_PARAMS:
                continue  # Skip unsupported parameters for reasoning models
            completion_params[key] = value

        if resolved_model in self._RESPONSES_API_MODELS:
            # These models require the /v1/responses endpoint; no fallback to
            # chat/completions on failure.
            return self._generate_with_responses_endpoint(
                model_name=resolved_model,
                messages=messages,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                **kwargs,
            )

        # Retry logic with progressive delays
        max_retries = 4  # Total of 4 attempts
        retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
        attempt_counter = {"value": 0}

        def _attempt() -> ModelResponse:
            attempt_counter["value"] += 1
            response = self.client.chat.completions.create(**completion_params)

            choice = response.choices[0]
            return ModelResponse(
                content=choice.message.content,
                usage=self._extract_usage(response),
                model_name=resolved_model,
                friendly_name=self.FRIENDLY_NAME,
                provider=self.get_provider_type(),
                metadata={
                    "finish_reason": choice.finish_reason,
                    "model": response.model,
                    "id": response.id,
                    "created": response.created,
                },
            )

        try:
            return self._run_with_retries(
                operation=_attempt,
                max_attempts=max_retries,
                delays=retry_delays,
                log_prefix=f"{self.FRIENDLY_NAME} API ({resolved_model})",
            )
        except Exception as exc:
            attempts = max(attempt_counter["value"], 1)
            error_msg = (
                f"{self.FRIENDLY_NAME} API error for model {resolved_model} after {attempts} attempt"
                f"{'s' if attempts > 1 else ''}: {exc}"
            )
            logging.error(error_msg)
            raise RuntimeError(error_msg) from exc

    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
        """Validate model parameters on a best-effort basis.

        For proxy providers capabilities may be generic, so any validation
        failure is logged as a warning instead of being raised.

        Args:
            model_name: Canonical model name or its alias
            temperature: Temperature to validate
            **kwargs: Additional parameters to validate
        """
        try:
            capabilities = self.get_capabilities(model_name)

            # Check if we're using generic capabilities
            if hasattr(capabilities, "_is_generic"):
                logging.debug(
                    f"Using generic parameter validation for {model_name}. Actual model constraints may differ."
                )

            super().validate_parameters(model_name, temperature, **kwargs)
        except Exception as e:
            # For proxy providers, we might not have accurate capabilities
            logging.warning(f"Parameter validation limited for {model_name}: {e}")

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from an OpenAI SDK response.

        Handles both chat-completions usage (``prompt_tokens``/``completion_tokens``)
        and Responses API usage (``input_tokens``/``output_tokens``). Missing or
        None counts become 0; a missing total is derived from input + output.

        Args:
            response: OpenAI API response object

        Returns:
            Dictionary with ``input_tokens``/``output_tokens``/``total_tokens``,
            or an empty dict when the response carries no usage.
        """
        usage_data = getattr(response, "usage", None)
        if not usage_data:
            return {}

        input_tokens = getattr(usage_data, "prompt_tokens", None)
        if input_tokens is None:
            input_tokens = getattr(usage_data, "input_tokens", 0)
        output_tokens = getattr(usage_data, "completion_tokens", None)
        if output_tokens is None:
            output_tokens = getattr(usage_data, "output_tokens", 0)

        input_tokens = input_tokens or 0
        output_tokens = output_tokens or 0
        total_tokens = getattr(usage_data, "total_tokens", 0) or (input_tokens + output_tokens)

        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
        }

    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens using tiktoken when available, else defer to the base class.

        Unknown models use the ``cl100k_base`` encoding.
        """
        resolved_model = self._resolve_model_name(model_name)

        try:
            import tiktoken

            try:
                encoding = tiktoken.encoding_for_model(resolved_model)
            except KeyError:
                encoding = tiktoken.get_encoding("cl100k_base")

            return len(encoding.encode(text))
        except Exception as exc:  # includes ImportError when tiktoken is absent
            logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)

        return super().count_tokens(text, model_name)

    @staticmethod
    def _extract_error_details(error: Exception) -> tuple:
        """Best-effort extraction of ``(type, code)`` from an OpenAI-style error payload.

        Parses the ``{...}`` payload embedded in the error message, e.g.
        ``Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded'}}``,
        falling back to ``error.response.json()``. Returns ``(None, None)`` when
        nothing can be extracted.
        """
        import ast
        import json
        import re

        def _from_payload(payload):
            if isinstance(payload, dict):
                info = payload.get("error")
                if isinstance(info, dict):
                    return info.get("type"), info.get("code")
            return None, None

        error_type = error_code = None

        match = re.search(r"\{.*\}", str(error), re.DOTALL)
        if match:
            raw = match.group(0)
            try:
                # Python-literal parsing handles single-quoted dicts safely.
                payload = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError, RecursionError):
                try:
                    payload = json.loads(raw.replace("'", '"'))
                except ValueError:  # json.JSONDecodeError is a ValueError
                    payload = None
            error_type, error_code = _from_payload(payload)

        if error_type is None and error_code is None:
            response = getattr(error, "response", None)
            if response is not None and hasattr(response, "json"):
                try:
                    error_type, error_code = _from_payload(response.json())
                except Exception as exc:
                    logging.debug(f"Could not parse error response body: {exc}")

        return error_type, error_code

    def _is_error_retryable(self, error: Exception) -> bool:
        """Determine whether an API error is worth retrying.

        The SDK's integer ``status_code`` is trusted when present; otherwise
        status codes are matched in the message as whole words, so numbers such
        as "150023 tokens" are not mistaken for HTTP 500.

        Args:
            error: Exception from OpenAI API call

        Returns:
            True if error should be retried, False otherwise
        """
        import re

        error_str = str(error).lower()

        status_code = getattr(error, "status_code", None)
        if not isinstance(status_code, int):
            status_code = None

        if status_code is not None:
            is_rate_limited = status_code == 429
        else:
            is_rate_limited = re.search(r"\b429\b", error_str) is not None

        if is_rate_limited:
            error_type, error_code = self._extract_error_details(error)

            if error_type == "tokens":
                # Token-related 429s are typically non-retryable (request too large)
                logging.debug(f"Non-retryable 429: token-related error (type={error_type}, code={error_code})")
                return False
            if error_code in self._NON_RETRYABLE_429_CODES or error_type == "insufficient_quota":
                # Permanent failures (bad request, context overflow, exhausted quota)
                logging.debug(f"Non-retryable 429: permanent failure (type={error_type}, code={error_code})")
                return False
            # Other 429s (like requests per minute) are retryable
            logging.debug(f"Retryable 429: rate limiting (type={error_type}, code={error_code})")
            return True

        if status_code is not None:
            if status_code in self._RETRYABLE_STATUS_CODES:
                return True
        else:
            codes = "|".join(str(code) for code in self._RETRYABLE_STATUS_CODES)
            if re.search(rf"\b(?:{codes})\b", error_str):
                return True

        return any(keyword in error_str for keyword in self._RETRYABLE_ERROR_KEYWORDS)

    def _process_image(self, image_path: str) -> Optional[dict]:
        """Convert an image path or data URL into a chat-completions ``image_url`` part.

        Returns:
            ``{"type": "image_url", "image_url": {"url": <data URL>}}``, or None
            when validation or processing fails (the failure is logged).
        """
        try:
            if image_path.startswith("data:"):
                # Validate the data URL and pass it through unchanged.
                validate_image(image_path)
                return {"type": "image_url", "image_url": {"url": image_path}}

            image_bytes, mime_type = validate_image(image_path)

            import base64

            image_data = base64.b64encode(image_bytes).decode()
            logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'")

            data_url = f"data:{mime_type};base64,{image_data}"
            return {"type": "image_url", "image_url": {"url": data_url}}

        except ValueError as e:
            logging.warning(str(e))
            return None
        except Exception as e:
            logging.error(f"Error processing image {image_path}: {e}")
            return None