refactor: improved retry logic and moved core logic to base class

This commit is contained in:
Fahad
2025-10-03 23:48:55 +04:00
parent 828c4eed5b
commit f955100f3a
5 changed files with 378 additions and 267 deletions

View File

@@ -1,8 +1,9 @@
"""Base interfaces and common behaviour for model providers.""" """Base interfaces and common behaviour for model providers."""
import logging import logging
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from tools.models import ToolModelCategory from tools.models import ToolModelCategory
@@ -168,6 +169,107 @@ class ModelProvider(ABC):
return return
# ------------------------------------------------------------------
# Retry helpers
# ------------------------------------------------------------------
def _is_error_retryable(self, error: Exception) -> bool:
"""Return True when an error warrants another attempt.
Subclasses with structured provider errors should override this hook.
The default implementation only retries obvious transient failures such
as timeouts or 5xx responses detected via string inspection.
"""
error_str = str(error).lower()
retryable_indicators = [
"timeout",
"connection",
"temporary",
"unavailable",
"retry",
"reset",
"refused",
"broken pipe",
"tls",
"handshake",
"network",
"rate limit",
"429",
"500",
"502",
"503",
"504",
]
return any(indicator in error_str for indicator in retryable_indicators)
def _run_with_retries(
self,
operation: Callable[[], Any],
*,
max_attempts: int,
delays: Optional[list[float]] = None,
log_prefix: str = "",
):
"""Execute ``operation`` with retry semantics.
Args:
operation: Callable returning the provider result.
max_attempts: Maximum number of attempts (>=1).
delays: Optional list of sleep durations between attempts.
log_prefix: Optional identifier for log clarity.
Returns:
Whatever ``operation`` returns.
Raises:
The last exception when all retries fail or the error is not retryable.
"""
if max_attempts < 1:
raise ValueError("max_attempts must be >= 1")
attempts = max_attempts
delays = delays or []
last_exc: Optional[Exception] = None
for attempt_index in range(attempts):
try:
return operation()
except Exception as exc: # noqa: BLE001 - bubble exact provider errors
last_exc = exc
attempt_number = attempt_index + 1
# Decide whether to retry based on subclass hook
retryable = self._is_error_retryable(exc)
if not retryable or attempt_number >= attempts:
raise
delay_idx = min(attempt_index, len(delays) - 1) if delays else -1
delay = delays[delay_idx] if delay_idx >= 0 else 0.0
if delay > 0:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying in %ss...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
delay,
)
time.sleep(delay)
else:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
)
# Should never reach here because loop either returns or raises
raise last_exc if last_exc else RuntimeError("Retry loop exited without result")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Validation hooks # Validation hooks
# ------------------------------------------------------------------ # ------------------------------------------------------------------

View File

@@ -3,7 +3,6 @@
import logging import logging
import os import os
import threading import threading
import time
from typing import Optional from typing import Optional
from .openai_compatible import OpenAICompatibleProvider from .openai_compatible import OpenAICompatibleProvider
@@ -405,15 +404,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
# DIAL-specific: Get cached client for deployment endpoint # DIAL-specific: Get cached client for deployment endpoint
deployment_client = self._get_deployment_client(resolved_model) deployment_client = self._get_deployment_client(resolved_model)
# Retry logic with progressive delays attempt_counter = {"value": 0}
last_exception = None
for attempt in range(self.MAX_RETRIES): def _attempt() -> ModelResponse:
try: attempt_counter["value"] += 1
# Generate completion using deployment-specific client
response = deployment_client.chat.completions.create(**completion_params) response = deployment_client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content content = response.choices[0].message.content
usage = self._extract_usage(response) usage = self._extract_usage(response)
@@ -431,29 +427,19 @@ class DIALModelProvider(OpenAICompatibleProvider):
}, },
) )
except Exception as e: try:
last_exception = e return self._run_with_retries(
operation=_attempt,
# Check if this is a retryable error max_attempts=self.MAX_RETRIES,
is_retryable = self._is_error_retryable(e) delays=self.RETRY_DELAYS,
log_prefix=f"DIAL API ({resolved_model})",
if not is_retryable:
# Non-retryable error, raise immediately
raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")
# If this isn't the last attempt and error is retryable, wait and retry
if attempt < self.MAX_RETRIES - 1:
delay = self.RETRY_DELAYS[attempt]
logger.info(
f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
) )
time.sleep(delay) except Exception as exc:
continue attempts = max(attempt_counter["value"], 1)
if attempts == 1:
raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc
# All retries exhausted raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc
raise ValueError(
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
)
def close(self) -> None: def close(self) -> None:
"""Clean up HTTP clients when provider is closed.""" """Clean up HTTP clients when provider is closed."""

View File

