refactor: improved retry logic and moved core logic to base class
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
"""Base interfaces and common behaviour for model providers."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tools.models import ToolModelCategory
|
||||
@@ -168,6 +169,107 @@ class ModelProvider(ABC):
|
||||
|
||||
return
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Retry helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _is_error_retryable(self, error: Exception) -> bool:
|
||||
"""Return True when an error warrants another attempt.
|
||||
|
||||
Subclasses with structured provider errors should override this hook.
|
||||
The default implementation only retries obvious transient failures such
|
||||
as timeouts or 5xx responses detected via string inspection.
|
||||
"""
|
||||
|
||||
error_str = str(error).lower()
|
||||
retryable_indicators = [
|
||||
"timeout",
|
||||
"connection",
|
||||
"temporary",
|
||||
"unavailable",
|
||||
"retry",
|
||||
"reset",
|
||||
"refused",
|
||||
"broken pipe",
|
||||
"tls",
|
||||
"handshake",
|
||||
"network",
|
||||
"rate limit",
|
||||
"429",
|
||||
"500",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
]
|
||||
|
||||
return any(indicator in error_str for indicator in retryable_indicators)
|
||||
|
||||
def _run_with_retries(
|
||||
self,
|
||||
operation: Callable[[], Any],
|
||||
*,
|
||||
max_attempts: int,
|
||||
delays: Optional[list[float]] = None,
|
||||
log_prefix: str = "",
|
||||
):
|
||||
"""Execute ``operation`` with retry semantics.
|
||||
|
||||
Args:
|
||||
operation: Callable returning the provider result.
|
||||
max_attempts: Maximum number of attempts (>=1).
|
||||
delays: Optional list of sleep durations between attempts.
|
||||
log_prefix: Optional identifier for log clarity.
|
||||
|
||||
Returns:
|
||||
Whatever ``operation`` returns.
|
||||
|
||||
Raises:
|
||||
The last exception when all retries fail or the error is not retryable.
|
||||
"""
|
||||
|
||||
if max_attempts < 1:
|
||||
raise ValueError("max_attempts must be >= 1")
|
||||
|
||||
attempts = max_attempts
|
||||
delays = delays or []
|
||||
last_exc: Optional[Exception] = None
|
||||
|
||||
for attempt_index in range(attempts):
|
||||
try:
|
||||
return operation()
|
||||
except Exception as exc: # noqa: BLE001 - bubble exact provider errors
|
||||
last_exc = exc
|
||||
attempt_number = attempt_index + 1
|
||||
|
||||
# Decide whether to retry based on subclass hook
|
||||
retryable = self._is_error_retryable(exc)
|
||||
if not retryable or attempt_number >= attempts:
|
||||
raise
|
||||
|
||||
delay_idx = min(attempt_index, len(delays) - 1) if delays else -1
|
||||
delay = delays[delay_idx] if delay_idx >= 0 else 0.0
|
||||
|
||||
if delay > 0:
|
||||
logger.warning(
|
||||
"%s retryable error (attempt %s/%s): %s. Retrying in %ss...",
|
||||
log_prefix or self.__class__.__name__,
|
||||
attempt_number,
|
||||
attempts,
|
||||
exc,
|
||||
delay,
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s retryable error (attempt %s/%s): %s. Retrying...",
|
||||
log_prefix or self.__class__.__name__,
|
||||
attempt_number,
|
||||
attempts,
|
||||
exc,
|
||||
)
|
||||
|
||||
# Should never reach here because loop either returns or raises
|
||||
raise last_exc if last_exc else RuntimeError("Retry loop exited without result")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation hooks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user