refactor: improved retry logic and moved core logic to base class

This commit is contained in:
Fahad
2025-10-03 23:48:55 +04:00
parent 828c4eed5b
commit f955100f3a
5 changed files with 378 additions and 267 deletions

View File

@@ -1,8 +1,9 @@
"""Base interfaces and common behaviour for model providers."""
import logging
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
@@ -168,6 +169,107 @@ class ModelProvider(ABC):
return
# ------------------------------------------------------------------
# Retry helpers
# ------------------------------------------------------------------
def _is_error_retryable(self, error: Exception) -> bool:
"""Return True when an error warrants another attempt.
Subclasses with structured provider errors should override this hook.
The default implementation only retries obvious transient failures such
as timeouts or 5xx responses detected via string inspection.
"""
error_str = str(error).lower()
retryable_indicators = [
"timeout",
"connection",
"temporary",
"unavailable",
"retry",
"reset",
"refused",
"broken pipe",
"tls",
"handshake",
"network",
"rate limit",
"429",
"500",
"502",
"503",
"504",
]
return any(indicator in error_str for indicator in retryable_indicators)
def _run_with_retries(
self,
operation: Callable[[], Any],
*,
max_attempts: int,
delays: Optional[list[float]] = None,
log_prefix: str = "",
):
"""Execute ``operation`` with retry semantics.
Args:
operation: Callable returning the provider result.
max_attempts: Maximum number of attempts (>=1).
delays: Optional list of sleep durations between attempts.
log_prefix: Optional identifier for log clarity.
Returns:
Whatever ``operation`` returns.
Raises:
The last exception when all retries fail or the error is not retryable.
"""
if max_attempts < 1:
raise ValueError("max_attempts must be >= 1")
attempts = max_attempts
delays = delays or []
last_exc: Optional[Exception] = None
for attempt_index in range(attempts):
try:
return operation()
except Exception as exc: # noqa: BLE001 - bubble exact provider errors
last_exc = exc
attempt_number = attempt_index + 1
# Decide whether to retry based on subclass hook
retryable = self._is_error_retryable(exc)
if not retryable or attempt_number >= attempts:
raise
delay_idx = min(attempt_index, len(delays) - 1) if delays else -1
delay = delays[delay_idx] if delay_idx >= 0 else 0.0
if delay > 0:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying in %ss...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
delay,
)
time.sleep(delay)
else:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
)
# Should never reach here because loop either returns or raises
raise last_exc if last_exc else RuntimeError("Retry loop exited without result")
# ------------------------------------------------------------------
# Validation hooks
# ------------------------------------------------------------------