Files
my-pal-mcp-server/providers/gemini.py
Fahad 97fa6781cf Vision support via images / PDFs etc. that can be passed on to other models as part of analysis, additional context etc.
Image processing pipeline added
OpenAI GPT-4.1 support
Chat tool prompt enhancement
Lint and code quality improvements
2025-06-16 13:14:53 +04:00

356 lines
14 KiB
Python

"""Gemini model provider implementation."""
import base64
import logging
import os
import time
from typing import Optional
from google import genai
from google.genai import types
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class GeminiModelProvider(ModelProvider):
    """Google Gemini model provider implementation.

    Serves the models listed in ``SUPPORTED_MODELS`` (plus their shorthands)
    through the ``ModelProvider`` interface, with optional thinking budgets
    and inline image parts.
    """

    # Model configurations. Dict values describe a concrete model; str values
    # are shorthands that alias a concrete model name (see _resolve_model_name).
    SUPPORTED_MODELS = {
        "gemini-2.5-flash-preview-05-20": {
            "context_window": 1_048_576,  # 1M tokens
            "supports_extended_thinking": True,
            "max_thinking_tokens": 24576,  # Flash 2.5 thinking budget limit
            "supports_images": True,  # Vision capability
            "max_image_size_mb": 20.0,  # Conservative 20MB limit for reliability
        },
        "gemini-2.5-pro-preview-06-05": {
            "context_window": 1_048_576,  # 1M tokens
            "supports_extended_thinking": True,
            "max_thinking_tokens": 32768,  # Pro 2.5 thinking budget limit
            "supports_images": True,  # Vision capability
            "max_image_size_mb": 32.0,  # Higher limit for Pro model
        },
        # Shorthands
        "flash": "gemini-2.5-flash-preview-05-20",
        "pro": "gemini-2.5-pro-preview-06-05",
    }

    # Thinking mode configurations: fractions of a model's max_thinking_tokens.
    # The same fractions apply to every model that supports thinking.
    THINKING_BUDGETS = {
        "minimal": 0.005,  # 0.5% of max - minimal thinking for fast responses
        "low": 0.08,  # 8% of max - light reasoning tasks
        "medium": 0.33,  # 33% of max - balanced reasoning (default)
        "high": 0.67,  # 67% of max - complex analysis
        "max": 1.0,  # 100% of max - full thinking budget
    }

    # Lower-cased substrings that mark an API error message as transient, so
    # generate_content retries it instead of failing immediately.
    _RETRYABLE_ERROR_TERMS = (
        "timeout",
        "connection",
        "network",
        "temporary",
        "unavailable",
        "retry",
        "429",
        "500",
        "502",
        "503",
        "504",
    )

    def __init__(self, api_key: str, **kwargs):
        """Initialize the Gemini provider.

        Args:
            api_key: API key used when the ``genai.Client`` is created.
            **kwargs: Extra options forwarded unchanged to ``ModelProvider``.
        """
        super().__init__(api_key, **kwargs)
        self._client = None  # created on first access of the ``client`` property
        self._token_counters = {}  # Cache for token counting (not read within this class)

    @property
    def client(self):
        """Return the ``genai.Client``, creating and caching it on first use."""
        if self._client is None:
            self._client = genai.Client(api_key=self.api_key)
        return self._client

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific Gemini model.

        Args:
            model_name: Full model name or a shorthand such as ``"flash"``.

        Returns:
            ModelCapabilities built from the model's ``SUPPORTED_MODELS`` entry.

        Raises:
            ValueError: If the model is unknown or blocked by the restriction policy.
        """
        # Resolve shorthand
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported Gemini model: {model_name}")

        # Check if model is allowed by restrictions (imported lazily, as elsewhere in this class)
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
            raise ValueError(f"Gemini model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        # Gemini models support 0.0-2.0 temperature range
        temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)

        return ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name=resolved_name,
            friendly_name="Gemini",
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_images=config.get("supports_images", False),
            max_image_size_mb=config.get("max_image_size_mb", 0.0),
            temperature_constraint=temp_constraint,
        )

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        thinking_mode: str = "medium",
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using a Gemini model.

        Args:
            prompt: User prompt text.
            model_name: Full model name or shorthand.
            system_prompt: Optional text prepended to the prompt (separated by a blank line).
            temperature: Sampling temperature, checked by ``validate_parameters``.
            max_output_tokens: Output token cap; falsy values leave the API default.
            thinking_mode: Key into ``THINKING_BUDGETS``; unknown keys send no thinking config.
            images: File paths or ``data:image/...`` URLs. They are ignored (with a
                warning) for non-vision models. Images that fail to load are skipped.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ModelResponse with the response text, token usage and metadata
            (``thinking_mode``, ``finish_reason``).

        Raises:
            ValueError: From ``get_capabilities`` when the model is unsupported or not allowed.
            RuntimeError: When the API call fails. Transient errors are retried first.
        """
        # Validate parameters
        resolved_name = self._resolve_model_name(model_name)
        self.validate_parameters(resolved_name, temperature)
        # Resolve capabilities (and enforce restrictions) before doing any image file I/O.
        capabilities = self.get_capabilities(resolved_name)

        # Gemini gets the system prompt and user prompt as a single text part.
        full_prompt = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
        parts = [{"text": full_prompt}]

        # Add images if provided and model supports vision
        if images:
            if self._supports_vision(resolved_name):
                for image_path in images:
                    try:
                        image_part = self._process_image(image_path)
                    except Exception as e:
                        # Best effort: skip this image, keep the text and other images.
                        logger.warning("Failed to process image %s: %s", image_path, e)
                        continue
                    if image_part:
                        parts.append(image_part)
            else:
                logger.warning(
                    "Model %s does not support images, ignoring %d image(s)", resolved_name, len(images)
                )

        # Create contents structure
        contents = [{"parts": parts}]

        # Prepare generation config
        generation_config = types.GenerateContentConfig(
            temperature=temperature,
            candidate_count=1,
        )

        # Add max output tokens if specified
        if max_output_tokens:
            generation_config.max_output_tokens = max_output_tokens

        # Add thinking configuration for models that support it; the budget math
        # is shared with get_thinking_budget so both paths always agree.
        model_config = self.SUPPORTED_MODELS.get(resolved_name)
        if (
            capabilities.supports_extended_thinking
            and thinking_mode in self.THINKING_BUDGETS
            and model_config
            and "max_thinking_tokens" in model_config
        ):
            generation_config.thinking_config = types.ThinkingConfig(
                thinking_budget=self.get_thinking_budget(resolved_name, thinking_mode)
            )

        # Retry logic with exponential backoff
        max_retries = 2  # Total of 2 attempts (1 initial + 1 retry)
        base_delay = 1.0  # Start with 1 second delay
        last_exception = None
        attempts_made = 0

        for attempt in range(max_retries):
            attempts_made = attempt + 1
            try:
                response = self.client.models.generate_content(
                    model=resolved_name,
                    contents=contents,
                    config=generation_config,
                )

                # Extract usage information if available
                usage = self._extract_usage(response)

                return ModelResponse(
                    content=response.text,
                    usage=usage,
                    model_name=resolved_name,
                    friendly_name="Gemini",
                    provider=ProviderType.GOOGLE,
                    metadata={
                        "thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
                        "finish_reason": (
                            getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP"
                        ),
                    },
                )
            except Exception as e:
                last_exception = e

                # Classify by message text: the SDK error type is not inspected here.
                error_str = str(e).lower()
                is_retryable = any(term in error_str for term in self._RETRYABLE_ERROR_TERMS)

                # If this is the last attempt or not retryable, give up
                if attempt == max_retries - 1 or not is_retryable:
                    break

                # Calculate delay with exponential backoff
                delay = base_delay * (2**attempt)
                logger.warning(
                    "Gemini API error for model %s (attempt %d/%d), retrying in %.1fs: %s",
                    resolved_name,
                    attempts_made,
                    max_retries,
                    delay,
                    e,
                )
                time.sleep(delay)

        # Report how many attempts actually ran; a non-retryable error stops after one.
        plural = "s" if attempts_made != 1 else ""
        error_msg = (
            f"Gemini API error for model {resolved_name} after {attempts_made} attempt{plural}: "
            f"{str(last_exception)}"
        )
        raise RuntimeError(error_msg) from last_exception

    def count_tokens(self, text: str, model_name: str) -> int:
        """Estimate the token count of ``text`` as ``len(text) // 4``.

        ``model_name`` is resolved but does not affect the estimate.
        """
        self._resolve_model_name(model_name)

        # For now, use a simple estimation
        # TODO: Use actual Gemini tokenizer when available in SDK
        # Rough estimation: ~4 characters per token for English text
        return len(text) // 4

    def get_provider_type(self) -> ProviderType:
        """Get the provider type (always ``ProviderType.GOOGLE``)."""
        return ProviderType.GOOGLE

    def validate_model_name(self, model_name: str) -> bool:
        """Return True if the model (or shorthand) is supported and allowed by restrictions."""
        resolved_name = self._resolve_model_name(model_name)

        # First check if model is supported (shorthand entries are str values, not configs)
        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return False

        # Then check if model is allowed by restrictions
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
            logger.debug("Gemini model '%s' -> '%s' blocked by restrictions", model_name, resolved_name)
            return False

        return True

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking (raises ValueError like get_capabilities)."""
        capabilities = self.get_capabilities(model_name)
        return capabilities.supports_extended_thinking

    def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
        """Return the thinking token budget for a model and thinking mode.

        The budget is ``int(max_thinking_tokens * THINKING_BUDGETS[thinking_mode])``.
        It is 0 when the model is unknown, lacks thinking support, or the mode is unknown.
        """
        resolved_name = self._resolve_model_name(model_name)
        model_config = self.SUPPORTED_MODELS.get(resolved_name, {})

        if not model_config.get("supports_extended_thinking", False):
            return 0

        if thinking_mode not in self.THINKING_BUDGETS:
            return 0

        max_thinking_tokens = model_config.get("max_thinking_tokens", 0)
        if max_thinking_tokens == 0:
            return 0

        return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve a shorthand (case-insensitive) to its full name; other names pass through unchanged."""
        # Check if it's a shorthand
        shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from a Gemini response.

        Returns a dict with ``input_tokens`` / ``output_tokens`` when the
        response reports them. ``total_tokens`` is added only when both are
        present. Counts reported as None are omitted, so every value is an int.
        """
        usage: dict[str, int] = {}

        # The SDK exposes these attributes even when their value is None, so a
        # hasattr() check alone is not enough; test the values themselves.
        metadata = getattr(response, "usage_metadata", None)
        if metadata is None:
            return usage

        input_tokens = getattr(metadata, "prompt_token_count", None)
        if input_tokens is not None:
            usage["input_tokens"] = input_tokens

        output_tokens = getattr(metadata, "candidates_token_count", None)
        if output_tokens is not None:
            usage["output_tokens"] = output_tokens

        if input_tokens is not None and output_tokens is not None:
            usage["total_tokens"] = input_tokens + output_tokens

        return usage

    def _supports_vision(self, model_name: str) -> bool:
        """Return True if ``model_name`` (already resolved) is in the hard-coded vision-capable set."""
        # Gemini 2.5 models support vision
        vision_models = {
            "gemini-2.5-flash-preview-05-20",
            "gemini-2.5-pro-preview-06-05",
            "gemini-2.0-flash",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
        }
        return model_name in vision_models

    def _process_image(self, image_path: str) -> Optional[dict]:
        """Build a Gemini ``inline_data`` part from a data URL or an image file path.

        Data URLs pass their payload through as-is. Files are read and
        base64-encoded. Returns None if the file is missing or processing fails.
        """
        try:
            if image_path.startswith("data:image/"):
                # Handle data URL: data:image/png;base64,iVBORw0...
                header, data = image_path.split(",", 1)
                mime_type = header.split(";")[0].split(":")[1]
                return {"inline_data": {"mime_type": mime_type, "data": data}}
            else:
                # Handle file path - translate for Docker environment
                from utils.file_types import get_image_mime_type
                from utils.file_utils import translate_path_for_environment

                translated_path = translate_path_for_environment(image_path)
                logger.debug("Translated image path from '%s' to '%s'", image_path, translated_path)

                if not os.path.exists(translated_path):
                    logger.warning("Image file not found: %s (original: %s)", translated_path, image_path)
                    return None

                # Use translated path for all subsequent operations
                image_path = translated_path

                # Detect MIME type from file extension using centralized mappings
                ext = os.path.splitext(image_path)[1].lower()
                mime_type = get_image_mime_type(ext)

                # Read and encode the image
                with open(image_path, "rb") as f:
                    image_data = base64.b64encode(f.read()).decode()

                return {"inline_data": {"mime_type": mime_type, "data": image_data}}
        except Exception as e:
            # Best effort: callers treat None as "skip this image".
            logger.error("Error processing image %s: %s", image_path, e)
            return None