From f955100f3a82973ccd987607e1d8a1bbe07828c8 Mon Sep 17 00:00:00 2001
From: Fahad
Date: Fri, 3 Oct 2025 23:48:55 +0400
Subject: [PATCH] refactor: improve retry logic and move core logic to base class

---
 providers/base.py                  | 104 ++++++++++++++-
 providers/dial.py                  |  76 +++++------
 providers/gemini.py                | 205 +++++++++++++----------------
 providers/openai_compatible.py     | 187 +++++++++++---------------
 tests/test_provider_retry_logic.py |  73 ++++++++++
 5 files changed, 378 insertions(+), 267 deletions(-)
 create mode 100644 tests/test_provider_retry_logic.py

diff --git a/providers/base.py b/providers/base.py
index 8d42738..c5e9cb0 100644
--- a/providers/base.py
+++ b/providers/base.py
@@ -1,8 +1,9 @@
 """Base interfaces and common behaviour for model providers."""
 
 import logging
+import time
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 if TYPE_CHECKING:
     from tools.models import ToolModelCategory
@@ -168,6 +169,107 @@ class ModelProvider(ABC):
 
         return
 
+    # ------------------------------------------------------------------
+    # Retry helpers
+    # ------------------------------------------------------------------
+    def _is_error_retryable(self, error: Exception) -> bool:
+        """Return True when an error warrants another attempt.
+
+        Subclasses with structured provider errors should override this hook.
+        The default implementation only retries obvious transient failures such
+        as timeouts or 5xx responses detected via string inspection.
+        """
+
+        error_str = str(error).lower()
+        retryable_indicators = [
+            "timeout",
+            "connection",
+            "temporary",
+            "unavailable",
+            "retry",
+            "reset",
+            "refused",
+            "broken pipe",
+            "tls",
+            "handshake",
+            "network",
+            "rate limit",
+            "429",
+            "500",
+            "502",
+            "503",
+            "504",
+        ]
+
+        return any(indicator in error_str for indicator in retryable_indicators)
+
+    def _run_with_retries(
+        self,
+        operation: Callable[[], Any],
+        *,
+        max_attempts: int,
+        delays: Optional[list[float]] = None,
+        log_prefix: str = "",
+    ):
+        """Execute ``operation`` with retry semantics.
+
+        Args:
+            operation: Callable returning the provider result.
+            max_attempts: Maximum number of attempts (>=1).
+            delays: Optional list of sleep durations between attempts.
+            log_prefix: Optional identifier for log clarity.
+
+        Returns:
+            Whatever ``operation`` returns.
+
+        Raises:
+            The last exception when all retries fail or the error is not retryable.
+        """
+
+        if max_attempts < 1:
+            raise ValueError("max_attempts must be >= 1")
+
+        attempts = max_attempts
+        delays = delays or []
+        last_exc: Optional[Exception] = None
+
+        for attempt_index in range(attempts):
+            try:
+                return operation()
+            except Exception as exc:  # noqa: BLE001 - bubble exact provider errors
+                last_exc = exc
+                attempt_number = attempt_index + 1
+
+                # Decide whether to retry based on subclass hook
+                retryable = self._is_error_retryable(exc)
+                if not retryable or attempt_number >= attempts:
+                    raise
+
+                delay_idx = min(attempt_index, len(delays) - 1) if delays else -1
+                delay = delays[delay_idx] if delay_idx >= 0 else 0.0
+
+                if delay > 0:
+                    logger.warning(
+                        "%s retryable error (attempt %s/%s): %s. Retrying in %ss...",
+                        log_prefix or self.__class__.__name__,
+                        attempt_number,
+                        attempts,
+                        exc,
+                        delay,
+                    )
+                    time.sleep(delay)
+                else:
+                    logger.warning(
+                        "%s retryable error (attempt %s/%s): %s. Retrying...",
+                        log_prefix or self.__class__.__name__,
+                        attempt_number,
+                        attempts,
+                        exc,
+                    )
+
+        # Should never reach here because loop either returns or raises
+        raise last_exc if last_exc else RuntimeError("Retry loop exited without result")
+
     # ------------------------------------------------------------------
     # Validation hooks
     # ------------------------------------------------------------------
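Usage note: the sketch below shows how a provider is expected to drive the new helper; `ExampleProvider` and `_call_api` are illustrative names, not part of this patch (the real call sites follow in the diffs below). One detail worth calling out: because `_run_with_retries` clamps the delay index with `min(attempt_index, len(delays) - 1)`, delays of [1, 3, 5, 8] combined with max_attempts=6 would sleep 1s, 3s, 5s, 8s, 8s, reusing the final delay once the list is exhausted.

    class ExampleProvider(ModelProvider):
        MAX_RETRIES = 4
        RETRY_DELAYS = [1, 3, 5, 8]

        def generate_content(self, prompt: str, model_name: str):
            def _attempt():
                # Hypothetical single-shot API call; raises on transient failure.
                return self._call_api(prompt, model=model_name)

            return self._run_with_retries(
                operation=_attempt,
                max_attempts=self.MAX_RETRIES,
                delays=self.RETRY_DELAYS,
                log_prefix=f"ExampleProvider ({model_name})",
            )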
Retrying...", + log_prefix or self.__class__.__name__, + attempt_number, + attempts, + exc, + ) + + # Should never reach here because loop either returns or raises + raise last_exc if last_exc else RuntimeError("Retry loop exited without result") + # ------------------------------------------------------------------ # Validation hooks # ------------------------------------------------------------------ diff --git a/providers/dial.py b/providers/dial.py index caf33f6..7c23331 100644 --- a/providers/dial.py +++ b/providers/dial.py @@ -3,7 +3,6 @@ import logging import os import threading -import time from typing import Optional from .openai_compatible import OpenAICompatibleProvider @@ -405,55 +404,42 @@ class DIALModelProvider(OpenAICompatibleProvider): # DIAL-specific: Get cached client for deployment endpoint deployment_client = self._get_deployment_client(resolved_model) - # Retry logic with progressive delays - last_exception = None + attempt_counter = {"value": 0} - for attempt in range(self.MAX_RETRIES): - try: - # Generate completion using deployment-specific client - response = deployment_client.chat.completions.create(**completion_params) + def _attempt() -> ModelResponse: + attempt_counter["value"] += 1 + response = deployment_client.chat.completions.create(**completion_params) - # Extract content and usage - content = response.choices[0].message.content - usage = self._extract_usage(response) + content = response.choices[0].message.content + usage = self._extract_usage(response) - return ModelResponse( - content=content, - usage=usage, - model_name=model_name, - friendly_name=self.FRIENDLY_NAME, - provider=self.get_provider_type(), - metadata={ - "finish_reason": response.choices[0].finish_reason, - "model": response.model, - "id": response.id, - "created": response.created, - }, - ) + return ModelResponse( + content=content, + usage=usage, + model_name=model_name, + friendly_name=self.FRIENDLY_NAME, + provider=self.get_provider_type(), + metadata={ + "finish_reason": response.choices[0].finish_reason, + "model": response.model, + "id": response.id, + "created": response.created, + }, + ) - except Exception as e: - last_exception = e + try: + return self._run_with_retries( + operation=_attempt, + max_attempts=self.MAX_RETRIES, + delays=self.RETRY_DELAYS, + log_prefix=f"DIAL API ({resolved_model})", + ) + except Exception as exc: + attempts = max(attempt_counter["value"], 1) + if attempts == 1: + raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc - # Check if this is a retryable error - is_retryable = self._is_error_retryable(e) - - if not is_retryable: - # Non-retryable error, raise immediately - raise ValueError(f"DIAL API error for model {model_name}: {str(e)}") - - # If this isn't the last attempt and error is retryable, wait and retry - if attempt < self.MAX_RETRIES - 1: - delay = self.RETRY_DELAYS[attempt] - logger.info( - f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}" - ) - time.sleep(delay) - continue - - # All retries exhausted - raise ValueError( - f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}" - ) + raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc def close(self) -> None: """Clean up HTTP clients when provider is closed.""" diff --git a/providers/gemini.py b/providers/gemini.py index 5788da0..0f0068b 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -2,7 +2,6 @@ import base64 import 
diff --git a/providers/gemini.py b/providers/gemini.py
index 5788da0..0f0068b 100644
--- a/providers/gemini.py
+++ b/providers/gemini.py
@@ -2,7 +2,6 @@
 
 import base64
 import logging
-import time
 from typing import TYPE_CHECKING, Optional
 
 if TYPE_CHECKING:
@@ -229,133 +228,111 @@ class GeminiModelProvider(ModelProvider):
         # Retry logic with progressive delays
         max_retries = 4  # Total of 4 attempts
         retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
+        attempt_counter = {"value": 0}
 
-        last_exception = None
+        def _attempt() -> ModelResponse:
+            attempt_counter["value"] += 1
+            response = self.client.models.generate_content(
+                model=resolved_name,
+                contents=contents,
+                config=generation_config,
+            )
 
-        for attempt in range(max_retries):
-            try:
-                # Generate content
-                response = self.client.models.generate_content(
-                    model=resolved_name,
-                    contents=contents,
-                    config=generation_config,
-                )
+            usage = self._extract_usage(response)
 
-                # Extract usage information if available
-                usage = self._extract_usage(response)
+            finish_reason_str = "UNKNOWN"
+            is_blocked_by_safety = False
+            safety_feedback_details = None
 
-                # Intelligently determine finish reason and safety blocks
-                finish_reason_str = "UNKNOWN"
-                is_blocked_by_safety = False
-                safety_feedback_details = None
+            if response.candidates:
+                candidate = response.candidates[0]
 
-                if response.candidates:
-                    candidate = response.candidates[0]
-
-                    # Safely get finish reason
-                    try:
-                        finish_reason_enum = candidate.finish_reason
-                        if finish_reason_enum:
-                            # Handle both enum objects and string values
-                            try:
-                                finish_reason_str = finish_reason_enum.name
-                            except AttributeError:
-                                finish_reason_str = str(finish_reason_enum)
-                        else:
-                            finish_reason_str = "STOP"
-                    except AttributeError:
-                        finish_reason_str = "STOP"
-
-                    # If content is empty, check safety ratings for the definitive cause
-                    if not response.text:
-                        try:
-                            safety_ratings = candidate.safety_ratings
-                            if safety_ratings:  # Check it's not None or empty
-                                for rating in safety_ratings:
-                                    try:
-                                        if rating.blocked:
-                                            is_blocked_by_safety = True
-                                            # Provide details for logging/debugging
-                                            category_name = "UNKNOWN"
-                                            probability_name = "UNKNOWN"
-
-                                            try:
-                                                category_name = rating.category.name
-                                            except (AttributeError, TypeError):
-                                                pass
-
-                                            try:
-                                                probability_name = rating.probability.name
-                                            except (AttributeError, TypeError):
-                                                pass
-
-                                            safety_feedback_details = (
-                                                f"Category: {category_name}, Probability: {probability_name}"
-                                            )
-                                            break
-                                    except (AttributeError, TypeError):
-                                        # Individual rating doesn't have expected attributes
-                                        continue
-                        except (AttributeError, TypeError):
-                            # candidate doesn't have safety_ratings or it's not iterable
-                            pass
+                try:
+                    finish_reason_enum = candidate.finish_reason
+                    if finish_reason_enum:
+                        try:
+                            finish_reason_str = finish_reason_enum.name
+                        except AttributeError:
+                            finish_reason_str = str(finish_reason_enum)
+                    else:
+                        finish_reason_str = "STOP"
+                except AttributeError:
+                    finish_reason_str = "STOP"
 
-                # Also check for prompt-level blocking (request rejected entirely)
-                elif response.candidates is not None and len(response.candidates) == 0:
-                    # No candidates is the primary indicator of a prompt-level block
-                    is_blocked_by_safety = True
-                    finish_reason_str = "SAFETY"
-                    safety_feedback_details = "Prompt blocked, reason unavailable"  # Default message
+                if not response.text:
+                    try:
+                        safety_ratings = candidate.safety_ratings
+                        if safety_ratings:
+                            for rating in safety_ratings:
+                                try:
+                                    if rating.blocked:
+                                        is_blocked_by_safety = True
+                                        category_name = "UNKNOWN"
+                                        probability_name = "UNKNOWN"
+
+                                        try:
+                                            category_name = rating.category.name
+                                        except (AttributeError, TypeError):
+                                            pass
+
+                                        try:
+                                            probability_name = rating.probability.name
+                                        except (AttributeError, TypeError):
+                                            pass
+
+                                        safety_feedback_details = (
+                                            f"Category: {category_name}, Probability: {probability_name}"
+                                        )
+                                        break
+                                except (AttributeError, TypeError):
+                                    continue
+                    except (AttributeError, TypeError):
+                        pass
 
-                    try:
-                        prompt_feedback = response.prompt_feedback
-                        if prompt_feedback and prompt_feedback.block_reason:
-                            try:
-                                block_reason_name = prompt_feedback.block_reason.name
-                            except AttributeError:
-                                block_reason_name = str(prompt_feedback.block_reason)
-                            safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
-                    except (AttributeError, TypeError):
-                        # prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
-                        pass
+            elif response.candidates is not None and len(response.candidates) == 0:
+                is_blocked_by_safety = True
+                finish_reason_str = "SAFETY"
+                safety_feedback_details = "Prompt blocked, reason unavailable"
 
-                return ModelResponse(
-                    content=response.text,
-                    usage=usage,
-                    model_name=resolved_name,
-                    friendly_name="Gemini",
-                    provider=ProviderType.GOOGLE,
-                    metadata={
-                        "thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
-                        "finish_reason": finish_reason_str,
-                        "is_blocked_by_safety": is_blocked_by_safety,
-                        "safety_feedback": safety_feedback_details,
-                    },
-                )
+                try:
+                    prompt_feedback = response.prompt_feedback
+                    if prompt_feedback and prompt_feedback.block_reason:
+                        try:
+                            block_reason_name = prompt_feedback.block_reason.name
+                        except AttributeError:
+                            block_reason_name = str(prompt_feedback.block_reason)
+                        safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
+                except (AttributeError, TypeError):
+                    pass
 
-            except Exception as e:
-                last_exception = e
+            return ModelResponse(
+                content=response.text,
+                usage=usage,
+                model_name=resolved_name,
+                friendly_name="Gemini",
+                provider=ProviderType.GOOGLE,
+                metadata={
+                    "thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
+                    "finish_reason": finish_reason_str,
+                    "is_blocked_by_safety": is_blocked_by_safety,
+                    "safety_feedback": safety_feedback_details,
+                },
+            )
 
-                # Check if this is a retryable error using structured error codes
-                is_retryable = self._is_error_retryable(e)
-
-                # If this is the last attempt or not retryable, give up
-                if attempt == max_retries - 1 or not is_retryable:
-                    break
-
-                # Get progressive delay
-                delay = retry_delays[attempt]
-
-                # Log retry attempt
-                logger.warning(
-                    f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
-                )
-                time.sleep(delay)
-
-        # If we get here, all retries failed
-        actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
-        error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
-        raise RuntimeError(error_msg) from last_exception
+        try:
+            return self._run_with_retries(
+                operation=_attempt,
+                max_attempts=max_retries,
+                delays=retry_delays,
+                log_prefix=f"Gemini API ({resolved_name})",
+            )
+        except Exception as exc:
+            attempts = max(attempt_counter["value"], 1)
+            error_msg = (
+                f"Gemini API error for model {resolved_name} after {attempts} attempt"
+                f"{'s' if attempts > 1 else ''}: {exc}"
+            )
+            raise RuntimeError(error_msg) from exc
 
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
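Note: the base `_is_error_retryable` hook is deliberately string-based, and the deleted inline comments above refer to providers checking "structured error codes" instead. A hedged sketch of such an override on a provider subclass; the `code` and `status_code` attribute names are assumptions about the SDK's exception shape, not something this patch defines:

    def _is_error_retryable(self, error: Exception) -> bool:
        # Prefer a structured HTTP status when the SDK exposes one.
        status = getattr(error, "code", None) or getattr(error, "status_code", None)
        if status in (408, 429, 500, 502, 503, 504):
            return True
        # Otherwise fall back to the base class's string inspection.
        return super()._is_error_retryable(error)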
diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py
index 5c93a3c..9eb126b 100644
--- a/providers/openai_compatible.py
+++ b/providers/openai_compatible.py
@@ -4,7 +4,6 @@ import copy
 import ipaddress
 import logging
 import os
-import time
 from typing import Optional
 from urllib.parse import urlparse
 
@@ -395,72 +394,59 @@ class OpenAICompatibleProvider(ModelProvider):
         # Retry logic with progressive delays
         max_retries = 4
         retry_delays = [1, 3, 5, 8]
-        last_exception = None
-        actual_attempts = 0
+        attempt_counter = {"value": 0}
 
-        for attempt in range(max_retries):
-            try:
-                # Log sanitized payload for debugging
-                import json
+        def _attempt() -> ModelResponse:
+            attempt_counter["value"] += 1
+            import json
 
-                sanitized_params = self._sanitize_for_logging(completion_params)
-                logging.info(
-                    f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
-                )
+            sanitized_params = self._sanitize_for_logging(completion_params)
+            logging.info(
+                f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
+            )
 
-                # Use OpenAI client's responses endpoint
-                response = self.client.responses.create(**completion_params)
+            response = self.client.responses.create(**completion_params)
 
-                # Extract content from responses endpoint format
-                # Use validation helper to safely extract output_text
-                content = self._safe_extract_output_text(response)
+            content = self._safe_extract_output_text(response)
 
-                # Try to extract usage information
-                usage = None
-                if hasattr(response, "usage"):
-                    usage = self._extract_usage(response)
-                elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
-                    # Safely extract token counts with None handling
-                    input_tokens = getattr(response, "input_tokens", 0) or 0
-                    output_tokens = getattr(response, "output_tokens", 0) or 0
-                    usage = {
-                        "input_tokens": input_tokens,
-                        "output_tokens": output_tokens,
-                        "total_tokens": input_tokens + output_tokens,
-                    }
+            usage = None
+            if hasattr(response, "usage"):
+                usage = self._extract_usage(response)
+            elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
+                input_tokens = getattr(response, "input_tokens", 0) or 0
+                output_tokens = getattr(response, "output_tokens", 0) or 0
+                usage = {
+                    "input_tokens": input_tokens,
+                    "output_tokens": output_tokens,
+                    "total_tokens": input_tokens + output_tokens,
+                }
 
-                return ModelResponse(
-                    content=content,
-                    usage=usage,
-                    model_name=model_name,
-                    friendly_name=self.FRIENDLY_NAME,
-                    provider=self.get_provider_type(),
-                    metadata={
-                        "model": getattr(response, "model", model_name),
-                        "id": getattr(response, "id", ""),
-                        "created": getattr(response, "created_at", 0),
-                        "endpoint": "responses",
-                    },
-                )
+            return ModelResponse(
+                content=content,
+                usage=usage,
+                model_name=model_name,
+                friendly_name=self.FRIENDLY_NAME,
+                provider=self.get_provider_type(),
+                metadata={
+                    "model": getattr(response, "model", model_name),
+                    "id": getattr(response, "id", ""),
+                    "created": getattr(response, "created_at", 0),
+                    "endpoint": "responses",
+                },
+            )
 
-            except Exception as e:
-                last_exception = e
-
-                # Check if this is a retryable error using structured error codes
-                is_retryable = self._is_error_retryable(e)
-
-                if is_retryable and attempt < max_retries - 1:
-                    delay = retry_delays[attempt]
-                    logging.warning(
-                        f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
-                    )
-                    time.sleep(delay)
-                else:
-                    break
-
-        # If we get here, all retries failed
-        error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
-        logging.error(error_msg)
-        raise RuntimeError(error_msg) from last_exception
+        try:
+            return self._run_with_retries(
+                operation=_attempt,
+                max_attempts=max_retries,
+                delays=retry_delays,
+                log_prefix="o3-pro responses endpoint",
+            )
+        except Exception as exc:
+            attempts = max(attempt_counter["value"], 1)
+            error_msg = f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
+            logging.error(error_msg)
+            raise RuntimeError(error_msg) from exc
 
     def generate_content(
         self,
@@ -587,57 +573,44 @@ class OpenAICompatibleProvider(ModelProvider):
         # Retry logic with progressive delays
         max_retries = 4  # Total of 4 attempts
         retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
+        attempt_counter = {"value": 0}
 
-        last_exception = None
-        actual_attempts = 0
+        def _attempt() -> ModelResponse:
+            attempt_counter["value"] += 1
+            response = self.client.chat.completions.create(**completion_params)
 
-        for attempt in range(max_retries):
-            actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
-            try:
-                # Generate completion
-                response = self.client.chat.completions.create(**completion_params)
+            content = response.choices[0].message.content
+            usage = self._extract_usage(response)
 
-                # Extract content and usage
-                content = response.choices[0].message.content
-                usage = self._extract_usage(response)
+            return ModelResponse(
+                content=content,
+                usage=usage,
+                model_name=model_name,
+                friendly_name=self.FRIENDLY_NAME,
+                provider=self.get_provider_type(),
+                metadata={
+                    "finish_reason": response.choices[0].finish_reason,
+                    "model": response.model,
+                    "id": response.id,
+                    "created": response.created,
+                },
+            )
 
-                return ModelResponse(
-                    content=content,
-                    usage=usage,
-                    model_name=model_name,
-                    friendly_name=self.FRIENDLY_NAME,
-                    provider=self.get_provider_type(),
-                    metadata={
-                        "finish_reason": response.choices[0].finish_reason,
-                        "model": response.model,  # Actual model used
-                        "id": response.id,
-                        "created": response.created,
-                    },
-                )
-
-            except Exception as e:
-                last_exception = e
-
-                # Check if this is a retryable error using structured error codes
-                is_retryable = self._is_error_retryable(e)
-
-                # If this is the last attempt or not retryable, give up
-                if attempt == max_retries - 1 or not is_retryable:
-                    break
-
-                # Get progressive delay
-                delay = retry_delays[attempt]
-
-                # Log retry attempt
-                logging.warning(
-                    f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
-                )
-                time.sleep(delay)
-
-        # If we get here, all retries failed
-        error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
-        logging.error(error_msg)
-        raise RuntimeError(error_msg) from last_exception
+        try:
+            return self._run_with_retries(
+                operation=_attempt,
+                max_attempts=max_retries,
+                delays=retry_delays,
+                log_prefix=f"{self.FRIENDLY_NAME} API ({model_name})",
+            )
+        except Exception as exc:
+            attempts = max(attempt_counter["value"], 1)
+            error_msg = (
+                f"{self.FRIENDLY_NAME} API error for model {model_name} after {attempts} attempt"
+                f"{'s' if attempts > 1 else ''}: {exc}"
+            )
+            logging.error(error_msg)
+            raise RuntimeError(error_msg) from exc
 
     def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
         """Validate model parameters.
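Note: deriving the attempt count from the closure counter also fixes a latent bug in the deleted o3-pro path above, where `actual_attempts` was initialized to 0 and never incremented, so its warnings always read "attempt 0/4" and the final error "after 0 attempts". A quick worked example of how the new message renders once retries are exhausted (the "timeout" text is illustrative):

    attempts = 4
    f"o3-pro responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: timeout"
    # -> "o3-pro responses endpoint error after 4 attempts: timeout"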
diff --git a/tests/test_provider_retry_logic.py b/tests/test_provider_retry_logic.py
new file mode 100644
index 0000000..4ff92ed
--- /dev/null
+++ b/tests/test_provider_retry_logic.py
@@ -0,0 +1,73 @@
+"""Tests covering shared retry behaviour for providers."""
+
+from types import SimpleNamespace
+
+import pytest
+
+from providers.openai_provider import OpenAIModelProvider
+
+
+def _mock_chat_response(content: str = "retry success") -> SimpleNamespace:
+    """Create a minimal chat completion response for tests."""
+
+    usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
+    message = SimpleNamespace(content=content)
+    choice = SimpleNamespace(message=message, finish_reason="stop")
+    return SimpleNamespace(choices=[choice], model="gpt-4.1", id="resp-1", created=123, usage=usage)
+
+
+def test_openai_provider_retries_on_transient_error(monkeypatch):
+    """Provider should retry once for retryable errors and eventually succeed."""
+
+    monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
+
+    provider = OpenAIModelProvider(api_key="test-key")
+
+    attempts = {"count": 0}
+
+    def create_completion(**kwargs):
+        attempts["count"] += 1
+        if attempts["count"] == 1:
+            raise RuntimeError("temporary network interruption")
+        return _mock_chat_response("second attempt response")
+
+    provider._client = SimpleNamespace(
+        chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
+        responses=SimpleNamespace(create=lambda **_: None),
+    )
+
+    result = provider.generate_content("hello", "gpt-4.1")
+
+    assert attempts["count"] == 2, "Expected a retry before succeeding"
+    assert result.content == "second attempt response"
+
+
+def test_openai_provider_bails_on_non_retryable_error(monkeypatch):
+    """Provider should stop immediately when the error is marked non-retryable."""
+
+    monkeypatch.setattr("providers.base.time.sleep", lambda _: None)
+
+    provider = OpenAIModelProvider(api_key="test-key")
+
+    attempts = {"count": 0}
+
+    def create_completion(**kwargs):
+        attempts["count"] += 1
+        raise RuntimeError("context length exceeded 429")
+
+    provider._client = SimpleNamespace(
+        chat=SimpleNamespace(completions=SimpleNamespace(create=create_completion)),
+        responses=SimpleNamespace(create=lambda **_: None),
+    )
+
+    monkeypatch.setattr(
+        OpenAIModelProvider,
+        "_is_error_retryable",
+        lambda self, error: False,
+    )
+
+    with pytest.raises(RuntimeError) as excinfo:
+        provider.generate_content("hello", "gpt-4.1")
+
+    assert "after 1 attempt" in str(excinfo.value)
+    assert attempts["count"] == 1
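Note: the new test module stubs the network client with `SimpleNamespace` objects and patches `time.sleep` in providers.base, so it runs quickly and can be executed in isolation, for example:

    pytest tests/test_provider_retry_logic.py -v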