"""Base class for OpenAI-compatible API providers.""" import base64 import ipaddress import logging import os from abc import abstractmethod from typing import Optional from urllib.parse import urlparse from openai import OpenAI from .base import ( ModelCapabilities, ModelProvider, ModelResponse, ProviderType, ) class OpenAICompatibleProvider(ModelProvider): """Base class for any provider using an OpenAI-compatible API. This includes: - Direct OpenAI API - OpenRouter - Any other OpenAI-compatible endpoint """ DEFAULT_HEADERS = {} FRIENDLY_NAME = "OpenAI Compatible" def __init__(self, api_key: str, base_url: str = None, **kwargs): """Initialize the provider with API key and optional base URL. Args: api_key: API key for authentication base_url: Base URL for the API endpoint **kwargs: Additional configuration options including timeout """ super().__init__(api_key, **kwargs) self._client = None self.base_url = base_url self.organization = kwargs.get("organization") self.allowed_models = self._parse_allowed_models() # Configure timeouts - especially important for custom/local endpoints self.timeout_config = self._configure_timeouts(**kwargs) # Validate base URL for security if self.base_url: self._validate_base_url() # Warn if using external URL without authentication if self.base_url and not self._is_localhost_url() and not api_key: logging.warning( f"Using external URL '{self.base_url}' without API key. " "This may be insecure. Consider setting an API key for authentication." ) def _parse_allowed_models(self) -> Optional[set[str]]: """Parse allowed models from environment variable. Returns: Set of allowed model names (lowercase) or None if not configured """ # Get provider-specific allowed models provider_type = self.get_provider_type().value.upper() env_var = f"{provider_type}_ALLOWED_MODELS" models_str = os.getenv(env_var, "") if models_str: # Parse and normalize to lowercase for case-insensitive comparison models = {m.strip().lower() for m in models_str.split(",") if m.strip()} if models: logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}") return models # Log info if no allow-list configured for proxy providers if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]: logging.info( f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. " f"To restrict access, set {env_var} with comma-separated model names." ) return None def _configure_timeouts(self, **kwargs): """Configure timeout settings based on provider type and custom settings. Custom URLs and local models often need longer timeouts due to: - Network latency on local networks - Extended thinking models taking longer to respond - Local inference being slower than cloud APIs Returns: httpx.Timeout object with appropriate timeout settings """ import httpx # Default timeouts - more generous for custom/local endpoints default_connect = 30.0 # 30 seconds for connection (vs OpenAI's 5s) default_read = 600.0 # 10 minutes for reading (same as OpenAI default) default_write = 600.0 # 10 minutes for writing default_pool = 600.0 # 10 minutes for pool # For custom/local URLs, use even longer timeouts if self.base_url and self._is_localhost_url(): default_connect = 60.0 # 1 minute for local connections default_read = 1800.0 # 30 minutes for local models (extended thinking) default_write = 1800.0 # 30 minutes for local models default_pool = 1800.0 # 30 minutes for local models logging.info(f"Using extended timeouts for local endpoint: {self.base_url}") elif self.base_url: default_connect = 45.0 # 45 seconds for custom remote endpoints default_read = 900.0 # 15 minutes for custom remote endpoints default_write = 900.0 # 15 minutes for custom remote endpoints default_pool = 900.0 # 15 minutes for custom remote endpoints logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}") # Allow override via kwargs or environment variables in future, for now... connect_timeout = kwargs.get("connect_timeout", float(os.getenv("CUSTOM_CONNECT_TIMEOUT", default_connect))) read_timeout = kwargs.get("read_timeout", float(os.getenv("CUSTOM_READ_TIMEOUT", default_read))) write_timeout = kwargs.get("write_timeout", float(os.getenv("CUSTOM_WRITE_TIMEOUT", default_write))) pool_timeout = kwargs.get("pool_timeout", float(os.getenv("CUSTOM_POOL_TIMEOUT", default_pool))) timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout) logging.debug( f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, " f"Write: {write_timeout}s, Pool: {pool_timeout}s" ) return timeout def _is_localhost_url(self) -> bool: """Check if the base URL points to localhost or local network. Returns: True if URL is localhost or local network, False otherwise """ if not self.base_url: return False try: parsed = urlparse(self.base_url) hostname = parsed.hostname # Check for common localhost patterns if hostname in ["localhost", "127.0.0.1", "::1"]: return True # Check for Docker internal hostnames (like host.docker.internal) if hostname and ("docker.internal" in hostname or "host.docker.internal" in hostname): return True # Check for private network ranges (local network) if hostname: try: ip = ipaddress.ip_address(hostname) return ip.is_private or ip.is_loopback except ValueError: # Not an IP address, might be a hostname pass return False except Exception: return False def _validate_base_url(self) -> None: """Validate base URL for security (SSRF protection). Raises: ValueError: If URL is invalid or potentially unsafe """ if not self.base_url: return try: parsed = urlparse(self.base_url) # Check URL scheme - only allow http/https if parsed.scheme not in ("http", "https"): raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.") # Check hostname exists if not parsed.hostname: raise ValueError("URL must include a hostname") # Check port is valid (if specified) port = parsed.port if port is not None and (port < 1 or port > 65535): raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.") except Exception as e: if isinstance(e, ValueError): raise raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}") @property def client(self): """Lazy initialization of OpenAI client with security checks and timeout configuration.""" if self._client is None: client_kwargs = { "api_key": self.api_key, } if self.base_url: client_kwargs["base_url"] = self.base_url if self.organization: client_kwargs["organization"] = self.organization # Add default headers if any if self.DEFAULT_HEADERS: client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy() # Add configured timeout settings if hasattr(self, "timeout_config") and self.timeout_config: client_kwargs["timeout"] = self.timeout_config logging.debug(f"OpenAI client initialized with custom timeout: {self.timeout_config}") self._client = OpenAI(**client_kwargs) return self._client def generate_content( self, prompt: str, model_name: str, system_prompt: Optional[str] = None, temperature: float = 0.7, max_output_tokens: Optional[int] = None, images: Optional[list[str]] = None, **kwargs, ) -> ModelResponse: """Generate content using the OpenAI-compatible API. Args: prompt: User prompt to send to the model model_name: Name of the model to use system_prompt: Optional system prompt for model behavior temperature: Sampling temperature max_output_tokens: Maximum tokens to generate **kwargs: Additional provider-specific parameters Returns: ModelResponse with generated content and metadata """ # Validate model name against allow-list if not self.validate_model_name(model_name): raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}") # Validate parameters self.validate_parameters(model_name, temperature) # Prepare messages messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) # Prepare user message with text and potentially images user_content = [] user_content.append({"type": "text", "text": prompt}) # Add images if provided and model supports vision if images and self._supports_vision(model_name): for image_path in images: try: image_content = self._process_image(image_path) if image_content: user_content.append(image_content) except Exception as e: logging.warning(f"Failed to process image {image_path}: {e}") # Continue with other images and text continue elif images and not self._supports_vision(model_name): logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)") # Add user message if len(user_content) == 1: # Only text content, use simple string format for compatibility messages.append({"role": "user", "content": prompt}) else: # Text + images, use content array format messages.append({"role": "user", "content": user_content}) # Prepare completion parameters completion_params = { "model": model_name, "messages": messages, "temperature": temperature, } # Add max tokens if specified if max_output_tokens: completion_params["max_tokens"] = max_output_tokens # Add any additional OpenAI-specific parameters for key, value in kwargs.items(): if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]: completion_params[key] = value try: # Generate completion response = self.client.chat.completions.create(**completion_params) # Extract content and usage content = response.choices[0].message.content usage = self._extract_usage(response) return ModelResponse( content=content, usage=usage, model_name=model_name, friendly_name=self.FRIENDLY_NAME, provider=self.get_provider_type(), metadata={ "finish_reason": response.choices[0].finish_reason, "model": response.model, # Actual model used "id": response.id, "created": response.created, }, ) except Exception as e: # Log error and re-raise with more context error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name}: {str(e)}" logging.error(error_msg) raise RuntimeError(error_msg) from e def count_tokens(self, text: str, model_name: str) -> int: """Count tokens for the given text. Uses a layered approach: 1. Try provider-specific token counting endpoint 2. Try tiktoken for known model families 3. Fall back to character-based estimation Args: text: Text to count tokens for model_name: Model name for tokenizer selection Returns: Estimated token count """ # 1. Check if provider has a remote token counting endpoint if hasattr(self, "count_tokens_remote"): try: return self.count_tokens_remote(text, model_name) except Exception as e: logging.debug(f"Remote token counting failed: {e}") # 2. Try tiktoken for known models try: import tiktoken # Try to get encoding for the specific model try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: # Try common encodings based on model patterns if "gpt-4" in model_name or "gpt-3.5" in model_name: encoding = tiktoken.get_encoding("cl100k_base") else: encoding = tiktoken.get_encoding("cl100k_base") # Default return len(encoding.encode(text)) except (ImportError, Exception) as e: logging.debug(f"Tiktoken not available or failed: {e}") # 3. Fall back to character-based estimation logging.warning( f"No specific tokenizer available for '{model_name}'. " "Using character-based estimation (~4 chars per token)." ) return len(text) // 4 def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None: """Validate model parameters. For proxy providers, this may use generic capabilities. Args: model_name: Model to validate for temperature: Temperature to validate **kwargs: Additional parameters to validate """ try: capabilities = self.get_capabilities(model_name) # Check if we're using generic capabilities if hasattr(capabilities, "_is_generic"): logging.debug( f"Using generic parameter validation for {model_name}. Actual model constraints may differ." ) # Validate temperature using parent class method super().validate_parameters(model_name, temperature, **kwargs) except Exception as e: # For proxy providers, we might not have accurate capabilities # Log warning but don't fail logging.warning(f"Parameter validation limited for {model_name}: {e}") def _extract_usage(self, response) -> dict[str, int]: """Extract token usage from OpenAI response. Args: response: OpenAI API response object Returns: Dictionary with usage statistics """ usage = {} if hasattr(response, "usage") and response.usage: usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) return usage @abstractmethod def get_capabilities(self, model_name: str) -> ModelCapabilities: """Get capabilities for a specific model. Must be implemented by subclasses. """ pass @abstractmethod def get_provider_type(self) -> ProviderType: """Get the provider type. Must be implemented by subclasses. """ pass @abstractmethod def validate_model_name(self, model_name: str) -> bool: """Validate if the model name is supported. Must be implemented by subclasses. """ pass def supports_thinking_mode(self, model_name: str) -> bool: """Check if the model supports extended thinking mode. Default is False for OpenAI-compatible providers. """ return False def _supports_vision(self, model_name: str) -> bool: """Check if the model supports vision (image processing). Default implementation for OpenAI-compatible providers. Subclasses should override with specific model support. """ # Common vision-capable models - only include models that actually support images vision_models = { "gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4-vision-preview", "gpt-4.1-2025-04-14", # GPT-4.1 supports vision "o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high", # Note: Claude models would be handled by a separate provider } supports = model_name.lower() in vision_models logging.debug(f"Model '{model_name}' vision support: {supports}") return supports def _process_image(self, image_path: str) -> Optional[dict]: """Process an image for OpenAI-compatible API.""" try: if image_path.startswith("data:image/"): # Handle data URL: data:image/png;base64,iVBORw0... return {"type": "image_url", "image_url": {"url": image_path}} else: # Handle file path - translate for Docker environment from utils.file_utils import translate_path_for_environment translated_path = translate_path_for_environment(image_path) logging.debug(f"Translated image path from '{image_path}' to '{translated_path}'") if not os.path.exists(translated_path): logging.warning(f"Image file not found: {translated_path} (original: {image_path})") return None # Use translated path for all subsequent operations image_path = translated_path # Detect MIME type from file extension using centralized mappings from utils.file_types import get_image_mime_type ext = os.path.splitext(image_path)[1].lower() mime_type = get_image_mime_type(ext) logging.debug(f"Processing image '{image_path}' with extension '{ext}' as MIME type '{mime_type}'") # Read and encode the image with open(image_path, "rb") as f: image_data = base64.b64encode(f.read()).decode() # Create data URL for OpenAI API data_url = f"data:{mime_type};base64,{image_data}" return {"type": "image_url", "image_url": {"url": data_url}} except Exception as e: logging.error(f"Error processing image {image_path}: {e}") return None