refactor: improved retry logic and moved core logic to base class
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -229,133 +228,111 @@ class GeminiModelProvider(ModelProvider):
|
||||
# Retry logic with progressive delays
|
||||
max_retries = 4 # Total of 4 attempts
|
||||
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
|
||||
attempt_counter = {"value": 0}
|
||||
|
||||
last_exception = None
|
||||
def _attempt() -> ModelResponse:
|
||||
attempt_counter["value"] += 1
|
||||
response = self.client.models.generate_content(
|
||||
model=resolved_name,
|
||||
contents=contents,
|
||||
config=generation_config,
|
||||
)
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Generate content
|
||||
response = self.client.models.generate_content(
|
||||
model=resolved_name,
|
||||
contents=contents,
|
||||
config=generation_config,
|
||||
)
|
||||
usage = self._extract_usage(response)
|
||||
|
||||
# Extract usage information if available
|
||||
usage = self._extract_usage(response)
|
||||
finish_reason_str = "UNKNOWN"
|
||||
is_blocked_by_safety = False
|
||||
safety_feedback_details = None
|
||||
|
||||
# Intelligently determine finish reason and safety blocks
|
||||
finish_reason_str = "UNKNOWN"
|
||||
is_blocked_by_safety = False
|
||||
safety_feedback_details = None
|
||||
if response.candidates:
|
||||
candidate = response.candidates[0]
|
||||
|
||||
if response.candidates:
|
||||
candidate = response.candidates[0]
|
||||
|
||||
# Safely get finish reason
|
||||
try:
|
||||
finish_reason_enum = candidate.finish_reason
|
||||
if finish_reason_enum:
|
||||
# Handle both enum objects and string values
|
||||
try:
|
||||
finish_reason_str = finish_reason_enum.name
|
||||
except AttributeError:
|
||||
finish_reason_str = str(finish_reason_enum)
|
||||
else:
|
||||
finish_reason_str = "STOP"
|
||||
except AttributeError:
|
||||
finish_reason_str = "STOP"
|
||||
|
||||
# If content is empty, check safety ratings for the definitive cause
|
||||
if not response.text:
|
||||
try:
|
||||
finish_reason_enum = candidate.finish_reason
|
||||
if finish_reason_enum:
|
||||
try:
|
||||
safety_ratings = candidate.safety_ratings
|
||||
if safety_ratings: # Check it's not None or empty
|
||||
for rating in safety_ratings:
|
||||
try:
|
||||
if rating.blocked:
|
||||
is_blocked_by_safety = True
|
||||
# Provide details for logging/debugging
|
||||
category_name = "UNKNOWN"
|
||||
probability_name = "UNKNOWN"
|
||||
|
||||
try:
|
||||
category_name = rating.category.name
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
try:
|
||||
probability_name = rating.probability.name
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
safety_feedback_details = (
|
||||
f"Category: {category_name}, Probability: {probability_name}"
|
||||
)
|
||||
break
|
||||
except (AttributeError, TypeError):
|
||||
# Individual rating doesn't have expected attributes
|
||||
continue
|
||||
except (AttributeError, TypeError):
|
||||
# candidate doesn't have safety_ratings or it's not iterable
|
||||
pass
|
||||
|
||||
# Also check for prompt-level blocking (request rejected entirely)
|
||||
elif response.candidates is not None and len(response.candidates) == 0:
|
||||
# No candidates is the primary indicator of a prompt-level block
|
||||
is_blocked_by_safety = True
|
||||
finish_reason_str = "SAFETY"
|
||||
safety_feedback_details = "Prompt blocked, reason unavailable" # Default message
|
||||
finish_reason_str = finish_reason_enum.name
|
||||
except AttributeError:
|
||||
finish_reason_str = str(finish_reason_enum)
|
||||
else:
|
||||
finish_reason_str = "STOP"
|
||||
except AttributeError:
|
||||
finish_reason_str = "STOP"
|
||||
|
||||
if not response.text:
|
||||
try:
|
||||
prompt_feedback = response.prompt_feedback
|
||||
if prompt_feedback and prompt_feedback.block_reason:
|
||||
try:
|
||||
block_reason_name = prompt_feedback.block_reason.name
|
||||
except AttributeError:
|
||||
block_reason_name = str(prompt_feedback.block_reason)
|
||||
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
|
||||
safety_ratings = candidate.safety_ratings
|
||||
if safety_ratings:
|
||||
for rating in safety_ratings:
|
||||
try:
|
||||
if rating.blocked:
|
||||
is_blocked_by_safety = True
|
||||
category_name = "UNKNOWN"
|
||||
probability_name = "UNKNOWN"
|
||||
|
||||
try:
|
||||
category_name = rating.category.name
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
try:
|
||||
probability_name = rating.probability.name
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
safety_feedback_details = (
|
||||
f"Category: {category_name}, Probability: {probability_name}"
|
||||
)
|
||||
break
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
except (AttributeError, TypeError):
|
||||
# prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
|
||||
pass
|
||||
|
||||
return ModelResponse(
|
||||
content=response.text,
|
||||
usage=usage,
|
||||
model_name=resolved_name,
|
||||
friendly_name="Gemini",
|
||||
provider=ProviderType.GOOGLE,
|
||||
metadata={
|
||||
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
|
||||
"finish_reason": finish_reason_str,
|
||||
"is_blocked_by_safety": is_blocked_by_safety,
|
||||
"safety_feedback": safety_feedback_details,
|
||||
},
|
||||
)
|
||||
elif response.candidates is not None and len(response.candidates) == 0:
|
||||
is_blocked_by_safety = True
|
||||
finish_reason_str = "SAFETY"
|
||||
safety_feedback_details = "Prompt blocked, reason unavailable"
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
try:
|
||||
prompt_feedback = response.prompt_feedback
|
||||
if prompt_feedback and prompt_feedback.block_reason:
|
||||
try:
|
||||
block_reason_name = prompt_feedback.block_reason.name
|
||||
except AttributeError:
|
||||
block_reason_name = str(prompt_feedback.block_reason)
|
||||
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# Check if this is a retryable error using structured error codes
|
||||
is_retryable = self._is_error_retryable(e)
|
||||
return ModelResponse(
|
||||
content=response.text,
|
||||
usage=usage,
|
||||
model_name=resolved_name,
|
||||
friendly_name="Gemini",
|
||||
provider=ProviderType.GOOGLE,
|
||||
metadata={
|
||||
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
|
||||
"finish_reason": finish_reason_str,
|
||||
"is_blocked_by_safety": is_blocked_by_safety,
|
||||
"safety_feedback": safety_feedback_details,
|
||||
},
|
||||
)
|
||||
|
||||
# If this is the last attempt or not retryable, give up
|
||||
if attempt == max_retries - 1 or not is_retryable:
|
||||
break
|
||||
|
||||
# Get progressive delay
|
||||
delay = retry_delays[attempt]
|
||||
|
||||
# Log retry attempt
|
||||
logger.warning(
|
||||
f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
# If we get here, all retries failed
|
||||
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
|
||||
error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
|
||||
raise RuntimeError(error_msg) from last_exception
|
||||
try:
|
||||
return self._run_with_retries(
|
||||
operation=_attempt,
|
||||
max_attempts=max_retries,
|
||||
delays=retry_delays,
|
||||
log_prefix=f"Gemini API ({resolved_name})",
|
||||
)
|
||||
except Exception as exc:
|
||||
attempts = max(attempt_counter["value"], 1)
|
||||
error_msg = (
|
||||
f"Gemini API error for model {resolved_name} after {attempts} attempt"
|
||||
f"{'s' if attempts > 1 else ''}: {exc}"
|
||||
)
|
||||
raise RuntimeError(error_msg) from exc
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
|
||||
Reference in New Issue
Block a user