refactor: improved retry logic and moved core logic to base class

This commit is contained in:
Fahad
2025-10-03 23:48:55 +04:00
parent 828c4eed5b
commit f955100f3a
5 changed files with 378 additions and 267 deletions

View File

@@ -1,8 +1,9 @@
"""Base interfaces and common behaviour for model providers.""" """Base interfaces and common behaviour for model providers."""
import logging import logging
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from tools.models import ToolModelCategory from tools.models import ToolModelCategory
@@ -168,6 +169,107 @@ class ModelProvider(ABC):
return return
# ------------------------------------------------------------------
# Retry helpers
# ------------------------------------------------------------------
def _is_error_retryable(self, error: Exception) -> bool:
"""Return True when an error warrants another attempt.
Subclasses with structured provider errors should override this hook.
The default implementation only retries obvious transient failures such
as timeouts or 5xx responses detected via string inspection.
"""
error_str = str(error).lower()
retryable_indicators = [
"timeout",
"connection",
"temporary",
"unavailable",
"retry",
"reset",
"refused",
"broken pipe",
"tls",
"handshake",
"network",
"rate limit",
"429",
"500",
"502",
"503",
"504",
]
return any(indicator in error_str for indicator in retryable_indicators)
def _run_with_retries(
self,
operation: Callable[[], Any],
*,
max_attempts: int,
delays: Optional[list[float]] = None,
log_prefix: str = "",
):
"""Execute ``operation`` with retry semantics.
Args:
operation: Callable returning the provider result.
max_attempts: Maximum number of attempts (>=1).
delays: Optional list of sleep durations between attempts.
log_prefix: Optional identifier for log clarity.
Returns:
Whatever ``operation`` returns.
Raises:
The last exception when all retries fail or the error is not retryable.
"""
if max_attempts < 1:
raise ValueError("max_attempts must be >= 1")
attempts = max_attempts
delays = delays or []
last_exc: Optional[Exception] = None
for attempt_index in range(attempts):
try:
return operation()
except Exception as exc: # noqa: BLE001 - bubble exact provider errors
last_exc = exc
attempt_number = attempt_index + 1
# Decide whether to retry based on subclass hook
retryable = self._is_error_retryable(exc)
if not retryable or attempt_number >= attempts:
raise
delay_idx = min(attempt_index, len(delays) - 1) if delays else -1
delay = delays[delay_idx] if delay_idx >= 0 else 0.0
if delay > 0:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying in %ss...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
delay,
)
time.sleep(delay)
else:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
)
# Should never reach here because loop either returns or raises
raise last_exc if last_exc else RuntimeError("Retry loop exited without result")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Validation hooks # Validation hooks
# ------------------------------------------------------------------ # ------------------------------------------------------------------

View File

