refactor: improved retry logic and moved core logic to base class
providers/base.py
@@ -1,8 +1,9 @@
 """Base interfaces and common behaviour for model providers."""
 
 import logging
+import time
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 if TYPE_CHECKING:
     from tools.models import ToolModelCategory
@@ -168,6 +169,107 @@ class ModelProvider(ABC):
         return
 
+    # ------------------------------------------------------------------
+    # Retry helpers
+    # ------------------------------------------------------------------
+    def _is_error_retryable(self, error: Exception) -> bool:
+        """Return True when an error warrants another attempt.
+
+        Subclasses with structured provider errors should override this hook.
+        The default implementation only retries obvious transient failures such
+        as timeouts or 5xx responses detected via string inspection.
+        """
+        error_str = str(error).lower()
+        retryable_indicators = [
+            "timeout",
+            "connection",
+            "temporary",
+            "unavailable",
+            "retry",
+            "reset",
+            "refused",
+            "broken pipe",
+            "tls",
+            "handshake",
+            "network",
+            "rate limit",
+            "429",
+            "500",
+            "502",
+            "503",
+            "504",
+        ]
+        return any(indicator in error_str for indicator in retryable_indicators)
+
+    def _run_with_retries(
+        self,
+        operation: Callable[[], Any],
+        *,
+        max_attempts: int,
+        delays: Optional[list[float]] = None,
+        log_prefix: str = "",
+    ):
+        """Execute ``operation`` with retry semantics.
+
+        Args:
+            operation: Callable returning the provider result.
+            max_attempts: Maximum number of attempts (>=1).
+            delays: Optional list of sleep durations between attempts.
+            log_prefix: Optional identifier for log clarity.
+
+        Returns:
+            Whatever ``operation`` returns.
+
+        Raises:
+            The last exception when all retries fail or the error is not retryable.
+        """
+        if max_attempts < 1:
+            raise ValueError("max_attempts must be >= 1")
+
+        attempts = max_attempts
+        delays = delays or []
+        last_exc: Optional[Exception] = None
+
+        for attempt_index in range(attempts):
+            try:
+                return operation()
+            except Exception as exc:  # noqa: BLE001 - bubble exact provider errors
+                last_exc = exc
+                attempt_number = attempt_index + 1
+
+                # Decide whether to retry based on subclass hook
+                retryable = self._is_error_retryable(exc)
+                if not retryable or attempt_number >= attempts:
+                    raise
+
+                delay_idx = min(attempt_index, len(delays) - 1) if delays else -1
+                delay = delays[delay_idx] if delay_idx >= 0 else 0.0
+
+                if delay > 0:
+                    logger.warning(
+                        "%s retryable error (attempt %s/%s): %s. Retrying in %ss...",
+                        log_prefix or self.__class__.__name__,
+                        attempt_number,
+                        attempts,
+                        exc,
+                        delay,
+                    )
+                    time.sleep(delay)
+                else:
+                    logger.warning(
+                        "%s retryable error (attempt %s/%s): %s. Retrying...",
+                        log_prefix or self.__class__.__name__,
+                        attempt_number,
+                        attempts,
+                        exc,
+                    )
+
+        # Should never reach here because loop either returns or raises
+        raise last_exc if last_exc else RuntimeError("Retry loop exited without result")
+
     # ------------------------------------------------------------------
     # Validation hooks
     # ------------------------------------------------------------------
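The helper pair above is the heart of the refactor: `_is_error_retryable` is the overridable classification hook, and `_run_with_retries` owns the attempt loop, delay schedule, and logging. A minimal sketch of how a subclass is expected to wire the two together (the subclass and `_call_api` below are hypothetical, for illustration only):

    class MyProvider(ModelProvider):  # hypothetical subclass
        def fetch(self):
            def _attempt():
                return self._call_api()  # one real provider call per attempt

            # Retries only while _is_error_retryable(exc) is True; sleeps 1s,
            # then 3s, between the up-to-3 attempts, and re-raises the last error.
            return self._run_with_retries(
                operation=_attempt,
                max_attempts=3,
                delays=[1.0, 3.0],
                log_prefix="MyProvider API",
            )

Note that the delay index is clamped to the last entry of `delays`, so a short list safely covers any `max_attempts`.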
providers/dial.py
@@ -3,7 +3,6 @@
 import logging
 import os
 import threading
-import time
 from typing import Optional
 
 from .openai_compatible import OpenAICompatibleProvider
@@ -405,15 +404,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
         # DIAL-specific: Get cached client for deployment endpoint
         deployment_client = self._get_deployment_client(resolved_model)
 
-        # Retry logic with progressive delays
-        last_exception = None
+        attempt_counter = {"value": 0}
 
-        for attempt in range(self.MAX_RETRIES):
-            try:
-                # Generate completion using deployment-specific client
+        def _attempt() -> ModelResponse:
+            attempt_counter["value"] += 1
             response = deployment_client.chat.completions.create(**completion_params)
 
-                # Extract content and usage
             content = response.choices[0].message.content
             usage = self._extract_usage(response)
 
@@ -431,29 +427,19 @@ class DIALModelProvider(OpenAICompatibleProvider):
                 },
             )
 
-        except Exception as e:
-            last_exception = e
-
-            # Check if this is a retryable error
-            is_retryable = self._is_error_retryable(e)
-
-            if not is_retryable:
-                # Non-retryable error, raise immediately
-                raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")
-
-            # If this isn't the last attempt and error is retryable, wait and retry
-            if attempt < self.MAX_RETRIES - 1:
-                delay = self.RETRY_DELAYS[attempt]
-                logger.info(
-                    f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
+        try:
+            return self._run_with_retries(
+                operation=_attempt,
+                max_attempts=self.MAX_RETRIES,
+                delays=self.RETRY_DELAYS,
+                log_prefix=f"DIAL API ({resolved_model})",
             )
-                time.sleep(delay)
-                continue
-
-        # All retries exhausted
-        raise ValueError(
-            f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
-        )
+        except Exception as exc:
+            attempts = max(attempt_counter["value"], 1)
+            if attempts == 1:
+                raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc
+            raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc
 
     def close(self) -> None:
         """Clean up HTTP clients when provider is closed."""
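Each migrated call site counts attempts with `attempt_counter = {"value": 0}` rather than a plain integer because a nested function cannot rebind an enclosing local by assignment; mutating a shared dict sidesteps that. An equivalent (hypothetical) formulation would use `nonlocal`:

    def example():  # illustration only, not part of the commit
        attempts = 0

        def _attempt():
            nonlocal attempts  # required to rebind the enclosing int
            attempts += 1

        _attempt()
        return attempts  # -> 1

The counter lets the `except` wrapper report how many attempts actually ran, instead of assuming the configured maximum as the old code did.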
providers/gemini.py
@@ -2,7 +2,6 @@
 
 import base64
 import logging
-import time
 from typing import TYPE_CHECKING, Optional
 
 if TYPE_CHECKING:
@@ -229,22 +228,18 @@ class GeminiModelProvider(ModelProvider):
         # Retry logic with progressive delays
         max_retries = 4  # Total of 4 attempts
         retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
+        attempt_counter = {"value": 0}
 
-        last_exception = None
-
-        for attempt in range(max_retries):
-            try:
-                # Generate content
+        def _attempt() -> ModelResponse:
+            attempt_counter["value"] += 1
             response = self.client.models.generate_content(
                 model=resolved_name,
                 contents=contents,
                 config=generation_config,
             )
 
-            # Extract usage information if available
             usage = self._extract_usage(response)
 
-            # Intelligently determine finish reason and safety blocks
             finish_reason_str = "UNKNOWN"
             is_blocked_by_safety = False
             safety_feedback_details = None
@@ -252,11 +247,9 @@ class GeminiModelProvider(ModelProvider):
             if response.candidates:
                 candidate = response.candidates[0]
 
-                # Safely get finish reason
                 try:
                     finish_reason_enum = candidate.finish_reason
                     if finish_reason_enum:
-                        # Handle both enum objects and string values
                        try:
                             finish_reason_str = finish_reason_enum.name
                         except AttributeError:
@@ -266,16 +259,14 @@ class GeminiModelProvider(ModelProvider):
                 except AttributeError:
                     finish_reason_str = "STOP"
 
-                # If content is empty, check safety ratings for the definitive cause
                 if not response.text:
                     try:
                         safety_ratings = candidate.safety_ratings
-                        if safety_ratings:  # Check it's not None or empty
+                        if safety_ratings:
                             for rating in safety_ratings:
                                 try:
                                     if rating.blocked:
                                         is_blocked_by_safety = True
-                                        # Provide details for logging/debugging
                                         category_name = "UNKNOWN"
                                         probability_name = "UNKNOWN"
 
@@ -294,18 +285,14 @@ class GeminiModelProvider(ModelProvider):
                                         )
                                         break
                                 except (AttributeError, TypeError):
-                                    # Individual rating doesn't have expected attributes
                                     continue
                     except (AttributeError, TypeError):
-                        # candidate doesn't have safety_ratings or it's not iterable
                         pass
 
-            # Also check for prompt-level blocking (request rejected entirely)
             elif response.candidates is not None and len(response.candidates) == 0:
-                # No candidates is the primary indicator of a prompt-level block
                 is_blocked_by_safety = True
                 finish_reason_str = "SAFETY"
-                safety_feedback_details = "Prompt blocked, reason unavailable"  # Default message
+                safety_feedback_details = "Prompt blocked, reason unavailable"
 
                 try:
                     prompt_feedback = response.prompt_feedback
@@ -316,7 +303,6 @@ class GeminiModelProvider(ModelProvider):
                         block_reason_name = str(prompt_feedback.block_reason)
                         safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
                 except (AttributeError, TypeError):
-                    # prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
                     pass
 
             return ModelResponse(
@@ -333,29 +319,20 @@ class GeminiModelProvider(ModelProvider):
                 },
             )
 
-        except Exception as e:
-            last_exception = e
-
-            # Check if this is a retryable error using structured error codes
-            is_retryable = self._is_error_retryable(e)
-
-            # If this is the last attempt or not retryable, give up
-            if attempt == max_retries - 1 or not is_retryable:
-                break
-
-            # Get progressive delay
-            delay = retry_delays[attempt]
-
-            # Log retry attempt
-            logger.warning(
-                f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
+        try:
+            return self._run_with_retries(
+                operation=_attempt,
+                max_attempts=max_retries,
+                delays=retry_delays,
+                log_prefix=f"Gemini API ({resolved_name})",
             )
-            time.sleep(delay)
-
-        # If we get here, all retries failed
-        actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
-        error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
-        raise RuntimeError(error_msg) from last_exception
+        except Exception as exc:
+            attempts = max(attempt_counter["value"], 1)
+            error_msg = (
+                f"Gemini API error for model {resolved_name} after {attempts} attempt"
+                f"{'s' if attempts > 1 else ''}: {exc}"
+            )
+            raise RuntimeError(error_msg) from exc
 
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
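The same pattern repeats below for the OpenAI-compatible provider. One behavioural improvement worth noting across all three call sites: the final errors are now raised with `from exc`, so the original provider exception is preserved as `__cause__` and both tracebacks appear in logs. A minimal illustration:

    try:
        raise TimeoutError("upstream timed out")  # stand-in for a provider error
    except Exception as exc:
        raise RuntimeError("API error after 2 attempts") from exc
    # the traceback now shows the TimeoutError as the direct cause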
providers/openai_compatible.py
@@ -4,7 +4,6 @@ import copy
 import ipaddress
 import logging
 import os
-import time
 from typing import Optional
 from urllib.parse import urlparse
 
@@ -395,11 +394,10 @@ class OpenAICompatibleProvider(ModelProvider):
         # Retry logic with progressive delays
         max_retries = 4
         retry_delays = [1, 3, 5, 8]
-        last_exception = None
-        actual_attempts = 0
+        attempt_counter = {"value": 0}
 
-        for attempt in range(max_retries):
-            try:  # Log sanitized payload for debugging
+        def _attempt() -> ModelResponse:
+            attempt_counter["value"] += 1
             import json
 
             sanitized_params = self._sanitize_for_logging(completion_params)
@@ -407,19 +405,14 @@ class OpenAICompatibleProvider(ModelProvider):
                 f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
             )
 
-            # Use OpenAI client's responses endpoint
             response = self.client.responses.create(**completion_params)
 
-            # Extract content from responses endpoint format
-            # Use validation helper to safely extract output_text
             content = self._safe_extract_output_text(response)
 
-            # Try to extract usage information
             usage = None
             if hasattr(response, "usage"):
                 usage = self._extract_usage(response)
             elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
-                # Safely extract token counts with None handling
                 input_tokens = getattr(response, "input_tokens", 0) or 0
                 output_tokens = getattr(response, "output_tokens", 0) or 0
                 usage = {
@@ -442,25 +435,18 @@ class OpenAICompatibleProvider(ModelProvider):
                 },
             )
 
-        except Exception as e:
-            last_exception = e
-
-            # Check if this is a retryable error using structured error codes
-            is_retryable = self._is_error_retryable(e)
-
-            if is_retryable and attempt < max_retries - 1:
-                delay = retry_delays[attempt]
-                logging.warning(
-                    f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
+        try:
+            return self._run_with_retries(
+                operation=_attempt,
+                max_attempts=max_retries,
+                delays=retry_delays,
+                log_prefix="o3-pro responses endpoint",
             )
-                time.sleep(delay)
-            else:
-                break
-
-        # If we get here, all retries failed
-        error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
+        except Exception as exc:
+            attempts = max(attempt_counter["value"], 1)
+            error_msg = f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
             logging.error(error_msg)
-        raise RuntimeError(error_msg) from last_exception
+            raise RuntimeError(error_msg) from exc
 
     def generate_content(
         self,
@@ -587,17 +573,12 @@ class OpenAICompatibleProvider(ModelProvider):
         # Retry logic with progressive delays
         max_retries = 4  # Total of 4 attempts
         retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
+        attempt_counter = {"value": 0}
 
-        last_exception = None
-        actual_attempts = 0
-
-        for attempt in range(max_retries):
-            actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
-            try:
-                # Generate completion
+        def _attempt() -> ModelResponse:
+            attempt_counter["value"] += 1
             response = self.client.chat.completions.create(**completion_params)
 
-            # Extract content and usage
             content = response.choices[0].message.content
             usage = self._extract_usage(response)
 
@@ -609,35 +590,27 @@ class OpenAICompatibleProvider(ModelProvider):
                 provider=self.get_provider_type(),
                 metadata={
                     "finish_reason": response.choices[0].finish_reason,
-                    "model": response.model,  # Actual model used
+                    "model": response.model,
                     "id": response.id,
                     "created": response.created,
                 },
             )
 
-        except Exception as e:
-            last_exception = e
-
-            # Check if this is a retryable error using structured error codes
-            is_retryable = self._is_error_retryable(e)
-
-            # If this is the last attempt or not retryable, give up
-            if attempt == max_retries - 1 or not is_retryable:
-                break
-
-            # Get progressive delay
-            delay = retry_delays[attempt]
-
-            # Log retry attempt
-            logging.warning(
-                f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
+        try:
+            return self._run_with_retries(
+                operation=_attempt,
+                max_attempts=max_retries,
+                delays=retry_delays,
+                log_prefix=f"{self.FRIENDLY_NAME} API ({model_name})",
             )
-            time.sleep(delay)
-
-        # If we get here, all retries failed
-        error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
+        except Exception as exc:
+            attempts = max(attempt_counter["value"], 1)
+            error_msg = (
+                f"{self.FRIENDLY_NAME} API error for model {model_name} after {attempts} attempt"
+                f"{'s' if attempts > 1 else ''}: {exc}"
+            )
             logging.error(error_msg)
-        raise RuntimeError(error_msg) from last_exception
+            raise RuntimeError(error_msg) from exc
 
     def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
         """Validate model parameters.
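The new test file below exercises the shared helper end to end: it stubs the OpenAI client by assigning `provider._client` directly and patches `providers.base.time.sleep` so the retries run instantly.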
tests/test_provider_retry_logic.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+"""Tests covering shared retry behaviour for providers."""
+
+from types import SimpleNamespace
+
+import pytest
+
+from providers.openai_provider import OpenAIModelProvider
+
+
+def _mock_chat_response(content: str = "retry success") -> SimpleNamespace:
+    """Create a minimal chat completion response for tests."""
+
+    usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
+    message = SimpleNamespace(content=content)
+    choice = SimpleNamespace(message=message, finish_reason="stop")
+    return SimpleNamespace(choices=[choice], model="gpt-4.1", id="resp-1", created=123, usage=usage)
+
+
+def test_openai_provider_retries_on_transient_error(monkeypatch):
+    """Provider should retry once for retryable errors and eventually succeed."""
+
+    monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
+
+    provider = OpenAIModelProvider(api_key="test-key")
+
+    attempts = {"count": 0}
+
+    def create_completion(**kwargs):
+        attempts["count"] += 1
+        if attempts["count"] == 1:
+            raise RuntimeError("temporary network interruption")
+        return _mock_chat_response("second attempt response")
+
+    provider._client = SimpleNamespace(
+        chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
+        responses=SimpleNamespace(create=lambda **_: None),
+    )
+
+    result = provider.generate_content("hello", "gpt-4.1")
+
+    assert attempts["count"] == 2, "Expected a retry before succeeding"
+    assert result.content == "second attempt response"
+
+
+def test_openai_provider_bails_on_non_retryable_error(monkeypatch):
+    """Provider should stop immediately when the error is marked non-retryable."""
+
+    monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
+
+    provider = OpenAIModelProvider(api_key="test-key")
+
+    attempts = {"count": 0}
+
+    def create_completion(**kwargs):
+        attempts["count"] += 1
+        raise RuntimeError("context length exceeded 429")
+
+    provider._client = SimpleNamespace(
+        chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
+        responses=SimpleNamespace(create=lambda **_: None),
+    )
+
+    monkeypatch.setattr(
+        OpenAIModelProvider,
+        "_is_error_retryable",
+        lambda self, error: False,
+    )
+
+    with pytest.raises(RuntimeError) as excinfo:
+        provider.generate_content("hello", "gpt-4.1")
+
+    assert "after 1 attempt" in str(excinfo.value)
+    assert attempts["count"] == 1
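Assuming pytest is available in the development environment, the new tests can be run in isolation with:

    pytest tests/test_provider_retry_logic.py -q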