@@ -2,7 +2,6 @@
import base64 import base64
import logging import logging
import time
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -229,22 +228,18 @@ class GeminiModelProvider(ModelProvider):
# Retry logic with progressive delays # Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
for attempt in range(max_retries):
try:
# Generate content
response = self.client.models.generate_content( response = self.client.models.generate_content(
model=resolved_name, model=resolved_name,
contents=contents, contents=contents,
config=generation_config, config=generation_config,
) )
# Extract usage information if available
usage = self._extract_usage(response) usage = self._extract_usage(response)
# Intelligently determine finish reason and safety blocks
finish_reason_str = "UNKNOWN" finish_reason_str = "UNKNOWN"
is_blocked_by_safety = False is_blocked_by_safety = False
safety_feedback_details = None safety_feedback_details = None
@@ -252,11 +247,9 @@ class GeminiModelProvider(ModelProvider):
if response.candidates: if response.candidates:
candidate = response.candidates[0] candidate = response.candidates[0]
# Safely get finish reason
try: try:
finish_reason_enum = candidate.finish_reason finish_reason_enum = candidate.finish_reason
if finish_reason_enum: if finish_reason_enum:
# Handle both enum objects and string values
try: try:
finish_reason_str = finish_reason_enum.name finish_reason_str = finish_reason_enum.name
except AttributeError: except AttributeError:
@@ -266,16 +259,14 @@ class GeminiModelProvider(ModelProvider):
except AttributeError: except AttributeError:
finish_reason_str = "STOP" finish_reason_str = "STOP"
# If content is empty, check safety ratings for the definitive cause
if not response.text: if not response.text:
try: try:
safety_ratings = candidate.safety_ratings safety_ratings = candidate.safety_ratings
if safety_ratings: # Check it's not None or empty if safety_ratings:
for rating in safety_ratings: for rating in safety_ratings:
try: try:
if rating.blocked: if rating.blocked:
is_blocked_by_safety = True is_blocked_by_safety = True
# Provide details for logging/debugging
category_name = "UNKNOWN" category_name = "UNKNOWN"
probability_name = "UNKNOWN" probability_name = "UNKNOWN"
@@ -294,18 +285,14 @@ class GeminiModelProvider(ModelProvider):
) )
break break
except (AttributeError, TypeError): except (AttributeError, TypeError):
# Individual rating doesn't have expected attributes
continue continue
except (AttributeError, TypeError): except (AttributeError, TypeError):
# candidate doesn't have safety_ratings or it's not iterable
pass pass
# Also check for prompt-level blocking (request rejected entirely)
elif response.candidates is not None and len(response.candidates) == 0: elif response.candidates is not None and len(response.candidates) == 0:
# No candidates is the primary indicator of a prompt-level block
is_blocked_by_safety = True is_blocked_by_safety = True
finish_reason_str = "SAFETY" finish_reason_str = "SAFETY"
safety_feedback_details = "Prompt blocked, reason unavailable" # Default message safety_feedback_details = "Prompt blocked, reason unavailable"
try: try:
prompt_feedback = response.prompt_feedback prompt_feedback = response.prompt_feedback
@@ -316,7 +303,6 @@ class GeminiModelProvider(ModelProvider):
block_reason_name = str(prompt_feedback.block_reason) block_reason_name = str(prompt_feedback.block_reason)
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}" safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
except (AttributeError, TypeError): except (AttributeError, TypeError):
# prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
pass pass
return ModelResponse( return ModelResponse(
@@ -333,29 +319,20 @@ class GeminiModelProvider(ModelProvider):
}, },
) )
except Exception as e: try:
last_exception = e return self._run_with_retries(
operation=_attempt,
# Check if this is a retryable error using structured error codes max_attempts=max_retries,
is_retryable = self._is_error_retryable(e) delays=retry_delays,
log_prefix=f"Gemini API ({resolved_name})",
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logger.warning(
f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
) )
time.sleep(delay) except Exception as exc:
attempts = max(attempt_counter["value"], 1)
# If we get here, all retries failed error_msg = (
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count f"Gemini API error for model {resolved_name} after {attempts} attempt"
error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}" f"{'s' if attempts > 1 else ''}: {exc}"
raise RuntimeError(error_msg) from last_exception )
raise RuntimeError(error_msg) from exc
def get_provider_type(self) -> ProviderType: def get_provider_type(self) -> ProviderType:
"""Get the provider type.""" """Get the provider type."""

View File

