|
|
|
|
@@ -1,5 +1,5 @@
|
|
|
|
|
"""
|
|
|
|
|
Base class for all Gemini MCP tools
|
|
|
|
|
Base class for all Zen MCP tools
|
|
|
|
|
|
|
|
|
|
This module provides the abstract base class that all tools must inherit from.
|
|
|
|
|
It defines the contract that tools must implement and provides common functionality
|
|
|
|
|
@@ -24,8 +24,8 @@ from mcp.types import TextContent
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
from config import DEFAULT_MODEL, MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
|
|
|
|
|
from providers import ModelProvider, ModelProviderRegistry
|
|
|
|
|
from utils import check_token_limit
|
|
|
|
|
from providers import ModelProviderRegistry, ModelProvider, ModelResponse
|
|
|
|
|
from utils.conversation_memory import (
|
|
|
|
|
MAX_CONVERSATION_TURNS,
|
|
|
|
|
add_turn,
|
|
|
|
|
@@ -146,21 +146,21 @@ class BaseTool(ABC):
|
|
|
|
|
def get_model_field_schema(self) -> dict[str, Any]:
|
|
|
|
|
"""
|
|
|
|
|
Generate the model field schema based on auto mode configuration.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
When auto mode is enabled, the model parameter becomes required
|
|
|
|
|
and includes detailed descriptions of each model's capabilities.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Dict containing the model field JSON schema
|
|
|
|
|
"""
|
|
|
|
|
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if IS_AUTO_MODE:
|
|
|
|
|
# In auto mode, model is required and we provide detailed descriptions
|
|
|
|
|
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
|
|
|
|
|
for model, desc in MODEL_CAPABILITIES_DESC.items():
|
|
|
|
|
model_desc_parts.append(f"- '{model}': {desc}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
"type": "string",
|
|
|
|
|
"description": "\n".join(model_desc_parts),
|
|
|
|
|
@@ -169,12 +169,12 @@ class BaseTool(ABC):
|
|
|
|
|
else:
|
|
|
|
|
# Normal mode - model is optional with default
|
|
|
|
|
available_models = list(MODEL_CAPABILITIES_DESC.keys())
|
|
|
|
|
models_str = ', '.join(f"'{m}'" for m in available_models)
|
|
|
|
|
models_str = ", ".join(f"'{m}'" for m in available_models)
|
|
|
|
|
return {
|
|
|
|
|
"type": "string",
|
|
|
|
|
"type": "string",
|
|
|
|
|
"description": f"Model to use. Available: {models_str}. Defaults to '{DEFAULT_MODEL}' if not specified.",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_default_temperature(self) -> float:
|
|
|
|
|
"""
|
|
|
|
|
Return the default temperature setting for this tool.
|
|
|
|
|
@@ -257,9 +257,7 @@ class BaseTool(ABC):
|
|
|
|
|
# Safety check: If no files are marked as embedded but we have a continuation_id,
|
|
|
|
|
# this might indicate an issue with conversation history. Be conservative.
|
|
|
|
|
if not embedded_files:
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"{self.name} tool: No files found in conversation history for thread {continuation_id}"
|
|
|
|
|
)
|
|
|
|
|
logger.debug(f"{self.name} tool: No files found in conversation history for thread {continuation_id}")
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"[FILES] {self.name}: No embedded files found, returning all {len(requested_files)} requested files"
|
|
|
|
|
)
|
|
|
|
|
@@ -324,7 +322,7 @@ class BaseTool(ABC):
|
|
|
|
|
"""
|
|
|
|
|
if not request_files:
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Note: Even if conversation history is already embedded, we still need to process
|
|
|
|
|
# any NEW files that aren't in the conversation history yet. The filter_new_files
|
|
|
|
|
# method will correctly identify which files need to be embedded.
|
|
|
|
|
@@ -345,48 +343,60 @@ class BaseTool(ABC):
|
|
|
|
|
# First check if model_context was passed from server.py
|
|
|
|
|
model_context = None
|
|
|
|
|
if arguments:
|
|
|
|
|
model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get("_model_context")
|
|
|
|
|
|
|
|
|
|
model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get(
|
|
|
|
|
"_model_context"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if model_context:
|
|
|
|
|
# Use the passed model context
|
|
|
|
|
try:
|
|
|
|
|
token_allocation = model_context.calculate_token_allocation()
|
|
|
|
|
effective_max_tokens = token_allocation.file_tokens - reserve_tokens
|
|
|
|
|
logger.debug(f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
|
|
|
|
|
f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total")
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
|
|
|
|
|
f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total"
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"[FILES] {self.name}: Error using passed model context: {e}")
|
|
|
|
|
# Fall through to manual calculation
|
|
|
|
|
model_context = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not model_context:
|
|
|
|
|
# Manual calculation as fallback
|
|
|
|
|
model_name = getattr(self, "_current_model_name", None) or DEFAULT_MODEL
|
|
|
|
|
try:
|
|
|
|
|
provider = self.get_model_provider(model_name)
|
|
|
|
|
capabilities = provider.get_capabilities(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Calculate content allocation based on model capacity
|
|
|
|
|
if capabilities.max_tokens < 300_000:
|
|
|
|
|
# Smaller context models: 60% content, 40% response
|
|
|
|
|
model_content_tokens = int(capabilities.max_tokens * 0.6)
|
|
|
|
|
else:
|
|
|
|
|
# Larger context models: 80% content, 20% response
|
|
|
|
|
# Larger context models: 80% content, 20% response
|
|
|
|
|
model_content_tokens = int(capabilities.max_tokens * 0.8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
effective_max_tokens = model_content_tokens - reserve_tokens
|
|
|
|
|
logger.debug(f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
|
|
|
|
|
f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total")
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
|
|
|
|
|
f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total"
|
|
|
|
|
)
|
|
|
|
|
except (ValueError, AttributeError) as e:
|
|
|
|
|
# Handle specific errors: provider not found, model not supported, missing attributes
|
|
|
|
|
logger.warning(f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}")
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}"
|
|
|
|
|
)
|
|
|
|
|
# Fall back to conservative default for safety
|
|
|
|
|
from config import MAX_CONTENT_TOKENS
|
|
|
|
|
|
|
|
|
|
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# Catch any other unexpected errors
|
|
|
|
|
logger.error(f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}")
|
|
|
|
|
logger.error(
|
|
|
|
|
f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}"
|
|
|
|
|
)
|
|
|
|
|
from config import MAX_CONTENT_TOKENS
|
|
|
|
|
|
|
|
|
|
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
|
|
|
|
|
|
|
|
|
|
# Ensure we have a reasonable minimum budget
|
|
|
|
|
@@ -394,12 +404,16 @@ class BaseTool(ABC):
|
|
|
|
|
|
|
|
|
|
files_to_embed = self.filter_new_files(request_files, continuation_id)
|
|
|
|
|
logger.debug(f"[FILES] {self.name}: Will embed {len(files_to_embed)} files after filtering")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Log the specific files for debugging/testing
|
|
|
|
|
if files_to_embed:
|
|
|
|
|
logger.info(f"[FILE_PROCESSING] {self.name} tool will embed new files: {', '.join([os.path.basename(f) for f in files_to_embed])}")
|
|
|
|
|
logger.info(
|
|
|
|
|
f"[FILE_PROCESSING] {self.name} tool will embed new files: {', '.join([os.path.basename(f) for f in files_to_embed])}"
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
logger.info(f"[FILE_PROCESSING] {self.name} tool: No new files to embed (all files already in conversation history)")
|
|
|
|
|
logger.info(
|
|
|
|
|
f"[FILE_PROCESSING] {self.name} tool: No new files to embed (all files already in conversation history)"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
content_parts = []
|
|
|
|
|
|
|
|
|
|
@@ -688,20 +702,20 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
|
|
|
|
|
# Check if we have continuation_id - if so, conversation history is already embedded
|
|
|
|
|
continuation_id = getattr(request, "continuation_id", None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if continuation_id:
|
|
|
|
|
# When continuation_id is present, server.py has already injected the
|
|
|
|
|
# conversation history into the appropriate field. We need to check if
|
|
|
|
|
# the prompt already contains conversation history marker.
|
|
|
|
|
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Store the original arguments to detect enhanced prompts
|
|
|
|
|
self._has_embedded_history = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check if conversation history is already embedded in the prompt field
|
|
|
|
|
field_value = getattr(request, "prompt", "")
|
|
|
|
|
field_name = "prompt"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "=== CONVERSATION HISTORY ===" in field_value:
|
|
|
|
|
# Conversation history is already embedded, use it directly
|
|
|
|
|
prompt = field_value
|
|
|
|
|
@@ -714,9 +728,10 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
else:
|
|
|
|
|
# New conversation, prepare prompt normally
|
|
|
|
|
prompt = await self.prepare_prompt(request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Add follow-up instructions for new conversations
|
|
|
|
|
from server import get_follow_up_instructions
|
|
|
|
|
|
|
|
|
|
follow_up_instructions = get_follow_up_instructions(0) # New conversation, turn 0
|
|
|
|
|
prompt = f"{prompt}\n\n{follow_up_instructions}"
|
|
|
|
|
logger.debug(f"Added follow-up instructions for new {self.name} conversation")
|
|
|
|
|
@@ -725,9 +740,10 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
model_name = getattr(request, "model", None)
|
|
|
|
|
if not model_name:
|
|
|
|
|
model_name = DEFAULT_MODEL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# In auto mode, model parameter is required
|
|
|
|
|
from config import IS_AUTO_MODE
|
|
|
|
|
|
|
|
|
|
if IS_AUTO_MODE and model_name.lower() == "auto":
|
|
|
|
|
error_output = ToolOutput(
|
|
|
|
|
status="error",
|
|
|
|
|
@@ -735,10 +751,10 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
content_type="text",
|
|
|
|
|
)
|
|
|
|
|
return [TextContent(type="text", text=error_output.model_dump_json())]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Store model name for use by helper methods like _prepare_file_content_for_prompt
|
|
|
|
|
self._current_model_name = model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
temperature = getattr(request, "temperature", None)
|
|
|
|
|
if temperature is None:
|
|
|
|
|
temperature = self.get_default_temperature()
|
|
|
|
|
@@ -748,14 +764,14 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
|
|
|
|
|
# Get the appropriate model provider
|
|
|
|
|
provider = self.get_model_provider(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Validate and correct temperature for this model
|
|
|
|
|
temperature, temp_warnings = self._validate_and_correct_temperature(model_name, temperature)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Log any temperature corrections
|
|
|
|
|
for warning in temp_warnings:
|
|
|
|
|
logger.warning(warning)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Get system prompt for this tool
|
|
|
|
|
system_prompt = self.get_system_prompt()
|
|
|
|
|
|
|
|
|
|
@@ -763,16 +779,16 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.name}")
|
|
|
|
|
logger.info(f"Using model: {model_name} via {provider.get_provider_type().value} provider")
|
|
|
|
|
logger.debug(f"Prompt length: {len(prompt)} characters")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Generate content with provider abstraction
|
|
|
|
|
model_response = provider.generate_content(
|
|
|
|
|
prompt=prompt,
|
|
|
|
|
model_name=model_name,
|
|
|
|
|
system_prompt=system_prompt,
|
|
|
|
|
temperature=temperature,
|
|
|
|
|
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
|
|
|
|
|
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"Received response from {provider.get_provider_type().value} API for {self.name}")
|
|
|
|
|
|
|
|
|
|
# Process the model's response
|
|
|
|
|
@@ -781,11 +797,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
|
|
|
|
|
# Parse response to check for clarification requests or format output
|
|
|
|
|
# Pass model info for conversation tracking
|
|
|
|
|
model_info = {
|
|
|
|
|
"provider": provider,
|
|
|
|
|
"model_name": model_name,
|
|
|
|
|
"model_response": model_response
|
|
|
|
|
}
|
|
|
|
|
model_info = {"provider": provider, "model_name": model_name, "model_response": model_response}
|
|
|
|
|
tool_output = self._parse_response(raw_text, request, model_info)
|
|
|
|
|
logger.info(f"Successfully completed {self.name} tool execution")
|
|
|
|
|
|
|
|
|
|
@@ -819,15 +831,15 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
model_name=model_name,
|
|
|
|
|
system_prompt=system_prompt,
|
|
|
|
|
temperature=temperature,
|
|
|
|
|
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
|
|
|
|
|
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if retry_response.content:
|
|
|
|
|
# If successful, process normally
|
|
|
|
|
retry_model_info = {
|
|
|
|
|
"provider": provider,
|
|
|
|
|
"model_name": model_name,
|
|
|
|
|
"model_response": retry_response
|
|
|
|
|
"model_response": retry_response,
|
|
|
|
|
}
|
|
|
|
|
tool_output = self._parse_response(retry_response.content, request, retry_model_info)
|
|
|
|
|
return [TextContent(type="text", text=tool_output.model_dump_json())]
|
|
|
|
|
@@ -916,7 +928,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
model_provider = None
|
|
|
|
|
model_name = None
|
|
|
|
|
model_metadata = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if model_info:
|
|
|
|
|
provider = model_info.get("provider")
|
|
|
|
|
if provider:
|
|
|
|
|
@@ -924,11 +936,8 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
model_name = model_info.get("model_name")
|
|
|
|
|
model_response = model_info.get("model_response")
|
|
|
|
|
if model_response:
|
|
|
|
|
model_metadata = {
|
|
|
|
|
"usage": model_response.usage,
|
|
|
|
|
"metadata": model_response.metadata
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
|
|
|
|
|
|
|
|
|
|
success = add_turn(
|
|
|
|
|
continuation_id,
|
|
|
|
|
"assistant",
|
|
|
|
|
@@ -986,7 +995,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def _create_follow_up_response(self, content: str, follow_up_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
|
|
|
|
|
def _create_follow_up_response(
|
|
|
|
|
self, content: str, follow_up_data: dict, request, model_info: Optional[dict] = None
|
|
|
|
|
) -> ToolOutput:
|
|
|
|
|
"""
|
|
|
|
|
Create a response with follow-up question for conversation threading.
|
|
|
|
|
|
|
|
|
|
@@ -1001,13 +1012,13 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
# Always create a new thread (with parent linkage if continuation)
|
|
|
|
|
continuation_id = getattr(request, "continuation_id", None)
|
|
|
|
|
request_files = getattr(request, "files", []) or []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# Create new thread with parent linkage if continuing
|
|
|
|
|
thread_id = create_thread(
|
|
|
|
|
tool_name=self.name,
|
|
|
|
|
tool_name=self.name,
|
|
|
|
|
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
|
|
|
|
|
parent_thread_id=continuation_id # Link to parent thread if continuing
|
|
|
|
|
parent_thread_id=continuation_id, # Link to parent thread if continuing
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Add the assistant's response with follow-up
|
|
|
|
|
@@ -1015,7 +1026,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
model_provider = None
|
|
|
|
|
model_name = None
|
|
|
|
|
model_metadata = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if model_info:
|
|
|
|
|
provider = model_info.get("provider")
|
|
|
|
|
if provider:
|
|
|
|
|
@@ -1023,11 +1034,8 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
model_name = model_info.get("model_name")
|
|
|
|
|
model_response = model_info.get("model_response")
|
|
|
|
|
if model_response:
|
|
|
|
|
model_metadata = {
|
|
|
|
|
"usage": model_response.usage,
|
|
|
|
|
"metadata": model_response.metadata
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
|
|
|
|
|
|
|
|
|
|
add_turn(
|
|
|
|
|
thread_id, # Add to the new thread
|
|
|
|
|
"assistant",
|
|
|
|
|
@@ -1088,6 +1096,12 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
Returns:
|
|
|
|
|
Dict with continuation data if opportunity should be offered, None otherwise
|
|
|
|
|
"""
|
|
|
|
|
# Skip continuation offers in test mode
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
if os.getenv("PYTEST_CURRENT_TEST"):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
continuation_id = getattr(request, "continuation_id", None)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
@@ -1117,7 +1131,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
# If anything fails, don't offer continuation
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def _create_continuation_offer_response(self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
|
|
|
|
|
def _create_continuation_offer_response(
|
|
|
|
|
self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None
|
|
|
|
|
) -> ToolOutput:
|
|
|
|
|
"""
|
|
|
|
|
Create a response offering Claude the opportunity to continue conversation.
|
|
|
|
|
|
|
|
|
|
@@ -1133,9 +1149,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
# Create new thread for potential continuation (with parent link if continuing)
|
|
|
|
|
continuation_id = getattr(request, "continuation_id", None)
|
|
|
|
|
thread_id = create_thread(
|
|
|
|
|
tool_name=self.name,
|
|
|
|
|
tool_name=self.name,
|
|
|
|
|
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
|
|
|
|
|
parent_thread_id=continuation_id # Link to parent if this is a continuation
|
|
|
|
|
parent_thread_id=continuation_id, # Link to parent if this is a continuation
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Add this response as the first turn (assistant turn)
|
|
|
|
|
@@ -1144,7 +1160,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
model_provider = None
|
|
|
|
|
model_name = None
|
|
|
|
|
model_metadata = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if model_info:
|
|
|
|
|
provider = model_info.get("provider")
|
|
|
|
|
if provider:
|
|
|
|
|
@@ -1152,16 +1168,13 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
model_name = model_info.get("model_name")
|
|
|
|
|
model_response = model_info.get("model_response")
|
|
|
|
|
if model_response:
|
|
|
|
|
model_metadata = {
|
|
|
|
|
"usage": model_response.usage,
|
|
|
|
|
"metadata": model_response.metadata
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
|
|
|
|
|
|
|
|
|
|
add_turn(
|
|
|
|
|
thread_id,
|
|
|
|
|
"assistant",
|
|
|
|
|
content,
|
|
|
|
|
files=request_files,
|
|
|
|
|
thread_id,
|
|
|
|
|
"assistant",
|
|
|
|
|
content,
|
|
|
|
|
files=request_files,
|
|
|
|
|
tool_name=self.name,
|
|
|
|
|
model_provider=model_provider,
|
|
|
|
|
model_name=model_name,
|
|
|
|
|
@@ -1260,11 +1273,11 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
def _validate_and_correct_temperature(self, model_name: str, temperature: float) -> tuple[float, list[str]]:
|
|
|
|
|
"""
|
|
|
|
|
Validate and correct temperature for the specified model.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model_name: Name of the model to validate temperature for
|
|
|
|
|
temperature: Temperature value to validate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple of (corrected_temperature, warning_messages)
|
|
|
|
|
"""
|
|
|
|
|
@@ -1272,9 +1285,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
provider = self.get_model_provider(model_name)
|
|
|
|
|
capabilities = provider.get_capabilities(model_name)
|
|
|
|
|
constraint = capabilities.temperature_constraint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
warnings = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not constraint.validate(temperature):
|
|
|
|
|
corrected = constraint.get_corrected_value(temperature)
|
|
|
|
|
warning = (
|
|
|
|
|
@@ -1283,9 +1296,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
)
|
|
|
|
|
warnings.append(warning)
|
|
|
|
|
return corrected, warnings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return temperature, warnings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# If validation fails for any reason, use the original temperature
|
|
|
|
|
# and log a warning (but don't fail the request)
|
|
|
|
|
@@ -1308,26 +1321,28 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|
|
|
|
"""
|
|
|
|
|
# Get provider from registry
|
|
|
|
|
provider = ModelProviderRegistry.get_provider_for_model(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not provider:
|
|
|
|
|
# Try to determine provider from model name patterns
|
|
|
|
|
if "gemini" in model_name.lower() or model_name.lower() in ["flash", "pro"]:
|
|
|
|
|
# Register Gemini provider if not already registered
|
|
|
|
|
from providers.gemini import GeminiModelProvider
|
|
|
|
|
from providers.base import ProviderType
|
|
|
|
|
from providers.gemini import GeminiModelProvider
|
|
|
|
|
|
|
|
|
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
|
|
|
|
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
|
|
|
|
elif "gpt" in model_name.lower() or "o3" in model_name.lower():
|
|
|
|
|
# Register OpenAI provider if not already registered
|
|
|
|
|
from providers.openai import OpenAIModelProvider
|
|
|
|
|
from providers.base import ProviderType
|
|
|
|
|
from providers.openai import OpenAIModelProvider
|
|
|
|
|
|
|
|
|
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
|
|
|
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not provider:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"No provider found for model '{model_name}'. "
|
|
|
|
|
f"Ensure the appropriate API key is set and the model name is correct."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return provider
|
|
|
|
|
|