refactor: improved retry logic and moved core logic to base class

This commit is contained in:
Fahad
2025-10-03 23:48:55 +04:00
parent 828c4eed5b
commit f955100f3a
5 changed files with 378 additions and 267 deletions

View File

@@ -4,7 +4,6 @@ import copy
import ipaddress
import logging
import os
import time
from typing import Optional
from urllib.parse import urlparse
@@ -395,72 +394,59 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays
max_retries = 4
retry_delays = [1, 3, 5, 8]
last_exception = None
actual_attempts = 0
attempt_counter = {"value": 0}
for attempt in range(max_retries):
try: # Log sanitized payload for debugging
import json
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
import json
sanitized_params = self._sanitize_for_logging(completion_params)
logging.info(
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
)
sanitized_params = self._sanitize_for_logging(completion_params)
logging.info(
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
)
# Use OpenAI client's responses endpoint
response = self.client.responses.create(**completion_params)
response = self.client.responses.create(**completion_params)
# Extract content from responses endpoint format
# Use validation helper to safely extract output_text
content = self._safe_extract_output_text(response)
content = self._safe_extract_output_text(response)
# Try to extract usage information
usage = None
if hasattr(response, "usage"):
usage = self._extract_usage(response)
elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
# Safely extract token counts with None handling
input_tokens = getattr(response, "input_tokens", 0) or 0
output_tokens = getattr(response, "output_tokens", 0) or 0
usage = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
usage = None
if hasattr(response, "usage"):
usage = self._extract_usage(response)
elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
input_tokens = getattr(response, "input_tokens", 0) or 0
output_tokens = getattr(response, "output_tokens", 0) or 0
usage = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"model": getattr(response, "model", model_name),
"id": getattr(response, "id", ""),
"created": getattr(response, "created_at", 0),
"endpoint": "responses",
},
)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"model": getattr(response, "model", model_name),
"id": getattr(response, "id", ""),
"created": getattr(response, "created_at", 0),
"endpoint": "responses",
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
if is_retryable and attempt < max_retries - 1:
delay = retry_delays[attempt]
logging.warning(
f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
else:
break
# If we get here, all retries failed
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix="o3-pro responses endpoint",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
logging.error(error_msg)
raise RuntimeError(error_msg) from exc
def generate_content(
self,
@@ -587,57 +573,44 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None
actual_attempts = 0
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = self.client.chat.completions.create(**completion_params)
for attempt in range(max_retries):
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
content = response.choices[0].message.content
usage = self._extract_usage(response)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model,
"id": response.id,
"created": response.created,
},
)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model, # Actual model used
"id": response.id,
"created": response.created,
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logging.warning(
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
# If we get here, all retries failed
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix=f"{self.FRIENDLY_NAME} API ({model_name})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"{self.FRIENDLY_NAME} API error for model {model_name} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
logging.error(error_msg)
raise RuntimeError(error_msg) from exc
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters.