refactor: improved retry logic and moved core logic to base class

This commit is contained in:
Fahad
2025-10-03 23:48:55 +04:00
parent 828c4eed5b
commit f955100f3a
5 changed files with 378 additions and 267 deletions

View File

@@ -2,7 +2,6 @@
import base64
import logging
import time
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
@@ -229,133 +228,111 @@ class GeminiModelProvider(ModelProvider):
# Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
attempt_counter = {"value": 0}
last_exception = None
def _attempt() -> ModelResponse:
attempt_counter["value"] += 1
response = self.client.models.generate_content(
model=resolved_name,
contents=contents,
config=generation_config,
)
for attempt in range(max_retries):
try:
# Generate content
response = self.client.models.generate_content(
model=resolved_name,
contents=contents,
config=generation_config,
)
usage = self._extract_usage(response)
# Extract usage information if available
usage = self._extract_usage(response)
finish_reason_str = "UNKNOWN"
is_blocked_by_safety = False
safety_feedback_details = None
# Intelligently determine finish reason and safety blocks
finish_reason_str = "UNKNOWN"
is_blocked_by_safety = False
safety_feedback_details = None
if response.candidates:
candidate = response.candidates[0]
if response.candidates:
candidate = response.candidates[0]
# Safely get finish reason
try:
finish_reason_enum = candidate.finish_reason
if finish_reason_enum:
# Handle both enum objects and string values
try:
finish_reason_str = finish_reason_enum.name
except AttributeError:
finish_reason_str = str(finish_reason_enum)
else:
finish_reason_str = "STOP"
except AttributeError:
finish_reason_str = "STOP"
# If content is empty, check safety ratings for the definitive cause
if not response.text:
try:
finish_reason_enum = candidate.finish_reason
if finish_reason_enum:
try:
safety_ratings = candidate.safety_ratings
if safety_ratings: # Check it's not None or empty
for rating in safety_ratings:
try:
if rating.blocked:
is_blocked_by_safety = True
# Provide details for logging/debugging
category_name = "UNKNOWN"
probability_name = "UNKNOWN"
try:
category_name = rating.category.name
except (AttributeError, TypeError):
pass
try:
probability_name = rating.probability.name
except (AttributeError, TypeError):
pass
safety_feedback_details = (
f"Category: {category_name}, Probability: {probability_name}"
)
break
except (AttributeError, TypeError):
# Individual rating doesn't have expected attributes
continue
except (AttributeError, TypeError):
# candidate doesn't have safety_ratings or it's not iterable
pass
# Also check for prompt-level blocking (request rejected entirely)
elif response.candidates is not None and len(response.candidates) == 0:
# No candidates is the primary indicator of a prompt-level block
is_blocked_by_safety = True
finish_reason_str = "SAFETY"
safety_feedback_details = "Prompt blocked, reason unavailable" # Default message
finish_reason_str = finish_reason_enum.name
except AttributeError:
finish_reason_str = str(finish_reason_enum)
else:
finish_reason_str = "STOP"
except AttributeError:
finish_reason_str = "STOP"
if not response.text:
try:
prompt_feedback = response.prompt_feedback
if prompt_feedback and prompt_feedback.block_reason:
try:
block_reason_name = prompt_feedback.block_reason.name
except AttributeError:
block_reason_name = str(prompt_feedback.block_reason)
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
safety_ratings = candidate.safety_ratings
if safety_ratings:
for rating in safety_ratings:
try:
if rating.blocked:
is_blocked_by_safety = True
category_name = "UNKNOWN"
probability_name = "UNKNOWN"
try:
category_name = rating.category.name
except (AttributeError, TypeError):
pass
try:
probability_name = rating.probability.name
except (AttributeError, TypeError):
pass
safety_feedback_details = (
f"Category: {category_name}, Probability: {probability_name}"
)
break
except (AttributeError, TypeError):
continue
except (AttributeError, TypeError):
# prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
pass
return ModelResponse(
content=response.text,
usage=usage,
model_name=resolved_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": finish_reason_str,
"is_blocked_by_safety": is_blocked_by_safety,
"safety_feedback": safety_feedback_details,
},
)
elif response.candidates is not None and len(response.candidates) == 0:
is_blocked_by_safety = True
finish_reason_str = "SAFETY"
safety_feedback_details = "Prompt blocked, reason unavailable"
except Exception as e:
last_exception = e
try:
prompt_feedback = response.prompt_feedback
if prompt_feedback and prompt_feedback.block_reason:
try:
block_reason_name = prompt_feedback.block_reason.name
except AttributeError:
block_reason_name = str(prompt_feedback.block_reason)
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
except (AttributeError, TypeError):
pass
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
return ModelResponse(
content=response.text,
usage=usage,
model_name=resolved_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": finish_reason_str,
"is_blocked_by_safety": is_blocked_by_safety,
"safety_feedback": safety_feedback_details,
},
)
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logger.warning(
f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
# If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
raise RuntimeError(error_msg) from last_exception
try:
return self._run_with_retries(
operation=_attempt,
max_attempts=max_retries,
delays=retry_delays,
log_prefix=f"Gemini API ({resolved_name})",
)
except Exception as exc:
attempts = max(attempt_counter["value"], 1)
error_msg = (
f"Gemini API error for model {resolved_name} after {attempts} attempt"
f"{'s' if attempts > 1 else ''}: {exc}"
)
raise RuntimeError(error_msg) from exc
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""