refactor: improved retry logic and moved core logic to base class

This commit is contained in:
Fahad
2025-10-03 23:48:55 +04:00
parent 828c4eed5b
commit f955100f3a
5 changed files with 378 additions and 267 deletions

View File

@@ -1,8 +1,9 @@
"""Base interfaces and common behaviour for model providers."""
import logging
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
@@ -168,6 +169,107 @@ class ModelProvider(ABC):
return
# ------------------------------------------------------------------
# Retry helpers
# ------------------------------------------------------------------
def _is_error_retryable(self, error: Exception) -> bool:
"""Return True when an error warrants another attempt.
Subclasses with structured provider errors should override this hook.
The default implementation only retries obvious transient failures such
as timeouts or 5xx responses detected via string inspection.
"""
error_str = str(error).lower()
retryable_indicators = [
"timeout",
"connection",
"temporary",
"unavailable",
"retry",
"reset",
"refused",
"broken pipe",
"tls",
"handshake",
"network",
"rate limit",
"429",
"500",
"502",
"503",
"504",
]
return any(indicator in error_str for indicator in retryable_indicators)
def _run_with_retries(
self,
operation: Callable[[], Any],
*,
max_attempts: int,
delays: Optional[list[float]] = None,
log_prefix: str = "",
):
"""Execute ``operation`` with retry semantics.
Args:
operation: Callable returning the provider result.
max_attempts: Maximum number of attempts (>=1).
delays: Optional list of sleep durations between attempts.
log_prefix: Optional identifier for log clarity.
Returns:
Whatever ``operation`` returns.
Raises:
The last exception when all retries fail or the error is not retryable.
"""
if max_attempts < 1:
raise ValueError("max_attempts must be >= 1")
attempts = max_attempts
delays = delays or []
last_exc: Optional[Exception] = None
for attempt_index in range(attempts):
try:
return operation()
except Exception as exc: # noqa: BLE001 - bubble exact provider errors
last_exc = exc
attempt_number = attempt_index + 1
# Decide whether to retry based on subclass hook
retryable = self._is_error_retryable(exc)
if not retryable or attempt_number >= attempts:
raise
delay_idx = min(attempt_index, len(delays) - 1) if delays else -1
delay = delays[delay_idx] if delay_idx >= 0 else 0.0
if delay > 0:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying in %ss...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
delay,
)
time.sleep(delay)
else:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
)
# Should never reach here because loop either returns or raises
raise last_exc if last_exc else RuntimeError("Retry loop exited without result")
# ------------------------------------------------------------------
# Validation hooks
# ------------------------------------------------------------------

View File

@@ -3,7 +3,6 @@
import logging
import os
import threading
import time
from typing import Optional
from .openai_compatible import OpenAICompatibleProvider
@@ -405,15 +404,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
# DIAL-specific: Get cached client for deployment endpoint
deployment_client = self._get_deployment_client(resolved_model)
# Retry logic with progressive delays
last_exception = None
attempt_counter = {"value": 0}
for attempt in range(self.MAX_RETRIES):
try:
# Generate completion using deployment-specific client
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = deployment_client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
@@ -431,29 +427,19 @@ class DIALModelProvider(OpenAICompatibleProvider):
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error
is_retryable = self._is_error_retryable(e)
if not is_retryable:
# Non-retryable error, raise immediately
raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")
# If this isn't the last attempt and error is retryable, wait and retry
if attempt < self.MAX_RETRIES - 1:
delay = self.RETRY_DELAYS[attempt]
logger.info(
f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=self.MAX_RETRIES,
delays=self.RETRY_DELAYS,
log_prefix=f"DIAL API ({resolved_model})",
)
time.sleep(delay)
continue
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
if attempts == 1:
raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc
# All retries exhausted
raise ValueError(
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
)
raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc
def close(self) -> None:
"""Clean up HTTP clients when provider is closed."""

View File

