Rebranding, refactoring, renaming, cleanup, updated docs

Fahad
2025-06-12 10:40:43 +04:00
parent 9a55ca8898
commit fb66825bf6
55 changed files with 1048 additions and 1474 deletions


@@ -1,5 +1,5 @@
"""
-Base class for all Gemini MCP tools
+Base class for all Zen MCP tools
This module provides the abstract base class that all tools must inherit from.
It defines the contract that tools must implement and provides common functionality
@@ -24,8 +24,8 @@ from mcp.types import TextContent
from pydantic import BaseModel, Field
from config import DEFAULT_MODEL, MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
+from providers import ModelProvider, ModelProviderRegistry
from utils import check_token_limit
-from providers import ModelProviderRegistry, ModelProvider, ModelResponse
from utils.conversation_memory import (
MAX_CONVERSATION_TURNS,
add_turn,
@@ -146,21 +146,21 @@ class BaseTool(ABC):
def get_model_field_schema(self) -> dict[str, Any]:
"""
Generate the model field schema based on auto mode configuration.
When auto mode is enabled, the model parameter becomes required
and includes detailed descriptions of each model's capabilities.
Returns:
Dict containing the model field JSON schema
"""
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
if IS_AUTO_MODE:
# In auto mode, model is required and we provide detailed descriptions
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
for model, desc in MODEL_CAPABILITIES_DESC.items():
model_desc_parts.append(f"- '{model}': {desc}")
return {
"type": "string",
"description": "\n".join(model_desc_parts),
@@ -169,12 +169,12 @@ class BaseTool(ABC):
else:
# Normal mode - model is optional with default
available_models = list(MODEL_CAPABILITIES_DESC.keys())
-models_str = ', '.join(f"'{m}'" for m in available_models)
+models_str = ", ".join(f"'{m}'" for m in available_models)
return {
"type": "string",
"type": "string",
"description": f"Model to use. Available: {models_str}. Defaults to '{DEFAULT_MODEL}' if not specified.",
}
def get_default_temperature(self) -> float:
"""
Return the default temperature setting for this tool.
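For illustration only (not part of this commit), the schema returned by get_model_field_schema() differs between auto and normal mode roughly as below; the model names and capability descriptions are hypothetical stand-ins for config.MODEL_CAPABILITIES_DESC and DEFAULT_MODEL.

# Illustrative sketch -- 'flash'/'pro' and their descriptions are placeholder values.
auto_mode_schema = {
    "type": "string",
    "description": (
        "Choose the best model for this task based on these capabilities:\n"
        "- 'flash': fast responses for simpler tasks\n"
        "- 'pro': deeper reasoning with a larger context window"
    ),
}
normal_mode_schema = {
    "type": "string",
    "description": "Model to use. Available: 'flash', 'pro'. Defaults to 'pro' if not specified.",
}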
@@ -257,9 +257,7 @@ class BaseTool(ABC):
# Safety check: If no files are marked as embedded but we have a continuation_id,
# this might indicate an issue with conversation history. Be conservative.
if not embedded_files:
-logger.debug(
-f"{self.name} tool: No files found in conversation history for thread {continuation_id}"
-)
+logger.debug(f"{self.name} tool: No files found in conversation history for thread {continuation_id}")
logger.debug(
f"[FILES] {self.name}: No embedded files found, returning all {len(requested_files)} requested files"
)
@@ -324,7 +322,7 @@ class BaseTool(ABC):
"""
if not request_files:
return ""
# Note: Even if conversation history is already embedded, we still need to process
# any NEW files that aren't in the conversation history yet. The filter_new_files
# method will correctly identify which files need to be embedded.
@@ -345,48 +343,60 @@ class BaseTool(ABC):
# First check if model_context was passed from server.py
model_context = None
if arguments:
-model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get("_model_context")
+model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get(
+"_model_context"
+)
if model_context:
# Use the passed model context
try:
token_allocation = model_context.calculate_token_allocation()
effective_max_tokens = token_allocation.file_tokens - reserve_tokens
logger.debug(f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total")
logger.debug(
f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total"
)
except Exception as e:
logger.warning(f"[FILES] {self.name}: Error using passed model context: {e}")
# Fall through to manual calculation
model_context = None
if not model_context:
# Manual calculation as fallback
model_name = getattr(self, "_current_model_name", None) or DEFAULT_MODEL
try:
provider = self.get_model_provider(model_name)
capabilities = provider.get_capabilities(model_name)
# Calculate content allocation based on model capacity
if capabilities.max_tokens < 300_000:
# Smaller context models: 60% content, 40% response
model_content_tokens = int(capabilities.max_tokens * 0.6)
else:
-# Larger context models: 80% content, 20% response
+# Larger context models: 80% content, 20% response
model_content_tokens = int(capabilities.max_tokens * 0.8)
effective_max_tokens = model_content_tokens - reserve_tokens
logger.debug(f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total")
logger.debug(
f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total"
)
except (ValueError, AttributeError) as e:
# Handle specific errors: provider not found, model not supported, missing attributes
logger.warning(f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}")
logger.warning(
f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}"
)
# Fall back to conservative default for safety
from config import MAX_CONTENT_TOKENS
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
except Exception as e:
# Catch any other unexpected errors
logger.error(f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}")
logger.error(
f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}"
)
from config import MAX_CONTENT_TOKENS
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
# Ensure we have a reasonable minimum budget
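To make the fallback above easier to follow, here is a minimal standalone sketch of the budget calculation it performs (illustrative only, not part of the commit; calculate_file_token_budget and fallback_cap are hypothetical names):

def calculate_file_token_budget(max_tokens: int, reserve_tokens: int, fallback_cap: int = 100_000) -> int:
    """Return the token budget available for embedded file content."""
    if max_tokens <= 0:
        # Conservative default when model capabilities are unavailable
        return max(0, fallback_cap - reserve_tokens)
    if max_tokens < 300_000:
        # Smaller context models: 60% content, 40% response
        content_tokens = int(max_tokens * 0.6)
    else:
        # Larger context models: 80% content, 20% response
        content_tokens = int(max_tokens * 0.8)
    return max(0, content_tokens - reserve_tokens)

# Example: a 1M-token model with 50K tokens reserved leaves 750K tokens for files.
assert calculate_file_token_budget(1_000_000, 50_000) == 750_000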
@@ -394,12 +404,16 @@ class BaseTool(ABC):
files_to_embed = self.filter_new_files(request_files, continuation_id)
logger.debug(f"[FILES] {self.name}: Will embed {len(files_to_embed)} files after filtering")
# Log the specific files for debugging/testing
if files_to_embed:
logger.info(f"[FILE_PROCESSING] {self.name} tool will embed new files: {', '.join([os.path.basename(f) for f in files_to_embed])}")
logger.info(
f"[FILE_PROCESSING] {self.name} tool will embed new files: {', '.join([os.path.basename(f) for f in files_to_embed])}"
)
else:
logger.info(f"[FILE_PROCESSING] {self.name} tool: No new files to embed (all files already in conversation history)")
logger.info(
f"[FILE_PROCESSING] {self.name} tool: No new files to embed (all files already in conversation history)"
)
content_parts = []
@@ -688,20 +702,20 @@ If any of these would strengthen your analysis, specify what Claude should searc
# Check if we have continuation_id - if so, conversation history is already embedded
continuation_id = getattr(request, "continuation_id", None)
if continuation_id:
# When continuation_id is present, server.py has already injected the
# conversation history into the appropriate field. We need to check if
# the prompt already contains conversation history marker.
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
# Store the original arguments to detect enhanced prompts
self._has_embedded_history = False
# Check if conversation history is already embedded in the prompt field
field_value = getattr(request, "prompt", "")
field_name = "prompt"
if "=== CONVERSATION HISTORY ===" in field_value:
# Conversation history is already embedded, use it directly
prompt = field_value
@@ -714,9 +728,10 @@ If any of these would strengthen your analysis, specify what Claude should searc
else:
# New conversation, prepare prompt normally
prompt = await self.prepare_prompt(request)
# Add follow-up instructions for new conversations
from server import get_follow_up_instructions
follow_up_instructions = get_follow_up_instructions(0) # New conversation, turn 0
prompt = f"{prompt}\n\n{follow_up_instructions}"
logger.debug(f"Added follow-up instructions for new {self.name} conversation")
@@ -725,9 +740,10 @@ If any of these would strengthen your analysis, specify what Claude should searc
model_name = getattr(request, "model", None)
if not model_name:
model_name = DEFAULT_MODEL
# In auto mode, model parameter is required
from config import IS_AUTO_MODE
if IS_AUTO_MODE and model_name.lower() == "auto":
error_output = ToolOutput(
status="error",
@@ -735,10 +751,10 @@ If any of these would strengthen your analysis, specify what Claude should searc
content_type="text",
)
return [TextContent(type="text", text=error_output.model_dump_json())]
# Store model name for use by helper methods like _prepare_file_content_for_prompt
self._current_model_name = model_name
temperature = getattr(request, "temperature", None)
if temperature is None:
temperature = self.get_default_temperature()
@@ -748,14 +764,14 @@ If any of these would strengthen your analysis, specify what Claude should searc
# Get the appropriate model provider
provider = self.get_model_provider(model_name)
# Validate and correct temperature for this model
temperature, temp_warnings = self._validate_and_correct_temperature(model_name, temperature)
# Log any temperature corrections
for warning in temp_warnings:
logger.warning(warning)
# Get system prompt for this tool
system_prompt = self.get_system_prompt()
@@ -763,16 +779,16 @@ If any of these would strengthen your analysis, specify what Claude should searc
logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.name}")
logger.info(f"Using model: {model_name} via {provider.get_provider_type().value} provider")
logger.debug(f"Prompt length: {len(prompt)} characters")
# Generate content with provider abstraction
model_response = provider.generate_content(
prompt=prompt,
model_name=model_name,
system_prompt=system_prompt,
temperature=temperature,
-thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
+thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None,
)
logger.info(f"Received response from {provider.get_provider_type().value} API for {self.name}")
# Process the model's response
@@ -781,11 +797,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
# Parse response to check for clarification requests or format output
# Pass model info for conversation tracking
-model_info = {
-"provider": provider,
-"model_name": model_name,
-"model_response": model_response
-}
+model_info = {"provider": provider, "model_name": model_name, "model_response": model_response}
tool_output = self._parse_response(raw_text, request, model_info)
logger.info(f"Successfully completed {self.name} tool execution")
@@ -819,15 +831,15 @@ If any of these would strengthen your analysis, specify what Claude should searc
model_name=model_name,
system_prompt=system_prompt,
temperature=temperature,
-thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
+thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None,
)
if retry_response.content:
# If successful, process normally
retry_model_info = {
"provider": provider,
"model_name": model_name,
"model_response": retry_response
"model_response": retry_response,
}
tool_output = self._parse_response(retry_response.content, request, retry_model_info)
return [TextContent(type="text", text=tool_output.model_dump_json())]
@@ -916,7 +928,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
model_provider = None
model_name = None
model_metadata = None
if model_info:
provider = model_info.get("provider")
if provider:
@@ -924,11 +936,8 @@ If any of these would strengthen your analysis, specify what Claude should searc
model_name = model_info.get("model_name")
model_response = model_info.get("model_response")
if model_response:
-model_metadata = {
-"usage": model_response.usage,
-"metadata": model_response.metadata
-}
+model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
success = add_turn(
continuation_id,
"assistant",
@@ -986,7 +995,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
return None
-def _create_follow_up_response(self, content: str, follow_up_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
+def _create_follow_up_response(
+self, content: str, follow_up_data: dict, request, model_info: Optional[dict] = None
+) -> ToolOutput:
"""
Create a response with follow-up question for conversation threading.
@@ -1001,13 +1012,13 @@ If any of these would strengthen your analysis, specify what Claude should searc
# Always create a new thread (with parent linkage if continuation)
continuation_id = getattr(request, "continuation_id", None)
request_files = getattr(request, "files", []) or []
try:
# Create new thread with parent linkage if continuing
thread_id = create_thread(
-tool_name=self.name,
+tool_name=self.name,
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
-parent_thread_id=continuation_id  # Link to parent thread if continuing
+parent_thread_id=continuation_id,  # Link to parent thread if continuing
)
# Add the assistant's response with follow-up
@@ -1015,7 +1026,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
model_provider = None
model_name = None
model_metadata = None
if model_info:
provider = model_info.get("provider")
if provider:
@@ -1023,11 +1034,8 @@ If any of these would strengthen your analysis, specify what Claude should searc
model_name = model_info.get("model_name")
model_response = model_info.get("model_response")
if model_response:
-model_metadata = {
-"usage": model_response.usage,
-"metadata": model_response.metadata
-}
+model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
add_turn(
thread_id, # Add to the new thread
"assistant",
@@ -1088,6 +1096,12 @@ If any of these would strengthen your analysis, specify what Claude should searc
Returns:
Dict with continuation data if opportunity should be offered, None otherwise
"""
+# Skip continuation offers in test mode
+import os
+if os.getenv("PYTEST_CURRENT_TEST"):
+return None
continuation_id = getattr(request, "continuation_id", None)
try:
@@ -1117,7 +1131,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
# If anything fails, don't offer continuation
return None
-def _create_continuation_offer_response(self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
+def _create_continuation_offer_response(
+self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None
+) -> ToolOutput:
"""
Create a response offering Claude the opportunity to continue conversation.
@@ -1133,9 +1149,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
# Create new thread for potential continuation (with parent link if continuing)
continuation_id = getattr(request, "continuation_id", None)
thread_id = create_thread(
-tool_name=self.name,
+tool_name=self.name,
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
-parent_thread_id=continuation_id  # Link to parent if this is a continuation
+parent_thread_id=continuation_id,  # Link to parent if this is a continuation
)
# Add this response as the first turn (assistant turn)
@@ -1144,7 +1160,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
model_provider = None
model_name = None
model_metadata = None
if model_info:
provider = model_info.get("provider")
if provider:
@@ -1152,16 +1168,13 @@ If any of these would strengthen your analysis, specify what Claude should searc
model_name = model_info.get("model_name")
model_response = model_info.get("model_response")
if model_response:
-model_metadata = {
-"usage": model_response.usage,
-"metadata": model_response.metadata
-}
+model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
add_turn(
-thread_id,
-"assistant",
-content,
-files=request_files,
+thread_id,
+"assistant",
+content,
+files=request_files,
tool_name=self.name,
model_provider=model_provider,
model_name=model_name,
@@ -1260,11 +1273,11 @@ If any of these would strengthen your analysis, specify what Claude should searc
def _validate_and_correct_temperature(self, model_name: str, temperature: float) -> tuple[float, list[str]]:
"""
Validate and correct temperature for the specified model.
Args:
model_name: Name of the model to validate temperature for
temperature: Temperature value to validate
Returns:
Tuple of (corrected_temperature, warning_messages)
"""
@@ -1272,9 +1285,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
provider = self.get_model_provider(model_name)
capabilities = provider.get_capabilities(model_name)
constraint = capabilities.temperature_constraint
warnings = []
if not constraint.validate(temperature):
corrected = constraint.get_corrected_value(temperature)
warning = (
@@ -1283,9 +1296,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
)
warnings.append(warning)
return corrected, warnings
return temperature, warnings
except Exception as e:
# If validation fails for any reason, use the original temperature
# and log a warning (but don't fail the request)
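For reference, the temperature correction above only relies on a constraint object exposing validate() and get_corrected_value(). A minimal hypothetical implementation (SimpleRangeConstraint is an illustrative stand-in, not part of the codebase) could look like this:

class SimpleRangeConstraint:
    """Hypothetical stand-in for capabilities.temperature_constraint."""

    def __init__(self, min_temp: float, max_temp: float):
        self.min_temp = min_temp
        self.max_temp = max_temp

    def validate(self, temperature: float) -> bool:
        return self.min_temp <= temperature <= self.max_temp

    def get_corrected_value(self, temperature: float) -> float:
        # Clamp out-of-range values to the nearest supported bound
        return min(max(temperature, self.min_temp), self.max_temp)

constraint = SimpleRangeConstraint(0.0, 1.0)
if not constraint.validate(1.7):
    corrected = constraint.get_corrected_value(1.7)  # clamped to 1.0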
@@ -1308,26 +1321,28 @@ If any of these would strengthen your analysis, specify what Claude should searc
"""
# Get provider from registry
provider = ModelProviderRegistry.get_provider_for_model(model_name)
if not provider:
# Try to determine provider from model name patterns
if "gemini" in model_name.lower() or model_name.lower() in ["flash", "pro"]:
# Register Gemini provider if not already registered
-from providers.gemini import GeminiModelProvider
from providers.base import ProviderType
+from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
elif "gpt" in model_name.lower() or "o3" in model_name.lower():
# Register OpenAI provider if not already registered
-from providers.openai import OpenAIModelProvider
from providers.base import ProviderType
+from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
if not provider:
raise ValueError(
f"No provider found for model '{model_name}'. "
f"Ensure the appropriate API key is set and the model name is correct."
)
return provider
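Condensed, the lazy-registration fallback in get_model_provider() follows the pattern sketched below (illustrative only, assuming the package layout shown in the imports above and a configured Gemini API key; resolve_provider is a hypothetical name and error handling is trimmed):

from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider

def resolve_provider(model_name: str):
    provider = ModelProviderRegistry.get_provider_for_model(model_name)
    if not provider and ("gemini" in model_name.lower() or model_name.lower() in ("flash", "pro")):
        # Register the Gemini provider on demand, then fetch the registered instance
        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
        provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
    if not provider:
        raise ValueError(f"No provider found for model '{model_name}'.")
    return provider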