refactor: improve retry logic and move it into the shared provider base class

Fahad
2025-10-03 23:48:55 +04:00
parent 828c4eed5b
commit f955100f3a
5 changed files with 378 additions and 267 deletions

View File

@@ -1,8 +1,9 @@
"""Base interfaces and common behaviour for model providers."""
import logging
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
@@ -168,6 +169,107 @@ class ModelProvider(ABC):
return
# ------------------------------------------------------------------
# Retry helpers
# ------------------------------------------------------------------
def _is_error_retryable(self, error: Exception) -> bool:
"""Return True when an error warrants another attempt.
Subclasses with structured provider errors should override this hook.
The default implementation retries only obvious transient failures, such
as timeouts, dropped connections, rate limits, and 429/5xx responses,
detected via string inspection of the error message.
"""
error_str = str(error).lower()
retryable_indicators = [
"timeout",
"connection",
"temporary",
"unavailable",
"retry",
"reset",
"refused",
"broken pipe",
"tls",
"handshake",
"network",
"rate limit",
"429",
"500",
"502",
"503",
"504",
]
return any(indicator in error_str for indicator in retryable_indicators)
def _run_with_retries(
self,
operation: Callable[[], Any],
*,
max_attempts: int,
delays: Optional[list[float]] = None,
log_prefix: str = "",
):
"""Execute ``operation`` with retry semantics.
Args:
operation: Callable returning the provider result.
max_attempts: Maximum number of attempts (>=1).
delays: Optional sleep durations between attempts; the final entry is
reused when attempts outnumber the list.
log_prefix: Optional identifier for log messages; defaults to the
class name.
Returns:
Whatever ``operation`` returns.
Raises:
The last exception when all retries fail or the error is not retryable.
"""
if max_attempts < 1:
raise ValueError("max_attempts must be >= 1")
attempts = max_attempts
delays = delays or []
last_exc: Optional[Exception] = None
for attempt_index in range(attempts):
try:
return operation()
except Exception as exc: # noqa: BLE001 - bubble exact provider errors
last_exc = exc
attempt_number = attempt_index + 1
# Decide whether to retry based on subclass hook
retryable = self._is_error_retryable(exc)
if not retryable or attempt_number >= attempts:
raise
delay_idx = min(attempt_index, len(delays) - 1) if delays else -1
delay = delays[delay_idx] if delay_idx >= 0 else 0.0
if delay > 0:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying in %ss...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
delay,
)
time.sleep(delay)
else:
logger.warning(
"%s retryable error (attempt %s/%s): %s. Retrying...",
log_prefix or self.__class__.__name__,
attempt_number,
attempts,
exc,
)
# Should never reach here because loop either returns or raises
raise last_exc if last_exc else RuntimeError("Retry loop exited without result")
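Taken together, the two hooks compose as follows. A minimal sketch, assuming a hypothetical ExampleProvider whose SDK attaches a status_code attribute to its errors; the _client.complete call is illustrative, not a real API:

# Sketch only: ExampleProvider, its client, and fetch() are hypothetical.
class ExampleProvider(ModelProvider):
    RETRY_DELAYS = [1, 3, 5, 8]  # progressive backoff: 1s, 3s, 5s, 8s

    def _is_error_retryable(self, error: Exception) -> bool:
        # Prefer a structured status code when the SDK provides one;
        # fall back to the base class's string inspection otherwise.
        status = getattr(error, "status_code", None)
        if status is not None:
            return status == 429 or 500 <= status < 600
        return super()._is_error_retryable(error)

    def fetch(self, prompt: str) -> Any:
        return self._run_with_retries(
            operation=lambda: self._client.complete(prompt),  # assumed client API
            max_attempts=4,
            delays=self.RETRY_DELAYS,
            log_prefix="Example API",
        )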
# ------------------------------------------------------------------
# Validation hooks
# ------------------------------------------------------------------

View File

@@ -3,7 +3,6 @@
import logging
import os
import threading
import time
from typing import Optional
from .openai_compatible import OpenAICompatibleProvider
@@ -405,55 +404,42 @@ class DIALModelProvider(OpenAICompatibleProvider):
# DIAL-specific: Get cached client for deployment endpoint
deployment_client = self._get_deployment_client(resolved_model)
# Retry logic with progressive delays
last_exception = None
attempt_counter = {"value": 0}
for attempt in range(self.MAX_RETRIES):
try:
# Generate completion using deployment-specific client
response = deployment_client.chat.completions.create(**completion_params)
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = deployment_client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model,
"id": response.id,
"created": response.created,
},
)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model,
"id": response.id,
"created": response.created,
},
)
except Exception as e:
last_exception = e
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=self.MAX_RETRIES,
delays=self.RETRY_DELAYS,
log_prefix=f"DIAL API ({resolved_model})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
if attempts == 1:
raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc
# Check if this is a retryable error
is_retryable = self._is_error_retryable(e)
if not is_retryable:
# Non-retryable error, raise immediately
raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")
# If this isn't the last attempt and error is retryable, wait and retry
if attempt < self.MAX_RETRIES - 1:
delay = self.RETRY_DELAYS[attempt]
logger.info(
f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
)
time.sleep(delay)
continue
# All retries exhausted
raise ValueError(
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
)
raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc
def close(self) -> None:
"""Clean up HTTP clients when provider is closed."""

View File

@@ -2,7 +2,6 @@
import base64
import logging
import time
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
@@ -229,133 +228,111 @@ class GeminiModelProvider(ModelProvider):
# Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = self.client.models.generate_content(
model=resolved_name,
contents=contents,
config=generation_config,
)
for attempt in range(max_retries):
try:
# Generate content
response = self.client.models.generate_content(
model=resolved_name,
contents=contents,
config=generation_config,
)
usage = self._extract_usage(response)
# Extract usage information if available
usage = self._extract_usage(response)
finish_reason_str = "UNKNOWN"
is_blocked_by_safety = False
safety_feedback_details = None
# Intelligently determine finish reason and safety blocks
finish_reason_str = "UNKNOWN"
is_blocked_by_safety = False
safety_feedback_details = None
if response.candidates:
candidate = response.candidates[0]
if response.candidates:
candidate = response.candidates[0]
# Safely get finish reason
try:
finish_reason_enum = candidate.finish_reason
if finish_reason_enum:
# Handle both enum objects and string values
try:
finish_reason_str = finish_reason_enum.name
except AttributeError:
finish_reason_str = str(finish_reason_enum)
else:
finish_reason_str = "STOP"
except AttributeError:
finish_reason_str = "STOP"
# If content is empty, check safety ratings for the definitive cause
if not response.text:
try:
finish_reason_enum = candidate.finish_reason
if finish_reason_enum:
try:
safety_ratings = candidate.safety_ratings
if safety_ratings: # Check it's not None or empty
for rating in safety_ratings:
try:
if rating.blocked:
is_blocked_by_safety = True
# Provide details for logging/debugging
category_name = "UNKNOWN"
probability_name = "UNKNOWN"
try:
category_name = rating.category.name
except (AttributeError, TypeError):
pass
try:
probability_name = rating.probability.name
except (AttributeError, TypeError):
pass
safety_feedback_details = (
f"Category: {category_name}, Probability: {probability_name}"
)
break
except (AttributeError, TypeError):
# Individual rating doesn't have expected attributes
continue
except (AttributeError, TypeError):
# candidate doesn't have safety_ratings or it's not iterable
pass
# Also check for prompt-level blocking (request rejected entirely)
elif response.candidates is not None and len(response.candidates) == 0:
# No candidates is the primary indicator of a prompt-level block
is_blocked_by_safety = True
finish_reason_str = "SAFETY"
safety_feedback_details = "Prompt blocked, reason unavailable" # Default message
finish_reason_str = finish_reason_enum.name
except AttributeError:
finish_reason_str = str(finish_reason_enum)
else:
finish_reason_str = "STOP"
except AttributeError:
finish_reason_str = "STOP"
if not response.text:
try:
prompt_feedback = response.prompt_feedback
if prompt_feedback and prompt_feedback.block_reason:
try:
block_reason_name = prompt_feedback.block_reason.name
except AttributeError:
block_reason_name = str(prompt_feedback.block_reason)
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
safety_ratings = candidate.safety_ratings
if safety_ratings:
for rating in safety_ratings:
try:
if rating.blocked:
is_blocked_by_safety = True
category_name = "UNKNOWN"
probability_name = "UNKNOWN"
try:
category_name = rating.category.name
except (AttributeError, TypeError):
pass
try:
probability_name = rating.probability.name
except (AttributeError, TypeError):
pass
safety_feedback_details = (
f"Category: {category_name}, Probability: {probability_name}"
)
break
except (AttributeError, TypeError):
continue
except (AttributeError, TypeError):
# prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
pass
return ModelResponse(
content=response.text,
usage=usage,
model_name=resolved_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": finish_reason_str,
"is_blocked_by_safety": is_blocked_by_safety,
"safety_feedback": safety_feedback_details,
},
)
elif response.candidates is not None and len(response.candidates) == 0:
is_blocked_by_safety = True
finish_reason_str = "SAFETY"
safety_feedback_details = "Prompt blocked, reason unavailable"
except Exception as e:
last_exception = e
try:
prompt_feedback = response.prompt_feedback
if prompt_feedback and prompt_feedback.block_reason:
try:
block_reason_name = prompt_feedback.block_reason.name
except AttributeError:
block_reason_name = str(prompt_feedback.block_reason)
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
except (AttributeError, TypeError):
pass
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
return ModelResponse(
content=response.text,
usage=usage,
model_name=resolved_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": finish_reason_str,
"is_blocked_by_safety": is_blocked_by_safety,
"safety_feedback": safety_feedback_details,
},
)
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logger.warning(
f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
# If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
raise RuntimeError(error_msg) from last_exception
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix=f"Gemini API ({resolved_name})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"Gemini API error for model {resolved_name} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
raise RuntimeError(error_msg) from exc
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""

View File

@@ -4,7 +4,6 @@ import copy
import ipaddress
import logging
import os
import time
from typing import Optional
from urllib.parse import urlparse
@@ -395,72 +394,59 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays
max_retries = 4
retry_delays = [1, 3, 5, 8]
last_exception = None
actual_attempts = 0
attempt_counter = {"value": 0}
for attempt in range(max_retries):
try: # Log sanitized payload for debugging
import json
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
import json
sanitized_params = self._sanitize_for_logging(completion_params)
logging.info(
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
)
sanitized_params = self._sanitize_for_logging(completion_params)
logging.info(
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
)
# Use OpenAI client's responses endpoint
response = self.client.responses.create(**completion_params)
response = self.client.responses.create(**completion_params)
# Extract content from responses endpoint format
# Use validation helper to safely extract output_text
content = self._safe_extract_output_text(response)
content = self._safe_extract_output_text(response)
# Try to extract usage information
usage = None
if hasattr(response, "usage"):
usage = self._extract_usage(response)
elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
# Safely extract token counts with None handling
input_tokens = getattr(response, "input_tokens", 0) or 0
output_tokens = getattr(response, "output_tokens", 0) or 0
usage = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
usage = None
if hasattr(response, "usage"):
usage = self._extract_usage(response)
elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
input_tokens = getattr(response, "input_tokens", 0) or 0
output_tokens = getattr(response, "output_tokens", 0) or 0
usage = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"model": getattr(response, "model", model_name),
"id": getattr(response, "id", ""),
"created": getattr(response, "created_at", 0),
"endpoint": "responses",
},
)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"model": getattr(response, "model", model_name),
"id": getattr(response, "id", ""),
"created": getattr(response, "created_at", 0),
"endpoint": "responses",
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
if is_retryable and attempt < max_retries - 1:
delay = retry_delays[attempt]
logging.warning(
f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
else:
break
# If we get here, all retries failed
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix="o3-pro responses endpoint",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
logging.error(error_msg)
raise RuntimeError(error_msg) from exc
def generate_content(
self,
@@ -587,57 +573,44 @@ class OpenAICompatibleProvider(ModelProvider):
# Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None
actual_attempts = 0
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = self.client.chat.completions.create(**completion_params)
for attempt in range(max_retries):
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
content = response.choices[0].message.content
usage = self._extract_usage(response)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model,
"id": response.id,
"created": response.created,
},
)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model, # Actual model used
"id": response.id,
"created": response.created,
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logging.warning(
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
# If we get here, all retries failed
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix=f"{self.FRIENDLY_NAME} API ({model_name})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"{self.FRIENDLY_NAME} API error for model {model_name} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
logging.error(error_msg)
raise RuntimeError(error_msg) from exc
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters.

View File

@@ -0,0 +1,73 @@
"""Tests covering shared retry behaviour for providers."""
from types import SimpleNamespace
import pytest
from providers.openai_provider import OpenAIModelProvider
def _mock_chat_response(content: str = "retry success") -> SimpleNamespace:
"""Create a minimal chat completion response for tests."""
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
message = SimpleNamespace(content=content)
choice = SimpleNamespace(message=message, finish_reason="stop")
return SimpleNamespace(choices=[choice], model="gpt-4.1", id="resp-1", created=123, usage=usage)
def test_openai_provider_retries_on_transient_error(monkeypatch):
"""Provider should retry once for retryable errors and eventually succeed."""
monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
provider = OpenAIModelProvider(api_key="test-key")
attempts = {"count": 0}
def create_completion(**kwargs):
attempts["count"] += 1
if attempts["count"] == 1:
raise RuntimeError("temporary network interruption")
return _mock_chat_response("second attempt response")
provider._client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
responses=SimpleNamespace(create=lambda **_: None),
)
result = provider.generate_content("hello", "gpt-4.1")
assert attempts["count"] == 2, "Expected a retry before succeeding"
assert result.content == "second attempt response"
def test_openai_provider_bails_on_non_retryable_error(monkeypatch):
"""Provider should stop immediately when the error is marked non-retryable."""
monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
provider = OpenAIModelProvider(api_key="test-key")
attempts = {"count": 0}
def create_completion(**kwargs):
attempts["count"] += 1
raise RuntimeError("context length exceeded 429")
provider._client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
responses=SimpleNamespace(create=lambda **_: None),
)
monkeypatch.setattr(
OpenAIModelProvider,
"_is_error_retryable",
lambda self, error: False,
)
with pytest.raises(RuntimeError) as excinfo:
provider.generate_content("hello", "gpt-4.1")
assert "after 1 attempt" in str(excinfo.value)
assert attempts["count"] == 1