@@ -3,7 +3,6 @@
import logging import logging
import os import os
import threading import threading
import time
from typing import Optional from typing import Optional
from .openai_compatible import OpenAICompatibleProvider from .openai_compatible import OpenAICompatibleProvider
@@ -405,55 +404,42 @@ class DIALModelProvider(OpenAICompatibleProvider):
# DIAL-specific: Get cached client for deployment endpoint # DIAL-specific: Get cached client for deployment endpoint
deployment_client = self._get_deployment_client(resolved_model) deployment_client = self._get_deployment_client(resolved_model)
# Retry logic with progressive delays attempt_counter = {"value": 0}
last_exception = None
for attempt in range(self.MAX_RETRIES): def _attempt() -> ModelResponse:
try: attempt_counter["value"] += 1
# Generate completion using deployment-specific client response = deployment_client.chat.completions.create(**completion_params)
response = deployment_client.chat.completions.create(**completion_params)
# Extract content and usage content = response.choices[0].message.content
content = response.choices[0].message.content usage = self._extract_usage(response)
usage = self._extract_usage(response)
return ModelResponse( return ModelResponse(
content=content, content=content,
usage=usage, usage=usage,
model_name=model_name, model_name=model_name,
friendly_name=self.FRIENDLY_NAME, friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(), provider=self.get_provider_type(),
metadata={ metadata={
"finish_reason": response.choices[0].finish_reason, "finish_reason": response.choices[0].finish_reason,
"model": response.model, "model": response.model,
"id": response.id, "id": response.id,
"created": response.created, "created": response.created,
}, },
) )
except Exception as e: try:
last_exception = e return self._run_with_retries(
operation=_attempt,
max_attempts=self.MAX_RETRIES,
delays=self.RETRY_DELAYS,
log_prefix=f"DIAL API ({resolved_model})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
if attempts == 1:
raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc
# Check if this is a retryable error raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc
is_retryable = self._is_error_retryable(e)
if not is_retryable:
# Non-retryable error, raise immediately
raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")
# If this isn't the last attempt and error is retryable, wait and retry
if attempt < self.MAX_RETRIES - 1:
delay = self.RETRY_DELAYS[attempt]
logger.info(
f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
)
time.sleep(delay)
continue
# All retries exhausted
raise ValueError(
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
)
def close(self) -> None: def close(self) -> None:
"""Clean up HTTP clients when provider is closed.""" """Clean up HTTP clients when provider is closed."""

View File

@@ -2,7 +2,6 @@
import base64 import base64
import logging import logging
import time
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -229,133 +228,111 @@ class GeminiModelProvider(ModelProvider):
# Retry logic with progressive delays # Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = self.client.models.generate_content(
model=resolved_name,
contents=contents,
config=generation_config,
)
for attempt in range(max_retries): usage = self._extract_usage(response)
try:
# Generate content
response = self.client.models.generate_content(
model=resolved_name,
contents=contents,
config=generation_config,
)
# Extract usage information if available finish_reason_str = "UNKNOWN"
usage = self._extract_usage(response) is_blocked_by_safety = False
safety_feedback_details = None
# Intelligently determine finish reason and safety blocks if response.candidates:
finish_reason_str = "UNKNOWN" candidate = response.candidates[0]
is_blocked_by_safety = False
safety_feedback_details = None
if response.candidates: try:
candidate = response.candidates[0] finish_reason_enum = candidate.finish_reason
if finish_reason_enum:
# Safely get finish reason
try:
finish_reason_enum = candidate.finish_reason
if finish_reason_enum:
# Handle both enum objects and string values
try:
finish_reason_str = finish_reason_enum.name
except AttributeError:
finish_reason_str = str(finish_reason_enum)
else:
finish_reason_str = "STOP"
except AttributeError:
finish_reason_str = "STOP"
# If content is empty, check safety ratings for the definitive cause
if not response.text:
try: try:
safety_ratings = candidate.safety_ratings finish_reason_str = finish_reason_enum.name
if safety_ratings: # Check it's not None or empty except AttributeError:
for rating in safety_ratings: finish_reason_str = str(finish_reason_enum)
try: else:
if rating.blocked: finish_reason_str = "STOP"
is_blocked_by_safety = True except AttributeError:
# Provide details for logging/debugging finish_reason_str = "STOP"
category_name = "UNKNOWN"
probability_name = "UNKNOWN"
try:
category_name = rating.category.name
except (AttributeError, TypeError):
pass
try:
probability_name = rating.probability.name
except (AttributeError, TypeError):
pass
safety_feedback_details = (
f"Category: {category_name}, Probability: {probability_name}"
)
break
except (AttributeError, TypeError):
# Individual rating doesn't have expected attributes
continue
except (AttributeError, TypeError):
# candidate doesn't have safety_ratings or it's not iterable
pass
# Also check for prompt-level blocking (request rejected entirely)
elif response.candidates is not None and len(response.candidates) == 0:
# No candidates is the primary indicator of a prompt-level block
is_blocked_by_safety = True
finish_reason_str = "SAFETY"
safety_feedback_details = "Prompt blocked, reason unavailable" # Default message
if not response.text:
try: try:
prompt_feedback = response.prompt_feedback safety_ratings = candidate.safety_ratings
if prompt_feedback and prompt_feedback.block_reason: if safety_ratings:
try: for rating in safety_ratings:
block_reason_name = prompt_feedback.block_reason.name try:
except AttributeError: if rating.blocked:
block_reason_name = str(prompt_feedback.block_reason) is_blocked_by_safety = True
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}" category_name = "UNKNOWN"
probability_name = "UNKNOWN"
try:
category_name = rating.category.name
except (AttributeError, TypeError):
pass
try:
probability_name = rating.probability.name
except (AttributeError, TypeError):
pass
safety_feedback_details = (
f"Category: {category_name}, Probability: {probability_name}"
)
break
except (AttributeError, TypeError):
continue
except (AttributeError, TypeError): except (AttributeError, TypeError):
# prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
pass pass
return ModelResponse( elif response.candidates is not None and len(response.candidates) == 0:
content=response.text, is_blocked_by_safety = True
usage=usage, finish_reason_str = "SAFETY"
model_name=resolved_name, safety_feedback_details = "Prompt blocked, reason unavailable"
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": finish_reason_str,
"is_blocked_by_safety": is_blocked_by_safety,
"safety_feedback": safety_feedback_details,
},
)
except Exception as e: try:
last_exception = e prompt_feedback = response.prompt_feedback
if prompt_feedback and prompt_feedback.block_reason:
try:
block_reason_name = prompt_feedback.block_reason.name
except AttributeError:
block_reason_name = str(prompt_feedback.block_reason)
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
except (AttributeError, TypeError):
pass
# Check if this is a retryable error using structured error codes return ModelResponse(
is_retryable = self._is_error_retryable(e) content=response.text,
usage=usage,
model_name=resolved_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": finish_reason_str,
"is_blocked_by_safety": is_blocked_by_safety,
"safety_feedback": safety_feedback_details,
},
)
# If this is the last attempt or not retryable, give up try:
if attempt == max_retries - 1 or not is_retryable: return self._run_with_retries(
break operation=_attempt,
max_attempts=max_retries,
# Get progressive delay delays=retry_delays,
delay = retry_delays[attempt] log_prefix=f"Gemini API ({resolved_name})",
)
# Log retry attempt except Exception as exc:
logger.warning( attempts = max(attempt_counter["value"], 1)
f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..." error_msg = (
) f"Gemini API error for model {resolved_name} after {attempts} attempt"
time.sleep(delay) f"{'s' if attempts > 1 else ''}: {exc}"
)
# If we get here, all retries failed raise RuntimeError(error_msg) from exc
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
raise RuntimeError(error_msg) from last_exception
def get_provider_type(self) -> ProviderType: def get_provider_type(self) -> ProviderType:
"""Get the provider type.""" """Get the provider type."""

