WIP major refactor and features
tools/base.py
@@ -20,13 +20,12 @@ import re
from abc import ABC, abstractmethod
from typing import Any, Literal, Optional

from google import genai
from google.genai import types
from mcp.types import TextContent
from pydantic import BaseModel, Field

from config import DEFAULT_MODEL, MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
from utils import check_token_limit
from providers import ModelProviderRegistry, ModelProvider, ModelResponse
from utils.conversation_memory import (
    MAX_CONVERSATION_TURNS,
    add_turn,
@@ -52,7 +51,7 @@ class ToolRequest(BaseModel):

    model: Optional[str] = Field(
        None,
        description=f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
        description="Model to use. See tool's input schema for available models and their capabilities.",
    )
    temperature: Optional[float] = Field(None, description="Temperature for response (tool-specific defaults)")
    # Thinking mode controls how much computational budget the model uses for reasoning
@@ -144,6 +143,38 @@ class BaseTool(ABC):
        """
        pass

    def get_model_field_schema(self) -> dict[str, Any]:
        """
        Generate the model field schema based on auto mode configuration.

        When auto mode is enabled, the model parameter becomes required
        and includes detailed descriptions of each model's capabilities.

        Returns:
            Dict containing the model field JSON schema
        """
        from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC

        if IS_AUTO_MODE:
            # In auto mode, model is required and we provide detailed descriptions
            model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
            for model, desc in MODEL_CAPABILITIES_DESC.items():
                model_desc_parts.append(f"- '{model}': {desc}")

            return {
                "type": "string",
                "description": "\n".join(model_desc_parts),
                "enum": list(MODEL_CAPABILITIES_DESC.keys()),
            }
        else:
            # Normal mode - model is optional with default
            available_models = list(MODEL_CAPABILITIES_DESC.keys())
            models_str = ', '.join(f"'{m}'" for m in available_models)
            return {
                "type": "string",
                "description": f"Model to use. Available: {models_str}. Defaults to '{DEFAULT_MODEL}' if not specified.",
            }

    def get_default_temperature(self) -> float:
        """
        Return the default temperature setting for this tool.
@@ -293,6 +324,11 @@ class BaseTool(ABC):
        """
        if not request_files:
            return ""

        # If conversation history is already embedded, skip file processing
        if hasattr(self, '_has_embedded_history') and self._has_embedded_history:
            logger.debug(f"[FILES] {self.name}: Skipping file processing - conversation history already embedded")
            return ""

        # Extract remaining budget from arguments if available
        if remaining_budget is None:
@@ -300,15 +336,59 @@ class BaseTool(ABC):
            args_to_use = arguments or getattr(self, "_current_arguments", {})
            remaining_budget = args_to_use.get("_remaining_tokens")

        # Use remaining budget if provided, otherwise fall back to max_tokens or default
        # Use remaining budget if provided, otherwise fall back to max_tokens or model-specific default
        if remaining_budget is not None:
            effective_max_tokens = remaining_budget - reserve_tokens
        elif max_tokens is not None:
            effective_max_tokens = max_tokens - reserve_tokens
        else:
            from config import MAX_CONTENT_TOKENS

            effective_max_tokens = MAX_CONTENT_TOKENS - reserve_tokens
            # Get model-specific limits
            # First check if model_context was passed from server.py
            model_context = None
            if arguments:
                model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get("_model_context")

            if model_context:
                # Use the passed model context
                try:
                    token_allocation = model_context.calculate_token_allocation()
                    effective_max_tokens = token_allocation.file_tokens - reserve_tokens
                    logger.debug(f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
                                 f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total")
                except Exception as e:
                    logger.warning(f"[FILES] {self.name}: Error using passed model context: {e}")
                    # Fall through to manual calculation
                    model_context = None

            if not model_context:
                # Manual calculation as fallback
                model_name = getattr(self, "_current_model_name", None) or DEFAULT_MODEL
                try:
                    provider = self.get_model_provider(model_name)
                    capabilities = provider.get_capabilities(model_name)

                    # Calculate content allocation based on model capacity
                    if capabilities.max_tokens < 300_000:
                        # Smaller context models: 60% content, 40% response
                        model_content_tokens = int(capabilities.max_tokens * 0.6)
                    else:
                        # Larger context models: 80% content, 20% response
                        model_content_tokens = int(capabilities.max_tokens * 0.8)

                    effective_max_tokens = model_content_tokens - reserve_tokens
                    logger.debug(f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
                                 f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total")
                except (ValueError, AttributeError) as e:
                    # Handle specific errors: provider not found, model not supported, missing attributes
                    logger.warning(f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}")
                    # Fall back to conservative default for safety
                    from config import MAX_CONTENT_TOKENS
                    effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
                except Exception as e:
                    # Catch any other unexpected errors
                    logger.error(f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}")
                    from config import MAX_CONTENT_TOKENS
                    effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens

        # Ensure we have a reasonable minimum budget
        effective_max_tokens = max(1000, effective_max_tokens)
@@ -601,34 +681,59 @@ If any of these would strengthen your analysis, specify what Claude should searc
            )
            return [TextContent(type="text", text=error_output.model_dump_json())]

        # Prepare the full prompt by combining system prompt with user request
        # This is delegated to the tool implementation for customization
        prompt = await self.prepare_prompt(request)

        # Add follow-up instructions for new conversations (not threaded)
        # Check if we have continuation_id - if so, conversation history is already embedded
        continuation_id = getattr(request, "continuation_id", None)
        if not continuation_id:
            # Import here to avoid circular imports

        if continuation_id:
            # When continuation_id is present, server.py has already injected the
            # conversation history into the appropriate field. We need to check if
            # the prompt already contains conversation history marker.
            logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")

            # Store the original arguments to detect enhanced prompts
            self._has_embedded_history = False

            # Check if conversation history is already embedded in the prompt field
            field_value = getattr(request, "prompt", "")
            field_name = "prompt"

            if "=== CONVERSATION HISTORY ===" in field_value:
                # Conversation history is already embedded, use it directly
                prompt = field_value
                self._has_embedded_history = True
                logger.debug(f"{self.name}: Using pre-embedded conversation history from {field_name}")
            else:
                # No embedded history, prepare prompt normally
                prompt = await self.prepare_prompt(request)
                logger.debug(f"{self.name}: No embedded history found, prepared prompt normally")
        else:
            # New conversation, prepare prompt normally
            prompt = await self.prepare_prompt(request)

            # Add follow-up instructions for new conversations
            from server import get_follow_up_instructions

            follow_up_instructions = get_follow_up_instructions(0)  # New conversation, turn 0
            prompt = f"{prompt}\n\n{follow_up_instructions}"

            logger.debug(f"Added follow-up instructions for new {self.name} conversation")

            # Also log to file for debugging MCP issues
            try:
                with open("/tmp/gemini_debug.log", "a") as f:
                    f.write(f"[{self.name}] Added follow-up instructions for new conversation\n")
            except Exception:
                pass
        else:
            logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
            # History reconstruction is handled by server.py:reconstruct_thread_context
            # No need to rebuild it here - prompt already contains conversation history

        # Extract model configuration from request or use defaults
        model_name = getattr(request, "model", None) or DEFAULT_MODEL
        model_name = getattr(request, "model", None)
        if not model_name:
            model_name = DEFAULT_MODEL

        # In auto mode, model parameter is required
        from config import IS_AUTO_MODE
        if IS_AUTO_MODE and model_name.lower() == "auto":
            error_output = ToolOutput(
                status="error",
                content="Model parameter is required. Please specify which model to use for this task.",
                content_type="text",
            )
            return [TextContent(type="text", text=error_output.model_dump_json())]

        # Store model name for use by helper methods like _prepare_file_content_for_prompt
        self._current_model_name = model_name

        temperature = getattr(request, "temperature", None)
        if temperature is None:
            temperature = self.get_default_temperature()
@@ -636,28 +741,45 @@ If any of these would strengthen your analysis, specify what Claude should searc
        if thinking_mode is None:
            thinking_mode = self.get_default_thinking_mode()

        # Create model instance with appropriate configuration
        # This handles both regular models and thinking-enabled models
        model = self.create_model(model_name, temperature, thinking_mode)
        # Get the appropriate model provider
        provider = self.get_model_provider(model_name)

        # Get system prompt for this tool
        system_prompt = self.get_system_prompt()

        # Generate AI response using the configured model
        logger.info(f"Sending request to Gemini API for {self.name}")
        # Generate AI response using the provider
        logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.name}")
        logger.debug(f"Prompt length: {len(prompt)} characters")
        response = model.generate_content(prompt)
        logger.info(f"Received response from Gemini API for {self.name}")

        # Generate content with provider abstraction
        model_response = provider.generate_content(
            prompt=prompt,
            model_name=model_name,
            system_prompt=system_prompt,
            temperature=temperature,
            thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
        )

        logger.info(f"Received response from {provider.get_provider_type().value} API for {self.name}")

        # Process the model's response
        if response.candidates and response.candidates[0].content.parts:
            raw_text = response.candidates[0].content.parts[0].text
        if model_response.content:
            raw_text = model_response.content

            # Parse response to check for clarification requests or format output
            tool_output = self._parse_response(raw_text, request)
            # Pass model info for conversation tracking
            model_info = {
                "provider": provider,
                "model_name": model_name,
                "model_response": model_response
            }
            tool_output = self._parse_response(raw_text, request, model_info)
            logger.info(f"Successfully completed {self.name} tool execution")

        else:
            # Handle cases where the model couldn't generate a response
            # This might happen due to safety filters or other constraints
            finish_reason = response.candidates[0].finish_reason if response.candidates else "Unknown"
            finish_reason = model_response.metadata.get("finish_reason", "Unknown")
            logger.warning(f"Response blocked or incomplete for {self.name}. Finish reason: {finish_reason}")
            tool_output = ToolOutput(
                status="error",
@@ -678,13 +800,24 @@ If any of these would strengthen your analysis, specify what Claude should searc
            if "500 INTERNAL" in error_msg and "Please retry" in error_msg:
                logger.warning(f"500 INTERNAL error in {self.name} - attempting retry")
                try:
                    # Single retry attempt
                    model = self._get_model_wrapper(request)
                    raw_response = await model.generate_content(prompt)
                    response = raw_response.text

                    # If successful, process normally
                    return [TextContent(type="text", text=self._process_response(response, request).model_dump_json())]
                    # Single retry attempt using provider
                    retry_response = provider.generate_content(
                        prompt=prompt,
                        model_name=model_name,
                        system_prompt=system_prompt,
                        temperature=temperature,
                        thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
                    )

                    if retry_response.content:
                        # If successful, process normally
                        retry_model_info = {
                            "provider": provider,
                            "model_name": model_name,
                            "model_response": retry_response
                        }
                        tool_output = self._parse_response(retry_response.content, request, retry_model_info)
                        return [TextContent(type="text", text=tool_output.model_dump_json())]

                except Exception as retry_e:
                    logger.error(f"Retry failed for {self.name} tool: {str(retry_e)}")
@@ -699,7 +832,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
            )
            return [TextContent(type="text", text=error_output.model_dump_json())]

    def _parse_response(self, raw_text: str, request) -> ToolOutput:
    def _parse_response(self, raw_text: str, request, model_info: Optional[dict] = None) -> ToolOutput:
        """
        Parse the raw response and determine if it's a clarification request or follow-up.

@@ -745,11 +878,11 @@ If any of these would strengthen your analysis, specify what Claude should searc
            pass

        # Normal text response - format using tool-specific formatting
        formatted_content = self.format_response(raw_text, request)
        formatted_content = self.format_response(raw_text, request, model_info)

        # If we found a follow-up question, prepare the threading response
        if follow_up_question:
            return self._create_follow_up_response(formatted_content, follow_up_question, request)
            return self._create_follow_up_response(formatted_content, follow_up_question, request, model_info)

        # Check if we should offer Claude a continuation opportunity
        continuation_offer = self._check_continuation_opportunity(request)
@@ -758,7 +891,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
            logger.debug(
                f"Creating continuation offer for {self.name} with {continuation_offer['remaining_turns']} turns remaining"
            )
            return self._create_continuation_offer_response(formatted_content, continuation_offer, request)
            return self._create_continuation_offer_response(formatted_content, continuation_offer, request, model_info)
        else:
            logger.debug(f"No continuation offer created for {self.name}")

@@ -766,12 +899,32 @@ If any of these would strengthen your analysis, specify what Claude should searc
        continuation_id = getattr(request, "continuation_id", None)
        if continuation_id:
            request_files = getattr(request, "files", []) or []
            # Extract model metadata for conversation tracking
            model_provider = None
            model_name = None
            model_metadata = None

            if model_info:
                provider = model_info.get("provider")
                if provider:
                    model_provider = provider.get_provider_type().value
                model_name = model_info.get("model_name")
                model_response = model_info.get("model_response")
                if model_response:
                    model_metadata = {
                        "usage": model_response.usage,
                        "metadata": model_response.metadata
                    }

            success = add_turn(
                continuation_id,
                "assistant",
                formatted_content,
                files=request_files,
                tool_name=self.name,
                model_provider=model_provider,
                model_name=model_name,
                model_metadata=model_metadata,
            )
            if not success:
                logging.warning(f"Failed to add turn to thread {continuation_id} for {self.name}")
@@ -820,7 +973,7 @@ If any of these would strengthen your analysis, specify what Claude should searc

        return None

    def _create_follow_up_response(self, content: str, follow_up_data: dict, request) -> ToolOutput:
    def _create_follow_up_response(self, content: str, follow_up_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
        """
        Create a response with follow-up question for conversation threading.

@@ -832,56 +985,57 @@ If any of these would strengthen your analysis, specify what Claude should searc
        Returns:
            ToolOutput configured for conversation continuation
        """
        # Create or get thread ID
        # Always create a new thread (with parent linkage if continuation)
        continuation_id = getattr(request, "continuation_id", None)
        request_files = getattr(request, "files", []) or []

        try:
            # Create new thread with parent linkage if continuing
            thread_id = create_thread(
                tool_name=self.name,
                initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
                parent_thread_id=continuation_id  # Link to parent thread if continuing
            )

        if continuation_id:
            # This is a continuation - add this turn to existing thread
            request_files = getattr(request, "files", []) or []
            success = add_turn(
                continuation_id,
            # Add the assistant's response with follow-up
            # Extract model metadata
            model_provider = None
            model_name = None
            model_metadata = None

            if model_info:
                provider = model_info.get("provider")
                if provider:
                    model_provider = provider.get_provider_type().value
                model_name = model_info.get("model_name")
                model_response = model_info.get("model_response")
                if model_response:
                    model_metadata = {
                        "usage": model_response.usage,
                        "metadata": model_response.metadata
                    }

            add_turn(
                thread_id,  # Add to the new thread
                "assistant",
                content,
                follow_up_question=follow_up_data.get("follow_up_question"),
                files=request_files,
                tool_name=self.name,
                model_provider=model_provider,
                model_name=model_name,
                model_metadata=model_metadata,
            )
        except Exception as e:
            # Threading failed, return normal response
            logger = logging.getLogger(f"tools.{self.name}")
            logger.warning(f"Follow-up threading failed in {self.name}: {str(e)}")
            return ToolOutput(
                status="success",
                content=content,
                content_type="markdown",
                metadata={"tool_name": self.name, "follow_up_error": str(e)},
            )
            if not success:
                # Thread not found or at limit, return normal response
                return ToolOutput(
                    status="success",
                    content=content,
                    content_type="markdown",
                    metadata={"tool_name": self.name},
                )
            thread_id = continuation_id
        else:
            # Create new thread
            try:
                thread_id = create_thread(
                    tool_name=self.name, initial_request=request.model_dump() if hasattr(request, "model_dump") else {}
                )

                # Add the assistant's response with follow-up
                request_files = getattr(request, "files", []) or []
                add_turn(
                    thread_id,
                    "assistant",
                    content,
                    follow_up_question=follow_up_data.get("follow_up_question"),
                    files=request_files,
                    tool_name=self.name,
                )
            except Exception as e:
                # Threading failed, return normal response
                logger = logging.getLogger(f"tools.{self.name}")
                logger.warning(f"Follow-up threading failed in {self.name}: {str(e)}")
                return ToolOutput(
                    status="success",
                    content=content,
                    content_type="markdown",
                    metadata={"tool_name": self.name, "follow_up_error": str(e)},
                )

        # Create follow-up request
        follow_up_request = FollowUpRequest(
@@ -925,13 +1079,14 @@ If any of these would strengthen your analysis, specify what Claude should searc

        try:
            if continuation_id:
                # Check remaining turns in existing thread
                from utils.conversation_memory import get_thread
                # Check remaining turns in thread chain
                from utils.conversation_memory import get_thread_chain

                context = get_thread(continuation_id)
                if context:
                    current_turns = len(context.turns)
                    remaining_turns = MAX_CONVERSATION_TURNS - current_turns - 1  # -1 for this response
                chain = get_thread_chain(continuation_id)
                if chain:
                    # Count total turns across all threads in chain
                    total_turns = sum(len(thread.turns) for thread in chain)
                    remaining_turns = MAX_CONVERSATION_TURNS - total_turns - 1  # -1 for this response
                else:
                    # Thread not found, don't offer continuation
                    return None
@@ -949,7 +1104,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
            # If anything fails, don't offer continuation
            return None

    def _create_continuation_offer_response(self, content: str, continuation_data: dict, request) -> ToolOutput:
    def _create_continuation_offer_response(self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
        """
        Create a response offering Claude the opportunity to continue conversation.

@@ -962,14 +1117,43 @@ If any of these would strengthen your analysis, specify what Claude should searc
            ToolOutput configured with continuation offer
        """
        try:
            # Create new thread for potential continuation
            # Create new thread for potential continuation (with parent link if continuing)
            continuation_id = getattr(request, "continuation_id", None)
            thread_id = create_thread(
                tool_name=self.name, initial_request=request.model_dump() if hasattr(request, "model_dump") else {}
                tool_name=self.name,
                initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
                parent_thread_id=continuation_id  # Link to parent if this is a continuation
            )

            # Add this response as the first turn (assistant turn)
            request_files = getattr(request, "files", []) or []
            add_turn(thread_id, "assistant", content, files=request_files, tool_name=self.name)
            # Extract model metadata
            model_provider = None
            model_name = None
            model_metadata = None

            if model_info:
                provider = model_info.get("provider")
                if provider:
                    model_provider = provider.get_provider_type().value
                model_name = model_info.get("model_name")
                model_response = model_info.get("model_response")
                if model_response:
                    model_metadata = {
                        "usage": model_response.usage,
                        "metadata": model_response.metadata
                    }

            add_turn(
                thread_id,
                "assistant",
                content,
                files=request_files,
                tool_name=self.name,
                model_provider=model_provider,
                model_name=model_name,
                model_metadata=model_metadata,
            )

            # Create continuation offer
            remaining_turns = continuation_data["remaining_turns"]
@@ -1022,7 +1206,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
        """
        pass

    def format_response(self, response: str, request) -> str:
    def format_response(self, response: str, request, model_info: Optional[dict] = None) -> str:
        """
        Format the model's response for display.

@@ -1033,6 +1217,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
        Args:
            response: The raw response from the model
            request: The original request for context
            model_info: Optional dict with model metadata (provider, model_name, model_response)

        Returns:
            str: Formatted response
@@ -1059,154 +1244,41 @@ If any of these would strengthen your analysis, specify what Claude should searc
            f"{context_type} too large (~{estimated_tokens:,} tokens). Maximum is {MAX_CONTEXT_TOKENS:,} tokens."
        )

    def create_model(self, model_name: str, temperature: float, thinking_mode: str = "medium"):
    def get_model_provider(self, model_name: str) -> ModelProvider:
        """
        Create a configured Gemini model instance.

        This method handles model creation with appropriate settings including
        temperature and thinking budget configuration for models that support it.
        Get a model provider for the specified model.

        Args:
            model_name: Name of the Gemini model to use (or shorthand like 'flash', 'pro')
            temperature: Temperature setting for response generation
            thinking_mode: Thinking depth mode (affects computational budget)
            model_name: Name of the model to use (can be provider-specific or generic)

        Returns:
            Model instance configured and ready for generation
            ModelProvider instance configured for the model

        Raises:
            ValueError: If no provider supports the requested model
        """
        # Define model shorthands for user convenience
        model_shorthands = {
            "pro": "gemini-2.5-pro-preview-06-05",
            "flash": "gemini-2.0-flash-exp",
        }

        # Resolve shorthand to full model name
        resolved_model_name = model_shorthands.get(model_name.lower(), model_name)

        # Map thinking modes to computational budget values
        # Higher budgets allow for more complex reasoning but increase latency
        thinking_budgets = {
            "minimal": 128,  # Minimum for 2.5 Pro - fast responses
            "low": 2048,  # Light reasoning tasks
            "medium": 8192,  # Balanced reasoning (default)
            "high": 16384,  # Complex analysis
            "max": 32768,  # Maximum reasoning depth
        }

        thinking_budget = thinking_budgets.get(thinking_mode, 8192)

        # Gemini 2.5 models support thinking configuration for enhanced reasoning
        # Skip special handling in test environment to allow mocking
        if "2.5" in resolved_model_name and not os.environ.get("PYTEST_CURRENT_TEST"):
            try:
                # Retrieve API key for Gemini client creation
                api_key = os.environ.get("GEMINI_API_KEY")
                if not api_key:
                    raise ValueError("GEMINI_API_KEY environment variable is required")

                client = genai.Client(api_key=api_key)

                # Create a wrapper class to provide a consistent interface
                # This abstracts the differences between API versions
                class ModelWrapper:
                    def __init__(self, client, model_name, temperature, thinking_budget):
                        self.client = client
                        self.model_name = model_name
                        self.temperature = temperature
                        self.thinking_budget = thinking_budget

                    def generate_content(self, prompt):
                        response = self.client.models.generate_content(
                            model=self.model_name,
                            contents=prompt,
                            config=types.GenerateContentConfig(
                                temperature=self.temperature,
                                candidate_count=1,
                                thinking_config=types.ThinkingConfig(thinking_budget=self.thinking_budget),
                            ),
                        )

                        # Wrap the response to match the expected format
                        # This ensures compatibility across different API versions
                        class ResponseWrapper:
                            def __init__(self, text):
                                self.text = text
                                self.candidates = [
                                    type(
                                        "obj",
                                        (object,),
                                        {
                                            "content": type(
                                                "obj",
                                                (object,),
                                                {
                                                    "parts": [
                                                        type(
                                                            "obj",
                                                            (object,),
                                                            {"text": text},
                                                        )
                                                    ]
                                                },
                                            )(),
                                            "finish_reason": "STOP",
                                        },
                                    )
                                ]

                        return ResponseWrapper(response.text)

                return ModelWrapper(client, resolved_model_name, temperature, thinking_budget)

            except Exception:
                # Fall back to regular API if thinking configuration fails
                # This ensures the tool remains functional even with API changes
                pass

        # For models that don't support thinking configuration, use standard API
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY environment variable is required")

        client = genai.Client(api_key=api_key)

        # Create a simple wrapper for models without thinking configuration
        # This provides the same interface as the thinking-enabled wrapper
        class SimpleModelWrapper:
            def __init__(self, client, model_name, temperature):
                self.client = client
                self.model_name = model_name
                self.temperature = temperature

            def generate_content(self, prompt):
                response = self.client.models.generate_content(
                    model=self.model_name,
                    contents=prompt,
                    config=types.GenerateContentConfig(
                        temperature=self.temperature,
                        candidate_count=1,
                    ),
                )

                # Convert to match expected format
                class ResponseWrapper:
                    def __init__(self, text):
                        self.text = text
                        self.candidates = [
                            type(
                                "obj",
                                (object,),
                                {
                                    "content": type(
                                        "obj",
                                        (object,),
                                        {"parts": [type("obj", (object,), {"text": text})]},
                                    )(),
                                    "finish_reason": "STOP",
                                },
                            )
                        ]

                return ResponseWrapper(response.text)

        return SimpleModelWrapper(client, resolved_model_name, temperature)
        # Get provider from registry
        provider = ModelProviderRegistry.get_provider_for_model(model_name)

        if not provider:
            # Try to determine provider from model name patterns
            if "gemini" in model_name.lower() or model_name.lower() in ["flash", "pro"]:
                # Register Gemini provider if not already registered
                from providers.gemini import GeminiModelProvider
                from providers.base import ProviderType
                ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
                provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
            elif "gpt" in model_name.lower() or "o3" in model_name.lower():
                # Register OpenAI provider if not already registered
                from providers.openai import OpenAIModelProvider
                from providers.base import ProviderType
                ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
                provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)

        if not provider:
            raise ValueError(
                f"No provider found for model '{model_name}'. "
                f"Ensure the appropriate API key is set and the model name is correct."
            )

        return provider

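For orientation, a minimal usage sketch of the provider flow this commit moves to. The call pattern (get_provider_for_model, supports_thinking_mode, generate_content, model_response.content) is taken from the diff above; the import paths, the explicit registration step, and the example arguments are assumptions for illustration, not code from this commit.

    # Minimal sketch only: mirrors the call pattern added in this commit.
    # Import paths and the explicit registration step are assumptions about the
    # providers package; error handling and API-key setup are omitted.
    from providers import ModelProviderRegistry
    from providers.base import ProviderType
    from providers.gemini import GeminiModelProvider

    # Register the Gemini provider, as the fallback path in get_model_provider does.
    ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

    model_name = "flash"
    provider = ModelProviderRegistry.get_provider_for_model(model_name)

    # generate_content replaces the old genai ModelWrapper.generate_content path.
    model_response = provider.generate_content(
        prompt="Explain what changed in tools/base.py",
        model_name=model_name,
        system_prompt="You are a code review assistant.",
        temperature=0.5,
        thinking_mode="medium" if provider.supports_thinking_mode(model_name) else None,
    )

    print(provider.get_provider_type().value, model_response.content)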