@@ -2,7 +2,6 @@
import base64
import logging
import time
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
@@ -229,22 +228,18 @@ class GeminiModelProvider(ModelProvider):
# Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None
for attempt in range(max_retries):
try:
# Generate content
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = self.client.models.generate_content(
model=resolved_name,
contents=contents,
config=generation_config,
)
# Extract usage information if available
usage = self._extract_usage(response)
# Intelligently determine finish reason and safety blocks
finish_reason_str = "UNKNOWN"
is_blocked_by_safety = False
safety_feedback_details = None
@@ -252,11 +247,9 @@ class GeminiModelProvider(ModelProvider):
if response.candidates:
candidate = response.candidates[0]
# Safely get finish reason
try:
finish_reason_enum = candidate.finish_reason
if finish_reason_enum:
# Handle both enum objects and string values
try:
finish_reason_str = finish_reason_enum.name
except AttributeError:
@@ -266,16 +259,14 @@ class GeminiModelProvider(ModelProvider):
except AttributeError:
finish_reason_str = "STOP"
# If content is empty, check safety ratings for the definitive cause
if not response.text:
try:
safety_ratings = candidate.safety_ratings
if safety_ratings: # Check it's not None or empty
if safety_ratings:
for rating in safety_ratings:
try:
if rating.blocked:
is_blocked_by_safety = True
# Provide details for logging/debugging
category_name = "UNKNOWN"
probability_name = "UNKNOWN"
@@ -294,18 +285,14 @@ class GeminiModelProvider(ModelProvider):
)
break
except (AttributeError, TypeError):
# Individual rating doesn't have expected attributes
continue
except (AttributeError, TypeError):
# candidate doesn't have safety_ratings or it's not iterable
pass
# Also check for prompt-level blocking (request rejected entirely)
elif response.candidates is not None and len(response.candidates) == 0:
# No candidates is the primary indicator of a prompt-level block
is_blocked_by_safety = True
finish_reason_str = "SAFETY"
safety_feedback_details = "Prompt blocked, reason unavailable" # Default message
safety_feedback_details = "Prompt blocked, reason unavailable"
try:
prompt_feedback = response.prompt_feedback
@@ -316,7 +303,6 @@ class GeminiModelProvider(ModelProvider):
block_reason_name = str(prompt_feedback.block_reason)
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
except (AttributeError, TypeError):
# prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
pass
return ModelResponse(
@@ -333,29 +319,20 @@ class GeminiModelProvider(ModelProvider):
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logger.warning(
f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix=f"Gemini API ({resolved_name})",
)
time.sleep(delay)
# If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
raise RuntimeError(error_msg) from last_exception
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"Gemini API error for model {resolved_name} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
raise RuntimeError(error_msg) from exc
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""

View File

@@ -4,7 +4,6 @@ import copy
import ipaddress
import logging
import os
import time
from typing import Optional
from urllib.parse import urlparse
@@ -395,11 +394,10 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays
max_retries = 4
retry_delays = [1, 3, 5, 8]
last_exception = None
actual_attempts = 0
attempt_counter = {"value": 0}
for attempt in range(max_retries):
try: # Log sanitized payload for debugging
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
import json
sanitized_params = self._sanitize_for_logging(completion_params)
@@ -407,19 +405,14 @@ class OpenAICompatibleProvider(ModelProvider):
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
)
# Use OpenAI client's responses endpoint
response = self.client.responses.create(**completion_params)
# Extract content from responses endpoint format
# Use validation helper to safely extract output_text
content = self._safe_extract_output_text(response)
# Try to extract usage information
usage = None
if hasattr(response, "usage"):
usage = self._extract_usage(response)
elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
# Safely extract token counts with None handling
input_tokens = getattr(response, "input_tokens", 0) or 0
output_tokens = getattr(response, "output_tokens", 0) or 0
usage = {
@@ -442,25 +435,18 @@ class OpenAICompatibleProvider(ModelProvider):
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
if is_retryable and attempt < max_retries - 1:
delay = retry_delays[attempt]
logging.warning(
f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix="o3-pro responses endpoint",
)
time.sleep(delay)
else:
break
# If we get here, all retries failed
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
raise RuntimeError(error_msg) from exc
def generate_content(
self,
@@ -587,17 +573,12 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None
actual_attempts = 0
for attempt in range(max_retries):
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
try:
# Generate completion
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = self.client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
@@ -609,35 +590,27 @@ class OpenAICompatibleProvider(ModelProvider):
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model, # Actual model used
"model": response.model,
"id": response.id,
"created": response.created,
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logging.warning(
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix=f"{self.FRIENDLY_NAME} API ({model_name})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"{self.FRIENDLY_NAME} API error for model {model_name} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
time.sleep(delay)
# If we get here, all retries failed
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
raise RuntimeError(error_msg) from exc
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters.

View File

@@ -0,0 +1,73 @@
"""Tests covering shared retry behaviour for providers."""
from types import SimpleNamespace
import pytest
from providers.openai_provider import OpenAIModelProvider
def _mock_chat_response(content: str = "retry success") -> SimpleNamespace:
"""Create a minimal chat completion response for tests."""
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
message = SimpleNamespace(content=content)
choice = SimpleNamespace(message=message, finish_reason="stop")
return SimpleNamespace(choices=[choice], model="gpt-4.1", id="resp-1", created=123, usage=usage)
def test_openai_provider_retries_on_transient_error(monkeypatch):
"""Provider should retry once for retryable errors and eventually succeed."""
monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
provider = OpenAIModelProvider(api_key="test-key")
attempts = {"count": 0}
def create_completion(**kwargs):
attempts["count"] += 1
if attempts["count"] == 1:
raise RuntimeError("temporary network interruption")
return _mock_chat_response("second attempt response")
provider._client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
responses=SimpleNamespace(create=lambda **_: None),
)
result = provider.generate_content("hello", "gpt-4.1")
assert attempts["count"] == 2, "Expected a retry before succeeding"
assert result.content == "second attempt response"
def test_openai_provider_bails_on_non_retryable_error(monkeypatch):
"""Provider should stop immediately when the error is marked non-retryable."""
monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
provider = OpenAIModelProvider(api_key="test-key")
attempts = {"count": 0}
def create_completion(**kwargs):
attempts["count"] += 1
raise RuntimeError("context length exceeded 429")
provider._client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
responses=SimpleNamespace(create=lambda **_: None),
)
monkeypatch.setattr(
OpenAIModelProvider,
"_is_error_retryable",
lambda self, error: False,
)
with pytest.raises(RuntimeError) as excinfo:
provider.generate_content("hello", "gpt-4.1")
assert "after 1 attempt" in str(excinfo.value)
assert attempts["count"] == 1