Merge branch 'main' into refactor-image-validation
@@ -1,5 +1,7 @@
 """Base class for OpenAI-compatible API providers."""
 
 import base64
+import copy
+import ipaddress
 import logging
 import os
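
Note: the new `ipaddress` import isn't exercised in the hunks shown here; given the branch name, it presumably backs image-URL validation elsewhere in the change. For context, a sketch of the kind of check the module enables (the function and policy are illustrative, not from this commit):

```python
import ipaddress

def is_private_host(host: str) -> bool:
    # True for loopback/private/link-local IP literals. Hostnames that need
    # DNS resolution first are out of scope for this sketch.
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        return False
    return ip.is_private or ip.is_loopback or ip.is_link_local

print(is_private_host("127.0.0.1"))  # True
print(is_private_host("8.8.8.8"))    # False
```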
@@ -219,10 +221,20 @@ class OpenAICompatibleProvider(ModelProvider):
 
-        # Create httpx client with minimal config to avoid proxy conflicts
-        # Note: proxies parameter was removed in httpx 0.28.0
-        http_client = httpx.Client(
-            timeout=timeout_config,
-            follow_redirects=True,
-        )
+        # Check for test transport injection
+        if hasattr(self, "_test_transport"):
+            # Use custom transport for testing (HTTP recording/replay)
+            http_client = httpx.Client(
+                transport=self._test_transport,
+                timeout=timeout_config,
+                follow_redirects=True,
+            )
+        else:
+            # Normal production client
+            http_client = httpx.Client(
+                timeout=timeout_config,
+                follow_redirects=True,
+            )
 
         # Keep client initialization minimal to avoid proxy parameter conflicts
         client_kwargs = {
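
The `_test_transport` hook added above pairs naturally with httpx's built-in `MockTransport`. A minimal sketch of how a test might drive it (the handler body and URL are illustrative, not from this repo):

```python
import httpx

# Stub handler: every request gets a canned JSON body, with no network I/O.
def handler(request: httpx.Request) -> httpx.Response:
    return httpx.Response(200, json={"output_text": "stubbed reply"})

# A test would set `provider._test_transport = httpx.MockTransport(handler)`
# before the client property is first accessed; the resulting client behaves
# like this:
client = httpx.Client(
    transport=httpx.MockTransport(handler),
    timeout=30.0,
    follow_redirects=True,
)
print(client.get("https://api.example.com/v1/responses").json())
```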
@@ -263,6 +275,63 @@ class OpenAICompatibleProvider(ModelProvider):
 
         return self._client
 
+    def _sanitize_for_logging(self, params: dict) -> dict:
+        """Sanitize sensitive data from parameters before logging.
+
+        Args:
+            params: Dictionary of API parameters
+
+        Returns:
+            dict: Sanitized copy of parameters safe for logging
+        """
+        sanitized = copy.deepcopy(params)
+
+        # Sanitize message content
+        if "input" in sanitized:
+            for msg in sanitized.get("input", []):
+                if isinstance(msg, dict) and "content" in msg:
+                    for content_item in msg.get("content", []):
+                        if isinstance(content_item, dict) and "text" in content_item:
+                            # Truncate long text and add ellipsis
+                            text = content_item["text"]
+                            if len(text) > 100:
+                                content_item["text"] = text[:100] + "... [truncated]"
+
+        # Remove any API keys that might be in headers/auth
+        sanitized.pop("api_key", None)
+        sanitized.pop("authorization", None)
+
+        return sanitized
+
+    def _safe_extract_output_text(self, response) -> str:
+        """Safely extract output_text from an o3-pro response, with validation.
+
+        Args:
+            response: Response object from the OpenAI SDK
+
+        Returns:
+            str: The output text content
+
+        Raises:
+            ValueError: If output_text is missing, None, or not a string
+        """
+        logging.debug(f"Response object type: {type(response)}")
+        logging.debug(f"Response attributes: {dir(response)}")
+
+        if not hasattr(response, "output_text"):
+            raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
+
+        content = response.output_text
+        logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")
+
+        if content is None:
+            raise ValueError("o3-pro returned None for output_text")
+
+        if not isinstance(content, str):
+            raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")
+
+        return content
 
     def _generate_with_responses_endpoint(
         self,
         model_name: str,
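
For illustration, a standalone rendition of `_sanitize_for_logging`'s truncate-and-redact behavior (the sample parameters are made up):

```python
import copy

def sanitize_for_logging(params: dict) -> dict:
    # Same logic as the method above, without the class around it.
    sanitized = copy.deepcopy(params)
    for msg in sanitized.get("input", []):
        if isinstance(msg, dict):
            for item in msg.get("content", []):
                if isinstance(item, dict) and "text" in item and len(item["text"]) > 100:
                    item["text"] = item["text"][:100] + "... [truncated]"
    sanitized.pop("api_key", None)
    sanitized.pop("authorization", None)
    return sanitized

params = {
    "input": [{"role": "user", "content": [{"type": "input_text", "text": "x" * 500}]}],
    "api_key": "sk-secret",
}
safe = sanitize_for_logging(params)
print(safe["input"][0]["content"][0]["text"][-15:])  # ... [truncated]
print("api_key" in safe)                             # False
```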
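And the fail-fast contract of `_safe_extract_output_text`, sketched against stand-in response objects (contrast with the silent `content = ""` fallback removed in the next hunk):

```python
from types import SimpleNamespace

def safe_extract_output_text(response) -> str:
    # Same checks as the method above, minus the debug logging.
    if not hasattr(response, "output_text"):
        raise ValueError("response missing output_text field")
    content = response.output_text
    if content is None:
        raise ValueError("output_text is None")
    if not isinstance(content, str):
        raise ValueError(f"output_text is not a string: {type(content).__name__}")
    return content

print(safe_extract_output_text(SimpleNamespace(output_text="hello")))  # hello
for bad in (SimpleNamespace(), SimpleNamespace(output_text=None)):
    try:
        safe_extract_output_text(bad)
    except ValueError as exc:
        print(f"rejected: {exc}")
```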
@@ -308,30 +377,24 @@ class OpenAICompatibleProvider(ModelProvider):
         max_retries = 4
         retry_delays = [1, 3, 5, 8]
         last_exception = None
+        actual_attempts = 0
 
         for attempt in range(max_retries):
-            try:  # Log the exact payload being sent for debugging
+            actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
+            try:  # Log sanitized payload for debugging
                 import json
 
+                sanitized_params = self._sanitize_for_logging(completion_params)
                 logging.info(
-                    f"o3-pro API request payload: {json.dumps(completion_params, indent=2, ensure_ascii=False)}"
+                    f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
                 )
 
                 # Use OpenAI client's responses endpoint
                 response = self.client.responses.create(**completion_params)
 
-                # Extract content and usage from responses endpoint format
-                # The response format is different for responses endpoint
-                content = ""
-                if hasattr(response, "output") and response.output:
-                    if hasattr(response.output, "content") and response.output.content:
-                        # Look for output_text in content
-                        for content_item in response.output.content:
-                            if hasattr(content_item, "type") and content_item.type == "output_text":
-                                content = content_item.text
-                                break
-                    elif hasattr(response.output, "text"):
-                        content = response.output.text
+                # Extract content from responses endpoint format
+                # Use validation helper to safely extract output_text
+                content = self._safe_extract_output_text(response)
 
                 # Try to extract usage information
                 usage = None
@@ -370,14 +433,13 @@ class OpenAICompatibleProvider(ModelProvider):
                 if is_retryable and attempt < max_retries - 1:
                     delay = retry_delays[attempt]
                     logging.warning(
-                        f"Retryable error for o3-pro responses endpoint, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
+                        f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
                     )
                     time.sleep(delay)
                 else:
                     break
 
         # If we get here, all retries failed
-        actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
         error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
         logging.error(error_msg)
         raise RuntimeError(error_msg) from last_exception
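
Both retry loops now share one counter pattern: seed `actual_attempts = 0` before the loop and set it to `attempt + 1` at the top of each iteration, so the warning and the final error report a 1-based count even when the loop exits early. A minimal standalone sketch (the `operation` callable is a stand-in for either endpoint call):

```python
import logging
import time

def call_with_retries(operation, max_retries: int = 4, retry_delays=(1, 3, 5, 8)):
    last_exception = None
    actual_attempts = 0
    for attempt in range(max_retries):
        actual_attempts = attempt + 1  # human-readable 1-based count
        try:
            return operation()
        except Exception as e:
            # Retryability classification from the real code is omitted here.
            last_exception = e
            if attempt < max_retries - 1:
                delay = retry_delays[attempt]
                logging.warning(f"attempt {actual_attempts}/{max_retries} failed: {e}. Retrying in {delay}s...")
                time.sleep(delay)
    raise RuntimeError(
        f"failed after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}"
    ) from last_exception
```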
@@ -387,7 +449,7 @@ class OpenAICompatibleProvider(ModelProvider):
         prompt: str,
         model_name: str,
         system_prompt: Optional[str] = None,
-        temperature: float = 0.7,
+        temperature: float = 0.3,
         max_output_tokens: Optional[int] = None,
         images: Optional[list[str]] = None,
         **kwargs,
@@ -480,7 +542,7 @@ class OpenAICompatibleProvider(ModelProvider):
                 completion_params[key] = value
 
         # Check if this is o3-pro and needs the responses endpoint
-        if resolved_model == "o3-pro-2025-06-10":
+        if resolved_model == "o3-pro":
            # This model requires the /v1/responses endpoint
            # If it fails, we should not fall back to chat/completions
            return self._generate_with_responses_endpoint(
@@ -496,8 +558,10 @@ class OpenAICompatibleProvider(ModelProvider):
         retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
 
         last_exception = None
+        actual_attempts = 0
 
         for attempt in range(max_retries):
+            actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
             try:
                 # Generate completion
                 response = self.client.chat.completions.create(**completion_params)
@@ -535,12 +599,11 @@ class OpenAICompatibleProvider(ModelProvider):
 
                 # Log retry attempt
                 logging.warning(
-                    f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
+                    f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
                 )
                 time.sleep(delay)
 
         # If we get here, all retries failed
-        actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
         error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
         logging.error(error_msg)
         raise RuntimeError(error_msg) from last_exception
@@ -575,11 +638,7 @@ class OpenAICompatibleProvider(ModelProvider):
         try:
             encoding = tiktoken.encoding_for_model(model_name)
         except KeyError:
-            # Try common encodings based on model patterns
-            if "gpt-4" in model_name or "gpt-3.5" in model_name:
-                encoding = tiktoken.get_encoding("cl100k_base")
-            else:
-                encoding = tiktoken.get_encoding("cl100k_base")  # Default
+            encoding = tiktoken.get_encoding("cl100k_base")
 
         return len(encoding.encode(text))
 
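On the tiktoken cleanup: both removed branches resolved to the same `cl100k_base` encoding, so the fallback collapses to a single line. `tiktoken.encoding_for_model` raises `KeyError` for names it doesn't know; a sketch of the resulting behavior:

```python
import tiktoken

def count_tokens(text: str, model_name: str) -> int:
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Unknown model name: fall back to a widely used default encoding
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

print(count_tokens("hello world", "gpt-4o"))             # known model mapping
print(count_tokens("hello world", "some-custom-model"))  # falls back to cl100k_base
```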
@@ -678,11 +737,13 @@ class OpenAICompatibleProvider(ModelProvider):
         """
         # Common vision-capable models - only include models that actually support images
         vision_models = {
+            "gpt-5",
+            "gpt-5-mini",
             "gpt-4o",
             "gpt-4o-mini",
             "gpt-4-turbo",
             "gpt-4-vision-preview",
-            "gpt-4.1-2025-04-14",  # GPT-4.1 supports vision
+            "gpt-4.1-2025-04-14",
             "o3",
             "o3-mini",
             "o3-pro",