From a07036e6805042895109c00f921c58a09caaa319 Mon Sep 17 00:00:00 2001 From: Fahad Date: Sun, 24 Aug 2025 21:25:01 +0400 Subject: [PATCH] fix: another fix for https://github.com/BeehiveInnovations/zen-mcp-server/issues/251 --- providers/gemini.py | 144 +++++++++++++++++++++++++++++--------------- 1 file changed, 96 insertions(+), 48 deletions(-) diff --git a/providers/gemini.py b/providers/gemini.py index 1bfea75..5c587ad 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -238,38 +238,73 @@ class GeminiModelProvider(ModelProvider): if response.candidates: candidate = response.candidates[0] - finish_reason_enum = getattr(candidate, "finish_reason", None) - if finish_reason_enum: - # Handle both enum objects and string values - finish_reason_str = getattr(finish_reason_enum, "name", str(finish_reason_enum)) - else: + + # Safely get finish reason + try: + finish_reason_enum = candidate.finish_reason + if finish_reason_enum: + # Handle both enum objects and string values + try: + finish_reason_str = finish_reason_enum.name + except AttributeError: + finish_reason_str = str(finish_reason_enum) + else: + finish_reason_str = "STOP" + except AttributeError: finish_reason_str = "STOP" # If content is empty, check safety ratings for the definitive cause - if not response.text and hasattr(candidate, "safety_ratings"): - for rating in candidate.safety_ratings: - if getattr(rating, "blocked", False): - is_blocked_by_safety = True - # Provide details for logging/debugging - category_name = ( - getattr(rating.category, "name", "UNKNOWN") - if hasattr(rating, "category") - else "UNKNOWN" - ) - probability_name = ( - getattr(rating.probability, "name", "UNKNOWN") - if hasattr(rating, "probability") - else "UNKNOWN" - ) - safety_feedback_details = f"Category: {category_name}, Probability: {probability_name}" - break + if not response.text: + try: + safety_ratings = candidate.safety_ratings + if safety_ratings: # Check it's not None or empty + for rating in safety_ratings: + try: + if rating.blocked: + is_blocked_by_safety = True + # Provide details for logging/debugging + category_name = "UNKNOWN" + probability_name = "UNKNOWN" + + try: + category_name = rating.category.name + except (AttributeError, TypeError): + pass + + try: + probability_name = rating.probability.name + except (AttributeError, TypeError): + pass + + safety_feedback_details = ( + f"Category: {category_name}, Probability: {probability_name}" + ) + break + except (AttributeError, TypeError): + # Individual rating doesn't have expected attributes + continue + except (AttributeError, TypeError): + # candidate doesn't have safety_ratings or it's not iterable + pass # Also check for prompt-level blocking (request rejected entirely) - elif hasattr(response, "prompt_feedback") and getattr(response.prompt_feedback, "block_reason", None): + elif response.candidates is not None and len(response.candidates) == 0: + # No candidates is the primary indicator of a prompt-level block is_blocked_by_safety = True - finish_reason_str = "SAFETY" # This is a clear safety block - block_reason_name = getattr(response.prompt_feedback.block_reason, "name", "UNKNOWN") - safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}" + finish_reason_str = "SAFETY" + safety_feedback_details = "Prompt blocked, reason unavailable" # Default message + + try: + prompt_feedback = response.prompt_feedback + if prompt_feedback and prompt_feedback.block_reason: + try: + block_reason_name = prompt_feedback.block_reason.name + except AttributeError: + block_reason_name = str(prompt_feedback.block_reason) + safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}" + except (AttributeError, TypeError): + # prompt_feedback doesn't exist or has unexpected attributes; stick with the default message + pass return ModelResponse( content=response.text, @@ -370,28 +405,35 @@ class GeminiModelProvider(ModelProvider): # Try to extract usage metadata from response # Note: The actual structure depends on the SDK version and response format - if hasattr(response, "usage_metadata"): + try: metadata = response.usage_metadata + if metadata: + # Extract token counts with explicit None checks + input_tokens = None + output_tokens = None - # Extract token counts with explicit None checks - input_tokens = None - output_tokens = None + try: + value = metadata.prompt_token_count + if value is not None: + input_tokens = value + usage["input_tokens"] = value + except (AttributeError, TypeError): + pass - if hasattr(metadata, "prompt_token_count"): - value = metadata.prompt_token_count - if value is not None: - input_tokens = value - usage["input_tokens"] = value + try: + value = metadata.candidates_token_count + if value is not None: + output_tokens = value + usage["output_tokens"] = value + except (AttributeError, TypeError): + pass - if hasattr(metadata, "candidates_token_count"): - value = metadata.candidates_token_count - if value is not None: - output_tokens = value - usage["output_tokens"] = value - - # Calculate total only if both values are available and valid - if input_tokens is not None and output_tokens is not None: - usage["total_tokens"] = input_tokens + output_tokens + # Calculate total only if both values are available and valid + if input_tokens is not None and output_tokens is not None: + usage["total_tokens"] = input_tokens + output_tokens + except (AttributeError, TypeError): + # response doesn't have usage_metadata + pass return usage @@ -438,11 +480,17 @@ class GeminiModelProvider(ModelProvider): # Also check if this is a structured error from Gemini SDK try: # Try to access error details if available - if hasattr(error, "details") or hasattr(error, "reason"): - # Gemini API errors may have structured details - error_details = getattr(error, "details", "") or getattr(error, "reason", "") - error_details_str = str(error_details).lower() + error_details = None + try: + error_details = error.details + except AttributeError: + try: + error_details = error.reason + except AttributeError: + pass + if error_details: + error_details_str = str(error_details).lower() # Check for non-retryable error codes/reasons if any(indicator in error_details_str for indicator in non_retryable_indicators): logger.debug(f"Non-retryable Gemini error: {error_details}")