WIP major refactor and features

This commit is contained in:
Fahad
2025-06-12 07:14:59 +04:00
parent e06a6fd1fc
commit 2a067a7f4e
46 changed files with 2960 additions and 1011 deletions

185
providers/gemini.py Normal file
View File

@@ -0,0 +1,185 @@
"""Gemini model provider implementation."""
import os
from typing import Dict, Optional, List
from google import genai
from google.genai import types
from .base import ModelProvider, ModelResponse, ModelCapabilities, ProviderType
class GeminiModelProvider(ModelProvider):
"""Google Gemini model provider implementation."""
# Model configurations
SUPPORTED_MODELS = {
"gemini-2.0-flash-exp": {
"max_tokens": 1_048_576, # 1M tokens
"supports_extended_thinking": False,
},
"gemini-2.5-pro-preview-06-05": {
"max_tokens": 1_048_576, # 1M tokens
"supports_extended_thinking": True,
},
# Shorthands
"flash": "gemini-2.0-flash-exp",
"pro": "gemini-2.5-pro-preview-06-05",
}
# Thinking mode configurations for models that support it
THINKING_BUDGETS = {
"minimal": 128, # Minimum for 2.5 Pro - fast responses
"low": 2048, # Light reasoning tasks
"medium": 8192, # Balanced reasoning (default)
"high": 16384, # Complex analysis
"max": 32768, # Maximum reasoning depth
}
def __init__(self, api_key: str, **kwargs):
"""Initialize Gemini provider with API key."""
super().__init__(api_key, **kwargs)
self._client = None
self._token_counters = {} # Cache for token counting
@property
def client(self):
"""Lazy initialization of Gemini client."""
if self._client is None:
self._client = genai.Client(api_key=self.api_key)
return self._client
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific Gemini model."""
# Resolve shorthand
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported Gemini model: {model_name}")
config = self.SUPPORTED_MODELS[resolved_name]
return ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name=resolved_name,
friendly_name="Gemini",
max_tokens=config["max_tokens"],
supports_extended_thinking=config["supports_extended_thinking"],
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
temperature_range=(0.0, 2.0),
)
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
thinking_mode: str = "medium",
**kwargs
) -> ModelResponse:
"""Generate content using Gemini model."""
# Validate parameters
resolved_name = self._resolve_model_name(model_name)
self.validate_parameters(resolved_name, temperature)
# Combine system prompt with user prompt if provided
if system_prompt:
full_prompt = f"{system_prompt}\n\n{prompt}"
else:
full_prompt = prompt
# Prepare generation config
generation_config = types.GenerateContentConfig(
temperature=temperature,
candidate_count=1,
)
# Add max output tokens if specified
if max_output_tokens:
generation_config.max_output_tokens = max_output_tokens
# Add thinking configuration for models that support it
capabilities = self.get_capabilities(resolved_name)
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
generation_config.thinking_config = types.ThinkingConfig(
thinking_budget=self.THINKING_BUDGETS[thinking_mode]
)
try:
# Generate content
response = self.client.models.generate_content(
model=resolved_name,
contents=full_prompt,
config=generation_config,
)
# Extract usage information if available
usage = self._extract_usage(response)
return ModelResponse(
content=response.text,
usage=usage,
model_name=resolved_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP",
}
)
except Exception as e:
# Log error and re-raise with more context
error_msg = f"Gemini API error for model {resolved_name}: {str(e)}"
raise RuntimeError(error_msg) from e
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text using Gemini's tokenizer."""
resolved_name = self._resolve_model_name(model_name)
# For now, use a simple estimation
# TODO: Use actual Gemini tokenizer when available in SDK
# Rough estimation: ~4 characters per token for English text
return len(text) // 4
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.GOOGLE
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported."""
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
capabilities = self.get_capabilities(model_name)
return capabilities.supports_extended_thinking
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name."""
# Check if it's a shorthand
shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
if isinstance(shorthand_value, str):
return shorthand_value
return model_name
def _extract_usage(self, response) -> Dict[str, int]:
"""Extract token usage from Gemini response."""
usage = {}
# Try to extract usage metadata from response
# Note: The actual structure depends on the SDK version and response format
if hasattr(response, "usage_metadata"):
metadata = response.usage_metadata
if hasattr(metadata, "prompt_token_count"):
usage["input_tokens"] = metadata.prompt_token_count
if hasattr(metadata, "candidates_token_count"):
usage["output_tokens"] = metadata.candidates_token_count
if "input_tokens" in usage and "output_tokens" in usage:
usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
return usage