Files
my-pal-mcp-server/providers/gemini.py
Fahad 9a55ca8898 WIP lots of new tests and validation scenarios
Simulation tests to confirm threading and history traversal
Chain of communication and branching validation tests from live simulation
Temperature enforcement per model
2025-06-12 09:35:05 +04:00

194 lines
7.3 KiB
Python

"""Gemini model provider implementation."""
import os
from typing import Dict, Optional, List
from google import genai
from google.genai import types
from .base import (
ModelProvider,
ModelResponse,
ModelCapabilities,
ProviderType,
RangeTemperatureConstraint
)
class GeminiModelProvider(ModelProvider):
"""Google Gemini model provider implementation."""
# Model configurations
SUPPORTED_MODELS = {
"gemini-2.0-flash-exp": {
"max_tokens": 1_048_576, # 1M tokens
"supports_extended_thinking": False,
},
"gemini-2.5-pro-preview-06-05": {
"max_tokens": 1_048_576, # 1M tokens
"supports_extended_thinking": True,
},
# Shorthands
"flash": "gemini-2.0-flash-exp",
"pro": "gemini-2.5-pro-preview-06-05",
}
# Thinking mode configurations for models that support it
THINKING_BUDGETS = {
"minimal": 128, # Minimum for 2.5 Pro - fast responses
"low": 2048, # Light reasoning tasks
"medium": 8192, # Balanced reasoning (default)
"high": 16384, # Complex analysis
"max": 32768, # Maximum reasoning depth
}
def __init__(self, api_key: str, **kwargs):
"""Initialize Gemini provider with API key."""
super().__init__(api_key, **kwargs)
self._client = None
self._token_counters = {} # Cache for token counting
@property
def client(self):
"""Lazy initialization of Gemini client."""
if self._client is None:
self._client = genai.Client(api_key=self.api_key)
return self._client
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific Gemini model."""
# Resolve shorthand
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported Gemini model: {model_name}")
config = self.SUPPORTED_MODELS[resolved_name]
# Gemini models support 0.0-2.0 temperature range
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
return ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name=resolved_name,
friendly_name="Gemini",
max_tokens=config["max_tokens"],
supports_extended_thinking=config["supports_extended_thinking"],
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
temperature_constraint=temp_constraint,
)
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
thinking_mode: str = "medium",
**kwargs
) -> ModelResponse:
"""Generate content using Gemini model."""
# Validate parameters
resolved_name = self._resolve_model_name(model_name)
self.validate_parameters(resolved_name, temperature)
# Combine system prompt with user prompt if provided
if system_prompt:
full_prompt = f"{system_prompt}\n\n{prompt}"
else:
full_prompt = prompt
# Prepare generation config
generation_config = types.GenerateContentConfig(
temperature=temperature,
candidate_count=1,
)
# Add max output tokens if specified
if max_output_tokens:
generation_config.max_output_tokens = max_output_tokens
# Add thinking configuration for models that support it
capabilities = self.get_capabilities(resolved_name)
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
generation_config.thinking_config = types.ThinkingConfig(
thinking_budget=self.THINKING_BUDGETS[thinking_mode]
)
try:
# Generate content
response = self.client.models.generate_content(
model=resolved_name,
contents=full_prompt,
config=generation_config,
)
# Extract usage information if available
usage = self._extract_usage(response)
return ModelResponse(
content=response.text,
usage=usage,
model_name=resolved_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP",
}
)
except Exception as e:
# Log error and re-raise with more context
error_msg = f"Gemini API error for model {resolved_name}: {str(e)}"
raise RuntimeError(error_msg) from e
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text using Gemini's tokenizer."""
resolved_name = self._resolve_model_name(model_name)
# For now, use a simple estimation
# TODO: Use actual Gemini tokenizer when available in SDK
# Rough estimation: ~4 characters per token for English text
return len(text) // 4
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.GOOGLE
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported."""
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
capabilities = self.get_capabilities(model_name)
return capabilities.supports_extended_thinking
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name."""
# Check if it's a shorthand
shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
if isinstance(shorthand_value, str):
return shorthand_value
return model_name
def _extract_usage(self, response) -> Dict[str, int]:
"""Extract token usage from Gemini response."""
usage = {}
# Try to extract usage metadata from response
# Note: The actual structure depends on the SDK version and response format
if hasattr(response, "usage_metadata"):
metadata = response.usage_metadata
if hasattr(metadata, "prompt_token_count"):
usage["input_tokens"] = metadata.prompt_token_count
if hasattr(metadata, "candidates_token_count"):
usage["output_tokens"] = metadata.candidates_token_count
if "input_tokens" in usage and "output_tokens" in usage:
usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
return usage