View File

@@ -4,7 +4,6 @@ import copy
import ipaddress import ipaddress
import logging import logging
import os import os
import time
from typing import Optional from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -395,72 +394,59 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays # Retry logic with progressive delays
max_retries = 4 max_retries = 4
retry_delays = [1, 3, 5, 8] retry_delays = [1, 3, 5, 8]
last_exception = None attempt_counter = {"value": 0}
actual_attempts = 0
for attempt in range(max_retries): def _attempt() -> ModelResponse:
try: # Log sanitized payload for debugging attempt_counter["value"] += 1
import json import json
sanitized_params = self._sanitize_for_logging(completion_params) sanitized_params = self._sanitize_for_logging(completion_params)
logging.info( logging.info(
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}" f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
) )
# Use OpenAI client's responses endpoint response = self.client.responses.create(**completion_params)
response = self.client.responses.create(**completion_params)
# Extract content from responses endpoint format content = self._safe_extract_output_text(response)
# Use validation helper to safely extract output_text
content = self._safe_extract_output_text(response)
# Try to extract usage information usage = None
usage = None if hasattr(response, "usage"):
if hasattr(response, "usage"): usage = self._extract_usage(response)
usage = self._extract_usage(response) elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"): input_tokens = getattr(response, "input_tokens", 0) or 0
# Safely extract token counts with None handling output_tokens = getattr(response, "output_tokens", 0) or 0
input_tokens = getattr(response, "input_tokens", 0) or 0 usage = {
output_tokens = getattr(response, "output_tokens", 0) or 0 "input_tokens": input_tokens,
usage = { "output_tokens": output_tokens,
"input_tokens": input_tokens, "total_tokens": input_tokens + output_tokens,
"output_tokens": output_tokens, }
"total_tokens": input_tokens + output_tokens,
}
return ModelResponse( return ModelResponse(
content=content, content=content,
usage=usage, usage=usage,
model_name=model_name, model_name=model_name,
friendly_name=self.FRIENDLY_NAME, friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(), provider=self.get_provider_type(),
metadata={ metadata={
"model": getattr(response, "model", model_name), "model": getattr(response, "model", model_name),
"id": getattr(response, "id", ""), "id": getattr(response, "id", ""),
"created": getattr(response, "created_at", 0), "created": getattr(response, "created_at", 0),
"endpoint": "responses", "endpoint": "responses",
}, },
) )
except Exception as e: try:
last_exception = e return self._run_with_retries(
operation=_attempt,
# Check if this is a retryable error using structured error codes max_attempts=max_retries,
is_retryable = self._is_error_retryable(e) delays=retry_delays,
log_prefix="o3-pro responses endpoint",
if is_retryable and attempt < max_retries - 1: )
delay = retry_delays[attempt] except Exception as exc:
logging.warning( attempts = max(attempt_counter["value"], 1)
f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..." error_msg = f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
) logging.error(error_msg)
time.sleep(delay) raise RuntimeError(error_msg) from exc
else:
break
# If we get here, all retries failed
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
def generate_content( def generate_content(
self, self,
@@ -587,57 +573,44 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays # Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None def _attempt() -> ModelResponse:
actual_attempts = 0 attempt_counter["value"] += 1
response = self.client.chat.completions.create(**completion_params)
for attempt in range(max_retries): content = response.choices[0].message.content
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count usage = self._extract_usage(response)
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
# Extract content and usage return ModelResponse(
content = response.choices[0].message.content content=content,
usage = self._extract_usage(response) usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model,
"id": response.id,
"created": response.created,
},
)
return ModelResponse( try:
content=content, return self._run_with_retries(
usage=usage, operation=_attempt,
model_name=model_name, max_attempts=max_retries,
friendly_name=self.FRIENDLY_NAME, delays=retry_delays,
provider=self.get_provider_type(), log_prefix=f"{self.FRIENDLY_NAME} API ({model_name})",
metadata={ )
"finish_reason": response.choices[0].finish_reason, except Exception as exc:
"model": response.model, # Actual model used attempts = max(attempt_counter["value"], 1)
"id": response.id, error_msg = (
"created": response.created, f"{self.FRIENDLY_NAME} API error for model {model_name} after {attempts} attempt"
}, f"{'s' if attempts > 1 else ''}: {exc}"
) )
logging.error(error_msg)
except Exception as e: raise RuntimeError(error_msg) from exc
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logging.warning(
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
# If we get here, all retries failed
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None: def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters. """Validate model parameters.

View File

@@ -0,0 +1,73 @@
"""Tests covering shared retry behaviour for providers."""
from types import SimpleNamespace
import pytest
from providers.openai_provider import OpenAIModelProvider
def _mock_chat_response(content: str = "retry success") -> SimpleNamespace:
"""Create a minimal chat completion response for tests."""
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
message = SimpleNamespace(content=content)
choice = SimpleNamespace(message=message, finish_reason="stop")
return SimpleNamespace(choices=[choice], model="gpt-4.1", id="resp-1", created=123, usage=usage)
def test_openai_provider_retries_on_transient_error(monkeypatch):
"""Provider should retry once for retryable errors and eventually succeed."""
monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
provider = OpenAIModelProvider(api_key="test-key")
attempts = {"count": 0}
def create_completion(**kwargs):
attempts["count"] += 1
if attempts["count"] == 1:
raise RuntimeError("temporary network interruption")
return _mock_chat_response("second attempt response")
provider._client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
responses=SimpleNamespace(create=lambda **_: None),
)
result = provider.generate_content("hello", "gpt-4.1")
assert attempts["count"] == 2, "Expected a retry before succeeding"
assert result.content == "second attempt response"
def test_openai_provider_bails_on_non_retryable_error(monkeypatch):
"""Provider should stop immediately when the error is marked non-retryable."""
monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
provider = OpenAIModelProvider(api_key="test-key")
attempts = {"count": 0}
def create_completion(**kwargs):
attempts["count"] += 1
raise RuntimeError("context length exceeded 429")
provider._client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
responses=SimpleNamespace(create=lambda **_: None),
)
monkeypatch.setattr(
OpenAIModelProvider,
"_is_error_retryable",
lambda self, error: False,
)
with pytest.raises(RuntimeError) as excinfo:
provider.generate_content("hello", "gpt-4.1")
assert "after 1 attempt" in str(excinfo.value)
assert attempts["count"] == 1