@@ -4,7 +4,6 @@ import copy
import ipaddress import ipaddress
import logging import logging
import os import os
import time
from typing import Optional from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -395,11 +394,10 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays # Retry logic with progressive delays
max_retries = 4 max_retries = 4
retry_delays = [1, 3, 5, 8] retry_delays = [1, 3, 5, 8]
last_exception = None attempt_counter = {"value": 0}
actual_attempts = 0
for attempt in range(max_retries): def _attempt() -> ModelResponse:
try: # Log sanitized payload for debugging attempt_counter["value"] += 1
import json import json
sanitized_params = self._sanitize_for_logging(completion_params) sanitized_params = self._sanitize_for_logging(completion_params)
@@ -407,19 +405,14 @@ class OpenAICompatibleProvider(ModelProvider):
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}" f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
) )
# Use OpenAI client's responses endpoint
response = self.client.responses.create(**completion_params) response = self.client.responses.create(**completion_params)
# Extract content from responses endpoint format
# Use validation helper to safely extract output_text
content = self._safe_extract_output_text(response) content = self._safe_extract_output_text(response)
# Try to extract usage information
usage = None usage = None
if hasattr(response, "usage"): if hasattr(response, "usage"):
usage = self._extract_usage(response) usage = self._extract_usage(response)
elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"): elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
# Safely extract token counts with None handling
input_tokens = getattr(response, "input_tokens", 0) or 0 input_tokens = getattr(response, "input_tokens", 0) or 0
output_tokens = getattr(response, "output_tokens", 0) or 0 output_tokens = getattr(response, "output_tokens", 0) or 0
usage = { usage = {
@@ -442,25 +435,18 @@ class OpenAICompatibleProvider(ModelProvider):
}, },
) )
except Exception as e: try:
last_exception = e return self._run_with_retries(
operation=_attempt,
# Check if this is a retryable error using structured error codes max_attempts=max_retries,
is_retryable = self._is_error_retryable(e) delays=retry_delays,
log_prefix="o3-pro responses endpoint",
if is_retryable and attempt < max_retries - 1:
delay = retry_delays[attempt]
logging.warning(
f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
) )
time.sleep(delay) except Exception as exc:
else: attempts = max(attempt_counter["value"], 1)
break error_msg = f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
# If we get here, all retries failed
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg) logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception raise RuntimeError(error_msg) from exc
def generate_content( def generate_content(
self, self,
@@ -587,17 +573,12 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays # Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None def _attempt() -> ModelResponse:
actual_attempts = 0 attempt_counter["value"] += 1
for attempt in range(max_retries):
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params) response = self.client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content content = response.choices[0].message.content
usage = self._extract_usage(response) usage = self._extract_usage(response)
@@ -609,35 +590,27 @@ class OpenAICompatibleProvider(ModelProvider):
provider=self.get_provider_type(), provider=self.get_provider_type(),
metadata={ metadata={
"finish_reason": response.choices[0].finish_reason, "finish_reason": response.choices[0].finish_reason,
"model": response.model, # Actual model used "model": response.model,
"id": response.id, "id": response.id,
"created": response.created, "created": response.created,
}, },
) )
except Exception as e: try:
last_exception = e return self._run_with_retries(
operation=_attempt,
# Check if this is a retryable error using structured error codes max_attempts=max_retries,
is_retryable = self._is_error_retryable(e) delays=retry_delays,
log_prefix=f"{self.FRIENDLY_NAME} API ({model_name})",
# If this is the last attempt or not retryable, give up )
if attempt == max_retries - 1 or not is_retryable: except Exception as exc:
break attempts = max(attempt_counter["value"], 1)
error_msg = (
# Get progressive delay f"{self.FRIENDLY_NAME} API error for model {model_name} after {attempts} attempt"
delay = retry_delays[attempt] f"{'s' if attempts > 1 else ''}: {exc}"
# Log retry attempt
logging.warning(
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
) )
time.sleep(delay)
# If we get here, all retries failed
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg) logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception raise RuntimeError(error_msg) from exc
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None: def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters. """Validate model parameters.

View File

@@ -0,0 +1,73 @@
"""Tests covering shared retry behaviour for providers."""
from types import SimpleNamespace
import pytest
from providers.openai_provider import OpenAIModelProvider
def _mock_chat_response(content: str = "retry success") -> SimpleNamespace:
"""Create a minimal chat completion response for tests."""
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
message = SimpleNamespace(content=content)
choice = SimpleNamespace(message=message, finish_reason="stop")
return SimpleNamespace(choices=[choice], model="gpt-4.1", id="resp-1", created=123, usage=usage)
def test_openai_provider_retries_on_transient_error(monkeypatch):
"""Provider should retry once for retryable errors and eventually succeed."""
monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
provider = OpenAIModelProvider(api_key="test-key")
attempts = {"count": 0}
def create_completion(**kwargs):
attempts["count"] += 1
if attempts["count"] == 1:
raise RuntimeError("temporary network interruption")
return _mock_chat_response("second attempt response")
provider._client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
responses=SimpleNamespace(create=lambda **_: None),
)
result = provider.generate_content("hello", "gpt-4.1")
assert attempts["count"] == 2, "Expected a retry before succeeding"
assert result.content == "second attempt response"
def test_openai_provider_bails_on_non_retryable_error(monkeypatch):
"""Provider should stop immediately when the error is marked non-retryable."""
monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
provider = OpenAIModelProvider(api_key="test-key")
attempts = {"count": 0}
def create_completion(**kwargs):
attempts["count"] += 1
raise RuntimeError("context length exceeded 429")
provider._client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
responses=SimpleNamespace(create=lambda **_: None),
)
monkeypatch.setattr(
OpenAIModelProvider,
"_is_error_retryable",
lambda self, error: False,
)
with pytest.raises(RuntimeError) as excinfo:
provider.generate_content("hello", "gpt-4.1")
assert "after 1 attempt" in str(excinfo.value)
assert attempts["count"] == 1