Merge remote-tracking branch 'origin/main' into feature/docker-image-docs

Patryk Ciechanski
2025-06-12 12:09:24 +02:00
68 changed files with 5382 additions and 2163 deletions

View File

@@ -1,14 +1,18 @@
# Gemini MCP Server Environment Configuration
# Zen MCP Server Environment Configuration
# Copy this file to .env and fill in your values
# Required: Google Gemini API Key
# Get your API key from: https://makersuite.google.com/app/apikey
# API Keys - At least one is required
# Get your Gemini API key from: https://makersuite.google.com/app/apikey
GEMINI_API_KEY=your_gemini_api_key_here
# Get your OpenAI API key from: https://platform.openai.com/api-keys
OPENAI_API_KEY=your_openai_api_key_here
# Optional: Default model to use
# Full names: 'gemini-2.5-pro-preview-06-05' or 'gemini-2.0-flash-exp'
# Defaults to gemini-2.5-pro-preview-06-05 if not specified
DEFAULT_MODEL=gemini-2.5-pro-preview-06-05
# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini'
# When set to 'auto', Claude will select the best model for each task
# Defaults to 'auto' if not specified
DEFAULT_MODEL=auto
# Optional: Default thinking mode for ThinkDeep tool
# NOTE: Only applies to models that support extended thinking (e.g., Gemini 2.5 Pro)

View File

@@ -28,12 +28,13 @@ jobs:
- name: Run unit tests
run: |
# Run all tests except live integration tests
# Run all unit tests
# These tests use mocks and don't require API keys
python -m pytest tests/ --ignore=tests/test_live_integration.py -v
python -m pytest tests/ -v
env:
# Ensure no API key is accidentally used in CI
GEMINI_API_KEY: ""
OPENAI_API_KEY: ""
lint:
runs-on: ubuntu-latest
@@ -56,9 +57,9 @@ jobs:
- name: Run ruff linter
run: ruff check .
live-tests:
simulation-tests:
runs-on: ubuntu-latest
# Only run live tests on main branch pushes (requires manual API key setup)
# Only run simulation tests on main branch pushes (requires manual API key setup)
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
steps:
- uses: actions/checkout@v4
@@ -76,24 +77,41 @@ jobs:
- name: Check API key availability
id: check-key
run: |
if [ -z "${{ secrets.GEMINI_API_KEY }}" ]; then
echo "api_key_available=false" >> $GITHUB_OUTPUT
echo "⚠️ GEMINI_API_KEY secret not configured - skipping live tests"
has_key=false
if [ -n "${{ secrets.GEMINI_API_KEY }}" ] || [ -n "${{ secrets.OPENAI_API_KEY }}" ]; then
has_key=true
echo "✅ API key(s) found - running simulation tests"
else
echo "api_key_available=true" >> $GITHUB_OUTPUT
echo "✅ GEMINI_API_KEY found - running live tests"
echo "⚠️ No API keys configured - skipping simulation tests"
fi
echo "api_key_available=$has_key" >> $GITHUB_OUTPUT
- name: Run live integration tests
- name: Set up Docker
if: steps.check-key.outputs.api_key_available == 'true'
uses: docker/setup-buildx-action@v3
- name: Build Docker image
if: steps.check-key.outputs.api_key_available == 'true'
run: |
# Run live tests that make actual API calls
python tests/test_live_integration.py
docker compose build
- name: Run simulation tests
if: steps.check-key.outputs.api_key_available == 'true'
run: |
# Start services
docker compose up -d
# Wait for services to be ready
sleep 10
# Run communication simulator tests
python communication_simulator_test.py --skip-docker
env:
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Skip live tests
- name: Skip simulation tests
if: steps.check-key.outputs.api_key_available == 'false'
run: |
echo "🔒 Live integration tests skipped (no API key configured)"
echo "To enable live tests, add GEMINI_API_KEY as a repository secret"
echo "🔒 Simulation tests skipped (no API keys configured)"
echo "To enable simulation tests, add GEMINI_API_KEY and/or OPENAI_API_KEY as repository secrets"

6
.gitignore vendored
View File

@@ -173,3 +173,9 @@ memory-bank/
@.claude/
@memory-bank/
CLAUDE.md
# Test simulation artifacts (dynamically created during testing)
test_simulation_files/.claude/
# Temporary test directories
test-setup/
/test_simulation_files/**

40
FIX_SUMMARY.md Normal file
View File

@@ -0,0 +1,40 @@
# Fix for Conversation History Bug in Continuation Flow
## Problem
When using `continuation_id` to continue a conversation, the conversation history (with embedded files) was being lost for tools that don't have a `prompt` field. Only new file content was being passed to the tool, resulting in minimal content (e.g., 322 chars for just a NOTE about files already in history).
## Root Cause
1. `reconstruct_thread_context()` builds conversation history and stores it in `arguments["prompt"]`
2. Different tools use different field names for user input:
- `chat``prompt`
- `analyze``question`
- `debug``error_description`
- `codereview``context`
- `thinkdeep``current_analysis`
- `precommit``original_request`
3. The enhanced prompt with conversation history was being placed in the wrong field
4. Tools would only see their new input, not the conversation history
## Solution
Modified `reconstruct_thread_context()` in `server.py` to:
1. Create a mapping of tool names to their primary input fields
2. Extract the user's new input from the correct field based on the tool
3. Store the enhanced prompt (with conversation history) back into the correct field
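A minimal sketch of this approach, using the field names listed in the Root Cause section; the helper and constant names below are illustrative, not the actual `server.py` implementation:

```python
# Illustrative sketch of the field-mapping fix described above.
# The mapping mirrors the Root Cause section; names are hypothetical.
TOOL_PROMPT_FIELDS = {
    "chat": "prompt",
    "analyze": "question",
    "debug": "error_description",
    "codereview": "context",
    "thinkdeep": "current_analysis",
    "precommit": "original_request",
}


def merge_history_into_arguments(tool_name: str, arguments: dict, history: str) -> dict:
    """Place conversation history into the field the tool actually reads."""
    field = TOOL_PROMPT_FIELDS.get(tool_name, "prompt")
    user_input = arguments.get(field, "")
    enhanced = arguments.copy()
    # Prepend the reconstructed history so the tool sees prior turns and embedded files
    enhanced[field] = f"{history}\n\n=== NEW USER INPUT ===\n{user_input}" if history else user_input
    return enhanced
```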
## Changes Made
1. **server.py**:
- Added `prompt_field_mapping` to map tools to their input fields
- Modified to extract user input from the correct field
- Modified to store enhanced prompt in the correct field
2. **tests/test_conversation_field_mapping.py**:
- Added comprehensive tests to verify the fix works for all tools
- Tests ensure conversation history is properly mapped to each tool's field
## Verification
All existing tests pass, including:
- `test_conversation_memory.py` (18 tests)
- `test_cross_tool_continuation.py` (4 tests)
- New `test_conversation_field_mapping.py` (2 tests)
The fix ensures that when continuing conversations, tools receive the full conversation history with embedded files, not just new content.
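For illustration, a regression test for this behaviour could look roughly like the following; the real `tests/test_conversation_field_mapping.py` is not shown in this diff, so the test name and the imported helper (from the sketch above) are assumptions:

```python
import pytest

# Assumes merge_history_into_arguments from the sketch above is importable;
# the real tests exercise server.py's reconstruct_thread_context() instead.
@pytest.mark.parametrize(
    "tool,field",
    [("chat", "prompt"), ("analyze", "question"), ("debug", "error_description")],
)
def test_history_lands_in_tool_field(tool, field):
    args = {field: "new user input"}
    enhanced = merge_history_into_arguments(tool, args, history="earlier turns")
    assert enhanced[field].startswith("earlier turns")
    assert "new user input" in enhanced[field]
```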

704
README.md

File diff suppressed because it is too large

View File

@@ -1,13 +1,17 @@
{
"comment": "Example Claude Desktop configuration for Gemini MCP Server",
"comment": "Example Claude Desktop configuration for Zen MCP Server",
"comment2": "For Docker setup, use examples/claude_config_docker_home.json",
"comment3": "For platform-specific examples, see the examples/ directory",
"mcpServers": {
"gemini": {
"command": "/path/to/gemini-mcp-server/run_gemini.sh",
"env": {
"GEMINI_API_KEY": "your-gemini-api-key-here"
}
"zen": {
"command": "docker",
"args": [
"exec",
"-i",
"zen-mcp-server",
"python",
"server.py"
]
}
}
}

View File

@@ -1,8 +1,8 @@
#!/usr/bin/env python3
"""
Communication Simulator Test for Gemini MCP Server
Communication Simulator Test for Zen MCP Server
This script provides comprehensive end-to-end testing of the Gemini MCP server
This script provides comprehensive end-to-end testing of the Zen MCP server
by simulating real Claude CLI communications and validating conversation
continuity, file handling, deduplication features, and clarification scenarios.
@@ -63,8 +63,8 @@ class CommunicationSimulator:
self.keep_logs = keep_logs
self.selected_tests = selected_tests or []
self.temp_dir = None
self.container_name = "gemini-mcp-server"
self.redis_container = "gemini-mcp-redis"
self.container_name = "zen-mcp-server"
self.redis_container = "zen-mcp-redis"
# Import test registry
from simulator_tests import TEST_REGISTRY
@@ -100,7 +100,7 @@ class CommunicationSimulator:
def setup_test_environment(self) -> bool:
"""Setup fresh Docker environment"""
try:
self.logger.info("🚀 Setting up test environment...")
self.logger.info("Setting up test environment...")
# Create temporary directory for test files
self.temp_dir = tempfile.mkdtemp(prefix="mcp_test_")
@@ -116,7 +116,7 @@ class CommunicationSimulator:
def _setup_docker(self) -> bool:
"""Setup fresh Docker environment"""
try:
self.logger.info("🐳 Setting up Docker environment...")
self.logger.info("Setting up Docker environment...")
# Stop and remove existing containers
self._run_command(["docker", "compose", "down", "--remove-orphans"], check=False, capture_output=True)
@@ -128,27 +128,27 @@ class CommunicationSimulator:
self._run_command(["docker", "rm", container], check=False, capture_output=True)
# Build and start services
self.logger.info("📦 Building Docker images...")
self.logger.info("Building Docker images...")
result = self._run_command(["docker", "compose", "build", "--no-cache"], capture_output=True)
if result.returncode != 0:
self.logger.error(f"Docker build failed: {result.stderr}")
return False
self.logger.info("🚀 Starting Docker services...")
self.logger.info("Starting Docker services...")
result = self._run_command(["docker", "compose", "up", "-d"], capture_output=True)
if result.returncode != 0:
self.logger.error(f"Docker startup failed: {result.stderr}")
return False
# Wait for services to be ready
self.logger.info("Waiting for services to be ready...")
self.logger.info("Waiting for services to be ready...")
time.sleep(10) # Give services time to initialize
# Verify containers are running
if not self._verify_containers():
return False
self.logger.info("Docker environment ready")
self.logger.info("Docker environment ready")
return True
except Exception as e:
@@ -177,7 +177,7 @@ class CommunicationSimulator:
def simulate_claude_cli_session(self) -> bool:
"""Simulate a complete Claude CLI session with conversation continuity"""
try:
self.logger.info("🤖 Starting Claude CLI simulation...")
self.logger.info("Starting Claude CLI simulation...")
# If specific tests are selected, run only those
if self.selected_tests:
@@ -190,7 +190,7 @@ class CommunicationSimulator:
if not self._run_single_test(test_name):
return False
self.logger.info("All tests passed")
self.logger.info("All tests passed")
return True
except Exception as e:
@@ -200,13 +200,13 @@ class CommunicationSimulator:
def _run_selected_tests(self) -> bool:
"""Run only the selected tests"""
try:
self.logger.info(f"🎯 Running selected tests: {', '.join(self.selected_tests)}")
self.logger.info(f"Running selected tests: {', '.join(self.selected_tests)}")
for test_name in self.selected_tests:
if not self._run_single_test(test_name):
return False
self.logger.info("All selected tests passed")
self.logger.info("All selected tests passed")
return True
except Exception as e:
@@ -221,14 +221,14 @@ class CommunicationSimulator:
self.logger.info(f"Available tests: {', '.join(self.available_tests.keys())}")
return False
self.logger.info(f"🧪 Running test: {test_name}")
self.logger.info(f"Running test: {test_name}")
test_function = self.available_tests[test_name]
result = test_function()
if result:
self.logger.info(f"Test {test_name} passed")
self.logger.info(f"Test {test_name} passed")
else:
self.logger.error(f"Test {test_name} failed")
self.logger.error(f"Test {test_name} failed")
return result
@@ -244,12 +244,12 @@ class CommunicationSimulator:
self.logger.info(f"Available tests: {', '.join(self.available_tests.keys())}")
return False
self.logger.info(f"🧪 Running individual test: {test_name}")
self.logger.info(f"Running individual test: {test_name}")
# Setup environment unless skipped
if not skip_docker_setup:
if not self.setup_test_environment():
self.logger.error("Environment setup failed")
self.logger.error("Environment setup failed")
return False
# Run the single test
@@ -257,9 +257,9 @@ class CommunicationSimulator:
result = test_function()
if result:
self.logger.info(f"Individual test {test_name} passed")
self.logger.info(f"Individual test {test_name} passed")
else:
self.logger.error(f"Individual test {test_name} failed")
self.logger.error(f"Individual test {test_name} failed")
return result
@@ -282,40 +282,40 @@ class CommunicationSimulator:
def print_test_summary(self):
"""Print comprehensive test results summary"""
print("\n" + "=" * 70)
print("🧪 GEMINI MCP COMMUNICATION SIMULATOR - TEST RESULTS SUMMARY")
print("ZEN MCP COMMUNICATION SIMULATOR - TEST RESULTS SUMMARY")
print("=" * 70)
passed_count = sum(1 for result in self.test_results.values() if result)
total_count = len(self.test_results)
for test_name, result in self.test_results.items():
status = "PASS" if result else "FAIL"
status = "PASS" if result else "FAIL"
# Get test description
temp_instance = self.test_registry[test_name](verbose=False)
description = temp_instance.test_description
print(f"📝 {description}: {status}")
print(f"{description}: {status}")
print(f"\n🎯 OVERALL RESULT: {'🎉 SUCCESS' if passed_count == total_count else 'FAILURE'}")
print(f"{passed_count}/{total_count} tests passed")
print(f"\nOVERALL RESULT: {'SUCCESS' if passed_count == total_count else 'FAILURE'}")
print(f"{passed_count}/{total_count} tests passed")
print("=" * 70)
return passed_count == total_count
def run_full_test_suite(self, skip_docker_setup: bool = False) -> bool:
"""Run the complete test suite"""
try:
self.logger.info("🚀 Starting Gemini MCP Communication Simulator Test Suite")
self.logger.info("Starting Zen MCP Communication Simulator Test Suite")
# Setup
if not skip_docker_setup:
if not self.setup_test_environment():
self.logger.error("Environment setup failed")
self.logger.error("Environment setup failed")
return False
else:
self.logger.info("Skipping Docker setup (containers assumed running)")
self.logger.info("Skipping Docker setup (containers assumed running)")
# Main simulation
if not self.simulate_claude_cli_session():
self.logger.error("Claude CLI simulation failed")
self.logger.error("Claude CLI simulation failed")
return False
# Print comprehensive summary
@@ -333,13 +333,13 @@ class CommunicationSimulator:
def cleanup(self):
"""Cleanup test environment"""
try:
self.logger.info("🧹 Cleaning up test environment...")
self.logger.info("Cleaning up test environment...")
if not self.keep_logs:
# Stop Docker services
self._run_command(["docker", "compose", "down", "--remove-orphans"], check=False, capture_output=True)
else:
self.logger.info("📋 Keeping Docker services running for log inspection")
self.logger.info("Keeping Docker services running for log inspection")
# Remove temp directory
if self.temp_dir and os.path.exists(self.temp_dir):
@@ -359,7 +359,7 @@ class CommunicationSimulator:
def parse_arguments():
"""Parse and validate command line arguments"""
parser = argparse.ArgumentParser(description="Gemini MCP Communication Simulator Test")
parser = argparse.ArgumentParser(description="Zen MCP Communication Simulator Test")
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging")
parser.add_argument("--keep-logs", action="store_true", help="Keep Docker services running for log inspection")
parser.add_argument("--tests", "-t", nargs="+", help="Specific tests to run (space-separated)")
@@ -392,19 +392,19 @@ def run_individual_test(simulator, test_name, skip_docker):
success = simulator.run_individual_test(test_name, skip_docker_setup=skip_docker)
if success:
print(f"\n🎉 INDIVIDUAL TEST {test_name.upper()}: PASSED")
print(f"\nINDIVIDUAL TEST {test_name.upper()}: PASSED")
return 0
else:
print(f"\nINDIVIDUAL TEST {test_name.upper()}: FAILED")
print(f"\nINDIVIDUAL TEST {test_name.upper()}: FAILED")
return 1
except KeyboardInterrupt:
print(f"\n🛑 Individual test {test_name} interrupted by user")
print(f"\nIndividual test {test_name} interrupted by user")
if not skip_docker:
simulator.cleanup()
return 130
except Exception as e:
print(f"\n💥 Individual test {test_name} failed with error: {e}")
print(f"\nIndividual test {test_name} failed with error: {e}")
if not skip_docker:
simulator.cleanup()
return 1
@@ -416,20 +416,20 @@ def run_test_suite(simulator, skip_docker=False):
success = simulator.run_full_test_suite(skip_docker_setup=skip_docker)
if success:
print("\n🎉 COMPREHENSIVE MCP COMMUNICATION TEST: PASSED")
print("\nCOMPREHENSIVE MCP COMMUNICATION TEST: PASSED")
return 0
else:
print("\nCOMPREHENSIVE MCP COMMUNICATION TEST: FAILED")
print("⚠️ Check detailed results above")
print("\nCOMPREHENSIVE MCP COMMUNICATION TEST: FAILED")
print("Check detailed results above")
return 1
except KeyboardInterrupt:
print("\n🛑 Test interrupted by user")
print("\nTest interrupted by user")
if not skip_docker:
simulator.cleanup()
return 130
except Exception as e:
print(f"\n💥 Unexpected error: {e}")
print(f"\nUnexpected error: {e}")
if not skip_docker:
simulator.cleanup()
return 1

View File

@@ -1,7 +1,7 @@
"""
Configuration and constants for Gemini MCP Server
Configuration and constants for Zen MCP Server
This module centralizes all configuration settings for the Gemini MCP Server.
This module centralizes all configuration settings for the Zen MCP Server.
It defines model configurations, token limits, temperature defaults, and other
constants used throughout the application.
@@ -13,15 +13,43 @@ import os
# Version and metadata
# These values are used in server responses and for tracking releases
# IMPORTANT: This is the single source of truth for version and author info
__version__ = "3.3.0" # Semantic versioning: MAJOR.MINOR.PATCH
__updated__ = "2025-06-11" # Last update date in ISO format
__version__ = "4.0.0" # Semantic versioning: MAJOR.MINOR.PATCH
__updated__ = "2025-06-12" # Last update date in ISO format
__author__ = "Fahad Gilani" # Primary maintainer
# Model configuration
# DEFAULT_MODEL: The default model used for all AI operations
# This should be a stable, high-performance model suitable for code analysis
# Can be overridden by setting DEFAULT_MODEL environment variable
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "gemini-2.5-pro-preview-06-05")
# Special value "auto" means Claude should pick the best model for each task
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
# Validate DEFAULT_MODEL and set to "auto" if invalid
# Only include actually supported models from providers
VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash", "gemini-2.5-pro-preview-06-05"]
if DEFAULT_MODEL not in VALID_MODELS:
import logging
logger = logging.getLogger(__name__)
logger.warning(
f"Invalid DEFAULT_MODEL '{DEFAULT_MODEL}'. Setting to 'auto'. Valid options: {', '.join(VALID_MODELS)}"
)
DEFAULT_MODEL = "auto"
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"
# Model capabilities descriptions for auto mode
# These help Claude choose the best model for each task
MODEL_CAPABILITIES_DESC = {
"flash": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
"pro": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
"o3": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
# Full model names also supported
"gemini-2.0-flash": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
"gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
}
# Token allocation for Gemini Pro (1M total capacity)
# MAX_CONTEXT_TOKENS: Total model capacity
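As a rough illustration of how auto mode is meant to be consumed, a caller might turn `MODEL_CAPABILITIES_DESC` into a model-selection hint for Claude; the helper below is a sketch, not code from this diff:

```python
# Illustrative only: builds a model-selection hint from the config values above.
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC


def model_hint() -> str:
    if not IS_AUTO_MODE:
        return f"Using fixed model: {DEFAULT_MODEL}"
    lines = [f"- {name}: {desc}" for name, desc in MODEL_CAPABILITIES_DESC.items()]
    return "Pick the best model for the task:\n" + "\n".join(lines)


print(model_hint())
```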

View File

@@ -1,7 +1,7 @@
services:
redis:
image: redis:7-alpine
container_name: gemini-mcp-redis
container_name: zen-mcp-redis
restart: unless-stopped
ports:
- "6379:6379"
@@ -20,17 +20,18 @@ services:
reservations:
memory: 256M
gemini-mcp:
zen-mcp:
build: .
image: gemini-mcp-server:latest
container_name: gemini-mcp-server
image: zen-mcp-server:latest
container_name: zen-mcp-server
restart: unless-stopped
depends_on:
redis:
condition: service_healthy
environment:
- GEMINI_API_KEY=${GEMINI_API_KEY:?GEMINI_API_KEY is required. Please set it in your .env file or environment.}
- DEFAULT_MODEL=${DEFAULT_MODEL:-gemini-2.5-pro-preview-06-05}
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
- DEFAULT_MODEL=${DEFAULT_MODEL:-auto}
- DEFAULT_THINKING_MODE_THINKDEEP=${DEFAULT_THINKING_MODE_THINKDEEP:-high}
- REDIS_URL=redis://redis:6379/0
# Use HOME not PWD: Claude needs access to any absolute file path, not just current project,
@@ -42,7 +43,6 @@ services:
- ${HOME:-/tmp}:/workspace:ro
- mcp_logs:/tmp # Shared volume for logs
- /etc/localtime:/etc/localtime:ro
- /etc/timezone:/etc/timezone:ro
stdin_open: true
tty: true
entrypoint: ["python"]
@@ -50,17 +50,16 @@ services:
log-monitor:
build: .
image: gemini-mcp-server:latest
container_name: gemini-mcp-log-monitor
image: zen-mcp-server:latest
container_name: zen-mcp-log-monitor
restart: unless-stopped
depends_on:
- gemini-mcp
- zen-mcp
environment:
- PYTHONUNBUFFERED=1
volumes:
- mcp_logs:/tmp # Shared volume for logs
- /etc/localtime:/etc/localtime:ro
- /etc/timezone:/etc/timezone:ro
entrypoint: ["python"]
command: ["log_monitor.py"]

View File

@@ -1,18 +1,18 @@
{
"comment": "Docker configuration that mounts your home directory",
"comment2": "Update paths: /path/to/gemini-mcp-server/.env and /Users/your-username",
"comment2": "Update paths: /path/to/zen-mcp-server/.env and /Users/your-username",
"comment3": "The container auto-detects /workspace as sandbox from WORKSPACE_ROOT",
"mcpServers": {
"gemini": {
"zen": {
"command": "docker",
"args": [
"run",
"--rm",
"-i",
"--env-file", "/path/to/gemini-mcp-server/.env",
"--env-file", "/path/to/zen-mcp-server/.env",
"-e", "WORKSPACE_ROOT=/Users/your-username",
"-v", "/Users/your-username:/workspace:ro",
"gemini-mcp-server:latest"
"zen-mcp-server:latest"
]
}
}

View File

@@ -1,13 +1,17 @@
{
"comment": "Traditional macOS/Linux configuration (non-Docker)",
"comment2": "Replace YOUR_USERNAME with your actual username",
"comment3": "This gives access to all files under your home directory",
"comment": "macOS configuration using Docker",
"comment2": "Ensure Docker is running and containers are started",
"comment3": "Run './setup-docker.sh' first to set up the environment",
"mcpServers": {
"gemini": {
"command": "/Users/YOUR_USERNAME/gemini-mcp-server/run_gemini.sh",
"env": {
"GEMINI_API_KEY": "your-gemini-api-key-here"
}
"zen": {
"command": "docker",
"args": [
"exec",
"-i",
"zen-mcp-server",
"python",
"server.py"
]
}
}
}

View File

@@ -1,14 +1,18 @@
{
"comment": "Windows configuration using WSL (Windows Subsystem for Linux)",
"comment2": "Replace YOUR_WSL_USERNAME with your WSL username",
"comment3": "Make sure the server is installed in your WSL environment",
"comment": "Windows configuration using WSL with Docker",
"comment2": "Ensure Docker Desktop is running and WSL integration is enabled",
"comment3": "Run './setup-docker.sh' in WSL first to set up the environment",
"mcpServers": {
"gemini": {
"zen": {
"command": "wsl.exe",
"args": ["/home/YOUR_WSL_USERNAME/gemini-mcp-server/run_gemini.sh"],
"env": {
"GEMINI_API_KEY": "your-gemini-api-key-here"
}
"args": [
"docker",
"exec",
"-i",
"zen-mcp-server",
"python",
"server.py"
]
}
}
}

15
providers/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
"""Model provider abstractions for supporting multiple AI providers."""
from .base import ModelCapabilities, ModelProvider, ModelResponse
from .gemini import GeminiModelProvider
from .openai import OpenAIModelProvider
from .registry import ModelProviderRegistry
__all__ = [
"ModelProvider",
"ModelResponse",
"ModelCapabilities",
"ModelProviderRegistry",
"GeminiModelProvider",
"OpenAIModelProvider",
]

220
providers/base.py Normal file
View File

@@ -0,0 +1,220 @@
"""Base model provider interface and data classes."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional
class ProviderType(Enum):
"""Supported model provider types."""
GOOGLE = "google"
OPENAI = "openai"
class TemperatureConstraint(ABC):
"""Abstract base class for temperature constraints."""
@abstractmethod
def validate(self, temperature: float) -> bool:
"""Check if temperature is valid."""
pass
@abstractmethod
def get_corrected_value(self, temperature: float) -> float:
"""Get nearest valid temperature."""
pass
@abstractmethod
def get_description(self) -> str:
"""Get human-readable description of constraint."""
pass
@abstractmethod
def get_default(self) -> float:
"""Get model's default temperature."""
pass
class FixedTemperatureConstraint(TemperatureConstraint):
"""For models that only support one temperature value (e.g., O3)."""
def __init__(self, value: float):
self.value = value
def validate(self, temperature: float) -> bool:
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
def get_corrected_value(self, temperature: float) -> float:
return self.value
def get_description(self) -> str:
return f"Only supports temperature={self.value}"
def get_default(self) -> float:
return self.value
class RangeTemperatureConstraint(TemperatureConstraint):
"""For models supporting continuous temperature ranges."""
def __init__(self, min_temp: float, max_temp: float, default: float = None):
self.min_temp = min_temp
self.max_temp = max_temp
self.default_temp = default or (min_temp + max_temp) / 2
def validate(self, temperature: float) -> bool:
return self.min_temp <= temperature <= self.max_temp
def get_corrected_value(self, temperature: float) -> float:
return max(self.min_temp, min(self.max_temp, temperature))
def get_description(self) -> str:
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
def get_default(self) -> float:
return self.default_temp
class DiscreteTemperatureConstraint(TemperatureConstraint):
"""For models supporting only specific temperature values."""
def __init__(self, allowed_values: list[float], default: float = None):
self.allowed_values = sorted(allowed_values)
self.default_temp = default or allowed_values[len(allowed_values) // 2]
def validate(self, temperature: float) -> bool:
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
def get_corrected_value(self, temperature: float) -> float:
return min(self.allowed_values, key=lambda x: abs(x - temperature))
def get_description(self) -> str:
return f"Supports temperatures: {self.allowed_values}"
def get_default(self) -> float:
return self.default_temp
@dataclass
class ModelCapabilities:
"""Capabilities and constraints for a specific model."""
provider: ProviderType
model_name: str
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
max_tokens: int
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
supports_streaming: bool = True
supports_function_calling: bool = False
# Temperature constraint object - preferred way to define temperature limits
temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
)
# Backward compatibility property for existing code
@property
def temperature_range(self) -> tuple[float, float]:
"""Backward compatibility for existing code that uses temperature_range."""
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
elif isinstance(self.temperature_constraint, FixedTemperatureConstraint):
return (self.temperature_constraint.value, self.temperature_constraint.value)
elif isinstance(self.temperature_constraint, DiscreteTemperatureConstraint):
values = self.temperature_constraint.allowed_values
return (min(values), max(values))
return (0.0, 2.0) # Fallback
@dataclass
class ModelResponse:
"""Response from a model provider."""
content: str
usage: dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
model_name: str = ""
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
provider: ProviderType = ProviderType.GOOGLE
metadata: dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
@property
def total_tokens(self) -> int:
"""Get total tokens used."""
return self.usage.get("total_tokens", 0)
class ModelProvider(ABC):
"""Abstract base class for model providers."""
def __init__(self, api_key: str, **kwargs):
"""Initialize the provider with API key and optional configuration."""
self.api_key = api_key
self.config = kwargs
@abstractmethod
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model."""
pass
@abstractmethod
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the model.
Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
system_prompt: Optional system prompt for model behavior
temperature: Sampling temperature (0-2)
max_output_tokens: Maximum tokens to generate
**kwargs: Provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
pass
@abstractmethod
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text using the specified model's tokenizer."""
pass
@abstractmethod
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
pass
@abstractmethod
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported by this provider."""
pass
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters against capabilities.
Raises:
ValueError: If parameters are invalid
"""
capabilities = self.get_capabilities(model_name)
# Validate temperature
min_temp, max_temp = capabilities.temperature_range
if not min_temp <= temperature <= max_temp:
raise ValueError(
f"Temperature {temperature} out of range [{min_temp}, {max_temp}] " f"for model {model_name}"
)
@abstractmethod
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
pass
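A brief usage sketch of the temperature-constraint classes defined above, with illustrative values:

```python
# Assumes the classes above are importable from providers.base
from providers.base import (
    DiscreteTemperatureConstraint,
    FixedTemperatureConstraint,
    RangeTemperatureConstraint,
)

fixed = FixedTemperatureConstraint(1.0)  # e.g. O3-style models
ranged = RangeTemperatureConstraint(0.0, 2.0, 0.7)
discrete = DiscreteTemperatureConstraint([0.0, 0.5, 1.0])

assert not fixed.validate(0.7) and fixed.get_corrected_value(0.7) == 1.0
assert ranged.validate(0.7) and ranged.get_corrected_value(5.0) == 2.0
assert discrete.get_corrected_value(0.6) == 0.5  # snaps to nearest allowed value
```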

190
providers/gemini.py Normal file
View File

@@ -0,0 +1,190 @@
"""Gemini model provider implementation."""
from typing import Optional
from google import genai
from google.genai import types
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
class GeminiModelProvider(ModelProvider):
"""Google Gemini model provider implementation."""
# Model configurations
SUPPORTED_MODELS = {
"gemini-2.0-flash": {
"max_tokens": 1_048_576, # 1M tokens
"supports_extended_thinking": False,
},
"gemini-2.5-pro-preview-06-05": {
"max_tokens": 1_048_576, # 1M tokens
"supports_extended_thinking": True,
},
# Shorthands
"flash": "gemini-2.0-flash",
"pro": "gemini-2.5-pro-preview-06-05",
}
# Thinking mode configurations for models that support it
THINKING_BUDGETS = {
"minimal": 128, # Minimum for 2.5 Pro - fast responses
"low": 2048, # Light reasoning tasks
"medium": 8192, # Balanced reasoning (default)
"high": 16384, # Complex analysis
"max": 32768, # Maximum reasoning depth
}
def __init__(self, api_key: str, **kwargs):
"""Initialize Gemini provider with API key."""
super().__init__(api_key, **kwargs)
self._client = None
self._token_counters = {} # Cache for token counting
@property
def client(self):
"""Lazy initialization of Gemini client."""
if self._client is None:
self._client = genai.Client(api_key=self.api_key)
return self._client
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific Gemini model."""
# Resolve shorthand
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported Gemini model: {model_name}")
config = self.SUPPORTED_MODELS[resolved_name]
# Gemini models support 0.0-2.0 temperature range
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
return ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name=resolved_name,
friendly_name="Gemini",
max_tokens=config["max_tokens"],
supports_extended_thinking=config["supports_extended_thinking"],
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
temperature_constraint=temp_constraint,
)
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
thinking_mode: str = "medium",
**kwargs,
) -> ModelResponse:
"""Generate content using Gemini model."""
# Validate parameters
resolved_name = self._resolve_model_name(model_name)
self.validate_parameters(resolved_name, temperature)
# Combine system prompt with user prompt if provided
if system_prompt:
full_prompt = f"{system_prompt}\n\n{prompt}"
else:
full_prompt = prompt
# Prepare generation config
generation_config = types.GenerateContentConfig(
temperature=temperature,
candidate_count=1,
)
# Add max output tokens if specified
if max_output_tokens:
generation_config.max_output_tokens = max_output_tokens
# Add thinking configuration for models that support it
capabilities = self.get_capabilities(resolved_name)
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
generation_config.thinking_config = types.ThinkingConfig(
thinking_budget=self.THINKING_BUDGETS[thinking_mode]
)
try:
# Generate content
response = self.client.models.generate_content(
model=resolved_name,
contents=full_prompt,
config=generation_config,
)
# Extract usage information if available
usage = self._extract_usage(response)
return ModelResponse(
content=response.text,
usage=usage,
model_name=resolved_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": (
getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP"
),
},
)
except Exception as e:
# Log error and re-raise with more context
error_msg = f"Gemini API error for model {resolved_name}: {str(e)}"
raise RuntimeError(error_msg) from e
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text using Gemini's tokenizer."""
self._resolve_model_name(model_name)
# For now, use a simple estimation
# TODO: Use actual Gemini tokenizer when available in SDK
# Rough estimation: ~4 characters per token for English text
return len(text) // 4
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.GOOGLE
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported."""
resolved_name = self._resolve_model_name(model_name)
return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
capabilities = self.get_capabilities(model_name)
return capabilities.supports_extended_thinking
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name."""
# Check if it's a shorthand
shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
if isinstance(shorthand_value, str):
return shorthand_value
return model_name
def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from Gemini response."""
usage = {}
# Try to extract usage metadata from response
# Note: The actual structure depends on the SDK version and response format
if hasattr(response, "usage_metadata"):
metadata = response.usage_metadata
if hasattr(metadata, "prompt_token_count"):
usage["input_tokens"] = metadata.prompt_token_count
if hasattr(metadata, "candidates_token_count"):
usage["output_tokens"] = metadata.candidates_token_count
if "input_tokens" in usage and "output_tokens" in usage:
usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
return usage
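A hedged usage sketch for `GeminiModelProvider`; the prompt is arbitrary, and a real run requires a valid `GEMINI_API_KEY` plus network access:

```python
import os

from providers.gemini import GeminiModelProvider

# Placeholder key lookup; a real run needs GEMINI_API_KEY set in the environment.
provider = GeminiModelProvider(api_key=os.environ["GEMINI_API_KEY"])

print(provider.validate_model_name("flash"))   # True - shorthand resolves to gemini-2.0-flash
print(provider.supports_thinking_mode("pro"))  # True - 2.5 Pro supports extended thinking

response = provider.generate_content(
    prompt="Summarize what a context manager does in Python.",
    model_name="pro",
    temperature=0.7,
    thinking_mode="low",
)
print(response.friendly_name, response.total_tokens)
print(response.content)
```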

177
providers/openai.py Normal file
View File

@@ -0,0 +1,177 @@
"""OpenAI model provider implementation."""
import logging
from typing import Optional
from openai import OpenAI
from .base import (
FixedTemperatureConstraint,
ModelCapabilities,
ModelProvider,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
class OpenAIModelProvider(ModelProvider):
"""OpenAI model provider implementation."""
# Model configurations
SUPPORTED_MODELS = {
"o3": {
"max_tokens": 200_000, # 200K tokens
"supports_extended_thinking": False,
},
"o3-mini": {
"max_tokens": 200_000, # 200K tokens
"supports_extended_thinking": False,
},
}
def __init__(self, api_key: str, **kwargs):
"""Initialize OpenAI provider with API key."""
super().__init__(api_key, **kwargs)
self._client = None
self.base_url = kwargs.get("base_url") # Support custom endpoints
self.organization = kwargs.get("organization")
@property
def client(self):
"""Lazy initialization of OpenAI client."""
if self._client is None:
client_kwargs = {"api_key": self.api_key}
if self.base_url:
client_kwargs["base_url"] = self.base_url
if self.organization:
client_kwargs["organization"] = self.organization
self._client = OpenAI(**client_kwargs)
return self._client
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific OpenAI model."""
if model_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported OpenAI model: {model_name}")
config = self.SUPPORTED_MODELS[model_name]
# Define temperature constraints per model
if model_name in ["o3", "o3-mini"]:
# O3 models only support temperature=1.0
temp_constraint = FixedTemperatureConstraint(1.0)
else:
# Other OpenAI models support 0.0-2.0 range
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
return ModelCapabilities(
provider=ProviderType.OPENAI,
model_name=model_name,
friendly_name="OpenAI",
max_tokens=config["max_tokens"],
supports_extended_thinking=config["supports_extended_thinking"],
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
temperature_constraint=temp_constraint,
)
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using OpenAI model."""
# Validate parameters
self.validate_parameters(model_name, temperature)
# Prepare messages
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# Prepare completion parameters
completion_params = {
"model": model_name,
"messages": messages,
"temperature": temperature,
}
# Add max tokens if specified
if max_output_tokens:
completion_params["max_tokens"] = max_output_tokens
# Add any additional OpenAI-specific parameters
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop"]:
completion_params[key] = value
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name="OpenAI",
provider=ProviderType.OPENAI,
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model, # Actual model used (in case of fallbacks)
"id": response.id,
"created": response.created,
},
)
except Exception as e:
# Log error and re-raise with more context
error_msg = f"OpenAI API error for model {model_name}: {str(e)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from e
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens for the given text.
Note: For accurate token counting, we should use tiktoken library.
This is a simplified estimation.
"""
# TODO: Implement proper token counting with tiktoken
# For now, use rough estimation
# O3 models ~4 chars per token
return len(text) // 4
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.OPENAI
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported."""
return model_name in self.SUPPORTED_MODELS
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
# Currently no OpenAI models support extended thinking
# This may change with future O3 models
return False
def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from OpenAI response."""
usage = {}
if hasattr(response, "usage") and response.usage:
usage["input_tokens"] = response.usage.prompt_tokens
usage["output_tokens"] = response.usage.completion_tokens
usage["total_tokens"] = response.usage.total_tokens
return usage
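A similar sketch for `OpenAIModelProvider`, highlighting the fixed temperature constraint on O3 models; the key lookup and prompt are placeholders:

```python
import os

from providers.openai import OpenAIModelProvider

provider = OpenAIModelProvider(api_key=os.environ["OPENAI_API_KEY"])  # placeholder lookup

caps = provider.get_capabilities("o3-mini")
print(caps.temperature_constraint.get_description())  # "Only supports temperature=1.0"

try:
    provider.validate_parameters("o3-mini", temperature=0.7)
except ValueError as exc:
    print("Rejected:", exc)

# A valid call must use temperature=1.0 for O3 models
response = provider.generate_content(
    prompt="List three uses of the walrus operator.",
    model_name="o3-mini",
    temperature=1.0,
)
print(response.content)
```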

181
providers/registry.py Normal file
View File

@@ -0,0 +1,181 @@
"""Model provider registry for managing available providers."""
import os
from typing import Optional
from .base import ModelProvider, ProviderType
class ModelProviderRegistry:
"""Registry for managing model providers."""
_instance = None
_providers: dict[ProviderType, type[ModelProvider]] = {}
_initialized_providers: dict[ProviderType, ModelProvider] = {}
def __new__(cls):
"""Singleton pattern for registry."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@classmethod
def register_provider(cls, provider_type: ProviderType, provider_class: type[ModelProvider]) -> None:
"""Register a new provider class.
Args:
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
provider_class: Class that implements ModelProvider interface
"""
cls._providers[provider_type] = provider_class
@classmethod
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
"""Get an initialized provider instance.
Args:
provider_type: Type of provider to get
force_new: Force creation of new instance instead of using cached
Returns:
Initialized ModelProvider instance or None if not available
"""
# Return cached instance if available and not forcing new
if not force_new and provider_type in cls._initialized_providers:
return cls._initialized_providers[provider_type]
# Check if provider class is registered
if provider_type not in cls._providers:
return None
# Get API key from environment
api_key = cls._get_api_key_for_provider(provider_type)
if not api_key:
return None
# Initialize provider
provider_class = cls._providers[provider_type]
provider = provider_class(api_key=api_key)
# Cache the instance
cls._initialized_providers[provider_type] = provider
return provider
@classmethod
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
"""Get provider instance for a specific model name.
Args:
model_name: Name of the model (e.g., "gemini-2.0-flash", "o3-mini")
Returns:
ModelProvider instance that supports this model
"""
# Check each registered provider
for provider_type, _provider_class in cls._providers.items():
# Get or create provider instance
provider = cls.get_provider(provider_type)
if provider and provider.validate_model_name(model_name):
return provider
return None
@classmethod
def get_available_providers(cls) -> list[ProviderType]:
"""Get list of registered provider types."""
return list(cls._providers.keys())
@classmethod
def get_available_models(cls) -> dict[str, ProviderType]:
"""Get mapping of all available models to their providers.
Returns:
Dict mapping model names to provider types
"""
models = {}
for provider_type in cls._providers:
provider = cls.get_provider(provider_type)
if provider:
# This assumes providers have a method to list supported models
# We'll need to add this to the interface
pass
return models
@classmethod
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
"""Get API key for a provider from environment variables.
Args:
provider_type: Provider type to get API key for
Returns:
API key string or None if not found
"""
key_mapping = {
ProviderType.GOOGLE: "GEMINI_API_KEY",
ProviderType.OPENAI: "OPENAI_API_KEY",
}
env_var = key_mapping.get(provider_type)
if not env_var:
return None
return os.getenv(env_var)
@classmethod
def get_preferred_fallback_model(cls) -> str:
"""Get the preferred fallback model based on available API keys.
This method checks which providers have valid API keys and returns
a sensible default model for auto mode fallback situations.
Priority order:
1. OpenAI o3-mini (balanced performance/cost) if OpenAI API key available
2. Gemini 2.0 Flash (fast and efficient) if Gemini API key available
3. OpenAI o3 (high performance) if OpenAI API key available
4. Gemini 2.5 Pro (deep reasoning) if Gemini API key available
5. Fallback to gemini-2.0-flash (most common case)
Returns:
Model name string for fallback use
"""
# Check provider availability by trying to get instances
openai_available = cls.get_provider(ProviderType.OPENAI) is not None
gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
# Priority order: prefer balanced models first, then high-performance
if openai_available:
return "o3-mini" # Balanced performance/cost
elif gemini_available:
return "gemini-2.0-flash" # Fast and efficient
else:
# No API keys available - return a reasonable default
# This maintains backward compatibility for tests
return "gemini-2.0-flash"
@classmethod
def get_available_providers_with_keys(cls) -> list[ProviderType]:
"""Get list of provider types that have valid API keys.
Returns:
List of ProviderType values for providers with valid API keys
"""
available = []
for provider_type in cls._providers:
if cls.get_provider(provider_type) is not None:
available.append(provider_type)
return available
@classmethod
def clear_cache(cls) -> None:
"""Clear cached provider instances."""
cls._initialized_providers.clear()
@classmethod
def unregister_provider(cls, provider_type: ProviderType) -> None:
"""Unregister a provider (mainly for testing)."""
cls._providers.pop(provider_type, None)
cls._initialized_providers.pop(provider_type, None)
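A short sketch of how the registry is wired up, mirroring the registration that `configure_providers()` in `server.py` performs; `get_provider_for_model` returns `None` unless a matching API key is present in the environment:

```python
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.registry import ModelProviderRegistry

# Register provider classes once at startup (as configure_providers() does)
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

# Look up a provider by model name; returns None if no matching key is configured
provider = ModelProviderRegistry.get_provider_for_model("o3-mini")
if provider is None:
    print("No provider available - falling back to", ModelProviderRegistry.get_preferred_fallback_model())
else:
    print("Using provider:", provider.get_provider_type().value)
```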

View File

@@ -53,6 +53,7 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"tests/*" = ["B011"]
"tests/conftest.py" = ["E402"] # Module level imports not at top of file - needed for test setup
[build-system]
requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"]

View File

@@ -1,5 +1,6 @@
mcp>=1.0.0
google-genai>=1.19.0
openai>=1.0.0
pydantic>=2.0.0
redis>=5.0.0

142
server.py
View File

@@ -1,8 +1,8 @@
"""
Gemini MCP Server - Main server implementation
Zen MCP Server - Main server implementation
This module implements the core MCP (Model Context Protocol) server that provides
AI-powered tools for code analysis, review, and assistance using Google's Gemini models.
AI-powered tools for code analysis, review, and assistance using multiple AI models.
The server follows the MCP specification to expose various AI tools as callable functions
that can be used by MCP clients (like Claude). Each tool provides specialized functionality
@@ -102,7 +102,7 @@ logger = logging.getLogger(__name__)
# Create the MCP server instance with a unique name identifier
# This name is used by MCP clients to identify and connect to this specific server
server: Server = Server("gemini-server")
server: Server = Server("zen-server")
# Initialize the tool registry with all available AI-powered tools
# Each tool provides specialized functionality for different development tasks
@@ -117,23 +117,46 @@ TOOLS = {
}
def configure_gemini():
def configure_providers():
"""
Configure Gemini API with the provided API key.
Configure and validate AI providers based on available API keys.
This function validates that the GEMINI_API_KEY environment variable is set.
The actual API key is used when creating Gemini clients within individual tools
to ensure proper isolation and error handling.
This function checks for API keys and registers the appropriate providers.
At least one valid API key (Gemini or OpenAI) is required.
Raises:
ValueError: If GEMINI_API_KEY environment variable is not set
ValueError: If no valid API keys are found
"""
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ValueError("GEMINI_API_KEY environment variable is required. Please set it with your Gemini API key.")
# Note: We don't store the API key globally for security reasons
# Each tool creates its own Gemini client with the API key when needed
logger.info("Gemini API key found")
from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
valid_providers = []
# Check for Gemini API key
gemini_key = os.getenv("GEMINI_API_KEY")
if gemini_key and gemini_key != "your_gemini_api_key_here":
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
valid_providers.append("Gemini")
logger.info("Gemini API key found - Gemini models available")
# Check for OpenAI API key
openai_key = os.getenv("OPENAI_API_KEY")
if openai_key and openai_key != "your_openai_api_key_here":
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
valid_providers.append("OpenAI (o3)")
logger.info("OpenAI API key found - o3 model available")
# Require at least one valid provider
if not valid_providers:
raise ValueError(
"At least one API key is required. Please set either:\n"
"- GEMINI_API_KEY for Gemini models\n"
"- OPENAI_API_KEY for OpenAI o3 model"
)
logger.info(f"Available providers: {', '.join(valid_providers)}")
@server.list_tools()
@@ -287,26 +310,26 @@ final analysis and recommendations."""
remaining_turns = max_turns - current_turn_count - 1
return f"""
🤝 CONVERSATION THREADING: You can continue this discussion with Claude! ({remaining_turns} exchanges remaining)
CONVERSATION CONTINUATION: You can continue this discussion with Claude! ({remaining_turns} exchanges remaining)
If you'd like to ask a follow-up question, explore a specific aspect deeper, or need clarification,
add this JSON block at the very end of your response:
Feel free to ask clarifying questions or suggest areas for deeper exploration naturally within your response.
If something needs clarification or you'd benefit from additional context, simply mention it conversationally.
```json
{{
"follow_up_question": "Would you like me to [specific action you could take]?",
"suggested_params": {{"files": ["relevant/files"], "focus_on": "specific area"}},
"ui_hint": "What this follow-up would accomplish"
}}
```
IMPORTANT: When you suggest follow-ups or ask questions, you MUST explicitly instruct Claude to use the continuation_id
to respond. Use clear, direct language based on urgency:
💡 Good follow-up opportunities:
- "Would you like me to examine the error handling in more detail?"
- "Should I analyze the performance implications of this approach?"
- "Would it be helpful to review the security aspects of this implementation?"
- "Should I dive deeper into the architecture patterns used here?"
For optional follow-ups: "Please continue this conversation using the continuation_id from this response if you'd like to explore this further."
Only ask follow-ups when they would genuinely add value to the discussion."""
For needed responses: "Please respond using the continuation_id from this response - your input is needed to proceed."
For essential/critical responses: "RESPONSE REQUIRED: Please immediately continue using the continuation_id from this response. Cannot proceed without your clarification/input."
This ensures Claude knows both HOW to maintain the conversation thread AND whether a response is optional, needed, or essential.
The tool will automatically provide a continuation_id in the structured response that Claude can use in subsequent
tool calls to maintain full conversation context across multiple exchanges.
Remember: Only suggest follow-ups when they would genuinely add value to the discussion, and always instruct Claude to use the continuation_id when you do."""
async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any]:
@@ -363,10 +386,16 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
else:
logger.debug(f"[CONVERSATION_DEBUG] Successfully added user turn to thread {continuation_id}")
# Build conversation history and track token usage
# Create model context early to use for history building
from utils.model_context import ModelContext
model_context = ModelContext.from_arguments(arguments)
# Build conversation history with model-specific limits
logger.debug(f"[CONVERSATION_DEBUG] Building conversation history for thread {continuation_id}")
logger.debug(f"[CONVERSATION_DEBUG] Thread has {len(context.turns)} turns, tool: {context.tool_name}")
conversation_history, conversation_tokens = build_conversation_history(context)
logger.debug(f"[CONVERSATION_DEBUG] Using model: {model_context.model_name}")
conversation_history, conversation_tokens = build_conversation_history(context, model_context)
logger.debug(f"[CONVERSATION_DEBUG] Conversation history built: {conversation_tokens:,} tokens")
logger.debug(f"[CONVERSATION_DEBUG] Conversation history length: {len(conversation_history)} chars")
@@ -374,8 +403,12 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
follow_up_instructions = get_follow_up_instructions(len(context.turns))
logger.debug(f"[CONVERSATION_DEBUG] Follow-up instructions added for turn {len(context.turns)}")
# Merge original context with new prompt and follow-up instructions
# All tools now use standardized 'prompt' field
original_prompt = arguments.get("prompt", "")
logger.debug("[CONVERSATION_DEBUG] Extracting user input from 'prompt' field")
logger.debug(f"[CONVERSATION_DEBUG] User input length: {len(original_prompt)} chars")
# Merge original context with new prompt and follow-up instructions
if conversation_history:
enhanced_prompt = (
f"{conversation_history}\n\n=== NEW USER INPUT ===\n{original_prompt}\n\n{follow_up_instructions}"
@@ -385,15 +418,25 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
# Update arguments with enhanced context and remaining token budget
enhanced_arguments = arguments.copy()
# Store the enhanced prompt in the prompt field
enhanced_arguments["prompt"] = enhanced_prompt
logger.debug("[CONVERSATION_DEBUG] Storing enhanced prompt in 'prompt' field")
# Calculate remaining token budget for current request files/content
from config import MAX_CONTENT_TOKENS
# Calculate remaining token budget based on current model
# (model_context was already created above for history building)
token_allocation = model_context.calculate_token_allocation()
remaining_tokens = MAX_CONTENT_TOKENS - conversation_tokens
# Calculate remaining tokens for files/new content
# History has already consumed some of the content budget
remaining_tokens = token_allocation.content_tokens - conversation_tokens
enhanced_arguments["_remaining_tokens"] = max(0, remaining_tokens) # Ensure non-negative
enhanced_arguments["_model_context"] = model_context # Pass context for use in tools
logger.debug("[CONVERSATION_DEBUG] Token budget calculation:")
logger.debug(f"[CONVERSATION_DEBUG] MAX_CONTENT_TOKENS: {MAX_CONTENT_TOKENS:,}")
logger.debug(f"[CONVERSATION_DEBUG] Model: {model_context.model_name}")
logger.debug(f"[CONVERSATION_DEBUG] Total capacity: {token_allocation.total_tokens:,}")
logger.debug(f"[CONVERSATION_DEBUG] Content allocation: {token_allocation.content_tokens:,}")
logger.debug(f"[CONVERSATION_DEBUG] Conversation tokens: {conversation_tokens:,}")
logger.debug(f"[CONVERSATION_DEBUG] Remaining tokens: {remaining_tokens:,}")
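To make the budget arithmetic above concrete, here is a toy calculation; the numbers and the `TokenAllocation` stand-in are assumptions, since `ModelContext.calculate_token_allocation()` is not shown in this diff:

```python
from dataclasses import dataclass


@dataclass
class TokenAllocation:  # hypothetical stand-in for the real allocation object
    total_tokens: int
    content_tokens: int


allocation = TokenAllocation(total_tokens=1_000_000, content_tokens=800_000)
conversation_tokens = 120_000  # tokens already consumed by reconstructed history

# Mirrors the calculation in reconstruct_thread_context(): history eats into the
# content budget, and the remainder is what new files/prompt content may use.
remaining_tokens = max(0, allocation.content_tokens - conversation_tokens)
print(f"Remaining budget for new content: {remaining_tokens:,} tokens")  # 680,000
```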
@@ -416,7 +459,7 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
try:
mcp_activity_logger = logging.getLogger("mcp_activity")
mcp_activity_logger.info(
f"CONVERSATION_CONTEXT: Thread {continuation_id} turn {len(context.turns)} - {len(context.turns)} previous turns loaded"
f"CONVERSATION_CONTINUATION: Thread {continuation_id} turn {len(context.turns)} - {len(context.turns)} previous turns loaded"
)
except Exception:
pass
@@ -452,7 +495,7 @@ async def handle_get_version() -> list[TextContent]:
}
# Format the information in a human-readable way
text = f"""Gemini MCP Server v{__version__}
text = f"""Zen MCP Server v{__version__}
Updated: {__updated__}
Author: {__author__}
@@ -466,7 +509,7 @@ Configuration:
Available Tools:
{chr(10).join(f" - {tool}" for tool in version_info["available_tools"])}
For updates, visit: https://github.com/BeehiveInnovations/gemini-mcp-server"""
For updates, visit: https://github.com/BeehiveInnovations/zen-mcp-server"""
# Create standardized tool output
tool_output = ToolOutput(status="success", content=text, content_type="text", metadata={"tool_name": "get_version"})
@@ -485,13 +528,20 @@ async def main():
The server communicates via standard input/output streams using the
MCP protocol's JSON-RPC message format.
"""
# Validate that Gemini API key is available before starting
configure_gemini()
# Validate and configure providers based on available API keys
configure_providers()
# Log startup message for Docker log monitoring
logger.info("Gemini MCP Server starting up...")
logger.info("Zen MCP Server starting up...")
logger.info(f"Log level: {log_level}")
logger.info(f"Using default model: {DEFAULT_MODEL}")
# Log current model mode
from config import IS_AUTO_MODE
if IS_AUTO_MODE:
logger.info("Model mode: AUTO (Claude will select the best model for each task)")
else:
logger.info(f"Model mode: Fixed model '{DEFAULT_MODEL}'")
# Import here to avoid circular imports
from config import DEFAULT_THINKING_MODE_THINKDEEP
@@ -508,7 +558,7 @@ async def main():
read_stream,
write_stream,
InitializationOptions(
server_name="gemini",
server_name="zen",
server_version=__version__,
capabilities=ServerCapabilities(tools=ToolsCapability()), # Advertise tool support capability
),

View File

@@ -3,10 +3,10 @@
# Exit on any error, undefined variables, and pipe failures
set -euo pipefail
# Modern Docker setup script for Gemini MCP Server with Redis
# Modern Docker setup script for Zen MCP Server with Redis
# This script sets up the complete Docker environment including Redis for conversation threading
echo "🚀 Setting up Gemini MCP Server with Docker Compose..."
echo "🚀 Setting up Zen MCP Server with Docker Compose..."
echo ""
# Get the current working directory (absolute path)
@@ -27,8 +27,8 @@ else
cp .env.example .env
echo "✅ Created .env from .env.example"
# Customize the API key if it's set in environment
if [ -n "$GEMINI_API_KEY" ]; then
# Customize the API keys if they're set in environment
if [ -n "${GEMINI_API_KEY:-}" ]; then
# Replace the placeholder API key with the actual value
if command -v sed >/dev/null 2>&1; then
sed -i.bak "s/your_gemini_api_key_here/$GEMINI_API_KEY/" .env && rm .env.bak
@@ -40,6 +40,18 @@ else
echo "⚠️ GEMINI_API_KEY not found in environment. Please edit .env and add your API key."
fi
if [ -n "${OPENAI_API_KEY:-}" ]; then
# Replace the placeholder API key with the actual value
if command -v sed >/dev/null 2>&1; then
sed -i.bak "s/your_openai_api_key_here/$OPENAI_API_KEY/" .env && rm .env.bak
echo "✅ Updated .env with existing OPENAI_API_KEY from environment"
else
echo "⚠️ Found OPENAI_API_KEY in environment, but sed not available. Please update .env manually."
fi
else
echo "⚠️ OPENAI_API_KEY not found in environment. Please edit .env and add your API key."
fi
# Update WORKSPACE_ROOT to use current user's home directory
if command -v sed >/dev/null 2>&1; then
sed -i.bak "s|WORKSPACE_ROOT=/Users/your-username|WORKSPACE_ROOT=$HOME|" .env && rm .env.bak
@@ -74,6 +86,41 @@ if ! docker compose version &> /dev/null; then
COMPOSE_CMD="docker-compose"
fi
# Check if at least one API key is properly configured
echo "🔑 Checking API key configuration..."
source .env 2>/dev/null || true
VALID_GEMINI_KEY=false
VALID_OPENAI_KEY=false
# Check if GEMINI_API_KEY is set and not the placeholder
if [ -n "${GEMINI_API_KEY:-}" ] && [ "$GEMINI_API_KEY" != "your_gemini_api_key_here" ]; then
VALID_GEMINI_KEY=true
echo "✅ Valid GEMINI_API_KEY found"
fi
# Check if OPENAI_API_KEY is set and not the placeholder
if [ -n "${OPENAI_API_KEY:-}" ] && [ "$OPENAI_API_KEY" != "your_openai_api_key_here" ]; then
VALID_OPENAI_KEY=true
echo "✅ Valid OPENAI_API_KEY found"
fi
# Require at least one valid API key
if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ]; then
echo ""
echo "❌ ERROR: At least one valid API key is required!"
echo ""
echo "Please edit the .env file and set at least one of:"
echo " - GEMINI_API_KEY (get from https://makersuite.google.com/app/apikey)"
echo " - OPENAI_API_KEY (get from https://platform.openai.com/api-keys)"
echo ""
echo "Example:"
echo " GEMINI_API_KEY=your-actual-api-key-here"
echo " OPENAI_API_KEY=sk-your-actual-openai-key-here"
echo ""
exit 1
fi
echo "🛠️ Building and starting services..."
echo ""
@@ -84,7 +131,7 @@ $COMPOSE_CMD down --remove-orphans >/dev/null 2>&1 || true
# Clean up any old containers with different naming patterns
OLD_CONTAINERS_FOUND=false
# Check for old Gemini MCP container
# Check for old Gemini MCP containers (for migration)
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-gemini-mcp-1$" 2>/dev/null || false; then
OLD_CONTAINERS_FOUND=true
echo " - Cleaning up old container: gemini-mcp-server-gemini-mcp-1"
@@ -92,6 +139,21 @@ if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-gemini-mcp-1
docker rm gemini-mcp-server-gemini-mcp-1 >/dev/null 2>&1 || true
fi
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server$" 2>/dev/null || false; then
OLD_CONTAINERS_FOUND=true
echo " - Cleaning up old container: gemini-mcp-server"
docker stop gemini-mcp-server >/dev/null 2>&1 || true
docker rm gemini-mcp-server >/dev/null 2>&1 || true
fi
# Check for current old containers (from recent versions)
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-log-monitor$" 2>/dev/null || false; then
OLD_CONTAINERS_FOUND=true
echo " - Cleaning up old container: gemini-mcp-log-monitor"
docker stop gemini-mcp-log-monitor >/dev/null 2>&1 || true
docker rm gemini-mcp-log-monitor >/dev/null 2>&1 || true
fi
# Check for old Redis container
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-redis-1$" 2>/dev/null || false; then
OLD_CONTAINERS_FOUND=true
@@ -100,17 +162,37 @@ if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-redis-1$" 2>
docker rm gemini-mcp-server-redis-1 >/dev/null 2>&1 || true
fi
# Check for old image
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-redis$" 2>/dev/null || false; then
OLD_CONTAINERS_FOUND=true
echo " - Cleaning up old container: gemini-mcp-redis"
docker stop gemini-mcp-redis >/dev/null 2>&1 || true
docker rm gemini-mcp-redis >/dev/null 2>&1 || true
fi
# Check for old images
if docker images --format "{{.Repository}}:{{.Tag}}" | grep -q "^gemini-mcp-server-gemini-mcp:latest$" 2>/dev/null || false; then
OLD_CONTAINERS_FOUND=true
echo " - Cleaning up old image: gemini-mcp-server-gemini-mcp:latest"
docker rmi gemini-mcp-server-gemini-mcp:latest >/dev/null 2>&1 || true
fi
if docker images --format "{{.Repository}}:{{.Tag}}" | grep -q "^gemini-mcp-server:latest$" 2>/dev/null || false; then
OLD_CONTAINERS_FOUND=true
echo " - Cleaning up old image: gemini-mcp-server:latest"
docker rmi gemini-mcp-server:latest >/dev/null 2>&1 || true
fi
# Check for current old network (if it exists)
if docker network ls --format "{{.Name}}" | grep -q "^gemini-mcp-server_default$" 2>/dev/null || false; then
OLD_CONTAINERS_FOUND=true
echo " - Cleaning up old network: gemini-mcp-server_default"
docker network rm gemini-mcp-server_default >/dev/null 2>&1 || true
fi
# Only show cleanup messages if something was actually cleaned up
# Build and start services
echo " - Building Gemini MCP Server image..."
echo " - Building Zen MCP Server image..."
if $COMPOSE_CMD build --no-cache >/dev/null 2>&1; then
echo "✅ Docker image built successfully!"
else
@@ -143,8 +225,15 @@ $COMPOSE_CMD ps --format table
echo ""
echo "🔄 Next steps:"
if grep -q "your-gemini-api-key-here" .env 2>/dev/null || false; then
echo "1. Edit .env and replace 'your-gemini-api-key-here' with your actual Gemini API key"
NEEDS_KEY_UPDATE=false
if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null; then
NEEDS_KEY_UPDATE=true
fi
if [ "$NEEDS_KEY_UPDATE" = true ]; then
echo "1. Edit .env and replace placeholder API keys with actual ones"
echo " - GEMINI_API_KEY: your-gemini-api-key-here"
echo " - OPENAI_API_KEY: your-openai-api-key-here"
echo "2. Restart services: $COMPOSE_CMD restart"
echo "3. Copy the configuration below to your Claude Desktop config:"
else
@@ -155,12 +244,12 @@ echo ""
echo "===== CLAUDE DESKTOP CONFIGURATION ====="
echo "{"
echo " \"mcpServers\": {"
echo " \"gemini\": {"
echo " \"zen\": {"
echo " \"command\": \"docker\","
echo " \"args\": ["
echo " \"exec\","
echo " \"-i\","
echo " \"gemini-mcp-server\","
echo " \"zen-mcp-server\","
echo " \"python\","
echo " \"server.py\""
echo " ]"
@@ -171,13 +260,13 @@ echo "==========================================="
echo ""
echo "===== CLAUDE CODE CLI CONFIGURATION ====="
echo "# Add the MCP server via Claude Code CLI:"
echo "claude mcp add gemini -s user -- docker exec -i gemini-mcp-server python server.py"
echo "claude mcp add zen -s user -- docker exec -i zen-mcp-server python server.py"
echo ""
echo "# List your MCP servers to verify:"
echo "claude mcp list"
echo ""
echo "# Remove if needed:"
echo "claude mcp remove gemini -s user"
echo "claude mcp remove zen -s user"
echo "==========================================="
echo ""

View File

@@ -1,19 +1,22 @@
"""
Communication Simulator Tests Package
This package contains individual test modules for the Gemini MCP Communication Simulator.
This package contains individual test modules for the Zen MCP Communication Simulator.
Each test is in its own file for better organization and maintainability.
"""
from .base_test import BaseSimulatorTest
from .test_basic_conversation import BasicConversationTest
from .test_content_validation import ContentValidationTest
from .test_conversation_chain_validation import ConversationChainValidationTest
from .test_cross_tool_comprehensive import CrossToolComprehensiveTest
from .test_cross_tool_continuation import CrossToolContinuationTest
from .test_logs_validation import LogsValidationTest
from .test_model_thinking_config import TestModelThinkingConfig
from .test_o3_model_selection import O3ModelSelectionTest
from .test_per_tool_deduplication import PerToolDeduplicationTest
from .test_redis_validation import RedisValidationTest
from .test_token_allocation_validation import TokenAllocationValidationTest
# Test registry for dynamic loading
TEST_REGISTRY = {
@@ -25,6 +28,9 @@ TEST_REGISTRY = {
"logs_validation": LogsValidationTest,
"redis_validation": RedisValidationTest,
"model_thinking_config": TestModelThinkingConfig,
"o3_model_selection": O3ModelSelectionTest,
"token_allocation_validation": TokenAllocationValidationTest,
"conversation_chain_validation": ConversationChainValidationTest,
}
__all__ = [
@@ -37,5 +43,8 @@ __all__ = [
"LogsValidationTest",
"RedisValidationTest",
"TestModelThinkingConfig",
"O3ModelSelectionTest",
"TokenAllocationValidationTest",
"ConversationChainValidationTest",
"TEST_REGISTRY",
]
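# A usage sketch of the registry's dynamic loading; the test name is any key from
# TEST_REGISTRY above, and the verbose flag mirrors BaseSimulatorTest's constructor.
test_cls = TEST_REGISTRY["o3_model_selection"]
test = test_cls(verbose=True)
success = test.run_test()  # returns True on pass, False on failure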

View File

@@ -19,8 +19,8 @@ class BaseSimulatorTest:
self.verbose = verbose
self.test_files = {}
self.test_dir = None
self.container_name = "gemini-mcp-server"
self.redis_container = "gemini-mcp-redis"
self.container_name = "zen-mcp-server"
self.redis_container = "zen-mcp-redis"
# Configure logging
log_level = logging.DEBUG if verbose else logging.INFO

View File

@@ -25,7 +25,7 @@ class BasicConversationTest(BaseSimulatorTest):
def run_test(self) -> bool:
"""Test basic conversation flow with chat tool"""
try:
self.logger.info("📝 Test: Basic conversation flow")
self.logger.info("Test: Basic conversation flow")
# Setup test files
self.setup_test_files()
@@ -37,6 +37,7 @@ class BasicConversationTest(BaseSimulatorTest):
{
"prompt": "Please use low thinking mode. Analyze this Python code and explain what it does",
"files": [self.test_files["python"]],
"model": "flash",
},
)
@@ -54,6 +55,7 @@ class BasicConversationTest(BaseSimulatorTest):
"prompt": "Please use low thinking mode. Now focus on the Calculator class specifically. Are there any improvements you'd suggest?",
"files": [self.test_files["python"]], # Same file - should be deduplicated
"continuation_id": continuation_id,
"model": "flash",
},
)
@@ -69,6 +71,7 @@ class BasicConversationTest(BaseSimulatorTest):
"prompt": "Please use low thinking mode. Now also analyze this configuration file and see how it might relate to the Python code",
"files": [self.test_files["python"], self.test_files["config"]],
"continuation_id": continuation_id,
"model": "flash",
},
)

View File

@@ -6,7 +6,6 @@ Tests that tools don't duplicate file content in their responses.
This test is specifically designed to catch content duplication bugs.
"""
import json
import os
from .base_test import BaseSimulatorTest
@@ -23,23 +22,58 @@ class ContentValidationTest(BaseSimulatorTest):
def test_description(self) -> str:
return "Content validation and duplicate detection"
def run_test(self) -> bool:
"""Test that tools don't duplicate file content in their responses"""
def get_docker_logs_since(self, since_time: str) -> str:
"""Get docker logs since a specific timestamp"""
try:
self.logger.info("📄 Test: Content validation and duplicate detection")
# Check both main server and log monitor for comprehensive logs
cmd_server = ["docker", "logs", "--since", since_time, self.container_name]
cmd_monitor = ["docker", "logs", "--since", since_time, "zen-mcp-log-monitor"]
import subprocess
result_server = subprocess.run(cmd_server, capture_output=True, text=True)
result_monitor = subprocess.run(cmd_monitor, capture_output=True, text=True)
# Get the internal log files which have more detailed logging
server_log_result = subprocess.run(
["docker", "exec", self.container_name, "cat", "/tmp/mcp_server.log"], capture_output=True, text=True
)
activity_log_result = subprocess.run(
["docker", "exec", self.container_name, "cat", "/tmp/mcp_activity.log"], capture_output=True, text=True
)
# Combine all logs
combined_logs = (
result_server.stdout
+ "\n"
+ result_monitor.stdout
+ "\n"
+ server_log_result.stdout
+ "\n"
+ activity_log_result.stdout
)
return combined_logs
except Exception as e:
self.logger.error(f"Failed to get docker logs: {e}")
return ""
def run_test(self) -> bool:
"""Test that file processing system properly handles file deduplication"""
try:
self.logger.info("📄 Test: Content validation and file processing deduplication")
# Setup test files first
self.setup_test_files()
# Create a test file with distinctive content for validation
# Create a test file for validation
validation_content = '''"""
Configuration file for content validation testing
This content should appear only ONCE in any tool response
"""
# Configuration constants
MAX_CONTENT_TOKENS = 800_000 # This line should appear exactly once
TEMPERATURE_ANALYTICAL = 0.2 # This should also appear exactly once
MAX_CONTENT_TOKENS = 800_000
TEMPERATURE_ANALYTICAL = 0.2
UNIQUE_VALIDATION_MARKER = "CONTENT_VALIDATION_TEST_12345"
# Database settings
@@ -57,138 +91,127 @@ DATABASE_CONFIG = {
# Ensure absolute path for MCP server compatibility
validation_file = os.path.abspath(validation_file)
# Test 1: Precommit tool with files parameter (where the bug occurred)
self.logger.info(" 1: Testing precommit tool content duplication")
# Get timestamp for log filtering
import datetime
# Call precommit tool with the validation file
start_time = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
# Test 1: Initial tool call with validation file
self.logger.info(" 1: Testing initial tool call with file")
# Call chat tool with the validation file
response1, thread_id = self.call_mcp_tool(
"precommit",
"chat",
{
"path": os.getcwd(),
"prompt": "Analyze this configuration file briefly",
"files": [validation_file],
"original_request": "Test for content duplication in precommit tool",
"model": "flash",
},
)
if response1:
# Parse response and check for content duplication
try:
response_data = json.loads(response1)
content = response_data.get("content", "")
if not response1:
self.logger.error(" ❌ Initial tool call failed")
return False
# Count occurrences of distinctive markers
max_content_count = content.count("MAX_CONTENT_TOKENS = 800_000")
temp_analytical_count = content.count("TEMPERATURE_ANALYTICAL = 0.2")
unique_marker_count = content.count("UNIQUE_VALIDATION_MARKER")
self.logger.info(" ✅ Initial tool call completed")
# Validate no duplication
duplication_detected = False
issues = []
if max_content_count > 1:
issues.append(f"MAX_CONTENT_TOKENS appears {max_content_count} times")
duplication_detected = True
if temp_analytical_count > 1:
issues.append(f"TEMPERATURE_ANALYTICAL appears {temp_analytical_count} times")
duplication_detected = True
if unique_marker_count > 1:
issues.append(f"UNIQUE_VALIDATION_MARKER appears {unique_marker_count} times")
duplication_detected = True
if duplication_detected:
self.logger.error(f" ❌ Content duplication detected in precommit tool: {'; '.join(issues)}")
return False
else:
self.logger.info(" ✅ No content duplication in precommit tool")
except json.JSONDecodeError:
self.logger.warning(" ⚠️ Could not parse precommit response as JSON")
else:
self.logger.warning(" ⚠️ Precommit tool failed to respond")
# Test 2: Other tools that use files parameter
tools_to_test = [
(
"chat",
{
"prompt": "Please use low thinking mode. Analyze this config file",
"files": [validation_file],
}, # Using absolute path
),
(
"codereview",
{
"files": [validation_file],
"context": "Please use low thinking mode. Review this configuration",
}, # Using absolute path
),
("analyze", {"files": [validation_file], "analysis_type": "code_quality"}), # Using absolute path
]
for tool_name, params in tools_to_test:
self.logger.info(f" 2.{tool_name}: Testing {tool_name} tool content duplication")
response, _ = self.call_mcp_tool(tool_name, params)
if response:
try:
response_data = json.loads(response)
content = response_data.get("content", "")
# Check for duplication
marker_count = content.count("UNIQUE_VALIDATION_MARKER")
if marker_count > 1:
self.logger.error(
f" ❌ Content duplication in {tool_name}: marker appears {marker_count} times"
)
return False
else:
self.logger.info(f" ✅ No content duplication in {tool_name}")
except json.JSONDecodeError:
self.logger.warning(f" ⚠️ Could not parse {tool_name} response")
else:
self.logger.warning(f" ⚠️ {tool_name} tool failed to respond")
# Test 3: Cross-tool content validation with file deduplication
self.logger.info(" 3: Testing cross-tool content consistency")
# Test 2: Continuation with same file (should be deduplicated)
self.logger.info(" 2: Testing continuation with same file")
if thread_id:
# Continue conversation with same file - content should be deduplicated in conversation history
response2, _ = self.call_mcp_tool(
"chat",
{
"prompt": "Please use low thinking mode. Continue analyzing this configuration file",
"prompt": "Continue analyzing this configuration file",
"files": [validation_file], # Same file should be deduplicated
"continuation_id": thread_id,
"model": "flash",
},
)
if response2:
try:
response_data = json.loads(response2)
content = response_data.get("content", "")
self.logger.info(" ✅ Continuation with same file completed")
else:
self.logger.warning(" ⚠️ Continuation failed")
# In continuation, the file content shouldn't be duplicated either
marker_count = content.count("UNIQUE_VALIDATION_MARKER")
if marker_count > 1:
self.logger.error(
f" ❌ Content duplication in cross-tool continuation: marker appears {marker_count} times"
)
return False
else:
self.logger.info(" ✅ No content duplication in cross-tool continuation")
# Test 3: Different tool with same file (new conversation)
self.logger.info(" 3: Testing different tool with same file")
except json.JSONDecodeError:
self.logger.warning(" ⚠️ Could not parse continuation response")
response3, _ = self.call_mcp_tool(
"codereview",
{
"files": [validation_file],
"prompt": "Review this configuration file",
"model": "flash",
},
)
if response3:
self.logger.info(" ✅ Different tool with same file completed")
else:
self.logger.warning(" ⚠️ Different tool failed")
# Validate file processing behavior from Docker logs
self.logger.info(" 4: Validating file processing logs")
logs = self.get_docker_logs_since(start_time)
# Check for proper file embedding logs
embedding_logs = [
line
for line in logs.split("\n")
if "[FILE_PROCESSING]" in line or "embedding" in line.lower() or "[FILES]" in line
]
# Check for deduplication evidence
deduplication_logs = [
line
for line in logs.split("\n")
if ("skipping" in line.lower() and "already in conversation" in line.lower())
or "No new files to embed" in line
]
# Check for file processing patterns
new_file_logs = [
line
for line in logs.split("\n")
if "will embed new files" in line or "New conversation" in line or "[FILE_PROCESSING]" in line
]
# Validation criteria
validation_file_mentioned = any("validation_config.py" in line for line in logs.split("\n"))
embedding_found = len(embedding_logs) > 0
(len(deduplication_logs) > 0 or len(new_file_logs) >= 2) # Should see new conversation patterns
self.logger.info(f" Embedding logs found: {len(embedding_logs)}")
self.logger.info(f" Deduplication evidence: {len(deduplication_logs)}")
self.logger.info(f" New conversation patterns: {len(new_file_logs)}")
self.logger.info(f" Validation file mentioned: {validation_file_mentioned}")
# Log sample evidence for debugging
if self.verbose and embedding_logs:
self.logger.debug(" 📋 Sample embedding logs:")
for log in embedding_logs[:5]:
self.logger.debug(f" {log}")
# Success criteria
success_criteria = [
("Embedding logs found", embedding_found),
("File processing evidence", validation_file_mentioned),
("Multiple tool calls", len(new_file_logs) >= 2),
]
passed_criteria = sum(1 for _, passed in success_criteria if passed)
self.logger.info(f" Success criteria met: {passed_criteria}/{len(success_criteria)}")
# Cleanup
os.remove(validation_file)
self.logger.info(" ✅ All content validation tests passed")
return True
if passed_criteria >= 2: # At least 2 out of 3 criteria
self.logger.info(" ✅ File processing validation passed")
return True
else:
self.logger.error(" ❌ File processing validation failed")
return False
except Exception as e:
self.logger.error(f"Content validation test failed: {e}")

View File

@@ -0,0 +1,412 @@
#!/usr/bin/env python3
"""
Conversation Chain and Threading Validation Test
This test validates that:
1. Multiple tool invocations create proper parent->parent->parent chains
2. New conversations can be started independently
3. Original conversation chains can be resumed from any point
4. History traversal works correctly for all scenarios
5. Thread relationships are properly maintained in Redis
Test Flow:
Chain A: chat -> analyze -> debug (3 linked threads)
Chain B: chat -> analyze (2 linked threads, independent)
Chain A Branch: debug (continue from original chat, creating branch)
This validates the conversation threading system's ability to:
- Build linear chains
- Create independent conversation threads
- Branch from earlier points in existing chains
- Properly traverse parent relationships for history reconstruction
"""
import re
import subprocess
from .base_test import BaseSimulatorTest
class ConversationChainValidationTest(BaseSimulatorTest):
"""Test conversation chain and threading functionality"""
@property
def test_name(self) -> str:
return "conversation_chain_validation"
@property
def test_description(self) -> str:
return "Conversation chain and threading validation"
def get_recent_server_logs(self) -> str:
"""Get recent server logs from the log file directly"""
try:
cmd = ["docker", "exec", self.container_name, "tail", "-n", "500", "/tmp/mcp_server.log"]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
return result.stdout
else:
self.logger.warning(f"Failed to read server logs: {result.stderr}")
return ""
except Exception as e:
self.logger.error(f"Failed to get server logs: {e}")
return ""
def extract_thread_creation_logs(self, logs: str) -> list[dict[str, str]]:
"""Extract thread creation logs with parent relationships"""
thread_logs = []
lines = logs.split("\n")
for line in lines:
if "[THREAD] Created new thread" in line:
# Parse: [THREAD] Created new thread 9dc779eb-645f-4850-9659-34c0e6978d73 with parent a0ce754d-c995-4b3e-9103-88af429455aa
match = re.search(r"\[THREAD\] Created new thread ([a-f0-9-]+) with parent ([a-f0-9-]+|None)", line)
if match:
thread_id = match.group(1)
parent_id = match.group(2) if match.group(2) != "None" else None
thread_logs.append({"thread_id": thread_id, "parent_id": parent_id, "log_line": line})
return thread_logs
def extract_history_traversal_logs(self, logs: str) -> list[dict[str, str]]:
"""Extract conversation history traversal logs"""
traversal_logs = []
lines = logs.split("\n")
for line in lines:
if "[THREAD] Retrieved chain of" in line:
# Parse: [THREAD] Retrieved chain of 3 threads for 9dc779eb-645f-4850-9659-34c0e6978d73
match = re.search(r"\[THREAD\] Retrieved chain of (\d+) threads for ([a-f0-9-]+)", line)
if match:
chain_length = int(match.group(1))
thread_id = match.group(2)
traversal_logs.append({"thread_id": thread_id, "chain_length": chain_length, "log_line": line})
return traversal_logs
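# Example of the two log formats parsed above (thread IDs illustrative):
#     "[THREAD] Created new thread 9dc779eb-... with parent a0ce754d-..."
#         -> {"thread_id": "9dc779eb-...", "parent_id": "a0ce754d-..."}
#     "[THREAD] Retrieved chain of 3 threads for 9dc779eb-..."
#         -> {"thread_id": "9dc779eb-...", "chain_length": 3}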
def run_test(self) -> bool:
"""Test conversation chain and threading functionality"""
try:
self.logger.info("Test: Conversation chain and threading validation")
# Setup test files
self.setup_test_files()
# Create test file for consistent context
test_file_content = """def example_function():
'''Simple test function for conversation continuity testing'''
return "Hello from conversation chain test"
class TestClass:
def method(self):
return "Method in test class"
"""
test_file_path = self.create_additional_test_file("chain_test.py", test_file_content)
# Track all continuation IDs and their relationships
conversation_chains = {}
# === CHAIN A: Build linear conversation chain ===
self.logger.info(" Chain A: Building linear conversation chain")
# Step A1: Start with chat tool (creates thread_id_1)
self.logger.info(" Step A1: Chat tool - start new conversation")
response_a1, continuation_id_a1 = self.call_mcp_tool(
"chat",
{
"prompt": "Analyze this test file and explain what it does.",
"files": [test_file_path],
"model": "flash",
"temperature": 0.7,
},
)
if not response_a1 or not continuation_id_a1:
self.logger.error(" ❌ Step A1 failed - no response or continuation ID")
return False
self.logger.info(f" ✅ Step A1 completed - thread_id: {continuation_id_a1[:8]}...")
conversation_chains["A1"] = continuation_id_a1
# Step A2: Continue with analyze tool (creates thread_id_2 with parent=thread_id_1)
self.logger.info(" Step A2: Analyze tool - continue Chain A")
response_a2, continuation_id_a2 = self.call_mcp_tool(
"analyze",
{
"prompt": "Now analyze the code quality and suggest improvements.",
"files": [test_file_path],
"continuation_id": continuation_id_a1,
"model": "flash",
"temperature": 0.7,
},
)
if not response_a2 or not continuation_id_a2:
self.logger.error(" ❌ Step A2 failed - no response or continuation ID")
return False
self.logger.info(f" ✅ Step A2 completed - thread_id: {continuation_id_a2[:8]}...")
conversation_chains["A2"] = continuation_id_a2
# Step A3: Continue with debug tool (creates thread_id_3 with parent=thread_id_2)
self.logger.info(" Step A3: Debug tool - continue Chain A")
response_a3, continuation_id_a3 = self.call_mcp_tool(
"debug",
{
"prompt": "Debug any potential issues in this code.",
"files": [test_file_path],
"continuation_id": continuation_id_a2,
"model": "flash",
"temperature": 0.7,
},
)
if not response_a3 or not continuation_id_a3:
self.logger.error(" ❌ Step A3 failed - no response or continuation ID")
return False
self.logger.info(f" ✅ Step A3 completed - thread_id: {continuation_id_a3[:8]}...")
conversation_chains["A3"] = continuation_id_a3
# === CHAIN B: Start independent conversation ===
self.logger.info(" Chain B: Starting independent conversation")
# Step B1: Start new chat conversation (creates thread_id_4, no parent)
self.logger.info(" Step B1: Chat tool - start NEW independent conversation")
response_b1, continuation_id_b1 = self.call_mcp_tool(
"chat",
{
"prompt": "This is a completely new conversation. Please greet me.",
"model": "flash",
"temperature": 0.7,
},
)
if not response_b1 or not continuation_id_b1:
self.logger.error(" ❌ Step B1 failed - no response or continuation ID")
return False
self.logger.info(f" ✅ Step B1 completed - thread_id: {continuation_id_b1[:8]}...")
conversation_chains["B1"] = continuation_id_b1
# Step B2: Continue the new conversation (creates thread_id_5 with parent=thread_id_4)
self.logger.info(" Step B2: Analyze tool - continue Chain B")
response_b2, continuation_id_b2 = self.call_mcp_tool(
"analyze",
{
"prompt": "Analyze the previous greeting and suggest improvements.",
"continuation_id": continuation_id_b1,
"model": "flash",
"temperature": 0.7,
},
)
if not response_b2 or not continuation_id_b2:
self.logger.error(" ❌ Step B2 failed - no response or continuation ID")
return False
self.logger.info(f" ✅ Step B2 completed - thread_id: {continuation_id_b2[:8]}...")
conversation_chains["B2"] = continuation_id_b2
# === CHAIN A BRANCH: Go back to original conversation ===
self.logger.info(" Chain A Branch: Resume original conversation from A1")
# Step A1-Branch: Use original continuation_id_a1 to branch (creates thread_id_6 with parent=thread_id_1)
self.logger.info(" Step A1-Branch: Debug tool - branch from original Chain A")
response_a1_branch, continuation_id_a1_branch = self.call_mcp_tool(
"debug",
{
"prompt": "Let's debug this from a different angle now.",
"files": [test_file_path],
"continuation_id": continuation_id_a1, # Go back to original!
"model": "flash",
"temperature": 0.7,
},
)
if not response_a1_branch or not continuation_id_a1_branch:
self.logger.error(" ❌ Step A1-Branch failed - no response or continuation ID")
return False
self.logger.info(f" ✅ Step A1-Branch completed - thread_id: {continuation_id_a1_branch[:8]}...")
conversation_chains["A1_Branch"] = continuation_id_a1_branch
# === ANALYSIS: Validate thread relationships and history traversal ===
self.logger.info(" Analyzing conversation chain structure...")
# Get logs and extract thread relationships
logs = self.get_recent_server_logs()
thread_creation_logs = self.extract_thread_creation_logs(logs)
history_traversal_logs = self.extract_history_traversal_logs(logs)
self.logger.info(f" Found {len(thread_creation_logs)} thread creation logs")
self.logger.info(f" Found {len(history_traversal_logs)} history traversal logs")
# Debug: Show what we found
if self.verbose:
self.logger.debug(" Thread creation logs found:")
for log in thread_creation_logs:
self.logger.debug(
f" {log['thread_id'][:8]}... parent: {log['parent_id'][:8] if log['parent_id'] else 'None'}..."
)
self.logger.debug(" History traversal logs found:")
for log in history_traversal_logs:
self.logger.debug(f" {log['thread_id'][:8]}... chain length: {log['chain_length']}")
# Build expected thread relationships
expected_relationships = []
# Note: A1 and B1 won't appear in thread creation logs because they're new conversations (no parent)
# Only continuation threads (A2, A3, B2, A1-Branch) will appear in creation logs
# Find logs for each continuation thread
a2_log = next((log for log in thread_creation_logs if log["thread_id"] == continuation_id_a2), None)
a3_log = next((log for log in thread_creation_logs if log["thread_id"] == continuation_id_a3), None)
b2_log = next((log for log in thread_creation_logs if log["thread_id"] == continuation_id_b2), None)
a1_branch_log = next(
(log for log in thread_creation_logs if log["thread_id"] == continuation_id_a1_branch), None
)
# A2 should have A1 as parent
if a2_log:
expected_relationships.append(("A2 has A1 as parent", a2_log["parent_id"] == continuation_id_a1))
# A3 should have A2 as parent
if a3_log:
expected_relationships.append(("A3 has A2 as parent", a3_log["parent_id"] == continuation_id_a2))
# B2 should have B1 as parent (independent chain)
if b2_log:
expected_relationships.append(("B2 has B1 as parent", b2_log["parent_id"] == continuation_id_b1))
# A1-Branch should have A1 as parent (branching)
if a1_branch_log:
expected_relationships.append(
("A1-Branch has A1 as parent", a1_branch_log["parent_id"] == continuation_id_a1)
)
# Validate history traversal
traversal_validations = []
# History traversal logs are only generated when conversation history is built from scratch
# (not when history is already embedded in the prompt by server.py)
# So we should expect at least 1 traversal log, but not necessarily for every continuation
if len(history_traversal_logs) > 0:
# Validate that any traversal logs we find have reasonable chain lengths
for log in history_traversal_logs:
thread_id = log["thread_id"]
chain_length = log["chain_length"]
# Chain length should be at least 2 for any continuation thread
# (original thread + continuation thread)
is_valid_length = chain_length >= 2
# Try to identify which thread this is for better validation
thread_description = "Unknown thread"
if thread_id == continuation_id_a2:
thread_description = "A2 (should be 2-thread chain)"
is_valid_length = chain_length == 2
elif thread_id == continuation_id_a3:
thread_description = "A3 (should be 3-thread chain)"
is_valid_length = chain_length == 3
elif thread_id == continuation_id_b2:
thread_description = "B2 (should be 2-thread chain)"
is_valid_length = chain_length == 2
elif thread_id == continuation_id_a1_branch:
thread_description = "A1-Branch (should be 2-thread chain)"
is_valid_length = chain_length == 2
traversal_validations.append(
(f"{thread_description[:8]}... has valid chain length", is_valid_length)
)
# Also validate we found at least one traversal (shows the system is working)
traversal_validations.append(
("At least one history traversal occurred", len(history_traversal_logs) >= 1)
)
# === VALIDATION RESULTS ===
self.logger.info(" Thread Relationship Validation:")
relationship_passed = 0
for desc, passed in expected_relationships:
status = "" if passed else ""
self.logger.info(f" {status} {desc}")
if passed:
relationship_passed += 1
self.logger.info(" History Traversal Validation:")
traversal_passed = 0
for desc, passed in traversal_validations:
status = "" if passed else ""
self.logger.info(f" {status} {desc}")
if passed:
traversal_passed += 1
# === SUCCESS CRITERIA ===
total_relationship_checks = len(expected_relationships)
total_traversal_checks = len(traversal_validations)
self.logger.info(" Validation Summary:")
self.logger.info(f" Thread relationships: {relationship_passed}/{total_relationship_checks}")
self.logger.info(f" History traversal: {traversal_passed}/{total_traversal_checks}")
# Success requires at least 80% of validations to pass
relationship_success = relationship_passed >= (total_relationship_checks * 0.8)
# If no traversal checks were possible, it means no traversal logs were found
# This could indicate an issue since we expect at least some history building
if total_traversal_checks == 0:
self.logger.warning(
" No history traversal logs found - this may indicate conversation history is always pre-embedded"
)
# Still consider it successful since the thread relationships are what matter most
traversal_success = True
else:
# For traversal success, we need at least 50% to pass since chain lengths can vary
# The important thing is that traversal is happening and relationships are correct
traversal_success = traversal_passed >= (total_traversal_checks * 0.5)
overall_success = relationship_success and traversal_success
self.logger.info(" Conversation Chain Structure:")
self.logger.info(
f" Chain A: {continuation_id_a1[:8]}{continuation_id_a2[:8]}{continuation_id_a3[:8]}"
)
self.logger.info(f" Chain B: {continuation_id_b1[:8]}{continuation_id_b2[:8]}")
self.logger.info(f" Branch: {continuation_id_a1[:8]}{continuation_id_a1_branch[:8]}")
if overall_success:
self.logger.info(" ✅ Conversation chain validation test PASSED")
return True
else:
self.logger.error(" ❌ Conversation chain validation test FAILED")
return False
except Exception as e:
self.logger.error(f"Conversation chain validation test failed: {e}")
return False
finally:
self.cleanup_test_files()
def main():
"""Run the conversation chain validation test"""
import sys
verbose = "--verbose" in sys.argv or "-v" in sys.argv
test = ConversationChainValidationTest(verbose=verbose)
success = test.run_test()
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()

View File

@@ -33,13 +33,30 @@ class CrossToolComprehensiveTest(BaseSimulatorTest):
try:
# Check both main server and log monitor for comprehensive logs
cmd_server = ["docker", "logs", "--since", since_time, self.container_name]
cmd_monitor = ["docker", "logs", "--since", since_time, "gemini-mcp-log-monitor"]
cmd_monitor = ["docker", "logs", "--since", since_time, "zen-mcp-log-monitor"]
result_server = subprocess.run(cmd_server, capture_output=True, text=True)
result_monitor = subprocess.run(cmd_monitor, capture_output=True, text=True)
# Combine logs from both containers
combined_logs = result_server.stdout + "\n" + result_monitor.stdout
# Get the internal log files which have more detailed logging
server_log_result = subprocess.run(
["docker", "exec", self.container_name, "cat", "/tmp/mcp_server.log"], capture_output=True, text=True
)
activity_log_result = subprocess.run(
["docker", "exec", self.container_name, "cat", "/tmp/mcp_activity.log"], capture_output=True, text=True
)
# Combine all logs
combined_logs = (
result_server.stdout
+ "\n"
+ result_monitor.stdout
+ "\n"
+ server_log_result.stdout
+ "\n"
+ activity_log_result.stdout
)
return combined_logs
except Exception as e:
self.logger.error(f"Failed to get docker logs: {e}")
@@ -91,6 +108,7 @@ def hash_pwd(pwd):
"prompt": "Please give me a quick one line reply. I have an authentication module that needs review. Can you help me understand potential issues?",
"files": [auth_file],
"thinking_mode": "low",
"model": "flash",
}
response1, continuation_id1 = self.call_mcp_tool("chat", chat_params)
@@ -106,8 +124,9 @@ def hash_pwd(pwd):
self.logger.info(" Step 2: analyze tool - Deep code analysis (fresh)")
analyze_params = {
"files": [auth_file],
"question": "Please give me a quick one line reply. What are the security vulnerabilities and architectural issues in this authentication code?",
"prompt": "Please give me a quick one line reply. What are the security vulnerabilities and architectural issues in this authentication code?",
"thinking_mode": "low",
"model": "flash",
}
response2, continuation_id2 = self.call_mcp_tool("analyze", analyze_params)
@@ -127,6 +146,7 @@ def hash_pwd(pwd):
"prompt": "Please give me a quick one line reply. I also have this configuration file. Can you analyze it alongside the authentication code?",
"files": [auth_file, config_file_path], # Old + new file
"thinking_mode": "low",
"model": "flash",
}
response3, _ = self.call_mcp_tool("chat", chat_continue_params)
@@ -141,8 +161,9 @@ def hash_pwd(pwd):
self.logger.info(" Step 4: debug tool - Identify specific problems")
debug_params = {
"files": [auth_file, config_file_path],
"error_description": "Please give me a quick one line reply. The authentication system has security vulnerabilities. Help me identify and fix the main issues.",
"prompt": "Please give me a quick one line reply. The authentication system has security vulnerabilities. Help me identify and fix the main issues.",
"thinking_mode": "low",
"model": "flash",
}
response4, continuation_id4 = self.call_mcp_tool("debug", debug_params)
@@ -161,8 +182,9 @@ def hash_pwd(pwd):
debug_continue_params = {
"continuation_id": continuation_id4,
"files": [auth_file, config_file_path],
"error_description": "Please give me a quick one line reply. What specific code changes would you recommend to fix the password hashing vulnerability?",
"prompt": "Please give me a quick one line reply. What specific code changes would you recommend to fix the password hashing vulnerability?",
"thinking_mode": "low",
"model": "flash",
}
response5, _ = self.call_mcp_tool("debug", debug_continue_params)
@@ -174,8 +196,9 @@ def hash_pwd(pwd):
self.logger.info(" Step 6: codereview tool - Comprehensive code review")
codereview_params = {
"files": [auth_file, config_file_path],
"context": "Please give me a quick one line reply. Comprehensive security-focused code review for production readiness",
"prompt": "Please give me a quick one line reply. Comprehensive security-focused code review for production readiness",
"thinking_mode": "low",
"model": "flash",
}
response6, continuation_id6 = self.call_mcp_tool("codereview", codereview_params)
@@ -207,8 +230,9 @@ def secure_login(user, pwd):
precommit_params = {
"path": self.test_dir,
"files": [auth_file, config_file_path, improved_file],
"original_request": "Please give me a quick one line reply. Ready to commit security improvements to authentication module",
"prompt": "Please give me a quick one line reply. Ready to commit security improvements to authentication module",
"thinking_mode": "low",
"model": "flash",
}
response7, continuation_id7 = self.call_mcp_tool("precommit", precommit_params)
@@ -253,15 +277,15 @@ def secure_login(user, pwd):
improved_file_mentioned = any("auth_improved.py" in line for line in logs.split("\n"))
# Print comprehensive diagnostics
self.logger.info(f" 📊 Tools used: {len(tools_used)} ({', '.join(tools_used)})")
self.logger.info(f" 📊 Continuation IDs created: {len(continuation_ids_created)}")
self.logger.info(f" 📊 Conversation logs found: {len(conversation_logs)}")
self.logger.info(f" 📊 File embedding logs found: {len(embedding_logs)}")
self.logger.info(f" 📊 Continuation logs found: {len(continuation_logs)}")
self.logger.info(f" 📊 Cross-tool activity logs: {len(cross_tool_logs)}")
self.logger.info(f" 📊 Auth file mentioned: {auth_file_mentioned}")
self.logger.info(f" 📊 Config file mentioned: {config_file_mentioned}")
self.logger.info(f" 📊 Improved file mentioned: {improved_file_mentioned}")
self.logger.info(f" Tools used: {len(tools_used)} ({', '.join(tools_used)})")
self.logger.info(f" Continuation IDs created: {len(continuation_ids_created)}")
self.logger.info(f" Conversation logs found: {len(conversation_logs)}")
self.logger.info(f" File embedding logs found: {len(embedding_logs)}")
self.logger.info(f" Continuation logs found: {len(continuation_logs)}")
self.logger.info(f" Cross-tool activity logs: {len(cross_tool_logs)}")
self.logger.info(f" Auth file mentioned: {auth_file_mentioned}")
self.logger.info(f" Config file mentioned: {config_file_mentioned}")
self.logger.info(f" Improved file mentioned: {improved_file_mentioned}")
if self.verbose:
self.logger.debug(" 📋 Sample tool activity logs:")
@@ -289,9 +313,9 @@ def secure_login(user, pwd):
passed_criteria = sum(success_criteria)
total_criteria = len(success_criteria)
self.logger.info(f" 📊 Success criteria met: {passed_criteria}/{total_criteria}")
self.logger.info(f" Success criteria met: {passed_criteria}/{total_criteria}")
if passed_criteria >= 6: # At least 6 out of 8 criteria
if passed_criteria == total_criteria: # All criteria must pass
self.logger.info(" ✅ Comprehensive cross-tool test: PASSED")
return True
else:

View File

@@ -67,6 +67,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
{
"prompt": "Please use low thinking mode. Look at this Python code and tell me what you think about it",
"files": [self.test_files["python"]],
"model": "flash",
},
)
@@ -81,6 +82,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
"prompt": "Please use low thinking mode. Think deeply about potential performance issues in this code",
"files": [self.test_files["python"]], # Same file should be deduplicated
"continuation_id": chat_id,
"model": "flash",
},
)
@@ -93,8 +95,9 @@ class CrossToolContinuationTest(BaseSimulatorTest):
"codereview",
{
"files": [self.test_files["python"]], # Same file should be deduplicated
"context": "Building on our previous analysis, provide a comprehensive code review",
"prompt": "Building on our previous analysis, provide a comprehensive code review",
"continuation_id": chat_id,
"model": "flash",
},
)
@@ -116,7 +119,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
# Start with analyze
analyze_response, analyze_id = self.call_mcp_tool(
"analyze", {"files": [self.test_files["python"]], "analysis_type": "code_quality"}
"analyze", {"files": [self.test_files["python"]], "analysis_type": "code_quality", "model": "flash"}
)
if not analyze_response or not analyze_id:
@@ -128,8 +131,9 @@ class CrossToolContinuationTest(BaseSimulatorTest):
"debug",
{
"files": [self.test_files["python"]], # Same file should be deduplicated
"issue_description": "Based on our analysis, help debug the performance issue in fibonacci",
"prompt": "Based on our analysis, help debug the performance issue in fibonacci",
"continuation_id": analyze_id,
"model": "flash",
},
)
@@ -144,6 +148,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
"prompt": "Please use low thinking mode. Think deeply about the architectural implications of the issues we've found",
"files": [self.test_files["python"]], # Same file should be deduplicated
"continuation_id": analyze_id,
"model": "flash",
},
)
@@ -169,6 +174,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
{
"prompt": "Please use low thinking mode. Analyze both the Python code and configuration file",
"files": [self.test_files["python"], self.test_files["config"]],
"model": "flash",
},
)
@@ -181,8 +187,9 @@ class CrossToolContinuationTest(BaseSimulatorTest):
"codereview",
{
"files": [self.test_files["python"], self.test_files["config"]], # Same files
"context": "Review both files in the context of our previous discussion",
"prompt": "Review both files in the context of our previous discussion",
"continuation_id": multi_id,
"model": "flash",
},
)

View File

@@ -35,7 +35,7 @@ class LogsValidationTest(BaseSimulatorTest):
main_logs = result.stdout.decode() + result.stderr.decode()
# Get logs from log monitor container (where detailed activity is logged)
monitor_result = self.run_command(["docker", "logs", "gemini-mcp-log-monitor"], capture_output=True)
monitor_result = self.run_command(["docker", "logs", "zen-mcp-log-monitor"], capture_output=True)
monitor_logs = ""
if monitor_result.returncode == 0:
monitor_logs = monitor_result.stdout.decode() + monitor_result.stderr.decode()

View File

@@ -55,7 +55,7 @@ class TestModelThinkingConfig(BaseSimulatorTest):
"chat",
{
"prompt": "What is 3 + 3? Give a quick answer.",
"model": "flash", # Should resolve to gemini-2.0-flash-exp
"model": "flash", # Should resolve to gemini-2.0-flash
"thinking_mode": "high", # Should be ignored for Flash model
},
)
@@ -80,7 +80,7 @@ class TestModelThinkingConfig(BaseSimulatorTest):
("pro", "should work with Pro model"),
("flash", "should work with Flash model"),
("gemini-2.5-pro-preview-06-05", "should work with full Pro model name"),
("gemini-2.0-flash-exp", "should work with full Flash model name"),
("gemini-2.0-flash", "should work with full Flash model name"),
]
success_count = 0
@@ -135,7 +135,7 @@ class TestModelThinkingConfig(BaseSimulatorTest):
def run_test(self) -> bool:
"""Run all model thinking configuration tests"""
self.logger.info(f"📝 Test: {self.test_description}")
self.logger.info(f" Test: {self.test_description}")
try:
# Test Pro model with thinking config

View File

@@ -0,0 +1,208 @@
#!/usr/bin/env python3
"""
O3 Model Selection Test
Tests that O3 models are properly selected and used when explicitly specified,
regardless of the default model configuration (even when set to auto).
Validates model selection via Docker logs.
"""
import datetime
import subprocess
from .base_test import BaseSimulatorTest
class O3ModelSelectionTest(BaseSimulatorTest):
"""Test O3 model selection and usage"""
@property
def test_name(self) -> str:
return "o3_model_selection"
@property
def test_description(self) -> str:
return "O3 model selection and usage validation"
def get_recent_server_logs(self) -> str:
"""Get recent server logs from the log file directly"""
try:
# Read logs directly from the log file - more reliable than docker logs --since
cmd = ["docker", "exec", self.container_name, "tail", "-n", "200", "/tmp/mcp_server.log"]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
return result.stdout
else:
self.logger.warning(f"Failed to read server logs: {result.stderr}")
return ""
except Exception as e:
self.logger.error(f"Failed to get server logs: {e}")
return ""
def run_test(self) -> bool:
"""Test O3 model selection and usage"""
try:
self.logger.info(" Test: O3 model selection and usage validation")
# Setup test files for later use
self.setup_test_files()
# Get timestamp for log filtering
datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
# Test 1: Explicit O3 model selection
self.logger.info(" 1: Testing explicit O3 model selection")
response1, _ = self.call_mcp_tool(
"chat",
{
"prompt": "Simple test: What is 2 + 2? Just give a brief answer.",
"model": "o3",
"temperature": 1.0, # O3 only supports default temperature of 1.0
},
)
if not response1:
self.logger.error(" ❌ O3 model test failed")
return False
self.logger.info(" ✅ O3 model call completed")
# Test 2: Explicit O3-mini model selection
self.logger.info(" 2: Testing explicit O3-mini model selection")
response2, _ = self.call_mcp_tool(
"chat",
{
"prompt": "Simple test: What is 3 + 3? Just give a brief answer.",
"model": "o3-mini",
"temperature": 1.0, # O3-mini only supports default temperature of 1.0
},
)
if not response2:
self.logger.error(" ❌ O3-mini model test failed")
return False
self.logger.info(" ✅ O3-mini model call completed")
# Test 3: Another tool with O3 to ensure it works across tools
self.logger.info(" 3: Testing O3 with different tool (codereview)")
# Create a simple test file
test_code = """def add(a, b):
return a + b
def multiply(x, y):
return x * y
"""
test_file = self.create_additional_test_file("simple_math.py", test_code)
response3, _ = self.call_mcp_tool(
"codereview",
{
"files": [test_file],
"prompt": "Quick review of this simple code",
"model": "o3",
"temperature": 1.0, # O3 only supports default temperature of 1.0
},
)
if not response3:
self.logger.error(" ❌ O3 with codereview tool failed")
return False
self.logger.info(" ✅ O3 with codereview tool completed")
# Validate model usage from server logs
self.logger.info(" 4: Validating model usage in logs")
logs = self.get_recent_server_logs()
# Check for OpenAI API calls (this proves O3 models are being used)
openai_api_logs = [line for line in logs.split("\n") if "Sending request to openai API for" in line]
# Check for OpenAI model usage logs
openai_model_logs = [
line for line in logs.split("\n") if "Using model:" in line and "openai provider" in line
]
# Check for successful OpenAI responses
openai_response_logs = [
line for line in logs.split("\n") if "openai provider" in line and "Using model:" in line
]
# Check that we have both chat and codereview tool calls to OpenAI
chat_openai_logs = [line for line in logs.split("\n") if "Sending request to openai API for chat" in line]
codereview_openai_logs = [
line for line in logs.split("\n") if "Sending request to openai API for codereview" in line
]
# Validation criteria - we expect 3 OpenAI calls (2 chat + 1 codereview)
openai_api_called = len(openai_api_logs) >= 3 # Should see 3 OpenAI API calls
openai_model_usage = len(openai_model_logs) >= 3 # Should see 3 model usage logs
openai_responses_received = len(openai_response_logs) >= 3 # Should see 3 responses
chat_calls_to_openai = len(chat_openai_logs) >= 2 # Should see 2 chat calls (o3 + o3-mini)
codereview_calls_to_openai = len(codereview_openai_logs) >= 1 # Should see 1 codereview call
self.logger.info(f" OpenAI API call logs: {len(openai_api_logs)}")
self.logger.info(f" OpenAI model usage logs: {len(openai_model_logs)}")
self.logger.info(f" OpenAI response logs: {len(openai_response_logs)}")
self.logger.info(f" Chat calls to OpenAI: {len(chat_openai_logs)}")
self.logger.info(f" Codereview calls to OpenAI: {len(codereview_openai_logs)}")
# Log sample evidence for debugging
if self.verbose and openai_api_logs:
self.logger.debug(" 📋 Sample OpenAI API logs:")
for log in openai_api_logs[:5]:
self.logger.debug(f" {log}")
if self.verbose and chat_openai_logs:
self.logger.debug(" 📋 Sample chat OpenAI logs:")
for log in chat_openai_logs[:3]:
self.logger.debug(f" {log}")
# Success criteria
success_criteria = [
("OpenAI API calls made", openai_api_called),
("OpenAI model usage logged", openai_model_usage),
("OpenAI responses received", openai_responses_received),
("Chat tool used OpenAI", chat_calls_to_openai),
("Codereview tool used OpenAI", codereview_calls_to_openai),
]
passed_criteria = sum(1 for _, passed in success_criteria if passed)
self.logger.info(f" Success criteria met: {passed_criteria}/{len(success_criteria)}")
for criterion, passed in success_criteria:
status = "" if passed else ""
self.logger.info(f" {status} {criterion}")
if passed_criteria >= 3:  # At least 3 out of 5 criteria
self.logger.info(" ✅ O3 model selection validation passed")
return True
else:
self.logger.error(" ❌ O3 model selection validation failed")
return False
except Exception as e:
self.logger.error(f"O3 model selection test failed: {e}")
return False
finally:
self.cleanup_test_files()
def main():
"""Run the O3 model selection tests"""
import sys
verbose = "--verbose" in sys.argv or "-v" in sys.argv
test = O3ModelSelectionTest(verbose=verbose)
success = test.run_test()
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()

View File

@@ -32,13 +32,30 @@ class PerToolDeduplicationTest(BaseSimulatorTest):
try:
# Check both main server and log monitor for comprehensive logs
cmd_server = ["docker", "logs", "--since", since_time, self.container_name]
cmd_monitor = ["docker", "logs", "--since", since_time, "gemini-mcp-log-monitor"]
cmd_monitor = ["docker", "logs", "--since", since_time, "zen-mcp-log-monitor"]
result_server = subprocess.run(cmd_server, capture_output=True, text=True)
result_monitor = subprocess.run(cmd_monitor, capture_output=True, text=True)
# Combine logs from both containers
combined_logs = result_server.stdout + "\n" + result_monitor.stdout
# Get the internal log files which have more detailed logging
server_log_result = subprocess.run(
["docker", "exec", self.container_name, "cat", "/tmp/mcp_server.log"], capture_output=True, text=True
)
activity_log_result = subprocess.run(
["docker", "exec", self.container_name, "cat", "/tmp/mcp_activity.log"], capture_output=True, text=True
)
# Combine all logs
combined_logs = (
result_server.stdout
+ "\n"
+ result_monitor.stdout
+ "\n"
+ server_log_result.stdout
+ "\n"
+ activity_log_result.stdout
)
return combined_logs
except Exception as e:
self.logger.error(f"Failed to get docker logs: {e}")
@@ -100,8 +117,9 @@ def divide(x, y):
precommit_params = {
"path": self.test_dir, # Required path parameter
"files": [dummy_file_path],
"original_request": "Please give me a quick one line reply. Review this code for commit readiness",
"prompt": "Please give me a quick one line reply. Review this code for commit readiness",
"thinking_mode": "low",
"model": "flash",
}
response1, continuation_id = self.call_mcp_tool("precommit", precommit_params)
@@ -124,8 +142,9 @@ def divide(x, y):
self.logger.info(" Step 2: codereview tool with same file (fresh conversation)")
codereview_params = {
"files": [dummy_file_path],
"context": "Please give me a quick one line reply. General code review for quality and best practices",
"prompt": "Please give me a quick one line reply. General code review for quality and best practices",
"thinking_mode": "low",
"model": "flash",
}
response2, _ = self.call_mcp_tool("codereview", codereview_params)
@@ -150,8 +169,9 @@ def subtract(a, b):
"continuation_id": continuation_id,
"path": self.test_dir, # Required path parameter
"files": [dummy_file_path, new_file_path], # Old + new file
"original_request": "Please give me a quick one line reply. Now also review the new feature file along with the previous one",
"prompt": "Please give me a quick one line reply. Now also review the new feature file along with the previous one",
"thinking_mode": "low",
"model": "flash",
}
response3, _ = self.call_mcp_tool("precommit", continue_params)
@@ -174,7 +194,7 @@ def subtract(a, b):
embedding_logs = [
line
for line in logs.split("\n")
if "📁" in line or "embedding" in line.lower() or "file" in line.lower()
if "[FILE_PROCESSING]" in line or "embedding" in line.lower() or "[FILES]" in line
]
# Check for continuation evidence
@@ -187,11 +207,11 @@ def subtract(a, b):
new_file_mentioned = any("new_feature.py" in line for line in logs.split("\n"))
# Print diagnostic information
self.logger.info(f" 📊 Conversation logs found: {len(conversation_logs)}")
self.logger.info(f" 📊 File embedding logs found: {len(embedding_logs)}")
self.logger.info(f" 📊 Continuation logs found: {len(continuation_logs)}")
self.logger.info(f" 📊 Dummy file mentioned: {dummy_file_mentioned}")
self.logger.info(f" 📊 New file mentioned: {new_file_mentioned}")
self.logger.info(f" Conversation logs found: {len(conversation_logs)}")
self.logger.info(f" File embedding logs found: {len(embedding_logs)}")
self.logger.info(f" Continuation logs found: {len(continuation_logs)}")
self.logger.info(f" Dummy file mentioned: {dummy_file_mentioned}")
self.logger.info(f" New file mentioned: {new_file_mentioned}")
if self.verbose:
self.logger.debug(" 📋 Sample embedding logs:")
@@ -215,9 +235,9 @@ def subtract(a, b):
passed_criteria = sum(success_criteria)
total_criteria = len(success_criteria)
self.logger.info(f" 📊 Success criteria met: {passed_criteria}/{total_criteria}")
self.logger.info(f" Success criteria met: {passed_criteria}/{total_criteria}")
if passed_criteria >= 3: # At least 3 out of 4 criteria
if passed_criteria == total_criteria: # All criteria must pass
self.logger.info(" ✅ File deduplication workflow test: PASSED")
return True
else:

View File

@@ -76,7 +76,7 @@ class RedisValidationTest(BaseSimulatorTest):
return True
else:
# If no existing threads, create a test thread to validate Redis functionality
self.logger.info("📝 No existing threads found, creating test thread to validate Redis...")
self.logger.info(" No existing threads found, creating test thread to validate Redis...")
test_thread_id = "test_thread_validation"
test_data = {

View File

@@ -0,0 +1,549 @@
#!/usr/bin/env python3
"""
Token Allocation and Conversation History Validation Test
This test validates that:
1. Token allocation logging works correctly for file processing
2. Conversation history builds up properly and consumes tokens
3. File deduplication works correctly across tool calls
4. Token usage increases appropriately as conversation history grows
"""
import datetime
import re
import subprocess
from .base_test import BaseSimulatorTest
class TokenAllocationValidationTest(BaseSimulatorTest):
"""Test token allocation and conversation history functionality"""
@property
def test_name(self) -> str:
return "token_allocation_validation"
@property
def test_description(self) -> str:
return "Token allocation and conversation history validation"
def get_recent_server_logs(self) -> str:
"""Get recent server logs from the log file directly"""
try:
cmd = ["docker", "exec", self.container_name, "tail", "-n", "300", "/tmp/mcp_server.log"]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
return result.stdout
else:
self.logger.warning(f"Failed to read server logs: {result.stderr}")
return ""
except Exception as e:
self.logger.error(f"Failed to get server logs: {e}")
return ""
def extract_conversation_usage_logs(self, logs: str) -> list[dict[str, int]]:
"""Extract actual conversation token usage from server logs"""
usage_logs = []
# Look for conversation debug logs that show actual usage
lines = logs.split("\n")
for i, line in enumerate(lines):
if "[CONVERSATION_DEBUG] Token budget calculation:" in line:
# Found start of token budget log, extract the following lines
usage = {}
for j in range(1, 8): # Next 7 lines contain the usage details
if i + j < len(lines):
detail_line = lines[i + j]
# Parse Total capacity: 1,048,576
if "Total capacity:" in detail_line:
match = re.search(r"Total capacity:\s*([\d,]+)", detail_line)
if match:
usage["total_capacity"] = int(match.group(1).replace(",", ""))
# Parse Content allocation: 838,860
elif "Content allocation:" in detail_line:
match = re.search(r"Content allocation:\s*([\d,]+)", detail_line)
if match:
usage["content_allocation"] = int(match.group(1).replace(",", ""))
# Parse Conversation tokens: 12,345
elif "Conversation tokens:" in detail_line:
match = re.search(r"Conversation tokens:\s*([\d,]+)", detail_line)
if match:
usage["conversation_tokens"] = int(match.group(1).replace(",", ""))
# Parse Remaining tokens: 825,515
elif "Remaining tokens:" in detail_line:
match = re.search(r"Remaining tokens:\s*([\d,]+)", detail_line)
if match:
usage["remaining_tokens"] = int(match.group(1).replace(",", ""))
if usage: # Only add if we found some usage data
usage_logs.append(usage)
return usage_logs
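The parser above keys off a multi-line [CONVERSATION_DEBUG] block in the server log. A minimal standalone sketch of the same extraction, using the sample figures from the comments (the exact log prefix and indentation are assumptions for illustration):

import re

sample_block = """[CONVERSATION_DEBUG] Token budget calculation:
[CONVERSATION_DEBUG]   Total capacity: 1,048,576
[CONVERSATION_DEBUG]   Content allocation: 838,860
[CONVERSATION_DEBUG]   Conversation tokens: 12,345
[CONVERSATION_DEBUG]   Remaining tokens: 825,515"""

usage = {}
for label, key in [
    ("Total capacity", "total_capacity"),
    ("Content allocation", "content_allocation"),
    ("Conversation tokens", "conversation_tokens"),
    ("Remaining tokens", "remaining_tokens"),
]:
    # Same comma-tolerant number pattern as the test's extractor
    match = re.search(rf"{label}:\s*([\d,]+)", sample_block)
    if match:
        usage[key] = int(match.group(1).replace(",", ""))

# usage == {"total_capacity": 1048576, "content_allocation": 838860,
#           "conversation_tokens": 12345, "remaining_tokens": 825515}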
def extract_conversation_token_usage(self, logs: str) -> list[int]:
"""Extract conversation token usage from logs"""
usage_values = []
# Look for conversation token usage logs
pattern = r"Conversation history token usage:\s*([\d,]+)"
matches = re.findall(pattern, logs)
for match in matches:
usage_values.append(int(match.replace(",", "")))
return usage_values
def run_test(self) -> bool:
"""Test token allocation and conversation history functionality"""
try:
self.logger.info(" Test: Token allocation and conversation history validation")
# Setup test files
self.setup_test_files()
# Create additional test files for this test - make them substantial enough to see token differences
file1_content = """def fibonacci(n):
'''Calculate fibonacci number recursively
This is a classic recursive algorithm that demonstrates
the exponential time complexity of naive recursion.
For large values of n, this becomes very slow.
Time complexity: O(2^n)
Space complexity: O(n) due to call stack
'''
if n <= 1:
return n
return fibonacci(n-1) + fibonacci(n-2)
def factorial(n):
'''Calculate factorial using recursion
More efficient than fibonacci as each value
is calculated only once.
Time complexity: O(n)
Space complexity: O(n) due to call stack
'''
if n <= 1:
return 1
return n * factorial(n-1)
def gcd(a, b):
'''Calculate greatest common divisor using Euclidean algorithm'''
while b:
a, b = b, a % b
return a
def lcm(a, b):
'''Calculate least common multiple'''
return abs(a * b) // gcd(a, b)
# Test functions with detailed output
if __name__ == "__main__":
print("=== Mathematical Functions Demo ===")
print(f"Fibonacci(10) = {fibonacci(10)}")
print(f"Factorial(5) = {factorial(5)}")
print(f"GCD(48, 18) = {gcd(48, 18)}")
print(f"LCM(48, 18) = {lcm(48, 18)}")
print("Fibonacci sequence (first 10 numbers):")
for i in range(10):
print(f" F({i}) = {fibonacci(i)}")
"""
file2_content = """class Calculator:
'''Advanced calculator class with error handling and logging'''
def __init__(self):
self.history = []
self.last_result = 0
def add(self, a, b):
'''Addition with history tracking'''
result = a + b
operation = f"{a} + {b} = {result}"
self.history.append(operation)
self.last_result = result
return result
def multiply(self, a, b):
'''Multiplication with history tracking'''
result = a * b
operation = f"{a} * {b} = {result}"
self.history.append(operation)
self.last_result = result
return result
def divide(self, a, b):
'''Division with error handling and history tracking'''
if b == 0:
error_msg = f"Division by zero error: {a} / {b}"
self.history.append(error_msg)
raise ValueError("Cannot divide by zero")
result = a / b
operation = f"{a} / {b} = {result}"
self.history.append(operation)
self.last_result = result
return result
def power(self, base, exponent):
'''Exponentiation with history tracking'''
result = base ** exponent
operation = f"{base} ^ {exponent} = {result}"
self.history.append(operation)
self.last_result = result
return result
def get_history(self):
'''Return calculation history'''
return self.history.copy()
def clear_history(self):
'''Clear calculation history'''
self.history.clear()
self.last_result = 0
# Demo usage
if __name__ == "__main__":
calc = Calculator()
print("=== Calculator Demo ===")
# Perform various calculations
print(f"Addition: {calc.add(10, 20)}")
print(f"Multiplication: {calc.multiply(5, 8)}")
print(f"Division: {calc.divide(100, 4)}")
print(f"Power: {calc.power(2, 8)}")
print("\\nCalculation History:")
for operation in calc.get_history():
print(f" {operation}")
print(f"\\nLast result: {calc.last_result}")
"""
# Create test files
file1_path = self.create_additional_test_file("math_functions.py", file1_content)
file2_path = self.create_additional_test_file("calculator.py", file2_content)
# Track continuation IDs to validate each step generates new ones
continuation_ids = []
# Step 1: Initial chat with first file
self.logger.info(" Step 1: Initial chat with file1 - checking token allocation")
datetime.datetime.now()
response1, continuation_id1 = self.call_mcp_tool(
"chat",
{
"prompt": "Please analyze this math functions file and explain what it does.",
"files": [file1_path],
"model": "flash",
"temperature": 0.7,
},
)
if not response1 or not continuation_id1:
self.logger.error(" ❌ Step 1 failed - no response or continuation ID")
return False
self.logger.info(f" ✅ Step 1 completed with continuation_id: {continuation_id1[:8]}...")
continuation_ids.append(continuation_id1)
# Get logs and analyze file processing (Step 1 is new conversation, no conversation debug logs expected)
logs_step1 = self.get_recent_server_logs()
# For Step 1, check for file embedding logs instead of conversation usage
file_embedding_logs_step1 = [
line
for line in logs_step1.split("\n")
if "successfully embedded" in line and "files" in line and "tokens" in line
]
if not file_embedding_logs_step1:
self.logger.error(" ❌ Step 1: No file embedding logs found")
return False
# Extract file token count from embedding logs
step1_file_tokens = 0
for log in file_embedding_logs_step1:
# Look for pattern like "successfully embedded 1 files (146 tokens)"
match = re.search(r"\((\d+) tokens\)", log)
if match:
step1_file_tokens = int(match.group(1))
break
self.logger.info(f" Step 1 File Processing - Embedded files: {step1_file_tokens:,} tokens")
# Validate that file1 is actually mentioned in the embedding logs (check for actual filename)
file1_mentioned = any("math_functions.py" in log for log in file_embedding_logs_step1)
if not file1_mentioned:
# Debug: show what files were actually found in the logs
self.logger.debug(" 📋 Files found in embedding logs:")
for log in file_embedding_logs_step1:
self.logger.debug(f" {log}")
# Also check if any files were embedded at all
any_file_embedded = len(file_embedding_logs_step1) > 0
if not any_file_embedded:
self.logger.error(" ❌ Step 1: No file embedding logs found at all")
return False
else:
self.logger.warning(" ⚠️ Step 1: math_functions.py not specifically found, but files were embedded")
# Continue test - the important thing is that files were processed
# Step 2: Different tool continuing same conversation - should build conversation history
self.logger.info(
" Step 2: Analyze tool continuing chat conversation - checking conversation history buildup"
)
response2, continuation_id2 = self.call_mcp_tool(
"analyze",
{
"prompt": "Analyze the performance implications of these recursive functions.",
"files": [file1_path],
"continuation_id": continuation_id1, # Continue the chat conversation
"model": "flash",
"temperature": 0.7,
},
)
if not response2 or not continuation_id2:
self.logger.error(" ❌ Step 2 failed - no response or continuation ID")
return False
self.logger.info(f" ✅ Step 2 completed with continuation_id: {continuation_id2[:8]}...")
continuation_ids.append(continuation_id2)
# Validate that we got a different continuation ID
if continuation_id2 == continuation_id1:
self.logger.error(" ❌ Step 2: Got same continuation ID as Step 1 - continuation not working")
return False
# Get logs and analyze token usage
logs_step2 = self.get_recent_server_logs()
usage_step2 = self.extract_conversation_usage_logs(logs_step2)
if len(usage_step2) < 2:
self.logger.warning(
f" ⚠️ Step 2: Only found {len(usage_step2)} conversation usage logs, expected at least 2"
)
# Debug: Look for any CONVERSATION_DEBUG logs
conversation_debug_lines = [line for line in logs_step2.split("\n") if "CONVERSATION_DEBUG" in line]
self.logger.debug(f" 📋 Found {len(conversation_debug_lines)} CONVERSATION_DEBUG lines in step 2")
if conversation_debug_lines:
self.logger.debug(" 📋 Recent CONVERSATION_DEBUG lines:")
for line in conversation_debug_lines[-10:]: # Show last 10
self.logger.debug(f" {line}")
# If we have at least 1 usage log, continue with adjusted expectations
if len(usage_step2) >= 1:
self.logger.info(" 📋 Continuing with single usage log for analysis")
else:
self.logger.error(" ❌ No conversation usage logs found at all")
return False
latest_usage_step2 = usage_step2[-1] # Get most recent usage
self.logger.info(
f" Step 2 Token Usage - Total Capacity: {latest_usage_step2.get('total_capacity', 0):,}, "
f"Conversation: {latest_usage_step2.get('conversation_tokens', 0):,}, "
f"Remaining: {latest_usage_step2.get('remaining_tokens', 0):,}"
)
# Step 3: Continue conversation with additional file - should show increased token usage
self.logger.info(" Step 3: Continue conversation with file1 + file2 - checking token growth")
response3, continuation_id3 = self.call_mcp_tool(
"chat",
{
"prompt": "Now compare the math functions with this calculator class. How do they differ in approach?",
"files": [file1_path, file2_path],
"continuation_id": continuation_id2, # Continue the conversation from step 2
"model": "flash",
"temperature": 0.7,
},
)
if not response3 or not continuation_id3:
self.logger.error(" ❌ Step 3 failed - no response or continuation ID")
return False
self.logger.info(f" ✅ Step 3 completed with continuation_id: {continuation_id3[:8]}...")
continuation_ids.append(continuation_id3)
# Get logs and analyze final token usage
logs_step3 = self.get_recent_server_logs()
usage_step3 = self.extract_conversation_usage_logs(logs_step3)
self.logger.info(f" 📋 Found {len(usage_step3)} total conversation usage logs")
if len(usage_step3) < 3:
self.logger.warning(
f" ⚠️ Step 3: Only found {len(usage_step3)} conversation usage logs, expected at least 3"
)
# Let's check if we have at least some logs to work with
if len(usage_step3) == 0:
self.logger.error(" ❌ No conversation usage logs found at all")
# Debug: show some recent logs
recent_lines = logs_step3.split("\n")[-50:]
self.logger.debug(" 📋 Recent log lines:")
for line in recent_lines:
if line.strip() and "CONVERSATION_DEBUG" in line:
self.logger.debug(f" {line}")
return False
latest_usage_step3 = usage_step3[-1] # Get most recent usage
self.logger.info(
f" Step 3 Token Usage - Total Capacity: {latest_usage_step3.get('total_capacity', 0):,}, "
f"Conversation: {latest_usage_step3.get('conversation_tokens', 0):,}, "
f"Remaining: {latest_usage_step3.get('remaining_tokens', 0):,}"
)
# Validation: Check token processing and conversation history
self.logger.info(" 📋 Validating token processing and conversation history...")
# Get conversation usage for steps with continuation_id
step2_conversation = 0
step2_remaining = 0
step3_conversation = 0
step3_remaining = 0
if len(usage_step2) > 0:
step2_conversation = latest_usage_step2.get("conversation_tokens", 0)
step2_remaining = latest_usage_step2.get("remaining_tokens", 0)
if len(usage_step3) >= len(usage_step2) + 1: # Should have one more log than step2
step3_conversation = latest_usage_step3.get("conversation_tokens", 0)
step3_remaining = latest_usage_step3.get("remaining_tokens", 0)
else:
# Use step2 values as fallback
step3_conversation = step2_conversation
step3_remaining = step2_remaining
self.logger.warning(" ⚠️ Using Step 2 usage for Step 3 comparison due to missing logs")
# Validation criteria
criteria = []
# 1. Step 1 should have processed files successfully
step1_processed_files = step1_file_tokens > 0
criteria.append(("Step 1 processed files successfully", step1_processed_files))
# 2. Step 2 should have conversation history (if continuation worked)
step2_has_conversation = (
step2_conversation > 0 if len(usage_step2) > 0 else True
) # Pass if no logs (might be different issue)
step2_has_remaining = step2_remaining > 0 if len(usage_step2) > 0 else True
criteria.append(("Step 2 has conversation history", step2_has_conversation))
criteria.append(("Step 2 has remaining tokens", step2_has_remaining))
# 3. Step 3 should show conversation growth
step3_has_conversation = (
step3_conversation >= step2_conversation if len(usage_step3) > len(usage_step2) else True
)
criteria.append(("Step 3 maintains conversation history", step3_has_conversation))
# 4. Check that we got some conversation usage logs for continuation calls
has_conversation_logs = len(usage_step3) > 0
criteria.append(("Found conversation usage logs", has_conversation_logs))
# 5. Validate unique continuation IDs per response
unique_continuation_ids = len(set(continuation_ids)) == len(continuation_ids)
criteria.append(("Each response generated unique continuation ID", unique_continuation_ids))
# 6. Validate continuation IDs were different from each step
step_ids_different = (
len(continuation_ids) == 3
and continuation_ids[0] != continuation_ids[1]
and continuation_ids[1] != continuation_ids[2]
)
criteria.append(("All continuation IDs are different", step_ids_different))
# Log detailed analysis
self.logger.info(" Token Processing Analysis:")
self.logger.info(f" Step 1 - File tokens: {step1_file_tokens:,} (new conversation)")
self.logger.info(f" Step 2 - Conversation: {step2_conversation:,}, Remaining: {step2_remaining:,}")
self.logger.info(f" Step 3 - Conversation: {step3_conversation:,}, Remaining: {step3_remaining:,}")
# Log continuation ID analysis
self.logger.info(" Continuation ID Analysis:")
self.logger.info(f" Step 1 ID: {continuation_ids[0][:8]}... (generated)")
self.logger.info(f" Step 2 ID: {continuation_ids[1][:8]}... (generated from Step 1)")
self.logger.info(f" Step 3 ID: {continuation_ids[2][:8]}... (generated from Step 2)")
# Check for file mentions in step 3 (should include both files)
# Look for file processing in conversation memory logs and tool embedding logs
file2_mentioned_step3 = any(
"calculator.py" in log
for log in logs_step3.split("\n")
if ("embedded" in log.lower() and ("conversation" in log.lower() or "tool" in log.lower()))
)
file1_still_mentioned_step3 = any(
"math_functions.py" in log
for log in logs_step3.split("\n")
if ("embedded" in log.lower() and ("conversation" in log.lower() or "tool" in log.lower()))
)
self.logger.info(" File Processing in Step 3:")
self.logger.info(f" File1 (math_functions.py) mentioned: {file1_still_mentioned_step3}")
self.logger.info(f" File2 (calculator.py) mentioned: {file2_mentioned_step3}")
# Add file increase validation
step3_file_increase = file2_mentioned_step3 # New file should be visible
criteria.append(("Step 3 shows new file being processed", step3_file_increase))
# Check validation criteria
passed_criteria = sum(1 for _, passed in criteria if passed)
total_criteria = len(criteria)
self.logger.info(f" Validation criteria: {passed_criteria}/{total_criteria}")
for criterion, passed in criteria:
                status = "✅" if passed else "❌"
self.logger.info(f" {status} {criterion}")
# Check for file embedding logs
file_embedding_logs = [
line for line in logs_step3.split("\n") if "tool embedding" in line and "files" in line
]
conversation_logs = [line for line in logs_step3.split("\n") if "conversation history" in line.lower()]
self.logger.info(f" File embedding logs: {len(file_embedding_logs)}")
self.logger.info(f" Conversation history logs: {len(conversation_logs)}")
# Success criteria: All validation criteria must pass
success = passed_criteria == total_criteria
if success:
self.logger.info(" ✅ Token allocation validation test PASSED")
return True
else:
self.logger.error(" ❌ Token allocation validation test FAILED")
return False
except Exception as e:
self.logger.error(f"Token allocation validation test failed: {e}")
return False
finally:
self.cleanup_test_files()
def main():
"""Run the token allocation validation test"""
import sys
verbose = "--verbose" in sys.argv or "-v" in sys.argv
test = TokenAllocationValidationTest(verbose=verbose)
success = test.run_test()
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()

View File

@@ -1 +1 @@
# Tests for Gemini MCP Server
# Tests for Zen MCP Server

View File

@@ -1,8 +1,9 @@
"""
Pytest configuration for Gemini MCP Server tests
Pytest configuration for Zen MCP Server tests
"""
import asyncio
import importlib
import os
import sys
import tempfile
@@ -15,20 +16,41 @@ parent_dir = Path(__file__).resolve().parent.parent
if str(parent_dir) not in sys.path:
sys.path.insert(0, str(parent_dir))
# Set dummy API key for tests if not already set
# Set dummy API keys for tests if not already set
if "GEMINI_API_KEY" not in os.environ:
os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
if "OPENAI_API_KEY" not in os.environ:
os.environ["OPENAI_API_KEY"] = "dummy-key-for-tests"
# Set default model to a specific value for tests to avoid auto mode
# This prevents all tests from failing due to missing model parameter
os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash"
# Force reload of config module to pick up the env var
import config # noqa: E402
importlib.reload(config)
# Set MCP_PROJECT_ROOT to a temporary directory for tests
# This provides a safe sandbox for file operations during testing
# Create a temporary directory that will be used as the project root for all tests
test_root = tempfile.mkdtemp(prefix="gemini_mcp_test_")
test_root = tempfile.mkdtemp(prefix="zen_mcp_test_")
os.environ["MCP_PROJECT_ROOT"] = test_root
# Configure asyncio for Windows compatibility
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# Register providers for all tests
from providers import ModelProviderRegistry # noqa: E402
from providers.base import ProviderType # noqa: E402
from providers.gemini import GeminiModelProvider # noqa: E402
from providers.openai import OpenAIModelProvider # noqa: E402
# Register providers at test startup
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
@pytest.fixture
def project_path(tmp_path):

41
tests/mock_helpers.py Normal file
View File

@@ -0,0 +1,41 @@
"""Helper functions for test mocking."""
from unittest.mock import Mock
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
def create_mock_provider(model_name="gemini-2.0-flash", max_tokens=1_048_576):
"""Create a properly configured mock provider."""
mock_provider = Mock()
# Set up capabilities
mock_capabilities = ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name=model_name,
friendly_name="Gemini",
max_tokens=max_tokens,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
)
mock_provider.get_capabilities.return_value = mock_capabilities
mock_provider.get_provider_type.return_value = ProviderType.GOOGLE
mock_provider.supports_thinking_mode.return_value = False
mock_provider.validate_model_name.return_value = True
# Set up generate_content response
mock_response = Mock()
mock_response.content = "Test response"
mock_response.usage = {"input_tokens": 10, "output_tokens": 20}
mock_response.model_name = model_name
mock_response.friendly_name = "Gemini"
mock_response.provider = ProviderType.GOOGLE
mock_response.metadata = {"finish_reason": "STOP"}
mock_provider.generate_content.return_value = mock_response
return mock_provider
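A typical usage sketch, mirroring how the tests in this change set wire the helper into a tool by patching get_model_provider (AnalyzeTool is just one convenient example):

from unittest.mock import patch

from tests.mock_helpers import create_mock_provider
from tools.analyze import AnalyzeTool

tool = AnalyzeTool()
with patch.object(tool, "get_model_provider") as mock_get_provider:
    # Provider behaves like a Gemini flash model but never touches the network
    mock_get_provider.return_value = create_mock_provider(model_name="gemini-2.0-flash")
    # tool.execute({...}) can now run in tests without any real API key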

192
tests/test_auto_mode.py Normal file
View File

@@ -0,0 +1,192 @@
"""Tests for auto mode functionality"""
import importlib
import os
from unittest.mock import patch
import pytest
from tools.analyze import AnalyzeTool
class TestAutoMode:
"""Test auto mode configuration and behavior"""
def test_auto_mode_detection(self):
"""Test that auto mode is detected correctly"""
# Save original
original = os.environ.get("DEFAULT_MODEL", "")
try:
# Test auto mode
os.environ["DEFAULT_MODEL"] = "auto"
import config
importlib.reload(config)
assert config.DEFAULT_MODEL == "auto"
assert config.IS_AUTO_MODE is True
# Test non-auto mode
os.environ["DEFAULT_MODEL"] = "pro"
importlib.reload(config)
assert config.DEFAULT_MODEL == "pro"
assert config.IS_AUTO_MODE is False
finally:
# Restore
if original:
os.environ["DEFAULT_MODEL"] = original
else:
os.environ.pop("DEFAULT_MODEL", None)
importlib.reload(config)
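The assertions above only pin the observable behaviour: reloading config after changing DEFAULT_MODEL flips IS_AUTO_MODE. A hedged sketch of the logic this implies (the real config.py may compute it differently):

import os

# Read once at import time, which is why the tests reload the module
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "auto")
IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"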
def test_model_capabilities_descriptions(self):
"""Test that model capabilities are properly defined"""
from config import MODEL_CAPABILITIES_DESC
# Check all expected models are present
expected_models = ["flash", "pro", "o3", "o3-mini"]
for model in expected_models:
assert model in MODEL_CAPABILITIES_DESC
assert isinstance(MODEL_CAPABILITIES_DESC[model], str)
assert len(MODEL_CAPABILITIES_DESC[model]) > 50 # Meaningful description
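The test constrains only the shape of MODEL_CAPABILITIES_DESC: one entry per model alias, each mapping to a description longer than 50 characters. A hypothetical illustration (these descriptions are placeholders, not the real ones):

MODEL_CAPABILITIES_DESC = {
    "flash": "Fast Gemini model suited to quick, low-latency tasks such as short reviews and summaries.",
    "pro": "Deep-reasoning Gemini model with extended thinking support for complex analysis and architecture work.",
    "o3": "OpenAI reasoning model aimed at careful, step-by-step logical problem solving on hard tasks.",
    "o3-mini": "Smaller OpenAI reasoning model that balances reasoning capability against cost and latency.",
}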
def test_tool_schema_in_auto_mode(self):
"""Test that tool schemas require model in auto mode"""
# Save original
original = os.environ.get("DEFAULT_MODEL", "")
try:
# Enable auto mode
os.environ["DEFAULT_MODEL"] = "auto"
import config
importlib.reload(config)
tool = AnalyzeTool()
schema = tool.get_input_schema()
# Model should be required
assert "model" in schema["required"]
# Model field should have detailed descriptions
model_schema = schema["properties"]["model"]
assert "enum" in model_schema
assert "flash" in model_schema["enum"]
assert "Choose the best model" in model_schema["description"]
finally:
# Restore
if original:
os.environ["DEFAULT_MODEL"] = original
else:
os.environ.pop("DEFAULT_MODEL", None)
importlib.reload(config)
def test_tool_schema_in_normal_mode(self):
"""Test that tool schemas don't require model in normal mode"""
# This test uses the default from conftest.py which sets non-auto mode
tool = AnalyzeTool()
schema = tool.get_input_schema()
# Model should not be required
assert "model" not in schema["required"]
# Model field should have simpler description
model_schema = schema["properties"]["model"]
assert "enum" not in model_schema
assert "Available:" in model_schema["description"]
@pytest.mark.asyncio
async def test_auto_mode_requires_model_parameter(self):
"""Test that auto mode enforces model parameter"""
# Save original
original = os.environ.get("DEFAULT_MODEL", "")
try:
# Enable auto mode
os.environ["DEFAULT_MODEL"] = "auto"
import config
importlib.reload(config)
tool = AnalyzeTool()
# Mock the provider to avoid real API calls
with patch.object(tool, "get_model_provider"):
# Execute without model parameter
result = await tool.execute({"files": ["/tmp/test.py"], "prompt": "Analyze this"})
# Should get error
assert len(result) == 1
response = result[0].text
assert "error" in response
assert "Model parameter is required" in response
finally:
# Restore
if original:
os.environ["DEFAULT_MODEL"] = original
else:
os.environ.pop("DEFAULT_MODEL", None)
importlib.reload(config)
def test_model_field_schema_generation(self):
"""Test the get_model_field_schema method"""
from tools.base import BaseTool
# Create a minimal concrete tool for testing
class TestTool(BaseTool):
def get_name(self):
return "test"
def get_description(self):
return "test"
def get_input_schema(self):
return {}
def get_system_prompt(self):
return ""
def get_request_model(self):
return None
async def prepare_prompt(self, request):
return ""
tool = TestTool()
# Save original
original = os.environ.get("DEFAULT_MODEL", "")
try:
# Test auto mode
os.environ["DEFAULT_MODEL"] = "auto"
import config
importlib.reload(config)
schema = tool.get_model_field_schema()
assert "enum" in schema
assert all(model in schema["enum"] for model in ["flash", "pro", "o3"])
assert "Choose the best model" in schema["description"]
# Test normal mode
os.environ["DEFAULT_MODEL"] = "pro"
importlib.reload(config)
schema = tool.get_model_field_schema()
assert "enum" not in schema
assert "Available:" in schema["description"]
assert "'pro'" in schema["description"]
finally:
# Restore
if original:
os.environ["DEFAULT_MODEL"] = original
else:
os.environ.pop("DEFAULT_MODEL", None)
importlib.reload(config)
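test_model_field_schema_generation fixes the contract for get_model_field_schema in both modes; a hedged sketch of a method satisfying those assertions (the actual BaseTool implementation may differ):

def get_model_field_schema(self) -> dict:
    import config

    if config.IS_AUTO_MODE:
        # Auto mode: Claude must pick a model, so enumerate the choices
        return {
            "type": "string",
            "enum": list(config.MODEL_CAPABILITIES_DESC.keys()),
            "description": "Choose the best model for this task. "
            + " ".join(f"'{name}': {desc}" for name, desc in config.MODEL_CAPABILITIES_DESC.items()),
        }
    # Normal mode: model is optional and falls back to the configured default
    return {
        "type": "string",
        "description": f"Model to use. Available: {', '.join(config.MODEL_CAPABILITIES_DESC)}. "
        f"Defaults to '{config.DEFAULT_MODEL}' if not specified.",
    }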

View File

@@ -11,8 +11,8 @@ from unittest.mock import Mock, patch
import pytest
from pydantic import Field
from tests.mock_helpers import create_mock_provider
from tools.base import BaseTool, ToolRequest
from tools.models import ContinuationOffer, ToolOutput
from utils.conversation_memory import MAX_CONVERSATION_TURNS
@@ -58,78 +58,117 @@ class TestClaudeContinuationOffers:
self.tool = ClaudeContinuationTool()
@patch("utils.conversation_memory.get_redis_client")
def test_new_conversation_offers_continuation(self, mock_redis):
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_new_conversation_offers_continuation(self, mock_redis):
"""Test that new conversations offer Claude continuation opportunity"""
mock_client = Mock()
mock_redis.return_value = mock_client
# Test request without continuation_id (new conversation)
request = ContinuationRequest(prompt="Analyze this code")
# Mock the model
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Analysis complete.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Check continuation opportunity
continuation_data = self.tool._check_continuation_opportunity(request)
# Execute tool without continuation_id (new conversation)
arguments = {"prompt": "Analyze this code"}
response = await self.tool.execute(arguments)
assert continuation_data is not None
assert continuation_data["remaining_turns"] == MAX_CONVERSATION_TURNS - 1
assert continuation_data["tool_name"] == "test_continuation"
# Parse response
response_data = json.loads(response[0].text)
def test_existing_conversation_no_continuation_offer(self):
"""Test that existing threaded conversations don't offer continuation"""
# Test request with continuation_id (existing conversation)
request = ContinuationRequest(
prompt="Continue analysis", continuation_id="12345678-1234-1234-1234-123456789012"
)
# Check continuation opportunity
continuation_data = self.tool._check_continuation_opportunity(request)
assert continuation_data is None
# Should offer continuation for new conversation
assert response_data["status"] == "continuation_available"
assert "continuation_offer" in response_data
assert response_data["continuation_offer"]["remaining_turns"] == MAX_CONVERSATION_TURNS - 1
@patch("utils.conversation_memory.get_redis_client")
def test_create_continuation_offer_response(self, mock_redis):
"""Test creating continuation offer response"""
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_existing_conversation_still_offers_continuation(self, mock_redis):
"""Test that existing threaded conversations still offer continuation if turns remain"""
mock_client = Mock()
mock_redis.return_value = mock_client
request = ContinuationRequest(prompt="Test prompt")
content = "This is the analysis result."
continuation_data = {"remaining_turns": 4, "tool_name": "test_continuation"}
# Mock existing thread context with 2 turns
from utils.conversation_memory import ConversationTurn, ThreadContext
# Create continuation offer response
response = self.tool._create_continuation_offer_response(content, continuation_data, request)
thread_context = ThreadContext(
thread_id="12345678-1234-1234-1234-123456789012",
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:01:00Z",
tool_name="test_continuation",
turns=[
ConversationTurn(
role="assistant",
content="Previous response",
timestamp="2023-01-01T00:00:30Z",
tool_name="test_continuation",
),
ConversationTurn(
role="user",
content="Follow up question",
timestamp="2023-01-01T00:01:00Z",
),
],
initial_context={"prompt": "Initial analysis"},
)
mock_client.get.return_value = thread_context.model_dump_json()
assert isinstance(response, ToolOutput)
assert response.status == "continuation_available"
assert response.content == content
assert response.continuation_offer is not None
# Mock the model
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Continued analysis.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
offer = response.continuation_offer
assert isinstance(offer, ContinuationOffer)
assert offer.remaining_turns == 4
assert "continuation_id" in offer.suggested_tool_params
assert "You have 4 more exchange(s) available" in offer.message_to_user
# Execute tool with continuation_id
arguments = {"prompt": "Continue analysis", "continuation_id": "12345678-1234-1234-1234-123456789012"}
response = await self.tool.execute(arguments)
# Parse response
response_data = json.loads(response[0].text)
# Should still offer continuation since turns remain
assert response_data["status"] == "continuation_available"
assert "continuation_offer" in response_data
# 10 max - 2 existing - 1 new = 7 remaining
assert response_data["continuation_offer"]["remaining_turns"] == 7
@patch("utils.conversation_memory.get_redis_client")
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_full_response_flow_with_continuation_offer(self, mock_redis):
"""Test complete response flow that creates continuation offer"""
mock_client = Mock()
mock_redis.return_value = mock_client
# Mock the model to return a response without follow-up question
with patch.object(self.tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(parts=[Mock(text="Analysis complete. The code looks good.")]),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Analysis complete. The code looks good.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Execute tool with new conversation
arguments = {"prompt": "Analyze this code"}
arguments = {"prompt": "Analyze this code", "model": "flash"}
response = await self.tool.execute(arguments)
# Parse response
@@ -151,37 +190,28 @@ class TestClaudeContinuationOffers:
assert "more exchange(s) available" in offer["message_to_user"]
@patch("utils.conversation_memory.get_redis_client")
async def test_gemini_follow_up_takes_precedence(self, mock_redis):
"""Test that Gemini follow-up questions take precedence over continuation offers"""
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_continuation_always_offered_with_natural_language(self, mock_redis):
"""Test that continuation is always offered with natural language prompts"""
mock_client = Mock()
mock_redis.return_value = mock_client
# Mock the model to return a response WITH follow-up question
with patch.object(self.tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(
parts=[
Mock(
text="""Analysis complete. The code looks good.
# Mock the model to return a response with natural language follow-up
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
# Include natural language follow-up in the content
content_with_followup = """Analysis complete. The code looks good.
```json
{
"follow_up_question": "Would you like me to examine the error handling patterns?",
"suggested_params": {"files": ["/src/error_handler.py"]},
"ui_hint": "Examining error handling would help ensure robustness"
}
```"""
)
]
),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
I'd be happy to examine the error handling patterns in more detail if that would be helpful."""
mock_provider.generate_content.return_value = Mock(
content=content_with_followup,
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Execute tool
arguments = {"prompt": "Analyze this code"}
@@ -190,12 +220,13 @@ class TestClaudeContinuationOffers:
# Parse response
response_data = json.loads(response[0].text)
# Should be follow-up, not continuation offer
assert response_data["status"] == "requires_continuation"
assert "follow_up_request" in response_data
assert response_data.get("continuation_offer") is None
# Should always offer continuation
assert response_data["status"] == "continuation_available"
assert "continuation_offer" in response_data
assert response_data["continuation_offer"]["remaining_turns"] == MAX_CONVERSATION_TURNS - 1
@patch("utils.conversation_memory.get_redis_client")
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_threaded_conversation_with_continuation_offer(self, mock_redis):
"""Test that threaded conversations still get continuation offers when turns remain"""
mock_client = Mock()
@@ -215,17 +246,17 @@ class TestClaudeContinuationOffers:
mock_client.get.return_value = thread_context.model_dump_json()
# Mock the model
with patch.object(self.tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(parts=[Mock(text="Continued analysis complete.")]),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Continued analysis complete.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Execute tool with continuation_id
arguments = {"prompt": "Continue the analysis", "continuation_id": "12345678-1234-1234-1234-123456789012"}
@@ -239,81 +270,60 @@ class TestClaudeContinuationOffers:
assert response_data.get("continuation_offer") is not None
assert response_data["continuation_offer"]["remaining_turns"] == 9
def test_max_turns_reached_no_continuation_offer(self):
@patch("utils.conversation_memory.get_redis_client")
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_max_turns_reached_no_continuation_offer(self, mock_redis):
"""Test that no continuation is offered when max turns would be exceeded"""
# Mock MAX_CONVERSATION_TURNS to be 1 for this test
with patch("tools.base.MAX_CONVERSATION_TURNS", 1):
request = ContinuationRequest(prompt="Test prompt")
# Check continuation opportunity
continuation_data = self.tool._check_continuation_opportunity(request)
# Should be None because remaining_turns would be 0
assert continuation_data is None
@patch("utils.conversation_memory.get_redis_client")
def test_continuation_offer_thread_creation_failure_fallback(self, mock_redis):
"""Test fallback to normal response when thread creation fails"""
# Mock Redis to fail
mock_client = Mock()
mock_client.setex.side_effect = Exception("Redis failure")
mock_redis.return_value = mock_client
request = ContinuationRequest(prompt="Test prompt")
content = "Analysis result"
continuation_data = {"remaining_turns": 4, "tool_name": "test_continuation"}
# Should fallback to normal response
response = self.tool._create_continuation_offer_response(content, continuation_data, request)
assert response.status == "success"
assert response.content == content
assert response.continuation_offer is None
@patch("utils.conversation_memory.get_redis_client")
def test_continuation_offer_message_format(self, mock_redis):
"""Test that continuation offer message is properly formatted for Claude"""
mock_client = Mock()
mock_redis.return_value = mock_client
request = ContinuationRequest(prompt="Analyze architecture")
content = "Architecture analysis complete."
continuation_data = {"remaining_turns": 3, "tool_name": "test_continuation"}
# Mock existing thread context at max turns
from utils.conversation_memory import ConversationTurn, ThreadContext
response = self.tool._create_continuation_offer_response(content, continuation_data, request)
# Create turns at the limit (MAX_CONVERSATION_TURNS - 1 since we're about to add one)
turns = [
ConversationTurn(
role="assistant" if i % 2 else "user",
content=f"Turn {i+1}",
timestamp="2023-01-01T00:00:00Z",
tool_name="test_continuation",
)
for i in range(MAX_CONVERSATION_TURNS - 1)
]
offer = response.continuation_offer
message = offer.message_to_user
thread_context = ThreadContext(
thread_id="12345678-1234-1234-1234-123456789012",
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:01:00Z",
tool_name="test_continuation",
turns=turns,
initial_context={"prompt": "Initial"},
)
mock_client.get.return_value = thread_context.model_dump_json()
# Check message contains key information for Claude
assert "continue this analysis" in message
assert "continuation_id" in message
assert "test_continuation tool call" in message
assert "3 more exchange(s)" in message
# Mock the model
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Final response.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Check suggested params are properly formatted
suggested_params = offer.suggested_tool_params
assert "continuation_id" in suggested_params
assert "prompt" in suggested_params
assert isinstance(suggested_params["continuation_id"], str)
# Execute tool with continuation_id at max turns
arguments = {"prompt": "Final question", "continuation_id": "12345678-1234-1234-1234-123456789012"}
response = await self.tool.execute(arguments)
@patch("utils.conversation_memory.get_redis_client")
def test_continuation_offer_metadata(self, mock_redis):
"""Test that continuation offer includes proper metadata"""
mock_client = Mock()
mock_redis.return_value = mock_client
# Parse response
response_data = json.loads(response[0].text)
request = ContinuationRequest(prompt="Test")
content = "Test content"
continuation_data = {"remaining_turns": 2, "tool_name": "test_continuation"}
response = self.tool._create_continuation_offer_response(content, continuation_data, request)
metadata = response.metadata
assert metadata["tool_name"] == "test_continuation"
assert metadata["remaining_turns"] == 2
assert "thread_id" in metadata
assert len(metadata["thread_id"]) == 36 # UUID length
# Should NOT offer continuation since we're at max turns
assert response_data["status"] == "success"
assert response_data.get("continuation_offer") is None
class TestContinuationIntegration:
@@ -323,7 +333,8 @@ class TestContinuationIntegration:
self.tool = ClaudeContinuationTool()
@patch("utils.conversation_memory.get_redis_client")
def test_continuation_offer_creates_proper_thread(self, mock_redis):
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_continuation_offer_creates_proper_thread(self, mock_redis):
"""Test that continuation offers create properly formatted threads"""
mock_client = Mock()
mock_redis.return_value = mock_client
@@ -339,77 +350,119 @@ class TestContinuationIntegration:
mock_client.get.side_effect = side_effect_get
request = ContinuationRequest(prompt="Initial analysis", files=["/test/file.py"])
content = "Analysis result"
continuation_data = {"remaining_turns": 4, "tool_name": "test_continuation"}
# Mock the model
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Analysis result",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
self.tool._create_continuation_offer_response(content, continuation_data, request)
# Execute tool for initial analysis
arguments = {"prompt": "Initial analysis", "files": ["/test/file.py"]}
response = await self.tool.execute(arguments)
# Verify thread creation was called (should be called twice: create_thread + add_turn)
assert mock_client.setex.call_count == 2
# Parse response
response_data = json.loads(response[0].text)
# Check the first call (create_thread)
first_call = mock_client.setex.call_args_list[0]
thread_key = first_call[0][0]
assert thread_key.startswith("thread:")
assert len(thread_key.split(":")[-1]) == 36 # UUID length
# Should offer continuation
assert response_data["status"] == "continuation_available"
assert "continuation_offer" in response_data
# Check the second call (add_turn) which should have the assistant response
second_call = mock_client.setex.call_args_list[1]
thread_data = second_call[0][2]
thread_context = json.loads(thread_data)
# Verify thread creation was called (should be called twice: create_thread + add_turn)
assert mock_client.setex.call_count == 2
assert thread_context["tool_name"] == "test_continuation"
assert len(thread_context["turns"]) == 1 # Assistant's response added
assert thread_context["turns"][0]["role"] == "assistant"
assert thread_context["turns"][0]["content"] == content
assert thread_context["turns"][0]["files"] == ["/test/file.py"] # Files from request
assert thread_context["initial_context"]["prompt"] == "Initial analysis"
assert thread_context["initial_context"]["files"] == ["/test/file.py"]
# Check the first call (create_thread)
first_call = mock_client.setex.call_args_list[0]
thread_key = first_call[0][0]
assert thread_key.startswith("thread:")
assert len(thread_key.split(":")[-1]) == 36 # UUID length
# Check the second call (add_turn) which should have the assistant response
second_call = mock_client.setex.call_args_list[1]
thread_data = second_call[0][2]
thread_context = json.loads(thread_data)
assert thread_context["tool_name"] == "test_continuation"
assert len(thread_context["turns"]) == 1 # Assistant's response added
assert thread_context["turns"][0]["role"] == "assistant"
assert thread_context["turns"][0]["content"] == "Analysis result"
assert thread_context["turns"][0]["files"] == ["/test/file.py"] # Files from request
assert thread_context["initial_context"]["prompt"] == "Initial analysis"
assert thread_context["initial_context"]["files"] == ["/test/file.py"]
@patch("utils.conversation_memory.get_redis_client")
def test_claude_can_use_continuation_id(self, mock_redis):
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_claude_can_use_continuation_id(self, mock_redis):
"""Test that Claude can use the provided continuation_id in subsequent calls"""
mock_client = Mock()
mock_redis.return_value = mock_client
# Step 1: Initial request creates continuation offer
request1 = ToolRequest(prompt="Analyze code structure")
continuation_data = {"remaining_turns": 4, "tool_name": "test_continuation"}
response1 = self.tool._create_continuation_offer_response(
"Structure analysis done.", continuation_data, request1
)
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Structure analysis done.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
thread_id = response1.continuation_offer.continuation_id
# Execute initial request
arguments = {"prompt": "Analyze code structure"}
response = await self.tool.execute(arguments)
# Step 2: Mock the thread context for Claude's follow-up
from utils.conversation_memory import ConversationTurn, ThreadContext
# Parse response
response_data = json.loads(response[0].text)
thread_id = response_data["continuation_offer"]["continuation_id"]
existing_context = ThreadContext(
thread_id=thread_id,
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:01:00Z",
tool_name="test_continuation",
turns=[
ConversationTurn(
role="assistant",
content="Structure analysis done.",
timestamp="2023-01-01T00:00:30Z",
tool_name="test_continuation",
)
],
initial_context={"prompt": "Analyze code structure"},
)
mock_client.get.return_value = existing_context.model_dump_json()
# Step 2: Mock the thread context for Claude's follow-up
from utils.conversation_memory import ConversationTurn, ThreadContext
# Step 3: Claude uses continuation_id
request2 = ToolRequest(prompt="Now analyze the performance aspects", continuation_id=thread_id)
existing_context = ThreadContext(
thread_id=thread_id,
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:01:00Z",
tool_name="test_continuation",
turns=[
ConversationTurn(
role="assistant",
content="Structure analysis done.",
timestamp="2023-01-01T00:00:30Z",
tool_name="test_continuation",
)
],
initial_context={"prompt": "Analyze code structure"},
)
mock_client.get.return_value = existing_context.model_dump_json()
# Should still offer continuation if there are remaining turns
continuation_data2 = self.tool._check_continuation_opportunity(request2)
assert continuation_data2 is not None
assert continuation_data2["remaining_turns"] == 8 # MAX_CONVERSATION_TURNS(10) - current_turns(1) - 1
assert continuation_data2["tool_name"] == "test_continuation"
# Step 3: Claude uses continuation_id
mock_provider.generate_content.return_value = Mock(
content="Performance analysis done.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
arguments2 = {"prompt": "Now analyze the performance aspects", "continuation_id": thread_id}
response2 = await self.tool.execute(arguments2)
# Parse response
response_data2 = json.loads(response2[0].text)
# Should still offer continuation if there are remaining turns
assert response_data2["status"] == "continuation_available"
assert "continuation_offer" in response_data2
# 10 max - 1 existing - 1 new = 8 remaining
assert response_data2["continuation_offer"]["remaining_turns"] == 8
if __name__ == "__main__":

View File

@@ -7,6 +7,7 @@ from unittest.mock import Mock, patch
import pytest
from tests.mock_helpers import create_mock_provider
from tools.analyze import AnalyzeTool
from tools.debug import DebugIssueTool
from tools.models import ClarificationRequest, ToolOutput
@@ -24,8 +25,8 @@ class TestDynamicContextRequests:
return DebugIssueTool()
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_clarification_request_parsing(self, mock_create_model, analyze_tool):
@patch("tools.base.BaseTool.get_model_provider")
async def test_clarification_request_parsing(self, mock_get_provider, analyze_tool):
"""Test that tools correctly parse clarification requests"""
# Mock model to return a clarification request
clarification_json = json.dumps(
@@ -36,16 +37,18 @@ class TestDynamicContextRequests:
}
)
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result = await analyze_tool.execute(
{
"files": ["/absolute/path/src/index.js"],
"question": "Analyze the dependencies used in this project",
"prompt": "Analyze the dependencies used in this project",
}
)
@@ -62,8 +65,8 @@ class TestDynamicContextRequests:
assert clarification["files_needed"] == ["package.json", "package-lock.json"]
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_normal_response_not_parsed_as_clarification(self, mock_create_model, debug_tool):
@patch("tools.base.BaseTool.get_model_provider")
async def test_normal_response_not_parsed_as_clarification(self, mock_get_provider, debug_tool):
"""Test that normal responses are not mistaken for clarification requests"""
normal_response = """
## Summary
@@ -75,13 +78,15 @@ class TestDynamicContextRequests:
**Root Cause:** The module 'utils' is not imported
"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text=normal_response)]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=normal_response, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result = await debug_tool.execute({"error_description": "NameError: name 'utils' is not defined"})
result = await debug_tool.execute({"prompt": "NameError: name 'utils' is not defined"})
assert len(result) == 1
@@ -92,18 +97,20 @@ class TestDynamicContextRequests:
assert "Summary" in response_data["content"]
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_malformed_clarification_request_treated_as_normal(self, mock_create_model, analyze_tool):
@patch("tools.base.BaseTool.get_model_provider")
async def test_malformed_clarification_request_treated_as_normal(self, mock_get_provider, analyze_tool):
"""Test that malformed JSON clarification requests are treated as normal responses"""
malformed_json = '{"status": "requires_clarification", "question": "Missing closing brace"'
malformed_json = '{"status": "requires_clarification", "prompt": "Missing closing brace"'
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text=malformed_json)]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=malformed_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "question": "What does this do?"})
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "prompt": "What does this do?"})
assert len(result) == 1
@@ -113,8 +120,8 @@ class TestDynamicContextRequests:
assert malformed_json in response_data["content"]
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_clarification_with_suggested_action(self, mock_create_model, debug_tool):
@patch("tools.base.BaseTool.get_model_provider")
async def test_clarification_with_suggested_action(self, mock_get_provider, debug_tool):
"""Test clarification request with suggested next action"""
clarification_json = json.dumps(
{
@@ -124,7 +131,7 @@ class TestDynamicContextRequests:
"suggested_next_action": {
"tool": "debug",
"args": {
"error_description": "Connection timeout to database",
"prompt": "Connection timeout to database",
"files": [
"/config/database.yml",
"/src/db.py",
@@ -135,15 +142,17 @@ class TestDynamicContextRequests:
}
)
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result = await debug_tool.execute(
{
"error_description": "Connection timeout to database",
"prompt": "Connection timeout to database",
"files": ["/absolute/logs/error.log"],
}
)
@@ -187,12 +196,12 @@ class TestDynamicContextRequests:
assert request.suggested_next_action["tool"] == "analyze"
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_error_response_format(self, mock_create_model, analyze_tool):
@patch("tools.base.BaseTool.get_model_provider")
async def test_error_response_format(self, mock_get_provider, analyze_tool):
"""Test error response format"""
mock_create_model.side_effect = Exception("API connection failed")
mock_get_provider.side_effect = Exception("API connection failed")
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "question": "Analyze this"})
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "prompt": "Analyze this"})
assert len(result) == 1
@@ -206,8 +215,8 @@ class TestCollaborationWorkflow:
"""Test complete collaboration workflows"""
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_dependency_analysis_triggers_clarification(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_dependency_analysis_triggers_clarification(self, mock_get_provider):
"""Test that asking about dependencies without package files triggers clarification"""
tool = AnalyzeTool()
@@ -220,17 +229,19 @@ class TestCollaborationWorkflow:
}
)
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
# Ask about dependencies with only source files
result = await tool.execute(
{
"files": ["/absolute/path/src/index.js"],
"question": "What npm packages and versions does this project use?",
"prompt": "What npm packages and versions does this project use?",
}
)
@@ -243,8 +254,8 @@ class TestCollaborationWorkflow:
assert "package.json" in str(clarification["files_needed"]), "Should specifically request package.json"
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_multi_step_collaboration(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_multi_step_collaboration(self, mock_get_provider):
"""Test a multi-step collaboration workflow"""
tool = DebugIssueTool()
@@ -257,15 +268,17 @@ class TestCollaborationWorkflow:
}
)
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result1 = await tool.execute(
{
"error_description": "Database connection timeout",
"prompt": "Database connection timeout",
"error_context": "Timeout after 30s",
}
)
@@ -285,13 +298,13 @@ class TestCollaborationWorkflow:
**Root Cause:** The config.py file shows the database host is set to 'localhost' but the database is running on a different server.
"""
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text=final_response)]))]
mock_provider.generate_content.return_value = Mock(
content=final_response, usage={}, model_name="gemini-2.0-flash", metadata={}
)
result2 = await tool.execute(
{
"error_description": "Database connection timeout",
"prompt": "Database connection timeout",
"error_context": "Timeout after 30s",
"files": ["/absolute/path/config.py"], # Additional context provided
}
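
The provider-mock pattern repeated throughout this file leans on a shared helper. Below is a minimal sketch of what tests/mock_helpers.create_mock_provider might look like, assuming only the behaviour exercised above (the repository's actual helper may differ):

from unittest.mock import Mock


def create_mock_provider(model_name="gemini-2.0-flash"):
    """Return a Mock provider whose generate_content yields a ModelResponse-like object."""
    provider = Mock()
    provider.get_provider_type.return_value = Mock(value="google")
    provider.supports_thinking_mode.return_value = True
    provider.generate_content.return_value = Mock(
        content="Mocked response",
        usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
        model_name=model_name,
        metadata={"finish_reason": "STOP"},
    )
    return provider
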

View File

@@ -31,7 +31,8 @@ class TestConfig:
def test_model_config(self):
"""Test model configuration"""
assert DEFAULT_MODEL == "gemini-2.5-pro-preview-06-05"
# DEFAULT_MODEL is set in conftest.py for tests
assert DEFAULT_MODEL == "gemini-2.0-flash"
assert MAX_CONTEXT_TOKENS == 1_000_000
def test_temperature_defaults(self):

View File

@@ -0,0 +1,177 @@
"""
Test that conversation history is correctly mapped to tool-specific fields
"""
import os
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from providers.base import ProviderType
from server import reconstruct_thread_context
from utils.conversation_memory import ConversationTurn, ThreadContext
@pytest.mark.asyncio
async def test_conversation_history_field_mapping():
"""Test that enhanced prompts are mapped to prompt field for all tools"""
# Test data for different tools - all use 'prompt' now
test_cases = [
{
"tool_name": "analyze",
"original_value": "What does this code do?",
},
{
"tool_name": "chat",
"original_value": "Explain this concept",
},
{
"tool_name": "debug",
"original_value": "Getting undefined error",
},
{
"tool_name": "codereview",
"original_value": "Review this implementation",
},
{
"tool_name": "thinkdeep",
"original_value": "My analysis so far",
},
]
for test_case in test_cases:
# Create mock conversation context
mock_context = ThreadContext(
thread_id="test-thread-123",
tool_name=test_case["tool_name"],
created_at=datetime.now().isoformat(),
last_updated_at=datetime.now().isoformat(),
turns=[
ConversationTurn(
role="user",
content="Previous user message",
timestamp=datetime.now().isoformat(),
files=["/test/file1.py"],
),
ConversationTurn(
role="assistant",
content="Previous assistant response",
timestamp=datetime.now().isoformat(),
),
],
initial_context={},
)
# Mock get_thread to return our test context
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
with patch("utils.conversation_memory.add_turn", return_value=True):
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
# Mock provider registry to avoid model lookup errors
with patch("providers.registry.ModelProviderRegistry.get_provider_for_model") as mock_get_provider:
from providers.base import ModelCapabilities
mock_provider = MagicMock()
mock_provider.get_capabilities.return_value = ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.0-flash",
friendly_name="Gemini",
max_tokens=200000,
supports_extended_thinking=True,
)
mock_get_provider.return_value = mock_provider
# Mock conversation history building
mock_build.return_value = (
"=== CONVERSATION HISTORY ===\nPrevious conversation content\n=== END HISTORY ===",
1000, # mock token count
)
# Create arguments with continuation_id
arguments = {
"continuation_id": "test-thread-123",
"prompt": test_case["original_value"],
"files": ["/test/file2.py"],
}
# Call reconstruct_thread_context
enhanced_args = await reconstruct_thread_context(arguments)
# Verify the enhanced prompt is in the prompt field
assert "prompt" in enhanced_args
enhanced_value = enhanced_args["prompt"]
# Should contain conversation history
assert "=== CONVERSATION HISTORY ===" in enhanced_value
assert "Previous conversation content" in enhanced_value
# Should contain the new user input
assert "=== NEW USER INPUT ===" in enhanced_value
assert test_case["original_value"] in enhanced_value
# Should have token budget
assert "_remaining_tokens" in enhanced_args
assert enhanced_args["_remaining_tokens"] > 0
@pytest.mark.asyncio
async def test_unknown_tool_defaults_to_prompt():
"""Test that unknown tools default to using 'prompt' field"""
mock_context = ThreadContext(
thread_id="test-thread-456",
tool_name="unknown_tool",
created_at=datetime.now().isoformat(),
last_updated_at=datetime.now().isoformat(),
turns=[],
initial_context={},
)
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
with patch("utils.conversation_memory.add_turn", return_value=True):
with patch("utils.conversation_memory.build_conversation_history", return_value=("History", 500)):
with patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False):
from providers.registry import ModelProviderRegistry
ModelProviderRegistry.clear_cache()
arguments = {
"continuation_id": "test-thread-456",
"prompt": "User input",
}
enhanced_args = await reconstruct_thread_context(arguments)
# Should default to 'prompt' field
assert "prompt" in enhanced_args
assert "History" in enhanced_args["prompt"]
@pytest.mark.asyncio
async def test_tool_parameter_standardization():
"""Test that all tools use standardized 'prompt' parameter"""
from tools.analyze import AnalyzeRequest
from tools.codereview import CodeReviewRequest
from tools.debug import DebugIssueRequest
from tools.precommit import PrecommitRequest
from tools.thinkdeep import ThinkDeepRequest
# Test analyze tool uses prompt
analyze = AnalyzeRequest(files=["/test.py"], prompt="What does this do?")
assert analyze.prompt == "What does this do?"
# Test debug tool uses prompt
debug = DebugIssueRequest(prompt="Error occurred")
assert debug.prompt == "Error occurred"
# Test codereview tool uses prompt
review = CodeReviewRequest(files=["/test.py"], prompt="Review this")
assert review.prompt == "Review this"
# Test thinkdeep tool uses prompt
think = ThinkDeepRequest(prompt="My analysis")
assert think.prompt == "My analysis"
# Test precommit tool uses prompt (optional)
precommit = PrecommitRequest(path="/repo", prompt="Fix bug")
assert precommit.prompt == "Fix bug"
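
For orientation, a hedged sketch of the prompt enhancement these assertions imply (an assumption about reconstruct_thread_context's output shape, not its actual implementation): the rebuilt conversation history is prepended to the standardized 'prompt' field and a token budget is attached.

def sketch_enhanced_arguments(history: str, user_input: str, remaining_tokens: int) -> dict:
    # Conversation history first, then the new user input, all carried in the 'prompt' field
    enhanced_prompt = f"{history}\n\n=== NEW USER INPUT ===\n{user_input}"
    return {"prompt": enhanced_prompt, "_remaining_tokens": remaining_tokens}


args = sketch_enhanced_arguments(
    "=== CONVERSATION HISTORY ===\nPrevious conversation content\n=== END HISTORY ===",
    "What does this code do?",
    remaining_tokens=1_000,  # illustrative value only
)
assert "=== NEW USER INPUT ===" in args["prompt"] and args["_remaining_tokens"] > 0
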

View File

@@ -16,6 +16,7 @@ from unittest.mock import Mock, patch
import pytest
from pydantic import Field
from tests.mock_helpers import create_mock_provider
from tools.base import BaseTool, ToolRequest
from utils.conversation_memory import ConversationTurn, ThreadContext
@@ -72,30 +73,10 @@ class TestConversationHistoryBugFix:
async def test_conversation_history_included_with_continuation_id(self, mock_add_turn):
"""Test that conversation history (including file context) is included when using continuation_id"""
# Create a thread context with previous turns including files
_thread_context = ThreadContext(
thread_id="test-history-id",
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:02:00Z",
tool_name="analyze", # Started with analyze tool
turns=[
ConversationTurn(
role="assistant",
content="I've analyzed the authentication module and found several security issues.",
timestamp="2023-01-01T00:01:00Z",
tool_name="analyze",
files=["/src/auth.py", "/src/security.py"], # Files from analyze tool
),
ConversationTurn(
role="assistant",
content="The code review shows these files have critical vulnerabilities.",
timestamp="2023-01-01T00:02:00Z",
tool_name="codereview",
files=["/src/auth.py", "/tests/test_auth.py"], # Files from codereview tool
),
],
initial_context={"question": "Analyze authentication security"},
)
# Test setup note: This test simulates a conversation thread with previous turns
# containing files from different tools (analyze -> codereview)
# The continuation_id "test-history-id" references this implicit thread context
# In the real flow, server.py would reconstruct this context and add it to the prompt
# Mock add_turn to return success
mock_add_turn.return_value = True
@@ -103,23 +84,23 @@ class TestConversationHistoryBugFix:
# Mock the model to capture what prompt it receives
captured_prompt = None
with patch.object(self.tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(parts=[Mock(text="Response with conversation context")]),
finish_reason="STOP",
)
]
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
def capture_prompt(prompt):
def capture_prompt(prompt, **kwargs):
nonlocal captured_prompt
captured_prompt = prompt
return mock_response
return Mock(
content="Response with conversation context",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_model.generate_content.side_effect = capture_prompt
mock_create_model.return_value = mock_model
mock_provider.generate_content.side_effect = capture_prompt
mock_get_provider.return_value = mock_provider
# Execute tool with continuation_id
# In the corrected flow, server.py:reconstruct_thread_context
@@ -163,23 +144,23 @@ class TestConversationHistoryBugFix:
captured_prompt = None
with patch.object(self.tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(parts=[Mock(text="Response without history")]),
finish_reason="STOP",
)
]
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
def capture_prompt(prompt):
def capture_prompt(prompt, **kwargs):
nonlocal captured_prompt
captured_prompt = prompt
return mock_response
return Mock(
content="Response without history",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_model.generate_content.side_effect = capture_prompt
mock_create_model.return_value = mock_model
mock_provider.generate_content.side_effect = capture_prompt
mock_get_provider.return_value = mock_provider
# Execute tool with continuation_id for non-existent thread
# In the real flow, server.py would have already handled the missing thread
@@ -201,23 +182,23 @@ class TestConversationHistoryBugFix:
captured_prompt = None
with patch.object(self.tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(parts=[Mock(text="New conversation response")]),
finish_reason="STOP",
)
]
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
def capture_prompt(prompt):
def capture_prompt(prompt, **kwargs):
nonlocal captured_prompt
captured_prompt = prompt
return mock_response
return Mock(
content="New conversation response",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_model.generate_content.side_effect = capture_prompt
mock_create_model.return_value = mock_model
mock_provider.generate_content.side_effect = capture_prompt
mock_get_provider.return_value = mock_provider
# Execute tool without continuation_id (new conversation)
arguments = {"prompt": "Start new conversation", "files": ["/src/new_file.py"]}
@@ -235,7 +216,7 @@ class TestConversationHistoryBugFix:
# Should include follow-up instructions for new conversation
# (This is the existing behavior for new conversations)
assert "If you'd like to ask a follow-up question" in captured_prompt
assert "CONVERSATION CONTINUATION" in captured_prompt
@patch("tools.base.get_thread")
@patch("tools.base.add_turn")
@@ -275,7 +256,7 @@ class TestConversationHistoryBugFix:
files=["/src/auth.py", "/tests/test_auth.py"], # auth.py referenced again + new file
),
],
initial_context={"question": "Analyze authentication security"},
initial_context={"prompt": "Analyze authentication security"},
)
# Mock get_thread to return our test context
@@ -285,23 +266,23 @@ class TestConversationHistoryBugFix:
# Mock the model to capture what prompt it receives
captured_prompt = None
with patch.object(self.tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(parts=[Mock(text="Analysis of new files complete")]),
finish_reason="STOP",
)
]
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
def capture_prompt(prompt):
def capture_prompt(prompt, **kwargs):
nonlocal captured_prompt
captured_prompt = prompt
return mock_response
return Mock(
content="Analysis of new files complete",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_model.generate_content.side_effect = capture_prompt
mock_create_model.return_value = mock_model
mock_provider.generate_content.side_effect = capture_prompt
mock_get_provider.return_value = mock_provider
# Mock read_files to simulate file existence and capture its calls
with patch("tools.base.read_files") as mock_read_files:

View File

@@ -5,6 +5,7 @@ Tests the Redis-based conversation persistence needed for AI-to-AI multi-turn
discussions in stateless MCP environments.
"""
import os
from unittest.mock import Mock, patch
import pytest
@@ -136,8 +137,13 @@ class TestConversationMemory:
assert success is False
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
def test_build_conversation_history(self):
"""Test building conversation history format with files and speaker identification"""
from providers.registry import ModelProviderRegistry
ModelProviderRegistry.clear_cache()
test_uuid = "12345678-1234-1234-1234-123456789012"
turns = [
@@ -151,7 +157,6 @@ class TestConversationMemory:
role="assistant",
content="Python is a programming language",
timestamp="2023-01-01T00:01:00Z",
follow_up_question="Would you like examples?",
files=["/home/user/examples/"],
tool_name="chat",
),
@@ -166,7 +171,7 @@ class TestConversationMemory:
initial_context={},
)
history, tokens = build_conversation_history(context)
history, tokens = build_conversation_history(context, model_context=None)
# Test basic structure
assert "CONVERSATION HISTORY" in history
@@ -188,11 +193,8 @@ class TestConversationMemory:
assert "The following files have been shared and analyzed during our conversation." in history
# Check that file context from previous turns is included (now shows files used per turn)
assert "📁 Files used in this turn: /home/user/main.py, /home/user/docs/readme.md" in history
assert "📁 Files used in this turn: /home/user/examples/" in history
# Test follow-up attribution
assert "[Gemini's Follow-up: Would you like examples?]" in history
assert "Files used in this turn: /home/user/main.py, /home/user/docs/readme.md" in history
assert "Files used in this turn: /home/user/examples/" in history
def test_build_conversation_history_empty(self):
"""Test building history with no turns"""
@@ -207,7 +209,7 @@ class TestConversationMemory:
initial_context={},
)
history, tokens = build_conversation_history(context)
history, tokens = build_conversation_history(context, model_context=None)
assert history == ""
assert tokens == 0
@@ -235,12 +237,11 @@ class TestConversationFlow:
)
mock_client.get.return_value = initial_context.model_dump_json()
# Add assistant response with follow-up
# Add assistant response
success = add_turn(
thread_id,
"assistant",
"Code analysis complete",
follow_up_question="Would you like me to check error handling?",
)
assert success is True
@@ -256,7 +257,6 @@ class TestConversationFlow:
role="assistant",
content="Code analysis complete",
timestamp="2023-01-01T00:00:30Z",
follow_up_question="Would you like me to check error handling?",
)
],
initial_context={"prompt": "Analyze this code"},
@@ -266,9 +266,7 @@ class TestConversationFlow:
success = add_turn(thread_id, "user", "Yes, check error handling")
assert success is True
success = add_turn(
thread_id, "assistant", "Error handling reviewed", follow_up_question="Should I examine the test coverage?"
)
success = add_turn(thread_id, "assistant", "Error handling reviewed")
assert success is True
# REQUEST 3-5: Continue conversation (simulating independent cycles)
@@ -283,14 +281,12 @@ class TestConversationFlow:
role="assistant",
content="Code analysis complete",
timestamp="2023-01-01T00:00:30Z",
follow_up_question="Would you like me to check error handling?",
),
ConversationTurn(role="user", content="Yes, check error handling", timestamp="2023-01-01T00:01:30Z"),
ConversationTurn(
role="assistant",
content="Error handling reviewed",
timestamp="2023-01-01T00:02:30Z",
follow_up_question="Should I examine the test coverage?",
),
],
initial_context={"prompt": "Analyze this code"},
@@ -349,8 +345,13 @@ class TestConversationFlow:
in error_msg
)
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
def test_dynamic_max_turns_configuration(self):
"""Test that all functions respect MAX_CONVERSATION_TURNS configuration"""
from providers.registry import ModelProviderRegistry
ModelProviderRegistry.clear_cache()
# This test ensures if we change MAX_CONVERSATION_TURNS, everything updates
# Test with different max values by patching the constant
@@ -374,7 +375,7 @@ class TestConversationFlow:
initial_context={},
)
history, tokens = build_conversation_history(context)
history, tokens = build_conversation_history(context, model_context=None)
expected_turn_text = f"Turn {test_max}/{MAX_CONVERSATION_TURNS}"
assert expected_turn_text in history
@@ -385,18 +386,20 @@ class TestConversationFlow:
# Test early conversation (should allow follow-ups)
early_instructions = get_follow_up_instructions(0, max_turns)
assert "CONVERSATION THREADING" in early_instructions
assert "CONVERSATION CONTINUATION" in early_instructions
assert f"({max_turns - 1} exchanges remaining)" in early_instructions
assert "Feel free to ask clarifying questions" in early_instructions
# Test mid conversation
mid_instructions = get_follow_up_instructions(2, max_turns)
assert "CONVERSATION THREADING" in mid_instructions
assert "CONVERSATION CONTINUATION" in mid_instructions
assert f"({max_turns - 3} exchanges remaining)" in mid_instructions
assert "Feel free to ask clarifying questions" in mid_instructions
# Test approaching limit (should stop follow-ups)
limit_instructions = get_follow_up_instructions(max_turns - 1, max_turns)
assert "Do NOT include any follow-up questions" in limit_instructions
assert "FOLLOW-UP CONVERSATIONS" not in limit_instructions
assert "final exchange" in limit_instructions
# Test at limit
at_limit_instructions = get_follow_up_instructions(max_turns, max_turns)
@@ -473,8 +476,13 @@ class TestConversationFlow:
assert success is False, f"Turn {MAX_CONVERSATION_TURNS + 1} should fail"
@patch("utils.conversation_memory.get_redis_client")
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
def test_conversation_with_files_and_context_preservation(self, mock_redis):
"""Test complete conversation flow with file tracking and context preservation"""
from providers.registry import ModelProviderRegistry
ModelProviderRegistry.clear_cache()
mock_client = Mock()
mock_redis.return_value = mock_client
@@ -492,12 +500,11 @@ class TestConversationFlow:
)
mock_client.get.return_value = initial_context.model_dump_json()
# Add Gemini's response with follow-up
# Add Gemini's response
success = add_turn(
thread_id,
"assistant",
"I've analyzed your codebase structure.",
follow_up_question="Would you like me to examine the test coverage?",
files=["/project/src/main.py", "/project/src/utils.py"],
tool_name="analyze",
)
@@ -514,7 +521,6 @@ class TestConversationFlow:
role="assistant",
content="I've analyzed your codebase structure.",
timestamp="2023-01-01T00:00:30Z",
follow_up_question="Would you like me to examine the test coverage?",
files=["/project/src/main.py", "/project/src/utils.py"],
tool_name="analyze",
)
@@ -540,7 +546,6 @@ class TestConversationFlow:
role="assistant",
content="I've analyzed your codebase structure.",
timestamp="2023-01-01T00:00:30Z",
follow_up_question="Would you like me to examine the test coverage?",
files=["/project/src/main.py", "/project/src/utils.py"],
tool_name="analyze",
),
@@ -575,7 +580,6 @@ class TestConversationFlow:
role="assistant",
content="I've analyzed your codebase structure.",
timestamp="2023-01-01T00:00:30Z",
follow_up_question="Would you like me to examine the test coverage?",
files=["/project/src/main.py", "/project/src/utils.py"],
tool_name="analyze",
),
@@ -604,19 +608,18 @@ class TestConversationFlow:
assert "--- Turn 3 (Gemini using analyze) ---" in history
# Verify all files are preserved in chronological order
turn_1_files = "📁 Files used in this turn: /project/src/main.py, /project/src/utils.py"
turn_2_files = "📁 Files used in this turn: /project/tests/, /project/test_main.py"
turn_3_files = "📁 Files used in this turn: /project/tests/test_utils.py, /project/coverage.html"
turn_1_files = "Files used in this turn: /project/src/main.py, /project/src/utils.py"
turn_2_files = "Files used in this turn: /project/tests/, /project/test_main.py"
turn_3_files = "Files used in this turn: /project/tests/test_utils.py, /project/coverage.html"
assert turn_1_files in history
assert turn_2_files in history
assert turn_3_files in history
# Verify content and follow-ups
# Verify content
assert "I've analyzed your codebase structure." in history
assert "Yes, check the test coverage" in history
assert "Test coverage analysis complete. Coverage is 85%." in history
assert "[Gemini's Follow-up: Would you like me to examine the test coverage?]" in history
# Verify chronological ordering (turn 1 appears before turn 2, etc.)
turn_1_pos = history.find("--- Turn 1 (Gemini using analyze) ---")
@@ -625,56 +628,6 @@ class TestConversationFlow:
assert turn_1_pos < turn_2_pos < turn_3_pos
@patch("utils.conversation_memory.get_redis_client")
def test_follow_up_question_parsing_cycle(self, mock_redis):
"""Test follow-up question persistence across request cycles"""
mock_client = Mock()
mock_redis.return_value = mock_client
thread_id = "12345678-1234-1234-1234-123456789012"
# First cycle: Assistant generates follow-up
context = ThreadContext(
thread_id=thread_id,
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:00:00Z",
tool_name="debug",
turns=[],
initial_context={"prompt": "Debug this error"},
)
mock_client.get.return_value = context.model_dump_json()
success = add_turn(
thread_id,
"assistant",
"Found potential issue in authentication",
follow_up_question="Should I examine the authentication middleware?",
)
assert success is True
# Second cycle: Retrieve conversation history
context_with_followup = ThreadContext(
thread_id=thread_id,
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:01:00Z",
tool_name="debug",
turns=[
ConversationTurn(
role="assistant",
content="Found potential issue in authentication",
timestamp="2023-01-01T00:00:30Z",
follow_up_question="Should I examine the authentication middleware?",
)
],
initial_context={"prompt": "Debug this error"},
)
mock_client.get.return_value = context_with_followup.model_dump_json()
# Build history to verify follow-up is preserved
history, tokens = build_conversation_history(context_with_followup)
assert "Found potential issue in authentication" in history
assert "[Gemini's Follow-up: Should I examine the authentication middleware?]" in history
@patch("utils.conversation_memory.get_redis_client")
def test_stateless_request_isolation(self, mock_redis):
"""Test that each request cycle is independent but shares context via Redis"""
@@ -695,9 +648,7 @@ class TestConversationFlow:
)
mock_client.get.return_value = initial_context.model_dump_json()
success = add_turn(
thread_id, "assistant", "Architecture analysis", follow_up_question="Want to explore scalability?"
)
success = add_turn(thread_id, "assistant", "Architecture analysis")
assert success is True
# Process 2: Different "request cycle" accesses same thread
@@ -711,7 +662,6 @@ class TestConversationFlow:
role="assistant",
content="Architecture analysis",
timestamp="2023-01-01T00:00:30Z",
follow_up_question="Want to explore scalability?",
)
],
initial_context={"prompt": "Think about architecture"},
@@ -722,13 +672,17 @@ class TestConversationFlow:
retrieved_context = get_thread(thread_id)
assert retrieved_context is not None
assert len(retrieved_context.turns) == 1
assert retrieved_context.turns[0].follow_up_question == "Want to explore scalability?"
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
def test_token_limit_optimization_in_conversation_history(self):
"""Test that build_conversation_history efficiently handles token limits"""
import os
import tempfile
from providers.registry import ModelProviderRegistry
ModelProviderRegistry.clear_cache()
from utils.conversation_memory import build_conversation_history
# Create test files with known content sizes
@@ -763,10 +717,10 @@ class TestConversationFlow:
)
# Build conversation history (should handle token limits gracefully)
history, tokens = build_conversation_history(context)
history, tokens = build_conversation_history(context, model_context=None)
# Verify the history was built successfully
assert "=== CONVERSATION HISTORY ===" in history
assert "=== CONVERSATION HISTORY" in history
assert "=== FILES REFERENCED IN THIS CONVERSATION ===" in history
# The small file should be included, but large file might be truncated
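
A small usage sketch of the updated signature exercised throughout this file: build_conversation_history now takes an explicit model_context argument and returns the formatted history plus a token count (an empty string and zero tokens for a thread with no turns, per the empty-history test above). This assumes the repository's utils.conversation_memory module is importable as shown in these tests.

from utils.conversation_memory import ThreadContext, build_conversation_history

context = ThreadContext(
    thread_id="12345678-1234-1234-1234-123456789012",
    created_at="2023-01-01T00:00:00Z",
    last_updated_at="2023-01-01T00:00:00Z",
    tool_name="chat",
    turns=[],  # no turns yet
    initial_context={},
)
history, tokens = build_conversation_history(context, model_context=None)
assert history == "" and tokens == 0
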

View File

@@ -6,11 +6,13 @@ allowing multi-turn conversations to span multiple tool types.
"""
import json
import os
from unittest.mock import Mock, patch
import pytest
from pydantic import Field
from tests.mock_helpers import create_mock_provider
from tools.base import BaseTool, ToolRequest
from utils.conversation_memory import ConversationTurn, ThreadContext
@@ -92,45 +94,36 @@ class TestCrossToolContinuation:
self.review_tool = MockReviewTool()
@patch("utils.conversation_memory.get_redis_client")
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_continuation_id_works_across_different_tools(self, mock_redis):
"""Test that a continuation_id from one tool can be used with another tool"""
mock_client = Mock()
mock_redis.return_value = mock_client
# Step 1: Analysis tool creates a conversation with follow-up
with patch.object(self.analysis_tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(
parts=[
Mock(
text="""Found potential security issues in authentication logic.
# Step 1: Analysis tool creates a conversation with continuation offer
with patch.object(self.analysis_tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
# Simple content without JSON follow-up
content = """Found potential security issues in authentication logic.
```json
{
"follow_up_question": "Would you like me to review these security findings in detail?",
"suggested_params": {"findings": "Authentication bypass vulnerability detected"},
"ui_hint": "Security review recommended"
}
```"""
)
]
),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
I'd be happy to review these security findings in detail if that would be helpful."""
mock_provider.generate_content.return_value = Mock(
content=content,
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Execute analysis tool
arguments = {"code": "function authenticate(user) { return true; }"}
response = await self.analysis_tool.execute(arguments)
response_data = json.loads(response[0].text)
assert response_data["status"] == "requires_continuation"
continuation_id = response_data["follow_up_request"]["continuation_id"]
assert response_data["status"] == "continuation_available"
continuation_id = response_data["continuation_offer"]["continuation_id"]
# Step 2: Mock the existing thread context for the review tool
# The thread was created by analysis_tool but will be continued by review_tool
@@ -142,10 +135,9 @@ class TestCrossToolContinuation:
turns=[
ConversationTurn(
role="assistant",
content="Found potential security issues in authentication logic.",
content="Found potential security issues in authentication logic.\n\nI'd be happy to review these security findings in detail if that would be helpful.",
timestamp="2023-01-01T00:00:30Z",
tool_name="test_analysis", # Original tool
follow_up_question="Would you like me to review these security findings in detail?",
)
],
initial_context={"code": "function authenticate(user) { return true; }"},
@@ -160,23 +152,17 @@ class TestCrossToolContinuation:
mock_client.get.side_effect = mock_get_side_effect
# Step 3: Review tool uses the same continuation_id
with patch.object(self.review_tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(
parts=[
Mock(
text="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks."
)
]
),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(self.review_tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Execute review tool with the continuation_id from analysis tool
arguments = {
@@ -245,9 +231,13 @@ class TestCrossToolContinuation:
)
# Build conversation history
from providers.registry import ModelProviderRegistry
from utils.conversation_memory import build_conversation_history
history, tokens = build_conversation_history(thread_context)
# Set up provider for this test
with patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False):
ModelProviderRegistry.clear_cache()
history, tokens = build_conversation_history(thread_context, model_context=None)
# Verify tool names are included in the history
assert "Turn 1 (Gemini using test_analysis)" in history
@@ -259,6 +249,7 @@ class TestCrossToolContinuation:
@patch("utils.conversation_memory.get_redis_client")
@patch("utils.conversation_memory.get_thread")
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_cross_tool_conversation_with_files_context(self, mock_get_thread, mock_redis):
"""Test that file context is preserved across tool switches"""
mock_client = Mock()
@@ -286,17 +277,17 @@ class TestCrossToolContinuation:
mock_get_thread.return_value = existing_context
# Mock review tool response
with patch.object(self.review_tool, "create_model") as mock_create_model:
mock_model = Mock()
mock_response = Mock()
mock_response.candidates = [
Mock(
content=Mock(parts=[Mock(text="Security review of auth.py shows vulnerabilities")]),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(self.review_tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Security review of auth.py shows vulnerabilities",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Execute review tool with additional files
arguments = {

View File

@@ -0,0 +1,181 @@
"""
Test suite for intelligent auto mode fallback logic
Tests the new dynamic model selection based on available API keys
"""
import os
from unittest.mock import Mock, patch
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
class TestIntelligentFallback:
"""Test intelligent model fallback logic"""
def setup_method(self):
"""Setup for each test - clear registry cache"""
ModelProviderRegistry.clear_cache()
def teardown_method(self):
"""Cleanup after each test"""
ModelProviderRegistry.clear_cache()
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
def test_prefers_openai_o3_mini_when_available(self):
"""Test that o3-mini is preferred when OpenAI API key is available"""
ModelProviderRegistry.clear_cache()
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o3-mini"
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_gemini_flash_when_openai_unavailable(self):
"""Test that gemini-2.0-flash is used when only Gemini API key is available"""
ModelProviderRegistry.clear_cache()
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "gemini-2.0-flash"
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_openai_when_both_available(self):
"""Test that OpenAI is preferred when both API keys are available"""
ModelProviderRegistry.clear_cache()
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o3-mini" # OpenAI has priority
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False)
def test_fallback_when_no_keys_available(self):
"""Test fallback behavior when no API keys are available"""
ModelProviderRegistry.clear_cache()
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "gemini-2.0-flash" # Default fallback
def test_available_providers_with_keys(self):
"""Test the get_available_providers_with_keys method"""
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
ModelProviderRegistry.clear_cache()
available = ModelProviderRegistry.get_available_providers_with_keys()
assert ProviderType.OPENAI in available
assert ProviderType.GOOGLE not in available
with patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False):
ModelProviderRegistry.clear_cache()
available = ModelProviderRegistry.get_available_providers_with_keys()
assert ProviderType.GOOGLE in available
assert ProviderType.OPENAI not in available
def test_auto_mode_conversation_memory_integration(self):
"""Test that conversation memory uses intelligent fallback in auto mode"""
from utils.conversation_memory import ThreadContext, build_conversation_history
# Mock auto mode - patch the config module where these values are defined
with (
patch("config.IS_AUTO_MODE", True),
patch("config.DEFAULT_MODEL", "auto"),
patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
):
ModelProviderRegistry.clear_cache()
# Create a context with at least one turn so it doesn't exit early
from utils.conversation_memory import ConversationTurn
context = ThreadContext(
thread_id="test-123",
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:00:00Z",
tool_name="chat",
turns=[ConversationTurn(role="user", content="Test message", timestamp="2023-01-01T00:00:30Z")],
initial_context={},
)
# This should use o3-mini for token calculations since OpenAI is available
with patch("utils.model_context.ModelContext") as mock_context_class:
mock_context_instance = Mock()
mock_context_class.return_value = mock_context_instance
mock_context_instance.calculate_token_allocation.return_value = Mock(
file_tokens=10000, history_tokens=5000
)
# Mock estimate_tokens to return integers for proper summing
mock_context_instance.estimate_tokens.return_value = 100
history, tokens = build_conversation_history(context, model_context=None)
# Verify that ModelContext was called with o3-mini (the intelligent fallback)
mock_context_class.assert_called_once_with("o3-mini")
def test_auto_mode_with_gemini_only(self):
"""Test auto mode behavior when only Gemini API key is available"""
from utils.conversation_memory import ThreadContext, build_conversation_history
with (
patch("config.IS_AUTO_MODE", True),
patch("config.DEFAULT_MODEL", "auto"),
patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False),
):
ModelProviderRegistry.clear_cache()
from utils.conversation_memory import ConversationTurn
context = ThreadContext(
thread_id="test-456",
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:00:00Z",
tool_name="analyze",
turns=[ConversationTurn(role="assistant", content="Test response", timestamp="2023-01-01T00:00:30Z")],
initial_context={},
)
with patch("utils.model_context.ModelContext") as mock_context_class:
mock_context_instance = Mock()
mock_context_class.return_value = mock_context_instance
mock_context_instance.calculate_token_allocation.return_value = Mock(
file_tokens=10000, history_tokens=5000
)
# Mock estimate_tokens to return integers for proper summing
mock_context_instance.estimate_tokens.return_value = 100
history, tokens = build_conversation_history(context, model_context=None)
# Should use gemini-2.0-flash when only Gemini is available
mock_context_class.assert_called_once_with("gemini-2.0-flash")
def test_non_auto_mode_unchanged(self):
"""Test that non-auto mode behavior is unchanged"""
from utils.conversation_memory import ThreadContext, build_conversation_history
with patch("config.IS_AUTO_MODE", False), patch("config.DEFAULT_MODEL", "gemini-2.5-pro-preview-06-05"):
from utils.conversation_memory import ConversationTurn
context = ThreadContext(
thread_id="test-789",
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:00:00Z",
tool_name="thinkdeep",
turns=[
ConversationTurn(role="user", content="Test in non-auto mode", timestamp="2023-01-01T00:00:30Z")
],
initial_context={},
)
with patch("utils.model_context.ModelContext") as mock_context_class:
mock_context_instance = Mock()
mock_context_class.return_value = mock_context_instance
mock_context_instance.calculate_token_allocation.return_value = Mock(
file_tokens=10000, history_tokens=5000
)
# Mock estimate_tokens to return integers for proper summing
mock_context_instance.estimate_tokens.return_value = 100
history, tokens = build_conversation_history(context, model_context=None)
# Should use the configured DEFAULT_MODEL, not the intelligent fallback
mock_context_class.assert_called_once_with("gemini-2.5-pro-preview-06-05")
if __name__ == "__main__":
pytest.main([__file__])
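
The selection rule these tests pin down can be summarized in a standalone sketch (an assumption about the behaviour of ModelProviderRegistry.get_preferred_fallback_model, not its actual code):

import os


def preferred_fallback_model() -> str:
    # OpenAI takes priority when a key is configured; otherwise fall back to Gemini Flash,
    # which is also the default when no keys are present at all.
    if os.environ.get("OPENAI_API_KEY"):
        return "o3-mini"
    return "gemini-2.0-flash"
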

View File

@@ -68,17 +68,17 @@ class TestLargePromptHandling:
tool = ChatTool()
# Mock the model to avoid actual API calls
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.candidates = [
MagicMock(
content=MagicMock(parts=[MagicMock(text="This is a test response")]),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = MagicMock(
content="This is a test response",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
result = await tool.execute({"prompt": normal_prompt})
@@ -93,17 +93,17 @@ class TestLargePromptHandling:
tool = ChatTool()
# Mock the model
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.candidates = [
MagicMock(
content=MagicMock(parts=[MagicMock(text="Processed large prompt")]),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = MagicMock(
content="Processed large prompt",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Mock read_file_content to avoid security checks
with patch("tools.base.read_file_content") as mock_read_file:
@@ -123,8 +123,11 @@ class TestLargePromptHandling:
mock_read_file.assert_called_once_with(temp_prompt_file)
# Verify the large content was used
call_args = mock_model.generate_content.call_args[0][0]
assert large_prompt in call_args
# generate_content is called with keyword arguments
call_kwargs = mock_provider.generate_content.call_args[1]
prompt_arg = call_kwargs.get("prompt")
assert prompt_arg is not None
assert large_prompt in prompt_arg
# Cleanup
temp_dir = os.path.dirname(temp_prompt_file)
@@ -134,7 +137,7 @@ class TestLargePromptHandling:
async def test_thinkdeep_large_analysis(self, large_prompt):
"""Test that thinkdeep tool detects large current_analysis."""
tool = ThinkDeepTool()
result = await tool.execute({"current_analysis": large_prompt})
result = await tool.execute({"prompt": large_prompt})
assert len(result) == 1
output = json.loads(result[0].text)
@@ -148,7 +151,7 @@ class TestLargePromptHandling:
{
"files": ["/some/file.py"],
"focus_on": large_prompt,
"context": "Test code review for validation purposes",
"prompt": "Test code review for validation purposes",
}
)
@@ -160,7 +163,7 @@ class TestLargePromptHandling:
async def test_review_changes_large_original_request(self, large_prompt):
"""Test that review_changes tool detects large original_request."""
tool = Precommit()
result = await tool.execute({"path": "/some/path", "original_request": large_prompt})
result = await tool.execute({"path": "/some/path", "prompt": large_prompt})
assert len(result) == 1
output = json.loads(result[0].text)
@@ -170,7 +173,7 @@ class TestLargePromptHandling:
async def test_debug_large_error_description(self, large_prompt):
"""Test that debug tool detects large error_description."""
tool = DebugIssueTool()
result = await tool.execute({"error_description": large_prompt})
result = await tool.execute({"prompt": large_prompt})
assert len(result) == 1
output = json.loads(result[0].text)
@@ -180,7 +183,7 @@ class TestLargePromptHandling:
async def test_debug_large_error_context(self, large_prompt, normal_prompt):
"""Test that debug tool detects large error_context."""
tool = DebugIssueTool()
result = await tool.execute({"error_description": normal_prompt, "error_context": large_prompt})
result = await tool.execute({"prompt": normal_prompt, "error_context": large_prompt})
assert len(result) == 1
output = json.loads(result[0].text)
@@ -190,7 +193,7 @@ class TestLargePromptHandling:
async def test_analyze_large_question(self, large_prompt):
"""Test that analyze tool detects large question."""
tool = AnalyzeTool()
result = await tool.execute({"files": ["/some/file.py"], "question": large_prompt})
result = await tool.execute({"files": ["/some/file.py"], "prompt": large_prompt})
assert len(result) == 1
output = json.loads(result[0].text)
@@ -202,17 +205,17 @@ class TestLargePromptHandling:
tool = ChatTool()
other_file = "/some/other/file.py"
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.candidates = [
MagicMock(
content=MagicMock(parts=[MagicMock(text="Success")]),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = MagicMock(
content="Success",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Mock the centralized file preparation method to avoid file system access
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
@@ -235,17 +238,17 @@ class TestLargePromptHandling:
tool = ChatTool()
exact_prompt = "x" * MCP_PROMPT_SIZE_LIMIT
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.candidates = [
MagicMock(
content=MagicMock(parts=[MagicMock(text="Success")]),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = MagicMock(
content="Success",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
result = await tool.execute({"prompt": exact_prompt})
output = json.loads(result[0].text)
@@ -266,17 +269,17 @@ class TestLargePromptHandling:
"""Test empty prompt without prompt.txt file."""
tool = ChatTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.candidates = [
MagicMock(
content=MagicMock(parts=[MagicMock(text="Success")]),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = MagicMock(
content="Success",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
result = await tool.execute({"prompt": ""})
output = json.loads(result[0].text)
@@ -288,17 +291,17 @@ class TestLargePromptHandling:
tool = ChatTool()
bad_file = "/nonexistent/prompt.txt"
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.candidates = [
MagicMock(
content=MagicMock(parts=[MagicMock(text="Success")]),
finish_reason="STOP",
)
]
mock_model.generate_content.return_value = mock_response
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = MagicMock(
content="Success",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# Should continue with empty prompt when file can't be read
result = await tool.execute({"prompt": "", "files": [bad_file]})
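
A hedged sketch of the size guard these tests exercise. The constant's home module and the exact comparison are assumptions; the tests above only show that arguments longer than MCP_PROMPT_SIZE_LIMIT are flagged, while oversized prompts can instead be supplied through a prompt.txt file.

# MCP_PROMPT_SIZE_LIMIT is imported from config here as an assumption; the tests
# above rely only on the constant's value, not on where it lives.
from config import MCP_PROMPT_SIZE_LIMIT


def exceeds_mcp_limit(text: str) -> bool:
    # Character-count check is an assumption, mirroring the 'x' * MCP_PROMPT_SIZE_LIMIT test above
    return len(text) > MCP_PROMPT_SIZE_LIMIT
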

View File

@@ -1,141 +0,0 @@
"""
Live integration tests for google-genai library
These tests require GEMINI_API_KEY to be set and will make real API calls
To run these tests manually:
python tests/test_live_integration.py
Note: These tests are excluded from regular pytest runs to avoid API rate limits.
They confirm that the google-genai library integration works correctly with live data.
"""
import asyncio
import os
import sys
import tempfile
from pathlib import Path
# Add parent directory to path to allow imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from tools.analyze import AnalyzeTool
from tools.thinkdeep import ThinkDeepTool
async def run_manual_live_tests():
"""Run live tests manually without pytest"""
print("🚀 Running manual live integration tests...")
# Check API key
if not os.environ.get("GEMINI_API_KEY"):
print("❌ GEMINI_API_KEY not found. Set it to run live tests.")
return False
try:
# Test google-genai import
print("✅ google-genai library import successful")
# Test tool integration
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
f.write("def hello(): return 'world'")
temp_path = f.name
try:
# Test AnalyzeTool
tool = AnalyzeTool()
result = await tool.execute(
{
"files": [temp_path],
"question": "What does this code do?",
"thinking_mode": "low",
}
)
if result and result[0].text:
print("✅ AnalyzeTool live test successful")
else:
print("❌ AnalyzeTool live test failed")
return False
# Test ThinkDeepTool
think_tool = ThinkDeepTool()
result = await think_tool.execute(
{
"current_analysis": "Testing live integration",
"thinking_mode": "minimal", # Fast test
}
)
if result and result[0].text and "Extended Analysis" in result[0].text:
print("✅ ThinkDeepTool live test successful")
else:
print("❌ ThinkDeepTool live test failed")
return False
# Test collaboration/clarification request
print("\n🔄 Testing dynamic context request (collaboration)...")
# Create a specific test case designed to trigger clarification
# We'll use analyze tool with a question that requires seeing files
analyze_tool = AnalyzeTool()
# Ask about dependencies without providing package files
result = await analyze_tool.execute(
{
"files": [temp_path], # Only Python file, no package.json
"question": "What npm packages and their versions does this project depend on? List all dependencies.",
"thinking_mode": "minimal", # Fast test
}
)
if result and result[0].text:
response_data = json.loads(result[0].text)
print(f" Response status: {response_data['status']}")
if response_data["status"] == "requires_clarification":
print("✅ Dynamic context request successfully triggered!")
clarification = json.loads(response_data["content"])
print(f" Gemini asks: {clarification.get('question', 'N/A')}")
if "files_needed" in clarification:
print(f" Files requested: {clarification['files_needed']}")
# Verify it's asking for package-related files
expected_files = [
"package.json",
"package-lock.json",
"yarn.lock",
]
if any(f in str(clarification["files_needed"]) for f in expected_files):
print(" ✅ Correctly identified missing package files!")
else:
print(" ⚠️ Unexpected files requested")
else:
# This is a failure - we specifically designed this to need clarification
print("❌ Expected clarification request but got direct response")
print(" This suggests the dynamic context feature may not be working")
print(" Response:", response_data.get("content", "")[:200])
return False
else:
print("❌ Collaboration test failed - no response")
return False
finally:
Path(temp_path).unlink(missing_ok=True)
print("\n🎉 All manual live tests passed!")
print("✅ google-genai library working correctly")
print("✅ All tools can make live API calls")
print("✅ Thinking modes functioning properly")
return True
except Exception as e:
print(f"❌ Live test failed: {e}")
return False
if __name__ == "__main__":
# Run live tests when script is executed directly
success = asyncio.run(run_manual_live_tests())
exit(0 if success else 1)

View File

@@ -28,7 +28,7 @@ class TestPrecommitTool:
schema = tool.get_input_schema()
assert schema["type"] == "object"
assert "path" in schema["properties"]
assert "original_request" in schema["properties"]
assert "prompt" in schema["properties"]
assert "compare_to" in schema["properties"]
assert "review_type" in schema["properties"]
@@ -36,7 +36,7 @@ class TestPrecommitTool:
"""Test request model default values"""
request = PrecommitRequest(path="/some/absolute/path")
assert request.path == "/some/absolute/path"
assert request.original_request is None
assert request.prompt is None
assert request.compare_to is None
assert request.include_staged is True
assert request.include_unstaged is True
@@ -48,7 +48,7 @@ class TestPrecommitTool:
@pytest.mark.asyncio
async def test_relative_path_rejected(self, tool):
"""Test that relative paths are rejected"""
result = await tool.execute({"path": "./relative/path", "original_request": "Test"})
result = await tool.execute({"path": "./relative/path", "prompt": "Test"})
assert len(result) == 1
response = json.loads(result[0].text)
assert response["status"] == "error"
@@ -128,7 +128,7 @@ class TestPrecommitTool:
request = PrecommitRequest(
path="/absolute/repo/path",
original_request="Add hello message",
prompt="Add hello message",
review_type="security",
)
result = await tool.prepare_prompt(request)

View File

@@ -124,7 +124,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
temp_dir, config_path = temp_repo
# Create request with files parameter
request = PrecommitRequest(path=temp_dir, files=[config_path], original_request="Test configuration changes")
request = PrecommitRequest(path=temp_dir, files=[config_path], prompt="Test configuration changes")
# Generate the prompt
prompt = await tool.prepare_prompt(request)
@@ -152,7 +152,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
# Mock conversation memory functions to use our mock redis
with patch("utils.conversation_memory.get_redis_client", return_value=mock_redis):
# First request - should embed file content
PrecommitRequest(path=temp_dir, files=[config_path], original_request="First review")
PrecommitRequest(path=temp_dir, files=[config_path], prompt="First review")
# Simulate conversation thread creation
from utils.conversation_memory import add_turn, create_thread
@@ -167,9 +167,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
add_turn(thread_id, "assistant", "First response", files=[config_path], tool_name="precommit")
# Second request with continuation - should skip already embedded files
PrecommitRequest(
path=temp_dir, files=[config_path], continuation_id=thread_id, original_request="Follow-up review"
)
PrecommitRequest(path=temp_dir, files=[config_path], continuation_id=thread_id, prompt="Follow-up review")
files_to_embed_2 = tool.filter_new_files([config_path], thread_id)
assert len(files_to_embed_2) == 0, "Continuation should skip already embedded files"
@@ -182,7 +180,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
request = PrecommitRequest(
path=temp_dir,
files=[config_path],
original_request="Validate prompt structure",
prompt="Validate prompt structure",
review_type="full",
severity_filter="high",
)
@@ -191,7 +189,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
# Split prompt into sections
sections = {
"original_request": "## Original Request",
"prompt": "## Original Request",
"review_parameters": "## Review Parameters",
"repo_summary": "## Repository Changes Summary",
"context_files_summary": "## Context Files Summary",
@@ -207,7 +205,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
section_indices[name] = index
# Verify sections appear in logical order
assert section_indices["original_request"] < section_indices["review_parameters"]
assert section_indices["prompt"] < section_indices["review_parameters"]
assert section_indices["review_parameters"] < section_indices["repo_summary"]
assert section_indices["git_diffs"] < section_indices["additional_context"]
assert section_indices["additional_context"] < section_indices["review_instructions"]

View File

@@ -24,16 +24,16 @@ class TestPromptRegression:
@pytest.fixture
def mock_model_response(self):
"""Create a mock model response."""
from unittest.mock import Mock
def _create_response(text="Test response"):
mock_response = MagicMock()
mock_response.candidates = [
MagicMock(
content=MagicMock(parts=[MagicMock(text=text)]),
finish_reason="STOP",
)
]
return mock_response
# Return a Mock that acts like ModelResponse
return Mock(
content=text,
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
return _create_response
@@ -42,10 +42,14 @@ class TestPromptRegression:
"""Test chat tool with normal prompt."""
tool = ChatTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response("This is a helpful response about Python.")
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"This is a helpful response about Python."
)
mock_get_provider.return_value = mock_provider
result = await tool.execute({"prompt": "Explain Python decorators"})
@@ -54,18 +58,20 @@ class TestPromptRegression:
assert output["status"] == "success"
assert "helpful response about Python" in output["content"]
# Verify model was called
mock_model.generate_content.assert_called_once()
# Verify provider was called
mock_provider.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_chat_with_files(self, mock_model_response):
"""Test chat tool with files parameter."""
tool = ChatTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider
# Mock file reading through the centralized method
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
@@ -83,16 +89,18 @@ class TestPromptRegression:
"""Test thinkdeep tool with normal analysis."""
tool = ThinkDeepTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"Here's a deeper analysis with edge cases..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result = await tool.execute(
{
"current_analysis": "I think we should use a cache for performance",
"prompt": "I think we should use a cache for performance",
"problem_context": "Building a high-traffic API",
"focus_areas": ["scalability", "reliability"],
}
@@ -101,7 +109,7 @@ class TestPromptRegression:
assert len(result) == 1
output = json.loads(result[0].text)
assert output["status"] == "success"
assert "Extended Analysis by Gemini" in output["content"]
assert "Critical Evaluation Required" in output["content"]
assert "deeper analysis" in output["content"]
@pytest.mark.asyncio
@@ -109,12 +117,14 @@ class TestPromptRegression:
"""Test codereview tool with normal inputs."""
tool = CodeReviewTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"Found 3 issues: 1) Missing error handling..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
# Mock file reading
with patch("tools.base.read_files") as mock_read_files:
@@ -125,7 +135,7 @@ class TestPromptRegression:
"files": ["/path/to/code.py"],
"review_type": "security",
"focus_on": "Look for SQL injection vulnerabilities",
"context": "Test code review for validation purposes",
"prompt": "Test code review for validation purposes",
}
)
@@ -139,12 +149,14 @@ class TestPromptRegression:
"""Test review_changes tool with normal original_request."""
tool = Precommit()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"Changes look good, implementing feature as requested..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
# Mock git operations
with patch("tools.precommit.find_git_repositories") as mock_find_repos:
@@ -158,7 +170,7 @@ class TestPromptRegression:
result = await tool.execute(
{
"path": "/path/to/repo",
"original_request": "Add user authentication feature with JWT tokens",
"prompt": "Add user authentication feature with JWT tokens",
}
)
@@ -171,16 +183,18 @@ class TestPromptRegression:
"""Test debug tool with normal error description."""
tool = DebugIssueTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"Root cause: The variable is undefined. Fix: Initialize it..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result = await tool.execute(
{
"error_description": "TypeError: Cannot read property 'name' of undefined",
"prompt": "TypeError: Cannot read property 'name' of undefined",
"error_context": "at line 42 in user.js\n console.log(user.name)",
"runtime_info": "Node.js v16.14.0",
}
@@ -189,7 +203,7 @@ class TestPromptRegression:
assert len(result) == 1
output = json.loads(result[0].text)
assert output["status"] == "success"
assert "Debug Analysis" in output["content"]
assert "Next Steps:" in output["content"]
assert "Root cause" in output["content"]
@pytest.mark.asyncio
@@ -197,12 +211,14 @@ class TestPromptRegression:
"""Test analyze tool with normal question."""
tool = AnalyzeTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"The code follows MVC pattern with clear separation..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
# Mock file reading
with patch("tools.base.read_files") as mock_read_files:
@@ -211,7 +227,7 @@ class TestPromptRegression:
result = await tool.execute(
{
"files": ["/path/to/project"],
"question": "What design patterns are used in this codebase?",
"prompt": "What design patterns are used in this codebase?",
"analysis_type": "architecture",
}
)
@@ -226,10 +242,12 @@ class TestPromptRegression:
"""Test tools work with empty optional fields."""
tool = ChatTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider
# Test with no files parameter
result = await tool.execute({"prompt": "Hello"})
@@ -243,10 +261,12 @@ class TestPromptRegression:
"""Test that thinking modes are properly passed through."""
tool = ChatTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider
result = await tool.execute({"prompt": "Test", "thinking_mode": "high", "temperature": 0.8})
@@ -254,21 +274,24 @@ class TestPromptRegression:
output = json.loads(result[0].text)
assert output["status"] == "success"
# Verify create_model was called with correct parameters
mock_create_model.assert_called_once()
call_args = mock_create_model.call_args
assert call_args[0][2] == "high" # thinking_mode
assert call_args[0][1] == 0.8 # temperature
# Verify generate_content was called with correct parameters
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("temperature") == 0.8
# thinking_mode would be passed if the provider supports it
# In this test, we set supports_thinking_mode to False, so it won't be passed
@pytest.mark.asyncio
async def test_special_characters_in_prompts(self, mock_model_response):
"""Test prompts with special characters work correctly."""
tool = ChatTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider
special_prompt = 'Test with "quotes" and\nnewlines\tand tabs'
result = await tool.execute({"prompt": special_prompt})
@@ -282,10 +305,12 @@ class TestPromptRegression:
"""Test handling of various file path formats."""
tool = AnalyzeTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider
with patch("tools.base.read_files") as mock_read_files:
mock_read_files.return_value = "Content"
@@ -297,7 +322,7 @@ class TestPromptRegression:
"/Users/name/project/src/",
"/home/user/code.js",
],
"question": "Analyze these files",
"prompt": "Analyze these files",
}
)
@@ -311,10 +336,12 @@ class TestPromptRegression:
"""Test handling of unicode content in prompts."""
tool = ChatTool()
with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider
unicode_prompt = "Explain this: 你好世界 مرحبا بالعالم"
result = await tool.execute({"prompt": unicode_prompt})

tests/test_providers.py (new file, 183 lines)
View File

@@ -0,0 +1,183 @@
"""Tests for the model provider abstraction system"""
import os
from unittest.mock import Mock, patch
from providers import ModelProviderRegistry, ModelResponse
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
class TestModelProviderRegistry:
"""Test the model provider registry"""
def setup_method(self):
"""Clear registry before each test"""
ModelProviderRegistry._providers.clear()
ModelProviderRegistry._initialized_providers.clear()
def test_register_provider(self):
"""Test registering a provider"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
assert ProviderType.GOOGLE in ModelProviderRegistry._providers
assert ModelProviderRegistry._providers[ProviderType.GOOGLE] == GeminiModelProvider
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
def test_get_provider(self):
"""Test getting a provider instance"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
assert provider is not None
assert isinstance(provider, GeminiModelProvider)
assert provider.api_key == "test-key"
@patch.dict(os.environ, {}, clear=True)
def test_get_provider_no_api_key(self):
"""Test getting provider without API key returns None"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
assert provider is None
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
def test_get_provider_for_model(self):
"""Test getting provider for a specific model"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash")
assert provider is not None
assert isinstance(provider, GeminiModelProvider)
def test_get_available_providers(self):
"""Test getting list of available providers"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
providers = ModelProviderRegistry.get_available_providers()
assert len(providers) == 2
assert ProviderType.GOOGLE in providers
assert ProviderType.OPENAI in providers
class TestGeminiProvider:
"""Test Gemini model provider"""
def test_provider_initialization(self):
"""Test provider initialization"""
provider = GeminiModelProvider(api_key="test-key")
assert provider.api_key == "test-key"
assert provider.get_provider_type() == ProviderType.GOOGLE
def test_get_capabilities(self):
"""Test getting model capabilities"""
provider = GeminiModelProvider(api_key="test-key")
capabilities = provider.get_capabilities("gemini-2.0-flash")
assert capabilities.provider == ProviderType.GOOGLE
assert capabilities.model_name == "gemini-2.0-flash"
assert capabilities.max_tokens == 1_048_576
assert not capabilities.supports_extended_thinking
def test_get_capabilities_pro_model(self):
"""Test getting capabilities for Pro model with thinking support"""
provider = GeminiModelProvider(api_key="test-key")
capabilities = provider.get_capabilities("gemini-2.5-pro-preview-06-05")
assert capabilities.supports_extended_thinking
def test_model_shorthand_resolution(self):
"""Test model shorthand resolution"""
provider = GeminiModelProvider(api_key="test-key")
assert provider.validate_model_name("flash")
assert provider.validate_model_name("pro")
capabilities = provider.get_capabilities("flash")
assert capabilities.model_name == "gemini-2.0-flash"
def test_supports_thinking_mode(self):
"""Test thinking mode support detection"""
provider = GeminiModelProvider(api_key="test-key")
assert not provider.supports_thinking_mode("gemini-2.0-flash")
assert provider.supports_thinking_mode("gemini-2.5-pro-preview-06-05")
@patch("google.genai.Client")
def test_generate_content(self, mock_client_class):
"""Test content generation"""
# Mock the client
mock_client = Mock()
mock_response = Mock()
mock_response.text = "Generated content"
# Mock candidates for finish_reason
mock_candidate = Mock()
mock_candidate.finish_reason = "STOP"
mock_response.candidates = [mock_candidate]
# Mock usage metadata
mock_usage = Mock()
mock_usage.prompt_token_count = 10
mock_usage.candidates_token_count = 20
mock_response.usage_metadata = mock_usage
mock_client.models.generate_content.return_value = mock_response
mock_client_class.return_value = mock_client
provider = GeminiModelProvider(api_key="test-key")
response = provider.generate_content(prompt="Test prompt", model_name="gemini-2.0-flash", temperature=0.7)
assert isinstance(response, ModelResponse)
assert response.content == "Generated content"
assert response.model_name == "gemini-2.0-flash"
assert response.provider == ProviderType.GOOGLE
assert response.usage["input_tokens"] == 10
assert response.usage["output_tokens"] == 20
assert response.usage["total_tokens"] == 30
class TestOpenAIProvider:
"""Test OpenAI model provider"""
def test_provider_initialization(self):
"""Test provider initialization"""
provider = OpenAIModelProvider(api_key="test-key", organization="test-org")
assert provider.api_key == "test-key"
assert provider.organization == "test-org"
assert provider.get_provider_type() == ProviderType.OPENAI
def test_get_capabilities_o3(self):
"""Test getting O3 model capabilities"""
provider = OpenAIModelProvider(api_key="test-key")
capabilities = provider.get_capabilities("o3-mini")
assert capabilities.provider == ProviderType.OPENAI
assert capabilities.model_name == "o3-mini"
assert capabilities.max_tokens == 200_000
assert not capabilities.supports_extended_thinking
def test_validate_model_names(self):
"""Test model name validation"""
provider = OpenAIModelProvider(api_key="test-key")
assert provider.validate_model_name("o3")
assert provider.validate_model_name("o3-mini")
assert not provider.validate_model_name("gpt-4o")
assert not provider.validate_model_name("invalid-model")
def test_no_thinking_mode_support(self):
"""Test that no OpenAI models support thinking mode"""
provider = OpenAIModelProvider(api_key="test-key")
assert not provider.supports_thinking_mode("o3")
assert not provider.supports_thinking_mode("o3-mini")

View File

@@ -7,6 +7,7 @@ from unittest.mock import Mock, patch
import pytest
from server import handle_call_tool, handle_list_tools
from tests.mock_helpers import create_mock_provider
class TestServerTools:
@@ -42,31 +43,33 @@ class TestServerTools:
assert "Unknown tool: unknown_tool" in result[0].text
@pytest.mark.asyncio
async def test_handle_chat(self):
@patch("tools.base.BaseTool.get_model_provider")
async def test_handle_chat(self, mock_get_provider):
"""Test chat functionality"""
# Set test environment
import os
os.environ["PYTEST_CURRENT_TEST"] = "test"
# Create a mock for the model
with patch("tools.base.BaseTool.create_model") as mock_create:
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Chat response")]))]
)
mock_create.return_value = mock_model
# Create a mock for the provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Chat response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
result = await handle_call_tool("chat", {"prompt": "Hello Gemini"})
result = await handle_call_tool("chat", {"prompt": "Hello Gemini"})
assert len(result) == 1
# Parse JSON response
import json
assert len(result) == 1
# Parse JSON response
import json
response_data = json.loads(result[0].text)
assert response_data["status"] == "success"
assert "Chat response" in response_data["content"]
assert "Claude's Turn" in response_data["content"]
response_data = json.loads(result[0].text)
assert response_data["status"] == "success"
assert "Chat response" in response_data["content"]
assert "Claude's Turn" in response_data["content"]
@pytest.mark.asyncio
async def test_handle_get_version(self):
@@ -75,6 +78,6 @@ class TestServerTools:
assert len(result) == 1
response = result[0].text
assert "Gemini MCP Server v" in response # Version agnostic check
assert "Zen MCP Server v" in response # Version agnostic check
assert "Available Tools:" in response
assert "thinkdeep" in response

View File

@@ -6,6 +6,7 @@ from unittest.mock import Mock, patch
import pytest
from tests.mock_helpers import create_mock_provider
from tools.analyze import AnalyzeTool
from tools.codereview import CodeReviewTool
from tools.debug import DebugIssueTool
@@ -37,135 +38,165 @@ class TestThinkingModes:
), f"{tool.__class__.__name__} should default to {expected_default}"
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_minimal(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_minimal(self, mock_get_provider):
"""Test minimal thinking mode"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Minimal thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Minimal thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
tool = AnalyzeTool()
result = await tool.execute(
{
"files": ["/absolute/path/test.py"],
"question": "What is this?",
"prompt": "What is this?",
"thinking_mode": "minimal",
}
)
# Verify create_model was called with correct thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "minimal" # thinking_mode parameter
assert mock_get_provider.called
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "minimal" or (
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
) # thinking_mode parameter
# Parse JSON response
import json
response_data = json.loads(result[0].text)
assert response_data["status"] == "success"
assert response_data["content"].startswith("Analysis:")
assert "Minimal thinking response" in response_data["content"] or "Analysis:" in response_data["content"]
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_low(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_low(self, mock_get_provider):
"""Test low thinking mode"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Low thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Low thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
tool = CodeReviewTool()
result = await tool.execute(
{
"files": ["/absolute/path/test.py"],
"thinking_mode": "low",
"context": "Test code review for validation purposes",
"prompt": "Test code review for validation purposes",
}
)
# Verify create_model was called with correct thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "low"
assert mock_get_provider.called
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "low" or (
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
)
assert "Code Review" in result[0].text
assert "Low thinking response" in result[0].text or "Code Review" in result[0].text
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_medium(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_medium(self, mock_get_provider):
"""Test medium thinking mode (default for most tools)"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Medium thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Medium thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
tool = DebugIssueTool()
result = await tool.execute(
{
"error_description": "Test error",
"prompt": "Test error",
# Not specifying thinking_mode, should use default (medium)
}
)
# Verify create_model was called with default thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "medium"
assert mock_get_provider.called
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "medium" or (
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
)
assert "Debug Analysis" in result[0].text
assert "Medium thinking response" in result[0].text or "Debug Analysis" in result[0].text
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_high(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_high(self, mock_get_provider):
"""Test high thinking mode"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="High thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="High thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
tool = AnalyzeTool()
await tool.execute(
{
"files": ["/absolute/path/complex.py"],
"question": "Analyze architecture",
"prompt": "Analyze architecture",
"thinking_mode": "high",
}
)
# Verify create_model was called with correct thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "high"
assert mock_get_provider.called
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "high" or (
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
)
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_max(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_max(self, mock_get_provider):
"""Test max thinking mode (default for thinkdeep)"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Max thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Max thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
tool = ThinkDeepTool()
result = await tool.execute(
{
"current_analysis": "Initial analysis",
"prompt": "Initial analysis",
# Not specifying thinking_mode, should use default (high)
}
)
# Verify create_model was called with default thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "high"
assert mock_get_provider.called
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "high" or (
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
)
assert "Extended Analysis by Gemini" in result[0].text
assert "Max thinking response" in result[0].text or "Extended Analysis by Gemini" in result[0].text
def test_thinking_budget_mapping(self):
"""Test that thinking modes map to correct budget values"""

View File

@@ -7,6 +7,7 @@ from unittest.mock import Mock, patch
import pytest
from tests.mock_helpers import create_mock_provider
from tools import AnalyzeTool, ChatTool, CodeReviewTool, DebugIssueTool, ThinkDeepTool
@@ -24,23 +25,25 @@ class TestThinkDeepTool:
assert tool.get_default_temperature() == 0.7
schema = tool.get_input_schema()
assert "current_analysis" in schema["properties"]
assert schema["required"] == ["current_analysis"]
assert "prompt" in schema["properties"]
assert schema["required"] == ["prompt"]
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_execute_success(self, mock_create_model, tool):
@patch("tools.base.BaseTool.get_model_provider")
async def test_execute_success(self, mock_get_provider, tool):
"""Test successful execution"""
# Mock model
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Extended analysis")]))]
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Extended analysis", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result = await tool.execute(
{
"current_analysis": "Initial analysis",
"prompt": "Initial analysis",
"problem_context": "Building a cache",
"focus_areas": ["performance", "scalability"],
}
@@ -50,7 +53,7 @@ class TestThinkDeepTool:
# Parse the JSON response
output = json.loads(result[0].text)
assert output["status"] == "success"
assert "Extended Analysis by Gemini" in output["content"]
assert "Critical Evaluation Required" in output["content"]
assert "Extended analysis" in output["content"]
@@ -69,36 +72,38 @@ class TestCodeReviewTool:
schema = tool.get_input_schema()
assert "files" in schema["properties"]
assert "context" in schema["properties"]
assert schema["required"] == ["files", "context"]
assert "prompt" in schema["properties"]
assert schema["required"] == ["files", "prompt"]
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_execute_with_review_type(self, mock_create_model, tool, tmp_path):
@patch("tools.base.BaseTool.get_model_provider")
async def test_execute_with_review_type(self, mock_get_provider, tool, tmp_path):
"""Test execution with specific review type"""
# Create test file
test_file = tmp_path / "test.py"
test_file.write_text("def insecure(): pass", encoding="utf-8")
# Mock model
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Security issues found")]))]
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Security issues found", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result = await tool.execute(
{
"files": [str(test_file)],
"review_type": "security",
"focus_on": "authentication",
"context": "Test code review for validation purposes",
"prompt": "Test code review for validation purposes",
}
)
assert len(result) == 1
assert "Code Review (SECURITY)" in result[0].text
assert "Focus: authentication" in result[0].text
assert "Security issues found" in result[0].text
assert "Claude's Next Steps:" in result[0].text
assert "Security issues found" in result[0].text
@@ -116,30 +121,32 @@ class TestDebugIssueTool:
assert tool.get_default_temperature() == 0.2
schema = tool.get_input_schema()
assert "error_description" in schema["properties"]
assert schema["required"] == ["error_description"]
assert "prompt" in schema["properties"]
assert schema["required"] == ["prompt"]
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_execute_with_context(self, mock_create_model, tool):
@patch("tools.base.BaseTool.get_model_provider")
async def test_execute_with_context(self, mock_get_provider, tool):
"""Test execution with error context"""
# Mock model
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Root cause: race condition")]))]
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Root cause: race condition", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider
result = await tool.execute(
{
"error_description": "Test fails intermittently",
"prompt": "Test fails intermittently",
"error_context": "AssertionError in test_async",
"previous_attempts": "Added sleep, still fails",
}
)
assert len(result) == 1
assert "Debug Analysis" in result[0].text
assert "Next Steps:" in result[0].text
assert "Root cause: race condition" in result[0].text
@@ -158,38 +165,38 @@ class TestAnalyzeTool:
schema = tool.get_input_schema()
assert "files" in schema["properties"]
assert "question" in schema["properties"]
assert set(schema["required"]) == {"files", "question"}
assert "prompt" in schema["properties"]
assert set(schema["required"]) == {"files", "prompt"}
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_execute_with_analysis_type(self, mock_model, tool, tmp_path):
@patch("tools.base.BaseTool.get_model_provider")
async def test_execute_with_analysis_type(self, mock_get_provider, tool, tmp_path):
"""Test execution with specific analysis type"""
# Create test file
test_file = tmp_path / "module.py"
test_file.write_text("class Service: pass", encoding="utf-8")
# Mock response
mock_response = Mock()
mock_response.candidates = [Mock()]
mock_response.candidates[0].content.parts = [Mock(text="Architecture analysis")]
mock_instance = Mock()
mock_instance.generate_content.return_value = mock_response
mock_model.return_value = mock_instance
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Architecture analysis", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
result = await tool.execute(
{
"files": [str(test_file)],
"question": "What's the structure?",
"prompt": "What's the structure?",
"analysis_type": "architecture",
"output_format": "summary",
}
)
assert len(result) == 1
assert "ARCHITECTURE Analysis" in result[0].text
assert "Analyzed 1 file(s)" in result[0].text
assert "Architecture analysis" in result[0].text
assert "Next Steps:" in result[0].text
assert "Architecture analysis" in result[0].text
@@ -203,7 +210,7 @@ class TestAbsolutePathValidation:
result = await tool.execute(
{
"files": ["./relative/path.py", "/absolute/path.py"],
"question": "What does this do?",
"prompt": "What does this do?",
}
)
@@ -221,7 +228,7 @@ class TestAbsolutePathValidation:
{
"files": ["../parent/file.py"],
"review_type": "full",
"context": "Test code review for validation purposes",
"prompt": "Test code review for validation purposes",
}
)
@@ -237,7 +244,7 @@ class TestAbsolutePathValidation:
tool = DebugIssueTool()
result = await tool.execute(
{
"error_description": "Something broke",
"prompt": "Something broke",
"files": ["src/main.py"], # relative path
}
)
@@ -252,7 +259,7 @@ class TestAbsolutePathValidation:
async def test_thinkdeep_tool_relative_path_rejected(self):
"""Test that thinkdeep tool rejects relative paths"""
tool = ThinkDeepTool()
result = await tool.execute({"current_analysis": "My analysis", "files": ["./local/file.py"]})
result = await tool.execute({"prompt": "My analysis", "files": ["./local/file.py"]})
assert len(result) == 1
response = json.loads(result[0].text)
@@ -278,21 +285,21 @@ class TestAbsolutePathValidation:
assert "code.py" in response["content"]
@pytest.mark.asyncio
@patch("tools.AnalyzeTool.create_model")
async def test_analyze_tool_accepts_absolute_paths(self, mock_model):
@patch("tools.AnalyzeTool.get_model_provider")
async def test_analyze_tool_accepts_absolute_paths(self, mock_get_provider):
"""Test that analyze tool accepts absolute paths"""
tool = AnalyzeTool()
# Mock the model response
mock_response = Mock()
mock_response.candidates = [Mock()]
mock_response.candidates[0].content.parts = [Mock(text="Analysis complete")]
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Analysis complete", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
mock_instance = Mock()
mock_instance.generate_content.return_value = mock_response
mock_model.return_value = mock_instance
result = await tool.execute({"files": ["/absolute/path/file.py"], "question": "What does this do?"})
result = await tool.execute({"files": ["/absolute/path/file.py"], "prompt": "What does this do?"})
assert len(result) == 1
response = json.loads(result[0].text)

View File

@@ -1,5 +1,5 @@
"""
Tool implementations for Gemini MCP Server
Tool implementations for Zen MCP Server
"""
from .analyze import AnalyzeTool

View File

@@ -18,7 +18,7 @@ class AnalyzeRequest(ToolRequest):
"""Request model for analyze tool"""
files: list[str] = Field(..., description="Files or directories to analyze (must be absolute paths)")
question: str = Field(..., description="What to analyze or look for")
prompt: str = Field(..., description="What to analyze or look for")
analysis_type: Optional[str] = Field(
None,
description="Type of analysis: architecture|performance|security|quality|general",
@@ -42,9 +42,9 @@ class AnalyzeTool(BaseTool):
)
def get_input_schema(self) -> dict[str, Any]:
from config import DEFAULT_MODEL
from config import IS_AUTO_MODE
return {
schema = {
"type": "object",
"properties": {
"files": {
@@ -52,11 +52,8 @@ class AnalyzeTool(BaseTool):
"items": {"type": "string"},
"description": "Files or directories to analyze (must be absolute paths)",
},
"model": {
"type": "string",
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
},
"question": {
"model": self.get_model_field_schema(),
"prompt": {
"type": "string",
"description": "What to analyze or look for",
},
@@ -98,9 +95,11 @@ class AnalyzeTool(BaseTool):
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
},
},
"required": ["files", "question"],
"required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
}
return schema
def get_system_prompt(self) -> str:
return ANALYZE_PROMPT
@@ -116,8 +115,8 @@ class AnalyzeTool(BaseTool):
request_model = self.get_request_model()
request = request_model(**arguments)
# Check question size
size_check = self.check_prompt_size(request.question)
# Check prompt size
size_check = self.check_prompt_size(request.prompt)
if size_check:
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
@@ -129,9 +128,9 @@ class AnalyzeTool(BaseTool):
# Check for prompt.txt in files
prompt_content, updated_files = self.handle_prompt_file(request.files)
# If prompt.txt was found, use it as the question
# If prompt.txt was found, use it as the prompt
if prompt_content:
request.question = prompt_content
request.prompt = prompt_content
# Update request files list
if updated_files is not None:
@@ -177,7 +176,7 @@ class AnalyzeTool(BaseTool):
{focus_instruction}{websearch_instruction}
=== USER QUESTION ===
{request.question}
{request.prompt}
=== END QUESTION ===
=== FILES TO ANALYZE ===
@@ -188,12 +187,6 @@ Please analyze these files to answer the user's question."""
return full_prompt
def format_response(self, response: str, request: AnalyzeRequest) -> str:
def format_response(self, response: str, request: AnalyzeRequest, model_info: Optional[dict] = None) -> str:
"""Format the analysis response"""
header = f"Analysis: {request.question[:50]}..."
if request.analysis_type:
header = f"{request.analysis_type.upper()} Analysis"
summary_text = f"Analyzed {len(request.files)} file(s)"
return f"{header}\n{summary_text}\n{'=' * 50}\n\n{response}\n\n---\n\n**Next Steps:** Consider if this analysis reveals areas needing deeper investigation, additional context, or specific implementation details."
return f"{response}\n\n---\n\n**Next Steps:** Use this analysis to actively continue your task. Investigate deeper into any findings, implement solutions based on these insights, and carry out the necessary work. Only pause to ask the user if you need their explicit approval for major changes or if critical decisions require their input."

View File

@@ -1,5 +1,5 @@
"""
Base class for all Gemini MCP tools
Base class for all Zen MCP tools
This module provides the abstract base class that all tools must inherit from.
It defines the contract that tools must implement and provides common functionality
@@ -16,16 +16,14 @@ Key responsibilities:
import json
import logging
import os
import re
from abc import ABC, abstractmethod
from typing import Any, Literal, Optional
from google import genai
from google.genai import types
from mcp.types import TextContent
from pydantic import BaseModel, Field
from config import DEFAULT_MODEL, MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
from config import MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
from providers import ModelProvider, ModelProviderRegistry
from utils import check_token_limit
from utils.conversation_memory import (
MAX_CONVERSATION_TURNS,
@@ -36,7 +34,7 @@ from utils.conversation_memory import (
)
from utils.file_utils import read_file_content, read_files, translate_path_for_environment
from .models import ClarificationRequest, ContinuationOffer, FollowUpRequest, ToolOutput
from .models import ClarificationRequest, ContinuationOffer, ToolOutput
logger = logging.getLogger(__name__)
@@ -52,7 +50,7 @@ class ToolRequest(BaseModel):
model: Optional[str] = Field(
None,
description=f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
description="Model to use. See tool's input schema for available models and their capabilities.",
)
temperature: Optional[float] = Field(None, description="Temperature for response (tool-specific defaults)")
# Thinking mode controls how much computational budget the model uses for reasoning
@@ -144,6 +142,38 @@ class BaseTool(ABC):
"""
pass
def get_model_field_schema(self) -> dict[str, Any]:
"""
Generate the model field schema based on auto mode configuration.
When auto mode is enabled, the model parameter becomes required
and includes detailed descriptions of each model's capabilities.
Returns:
Dict containing the model field JSON schema
"""
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
if IS_AUTO_MODE:
# In auto mode, model is required and we provide detailed descriptions
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
for model, desc in MODEL_CAPABILITIES_DESC.items():
model_desc_parts.append(f"- '{model}': {desc}")
return {
"type": "string",
"description": "\n".join(model_desc_parts),
"enum": list(MODEL_CAPABILITIES_DESC.keys()),
}
else:
# Normal mode - model is optional with default
available_models = list(MODEL_CAPABILITIES_DESC.keys())
models_str = ", ".join(f"'{m}'" for m in available_models)
return {
"type": "string",
"description": f"Model to use. Available: {models_str}. Defaults to '{DEFAULT_MODEL}' if not specified.",
}
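For illustration, the two shapes this helper returns look roughly as follows, assuming MODEL_CAPABILITIES_DESC contains 'pro' and 'flash' entries; the real descriptions and DEFAULT_MODEL come from config and may differ.

# Auto mode: model is a required enum with per-model capability descriptions.
auto_mode_schema = {
    "type": "string",
    "description": (
        "Choose the best model for this task based on these capabilities:\n"
        "- 'pro': ...\n"
        "- 'flash': ..."
    ),
    "enum": ["pro", "flash"],
}

# Normal mode: model is optional; '<DEFAULT_MODEL>' stands in for the configured default.
normal_mode_schema = {
    "type": "string",
    "description": "Model to use. Available: 'pro', 'flash'. Defaults to '<DEFAULT_MODEL>' if not specified.",
}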
def get_default_temperature(self) -> float:
"""
Return the default temperature setting for this tool.
@@ -226,9 +256,7 @@ class BaseTool(ABC):
# Safety check: If no files are marked as embedded but we have a continuation_id,
# this might indicate an issue with conversation history. Be conservative.
if not embedded_files:
logger.debug(
f"📁 {self.name} tool: No files found in conversation history for thread {continuation_id}"
)
logger.debug(f"{self.name} tool: No files found in conversation history for thread {continuation_id}")
logger.debug(
f"[FILES] {self.name}: No embedded files found, returning all {len(requested_files)} requested files"
)
@@ -245,7 +273,7 @@ class BaseTool(ABC):
if len(new_files) < len(requested_files):
skipped = [f for f in requested_files if f in embedded_files]
logger.debug(
f"📁 {self.name} tool: Filtering {len(skipped)} files already in conversation history: {', '.join(skipped)}"
f"{self.name} tool: Filtering {len(skipped)} files already in conversation history: {', '.join(skipped)}"
)
logger.debug(f"[FILES] {self.name}: Skipped (already embedded): {skipped}")
@@ -254,8 +282,8 @@ class BaseTool(ABC):
except Exception as e:
# If there's any issue with conversation history lookup, be conservative
# and include all files rather than risk losing access to needed files
logger.warning(f"📁 {self.name} tool: Error checking conversation history for {continuation_id}: {e}")
logger.warning(f"📁 {self.name} tool: Including all requested files as fallback")
logger.warning(f"{self.name} tool: Error checking conversation history for {continuation_id}: {e}")
logger.warning(f"{self.name} tool: Including all requested files as fallback")
logger.debug(
f"[FILES] {self.name}: Exception in filter_new_files, returning all {len(requested_files)} files as fallback"
)
@@ -294,21 +322,83 @@ class BaseTool(ABC):
if not request_files:
return ""
# Note: Even if conversation history is already embedded, we still need to process
# any NEW files that aren't in the conversation history yet. The filter_new_files
# method will correctly identify which files need to be embedded.
# Extract remaining budget from arguments if available
if remaining_budget is None:
# Use provided arguments or fall back to stored arguments from execute()
args_to_use = arguments or getattr(self, "_current_arguments", {})
remaining_budget = args_to_use.get("_remaining_tokens")
# Use remaining budget if provided, otherwise fall back to max_tokens or default
# Use remaining budget if provided, otherwise fall back to max_tokens or model-specific default
if remaining_budget is not None:
effective_max_tokens = remaining_budget - reserve_tokens
elif max_tokens is not None:
effective_max_tokens = max_tokens - reserve_tokens
else:
from config import MAX_CONTENT_TOKENS
# Get model-specific limits
# First check if model_context was passed from server.py
model_context = None
if arguments:
model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get(
"_model_context"
)
effective_max_tokens = MAX_CONTENT_TOKENS - reserve_tokens
if model_context:
# Use the passed model context
try:
token_allocation = model_context.calculate_token_allocation()
effective_max_tokens = token_allocation.file_tokens - reserve_tokens
logger.debug(
f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total"
)
except Exception as e:
logger.warning(f"[FILES] {self.name}: Error using passed model context: {e}")
# Fall through to manual calculation
model_context = None
if not model_context:
# Manual calculation as fallback
from config import DEFAULT_MODEL
model_name = getattr(self, "_current_model_name", None) or DEFAULT_MODEL
try:
provider = self.get_model_provider(model_name)
capabilities = provider.get_capabilities(model_name)
# Calculate content allocation based on model capacity
if capabilities.max_tokens < 300_000:
# Smaller context models: 60% content, 40% response
model_content_tokens = int(capabilities.max_tokens * 0.6)
else:
# Larger context models: 80% content, 20% response
model_content_tokens = int(capabilities.max_tokens * 0.8)
effective_max_tokens = model_content_tokens - reserve_tokens
logger.debug(
f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total"
)
except (ValueError, AttributeError) as e:
# Handle specific errors: provider not found, model not supported, missing attributes
logger.warning(
f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}"
)
# Fall back to conservative default for safety
from config import MAX_CONTENT_TOKENS
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
except Exception as e:
# Catch any other unexpected errors
logger.error(
f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}"
)
from config import MAX_CONTENT_TOKENS
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
# Ensure we have a reasonable minimum budget
effective_max_tokens = max(1000, effective_max_tokens)
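As a rough worked example of the capacity split above, using the context windows exercised elsewhere in this commit's provider tests (the reserve_tokens value is illustrative):

def content_budget(max_tokens: int, reserve_tokens: int = 1_000) -> int:
    # 60% of capacity for content on smaller-context models, 80% on larger ones.
    share = 0.6 if max_tokens < 300_000 else 0.8
    return max(1_000, int(max_tokens * share) - reserve_tokens)

print(content_budget(200_000))    # o3-mini-sized window    -> 119000
print(content_budget(1_048_576))  # gemini-2.0-flash window -> 837860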
@@ -316,11 +406,21 @@ class BaseTool(ABC):
files_to_embed = self.filter_new_files(request_files, continuation_id)
logger.debug(f"[FILES] {self.name}: Will embed {len(files_to_embed)} files after filtering")
# Log the specific files for debugging/testing
if files_to_embed:
logger.info(
f"[FILE_PROCESSING] {self.name} tool will embed new files: {', '.join([os.path.basename(f) for f in files_to_embed])}"
)
else:
logger.info(
f"[FILE_PROCESSING] {self.name} tool: No new files to embed (all files already in conversation history)"
)
content_parts = []
# Read content of new files only
if files_to_embed:
logger.debug(f"📁 {self.name} tool embedding {len(files_to_embed)} new files: {', '.join(files_to_embed)}")
logger.debug(f"{self.name} tool embedding {len(files_to_embed)} new files: {', '.join(files_to_embed)}")
logger.debug(
f"[FILES] {self.name}: Starting file embedding with token budget {effective_max_tokens + reserve_tokens:,}"
)
@@ -336,11 +436,11 @@ class BaseTool(ABC):
content_tokens = estimate_tokens(file_content)
logger.debug(
f"📁 {self.name} tool successfully embedded {len(files_to_embed)} files ({content_tokens:,} tokens)"
f"{self.name} tool successfully embedded {len(files_to_embed)} files ({content_tokens:,} tokens)"
)
logger.debug(f"[FILES] {self.name}: Successfully embedded files - {content_tokens:,} tokens used")
except Exception as e:
logger.error(f"📁 {self.name} tool failed to embed files {files_to_embed}: {type(e).__name__}: {e}")
logger.error(f"{self.name} tool failed to embed files {files_to_embed}: {type(e).__name__}: {e}")
logger.debug(f"[FILES] {self.name}: File embedding failed - {type(e).__name__}: {e}")
raise
else:
@@ -352,7 +452,7 @@ class BaseTool(ABC):
skipped_files = [f for f in request_files if f in embedded_files]
if skipped_files:
logger.debug(
f"📁 {self.name} tool skipping {len(skipped_files)} files already in conversation history: {', '.join(skipped_files)}"
f"{self.name} tool skipping {len(skipped_files)} files already in conversation history: {', '.join(skipped_files)}"
)
logger.debug(f"[FILES] {self.name}: Adding note about {len(skipped_files)} skipped files")
if content_parts:
@@ -601,34 +701,63 @@ If any of these would strengthen your analysis, specify what Claude should searc
)
return [TextContent(type="text", text=error_output.model_dump_json())]
# Prepare the full prompt by combining system prompt with user request
# This is delegated to the tool implementation for customization
prompt = await self.prepare_prompt(request)
# Add follow-up instructions for new conversations (not threaded)
# Check if we have continuation_id - if so, conversation history is already embedded
continuation_id = getattr(request, "continuation_id", None)
if not continuation_id:
# Import here to avoid circular imports
if continuation_id:
# When continuation_id is present, server.py has already injected the
# conversation history into the appropriate field. We need to check if
# the prompt already contains conversation history marker.
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
# Store the original arguments to detect enhanced prompts
self._has_embedded_history = False
# Check if conversation history is already embedded in the prompt field
field_value = getattr(request, "prompt", "")
field_name = "prompt"
if "=== CONVERSATION HISTORY ===" in field_value:
# Conversation history is already embedded, use it directly
prompt = field_value
self._has_embedded_history = True
logger.debug(f"{self.name}: Using pre-embedded conversation history from {field_name}")
else:
# No embedded history, prepare prompt normally
prompt = await self.prepare_prompt(request)
logger.debug(f"{self.name}: No embedded history found, prepared prompt normally")
else:
# New conversation, prepare prompt normally
prompt = await self.prepare_prompt(request)
# Add follow-up instructions for new conversations
from server import get_follow_up_instructions
follow_up_instructions = get_follow_up_instructions(0) # New conversation, turn 0
prompt = f"{prompt}\n\n{follow_up_instructions}"
logger.debug(f"Added follow-up instructions for new {self.name} conversation")
# Also log to file for debugging MCP issues
try:
with open("/tmp/gemini_debug.log", "a") as f:
f.write(f"[{self.name}] Added follow-up instructions for new conversation\n")
except Exception:
pass
else:
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
# History reconstruction is handled by server.py:reconstruct_thread_context
# No need to rebuild it here - prompt already contains conversation history
# Extract model configuration from request or use defaults
model_name = getattr(request, "model", None) or DEFAULT_MODEL
model_name = getattr(request, "model", None)
if not model_name:
from config import DEFAULT_MODEL
model_name = DEFAULT_MODEL
# In auto mode, model parameter is required
from config import IS_AUTO_MODE
if IS_AUTO_MODE and model_name.lower() == "auto":
error_output = ToolOutput(
status="error",
content="Model parameter is required. Please specify which model to use for this task.",
content_type="text",
)
return [TextContent(type="text", text=error_output.model_dump_json())]
# Store model name for use by helper methods like _prepare_file_content_for_prompt
self._current_model_name = model_name
temperature = getattr(request, "temperature", None)
if temperature is None:
temperature = self.get_default_temperature()
@@ -636,28 +765,49 @@ If any of these would strengthen your analysis, specify what Claude should searc
if thinking_mode is None:
thinking_mode = self.get_default_thinking_mode()
# Create model instance with appropriate configuration
# This handles both regular models and thinking-enabled models
model = self.create_model(model_name, temperature, thinking_mode)
# Get the appropriate model provider
provider = self.get_model_provider(model_name)
# Generate AI response using the configured model
logger.info(f"Sending request to Gemini API for {self.name}")
# Validate and correct temperature for this model
temperature, temp_warnings = self._validate_and_correct_temperature(model_name, temperature)
# Log any temperature corrections
for warning in temp_warnings:
logger.warning(warning)
# Get system prompt for this tool
system_prompt = self.get_system_prompt()
# Generate AI response using the provider
logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.name}")
logger.info(f"Using model: {model_name} via {provider.get_provider_type().value} provider")
logger.debug(f"Prompt length: {len(prompt)} characters")
response = model.generate_content(prompt)
logger.info(f"Received response from Gemini API for {self.name}")
# Generate content with provider abstraction
model_response = provider.generate_content(
prompt=prompt,
model_name=model_name,
system_prompt=system_prompt,
temperature=temperature,
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None,
)
logger.info(f"Received response from {provider.get_provider_type().value} API for {self.name}")
# Process the model's response
if response.candidates and response.candidates[0].content.parts:
raw_text = response.candidates[0].content.parts[0].text
if model_response.content:
raw_text = model_response.content
# Parse the response to detect clarification requests and format the output
tool_output = self._parse_response(raw_text, request)
# Pass model info for conversation tracking
model_info = {"provider": provider, "model_name": model_name, "model_response": model_response}
tool_output = self._parse_response(raw_text, request, model_info)
logger.info(f"Successfully completed {self.name} tool execution")
else:
# Handle cases where the model couldn't generate a response
# This might happen due to safety filters or other constraints
finish_reason = response.candidates[0].finish_reason if response.candidates else "Unknown"
finish_reason = model_response.metadata.get("finish_reason", "Unknown")
logger.warning(f"Response blocked or incomplete for {self.name}. Finish reason: {finish_reason}")
tool_output = ToolOutput(
status="error",
@@ -678,13 +828,24 @@ If any of these would strengthen your analysis, specify what Claude should searc
if "500 INTERNAL" in error_msg and "Please retry" in error_msg:
logger.warning(f"500 INTERNAL error in {self.name} - attempting retry")
try:
# Single retry attempt
model = self._get_model_wrapper(request)
raw_response = await model.generate_content(prompt)
response = raw_response.text
# Single retry attempt using provider
retry_response = provider.generate_content(
prompt=prompt,
model_name=model_name,
system_prompt=system_prompt,
temperature=temperature,
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None,
)
# If successful, process normally
return [TextContent(type="text", text=self._process_response(response, request).model_dump_json())]
if retry_response.content:
# If successful, process normally
retry_model_info = {
"provider": provider,
"model_name": model_name,
"model_response": retry_response,
}
tool_output = self._parse_response(retry_response.content, request, retry_model_info)
return [TextContent(type="text", text=tool_output.model_dump_json())]
except Exception as retry_e:
logger.error(f"Retry failed for {self.name} tool: {str(retry_e)}")
@@ -699,31 +860,23 @@ If any of these would strengthen your analysis, specify what Claude should searc
)
return [TextContent(type="text", text=error_output.model_dump_json())]
def _parse_response(self, raw_text: str, request) -> ToolOutput:
def _parse_response(self, raw_text: str, request, model_info: Optional[dict] = None) -> ToolOutput:
"""
Parse the raw response and determine if it's a clarification request or follow-up.
Parse the raw response and check for clarification requests.
Some tools may return JSON indicating they need more information or want to
continue the conversation. This method detects such responses and formats them.
This method formats the response and always offers a continuation opportunity
unless max conversation turns have been reached.
Args:
raw_text: The raw text response from the model
request: The original request for context
model_info: Optional dict with model metadata
Returns:
ToolOutput: Standardized output object
"""
# Check for follow-up questions in JSON blocks at the end of the response
follow_up_question = self._extract_follow_up_question(raw_text)
logger = logging.getLogger(f"tools.{self.name}")
if follow_up_question:
logger.debug(
f"Found follow-up question in {self.name} response: {follow_up_question.get('follow_up_question', 'N/A')}"
)
else:
logger.debug(f"No follow-up question found in {self.name} response")
try:
# Try to parse as JSON to check for clarification requests
potential_json = json.loads(raw_text.strip())
@@ -745,33 +898,46 @@ If any of these would strengthen your analysis, specify what Claude should searc
pass
# Normal text response - format using tool-specific formatting
formatted_content = self.format_response(raw_text, request)
formatted_content = self.format_response(raw_text, request, model_info)
# If we found a follow-up question, prepare the threading response
if follow_up_question:
return self._create_follow_up_response(formatted_content, follow_up_question, request)
# Check if we should offer Claude a continuation opportunity
# Always check if we should offer Claude a continuation opportunity
continuation_offer = self._check_continuation_opportunity(request)
if continuation_offer:
logger.debug(
f"Creating continuation offer for {self.name} with {continuation_offer['remaining_turns']} turns remaining"
)
return self._create_continuation_offer_response(formatted_content, continuation_offer, request)
return self._create_continuation_offer_response(formatted_content, continuation_offer, request, model_info)
else:
logger.debug(f"No continuation offer created for {self.name}")
logger.debug(f"No continuation offer created for {self.name} - max turns reached")
# If this is a threaded conversation (has continuation_id), save the response
continuation_id = getattr(request, "continuation_id", None)
if continuation_id:
request_files = getattr(request, "files", []) or []
# Extract model metadata for conversation tracking
model_provider = None
model_name = None
model_metadata = None
if model_info:
provider = model_info.get("provider")
if provider:
model_provider = provider.get_provider_type().value
model_name = model_info.get("model_name")
model_response = model_info.get("model_response")
if model_response:
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
success = add_turn(
continuation_id,
"assistant",
formatted_content,
files=request_files,
tool_name=self.name,
model_provider=model_provider,
model_name=model_name,
model_metadata=model_metadata,
)
if not success:
logging.warning(f"Failed to add turn to thread {continuation_id} for {self.name}")
@@ -788,126 +954,6 @@ If any of these would strengthen your analysis, specify what Claude should searc
metadata={"tool_name": self.name},
)
def _extract_follow_up_question(self, text: str) -> Optional[dict]:
"""
Extract follow-up question from JSON blocks in the response.
Looks for JSON blocks containing follow_up_question at the end of responses.
Args:
text: The response text to parse
Returns:
Dict with follow-up data if found, None otherwise
"""
# Look for JSON blocks that contain follow_up_question
# Pattern handles optional leading whitespace and indentation
json_pattern = r'```json\s*\n\s*(\{.*?"follow_up_question".*?\})\s*\n\s*```'
matches = re.findall(json_pattern, text, re.DOTALL)
if not matches:
return None
# Take the last match (most recent follow-up)
try:
# Clean up the JSON string - remove excess whitespace and normalize
json_str = re.sub(r"\n\s+", "\n", matches[-1]).strip()
follow_up_data = json.loads(json_str)
if "follow_up_question" in follow_up_data:
return follow_up_data
except (json.JSONDecodeError, ValueError):
pass
return None
def _create_follow_up_response(self, content: str, follow_up_data: dict, request) -> ToolOutput:
"""
Create a response with follow-up question for conversation threading.
Args:
content: The main response content
follow_up_data: Dict containing follow_up_question and optional suggested_params
request: Original request for context
Returns:
ToolOutput configured for conversation continuation
"""
# Create or get thread ID
continuation_id = getattr(request, "continuation_id", None)
if continuation_id:
# This is a continuation - add this turn to existing thread
request_files = getattr(request, "files", []) or []
success = add_turn(
continuation_id,
"assistant",
content,
follow_up_question=follow_up_data.get("follow_up_question"),
files=request_files,
tool_name=self.name,
)
if not success:
# Thread not found or at limit, return normal response
return ToolOutput(
status="success",
content=content,
content_type="markdown",
metadata={"tool_name": self.name},
)
thread_id = continuation_id
else:
# Create new thread
try:
thread_id = create_thread(
tool_name=self.name, initial_request=request.model_dump() if hasattr(request, "model_dump") else {}
)
# Add the assistant's response with follow-up
request_files = getattr(request, "files", []) or []
add_turn(
thread_id,
"assistant",
content,
follow_up_question=follow_up_data.get("follow_up_question"),
files=request_files,
tool_name=self.name,
)
except Exception as e:
# Threading failed, return normal response
logger = logging.getLogger(f"tools.{self.name}")
logger.warning(f"Follow-up threading failed in {self.name}: {str(e)}")
return ToolOutput(
status="success",
content=content,
content_type="markdown",
metadata={"tool_name": self.name, "follow_up_error": str(e)},
)
# Create follow-up request
follow_up_request = FollowUpRequest(
continuation_id=thread_id,
question_to_user=follow_up_data["follow_up_question"],
suggested_tool_params=follow_up_data.get("suggested_params"),
ui_hint=follow_up_data.get("ui_hint"),
)
# Strip the JSON block from the content since it's now in the follow_up_request
clean_content = self._remove_follow_up_json(content)
return ToolOutput(
status="requires_continuation",
content=clean_content,
content_type="markdown",
follow_up_request=follow_up_request,
metadata={"tool_name": self.name, "thread_id": thread_id},
)
def _remove_follow_up_json(self, text: str) -> str:
"""Remove follow-up JSON blocks from the response text"""
# Remove JSON blocks containing follow_up_question
pattern = r'```json\s*\n\s*\{.*?"follow_up_question".*?\}\s*\n\s*```'
return re.sub(pattern, "", text, flags=re.DOTALL).strip()
def _check_continuation_opportunity(self, request) -> Optional[dict]:
"""
Check if we should offer Claude a continuation opportunity.
@@ -921,17 +967,24 @@ If any of these would strengthen your analysis, specify what Claude should searc
Returns:
Dict with continuation data if opportunity should be offered, None otherwise
"""
# Skip continuation offers in test mode
import os
if os.getenv("PYTEST_CURRENT_TEST"):
return None
continuation_id = getattr(request, "continuation_id", None)
try:
if continuation_id:
# Check remaining turns in existing thread
from utils.conversation_memory import get_thread
# Check remaining turns in thread chain
from utils.conversation_memory import get_thread_chain
context = get_thread(continuation_id)
if context:
current_turns = len(context.turns)
remaining_turns = MAX_CONVERSATION_TURNS - current_turns - 1 # -1 for this response
chain = get_thread_chain(continuation_id)
if chain:
# Count total turns across all threads in chain
total_turns = sum(len(thread.turns) for thread in chain)
remaining_turns = MAX_CONVERSATION_TURNS - total_turns - 1 # -1 for this response
else:
# Thread not found, don't offer continuation
return None
@@ -949,7 +1002,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
# If anything fails, don't offer continuation
return None
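A self-contained sketch of the remaining-turns arithmetic above, with plain turn counts standing in for the ThreadContext objects returned by get_thread_chain (the limit of 10 is illustrative, not the project's actual MAX_CONVERSATION_TURNS):

MAX_CONVERSATION_TURNS = 10  # illustrative stand-in for the real constant

chain_turn_counts = [4, 3]  # e.g. a parent thread with 4 turns and the current thread with 3

total_turns = sum(chain_turn_counts)
remaining_turns = MAX_CONVERSATION_TURNS - total_turns - 1  # -1 for the response being produced now
offer_continuation = remaining_turns > 0
print(total_turns, remaining_turns, offer_continuation)  # 7 2 True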
def _create_continuation_offer_response(self, content: str, continuation_data: dict, request) -> ToolOutput:
def _create_continuation_offer_response(
self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None
) -> ToolOutput:
"""
Create a response offering Claude the opportunity to continue conversation.
@@ -962,27 +1017,53 @@ If any of these would strengthen your analysis, specify what Claude should searc
ToolOutput configured with continuation offer
"""
try:
# Create new thread for potential continuation
# Create new thread for potential continuation (with parent link if continuing)
continuation_id = getattr(request, "continuation_id", None)
thread_id = create_thread(
tool_name=self.name, initial_request=request.model_dump() if hasattr(request, "model_dump") else {}
tool_name=self.name,
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
parent_thread_id=continuation_id, # Link to parent if this is a continuation
)
# Add this response as the first turn (assistant turn)
request_files = getattr(request, "files", []) or []
add_turn(thread_id, "assistant", content, files=request_files, tool_name=self.name)
# Extract model metadata
model_provider = None
model_name = None
model_metadata = None
if model_info:
provider = model_info.get("provider")
if provider:
model_provider = provider.get_provider_type().value
model_name = model_info.get("model_name")
model_response = model_info.get("model_response")
if model_response:
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
add_turn(
thread_id,
"assistant",
content,
files=request_files,
tool_name=self.name,
model_provider=model_provider,
model_name=model_name,
model_metadata=model_metadata,
)
# Create continuation offer
remaining_turns = continuation_data["remaining_turns"]
continuation_offer = ContinuationOffer(
continuation_id=thread_id,
message_to_user=(
f"If you'd like to continue this analysis or need further details, "
f"you can use the continuation_id '{thread_id}' in your next {self.name} tool call. "
f"If you'd like to continue this discussion or need to provide me with further details or context, "
f"you can use the continuation_id '{thread_id}' with any tool and any model. "
f"You have {remaining_turns} more exchange(s) available in this conversation thread."
),
suggested_tool_params={
"continuation_id": thread_id,
"prompt": "[Your follow-up question or request for additional analysis]",
"prompt": "[Your follow-up question, additional context, or further details]",
},
remaining_turns=remaining_turns,
)
@@ -1022,7 +1103,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
"""
pass
def format_response(self, response: str, request) -> str:
def format_response(self, response: str, request, model_info: Optional[dict] = None) -> str:
"""
Format the model's response for display.
@@ -1033,6 +1114,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
Args:
response: The raw response from the model
request: The original request for context
model_info: Optional dict with model metadata (provider, model_name, model_response)
Returns:
str: Formatted response
@@ -1059,154 +1141,79 @@ If any of these would strengthen your analysis, specify what Claude should searc
f"{context_type} too large (~{estimated_tokens:,} tokens). Maximum is {MAX_CONTEXT_TOKENS:,} tokens."
)
def create_model(self, model_name: str, temperature: float, thinking_mode: str = "medium"):
def _validate_and_correct_temperature(self, model_name: str, temperature: float) -> tuple[float, list[str]]:
"""
Create a configured Gemini model instance.
This method handles model creation with appropriate settings including
temperature and thinking budget configuration for models that support it.
Validate and correct temperature for the specified model.
Args:
model_name: Name of the Gemini model to use (or shorthand like 'flash', 'pro')
temperature: Temperature setting for response generation
thinking_mode: Thinking depth mode (affects computational budget)
model_name: Name of the model to validate temperature for
temperature: Temperature value to validate
Returns:
Model instance configured and ready for generation
Tuple of (corrected_temperature, warning_messages)
"""
# Define model shorthands for user convenience
model_shorthands = {
"pro": "gemini-2.5-pro-preview-06-05",
"flash": "gemini-2.0-flash-exp",
}
try:
provider = self.get_model_provider(model_name)
capabilities = provider.get_capabilities(model_name)
constraint = capabilities.temperature_constraint
# Resolve shorthand to full model name
resolved_model_name = model_shorthands.get(model_name.lower(), model_name)
warnings = []
# Map thinking modes to computational budget values
# Higher budgets allow for more complex reasoning but increase latency
thinking_budgets = {
"minimal": 128, # Minimum for 2.5 Pro - fast responses
"low": 2048, # Light reasoning tasks
"medium": 8192, # Balanced reasoning (default)
"high": 16384, # Complex analysis
"max": 32768, # Maximum reasoning depth
}
thinking_budget = thinking_budgets.get(thinking_mode, 8192)
# Gemini 2.5 models support thinking configuration for enhanced reasoning
# Skip special handling in test environment to allow mocking
if "2.5" in resolved_model_name and not os.environ.get("PYTEST_CURRENT_TEST"):
try:
# Retrieve API key for Gemini client creation
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
raise ValueError("GEMINI_API_KEY environment variable is required")
client = genai.Client(api_key=api_key)
# Create a wrapper class to provide a consistent interface
# This abstracts the differences between API versions
class ModelWrapper:
def __init__(self, client, model_name, temperature, thinking_budget):
self.client = client
self.model_name = model_name
self.temperature = temperature
self.thinking_budget = thinking_budget
def generate_content(self, prompt):
response = self.client.models.generate_content(
model=self.model_name,
contents=prompt,
config=types.GenerateContentConfig(
temperature=self.temperature,
candidate_count=1,
thinking_config=types.ThinkingConfig(thinking_budget=self.thinking_budget),
),
)
# Wrap the response to match the expected format
# This ensures compatibility across different API versions
class ResponseWrapper:
def __init__(self, text):
self.text = text
self.candidates = [
type(
"obj",
(object,),
{
"content": type(
"obj",
(object,),
{
"parts": [
type(
"obj",
(object,),
{"text": text},
)
]
},
)(),
"finish_reason": "STOP",
},
)
]
return ResponseWrapper(response.text)
return ModelWrapper(client, resolved_model_name, temperature, thinking_budget)
except Exception:
# Fall back to regular API if thinking configuration fails
# This ensures the tool remains functional even with API changes
pass
# For models that don't support thinking configuration, use standard API
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
raise ValueError("GEMINI_API_KEY environment variable is required")
client = genai.Client(api_key=api_key)
# Create a simple wrapper for models without thinking configuration
# This provides the same interface as the thinking-enabled wrapper
class SimpleModelWrapper:
def __init__(self, client, model_name, temperature):
self.client = client
self.model_name = model_name
self.temperature = temperature
def generate_content(self, prompt):
response = self.client.models.generate_content(
model=self.model_name,
contents=prompt,
config=types.GenerateContentConfig(
temperature=self.temperature,
candidate_count=1,
),
if not constraint.validate(temperature):
corrected = constraint.get_corrected_value(temperature)
warning = (
f"Temperature {temperature} invalid for {model_name}. "
f"{constraint.get_description()}. Using {corrected} instead."
)
warnings.append(warning)
return corrected, warnings
# Convert to match expected format
class ResponseWrapper:
def __init__(self, text):
self.text = text
self.candidates = [
type(
"obj",
(object,),
{
"content": type(
"obj",
(object,),
{"parts": [type("obj", (object,), {"text": text})]},
)(),
"finish_reason": "STOP",
},
)
]
return temperature, warnings
return ResponseWrapper(response.text)
except Exception as e:
# If validation fails for any reason, use the original temperature
# and log a warning (but don't fail the request)
logger = logging.getLogger(f"tools.{self.name}")
logger.warning(f"Temperature validation failed for {model_name}: {e}")
return temperature, [f"Temperature validation failed: {e}"]
return SimpleModelWrapper(client, resolved_model_name, temperature)
def get_model_provider(self, model_name: str) -> ModelProvider:
"""
Get a model provider for the specified model.
Args:
model_name: Name of the model to use (can be provider-specific or generic)
Returns:
ModelProvider instance configured for the model
Raises:
ValueError: If no provider supports the requested model
"""
# Get provider from registry
provider = ModelProviderRegistry.get_provider_for_model(model_name)
if not provider:
# Try to determine provider from model name patterns
if "gemini" in model_name.lower() or model_name.lower() in ["flash", "pro"]:
# Register Gemini provider if not already registered
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
elif "gpt" in model_name.lower() or "o3" in model_name.lower():
# Register OpenAI provider if not already registered
from providers.base import ProviderType
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
if not provider:
raise ValueError(
f"No provider found for model '{model_name}'. "
f"Ensure the appropriate API key is set and the model name is correct."
)
return provider
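The name-pattern fallback above amounts to a small routing rule; here is a hedged sketch of just that selection logic, with plain strings in place of the ProviderType/registry machinery:

def guess_provider(model_name: str) -> str:
    """Pick a provider family from the model name, mirroring the patterns above."""
    name = model_name.lower()
    if "gemini" in name or name in ("flash", "pro"):
        return "google"
    if "gpt" in name or "o3" in name:
        return "openai"
    raise ValueError(f"No provider found for model '{model_name}'")

print(guess_provider("flash"))    # google
print(guess_provider("o3-mini"))  # openai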

View File

@@ -19,7 +19,7 @@ class ChatRequest(ToolRequest):
prompt: str = Field(
...,
description="Your question, topic, or current thinking to discuss with Gemini",
description="Your question, topic, or current thinking to discuss",
)
files: Optional[list[str]] = Field(
default_factory=list,
@@ -35,33 +35,30 @@ class ChatTool(BaseTool):
def get_description(self) -> str:
return (
"GENERAL CHAT & COLLABORATIVE THINKING - Use Gemini as your thinking partner! "
"GENERAL CHAT & COLLABORATIVE THINKING - Use the AI model as your thinking partner! "
"Perfect for: bouncing ideas during your own analysis, getting second opinions on your plans, "
"collaborative brainstorming, validating your checklists and approaches, exploring alternatives. "
"Also great for: explanations, comparisons, general development questions. "
"Use this when you want to ask Gemini questions, brainstorm ideas, get opinions, discuss topics, "
"Use this when you want to ask questions, brainstorm ideas, get opinions, discuss topics, "
"share your thinking, or need explanations about concepts and approaches."
)
def get_input_schema(self) -> dict[str, Any]:
from config import DEFAULT_MODEL
from config import IS_AUTO_MODE
return {
schema = {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Your question, topic, or current thinking to discuss with Gemini",
"description": "Your question, topic, or current thinking to discuss",
},
"files": {
"type": "array",
"items": {"type": "string"},
"description": "Optional files for context (must be absolute paths)",
},
"model": {
"type": "string",
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
},
"model": self.get_model_field_schema(),
"temperature": {
"type": "number",
"description": "Response creativity (0-1, default 0.5)",
@@ -83,9 +80,11 @@ class ChatTool(BaseTool):
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
},
},
"required": ["prompt"],
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
}
return schema
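Taken together with the runtime guard in base.py above, auto mode shows up in two places; a compact sketch of both, with IS_AUTO_MODE as a plain boolean stand-in for the config flag:

IS_AUTO_MODE = True  # stand-in for config.IS_AUTO_MODE

# 1) Schema side: "model" becomes a required property.
required = ["prompt"] + (["model"] if IS_AUTO_MODE else [])

# 2) Runtime side: a literal "auto" model is rejected so Claude must pick one explicitly.
def resolve_model(requested, default="auto"):
    model_name = requested or default
    if IS_AUTO_MODE and model_name.lower() == "auto":
        raise ValueError("Model parameter is required. Please specify which model to use for this task.")
    return model_name

print(required)                # ['prompt', 'model']
print(resolve_model("flash"))  # flash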
def get_system_prompt(self) -> str:
return CHAT_PROMPT
@@ -153,6 +152,6 @@ Please provide a thoughtful, comprehensive response:"""
return full_prompt
def format_response(self, response: str, request: ChatRequest) -> str:
"""Format the chat response with actionable guidance"""
def format_response(self, response: str, request: ChatRequest, model_info: Optional[dict] = None) -> str:
"""Format the chat response"""
return f"{response}\n\n---\n\n**Claude's Turn:** Evaluate this perspective alongside your analysis to form a comprehensive solution and continue with the user's request and task at hand."

View File

@@ -39,12 +39,15 @@ class CodeReviewRequest(ToolRequest):
...,
description="Code files or directories to review (must be absolute paths)",
)
context: str = Field(
prompt: str = Field(
...,
description="User's summary of what the code does, expected behavior, constraints, and review objectives",
)
review_type: str = Field("full", description="Type of review: full|security|performance|quick")
focus_on: Optional[str] = Field(None, description="Specific aspects to focus on during review")
focus_on: Optional[str] = Field(
None,
description="Specific aspects to focus on, or additional context that would help understand areas of concern",
)
standards: Optional[str] = Field(None, description="Coding standards or guidelines to enforce")
severity_filter: str = Field(
"all",
@@ -79,9 +82,9 @@ class CodeReviewTool(BaseTool):
)
def get_input_schema(self) -> dict[str, Any]:
from config import DEFAULT_MODEL
from config import IS_AUTO_MODE
return {
schema = {
"type": "object",
"properties": {
"files": {
@@ -89,11 +92,8 @@ class CodeReviewTool(BaseTool):
"items": {"type": "string"},
"description": "Code files or directories to review (must be absolute paths)",
},
"model": {
"type": "string",
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
},
"context": {
"model": self.get_model_field_schema(),
"prompt": {
"type": "string",
"description": "User's summary of what the code does, expected behavior, constraints, and review objectives",
},
@@ -105,7 +105,7 @@ class CodeReviewTool(BaseTool):
},
"focus_on": {
"type": "string",
"description": "Specific aspects to focus on",
"description": "Specific aspects to focus on, or additional context that would help understand areas of concern",
},
"standards": {
"type": "string",
@@ -138,9 +138,11 @@ class CodeReviewTool(BaseTool):
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
},
},
"required": ["files", "context"],
"required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
}
return schema
def get_system_prompt(self) -> str:
return CODEREVIEW_PROMPT
@@ -184,9 +186,9 @@ class CodeReviewTool(BaseTool):
# Check for prompt.txt in files
prompt_content, updated_files = self.handle_prompt_file(request.files)
# If prompt.txt was found, use it as focus_on
# If prompt.txt was found, incorporate it into the prompt
if prompt_content:
request.focus_on = prompt_content
request.prompt = prompt_content + "\n\n" + request.prompt
# Update request files list
if updated_files is not None:
@@ -234,7 +236,7 @@ class CodeReviewTool(BaseTool):
full_prompt = f"""{self.get_system_prompt()}{websearch_instruction}
=== USER CONTEXT ===
{request.context}
{request.prompt}
=== END CONTEXT ===
{focus_instruction}
@@ -247,27 +249,19 @@ Please provide a code review aligned with the user's context and expectations, f
return full_prompt
def format_response(self, response: str, request: CodeReviewRequest) -> str:
def format_response(self, response: str, request: CodeReviewRequest, model_info: Optional[dict] = None) -> str:
"""
Format the review response with appropriate headers.
Adds context about the review type and focus area to help
users understand the scope of the review.
Format the review response.
Args:
response: The raw review from the model
request: The original request for context
model_info: Optional dict with model metadata
Returns:
str: Formatted response with headers
str: Formatted response with next steps
"""
header = f"Code Review ({request.review_type.upper()})"
if request.focus_on:
header += f" - Focus: {request.focus_on}"
return f"""{header}
{"=" * 50}
{response}
return f"""{response}
---

View File

@@ -17,7 +17,7 @@ from .models import ToolOutput
class DebugIssueRequest(ToolRequest):
"""Request model for debug tool"""
error_description: str = Field(..., description="Error message, symptoms, or issue description")
prompt: str = Field(..., description="Error message, symptoms, or issue description")
error_context: Optional[str] = Field(None, description="Stack trace, logs, or additional error context")
files: Optional[list[str]] = Field(
None,
@@ -38,7 +38,7 @@ class DebugIssueTool(BaseTool):
"DEBUG & ROOT CAUSE ANALYSIS - Expert debugging for complex issues with 1M token capacity. "
"Use this when you need to debug code, find out why something is failing, identify root causes, "
"trace errors, or diagnose issues. "
"IMPORTANT: Share diagnostic files liberally! Gemini can handle up to 1M tokens, so include: "
"IMPORTANT: Share diagnostic files liberally! The model can handle up to 1M tokens, so include: "
"large log files, full stack traces, memory dumps, diagnostic outputs, multiple related files, "
"entire modules, test results, configuration files - anything that might help debug the issue. "
"Claude should proactively use this tool whenever debugging is needed and share comprehensive "
@@ -50,19 +50,16 @@ class DebugIssueTool(BaseTool):
)
def get_input_schema(self) -> dict[str, Any]:
from config import DEFAULT_MODEL
from config import IS_AUTO_MODE
return {
schema = {
"type": "object",
"properties": {
"error_description": {
"prompt": {
"type": "string",
"description": "Error message, symptoms, or issue description",
},
"model": {
"type": "string",
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
},
"model": self.get_model_field_schema(),
"error_context": {
"type": "string",
"description": "Stack trace, logs, or additional error context",
@@ -101,9 +98,11 @@ class DebugIssueTool(BaseTool):
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
},
},
"required": ["error_description"],
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
}
return schema
def get_system_prompt(self) -> str:
return DEBUG_ISSUE_PROMPT
@@ -119,8 +118,8 @@ class DebugIssueTool(BaseTool):
request_model = self.get_request_model()
request = request_model(**arguments)
# Check error_description size
size_check = self.check_prompt_size(request.error_description)
# Check prompt size
size_check = self.check_prompt_size(request.prompt)
if size_check:
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
@@ -138,11 +137,10 @@ class DebugIssueTool(BaseTool):
# Check for prompt.txt in files
prompt_content, updated_files = self.handle_prompt_file(request.files)
# If prompt.txt was found, use it as error_description or error_context
# Priority: if error_description is empty, use it there, otherwise use as error_context
# If prompt.txt was found, use it as prompt or error_context
if prompt_content:
if not request.error_description or request.error_description == "":
request.error_description = prompt_content
if not request.prompt or request.prompt == "":
request.prompt = prompt_content
else:
request.error_context = prompt_content
@@ -151,7 +149,7 @@ class DebugIssueTool(BaseTool):
request.files = updated_files
# Build context sections
context_parts = [f"=== ISSUE DESCRIPTION ===\n{request.error_description}\n=== END DESCRIPTION ==="]
context_parts = [f"=== ISSUE DESCRIPTION ===\n{request.prompt}\n=== END DESCRIPTION ==="]
if request.error_context:
context_parts.append(f"\n=== ERROR CONTEXT/STACK TRACE ===\n{request.error_context}\n=== END CONTEXT ===")
@@ -197,11 +195,15 @@ Focus on finding the root cause and providing actionable solutions."""
return full_prompt
def format_response(self, response: str, request: DebugIssueRequest) -> str:
def format_response(self, response: str, request: DebugIssueRequest, model_info: Optional[dict] = None) -> str:
"""Format the debugging response"""
return (
f"Debug Analysis\n{'=' * 50}\n\n{response}\n\n---\n\n"
"**Next Steps:** Evaluate Gemini's recommendations, synthesize the best fix considering potential "
"regressions, and if the root cause has been clearly identified, proceed with implementing the "
"potential fixes."
)
# Get the friendly model name
model_name = "the model"
if model_info and model_info.get("model_response"):
model_name = model_info["model_response"].friendly_name or "the model"
return f"""{response}
---
**Next Steps:** Evaluate {model_name}'s recommendations, synthesize the best fix considering potential regressions, and if the root cause has been clearly identified, proceed with implementing the potential fixes."""

View File

@@ -7,21 +7,6 @@ from typing import Any, Literal, Optional
from pydantic import BaseModel, Field
class FollowUpRequest(BaseModel):
"""Request for follow-up conversation turn"""
continuation_id: str = Field(
..., description="Thread continuation ID for multi-turn conversations across different tools"
)
question_to_user: str = Field(..., description="Follow-up question to ask Claude")
suggested_tool_params: Optional[dict[str, Any]] = Field(
None, description="Suggested parameters for the next tool call"
)
ui_hint: Optional[str] = Field(
None, description="UI hint for Claude (e.g., 'text_input', 'file_select', 'multi_choice')"
)
class ContinuationOffer(BaseModel):
"""Offer for Claude to continue conversation when Gemini doesn't ask follow-up"""
@@ -43,15 +28,11 @@ class ToolOutput(BaseModel):
"error",
"requires_clarification",
"requires_file_prompt",
"requires_continuation",
"continuation_available",
] = "success"
content: Optional[str] = Field(None, description="The main content/response from the tool")
content_type: Literal["text", "markdown", "json"] = "text"
metadata: Optional[dict[str, Any]] = Field(default_factory=dict)
follow_up_request: Optional[FollowUpRequest] = Field(
None, description="Optional follow-up request for continued conversation"
)
continuation_offer: Optional[ContinuationOffer] = Field(
None, description="Optional offer for Claude to continue conversation"
)

View File

@@ -31,7 +31,7 @@ class PrecommitRequest(ToolRequest):
...,
description="Starting directory to search for git repositories (must be absolute path).",
)
original_request: Optional[str] = Field(
prompt: Optional[str] = Field(
None,
description="The original user request description for the changes. Provides critical context for the review.",
)
@@ -98,15 +98,17 @@ class Precommit(BaseTool):
)
def get_input_schema(self) -> dict[str, Any]:
from config import DEFAULT_MODEL
from config import IS_AUTO_MODE
schema = self.get_request_model().model_json_schema()
# Ensure model parameter has enhanced description
if "properties" in schema and "model" in schema["properties"]:
schema["properties"]["model"] = {
"type": "string",
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
}
schema["properties"]["model"] = self.get_model_field_schema()
# In auto mode, model is required
if IS_AUTO_MODE and "required" in schema:
if "model" not in schema["required"]:
schema["required"].append("model")
# Ensure use_websearch is in the schema with proper description
if "properties" in schema and "use_websearch" not in schema["properties"]:
schema["properties"]["use_websearch"] = {
@@ -140,9 +142,9 @@ class Precommit(BaseTool):
request_model = self.get_request_model()
request = request_model(**arguments)
# Check original_request size if provided
if request.original_request:
size_check = self.check_prompt_size(request.original_request)
# Check prompt size if provided
if request.prompt:
size_check = self.check_prompt_size(request.prompt)
if size_check:
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
@@ -154,9 +156,9 @@ class Precommit(BaseTool):
# Check for prompt.txt in files
prompt_content, updated_files = self.handle_prompt_file(request.files)
# If prompt.txt was found, use it as original_request
# If prompt.txt was found, use it as prompt
if prompt_content:
request.original_request = prompt_content
request.prompt = prompt_content
# Update request files list
if updated_files is not None:
@@ -330,7 +332,7 @@ class Precommit(BaseTool):
context_files_content = [file_content]
context_files_summary.append(f"✅ Included: {len(translated_files)} context files")
else:
context_files_summary.append("⚠️ No context files could be read or files too large")
context_files_summary.append("WARNING: No context files could be read or files too large")
total_tokens += context_tokens
@@ -338,8 +340,8 @@ class Precommit(BaseTool):
prompt_parts = []
# Add original request context if provided
if request.original_request:
prompt_parts.append(f"## Original Request\n\n{request.original_request}\n")
if request.prompt:
prompt_parts.append(f"## Original Request\n\n{request.prompt}\n")
# Add review parameters
prompt_parts.append("## Review Parameters\n")
@@ -366,7 +368,7 @@ class Precommit(BaseTool):
for idx, summary in enumerate(repo_summaries, 1):
prompt_parts.append(f"\n### Repository {idx}: {summary['path']}")
if "error" in summary:
prompt_parts.append(f"⚠️ Error: {summary['error']}")
prompt_parts.append(f"ERROR: {summary['error']}")
else:
prompt_parts.append(f"- Branch: {summary['branch']}")
if summary["ahead"] or summary["behind"]:
@@ -443,6 +445,6 @@ class Precommit(BaseTool):
return full_prompt
def format_response(self, response: str, request: PrecommitRequest) -> str:
def format_response(self, response: str, request: PrecommitRequest, model_info: Optional[dict] = None) -> str:
"""Format the response with commit guidance"""
return f"{response}\n\n---\n\n**Commit Status:** If no critical issues found, changes are ready for commit. Otherwise, address issues first and re-run review. Check with user before proceeding with any commit."

View File

@@ -17,7 +17,7 @@ from .models import ToolOutput
class ThinkDeepRequest(ToolRequest):
"""Request model for thinkdeep tool"""
current_analysis: str = Field(..., description="Claude's current thinking/analysis to extend")
prompt: str = Field(..., description="Your current thinking/analysis to extend and validate")
problem_context: Optional[str] = Field(None, description="Additional context about the problem or goal")
focus_areas: Optional[list[str]] = Field(
None,
@@ -48,19 +48,16 @@ class ThinkDeepTool(BaseTool):
)
def get_input_schema(self) -> dict[str, Any]:
from config import DEFAULT_MODEL
from config import IS_AUTO_MODE
return {
schema = {
"type": "object",
"properties": {
"current_analysis": {
"prompt": {
"type": "string",
"description": "Your current thinking/analysis to extend and validate",
},
"model": {
"type": "string",
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
},
"model": self.get_model_field_schema(),
"problem_context": {
"type": "string",
"description": "Additional context about the problem or goal",
@@ -96,9 +93,11 @@ class ThinkDeepTool(BaseTool):
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
},
},
"required": ["current_analysis"],
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
}
return schema
def get_system_prompt(self) -> str:
return THINKDEEP_PROMPT
@@ -120,8 +119,8 @@ class ThinkDeepTool(BaseTool):
request_model = self.get_request_model()
request = request_model(**arguments)
# Check current_analysis size
size_check = self.check_prompt_size(request.current_analysis)
# Check prompt size
size_check = self.check_prompt_size(request.prompt)
if size_check:
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
@@ -133,8 +132,8 @@ class ThinkDeepTool(BaseTool):
# Check for prompt.txt in files
prompt_content, updated_files = self.handle_prompt_file(request.files)
# Use prompt.txt content if available, otherwise use the current_analysis field
current_analysis = prompt_content if prompt_content else request.current_analysis
# Use prompt.txt content if available, otherwise use the prompt field
current_analysis = prompt_content if prompt_content else request.prompt
# Update request files list
if updated_files is not None:
@@ -190,21 +189,24 @@ Please provide deep analysis that extends Claude's thinking with:
return full_prompt
def format_response(self, response: str, request: ThinkDeepRequest) -> str:
def format_response(self, response: str, request: ThinkDeepRequest, model_info: Optional[dict] = None) -> str:
"""Format the response with clear attribution and critical thinking prompt"""
return f"""## Extended Analysis by Gemini
# Get the friendly model name
model_name = "your fellow developer"
if model_info and model_info.get("model_response"):
model_name = model_info["model_response"].friendly_name or "your fellow developer"
{response}
return f"""{response}
---
## Critical Evaluation Required
Claude, please critically evaluate Gemini's analysis by considering:
Claude, please critically evaluate {model_name}'s analysis by thinking hard about the following:
1. **Technical merit** - Which suggestions are valuable vs. have limitations?
2. **Constraints** - Fit with codebase patterns, performance, security, architecture
3. **Risks** - Hidden complexities, edge cases, potential failure modes
4. **Final recommendation** - Synthesize both perspectives, then think deeply further to explore additional considerations and arrive at the best technical solution
Remember: Use Gemini's insights to enhance, not replace, your analysis."""
Remember: Use {model_name}'s insights to enhance, not replace, your analysis."""

View File

@@ -1,5 +1,5 @@
"""
Utility functions for Gemini MCP Server
Utility functions for Zen MCP Server
"""
from .file_utils import CODE_EXTENSIONS, expand_paths, read_file_content, read_files

View File

@@ -68,20 +68,24 @@ class ConversationTurn(BaseModel):
the content and metadata needed for cross-tool continuation.
Attributes:
role: "user" (Claude) or "assistant" (Gemini)
role: "user" (Claude) or "assistant" (Gemini/O3/etc)
content: The actual message content/response
timestamp: ISO timestamp when this turn was created
follow_up_question: Optional follow-up question from Gemini to Claude
files: List of file paths referenced in this specific turn
tool_name: Which tool generated this turn (for cross-tool tracking)
model_provider: Provider used (e.g., "google", "openai")
model_name: Specific model used (e.g., "gemini-2.0-flash", "o3-mini")
model_metadata: Additional model-specific metadata (e.g., thinking mode, token usage)
"""
role: str # "user" or "assistant"
content: str
timestamp: str
follow_up_question: Optional[str] = None
files: Optional[list[str]] = None # Files referenced in this turn
tool_name: Optional[str] = None # Tool used for this turn
model_provider: Optional[str] = None # Model provider (google, openai, etc)
model_name: Optional[str] = None # Specific model used
model_metadata: Optional[dict[str, Any]] = None # Additional model info
class ThreadContext(BaseModel):
@@ -94,6 +98,7 @@ class ThreadContext(BaseModel):
Attributes:
thread_id: UUID identifying this conversation thread
parent_thread_id: UUID of parent thread (for conversation chains)
created_at: ISO timestamp when thread was created
last_updated_at: ISO timestamp of last modification
tool_name: Name of the tool that initiated this thread
@@ -102,6 +107,7 @@ class ThreadContext(BaseModel):
"""
thread_id: str
parent_thread_id: Optional[str] = None # Parent thread for conversation chains
created_at: str
last_updated_at: str
tool_name: str # Tool that created this thread (preserved for attribution)
@@ -131,7 +137,7 @@ def get_redis_client():
raise ValueError("redis package required. Install with: pip install redis")
def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
def create_thread(tool_name: str, initial_request: dict[str, Any], parent_thread_id: Optional[str] = None) -> str:
"""
Create new conversation thread and return thread ID
@@ -142,6 +148,7 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
Args:
tool_name: Name of the tool creating this thread (e.g., "analyze", "chat")
initial_request: Original request parameters (will be filtered for serialization)
parent_thread_id: Optional parent thread ID for conversation chains
Returns:
str: UUID thread identifier that can be used for continuation
@@ -150,6 +157,7 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
- Thread expires after 1 hour (3600 seconds)
- Non-serializable parameters are filtered out automatically
- Thread can be continued by any tool using the returned UUID
- A parent link creates a chain that can be traversed for conversation history
"""
thread_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
@@ -163,6 +171,7 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
context = ThreadContext(
thread_id=thread_id,
parent_thread_id=parent_thread_id, # Link to parent for conversation chains
created_at=now,
last_updated_at=now,
tool_name=tool_name, # Track which tool initiated this conversation
@@ -175,6 +184,8 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
key = f"thread:{thread_id}"
client.setex(key, 3600, context.model_dump_json())
logger.debug(f"[THREAD] Created new thread {thread_id} with parent {parent_thread_id}")
return thread_id
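A minimal in-memory sketch of the parent-linking behaviour above; plain dicts stand in for Redis and the ThreadContext model, and the 1-hour TTL is only noted in a comment:

import uuid
from datetime import datetime, timezone

THREADS = {}  # stand-in for Redis; the real code stores JSON via setex(key, 3600, ...)

def create_thread_stub(tool_name, initial_request, parent_thread_id=None):
    thread_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc).isoformat()
    THREADS[thread_id] = {
        "thread_id": thread_id,
        "parent_thread_id": parent_thread_id,  # link used later for chain traversal
        "created_at": now,
        "last_updated_at": now,
        "tool_name": tool_name,
        "turns": [],
        "initial_request": initial_request,
    }
    return thread_id

root = create_thread_stub("chat", {"prompt": "..."})
child = create_thread_stub("codereview", {"prompt": "..."}, parent_thread_id=root)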
@@ -218,37 +229,42 @@ def add_turn(
thread_id: str,
role: str,
content: str,
follow_up_question: Optional[str] = None,
files: Optional[list[str]] = None,
tool_name: Optional[str] = None,
model_provider: Optional[str] = None,
model_name: Optional[str] = None,
model_metadata: Optional[dict[str, Any]] = None,
) -> bool:
"""
Add turn to existing thread
Appends a new conversation turn to an existing thread. This is the core
function for building conversation history and enabling cross-tool
continuation. Each turn preserves the tool that generated it.
continuation. Each turn preserves the tool and model that generated it.
Args:
thread_id: UUID of the conversation thread
role: "user" (Claude) or "assistant" (Gemini)
role: "user" (Claude) or "assistant" (Gemini/O3/etc)
content: The actual message/response content
follow_up_question: Optional follow-up question from Gemini
files: Optional list of files referenced in this turn
tool_name: Name of the tool adding this turn (for attribution)
model_provider: Provider used (e.g., "google", "openai")
model_name: Specific model used (e.g., "gemini-2.0-flash", "o3-mini")
model_metadata: Additional model info (e.g., thinking mode, token usage)
Returns:
bool: True if turn was successfully added, False otherwise
Failure cases:
- Thread doesn't exist or expired
- Maximum turn limit reached (5 turns)
- Maximum turn limit reached
- Redis connection failure
Note:
- Refreshes thread TTL to 1 hour on successful update
- Turn limits prevent runaway conversations
- File references are preserved for cross-tool access
- Model information enables cross-provider conversations
"""
logger.debug(f"[FLOW] Adding {role} turn to {thread_id} ({tool_name})")
@@ -267,9 +283,11 @@ def add_turn(
role=role,
content=content,
timestamp=datetime.now(timezone.utc).isoformat(),
follow_up_question=follow_up_question,
files=files, # Preserved for cross-tool file context
tool_name=tool_name, # Track which tool generated this turn
model_provider=model_provider, # Track model provider
model_name=model_name, # Track specific model
model_metadata=model_metadata, # Additional model info
)
context.turns.append(turn)
@@ -286,6 +304,48 @@ def add_turn(
return False
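A usage sketch of the extended add_turn signature above; it assumes the Redis backend is running, and the usage numbers in model_metadata are invented for illustration:

from utils.conversation_memory import add_turn, create_thread

thread_id = create_thread("chat", {"prompt": "..."})
ok = add_turn(
    thread_id,
    "assistant",
    "Here is my analysis ...",
    files=["/abs/path/main.py"],
    tool_name="chat",
    model_provider="google",
    model_name="gemini-2.0-flash",
    model_metadata={"usage": {"input_tokens": 1200, "output_tokens": 300}},  # invented numbers
)
if not ok:
    print("thread expired, missing, or at its turn limit")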
def get_thread_chain(thread_id: str, max_depth: int = 20) -> list[ThreadContext]:
"""
Traverse the parent chain to get all threads in conversation sequence.
Retrieves the complete conversation chain by following parent_thread_id
links. Returns threads in chronological order (oldest first).
Args:
thread_id: Starting thread ID
max_depth: Maximum chain depth to prevent infinite loops
Returns:
list[ThreadContext]: All threads in chain, oldest first
"""
chain = []
current_id = thread_id
seen_ids = set()
# Build chain from current to oldest
while current_id and len(chain) < max_depth:
# Prevent circular references
if current_id in seen_ids:
logger.warning(f"[THREAD] Circular reference detected in thread chain at {current_id}")
break
seen_ids.add(current_id)
context = get_thread(current_id)
if not context:
logger.debug(f"[THREAD] Thread {current_id} not found in chain traversal")
break
chain.append(context)
current_id = context.parent_thread_id
# Reverse to get chronological order (oldest first)
chain.reverse()
logger.debug(f"[THREAD] Retrieved chain of {len(chain)} threads for {thread_id}")
return chain
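A small usage sketch of get_thread_chain, aggregating the same things the history builder needs (total turns and unique file paths); it assumes, as above, that the Redis-backed threads exist:

from utils.conversation_memory import get_thread_chain

def summarize_chain(thread_id: str) -> tuple[int, set[str]]:
    """Count turns and collect unique file paths across a whole conversation chain."""
    chain = get_thread_chain(thread_id)  # returned oldest thread first
    total_turns = sum(len(thread.turns) for thread in chain)
    files = {f for thread in chain for turn in thread.turns for f in (turn.files or [])}
    return total_turns, files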
def get_conversation_file_list(context: ThreadContext) -> list[str]:
"""
Get all unique files referenced across all turns in a conversation.
@@ -327,7 +387,7 @@ def get_conversation_file_list(context: ThreadContext) -> list[str]:
return unique_files
def build_conversation_history(context: ThreadContext, read_files_func=None) -> tuple[str, int]:
def build_conversation_history(context: ThreadContext, model_context=None, read_files_func=None) -> tuple[str, int]:
"""
Build formatted conversation history for tool prompts with embedded file contents.
@@ -336,8 +396,13 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
start, even if referenced in multiple turns, to prevent duplication and optimize
token usage.
If the thread has a parent chain, this function traverses the entire chain to
include the complete conversation history.
Args:
context: ThreadContext containing the complete conversation
model_context: ModelContext for token allocation (optional, uses DEFAULT_MODEL if not provided)
read_files_func: Optional function to read files (for testing)
Returns:
tuple[str, int]: (formatted_conversation_history, total_tokens_used)
@@ -355,18 +420,68 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
file contents from previous tools, enabling true cross-tool collaboration
while preventing duplicate file embeddings.
"""
if not context.turns:
# Get the complete thread chain
if context.parent_thread_id:
# This thread has a parent, get the full chain
chain = get_thread_chain(context.thread_id)
# Collect all turns from all threads in chain
all_turns = []
all_files_set = set()
total_turns = 0
for thread in chain:
all_turns.extend(thread.turns)
total_turns += len(thread.turns)
# Collect files from this thread
for turn in thread.turns:
if turn.files:
all_files_set.update(turn.files)
all_files = list(all_files_set)
logger.debug(f"[THREAD] Built history from {len(chain)} threads with {total_turns} total turns")
else:
# Single thread, no parent chain
all_turns = context.turns
total_turns = len(context.turns)
all_files = get_conversation_file_list(context)
if not all_turns:
return "", 0
# Get all unique files referenced in this conversation
all_files = get_conversation_file_list(context)
logger.debug(f"[FILES] Found {len(all_files)} unique files in conversation history")
# Get model-specific token allocation early (needed for both files and turns)
if model_context is None:
from config import DEFAULT_MODEL, IS_AUTO_MODE
from utils.model_context import ModelContext
# In auto mode, use an intelligent fallback model for token calculations
# since "auto" is not a real model with a provider
model_name = DEFAULT_MODEL
if IS_AUTO_MODE and model_name.lower() == "auto":
# Use intelligent fallback based on available API keys
from providers.registry import ModelProviderRegistry
model_name = ModelProviderRegistry.get_preferred_fallback_model()
model_context = ModelContext(model_name)
token_allocation = model_context.calculate_token_allocation()
max_file_tokens = token_allocation.file_tokens
max_history_tokens = token_allocation.history_tokens
logger.debug(f"[HISTORY] Using model-specific limits for {model_context.model_name}:")
logger.debug(f"[HISTORY] Max file tokens: {max_file_tokens:,}")
logger.debug(f"[HISTORY] Max history tokens: {max_history_tokens:,}")
history_parts = [
"=== CONVERSATION HISTORY ===",
"=== CONVERSATION HISTORY (CONTINUATION) ===",
f"Thread: {context.thread_id}",
f"Tool: {context.tool_name}", # Original tool that started the conversation
f"Turn {len(context.turns)}/{MAX_CONVERSATION_TURNS}",
f"Turn {total_turns}/{MAX_CONVERSATION_TURNS}",
"You are continuing this conversation thread from where it left off.",
"",
]
@@ -382,9 +497,6 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
]
)
# Import required functions
from config import MAX_CONTENT_TOKENS
if read_files_func is None:
from utils.file_utils import read_file_content
@@ -402,12 +514,12 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
if formatted_content:
# read_file_content already returns formatted content, use it directly
# Check if adding this file would exceed the limit
if total_tokens + content_tokens <= MAX_CONTENT_TOKENS:
if total_tokens + content_tokens <= max_file_tokens:
file_contents.append(formatted_content)
total_tokens += content_tokens
files_included += 1
logger.debug(
f"📄 File embedded in conversation history: {file_path} ({content_tokens:,} tokens)"
f"File embedded in conversation history: {file_path} ({content_tokens:,} tokens)"
)
logger.debug(
f"[FILES] Successfully embedded {file_path} - {content_tokens:,} tokens (total: {total_tokens:,})"
@@ -415,7 +527,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
else:
files_truncated += 1
logger.debug(
f"📄 File truncated due to token limit: {file_path} ({content_tokens:,} tokens, would exceed {MAX_CONTENT_TOKENS:,} limit)"
f"File truncated due to token limit: {file_path} ({content_tokens:,} tokens, would exceed {max_file_tokens:,} limit)"
)
logger.debug(
f"[FILES] File {file_path} would exceed token limit - skipping (would be {total_tokens + content_tokens:,} tokens)"
@@ -423,12 +535,12 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
# Stop processing more files
break
else:
logger.debug(f"📄 File skipped (empty content): {file_path}")
logger.debug(f"File skipped (empty content): {file_path}")
logger.debug(f"[FILES] File {file_path} has empty content - skipping")
except Exception as e:
# Skip files that can't be read but log the failure
logger.warning(
f"📄 Failed to embed file in conversation history: {file_path} - {type(e).__name__}: {e}"
f"Failed to embed file in conversation history: {file_path} - {type(e).__name__}: {e}"
)
logger.debug(f"[FILES] Failed to read file {file_path} - {type(e).__name__}: {e}")
continue
@@ -441,7 +553,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
)
history_parts.append(files_content)
logger.debug(
f"📄 Conversation history file embedding complete: {files_included} files embedded, {files_truncated} truncated, {total_tokens:,} total tokens"
f"Conversation history file embedding complete: {files_included} files embedded, {files_truncated} truncated, {total_tokens:,} total tokens"
)
logger.debug(
f"[FILES] File embedding summary - {files_included} embedded, {files_truncated} truncated, {total_tokens:,} tokens total"
@@ -449,7 +561,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
else:
history_parts.append("(No accessible files found)")
logger.debug(
f"📄 Conversation history file embedding: no accessible files found from {len(all_files)} requested"
f"Conversation history file embedding: no accessible files found from {len(all_files)} requested"
)
logger.debug(f"[FILES] No accessible files found from {len(all_files)} requested files")
else:
@@ -464,7 +576,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
history_parts.append(files_content)
else:
# Handle token limit exceeded for conversation files
error_message = f"ERROR: The total size of files referenced in this conversation has exceeded the context limit and cannot be displayed.\nEstimated tokens: {estimated_tokens}, but limit is {MAX_CONTENT_TOKENS}."
error_message = f"ERROR: The total size of files referenced in this conversation has exceeded the context limit and cannot be displayed.\nEstimated tokens: {estimated_tokens}, but limit is {max_file_tokens}."
history_parts.append(error_message)
else:
history_parts.append("(No accessible files found)")
@@ -479,31 +591,84 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
history_parts.append("Previous conversation turns:")
for i, turn in enumerate(context.turns, 1):
# Build conversation turns bottom-up (most recent first) but present chronologically
# This ensures we include as many recent turns as possible within the token budget
turn_entries = [] # Will store (index, formatted_turn_content) for chronological ordering
total_turn_tokens = 0
file_embedding_tokens = sum(model_context.estimate_tokens(part) for part in history_parts)
# Process turns in reverse order (most recent first) to prioritize recent context
for idx in range(len(all_turns) - 1, -1, -1):
turn = all_turns[idx]
turn_num = idx + 1
role_label = "Claude" if turn.role == "user" else "Gemini"
# Build the complete turn content
turn_parts = []
# Add turn header with tool attribution for cross-tool tracking
turn_header = f"\n--- Turn {i} ({role_label}"
turn_header = f"\n--- Turn {turn_num} ({role_label}"
if turn.tool_name:
turn_header += f" using {turn.tool_name}"
# Add model info if available
if turn.model_provider and turn.model_name:
turn_header += f" via {turn.model_provider}/{turn.model_name}"
turn_header += ") ---"
history_parts.append(turn_header)
turn_parts.append(turn_header)
# Add files context if present - but just reference which files were used
# (the actual contents are already embedded above)
if turn.files:
history_parts.append(f"📁 Files used in this turn: {', '.join(turn.files)}")
history_parts.append("") # Empty line for readability
turn_parts.append(f"Files used in this turn: {', '.join(turn.files)}")
turn_parts.append("") # Empty line for readability
# Add the actual content
history_parts.append(turn.content)
turn_parts.append(turn.content)
# Add follow-up question if present
if turn.follow_up_question:
history_parts.append(f"\n[Gemini's Follow-up: {turn.follow_up_question}]")
# Calculate tokens for this turn
turn_content = "\n".join(turn_parts)
turn_tokens = model_context.estimate_tokens(turn_content)
# Check if adding this turn would exceed history budget
if file_embedding_tokens + total_turn_tokens + turn_tokens > max_history_tokens:
# Stop adding turns - we've reached the limit
logger.debug(f"[HISTORY] Stopping at turn {turn_num} - would exceed history budget")
logger.debug(f"[HISTORY] File tokens: {file_embedding_tokens:,}")
logger.debug(f"[HISTORY] Turn tokens so far: {total_turn_tokens:,}")
logger.debug(f"[HISTORY] This turn: {turn_tokens:,}")
logger.debug(f"[HISTORY] Would total: {file_embedding_tokens + total_turn_tokens + turn_tokens:,}")
logger.debug(f"[HISTORY] Budget: {max_history_tokens:,}")
break
# Add this turn to our list (we'll reverse it later for chronological order)
turn_entries.append((idx, turn_content))
total_turn_tokens += turn_tokens
# Reverse to get chronological order (oldest first)
turn_entries.reverse()
# Add the turns in chronological order
for _, turn_content in turn_entries:
history_parts.append(turn_content)
# Log what we included
included_turns = len(turn_entries)
total_turns = len(all_turns)
if included_turns < total_turns:
logger.info(f"[HISTORY] Included {included_turns}/{total_turns} turns due to token limit")
history_parts.append(f"\n[Note: Showing {included_turns} most recent turns out of {total_turns} total]")
history_parts.extend(
["", "=== END CONVERSATION HISTORY ===", "", "Continue this conversation by building on the previous context."]
[
"",
"=== END CONVERSATION HISTORY ===",
"",
"IMPORTANT: You are continuing an existing conversation thread. Build upon the previous exchanges shown above,",
"reference earlier points, and maintain consistency with what has been discussed.",
f"This is turn {len(all_turns) + 1} of the conversation - use the conversation history above to provide a coherent continuation.",
]
)
# Calculate total tokens for the complete conversation history
@@ -513,8 +678,8 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
total_conversation_tokens = estimate_tokens(complete_history)
# Summary log of what was built
user_turns = len([t for t in context.turns if t.role == "user"])
assistant_turns = len([t for t in context.turns if t.role == "assistant"])
user_turns = len([t for t in all_turns if t.role == "user"])
assistant_turns = len([t for t in all_turns if t.role == "assistant"])
logger.debug(
f"[FLOW] Built conversation history: {user_turns} user + {assistant_turns} assistant turns, {len(all_files)} files, {total_conversation_tokens:,} tokens"
)

131
utils/model_context.py Normal file
View File

@@ -0,0 +1,131 @@
"""
Model context management for dynamic token allocation.
This module provides a clean abstraction for model-specific token management,
ensuring that token limits are properly calculated based on the current model
being used, not global constants.
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional
from config import DEFAULT_MODEL
from providers import ModelCapabilities, ModelProviderRegistry
logger = logging.getLogger(__name__)
@dataclass
class TokenAllocation:
"""Token allocation strategy for a model."""
total_tokens: int
content_tokens: int
response_tokens: int
file_tokens: int
history_tokens: int
@property
def available_for_prompt(self) -> int:
"""Tokens available for the actual prompt after allocations."""
return self.content_tokens - self.file_tokens - self.history_tokens
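# Worked example (illustrative): with content_tokens=800_000, file_tokens=320_000 and
# history_tokens=320_000, available_for_prompt is 160_000 tokens.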
class ModelContext:
"""
Encapsulates model-specific information and token calculations.
This class provides a single source of truth for all model-related
token calculations, ensuring consistency across the system.
"""
def __init__(self, model_name: str):
self.model_name = model_name
self._provider = None
self._capabilities = None
self._token_allocation = None
@property
def provider(self):
"""Get the model provider lazily."""
if self._provider is None:
self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
if not self._provider:
raise ValueError(f"No provider found for model: {self.model_name}")
return self._provider
@property
def capabilities(self) -> ModelCapabilities:
"""Get model capabilities lazily."""
if self._capabilities is None:
self._capabilities = self.provider.get_capabilities(self.model_name)
return self._capabilities
def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation:
"""
Calculate token allocation based on model capacity.
Args:
reserved_for_response: Override response token reservation
Returns:
TokenAllocation with calculated budgets
"""
total_tokens = self.capabilities.max_tokens
# Dynamic allocation based on model capacity
if total_tokens < 300_000:
# Smaller context models (O3): Conservative allocation
content_ratio = 0.6 # 60% for content
response_ratio = 0.4 # 40% for response
file_ratio = 0.3 # 30% of content for files
history_ratio = 0.5 # 50% of content for history
else:
# Larger context models (Gemini): More generous allocation
content_ratio = 0.8 # 80% for content
response_ratio = 0.2 # 20% for response
file_ratio = 0.4 # 40% of content for files
history_ratio = 0.4 # 40% of content for history
# Calculate allocations
content_tokens = int(total_tokens * content_ratio)
response_tokens = reserved_for_response or int(total_tokens * response_ratio)
# Sub-allocations within content budget
file_tokens = int(content_tokens * file_ratio)
history_tokens = int(content_tokens * history_ratio)
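# Worked example (illustrative): a 1,000,000-token model gets content=800,000,
# response=200,000, files=320,000, history=320,000; a 200,000-token model gets
# content=120,000, response=80,000, files=36,000, history=60,000.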
allocation = TokenAllocation(
total_tokens=total_tokens,
content_tokens=content_tokens,
response_tokens=response_tokens,
file_tokens=file_tokens,
history_tokens=history_tokens,
)
logger.debug(f"Token allocation for {self.model_name}:")
logger.debug(f" Total: {allocation.total_tokens:,}")
logger.debug(f" Content: {allocation.content_tokens:,} ({content_ratio:.0%})")
logger.debug(f" Response: {allocation.response_tokens:,} ({response_ratio:.0%})")
logger.debug(f" Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)")
logger.debug(f" History: {allocation.history_tokens:,} ({history_ratio:.0%} of content)")
return allocation
def estimate_tokens(self, text: str) -> int:
"""
Estimate token count for text.
For now, uses simple estimation. Can be enhanced with model-specific
tokenizers (tiktoken for OpenAI, etc.) in the future.
"""
# TODO: Integrate model-specific tokenizers
# For now, use conservative estimation
return len(text) // 3 # Conservative estimate
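# Example (illustrative): a 3,000-character string estimates to ~1,000 tokens, which
# overestimates relative to typical ~4-chars-per-token tokenizers and so keeps budgets safe.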
@classmethod
def from_arguments(cls, arguments: dict[str, Any]) -> "ModelContext":
"""Create ModelContext from tool arguments."""
model_name = arguments.get("model") or DEFAULT_MODEL
return cls(model_name)
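# Example (illustrative): ModelContext.from_arguments({"model": "flash"}) builds a context
# for the requested model, while an empty dict falls back to DEFAULT_MODEL.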

View File

@@ -1,5 +1,5 @@
"""
Gemini MCP Server - Entry point for backward compatibility
Zen MCP Server - Entry point for backward compatibility
This file exists to maintain compatibility with existing configurations.
The main implementation is now in server.py
"""