"""Base class for OpenAI-compatible API providers.""" import base64 import ipaddress import logging import os import time from abc import abstractmethod from typing import Optional from urllib.parse import urlparse from openai import OpenAI from .base import ( ModelCapabilities, ModelProvider, ModelResponse, ProviderType, ) class OpenAICompatibleProvider(ModelProvider): """Base class for any provider using an OpenAI-compatible API. This includes: - Direct OpenAI API - OpenRouter - Any other OpenAI-compatible endpoint """ DEFAULT_HEADERS = {} FRIENDLY_NAME = "OpenAI Compatible" def __init__(self, api_key: str, base_url: str = None, **kwargs): """Initialize the provider with API key and optional base URL. Args: api_key: API key for authentication base_url: Base URL for the API endpoint **kwargs: Additional configuration options including timeout """ super().__init__(api_key, **kwargs) self._client = None self.base_url = base_url self.organization = kwargs.get("organization") self.allowed_models = self._parse_allowed_models() # Configure timeouts - especially important for custom/local endpoints self.timeout_config = self._configure_timeouts(**kwargs) # Validate base URL for security if self.base_url: self._validate_base_url() # Warn if using external URL without authentication if self.base_url and not self._is_localhost_url() and not api_key: logging.warning( f"Using external URL '{self.base_url}' without API key. " "This may be insecure. Consider setting an API key for authentication." ) def _parse_allowed_models(self) -> Optional[set[str]]: """Parse allowed models from environment variable. Returns: Set of allowed model names (lowercase) or None if not configured """ # Get provider-specific allowed models provider_type = self.get_provider_type().value.upper() env_var = f"{provider_type}_ALLOWED_MODELS" models_str = os.getenv(env_var, "") if models_str: # Parse and normalize to lowercase for case-insensitive comparison models = {m.strip().lower() for m in models_str.split(",") if m.strip()} if models: logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}") return models # Log info if no allow-list configured for proxy providers if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]: logging.info( f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. " f"To restrict access, set {env_var} with comma-separated model names." ) return None def _configure_timeouts(self, **kwargs): """Configure timeout settings based on provider type and custom settings. 
    def _configure_timeouts(self, **kwargs):
        """Configure timeout settings based on provider type and custom settings.

        Custom URLs and local models often need longer timeouts due to:
        - Network latency on local networks
        - Extended thinking models taking longer to respond
        - Local inference being slower than cloud APIs

        Returns:
            httpx.Timeout object with appropriate timeout settings
        """
        import httpx

        # Default timeouts - more generous for custom/local endpoints
        default_connect = 30.0  # 30 seconds for connection (vs OpenAI's 5s)
        default_read = 600.0  # 10 minutes for reading (same as OpenAI default)
        default_write = 600.0  # 10 minutes for writing
        default_pool = 600.0  # 10 minutes for pool

        # For custom/local URLs, use even longer timeouts
        if self.base_url and self._is_localhost_url():
            default_connect = 60.0  # 1 minute for local connections
            default_read = 1800.0  # 30 minutes for local models (extended thinking)
            default_write = 1800.0  # 30 minutes for local models
            default_pool = 1800.0  # 30 minutes for local models
            logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
        elif self.base_url:
            default_connect = 45.0  # 45 seconds for custom remote endpoints
            default_read = 900.0  # 15 minutes for custom remote endpoints
            default_write = 900.0  # 15 minutes for custom remote endpoints
            default_pool = 900.0  # 15 minutes for custom remote endpoints
            logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")

        # Allow overrides via constructor kwargs or environment variables
        connect_timeout = kwargs.get("connect_timeout", float(os.getenv("CUSTOM_CONNECT_TIMEOUT", default_connect)))
        read_timeout = kwargs.get("read_timeout", float(os.getenv("CUSTOM_READ_TIMEOUT", default_read)))
        write_timeout = kwargs.get("write_timeout", float(os.getenv("CUSTOM_WRITE_TIMEOUT", default_write)))
        pool_timeout = kwargs.get("pool_timeout", float(os.getenv("CUSTOM_POOL_TIMEOUT", default_pool)))

        timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)

        logging.debug(
            f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
            f"Write: {write_timeout}s, Pool: {pool_timeout}s"
        )

        return timeout
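
    # Example (hedged): with no overrides, a localhost base URL yields
    # httpx.Timeout(connect=60.0, read=1800.0, write=1800.0, pool=1800.0).
    # Setting CUSTOM_READ_TIMEOUT=3600, or passing read_timeout=3600.0 to the
    # constructor, raises only the read timeout; the other values keep their
    # tier defaults.
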
    def _is_localhost_url(self) -> bool:
        """Check if the base URL points to localhost or local network.

        Returns:
            True if URL is localhost or local network, False otherwise
        """
        if not self.base_url:
            return False

        try:
            parsed = urlparse(self.base_url)
            hostname = parsed.hostname

            # Check for common localhost patterns
            if hostname in ["localhost", "127.0.0.1", "::1"]:
                return True

            # Check for private network ranges (local network)
            if hostname:
                try:
                    ip = ipaddress.ip_address(hostname)
                    return ip.is_private or ip.is_loopback
                except ValueError:
                    # Not an IP address, might be a hostname
                    pass

            return False
        except Exception:
            return False

    def _validate_base_url(self) -> None:
        """Validate base URL for security (SSRF protection).

        Raises:
            ValueError: If URL is invalid or potentially unsafe
        """
        if not self.base_url:
            return

        try:
            parsed = urlparse(self.base_url)

            # Check URL scheme - only allow http/https
            if parsed.scheme not in ("http", "https"):
                raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")

            # Check hostname exists
            if not parsed.hostname:
                raise ValueError("URL must include a hostname")

            # Check port is valid (if specified)
            port = parsed.port
            if port is not None and (port < 1 or port > 65535):
                raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
        except Exception as e:
            if isinstance(e, ValueError):
                raise
            raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
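
    # Example (hedged): "https://openrouter.ai/api/v1" and
    # "http://192.168.1.10:8080/v1" pass validation (the latter, being a private
    # address, is also treated as a local endpoint for timeout purposes), while
    # "ftp://example.com" or a URL without a hostname raises ValueError during
    # __init__.
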
    @property
    def client(self):
        """Lazy initialization of OpenAI client with security checks and timeout configuration."""
        if self._client is None:
            import httpx

            # Temporarily disable proxy environment variables to prevent httpx from detecting them
            original_env = {}
            proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]

            for var in proxy_env_vars:
                if var in os.environ:
                    original_env[var] = os.environ[var]
                    del os.environ[var]

            try:
                # Create a custom httpx client that explicitly avoids proxy parameters
                timeout_config = (
                    self.timeout_config
                    if hasattr(self, "timeout_config") and self.timeout_config
                    else httpx.Timeout(30.0)
                )

                # Create httpx client with minimal config to avoid proxy conflicts
                # Note: proxies parameter was removed in httpx 0.28.0
                http_client = httpx.Client(
                    timeout=timeout_config,
                    follow_redirects=True,
                )

                # Keep client initialization minimal to avoid proxy parameter conflicts
                client_kwargs = {
                    "api_key": self.api_key,
                    "http_client": http_client,
                }

                if self.base_url:
                    client_kwargs["base_url"] = self.base_url

                if self.organization:
                    client_kwargs["organization"] = self.organization

                # Add default headers if any
                if self.DEFAULT_HEADERS:
                    client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

                logging.debug(f"OpenAI client initialized with custom httpx client and timeout: {timeout_config}")

                # Create OpenAI client with custom httpx client
                self._client = OpenAI(**client_kwargs)

            except Exception as e:
                # If all else fails, try absolute minimal client without custom httpx
                logging.warning(f"Failed to create client with custom httpx, falling back to minimal config: {e}")
                try:
                    minimal_kwargs = {"api_key": self.api_key}
                    if self.base_url:
                        minimal_kwargs["base_url"] = self.base_url
                    self._client = OpenAI(**minimal_kwargs)
                except Exception as fallback_error:
                    logging.error(f"Even minimal OpenAI client creation failed: {fallback_error}")
                    raise
            finally:
                # Restore original proxy environment variables
                for var, value in original_env.items():
                    os.environ[var] = value

        return self._client

    def _generate_with_responses_endpoint(
        self,
        model_name: str,
        messages: list,
        temperature: float,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the /v1/responses endpoint for o3-pro via OpenAI library."""
        # Convert messages to the correct format for responses endpoint
        input_messages = []

        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")

            if role == "system":
                # For o3-pro, system messages should be handled carefully to avoid policy violations
                # Instead of prefixing with "System:", we'll include the system content naturally
                input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
            elif role == "user":
                input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
            elif role == "assistant":
                input_messages.append({"role": "assistant", "content": [{"type": "output_text", "text": content}]})

        # Prepare completion parameters for responses endpoint
        # Based on OpenAI documentation, use nested reasoning object for responses endpoint
        completion_params = {
            "model": model_name,
            "input": input_messages,
            "reasoning": {"effort": "medium"},  # Use nested object for responses endpoint
            "store": True,
        }

        # Add max tokens if specified (using max_completion_tokens for responses endpoint)
        if max_output_tokens:
            completion_params["max_completion_tokens"] = max_output_tokens

        # For responses endpoint, we only add parameters that are explicitly supported
        # Remove unsupported chat completion parameters that may cause API errors

        # Retry logic with progressive delays
        max_retries = 4
        retry_delays = [1, 3, 5, 8]
        last_exception = None

        for attempt in range(max_retries):
            try:
                # Log the exact payload being sent for debugging
                import json

                logging.info(f"o3-pro API request payload: {json.dumps(completion_params, indent=2)}")

                # Use OpenAI client's responses endpoint
                response = self.client.responses.create(**completion_params)

                # Extract content and usage from responses endpoint format
                # The response format is different for responses endpoint
                content = ""

                if hasattr(response, "output") and response.output:
                    if hasattr(response.output, "content") and response.output.content:
                        # Look for output_text in content
                        for content_item in response.output.content:
                            if hasattr(content_item, "type") and content_item.type == "output_text":
                                content = content_item.text
                                break
                    elif hasattr(response.output, "text"):
                        content = response.output.text

                # Try to extract usage information
                usage = None
                if hasattr(response, "usage"):
                    usage = self._extract_usage(response)
                elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
                    # Safely extract token counts with None handling
                    input_tokens = getattr(response, "input_tokens", 0) or 0
                    output_tokens = getattr(response, "output_tokens", 0) or 0
                    usage = {
                        "input_tokens": input_tokens,
                        "output_tokens": output_tokens,
                        "total_tokens": input_tokens + output_tokens,
                    }

                return ModelResponse(
                    content=content,
                    usage=usage,
                    model_name=model_name,
                    friendly_name=self.FRIENDLY_NAME,
                    provider=self.get_provider_type(),
                    metadata={
                        "model": getattr(response, "model", model_name),
                        "id": getattr(response, "id", ""),
                        "created": getattr(response, "created_at", 0),
                        "endpoint": "responses",
                    },
                )

            except Exception as e:
                last_exception = e

                # Check if this is a retryable error using structured error codes
                is_retryable = self._is_error_retryable(e)

                if is_retryable and attempt < max_retries - 1:
                    delay = retry_delays[attempt]
                    logging.warning(
                        f"Retryable error for o3-pro responses endpoint, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
                    )
                    time.sleep(delay)
                else:
                    break

        # If we get here, all retries failed
        actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
        error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
        logging.error(error_msg)
        raise RuntimeError(error_msg) from last_exception
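
    # Example (hedged): for messages [{"role": "system", "content": "Be brief"},
    # {"role": "user", "content": "Hi"}], the payload sent to /v1/responses looks
    # roughly like:
    #   {"model": "o3-pro-2025-06-10",
    #    "input": [{"role": "user", "content": [{"type": "input_text", "text": "Be brief"}]},
    #              {"role": "user", "content": [{"type": "input_text", "text": "Hi"}]}],
    #    "reasoning": {"effort": "medium"}, "store": True}
    # Note that the system turn is folded into a user turn by the conversion above.
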
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the OpenAI-compatible API.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            images: Optional list of image file paths or data URLs to include
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")

        # Get effective temperature for this model
        effective_temperature = self.get_effective_temperature(model_name, temperature)

        # Only validate if temperature is not None (meaning the model supports it)
        if effective_temperature is not None:
            # Validate parameters with the effective temperature
            self.validate_parameters(model_name, effective_temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Prepare user message with text and potentially images
        user_content = []
        user_content.append({"type": "text", "text": prompt})

        # Add images if provided and model supports vision
        if images and self._supports_vision(model_name):
            for image_path in images:
                try:
                    image_content = self._process_image(image_path)
                    if image_content:
                        user_content.append(image_content)
                except Exception as e:
                    logging.warning(f"Failed to process image {image_path}: {e}")
                    # Continue with other images and text
                    continue
        elif images and not self._supports_vision(model_name):
            logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")

        # Add user message
        if len(user_content) == 1:
            # Only text content, use simple string format for compatibility
            messages.append({"role": "user", "content": prompt})
        else:
            # Text + images, use content array format
            messages.append({"role": "user", "content": user_content})

        # Prepare completion parameters
        completion_params = {
            "model": model_name,
            "messages": messages,
        }

        # Check model capabilities once to determine parameter support
        resolved_model = self._resolve_model_name(model_name)

        # Use the effective temperature we calculated earlier
        if effective_temperature is not None:
            completion_params["temperature"] = effective_temperature
            supports_temperature = True
        else:
            # Model doesn't support temperature
            supports_temperature = False

        # Add max tokens if specified and model supports it
        # O3/O4 models that don't support temperature also don't support max_tokens
        if max_output_tokens and supports_temperature:
            completion_params["max_tokens"] = max_output_tokens

        # Add any additional OpenAI-specific parameters
        # Use capabilities to filter parameters for reasoning models
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                # Reasoning models (those that don't support temperature) also don't support these parameters
                if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
                    continue  # Skip unsupported parameters for reasoning models
                completion_params[key] = value

        # Check if this is o3-pro and needs the responses endpoint
        if resolved_model == "o3-pro-2025-06-10":
            # This model requires the /v1/responses endpoint
            # If it fails, we should not fall back to chat/completions
            return self._generate_with_responses_endpoint(
                model_name=resolved_model,
                messages=messages,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                **kwargs,
            )

        # Retry logic with progressive delays
        max_retries = 4  # Total of 4 attempts
        retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
        last_exception = None

        for attempt in range(max_retries):
            try:
                # Generate completion
                response = self.client.chat.completions.create(**completion_params)

                # Extract content and usage
                content = response.choices[0].message.content
                usage = self._extract_usage(response)

                return ModelResponse(
                    content=content,
                    usage=usage,
                    model_name=model_name,
                    friendly_name=self.FRIENDLY_NAME,
                    provider=self.get_provider_type(),
                    metadata={
                        "finish_reason": response.choices[0].finish_reason,
                        "model": response.model,  # Actual model used
                        "id": response.id,
                        "created": response.created,
                    },
                )

            except Exception as e:
                last_exception = e

                # Check if this is a retryable error using structured error codes
                is_retryable = self._is_error_retryable(e)

                # If this is the last attempt or not retryable, give up
                if attempt == max_retries - 1 or not is_retryable:
                    break

                # Get progressive delay
                delay = retry_delays[attempt]

                # Log retry attempt
                logging.warning(
                    f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
                )
                time.sleep(delay)

        # If we get here, all retries failed
        actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
        error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
        logging.error(error_msg)
        raise RuntimeError(error_msg) from last_exception
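
    # Example (hedged): a typical call on a concrete subclass might look like
    #   response = provider.generate_content(
    #       prompt="Summarize this diff",
    #       model_name="gpt-4o",
    #       system_prompt="You are a code reviewer.",
    #       temperature=0.2,
    #       max_output_tokens=1024,
    #   )
    #   print(response.content, response.usage)
    # Transient errors are retried for up to 4 total attempts with progressive
    # delays before a RuntimeError is raised.
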
"finish_reason": response.choices[0].finish_reason, "model": response.model, # Actual model used "id": response.id, "created": response.created, }, ) except Exception as e: last_exception = e # Check if this is a retryable error using structured error codes is_retryable = self._is_error_retryable(e) # If this is the last attempt or not retryable, give up if attempt == max_retries - 1 or not is_retryable: break # Get progressive delay delay = retry_delays[attempt] # Log retry attempt logging.warning( f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..." ) time.sleep(delay) # If we get here, all retries failed actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}" logging.error(error_msg) raise RuntimeError(error_msg) from last_exception def count_tokens(self, text: str, model_name: str) -> int: """Count tokens for the given text. Uses a layered approach: 1. Try provider-specific token counting endpoint 2. Try tiktoken for known model families 3. Fall back to character-based estimation Args: text: Text to count tokens for model_name: Model name for tokenizer selection Returns: Estimated token count """ # 1. Check if provider has a remote token counting endpoint if hasattr(self, "count_tokens_remote"): try: return self.count_tokens_remote(text, model_name) except Exception as e: logging.debug(f"Remote token counting failed: {e}") # 2. Try tiktoken for known models try: import tiktoken # Try to get encoding for the specific model try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: # Try common encodings based on model patterns if "gpt-4" in model_name or "gpt-3.5" in model_name: encoding = tiktoken.get_encoding("cl100k_base") else: encoding = tiktoken.get_encoding("cl100k_base") # Default return len(encoding.encode(text)) except (ImportError, Exception) as e: logging.debug(f"Tiktoken not available or failed: {e}") # 3. Fall back to character-based estimation logging.warning( f"No specific tokenizer available for '{model_name}'. " "Using character-based estimation (~4 chars per token)." ) return len(text) // 4 def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None: """Validate model parameters. For proxy providers, this may use generic capabilities. Args: model_name: Model to validate for temperature: Temperature to validate **kwargs: Additional parameters to validate """ try: capabilities = self.get_capabilities(model_name) # Check if we're using generic capabilities if hasattr(capabilities, "_is_generic"): logging.debug( f"Using generic parameter validation for {model_name}. Actual model constraints may differ." ) # Validate temperature using parent class method super().validate_parameters(model_name, temperature, **kwargs) except Exception as e: # For proxy providers, we might not have accurate capabilities # Log warning but don't fail logging.warning(f"Parameter validation limited for {model_name}: {e}") def _extract_usage(self, response) -> dict[str, int]: """Extract token usage from OpenAI response. 
    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
        """Validate model parameters.

        For proxy providers, this may use generic capabilities.

        Args:
            model_name: Model to validate for
            temperature: Temperature to validate
            **kwargs: Additional parameters to validate
        """
        try:
            capabilities = self.get_capabilities(model_name)

            # Check if we're using generic capabilities
            if hasattr(capabilities, "_is_generic"):
                logging.debug(
                    f"Using generic parameter validation for {model_name}. Actual model constraints may differ."
                )

            # Validate temperature using parent class method
            super().validate_parameters(model_name, temperature, **kwargs)

        except Exception as e:
            # For proxy providers, we might not have accurate capabilities
            # Log warning but don't fail
            logging.warning(f"Parameter validation limited for {model_name}: {e}")

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from OpenAI response.

        Args:
            response: OpenAI API response object

        Returns:
            Dictionary with usage statistics
        """
        usage = {}

        if hasattr(response, "usage") and response.usage:
            # Safely extract token counts with None handling
            usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
            usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
            usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0

        return usage

    @abstractmethod
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def get_provider_type(self) -> ProviderType:
        """Get the provider type.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported.

        Must be implemented by subclasses.
        """
        pass

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Default is False for OpenAI-compatible providers.
        """
        return False

    def _supports_vision(self, model_name: str) -> bool:
        """Check if the model supports vision (image processing).

        Default implementation for OpenAI-compatible providers.
        Subclasses should override with specific model support.
        """
        # Common vision-capable models - only include models that actually support images
        vision_models = {
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-4-vision-preview",
            "gpt-4.1-2025-04-14",  # GPT-4.1 supports vision
            "o3",
            "o3-mini",
            "o3-pro",
            "o4-mini",
            "o4-mini-high",
            # Note: Claude models would be handled by a separate provider
        }

        supports = model_name.lower() in vision_models
        logging.debug(f"Model '{model_name}' vision support: {supports}")
        return supports
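
    # Example (hedged): the membership test above is an exact lowercase match,
    # so _supports_vision("GPT-4o") is True but _supports_vision("gpt-4.1") is
    # False, since only the dated "gpt-4.1-2025-04-14" name is listed, unless a
    # subclass overrides this check.
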
    def _is_error_retryable(self, error: Exception) -> bool:
        """Determine if an error should be retried based on structured error codes.

        Uses OpenAI API error structure instead of text pattern matching for reliability.

        Args:
            error: Exception from OpenAI API call

        Returns:
            True if error should be retried, False otherwise
        """
        error_str = str(error).lower()

        # Check for 429 errors first - these need special handling
        if "429" in error_str:
            # Try to extract structured error information
            error_type = None
            error_code = None

            # Parse structured error from OpenAI API response
            # Format: "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded', ...}}"
            try:
                import ast
                import json
                import re

                # Extract JSON part from error string using regex
                # Look for pattern: {...} (from first { to last })
                json_match = re.search(r"\{.*\}", str(error))
                if json_match:
                    json_like_str = json_match.group(0)

                    # First try: parse as Python literal (handles single quotes safely)
                    try:
                        error_data = ast.literal_eval(json_like_str)
                    except (ValueError, SyntaxError):
                        # Fallback: try JSON parsing with simple quote replacement
                        # (for cases where it's already valid JSON or simple replacements work)
                        json_str = json_like_str.replace("'", '"')
                        error_data = json.loads(json_str)

                    if "error" in error_data:
                        error_info = error_data["error"]
                        error_type = error_info.get("type")
                        error_code = error_info.get("code")

            except (json.JSONDecodeError, ValueError, SyntaxError, AttributeError):
                # Fall back to checking hasattr for OpenAI SDK exception objects
                if hasattr(error, "response") and hasattr(error.response, "json"):
                    try:
                        response_data = error.response.json()
                        if "error" in response_data:
                            error_info = response_data["error"]
                            error_type = error_info.get("type")
                            error_code = error_info.get("code")
                    except Exception:
                        pass

            # Determine if 429 is retryable based on structured error codes
            if error_type == "tokens":
                # Token-related 429s are typically non-retryable (request too large)
                logging.debug(f"Non-retryable 429: token-related error (type={error_type}, code={error_code})")
                return False
            elif error_code in ["invalid_request_error", "context_length_exceeded"]:
                # These are permanent failures
                logging.debug(f"Non-retryable 429: permanent failure (type={error_type}, code={error_code})")
                return False
            else:
                # Other 429s (like requests per minute) are retryable
                logging.debug(f"Retryable 429: rate limiting (type={error_type}, code={error_code})")
                return True

        # For non-429 errors, check if they're retryable
        retryable_indicators = [
            "timeout",
            "connection",
            "network",
            "temporary",
            "unavailable",
            "retry",
            "408",  # Request timeout
            "500",  # Internal server error
            "502",  # Bad gateway
            "503",  # Service unavailable
            "504",  # Gateway timeout
            "ssl",  # SSL errors
            "handshake",  # Handshake failures
        ]

        return any(indicator in error_str for indicator in retryable_indicators)
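
    # Example (hedged): an error string like
    #   "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded', ...}}"
    # parses to type == "tokens" and is treated as non-retryable (the request is
    # simply too large), while a plain per-minute rate-limit 429 or a transient
    # "connection reset" / 503 error is retried.
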
return {"type": "image_url", "image_url": {"url": image_path}} else: # Handle file path if not os.path.exists(image_path): logging.warning(f"Image file not found: {image_path}") return None # Detect MIME type from file extension using centralized mappings from utils.file_types import get_image_mime_type ext = os.path.splitext(image_path)[1].lower() mime_type = get_image_mime_type(ext) logging.debug(f"Processing image '{image_path}' with extension '{ext}' as MIME type '{mime_type}'") # Read and encode the image with open(image_path, "rb") as f: image_data = base64.b64encode(f.read()).decode() # Create data URL for OpenAI API data_url = f"data:{mime_type};base64,{image_data}" return {"type": "image_url", "image_url": {"url": data_url}} except Exception as e: logging.error(f"Error processing image {image_path}: {e}") return None