WIP major refactor and features
This commit is contained in:
13
.env.example
13
.env.example
@@ -1,14 +1,19 @@
|
|||||||
# Gemini MCP Server Environment Configuration
|
# Gemini MCP Server Environment Configuration
|
||||||
# Copy this file to .env and fill in your values
|
# Copy this file to .env and fill in your values
|
||||||
|
|
||||||
# Required: Google Gemini API Key
|
# API Keys - At least one is required
|
||||||
# Get your API key from: https://makersuite.google.com/app/apikey
|
# Get your Gemini API key from: https://makersuite.google.com/app/apikey
|
||||||
GEMINI_API_KEY=your_gemini_api_key_here
|
GEMINI_API_KEY=your_gemini_api_key_here
|
||||||
|
|
||||||
|
# Get your OpenAI API key from: https://platform.openai.com/api-keys
|
||||||
|
OPENAI_API_KEY=your_openai_api_key_here
|
||||||
|
|
||||||
# Optional: Default model to use
|
# Optional: Default model to use
|
||||||
|
# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'gpt-4o'
|
||||||
# Full names: 'gemini-2.5-pro-preview-06-05' or 'gemini-2.0-flash-exp'
|
# Full names: 'gemini-2.5-pro-preview-06-05' or 'gemini-2.0-flash-exp'
|
||||||
# Defaults to gemini-2.5-pro-preview-06-05 if not specified
|
# When set to 'auto', Claude will select the best model for each task
|
||||||
DEFAULT_MODEL=gemini-2.5-pro-preview-06-05
|
# Defaults to 'auto' if not specified
|
||||||
|
DEFAULT_MODEL=auto
|
||||||
|
|
||||||
# Optional: Default thinking mode for ThinkDeep tool
|
# Optional: Default thinking mode for ThinkDeep tool
|
||||||
# NOTE: Only applies to models that support extended thinking (e.g., Gemini 2.5 Pro)
|
# NOTE: Only applies to models that support extended thinking (e.g., Gemini 2.5 Pro)
|
||||||
|
|||||||
40
FIX_SUMMARY.md
Normal file
40
FIX_SUMMARY.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Fix for Conversation History Bug in Continuation Flow
|
||||||
|
|
||||||
|
## Problem
|
||||||
|
When using `continuation_id` to continue a conversation, the conversation history (with embedded files) was being lost for tools that don't have a `prompt` field. Only new file content was being passed to the tool, resulting in minimal content (e.g., 322 chars for just a NOTE about files already in history).
|
||||||
|
|
||||||
|
## Root Cause
|
||||||
|
1. `reconstruct_thread_context()` builds conversation history and stores it in `arguments["prompt"]`
|
||||||
|
2. Different tools use different field names for user input:
|
||||||
|
- `chat` → `prompt`
|
||||||
|
- `analyze` → `question`
|
||||||
|
- `debug` → `error_description`
|
||||||
|
- `codereview` → `context`
|
||||||
|
- `thinkdeep` → `current_analysis`
|
||||||
|
- `precommit` → `original_request`
|
||||||
|
3. The enhanced prompt with conversation history was being placed in the wrong field
|
||||||
|
4. Tools would only see their new input, not the conversation history
|
||||||
|
|
||||||
|
## Solution
|
||||||
|
Modified `reconstruct_thread_context()` in `server.py` to:
|
||||||
|
1. Create a mapping of tool names to their primary input fields
|
||||||
|
2. Extract the user's new input from the correct field based on the tool
|
||||||
|
3. Store the enhanced prompt (with conversation history) back into the correct field
|
||||||
|
|
||||||
|
## Changes Made
|
||||||
|
1. **server.py**:
|
||||||
|
- Added `prompt_field_mapping` to map tools to their input fields
|
||||||
|
- Modified to extract user input from the correct field
|
||||||
|
- Modified to store enhanced prompt in the correct field
|
||||||
|
|
||||||
|
2. **tests/test_conversation_field_mapping.py**:
|
||||||
|
- Added comprehensive tests to verify the fix works for all tools
|
||||||
|
- Tests ensure conversation history is properly mapped to each tool's field
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
All existing tests pass, including:
|
||||||
|
- `test_conversation_memory.py` (18 tests)
|
||||||
|
- `test_cross_tool_continuation.py` (4 tests)
|
||||||
|
- New `test_conversation_field_mapping.py` (2 tests)
|
||||||
|
|
||||||
|
The fix ensures that when continuing conversations, tools receive the full conversation history with embedded files, not just new content.
|
||||||
245
README.md
245
README.md
@@ -1,29 +1,49 @@
|
|||||||
# Claude Code + Gemini: Working Together as One
|
# Claude Code + Multi-Model AI: Your Ultimate Development Team
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/a67099df-9387-4720-9b41-c986243ac11b
|
https://github.com/user-attachments/assets/a67099df-9387-4720-9b41-c986243ac11b
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<b>🤖 Claude + Gemini = Your Ultimate AI Development Team</b>
|
<b>🤖 Claude + Gemini / O3 / GPT-4o = Your Ultimate AI Development Team</b>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
The ultimate development partner for Claude - a Model Context Protocol server that gives Claude access to Google's Gemini models (2.5 Pro for extended thinking, 2.0 Flash for speed) for code analysis, problem-solving, and collaborative development. **Automatically reads files and directories, passing their contents to Gemini for analysis within its 1M token context.**
|
The ultimate development partner for Claude - a Model Context Protocol server that gives Claude access to multiple AI models for enhanced code analysis, problem-solving, and collaborative development.
|
||||||
|
|
||||||
**Features true AI orchestration with conversations that continue across tasks** - Give Claude a complex task and ask it to collaborate with Gemini.
|
**🎯 Auto Mode (NEW):** Set `DEFAULT_MODEL=auto` and Claude will intelligently select the best model for each task:
|
||||||
Claude stays in control, performs the actual work, but gets a second perspective from Gemini. Claude will talk to Gemini, work on implementation, then automatically resume the
|
- **Complex architecture review?** → Claude picks Gemini Pro with extended thinking
|
||||||
conversation with Gemini while maintaining the full thread.
|
- **Quick code formatting?** → Claude picks Gemini Flash for speed
|
||||||
Claude can switch between different Gemini tools ([`thinkdeep`](#2-thinkdeep---extended-reasoning-partner) → [`chat`](#1-chat---general-development-chat--collaborative-thinking) → [`precommit`](#4-precommit---pre-commit-validation) → [`codereview`](#3-codereview---professional-code-review)) and the conversation context carries forward seamlessly.
|
- **Logical debugging?** → Claude picks O3 for reasoning
|
||||||
For example, in the video above, Claude was asked to debate SwiftUI vs UIKit with Gemini, resulting in a back-and-forth discussion rather than a simple one-shot query and response.
|
- **Or specify your preference:** "Use flash to quickly analyze this" or "Use o3 for debugging"
|
||||||
|
|
||||||
|
**📚 Supported Models:**
|
||||||
|
- **Google Gemini**: 2.5 Pro (extended thinking, 1M tokens) & 2.0 Flash (ultra-fast, 1M tokens)
|
||||||
|
- **OpenAI**: O3 (strong reasoning, 200K tokens), O3-mini (faster variant), GPT-4o (128K tokens)
|
||||||
|
- **More providers coming soon!**
|
||||||
|
|
||||||
|
**Features true AI orchestration with conversations that continue across tasks** - Give Claude a complex task and let it orchestrate between models automatically. Claude stays in control, performs the actual work, but gets perspectives from the best AI for each subtask. Claude can switch between different tools AND models mid-conversation, with context carrying forward seamlessly.
|
||||||
|
|
||||||
|
**Example Workflow:**
|
||||||
|
1. Claude uses Gemini Pro to deeply analyze your architecture
|
||||||
|
2. Switches to O3 for logical debugging of a specific issue
|
||||||
|
3. Uses Flash for quick code formatting
|
||||||
|
4. Returns to Pro for security review
|
||||||
|
|
||||||
|
All within a single conversation thread!
|
||||||
|
|
||||||
**Think of it as Claude Code _for_ Claude Code.**
|
**Think of it as Claude Code _for_ Claude Code.**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
> ⚠️ **Active Development Notice**
|
> 🚀 **Multi-Provider Support with Auto Mode!**
|
||||||
> This project is under rapid development with frequent commits and changes over the past few days.
|
> Claude automatically selects the best model for each task when using `DEFAULT_MODEL=auto`:
|
||||||
> The goal is to expand support beyond Gemini to include additional AI models and providers.
|
> - **Gemini Pro**: Extended thinking (up to 32K tokens), best for complex problems
|
||||||
> **Watch this space** for new capabilities and potentially breaking changes in between updates!
|
> - **Gemini Flash**: Ultra-fast responses, best for quick tasks
|
||||||
|
> - **O3**: Strong reasoning, best for logical problems and debugging
|
||||||
|
> - **O3-mini**: Balanced performance, good for moderate complexity
|
||||||
|
> - **GPT-4o**: General-purpose, good for explanations and chat
|
||||||
|
>
|
||||||
|
> Or manually specify: "Use pro for deep analysis" or "Use o3 to debug this"
|
||||||
|
|
||||||
## Quick Navigation
|
## Quick Navigation
|
||||||
|
|
||||||
@@ -58,18 +78,20 @@ For example, in the video above, Claude was asked to debate SwiftUI vs UIKit wit
|
|||||||
## Why This Server?
|
## Why This Server?
|
||||||
|
|
||||||
Claude is brilliant, but sometimes you need:
|
Claude is brilliant, but sometimes you need:
|
||||||
|
- **Multiple AI perspectives** - Let Claude orchestrate between different models to get the best analysis
|
||||||
|
- **Automatic model selection** - Claude picks the right model for each task (or you can specify)
|
||||||
- **A senior developer partner** to validate and extend ideas ([`chat`](#1-chat---general-development-chat--collaborative-thinking))
|
- **A senior developer partner** to validate and extend ideas ([`chat`](#1-chat---general-development-chat--collaborative-thinking))
|
||||||
- **A second opinion** on complex architectural decisions - augment Claude's extended thinking with Gemini's perspective ([`thinkdeep`](#2-thinkdeep---extended-reasoning-partner))
|
- **A second opinion** on complex architectural decisions - augment Claude's thinking with perspectives from Gemini Pro, O3, or others ([`thinkdeep`](#2-thinkdeep---extended-reasoning-partner))
|
||||||
- **Professional code reviews** with actionable feedback across entire repositories ([`codereview`](#3-codereview---professional-code-review))
|
- **Professional code reviews** with actionable feedback across entire repositories ([`codereview`](#3-codereview---professional-code-review))
|
||||||
- **Pre-commit validation** with deep analysis that finds edge cases, validates your implementation against original requirements, and catches subtle bugs Claude might miss ([`precommit`](#4-precommit---pre-commit-validation))
|
- **Pre-commit validation** with deep analysis using the best model for the job ([`precommit`](#4-precommit---pre-commit-validation))
|
||||||
- **Expert debugging** for tricky issues with full system context ([`debug`](#5-debug---expert-debugging-assistant))
|
- **Expert debugging** - O3 for logical issues, Gemini for architectural problems ([`debug`](#5-debug---expert-debugging-assistant))
|
||||||
- **Massive context window** (1M tokens) - Gemini 2.5 Pro can analyze entire codebases, read hundreds of files at once, and provide comprehensive insights ([`analyze`](#6-analyze---smart-file-analysis))
|
- **Massive context windows** - Gemini (1M tokens), O3 (200K tokens), GPT-4o (128K tokens)
|
||||||
- **Deep code analysis** across massive codebases that exceed Claude's context limits ([`analyze`](#6-analyze---smart-file-analysis))
|
- **Model-specific strengths** - Extended thinking with Gemini Pro, fast iteration with Flash, strong reasoning with O3
|
||||||
- **Dynamic collaboration** - Gemini can request additional context from Claude mid-analysis for more thorough insights
|
- **Dynamic collaboration** - Models can request additional context from Claude mid-analysis
|
||||||
- **Smart file handling** - Automatically expands directories, filters irrelevant files, and manages token limits when analyzing `"main.py, src/, tests/"`
|
- **Smart file handling** - Automatically expands directories, manages token limits based on model capacity
|
||||||
- **[Bypass MCP's token limits](#working-with-large-prompts)** - Work around MCP's 25K combined token limit by automatically handling large prompts as files, preserving the full capacity for responses
|
- **[Bypass MCP's token limits](#working-with-large-prompts)** - Work around MCP's 25K limit automatically
|
||||||
|
|
||||||
This server makes Gemini your development sidekick, handling what Claude can't or extending what Claude starts.
|
This server orchestrates multiple AI models as your development team, with Claude automatically selecting the best model for each task or allowing you to choose specific models for different strengths.
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<img src="https://github.com/user-attachments/assets/0f3c8e2d-a236-4068-a80e-46f37b0c9d35" width="600">
|
<img src="https://github.com/user-attachments/assets/0f3c8e2d-a236-4068-a80e-46f37b0c9d35" width="600">
|
||||||
@@ -93,8 +115,9 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
|
|||||||
- Git
|
- Git
|
||||||
- **Windows users**: WSL2 is required for Claude Code CLI
|
- **Windows users**: WSL2 is required for Claude Code CLI
|
||||||
|
|
||||||
### 1. Get a Gemini API Key
|
### 1. Get API Keys (at least one required)
|
||||||
Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models.
|
- **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models.
|
||||||
|
- **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) to get an API key for O3 model access.
|
||||||
|
|
||||||
### 2. Clone and Set Up
|
### 2. Clone and Set Up
|
||||||
|
|
||||||
@@ -109,22 +132,25 @@ cd gemini-mcp-server
|
|||||||
|
|
||||||
**What this does:**
|
**What this does:**
|
||||||
- **Builds Docker images** with all dependencies (including Redis for conversation threading)
|
- **Builds Docker images** with all dependencies (including Redis for conversation threading)
|
||||||
- **Creates .env file** (automatically uses `$GEMINI_API_KEY` if set in environment)
|
- **Creates .env file** (automatically uses `$GEMINI_API_KEY` and `$OPENAI_API_KEY` if set in environment)
|
||||||
- **Starts Redis service** for AI-to-AI conversation memory
|
- **Starts Redis service** for AI-to-AI conversation memory
|
||||||
- **Starts MCP server** ready to connect
|
- **Starts MCP server** with providers based on available API keys
|
||||||
- **Shows exact Claude Desktop configuration** to copy
|
- **Shows exact Claude Desktop configuration** to copy
|
||||||
- **Multi-turn AI conversations** - Gemini can ask follow-up questions that persist across requests
|
- **Multi-turn AI conversations** - Models can ask follow-up questions that persist across requests
|
||||||
|
|
||||||
### 3. Add Your API Key
|
### 3. Add Your API Keys
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Edit .env to add your Gemini API key (if not already set in environment)
|
# Edit .env to add your API keys (if not already set in environment)
|
||||||
nano .env
|
nano .env
|
||||||
|
|
||||||
# The file will contain:
|
# The file will contain:
|
||||||
# GEMINI_API_KEY=your-gemini-api-key-here
|
# GEMINI_API_KEY=your-gemini-api-key-here # For Gemini models
|
||||||
|
# OPENAI_API_KEY=your-openai-api-key-here # For O3 model
|
||||||
# REDIS_URL=redis://redis:6379/0 (automatically configured)
|
# REDIS_URL=redis://redis:6379/0 (automatically configured)
|
||||||
# WORKSPACE_ROOT=/workspace (automatically configured)
|
# WORKSPACE_ROOT=/workspace (automatically configured)
|
||||||
|
|
||||||
|
# Note: At least one API key is required (Gemini or OpenAI)
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4. Configure Claude Desktop
|
### 4. Configure Claude Desktop
|
||||||
@@ -184,17 +210,17 @@ Completely quit and restart Claude Desktop for the changes to take effect.
|
|||||||
### 6. Start Using It!
|
### 6. Start Using It!
|
||||||
|
|
||||||
Just ask Claude naturally:
|
Just ask Claude naturally:
|
||||||
- "Use gemini to think deeper about this architecture design" → `thinkdeep`
|
- "Think deeper about this architecture design" → Claude picks best model + `thinkdeep`
|
||||||
- "Get gemini to review this code for security issues" → `codereview`
|
- "Review this code for security issues" → Claude might pick Gemini Pro + `codereview`
|
||||||
- "Get gemini to debug why this test is failing" → `debug`
|
- "Debug why this test is failing" → Claude might pick O3 + `debug`
|
||||||
- "Use gemini to analyze these files to understand the data flow" → `analyze`
|
- "Analyze these files to understand the data flow" → Claude picks appropriate model + `analyze`
|
||||||
- "Brainstorm with gemini about scaling strategies" → `chat`
|
- "Use flash to quickly format this code" → Uses Gemini Flash specifically
|
||||||
- "Share my implementation plan with gemini for feedback" → `chat`
|
- "Get o3 to debug this logic error" → Uses O3 specifically
|
||||||
- "Get gemini's opinion on my authentication design" → `chat`
|
- "Brainstorm scaling strategies with pro" → Uses Gemini Pro specifically
|
||||||
|
|
||||||
## AI-to-AI Conversation Threading
|
## AI-to-AI Conversation Threading
|
||||||
|
|
||||||
This server enables **true AI collaboration** between Claude and Gemini, where they can coordinate and question each other's approaches:
|
This server enables **true AI collaboration** between Claude and multiple AI models (Gemini, O3, GPT-4o), where they can coordinate and question each other's approaches:
|
||||||
|
|
||||||
**How it works:**
|
**How it works:**
|
||||||
- **Gemini can ask Claude follow-up questions** to clarify requirements or gather more context
|
- **Gemini can ask Claude follow-up questions** to clarify requirements or gather more context
|
||||||
@@ -206,12 +232,14 @@ This server enables **true AI collaboration** between Claude and Gemini, where t
|
|||||||
- Each conversation maintains full context while only sending incremental updates
|
- Each conversation maintains full context while only sending incremental updates
|
||||||
- Conversations are automatically managed with Redis for persistence
|
- Conversations are automatically managed with Redis for persistence
|
||||||
|
|
||||||
**Example of AI-to-AI coordination:**
|
**Example of Multi-Model AI Coordination:**
|
||||||
1. You: "Chat with gemini to debate SwiftUI vs UIKit - which is better for iOS development?"
|
1. You: "Debate SwiftUI vs UIKit - which is better for iOS development?"
|
||||||
2. Claude: "I'll present SwiftUI's case: Modern declarative syntax, automatic state management, cross-platform compatibility, and Apple's clear future direction make SwiftUI the superior choice for new projects."
|
2. Claude (auto mode): "I'll orchestrate a debate between different models for diverse perspectives."
|
||||||
3. Gemini: "I disagree. UIKit remains superior for production apps: mature ecosystem, granular control, extensive third-party support, proven performance in complex UIs, and better debugging tools. SwiftUI still has too many limitations."
|
3. Gemini Pro: "From an architectural standpoint, SwiftUI's declarative paradigm and state management make it superior for maintainable, modern apps."
|
||||||
4. Claude: "While UIKit has maturity advantages, SwiftUI's rapid evolution, simplified development workflow, and Apple's investment make it the strategic choice. The learning curve pays off with faster development and maintainable code."
|
4. O3: "Logically analyzing the trade-offs: UIKit offers 15+ years of stability, complete control, and proven scalability. SwiftUI has <5 years maturity with ongoing breaking changes."
|
||||||
5. **Final recommendation**: After this AI debate, Claude concludes: "Based on our discussion, SwiftUI is recommended for new projects despite Gemini's valid UIKit points."
|
5. Claude: "Let me get Flash's quick take on developer experience..."
|
||||||
|
6. Gemini Flash: "SwiftUI = faster development, less code, better previews. UIKit = more control, better debugging, stable APIs."
|
||||||
|
7. **Claude's synthesis**: "Based on the multi-model analysis: Use SwiftUI for new projects prioritizing development speed, UIKit for apps requiring fine control or supporting older iOS versions."
|
||||||
|
|
||||||
**Asynchronous workflow example:**
|
**Asynchronous workflow example:**
|
||||||
- Claude can work independently between exchanges (analyzing code, implementing fixes, gathering data)
|
- Claude can work independently between exchanges (analyzing code, implementing fixes, gathering data)
|
||||||
@@ -231,32 +259,45 @@ This server enables **true AI collaboration** between Claude and Gemini, where t
|
|||||||
- Up to 5 exchanges per conversation with 1-hour expiry
|
- Up to 5 exchanges per conversation with 1-hour expiry
|
||||||
- Thread-safe with Redis persistence across all tools
|
- Thread-safe with Redis persistence across all tools
|
||||||
|
|
||||||
**Cross-tool continuation example:**
|
**Cross-tool & Cross-Model Continuation Example:**
|
||||||
```
|
```
|
||||||
1. Claude: "Use gemini to analyze /src/auth.py for security issues"
|
1. Claude: "Analyze /src/auth.py for security issues"
|
||||||
→ Gemini analyzes and finds vulnerabilities, provides continuation_id
|
→ Auto mode: Claude picks Gemini Pro for deep security analysis
|
||||||
|
→ Pro analyzes and finds vulnerabilities, provides continuation_id
|
||||||
|
|
||||||
2. Claude: "Use gemini to review the authentication logic thoroughly"
|
2. Claude: "Review the authentication logic thoroughly"
|
||||||
→ Uses same continuation_id, Gemini sees previous analysis and files
|
→ Uses same continuation_id, but Claude picks O3 for logical analysis
|
||||||
→ Provides detailed code review building on previous findings
|
→ O3 sees previous Pro analysis and provides logic-focused review
|
||||||
|
|
||||||
3. Claude: "Use gemini to help debug the auth test failures"
|
3. Claude: "Debug the auth test failures"
|
||||||
→ Same continuation_id, full context from analysis + review
|
→ Same continuation_id, Claude keeps O3 for debugging
|
||||||
→ Gemini provides targeted debugging with complete understanding
|
→ O3 provides targeted debugging with full context from both previous analyses
|
||||||
|
|
||||||
|
4. Claude: "Quick style check before committing"
|
||||||
|
→ Same thread, but Claude switches to Flash for speed
|
||||||
|
→ Flash quickly validates formatting with awareness of all previous fixes
|
||||||
```
|
```
|
||||||
|
|
||||||
## Available Tools
|
## Available Tools
|
||||||
|
|
||||||
**Quick Tool Selection Guide:**
|
**Quick Tool Selection Guide:**
|
||||||
- **Need a thinking partner?** → `chat` (brainstorm ideas, get second opinions, validate approaches)
|
- **Need a thinking partner?** → `chat` (brainstorm ideas, get second opinions, validate approaches)
|
||||||
- **Need deeper thinking?** → `thinkdeep` (extends Claude's analysis, finds edge cases)
|
- **Need deeper thinking?** → `thinkdeep` (extends analysis, finds edge cases)
|
||||||
- **Code needs review?** → `codereview` (bugs, security, performance issues)
|
- **Code needs review?** → `codereview` (bugs, security, performance issues)
|
||||||
- **Pre-commit validation?** → `precommit` (validate git changes before committing)
|
- **Pre-commit validation?** → `precommit` (validate git changes before committing)
|
||||||
- **Something's broken?** → `debug` (root cause analysis, error tracing)
|
- **Something's broken?** → `debug` (root cause analysis, error tracing)
|
||||||
- **Want to understand code?** → `analyze` (architecture, patterns, dependencies)
|
- **Want to understand code?** → `analyze` (architecture, patterns, dependencies)
|
||||||
- **Server info?** → `get_version` (version and configuration details)
|
- **Server info?** → `get_version` (version and configuration details)
|
||||||
|
|
||||||
**Pro Tip:** You can control the depth of Gemini's analysis with thinking modes to manage token costs. For quick tasks use "minimal" or "low" to save tokens, for complex problems use "high" or "max" when quality matters more than cost. [Learn more about thinking modes](#thinking-modes---managing-token-costs--quality)
|
**Auto Mode:** When `DEFAULT_MODEL=auto`, Claude automatically picks the best model for each task. You can override with: "Use flash for quick analysis" or "Use o3 to debug this".
|
||||||
|
|
||||||
|
**Model Selection Examples:**
|
||||||
|
- Complex architecture review → Claude picks Gemini Pro
|
||||||
|
- Quick formatting check → Claude picks Flash
|
||||||
|
- Logical debugging → Claude picks O3
|
||||||
|
- General explanations → Claude picks GPT-4o
|
||||||
|
|
||||||
|
**Pro Tip:** Thinking modes (for Gemini models) control depth vs token cost. Use "minimal" or "low" for quick tasks, "high" or "max" for complex problems. [Learn more](#thinking-modes---managing-token-costs--quality)
|
||||||
|
|
||||||
**Tools Overview:**
|
**Tools Overview:**
|
||||||
1. [`chat`](#1-chat---general-development-chat--collaborative-thinking) - Collaborative thinking and development conversations
|
1. [`chat`](#1-chat---general-development-chat--collaborative-thinking) - Collaborative thinking and development conversations
|
||||||
@@ -591,58 +632,65 @@ All tools that work with files support **both individual files and entire direct
|
|||||||
|
|
||||||
**`analyze`** - Analyze files or directories
|
**`analyze`** - Analyze files or directories
|
||||||
- `files`: List of file paths or directories (required)
|
- `files`: List of file paths or directories (required)
|
||||||
- `question`: What to analyze (required)
|
- `question`: What to analyze (required)
|
||||||
- `model`: pro|flash (default: server default)
|
- `model`: auto|pro|flash|o3|o3-mini|gpt-4o (default: server default)
|
||||||
- `analysis_type`: architecture|performance|security|quality|general
|
- `analysis_type`: architecture|performance|security|quality|general
|
||||||
- `output_format`: summary|detailed|actionable
|
- `output_format`: summary|detailed|actionable
|
||||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium)
|
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||||
- `use_websearch`: Enable web search for documentation and best practices (default: false)
|
- `use_websearch`: Enable web search for documentation and best practices (default: false)
|
||||||
|
|
||||||
```
|
```
|
||||||
"Use gemini to analyze the src/ directory for architectural patterns"
|
"Analyze the src/ directory for architectural patterns" (auto mode picks best model)
|
||||||
"Use flash to quickly analyze main.py and tests/ to understand test coverage"
|
"Use flash to quickly analyze main.py and tests/ to understand test coverage"
|
||||||
|
"Use o3 for logical analysis of the algorithm in backend/core.py"
|
||||||
"Use pro for deep analysis of the entire backend/ directory structure"
|
"Use pro for deep analysis of the entire backend/ directory structure"
|
||||||
```
|
```
|
||||||
|
|
||||||
**`codereview`** - Review code files or directories
|
**`codereview`** - Review code files or directories
|
||||||
- `files`: List of file paths or directories (required)
|
- `files`: List of file paths or directories (required)
|
||||||
- `model`: pro|flash (default: server default)
|
- `model`: auto|pro|flash|o3|o3-mini|gpt-4o (default: server default)
|
||||||
- `review_type`: full|security|performance|quick
|
- `review_type`: full|security|performance|quick
|
||||||
- `focus_on`: Specific aspects to focus on
|
- `focus_on`: Specific aspects to focus on
|
||||||
- `standards`: Coding standards to enforce
|
- `standards`: Coding standards to enforce
|
||||||
- `severity_filter`: critical|high|medium|all
|
- `severity_filter`: critical|high|medium|all
|
||||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium)
|
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||||
|
|
||||||
```
|
```
|
||||||
"Use pro to review the entire api/ directory for security issues"
|
"Review the entire api/ directory for security issues" (auto mode picks best model)
|
||||||
|
"Use pro to review auth/ for deep security analysis"
|
||||||
|
"Use o3 to review logic in algorithms/ for correctness"
|
||||||
"Use flash to quickly review src/ with focus on performance, only show critical issues"
|
"Use flash to quickly review src/ with focus on performance, only show critical issues"
|
||||||
```
|
```
|
||||||
|
|
||||||
**`debug`** - Debug with file context
|
**`debug`** - Debug with file context
|
||||||
- `error_description`: Description of the issue (required)
|
- `error_description`: Description of the issue (required)
|
||||||
- `model`: pro|flash (default: server default)
|
- `model`: auto|pro|flash|o3|o3-mini|gpt-4o (default: server default)
|
||||||
- `error_context`: Stack trace or logs
|
- `error_context`: Stack trace or logs
|
||||||
- `files`: Files or directories related to the issue
|
- `files`: Files or directories related to the issue
|
||||||
- `runtime_info`: Environment details
|
- `runtime_info`: Environment details
|
||||||
- `previous_attempts`: What you've tried
|
- `previous_attempts`: What you've tried
|
||||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium)
|
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||||
- `use_websearch`: Enable web search for error messages and solutions (default: false)
|
- `use_websearch`: Enable web search for error messages and solutions (default: false)
|
||||||
|
|
||||||
```
|
```
|
||||||
"Use gemini to debug this error with context from the entire backend/ directory"
|
"Debug this logic error with context from backend/" (auto mode picks best model)
|
||||||
|
"Use o3 to debug this algorithm correctness issue"
|
||||||
|
"Use pro to debug this complex architecture problem"
|
||||||
```
|
```
|
||||||
|
|
||||||
**`thinkdeep`** - Extended analysis with file context
|
**`thinkdeep`** - Extended analysis with file context
|
||||||
- `current_analysis`: Your current thinking (required)
|
- `current_analysis`: Your current thinking (required)
|
||||||
- `model`: pro|flash (default: server default)
|
- `model`: auto|pro|flash|o3|o3-mini|gpt-4o (default: server default)
|
||||||
- `problem_context`: Additional context
|
- `problem_context`: Additional context
|
||||||
- `focus_areas`: Specific aspects to focus on
|
- `focus_areas`: Specific aspects to focus on
|
||||||
- `files`: Files or directories for context
|
- `files`: Files or directories for context
|
||||||
- `thinking_mode`: minimal|low|medium|high|max (default: max)
|
- `thinking_mode`: minimal|low|medium|high|max (default: max, Gemini only)
|
||||||
- `use_websearch`: Enable web search for documentation and insights (default: false)
|
- `use_websearch`: Enable web search for documentation and insights (default: false)
|
||||||
|
|
||||||
```
|
```
|
||||||
"Use gemini to think deeper about my design with reference to the src/models/ directory"
|
"Think deeper about my design with reference to src/models/" (auto mode picks best model)
|
||||||
|
"Use pro to think deeper about this architecture with extended thinking"
|
||||||
|
"Use o3 to think deeper about the logical flow in this algorithm"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Collaborative Workflows
|
## Collaborative Workflows
|
||||||
@@ -877,31 +925,54 @@ The server includes several configurable properties that control its behavior:
|
|||||||
|
|
||||||
### Model Configuration
|
### Model Configuration
|
||||||
|
|
||||||
**Default Model (Environment Variable):**
|
**🎯 Auto Mode (Recommended):**
|
||||||
- **`DEFAULT_MODEL`**: Set your preferred default model globally
|
Set `DEFAULT_MODEL=auto` in your .env file and Claude will intelligently select the best model for each task:
|
||||||
- Default: `"gemini-2.5-pro-preview-06-05"` (extended thinking capabilities)
|
|
||||||
- Alternative: `"gemini-2.0-flash-exp"` (faster responses)
|
|
||||||
|
|
||||||
**Per-Tool Model Selection:**
|
|
||||||
All tools support a `model` parameter for flexible model switching:
|
|
||||||
- **`"pro"`** → Gemini 2.5 Pro (extended thinking, slower, higher quality)
|
|
||||||
- **`"flash"`** → Gemini 2.0 Flash (faster responses, lower cost)
|
|
||||||
- **Full model names** → Direct model specification
|
|
||||||
|
|
||||||
**Examples:**
|
|
||||||
```env
|
```env
|
||||||
# Set default globally in .env file
|
# .env file
|
||||||
DEFAULT_MODEL=flash
|
DEFAULT_MODEL=auto # Claude picks the best model automatically
|
||||||
|
|
||||||
|
# API Keys (at least one required)
|
||||||
|
GEMINI_API_KEY=your-gemini-key # Enables Gemini Pro & Flash
|
||||||
|
OPENAI_API_KEY=your-openai-key # Enables O3, O3-mini, GPT-4o
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
**How Auto Mode Works:**
|
||||||
# Per-tool usage in Claude
|
- Claude analyzes each request and selects the optimal model
|
||||||
"Use flash to quickly analyze this function"
|
- Model selection is based on task complexity, requirements, and model strengths
|
||||||
"Use pro for deep architectural analysis"
|
- You can always override: "Use flash for quick check" or "Use o3 to debug"
|
||||||
|
|
||||||
|
**Supported Models & When Claude Uses Them:**
|
||||||
|
|
||||||
|
| Model | Provider | Context | Strengths | Auto Mode Usage |
|
||||||
|
|-------|----------|---------|-----------|------------------|
|
||||||
|
| **`pro`** (Gemini 2.5 Pro) | Google | 1M tokens | Extended thinking (up to 32K tokens), deep analysis | Complex architecture, security reviews, deep debugging |
|
||||||
|
| **`flash`** (Gemini 2.0 Flash) | Google | 1M tokens | Ultra-fast responses | Quick checks, formatting, simple analysis |
|
||||||
|
| **`o3`** | OpenAI | 200K tokens | Strong logical reasoning | Debugging logic errors, systematic analysis |
|
||||||
|
| **`o3-mini`** | OpenAI | 200K tokens | Balanced speed/quality | Moderate complexity tasks |
|
||||||
|
| **`gpt-4o`** | OpenAI | 128K tokens | General purpose | Explanations, documentation, chat |
|
||||||
|
|
||||||
|
**Manual Model Selection:**
|
||||||
|
You can specify a default model instead of auto mode:
|
||||||
|
|
||||||
|
```env
|
||||||
|
# Use a specific model by default
|
||||||
|
DEFAULT_MODEL=gemini-2.5-pro-preview-06-05 # Always use Gemini Pro
|
||||||
|
DEFAULT_MODEL=flash # Always use Flash
|
||||||
|
DEFAULT_MODEL=o3 # Always use O3
|
||||||
```
|
```
|
||||||
|
|
||||||
**Token Limits:**
|
**Per-Request Model Override:**
|
||||||
- **`MAX_CONTEXT_TOKENS`**: `1,000,000` - Maximum input context (1M tokens for Gemini 2.5 Pro)
|
Regardless of your default setting, you can specify models per request:
|
||||||
|
- "Use **pro** for deep security analysis of auth.py"
|
||||||
|
- "Use **flash** to quickly format this code"
|
||||||
|
- "Use **o3** to debug this logic error"
|
||||||
|
- "Review with **o3-mini** for balanced analysis"
|
||||||
|
|
||||||
|
**Model Capabilities:**
|
||||||
|
- **Gemini Models**: Support thinking modes (minimal to max), web search, 1M context
|
||||||
|
- **O3 Models**: Excellent reasoning, systematic analysis, 200K context
|
||||||
|
- **GPT-4o**: Balanced general-purpose model, 128K context
|
||||||
|
|
||||||
### Temperature Defaults
|
### Temperature Defaults
|
||||||
Different tools use optimized temperature settings:
|
Different tools use optimized temperature settings:
|
||||||
|
|||||||
27
config.py
27
config.py
@@ -21,7 +21,32 @@ __author__ = "Fahad Gilani" # Primary maintainer
|
|||||||
# DEFAULT_MODEL: The default model used for all AI operations
|
# DEFAULT_MODEL: The default model used for all AI operations
|
||||||
# This should be a stable, high-performance model suitable for code analysis
|
# This should be a stable, high-performance model suitable for code analysis
|
||||||
# Can be overridden by setting DEFAULT_MODEL environment variable
|
# Can be overridden by setting DEFAULT_MODEL environment variable
|
||||||
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "gemini-2.5-pro-preview-06-05")
|
# Special value "auto" means Claude should pick the best model for each task
|
||||||
|
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
|
||||||
|
|
||||||
|
# Validate DEFAULT_MODEL and set to "auto" if invalid
|
||||||
|
# Only include actually supported models from providers
|
||||||
|
VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash-exp", "gemini-2.5-pro-preview-06-05"]
|
||||||
|
if DEFAULT_MODEL not in VALID_MODELS:
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.warning(f"Invalid DEFAULT_MODEL '{DEFAULT_MODEL}'. Setting to 'auto'. Valid options: {', '.join(VALID_MODELS)}")
|
||||||
|
DEFAULT_MODEL = "auto"
|
||||||
|
|
||||||
|
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
|
||||||
|
IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"
|
||||||
|
|
||||||
|
# Model capabilities descriptions for auto mode
|
||||||
|
# These help Claude choose the best model for each task
|
||||||
|
MODEL_CAPABILITIES_DESC = {
|
||||||
|
"flash": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
||||||
|
"pro": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
|
||||||
|
"o3": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
|
||||||
|
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
||||||
|
# Full model names also supported
|
||||||
|
"gemini-2.0-flash-exp": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
||||||
|
"gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis"
|
||||||
|
}
|
||||||
|
|
||||||
# Token allocation for Gemini Pro (1M total capacity)
|
# Token allocation for Gemini Pro (1M total capacity)
|
||||||
# MAX_CONTEXT_TOKENS: Total model capacity
|
# MAX_CONTEXT_TOKENS: Total model capacity
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ services:
|
|||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
environment:
|
environment:
|
||||||
- GEMINI_API_KEY=${GEMINI_API_KEY:?GEMINI_API_KEY is required. Please set it in your .env file or environment.}
|
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
|
||||||
- DEFAULT_MODEL=${DEFAULT_MODEL:-gemini-2.5-pro-preview-06-05}
|
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||||
|
- DEFAULT_MODEL=${DEFAULT_MODEL:-auto}
|
||||||
- DEFAULT_THINKING_MODE_THINKDEEP=${DEFAULT_THINKING_MODE_THINKDEEP:-high}
|
- DEFAULT_THINKING_MODE_THINKDEEP=${DEFAULT_THINKING_MODE_THINKDEEP:-high}
|
||||||
- REDIS_URL=redis://redis:6379/0
|
- REDIS_URL=redis://redis:6379/0
|
||||||
# Use HOME not PWD: Claude needs access to any absolute file path, not just current project,
|
# Use HOME not PWD: Claude needs access to any absolute file path, not just current project,
|
||||||
@@ -42,7 +43,6 @@ services:
|
|||||||
- ${HOME:-/tmp}:/workspace:ro
|
- ${HOME:-/tmp}:/workspace:ro
|
||||||
- mcp_logs:/tmp # Shared volume for logs
|
- mcp_logs:/tmp # Shared volume for logs
|
||||||
- /etc/localtime:/etc/localtime:ro
|
- /etc/localtime:/etc/localtime:ro
|
||||||
- /etc/timezone:/etc/timezone:ro
|
|
||||||
stdin_open: true
|
stdin_open: true
|
||||||
tty: true
|
tty: true
|
||||||
entrypoint: ["python"]
|
entrypoint: ["python"]
|
||||||
@@ -60,7 +60,6 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- mcp_logs:/tmp # Shared volume for logs
|
- mcp_logs:/tmp # Shared volume for logs
|
||||||
- /etc/localtime:/etc/localtime:ro
|
- /etc/localtime:/etc/localtime:ro
|
||||||
- /etc/timezone:/etc/timezone:ro
|
|
||||||
entrypoint: ["python"]
|
entrypoint: ["python"]
|
||||||
command: ["log_monitor.py"]
|
command: ["log_monitor.py"]
|
||||||
|
|
||||||
|
|||||||
15
providers/__init__.py
Normal file
15
providers/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Model provider abstractions for supporting multiple AI providers."""
|
||||||
|
|
||||||
|
from .base import ModelProvider, ModelResponse, ModelCapabilities
|
||||||
|
from .registry import ModelProviderRegistry
|
||||||
|
from .gemini import GeminiModelProvider
|
||||||
|
from .openai import OpenAIModelProvider
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModelProvider",
|
||||||
|
"ModelResponse",
|
||||||
|
"ModelCapabilities",
|
||||||
|
"ModelProviderRegistry",
|
||||||
|
"GeminiModelProvider",
|
||||||
|
"OpenAIModelProvider",
|
||||||
|
]
|
||||||
122
providers/base.py
Normal file
122
providers/base.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Base model provider interface and data classes."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Any, Tuple
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderType(Enum):
|
||||||
|
"""Supported model provider types."""
|
||||||
|
GOOGLE = "google"
|
||||||
|
OPENAI = "openai"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelCapabilities:
|
||||||
|
"""Capabilities and constraints for a specific model."""
|
||||||
|
provider: ProviderType
|
||||||
|
model_name: str
|
||||||
|
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
|
||||||
|
max_tokens: int
|
||||||
|
supports_extended_thinking: bool = False
|
||||||
|
supports_system_prompts: bool = True
|
||||||
|
supports_streaming: bool = True
|
||||||
|
supports_function_calling: bool = False
|
||||||
|
temperature_range: Tuple[float, float] = (0.0, 2.0)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelResponse:
|
||||||
|
"""Response from a model provider."""
|
||||||
|
content: str
|
||||||
|
usage: Dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
|
||||||
|
model_name: str = ""
|
||||||
|
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
|
||||||
|
provider: ProviderType = ProviderType.GOOGLE
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
"""Get total tokens used."""
|
||||||
|
return self.usage.get("total_tokens", 0)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProvider(ABC):
|
||||||
|
"""Abstract base class for model providers."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, **kwargs):
|
||||||
|
"""Initialize the provider with API key and optional configuration."""
|
||||||
|
self.api_key = api_key
|
||||||
|
self.config = kwargs
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
|
"""Get capabilities for a specific model."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send to the model
|
||||||
|
model_name: Name of the model to use
|
||||||
|
system_prompt: Optional system prompt for model behavior
|
||||||
|
temperature: Sampling temperature (0-2)
|
||||||
|
max_output_tokens: Maximum tokens to generate
|
||||||
|
**kwargs: Provider-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelResponse with generated content and metadata
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
|
"""Count tokens for the given text using the specified model's tokenizer."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Get the provider type."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
|
"""Validate if the model name is supported by this provider."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validate_parameters(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
temperature: float,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
"""Validate model parameters against capabilities.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If parameters are invalid
|
||||||
|
"""
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
|
# Validate temperature
|
||||||
|
min_temp, max_temp = capabilities.temperature_range
|
||||||
|
if not min_temp <= temperature <= max_temp:
|
||||||
|
raise ValueError(
|
||||||
|
f"Temperature {temperature} out of range [{min_temp}, {max_temp}] "
|
||||||
|
f"for model {model_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
|
"""Check if the model supports extended thinking mode."""
|
||||||
|
pass
|
||||||
185
providers/gemini.py
Normal file
185
providers/gemini.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""Gemini model provider implementation."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Optional, List
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
from .base import ModelProvider, ModelResponse, ModelCapabilities, ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiModelProvider(ModelProvider):
|
||||||
|
"""Google Gemini model provider implementation."""
|
||||||
|
|
||||||
|
# Model configurations
|
||||||
|
SUPPORTED_MODELS = {
|
||||||
|
"gemini-2.0-flash-exp": {
|
||||||
|
"max_tokens": 1_048_576, # 1M tokens
|
||||||
|
"supports_extended_thinking": False,
|
||||||
|
},
|
||||||
|
"gemini-2.5-pro-preview-06-05": {
|
||||||
|
"max_tokens": 1_048_576, # 1M tokens
|
||||||
|
"supports_extended_thinking": True,
|
||||||
|
},
|
||||||
|
# Shorthands
|
||||||
|
"flash": "gemini-2.0-flash-exp",
|
||||||
|
"pro": "gemini-2.5-pro-preview-06-05",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Thinking mode configurations for models that support it
|
||||||
|
THINKING_BUDGETS = {
|
||||||
|
"minimal": 128, # Minimum for 2.5 Pro - fast responses
|
||||||
|
"low": 2048, # Light reasoning tasks
|
||||||
|
"medium": 8192, # Balanced reasoning (default)
|
||||||
|
"high": 16384, # Complex analysis
|
||||||
|
"max": 32768, # Maximum reasoning depth
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, **kwargs):
|
||||||
|
"""Initialize Gemini provider with API key."""
|
||||||
|
super().__init__(api_key, **kwargs)
|
||||||
|
self._client = None
|
||||||
|
self._token_counters = {} # Cache for token counting
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self):
|
||||||
|
"""Lazy initialization of Gemini client."""
|
||||||
|
if self._client is None:
|
||||||
|
self._client = genai.Client(api_key=self.api_key)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
|
"""Get capabilities for a specific Gemini model."""
|
||||||
|
# Resolve shorthand
|
||||||
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
if resolved_name not in self.SUPPORTED_MODELS:
|
||||||
|
raise ValueError(f"Unsupported Gemini model: {model_name}")
|
||||||
|
|
||||||
|
config = self.SUPPORTED_MODELS[resolved_name]
|
||||||
|
|
||||||
|
return ModelCapabilities(
|
||||||
|
provider=ProviderType.GOOGLE,
|
||||||
|
model_name=resolved_name,
|
||||||
|
friendly_name="Gemini",
|
||||||
|
max_tokens=config["max_tokens"],
|
||||||
|
supports_extended_thinking=config["supports_extended_thinking"],
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
temperature_range=(0.0, 2.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
thinking_mode: str = "medium",
|
||||||
|
**kwargs
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using Gemini model."""
|
||||||
|
# Validate parameters
|
||||||
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
self.validate_parameters(resolved_name, temperature)
|
||||||
|
|
||||||
|
# Combine system prompt with user prompt if provided
|
||||||
|
if system_prompt:
|
||||||
|
full_prompt = f"{system_prompt}\n\n{prompt}"
|
||||||
|
else:
|
||||||
|
full_prompt = prompt
|
||||||
|
|
||||||
|
# Prepare generation config
|
||||||
|
generation_config = types.GenerateContentConfig(
|
||||||
|
temperature=temperature,
|
||||||
|
candidate_count=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add max output tokens if specified
|
||||||
|
if max_output_tokens:
|
||||||
|
generation_config.max_output_tokens = max_output_tokens
|
||||||
|
|
||||||
|
# Add thinking configuration for models that support it
|
||||||
|
capabilities = self.get_capabilities(resolved_name)
|
||||||
|
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
|
||||||
|
generation_config.thinking_config = types.ThinkingConfig(
|
||||||
|
thinking_budget=self.THINKING_BUDGETS[thinking_mode]
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate content
|
||||||
|
response = self.client.models.generate_content(
|
||||||
|
model=resolved_name,
|
||||||
|
contents=full_prompt,
|
||||||
|
config=generation_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract usage information if available
|
||||||
|
usage = self._extract_usage(response)
|
||||||
|
|
||||||
|
return ModelResponse(
|
||||||
|
content=response.text,
|
||||||
|
usage=usage,
|
||||||
|
model_name=resolved_name,
|
||||||
|
friendly_name="Gemini",
|
||||||
|
provider=ProviderType.GOOGLE,
|
||||||
|
metadata={
|
||||||
|
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
|
||||||
|
"finish_reason": getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log error and re-raise with more context
|
||||||
|
error_msg = f"Gemini API error for model {resolved_name}: {str(e)}"
|
||||||
|
raise RuntimeError(error_msg) from e
|
||||||
|
|
||||||
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
|
"""Count tokens for the given text using Gemini's tokenizer."""
|
||||||
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
# For now, use a simple estimation
|
||||||
|
# TODO: Use actual Gemini tokenizer when available in SDK
|
||||||
|
# Rough estimation: ~4 characters per token for English text
|
||||||
|
return len(text) // 4
|
||||||
|
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Get the provider type."""
|
||||||
|
return ProviderType.GOOGLE
|
||||||
|
|
||||||
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
|
"""Validate if the model name is supported."""
|
||||||
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
|
||||||
|
|
||||||
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
|
"""Check if the model supports extended thinking mode."""
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
return capabilities.supports_extended_thinking
|
||||||
|
|
||||||
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
|
"""Resolve model shorthand to full name."""
|
||||||
|
# Check if it's a shorthand
|
||||||
|
shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
|
||||||
|
if isinstance(shorthand_value, str):
|
||||||
|
return shorthand_value
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def _extract_usage(self, response) -> Dict[str, int]:
|
||||||
|
"""Extract token usage from Gemini response."""
|
||||||
|
usage = {}
|
||||||
|
|
||||||
|
# Try to extract usage metadata from response
|
||||||
|
# Note: The actual structure depends on the SDK version and response format
|
||||||
|
if hasattr(response, "usage_metadata"):
|
||||||
|
metadata = response.usage_metadata
|
||||||
|
if hasattr(metadata, "prompt_token_count"):
|
||||||
|
usage["input_tokens"] = metadata.prompt_token_count
|
||||||
|
if hasattr(metadata, "candidates_token_count"):
|
||||||
|
usage["output_tokens"] = metadata.candidates_token_count
|
||||||
|
if "input_tokens" in usage and "output_tokens" in usage:
|
||||||
|
usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
|
||||||
|
|
||||||
|
return usage
|
||||||
163
providers/openai.py
Normal file
163
providers/openai.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""OpenAI model provider implementation."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Optional, List, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from .base import ModelProvider, ModelResponse, ModelCapabilities, ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIModelProvider(ModelProvider):
|
||||||
|
"""OpenAI model provider implementation."""
|
||||||
|
|
||||||
|
# Model configurations
|
||||||
|
SUPPORTED_MODELS = {
|
||||||
|
"o3": {
|
||||||
|
"max_tokens": 200_000, # 200K tokens
|
||||||
|
"supports_extended_thinking": False,
|
||||||
|
},
|
||||||
|
"o3-mini": {
|
||||||
|
"max_tokens": 200_000, # 200K tokens
|
||||||
|
"supports_extended_thinking": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, **kwargs):
|
||||||
|
"""Initialize OpenAI provider with API key."""
|
||||||
|
super().__init__(api_key, **kwargs)
|
||||||
|
self._client = None
|
||||||
|
self.base_url = kwargs.get("base_url") # Support custom endpoints
|
||||||
|
self.organization = kwargs.get("organization")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self):
|
||||||
|
"""Lazy initialization of OpenAI client."""
|
||||||
|
if self._client is None:
|
||||||
|
client_kwargs = {"api_key": self.api_key}
|
||||||
|
if self.base_url:
|
||||||
|
client_kwargs["base_url"] = self.base_url
|
||||||
|
if self.organization:
|
||||||
|
client_kwargs["organization"] = self.organization
|
||||||
|
|
||||||
|
self._client = OpenAI(**client_kwargs)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
|
"""Get capabilities for a specific OpenAI model."""
|
||||||
|
if model_name not in self.SUPPORTED_MODELS:
|
||||||
|
raise ValueError(f"Unsupported OpenAI model: {model_name}")
|
||||||
|
|
||||||
|
config = self.SUPPORTED_MODELS[model_name]
|
||||||
|
|
||||||
|
return ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
model_name=model_name,
|
||||||
|
friendly_name="OpenAI",
|
||||||
|
max_tokens=config["max_tokens"],
|
||||||
|
supports_extended_thinking=config["supports_extended_thinking"],
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
temperature_range=(0.0, 2.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using OpenAI model."""
|
||||||
|
# Validate parameters
|
||||||
|
self.validate_parameters(model_name, temperature)
|
||||||
|
|
||||||
|
# Prepare messages
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
# Prepare completion parameters
|
||||||
|
completion_params = {
|
||||||
|
"model": model_name,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add max tokens if specified
|
||||||
|
if max_output_tokens:
|
||||||
|
completion_params["max_tokens"] = max_output_tokens
|
||||||
|
|
||||||
|
# Add any additional OpenAI-specific parameters
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop"]:
|
||||||
|
completion_params[key] = value
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate completion
|
||||||
|
response = self.client.chat.completions.create(**completion_params)
|
||||||
|
|
||||||
|
# Extract content and usage
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
usage = self._extract_usage(response)
|
||||||
|
|
||||||
|
return ModelResponse(
|
||||||
|
content=content,
|
||||||
|
usage=usage,
|
||||||
|
model_name=model_name,
|
||||||
|
friendly_name="OpenAI",
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
metadata={
|
||||||
|
"finish_reason": response.choices[0].finish_reason,
|
||||||
|
"model": response.model, # Actual model used (in case of fallbacks)
|
||||||
|
"id": response.id,
|
||||||
|
"created": response.created,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log error and re-raise with more context
|
||||||
|
error_msg = f"OpenAI API error for model {model_name}: {str(e)}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
raise RuntimeError(error_msg) from e
|
||||||
|
|
||||||
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
|
"""Count tokens for the given text.
|
||||||
|
|
||||||
|
Note: For accurate token counting, we should use tiktoken library.
|
||||||
|
This is a simplified estimation.
|
||||||
|
"""
|
||||||
|
# TODO: Implement proper token counting with tiktoken
|
||||||
|
# For now, use rough estimation
|
||||||
|
# O3 models ~4 chars per token
|
||||||
|
return len(text) // 4
|
||||||
|
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Get the provider type."""
|
||||||
|
return ProviderType.OPENAI
|
||||||
|
|
||||||
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
|
"""Validate if the model name is supported."""
|
||||||
|
return model_name in self.SUPPORTED_MODELS
|
||||||
|
|
||||||
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
|
"""Check if the model supports extended thinking mode."""
|
||||||
|
# Currently no OpenAI models support extended thinking
|
||||||
|
# This may change with future O3 models
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _extract_usage(self, response) -> Dict[str, int]:
|
||||||
|
"""Extract token usage from OpenAI response."""
|
||||||
|
usage = {}
|
||||||
|
|
||||||
|
if hasattr(response, "usage") and response.usage:
|
||||||
|
usage["input_tokens"] = response.usage.prompt_tokens
|
||||||
|
usage["output_tokens"] = response.usage.completion_tokens
|
||||||
|
usage["total_tokens"] = response.usage.total_tokens
|
||||||
|
|
||||||
|
return usage
|
||||||
136
providers/registry.py
Normal file
136
providers/registry.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""Model provider registry for managing available providers."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Optional, Type, List
|
||||||
|
from .base import ModelProvider, ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProviderRegistry:
|
||||||
|
"""Registry for managing model providers."""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_providers: Dict[ProviderType, Type[ModelProvider]] = {}
|
||||||
|
_initialized_providers: Dict[ProviderType, ModelProvider] = {}
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
"""Singleton pattern for registry."""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_provider(cls, provider_type: ProviderType, provider_class: Type[ModelProvider]) -> None:
|
||||||
|
"""Register a new provider class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
|
||||||
|
provider_class: Class that implements ModelProvider interface
|
||||||
|
"""
|
||||||
|
cls._providers[provider_type] = provider_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
|
||||||
|
"""Get an initialized provider instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Type of provider to get
|
||||||
|
force_new: Force creation of new instance instead of using cached
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Initialized ModelProvider instance or None if not available
|
||||||
|
"""
|
||||||
|
# Return cached instance if available and not forcing new
|
||||||
|
if not force_new and provider_type in cls._initialized_providers:
|
||||||
|
return cls._initialized_providers[provider_type]
|
||||||
|
|
||||||
|
# Check if provider class is registered
|
||||||
|
if provider_type not in cls._providers:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get API key from environment
|
||||||
|
api_key = cls._get_api_key_for_provider(provider_type)
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Initialize provider
|
||||||
|
provider_class = cls._providers[provider_type]
|
||||||
|
provider = provider_class(api_key=api_key)
|
||||||
|
|
||||||
|
# Cache the instance
|
||||||
|
cls._initialized_providers[provider_type] = provider
|
||||||
|
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
|
||||||
|
"""Get provider instance for a specific model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model (e.g., "gemini-2.0-flash-exp", "o3-mini")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelProvider instance that supports this model
|
||||||
|
"""
|
||||||
|
# Check each registered provider
|
||||||
|
for provider_type, provider_class in cls._providers.items():
|
||||||
|
# Get or create provider instance
|
||||||
|
provider = cls.get_provider(provider_type)
|
||||||
|
if provider and provider.validate_model_name(model_name):
|
||||||
|
return provider
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_available_providers(cls) -> List[ProviderType]:
|
||||||
|
"""Get list of registered provider types."""
|
||||||
|
return list(cls._providers.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_available_models(cls) -> Dict[str, ProviderType]:
|
||||||
|
"""Get mapping of all available models to their providers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping model names to provider types
|
||||||
|
"""
|
||||||
|
models = {}
|
||||||
|
|
||||||
|
for provider_type in cls._providers:
|
||||||
|
provider = cls.get_provider(provider_type)
|
||||||
|
if provider:
|
||||||
|
# This assumes providers have a method to list supported models
|
||||||
|
# We'll need to add this to the interface
|
||||||
|
pass
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
|
||||||
|
"""Get API key for a provider from environment variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Provider type to get API key for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API key string or None if not found
|
||||||
|
"""
|
||||||
|
key_mapping = {
|
||||||
|
ProviderType.GOOGLE: "GEMINI_API_KEY",
|
||||||
|
ProviderType.OPENAI: "OPENAI_API_KEY",
|
||||||
|
}
|
||||||
|
|
||||||
|
env_var = key_mapping.get(provider_type)
|
||||||
|
if not env_var:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return os.getenv(env_var)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clear_cache(cls) -> None:
|
||||||
|
"""Clear cached provider instances."""
|
||||||
|
cls._initialized_providers.clear()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unregister_provider(cls, provider_type: ProviderType) -> None:
|
||||||
|
"""Unregister a provider (mainly for testing)."""
|
||||||
|
cls._providers.pop(provider_type, None)
|
||||||
|
cls._initialized_providers.pop(provider_type, None)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
mcp>=1.0.0
|
mcp>=1.0.0
|
||||||
google-genai>=1.19.0
|
google-genai>=1.19.0
|
||||||
|
openai>=1.0.0
|
||||||
pydantic>=2.0.0
|
pydantic>=2.0.0
|
||||||
redis>=5.0.0
|
redis>=5.0.0
|
||||||
|
|
||||||
|
|||||||
94
server.py
94
server.py
@@ -117,23 +117,46 @@ TOOLS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def configure_gemini():
|
def configure_providers():
|
||||||
"""
|
"""
|
||||||
Configure Gemini API with the provided API key.
|
Configure and validate AI providers based on available API keys.
|
||||||
|
|
||||||
This function validates that the GEMINI_API_KEY environment variable is set.
|
This function checks for API keys and registers the appropriate providers.
|
||||||
The actual API key is used when creating Gemini clients within individual tools
|
At least one valid API key (Gemini or OpenAI) is required.
|
||||||
to ensure proper isolation and error handling.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If GEMINI_API_KEY environment variable is not set
|
ValueError: If no valid API keys are found
|
||||||
"""
|
"""
|
||||||
api_key = os.getenv("GEMINI_API_KEY")
|
from providers import ModelProviderRegistry
|
||||||
if not api_key:
|
from providers.base import ProviderType
|
||||||
raise ValueError("GEMINI_API_KEY environment variable is required. Please set it with your Gemini API key.")
|
from providers.gemini import GeminiModelProvider
|
||||||
# Note: We don't store the API key globally for security reasons
|
from providers.openai import OpenAIModelProvider
|
||||||
# Each tool creates its own Gemini client with the API key when needed
|
|
||||||
logger.info("Gemini API key found")
|
valid_providers = []
|
||||||
|
|
||||||
|
# Check for Gemini API key
|
||||||
|
gemini_key = os.getenv("GEMINI_API_KEY")
|
||||||
|
if gemini_key and gemini_key != "your_gemini_api_key_here":
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
valid_providers.append("Gemini")
|
||||||
|
logger.info("Gemini API key found - Gemini models available")
|
||||||
|
|
||||||
|
# Check for OpenAI API key
|
||||||
|
openai_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
if openai_key and openai_key != "your_openai_api_key_here":
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
valid_providers.append("OpenAI (o3)")
|
||||||
|
logger.info("OpenAI API key found - o3 model available")
|
||||||
|
|
||||||
|
# Require at least one valid provider
|
||||||
|
if not valid_providers:
|
||||||
|
raise ValueError(
|
||||||
|
"At least one API key is required. Please set either:\n"
|
||||||
|
"- GEMINI_API_KEY for Gemini models\n"
|
||||||
|
"- OPENAI_API_KEY for OpenAI o3 model"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Available providers: {', '.join(valid_providers)}")
|
||||||
|
|
||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
@@ -363,10 +386,15 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
|||||||
else:
|
else:
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Successfully added user turn to thread {continuation_id}")
|
logger.debug(f"[CONVERSATION_DEBUG] Successfully added user turn to thread {continuation_id}")
|
||||||
|
|
||||||
# Build conversation history and track token usage
|
# Create model context early to use for history building
|
||||||
|
from utils.model_context import ModelContext
|
||||||
|
model_context = ModelContext.from_arguments(arguments)
|
||||||
|
|
||||||
|
# Build conversation history with model-specific limits
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Building conversation history for thread {continuation_id}")
|
logger.debug(f"[CONVERSATION_DEBUG] Building conversation history for thread {continuation_id}")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Thread has {len(context.turns)} turns, tool: {context.tool_name}")
|
logger.debug(f"[CONVERSATION_DEBUG] Thread has {len(context.turns)} turns, tool: {context.tool_name}")
|
||||||
conversation_history, conversation_tokens = build_conversation_history(context)
|
logger.debug(f"[CONVERSATION_DEBUG] Using model: {model_context.model_name}")
|
||||||
|
conversation_history, conversation_tokens = build_conversation_history(context, model_context)
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Conversation history built: {conversation_tokens:,} tokens")
|
logger.debug(f"[CONVERSATION_DEBUG] Conversation history built: {conversation_tokens:,} tokens")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Conversation history length: {len(conversation_history)} chars")
|
logger.debug(f"[CONVERSATION_DEBUG] Conversation history length: {len(conversation_history)} chars")
|
||||||
|
|
||||||
@@ -374,8 +402,12 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
|||||||
follow_up_instructions = get_follow_up_instructions(len(context.turns))
|
follow_up_instructions = get_follow_up_instructions(len(context.turns))
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Follow-up instructions added for turn {len(context.turns)}")
|
logger.debug(f"[CONVERSATION_DEBUG] Follow-up instructions added for turn {len(context.turns)}")
|
||||||
|
|
||||||
# Merge original context with new prompt and follow-up instructions
|
# All tools now use standardized 'prompt' field
|
||||||
original_prompt = arguments.get("prompt", "")
|
original_prompt = arguments.get("prompt", "")
|
||||||
|
logger.debug(f"[CONVERSATION_DEBUG] Extracting user input from 'prompt' field")
|
||||||
|
logger.debug(f"[CONVERSATION_DEBUG] User input length: {len(original_prompt)} chars")
|
||||||
|
|
||||||
|
# Merge original context with new prompt and follow-up instructions
|
||||||
if conversation_history:
|
if conversation_history:
|
||||||
enhanced_prompt = (
|
enhanced_prompt = (
|
||||||
f"{conversation_history}\n\n=== NEW USER INPUT ===\n{original_prompt}\n\n{follow_up_instructions}"
|
f"{conversation_history}\n\n=== NEW USER INPUT ===\n{original_prompt}\n\n{follow_up_instructions}"
|
||||||
@@ -385,15 +417,25 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
|||||||
|
|
||||||
# Update arguments with enhanced context and remaining token budget
|
# Update arguments with enhanced context and remaining token budget
|
||||||
enhanced_arguments = arguments.copy()
|
enhanced_arguments = arguments.copy()
|
||||||
|
|
||||||
|
# Store the enhanced prompt in the prompt field
|
||||||
enhanced_arguments["prompt"] = enhanced_prompt
|
enhanced_arguments["prompt"] = enhanced_prompt
|
||||||
|
logger.debug(f"[CONVERSATION_DEBUG] Storing enhanced prompt in 'prompt' field")
|
||||||
|
|
||||||
# Calculate remaining token budget for current request files/content
|
# Calculate remaining token budget based on current model
|
||||||
from config import MAX_CONTENT_TOKENS
|
# (model_context was already created above for history building)
|
||||||
|
token_allocation = model_context.calculate_token_allocation()
|
||||||
remaining_tokens = MAX_CONTENT_TOKENS - conversation_tokens
|
|
||||||
|
# Calculate remaining tokens for files/new content
|
||||||
|
# History has already consumed some of the content budget
|
||||||
|
remaining_tokens = token_allocation.content_tokens - conversation_tokens
|
||||||
enhanced_arguments["_remaining_tokens"] = max(0, remaining_tokens) # Ensure non-negative
|
enhanced_arguments["_remaining_tokens"] = max(0, remaining_tokens) # Ensure non-negative
|
||||||
|
enhanced_arguments["_model_context"] = model_context # Pass context for use in tools
|
||||||
|
|
||||||
logger.debug("[CONVERSATION_DEBUG] Token budget calculation:")
|
logger.debug("[CONVERSATION_DEBUG] Token budget calculation:")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] MAX_CONTENT_TOKENS: {MAX_CONTENT_TOKENS:,}")
|
logger.debug(f"[CONVERSATION_DEBUG] Model: {model_context.model_name}")
|
||||||
|
logger.debug(f"[CONVERSATION_DEBUG] Total capacity: {token_allocation.total_tokens:,}")
|
||||||
|
logger.debug(f"[CONVERSATION_DEBUG] Content allocation: {token_allocation.content_tokens:,}")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Conversation tokens: {conversation_tokens:,}")
|
logger.debug(f"[CONVERSATION_DEBUG] Conversation tokens: {conversation_tokens:,}")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Remaining tokens: {remaining_tokens:,}")
|
logger.debug(f"[CONVERSATION_DEBUG] Remaining tokens: {remaining_tokens:,}")
|
||||||
|
|
||||||
@@ -485,13 +527,19 @@ async def main():
|
|||||||
The server communicates via standard input/output streams using the
|
The server communicates via standard input/output streams using the
|
||||||
MCP protocol's JSON-RPC message format.
|
MCP protocol's JSON-RPC message format.
|
||||||
"""
|
"""
|
||||||
# Validate that Gemini API key is available before starting
|
# Validate and configure providers based on available API keys
|
||||||
configure_gemini()
|
configure_providers()
|
||||||
|
|
||||||
# Log startup message for Docker log monitoring
|
# Log startup message for Docker log monitoring
|
||||||
logger.info("Gemini MCP Server starting up...")
|
logger.info("Gemini MCP Server starting up...")
|
||||||
logger.info(f"Log level: {log_level}")
|
logger.info(f"Log level: {log_level}")
|
||||||
logger.info(f"Using default model: {DEFAULT_MODEL}")
|
|
||||||
|
# Log current model mode
|
||||||
|
from config import IS_AUTO_MODE
|
||||||
|
if IS_AUTO_MODE:
|
||||||
|
logger.info("Model mode: AUTO (Claude will select the best model for each task)")
|
||||||
|
else:
|
||||||
|
logger.info(f"Model mode: Fixed model '{DEFAULT_MODEL}'")
|
||||||
|
|
||||||
# Import here to avoid circular imports
|
# Import here to avoid circular imports
|
||||||
from config import DEFAULT_THINKING_MODE_THINKDEEP
|
from config import DEFAULT_THINKING_MODE_THINKDEEP
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ else
|
|||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
echo "✅ Created .env from .env.example"
|
echo "✅ Created .env from .env.example"
|
||||||
|
|
||||||
# Customize the API key if it's set in environment
|
# Customize the API keys if they're set in environment
|
||||||
if [ -n "$GEMINI_API_KEY" ]; then
|
if [ -n "${GEMINI_API_KEY:-}" ]; then
|
||||||
# Replace the placeholder API key with the actual value
|
# Replace the placeholder API key with the actual value
|
||||||
if command -v sed >/dev/null 2>&1; then
|
if command -v sed >/dev/null 2>&1; then
|
||||||
sed -i.bak "s/your_gemini_api_key_here/$GEMINI_API_KEY/" .env && rm .env.bak
|
sed -i.bak "s/your_gemini_api_key_here/$GEMINI_API_KEY/" .env && rm .env.bak
|
||||||
@@ -40,6 +40,18 @@ else
|
|||||||
echo "⚠️ GEMINI_API_KEY not found in environment. Please edit .env and add your API key."
|
echo "⚠️ GEMINI_API_KEY not found in environment. Please edit .env and add your API key."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ -n "${OPENAI_API_KEY:-}" ]; then
|
||||||
|
# Replace the placeholder API key with the actual value
|
||||||
|
if command -v sed >/dev/null 2>&1; then
|
||||||
|
sed -i.bak "s/your_openai_api_key_here/$OPENAI_API_KEY/" .env && rm .env.bak
|
||||||
|
echo "✅ Updated .env with existing OPENAI_API_KEY from environment"
|
||||||
|
else
|
||||||
|
echo "⚠️ Found OPENAI_API_KEY in environment, but sed not available. Please update .env manually."
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "⚠️ OPENAI_API_KEY not found in environment. Please edit .env and add your API key."
|
||||||
|
fi
|
||||||
|
|
||||||
# Update WORKSPACE_ROOT to use current user's home directory
|
# Update WORKSPACE_ROOT to use current user's home directory
|
||||||
if command -v sed >/dev/null 2>&1; then
|
if command -v sed >/dev/null 2>&1; then
|
||||||
sed -i.bak "s|WORKSPACE_ROOT=/Users/your-username|WORKSPACE_ROOT=$HOME|" .env && rm .env.bak
|
sed -i.bak "s|WORKSPACE_ROOT=/Users/your-username|WORKSPACE_ROOT=$HOME|" .env && rm .env.bak
|
||||||
@@ -74,6 +86,41 @@ if ! docker compose version &> /dev/null; then
|
|||||||
COMPOSE_CMD="docker-compose"
|
COMPOSE_CMD="docker-compose"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Check if at least one API key is properly configured
|
||||||
|
echo "🔑 Checking API key configuration..."
|
||||||
|
source .env 2>/dev/null || true
|
||||||
|
|
||||||
|
VALID_GEMINI_KEY=false
|
||||||
|
VALID_OPENAI_KEY=false
|
||||||
|
|
||||||
|
# Check if GEMINI_API_KEY is set and not the placeholder
|
||||||
|
if [ -n "${GEMINI_API_KEY:-}" ] && [ "$GEMINI_API_KEY" != "your_gemini_api_key_here" ]; then
|
||||||
|
VALID_GEMINI_KEY=true
|
||||||
|
echo "✅ Valid GEMINI_API_KEY found"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if OPENAI_API_KEY is set and not the placeholder
|
||||||
|
if [ -n "${OPENAI_API_KEY:-}" ] && [ "$OPENAI_API_KEY" != "your_openai_api_key_here" ]; then
|
||||||
|
VALID_OPENAI_KEY=true
|
||||||
|
echo "✅ Valid OPENAI_API_KEY found"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Require at least one valid API key
|
||||||
|
if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ]; then
|
||||||
|
echo ""
|
||||||
|
echo "❌ ERROR: At least one valid API key is required!"
|
||||||
|
echo ""
|
||||||
|
echo "Please edit the .env file and set at least one of:"
|
||||||
|
echo " - GEMINI_API_KEY (get from https://makersuite.google.com/app/apikey)"
|
||||||
|
echo " - OPENAI_API_KEY (get from https://platform.openai.com/api-keys)"
|
||||||
|
echo ""
|
||||||
|
echo "Example:"
|
||||||
|
echo " GEMINI_API_KEY=your-actual-api-key-here"
|
||||||
|
echo " OPENAI_API_KEY=sk-your-actual-openai-key-here"
|
||||||
|
echo ""
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
echo "🛠️ Building and starting services..."
|
echo "🛠️ Building and starting services..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
@@ -143,8 +190,15 @@ $COMPOSE_CMD ps --format table
|
|||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "🔄 Next steps:"
|
echo "🔄 Next steps:"
|
||||||
if grep -q "your-gemini-api-key-here" .env 2>/dev/null || false; then
|
NEEDS_KEY_UPDATE=false
|
||||||
echo "1. Edit .env and replace 'your-gemini-api-key-here' with your actual Gemini API key"
|
if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null; then
|
||||||
|
NEEDS_KEY_UPDATE=true
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$NEEDS_KEY_UPDATE" = true ]; then
|
||||||
|
echo "1. Edit .env and replace placeholder API keys with actual ones"
|
||||||
|
echo " - GEMINI_API_KEY: your-gemini-api-key-here"
|
||||||
|
echo " - OPENAI_API_KEY: your-openai-api-key-here"
|
||||||
echo "2. Restart services: $COMPOSE_CMD restart"
|
echo "2. Restart services: $COMPOSE_CMD restart"
|
||||||
echo "3. Copy the configuration below to your Claude Desktop config:"
|
echo "3. Copy the configuration below to your Claude Desktop config:"
|
||||||
else
|
else
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ class BasicConversationTest(BaseSimulatorTest):
|
|||||||
{
|
{
|
||||||
"prompt": "Please use low thinking mode. Analyze this Python code and explain what it does",
|
"prompt": "Please use low thinking mode. Analyze this Python code and explain what it does",
|
||||||
"files": [self.test_files["python"]],
|
"files": [self.test_files["python"]],
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,6 +55,7 @@ class BasicConversationTest(BaseSimulatorTest):
|
|||||||
"prompt": "Please use low thinking mode. Now focus on the Calculator class specifically. Are there any improvements you'd suggest?",
|
"prompt": "Please use low thinking mode. Now focus on the Calculator class specifically. Are there any improvements you'd suggest?",
|
||||||
"files": [self.test_files["python"]], # Same file - should be deduplicated
|
"files": [self.test_files["python"]], # Same file - should be deduplicated
|
||||||
"continuation_id": continuation_id,
|
"continuation_id": continuation_id,
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -69,6 +71,7 @@ class BasicConversationTest(BaseSimulatorTest):
|
|||||||
"prompt": "Please use low thinking mode. Now also analyze this configuration file and see how it might relate to the Python code",
|
"prompt": "Please use low thinking mode. Now also analyze this configuration file and see how it might relate to the Python code",
|
||||||
"files": [self.test_files["python"], self.test_files["config"]],
|
"files": [self.test_files["python"], self.test_files["config"]],
|
||||||
"continuation_id": continuation_id,
|
"continuation_id": continuation_id,
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ DATABASE_CONFIG = {
|
|||||||
{
|
{
|
||||||
"path": os.getcwd(),
|
"path": os.getcwd(),
|
||||||
"files": [validation_file],
|
"files": [validation_file],
|
||||||
"original_request": "Test for content duplication in precommit tool",
|
"prompt": "Test for content duplication in precommit tool",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -116,16 +116,18 @@ DATABASE_CONFIG = {
|
|||||||
{
|
{
|
||||||
"prompt": "Please use low thinking mode. Analyze this config file",
|
"prompt": "Please use low thinking mode. Analyze this config file",
|
||||||
"files": [validation_file],
|
"files": [validation_file],
|
||||||
|
"model": "flash",
|
||||||
}, # Using absolute path
|
}, # Using absolute path
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"codereview",
|
"codereview",
|
||||||
{
|
{
|
||||||
"files": [validation_file],
|
"files": [validation_file],
|
||||||
"context": "Please use low thinking mode. Review this configuration",
|
"prompt": "Please use low thinking mode. Review this configuration",
|
||||||
|
"model": "flash",
|
||||||
}, # Using absolute path
|
}, # Using absolute path
|
||||||
),
|
),
|
||||||
("analyze", {"files": [validation_file], "analysis_type": "code_quality"}), # Using absolute path
|
("analyze", {"files": [validation_file], "analysis_type": "code_quality", "model": "flash"}), # Using absolute path
|
||||||
]
|
]
|
||||||
|
|
||||||
for tool_name, params in tools_to_test:
|
for tool_name, params in tools_to_test:
|
||||||
@@ -163,6 +165,7 @@ DATABASE_CONFIG = {
|
|||||||
"prompt": "Please use low thinking mode. Continue analyzing this configuration file",
|
"prompt": "Please use low thinking mode. Continue analyzing this configuration file",
|
||||||
"files": [validation_file], # Same file should be deduplicated
|
"files": [validation_file], # Same file should be deduplicated
|
||||||
"continuation_id": thread_id,
|
"continuation_id": thread_id,
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ def hash_pwd(pwd):
|
|||||||
"prompt": "Please give me a quick one line reply. I have an authentication module that needs review. Can you help me understand potential issues?",
|
"prompt": "Please give me a quick one line reply. I have an authentication module that needs review. Can you help me understand potential issues?",
|
||||||
"files": [auth_file],
|
"files": [auth_file],
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
|
"model": "flash",
|
||||||
}
|
}
|
||||||
|
|
||||||
response1, continuation_id1 = self.call_mcp_tool("chat", chat_params)
|
response1, continuation_id1 = self.call_mcp_tool("chat", chat_params)
|
||||||
@@ -106,8 +107,9 @@ def hash_pwd(pwd):
|
|||||||
self.logger.info(" Step 2: analyze tool - Deep code analysis (fresh)")
|
self.logger.info(" Step 2: analyze tool - Deep code analysis (fresh)")
|
||||||
analyze_params = {
|
analyze_params = {
|
||||||
"files": [auth_file],
|
"files": [auth_file],
|
||||||
"question": "Please give me a quick one line reply. What are the security vulnerabilities and architectural issues in this authentication code?",
|
"prompt": "Please give me a quick one line reply. What are the security vulnerabilities and architectural issues in this authentication code?",
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
|
"model": "flash",
|
||||||
}
|
}
|
||||||
|
|
||||||
response2, continuation_id2 = self.call_mcp_tool("analyze", analyze_params)
|
response2, continuation_id2 = self.call_mcp_tool("analyze", analyze_params)
|
||||||
@@ -127,6 +129,7 @@ def hash_pwd(pwd):
|
|||||||
"prompt": "Please give me a quick one line reply. I also have this configuration file. Can you analyze it alongside the authentication code?",
|
"prompt": "Please give me a quick one line reply. I also have this configuration file. Can you analyze it alongside the authentication code?",
|
||||||
"files": [auth_file, config_file_path], # Old + new file
|
"files": [auth_file, config_file_path], # Old + new file
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
|
"model": "flash",
|
||||||
}
|
}
|
||||||
|
|
||||||
response3, _ = self.call_mcp_tool("chat", chat_continue_params)
|
response3, _ = self.call_mcp_tool("chat", chat_continue_params)
|
||||||
@@ -141,8 +144,9 @@ def hash_pwd(pwd):
|
|||||||
self.logger.info(" Step 4: debug tool - Identify specific problems")
|
self.logger.info(" Step 4: debug tool - Identify specific problems")
|
||||||
debug_params = {
|
debug_params = {
|
||||||
"files": [auth_file, config_file_path],
|
"files": [auth_file, config_file_path],
|
||||||
"error_description": "Please give me a quick one line reply. The authentication system has security vulnerabilities. Help me identify and fix the main issues.",
|
"prompt": "Please give me a quick one line reply. The authentication system has security vulnerabilities. Help me identify and fix the main issues.",
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
|
"model": "flash",
|
||||||
}
|
}
|
||||||
|
|
||||||
response4, continuation_id4 = self.call_mcp_tool("debug", debug_params)
|
response4, continuation_id4 = self.call_mcp_tool("debug", debug_params)
|
||||||
@@ -161,8 +165,9 @@ def hash_pwd(pwd):
|
|||||||
debug_continue_params = {
|
debug_continue_params = {
|
||||||
"continuation_id": continuation_id4,
|
"continuation_id": continuation_id4,
|
||||||
"files": [auth_file, config_file_path],
|
"files": [auth_file, config_file_path],
|
||||||
"error_description": "Please give me a quick one line reply. What specific code changes would you recommend to fix the password hashing vulnerability?",
|
"prompt": "Please give me a quick one line reply. What specific code changes would you recommend to fix the password hashing vulnerability?",
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
|
"model": "flash",
|
||||||
}
|
}
|
||||||
|
|
||||||
response5, _ = self.call_mcp_tool("debug", debug_continue_params)
|
response5, _ = self.call_mcp_tool("debug", debug_continue_params)
|
||||||
@@ -174,8 +179,9 @@ def hash_pwd(pwd):
|
|||||||
self.logger.info(" Step 6: codereview tool - Comprehensive code review")
|
self.logger.info(" Step 6: codereview tool - Comprehensive code review")
|
||||||
codereview_params = {
|
codereview_params = {
|
||||||
"files": [auth_file, config_file_path],
|
"files": [auth_file, config_file_path],
|
||||||
"context": "Please give me a quick one line reply. Comprehensive security-focused code review for production readiness",
|
"prompt": "Please give me a quick one line reply. Comprehensive security-focused code review for production readiness",
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
|
"model": "flash",
|
||||||
}
|
}
|
||||||
|
|
||||||
response6, continuation_id6 = self.call_mcp_tool("codereview", codereview_params)
|
response6, continuation_id6 = self.call_mcp_tool("codereview", codereview_params)
|
||||||
@@ -207,7 +213,7 @@ def secure_login(user, pwd):
|
|||||||
precommit_params = {
|
precommit_params = {
|
||||||
"path": self.test_dir,
|
"path": self.test_dir,
|
||||||
"files": [auth_file, config_file_path, improved_file],
|
"files": [auth_file, config_file_path, improved_file],
|
||||||
"original_request": "Please give me a quick one line reply. Ready to commit security improvements to authentication module",
|
"prompt": "Please give me a quick one line reply. Ready to commit security improvements to authentication module",
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
|
|||||||
{
|
{
|
||||||
"prompt": "Please use low thinking mode. Look at this Python code and tell me what you think about it",
|
"prompt": "Please use low thinking mode. Look at this Python code and tell me what you think about it",
|
||||||
"files": [self.test_files["python"]],
|
"files": [self.test_files["python"]],
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,6 +82,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
|
|||||||
"prompt": "Please use low thinking mode. Think deeply about potential performance issues in this code",
|
"prompt": "Please use low thinking mode. Think deeply about potential performance issues in this code",
|
||||||
"files": [self.test_files["python"]], # Same file should be deduplicated
|
"files": [self.test_files["python"]], # Same file should be deduplicated
|
||||||
"continuation_id": chat_id,
|
"continuation_id": chat_id,
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -93,8 +95,9 @@ class CrossToolContinuationTest(BaseSimulatorTest):
|
|||||||
"codereview",
|
"codereview",
|
||||||
{
|
{
|
||||||
"files": [self.test_files["python"]], # Same file should be deduplicated
|
"files": [self.test_files["python"]], # Same file should be deduplicated
|
||||||
"context": "Building on our previous analysis, provide a comprehensive code review",
|
"prompt": "Building on our previous analysis, provide a comprehensive code review",
|
||||||
"continuation_id": chat_id,
|
"continuation_id": chat_id,
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -116,7 +119,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
|
|||||||
|
|
||||||
# Start with analyze
|
# Start with analyze
|
||||||
analyze_response, analyze_id = self.call_mcp_tool(
|
analyze_response, analyze_id = self.call_mcp_tool(
|
||||||
"analyze", {"files": [self.test_files["python"]], "analysis_type": "code_quality"}
|
"analyze", {"files": [self.test_files["python"]], "analysis_type": "code_quality", "model": "flash"}
|
||||||
)
|
)
|
||||||
|
|
||||||
if not analyze_response or not analyze_id:
|
if not analyze_response or not analyze_id:
|
||||||
@@ -128,8 +131,9 @@ class CrossToolContinuationTest(BaseSimulatorTest):
|
|||||||
"debug",
|
"debug",
|
||||||
{
|
{
|
||||||
"files": [self.test_files["python"]], # Same file should be deduplicated
|
"files": [self.test_files["python"]], # Same file should be deduplicated
|
||||||
"issue_description": "Based on our analysis, help debug the performance issue in fibonacci",
|
"prompt": "Based on our analysis, help debug the performance issue in fibonacci",
|
||||||
"continuation_id": analyze_id,
|
"continuation_id": analyze_id,
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -144,6 +148,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
|
|||||||
"prompt": "Please use low thinking mode. Think deeply about the architectural implications of the issues we've found",
|
"prompt": "Please use low thinking mode. Think deeply about the architectural implications of the issues we've found",
|
||||||
"files": [self.test_files["python"]], # Same file should be deduplicated
|
"files": [self.test_files["python"]], # Same file should be deduplicated
|
||||||
"continuation_id": analyze_id,
|
"continuation_id": analyze_id,
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -169,6 +174,7 @@ class CrossToolContinuationTest(BaseSimulatorTest):
|
|||||||
{
|
{
|
||||||
"prompt": "Please use low thinking mode. Analyze both the Python code and configuration file",
|
"prompt": "Please use low thinking mode. Analyze both the Python code and configuration file",
|
||||||
"files": [self.test_files["python"], self.test_files["config"]],
|
"files": [self.test_files["python"], self.test_files["config"]],
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -181,8 +187,9 @@ class CrossToolContinuationTest(BaseSimulatorTest):
|
|||||||
"codereview",
|
"codereview",
|
||||||
{
|
{
|
||||||
"files": [self.test_files["python"], self.test_files["config"]], # Same files
|
"files": [self.test_files["python"], self.test_files["config"]], # Same files
|
||||||
"context": "Review both files in the context of our previous discussion",
|
"prompt": "Review both files in the context of our previous discussion",
|
||||||
"continuation_id": multi_id,
|
"continuation_id": multi_id,
|
||||||
|
"model": "flash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -100,8 +100,9 @@ def divide(x, y):
|
|||||||
precommit_params = {
|
precommit_params = {
|
||||||
"path": self.test_dir, # Required path parameter
|
"path": self.test_dir, # Required path parameter
|
||||||
"files": [dummy_file_path],
|
"files": [dummy_file_path],
|
||||||
"original_request": "Please give me a quick one line reply. Review this code for commit readiness",
|
"prompt": "Please give me a quick one line reply. Review this code for commit readiness",
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
|
"model": "flash",
|
||||||
}
|
}
|
||||||
|
|
||||||
response1, continuation_id = self.call_mcp_tool("precommit", precommit_params)
|
response1, continuation_id = self.call_mcp_tool("precommit", precommit_params)
|
||||||
@@ -124,8 +125,9 @@ def divide(x, y):
|
|||||||
self.logger.info(" Step 2: codereview tool with same file (fresh conversation)")
|
self.logger.info(" Step 2: codereview tool with same file (fresh conversation)")
|
||||||
codereview_params = {
|
codereview_params = {
|
||||||
"files": [dummy_file_path],
|
"files": [dummy_file_path],
|
||||||
"context": "Please give me a quick one line reply. General code review for quality and best practices",
|
"prompt": "Please give me a quick one line reply. General code review for quality and best practices",
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
|
"model": "flash",
|
||||||
}
|
}
|
||||||
|
|
||||||
response2, _ = self.call_mcp_tool("codereview", codereview_params)
|
response2, _ = self.call_mcp_tool("codereview", codereview_params)
|
||||||
@@ -150,8 +152,9 @@ def subtract(a, b):
|
|||||||
"continuation_id": continuation_id,
|
"continuation_id": continuation_id,
|
||||||
"path": self.test_dir, # Required path parameter
|
"path": self.test_dir, # Required path parameter
|
||||||
"files": [dummy_file_path, new_file_path], # Old + new file
|
"files": [dummy_file_path, new_file_path], # Old + new file
|
||||||
"original_request": "Please give me a quick one line reply. Now also review the new feature file along with the previous one",
|
"prompt": "Please give me a quick one line reply. Now also review the new feature file along with the previous one",
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
|
"model": "flash",
|
||||||
}
|
}
|
||||||
|
|
||||||
response3, _ = self.call_mcp_tool("precommit", continue_params)
|
response3, _ = self.call_mcp_tool("precommit", continue_params)
|
||||||
|
|||||||
@@ -15,9 +15,20 @@ parent_dir = Path(__file__).resolve().parent.parent
|
|||||||
if str(parent_dir) not in sys.path:
|
if str(parent_dir) not in sys.path:
|
||||||
sys.path.insert(0, str(parent_dir))
|
sys.path.insert(0, str(parent_dir))
|
||||||
|
|
||||||
# Set dummy API key for tests if not already set
|
# Set dummy API keys for tests if not already set
|
||||||
if "GEMINI_API_KEY" not in os.environ:
|
if "GEMINI_API_KEY" not in os.environ:
|
||||||
os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
|
os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
os.environ["OPENAI_API_KEY"] = "dummy-key-for-tests"
|
||||||
|
|
||||||
|
# Set default model to a specific value for tests to avoid auto mode
|
||||||
|
# This prevents all tests from failing due to missing model parameter
|
||||||
|
os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash-exp"
|
||||||
|
|
||||||
|
# Force reload of config module to pick up the env var
|
||||||
|
import importlib
|
||||||
|
import config
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
# Set MCP_PROJECT_ROOT to a temporary directory for tests
|
# Set MCP_PROJECT_ROOT to a temporary directory for tests
|
||||||
# This provides a safe sandbox for file operations during testing
|
# This provides a safe sandbox for file operations during testing
|
||||||
@@ -29,6 +40,16 @@ os.environ["MCP_PROJECT_ROOT"] = test_root
|
|||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
|
# Register providers for all tests
|
||||||
|
from providers import ModelProviderRegistry
|
||||||
|
from providers.gemini import GeminiModelProvider
|
||||||
|
from providers.openai import OpenAIModelProvider
|
||||||
|
from providers.base import ProviderType
|
||||||
|
|
||||||
|
# Register providers at test startup
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def project_path(tmp_path):
|
def project_path(tmp_path):
|
||||||
|
|||||||
39
tests/mock_helpers.py
Normal file
39
tests/mock_helpers.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""Helper functions for test mocking."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
from providers.base import ModelCapabilities, ProviderType
|
||||||
|
|
||||||
|
def create_mock_provider(model_name="gemini-2.0-flash-exp", max_tokens=1_048_576):
|
||||||
|
"""Create a properly configured mock provider."""
|
||||||
|
mock_provider = Mock()
|
||||||
|
|
||||||
|
# Set up capabilities
|
||||||
|
mock_capabilities = ModelCapabilities(
|
||||||
|
provider=ProviderType.GOOGLE,
|
||||||
|
model_name=model_name,
|
||||||
|
friendly_name="Gemini",
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
temperature_range=(0.0, 2.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_provider.get_capabilities.return_value = mock_capabilities
|
||||||
|
mock_provider.get_provider_type.return_value = ProviderType.GOOGLE
|
||||||
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.validate_model_name.return_value = True
|
||||||
|
|
||||||
|
# Set up generate_content response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.content = "Test response"
|
||||||
|
mock_response.usage = {"input_tokens": 10, "output_tokens": 20}
|
||||||
|
mock_response.model_name = model_name
|
||||||
|
mock_response.friendly_name = "Gemini"
|
||||||
|
mock_response.provider = ProviderType.GOOGLE
|
||||||
|
mock_response.metadata = {"finish_reason": "STOP"}
|
||||||
|
|
||||||
|
mock_provider.generate_content.return_value = mock_response
|
||||||
|
|
||||||
|
return mock_provider
|
||||||
180
tests/test_auto_mode.py
Normal file
180
tests/test_auto_mode.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""Tests for auto mode functionality"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from mcp.types import TextContent
|
||||||
|
from tools.analyze import AnalyzeTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestAutoMode:
|
||||||
|
"""Test auto mode configuration and behavior"""
|
||||||
|
|
||||||
|
def test_auto_mode_detection(self):
|
||||||
|
"""Test that auto mode is detected correctly"""
|
||||||
|
# Save original
|
||||||
|
original = os.environ.get("DEFAULT_MODEL", "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test auto mode
|
||||||
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
import config
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
assert config.DEFAULT_MODEL == "auto"
|
||||||
|
assert config.IS_AUTO_MODE is True
|
||||||
|
|
||||||
|
# Test non-auto mode
|
||||||
|
os.environ["DEFAULT_MODEL"] = "pro"
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
assert config.DEFAULT_MODEL == "pro"
|
||||||
|
assert config.IS_AUTO_MODE is False
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore
|
||||||
|
if original:
|
||||||
|
os.environ["DEFAULT_MODEL"] = original
|
||||||
|
else:
|
||||||
|
os.environ.pop("DEFAULT_MODEL", None)
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
def test_model_capabilities_descriptions(self):
|
||||||
|
"""Test that model capabilities are properly defined"""
|
||||||
|
from config import MODEL_CAPABILITIES_DESC
|
||||||
|
|
||||||
|
# Check all expected models are present
|
||||||
|
expected_models = ["flash", "pro", "o3", "o3-mini", "gpt-4o"]
|
||||||
|
for model in expected_models:
|
||||||
|
assert model in MODEL_CAPABILITIES_DESC
|
||||||
|
assert isinstance(MODEL_CAPABILITIES_DESC[model], str)
|
||||||
|
assert len(MODEL_CAPABILITIES_DESC[model]) > 50 # Meaningful description
|
||||||
|
|
||||||
|
def test_tool_schema_in_auto_mode(self):
|
||||||
|
"""Test that tool schemas require model in auto mode"""
|
||||||
|
# Save original
|
||||||
|
original = os.environ.get("DEFAULT_MODEL", "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Enable auto mode
|
||||||
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
import config
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
tool = AnalyzeTool()
|
||||||
|
schema = tool.get_input_schema()
|
||||||
|
|
||||||
|
# Model should be required
|
||||||
|
assert "model" in schema["required"]
|
||||||
|
|
||||||
|
# Model field should have detailed descriptions
|
||||||
|
model_schema = schema["properties"]["model"]
|
||||||
|
assert "enum" in model_schema
|
||||||
|
assert "flash" in model_schema["enum"]
|
||||||
|
assert "Choose the best model" in model_schema["description"]
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore
|
||||||
|
if original:
|
||||||
|
os.environ["DEFAULT_MODEL"] = original
|
||||||
|
else:
|
||||||
|
os.environ.pop("DEFAULT_MODEL", None)
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
def test_tool_schema_in_normal_mode(self):
|
||||||
|
"""Test that tool schemas don't require model in normal mode"""
|
||||||
|
# This test uses the default from conftest.py which sets non-auto mode
|
||||||
|
tool = AnalyzeTool()
|
||||||
|
schema = tool.get_input_schema()
|
||||||
|
|
||||||
|
# Model should not be required
|
||||||
|
assert "model" not in schema["required"]
|
||||||
|
|
||||||
|
# Model field should have simpler description
|
||||||
|
model_schema = schema["properties"]["model"]
|
||||||
|
assert "enum" not in model_schema
|
||||||
|
assert "Available:" in model_schema["description"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_mode_requires_model_parameter(self):
|
||||||
|
"""Test that auto mode enforces model parameter"""
|
||||||
|
# Save original
|
||||||
|
original = os.environ.get("DEFAULT_MODEL", "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Enable auto mode
|
||||||
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
import config
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
tool = AnalyzeTool()
|
||||||
|
|
||||||
|
# Mock the provider to avoid real API calls
|
||||||
|
with patch.object(tool, 'get_model_provider') as mock_provider:
|
||||||
|
# Execute without model parameter
|
||||||
|
result = await tool.execute({
|
||||||
|
"files": ["/tmp/test.py"],
|
||||||
|
"prompt": "Analyze this"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Should get error
|
||||||
|
assert len(result) == 1
|
||||||
|
response = result[0].text
|
||||||
|
assert "error" in response
|
||||||
|
assert "Model parameter is required" in response
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore
|
||||||
|
if original:
|
||||||
|
os.environ["DEFAULT_MODEL"] = original
|
||||||
|
else:
|
||||||
|
os.environ.pop("DEFAULT_MODEL", None)
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
def test_model_field_schema_generation(self):
|
||||||
|
"""Test the get_model_field_schema method"""
|
||||||
|
from tools.base import BaseTool
|
||||||
|
|
||||||
|
# Create a minimal concrete tool for testing
|
||||||
|
class TestTool(BaseTool):
|
||||||
|
def get_name(self): return "test"
|
||||||
|
def get_description(self): return "test"
|
||||||
|
def get_input_schema(self): return {}
|
||||||
|
def get_system_prompt(self): return ""
|
||||||
|
def get_request_model(self): return None
|
||||||
|
async def prepare_prompt(self, request): return ""
|
||||||
|
|
||||||
|
tool = TestTool()
|
||||||
|
|
||||||
|
# Save original
|
||||||
|
original = os.environ.get("DEFAULT_MODEL", "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test auto mode
|
||||||
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
import config
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
schema = tool.get_model_field_schema()
|
||||||
|
assert "enum" in schema
|
||||||
|
assert all(model in schema["enum"] for model in ["flash", "pro", "o3"])
|
||||||
|
assert "Choose the best model" in schema["description"]
|
||||||
|
|
||||||
|
# Test normal mode
|
||||||
|
os.environ["DEFAULT_MODEL"] = "pro"
|
||||||
|
importlib.reload(config)
|
||||||
|
|
||||||
|
schema = tool.get_model_field_schema()
|
||||||
|
assert "enum" not in schema
|
||||||
|
assert "Available:" in schema["description"]
|
||||||
|
assert "'pro'" in schema["description"]
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore
|
||||||
|
if original:
|
||||||
|
os.environ["DEFAULT_MODEL"] = original
|
||||||
|
else:
|
||||||
|
os.environ.pop("DEFAULT_MODEL", None)
|
||||||
|
importlib.reload(config)
|
||||||
@@ -7,6 +7,7 @@ when Gemini doesn't explicitly ask a follow-up question.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@@ -116,20 +117,20 @@ class TestClaudeContinuationOffers:
|
|||||||
mock_redis.return_value = mock_client
|
mock_redis.return_value = mock_client
|
||||||
|
|
||||||
# Mock the model to return a response without follow-up question
|
# Mock the model to return a response without follow-up question
|
||||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=Mock(parts=[Mock(text="Analysis complete. The code looks good.")]),
|
content="Analysis complete. The code looks good.",
|
||||||
finish_reason="STOP",
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
)
|
model_name="gemini-2.0-flash-exp",
|
||||||
]
|
metadata={"finish_reason": "STOP"}
|
||||||
mock_model.generate_content.return_value = mock_response
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Execute tool with new conversation
|
# Execute tool with new conversation
|
||||||
arguments = {"prompt": "Analyze this code"}
|
arguments = {"prompt": "Analyze this code", "model": "flash"}
|
||||||
response = await self.tool.execute(arguments)
|
response = await self.tool.execute(arguments)
|
||||||
|
|
||||||
# Parse response
|
# Parse response
|
||||||
@@ -157,15 +158,12 @@ class TestClaudeContinuationOffers:
|
|||||||
mock_redis.return_value = mock_client
|
mock_redis.return_value = mock_client
|
||||||
|
|
||||||
# Mock the model to return a response WITH follow-up question
|
# Mock the model to return a response WITH follow-up question
|
||||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
# Include follow-up JSON in the content
|
||||||
content=Mock(
|
content_with_followup = """Analysis complete. The code looks good.
|
||||||
parts=[
|
|
||||||
Mock(
|
|
||||||
text="""Analysis complete. The code looks good.
|
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@@ -174,14 +172,13 @@ class TestClaudeContinuationOffers:
|
|||||||
"ui_hint": "Examining error handling would help ensure robustness"
|
"ui_hint": "Examining error handling would help ensure robustness"
|
||||||
}
|
}
|
||||||
```"""
|
```"""
|
||||||
)
|
mock_provider.generate_content.return_value = Mock(
|
||||||
]
|
content=content_with_followup,
|
||||||
),
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
finish_reason="STOP",
|
model_name="gemini-2.0-flash-exp",
|
||||||
)
|
metadata={"finish_reason": "STOP"}
|
||||||
]
|
)
|
||||||
mock_model.generate_content.return_value = mock_response
|
mock_get_provider.return_value = mock_provider
|
||||||
mock_create_model.return_value = mock_model
|
|
||||||
|
|
||||||
# Execute tool
|
# Execute tool
|
||||||
arguments = {"prompt": "Analyze this code"}
|
arguments = {"prompt": "Analyze this code"}
|
||||||
@@ -215,17 +212,17 @@ class TestClaudeContinuationOffers:
|
|||||||
mock_client.get.return_value = thread_context.model_dump_json()
|
mock_client.get.return_value = thread_context.model_dump_json()
|
||||||
|
|
||||||
# Mock the model
|
# Mock the model
|
||||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=Mock(parts=[Mock(text="Continued analysis complete.")]),
|
content="Continued analysis complete.",
|
||||||
finish_reason="STOP",
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
)
|
model_name="gemini-2.0-flash-exp",
|
||||||
]
|
metadata={"finish_reason": "STOP"}
|
||||||
mock_model.generate_content.return_value = mock_response
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Execute tool with continuation_id
|
# Execute tool with continuation_id
|
||||||
arguments = {"prompt": "Continue the analysis", "continuation_id": "12345678-1234-1234-1234-123456789012"}
|
arguments = {"prompt": "Continue the analysis", "continuation_id": "12345678-1234-1234-1234-123456789012"}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Tests for dynamic context request and collaboration features
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -24,8 +25,8 @@ class TestDynamicContextRequests:
|
|||||||
return DebugIssueTool()
|
return DebugIssueTool()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_clarification_request_parsing(self, mock_create_model, analyze_tool):
|
async def test_clarification_request_parsing(self, mock_get_provider, analyze_tool):
|
||||||
"""Test that tools correctly parse clarification requests"""
|
"""Test that tools correctly parse clarification requests"""
|
||||||
# Mock model to return a clarification request
|
# Mock model to return a clarification request
|
||||||
clarification_json = json.dumps(
|
clarification_json = json.dumps(
|
||||||
@@ -36,16 +37,21 @@ class TestDynamicContextRequests:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content=clarification_json,
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await analyze_tool.execute(
|
result = await analyze_tool.execute(
|
||||||
{
|
{
|
||||||
"files": ["/absolute/path/src/index.js"],
|
"files": ["/absolute/path/src/index.js"],
|
||||||
"question": "Analyze the dependencies used in this project",
|
"prompt": "Analyze the dependencies used in this project",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,8 +68,8 @@ class TestDynamicContextRequests:
|
|||||||
assert clarification["files_needed"] == ["package.json", "package-lock.json"]
|
assert clarification["files_needed"] == ["package.json", "package-lock.json"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_normal_response_not_parsed_as_clarification(self, mock_create_model, debug_tool):
|
async def test_normal_response_not_parsed_as_clarification(self, mock_get_provider, debug_tool):
|
||||||
"""Test that normal responses are not mistaken for clarification requests"""
|
"""Test that normal responses are not mistaken for clarification requests"""
|
||||||
normal_response = """
|
normal_response = """
|
||||||
## Summary
|
## Summary
|
||||||
@@ -75,13 +81,18 @@ class TestDynamicContextRequests:
|
|||||||
**Root Cause:** The module 'utils' is not imported
|
**Root Cause:** The module 'utils' is not imported
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text=normal_response)]))]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content=normal_response,
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await debug_tool.execute({"error_description": "NameError: name 'utils' is not defined"})
|
result = await debug_tool.execute({"prompt": "NameError: name 'utils' is not defined"})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
|
||||||
@@ -92,18 +103,23 @@ class TestDynamicContextRequests:
|
|||||||
assert "Summary" in response_data["content"]
|
assert "Summary" in response_data["content"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_malformed_clarification_request_treated_as_normal(self, mock_create_model, analyze_tool):
|
async def test_malformed_clarification_request_treated_as_normal(self, mock_get_provider, analyze_tool):
|
||||||
"""Test that malformed JSON clarification requests are treated as normal responses"""
|
"""Test that malformed JSON clarification requests are treated as normal responses"""
|
||||||
malformed_json = '{"status": "requires_clarification", "question": "Missing closing brace"'
|
malformed_json = '{"status": "requires_clarification", "prompt": "Missing closing brace"'
|
||||||
|
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text=malformed_json)]))]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content=malformed_json,
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "question": "What does this do?"})
|
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "prompt": "What does this do?"})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
|
||||||
@@ -113,8 +129,8 @@ class TestDynamicContextRequests:
|
|||||||
assert malformed_json in response_data["content"]
|
assert malformed_json in response_data["content"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_clarification_with_suggested_action(self, mock_create_model, debug_tool):
|
async def test_clarification_with_suggested_action(self, mock_get_provider, debug_tool):
|
||||||
"""Test clarification request with suggested next action"""
|
"""Test clarification request with suggested next action"""
|
||||||
clarification_json = json.dumps(
|
clarification_json = json.dumps(
|
||||||
{
|
{
|
||||||
@@ -124,7 +140,7 @@ class TestDynamicContextRequests:
|
|||||||
"suggested_next_action": {
|
"suggested_next_action": {
|
||||||
"tool": "debug",
|
"tool": "debug",
|
||||||
"args": {
|
"args": {
|
||||||
"error_description": "Connection timeout to database",
|
"prompt": "Connection timeout to database",
|
||||||
"files": [
|
"files": [
|
||||||
"/config/database.yml",
|
"/config/database.yml",
|
||||||
"/src/db.py",
|
"/src/db.py",
|
||||||
@@ -135,15 +151,20 @@ class TestDynamicContextRequests:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content=clarification_json,
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await debug_tool.execute(
|
result = await debug_tool.execute(
|
||||||
{
|
{
|
||||||
"error_description": "Connection timeout to database",
|
"prompt": "Connection timeout to database",
|
||||||
"files": ["/absolute/logs/error.log"],
|
"files": ["/absolute/logs/error.log"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -187,12 +208,12 @@ class TestDynamicContextRequests:
|
|||||||
assert request.suggested_next_action["tool"] == "analyze"
|
assert request.suggested_next_action["tool"] == "analyze"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_error_response_format(self, mock_create_model, analyze_tool):
|
async def test_error_response_format(self, mock_get_provider, analyze_tool):
|
||||||
"""Test error response format"""
|
"""Test error response format"""
|
||||||
mock_create_model.side_effect = Exception("API connection failed")
|
mock_get_provider.side_effect = Exception("API connection failed")
|
||||||
|
|
||||||
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "question": "Analyze this"})
|
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "prompt": "Analyze this"})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
|
||||||
@@ -206,8 +227,8 @@ class TestCollaborationWorkflow:
|
|||||||
"""Test complete collaboration workflows"""
|
"""Test complete collaboration workflows"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_dependency_analysis_triggers_clarification(self, mock_create_model):
|
async def test_dependency_analysis_triggers_clarification(self, mock_get_provider):
|
||||||
"""Test that asking about dependencies without package files triggers clarification"""
|
"""Test that asking about dependencies without package files triggers clarification"""
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
|
|
||||||
@@ -220,17 +241,22 @@ class TestCollaborationWorkflow:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content=clarification_json,
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Ask about dependencies with only source files
|
# Ask about dependencies with only source files
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"files": ["/absolute/path/src/index.js"],
|
"files": ["/absolute/path/src/index.js"],
|
||||||
"question": "What npm packages and versions does this project use?",
|
"prompt": "What npm packages and versions does this project use?",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -243,8 +269,8 @@ class TestCollaborationWorkflow:
|
|||||||
assert "package.json" in str(clarification["files_needed"]), "Should specifically request package.json"
|
assert "package.json" in str(clarification["files_needed"]), "Should specifically request package.json"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_multi_step_collaboration(self, mock_create_model):
|
async def test_multi_step_collaboration(self, mock_get_provider):
|
||||||
"""Test a multi-step collaboration workflow"""
|
"""Test a multi-step collaboration workflow"""
|
||||||
tool = DebugIssueTool()
|
tool = DebugIssueTool()
|
||||||
|
|
||||||
@@ -257,15 +283,20 @@ class TestCollaborationWorkflow:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content=clarification_json,
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result1 = await tool.execute(
|
result1 = await tool.execute(
|
||||||
{
|
{
|
||||||
"error_description": "Database connection timeout",
|
"prompt": "Database connection timeout",
|
||||||
"error_context": "Timeout after 30s",
|
"error_context": "Timeout after 30s",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -285,13 +316,16 @@ class TestCollaborationWorkflow:
|
|||||||
**Root Cause:** The config.py file shows the database host is set to 'localhost' but the database is running on a different server.
|
**Root Cause:** The config.py file shows the database host is set to 'localhost' but the database is running on a different server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text=final_response)]))]
|
content=final_response,
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
|
|
||||||
result2 = await tool.execute(
|
result2 = await tool.execute(
|
||||||
{
|
{
|
||||||
"error_description": "Database connection timeout",
|
"prompt": "Database connection timeout",
|
||||||
"error_context": "Timeout after 30s",
|
"error_context": "Timeout after 30s",
|
||||||
"files": ["/absolute/path/config.py"], # Additional context provided
|
"files": ["/absolute/path/config.py"], # Additional context provided
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ class TestConfig:
|
|||||||
|
|
||||||
def test_model_config(self):
|
def test_model_config(self):
|
||||||
"""Test model configuration"""
|
"""Test model configuration"""
|
||||||
assert DEFAULT_MODEL == "gemini-2.5-pro-preview-06-05"
|
# DEFAULT_MODEL is set in conftest.py for tests
|
||||||
|
assert DEFAULT_MODEL == "gemini-2.0-flash-exp"
|
||||||
assert MAX_CONTEXT_TOKENS == 1_000_000
|
assert MAX_CONTEXT_TOKENS == 1_000_000
|
||||||
|
|
||||||
def test_temperature_defaults(self):
|
def test_temperature_defaults(self):
|
||||||
|
|||||||
171
tests/test_conversation_field_mapping.py
Normal file
171
tests/test_conversation_field_mapping.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""
|
||||||
|
Test that conversation history is correctly mapped to tool-specific fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from server import reconstruct_thread_context
|
||||||
|
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||||
|
from providers.base import ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_conversation_history_field_mapping():
|
||||||
|
"""Test that enhanced prompts are mapped to prompt field for all tools"""
|
||||||
|
|
||||||
|
# Test data for different tools - all use 'prompt' now
|
||||||
|
test_cases = [
|
||||||
|
{
|
||||||
|
"tool_name": "analyze",
|
||||||
|
"original_value": "What does this code do?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tool_name": "chat",
|
||||||
|
"original_value": "Explain this concept",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tool_name": "debug",
|
||||||
|
"original_value": "Getting undefined error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tool_name": "codereview",
|
||||||
|
"original_value": "Review this implementation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tool_name": "thinkdeep",
|
||||||
|
"original_value": "My analysis so far",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for test_case in test_cases:
|
||||||
|
# Create mock conversation context
|
||||||
|
mock_context = ThreadContext(
|
||||||
|
thread_id="test-thread-123",
|
||||||
|
tool_name=test_case["tool_name"],
|
||||||
|
created_at=datetime.now().isoformat(),
|
||||||
|
last_updated_at=datetime.now().isoformat(),
|
||||||
|
turns=[
|
||||||
|
ConversationTurn(
|
||||||
|
role="user",
|
||||||
|
content="Previous user message",
|
||||||
|
timestamp=datetime.now().isoformat(),
|
||||||
|
files=["/test/file1.py"],
|
||||||
|
),
|
||||||
|
ConversationTurn(
|
||||||
|
role="assistant",
|
||||||
|
content="Previous assistant response",
|
||||||
|
timestamp=datetime.now().isoformat(),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
initial_context={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock get_thread to return our test context
|
||||||
|
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
|
||||||
|
with patch("utils.conversation_memory.add_turn", return_value=True):
|
||||||
|
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
|
||||||
|
# Mock provider registry to avoid model lookup errors
|
||||||
|
with patch("providers.registry.ModelProviderRegistry.get_provider_for_model") as mock_get_provider:
|
||||||
|
from providers.base import ModelCapabilities
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.get_capabilities.return_value = ModelCapabilities(
|
||||||
|
provider=ProviderType.GOOGLE,
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
friendly_name="Gemini",
|
||||||
|
max_tokens=200000,
|
||||||
|
supports_extended_thinking=True
|
||||||
|
)
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
# Mock conversation history building
|
||||||
|
mock_build.return_value = (
|
||||||
|
"=== CONVERSATION HISTORY ===\nPrevious conversation content\n=== END HISTORY ===",
|
||||||
|
1000 # mock token count
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create arguments with continuation_id
|
||||||
|
arguments = {
|
||||||
|
"continuation_id": "test-thread-123",
|
||||||
|
"prompt": test_case["original_value"],
|
||||||
|
"files": ["/test/file2.py"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Call reconstruct_thread_context
|
||||||
|
enhanced_args = await reconstruct_thread_context(arguments)
|
||||||
|
|
||||||
|
# Verify the enhanced prompt is in the prompt field
|
||||||
|
assert "prompt" in enhanced_args
|
||||||
|
enhanced_value = enhanced_args["prompt"]
|
||||||
|
|
||||||
|
# Should contain conversation history
|
||||||
|
assert "=== CONVERSATION HISTORY ===" in enhanced_value
|
||||||
|
assert "Previous conversation content" in enhanced_value
|
||||||
|
|
||||||
|
# Should contain the new user input
|
||||||
|
assert "=== NEW USER INPUT ===" in enhanced_value
|
||||||
|
assert test_case["original_value"] in enhanced_value
|
||||||
|
|
||||||
|
# Should have token budget
|
||||||
|
assert "_remaining_tokens" in enhanced_args
|
||||||
|
assert enhanced_args["_remaining_tokens"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unknown_tool_defaults_to_prompt():
|
||||||
|
"""Test that unknown tools default to using 'prompt' field"""
|
||||||
|
|
||||||
|
mock_context = ThreadContext(
|
||||||
|
thread_id="test-thread-456",
|
||||||
|
tool_name="unknown_tool",
|
||||||
|
created_at=datetime.now().isoformat(),
|
||||||
|
last_updated_at=datetime.now().isoformat(),
|
||||||
|
turns=[],
|
||||||
|
initial_context={},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
|
||||||
|
with patch("utils.conversation_memory.add_turn", return_value=True):
|
||||||
|
with patch("utils.conversation_memory.build_conversation_history", return_value=("History", 500)):
|
||||||
|
arguments = {
|
||||||
|
"continuation_id": "test-thread-456",
|
||||||
|
"prompt": "User input",
|
||||||
|
}
|
||||||
|
|
||||||
|
enhanced_args = await reconstruct_thread_context(arguments)
|
||||||
|
|
||||||
|
# Should default to 'prompt' field
|
||||||
|
assert "prompt" in enhanced_args
|
||||||
|
assert "History" in enhanced_args["prompt"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_parameter_standardization():
|
||||||
|
"""Test that all tools use standardized 'prompt' parameter"""
|
||||||
|
from tools.analyze import AnalyzeRequest
|
||||||
|
from tools.debug import DebugIssueRequest
|
||||||
|
from tools.codereview import CodeReviewRequest
|
||||||
|
from tools.thinkdeep import ThinkDeepRequest
|
||||||
|
from tools.precommit import PrecommitRequest
|
||||||
|
|
||||||
|
# Test analyze tool uses prompt
|
||||||
|
analyze = AnalyzeRequest(files=["/test.py"], prompt="What does this do?")
|
||||||
|
assert analyze.prompt == "What does this do?"
|
||||||
|
|
||||||
|
# Test debug tool uses prompt
|
||||||
|
debug = DebugIssueRequest(prompt="Error occurred")
|
||||||
|
assert debug.prompt == "Error occurred"
|
||||||
|
|
||||||
|
# Test codereview tool uses prompt
|
||||||
|
review = CodeReviewRequest(files=["/test.py"], prompt="Review this")
|
||||||
|
assert review.prompt == "Review this"
|
||||||
|
|
||||||
|
# Test thinkdeep tool uses prompt
|
||||||
|
think = ThinkDeepRequest(prompt="My analysis")
|
||||||
|
assert think.prompt == "My analysis"
|
||||||
|
|
||||||
|
# Test precommit tool uses prompt (optional)
|
||||||
|
precommit = PrecommitRequest(path="/repo", prompt="Fix bug")
|
||||||
|
assert precommit.prompt == "Fix bug"
|
||||||
@@ -12,6 +12,7 @@ Claude had shared in earlier turns.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@@ -94,7 +95,7 @@ class TestConversationHistoryBugFix:
|
|||||||
files=["/src/auth.py", "/tests/test_auth.py"], # Files from codereview tool
|
files=["/src/auth.py", "/tests/test_auth.py"], # Files from codereview tool
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
initial_context={"question": "Analyze authentication security"},
|
initial_context={"prompt": "Analyze authentication security"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock add_turn to return success
|
# Mock add_turn to return success
|
||||||
@@ -103,23 +104,23 @@ class TestConversationHistoryBugFix:
|
|||||||
# Mock the model to capture what prompt it receives
|
# Mock the model to capture what prompt it receives
|
||||||
captured_prompt = None
|
captured_prompt = None
|
||||||
|
|
||||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
|
||||||
content=Mock(parts=[Mock(text="Response with conversation context")]),
|
|
||||||
finish_reason="STOP",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def capture_prompt(prompt):
|
def capture_prompt(prompt, **kwargs):
|
||||||
nonlocal captured_prompt
|
nonlocal captured_prompt
|
||||||
captured_prompt = prompt
|
captured_prompt = prompt
|
||||||
return mock_response
|
return Mock(
|
||||||
|
content="Response with conversation context",
|
||||||
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={"finish_reason": "STOP"}
|
||||||
|
)
|
||||||
|
|
||||||
mock_model.generate_content.side_effect = capture_prompt
|
mock_provider.generate_content.side_effect = capture_prompt
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Execute tool with continuation_id
|
# Execute tool with continuation_id
|
||||||
# In the corrected flow, server.py:reconstruct_thread_context
|
# In the corrected flow, server.py:reconstruct_thread_context
|
||||||
@@ -163,23 +164,23 @@ class TestConversationHistoryBugFix:
|
|||||||
|
|
||||||
captured_prompt = None
|
captured_prompt = None
|
||||||
|
|
||||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
|
||||||
content=Mock(parts=[Mock(text="Response without history")]),
|
|
||||||
finish_reason="STOP",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def capture_prompt(prompt):
|
def capture_prompt(prompt, **kwargs):
|
||||||
nonlocal captured_prompt
|
nonlocal captured_prompt
|
||||||
captured_prompt = prompt
|
captured_prompt = prompt
|
||||||
return mock_response
|
return Mock(
|
||||||
|
content="Response without history",
|
||||||
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={"finish_reason": "STOP"}
|
||||||
|
)
|
||||||
|
|
||||||
mock_model.generate_content.side_effect = capture_prompt
|
mock_provider.generate_content.side_effect = capture_prompt
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Execute tool with continuation_id for non-existent thread
|
# Execute tool with continuation_id for non-existent thread
|
||||||
# In the real flow, server.py would have already handled the missing thread
|
# In the real flow, server.py would have already handled the missing thread
|
||||||
@@ -201,23 +202,23 @@ class TestConversationHistoryBugFix:
|
|||||||
|
|
||||||
captured_prompt = None
|
captured_prompt = None
|
||||||
|
|
||||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
|
||||||
content=Mock(parts=[Mock(text="New conversation response")]),
|
|
||||||
finish_reason="STOP",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def capture_prompt(prompt):
|
def capture_prompt(prompt, **kwargs):
|
||||||
nonlocal captured_prompt
|
nonlocal captured_prompt
|
||||||
captured_prompt = prompt
|
captured_prompt = prompt
|
||||||
return mock_response
|
return Mock(
|
||||||
|
content="New conversation response",
|
||||||
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={"finish_reason": "STOP"}
|
||||||
|
)
|
||||||
|
|
||||||
mock_model.generate_content.side_effect = capture_prompt
|
mock_provider.generate_content.side_effect = capture_prompt
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Execute tool without continuation_id (new conversation)
|
# Execute tool without continuation_id (new conversation)
|
||||||
arguments = {"prompt": "Start new conversation", "files": ["/src/new_file.py"]}
|
arguments = {"prompt": "Start new conversation", "files": ["/src/new_file.py"]}
|
||||||
@@ -275,7 +276,7 @@ class TestConversationHistoryBugFix:
|
|||||||
files=["/src/auth.py", "/tests/test_auth.py"], # auth.py referenced again + new file
|
files=["/src/auth.py", "/tests/test_auth.py"], # auth.py referenced again + new file
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
initial_context={"question": "Analyze authentication security"},
|
initial_context={"prompt": "Analyze authentication security"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock get_thread to return our test context
|
# Mock get_thread to return our test context
|
||||||
@@ -285,23 +286,23 @@ class TestConversationHistoryBugFix:
|
|||||||
# Mock the model to capture what prompt it receives
|
# Mock the model to capture what prompt it receives
|
||||||
captured_prompt = None
|
captured_prompt = None
|
||||||
|
|
||||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
|
||||||
content=Mock(parts=[Mock(text="Analysis of new files complete")]),
|
|
||||||
finish_reason="STOP",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def capture_prompt(prompt):
|
def capture_prompt(prompt, **kwargs):
|
||||||
nonlocal captured_prompt
|
nonlocal captured_prompt
|
||||||
captured_prompt = prompt
|
captured_prompt = prompt
|
||||||
return mock_response
|
return Mock(
|
||||||
|
content="Analysis of new files complete",
|
||||||
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={"finish_reason": "STOP"}
|
||||||
|
)
|
||||||
|
|
||||||
mock_model.generate_content.side_effect = capture_prompt
|
mock_provider.generate_content.side_effect = capture_prompt
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Mock read_files to simulate file existence and capture its calls
|
# Mock read_files to simulate file existence and capture its calls
|
||||||
with patch("tools.base.read_files") as mock_read_files:
|
with patch("tools.base.read_files") as mock_read_files:
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ class TestConversationMemory:
|
|||||||
initial_context={},
|
initial_context={},
|
||||||
)
|
)
|
||||||
|
|
||||||
history, tokens = build_conversation_history(context)
|
history, tokens = build_conversation_history(context, model_context=None)
|
||||||
|
|
||||||
# Test basic structure
|
# Test basic structure
|
||||||
assert "CONVERSATION HISTORY" in history
|
assert "CONVERSATION HISTORY" in history
|
||||||
@@ -207,7 +207,7 @@ class TestConversationMemory:
|
|||||||
initial_context={},
|
initial_context={},
|
||||||
)
|
)
|
||||||
|
|
||||||
history, tokens = build_conversation_history(context)
|
history, tokens = build_conversation_history(context, model_context=None)
|
||||||
assert history == ""
|
assert history == ""
|
||||||
assert tokens == 0
|
assert tokens == 0
|
||||||
|
|
||||||
@@ -374,7 +374,7 @@ class TestConversationFlow:
|
|||||||
initial_context={},
|
initial_context={},
|
||||||
)
|
)
|
||||||
|
|
||||||
history, tokens = build_conversation_history(context)
|
history, tokens = build_conversation_history(context, model_context=None)
|
||||||
expected_turn_text = f"Turn {test_max}/{MAX_CONVERSATION_TURNS}"
|
expected_turn_text = f"Turn {test_max}/{MAX_CONVERSATION_TURNS}"
|
||||||
assert expected_turn_text in history
|
assert expected_turn_text in history
|
||||||
|
|
||||||
@@ -763,7 +763,7 @@ class TestConversationFlow:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build conversation history (should handle token limits gracefully)
|
# Build conversation history (should handle token limits gracefully)
|
||||||
history, tokens = build_conversation_history(context)
|
history, tokens = build_conversation_history(context, model_context=None)
|
||||||
|
|
||||||
# Verify the history was built successfully
|
# Verify the history was built successfully
|
||||||
assert "=== CONVERSATION HISTORY ===" in history
|
assert "=== CONVERSATION HISTORY ===" in history
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ allowing multi-turn conversations to span multiple tool types.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@@ -98,15 +99,12 @@ class TestCrossToolContinuation:
|
|||||||
mock_redis.return_value = mock_client
|
mock_redis.return_value = mock_client
|
||||||
|
|
||||||
# Step 1: Analysis tool creates a conversation with follow-up
|
# Step 1: Analysis tool creates a conversation with follow-up
|
||||||
with patch.object(self.analysis_tool, "create_model") as mock_create_model:
|
with patch.object(self.analysis_tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
# Include follow-up JSON in the content
|
||||||
content=Mock(
|
content_with_followup = """Found potential security issues in authentication logic.
|
||||||
parts=[
|
|
||||||
Mock(
|
|
||||||
text="""Found potential security issues in authentication logic.
|
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@@ -115,14 +113,13 @@ class TestCrossToolContinuation:
|
|||||||
"ui_hint": "Security review recommended"
|
"ui_hint": "Security review recommended"
|
||||||
}
|
}
|
||||||
```"""
|
```"""
|
||||||
)
|
mock_provider.generate_content.return_value = Mock(
|
||||||
]
|
content=content_with_followup,
|
||||||
),
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
finish_reason="STOP",
|
model_name="gemini-2.0-flash-exp",
|
||||||
)
|
metadata={"finish_reason": "STOP"}
|
||||||
]
|
)
|
||||||
mock_model.generate_content.return_value = mock_response
|
mock_get_provider.return_value = mock_provider
|
||||||
mock_create_model.return_value = mock_model
|
|
||||||
|
|
||||||
# Execute analysis tool
|
# Execute analysis tool
|
||||||
arguments = {"code": "function authenticate(user) { return true; }"}
|
arguments = {"code": "function authenticate(user) { return true; }"}
|
||||||
@@ -160,23 +157,17 @@ class TestCrossToolContinuation:
|
|||||||
mock_client.get.side_effect = mock_get_side_effect
|
mock_client.get.side_effect = mock_get_side_effect
|
||||||
|
|
||||||
# Step 3: Review tool uses the same continuation_id
|
# Step 3: Review tool uses the same continuation_id
|
||||||
with patch.object(self.review_tool, "create_model") as mock_create_model:
|
with patch.object(self.review_tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=Mock(
|
content="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks.",
|
||||||
parts=[
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
Mock(
|
model_name="gemini-2.0-flash-exp",
|
||||||
text="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks."
|
metadata={"finish_reason": "STOP"}
|
||||||
)
|
)
|
||||||
]
|
mock_get_provider.return_value = mock_provider
|
||||||
),
|
|
||||||
finish_reason="STOP",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
mock_model.generate_content.return_value = mock_response
|
|
||||||
mock_create_model.return_value = mock_model
|
|
||||||
|
|
||||||
# Execute review tool with the continuation_id from analysis tool
|
# Execute review tool with the continuation_id from analysis tool
|
||||||
arguments = {
|
arguments = {
|
||||||
@@ -247,7 +238,7 @@ class TestCrossToolContinuation:
|
|||||||
# Build conversation history
|
# Build conversation history
|
||||||
from utils.conversation_memory import build_conversation_history
|
from utils.conversation_memory import build_conversation_history
|
||||||
|
|
||||||
history, tokens = build_conversation_history(thread_context)
|
history, tokens = build_conversation_history(thread_context, model_context=None)
|
||||||
|
|
||||||
# Verify tool names are included in the history
|
# Verify tool names are included in the history
|
||||||
assert "Turn 1 (Gemini using test_analysis)" in history
|
assert "Turn 1 (Gemini using test_analysis)" in history
|
||||||
@@ -286,17 +277,17 @@ class TestCrossToolContinuation:
|
|||||||
mock_get_thread.return_value = existing_context
|
mock_get_thread.return_value = existing_context
|
||||||
|
|
||||||
# Mock review tool response
|
# Mock review tool response
|
||||||
with patch.object(self.review_tool, "create_model") as mock_create_model:
|
with patch.object(self.review_tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=Mock(parts=[Mock(text="Security review of auth.py shows vulnerabilities")]),
|
content="Security review of auth.py shows vulnerabilities",
|
||||||
finish_reason="STOP",
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
)
|
model_name="gemini-2.0-flash-exp",
|
||||||
]
|
metadata={"finish_reason": "STOP"}
|
||||||
mock_model.generate_content.return_value = mock_response
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Execute review tool with additional files
|
# Execute review tool with additional files
|
||||||
arguments = {
|
arguments = {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from mcp.types import TextContent
|
from mcp.types import TextContent
|
||||||
@@ -68,17 +69,17 @@ class TestLargePromptHandling:
|
|||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
|
|
||||||
# Mock the model to avoid actual API calls
|
# Mock the model to avoid actual API calls
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
MagicMock(
|
mock_provider.generate_content.return_value = MagicMock(
|
||||||
content=MagicMock(parts=[MagicMock(text="This is a test response")]),
|
content="This is a test response",
|
||||||
finish_reason="STOP",
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
)
|
model_name="gemini-2.0-flash-exp",
|
||||||
]
|
metadata={"finish_reason": "STOP"}
|
||||||
mock_model.generate_content.return_value = mock_response
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute({"prompt": normal_prompt})
|
result = await tool.execute({"prompt": normal_prompt})
|
||||||
|
|
||||||
@@ -93,17 +94,17 @@ class TestLargePromptHandling:
|
|||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
|
|
||||||
# Mock the model
|
# Mock the model
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
MagicMock(
|
mock_provider.generate_content.return_value = MagicMock(
|
||||||
content=MagicMock(parts=[MagicMock(text="Processed large prompt")]),
|
content="Processed large prompt",
|
||||||
finish_reason="STOP",
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
)
|
model_name="gemini-2.0-flash-exp",
|
||||||
]
|
metadata={"finish_reason": "STOP"}
|
||||||
mock_model.generate_content.return_value = mock_response
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Mock read_file_content to avoid security checks
|
# Mock read_file_content to avoid security checks
|
||||||
with patch("tools.base.read_file_content") as mock_read_file:
|
with patch("tools.base.read_file_content") as mock_read_file:
|
||||||
@@ -123,8 +124,11 @@ class TestLargePromptHandling:
|
|||||||
mock_read_file.assert_called_once_with(temp_prompt_file)
|
mock_read_file.assert_called_once_with(temp_prompt_file)
|
||||||
|
|
||||||
# Verify the large content was used
|
# Verify the large content was used
|
||||||
call_args = mock_model.generate_content.call_args[0][0]
|
# generate_content is called with keyword arguments
|
||||||
assert large_prompt in call_args
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
|
prompt_arg = call_kwargs.get("prompt")
|
||||||
|
assert prompt_arg is not None
|
||||||
|
assert large_prompt in prompt_arg
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
temp_dir = os.path.dirname(temp_prompt_file)
|
temp_dir = os.path.dirname(temp_prompt_file)
|
||||||
@@ -134,7 +138,7 @@ class TestLargePromptHandling:
|
|||||||
async def test_thinkdeep_large_analysis(self, large_prompt):
|
async def test_thinkdeep_large_analysis(self, large_prompt):
|
||||||
"""Test that thinkdeep tool detects large current_analysis."""
|
"""Test that thinkdeep tool detects large current_analysis."""
|
||||||
tool = ThinkDeepTool()
|
tool = ThinkDeepTool()
|
||||||
result = await tool.execute({"current_analysis": large_prompt})
|
result = await tool.execute({"prompt": large_prompt})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
output = json.loads(result[0].text)
|
output = json.loads(result[0].text)
|
||||||
@@ -148,7 +152,7 @@ class TestLargePromptHandling:
|
|||||||
{
|
{
|
||||||
"files": ["/some/file.py"],
|
"files": ["/some/file.py"],
|
||||||
"focus_on": large_prompt,
|
"focus_on": large_prompt,
|
||||||
"context": "Test code review for validation purposes",
|
"prompt": "Test code review for validation purposes",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -160,7 +164,7 @@ class TestLargePromptHandling:
|
|||||||
async def test_review_changes_large_original_request(self, large_prompt):
|
async def test_review_changes_large_original_request(self, large_prompt):
|
||||||
"""Test that review_changes tool detects large original_request."""
|
"""Test that review_changes tool detects large original_request."""
|
||||||
tool = Precommit()
|
tool = Precommit()
|
||||||
result = await tool.execute({"path": "/some/path", "original_request": large_prompt})
|
result = await tool.execute({"path": "/some/path", "prompt": large_prompt})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
output = json.loads(result[0].text)
|
output = json.loads(result[0].text)
|
||||||
@@ -170,7 +174,7 @@ class TestLargePromptHandling:
|
|||||||
async def test_debug_large_error_description(self, large_prompt):
|
async def test_debug_large_error_description(self, large_prompt):
|
||||||
"""Test that debug tool detects large error_description."""
|
"""Test that debug tool detects large error_description."""
|
||||||
tool = DebugIssueTool()
|
tool = DebugIssueTool()
|
||||||
result = await tool.execute({"error_description": large_prompt})
|
result = await tool.execute({"prompt": large_prompt})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
output = json.loads(result[0].text)
|
output = json.loads(result[0].text)
|
||||||
@@ -180,7 +184,7 @@ class TestLargePromptHandling:
|
|||||||
async def test_debug_large_error_context(self, large_prompt, normal_prompt):
|
async def test_debug_large_error_context(self, large_prompt, normal_prompt):
|
||||||
"""Test that debug tool detects large error_context."""
|
"""Test that debug tool detects large error_context."""
|
||||||
tool = DebugIssueTool()
|
tool = DebugIssueTool()
|
||||||
result = await tool.execute({"error_description": normal_prompt, "error_context": large_prompt})
|
result = await tool.execute({"prompt": normal_prompt, "error_context": large_prompt})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
output = json.loads(result[0].text)
|
output = json.loads(result[0].text)
|
||||||
@@ -190,7 +194,7 @@ class TestLargePromptHandling:
|
|||||||
async def test_analyze_large_question(self, large_prompt):
|
async def test_analyze_large_question(self, large_prompt):
|
||||||
"""Test that analyze tool detects large question."""
|
"""Test that analyze tool detects large question."""
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
result = await tool.execute({"files": ["/some/file.py"], "question": large_prompt})
|
result = await tool.execute({"files": ["/some/file.py"], "prompt": large_prompt})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
output = json.loads(result[0].text)
|
output = json.loads(result[0].text)
|
||||||
@@ -202,17 +206,17 @@ class TestLargePromptHandling:
|
|||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
other_file = "/some/other/file.py"
|
other_file = "/some/other/file.py"
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
MagicMock(
|
mock_provider.generate_content.return_value = MagicMock(
|
||||||
content=MagicMock(parts=[MagicMock(text="Success")]),
|
content="Success",
|
||||||
finish_reason="STOP",
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
)
|
model_name="gemini-2.0-flash-exp",
|
||||||
]
|
metadata={"finish_reason": "STOP"}
|
||||||
mock_model.generate_content.return_value = mock_response
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Mock the centralized file preparation method to avoid file system access
|
# Mock the centralized file preparation method to avoid file system access
|
||||||
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
|
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
|
||||||
@@ -235,17 +239,17 @@ class TestLargePromptHandling:
|
|||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
exact_prompt = "x" * MCP_PROMPT_SIZE_LIMIT
|
exact_prompt = "x" * MCP_PROMPT_SIZE_LIMIT
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
MagicMock(
|
mock_provider.generate_content.return_value = MagicMock(
|
||||||
content=MagicMock(parts=[MagicMock(text="Success")]),
|
content="Success",
|
||||||
finish_reason="STOP",
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
)
|
model_name="gemini-2.0-flash-exp",
|
||||||
]
|
metadata={"finish_reason": "STOP"}
|
||||||
mock_model.generate_content.return_value = mock_response
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute({"prompt": exact_prompt})
|
result = await tool.execute({"prompt": exact_prompt})
|
||||||
output = json.loads(result[0].text)
|
output = json.loads(result[0].text)
|
||||||
@@ -266,17 +270,17 @@ class TestLargePromptHandling:
|
|||||||
"""Test empty prompt without prompt.txt file."""
|
"""Test empty prompt without prompt.txt file."""
|
||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
MagicMock(
|
mock_provider.generate_content.return_value = MagicMock(
|
||||||
content=MagicMock(parts=[MagicMock(text="Success")]),
|
content="Success",
|
||||||
finish_reason="STOP",
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
)
|
model_name="gemini-2.0-flash-exp",
|
||||||
]
|
metadata={"finish_reason": "STOP"}
|
||||||
mock_model.generate_content.return_value = mock_response
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute({"prompt": ""})
|
result = await tool.execute({"prompt": ""})
|
||||||
output = json.loads(result[0].text)
|
output = json.loads(result[0].text)
|
||||||
@@ -288,17 +292,17 @@ class TestLargePromptHandling:
|
|||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
bad_file = "/nonexistent/prompt.txt"
|
bad_file = "/nonexistent/prompt.txt"
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_response.candidates = [
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
MagicMock(
|
mock_provider.generate_content.return_value = MagicMock(
|
||||||
content=MagicMock(parts=[MagicMock(text="Success")]),
|
content="Success",
|
||||||
finish_reason="STOP",
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
)
|
model_name="gemini-2.0-flash-exp",
|
||||||
]
|
metadata={"finish_reason": "STOP"}
|
||||||
mock_model.generate_content.return_value = mock_response
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Should continue with empty prompt when file can't be read
|
# Should continue with empty prompt when file can't be read
|
||||||
result = await tool.execute({"prompt": "", "files": [bad_file]})
|
result = await tool.execute({"prompt": "", "files": [bad_file]})
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ async def run_manual_live_tests():
|
|||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"files": [temp_path],
|
"files": [temp_path],
|
||||||
"question": "What does this code do?",
|
"prompt": "What does this code do?",
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -64,7 +64,7 @@ async def run_manual_live_tests():
|
|||||||
think_tool = ThinkDeepTool()
|
think_tool = ThinkDeepTool()
|
||||||
result = await think_tool.execute(
|
result = await think_tool.execute(
|
||||||
{
|
{
|
||||||
"current_analysis": "Testing live integration",
|
"prompt": "Testing live integration",
|
||||||
"thinking_mode": "minimal", # Fast test
|
"thinking_mode": "minimal", # Fast test
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -86,7 +86,7 @@ async def run_manual_live_tests():
|
|||||||
result = await analyze_tool.execute(
|
result = await analyze_tool.execute(
|
||||||
{
|
{
|
||||||
"files": [temp_path], # Only Python file, no package.json
|
"files": [temp_path], # Only Python file, no package.json
|
||||||
"question": "What npm packages and their versions does this project depend on? List all dependencies.",
|
"prompt": "What npm packages and their versions does this project depend on? List all dependencies.",
|
||||||
"thinking_mode": "minimal", # Fast test
|
"thinking_mode": "minimal", # Fast test
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class TestPrecommitTool:
|
|||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
assert schema["type"] == "object"
|
assert schema["type"] == "object"
|
||||||
assert "path" in schema["properties"]
|
assert "path" in schema["properties"]
|
||||||
assert "original_request" in schema["properties"]
|
assert "prompt" in schema["properties"]
|
||||||
assert "compare_to" in schema["properties"]
|
assert "compare_to" in schema["properties"]
|
||||||
assert "review_type" in schema["properties"]
|
assert "review_type" in schema["properties"]
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ class TestPrecommitTool:
|
|||||||
"""Test request model default values"""
|
"""Test request model default values"""
|
||||||
request = PrecommitRequest(path="/some/absolute/path")
|
request = PrecommitRequest(path="/some/absolute/path")
|
||||||
assert request.path == "/some/absolute/path"
|
assert request.path == "/some/absolute/path"
|
||||||
assert request.original_request is None
|
assert request.prompt is None
|
||||||
assert request.compare_to is None
|
assert request.compare_to is None
|
||||||
assert request.include_staged is True
|
assert request.include_staged is True
|
||||||
assert request.include_unstaged is True
|
assert request.include_unstaged is True
|
||||||
@@ -48,7 +48,7 @@ class TestPrecommitTool:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_relative_path_rejected(self, tool):
|
async def test_relative_path_rejected(self, tool):
|
||||||
"""Test that relative paths are rejected"""
|
"""Test that relative paths are rejected"""
|
||||||
result = await tool.execute({"path": "./relative/path", "original_request": "Test"})
|
result = await tool.execute({"path": "./relative/path", "prompt": "Test"})
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
response = json.loads(result[0].text)
|
response = json.loads(result[0].text)
|
||||||
assert response["status"] == "error"
|
assert response["status"] == "error"
|
||||||
@@ -128,7 +128,7 @@ class TestPrecommitTool:
|
|||||||
|
|
||||||
request = PrecommitRequest(
|
request = PrecommitRequest(
|
||||||
path="/absolute/repo/path",
|
path="/absolute/repo/path",
|
||||||
original_request="Add hello message",
|
prompt="Add hello message",
|
||||||
review_type="security",
|
review_type="security",
|
||||||
)
|
)
|
||||||
result = await tool.prepare_prompt(request)
|
result = await tool.prepare_prompt(request)
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
|||||||
temp_dir, config_path = temp_repo
|
temp_dir, config_path = temp_repo
|
||||||
|
|
||||||
# Create request with files parameter
|
# Create request with files parameter
|
||||||
request = PrecommitRequest(path=temp_dir, files=[config_path], original_request="Test configuration changes")
|
request = PrecommitRequest(path=temp_dir, files=[config_path], prompt="Test configuration changes")
|
||||||
|
|
||||||
# Generate the prompt
|
# Generate the prompt
|
||||||
prompt = await tool.prepare_prompt(request)
|
prompt = await tool.prepare_prompt(request)
|
||||||
@@ -152,7 +152,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
|||||||
# Mock conversation memory functions to use our mock redis
|
# Mock conversation memory functions to use our mock redis
|
||||||
with patch("utils.conversation_memory.get_redis_client", return_value=mock_redis):
|
with patch("utils.conversation_memory.get_redis_client", return_value=mock_redis):
|
||||||
# First request - should embed file content
|
# First request - should embed file content
|
||||||
PrecommitRequest(path=temp_dir, files=[config_path], original_request="First review")
|
PrecommitRequest(path=temp_dir, files=[config_path], prompt="First review")
|
||||||
|
|
||||||
# Simulate conversation thread creation
|
# Simulate conversation thread creation
|
||||||
from utils.conversation_memory import add_turn, create_thread
|
from utils.conversation_memory import add_turn, create_thread
|
||||||
@@ -168,7 +168,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
|||||||
|
|
||||||
# Second request with continuation - should skip already embedded files
|
# Second request with continuation - should skip already embedded files
|
||||||
PrecommitRequest(
|
PrecommitRequest(
|
||||||
path=temp_dir, files=[config_path], continuation_id=thread_id, original_request="Follow-up review"
|
path=temp_dir, files=[config_path], continuation_id=thread_id, prompt="Follow-up review"
|
||||||
)
|
)
|
||||||
|
|
||||||
files_to_embed_2 = tool.filter_new_files([config_path], thread_id)
|
files_to_embed_2 = tool.filter_new_files([config_path], thread_id)
|
||||||
@@ -182,7 +182,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
|||||||
request = PrecommitRequest(
|
request = PrecommitRequest(
|
||||||
path=temp_dir,
|
path=temp_dir,
|
||||||
files=[config_path],
|
files=[config_path],
|
||||||
original_request="Validate prompt structure",
|
prompt="Validate prompt structure",
|
||||||
review_type="full",
|
review_type="full",
|
||||||
severity_filter="high",
|
severity_filter="high",
|
||||||
)
|
)
|
||||||
@@ -191,7 +191,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
|||||||
|
|
||||||
# Split prompt into sections
|
# Split prompt into sections
|
||||||
sections = {
|
sections = {
|
||||||
"original_request": "## Original Request",
|
"prompt": "## Original Request",
|
||||||
"review_parameters": "## Review Parameters",
|
"review_parameters": "## Review Parameters",
|
||||||
"repo_summary": "## Repository Changes Summary",
|
"repo_summary": "## Repository Changes Summary",
|
||||||
"context_files_summary": "## Context Files Summary",
|
"context_files_summary": "## Context Files Summary",
|
||||||
@@ -207,7 +207,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
|||||||
section_indices[name] = index
|
section_indices[name] = index
|
||||||
|
|
||||||
# Verify sections appear in logical order
|
# Verify sections appear in logical order
|
||||||
assert section_indices["original_request"] < section_indices["review_parameters"]
|
assert section_indices["prompt"] < section_indices["review_parameters"]
|
||||||
assert section_indices["review_parameters"] < section_indices["repo_summary"]
|
assert section_indices["review_parameters"] < section_indices["repo_summary"]
|
||||||
assert section_indices["git_diffs"] < section_indices["additional_context"]
|
assert section_indices["git_diffs"] < section_indices["additional_context"]
|
||||||
assert section_indices["additional_context"] < section_indices["review_instructions"]
|
assert section_indices["additional_context"] < section_indices["review_instructions"]
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ normal-sized prompts after implementing the large prompt handling feature.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -24,16 +25,16 @@ class TestPromptRegression:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_model_response(self):
|
def mock_model_response(self):
|
||||||
"""Create a mock model response."""
|
"""Create a mock model response."""
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
def _create_response(text="Test response"):
|
def _create_response(text="Test response"):
|
||||||
mock_response = MagicMock()
|
# Return a Mock that acts like ModelResponse
|
||||||
mock_response.candidates = [
|
return Mock(
|
||||||
MagicMock(
|
content=text,
|
||||||
content=MagicMock(parts=[MagicMock(text=text)]),
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
finish_reason="STOP",
|
model_name="gemini-2.0-flash-exp",
|
||||||
)
|
metadata={"finish_reason": "STOP"}
|
||||||
]
|
)
|
||||||
return mock_response
|
|
||||||
|
|
||||||
return _create_response
|
return _create_response
|
||||||
|
|
||||||
@@ -42,10 +43,12 @@ class TestPromptRegression:
|
|||||||
"""Test chat tool with normal prompt."""
|
"""Test chat tool with normal prompt."""
|
||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response("This is a helpful response about Python.")
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_create_model.return_value = mock_model
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response("This is a helpful response about Python.")
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute({"prompt": "Explain Python decorators"})
|
result = await tool.execute({"prompt": "Explain Python decorators"})
|
||||||
|
|
||||||
@@ -54,18 +57,20 @@ class TestPromptRegression:
|
|||||||
assert output["status"] == "success"
|
assert output["status"] == "success"
|
||||||
assert "helpful response about Python" in output["content"]
|
assert "helpful response about Python" in output["content"]
|
||||||
|
|
||||||
# Verify model was called
|
# Verify provider was called
|
||||||
mock_model.generate_content.assert_called_once()
|
mock_provider.generate_content.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_with_files(self, mock_model_response):
|
async def test_chat_with_files(self, mock_model_response):
|
||||||
"""Test chat tool with files parameter."""
|
"""Test chat tool with files parameter."""
|
||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_create_model.return_value = mock_model
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response()
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Mock file reading through the centralized method
|
# Mock file reading through the centralized method
|
||||||
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
|
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
|
||||||
@@ -83,16 +88,18 @@ class TestPromptRegression:
|
|||||||
"""Test thinkdeep tool with normal analysis."""
|
"""Test thinkdeep tool with normal analysis."""
|
||||||
tool = ThinkDeepTool()
|
tool = ThinkDeepTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response(
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response(
|
||||||
"Here's a deeper analysis with edge cases..."
|
"Here's a deeper analysis with edge cases..."
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"current_analysis": "I think we should use a cache for performance",
|
"prompt": "I think we should use a cache for performance",
|
||||||
"problem_context": "Building a high-traffic API",
|
"problem_context": "Building a high-traffic API",
|
||||||
"focus_areas": ["scalability", "reliability"],
|
"focus_areas": ["scalability", "reliability"],
|
||||||
}
|
}
|
||||||
@@ -109,12 +116,14 @@ class TestPromptRegression:
|
|||||||
"""Test codereview tool with normal inputs."""
|
"""Test codereview tool with normal inputs."""
|
||||||
tool = CodeReviewTool()
|
tool = CodeReviewTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response(
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response(
|
||||||
"Found 3 issues: 1) Missing error handling..."
|
"Found 3 issues: 1) Missing error handling..."
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Mock file reading
|
# Mock file reading
|
||||||
with patch("tools.base.read_files") as mock_read_files:
|
with patch("tools.base.read_files") as mock_read_files:
|
||||||
@@ -125,7 +134,7 @@ class TestPromptRegression:
|
|||||||
"files": ["/path/to/code.py"],
|
"files": ["/path/to/code.py"],
|
||||||
"review_type": "security",
|
"review_type": "security",
|
||||||
"focus_on": "Look for SQL injection vulnerabilities",
|
"focus_on": "Look for SQL injection vulnerabilities",
|
||||||
"context": "Test code review for validation purposes",
|
"prompt": "Test code review for validation purposes",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -139,12 +148,14 @@ class TestPromptRegression:
|
|||||||
"""Test review_changes tool with normal original_request."""
|
"""Test review_changes tool with normal original_request."""
|
||||||
tool = Precommit()
|
tool = Precommit()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response(
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response(
|
||||||
"Changes look good, implementing feature as requested..."
|
"Changes look good, implementing feature as requested..."
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Mock git operations
|
# Mock git operations
|
||||||
with patch("tools.precommit.find_git_repositories") as mock_find_repos:
|
with patch("tools.precommit.find_git_repositories") as mock_find_repos:
|
||||||
@@ -158,7 +169,7 @@ class TestPromptRegression:
|
|||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"path": "/path/to/repo",
|
"path": "/path/to/repo",
|
||||||
"original_request": "Add user authentication feature with JWT tokens",
|
"prompt": "Add user authentication feature with JWT tokens",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -171,16 +182,18 @@ class TestPromptRegression:
|
|||||||
"""Test debug tool with normal error description."""
|
"""Test debug tool with normal error description."""
|
||||||
tool = DebugIssueTool()
|
tool = DebugIssueTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response(
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response(
|
||||||
"Root cause: The variable is undefined. Fix: Initialize it..."
|
"Root cause: The variable is undefined. Fix: Initialize it..."
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"error_description": "TypeError: Cannot read property 'name' of undefined",
|
"prompt": "TypeError: Cannot read property 'name' of undefined",
|
||||||
"error_context": "at line 42 in user.js\n console.log(user.name)",
|
"error_context": "at line 42 in user.js\n console.log(user.name)",
|
||||||
"runtime_info": "Node.js v16.14.0",
|
"runtime_info": "Node.js v16.14.0",
|
||||||
}
|
}
|
||||||
@@ -197,12 +210,14 @@ class TestPromptRegression:
|
|||||||
"""Test analyze tool with normal question."""
|
"""Test analyze tool with normal question."""
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response(
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response(
|
||||||
"The code follows MVC pattern with clear separation..."
|
"The code follows MVC pattern with clear separation..."
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Mock file reading
|
# Mock file reading
|
||||||
with patch("tools.base.read_files") as mock_read_files:
|
with patch("tools.base.read_files") as mock_read_files:
|
||||||
@@ -211,7 +226,7 @@ class TestPromptRegression:
|
|||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"files": ["/path/to/project"],
|
"files": ["/path/to/project"],
|
||||||
"question": "What design patterns are used in this codebase?",
|
"prompt": "What design patterns are used in this codebase?",
|
||||||
"analysis_type": "architecture",
|
"analysis_type": "architecture",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -226,10 +241,12 @@ class TestPromptRegression:
|
|||||||
"""Test tools work with empty optional fields."""
|
"""Test tools work with empty optional fields."""
|
||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_create_model.return_value = mock_model
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response()
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
# Test with no files parameter
|
# Test with no files parameter
|
||||||
result = await tool.execute({"prompt": "Hello"})
|
result = await tool.execute({"prompt": "Hello"})
|
||||||
@@ -243,10 +260,12 @@ class TestPromptRegression:
|
|||||||
"""Test that thinking modes are properly passed through."""
|
"""Test that thinking modes are properly passed through."""
|
||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_create_model.return_value = mock_model
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response()
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute({"prompt": "Test", "thinking_mode": "high", "temperature": 0.8})
|
result = await tool.execute({"prompt": "Test", "thinking_mode": "high", "temperature": 0.8})
|
||||||
|
|
||||||
@@ -254,21 +273,24 @@ class TestPromptRegression:
|
|||||||
output = json.loads(result[0].text)
|
output = json.loads(result[0].text)
|
||||||
assert output["status"] == "success"
|
assert output["status"] == "success"
|
||||||
|
|
||||||
# Verify create_model was called with correct parameters
|
# Verify generate_content was called with correct parameters
|
||||||
mock_create_model.assert_called_once()
|
mock_provider.generate_content.assert_called_once()
|
||||||
call_args = mock_create_model.call_args
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
assert call_args[0][2] == "high" # thinking_mode
|
assert call_kwargs.get("temperature") == 0.8
|
||||||
assert call_args[0][1] == 0.8 # temperature
|
# thinking_mode would be passed if the provider supports it
|
||||||
|
# In this test, we set supports_thinking_mode to False, so it won't be passed
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_special_characters_in_prompts(self, mock_model_response):
|
async def test_special_characters_in_prompts(self, mock_model_response):
|
||||||
"""Test prompts with special characters work correctly."""
|
"""Test prompts with special characters work correctly."""
|
||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_create_model.return_value = mock_model
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response()
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
special_prompt = 'Test with "quotes" and\nnewlines\tand tabs'
|
special_prompt = 'Test with "quotes" and\nnewlines\tand tabs'
|
||||||
result = await tool.execute({"prompt": special_prompt})
|
result = await tool.execute({"prompt": special_prompt})
|
||||||
@@ -282,10 +304,12 @@ class TestPromptRegression:
|
|||||||
"""Test handling of various file path formats."""
|
"""Test handling of various file path formats."""
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_create_model.return_value = mock_model
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response()
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
with patch("tools.base.read_files") as mock_read_files:
|
with patch("tools.base.read_files") as mock_read_files:
|
||||||
mock_read_files.return_value = "Content"
|
mock_read_files.return_value = "Content"
|
||||||
@@ -297,7 +321,7 @@ class TestPromptRegression:
|
|||||||
"/Users/name/project/src/",
|
"/Users/name/project/src/",
|
||||||
"/home/user/code.js",
|
"/home/user/code.js",
|
||||||
],
|
],
|
||||||
"question": "Analyze these files",
|
"prompt": "Analyze these files",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -311,10 +335,12 @@ class TestPromptRegression:
|
|||||||
"""Test handling of unicode content in prompts."""
|
"""Test handling of unicode content in prompts."""
|
||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
|
|
||||||
with patch.object(tool, "create_model") as mock_create_model:
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||||
mock_model = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_model.generate_content.return_value = mock_model_response()
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_create_model.return_value = mock_model
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = mock_model_response()
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
unicode_prompt = "Explain this: 你好世界 مرحبا بالعالم"
|
unicode_prompt = "Explain this: 你好世界 مرحبا بالعالم"
|
||||||
result = await tool.execute({"prompt": unicode_prompt})
|
result = await tool.execute({"prompt": unicode_prompt})
|
||||||
|
|||||||
187
tests/test_providers.py
Normal file
187
tests/test_providers.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
"""Tests for the model provider abstraction system"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
import os
|
||||||
|
|
||||||
|
from providers import ModelProviderRegistry, ModelProvider, ModelResponse, ModelCapabilities
|
||||||
|
from providers.base import ProviderType
|
||||||
|
from providers.gemini import GeminiModelProvider
|
||||||
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelProviderRegistry:
|
||||||
|
"""Test the model provider registry"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Clear registry before each test"""
|
||||||
|
ModelProviderRegistry._providers.clear()
|
||||||
|
ModelProviderRegistry._initialized_providers.clear()
|
||||||
|
|
||||||
|
def test_register_provider(self):
|
||||||
|
"""Test registering a provider"""
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
|
assert ProviderType.GOOGLE in ModelProviderRegistry._providers
|
||||||
|
assert ModelProviderRegistry._providers[ProviderType.GOOGLE] == GeminiModelProvider
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
|
||||||
|
def test_get_provider(self):
|
||||||
|
"""Test getting a provider instance"""
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
|
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
||||||
|
|
||||||
|
assert provider is not None
|
||||||
|
assert isinstance(provider, GeminiModelProvider)
|
||||||
|
assert provider.api_key == "test-key"
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {}, clear=True)
|
||||||
|
def test_get_provider_no_api_key(self):
|
||||||
|
"""Test getting provider without API key returns None"""
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
|
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
||||||
|
|
||||||
|
assert provider is None
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
|
||||||
|
def test_get_provider_for_model(self):
|
||||||
|
"""Test getting provider for a specific model"""
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
|
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp")
|
||||||
|
|
||||||
|
assert provider is not None
|
||||||
|
assert isinstance(provider, GeminiModelProvider)
|
||||||
|
|
||||||
|
def test_get_available_providers(self):
|
||||||
|
"""Test getting list of available providers"""
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
|
providers = ModelProviderRegistry.get_available_providers()
|
||||||
|
|
||||||
|
assert len(providers) == 2
|
||||||
|
assert ProviderType.GOOGLE in providers
|
||||||
|
assert ProviderType.OPENAI in providers
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeminiProvider:
|
||||||
|
"""Test Gemini model provider"""
|
||||||
|
|
||||||
|
def test_provider_initialization(self):
|
||||||
|
"""Test provider initialization"""
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
assert provider.api_key == "test-key"
|
||||||
|
assert provider.get_provider_type() == ProviderType.GOOGLE
|
||||||
|
|
||||||
|
def test_get_capabilities(self):
|
||||||
|
"""Test getting model capabilities"""
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
capabilities = provider.get_capabilities("gemini-2.0-flash-exp")
|
||||||
|
|
||||||
|
assert capabilities.provider == ProviderType.GOOGLE
|
||||||
|
assert capabilities.model_name == "gemini-2.0-flash-exp"
|
||||||
|
assert capabilities.max_tokens == 1_048_576
|
||||||
|
assert not capabilities.supports_extended_thinking
|
||||||
|
|
||||||
|
def test_get_capabilities_pro_model(self):
|
||||||
|
"""Test getting capabilities for Pro model with thinking support"""
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
capabilities = provider.get_capabilities("gemini-2.5-pro-preview-06-05")
|
||||||
|
|
||||||
|
assert capabilities.supports_extended_thinking
|
||||||
|
|
||||||
|
def test_model_shorthand_resolution(self):
|
||||||
|
"""Test model shorthand resolution"""
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
assert provider.validate_model_name("flash")
|
||||||
|
assert provider.validate_model_name("pro")
|
||||||
|
|
||||||
|
capabilities = provider.get_capabilities("flash")
|
||||||
|
assert capabilities.model_name == "gemini-2.0-flash-exp"
|
||||||
|
|
||||||
|
def test_supports_thinking_mode(self):
|
||||||
|
"""Test thinking mode support detection"""
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
assert not provider.supports_thinking_mode("gemini-2.0-flash-exp")
|
||||||
|
assert provider.supports_thinking_mode("gemini-2.5-pro-preview-06-05")
|
||||||
|
|
||||||
|
@patch("google.genai.Client")
|
||||||
|
def test_generate_content(self, mock_client_class):
|
||||||
|
"""Test content generation"""
|
||||||
|
# Mock the client
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.text = "Generated content"
|
||||||
|
# Mock candidates for finish_reason
|
||||||
|
mock_candidate = Mock()
|
||||||
|
mock_candidate.finish_reason = "STOP"
|
||||||
|
mock_response.candidates = [mock_candidate]
|
||||||
|
# Mock usage metadata
|
||||||
|
mock_usage = Mock()
|
||||||
|
mock_usage.prompt_token_count = 10
|
||||||
|
mock_usage.candidates_token_count = 20
|
||||||
|
mock_response.usage_metadata = mock_usage
|
||||||
|
mock_client.models.generate_content.return_value = mock_response
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
response = provider.generate_content(
|
||||||
|
prompt="Test prompt",
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, ModelResponse)
|
||||||
|
assert response.content == "Generated content"
|
||||||
|
assert response.model_name == "gemini-2.0-flash-exp"
|
||||||
|
assert response.provider == ProviderType.GOOGLE
|
||||||
|
assert response.usage["input_tokens"] == 10
|
||||||
|
assert response.usage["output_tokens"] == 20
|
||||||
|
assert response.usage["total_tokens"] == 30
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIProvider:
|
||||||
|
"""Test OpenAI model provider"""
|
||||||
|
|
||||||
|
def test_provider_initialization(self):
|
||||||
|
"""Test provider initialization"""
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key", organization="test-org")
|
||||||
|
|
||||||
|
assert provider.api_key == "test-key"
|
||||||
|
assert provider.organization == "test-org"
|
||||||
|
assert provider.get_provider_type() == ProviderType.OPENAI
|
||||||
|
|
||||||
|
def test_get_capabilities_o3(self):
|
||||||
|
"""Test getting O3 model capabilities"""
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
capabilities = provider.get_capabilities("o3-mini")
|
||||||
|
|
||||||
|
assert capabilities.provider == ProviderType.OPENAI
|
||||||
|
assert capabilities.model_name == "o3-mini"
|
||||||
|
assert capabilities.max_tokens == 200_000
|
||||||
|
assert not capabilities.supports_extended_thinking
|
||||||
|
|
||||||
|
def test_validate_model_names(self):
|
||||||
|
"""Test model name validation"""
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
assert provider.validate_model_name("o3-mini")
|
||||||
|
assert provider.validate_model_name("gpt-4o")
|
||||||
|
assert not provider.validate_model_name("invalid-model")
|
||||||
|
|
||||||
|
def test_no_thinking_mode_support(self):
|
||||||
|
"""Test that no OpenAI models support thinking mode"""
|
||||||
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
|
assert not provider.supports_thinking_mode("o3-mini")
|
||||||
|
assert not provider.supports_thinking_mode("gpt-4o")
|
||||||
@@ -3,6 +3,7 @@ Tests for the main server functionality
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -42,31 +43,36 @@ class TestServerTools:
|
|||||||
assert "Unknown tool: unknown_tool" in result[0].text
|
assert "Unknown tool: unknown_tool" in result[0].text
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_chat(self):
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
|
async def test_handle_chat(self, mock_get_provider):
|
||||||
"""Test chat functionality"""
|
"""Test chat functionality"""
|
||||||
# Set test environment
|
# Set test environment
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["PYTEST_CURRENT_TEST"] = "test"
|
os.environ["PYTEST_CURRENT_TEST"] = "test"
|
||||||
|
|
||||||
# Create a mock for the model
|
# Create a mock for the provider
|
||||||
with patch("tools.base.BaseTool.create_model") as mock_create:
|
mock_provider = create_mock_provider()
|
||||||
mock_model = Mock()
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="Chat response")]))]
|
mock_provider.generate_content.return_value = Mock(
|
||||||
)
|
content="Chat response",
|
||||||
mock_create.return_value = mock_model
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
|
)
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await handle_call_tool("chat", {"prompt": "Hello Gemini"})
|
result = await handle_call_tool("chat", {"prompt": "Hello Gemini"})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
# Parse JSON response
|
# Parse JSON response
|
||||||
import json
|
import json
|
||||||
|
|
||||||
response_data = json.loads(result[0].text)
|
response_data = json.loads(result[0].text)
|
||||||
assert response_data["status"] == "success"
|
assert response_data["status"] == "success"
|
||||||
assert "Chat response" in response_data["content"]
|
assert "Chat response" in response_data["content"]
|
||||||
assert "Claude's Turn" in response_data["content"]
|
assert "Claude's Turn" in response_data["content"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_get_version(self):
|
async def test_handle_get_version(self):
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ Tests for thinking_mode functionality across all tools
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -37,28 +38,35 @@ class TestThinkingModes:
|
|||||||
), f"{tool.__class__.__name__} should default to {expected_default}"
|
), f"{tool.__class__.__name__} should default to {expected_default}"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_thinking_mode_minimal(self, mock_create_model):
|
async def test_thinking_mode_minimal(self, mock_get_provider):
|
||||||
"""Test minimal thinking mode"""
|
"""Test minimal thinking mode"""
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="Minimal thinking response")]))]
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content="Minimal thinking response",
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"files": ["/absolute/path/test.py"],
|
"files": ["/absolute/path/test.py"],
|
||||||
"question": "What is this?",
|
"prompt": "What is this?",
|
||||||
"thinking_mode": "minimal",
|
"thinking_mode": "minimal",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify create_model was called with correct thinking_mode
|
# Verify create_model was called with correct thinking_mode
|
||||||
mock_create_model.assert_called_once()
|
mock_get_provider.assert_called_once()
|
||||||
args = mock_create_model.call_args[0]
|
# Verify generate_content was called with thinking_mode
|
||||||
assert args[2] == "minimal" # thinking_mode parameter
|
mock_provider.generate_content.assert_called_once()
|
||||||
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
|
assert call_kwargs.get("thinking_mode") == "minimal" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None) # thinking_mode parameter
|
||||||
|
|
||||||
# Parse JSON response
|
# Parse JSON response
|
||||||
import json
|
import json
|
||||||
@@ -68,102 +76,130 @@ class TestThinkingModes:
|
|||||||
assert response_data["content"].startswith("Analysis:")
|
assert response_data["content"].startswith("Analysis:")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_thinking_mode_low(self, mock_create_model):
|
async def test_thinking_mode_low(self, mock_get_provider):
|
||||||
"""Test low thinking mode"""
|
"""Test low thinking mode"""
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="Low thinking response")]))]
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content="Low thinking response",
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
tool = CodeReviewTool()
|
tool = CodeReviewTool()
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"files": ["/absolute/path/test.py"],
|
"files": ["/absolute/path/test.py"],
|
||||||
"thinking_mode": "low",
|
"thinking_mode": "low",
|
||||||
"context": "Test code review for validation purposes",
|
"prompt": "Test code review for validation purposes",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify create_model was called with correct thinking_mode
|
# Verify create_model was called with correct thinking_mode
|
||||||
mock_create_model.assert_called_once()
|
mock_get_provider.assert_called_once()
|
||||||
args = mock_create_model.call_args[0]
|
# Verify generate_content was called with thinking_mode
|
||||||
assert args[2] == "low"
|
mock_provider.generate_content.assert_called_once()
|
||||||
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
|
assert call_kwargs.get("thinking_mode") == "low" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)
|
||||||
|
|
||||||
assert "Code Review" in result[0].text
|
assert "Code Review" in result[0].text
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_thinking_mode_medium(self, mock_create_model):
|
async def test_thinking_mode_medium(self, mock_get_provider):
|
||||||
"""Test medium thinking mode (default for most tools)"""
|
"""Test medium thinking mode (default for most tools)"""
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="Medium thinking response")]))]
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content="Medium thinking response",
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
tool = DebugIssueTool()
|
tool = DebugIssueTool()
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"error_description": "Test error",
|
"prompt": "Test error",
|
||||||
# Not specifying thinking_mode, should use default (medium)
|
# Not specifying thinking_mode, should use default (medium)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify create_model was called with default thinking_mode
|
# Verify create_model was called with default thinking_mode
|
||||||
mock_create_model.assert_called_once()
|
mock_get_provider.assert_called_once()
|
||||||
args = mock_create_model.call_args[0]
|
# Verify generate_content was called with thinking_mode
|
||||||
assert args[2] == "medium"
|
mock_provider.generate_content.assert_called_once()
|
||||||
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
|
assert call_kwargs.get("thinking_mode") == "medium" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)
|
||||||
|
|
||||||
assert "Debug Analysis" in result[0].text
|
assert "Debug Analysis" in result[0].text
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_thinking_mode_high(self, mock_create_model):
|
async def test_thinking_mode_high(self, mock_get_provider):
|
||||||
"""Test high thinking mode"""
|
"""Test high thinking mode"""
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="High thinking response")]))]
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content="High thinking response",
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
await tool.execute(
|
await tool.execute(
|
||||||
{
|
{
|
||||||
"files": ["/absolute/path/complex.py"],
|
"files": ["/absolute/path/complex.py"],
|
||||||
"question": "Analyze architecture",
|
"prompt": "Analyze architecture",
|
||||||
"thinking_mode": "high",
|
"thinking_mode": "high",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify create_model was called with correct thinking_mode
|
# Verify create_model was called with correct thinking_mode
|
||||||
mock_create_model.assert_called_once()
|
mock_get_provider.assert_called_once()
|
||||||
args = mock_create_model.call_args[0]
|
# Verify generate_content was called with thinking_mode
|
||||||
assert args[2] == "high"
|
mock_provider.generate_content.assert_called_once()
|
||||||
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
|
assert call_kwargs.get("thinking_mode") == "high" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_thinking_mode_max(self, mock_create_model):
|
async def test_thinking_mode_max(self, mock_get_provider):
|
||||||
"""Test max thinking mode (default for thinkdeep)"""
|
"""Test max thinking mode (default for thinkdeep)"""
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="Max thinking response")]))]
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content="Max thinking response",
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
tool = ThinkDeepTool()
|
tool = ThinkDeepTool()
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"current_analysis": "Initial analysis",
|
"prompt": "Initial analysis",
|
||||||
# Not specifying thinking_mode, should use default (high)
|
# Not specifying thinking_mode, should use default (high)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify create_model was called with default thinking_mode
|
# Verify create_model was called with default thinking_mode
|
||||||
mock_create_model.assert_called_once()
|
mock_get_provider.assert_called_once()
|
||||||
args = mock_create_model.call_args[0]
|
# Verify generate_content was called with thinking_mode
|
||||||
assert args[2] == "high"
|
mock_provider.generate_content.assert_called_once()
|
||||||
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
|
assert call_kwargs.get("thinking_mode") == "high" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)
|
||||||
|
|
||||||
assert "Extended Analysis by Gemini" in result[0].text
|
assert "Extended Analysis by Gemini" in result[0].text
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Tests for individual tool implementations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -24,23 +25,28 @@ class TestThinkDeepTool:
|
|||||||
assert tool.get_default_temperature() == 0.7
|
assert tool.get_default_temperature() == 0.7
|
||||||
|
|
||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
assert "current_analysis" in schema["properties"]
|
assert "prompt" in schema["properties"]
|
||||||
assert schema["required"] == ["current_analysis"]
|
assert schema["required"] == ["prompt"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_execute_success(self, mock_create_model, tool):
|
async def test_execute_success(self, mock_get_provider, tool):
|
||||||
"""Test successful execution"""
|
"""Test successful execution"""
|
||||||
# Mock model
|
# Mock provider
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="Extended analysis")]))]
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content="Extended analysis",
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"current_analysis": "Initial analysis",
|
"prompt": "Initial analysis",
|
||||||
"problem_context": "Building a cache",
|
"problem_context": "Building a cache",
|
||||||
"focus_areas": ["performance", "scalability"],
|
"focus_areas": ["performance", "scalability"],
|
||||||
}
|
}
|
||||||
@@ -69,30 +75,35 @@ class TestCodeReviewTool:
|
|||||||
|
|
||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
assert "files" in schema["properties"]
|
assert "files" in schema["properties"]
|
||||||
assert "context" in schema["properties"]
|
assert "prompt" in schema["properties"]
|
||||||
assert schema["required"] == ["files", "context"]
|
assert schema["required"] == ["files", "prompt"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_execute_with_review_type(self, mock_create_model, tool, tmp_path):
|
async def test_execute_with_review_type(self, mock_get_provider, tool, tmp_path):
|
||||||
"""Test execution with specific review type"""
|
"""Test execution with specific review type"""
|
||||||
# Create test file
|
# Create test file
|
||||||
test_file = tmp_path / "test.py"
|
test_file = tmp_path / "test.py"
|
||||||
test_file.write_text("def insecure(): pass", encoding="utf-8")
|
test_file.write_text("def insecure(): pass", encoding="utf-8")
|
||||||
|
|
||||||
# Mock model
|
# Mock provider
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="Security issues found")]))]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content="Security issues found",
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"files": [str(test_file)],
|
"files": [str(test_file)],
|
||||||
"review_type": "security",
|
"review_type": "security",
|
||||||
"focus_on": "authentication",
|
"focus_on": "authentication",
|
||||||
"context": "Test code review for validation purposes",
|
"prompt": "Test code review for validation purposes",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -116,23 +127,28 @@ class TestDebugIssueTool:
|
|||||||
assert tool.get_default_temperature() == 0.2
|
assert tool.get_default_temperature() == 0.2
|
||||||
|
|
||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
assert "error_description" in schema["properties"]
|
assert "prompt" in schema["properties"]
|
||||||
assert schema["required"] == ["error_description"]
|
assert schema["required"] == ["prompt"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_execute_with_context(self, mock_create_model, tool):
|
async def test_execute_with_context(self, mock_get_provider, tool):
|
||||||
"""Test execution with error context"""
|
"""Test execution with error context"""
|
||||||
# Mock model
|
# Mock provider
|
||||||
mock_model = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_model.generate_content.return_value = Mock(
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="Root cause: race condition")]))]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content="Root cause: race condition",
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
)
|
)
|
||||||
mock_create_model.return_value = mock_model
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"error_description": "Test fails intermittently",
|
"prompt": "Test fails intermittently",
|
||||||
"error_context": "AssertionError in test_async",
|
"error_context": "AssertionError in test_async",
|
||||||
"previous_attempts": "Added sleep, still fails",
|
"previous_attempts": "Added sleep, still fails",
|
||||||
}
|
}
|
||||||
@@ -158,30 +174,33 @@ class TestAnalyzeTool:
|
|||||||
|
|
||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
assert "files" in schema["properties"]
|
assert "files" in schema["properties"]
|
||||||
assert "question" in schema["properties"]
|
assert "prompt" in schema["properties"]
|
||||||
assert set(schema["required"]) == {"files", "question"}
|
assert set(schema["required"]) == {"files", "prompt"}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.create_model")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
async def test_execute_with_analysis_type(self, mock_model, tool, tmp_path):
|
async def test_execute_with_analysis_type(self, mock_get_provider, tool, tmp_path):
|
||||||
"""Test execution with specific analysis type"""
|
"""Test execution with specific analysis type"""
|
||||||
# Create test file
|
# Create test file
|
||||||
test_file = tmp_path / "module.py"
|
test_file = tmp_path / "module.py"
|
||||||
test_file.write_text("class Service: pass", encoding="utf-8")
|
test_file.write_text("class Service: pass", encoding="utf-8")
|
||||||
|
|
||||||
# Mock response
|
# Mock provider
|
||||||
mock_response = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response.candidates = [Mock()]
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates[0].content.parts = [Mock(text="Architecture analysis")]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
mock_instance = Mock()
|
content="Architecture analysis",
|
||||||
mock_instance.generate_content.return_value = mock_response
|
usage={},
|
||||||
mock_model.return_value = mock_instance
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
|
)
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"files": [str(test_file)],
|
"files": [str(test_file)],
|
||||||
"question": "What's the structure?",
|
"prompt": "What's the structure?",
|
||||||
"analysis_type": "architecture",
|
"analysis_type": "architecture",
|
||||||
"output_format": "summary",
|
"output_format": "summary",
|
||||||
}
|
}
|
||||||
@@ -203,7 +222,7 @@ class TestAbsolutePathValidation:
|
|||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"files": ["./relative/path.py", "/absolute/path.py"],
|
"files": ["./relative/path.py", "/absolute/path.py"],
|
||||||
"question": "What does this do?",
|
"prompt": "What does this do?",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -221,7 +240,7 @@ class TestAbsolutePathValidation:
|
|||||||
{
|
{
|
||||||
"files": ["../parent/file.py"],
|
"files": ["../parent/file.py"],
|
||||||
"review_type": "full",
|
"review_type": "full",
|
||||||
"context": "Test code review for validation purposes",
|
"prompt": "Test code review for validation purposes",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -237,7 +256,7 @@ class TestAbsolutePathValidation:
|
|||||||
tool = DebugIssueTool()
|
tool = DebugIssueTool()
|
||||||
result = await tool.execute(
|
result = await tool.execute(
|
||||||
{
|
{
|
||||||
"error_description": "Something broke",
|
"prompt": "Something broke",
|
||||||
"files": ["src/main.py"], # relative path
|
"files": ["src/main.py"], # relative path
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -252,7 +271,7 @@ class TestAbsolutePathValidation:
|
|||||||
async def test_thinkdeep_tool_relative_path_rejected(self):
|
async def test_thinkdeep_tool_relative_path_rejected(self):
|
||||||
"""Test that thinkdeep tool rejects relative paths"""
|
"""Test that thinkdeep tool rejects relative paths"""
|
||||||
tool = ThinkDeepTool()
|
tool = ThinkDeepTool()
|
||||||
result = await tool.execute({"current_analysis": "My analysis", "files": ["./local/file.py"]})
|
result = await tool.execute({"prompt": "My analysis", "files": ["./local/file.py"]})
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
response = json.loads(result[0].text)
|
response = json.loads(result[0].text)
|
||||||
@@ -278,21 +297,24 @@ class TestAbsolutePathValidation:
|
|||||||
assert "code.py" in response["content"]
|
assert "code.py" in response["content"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.AnalyzeTool.create_model")
|
@patch("tools.AnalyzeTool.get_model_provider")
|
||||||
async def test_analyze_tool_accepts_absolute_paths(self, mock_model):
|
async def test_analyze_tool_accepts_absolute_paths(self, mock_get_provider):
|
||||||
"""Test that analyze tool accepts absolute paths"""
|
"""Test that analyze tool accepts absolute paths"""
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
|
|
||||||
# Mock the model response
|
# Mock provider
|
||||||
mock_response = Mock()
|
mock_provider = create_mock_provider()
|
||||||
mock_response.candidates = [Mock()]
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_response.candidates[0].content.parts = [Mock(text="Analysis complete")]
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
|
mock_provider.generate_content.return_value = Mock(
|
||||||
|
content="Analysis complete",
|
||||||
|
usage={},
|
||||||
|
model_name="gemini-2.0-flash-exp",
|
||||||
|
metadata={}
|
||||||
|
)
|
||||||
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
mock_instance = Mock()
|
result = await tool.execute({"files": ["/absolute/path/file.py"], "prompt": "What does this do?"})
|
||||||
mock_instance.generate_content.return_value = mock_response
|
|
||||||
mock_model.return_value = mock_instance
|
|
||||||
|
|
||||||
result = await tool.execute({"files": ["/absolute/path/file.py"], "question": "What does this do?"})
|
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
response = json.loads(result[0].text)
|
response = json.loads(result[0].text)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class AnalyzeRequest(ToolRequest):
|
|||||||
"""Request model for analyze tool"""
|
"""Request model for analyze tool"""
|
||||||
|
|
||||||
files: list[str] = Field(..., description="Files or directories to analyze (must be absolute paths)")
|
files: list[str] = Field(..., description="Files or directories to analyze (must be absolute paths)")
|
||||||
question: str = Field(..., description="What to analyze or look for")
|
prompt: str = Field(..., description="What to analyze or look for")
|
||||||
analysis_type: Optional[str] = Field(
|
analysis_type: Optional[str] = Field(
|
||||||
None,
|
None,
|
||||||
description="Type of analysis: architecture|performance|security|quality|general",
|
description="Type of analysis: architecture|performance|security|quality|general",
|
||||||
@@ -42,9 +42,9 @@ class AnalyzeTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_input_schema(self) -> dict[str, Any]:
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
from config import DEFAULT_MODEL
|
from config import IS_AUTO_MODE
|
||||||
|
|
||||||
return {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"files": {
|
"files": {
|
||||||
@@ -52,11 +52,8 @@ class AnalyzeTool(BaseTool):
|
|||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"description": "Files or directories to analyze (must be absolute paths)",
|
"description": "Files or directories to analyze (must be absolute paths)",
|
||||||
},
|
},
|
||||||
"model": {
|
"model": self.get_model_field_schema(),
|
||||||
"type": "string",
|
"prompt": {
|
||||||
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
|
|
||||||
},
|
|
||||||
"question": {
|
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "What to analyze or look for",
|
"description": "What to analyze or look for",
|
||||||
},
|
},
|
||||||
@@ -98,8 +95,10 @@ class AnalyzeTool(BaseTool):
|
|||||||
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["files", "question"],
|
"required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
return ANALYZE_PROMPT
|
return ANALYZE_PROMPT
|
||||||
@@ -116,8 +115,8 @@ class AnalyzeTool(BaseTool):
|
|||||||
request_model = self.get_request_model()
|
request_model = self.get_request_model()
|
||||||
request = request_model(**arguments)
|
request = request_model(**arguments)
|
||||||
|
|
||||||
# Check question size
|
# Check prompt size
|
||||||
size_check = self.check_prompt_size(request.question)
|
size_check = self.check_prompt_size(request.prompt)
|
||||||
if size_check:
|
if size_check:
|
||||||
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
|
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
|
||||||
|
|
||||||
@@ -129,9 +128,9 @@ class AnalyzeTool(BaseTool):
|
|||||||
# Check for prompt.txt in files
|
# Check for prompt.txt in files
|
||||||
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
||||||
|
|
||||||
# If prompt.txt was found, use it as the question
|
# If prompt.txt was found, use it as the prompt
|
||||||
if prompt_content:
|
if prompt_content:
|
||||||
request.question = prompt_content
|
request.prompt = prompt_content
|
||||||
|
|
||||||
# Update request files list
|
# Update request files list
|
||||||
if updated_files is not None:
|
if updated_files is not None:
|
||||||
@@ -177,7 +176,7 @@ class AnalyzeTool(BaseTool):
|
|||||||
{focus_instruction}{websearch_instruction}
|
{focus_instruction}{websearch_instruction}
|
||||||
|
|
||||||
=== USER QUESTION ===
|
=== USER QUESTION ===
|
||||||
{request.question}
|
{request.prompt}
|
||||||
=== END QUESTION ===
|
=== END QUESTION ===
|
||||||
|
|
||||||
=== FILES TO ANALYZE ===
|
=== FILES TO ANALYZE ===
|
||||||
@@ -188,12 +187,6 @@ Please analyze these files to answer the user's question."""
|
|||||||
|
|
||||||
return full_prompt
|
return full_prompt
|
||||||
|
|
||||||
def format_response(self, response: str, request: AnalyzeRequest) -> str:
|
def format_response(self, response: str, request: AnalyzeRequest, model_info: Optional[dict] = None) -> str:
|
||||||
"""Format the analysis response"""
|
"""Format the analysis response"""
|
||||||
header = f"Analysis: {request.question[:50]}..."
|
return f"{response}\n\n---\n\n**Next Steps:** Use this analysis to actively continue your task. Investigate deeper into any findings, implement solutions based on these insights, and carry out the necessary work. Only pause to ask the user if you need their explicit approval for major changes or if critical decisions require their input."
|
||||||
if request.analysis_type:
|
|
||||||
header = f"{request.analysis_type.upper()} Analysis"
|
|
||||||
|
|
||||||
summary_text = f"Analyzed {len(request.files)} file(s)"
|
|
||||||
|
|
||||||
return f"{header}\n{summary_text}\n{'=' * 50}\n\n{response}\n\n---\n\n**Next Steps:** Consider if this analysis reveals areas needing deeper investigation, additional context, or specific implementation details."
|
|
||||||
|
|||||||
570
tools/base.py
570
tools/base.py
@@ -20,13 +20,12 @@ import re
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from google import genai
|
|
||||||
from google.genai import types
|
|
||||||
from mcp.types import TextContent
|
from mcp.types import TextContent
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from config import DEFAULT_MODEL, MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
|
from config import DEFAULT_MODEL, MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
|
||||||
from utils import check_token_limit
|
from utils import check_token_limit
|
||||||
|
from providers import ModelProviderRegistry, ModelProvider, ModelResponse
|
||||||
from utils.conversation_memory import (
|
from utils.conversation_memory import (
|
||||||
MAX_CONVERSATION_TURNS,
|
MAX_CONVERSATION_TURNS,
|
||||||
add_turn,
|
add_turn,
|
||||||
@@ -52,7 +51,7 @@ class ToolRequest(BaseModel):
|
|||||||
|
|
||||||
model: Optional[str] = Field(
|
model: Optional[str] = Field(
|
||||||
None,
|
None,
|
||||||
description=f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
|
description="Model to use. See tool's input schema for available models and their capabilities.",
|
||||||
)
|
)
|
||||||
temperature: Optional[float] = Field(None, description="Temperature for response (tool-specific defaults)")
|
temperature: Optional[float] = Field(None, description="Temperature for response (tool-specific defaults)")
|
||||||
# Thinking mode controls how much computational budget the model uses for reasoning
|
# Thinking mode controls how much computational budget the model uses for reasoning
|
||||||
@@ -144,6 +143,38 @@ class BaseTool(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_model_field_schema(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate the model field schema based on auto mode configuration.
|
||||||
|
|
||||||
|
When auto mode is enabled, the model parameter becomes required
|
||||||
|
and includes detailed descriptions of each model's capabilities.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict containing the model field JSON schema
|
||||||
|
"""
|
||||||
|
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
|
||||||
|
|
||||||
|
if IS_AUTO_MODE:
|
||||||
|
# In auto mode, model is required and we provide detailed descriptions
|
||||||
|
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
|
||||||
|
for model, desc in MODEL_CAPABILITIES_DESC.items():
|
||||||
|
model_desc_parts.append(f"- '{model}': {desc}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "string",
|
||||||
|
"description": "\n".join(model_desc_parts),
|
||||||
|
"enum": list(MODEL_CAPABILITIES_DESC.keys()),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Normal mode - model is optional with default
|
||||||
|
available_models = list(MODEL_CAPABILITIES_DESC.keys())
|
||||||
|
models_str = ', '.join(f"'{m}'" for m in available_models)
|
||||||
|
return {
|
||||||
|
"type": "string",
|
||||||
|
"description": f"Model to use. Available: {models_str}. Defaults to '{DEFAULT_MODEL}' if not specified.",
|
||||||
|
}
|
||||||
|
|
||||||
def get_default_temperature(self) -> float:
|
def get_default_temperature(self) -> float:
|
||||||
"""
|
"""
|
||||||
Return the default temperature setting for this tool.
|
Return the default temperature setting for this tool.
|
||||||
@@ -293,6 +324,11 @@ class BaseTool(ABC):
|
|||||||
"""
|
"""
|
||||||
if not request_files:
|
if not request_files:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# If conversation history is already embedded, skip file processing
|
||||||
|
if hasattr(self, '_has_embedded_history') and self._has_embedded_history:
|
||||||
|
logger.debug(f"[FILES] {self.name}: Skipping file processing - conversation history already embedded")
|
||||||
|
return ""
|
||||||
|
|
||||||
# Extract remaining budget from arguments if available
|
# Extract remaining budget from arguments if available
|
||||||
if remaining_budget is None:
|
if remaining_budget is None:
|
||||||
@@ -300,15 +336,59 @@ class BaseTool(ABC):
|
|||||||
args_to_use = arguments or getattr(self, "_current_arguments", {})
|
args_to_use = arguments or getattr(self, "_current_arguments", {})
|
||||||
remaining_budget = args_to_use.get("_remaining_tokens")
|
remaining_budget = args_to_use.get("_remaining_tokens")
|
||||||
|
|
||||||
# Use remaining budget if provided, otherwise fall back to max_tokens or default
|
# Use remaining budget if provided, otherwise fall back to max_tokens or model-specific default
|
||||||
if remaining_budget is not None:
|
if remaining_budget is not None:
|
||||||
effective_max_tokens = remaining_budget - reserve_tokens
|
effective_max_tokens = remaining_budget - reserve_tokens
|
||||||
elif max_tokens is not None:
|
elif max_tokens is not None:
|
||||||
effective_max_tokens = max_tokens - reserve_tokens
|
effective_max_tokens = max_tokens - reserve_tokens
|
||||||
else:
|
else:
|
||||||
from config import MAX_CONTENT_TOKENS
|
# Get model-specific limits
|
||||||
|
# First check if model_context was passed from server.py
|
||||||
effective_max_tokens = MAX_CONTENT_TOKENS - reserve_tokens
|
model_context = None
|
||||||
|
if arguments:
|
||||||
|
model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get("_model_context")
|
||||||
|
|
||||||
|
if model_context:
|
||||||
|
# Use the passed model context
|
||||||
|
try:
|
||||||
|
token_allocation = model_context.calculate_token_allocation()
|
||||||
|
effective_max_tokens = token_allocation.file_tokens - reserve_tokens
|
||||||
|
logger.debug(f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
|
||||||
|
f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[FILES] {self.name}: Error using passed model context: {e}")
|
||||||
|
# Fall through to manual calculation
|
||||||
|
model_context = None
|
||||||
|
|
||||||
|
if not model_context:
|
||||||
|
# Manual calculation as fallback
|
||||||
|
model_name = getattr(self, "_current_model_name", None) or DEFAULT_MODEL
|
||||||
|
try:
|
||||||
|
provider = self.get_model_provider(model_name)
|
||||||
|
capabilities = provider.get_capabilities(model_name)
|
||||||
|
|
||||||
|
# Calculate content allocation based on model capacity
|
||||||
|
if capabilities.max_tokens < 300_000:
|
||||||
|
# Smaller context models: 60% content, 40% response
|
||||||
|
model_content_tokens = int(capabilities.max_tokens * 0.6)
|
||||||
|
else:
|
||||||
|
# Larger context models: 80% content, 20% response
|
||||||
|
model_content_tokens = int(capabilities.max_tokens * 0.8)
|
||||||
|
|
||||||
|
effective_max_tokens = model_content_tokens - reserve_tokens
|
||||||
|
logger.debug(f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
|
||||||
|
f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total")
|
||||||
|
except (ValueError, AttributeError) as e:
|
||||||
|
# Handle specific errors: provider not found, model not supported, missing attributes
|
||||||
|
logger.warning(f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}")
|
||||||
|
# Fall back to conservative default for safety
|
||||||
|
from config import MAX_CONTENT_TOKENS
|
||||||
|
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
|
||||||
|
except Exception as e:
|
||||||
|
# Catch any other unexpected errors
|
||||||
|
logger.error(f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}")
|
||||||
|
from config import MAX_CONTENT_TOKENS
|
||||||
|
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
|
||||||
|
|
||||||
# Ensure we have a reasonable minimum budget
|
# Ensure we have a reasonable minimum budget
|
||||||
effective_max_tokens = max(1000, effective_max_tokens)
|
effective_max_tokens = max(1000, effective_max_tokens)
|
||||||
@@ -601,34 +681,59 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
)
|
)
|
||||||
return [TextContent(type="text", text=error_output.model_dump_json())]
|
return [TextContent(type="text", text=error_output.model_dump_json())]
|
||||||
|
|
||||||
# Prepare the full prompt by combining system prompt with user request
|
# Check if we have continuation_id - if so, conversation history is already embedded
|
||||||
# This is delegated to the tool implementation for customization
|
|
||||||
prompt = await self.prepare_prompt(request)
|
|
||||||
|
|
||||||
# Add follow-up instructions for new conversations (not threaded)
|
|
||||||
continuation_id = getattr(request, "continuation_id", None)
|
continuation_id = getattr(request, "continuation_id", None)
|
||||||
if not continuation_id:
|
|
||||||
# Import here to avoid circular imports
|
if continuation_id:
|
||||||
|
# When continuation_id is present, server.py has already injected the
|
||||||
|
# conversation history into the appropriate field. We need to check if
|
||||||
|
# the prompt already contains conversation history marker.
|
||||||
|
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
|
||||||
|
|
||||||
|
# Store the original arguments to detect enhanced prompts
|
||||||
|
self._has_embedded_history = False
|
||||||
|
|
||||||
|
# Check if conversation history is already embedded in the prompt field
|
||||||
|
field_value = getattr(request, "prompt", "")
|
||||||
|
field_name = "prompt"
|
||||||
|
|
||||||
|
if "=== CONVERSATION HISTORY ===" in field_value:
|
||||||
|
# Conversation history is already embedded, use it directly
|
||||||
|
prompt = field_value
|
||||||
|
self._has_embedded_history = True
|
||||||
|
logger.debug(f"{self.name}: Using pre-embedded conversation history from {field_name}")
|
||||||
|
else:
|
||||||
|
# No embedded history, prepare prompt normally
|
||||||
|
prompt = await self.prepare_prompt(request)
|
||||||
|
logger.debug(f"{self.name}: No embedded history found, prepared prompt normally")
|
||||||
|
else:
|
||||||
|
# New conversation, prepare prompt normally
|
||||||
|
prompt = await self.prepare_prompt(request)
|
||||||
|
|
||||||
|
# Add follow-up instructions for new conversations
|
||||||
from server import get_follow_up_instructions
|
from server import get_follow_up_instructions
|
||||||
|
|
||||||
follow_up_instructions = get_follow_up_instructions(0) # New conversation, turn 0
|
follow_up_instructions = get_follow_up_instructions(0) # New conversation, turn 0
|
||||||
prompt = f"{prompt}\n\n{follow_up_instructions}"
|
prompt = f"{prompt}\n\n{follow_up_instructions}"
|
||||||
|
|
||||||
logger.debug(f"Added follow-up instructions for new {self.name} conversation")
|
logger.debug(f"Added follow-up instructions for new {self.name} conversation")
|
||||||
|
|
||||||
# Also log to file for debugging MCP issues
|
|
||||||
try:
|
|
||||||
with open("/tmp/gemini_debug.log", "a") as f:
|
|
||||||
f.write(f"[{self.name}] Added follow-up instructions for new conversation\n")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
|
|
||||||
# History reconstruction is handled by server.py:reconstruct_thread_context
|
|
||||||
# No need to rebuild it here - prompt already contains conversation history
|
|
||||||
|
|
||||||
# Extract model configuration from request or use defaults
|
# Extract model configuration from request or use defaults
|
||||||
model_name = getattr(request, "model", None) or DEFAULT_MODEL
|
model_name = getattr(request, "model", None)
|
||||||
|
if not model_name:
|
||||||
|
model_name = DEFAULT_MODEL
|
||||||
|
|
||||||
|
# In auto mode, model parameter is required
|
||||||
|
from config import IS_AUTO_MODE
|
||||||
|
if IS_AUTO_MODE and model_name.lower() == "auto":
|
||||||
|
error_output = ToolOutput(
|
||||||
|
status="error",
|
||||||
|
content="Model parameter is required. Please specify which model to use for this task.",
|
||||||
|
content_type="text",
|
||||||
|
)
|
||||||
|
return [TextContent(type="text", text=error_output.model_dump_json())]
|
||||||
|
|
||||||
|
# Store model name for use by helper methods like _prepare_file_content_for_prompt
|
||||||
|
self._current_model_name = model_name
|
||||||
|
|
||||||
temperature = getattr(request, "temperature", None)
|
temperature = getattr(request, "temperature", None)
|
||||||
if temperature is None:
|
if temperature is None:
|
||||||
temperature = self.get_default_temperature()
|
temperature = self.get_default_temperature()
|
||||||
@@ -636,28 +741,45 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
if thinking_mode is None:
|
if thinking_mode is None:
|
||||||
thinking_mode = self.get_default_thinking_mode()
|
thinking_mode = self.get_default_thinking_mode()
|
||||||
|
|
||||||
# Create model instance with appropriate configuration
|
# Get the appropriate model provider
|
||||||
# This handles both regular models and thinking-enabled models
|
provider = self.get_model_provider(model_name)
|
||||||
model = self.create_model(model_name, temperature, thinking_mode)
|
|
||||||
|
# Get system prompt for this tool
|
||||||
|
system_prompt = self.get_system_prompt()
|
||||||
|
|
||||||
# Generate AI response using the configured model
|
# Generate AI response using the provider
|
||||||
logger.info(f"Sending request to Gemini API for {self.name}")
|
logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.name}")
|
||||||
logger.debug(f"Prompt length: {len(prompt)} characters")
|
logger.debug(f"Prompt length: {len(prompt)} characters")
|
||||||
response = model.generate_content(prompt)
|
|
||||||
logger.info(f"Received response from Gemini API for {self.name}")
|
# Generate content with provider abstraction
|
||||||
|
model_response = provider.generate_content(
|
||||||
|
prompt=prompt,
|
||||||
|
model_name=model_name,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Received response from {provider.get_provider_type().value} API for {self.name}")
|
||||||
|
|
||||||
# Process the model's response
|
# Process the model's response
|
||||||
if response.candidates and response.candidates[0].content.parts:
|
if model_response.content:
|
||||||
raw_text = response.candidates[0].content.parts[0].text
|
raw_text = model_response.content
|
||||||
|
|
||||||
# Parse response to check for clarification requests or format output
|
# Parse response to check for clarification requests or format output
|
||||||
tool_output = self._parse_response(raw_text, request)
|
# Pass model info for conversation tracking
|
||||||
|
model_info = {
|
||||||
|
"provider": provider,
|
||||||
|
"model_name": model_name,
|
||||||
|
"model_response": model_response
|
||||||
|
}
|
||||||
|
tool_output = self._parse_response(raw_text, request, model_info)
|
||||||
logger.info(f"Successfully completed {self.name} tool execution")
|
logger.info(f"Successfully completed {self.name} tool execution")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Handle cases where the model couldn't generate a response
|
# Handle cases where the model couldn't generate a response
|
||||||
# This might happen due to safety filters or other constraints
|
# This might happen due to safety filters or other constraints
|
||||||
finish_reason = response.candidates[0].finish_reason if response.candidates else "Unknown"
|
finish_reason = model_response.metadata.get("finish_reason", "Unknown")
|
||||||
logger.warning(f"Response blocked or incomplete for {self.name}. Finish reason: {finish_reason}")
|
logger.warning(f"Response blocked or incomplete for {self.name}. Finish reason: {finish_reason}")
|
||||||
tool_output = ToolOutput(
|
tool_output = ToolOutput(
|
||||||
status="error",
|
status="error",
|
||||||
@@ -678,13 +800,24 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
if "500 INTERNAL" in error_msg and "Please retry" in error_msg:
|
if "500 INTERNAL" in error_msg and "Please retry" in error_msg:
|
||||||
logger.warning(f"500 INTERNAL error in {self.name} - attempting retry")
|
logger.warning(f"500 INTERNAL error in {self.name} - attempting retry")
|
||||||
try:
|
try:
|
||||||
# Single retry attempt
|
# Single retry attempt using provider
|
||||||
model = self._get_model_wrapper(request)
|
retry_response = provider.generate_content(
|
||||||
raw_response = await model.generate_content(prompt)
|
prompt=prompt,
|
||||||
response = raw_response.text
|
model_name=model_name,
|
||||||
|
system_prompt=system_prompt,
|
||||||
# If successful, process normally
|
temperature=temperature,
|
||||||
return [TextContent(type="text", text=self._process_response(response, request).model_dump_json())]
|
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if retry_response.content:
|
||||||
|
# If successful, process normally
|
||||||
|
retry_model_info = {
|
||||||
|
"provider": provider,
|
||||||
|
"model_name": model_name,
|
||||||
|
"model_response": retry_response
|
||||||
|
}
|
||||||
|
tool_output = self._parse_response(retry_response.content, request, retry_model_info)
|
||||||
|
return [TextContent(type="text", text=tool_output.model_dump_json())]
|
||||||
|
|
||||||
except Exception as retry_e:
|
except Exception as retry_e:
|
||||||
logger.error(f"Retry failed for {self.name} tool: {str(retry_e)}")
|
logger.error(f"Retry failed for {self.name} tool: {str(retry_e)}")
|
||||||
@@ -699,7 +832,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
)
|
)
|
||||||
return [TextContent(type="text", text=error_output.model_dump_json())]
|
return [TextContent(type="text", text=error_output.model_dump_json())]
|
||||||
|
|
||||||
def _parse_response(self, raw_text: str, request) -> ToolOutput:
|
def _parse_response(self, raw_text: str, request, model_info: Optional[dict] = None) -> ToolOutput:
|
||||||
"""
|
"""
|
||||||
Parse the raw response and determine if it's a clarification request or follow-up.
|
Parse the raw response and determine if it's a clarification request or follow-up.
|
||||||
|
|
||||||
@@ -745,11 +878,11 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Normal text response - format using tool-specific formatting
|
# Normal text response - format using tool-specific formatting
|
||||||
formatted_content = self.format_response(raw_text, request)
|
formatted_content = self.format_response(raw_text, request, model_info)
|
||||||
|
|
||||||
# If we found a follow-up question, prepare the threading response
|
# If we found a follow-up question, prepare the threading response
|
||||||
if follow_up_question:
|
if follow_up_question:
|
||||||
return self._create_follow_up_response(formatted_content, follow_up_question, request)
|
return self._create_follow_up_response(formatted_content, follow_up_question, request, model_info)
|
||||||
|
|
||||||
# Check if we should offer Claude a continuation opportunity
|
# Check if we should offer Claude a continuation opportunity
|
||||||
continuation_offer = self._check_continuation_opportunity(request)
|
continuation_offer = self._check_continuation_opportunity(request)
|
||||||
@@ -758,7 +891,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"Creating continuation offer for {self.name} with {continuation_offer['remaining_turns']} turns remaining"
|
f"Creating continuation offer for {self.name} with {continuation_offer['remaining_turns']} turns remaining"
|
||||||
)
|
)
|
||||||
return self._create_continuation_offer_response(formatted_content, continuation_offer, request)
|
return self._create_continuation_offer_response(formatted_content, continuation_offer, request, model_info)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"No continuation offer created for {self.name}")
|
logger.debug(f"No continuation offer created for {self.name}")
|
||||||
|
|
||||||
@@ -766,12 +899,32 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
continuation_id = getattr(request, "continuation_id", None)
|
continuation_id = getattr(request, "continuation_id", None)
|
||||||
if continuation_id:
|
if continuation_id:
|
||||||
request_files = getattr(request, "files", []) or []
|
request_files = getattr(request, "files", []) or []
|
||||||
|
# Extract model metadata for conversation tracking
|
||||||
|
model_provider = None
|
||||||
|
model_name = None
|
||||||
|
model_metadata = None
|
||||||
|
|
||||||
|
if model_info:
|
||||||
|
provider = model_info.get("provider")
|
||||||
|
if provider:
|
||||||
|
model_provider = provider.get_provider_type().value
|
||||||
|
model_name = model_info.get("model_name")
|
||||||
|
model_response = model_info.get("model_response")
|
||||||
|
if model_response:
|
||||||
|
model_metadata = {
|
||||||
|
"usage": model_response.usage,
|
||||||
|
"metadata": model_response.metadata
|
||||||
|
}
|
||||||
|
|
||||||
success = add_turn(
|
success = add_turn(
|
||||||
continuation_id,
|
continuation_id,
|
||||||
"assistant",
|
"assistant",
|
||||||
formatted_content,
|
formatted_content,
|
||||||
files=request_files,
|
files=request_files,
|
||||||
tool_name=self.name,
|
tool_name=self.name,
|
||||||
|
model_provider=model_provider,
|
||||||
|
model_name=model_name,
|
||||||
|
model_metadata=model_metadata,
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
logging.warning(f"Failed to add turn to thread {continuation_id} for {self.name}")
|
logging.warning(f"Failed to add turn to thread {continuation_id} for {self.name}")
|
||||||
@@ -820,7 +973,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _create_follow_up_response(self, content: str, follow_up_data: dict, request) -> ToolOutput:
|
def _create_follow_up_response(self, content: str, follow_up_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
|
||||||
"""
|
"""
|
||||||
Create a response with follow-up question for conversation threading.
|
Create a response with follow-up question for conversation threading.
|
||||||
|
|
||||||
@@ -832,56 +985,57 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
Returns:
|
Returns:
|
||||||
ToolOutput configured for conversation continuation
|
ToolOutput configured for conversation continuation
|
||||||
"""
|
"""
|
||||||
# Create or get thread ID
|
# Always create a new thread (with parent linkage if continuation)
|
||||||
continuation_id = getattr(request, "continuation_id", None)
|
continuation_id = getattr(request, "continuation_id", None)
|
||||||
|
request_files = getattr(request, "files", []) or []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create new thread with parent linkage if continuing
|
||||||
|
thread_id = create_thread(
|
||||||
|
tool_name=self.name,
|
||||||
|
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
|
||||||
|
parent_thread_id=continuation_id # Link to parent thread if continuing
|
||||||
|
)
|
||||||
|
|
||||||
if continuation_id:
|
# Add the assistant's response with follow-up
|
||||||
# This is a continuation - add this turn to existing thread
|
# Extract model metadata
|
||||||
request_files = getattr(request, "files", []) or []
|
model_provider = None
|
||||||
success = add_turn(
|
model_name = None
|
||||||
continuation_id,
|
model_metadata = None
|
||||||
|
|
||||||
|
if model_info:
|
||||||
|
provider = model_info.get("provider")
|
||||||
|
if provider:
|
||||||
|
model_provider = provider.get_provider_type().value
|
||||||
|
model_name = model_info.get("model_name")
|
||||||
|
model_response = model_info.get("model_response")
|
||||||
|
if model_response:
|
||||||
|
model_metadata = {
|
||||||
|
"usage": model_response.usage,
|
||||||
|
"metadata": model_response.metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
add_turn(
|
||||||
|
thread_id, # Add to the new thread
|
||||||
"assistant",
|
"assistant",
|
||||||
content,
|
content,
|
||||||
follow_up_question=follow_up_data.get("follow_up_question"),
|
follow_up_question=follow_up_data.get("follow_up_question"),
|
||||||
files=request_files,
|
files=request_files,
|
||||||
tool_name=self.name,
|
tool_name=self.name,
|
||||||
|
model_provider=model_provider,
|
||||||
|
model_name=model_name,
|
||||||
|
model_metadata=model_metadata,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Threading failed, return normal response
|
||||||
|
logger = logging.getLogger(f"tools.{self.name}")
|
||||||
|
logger.warning(f"Follow-up threading failed in {self.name}: {str(e)}")
|
||||||
|
return ToolOutput(
|
||||||
|
status="success",
|
||||||
|
content=content,
|
||||||
|
content_type="markdown",
|
||||||
|
metadata={"tool_name": self.name, "follow_up_error": str(e)},
|
||||||
)
|
)
|
||||||
if not success:
|
|
||||||
# Thread not found or at limit, return normal response
|
|
||||||
return ToolOutput(
|
|
||||||
status="success",
|
|
||||||
content=content,
|
|
||||||
content_type="markdown",
|
|
||||||
metadata={"tool_name": self.name},
|
|
||||||
)
|
|
||||||
thread_id = continuation_id
|
|
||||||
else:
|
|
||||||
# Create new thread
|
|
||||||
try:
|
|
||||||
thread_id = create_thread(
|
|
||||||
tool_name=self.name, initial_request=request.model_dump() if hasattr(request, "model_dump") else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the assistant's response with follow-up
|
|
||||||
request_files = getattr(request, "files", []) or []
|
|
||||||
add_turn(
|
|
||||||
thread_id,
|
|
||||||
"assistant",
|
|
||||||
content,
|
|
||||||
follow_up_question=follow_up_data.get("follow_up_question"),
|
|
||||||
files=request_files,
|
|
||||||
tool_name=self.name,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# Threading failed, return normal response
|
|
||||||
logger = logging.getLogger(f"tools.{self.name}")
|
|
||||||
logger.warning(f"Follow-up threading failed in {self.name}: {str(e)}")
|
|
||||||
return ToolOutput(
|
|
||||||
status="success",
|
|
||||||
content=content,
|
|
||||||
content_type="markdown",
|
|
||||||
metadata={"tool_name": self.name, "follow_up_error": str(e)},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create follow-up request
|
# Create follow-up request
|
||||||
follow_up_request = FollowUpRequest(
|
follow_up_request = FollowUpRequest(
|
||||||
@@ -925,13 +1079,14 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if continuation_id:
|
if continuation_id:
|
||||||
# Check remaining turns in existing thread
|
# Check remaining turns in thread chain
|
||||||
from utils.conversation_memory import get_thread
|
from utils.conversation_memory import get_thread_chain
|
||||||
|
|
||||||
context = get_thread(continuation_id)
|
chain = get_thread_chain(continuation_id)
|
||||||
if context:
|
if chain:
|
||||||
current_turns = len(context.turns)
|
# Count total turns across all threads in chain
|
||||||
remaining_turns = MAX_CONVERSATION_TURNS - current_turns - 1 # -1 for this response
|
total_turns = sum(len(thread.turns) for thread in chain)
|
||||||
|
remaining_turns = MAX_CONVERSATION_TURNS - total_turns - 1 # -1 for this response
|
||||||
else:
|
else:
|
||||||
# Thread not found, don't offer continuation
|
# Thread not found, don't offer continuation
|
||||||
return None
|
return None
|
||||||
@@ -949,7 +1104,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
# If anything fails, don't offer continuation
|
# If anything fails, don't offer continuation
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _create_continuation_offer_response(self, content: str, continuation_data: dict, request) -> ToolOutput:
|
def _create_continuation_offer_response(self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
|
||||||
"""
|
"""
|
||||||
Create a response offering Claude the opportunity to continue conversation.
|
Create a response offering Claude the opportunity to continue conversation.
|
||||||
|
|
||||||
@@ -962,14 +1117,43 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
ToolOutput configured with continuation offer
|
ToolOutput configured with continuation offer
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Create new thread for potential continuation
|
# Create new thread for potential continuation (with parent link if continuing)
|
||||||
|
continuation_id = getattr(request, "continuation_id", None)
|
||||||
thread_id = create_thread(
|
thread_id = create_thread(
|
||||||
tool_name=self.name, initial_request=request.model_dump() if hasattr(request, "model_dump") else {}
|
tool_name=self.name,
|
||||||
|
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
|
||||||
|
parent_thread_id=continuation_id # Link to parent if this is a continuation
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add this response as the first turn (assistant turn)
|
# Add this response as the first turn (assistant turn)
|
||||||
request_files = getattr(request, "files", []) or []
|
request_files = getattr(request, "files", []) or []
|
||||||
add_turn(thread_id, "assistant", content, files=request_files, tool_name=self.name)
|
# Extract model metadata
|
||||||
|
model_provider = None
|
||||||
|
model_name = None
|
||||||
|
model_metadata = None
|
||||||
|
|
||||||
|
if model_info:
|
||||||
|
provider = model_info.get("provider")
|
||||||
|
if provider:
|
||||||
|
model_provider = provider.get_provider_type().value
|
||||||
|
model_name = model_info.get("model_name")
|
||||||
|
model_response = model_info.get("model_response")
|
||||||
|
if model_response:
|
||||||
|
model_metadata = {
|
||||||
|
"usage": model_response.usage,
|
||||||
|
"metadata": model_response.metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
add_turn(
|
||||||
|
thread_id,
|
||||||
|
"assistant",
|
||||||
|
content,
|
||||||
|
files=request_files,
|
||||||
|
tool_name=self.name,
|
||||||
|
model_provider=model_provider,
|
||||||
|
model_name=model_name,
|
||||||
|
model_metadata=model_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
# Create continuation offer
|
# Create continuation offer
|
||||||
remaining_turns = continuation_data["remaining_turns"]
|
remaining_turns = continuation_data["remaining_turns"]
|
||||||
@@ -1022,7 +1206,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def format_response(self, response: str, request) -> str:
|
def format_response(self, response: str, request, model_info: Optional[dict] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Format the model's response for display.
|
Format the model's response for display.
|
||||||
|
|
||||||
@@ -1033,6 +1217,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
Args:
|
Args:
|
||||||
response: The raw response from the model
|
response: The raw response from the model
|
||||||
request: The original request for context
|
request: The original request for context
|
||||||
|
model_info: Optional dict with model metadata (provider, model_name, model_response)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Formatted response
|
str: Formatted response
|
||||||
@@ -1059,154 +1244,41 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
f"{context_type} too large (~{estimated_tokens:,} tokens). Maximum is {MAX_CONTEXT_TOKENS:,} tokens."
|
f"{context_type} too large (~{estimated_tokens:,} tokens). Maximum is {MAX_CONTEXT_TOKENS:,} tokens."
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_model(self, model_name: str, temperature: float, thinking_mode: str = "medium"):
|
def get_model_provider(self, model_name: str) -> ModelProvider:
|
||||||
"""
|
"""
|
||||||
Create a configured Gemini model instance.
|
Get a model provider for the specified model.
|
||||||
|
|
||||||
This method handles model creation with appropriate settings including
|
|
||||||
temperature and thinking budget configuration for models that support it.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Name of the Gemini model to use (or shorthand like 'flash', 'pro')
|
model_name: Name of the model to use (can be provider-specific or generic)
|
||||||
temperature: Temperature setting for response generation
|
|
||||||
thinking_mode: Thinking depth mode (affects computational budget)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model instance configured and ready for generation
|
ModelProvider instance configured for the model
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no provider supports the requested model
|
||||||
"""
|
"""
|
||||||
# Define model shorthands for user convenience
|
# Get provider from registry
|
||||||
model_shorthands = {
|
provider = ModelProviderRegistry.get_provider_for_model(model_name)
|
||||||
"pro": "gemini-2.5-pro-preview-06-05",
|
|
||||||
"flash": "gemini-2.0-flash-exp",
|
if not provider:
|
||||||
}
|
# Try to determine provider from model name patterns
|
||||||
|
if "gemini" in model_name.lower() or model_name.lower() in ["flash", "pro"]:
|
||||||
# Resolve shorthand to full model name
|
# Register Gemini provider if not already registered
|
||||||
resolved_model_name = model_shorthands.get(model_name.lower(), model_name)
|
from providers.gemini import GeminiModelProvider
|
||||||
|
from providers.base import ProviderType
|
||||||
# Map thinking modes to computational budget values
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
# Higher budgets allow for more complex reasoning but increase latency
|
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
||||||
thinking_budgets = {
|
elif "gpt" in model_name.lower() or "o3" in model_name.lower():
|
||||||
"minimal": 128, # Minimum for 2.5 Pro - fast responses
|
# Register OpenAI provider if not already registered
|
||||||
"low": 2048, # Light reasoning tasks
|
from providers.openai import OpenAIModelProvider
|
||||||
"medium": 8192, # Balanced reasoning (default)
|
from providers.base import ProviderType
|
||||||
"high": 16384, # Complex analysis
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
"max": 32768, # Maximum reasoning depth
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
|
||||||
}
|
|
||||||
|
if not provider:
|
||||||
thinking_budget = thinking_budgets.get(thinking_mode, 8192)
|
raise ValueError(
|
||||||
|
f"No provider found for model '{model_name}'. "
|
||||||
# Gemini 2.5 models support thinking configuration for enhanced reasoning
|
f"Ensure the appropriate API key is set and the model name is correct."
|
||||||
# Skip special handling in test environment to allow mocking
|
)
|
||||||
if "2.5" in resolved_model_name and not os.environ.get("PYTEST_CURRENT_TEST"):
|
|
||||||
try:
|
return provider
|
||||||
# Retrieve API key for Gemini client creation
|
|
||||||
api_key = os.environ.get("GEMINI_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
raise ValueError("GEMINI_API_KEY environment variable is required")
|
|
||||||
|
|
||||||
client = genai.Client(api_key=api_key)
|
|
||||||
|
|
||||||
# Create a wrapper class to provide a consistent interface
|
|
||||||
# This abstracts the differences between API versions
|
|
||||||
class ModelWrapper:
|
|
||||||
def __init__(self, client, model_name, temperature, thinking_budget):
|
|
||||||
self.client = client
|
|
||||||
self.model_name = model_name
|
|
||||||
self.temperature = temperature
|
|
||||||
self.thinking_budget = thinking_budget
|
|
||||||
|
|
||||||
def generate_content(self, prompt):
|
|
||||||
response = self.client.models.generate_content(
|
|
||||||
model=self.model_name,
|
|
||||||
contents=prompt,
|
|
||||||
config=types.GenerateContentConfig(
|
|
||||||
temperature=self.temperature,
|
|
||||||
candidate_count=1,
|
|
||||||
thinking_config=types.ThinkingConfig(thinking_budget=self.thinking_budget),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wrap the response to match the expected format
|
|
||||||
# This ensures compatibility across different API versions
|
|
||||||
class ResponseWrapper:
|
|
||||||
def __init__(self, text):
|
|
||||||
self.text = text
|
|
||||||
self.candidates = [
|
|
||||||
type(
|
|
||||||
"obj",
|
|
||||||
(object,),
|
|
||||||
{
|
|
||||||
"content": type(
|
|
||||||
"obj",
|
|
||||||
(object,),
|
|
||||||
{
|
|
||||||
"parts": [
|
|
||||||
type(
|
|
||||||
"obj",
|
|
||||||
(object,),
|
|
||||||
{"text": text},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)(),
|
|
||||||
"finish_reason": "STOP",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return ResponseWrapper(response.text)
|
|
||||||
|
|
||||||
return ModelWrapper(client, resolved_model_name, temperature, thinking_budget)
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
# Fall back to regular API if thinking configuration fails
|
|
||||||
# This ensures the tool remains functional even with API changes
|
|
||||||
pass
|
|
||||||
|
|
||||||
# For models that don't support thinking configuration, use standard API
|
|
||||||
api_key = os.environ.get("GEMINI_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
raise ValueError("GEMINI_API_KEY environment variable is required")
|
|
||||||
|
|
||||||
client = genai.Client(api_key=api_key)
|
|
||||||
|
|
||||||
# Create a simple wrapper for models without thinking configuration
|
|
||||||
# This provides the same interface as the thinking-enabled wrapper
|
|
||||||
class SimpleModelWrapper:
|
|
||||||
def __init__(self, client, model_name, temperature):
|
|
||||||
self.client = client
|
|
||||||
self.model_name = model_name
|
|
||||||
self.temperature = temperature
|
|
||||||
|
|
||||||
def generate_content(self, prompt):
|
|
||||||
response = self.client.models.generate_content(
|
|
||||||
model=self.model_name,
|
|
||||||
contents=prompt,
|
|
||||||
config=types.GenerateContentConfig(
|
|
||||||
temperature=self.temperature,
|
|
||||||
candidate_count=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to match expected format
|
|
||||||
class ResponseWrapper:
|
|
||||||
def __init__(self, text):
|
|
||||||
self.text = text
|
|
||||||
self.candidates = [
|
|
||||||
type(
|
|
||||||
"obj",
|
|
||||||
(object,),
|
|
||||||
{
|
|
||||||
"content": type(
|
|
||||||
"obj",
|
|
||||||
(object,),
|
|
||||||
{"parts": [type("obj", (object,), {"text": text})]},
|
|
||||||
)(),
|
|
||||||
"finish_reason": "STOP",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return ResponseWrapper(response.text)
|
|
||||||
|
|
||||||
return SimpleModelWrapper(client, resolved_model_name, temperature)
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class ChatRequest(ToolRequest):
|
|||||||
|
|
||||||
prompt: str = Field(
|
prompt: str = Field(
|
||||||
...,
|
...,
|
||||||
description="Your question, topic, or current thinking to discuss with Gemini",
|
description="Your question, topic, or current thinking to discuss",
|
||||||
)
|
)
|
||||||
files: Optional[list[str]] = Field(
|
files: Optional[list[str]] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
@@ -35,33 +35,30 @@ class ChatTool(BaseTool):
|
|||||||
|
|
||||||
def get_description(self) -> str:
|
def get_description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"GENERAL CHAT & COLLABORATIVE THINKING - Use Gemini as your thinking partner! "
|
"GENERAL CHAT & COLLABORATIVE THINKING - Use the AI model as your thinking partner! "
|
||||||
"Perfect for: bouncing ideas during your own analysis, getting second opinions on your plans, "
|
"Perfect for: bouncing ideas during your own analysis, getting second opinions on your plans, "
|
||||||
"collaborative brainstorming, validating your checklists and approaches, exploring alternatives. "
|
"collaborative brainstorming, validating your checklists and approaches, exploring alternatives. "
|
||||||
"Also great for: explanations, comparisons, general development questions. "
|
"Also great for: explanations, comparisons, general development questions. "
|
||||||
"Use this when you want to ask Gemini questions, brainstorm ideas, get opinions, discuss topics, "
|
"Use this when you want to ask questions, brainstorm ideas, get opinions, discuss topics, "
|
||||||
"share your thinking, or need explanations about concepts and approaches."
|
"share your thinking, or need explanations about concepts and approaches."
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_input_schema(self) -> dict[str, Any]:
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
from config import DEFAULT_MODEL
|
from config import IS_AUTO_MODE
|
||||||
|
|
||||||
return {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"prompt": {
|
"prompt": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Your question, topic, or current thinking to discuss with Gemini",
|
"description": "Your question, topic, or current thinking to discuss",
|
||||||
},
|
},
|
||||||
"files": {
|
"files": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"description": "Optional files for context (must be absolute paths)",
|
"description": "Optional files for context (must be absolute paths)",
|
||||||
},
|
},
|
||||||
"model": {
|
"model": self.get_model_field_schema(),
|
||||||
"type": "string",
|
|
||||||
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
|
|
||||||
},
|
|
||||||
"temperature": {
|
"temperature": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"description": "Response creativity (0-1, default 0.5)",
|
"description": "Response creativity (0-1, default 0.5)",
|
||||||
@@ -83,8 +80,10 @@ class ChatTool(BaseTool):
|
|||||||
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["prompt"],
|
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
return CHAT_PROMPT
|
return CHAT_PROMPT
|
||||||
@@ -153,6 +152,6 @@ Please provide a thoughtful, comprehensive response:"""
|
|||||||
|
|
||||||
return full_prompt
|
return full_prompt
|
||||||
|
|
||||||
def format_response(self, response: str, request: ChatRequest) -> str:
|
def format_response(self, response: str, request: ChatRequest, model_info: Optional[dict] = None) -> str:
|
||||||
"""Format the chat response with actionable guidance"""
|
"""Format the chat response"""
|
||||||
return f"{response}\n\n---\n\n**Claude's Turn:** Evaluate this perspective alongside your analysis to form a comprehensive solution and continue with the user's request and task at hand."
|
return f"{response}\n\n---\n\n**Claude's Turn:** Evaluate this perspective alongside your analysis to form a comprehensive solution and continue with the user's request and task at hand."
|
||||||
|
|||||||
@@ -39,12 +39,12 @@ class CodeReviewRequest(ToolRequest):
|
|||||||
...,
|
...,
|
||||||
description="Code files or directories to review (must be absolute paths)",
|
description="Code files or directories to review (must be absolute paths)",
|
||||||
)
|
)
|
||||||
context: str = Field(
|
prompt: str = Field(
|
||||||
...,
|
...,
|
||||||
description="User's summary of what the code does, expected behavior, constraints, and review objectives",
|
description="User's summary of what the code does, expected behavior, constraints, and review objectives",
|
||||||
)
|
)
|
||||||
review_type: str = Field("full", description="Type of review: full|security|performance|quick")
|
review_type: str = Field("full", description="Type of review: full|security|performance|quick")
|
||||||
focus_on: Optional[str] = Field(None, description="Specific aspects to focus on during review")
|
focus_on: Optional[str] = Field(None, description="Specific aspects to focus on, or additional context that would help understand areas of concern")
|
||||||
standards: Optional[str] = Field(None, description="Coding standards or guidelines to enforce")
|
standards: Optional[str] = Field(None, description="Coding standards or guidelines to enforce")
|
||||||
severity_filter: str = Field(
|
severity_filter: str = Field(
|
||||||
"all",
|
"all",
|
||||||
@@ -79,9 +79,9 @@ class CodeReviewTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_input_schema(self) -> dict[str, Any]:
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
from config import DEFAULT_MODEL
|
from config import IS_AUTO_MODE
|
||||||
|
|
||||||
return {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"files": {
|
"files": {
|
||||||
@@ -89,11 +89,8 @@ class CodeReviewTool(BaseTool):
|
|||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"description": "Code files or directories to review (must be absolute paths)",
|
"description": "Code files or directories to review (must be absolute paths)",
|
||||||
},
|
},
|
||||||
"model": {
|
"model": self.get_model_field_schema(),
|
||||||
"type": "string",
|
"prompt": {
|
||||||
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
|
|
||||||
},
|
|
||||||
"context": {
|
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "User's summary of what the code does, expected behavior, constraints, and review objectives",
|
"description": "User's summary of what the code does, expected behavior, constraints, and review objectives",
|
||||||
},
|
},
|
||||||
@@ -105,7 +102,7 @@ class CodeReviewTool(BaseTool):
|
|||||||
},
|
},
|
||||||
"focus_on": {
|
"focus_on": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Specific aspects to focus on",
|
"description": "Specific aspects to focus on, or additional context that would help understand areas of concern",
|
||||||
},
|
},
|
||||||
"standards": {
|
"standards": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -138,8 +135,10 @@ class CodeReviewTool(BaseTool):
|
|||||||
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["files", "context"],
|
"required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
return CODEREVIEW_PROMPT
|
return CODEREVIEW_PROMPT
|
||||||
@@ -184,9 +183,9 @@ class CodeReviewTool(BaseTool):
|
|||||||
# Check for prompt.txt in files
|
# Check for prompt.txt in files
|
||||||
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
||||||
|
|
||||||
# If prompt.txt was found, use it as focus_on
|
# If prompt.txt was found, incorporate it into the prompt
|
||||||
if prompt_content:
|
if prompt_content:
|
||||||
request.focus_on = prompt_content
|
request.prompt = prompt_content + "\n\n" + request.prompt
|
||||||
|
|
||||||
# Update request files list
|
# Update request files list
|
||||||
if updated_files is not None:
|
if updated_files is not None:
|
||||||
@@ -234,7 +233,7 @@ class CodeReviewTool(BaseTool):
|
|||||||
full_prompt = f"""{self.get_system_prompt()}{websearch_instruction}
|
full_prompt = f"""{self.get_system_prompt()}{websearch_instruction}
|
||||||
|
|
||||||
=== USER CONTEXT ===
|
=== USER CONTEXT ===
|
||||||
{request.context}
|
{request.prompt}
|
||||||
=== END CONTEXT ===
|
=== END CONTEXT ===
|
||||||
|
|
||||||
{focus_instruction}
|
{focus_instruction}
|
||||||
@@ -247,27 +246,19 @@ Please provide a code review aligned with the user's context and expectations, f
|
|||||||
|
|
||||||
return full_prompt
|
return full_prompt
|
||||||
|
|
||||||
def format_response(self, response: str, request: CodeReviewRequest) -> str:
|
def format_response(self, response: str, request: CodeReviewRequest, model_info: Optional[dict] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Format the review response with appropriate headers.
|
Format the review response.
|
||||||
|
|
||||||
Adds context about the review type and focus area to help
|
|
||||||
users understand the scope of the review.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: The raw review from the model
|
response: The raw review from the model
|
||||||
request: The original request for context
|
request: The original request for context
|
||||||
|
model_info: Optional dict with model metadata
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Formatted response with headers
|
str: Formatted response with next steps
|
||||||
"""
|
"""
|
||||||
header = f"Code Review ({request.review_type.upper()})"
|
return f"""{response}
|
||||||
if request.focus_on:
|
|
||||||
header += f" - Focus: {request.focus_on}"
|
|
||||||
return f"""{header}
|
|
||||||
{"=" * 50}
|
|
||||||
|
|
||||||
{response}
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from .models import ToolOutput
|
|||||||
class DebugIssueRequest(ToolRequest):
|
class DebugIssueRequest(ToolRequest):
|
||||||
"""Request model for debug tool"""
|
"""Request model for debug tool"""
|
||||||
|
|
||||||
error_description: str = Field(..., description="Error message, symptoms, or issue description")
|
prompt: str = Field(..., description="Error message, symptoms, or issue description")
|
||||||
error_context: Optional[str] = Field(None, description="Stack trace, logs, or additional error context")
|
error_context: Optional[str] = Field(None, description="Stack trace, logs, or additional error context")
|
||||||
files: Optional[list[str]] = Field(
|
files: Optional[list[str]] = Field(
|
||||||
None,
|
None,
|
||||||
@@ -38,7 +38,7 @@ class DebugIssueTool(BaseTool):
|
|||||||
"DEBUG & ROOT CAUSE ANALYSIS - Expert debugging for complex issues with 1M token capacity. "
|
"DEBUG & ROOT CAUSE ANALYSIS - Expert debugging for complex issues with 1M token capacity. "
|
||||||
"Use this when you need to debug code, find out why something is failing, identify root causes, "
|
"Use this when you need to debug code, find out why something is failing, identify root causes, "
|
||||||
"trace errors, or diagnose issues. "
|
"trace errors, or diagnose issues. "
|
||||||
"IMPORTANT: Share diagnostic files liberally! Gemini can handle up to 1M tokens, so include: "
|
"IMPORTANT: Share diagnostic files liberally! The model can handle up to 1M tokens, so include: "
|
||||||
"large log files, full stack traces, memory dumps, diagnostic outputs, multiple related files, "
|
"large log files, full stack traces, memory dumps, diagnostic outputs, multiple related files, "
|
||||||
"entire modules, test results, configuration files - anything that might help debug the issue. "
|
"entire modules, test results, configuration files - anything that might help debug the issue. "
|
||||||
"Claude should proactively use this tool whenever debugging is needed and share comprehensive "
|
"Claude should proactively use this tool whenever debugging is needed and share comprehensive "
|
||||||
@@ -50,19 +50,16 @@ class DebugIssueTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_input_schema(self) -> dict[str, Any]:
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
from config import DEFAULT_MODEL
|
from config import IS_AUTO_MODE
|
||||||
|
|
||||||
return {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"error_description": {
|
"prompt": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Error message, symptoms, or issue description",
|
"description": "Error message, symptoms, or issue description",
|
||||||
},
|
},
|
||||||
"model": {
|
"model": self.get_model_field_schema(),
|
||||||
"type": "string",
|
|
||||||
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
|
|
||||||
},
|
|
||||||
"error_context": {
|
"error_context": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Stack trace, logs, or additional error context",
|
"description": "Stack trace, logs, or additional error context",
|
||||||
@@ -101,8 +98,10 @@ class DebugIssueTool(BaseTool):
|
|||||||
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["error_description"],
|
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
return DEBUG_ISSUE_PROMPT
|
return DEBUG_ISSUE_PROMPT
|
||||||
@@ -119,8 +118,8 @@ class DebugIssueTool(BaseTool):
|
|||||||
request_model = self.get_request_model()
|
request_model = self.get_request_model()
|
||||||
request = request_model(**arguments)
|
request = request_model(**arguments)
|
||||||
|
|
||||||
# Check error_description size
|
# Check prompt size
|
||||||
size_check = self.check_prompt_size(request.error_description)
|
size_check = self.check_prompt_size(request.prompt)
|
||||||
if size_check:
|
if size_check:
|
||||||
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
|
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
|
||||||
|
|
||||||
@@ -138,11 +137,10 @@ class DebugIssueTool(BaseTool):
|
|||||||
# Check for prompt.txt in files
|
# Check for prompt.txt in files
|
||||||
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
||||||
|
|
||||||
# If prompt.txt was found, use it as error_description or error_context
|
# If prompt.txt was found, use it as prompt or error_context
|
||||||
# Priority: if error_description is empty, use it there, otherwise use as error_context
|
|
||||||
if prompt_content:
|
if prompt_content:
|
||||||
if not request.error_description or request.error_description == "":
|
if not request.prompt or request.prompt == "":
|
||||||
request.error_description = prompt_content
|
request.prompt = prompt_content
|
||||||
else:
|
else:
|
||||||
request.error_context = prompt_content
|
request.error_context = prompt_content
|
||||||
|
|
||||||
@@ -151,7 +149,7 @@ class DebugIssueTool(BaseTool):
|
|||||||
request.files = updated_files
|
request.files = updated_files
|
||||||
|
|
||||||
# Build context sections
|
# Build context sections
|
||||||
context_parts = [f"=== ISSUE DESCRIPTION ===\n{request.error_description}\n=== END DESCRIPTION ==="]
|
context_parts = [f"=== ISSUE DESCRIPTION ===\n{request.prompt}\n=== END DESCRIPTION ==="]
|
||||||
|
|
||||||
if request.error_context:
|
if request.error_context:
|
||||||
context_parts.append(f"\n=== ERROR CONTEXT/STACK TRACE ===\n{request.error_context}\n=== END CONTEXT ===")
|
context_parts.append(f"\n=== ERROR CONTEXT/STACK TRACE ===\n{request.error_context}\n=== END CONTEXT ===")
|
||||||
@@ -197,11 +195,15 @@ Focus on finding the root cause and providing actionable solutions."""
|
|||||||
|
|
||||||
return full_prompt
|
return full_prompt
|
||||||
|
|
||||||
def format_response(self, response: str, request: DebugIssueRequest) -> str:
|
def format_response(self, response: str, request: DebugIssueRequest, model_info: Optional[dict] = None) -> str:
|
||||||
"""Format the debugging response"""
|
"""Format the debugging response"""
|
||||||
return (
|
# Get the friendly model name
|
||||||
f"Debug Analysis\n{'=' * 50}\n\n{response}\n\n---\n\n"
|
model_name = "the model"
|
||||||
"**Next Steps:** Evaluate Gemini's recommendations, synthesize the best fix considering potential "
|
if model_info and model_info.get("model_response"):
|
||||||
"regressions, and if the root cause has been clearly identified, proceed with implementing the "
|
model_name = model_info["model_response"].friendly_name or "the model"
|
||||||
"potential fixes."
|
|
||||||
)
|
return f"""{response}
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Next Steps:** Evaluate {model_name}'s recommendations, synthesize the best fix considering potential regressions, and if the root cause has been clearly identified, proceed with implementing the potential fixes."""
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class PrecommitRequest(ToolRequest):
|
|||||||
...,
|
...,
|
||||||
description="Starting directory to search for git repositories (must be absolute path).",
|
description="Starting directory to search for git repositories (must be absolute path).",
|
||||||
)
|
)
|
||||||
original_request: Optional[str] = Field(
|
prompt: Optional[str] = Field(
|
||||||
None,
|
None,
|
||||||
description="The original user request description for the changes. Provides critical context for the review.",
|
description="The original user request description for the changes. Provides critical context for the review.",
|
||||||
)
|
)
|
||||||
@@ -98,15 +98,17 @@ class Precommit(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_input_schema(self) -> dict[str, Any]:
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
from config import DEFAULT_MODEL
|
from config import IS_AUTO_MODE
|
||||||
|
|
||||||
schema = self.get_request_model().model_json_schema()
|
schema = self.get_request_model().model_json_schema()
|
||||||
# Ensure model parameter has enhanced description
|
# Ensure model parameter has enhanced description
|
||||||
if "properties" in schema and "model" in schema["properties"]:
|
if "properties" in schema and "model" in schema["properties"]:
|
||||||
schema["properties"]["model"] = {
|
schema["properties"]["model"] = self.get_model_field_schema()
|
||||||
"type": "string",
|
|
||||||
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
|
# In auto mode, model is required
|
||||||
}
|
if IS_AUTO_MODE and "required" in schema:
|
||||||
|
if "model" not in schema["required"]:
|
||||||
|
schema["required"].append("model")
|
||||||
# Ensure use_websearch is in the schema with proper description
|
# Ensure use_websearch is in the schema with proper description
|
||||||
if "properties" in schema and "use_websearch" not in schema["properties"]:
|
if "properties" in schema and "use_websearch" not in schema["properties"]:
|
||||||
schema["properties"]["use_websearch"] = {
|
schema["properties"]["use_websearch"] = {
|
||||||
@@ -140,9 +142,9 @@ class Precommit(BaseTool):
|
|||||||
request_model = self.get_request_model()
|
request_model = self.get_request_model()
|
||||||
request = request_model(**arguments)
|
request = request_model(**arguments)
|
||||||
|
|
||||||
# Check original_request size if provided
|
# Check prompt size if provided
|
||||||
if request.original_request:
|
if request.prompt:
|
||||||
size_check = self.check_prompt_size(request.original_request)
|
size_check = self.check_prompt_size(request.prompt)
|
||||||
if size_check:
|
if size_check:
|
||||||
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
|
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
|
||||||
|
|
||||||
@@ -154,9 +156,9 @@ class Precommit(BaseTool):
|
|||||||
# Check for prompt.txt in files
|
# Check for prompt.txt in files
|
||||||
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
||||||
|
|
||||||
# If prompt.txt was found, use it as original_request
|
# If prompt.txt was found, use it as prompt
|
||||||
if prompt_content:
|
if prompt_content:
|
||||||
request.original_request = prompt_content
|
request.prompt = prompt_content
|
||||||
|
|
||||||
# Update request files list
|
# Update request files list
|
||||||
if updated_files is not None:
|
if updated_files is not None:
|
||||||
@@ -338,8 +340,8 @@ class Precommit(BaseTool):
|
|||||||
prompt_parts = []
|
prompt_parts = []
|
||||||
|
|
||||||
# Add original request context if provided
|
# Add original request context if provided
|
||||||
if request.original_request:
|
if request.prompt:
|
||||||
prompt_parts.append(f"## Original Request\n\n{request.original_request}\n")
|
prompt_parts.append(f"## Original Request\n\n{request.prompt}\n")
|
||||||
|
|
||||||
# Add review parameters
|
# Add review parameters
|
||||||
prompt_parts.append("## Review Parameters\n")
|
prompt_parts.append("## Review Parameters\n")
|
||||||
@@ -443,6 +445,6 @@ class Precommit(BaseTool):
|
|||||||
|
|
||||||
return full_prompt
|
return full_prompt
|
||||||
|
|
||||||
def format_response(self, response: str, request: PrecommitRequest) -> str:
|
def format_response(self, response: str, request: PrecommitRequest, model_info: Optional[dict] = None) -> str:
|
||||||
"""Format the response with commit guidance"""
|
"""Format the response with commit guidance"""
|
||||||
return f"{response}\n\n---\n\n**Commit Status:** If no critical issues found, changes are ready for commit. Otherwise, address issues first and re-run review. Check with user before proceeding with any commit."
|
return f"{response}\n\n---\n\n**Commit Status:** If no critical issues found, changes are ready for commit. Otherwise, address issues first and re-run review. Check with user before proceeding with any commit."
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from .models import ToolOutput
|
|||||||
class ThinkDeepRequest(ToolRequest):
|
class ThinkDeepRequest(ToolRequest):
|
||||||
"""Request model for thinkdeep tool"""
|
"""Request model for thinkdeep tool"""
|
||||||
|
|
||||||
current_analysis: str = Field(..., description="Claude's current thinking/analysis to extend")
|
prompt: str = Field(..., description="Your current thinking/analysis to extend and validate")
|
||||||
problem_context: Optional[str] = Field(None, description="Additional context about the problem or goal")
|
problem_context: Optional[str] = Field(None, description="Additional context about the problem or goal")
|
||||||
focus_areas: Optional[list[str]] = Field(
|
focus_areas: Optional[list[str]] = Field(
|
||||||
None,
|
None,
|
||||||
@@ -48,19 +48,16 @@ class ThinkDeepTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_input_schema(self) -> dict[str, Any]:
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
from config import DEFAULT_MODEL
|
from config import IS_AUTO_MODE
|
||||||
|
|
||||||
return {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"current_analysis": {
|
"prompt": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Your current thinking/analysis to extend and validate",
|
"description": "Your current thinking/analysis to extend and validate",
|
||||||
},
|
},
|
||||||
"model": {
|
"model": self.get_model_field_schema(),
|
||||||
"type": "string",
|
|
||||||
"description": f"Model to use: 'pro' (Gemini 2.5 Pro with extended thinking) or 'flash' (Gemini 2.0 Flash - faster). Defaults to '{DEFAULT_MODEL}' if not specified.",
|
|
||||||
},
|
|
||||||
"problem_context": {
|
"problem_context": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Additional context about the problem or goal",
|
"description": "Additional context about the problem or goal",
|
||||||
@@ -96,8 +93,10 @@ class ThinkDeepTool(BaseTool):
|
|||||||
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
"description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["current_analysis"],
|
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
return THINKDEEP_PROMPT
|
return THINKDEEP_PROMPT
|
||||||
@@ -120,8 +119,8 @@ class ThinkDeepTool(BaseTool):
|
|||||||
request_model = self.get_request_model()
|
request_model = self.get_request_model()
|
||||||
request = request_model(**arguments)
|
request = request_model(**arguments)
|
||||||
|
|
||||||
# Check current_analysis size
|
# Check prompt size
|
||||||
size_check = self.check_prompt_size(request.current_analysis)
|
size_check = self.check_prompt_size(request.prompt)
|
||||||
if size_check:
|
if size_check:
|
||||||
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
|
return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())]
|
||||||
|
|
||||||
@@ -133,8 +132,8 @@ class ThinkDeepTool(BaseTool):
|
|||||||
# Check for prompt.txt in files
|
# Check for prompt.txt in files
|
||||||
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
prompt_content, updated_files = self.handle_prompt_file(request.files)
|
||||||
|
|
||||||
# Use prompt.txt content if available, otherwise use the current_analysis field
|
# Use prompt.txt content if available, otherwise use the prompt field
|
||||||
current_analysis = prompt_content if prompt_content else request.current_analysis
|
current_analysis = prompt_content if prompt_content else request.prompt
|
||||||
|
|
||||||
# Update request files list
|
# Update request files list
|
||||||
if updated_files is not None:
|
if updated_files is not None:
|
||||||
@@ -190,21 +189,24 @@ Please provide deep analysis that extends Claude's thinking with:
|
|||||||
|
|
||||||
return full_prompt
|
return full_prompt
|
||||||
|
|
||||||
def format_response(self, response: str, request: ThinkDeepRequest) -> str:
|
def format_response(self, response: str, request: ThinkDeepRequest, model_info: Optional[dict] = None) -> str:
|
||||||
"""Format the response with clear attribution and critical thinking prompt"""
|
"""Format the response with clear attribution and critical thinking prompt"""
|
||||||
return f"""## Extended Analysis by Gemini
|
# Get the friendly model name
|
||||||
|
model_name = "your fellow developer"
|
||||||
{response}
|
if model_info and model_info.get("model_response"):
|
||||||
|
model_name = model_info["model_response"].friendly_name or "your fellow developer"
|
||||||
|
|
||||||
|
return f"""{response}
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Critical Evaluation Required
|
## Critical Evaluation Required
|
||||||
|
|
||||||
Claude, please critically evaluate Gemini's analysis by considering:
|
Claude, please critically evaluate {model_name}'s analysis by thinking hard about the following:
|
||||||
|
|
||||||
1. **Technical merit** - Which suggestions are valuable vs. have limitations?
|
1. **Technical merit** - Which suggestions are valuable vs. have limitations?
|
||||||
2. **Constraints** - Fit with codebase patterns, performance, security, architecture
|
2. **Constraints** - Fit with codebase patterns, performance, security, architecture
|
||||||
3. **Risks** - Hidden complexities, edge cases, potential failure modes
|
3. **Risks** - Hidden complexities, edge cases, potential failure modes
|
||||||
4. **Final recommendation** - Synthesize both perspectives, then think deeply further to explore additional considerations and arrive at the best technical solution
|
4. **Final recommendation** - Synthesize both perspectives, then think deeply further to explore additional considerations and arrive at the best technical solution
|
||||||
|
|
||||||
Remember: Use Gemini's insights to enhance, not replace, your analysis."""
|
Remember: Use {model_name}'s insights to enhance, not replace, your analysis."""
|
||||||
|
|||||||
@@ -68,12 +68,15 @@ class ConversationTurn(BaseModel):
|
|||||||
the content and metadata needed for cross-tool continuation.
|
the content and metadata needed for cross-tool continuation.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
role: "user" (Claude) or "assistant" (Gemini)
|
role: "user" (Claude) or "assistant" (Gemini/O3/etc)
|
||||||
content: The actual message content/response
|
content: The actual message content/response
|
||||||
timestamp: ISO timestamp when this turn was created
|
timestamp: ISO timestamp when this turn was created
|
||||||
follow_up_question: Optional follow-up question from Gemini to Claude
|
follow_up_question: Optional follow-up question from assistant to Claude
|
||||||
files: List of file paths referenced in this specific turn
|
files: List of file paths referenced in this specific turn
|
||||||
tool_name: Which tool generated this turn (for cross-tool tracking)
|
tool_name: Which tool generated this turn (for cross-tool tracking)
|
||||||
|
model_provider: Provider used (e.g., "google", "openai")
|
||||||
|
model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini")
|
||||||
|
model_metadata: Additional model-specific metadata (e.g., thinking mode, token usage)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
role: str # "user" or "assistant"
|
role: str # "user" or "assistant"
|
||||||
@@ -82,6 +85,9 @@ class ConversationTurn(BaseModel):
|
|||||||
follow_up_question: Optional[str] = None
|
follow_up_question: Optional[str] = None
|
||||||
files: Optional[list[str]] = None # Files referenced in this turn
|
files: Optional[list[str]] = None # Files referenced in this turn
|
||||||
tool_name: Optional[str] = None # Tool used for this turn
|
tool_name: Optional[str] = None # Tool used for this turn
|
||||||
|
model_provider: Optional[str] = None # Model provider (google, openai, etc)
|
||||||
|
model_name: Optional[str] = None # Specific model used
|
||||||
|
model_metadata: Optional[dict[str, Any]] = None # Additional model info
|
||||||
|
|
||||||
|
|
||||||
class ThreadContext(BaseModel):
|
class ThreadContext(BaseModel):
|
||||||
@@ -94,6 +100,7 @@ class ThreadContext(BaseModel):
|
|||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
thread_id: UUID identifying this conversation thread
|
thread_id: UUID identifying this conversation thread
|
||||||
|
parent_thread_id: UUID of parent thread (for conversation chains)
|
||||||
created_at: ISO timestamp when thread was created
|
created_at: ISO timestamp when thread was created
|
||||||
last_updated_at: ISO timestamp of last modification
|
last_updated_at: ISO timestamp of last modification
|
||||||
tool_name: Name of the tool that initiated this thread
|
tool_name: Name of the tool that initiated this thread
|
||||||
@@ -102,6 +109,7 @@ class ThreadContext(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
thread_id: str
|
thread_id: str
|
||||||
|
parent_thread_id: Optional[str] = None # Parent thread for conversation chains
|
||||||
created_at: str
|
created_at: str
|
||||||
last_updated_at: str
|
last_updated_at: str
|
||||||
tool_name: str # Tool that created this thread (preserved for attribution)
|
tool_name: str # Tool that created this thread (preserved for attribution)
|
||||||
@@ -131,7 +139,7 @@ def get_redis_client():
|
|||||||
raise ValueError("redis package required. Install with: pip install redis")
|
raise ValueError("redis package required. Install with: pip install redis")
|
||||||
|
|
||||||
|
|
||||||
def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
|
def create_thread(tool_name: str, initial_request: dict[str, Any], parent_thread_id: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Create new conversation thread and return thread ID
|
Create new conversation thread and return thread ID
|
||||||
|
|
||||||
@@ -142,6 +150,7 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
|
|||||||
Args:
|
Args:
|
||||||
tool_name: Name of the tool creating this thread (e.g., "analyze", "chat")
|
tool_name: Name of the tool creating this thread (e.g., "analyze", "chat")
|
||||||
initial_request: Original request parameters (will be filtered for serialization)
|
initial_request: Original request parameters (will be filtered for serialization)
|
||||||
|
parent_thread_id: Optional parent thread ID for conversation chains
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: UUID thread identifier that can be used for continuation
|
str: UUID thread identifier that can be used for continuation
|
||||||
@@ -150,6 +159,7 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
|
|||||||
- Thread expires after 1 hour (3600 seconds)
|
- Thread expires after 1 hour (3600 seconds)
|
||||||
- Non-serializable parameters are filtered out automatically
|
- Non-serializable parameters are filtered out automatically
|
||||||
- Thread can be continued by any tool using the returned UUID
|
- Thread can be continued by any tool using the returned UUID
|
||||||
|
- Parent thread creates a chain for conversation history traversal
|
||||||
"""
|
"""
|
||||||
thread_id = str(uuid.uuid4())
|
thread_id = str(uuid.uuid4())
|
||||||
now = datetime.now(timezone.utc).isoformat()
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
@@ -163,6 +173,7 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
|
|||||||
|
|
||||||
context = ThreadContext(
|
context = ThreadContext(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
|
parent_thread_id=parent_thread_id, # Link to parent for conversation chains
|
||||||
created_at=now,
|
created_at=now,
|
||||||
last_updated_at=now,
|
last_updated_at=now,
|
||||||
tool_name=tool_name, # Track which tool initiated this conversation
|
tool_name=tool_name, # Track which tool initiated this conversation
|
||||||
@@ -175,6 +186,8 @@ def create_thread(tool_name: str, initial_request: dict[str, Any]) -> str:
|
|||||||
key = f"thread:{thread_id}"
|
key = f"thread:{thread_id}"
|
||||||
client.setex(key, 3600, context.model_dump_json())
|
client.setex(key, 3600, context.model_dump_json())
|
||||||
|
|
||||||
|
logger.debug(f"[THREAD] Created new thread {thread_id} with parent {parent_thread_id}")
|
||||||
|
|
||||||
return thread_id
|
return thread_id
|
||||||
|
|
||||||
|
|
||||||
@@ -221,34 +234,41 @@ def add_turn(
|
|||||||
follow_up_question: Optional[str] = None,
|
follow_up_question: Optional[str] = None,
|
||||||
files: Optional[list[str]] = None,
|
files: Optional[list[str]] = None,
|
||||||
tool_name: Optional[str] = None,
|
tool_name: Optional[str] = None,
|
||||||
|
model_provider: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
model_metadata: Optional[dict[str, Any]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Add turn to existing thread
|
Add turn to existing thread
|
||||||
|
|
||||||
Appends a new conversation turn to an existing thread. This is the core
|
Appends a new conversation turn to an existing thread. This is the core
|
||||||
function for building conversation history and enabling cross-tool
|
function for building conversation history and enabling cross-tool
|
||||||
continuation. Each turn preserves the tool that generated it.
|
continuation. Each turn preserves the tool and model that generated it.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: UUID of the conversation thread
|
thread_id: UUID of the conversation thread
|
||||||
role: "user" (Claude) or "assistant" (Gemini)
|
role: "user" (Claude) or "assistant" (Gemini/O3/etc)
|
||||||
content: The actual message/response content
|
content: The actual message/response content
|
||||||
follow_up_question: Optional follow-up question from Gemini
|
follow_up_question: Optional follow-up question from assistant
|
||||||
files: Optional list of files referenced in this turn
|
files: Optional list of files referenced in this turn
|
||||||
tool_name: Name of the tool adding this turn (for attribution)
|
tool_name: Name of the tool adding this turn (for attribution)
|
||||||
|
model_provider: Provider used (e.g., "google", "openai")
|
||||||
|
model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini")
|
||||||
|
model_metadata: Additional model info (e.g., thinking mode, token usage)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if turn was successfully added, False otherwise
|
bool: True if turn was successfully added, False otherwise
|
||||||
|
|
||||||
Failure cases:
|
Failure cases:
|
||||||
- Thread doesn't exist or expired
|
- Thread doesn't exist or expired
|
||||||
- Maximum turn limit reached (5 turns)
|
- Maximum turn limit reached
|
||||||
- Redis connection failure
|
- Redis connection failure
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
- Refreshes thread TTL to 1 hour on successful update
|
- Refreshes thread TTL to 1 hour on successful update
|
||||||
- Turn limits prevent runaway conversations
|
- Turn limits prevent runaway conversations
|
||||||
- File references are preserved for cross-tool access
|
- File references are preserved for cross-tool access
|
||||||
|
- Model information enables cross-provider conversations
|
||||||
"""
|
"""
|
||||||
logger.debug(f"[FLOW] Adding {role} turn to {thread_id} ({tool_name})")
|
logger.debug(f"[FLOW] Adding {role} turn to {thread_id} ({tool_name})")
|
||||||
|
|
||||||
@@ -270,6 +290,9 @@ def add_turn(
|
|||||||
follow_up_question=follow_up_question,
|
follow_up_question=follow_up_question,
|
||||||
files=files, # Preserved for cross-tool file context
|
files=files, # Preserved for cross-tool file context
|
||||||
tool_name=tool_name, # Track which tool generated this turn
|
tool_name=tool_name, # Track which tool generated this turn
|
||||||
|
model_provider=model_provider, # Track model provider
|
||||||
|
model_name=model_name, # Track specific model
|
||||||
|
model_metadata=model_metadata, # Additional model info
|
||||||
)
|
)
|
||||||
|
|
||||||
context.turns.append(turn)
|
context.turns.append(turn)
|
||||||
@@ -286,6 +309,48 @@ def add_turn(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_chain(thread_id: str, max_depth: int = 20) -> list[ThreadContext]:
|
||||||
|
"""
|
||||||
|
Traverse the parent chain to get all threads in conversation sequence.
|
||||||
|
|
||||||
|
Retrieves the complete conversation chain by following parent_thread_id
|
||||||
|
links. Returns threads in chronological order (oldest first).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: Starting thread ID
|
||||||
|
max_depth: Maximum chain depth to prevent infinite loops
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[ThreadContext]: All threads in chain, oldest first
|
||||||
|
"""
|
||||||
|
chain = []
|
||||||
|
current_id = thread_id
|
||||||
|
seen_ids = set()
|
||||||
|
|
||||||
|
# Build chain from current to oldest
|
||||||
|
while current_id and len(chain) < max_depth:
|
||||||
|
# Prevent circular references
|
||||||
|
if current_id in seen_ids:
|
||||||
|
logger.warning(f"[THREAD] Circular reference detected in thread chain at {current_id}")
|
||||||
|
break
|
||||||
|
|
||||||
|
seen_ids.add(current_id)
|
||||||
|
|
||||||
|
context = get_thread(current_id)
|
||||||
|
if not context:
|
||||||
|
logger.debug(f"[THREAD] Thread {current_id} not found in chain traversal")
|
||||||
|
break
|
||||||
|
|
||||||
|
chain.append(context)
|
||||||
|
current_id = context.parent_thread_id
|
||||||
|
|
||||||
|
# Reverse to get chronological order (oldest first)
|
||||||
|
chain.reverse()
|
||||||
|
|
||||||
|
logger.debug(f"[THREAD] Retrieved chain of {len(chain)} threads for {thread_id}")
|
||||||
|
return chain
|
||||||
|
|
||||||
|
|
||||||
def get_conversation_file_list(context: ThreadContext) -> list[str]:
|
def get_conversation_file_list(context: ThreadContext) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Get all unique files referenced across all turns in a conversation.
|
Get all unique files referenced across all turns in a conversation.
|
||||||
@@ -327,7 +392,7 @@ def get_conversation_file_list(context: ThreadContext) -> list[str]:
|
|||||||
return unique_files
|
return unique_files
|
||||||
|
|
||||||
|
|
||||||
def build_conversation_history(context: ThreadContext, read_files_func=None) -> tuple[str, int]:
|
def build_conversation_history(context: ThreadContext, model_context=None, read_files_func=None) -> tuple[str, int]:
|
||||||
"""
|
"""
|
||||||
Build formatted conversation history for tool prompts with embedded file contents.
|
Build formatted conversation history for tool prompts with embedded file contents.
|
||||||
|
|
||||||
@@ -335,9 +400,14 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
|
|||||||
full file contents from all referenced files. Files are embedded only ONCE at the
|
full file contents from all referenced files. Files are embedded only ONCE at the
|
||||||
start, even if referenced in multiple turns, to prevent duplication and optimize
|
start, even if referenced in multiple turns, to prevent duplication and optimize
|
||||||
token usage.
|
token usage.
|
||||||
|
|
||||||
|
If the thread has a parent chain, this function traverses the entire chain to
|
||||||
|
include the complete conversation history.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: ThreadContext containing the complete conversation
|
context: ThreadContext containing the complete conversation
|
||||||
|
model_context: ModelContext for token allocation (optional, uses DEFAULT_MODEL if not provided)
|
||||||
|
read_files_func: Optional function to read files (for testing)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[str, int]: (formatted_conversation_history, total_tokens_used)
|
tuple[str, int]: (formatted_conversation_history, total_tokens_used)
|
||||||
@@ -355,18 +425,57 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
|
|||||||
file contents from previous tools, enabling true cross-tool collaboration
|
file contents from previous tools, enabling true cross-tool collaboration
|
||||||
while preventing duplicate file embeddings.
|
while preventing duplicate file embeddings.
|
||||||
"""
|
"""
|
||||||
if not context.turns:
|
# Get the complete thread chain
|
||||||
|
if context.parent_thread_id:
|
||||||
|
# This thread has a parent, get the full chain
|
||||||
|
chain = get_thread_chain(context.thread_id)
|
||||||
|
|
||||||
|
# Collect all turns from all threads in chain
|
||||||
|
all_turns = []
|
||||||
|
all_files_set = set()
|
||||||
|
total_turns = 0
|
||||||
|
|
||||||
|
for thread in chain:
|
||||||
|
all_turns.extend(thread.turns)
|
||||||
|
total_turns += len(thread.turns)
|
||||||
|
|
||||||
|
# Collect files from this thread
|
||||||
|
for turn in thread.turns:
|
||||||
|
if turn.files:
|
||||||
|
all_files_set.update(turn.files)
|
||||||
|
|
||||||
|
all_files = list(all_files_set)
|
||||||
|
logger.debug(f"[THREAD] Built history from {len(chain)} threads with {total_turns} total turns")
|
||||||
|
else:
|
||||||
|
# Single thread, no parent chain
|
||||||
|
all_turns = context.turns
|
||||||
|
total_turns = len(context.turns)
|
||||||
|
all_files = get_conversation_file_list(context)
|
||||||
|
|
||||||
|
if not all_turns:
|
||||||
return "", 0
|
return "", 0
|
||||||
|
|
||||||
# Get all unique files referenced in this conversation
|
|
||||||
all_files = get_conversation_file_list(context)
|
|
||||||
logger.debug(f"[FILES] Found {len(all_files)} unique files in conversation history")
|
logger.debug(f"[FILES] Found {len(all_files)} unique files in conversation history")
|
||||||
|
|
||||||
|
# Get model-specific token allocation early (needed for both files and turns)
|
||||||
|
if model_context is None:
|
||||||
|
from utils.model_context import ModelContext
|
||||||
|
from config import DEFAULT_MODEL
|
||||||
|
model_context = ModelContext(DEFAULT_MODEL)
|
||||||
|
|
||||||
|
token_allocation = model_context.calculate_token_allocation()
|
||||||
|
max_file_tokens = token_allocation.file_tokens
|
||||||
|
max_history_tokens = token_allocation.history_tokens
|
||||||
|
|
||||||
|
logger.debug(f"[HISTORY] Using model-specific limits for {model_context.model_name}:")
|
||||||
|
logger.debug(f"[HISTORY] Max file tokens: {max_file_tokens:,}")
|
||||||
|
logger.debug(f"[HISTORY] Max history tokens: {max_history_tokens:,}")
|
||||||
|
|
||||||
history_parts = [
|
history_parts = [
|
||||||
"=== CONVERSATION HISTORY ===",
|
"=== CONVERSATION HISTORY ===",
|
||||||
f"Thread: {context.thread_id}",
|
f"Thread: {context.thread_id}",
|
||||||
f"Tool: {context.tool_name}", # Original tool that started the conversation
|
f"Tool: {context.tool_name}", # Original tool that started the conversation
|
||||||
f"Turn {len(context.turns)}/{MAX_CONVERSATION_TURNS}",
|
f"Turn {total_turns}/{MAX_CONVERSATION_TURNS}",
|
||||||
"",
|
"",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -382,9 +491,6 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import required functions
|
|
||||||
from config import MAX_CONTENT_TOKENS
|
|
||||||
|
|
||||||
if read_files_func is None:
|
if read_files_func is None:
|
||||||
from utils.file_utils import read_file_content
|
from utils.file_utils import read_file_content
|
||||||
|
|
||||||
@@ -402,7 +508,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
|
|||||||
if formatted_content:
|
if formatted_content:
|
||||||
# read_file_content already returns formatted content, use it directly
|
# read_file_content already returns formatted content, use it directly
|
||||||
# Check if adding this file would exceed the limit
|
# Check if adding this file would exceed the limit
|
||||||
if total_tokens + content_tokens <= MAX_CONTENT_TOKENS:
|
if total_tokens + content_tokens <= max_file_tokens:
|
||||||
file_contents.append(formatted_content)
|
file_contents.append(formatted_content)
|
||||||
total_tokens += content_tokens
|
total_tokens += content_tokens
|
||||||
files_included += 1
|
files_included += 1
|
||||||
@@ -415,7 +521,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
|
|||||||
else:
|
else:
|
||||||
files_truncated += 1
|
files_truncated += 1
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"📄 File truncated due to token limit: {file_path} ({content_tokens:,} tokens, would exceed {MAX_CONTENT_TOKENS:,} limit)"
|
f"📄 File truncated due to token limit: {file_path} ({content_tokens:,} tokens, would exceed {max_file_tokens:,} limit)"
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[FILES] File {file_path} would exceed token limit - skipping (would be {total_tokens + content_tokens:,} tokens)"
|
f"[FILES] File {file_path} would exceed token limit - skipping (would be {total_tokens + content_tokens:,} tokens)"
|
||||||
@@ -464,7 +570,7 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
|
|||||||
history_parts.append(files_content)
|
history_parts.append(files_content)
|
||||||
else:
|
else:
|
||||||
# Handle token limit exceeded for conversation files
|
# Handle token limit exceeded for conversation files
|
||||||
error_message = f"ERROR: The total size of files referenced in this conversation has exceeded the context limit and cannot be displayed.\nEstimated tokens: {estimated_tokens}, but limit is {MAX_CONTENT_TOKENS}."
|
error_message = f"ERROR: The total size of files referenced in this conversation has exceeded the context limit and cannot be displayed.\nEstimated tokens: {estimated_tokens}, but limit is {max_file_tokens}."
|
||||||
history_parts.append(error_message)
|
history_parts.append(error_message)
|
||||||
else:
|
else:
|
||||||
history_parts.append("(No accessible files found)")
|
history_parts.append("(No accessible files found)")
|
||||||
@@ -478,29 +584,79 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
history_parts.append("Previous conversation turns:")
|
history_parts.append("Previous conversation turns:")
|
||||||
|
|
||||||
for i, turn in enumerate(context.turns, 1):
|
# Build conversation turns bottom-up (most recent first) but present chronologically
|
||||||
|
# This ensures we include as many recent turns as possible within the token budget
|
||||||
|
turn_entries = [] # Will store (index, formatted_turn_content) for chronological ordering
|
||||||
|
total_turn_tokens = 0
|
||||||
|
file_embedding_tokens = sum(model_context.estimate_tokens(part) for part in history_parts)
|
||||||
|
|
||||||
|
# Process turns in reverse order (most recent first) to prioritize recent context
|
||||||
|
for idx in range(len(all_turns) - 1, -1, -1):
|
||||||
|
turn = all_turns[idx]
|
||||||
|
turn_num = idx + 1
|
||||||
role_label = "Claude" if turn.role == "user" else "Gemini"
|
role_label = "Claude" if turn.role == "user" else "Gemini"
|
||||||
|
|
||||||
|
# Build the complete turn content
|
||||||
|
turn_parts = []
|
||||||
|
|
||||||
# Add turn header with tool attribution for cross-tool tracking
|
# Add turn header with tool attribution for cross-tool tracking
|
||||||
turn_header = f"\n--- Turn {i} ({role_label}"
|
turn_header = f"\n--- Turn {turn_num} ({role_label}"
|
||||||
if turn.tool_name:
|
if turn.tool_name:
|
||||||
turn_header += f" using {turn.tool_name}"
|
turn_header += f" using {turn.tool_name}"
|
||||||
|
|
||||||
|
# Add model info if available
|
||||||
|
if turn.model_provider and turn.model_name:
|
||||||
|
turn_header += f" via {turn.model_provider}/{turn.model_name}"
|
||||||
|
|
||||||
turn_header += ") ---"
|
turn_header += ") ---"
|
||||||
history_parts.append(turn_header)
|
turn_parts.append(turn_header)
|
||||||
|
|
||||||
# Add files context if present - but just reference which files were used
|
# Add files context if present - but just reference which files were used
|
||||||
# (the actual contents are already embedded above)
|
# (the actual contents are already embedded above)
|
||||||
if turn.files:
|
if turn.files:
|
||||||
history_parts.append(f"📁 Files used in this turn: {', '.join(turn.files)}")
|
turn_parts.append(f"📁 Files used in this turn: {', '.join(turn.files)}")
|
||||||
history_parts.append("") # Empty line for readability
|
turn_parts.append("") # Empty line for readability
|
||||||
|
|
||||||
# Add the actual content
|
# Add the actual content
|
||||||
history_parts.append(turn.content)
|
turn_parts.append(turn.content)
|
||||||
|
|
||||||
# Add follow-up question if present
|
# Add follow-up question if present
|
||||||
if turn.follow_up_question:
|
if turn.follow_up_question:
|
||||||
history_parts.append(f"\n[Gemini's Follow-up: {turn.follow_up_question}]")
|
turn_parts.append(f"\n[Gemini's Follow-up: {turn.follow_up_question}]")
|
||||||
|
|
||||||
|
# Calculate tokens for this turn
|
||||||
|
turn_content = "\n".join(turn_parts)
|
||||||
|
turn_tokens = model_context.estimate_tokens(turn_content)
|
||||||
|
|
||||||
|
# Check if adding this turn would exceed history budget
|
||||||
|
if file_embedding_tokens + total_turn_tokens + turn_tokens > max_history_tokens:
|
||||||
|
# Stop adding turns - we've reached the limit
|
||||||
|
logger.debug(f"[HISTORY] Stopping at turn {turn_num} - would exceed history budget")
|
||||||
|
logger.debug(f"[HISTORY] File tokens: {file_embedding_tokens:,}")
|
||||||
|
logger.debug(f"[HISTORY] Turn tokens so far: {total_turn_tokens:,}")
|
||||||
|
logger.debug(f"[HISTORY] This turn: {turn_tokens:,}")
|
||||||
|
logger.debug(f"[HISTORY] Would total: {file_embedding_tokens + total_turn_tokens + turn_tokens:,}")
|
||||||
|
logger.debug(f"[HISTORY] Budget: {max_history_tokens:,}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Add this turn to our list (we'll reverse it later for chronological order)
|
||||||
|
turn_entries.append((idx, turn_content))
|
||||||
|
total_turn_tokens += turn_tokens
|
||||||
|
|
||||||
|
# Reverse to get chronological order (oldest first)
|
||||||
|
turn_entries.reverse()
|
||||||
|
|
||||||
|
# Add the turns in chronological order
|
||||||
|
for _, turn_content in turn_entries:
|
||||||
|
history_parts.append(turn_content)
|
||||||
|
|
||||||
|
# Log what we included
|
||||||
|
included_turns = len(turn_entries)
|
||||||
|
total_turns = len(all_turns)
|
||||||
|
if included_turns < total_turns:
|
||||||
|
logger.info(f"[HISTORY] Included {included_turns}/{total_turns} turns due to token limit")
|
||||||
|
history_parts.append(f"\n[Note: Showing {included_turns} most recent turns out of {total_turns} total]")
|
||||||
|
|
||||||
history_parts.extend(
|
history_parts.extend(
|
||||||
["", "=== END CONVERSATION HISTORY ===", "", "Continue this conversation by building on the previous context."]
|
["", "=== END CONVERSATION HISTORY ===", "", "Continue this conversation by building on the previous context."]
|
||||||
@@ -513,8 +669,8 @@ def build_conversation_history(context: ThreadContext, read_files_func=None) ->
|
|||||||
total_conversation_tokens = estimate_tokens(complete_history)
|
total_conversation_tokens = estimate_tokens(complete_history)
|
||||||
|
|
||||||
# Summary log of what was built
|
# Summary log of what was built
|
||||||
user_turns = len([t for t in context.turns if t.role == "user"])
|
user_turns = len([t for t in all_turns if t.role == "user"])
|
||||||
assistant_turns = len([t for t in context.turns if t.role == "assistant"])
|
assistant_turns = len([t for t in all_turns if t.role == "assistant"])
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[FLOW] Built conversation history: {user_turns} user + {assistant_turns} assistant turns, {len(all_files)} files, {total_conversation_tokens:,} tokens"
|
f"[FLOW] Built conversation history: {user_turns} user + {assistant_turns} assistant turns, {len(all_files)} files, {total_conversation_tokens:,} tokens"
|
||||||
)
|
)
|
||||||
|
|||||||
130
utils/model_context.py
Normal file
130
utils/model_context.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""
|
||||||
|
Model context management for dynamic token allocation.
|
||||||
|
|
||||||
|
This module provides a clean abstraction for model-specific token management,
|
||||||
|
ensuring that token limits are properly calculated based on the current model
|
||||||
|
being used, not global constants.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from providers import ModelProviderRegistry, ModelCapabilities
|
||||||
|
from config import DEFAULT_MODEL
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenAllocation:
|
||||||
|
"""Token allocation strategy for a model."""
|
||||||
|
total_tokens: int
|
||||||
|
content_tokens: int
|
||||||
|
response_tokens: int
|
||||||
|
file_tokens: int
|
||||||
|
history_tokens: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_for_prompt(self) -> int:
|
||||||
|
"""Tokens available for the actual prompt after allocations."""
|
||||||
|
return self.content_tokens - self.file_tokens - self.history_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class ModelContext:
|
||||||
|
"""
|
||||||
|
Encapsulates model-specific information and token calculations.
|
||||||
|
|
||||||
|
This class provides a single source of truth for all model-related
|
||||||
|
token calculations, ensuring consistency across the system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str):
|
||||||
|
self.model_name = model_name
|
||||||
|
self._provider = None
|
||||||
|
self._capabilities = None
|
||||||
|
self._token_allocation = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider(self):
|
||||||
|
"""Get the model provider lazily."""
|
||||||
|
if self._provider is None:
|
||||||
|
self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
|
||||||
|
if not self._provider:
|
||||||
|
raise ValueError(f"No provider found for model: {self.model_name}")
|
||||||
|
return self._provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capabilities(self) -> ModelCapabilities:
|
||||||
|
"""Get model capabilities lazily."""
|
||||||
|
if self._capabilities is None:
|
||||||
|
self._capabilities = self.provider.get_capabilities(self.model_name)
|
||||||
|
return self._capabilities
|
||||||
|
|
||||||
|
def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation:
|
||||||
|
"""
|
||||||
|
Calculate token allocation based on model capacity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reserved_for_response: Override response token reservation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenAllocation with calculated budgets
|
||||||
|
"""
|
||||||
|
total_tokens = self.capabilities.max_tokens
|
||||||
|
|
||||||
|
# Dynamic allocation based on model capacity
|
||||||
|
if total_tokens < 300_000:
|
||||||
|
# Smaller context models (O3, GPT-4O): Conservative allocation
|
||||||
|
content_ratio = 0.6 # 60% for content
|
||||||
|
response_ratio = 0.4 # 40% for response
|
||||||
|
file_ratio = 0.3 # 30% of content for files
|
||||||
|
history_ratio = 0.5 # 50% of content for history
|
||||||
|
else:
|
||||||
|
# Larger context models (Gemini): More generous allocation
|
||||||
|
content_ratio = 0.8 # 80% for content
|
||||||
|
response_ratio = 0.2 # 20% for response
|
||||||
|
file_ratio = 0.4 # 40% of content for files
|
||||||
|
history_ratio = 0.4 # 40% of content for history
|
||||||
|
|
||||||
|
# Calculate allocations
|
||||||
|
content_tokens = int(total_tokens * content_ratio)
|
||||||
|
response_tokens = reserved_for_response or int(total_tokens * response_ratio)
|
||||||
|
|
||||||
|
# Sub-allocations within content budget
|
||||||
|
file_tokens = int(content_tokens * file_ratio)
|
||||||
|
history_tokens = int(content_tokens * history_ratio)
|
||||||
|
|
||||||
|
allocation = TokenAllocation(
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
content_tokens=content_tokens,
|
||||||
|
response_tokens=response_tokens,
|
||||||
|
file_tokens=file_tokens,
|
||||||
|
history_tokens=history_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Token allocation for {self.model_name}:")
|
||||||
|
logger.debug(f" Total: {allocation.total_tokens:,}")
|
||||||
|
logger.debug(f" Content: {allocation.content_tokens:,} ({content_ratio:.0%})")
|
||||||
|
logger.debug(f" Response: {allocation.response_tokens:,} ({response_ratio:.0%})")
|
||||||
|
logger.debug(f" Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)")
|
||||||
|
logger.debug(f" History: {allocation.history_tokens:,} ({history_ratio:.0%} of content)")
|
||||||
|
|
||||||
|
return allocation
|
||||||
|
|
||||||
|
def estimate_tokens(self, text: str) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count for text using model-specific tokenizer.
|
||||||
|
|
||||||
|
For now, uses simple estimation. Can be enhanced with model-specific
|
||||||
|
tokenizers (tiktoken for OpenAI, etc.) in the future.
|
||||||
|
"""
|
||||||
|
# TODO: Integrate model-specific tokenizers
|
||||||
|
# For now, use conservative estimation
|
||||||
|
return len(text) // 3 # Conservative estimate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_arguments(cls, arguments: Dict[str, Any]) -> "ModelContext":
|
||||||
|
"""Create ModelContext from tool arguments."""
|
||||||
|
model_name = arguments.get("model") or DEFAULT_MODEL
|
||||||
|
return cls(model_name)
|
||||||
Reference in New Issue
Block a user