Rebranding, refactoring, renaming, cleanup, updated docs
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
# Gemini MCP Server Environment Configuration
|
# Zen MCP Server Environment Configuration
|
||||||
# Copy this file to .env and fill in your values
|
# Copy this file to .env and fill in your values
|
||||||
|
|
||||||
# API Keys - At least one is required
|
# API Keys - At least one is required
|
||||||
@@ -9,8 +9,7 @@ GEMINI_API_KEY=your_gemini_api_key_here
|
|||||||
OPENAI_API_KEY=your_openai_api_key_here
|
OPENAI_API_KEY=your_openai_api_key_here
|
||||||
|
|
||||||
# Optional: Default model to use
|
# Optional: Default model to use
|
||||||
# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'gpt-4o'
|
# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini'
|
||||||
# Full names: 'gemini-2.5-pro-preview-06-05' or 'gemini-2.0-flash-exp'
|
|
||||||
# When set to 'auto', Claude will select the best model for each task
|
# When set to 'auto', Claude will select the best model for each task
|
||||||
# Defaults to 'auto' if not specified
|
# Defaults to 'auto' if not specified
|
||||||
DEFAULT_MODEL=auto
|
DEFAULT_MODEL=auto
|
||||||
|
|||||||
48
.github/workflows/test.yml
vendored
48
.github/workflows/test.yml
vendored
@@ -28,12 +28,13 @@ jobs:
|
|||||||
|
|
||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: |
|
run: |
|
||||||
# Run all tests except live integration tests
|
# Run all unit tests
|
||||||
# These tests use mocks and don't require API keys
|
# These tests use mocks and don't require API keys
|
||||||
python -m pytest tests/ --ignore=tests/test_live_integration.py -v
|
python -m pytest tests/ -v
|
||||||
env:
|
env:
|
||||||
# Ensure no API key is accidentally used in CI
|
# Ensure no API key is accidentally used in CI
|
||||||
GEMINI_API_KEY: ""
|
GEMINI_API_KEY: ""
|
||||||
|
OPENAI_API_KEY: ""
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -56,9 +57,9 @@ jobs:
|
|||||||
- name: Run ruff linter
|
- name: Run ruff linter
|
||||||
run: ruff check .
|
run: ruff check .
|
||||||
|
|
||||||
live-tests:
|
simulation-tests:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
# Only run live tests on main branch pushes (requires manual API key setup)
|
# Only run simulation tests on main branch pushes (requires manual API key setup)
|
||||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@@ -76,24 +77,41 @@ jobs:
|
|||||||
- name: Check API key availability
|
- name: Check API key availability
|
||||||
id: check-key
|
id: check-key
|
||||||
run: |
|
run: |
|
||||||
if [ -z "${{ secrets.GEMINI_API_KEY }}" ]; then
|
has_key=false
|
||||||
echo "api_key_available=false" >> $GITHUB_OUTPUT
|
if [ -n "${{ secrets.GEMINI_API_KEY }}" ] || [ -n "${{ secrets.OPENAI_API_KEY }}" ]; then
|
||||||
echo "⚠️ GEMINI_API_KEY secret not configured - skipping live tests"
|
has_key=true
|
||||||
|
echo "✅ API key(s) found - running simulation tests"
|
||||||
else
|
else
|
||||||
echo "api_key_available=true" >> $GITHUB_OUTPUT
|
echo "⚠️ No API keys configured - skipping simulation tests"
|
||||||
echo "✅ GEMINI_API_KEY found - running live tests"
|
|
||||||
fi
|
fi
|
||||||
|
echo "api_key_available=$has_key" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Run live integration tests
|
- name: Set up Docker
|
||||||
|
if: steps.check-key.outputs.api_key_available == 'true'
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Build Docker image
|
||||||
if: steps.check-key.outputs.api_key_available == 'true'
|
if: steps.check-key.outputs.api_key_available == 'true'
|
||||||
run: |
|
run: |
|
||||||
# Run live tests that make actual API calls
|
docker compose build
|
||||||
python tests/test_live_integration.py
|
|
||||||
|
- name: Run simulation tests
|
||||||
|
if: steps.check-key.outputs.api_key_available == 'true'
|
||||||
|
run: |
|
||||||
|
# Start services
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Wait for services to be ready
|
||||||
|
sleep 10
|
||||||
|
|
||||||
|
# Run communication simulator tests
|
||||||
|
python communication_simulator_test.py --skip-docker
|
||||||
env:
|
env:
|
||||||
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
||||||
- name: Skip live tests
|
- name: Skip simulation tests
|
||||||
if: steps.check-key.outputs.api_key_available == 'false'
|
if: steps.check-key.outputs.api_key_available == 'false'
|
||||||
run: |
|
run: |
|
||||||
echo "🔒 Live integration tests skipped (no API key configured)"
|
echo "🔒 Simulation tests skipped (no API keys configured)"
|
||||||
echo "To enable live tests, add GEMINI_API_KEY as a repository secret"
|
echo "To enable simulation tests, add GEMINI_API_KEY and/or OPENAI_API_KEY as repository secrets"
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -165,5 +165,4 @@ test_simulation_files/.claude/
|
|||||||
|
|
||||||
# Temporary test directories
|
# Temporary test directories
|
||||||
test-setup/
|
test-setup/
|
||||||
/test_simulation_files/config.json
|
/test_simulation_files/**
|
||||||
/test_simulation_files/test_module.py
|
|
||||||
|
|||||||
155
CONTRIBUTING.md
155
CONTRIBUTING.md
@@ -1,155 +0,0 @@
|
|||||||
# Contributing to Gemini MCP Server
|
|
||||||
|
|
||||||
Thank you for your interest in contributing! This guide explains how to set up the development environment and contribute to the project.
|
|
||||||
|
|
||||||
## Development Setup
|
|
||||||
|
|
||||||
1. **Clone the repository**
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/BeehiveInnovations/gemini-mcp-server.git
|
|
||||||
cd gemini-mcp-server
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Create virtual environment**
|
|
||||||
```bash
|
|
||||||
python -m venv venv
|
|
||||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Install dependencies**
|
|
||||||
```bash
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
## Testing Strategy
|
|
||||||
|
|
||||||
### Two Types of Tests
|
|
||||||
|
|
||||||
#### 1. Unit Tests (Mandatory - No API Key Required)
|
|
||||||
- **Location**: `tests/test_*.py` (except `test_live_integration.py`)
|
|
||||||
- **Purpose**: Test logic, mocking, and functionality without API calls
|
|
||||||
- **Run with**: `python -m pytest tests/ --ignore=tests/test_live_integration.py -v`
|
|
||||||
- **GitHub Actions**: ✅ Always runs
|
|
||||||
- **Coverage**: Measures code coverage
|
|
||||||
|
|
||||||
#### 2. Live Integration Tests (Optional - API Key Required)
|
|
||||||
- **Location**: `tests/test_live_integration.py`
|
|
||||||
- **Purpose**: Verify actual API integration works
|
|
||||||
- **Run with**: `python tests/test_live_integration.py` (requires `GEMINI_API_KEY`)
|
|
||||||
- **GitHub Actions**: 🔒 Only runs if `GEMINI_API_KEY` secret is set
|
|
||||||
|
|
||||||
### Running Tests
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run all unit tests (CI-friendly, no API key needed)
|
|
||||||
python -m pytest tests/ --ignore=tests/test_live_integration.py -v
|
|
||||||
|
|
||||||
# Run with coverage
|
|
||||||
python -m pytest tests/ --ignore=tests/test_live_integration.py --cov=. --cov-report=html
|
|
||||||
|
|
||||||
# Run live integration tests (requires API key)
|
|
||||||
export GEMINI_API_KEY=your-api-key-here
|
|
||||||
python tests/test_live_integration.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Code Quality
|
|
||||||
|
|
||||||
### Formatting and Linting
|
|
||||||
```bash
|
|
||||||
# Install development tools
|
|
||||||
pip install black ruff
|
|
||||||
|
|
||||||
# Format code
|
|
||||||
black .
|
|
||||||
|
|
||||||
# Lint code
|
|
||||||
ruff check .
|
|
||||||
```
|
|
||||||
|
|
||||||
### Pre-commit Checks
|
|
||||||
Before submitting a PR, ensure:
|
|
||||||
- [ ] All unit tests pass: `python -m pytest tests/ --ignore=tests/test_live_integration.py -v`
|
|
||||||
- [ ] Code is formatted: `black --check .`
|
|
||||||
- [ ] Code passes linting: `ruff check .`
|
|
||||||
- [ ] Live tests work (if you have API access): `python tests/test_live_integration.py`
|
|
||||||
|
|
||||||
## Adding New Features
|
|
||||||
|
|
||||||
### Adding a New Tool
|
|
||||||
|
|
||||||
1. **Create tool file**: `tools/your_tool.py`
|
|
||||||
2. **Inherit from BaseTool**: Implement all required methods
|
|
||||||
3. **Add system prompt**: Include prompt in `prompts/tool_prompts.py`
|
|
||||||
4. **Register tool**: Add to `TOOLS` dict in `server.py`
|
|
||||||
5. **Write tests**: Add unit tests that use mocks
|
|
||||||
6. **Test live**: Verify with live API calls
|
|
||||||
|
|
||||||
### Testing New Tools
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Unit test example (tools/test_your_tool.py)
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("tools.base.BaseTool.create_model")
|
|
||||||
async def test_your_tool(self, mock_create_model):
|
|
||||||
mock_model = Mock()
|
|
||||||
mock_model.generate_content.return_value = Mock(
|
|
||||||
candidates=[Mock(content=Mock(parts=[Mock(text="Expected response")]))]
|
|
||||||
)
|
|
||||||
mock_create_model.return_value = mock_model
|
|
||||||
|
|
||||||
tool = YourTool()
|
|
||||||
result = await tool.execute({"param": "value"})
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert "Expected response" in result[0].text
|
|
||||||
```
|
|
||||||
|
|
||||||
## CI/CD Pipeline
|
|
||||||
|
|
||||||
The GitHub Actions workflow:
|
|
||||||
|
|
||||||
1. **Unit Tests**: Run on all Python versions (3.10, 3.11, 3.12)
|
|
||||||
2. **Linting**: Check code formatting and style
|
|
||||||
3. **Live Tests**: Only run if `GEMINI_API_KEY` secret is available
|
|
||||||
|
|
||||||
### Key Features:
|
|
||||||
- **✅ No API key required for PRs** - All contributors can run tests
|
|
||||||
- **🔒 Live verification available** - Maintainers can verify API integration
|
|
||||||
- **📊 Coverage reporting** - Track test coverage
|
|
||||||
- **🐍 Multi-Python support** - Ensure compatibility
|
|
||||||
|
|
||||||
## Contribution Guidelines
|
|
||||||
|
|
||||||
### Pull Request Process
|
|
||||||
|
|
||||||
1. **Fork the repository**
|
|
||||||
2. **Create a feature branch**: `git checkout -b feature/your-feature`
|
|
||||||
3. **Make your changes**
|
|
||||||
4. **Add/update tests**
|
|
||||||
5. **Run tests locally**: Ensure unit tests pass
|
|
||||||
6. **Submit PR**: Include description of changes
|
|
||||||
|
|
||||||
### Code Standards
|
|
||||||
|
|
||||||
- **Follow existing patterns**: Look at existing tools for examples
|
|
||||||
- **Add comprehensive tests**: Both unit tests (required) and live tests (recommended)
|
|
||||||
- **Update documentation**: Update README if adding new features
|
|
||||||
- **Use type hints**: All new code should include proper type annotations
|
|
||||||
- **Keep it simple**: Follow SOLID principles and keep functions focused
|
|
||||||
|
|
||||||
### Security Considerations
|
|
||||||
|
|
||||||
- **Never commit API keys**: Use environment variables
|
|
||||||
- **Validate inputs**: Always validate user inputs in tools
|
|
||||||
- **Handle errors gracefully**: Provide meaningful error messages
|
|
||||||
- **Follow security best practices**: Sanitize file paths, validate file access
|
|
||||||
|
|
||||||
## Getting Help
|
|
||||||
|
|
||||||
- **Issues**: Open an issue for bugs or feature requests
|
|
||||||
- **Discussions**: Use GitHub Discussions for questions
|
|
||||||
- **Documentation**: Check the README for usage examples
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
By contributing, you agree that your contributions will be licensed under the MIT License.
|
|
||||||
361
README.md
361
README.md
@@ -3,48 +3,31 @@
|
|||||||
https://github.com/user-attachments/assets/a67099df-9387-4720-9b41-c986243ac11b
|
https://github.com/user-attachments/assets/a67099df-9387-4720-9b41-c986243ac11b
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<b>🤖 Claude + [Gemini / O3 / Both] = Your Ultimate AI Development Team</b>
|
<b>🤖 Claude + [Gemini / O3 / or Both] = Your Ultimate AI Development Team</b>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
The ultimate development partner for Claude - a Model Context Protocol server that gives Claude access to multiple AI models for enhanced code analysis, problem-solving, and collaborative development.
|
The ultimate development partners for Claude - a Model Context Protocol server that gives Claude access to multiple AI models for enhanced code analysis,
|
||||||
|
problem-solving, and collaborative development.
|
||||||
|
|
||||||
**🎯 Auto Mode (NEW):** Set `DEFAULT_MODEL=auto` and Claude will intelligently select the best model for each task:
|
**Features true AI orchestration with conversations that continue across tasks** - Give Claude a complex
|
||||||
- **Complex architecture review?** → Claude picks Gemini Pro with extended thinking
|
task and let it orchestrate between models automatically. Claude stays in control, performs the actual work,
|
||||||
- **Quick code formatting?** → Claude picks Gemini Flash for speed
|
but gets perspectives from the best AI for each subtask. Claude can switch between different tools _and_ models mid-conversation,
|
||||||
- **Logical debugging?** → Claude picks O3 for reasoning
|
with context carrying forward seamlessly.
|
||||||
- **Or specify your preference:** "Use flash to quickly analyze this" or "Use o3 for debugging"
|
|
||||||
|
|
||||||
**📚 Supported Models:**
|
|
||||||
- **Google Gemini**: 2.5 Pro (extended thinking, 1M tokens) & 2.0 Flash (ultra-fast, 1M tokens)
|
|
||||||
- **OpenAI**: O3 (strong reasoning, 200K tokens), O3-mini (faster variant), GPT-4o (128K tokens)
|
|
||||||
- **More providers coming soon!**
|
|
||||||
|
|
||||||
**Features true AI orchestration with conversations that continue across tasks** - Give Claude a complex task and let it orchestrate between models automatically. Claude stays in control, performs the actual work, but gets perspectives from the best AI for each subtask. Claude can switch between different tools AND models mid-conversation, with context carrying forward seamlessly.
|
|
||||||
|
|
||||||
**Example Workflow:**
|
**Example Workflow:**
|
||||||
1. Claude uses Gemini Pro to deeply analyze your architecture
|
1. Claude uses Gemini Pro to deeply [`analyze`](#6-analyze---smart-file-analysis) the code in question
|
||||||
2. Switches to O3 for logical debugging of a specific issue
|
2. Switches to O3 to continue [`chatting`](#1-chat---general-development-chat--collaborative-thinking) about its findings
|
||||||
3. Uses Flash for quick code formatting
|
3. Uses Flash to validate formatting suggestions from O3
|
||||||
4. Returns to Pro for security review
|
4. Performs the actual work after taking in feedback from all three
|
||||||
|
5. Returns to Pro for a [`precommit`](#4-precommit---pre-commit-validation) review
|
||||||
|
|
||||||
All within a single conversation thread!
|
All within a single conversation thread! Gemini Pro in step 5 _knows_ what was recommended by O3 in step 2, taking that context
|
||||||
|
and review into consideration to aid with its pre-commit review.
|
||||||
|
|
||||||
**Think of it as Claude Code _for_ Claude Code.**
|
**Think of it as Claude Code _for_ Claude Code.**
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
> 🚀 **Multi-Provider Support with Auto Mode!**
|
|
||||||
> Claude automatically selects the best model for each task when using `DEFAULT_MODEL=auto`:
|
|
||||||
> - **Gemini Pro**: Extended thinking (up to 32K tokens), best for complex problems
|
|
||||||
> - **Gemini Flash**: Ultra-fast responses, best for quick tasks
|
|
||||||
> - **O3**: Strong reasoning, best for logical problems and debugging
|
|
||||||
> - **O3-mini**: Balanced performance, good for moderate complexity
|
|
||||||
> - **GPT-4o**: General-purpose, good for explanations and chat
|
|
||||||
>
|
|
||||||
> Or manually specify: "Use pro for deep analysis" or "Use o3 to debug this"
|
|
||||||
|
|
||||||
## Quick Navigation
|
## Quick Navigation
|
||||||
|
|
||||||
- **Getting Started**
|
- **Getting Started**
|
||||||
@@ -72,7 +55,6 @@ All within a single conversation thread!
|
|||||||
- **Resources**
|
- **Resources**
|
||||||
- [Windows Setup](#windows-setup-guide) - WSL setup instructions for Windows
|
- [Windows Setup](#windows-setup-guide) - WSL setup instructions for Windows
|
||||||
- [Troubleshooting](#troubleshooting) - Common issues and solutions
|
- [Troubleshooting](#troubleshooting) - Common issues and solutions
|
||||||
- [Contributing](#contributing) - How to contribute
|
|
||||||
- [Testing](#testing) - Running tests
|
- [Testing](#testing) - Running tests
|
||||||
|
|
||||||
## Why This Server?
|
## Why This Server?
|
||||||
@@ -85,9 +67,9 @@ Claude is brilliant, but sometimes you need:
|
|||||||
- **Professional code reviews** with actionable feedback across entire repositories ([`codereview`](#3-codereview---professional-code-review))
|
- **Professional code reviews** with actionable feedback across entire repositories ([`codereview`](#3-codereview---professional-code-review))
|
||||||
- **Pre-commit validation** with deep analysis using the best model for the job ([`precommit`](#4-precommit---pre-commit-validation))
|
- **Pre-commit validation** with deep analysis using the best model for the job ([`precommit`](#4-precommit---pre-commit-validation))
|
||||||
- **Expert debugging** - O3 for logical issues, Gemini for architectural problems ([`debug`](#5-debug---expert-debugging-assistant))
|
- **Expert debugging** - O3 for logical issues, Gemini for architectural problems ([`debug`](#5-debug---expert-debugging-assistant))
|
||||||
- **Massive context windows** - Gemini (1M tokens), O3 (200K tokens), GPT-4o (128K tokens)
|
- **Extended context windows beyond Claude's limits** - Delegate analysis to Gemini (1M tokens) or O3 (200K tokens) for entire codebases, large datasets, or comprehensive documentation
|
||||||
- **Model-specific strengths** - Extended thinking with Gemini Pro, fast iteration with Flash, strong reasoning with O3
|
- **Model-specific strengths** - Extended thinking with Gemini Pro, fast iteration with Flash, strong reasoning with O3
|
||||||
- **Dynamic collaboration** - Models can request additional context from Claude mid-analysis
|
- **Dynamic collaboration** - Models can request additional context and follow-up replies from Claude mid-analysis
|
||||||
- **Smart file handling** - Automatically expands directories, manages token limits based on model capacity
|
- **Smart file handling** - Automatically expands directories, manages token limits based on model capacity
|
||||||
- **[Bypass MCP's token limits](#working-with-large-prompts)** - Work around MCP's 25K limit automatically
|
- **[Bypass MCP's token limits](#working-with-large-prompts)** - Work around MCP's 25K limit automatically
|
||||||
|
|
||||||
@@ -123,8 +105,8 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Clone to your preferred location
|
# Clone to your preferred location
|
||||||
git clone https://github.com/BeehiveInnovations/gemini-mcp-server.git
|
git clone https://github.com/BeehiveInnovations/zen-mcp-server.git
|
||||||
cd gemini-mcp-server
|
cd zen-mcp-server
|
||||||
|
|
||||||
# One-command setup (includes Redis for AI conversations)
|
# One-command setup (includes Redis for AI conversations)
|
||||||
./setup-docker.sh
|
./setup-docker.sh
|
||||||
@@ -147,7 +129,7 @@ nano .env
|
|||||||
# The file will contain:
|
# The file will contain:
|
||||||
# GEMINI_API_KEY=your-gemini-api-key-here # For Gemini models
|
# GEMINI_API_KEY=your-gemini-api-key-here # For Gemini models
|
||||||
# OPENAI_API_KEY=your-openai-api-key-here # For O3 model
|
# OPENAI_API_KEY=your-openai-api-key-here # For O3 model
|
||||||
# WORKSPACE_ROOT=/workspace (automatically configured)
|
# WORKSPACE_ROOT=/Users/your-username (automatically configured)
|
||||||
|
|
||||||
# Note: At least one API key is required (Gemini or OpenAI)
|
# Note: At least one API key is required (Gemini or OpenAI)
|
||||||
```
|
```
|
||||||
@@ -158,13 +140,13 @@ nano .env
|
|||||||
Run the following commands on the terminal to add the MCP directly to Claude Code
|
Run the following commands on the terminal to add the MCP directly to Claude Code
|
||||||
```bash
|
```bash
|
||||||
# Add the MCP server directly via Claude Code CLI
|
# Add the MCP server directly via Claude Code CLI
|
||||||
claude mcp add gemini -s user -- docker exec -i gemini-mcp-server python server.py
|
claude mcp add zen -s user -- docker exec -i zen-mcp-server python server.py
|
||||||
|
|
||||||
# List your MCP servers to verify
|
# List your MCP servers to verify
|
||||||
claude mcp list
|
claude mcp list
|
||||||
|
|
||||||
# Remove when needed
|
# Remove when needed
|
||||||
claude mcp remove gemini
|
claude mcp remove zen
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Claude Desktop
|
#### Claude Desktop
|
||||||
@@ -184,12 +166,12 @@ The setup script shows you the exact configuration. It looks like this:
|
|||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"mcpServers": {
|
"mcpServers": {
|
||||||
"gemini": {
|
"zen": {
|
||||||
"command": "docker",
|
"command": "docker",
|
||||||
"args": [
|
"args": [
|
||||||
"exec",
|
"exec",
|
||||||
"-i",
|
"-i",
|
||||||
"gemini-mcp-server",
|
"zen-mcp-server",
|
||||||
"python",
|
"python",
|
||||||
"server.py"
|
"server.py"
|
||||||
]
|
]
|
||||||
@@ -289,7 +271,7 @@ This server enables **true AI collaboration** between Claude and multiple AI mod
|
|||||||
- Complex architecture review → Claude picks Gemini Pro
|
- Complex architecture review → Claude picks Gemini Pro
|
||||||
- Quick formatting check → Claude picks Flash
|
- Quick formatting check → Claude picks Flash
|
||||||
- Logical debugging → Claude picks O3
|
- Logical debugging → Claude picks O3
|
||||||
- General explanations → Claude picks GPT-4o
|
- General explanations → Claude picks Flash for speed
|
||||||
|
|
||||||
**Pro Tip:** Thinking modes (for Gemini models) control depth vs token cost. Use "minimal" or "low" for quick tasks, "high" or "max" for complex problems. [Learn more](#thinking-modes---managing-token-costs--quality)
|
**Pro Tip:** Thinking modes (for Gemini models) control depth vs token cost. Use "minimal" or "low" for quick tasks, "high" or "max" for complex problems. [Learn more](#thinking-modes---managing-token-costs--quality)
|
||||||
|
|
||||||
@@ -307,37 +289,12 @@ This server enables **true AI collaboration** between Claude and multiple AI mod
|
|||||||
|
|
||||||
**Thinking Mode:** Default is `medium` (8,192 tokens). Use `low` for quick questions to save tokens, or `high` for complex discussions when thoroughness matters.
|
**Thinking Mode:** Default is `medium` (8,192 tokens). Use `low` for quick questions to save tokens, or `high` for complex discussions when thoroughness matters.
|
||||||
|
|
||||||
#### Example Prompts:
|
#### Example Prompt:
|
||||||
|
|
||||||
**Basic Usage:**
|
|
||||||
```
|
```
|
||||||
"Use gemini to explain how async/await works in Python"
|
Chat with zen and pick the best model for this job. I need to pick between Redis and Memcached for session storage
|
||||||
"Get gemini to compare Redis vs Memcached for session storage"
|
and I need an expert opinion for the project I'm working on. Get a good idea of what the project does, pick one of the two options
|
||||||
"Share my authentication design with gemini and get their opinion"
|
and then debate with the other models to give me a final verdict
|
||||||
"Brainstorm with gemini about scaling strategies for our API"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Managing Token Costs:**
|
|
||||||
```
|
|
||||||
# Save tokens (~6k) for simple questions
|
|
||||||
"Use gemini with minimal thinking to explain what a REST API is"
|
|
||||||
"Chat with gemini using low thinking mode about Python naming conventions"
|
|
||||||
|
|
||||||
# Use default for balanced analysis
|
|
||||||
"Get gemini to review my database schema design" (uses default medium)
|
|
||||||
|
|
||||||
# Invest tokens for complex discussions
|
|
||||||
"Use gemini with high thinking to brainstorm distributed system architecture"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Collaborative Workflow:**
|
|
||||||
```
|
|
||||||
"Research the best message queue for our use case (high throughput, exactly-once delivery).
|
|
||||||
Use gemini to compare RabbitMQ, Kafka, and AWS SQS. Based on gemini's analysis and your research,
|
|
||||||
recommend the best option with implementation plan."
|
|
||||||
|
|
||||||
"Design a caching strategy for our API. Get gemini's input on Redis vs Memcached vs in-memory caching.
|
|
||||||
Combine both perspectives to create a comprehensive caching implementation guide."
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Features:**
|
**Key Features:**
|
||||||
@@ -351,47 +308,18 @@ Combine both perspectives to create a comprehensive caching implementation guide
|
|||||||
- Can reference files for context: `"Use gemini to explain this algorithm with context from algorithm.py"`
|
- Can reference files for context: `"Use gemini to explain this algorithm with context from algorithm.py"`
|
||||||
- **Dynamic collaboration**: Gemini can request additional files or context during the conversation if needed for a more thorough response
|
- **Dynamic collaboration**: Gemini can request additional files or context during the conversation if needed for a more thorough response
|
||||||
- **Web search capability**: Analyzes when web searches would be helpful and recommends specific searches for Claude to perform, ensuring access to current documentation and best practices
|
- **Web search capability**: Analyzes when web searches would be helpful and recommends specific searches for Claude to perform, ensuring access to current documentation and best practices
|
||||||
|
|
||||||
### 2. `thinkdeep` - Extended Reasoning Partner
|
### 2. `thinkdeep` - Extended Reasoning Partner
|
||||||
|
|
||||||
**Get a second opinion to augment Claude's own extended thinking**
|
**Get a second opinion to augment Claude's own extended thinking**
|
||||||
|
|
||||||
**Thinking Mode:** Default is `high` (16,384 tokens) for deep analysis. Claude will automatically choose the best mode based on complexity - use `low` for quick validations, `medium` for standard problems, `high` for complex issues (default), or `max` for extremely complex challenges requiring deepest analysis.
|
**Thinking Mode:** Default is `high` (16,384 tokens) for deep analysis. Claude will automatically choose the best mode based on complexity - use `low` for quick validations, `medium` for standard problems, `high` for complex issues (default), or `max` for extremely complex challenges requiring deepest analysis.
|
||||||
|
|
||||||
#### Example Prompts:
|
#### Example Prompt:
|
||||||
|
|
||||||
**Basic Usage:**
|
|
||||||
```
|
```
|
||||||
"Use gemini to think deeper about my authentication design"
|
Think deeper about my authentication design with zen using max thinking mode and brainstorm to come up
|
||||||
"Use gemini to extend my analysis of this distributed system architecture"
|
with the best architecture for my project
|
||||||
```
|
|
||||||
|
|
||||||
**With Web Search (for exploring new technologies):**
|
|
||||||
```
|
|
||||||
"Use gemini to think deeper about using HTMX vs React for this project - enable web search to explore current best practices"
|
|
||||||
"Get gemini to think deeper about implementing WebAuthn authentication with web search enabled for latest standards"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Managing Token Costs:**
|
|
||||||
```
|
|
||||||
# Claude will intelligently select the right mode, but you can override:
|
|
||||||
"Use gemini to think deeper with medium thinking about this refactoring approach" (saves ~8k tokens vs default)
|
|
||||||
"Get gemini to think deeper using low thinking to validate my basic approach" (saves ~14k tokens vs default)
|
|
||||||
|
|
||||||
# Use default high for most complex problems
|
|
||||||
"Use gemini to think deeper about this security architecture" (uses default high - 16k tokens)
|
|
||||||
|
|
||||||
# For extremely complex challenges requiring maximum depth
|
|
||||||
"Use gemini with max thinking to solve this distributed consensus problem" (adds ~16k tokens vs default)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Collaborative Workflow:**
|
|
||||||
```
|
|
||||||
"Design an authentication system for our SaaS platform. Then use gemini to review your design
|
|
||||||
for security vulnerabilities. After getting gemini's feedback, incorporate the suggestions and
|
|
||||||
show me the final improved design."
|
|
||||||
|
|
||||||
"Create an event-driven architecture for our order processing system. Use gemini to think deeper
|
|
||||||
about event ordering and failure scenarios. Then integrate gemini's insights and present the enhanced architecture."
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Features:**
|
**Key Features:**
|
||||||
@@ -403,6 +331,7 @@ about event ordering and failure scenarios. Then integrate gemini's insights and
|
|||||||
- Can reference specific files for context: `"Use gemini to think deeper about my API design with reference to api/routes.py"`
|
- Can reference specific files for context: `"Use gemini to think deeper about my API design with reference to api/routes.py"`
|
||||||
- **Enhanced Critical Evaluation (v2.10.0)**: After Gemini's analysis, Claude is prompted to critically evaluate the suggestions, consider context and constraints, identify risks, and synthesize a final recommendation - ensuring a balanced, well-considered solution
|
- **Enhanced Critical Evaluation (v2.10.0)**: After Gemini's analysis, Claude is prompted to critically evaluate the suggestions, consider context and constraints, identify risks, and synthesize a final recommendation - ensuring a balanced, well-considered solution
|
||||||
- **Web search capability**: When enabled (default: true), identifies areas where current documentation or community solutions would strengthen the analysis and suggests specific searches for Claude
|
- **Web search capability**: When enabled (default: true), identifies areas where current documentation or community solutions would strengthen the analysis and suggests specific searches for Claude
|
||||||
|
|
||||||
### 3. `codereview` - Professional Code Review
|
### 3. `codereview` - Professional Code Review
|
||||||
**Comprehensive code analysis with prioritized feedback**
|
**Comprehensive code analysis with prioritized feedback**
|
||||||
|
|
||||||
@@ -410,34 +339,9 @@ about event ordering and failure scenarios. Then integrate gemini's insights and
|
|||||||
|
|
||||||
#### Example Prompts:
|
#### Example Prompts:
|
||||||
|
|
||||||
**Basic Usage:**
|
|
||||||
```
|
```
|
||||||
"Use gemini to review auth.py for issues"
|
Perform a codereview with zen using gemini pro and review auth.py for security issues and potential vulnerabilities.
|
||||||
"Use gemini to do a security review of auth/ focusing on authentication"
|
I need an actionable plan but break it down into smaller quick-wins that we can implement and test rapidly
|
||||||
```
|
|
||||||
|
|
||||||
**Managing Token Costs:**
|
|
||||||
```
|
|
||||||
# Save tokens for style/formatting reviews
|
|
||||||
"Use gemini with minimal thinking to check code style in utils.py" (saves ~8k tokens)
|
|
||||||
"Review this file with gemini using low thinking for basic issues" (saves ~6k tokens)
|
|
||||||
|
|
||||||
# Default for standard reviews
|
|
||||||
"Use gemini to review the API endpoints" (uses default medium)
|
|
||||||
|
|
||||||
# Invest tokens for critical code
|
|
||||||
"Get gemini to review auth.py with high thinking mode for security issues" (adds ~8k tokens)
|
|
||||||
"Use gemini with max thinking to audit our encryption module" (adds ~24k tokens - justified for security)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Collaborative Workflow:**
|
|
||||||
```
|
|
||||||
"Refactor the authentication module to use dependency injection. Then use gemini to
|
|
||||||
review your refactoring for any security vulnerabilities. Based on gemini's feedback,
|
|
||||||
make any necessary adjustments and show me the final secure implementation."
|
|
||||||
|
|
||||||
"Optimize the slow database queries in user_service.py. Get gemini to review your optimizations
|
|
||||||
for potential regressions or edge cases. Incorporate gemini's suggestions and present the final optimized queries."
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Features:**
|
**Key Features:**
|
||||||
@@ -445,6 +349,7 @@ make any necessary adjustments and show me the final secure implementation."
|
|||||||
- Supports specialized reviews: security, performance, quick
|
- Supports specialized reviews: security, performance, quick
|
||||||
- Can enforce coding standards: `"Use gemini to review src/ against PEP8 standards"`
|
- Can enforce coding standards: `"Use gemini to review src/ against PEP8 standards"`
|
||||||
- Filters by severity: `"Get gemini to review auth/ - only report critical vulnerabilities"`
|
- Filters by severity: `"Get gemini to review auth/ - only report critical vulnerabilities"`
|
||||||
|
|
||||||
### 4. `precommit` - Pre-Commit Validation
|
### 4. `precommit` - Pre-Commit Validation
|
||||||
**Comprehensive review of staged/unstaged git changes across multiple repositories**
|
**Comprehensive review of staged/unstaged git changes across multiple repositories**
|
||||||
|
|
||||||
@@ -454,7 +359,7 @@ make any necessary adjustments and show me the final secure implementation."
|
|||||||
<img src="https://github.com/user-attachments/assets/584adfa6-d252-49b4-b5b0-0cd6e97fb2c6" width="950">
|
<img src="https://github.com/user-attachments/assets/584adfa6-d252-49b4-b5b0-0cd6e97fb2c6" width="950">
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
**Prompt:**
|
**Prompt Used:**
|
||||||
```
|
```
|
||||||
Now use gemini and perform a review and precommit and ensure original requirements are met, no duplication of code or
|
Now use gemini and perform a review and precommit and ensure original requirements are met, no duplication of code or
|
||||||
logic, everything should work as expected
|
logic, everything should work as expected
|
||||||
@@ -464,35 +369,8 @@ How beautiful is that? Claude used `precommit` twice and `codereview` once and a
|
|||||||
|
|
||||||
#### Example Prompts:
|
#### Example Prompts:
|
||||||
|
|
||||||
**Basic Usage:**
|
|
||||||
```
|
```
|
||||||
"Use gemini to review my pending changes before I commit"
|
Use zen and perform a thorough precommit ensuring there aren't any new regressions or bugs introduced
|
||||||
"Get gemini to validate all my git changes match the original requirements"
|
|
||||||
"Review pending changes in the frontend/ directory"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Managing Token Costs:**
|
|
||||||
```
|
|
||||||
# Save tokens for small changes
|
|
||||||
"Use gemini with low thinking to review my README updates" (saves ~6k tokens)
|
|
||||||
"Review my config changes with gemini using minimal thinking" (saves ~8k tokens)
|
|
||||||
|
|
||||||
# Default for regular commits
|
|
||||||
"Use gemini to review my feature changes" (uses default medium)
|
|
||||||
|
|
||||||
# Invest tokens for critical releases
|
|
||||||
"Use gemini with high thinking to review changes before production release" (adds ~8k tokens)
|
|
||||||
"Get gemini to validate all changes with max thinking for this security patch" (adds ~24k tokens - worth it!)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Collaborative Workflow:**
|
|
||||||
```
|
|
||||||
"I've implemented the user authentication feature. Use gemini to review all pending changes
|
|
||||||
across the codebase to ensure they align with the security requirements. Fix any issues
|
|
||||||
gemini identifies before committing."
|
|
||||||
|
|
||||||
"Review all my changes for the API refactoring task. Get gemini to check for incomplete
|
|
||||||
implementations or missing test coverage. Update the code based on gemini's findings."
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Features:**
|
**Key Features:**
|
||||||
@@ -524,37 +402,6 @@ implementations or missing test coverage. Update the code based on gemini's find
|
|||||||
"Get gemini to debug why my API returns 500 errors with the full stack trace: [paste traceback]"
|
"Get gemini to debug why my API returns 500 errors with the full stack trace: [paste traceback]"
|
||||||
```
|
```
|
||||||
|
|
||||||
**With Web Search (for unfamiliar errors):**
|
|
||||||
```
|
|
||||||
"Use gemini to debug this cryptic Kubernetes error with web search enabled to find similar issues"
|
|
||||||
"Debug this React hydration error with gemini - enable web search to check for known solutions"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Managing Token Costs:**
|
|
||||||
```
|
|
||||||
# Save tokens for simple errors
|
|
||||||
"Use gemini with minimal thinking to debug this syntax error" (saves ~8k tokens)
|
|
||||||
"Debug this import error with gemini using low thinking" (saves ~6k tokens)
|
|
||||||
|
|
||||||
# Default for standard debugging
|
|
||||||
"Use gemini to debug why this function returns null" (uses default medium)
|
|
||||||
|
|
||||||
# Invest tokens for complex bugs
|
|
||||||
"Use gemini with high thinking to debug this race condition" (adds ~8k tokens)
|
|
||||||
"Get gemini to debug this memory leak with max thinking mode" (adds ~24k tokens - find that leak!)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Collaborative Workflow:**
|
|
||||||
```
|
|
||||||
"I'm getting 'ConnectionPool limit exceeded' errors under load. Debug the issue and use
|
|
||||||
gemini to analyze it deeper with context from db/pool.py. Based on gemini's root cause analysis,
|
|
||||||
implement a fix and get gemini to validate the solution will scale."
|
|
||||||
|
|
||||||
"Debug why tests fail randomly on CI. Once you identify potential causes, share with gemini along
|
|
||||||
with test logs and CI configuration. Apply gemini's debugging strategy, then use gemini to
|
|
||||||
suggest preventive measures."
|
|
||||||
```
|
|
||||||
|
|
||||||
**Key Features:**
|
**Key Features:**
|
||||||
- Generates multiple ranked hypotheses for systematic debugging
|
- Generates multiple ranked hypotheses for systematic debugging
|
||||||
- Accepts error context, stack traces, and logs
|
- Accepts error context, stack traces, and logs
|
||||||
@@ -576,36 +423,6 @@ suggest preventive measures."
|
|||||||
"Get gemini to do an architecture analysis of the src/ directory"
|
"Get gemini to do an architecture analysis of the src/ directory"
|
||||||
```
|
```
|
||||||
|
|
||||||
**With Web Search (for unfamiliar code):**
|
|
||||||
```
|
|
||||||
"Use gemini to analyze this GraphQL schema with web search enabled to understand best practices"
|
|
||||||
"Analyze this Rust code with gemini - enable web search to look up unfamiliar patterns and idioms"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Managing Token Costs:**
|
|
||||||
```
|
|
||||||
# Save tokens for quick overviews
|
|
||||||
"Use gemini with minimal thinking to analyze what config.py does" (saves ~8k tokens)
|
|
||||||
"Analyze this utility file with gemini using low thinking" (saves ~6k tokens)
|
|
||||||
|
|
||||||
# Default for standard analysis
|
|
||||||
"Use gemini to analyze the API structure" (uses default medium)
|
|
||||||
|
|
||||||
# Invest tokens for deep analysis
|
|
||||||
"Use gemini with high thinking to analyze the entire codebase architecture" (adds ~8k tokens)
|
|
||||||
"Get gemini to analyze system design with max thinking for refactoring plan" (adds ~24k tokens)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Collaborative Workflow:**
|
|
||||||
```
|
|
||||||
"Analyze our project structure in src/ and identify architectural improvements. Share your
|
|
||||||
analysis with gemini for a deeper review of design patterns and anti-patterns. Based on both
|
|
||||||
analyses, create a refactoring roadmap."
|
|
||||||
|
|
||||||
"Perform a security analysis of our authentication system. Use gemini to analyze auth/, middleware/, and api/ for vulnerabilities.
|
|
||||||
Combine your findings with gemini's to create a comprehensive security report."
|
|
||||||
```
|
|
||||||
|
|
||||||
**Key Features:**
|
**Key Features:**
|
||||||
- Analyzes single files or entire directories
|
- Analyzes single files or entire directories
|
||||||
- Supports specialized analysis types: architecture, performance, security, quality
|
- Supports specialized analysis types: architecture, performance, security, quality
|
||||||
@@ -627,7 +444,7 @@ All tools that work with files support **both individual files and entire direct
|
|||||||
**`analyze`** - Analyze files or directories
|
**`analyze`** - Analyze files or directories
|
||||||
- `files`: List of file paths or directories (required)
|
- `files`: List of file paths or directories (required)
|
||||||
- `question`: What to analyze (required)
|
- `question`: What to analyze (required)
|
||||||
- `model`: auto|pro|flash|o3|o3-mini|gpt-4o (default: server default)
|
- `model`: auto|pro|flash|o3|o3-mini (default: server default)
|
||||||
- `analysis_type`: architecture|performance|security|quality|general
|
- `analysis_type`: architecture|performance|security|quality|general
|
||||||
- `output_format`: summary|detailed|actionable
|
- `output_format`: summary|detailed|actionable
|
||||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||||
@@ -642,7 +459,7 @@ All tools that work with files support **both individual files and entire direct
|
|||||||
|
|
||||||
**`codereview`** - Review code files or directories
|
**`codereview`** - Review code files or directories
|
||||||
- `files`: List of file paths or directories (required)
|
- `files`: List of file paths or directories (required)
|
||||||
- `model`: auto|pro|flash|o3|o3-mini|gpt-4o (default: server default)
|
- `model`: auto|pro|flash|o3|o3-mini (default: server default)
|
||||||
- `review_type`: full|security|performance|quick
|
- `review_type`: full|security|performance|quick
|
||||||
- `focus_on`: Specific aspects to focus on
|
- `focus_on`: Specific aspects to focus on
|
||||||
- `standards`: Coding standards to enforce
|
- `standards`: Coding standards to enforce
|
||||||
@@ -658,7 +475,7 @@ All tools that work with files support **both individual files and entire direct
|
|||||||
|
|
||||||
**`debug`** - Debug with file context
|
**`debug`** - Debug with file context
|
||||||
- `error_description`: Description of the issue (required)
|
- `error_description`: Description of the issue (required)
|
||||||
- `model`: auto|pro|flash|o3|o3-mini|gpt-4o (default: server default)
|
- `model`: auto|pro|flash|o3|o3-mini (default: server default)
|
||||||
- `error_context`: Stack trace or logs
|
- `error_context`: Stack trace or logs
|
||||||
- `files`: Files or directories related to the issue
|
- `files`: Files or directories related to the issue
|
||||||
- `runtime_info`: Environment details
|
- `runtime_info`: Environment details
|
||||||
@@ -674,7 +491,7 @@ All tools that work with files support **both individual files and entire direct
|
|||||||
|
|
||||||
**`thinkdeep`** - Extended analysis with file context
|
**`thinkdeep`** - Extended analysis with file context
|
||||||
- `current_analysis`: Your current thinking (required)
|
- `current_analysis`: Your current thinking (required)
|
||||||
- `model`: auto|pro|flash|o3|o3-mini|gpt-4o (default: server default)
|
- `model`: auto|pro|flash|o3|o3-mini (default: server default)
|
||||||
- `problem_context`: Additional context
|
- `problem_context`: Additional context
|
||||||
- `focus_areas`: Specific aspects to focus on
|
- `focus_areas`: Specific aspects to focus on
|
||||||
- `files`: Files or directories for context
|
- `files`: Files or directories for context
|
||||||
@@ -800,16 +617,16 @@ To help choose the right tool for your needs:
|
|||||||
**Examples by scenario:**
|
**Examples by scenario:**
|
||||||
```
|
```
|
||||||
# Quick style check
|
# Quick style check
|
||||||
"Use gemini to review formatting in utils.py with minimal thinking"
|
"Use o3 to review formatting in utils.py with minimal thinking"
|
||||||
|
|
||||||
# Security audit
|
# Security audit
|
||||||
"Get gemini to do a security review of auth/ with thinking mode high"
|
"Get o3 to do a security review of auth/ with thinking mode high"
|
||||||
|
|
||||||
# Complex debugging
|
# Complex debugging
|
||||||
"Use gemini to debug this race condition with max thinking mode"
|
"Use zen to debug this race condition with max thinking mode"
|
||||||
|
|
||||||
# Architecture analysis
|
# Architecture analysis
|
||||||
"Analyze the entire src/ directory architecture with high thinking"
|
"Analyze the entire src/ directory architecture with high thinking using zen"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Advanced Features
|
## Advanced Features
|
||||||
@@ -831,7 +648,7 @@ The MCP protocol has a combined request+response limit of approximately 25K toke
|
|||||||
User: "Use gemini to review this code: [50,000+ character detailed analysis]"
|
User: "Use gemini to review this code: [50,000+ character detailed analysis]"
|
||||||
|
|
||||||
# Server detects the large prompt and responds:
|
# Server detects the large prompt and responds:
|
||||||
Gemini MCP: "The prompt is too large for MCP's token limits (>50,000 characters).
|
Zen MCP: "The prompt is too large for MCP's token limits (>50,000 characters).
|
||||||
Please save the prompt text to a temporary file named 'prompt.txt' and resend
|
Please save the prompt text to a temporary file named 'prompt.txt' and resend
|
||||||
the request with an empty prompt string and the absolute file path included
|
the request with an empty prompt string and the absolute file path included
|
||||||
in the files parameter, along with any other files you wish to share as context."
|
in the files parameter, along with any other files you wish to share as context."
|
||||||
@@ -928,7 +745,7 @@ DEFAULT_MODEL=auto # Claude picks the best model automatically
|
|||||||
|
|
||||||
# API Keys (at least one required)
|
# API Keys (at least one required)
|
||||||
GEMINI_API_KEY=your-gemini-key # Enables Gemini Pro & Flash
|
GEMINI_API_KEY=your-gemini-key # Enables Gemini Pro & Flash
|
||||||
OPENAI_API_KEY=your-openai-key # Enables O3, O3-mini, GPT-4o
|
OPENAI_API_KEY=your-openai-key # Enables O3, O3-mini
|
||||||
```
|
```
|
||||||
|
|
||||||
**How Auto Mode Works:**
|
**How Auto Mode Works:**
|
||||||
@@ -944,7 +761,6 @@ OPENAI_API_KEY=your-openai-key # Enables O3, O3-mini, GPT-4o
|
|||||||
| **`flash`** (Gemini 2.0 Flash) | Google | 1M tokens | Ultra-fast responses | Quick checks, formatting, simple analysis |
|
| **`flash`** (Gemini 2.0 Flash) | Google | 1M tokens | Ultra-fast responses | Quick checks, formatting, simple analysis |
|
||||||
| **`o3`** | OpenAI | 200K tokens | Strong logical reasoning | Debugging logic errors, systematic analysis |
|
| **`o3`** | OpenAI | 200K tokens | Strong logical reasoning | Debugging logic errors, systematic analysis |
|
||||||
| **`o3-mini`** | OpenAI | 200K tokens | Balanced speed/quality | Moderate complexity tasks |
|
| **`o3-mini`** | OpenAI | 200K tokens | Balanced speed/quality | Moderate complexity tasks |
|
||||||
| **`gpt-4o`** | OpenAI | 128K tokens | General purpose | Explanations, documentation, chat |
|
|
||||||
|
|
||||||
**Manual Model Selection:**
|
**Manual Model Selection:**
|
||||||
You can specify a default model instead of auto mode:
|
You can specify a default model instead of auto mode:
|
||||||
@@ -966,7 +782,6 @@ Regardless of your default setting, you can specify models per request:
|
|||||||
**Model Capabilities:**
|
**Model Capabilities:**
|
||||||
- **Gemini Models**: Support thinking modes (minimal to max), web search, 1M context
|
- **Gemini Models**: Support thinking modes (minimal to max), web search, 1M context
|
||||||
- **O3 Models**: Excellent reasoning, systematic analysis, 200K context
|
- **O3 Models**: Excellent reasoning, systematic analysis, 200K context
|
||||||
- **GPT-4o**: Balanced general-purpose model, 128K context
|
|
||||||
|
|
||||||
### Temperature Defaults
|
### Temperature Defaults
|
||||||
Different tools use optimized temperature settings:
|
Different tools use optimized temperature settings:
|
||||||
@@ -1011,15 +826,16 @@ When using any Gemini tool, always provide absolute paths:
|
|||||||
|
|
||||||
By default, the server allows access to files within your home directory. This is necessary for the server to work with any file you might want to analyze from Claude.
|
By default, the server allows access to files within your home directory. This is necessary for the server to work with any file you might want to analyze from Claude.
|
||||||
|
|
||||||
**To restrict access to a specific project directory**, set the `MCP_PROJECT_ROOT` environment variable:
|
**For Docker environments**, the `WORKSPACE_ROOT` environment variable is used to map your local directory to the internal `/workspace` directory, enabling the MCP server to translate absolute file references correctly:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
"env": {
|
"env": {
|
||||||
"GEMINI_API_KEY": "your-key",
|
"GEMINI_API_KEY": "your-key",
|
||||||
"MCP_PROJECT_ROOT": "/Users/you/specific-project"
|
"WORKSPACE_ROOT": "/Users/you/project"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
This creates a sandbox limiting file access to only that directory and its subdirectories.
|
This allows Claude to use absolute paths that will be correctly translated between your local filesystem and the Docker container.
|
||||||
|
|
||||||
|
|
||||||
## How System Prompts Work
|
## How System Prompts Work
|
||||||
@@ -1044,18 +860,6 @@ To modify tool behavior, you can:
|
|||||||
2. Override `get_system_prompt()` in a tool class for tool-specific changes
|
2. Override `get_system_prompt()` in a tool class for tool-specific changes
|
||||||
3. Use the `temperature` parameter to adjust response style (0.2 for focused, 0.7 for creative)
|
3. Use the `temperature` parameter to adjust response style (0.2 for focused, 0.7 for creative)
|
||||||
|
|
||||||
## Contributing
|
|
||||||
|
|
||||||
We welcome contributions! The modular architecture makes it easy to add new tools:
|
|
||||||
|
|
||||||
1. Create a new tool in `tools/`
|
|
||||||
2. Inherit from `BaseTool`
|
|
||||||
3. Implement required methods (including `get_system_prompt()`)
|
|
||||||
4. Add your system prompt to `prompts/tool_prompts.py`
|
|
||||||
5. Register your tool in `TOOLS` dict in `server.py`
|
|
||||||
|
|
||||||
See existing tools for examples.
|
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
### Unit Tests (No API Key Required)
|
### Unit Tests (No API Key Required)
|
||||||
@@ -1063,32 +867,48 @@ The project includes comprehensive unit tests that use mocks and don't require a
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run all unit tests
|
# Run all unit tests
|
||||||
python -m pytest tests/ --ignore=tests/test_live_integration.py -v
|
python -m pytest tests/ -v
|
||||||
|
|
||||||
# Run with coverage
|
# Run with coverage
|
||||||
python -m pytest tests/ --ignore=tests/test_live_integration.py --cov=. --cov-report=html
|
python -m pytest tests/ --cov=. --cov-report=html
|
||||||
```
|
```
|
||||||
|
|
||||||
### Live Integration Tests (API Key Required)
|
### Simulation Tests (API Key Required)
|
||||||
To test actual API integration:
|
To test the MCP server with comprehensive end-to-end simulation:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Set your API key
|
# Set your API keys (at least one required)
|
||||||
export GEMINI_API_KEY=your-api-key-here
|
export GEMINI_API_KEY=your-gemini-api-key-here
|
||||||
|
export OPENAI_API_KEY=your-openai-api-key-here
|
||||||
|
|
||||||
# Run live integration tests
|
# Run all simulation tests (default: uses existing Docker containers)
|
||||||
python tests/test_live_integration.py
|
python communication_simulator_test.py
|
||||||
|
|
||||||
|
# Run specific tests only
|
||||||
|
python communication_simulator_test.py --tests basic_conversation content_validation
|
||||||
|
|
||||||
|
# Run with Docker rebuild (if needed)
|
||||||
|
python communication_simulator_test.py --rebuild-docker
|
||||||
|
|
||||||
|
# List available tests
|
||||||
|
python communication_simulator_test.py --list-tests
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The simulation tests validate:
|
||||||
|
- Basic conversation flow with continuation
|
||||||
|
- File handling and deduplication
|
||||||
|
- Cross-tool conversation threading
|
||||||
|
- Redis memory persistence
|
||||||
|
- Docker container integration
|
||||||
|
|
||||||
### GitHub Actions CI/CD
|
### GitHub Actions CI/CD
|
||||||
The project includes GitHub Actions workflows that:
|
The project includes GitHub Actions workflows that:
|
||||||
|
|
||||||
- **✅ Run unit tests automatically** - No API key needed, uses mocks
|
- **✅ Run unit tests automatically** - No API key needed, uses mocks
|
||||||
- **✅ Test on Python 3.10, 3.11, 3.12** - Ensures compatibility
|
- **✅ Test on Python 3.10, 3.11, 3.12** - Ensures compatibility
|
||||||
- **✅ Run linting and formatting checks** - Maintains code quality
|
- **✅ Run linting and formatting checks** - Maintains code quality
|
||||||
- **🔒 Run live tests only if API key is available** - Optional live verification
|
|
||||||
|
|
||||||
The CI pipeline works without any secrets and will pass all tests using mocked responses. Live integration tests only run if a `GEMINI_API_KEY` secret is configured in the repository.
|
The CI pipeline works without any secrets and will pass all unit tests using mocked responses. The simulation tests, by contrast, require at least one API key secret (`GEMINI_API_KEY` and/or `OPENAI_API_KEY`) because the communication simulator talks to the real model APIs.
|
||||||
|
|
||||||
## Troubleshooting
|
## Troubleshooting
|
||||||
|
|
||||||
@@ -1097,14 +917,14 @@ The CI pipeline works without any secrets and will pass all tests using mocked r
|
|||||||
**"Connection failed" in Claude Desktop**
|
**"Connection failed" in Claude Desktop**
|
||||||
- Ensure Docker services are running: `docker compose ps`
|
- Ensure Docker services are running: `docker compose ps`
|
||||||
- Check if the container name is correct: `docker ps` to see actual container names
|
- Check if the container name is correct: `docker ps` to see actual container names
|
||||||
- Verify your .env file has the correct GEMINI_API_KEY
|
- Verify your .env file has at least one valid API key (GEMINI_API_KEY or OPENAI_API_KEY)
|
||||||
|
|
||||||
**"GEMINI_API_KEY environment variable is required"**
|
**"API key environment variable is required"**
|
||||||
- Edit your .env file and add your API key
|
- Edit your .env file and add at least one API key (Gemini or OpenAI)
|
||||||
- Restart services: `docker compose restart`
|
- Restart services: `docker compose restart`
|
||||||
|
|
||||||
**Container fails to start**
|
**Container fails to start**
|
||||||
- Check logs: `docker compose logs gemini-mcp`
|
- Check logs: `docker compose logs zen-mcp`
|
||||||
- Ensure Docker has enough resources (memory/disk space)
|
- Ensure Docker has enough resources (memory/disk space)
|
||||||
- Try rebuilding: `docker compose build --no-cache`
|
- Try rebuilding: `docker compose build --no-cache`
|
||||||
|
|
||||||
@@ -1119,25 +939,12 @@ The CI pipeline works without any secrets and will pass all tests using mocked r
|
|||||||
docker compose ps
|
docker compose ps
|
||||||
|
|
||||||
# Test manual connection
|
# Test manual connection
|
||||||
docker exec -i gemini-mcp-server-gemini-mcp-1 echo "Connection test"
|
docker exec -i zen-mcp-server echo "Connection test"
|
||||||
|
|
||||||
# View logs
|
# View logs
|
||||||
docker compose logs -f
|
docker compose logs -f
|
||||||
```
|
```
|
||||||
|
|
||||||
**Conversation threading not working?**
|
|
||||||
If you're not seeing follow-up questions from Gemini:
|
|
||||||
```bash
|
|
||||||
# Check if Redis is running
|
|
||||||
docker compose logs redis
|
|
||||||
|
|
||||||
# Test conversation memory system
|
|
||||||
docker exec -i gemini-mcp-server-gemini-mcp-1 python debug_conversation.py
|
|
||||||
|
|
||||||
# Check for threading errors in logs
|
|
||||||
docker compose logs gemini-mcp | grep "threading failed"
|
|
||||||
```
|
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
MIT License - see LICENSE file for details.
|
MIT License - see LICENSE file for details.
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
{
|
{
|
||||||
"comment": "Example Claude Desktop configuration for Gemini MCP Server",
|
"comment": "Example Claude Desktop configuration for Zen MCP Server",
|
||||||
"comment2": "For Docker setup, use examples/claude_config_docker_home.json",
|
"comment2": "For Docker setup, use examples/claude_config_docker_home.json",
|
||||||
"comment3": "For platform-specific examples, see the examples/ directory",
|
"comment3": "For platform-specific examples, see the examples/ directory",
|
||||||
"mcpServers": {
|
"mcpServers": {
|
||||||
"gemini": {
|
"zen": {
|
||||||
"command": "/path/to/gemini-mcp-server/run_gemini.sh",
|
"command": "docker",
|
||||||
"env": {
|
"args": [
|
||||||
"GEMINI_API_KEY": "your-gemini-api-key-here"
|
"exec",
|
||||||
}
|
"-i",
|
||||||
|
"zen-mcp-server",
|
||||||
|
"python",
|
||||||
|
"server.py"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Communication Simulator Test for Gemini MCP Server
|
Communication Simulator Test for Zen MCP Server
|
||||||
|
|
||||||
This script provides comprehensive end-to-end testing of the Gemini MCP server
|
This script provides comprehensive end-to-end testing of the Zen MCP server
|
||||||
by simulating real Claude CLI communications and validating conversation
|
by simulating real Claude CLI communications and validating conversation
|
||||||
continuity, file handling, deduplication features, and clarification scenarios.
|
continuity, file handling, deduplication features, and clarification scenarios.
|
||||||
|
|
||||||
@@ -63,8 +63,8 @@ class CommunicationSimulator:
|
|||||||
self.keep_logs = keep_logs
|
self.keep_logs = keep_logs
|
||||||
self.selected_tests = selected_tests or []
|
self.selected_tests = selected_tests or []
|
||||||
self.temp_dir = None
|
self.temp_dir = None
|
||||||
self.container_name = "gemini-mcp-server"
|
self.container_name = "zen-mcp-server"
|
||||||
self.redis_container = "gemini-mcp-redis"
|
self.redis_container = "zen-mcp-redis"
|
||||||
|
|
||||||
# Import test registry
|
# Import test registry
|
||||||
from simulator_tests import TEST_REGISTRY
|
from simulator_tests import TEST_REGISTRY
|
||||||
@@ -282,7 +282,7 @@ class CommunicationSimulator:
|
|||||||
def print_test_summary(self):
|
def print_test_summary(self):
|
||||||
"""Print comprehensive test results summary"""
|
"""Print comprehensive test results summary"""
|
||||||
print("\\n" + "=" * 70)
|
print("\\n" + "=" * 70)
|
||||||
print("🧪 GEMINI MCP COMMUNICATION SIMULATOR - TEST RESULTS SUMMARY")
|
print("🧪 ZEN MCP COMMUNICATION SIMULATOR - TEST RESULTS SUMMARY")
|
||||||
print("=" * 70)
|
print("=" * 70)
|
||||||
|
|
||||||
passed_count = sum(1 for result in self.test_results.values() if result)
|
passed_count = sum(1 for result in self.test_results.values() if result)
|
||||||
@@ -303,7 +303,7 @@ class CommunicationSimulator:
|
|||||||
def run_full_test_suite(self, skip_docker_setup: bool = False) -> bool:
|
def run_full_test_suite(self, skip_docker_setup: bool = False) -> bool:
|
||||||
"""Run the complete test suite"""
|
"""Run the complete test suite"""
|
||||||
try:
|
try:
|
||||||
self.logger.info("🚀 Starting Gemini MCP Communication Simulator Test Suite")
|
self.logger.info("🚀 Starting Zen MCP Communication Simulator Test Suite")
|
||||||
|
|
||||||
# Setup
|
# Setup
|
||||||
if not skip_docker_setup:
|
if not skip_docker_setup:
|
||||||
@@ -359,7 +359,7 @@ class CommunicationSimulator:
|
|||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
"""Parse and validate command line arguments"""
|
"""Parse and validate command line arguments"""
|
||||||
parser = argparse.ArgumentParser(description="Gemini MCP Communication Simulator Test")
|
parser = argparse.ArgumentParser(description="Zen MCP Communication Simulator Test")
|
||||||
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging")
|
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging")
|
||||||
parser.add_argument("--keep-logs", action="store_true", help="Keep Docker services running for log inspection")
|
parser.add_argument("--keep-logs", action="store_true", help="Keep Docker services running for log inspection")
|
||||||
parser.add_argument("--tests", "-t", nargs="+", help="Specific tests to run (space-separated)")
|
parser.add_argument("--tests", "-t", nargs="+", help="Specific tests to run (space-separated)")
|
||||||
|
|||||||
11
config.py
11
config.py
@@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Configuration and constants for Gemini MCP Server
|
Configuration and constants for Zen MCP Server
|
||||||
|
|
||||||
This module centralizes all configuration settings for the Gemini MCP Server.
|
This module centralizes all configuration settings for the Zen MCP Server.
|
||||||
It defines model configurations, token limits, temperature defaults, and other
|
It defines model configurations, token limits, temperature defaults, and other
|
||||||
constants used throughout the application.
|
constants used throughout the application.
|
||||||
|
|
||||||
@@ -29,8 +29,11 @@ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
|
|||||||
VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash-exp", "gemini-2.5-pro-preview-06-05"]
|
VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash-exp", "gemini-2.5-pro-preview-06-05"]
|
||||||
if DEFAULT_MODEL not in VALID_MODELS:
|
if DEFAULT_MODEL not in VALID_MODELS:
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.warning(f"Invalid DEFAULT_MODEL '{DEFAULT_MODEL}'. Setting to 'auto'. Valid options: {', '.join(VALID_MODELS)}")
|
logger.warning(
|
||||||
|
f"Invalid DEFAULT_MODEL '{DEFAULT_MODEL}'. Setting to 'auto'. Valid options: {', '.join(VALID_MODELS)}"
|
||||||
|
)
|
||||||
DEFAULT_MODEL = "auto"
|
DEFAULT_MODEL = "auto"
|
||||||
|
|
||||||
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
|
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
|
||||||
@@ -45,7 +48,7 @@ MODEL_CAPABILITIES_DESC = {
|
|||||||
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
||||||
# Full model names also supported
|
# Full model names also supported
|
||||||
"gemini-2.0-flash-exp": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
"gemini-2.0-flash-exp": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
||||||
"gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis"
|
"gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Token allocation for Gemini Pro (1M total capacity)
|
# Token allocation for Gemini Pro (1M total capacity)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
services:
|
services:
|
||||||
redis:
|
redis:
|
||||||
image: redis:7-alpine
|
image: redis:7-alpine
|
||||||
container_name: gemini-mcp-redis
|
container_name: zen-mcp-redis
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
ports:
|
ports:
|
||||||
- "6379:6379"
|
- "6379:6379"
|
||||||
@@ -20,10 +20,10 @@ services:
|
|||||||
reservations:
|
reservations:
|
||||||
memory: 256M
|
memory: 256M
|
||||||
|
|
||||||
gemini-mcp:
|
zen-mcp:
|
||||||
build: .
|
build: .
|
||||||
image: gemini-mcp-server:latest
|
image: zen-mcp-server:latest
|
||||||
container_name: gemini-mcp-server
|
container_name: zen-mcp-server
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
depends_on:
|
depends_on:
|
||||||
redis:
|
redis:
|
||||||
@@ -50,11 +50,11 @@ services:
|
|||||||
|
|
||||||
log-monitor:
|
log-monitor:
|
||||||
build: .
|
build: .
|
||||||
image: gemini-mcp-server:latest
|
image: zen-mcp-server:latest
|
||||||
container_name: gemini-mcp-log-monitor
|
container_name: zen-mcp-log-monitor
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
depends_on:
|
depends_on:
|
||||||
- gemini-mcp
|
- zen-mcp
|
||||||
environment:
|
environment:
|
||||||
- PYTHONUNBUFFERED=1
|
- PYTHONUNBUFFERED=1
|
||||||
volumes:
|
volumes:
|
||||||
|
|||||||
@@ -1,18 +1,18 @@
|
|||||||
{
|
{
|
||||||
"comment": "Docker configuration that mounts your home directory",
|
"comment": "Docker configuration that mounts your home directory",
|
||||||
"comment2": "Update paths: /path/to/gemini-mcp-server/.env and /Users/your-username",
|
"comment2": "Update paths: /path/to/zen-mcp-server/.env and /Users/your-username",
|
||||||
"comment3": "The container auto-detects /workspace as sandbox from WORKSPACE_ROOT",
|
"comment3": "The container auto-detects /workspace as sandbox from WORKSPACE_ROOT",
|
||||||
"mcpServers": {
|
"mcpServers": {
|
||||||
"gemini": {
|
"zen": {
|
||||||
"command": "docker",
|
"command": "docker",
|
||||||
"args": [
|
"args": [
|
||||||
"run",
|
"run",
|
||||||
"--rm",
|
"--rm",
|
||||||
"-i",
|
"-i",
|
||||||
"--env-file", "/path/to/gemini-mcp-server/.env",
|
"--env-file", "/path/to/zen-mcp-server/.env",
|
||||||
"-e", "WORKSPACE_ROOT=/Users/your-username",
|
"-e", "WORKSPACE_ROOT=/Users/your-username",
|
||||||
"-v", "/Users/your-username:/workspace:ro",
|
"-v", "/Users/your-username:/workspace:ro",
|
||||||
"gemini-mcp-server:latest"
|
"zen-mcp-server:latest"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
{
|
{
|
||||||
"comment": "Traditional macOS/Linux configuration (non-Docker)",
|
"comment": "macOS configuration using Docker",
|
||||||
"comment2": "Replace YOUR_USERNAME with your actual username",
|
"comment2": "Ensure Docker is running and containers are started",
|
||||||
"comment3": "This gives access to all files under your home directory",
|
"comment3": "Run './setup-docker.sh' first to set up the environment",
|
||||||
"mcpServers": {
|
"mcpServers": {
|
||||||
"gemini": {
|
"zen": {
|
||||||
"command": "/Users/YOUR_USERNAME/gemini-mcp-server/run_gemini.sh",
|
"command": "docker",
|
||||||
"env": {
|
"args": [
|
||||||
"GEMINI_API_KEY": "your-gemini-api-key-here"
|
"exec",
|
||||||
}
|
"-i",
|
||||||
|
"zen-mcp-server",
|
||||||
|
"python",
|
||||||
|
"server.py"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
{
|
{
|
||||||
"comment": "Windows configuration using WSL (Windows Subsystem for Linux)",
|
"comment": "Windows configuration using WSL with Docker",
|
||||||
"comment2": "Replace YOUR_WSL_USERNAME with your WSL username",
|
"comment2": "Ensure Docker Desktop is running and WSL integration is enabled",
|
||||||
"comment3": "Make sure the server is installed in your WSL environment",
|
"comment3": "Run './setup-docker.sh' in WSL first to set up the environment",
|
||||||
"mcpServers": {
|
"mcpServers": {
|
||||||
"gemini": {
|
"zen": {
|
||||||
"command": "wsl.exe",
|
"command": "wsl.exe",
|
||||||
"args": ["/home/YOUR_WSL_USERNAME/gemini-mcp-server/run_gemini.sh"],
|
"args": [
|
||||||
"env": {
|
"docker",
|
||||||
"GEMINI_API_KEY": "your-gemini-api-key-here"
|
"exec",
|
||||||
}
|
"-i",
|
||||||
|
"zen-mcp-server",
|
||||||
|
"python",
|
||||||
|
"server.py"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""Model provider abstractions for supporting multiple AI providers."""
|
"""Model provider abstractions for supporting multiple AI providers."""
|
||||||
|
|
||||||
from .base import ModelProvider, ModelResponse, ModelCapabilities
|
from .base import ModelCapabilities, ModelProvider, ModelResponse
|
||||||
from .registry import ModelProviderRegistry
|
|
||||||
from .gemini import GeminiModelProvider
|
from .gemini import GeminiModelProvider
|
||||||
from .openai import OpenAIModelProvider
|
from .openai import OpenAIModelProvider
|
||||||
|
from .registry import ModelProviderRegistry
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModelProvider",
|
"ModelProvider",
|
||||||
@@ -12,4 +12,4 @@ __all__ = [
|
|||||||
"ModelProviderRegistry",
|
"ModelProviderRegistry",
|
||||||
"GeminiModelProvider",
|
"GeminiModelProvider",
|
||||||
"OpenAIModelProvider",
|
"OpenAIModelProvider",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,34 +2,35 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional, Any, Tuple
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
class ProviderType(Enum):
|
class ProviderType(Enum):
|
||||||
"""Supported model provider types."""
|
"""Supported model provider types."""
|
||||||
|
|
||||||
GOOGLE = "google"
|
GOOGLE = "google"
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
|
|
||||||
|
|
||||||
class TemperatureConstraint(ABC):
|
class TemperatureConstraint(ABC):
|
||||||
"""Abstract base class for temperature constraints."""
|
"""Abstract base class for temperature constraints."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def validate(self, temperature: float) -> bool:
|
def validate(self, temperature: float) -> bool:
|
||||||
"""Check if temperature is valid."""
|
"""Check if temperature is valid."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_corrected_value(self, temperature: float) -> float:
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
"""Get nearest valid temperature."""
|
"""Get nearest valid temperature."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_description(self) -> str:
|
def get_description(self) -> str:
|
||||||
"""Get human-readable description of constraint."""
|
"""Get human-readable description of constraint."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_default(self) -> float:
|
def get_default(self) -> float:
|
||||||
"""Get model's default temperature."""
|
"""Get model's default temperature."""
|
||||||
@@ -38,60 +39,60 @@ class TemperatureConstraint(ABC):
|
|||||||
|
|
||||||
class FixedTemperatureConstraint(TemperatureConstraint):
|
class FixedTemperatureConstraint(TemperatureConstraint):
|
||||||
"""For models that only support one temperature value (e.g., O3)."""
|
"""For models that only support one temperature value (e.g., O3)."""
|
||||||
|
|
||||||
def __init__(self, value: float):
|
def __init__(self, value: float):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def validate(self, temperature: float) -> bool:
|
def validate(self, temperature: float) -> bool:
|
||||||
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
|
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
|
||||||
|
|
||||||
def get_corrected_value(self, temperature: float) -> float:
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
def get_description(self) -> str:
|
def get_description(self) -> str:
|
||||||
return f"Only supports temperature={self.value}"
|
return f"Only supports temperature={self.value}"
|
||||||
|
|
||||||
def get_default(self) -> float:
|
def get_default(self) -> float:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
class RangeTemperatureConstraint(TemperatureConstraint):
|
class RangeTemperatureConstraint(TemperatureConstraint):
|
||||||
"""For models supporting continuous temperature ranges."""
|
"""For models supporting continuous temperature ranges."""
|
||||||
|
|
||||||
def __init__(self, min_temp: float, max_temp: float, default: float = None):
|
def __init__(self, min_temp: float, max_temp: float, default: float = None):
|
||||||
self.min_temp = min_temp
|
self.min_temp = min_temp
|
||||||
self.max_temp = max_temp
|
self.max_temp = max_temp
|
||||||
self.default_temp = default or (min_temp + max_temp) / 2
|
self.default_temp = default or (min_temp + max_temp) / 2
|
||||||
|
|
||||||
def validate(self, temperature: float) -> bool:
|
def validate(self, temperature: float) -> bool:
|
||||||
return self.min_temp <= temperature <= self.max_temp
|
return self.min_temp <= temperature <= self.max_temp
|
||||||
|
|
||||||
def get_corrected_value(self, temperature: float) -> float:
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
return max(self.min_temp, min(self.max_temp, temperature))
|
return max(self.min_temp, min(self.max_temp, temperature))
|
||||||
|
|
||||||
def get_description(self) -> str:
|
def get_description(self) -> str:
|
||||||
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
|
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
|
||||||
|
|
||||||
def get_default(self) -> float:
|
def get_default(self) -> float:
|
||||||
return self.default_temp
|
return self.default_temp
|
||||||
|
|
||||||
|
|
||||||
class DiscreteTemperatureConstraint(TemperatureConstraint):
|
class DiscreteTemperatureConstraint(TemperatureConstraint):
|
||||||
"""For models supporting only specific temperature values."""
|
"""For models supporting only specific temperature values."""
|
||||||
|
|
||||||
def __init__(self, allowed_values: List[float], default: float = None):
|
def __init__(self, allowed_values: list[float], default: float = None):
|
||||||
self.allowed_values = sorted(allowed_values)
|
self.allowed_values = sorted(allowed_values)
|
||||||
self.default_temp = default or allowed_values[len(allowed_values)//2]
|
self.default_temp = default or allowed_values[len(allowed_values) // 2]
|
||||||
|
|
||||||
def validate(self, temperature: float) -> bool:
|
def validate(self, temperature: float) -> bool:
|
||||||
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
|
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
|
||||||
|
|
||||||
def get_corrected_value(self, temperature: float) -> float:
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
return min(self.allowed_values, key=lambda x: abs(x - temperature))
|
return min(self.allowed_values, key=lambda x: abs(x - temperature))
|
||||||
|
|
||||||
def get_description(self) -> str:
|
def get_description(self) -> str:
|
||||||
return f"Supports temperatures: {self.allowed_values}"
|
return f"Supports temperatures: {self.allowed_values}"
|
||||||
|
|
||||||
def get_default(self) -> float:
|
def get_default(self) -> float:
|
||||||
return self.default_temp
|
return self.default_temp
|
||||||
|
|
||||||
@@ -99,6 +100,7 @@ class DiscreteTemperatureConstraint(TemperatureConstraint):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ModelCapabilities:
|
class ModelCapabilities:
|
||||||
"""Capabilities and constraints for a specific model."""
|
"""Capabilities and constraints for a specific model."""
|
||||||
|
|
||||||
provider: ProviderType
|
provider: ProviderType
|
||||||
model_name: str
|
model_name: str
|
||||||
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
|
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
|
||||||
@@ -107,15 +109,15 @@ class ModelCapabilities:
|
|||||||
supports_system_prompts: bool = True
|
supports_system_prompts: bool = True
|
||||||
supports_streaming: bool = True
|
supports_streaming: bool = True
|
||||||
supports_function_calling: bool = False
|
supports_function_calling: bool = False
|
||||||
|
|
||||||
# Temperature constraint object - preferred way to define temperature limits
|
# Temperature constraint object - preferred way to define temperature limits
|
||||||
temperature_constraint: TemperatureConstraint = field(
|
temperature_constraint: TemperatureConstraint = field(
|
||||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Backward compatibility property for existing code
|
# Backward compatibility property for existing code
|
||||||
@property
|
@property
|
||||||
def temperature_range(self) -> Tuple[float, float]:
|
def temperature_range(self) -> tuple[float, float]:
|
||||||
"""Backward compatibility for existing code that uses temperature_range."""
|
"""Backward compatibility for existing code that uses temperature_range."""
|
||||||
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
|
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
|
||||||
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
|
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
|
||||||
@@ -130,13 +132,14 @@ class ModelCapabilities:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ModelResponse:
|
class ModelResponse:
|
||||||
"""Response from a model provider."""
|
"""Response from a model provider."""
|
||||||
|
|
||||||
content: str
|
content: str
|
||||||
usage: Dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
|
usage: dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
|
||||||
model_name: str = ""
|
model_name: str = ""
|
||||||
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
|
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
|
||||||
provider: ProviderType = ProviderType.GOOGLE
|
provider: ProviderType = ProviderType.GOOGLE
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
|
metadata: dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_tokens(self) -> int:
|
def total_tokens(self) -> int:
|
||||||
"""Get total tokens used."""
|
"""Get total tokens used."""
|
||||||
@@ -145,17 +148,17 @@ class ModelResponse:
|
|||||||
|
|
||||||
class ModelProvider(ABC):
|
class ModelProvider(ABC):
|
||||||
"""Abstract base class for model providers."""
|
"""Abstract base class for model providers."""
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize the provider with API key and optional configuration."""
|
"""Initialize the provider with API key and optional configuration."""
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.config = kwargs
|
self.config = kwargs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
"""Get capabilities for a specific model."""
|
"""Get capabilities for a specific model."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
@@ -164,10 +167,10 @@ class ModelProvider(ABC):
|
|||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_output_tokens: Optional[int] = None,
|
max_output_tokens: Optional[int] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Generate content using the model.
|
"""Generate content using the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: User prompt to send to the model
|
prompt: User prompt to send to the model
|
||||||
model_name: Name of the model to use
|
model_name: Name of the model to use
|
||||||
@@ -175,49 +178,43 @@ class ModelProvider(ABC):
|
|||||||
temperature: Sampling temperature (0-2)
|
temperature: Sampling temperature (0-2)
|
||||||
max_output_tokens: Maximum tokens to generate
|
max_output_tokens: Maximum tokens to generate
|
||||||
**kwargs: Provider-specific parameters
|
**kwargs: Provider-specific parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelResponse with generated content and metadata
|
ModelResponse with generated content and metadata
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def count_tokens(self, text: str, model_name: str) -> int:
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
"""Count tokens for the given text using the specified model's tokenizer."""
|
"""Count tokens for the given text using the specified model's tokenizer."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
"""Validate if the model name is supported by this provider."""
|
"""Validate if the model name is supported by this provider."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def validate_parameters(
|
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
temperature: float,
|
|
||||||
**kwargs
|
|
||||||
) -> None:
|
|
||||||
"""Validate model parameters against capabilities.
|
"""Validate model parameters against capabilities.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If parameters are invalid
|
ValueError: If parameters are invalid
|
||||||
"""
|
"""
|
||||||
capabilities = self.get_capabilities(model_name)
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
# Validate temperature
|
# Validate temperature
|
||||||
min_temp, max_temp = capabilities.temperature_range
|
min_temp, max_temp = capabilities.temperature_range
|
||||||
if not min_temp <= temperature <= max_temp:
|
if not min_temp <= temperature <= max_temp:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Temperature {temperature} out of range [{min_temp}, {max_temp}] "
|
f"Temperature {temperature} out of range [{min_temp}, {max_temp}] " f"for model {model_name}"
|
||||||
f"for model {model_name}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
"""Check if the model supports extended thinking mode."""
|
"""Check if the model supports extended thinking mode."""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,22 +1,16 @@
|
|||||||
"""Gemini model provider implementation."""
|
"""Gemini model provider implementation."""
|
||||||
|
|
||||||
import os
|
from typing import Optional
|
||||||
from typing import Dict, Optional, List
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
from .base import (
|
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
|
||||||
ModelProvider,
|
|
||||||
ModelResponse,
|
|
||||||
ModelCapabilities,
|
|
||||||
ProviderType,
|
|
||||||
RangeTemperatureConstraint
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiModelProvider(ModelProvider):
|
class GeminiModelProvider(ModelProvider):
|
||||||
"""Google Gemini model provider implementation."""
|
"""Google Gemini model provider implementation."""
|
||||||
|
|
||||||
# Model configurations
|
# Model configurations
|
||||||
SUPPORTED_MODELS = {
|
SUPPORTED_MODELS = {
|
||||||
"gemini-2.0-flash-exp": {
|
"gemini-2.0-flash-exp": {
|
||||||
@@ -31,42 +25,42 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
"flash": "gemini-2.0-flash-exp",
|
"flash": "gemini-2.0-flash-exp",
|
||||||
"pro": "gemini-2.5-pro-preview-06-05",
|
"pro": "gemini-2.5-pro-preview-06-05",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Thinking mode configurations for models that support it
|
# Thinking mode configurations for models that support it
|
||||||
THINKING_BUDGETS = {
|
THINKING_BUDGETS = {
|
||||||
"minimal": 128, # Minimum for 2.5 Pro - fast responses
|
"minimal": 128, # Minimum for 2.5 Pro - fast responses
|
||||||
"low": 2048, # Light reasoning tasks
|
"low": 2048, # Light reasoning tasks
|
||||||
"medium": 8192, # Balanced reasoning (default)
|
"medium": 8192, # Balanced reasoning (default)
|
||||||
"high": 16384, # Complex analysis
|
"high": 16384, # Complex analysis
|
||||||
"max": 32768, # Maximum reasoning depth
|
"max": 32768, # Maximum reasoning depth
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize Gemini provider with API key."""
|
"""Initialize Gemini provider with API key."""
|
||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
self._client = None
|
self._client = None
|
||||||
self._token_counters = {} # Cache for token counting
|
self._token_counters = {} # Cache for token counting
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""Lazy initialization of Gemini client."""
|
"""Lazy initialization of Gemini client."""
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
self._client = genai.Client(api_key=self.api_key)
|
self._client = genai.Client(api_key=self.api_key)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
"""Get capabilities for a specific Gemini model."""
|
"""Get capabilities for a specific Gemini model."""
|
||||||
# Resolve shorthand
|
# Resolve shorthand
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
if resolved_name not in self.SUPPORTED_MODELS:
|
if resolved_name not in self.SUPPORTED_MODELS:
|
||||||
raise ValueError(f"Unsupported Gemini model: {model_name}")
|
raise ValueError(f"Unsupported Gemini model: {model_name}")
|
||||||
|
|
||||||
config = self.SUPPORTED_MODELS[resolved_name]
|
config = self.SUPPORTED_MODELS[resolved_name]
|
||||||
|
|
||||||
# Gemini models support 0.0-2.0 temperature range
|
# Gemini models support 0.0-2.0 temperature range
|
||||||
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||||
|
|
||||||
return ModelCapabilities(
|
return ModelCapabilities(
|
||||||
provider=ProviderType.GOOGLE,
|
provider=ProviderType.GOOGLE,
|
||||||
model_name=resolved_name,
|
model_name=resolved_name,
|
||||||
@@ -78,7 +72,7 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
supports_function_calling=True,
|
supports_function_calling=True,
|
||||||
temperature_constraint=temp_constraint,
|
temperature_constraint=temp_constraint,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -87,36 +81,36 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_output_tokens: Optional[int] = None,
|
max_output_tokens: Optional[int] = None,
|
||||||
thinking_mode: str = "medium",
|
thinking_mode: str = "medium",
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Generate content using Gemini model."""
|
"""Generate content using Gemini model."""
|
||||||
# Validate parameters
|
# Validate parameters
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
self.validate_parameters(resolved_name, temperature)
|
self.validate_parameters(resolved_name, temperature)
|
||||||
|
|
||||||
# Combine system prompt with user prompt if provided
|
# Combine system prompt with user prompt if provided
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
full_prompt = f"{system_prompt}\n\n{prompt}"
|
full_prompt = f"{system_prompt}\n\n{prompt}"
|
||||||
else:
|
else:
|
||||||
full_prompt = prompt
|
full_prompt = prompt
|
||||||
|
|
||||||
# Prepare generation config
|
# Prepare generation config
|
||||||
generation_config = types.GenerateContentConfig(
|
generation_config = types.GenerateContentConfig(
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
candidate_count=1,
|
candidate_count=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add max output tokens if specified
|
# Add max output tokens if specified
|
||||||
if max_output_tokens:
|
if max_output_tokens:
|
||||||
generation_config.max_output_tokens = max_output_tokens
|
generation_config.max_output_tokens = max_output_tokens
|
||||||
|
|
||||||
# Add thinking configuration for models that support it
|
# Add thinking configuration for models that support it
|
||||||
capabilities = self.get_capabilities(resolved_name)
|
capabilities = self.get_capabilities(resolved_name)
|
||||||
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
|
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
|
||||||
generation_config.thinking_config = types.ThinkingConfig(
|
generation_config.thinking_config = types.ThinkingConfig(
|
||||||
thinking_budget=self.THINKING_BUDGETS[thinking_mode]
|
thinking_budget=self.THINKING_BUDGETS[thinking_mode]
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate content
|
# Generate content
|
||||||
response = self.client.models.generate_content(
|
response = self.client.models.generate_content(
|
||||||
@@ -124,10 +118,10 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
contents=full_prompt,
|
contents=full_prompt,
|
||||||
config=generation_config,
|
config=generation_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract usage information if available
|
# Extract usage information if available
|
||||||
usage = self._extract_usage(response)
|
usage = self._extract_usage(response)
|
||||||
|
|
||||||
return ModelResponse(
|
return ModelResponse(
|
||||||
content=response.text,
|
content=response.text,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@@ -136,38 +130,40 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
provider=ProviderType.GOOGLE,
|
provider=ProviderType.GOOGLE,
|
||||||
metadata={
|
metadata={
|
||||||
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
|
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
|
||||||
"finish_reason": getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP",
|
"finish_reason": (
|
||||||
}
|
getattr(response.candidates[0], "finish_reason", "STOP") if response.candidates else "STOP"
|
||||||
|
),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log error and re-raise with more context
|
# Log error and re-raise with more context
|
||||||
error_msg = f"Gemini API error for model {resolved_name}: {str(e)}"
|
error_msg = f"Gemini API error for model {resolved_name}: {str(e)}"
|
||||||
raise RuntimeError(error_msg) from e
|
raise RuntimeError(error_msg) from e
|
||||||
|
|
||||||
def count_tokens(self, text: str, model_name: str) -> int:
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
"""Count tokens for the given text using Gemini's tokenizer."""
|
"""Count tokens for the given text using Gemini's tokenizer."""
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
self._resolve_model_name(model_name)
|
||||||
|
|
||||||
# For now, use a simple estimation
|
# For now, use a simple estimation
|
||||||
# TODO: Use actual Gemini tokenizer when available in SDK
|
# TODO: Use actual Gemini tokenizer when available in SDK
|
||||||
# Rough estimation: ~4 characters per token for English text
|
# Rough estimation: ~4 characters per token for English text
|
||||||
return len(text) // 4
|
return len(text) // 4
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
return ProviderType.GOOGLE
|
return ProviderType.GOOGLE
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
"""Validate if the model name is supported."""
|
"""Validate if the model name is supported."""
|
||||||
resolved_name = self._resolve_model_name(model_name)
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
|
return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
|
||||||
|
|
||||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
"""Check if the model supports extended thinking mode."""
|
"""Check if the model supports extended thinking mode."""
|
||||||
capabilities = self.get_capabilities(model_name)
|
capabilities = self.get_capabilities(model_name)
|
||||||
return capabilities.supports_extended_thinking
|
return capabilities.supports_extended_thinking
|
||||||
|
|
||||||
def _resolve_model_name(self, model_name: str) -> str:
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
"""Resolve model shorthand to full name."""
|
"""Resolve model shorthand to full name."""
|
||||||
# Check if it's a shorthand
|
# Check if it's a shorthand
|
||||||
@@ -175,11 +171,11 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
if isinstance(shorthand_value, str):
|
if isinstance(shorthand_value, str):
|
||||||
return shorthand_value
|
return shorthand_value
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
def _extract_usage(self, response) -> Dict[str, int]:
|
def _extract_usage(self, response) -> dict[str, int]:
|
||||||
"""Extract token usage from Gemini response."""
|
"""Extract token usage from Gemini response."""
|
||||||
usage = {}
|
usage = {}
|
||||||
|
|
||||||
# Try to extract usage metadata from response
|
# Try to extract usage metadata from response
|
||||||
# Note: The actual structure depends on the SDK version and response format
|
# Note: The actual structure depends on the SDK version and response format
|
||||||
if hasattr(response, "usage_metadata"):
|
if hasattr(response, "usage_metadata"):
|
||||||
@@ -190,5 +186,5 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
usage["output_tokens"] = metadata.candidates_token_count
|
usage["output_tokens"] = metadata.candidates_token_count
|
||||||
if "input_tokens" in usage and "output_tokens" in usage:
|
if "input_tokens" in usage and "output_tokens" in usage:
|
||||||
usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
|
usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
|
||||||
|
|
||||||
return usage
|
return usage
|
||||||
|
|||||||
@@ -1,24 +1,23 @@
|
|||||||
"""OpenAI model provider implementation."""
|
"""OpenAI model provider implementation."""
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Dict, Optional, List, Any
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelProvider,
|
|
||||||
ModelResponse,
|
|
||||||
ModelCapabilities,
|
|
||||||
ProviderType,
|
|
||||||
FixedTemperatureConstraint,
|
FixedTemperatureConstraint,
|
||||||
RangeTemperatureConstraint
|
ModelCapabilities,
|
||||||
|
ModelProvider,
|
||||||
|
ModelResponse,
|
||||||
|
ProviderType,
|
||||||
|
RangeTemperatureConstraint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModelProvider(ModelProvider):
|
class OpenAIModelProvider(ModelProvider):
|
||||||
"""OpenAI model provider implementation."""
|
"""OpenAI model provider implementation."""
|
||||||
|
|
||||||
# Model configurations
|
# Model configurations
|
||||||
SUPPORTED_MODELS = {
|
SUPPORTED_MODELS = {
|
||||||
"o3": {
|
"o3": {
|
||||||
@@ -30,14 +29,14 @@ class OpenAIModelProvider(ModelProvider):
|
|||||||
"supports_extended_thinking": False,
|
"supports_extended_thinking": False,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize OpenAI provider with API key."""
|
"""Initialize OpenAI provider with API key."""
|
||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
self._client = None
|
self._client = None
|
||||||
self.base_url = kwargs.get("base_url") # Support custom endpoints
|
self.base_url = kwargs.get("base_url") # Support custom endpoints
|
||||||
self.organization = kwargs.get("organization")
|
self.organization = kwargs.get("organization")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""Lazy initialization of OpenAI client."""
|
"""Lazy initialization of OpenAI client."""
|
||||||
@@ -47,17 +46,17 @@ class OpenAIModelProvider(ModelProvider):
|
|||||||
client_kwargs["base_url"] = self.base_url
|
client_kwargs["base_url"] = self.base_url
|
||||||
if self.organization:
|
if self.organization:
|
||||||
client_kwargs["organization"] = self.organization
|
client_kwargs["organization"] = self.organization
|
||||||
|
|
||||||
self._client = OpenAI(**client_kwargs)
|
self._client = OpenAI(**client_kwargs)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
"""Get capabilities for a specific OpenAI model."""
|
"""Get capabilities for a specific OpenAI model."""
|
||||||
if model_name not in self.SUPPORTED_MODELS:
|
if model_name not in self.SUPPORTED_MODELS:
|
||||||
raise ValueError(f"Unsupported OpenAI model: {model_name}")
|
raise ValueError(f"Unsupported OpenAI model: {model_name}")
|
||||||
|
|
||||||
config = self.SUPPORTED_MODELS[model_name]
|
config = self.SUPPORTED_MODELS[model_name]
|
||||||
|
|
||||||
# Define temperature constraints per model
|
# Define temperature constraints per model
|
||||||
if model_name in ["o3", "o3-mini"]:
|
if model_name in ["o3", "o3-mini"]:
|
||||||
# O3 models only support temperature=1.0
|
# O3 models only support temperature=1.0
|
||||||
@@ -65,7 +64,7 @@ class OpenAIModelProvider(ModelProvider):
|
|||||||
else:
|
else:
|
||||||
# Other OpenAI models support 0.0-2.0 range
|
# Other OpenAI models support 0.0-2.0 range
|
||||||
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||||
|
|
||||||
return ModelCapabilities(
|
return ModelCapabilities(
|
||||||
provider=ProviderType.OPENAI,
|
provider=ProviderType.OPENAI,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@@ -77,7 +76,7 @@ class OpenAIModelProvider(ModelProvider):
|
|||||||
supports_function_calling=True,
|
supports_function_calling=True,
|
||||||
temperature_constraint=temp_constraint,
|
temperature_constraint=temp_constraint,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_content(
|
def generate_content(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -85,42 +84,42 @@ class OpenAIModelProvider(ModelProvider):
|
|||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_output_tokens: Optional[int] = None,
|
max_output_tokens: Optional[int] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Generate content using OpenAI model."""
|
"""Generate content using OpenAI model."""
|
||||||
# Validate parameters
|
# Validate parameters
|
||||||
self.validate_parameters(model_name, temperature)
|
self.validate_parameters(model_name, temperature)
|
||||||
|
|
||||||
# Prepare messages
|
# Prepare messages
|
||||||
messages = []
|
messages = []
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
messages.append({"role": "user", "content": prompt})
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
# Prepare completion parameters
|
# Prepare completion parameters
|
||||||
completion_params = {
|
completion_params = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add max tokens if specified
|
# Add max tokens if specified
|
||||||
if max_output_tokens:
|
if max_output_tokens:
|
||||||
completion_params["max_tokens"] = max_output_tokens
|
completion_params["max_tokens"] = max_output_tokens
|
||||||
|
|
||||||
# Add any additional OpenAI-specific parameters
|
# Add any additional OpenAI-specific parameters
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop"]:
|
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop"]:
|
||||||
completion_params[key] = value
|
completion_params[key] = value
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate completion
|
# Generate completion
|
||||||
response = self.client.chat.completions.create(**completion_params)
|
response = self.client.chat.completions.create(**completion_params)
|
||||||
|
|
||||||
# Extract content and usage
|
# Extract content and usage
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
usage = self._extract_usage(response)
|
usage = self._extract_usage(response)
|
||||||
|
|
||||||
return ModelResponse(
|
return ModelResponse(
|
||||||
content=content,
|
content=content,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@@ -132,18 +131,18 @@ class OpenAIModelProvider(ModelProvider):
|
|||||||
"model": response.model, # Actual model used (in case of fallbacks)
|
"model": response.model, # Actual model used (in case of fallbacks)
|
||||||
"id": response.id,
|
"id": response.id,
|
||||||
"created": response.created,
|
"created": response.created,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log error and re-raise with more context
|
# Log error and re-raise with more context
|
||||||
error_msg = f"OpenAI API error for model {model_name}: {str(e)}"
|
error_msg = f"OpenAI API error for model {model_name}: {str(e)}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise RuntimeError(error_msg) from e
|
raise RuntimeError(error_msg) from e
|
||||||
|
|
||||||
def count_tokens(self, text: str, model_name: str) -> int:
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
"""Count tokens for the given text.
|
"""Count tokens for the given text.
|
||||||
|
|
||||||
Note: For accurate token counting, we should use tiktoken library.
|
Note: For accurate token counting, we should use tiktoken library.
|
||||||
This is a simplified estimation.
|
This is a simplified estimation.
|
||||||
"""
|
"""
|
||||||
@@ -151,28 +150,28 @@ class OpenAIModelProvider(ModelProvider):
|
|||||||
# For now, use rough estimation
|
# For now, use rough estimation
|
||||||
# O3 models ~4 chars per token
|
# O3 models ~4 chars per token
|
||||||
return len(text) // 4
|
return len(text) // 4
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
return ProviderType.OPENAI
|
return ProviderType.OPENAI
|
||||||
|
|
||||||
def validate_model_name(self, model_name: str) -> bool:
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
"""Validate if the model name is supported."""
|
"""Validate if the model name is supported."""
|
||||||
return model_name in self.SUPPORTED_MODELS
|
return model_name in self.SUPPORTED_MODELS
|
||||||
|
|
||||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
"""Check if the model supports extended thinking mode."""
|
"""Check if the model supports extended thinking mode."""
|
||||||
# Currently no OpenAI models support extended thinking
|
# Currently no OpenAI models support extended thinking
|
||||||
# This may change with future O3 models
|
# This may change with future O3 models
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _extract_usage(self, response) -> Dict[str, int]:
|
def _extract_usage(self, response) -> dict[str, int]:
|
||||||
"""Extract token usage from OpenAI response."""
|
"""Extract token usage from OpenAI response."""
|
||||||
usage = {}
|
usage = {}
|
||||||
|
|
||||||
if hasattr(response, "usage") and response.usage:
|
if hasattr(response, "usage") and response.usage:
|
||||||
usage["input_tokens"] = response.usage.prompt_tokens
|
usage["input_tokens"] = response.usage.prompt_tokens
|
||||||
usage["output_tokens"] = response.usage.completion_tokens
|
usage["output_tokens"] = response.usage.completion_tokens
|
||||||
usage["total_tokens"] = response.usage.total_tokens
|
usage["total_tokens"] = response.usage.total_tokens
|
||||||
|
|
||||||
return usage
|
return usage
|
||||||
|
|||||||
@@ -1,115 +1,116 @@
|
|||||||
"""Model provider registry for managing available providers."""
|
"""Model provider registry for managing available providers."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Optional, Type, List
|
from typing import Optional
|
||||||
|
|
||||||
from .base import ModelProvider, ProviderType
|
from .base import ModelProvider, ProviderType
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderRegistry:
|
class ModelProviderRegistry:
|
||||||
"""Registry for managing model providers."""
|
"""Registry for managing model providers."""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
_providers: Dict[ProviderType, Type[ModelProvider]] = {}
|
_providers: dict[ProviderType, type[ModelProvider]] = {}
|
||||||
_initialized_providers: Dict[ProviderType, ModelProvider] = {}
|
_initialized_providers: dict[ProviderType, ModelProvider] = {}
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls):
|
||||||
"""Singleton pattern for registry."""
|
"""Singleton pattern for registry."""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_provider(cls, provider_type: ProviderType, provider_class: Type[ModelProvider]) -> None:
|
def register_provider(cls, provider_type: ProviderType, provider_class: type[ModelProvider]) -> None:
|
||||||
"""Register a new provider class.
|
"""Register a new provider class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
|
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
|
||||||
provider_class: Class that implements ModelProvider interface
|
provider_class: Class that implements ModelProvider interface
|
||||||
"""
|
"""
|
||||||
cls._providers[provider_type] = provider_class
|
cls._providers[provider_type] = provider_class
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
|
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
|
||||||
"""Get an initialized provider instance.
|
"""Get an initialized provider instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider_type: Type of provider to get
|
provider_type: Type of provider to get
|
||||||
force_new: Force creation of new instance instead of using cached
|
force_new: Force creation of new instance instead of using cached
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Initialized ModelProvider instance or None if not available
|
Initialized ModelProvider instance or None if not available
|
||||||
"""
|
"""
|
||||||
# Return cached instance if available and not forcing new
|
# Return cached instance if available and not forcing new
|
||||||
if not force_new and provider_type in cls._initialized_providers:
|
if not force_new and provider_type in cls._initialized_providers:
|
||||||
return cls._initialized_providers[provider_type]
|
return cls._initialized_providers[provider_type]
|
||||||
|
|
||||||
# Check if provider class is registered
|
# Check if provider class is registered
|
||||||
if provider_type not in cls._providers:
|
if provider_type not in cls._providers:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get API key from environment
|
# Get API key from environment
|
||||||
api_key = cls._get_api_key_for_provider(provider_type)
|
api_key = cls._get_api_key_for_provider(provider_type)
|
||||||
if not api_key:
|
if not api_key:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Initialize provider
|
# Initialize provider
|
||||||
provider_class = cls._providers[provider_type]
|
provider_class = cls._providers[provider_type]
|
||||||
provider = provider_class(api_key=api_key)
|
provider = provider_class(api_key=api_key)
|
||||||
|
|
||||||
# Cache the instance
|
# Cache the instance
|
||||||
cls._initialized_providers[provider_type] = provider
|
cls._initialized_providers[provider_type] = provider
|
||||||
|
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
|
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
|
||||||
"""Get provider instance for a specific model name.
|
"""Get provider instance for a specific model name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Name of the model (e.g., "gemini-2.0-flash-exp", "o3-mini")
|
model_name: Name of the model (e.g., "gemini-2.0-flash-exp", "o3-mini")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelProvider instance that supports this model
|
ModelProvider instance that supports this model
|
||||||
"""
|
"""
|
||||||
# Check each registered provider
|
# Check each registered provider
|
||||||
for provider_type, provider_class in cls._providers.items():
|
for provider_type, _provider_class in cls._providers.items():
|
||||||
# Get or create provider instance
|
# Get or create provider instance
|
||||||
provider = cls.get_provider(provider_type)
|
provider = cls.get_provider(provider_type)
|
||||||
if provider and provider.validate_model_name(model_name):
|
if provider and provider.validate_model_name(model_name):
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_available_providers(cls) -> List[ProviderType]:
|
def get_available_providers(cls) -> list[ProviderType]:
|
||||||
"""Get list of registered provider types."""
|
"""Get list of registered provider types."""
|
||||||
return list(cls._providers.keys())
|
return list(cls._providers.keys())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_available_models(cls) -> Dict[str, ProviderType]:
|
def get_available_models(cls) -> dict[str, ProviderType]:
|
||||||
"""Get mapping of all available models to their providers.
|
"""Get mapping of all available models to their providers.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping model names to provider types
|
Dict mapping model names to provider types
|
||||||
"""
|
"""
|
||||||
models = {}
|
models = {}
|
||||||
|
|
||||||
for provider_type in cls._providers:
|
for provider_type in cls._providers:
|
||||||
provider = cls.get_provider(provider_type)
|
provider = cls.get_provider(provider_type)
|
||||||
if provider:
|
if provider:
|
||||||
# This assumes providers have a method to list supported models
|
# This assumes providers have a method to list supported models
|
||||||
# We'll need to add this to the interface
|
# We'll need to add this to the interface
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
|
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
|
||||||
"""Get API key for a provider from environment variables.
|
"""Get API key for a provider from environment variables.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider_type: Provider type to get API key for
|
provider_type: Provider type to get API key for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
API key string or None if not found
|
API key string or None if not found
|
||||||
"""
|
"""
|
||||||
@@ -117,20 +118,20 @@ class ModelProviderRegistry:
|
|||||||
ProviderType.GOOGLE: "GEMINI_API_KEY",
|
ProviderType.GOOGLE: "GEMINI_API_KEY",
|
||||||
ProviderType.OPENAI: "OPENAI_API_KEY",
|
ProviderType.OPENAI: "OPENAI_API_KEY",
|
||||||
}
|
}
|
||||||
|
|
||||||
env_var = key_mapping.get(provider_type)
|
env_var = key_mapping.get(provider_type)
|
||||||
if not env_var:
|
if not env_var:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return os.getenv(env_var)
|
return os.getenv(env_var)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clear_cache(cls) -> None:
|
def clear_cache(cls) -> None:
|
||||||
"""Clear cached provider instances."""
|
"""Clear cached provider instances."""
|
||||||
cls._initialized_providers.clear()
|
cls._initialized_providers.clear()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def unregister_provider(cls, provider_type: ProviderType) -> None:
|
def unregister_provider(cls, provider_type: ProviderType) -> None:
|
||||||
"""Unregister a provider (mainly for testing)."""
|
"""Unregister a provider (mainly for testing)."""
|
||||||
cls._providers.pop(provider_type, None)
|
cls._providers.pop(provider_type, None)
|
||||||
cls._initialized_providers.pop(provider_type, None)
|
cls._initialized_providers.pop(provider_type, None)
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ ignore = [
|
|||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"__init__.py" = ["F401"]
|
"__init__.py" = ["F401"]
|
||||||
"tests/*" = ["B011"]
|
"tests/*" = ["B011"]
|
||||||
|
"tests/conftest.py" = ["E402"] # Module level imports not at top of file - needed for test setup
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"]
|
requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"]
|
||||||
|
|||||||
44
server.py
44
server.py
@@ -1,8 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Gemini MCP Server - Main server implementation
|
Zen MCP Server - Main server implementation
|
||||||
|
|
||||||
This module implements the core MCP (Model Context Protocol) server that provides
|
This module implements the core MCP (Model Context Protocol) server that provides
|
||||||
AI-powered tools for code analysis, review, and assistance using Google's Gemini models.
|
AI-powered tools for code analysis, review, and assistance using multiple AI models.
|
||||||
|
|
||||||
The server follows the MCP specification to expose various AI tools as callable functions
|
The server follows the MCP specification to expose various AI tools as callable functions
|
||||||
that can be used by MCP clients (like Claude). Each tool provides specialized functionality
|
that can be used by MCP clients (like Claude). Each tool provides specialized functionality
|
||||||
@@ -102,7 +102,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Create the MCP server instance with a unique name identifier
|
# Create the MCP server instance with a unique name identifier
|
||||||
# This name is used by MCP clients to identify and connect to this specific server
|
# This name is used by MCP clients to identify and connect to this specific server
|
||||||
server: Server = Server("gemini-server")
|
server: Server = Server("zen-server")
|
||||||
|
|
||||||
# Initialize the tool registry with all available AI-powered tools
|
# Initialize the tool registry with all available AI-powered tools
|
||||||
# Each tool provides specialized functionality for different development tasks
|
# Each tool provides specialized functionality for different development tasks
|
||||||
@@ -131,23 +131,23 @@ def configure_providers():
|
|||||||
from providers.base import ProviderType
|
from providers.base import ProviderType
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
valid_providers = []
|
valid_providers = []
|
||||||
|
|
||||||
# Check for Gemini API key
|
# Check for Gemini API key
|
||||||
gemini_key = os.getenv("GEMINI_API_KEY")
|
gemini_key = os.getenv("GEMINI_API_KEY")
|
||||||
if gemini_key and gemini_key != "your_gemini_api_key_here":
|
if gemini_key and gemini_key != "your_gemini_api_key_here":
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
valid_providers.append("Gemini")
|
valid_providers.append("Gemini")
|
||||||
logger.info("Gemini API key found - Gemini models available")
|
logger.info("Gemini API key found - Gemini models available")
|
||||||
|
|
||||||
# Check for OpenAI API key
|
# Check for OpenAI API key
|
||||||
openai_key = os.getenv("OPENAI_API_KEY")
|
openai_key = os.getenv("OPENAI_API_KEY")
|
||||||
if openai_key and openai_key != "your_openai_api_key_here":
|
if openai_key and openai_key != "your_openai_api_key_here":
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
valid_providers.append("OpenAI (o3)")
|
valid_providers.append("OpenAI (o3)")
|
||||||
logger.info("OpenAI API key found - o3 model available")
|
logger.info("OpenAI API key found - o3 model available")
|
||||||
|
|
||||||
# Require at least one valid provider
|
# Require at least one valid provider
|
||||||
if not valid_providers:
|
if not valid_providers:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -155,7 +155,7 @@ def configure_providers():
|
|||||||
"- GEMINI_API_KEY for Gemini models\n"
|
"- GEMINI_API_KEY for Gemini models\n"
|
||||||
"- OPENAI_API_KEY for OpenAI o3 model"
|
"- OPENAI_API_KEY for OpenAI o3 model"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Available providers: {', '.join(valid_providers)}")
|
logger.info(f"Available providers: {', '.join(valid_providers)}")
|
||||||
|
|
||||||
|
|
||||||
@@ -388,8 +388,9 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
|||||||
|
|
||||||
# Create model context early to use for history building
|
# Create model context early to use for history building
|
||||||
from utils.model_context import ModelContext
|
from utils.model_context import ModelContext
|
||||||
|
|
||||||
model_context = ModelContext.from_arguments(arguments)
|
model_context = ModelContext.from_arguments(arguments)
|
||||||
|
|
||||||
# Build conversation history with model-specific limits
|
# Build conversation history with model-specific limits
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Building conversation history for thread {continuation_id}")
|
logger.debug(f"[CONVERSATION_DEBUG] Building conversation history for thread {continuation_id}")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Thread has {len(context.turns)} turns, tool: {context.tool_name}")
|
logger.debug(f"[CONVERSATION_DEBUG] Thread has {len(context.turns)} turns, tool: {context.tool_name}")
|
||||||
@@ -404,9 +405,9 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
|||||||
|
|
||||||
# All tools now use standardized 'prompt' field
|
# All tools now use standardized 'prompt' field
|
||||||
original_prompt = arguments.get("prompt", "")
|
original_prompt = arguments.get("prompt", "")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Extracting user input from 'prompt' field")
|
logger.debug("[CONVERSATION_DEBUG] Extracting user input from 'prompt' field")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] User input length: {len(original_prompt)} chars")
|
logger.debug(f"[CONVERSATION_DEBUG] User input length: {len(original_prompt)} chars")
|
||||||
|
|
||||||
# Merge original context with new prompt and follow-up instructions
|
# Merge original context with new prompt and follow-up instructions
|
||||||
if conversation_history:
|
if conversation_history:
|
||||||
enhanced_prompt = (
|
enhanced_prompt = (
|
||||||
@@ -417,25 +418,25 @@ async def reconstruct_thread_context(arguments: dict[str, Any]) -> dict[str, Any
|
|||||||
|
|
||||||
# Update arguments with enhanced context and remaining token budget
|
# Update arguments with enhanced context and remaining token budget
|
||||||
enhanced_arguments = arguments.copy()
|
enhanced_arguments = arguments.copy()
|
||||||
|
|
||||||
# Store the enhanced prompt in the prompt field
|
# Store the enhanced prompt in the prompt field
|
||||||
enhanced_arguments["prompt"] = enhanced_prompt
|
enhanced_arguments["prompt"] = enhanced_prompt
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Storing enhanced prompt in 'prompt' field")
|
logger.debug("[CONVERSATION_DEBUG] Storing enhanced prompt in 'prompt' field")
|
||||||
|
|
||||||
# Calculate remaining token budget based on current model
|
# Calculate remaining token budget based on current model
|
||||||
# (model_context was already created above for history building)
|
# (model_context was already created above for history building)
|
||||||
token_allocation = model_context.calculate_token_allocation()
|
token_allocation = model_context.calculate_token_allocation()
|
||||||
|
|
||||||
# Calculate remaining tokens for files/new content
|
# Calculate remaining tokens for files/new content
|
||||||
# History has already consumed some of the content budget
|
# History has already consumed some of the content budget
|
||||||
remaining_tokens = token_allocation.content_tokens - conversation_tokens
|
remaining_tokens = token_allocation.content_tokens - conversation_tokens
|
||||||
enhanced_arguments["_remaining_tokens"] = max(0, remaining_tokens) # Ensure non-negative
|
enhanced_arguments["_remaining_tokens"] = max(0, remaining_tokens) # Ensure non-negative
|
||||||
enhanced_arguments["_model_context"] = model_context # Pass context for use in tools
|
enhanced_arguments["_model_context"] = model_context # Pass context for use in tools
|
||||||
|
|
||||||
logger.debug("[CONVERSATION_DEBUG] Token budget calculation:")
|
logger.debug("[CONVERSATION_DEBUG] Token budget calculation:")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Model: {model_context.model_name}")
|
logger.debug(f"[CONVERSATION_DEBUG] Model: {model_context.model_name}")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Total capacity: {token_allocation.total_tokens:,}")
|
logger.debug(f"[CONVERSATION_DEBUG] Total capacity: {token_allocation.total_tokens:,}")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Content allocation: {token_allocation.content_tokens:,}")
|
logger.debug(f"[CONVERSATION_DEBUG] Content allocation: {token_allocation.content_tokens:,}")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Conversation tokens: {conversation_tokens:,}")
|
logger.debug(f"[CONVERSATION_DEBUG] Conversation tokens: {conversation_tokens:,}")
|
||||||
logger.debug(f"[CONVERSATION_DEBUG] Remaining tokens: {remaining_tokens:,}")
|
logger.debug(f"[CONVERSATION_DEBUG] Remaining tokens: {remaining_tokens:,}")
|
||||||
|
|
||||||
@@ -494,7 +495,7 @@ async def handle_get_version() -> list[TextContent]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Format the information in a human-readable way
|
# Format the information in a human-readable way
|
||||||
text = f"""Gemini MCP Server v{__version__}
|
text = f"""Zen MCP Server v{__version__}
|
||||||
Updated: {__updated__}
|
Updated: {__updated__}
|
||||||
Author: {__author__}
|
Author: {__author__}
|
||||||
|
|
||||||
@@ -508,7 +509,7 @@ Configuration:
|
|||||||
Available Tools:
|
Available Tools:
|
||||||
{chr(10).join(f" - {tool}" for tool in version_info["available_tools"])}
|
{chr(10).join(f" - {tool}" for tool in version_info["available_tools"])}
|
||||||
|
|
||||||
For updates, visit: https://github.com/BeehiveInnovations/gemini-mcp-server"""
|
For updates, visit: https://github.com/BeehiveInnovations/zen-mcp-server"""
|
||||||
|
|
||||||
# Create standardized tool output
|
# Create standardized tool output
|
||||||
tool_output = ToolOutput(status="success", content=text, content_type="text", metadata={"tool_name": "get_version"})
|
tool_output = ToolOutput(status="success", content=text, content_type="text", metadata={"tool_name": "get_version"})
|
||||||
@@ -531,11 +532,12 @@ async def main():
|
|||||||
configure_providers()
|
configure_providers()
|
||||||
|
|
||||||
# Log startup message for Docker log monitoring
|
# Log startup message for Docker log monitoring
|
||||||
logger.info("Gemini MCP Server starting up...")
|
logger.info("Zen MCP Server starting up...")
|
||||||
logger.info(f"Log level: {log_level}")
|
logger.info(f"Log level: {log_level}")
|
||||||
|
|
||||||
# Log current model mode
|
# Log current model mode
|
||||||
from config import IS_AUTO_MODE
|
from config import IS_AUTO_MODE
|
||||||
|
|
||||||
if IS_AUTO_MODE:
|
if IS_AUTO_MODE:
|
||||||
logger.info("Model mode: AUTO (Claude will select the best model for each task)")
|
logger.info("Model mode: AUTO (Claude will select the best model for each task)")
|
||||||
else:
|
else:
|
||||||
@@ -556,7 +558,7 @@ async def main():
|
|||||||
read_stream,
|
read_stream,
|
||||||
write_stream,
|
write_stream,
|
||||||
InitializationOptions(
|
InitializationOptions(
|
||||||
server_name="gemini",
|
server_name="zen",
|
||||||
server_version=__version__,
|
server_version=__version__,
|
||||||
capabilities=ServerCapabilities(tools=ToolsCapability()), # Advertise tool support capability
|
capabilities=ServerCapabilities(tools=ToolsCapability()), # Advertise tool support capability
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -3,10 +3,10 @@
|
|||||||
# Exit on any error, undefined variables, and pipe failures
|
# Exit on any error, undefined variables, and pipe failures
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
# Modern Docker setup script for Gemini MCP Server with Redis
|
# Modern Docker setup script for Zen MCP Server with Redis
|
||||||
# This script sets up the complete Docker environment including Redis for conversation threading
|
# This script sets up the complete Docker environment including Redis for conversation threading
|
||||||
|
|
||||||
echo "🚀 Setting up Gemini MCP Server with Docker Compose..."
|
echo "🚀 Setting up Zen MCP Server with Docker Compose..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Get the current working directory (absolute path)
|
# Get the current working directory (absolute path)
|
||||||
@@ -131,7 +131,7 @@ $COMPOSE_CMD down --remove-orphans >/dev/null 2>&1 || true
|
|||||||
# Clean up any old containers with different naming patterns
|
# Clean up any old containers with different naming patterns
|
||||||
OLD_CONTAINERS_FOUND=false
|
OLD_CONTAINERS_FOUND=false
|
||||||
|
|
||||||
# Check for old Gemini MCP container
|
# Check for old Gemini MCP containers (for migration)
|
||||||
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-gemini-mcp-1$" 2>/dev/null || false; then
|
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-gemini-mcp-1$" 2>/dev/null || false; then
|
||||||
OLD_CONTAINERS_FOUND=true
|
OLD_CONTAINERS_FOUND=true
|
||||||
echo " - Cleaning up old container: gemini-mcp-server-gemini-mcp-1"
|
echo " - Cleaning up old container: gemini-mcp-server-gemini-mcp-1"
|
||||||
@@ -139,6 +139,21 @@ if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-gemini-mcp-1
|
|||||||
docker rm gemini-mcp-server-gemini-mcp-1 >/dev/null 2>&1 || true
|
docker rm gemini-mcp-server-gemini-mcp-1 >/dev/null 2>&1 || true
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server$" 2>/dev/null || false; then
|
||||||
|
OLD_CONTAINERS_FOUND=true
|
||||||
|
echo " - Cleaning up old container: gemini-mcp-server"
|
||||||
|
docker stop gemini-mcp-server >/dev/null 2>&1 || true
|
||||||
|
docker rm gemini-mcp-server >/dev/null 2>&1 || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check for old containers left over from recent versions
|
||||||
|
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-log-monitor$" 2>/dev/null || false; then
|
||||||
|
OLD_CONTAINERS_FOUND=true
|
||||||
|
echo " - Cleaning up old container: gemini-mcp-log-monitor"
|
||||||
|
docker stop gemini-mcp-log-monitor >/dev/null 2>&1 || true
|
||||||
|
docker rm gemini-mcp-log-monitor >/dev/null 2>&1 || true
|
||||||
|
fi
|
||||||
|
|
||||||
# Check for old Redis container
|
# Check for old Redis container
|
||||||
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-redis-1$" 2>/dev/null || false; then
|
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-redis-1$" 2>/dev/null || false; then
|
||||||
OLD_CONTAINERS_FOUND=true
|
OLD_CONTAINERS_FOUND=true
|
||||||
@@ -147,17 +162,37 @@ if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-server-redis-1$" 2>
|
|||||||
docker rm gemini-mcp-server-redis-1 >/dev/null 2>&1 || true
|
docker rm gemini-mcp-server-redis-1 >/dev/null 2>&1 || true
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Check for old image
|
if docker ps -a --format "{{.Names}}" | grep -q "^gemini-mcp-redis$" 2>/dev/null || false; then
|
||||||
|
OLD_CONTAINERS_FOUND=true
|
||||||
|
echo " - Cleaning up old container: gemini-mcp-redis"
|
||||||
|
docker stop gemini-mcp-redis >/dev/null 2>&1 || true
|
||||||
|
docker rm gemini-mcp-redis >/dev/null 2>&1 || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check for old images
|
||||||
if docker images --format "{{.Repository}}:{{.Tag}}" | grep -q "^gemini-mcp-server-gemini-mcp:latest$" 2>/dev/null || false; then
|
if docker images --format "{{.Repository}}:{{.Tag}}" | grep -q "^gemini-mcp-server-gemini-mcp:latest$" 2>/dev/null || false; then
|
||||||
OLD_CONTAINERS_FOUND=true
|
OLD_CONTAINERS_FOUND=true
|
||||||
echo " - Cleaning up old image: gemini-mcp-server-gemini-mcp:latest"
|
echo " - Cleaning up old image: gemini-mcp-server-gemini-mcp:latest"
|
||||||
docker rmi gemini-mcp-server-gemini-mcp:latest >/dev/null 2>&1 || true
|
docker rmi gemini-mcp-server-gemini-mcp:latest >/dev/null 2>&1 || true
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if docker images --format "{{.Repository}}:{{.Tag}}" | grep -q "^gemini-mcp-server:latest$" 2>/dev/null || false; then
|
||||||
|
OLD_CONTAINERS_FOUND=true
|
||||||
|
echo " - Cleaning up old image: gemini-mcp-server:latest"
|
||||||
|
docker rmi gemini-mcp-server:latest >/dev/null 2>&1 || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check for the old network left over from recent versions (if it exists)
|
||||||
|
if docker network ls --format "{{.Name}}" | grep -q "^gemini-mcp-server_default$" 2>/dev/null || false; then
|
||||||
|
OLD_CONTAINERS_FOUND=true
|
||||||
|
echo " - Cleaning up old network: gemini-mcp-server_default"
|
||||||
|
docker network rm gemini-mcp-server_default >/dev/null 2>&1 || true
|
||||||
|
fi
|
||||||
|
|
||||||
# Only show cleanup messages if something was actually cleaned up
|
# Only show cleanup messages if something was actually cleaned up
|
||||||
|
|
||||||
# Build and start services
|
# Build and start services
|
||||||
echo " - Building Gemini MCP Server image..."
|
echo " - Building Zen MCP Server image..."
|
||||||
if $COMPOSE_CMD build --no-cache >/dev/null 2>&1; then
|
if $COMPOSE_CMD build --no-cache >/dev/null 2>&1; then
|
||||||
echo "✅ Docker image built successfully!"
|
echo "✅ Docker image built successfully!"
|
||||||
else
|
else
|
||||||
@@ -209,12 +244,12 @@ echo ""
|
|||||||
echo "===== CLAUDE DESKTOP CONFIGURATION ====="
|
echo "===== CLAUDE DESKTOP CONFIGURATION ====="
|
||||||
echo "{"
|
echo "{"
|
||||||
echo " \"mcpServers\": {"
|
echo " \"mcpServers\": {"
|
||||||
echo " \"gemini\": {"
|
echo " \"zen\": {"
|
||||||
echo " \"command\": \"docker\","
|
echo " \"command\": \"docker\","
|
||||||
echo " \"args\": ["
|
echo " \"args\": ["
|
||||||
echo " \"exec\","
|
echo " \"exec\","
|
||||||
echo " \"-i\","
|
echo " \"-i\","
|
||||||
echo " \"gemini-mcp-server\","
|
echo " \"zen-mcp-server\","
|
||||||
echo " \"python\","
|
echo " \"python\","
|
||||||
echo " \"server.py\""
|
echo " \"server.py\""
|
||||||
echo " ]"
|
echo " ]"
|
||||||
@@ -225,13 +260,13 @@ echo "==========================================="
|
|||||||
echo ""
|
echo ""
|
||||||
echo "===== CLAUDE CODE CLI CONFIGURATION ====="
|
echo "===== CLAUDE CODE CLI CONFIGURATION ====="
|
||||||
echo "# Add the MCP server via Claude Code CLI:"
|
echo "# Add the MCP server via Claude Code CLI:"
|
||||||
echo "claude mcp add gemini -s user -- docker exec -i gemini-mcp-server python server.py"
|
echo "claude mcp add zen -s user -- docker exec -i zen-mcp-server python server.py"
|
||||||
echo ""
|
echo ""
|
||||||
echo "# List your MCP servers to verify:"
|
echo "# List your MCP servers to verify:"
|
||||||
echo "claude mcp list"
|
echo "claude mcp list"
|
||||||
echo ""
|
echo ""
|
||||||
echo "# Remove if needed:"
|
echo "# Remove if needed:"
|
||||||
echo "claude mcp remove gemini -s user"
|
echo "claude mcp remove zen -s user"
|
||||||
echo "==========================================="
|
echo "==========================================="
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
Communication Simulator Tests Package
|
Communication Simulator Tests Package
|
||||||
|
|
||||||
This package contains individual test modules for the Gemini MCP Communication Simulator.
|
This package contains individual test modules for the Zen MCP Communication Simulator.
|
||||||
Each test is in its own file for better organization and maintainability.
|
Each test is in its own file for better organization and maintainability.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .base_test import BaseSimulatorTest
|
from .base_test import BaseSimulatorTest
|
||||||
from .test_basic_conversation import BasicConversationTest
|
from .test_basic_conversation import BasicConversationTest
|
||||||
from .test_content_validation import ContentValidationTest
|
from .test_content_validation import ContentValidationTest
|
||||||
|
from .test_conversation_chain_validation import ConversationChainValidationTest
|
||||||
from .test_cross_tool_comprehensive import CrossToolComprehensiveTest
|
from .test_cross_tool_comprehensive import CrossToolComprehensiveTest
|
||||||
from .test_cross_tool_continuation import CrossToolContinuationTest
|
from .test_cross_tool_continuation import CrossToolContinuationTest
|
||||||
from .test_logs_validation import LogsValidationTest
|
from .test_logs_validation import LogsValidationTest
|
||||||
@@ -16,7 +17,6 @@ from .test_o3_model_selection import O3ModelSelectionTest
|
|||||||
from .test_per_tool_deduplication import PerToolDeduplicationTest
|
from .test_per_tool_deduplication import PerToolDeduplicationTest
|
||||||
from .test_redis_validation import RedisValidationTest
|
from .test_redis_validation import RedisValidationTest
|
||||||
from .test_token_allocation_validation import TokenAllocationValidationTest
|
from .test_token_allocation_validation import TokenAllocationValidationTest
|
||||||
from .test_conversation_chain_validation import ConversationChainValidationTest
|
|
||||||
|
|
||||||
# Test registry for dynamic loading
|
# Test registry for dynamic loading
|
||||||
TEST_REGISTRY = {
|
TEST_REGISTRY = {
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ class BaseSimulatorTest:
|
|||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.test_files = {}
|
self.test_files = {}
|
||||||
self.test_dir = None
|
self.test_dir = None
|
||||||
self.container_name = "gemini-mcp-server"
|
self.container_name = "zen-mcp-server"
|
||||||
self.redis_container = "gemini-mcp-redis"
|
self.redis_container = "zen-mcp-redis"
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
log_level = logging.DEBUG if verbose else logging.INFO
|
log_level = logging.DEBUG if verbose else logging.INFO
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ Tests that tools don't duplicate file content in their responses.
|
|||||||
This test is specifically designed to catch content duplication bugs.
|
This test is specifically designed to catch content duplication bugs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .base_test import BaseSimulatorTest
|
from .base_test import BaseSimulatorTest
|
||||||
@@ -31,6 +30,7 @@ class ContentValidationTest(BaseSimulatorTest):
|
|||||||
cmd_monitor = ["docker", "logs", "--since", since_time, "gemini-mcp-log-monitor"]
|
cmd_monitor = ["docker", "logs", "--since", since_time, "gemini-mcp-log-monitor"]
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
result_server = subprocess.run(cmd_server, capture_output=True, text=True)
|
result_server = subprocess.run(cmd_server, capture_output=True, text=True)
|
||||||
result_monitor = subprocess.run(cmd_monitor, capture_output=True, text=True)
|
result_monitor = subprocess.run(cmd_monitor, capture_output=True, text=True)
|
||||||
|
|
||||||
@@ -76,6 +76,7 @@ DATABASE_CONFIG = {
|
|||||||
|
|
||||||
# Get timestamp for log filtering
|
# Get timestamp for log filtering
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
start_time = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
|
start_time = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
|
||||||
|
|
||||||
# Test 1: Initial tool call with validation file
|
# Test 1: Initial tool call with validation file
|
||||||
@@ -139,26 +140,25 @@ DATABASE_CONFIG = {
|
|||||||
|
|
||||||
# Check for proper file embedding logs
|
# Check for proper file embedding logs
|
||||||
embedding_logs = [
|
embedding_logs = [
|
||||||
line for line in logs.split("\n")
|
line for line in logs.split("\n") if "📁" in line or "embedding" in line.lower() or "[FILES]" in line
|
||||||
if "📁" in line or "embedding" in line.lower() or "[FILES]" in line
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check for deduplication evidence
|
# Check for deduplication evidence
|
||||||
deduplication_logs = [
|
deduplication_logs = [
|
||||||
line for line in logs.split("\n")
|
line
|
||||||
|
for line in logs.split("\n")
|
||||||
if "skipping" in line.lower() and "already in conversation" in line.lower()
|
if "skipping" in line.lower() and "already in conversation" in line.lower()
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check for file processing patterns
|
# Check for file processing patterns
|
||||||
new_file_logs = [
|
new_file_logs = [
|
||||||
line for line in logs.split("\n")
|
line for line in logs.split("\n") if "all 1 files are new" in line or "New conversation" in line
|
||||||
if "all 1 files are new" in line or "New conversation" in line
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Validation criteria
|
# Validation criteria
|
||||||
validation_file_mentioned = any("validation_config.py" in line for line in logs.split("\n"))
|
validation_file_mentioned = any("validation_config.py" in line for line in logs.split("\n"))
|
||||||
embedding_found = len(embedding_logs) > 0
|
embedding_found = len(embedding_logs) > 0
|
||||||
proper_deduplication = len(deduplication_logs) > 0 or len(new_file_logs) >= 2 # Should see new conversation patterns
|
(len(deduplication_logs) > 0 or len(new_file_logs) >= 2) # Should see new conversation patterns
|
||||||
|
|
||||||
self.logger.info(f" 📊 Embedding logs found: {len(embedding_logs)}")
|
self.logger.info(f" 📊 Embedding logs found: {len(embedding_logs)}")
|
||||||
self.logger.info(f" 📊 Deduplication evidence: {len(deduplication_logs)}")
|
self.logger.info(f" 📊 Deduplication evidence: {len(deduplication_logs)}")
|
||||||
@@ -175,7 +175,7 @@ DATABASE_CONFIG = {
|
|||||||
success_criteria = [
|
success_criteria = [
|
||||||
("Embedding logs found", embedding_found),
|
("Embedding logs found", embedding_found),
|
||||||
("File processing evidence", validation_file_mentioned),
|
("File processing evidence", validation_file_mentioned),
|
||||||
("Multiple tool calls", len(new_file_logs) >= 2)
|
("Multiple tool calls", len(new_file_logs) >= 2),
|
||||||
]
|
]
|
||||||
|
|
||||||
passed_criteria = sum(1 for _, passed in success_criteria if passed)
|
passed_criteria = sum(1 for _, passed in success_criteria if passed)
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ Conversation Chain and Threading Validation Test
|
|||||||
|
|
||||||
This test validates that:
|
This test validates that:
|
||||||
1. Multiple tool invocations create proper parent->parent->parent chains
|
1. Multiple tool invocations create proper parent->parent->parent chains
|
||||||
2. New conversations can be started independently
|
2. New conversations can be started independently
|
||||||
3. Original conversation chains can be resumed from any point
|
3. Original conversation chains can be resumed from any point
|
||||||
4. History traversal works correctly for all scenarios
|
4. History traversal works correctly for all scenarios
|
||||||
5. Thread relationships are properly maintained in Redis
|
5. Thread relationships are properly maintained in Redis
|
||||||
|
|
||||||
Test Flow:
|
Test Flow:
|
||||||
Chain A: chat -> analyze -> debug (3 linked threads)
|
Chain A: chat -> analyze -> debug (3 linked threads)
|
||||||
Chain B: chat -> analyze (2 linked threads, independent)
|
Chain B: chat -> analyze (2 linked threads, independent)
|
||||||
Chain A Branch: debug (continue from original chat, creating branch)
|
Chain A Branch: debug (continue from original chat, creating branch)
|
||||||
|
|
||||||
This validates the conversation threading system's ability to:
|
This validates the conversation threading system's ability to:
|
||||||
@@ -21,10 +21,8 @@ This validates the conversation threading system's ability to:
|
|||||||
- Properly traverse parent relationships for history reconstruction
|
- Properly traverse parent relationships for history reconstruction
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import datetime
|
|
||||||
import subprocess
|
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Tuple, Optional
|
import subprocess
|
||||||
|
|
||||||
from .base_test import BaseSimulatorTest
|
from .base_test import BaseSimulatorTest
|
||||||
|
|
||||||
@@ -45,7 +43,7 @@ class ConversationChainValidationTest(BaseSimulatorTest):
|
|||||||
try:
|
try:
|
||||||
cmd = ["docker", "exec", self.container_name, "tail", "-n", "500", "/tmp/mcp_server.log"]
|
cmd = ["docker", "exec", self.container_name, "tail", "-n", "500", "/tmp/mcp_server.log"]
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
return result.stdout
|
return result.stdout
|
||||||
else:
|
else:
|
||||||
@@ -55,44 +53,36 @@ class ConversationChainValidationTest(BaseSimulatorTest):
|
|||||||
self.logger.error(f"Failed to get server logs: {e}")
|
self.logger.error(f"Failed to get server logs: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def extract_thread_creation_logs(self, logs: str) -> List[Dict[str, str]]:
|
def extract_thread_creation_logs(self, logs: str) -> list[dict[str, str]]:
|
||||||
"""Extract thread creation logs with parent relationships"""
|
"""Extract thread creation logs with parent relationships"""
|
||||||
thread_logs = []
|
thread_logs = []
|
||||||
|
|
||||||
lines = logs.split('\n')
|
lines = logs.split("\n")
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if "[THREAD] Created new thread" in line:
|
if "[THREAD] Created new thread" in line:
|
||||||
# Parse: [THREAD] Created new thread 9dc779eb-645f-4850-9659-34c0e6978d73 with parent a0ce754d-c995-4b3e-9103-88af429455aa
|
# Parse: [THREAD] Created new thread 9dc779eb-645f-4850-9659-34c0e6978d73 with parent a0ce754d-c995-4b3e-9103-88af429455aa
|
||||||
match = re.search(r'\[THREAD\] Created new thread ([a-f0-9-]+) with parent ([a-f0-9-]+|None)', line)
|
match = re.search(r"\[THREAD\] Created new thread ([a-f0-9-]+) with parent ([a-f0-9-]+|None)", line)
|
||||||
if match:
|
if match:
|
||||||
thread_id = match.group(1)
|
thread_id = match.group(1)
|
||||||
parent_id = match.group(2) if match.group(2) != "None" else None
|
parent_id = match.group(2) if match.group(2) != "None" else None
|
||||||
thread_logs.append({
|
thread_logs.append({"thread_id": thread_id, "parent_id": parent_id, "log_line": line})
|
||||||
"thread_id": thread_id,
|
|
||||||
"parent_id": parent_id,
|
|
||||||
"log_line": line
|
|
||||||
})
|
|
||||||
|
|
||||||
return thread_logs
|
return thread_logs
|
||||||
|
|
||||||
def extract_history_traversal_logs(self, logs: str) -> List[Dict[str, str]]:
|
def extract_history_traversal_logs(self, logs: str) -> list[dict[str, str]]:
|
||||||
"""Extract conversation history traversal logs"""
|
"""Extract conversation history traversal logs"""
|
||||||
traversal_logs = []
|
traversal_logs = []
|
||||||
|
|
||||||
lines = logs.split('\n')
|
lines = logs.split("\n")
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if "[THREAD] Retrieved chain of" in line:
|
if "[THREAD] Retrieved chain of" in line:
|
||||||
# Parse: [THREAD] Retrieved chain of 3 threads for 9dc779eb-645f-4850-9659-34c0e6978d73
|
# Parse: [THREAD] Retrieved chain of 3 threads for 9dc779eb-645f-4850-9659-34c0e6978d73
|
||||||
match = re.search(r'\[THREAD\] Retrieved chain of (\d+) threads for ([a-f0-9-]+)', line)
|
match = re.search(r"\[THREAD\] Retrieved chain of (\d+) threads for ([a-f0-9-]+)", line)
|
||||||
if match:
|
if match:
|
||||||
chain_length = int(match.group(1))
|
chain_length = int(match.group(1))
|
||||||
thread_id = match.group(2)
|
thread_id = match.group(2)
|
||||||
traversal_logs.append({
|
traversal_logs.append({"thread_id": thread_id, "chain_length": chain_length, "log_line": line})
|
||||||
"thread_id": thread_id,
|
|
||||||
"chain_length": chain_length,
|
|
||||||
"log_line": line
|
|
||||||
})
|
|
||||||
|
|
||||||
return traversal_logs
|
return traversal_logs
|
||||||
|
|
||||||
def run_test(self) -> bool:
|
def run_test(self) -> bool:
|
||||||
@@ -113,16 +103,16 @@ class TestClass:
|
|||||||
return "Method in test class"
|
return "Method in test class"
|
||||||
"""
|
"""
|
||||||
test_file_path = self.create_additional_test_file("chain_test.py", test_file_content)
|
test_file_path = self.create_additional_test_file("chain_test.py", test_file_content)
|
||||||
|
|
||||||
# Track all continuation IDs and their relationships
|
# Track all continuation IDs and their relationships
|
||||||
conversation_chains = {}
|
conversation_chains = {}
|
||||||
|
|
||||||
# === CHAIN A: Build linear conversation chain ===
|
# === CHAIN A: Build linear conversation chain ===
|
||||||
self.logger.info(" 🔗 Chain A: Building linear conversation chain")
|
self.logger.info(" 🔗 Chain A: Building linear conversation chain")
|
||||||
|
|
||||||
# Step A1: Start with chat tool (creates thread_id_1)
|
# Step A1: Start with chat tool (creates thread_id_1)
|
||||||
self.logger.info(" Step A1: Chat tool - start new conversation")
|
self.logger.info(" Step A1: Chat tool - start new conversation")
|
||||||
|
|
||||||
response_a1, continuation_id_a1 = self.call_mcp_tool(
|
response_a1, continuation_id_a1 = self.call_mcp_tool(
|
||||||
"chat",
|
"chat",
|
||||||
{
|
{
|
||||||
@@ -138,11 +128,11 @@ class TestClass:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
self.logger.info(f" ✅ Step A1 completed - thread_id: {continuation_id_a1[:8]}...")
|
self.logger.info(f" ✅ Step A1 completed - thread_id: {continuation_id_a1[:8]}...")
|
||||||
conversation_chains['A1'] = continuation_id_a1
|
conversation_chains["A1"] = continuation_id_a1
|
||||||
|
|
||||||
# Step A2: Continue with analyze tool (creates thread_id_2 with parent=thread_id_1)
|
# Step A2: Continue with analyze tool (creates thread_id_2 with parent=thread_id_1)
|
||||||
self.logger.info(" Step A2: Analyze tool - continue Chain A")
|
self.logger.info(" Step A2: Analyze tool - continue Chain A")
|
||||||
|
|
||||||
response_a2, continuation_id_a2 = self.call_mcp_tool(
|
response_a2, continuation_id_a2 = self.call_mcp_tool(
|
||||||
"analyze",
|
"analyze",
|
||||||
{
|
{
|
||||||
@@ -159,11 +149,11 @@ class TestClass:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
self.logger.info(f" ✅ Step A2 completed - thread_id: {continuation_id_a2[:8]}...")
|
self.logger.info(f" ✅ Step A2 completed - thread_id: {continuation_id_a2[:8]}...")
|
||||||
conversation_chains['A2'] = continuation_id_a2
|
conversation_chains["A2"] = continuation_id_a2
|
||||||
|
|
||||||
# Step A3: Continue with debug tool (creates thread_id_3 with parent=thread_id_2)
|
# Step A3: Continue with debug tool (creates thread_id_3 with parent=thread_id_2)
|
||||||
self.logger.info(" Step A3: Debug tool - continue Chain A")
|
self.logger.info(" Step A3: Debug tool - continue Chain A")
|
||||||
|
|
||||||
response_a3, continuation_id_a3 = self.call_mcp_tool(
|
response_a3, continuation_id_a3 = self.call_mcp_tool(
|
||||||
"debug",
|
"debug",
|
||||||
{
|
{
|
||||||
@@ -180,14 +170,14 @@ class TestClass:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
self.logger.info(f" ✅ Step A3 completed - thread_id: {continuation_id_a3[:8]}...")
|
self.logger.info(f" ✅ Step A3 completed - thread_id: {continuation_id_a3[:8]}...")
|
||||||
conversation_chains['A3'] = continuation_id_a3
|
conversation_chains["A3"] = continuation_id_a3
|
||||||
|
|
||||||
# === CHAIN B: Start independent conversation ===
|
# === CHAIN B: Start independent conversation ===
|
||||||
self.logger.info(" 🔗 Chain B: Starting independent conversation")
|
self.logger.info(" 🔗 Chain B: Starting independent conversation")
|
||||||
|
|
||||||
# Step B1: Start new chat conversation (creates thread_id_4, no parent)
|
# Step B1: Start new chat conversation (creates thread_id_4, no parent)
|
||||||
self.logger.info(" Step B1: Chat tool - start NEW independent conversation")
|
self.logger.info(" Step B1: Chat tool - start NEW independent conversation")
|
||||||
|
|
||||||
response_b1, continuation_id_b1 = self.call_mcp_tool(
|
response_b1, continuation_id_b1 = self.call_mcp_tool(
|
||||||
"chat",
|
"chat",
|
||||||
{
|
{
|
||||||
@@ -202,11 +192,11 @@ class TestClass:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
self.logger.info(f" ✅ Step B1 completed - thread_id: {continuation_id_b1[:8]}...")
|
self.logger.info(f" ✅ Step B1 completed - thread_id: {continuation_id_b1[:8]}...")
|
||||||
conversation_chains['B1'] = continuation_id_b1
|
conversation_chains["B1"] = continuation_id_b1
|
||||||
|
|
||||||
# Step B2: Continue the new conversation (creates thread_id_5 with parent=thread_id_4)
|
# Step B2: Continue the new conversation (creates thread_id_5 with parent=thread_id_4)
|
||||||
self.logger.info(" Step B2: Analyze tool - continue Chain B")
|
self.logger.info(" Step B2: Analyze tool - continue Chain B")
|
||||||
|
|
||||||
response_b2, continuation_id_b2 = self.call_mcp_tool(
|
response_b2, continuation_id_b2 = self.call_mcp_tool(
|
||||||
"analyze",
|
"analyze",
|
||||||
{
|
{
|
||||||
@@ -222,14 +212,14 @@ class TestClass:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
self.logger.info(f" ✅ Step B2 completed - thread_id: {continuation_id_b2[:8]}...")
|
self.logger.info(f" ✅ Step B2 completed - thread_id: {continuation_id_b2[:8]}...")
|
||||||
conversation_chains['B2'] = continuation_id_b2
|
conversation_chains["B2"] = continuation_id_b2
|
||||||
|
|
||||||
# === CHAIN A BRANCH: Go back to original conversation ===
|
# === CHAIN A BRANCH: Go back to original conversation ===
|
||||||
self.logger.info(" 🔗 Chain A Branch: Resume original conversation from A1")
|
self.logger.info(" 🔗 Chain A Branch: Resume original conversation from A1")
|
||||||
|
|
||||||
# Step A1-Branch: Use original continuation_id_a1 to branch (creates thread_id_6 with parent=thread_id_1)
|
# Step A1-Branch: Use original continuation_id_a1 to branch (creates thread_id_6 with parent=thread_id_1)
|
||||||
self.logger.info(" Step A1-Branch: Debug tool - branch from original Chain A")
|
self.logger.info(" Step A1-Branch: Debug tool - branch from original Chain A")
|
||||||
|
|
||||||
response_a1_branch, continuation_id_a1_branch = self.call_mcp_tool(
|
response_a1_branch, continuation_id_a1_branch = self.call_mcp_tool(
|
||||||
"debug",
|
"debug",
|
||||||
{
|
{
|
||||||
@@ -246,73 +236,79 @@ class TestClass:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
self.logger.info(f" ✅ Step A1-Branch completed - thread_id: {continuation_id_a1_branch[:8]}...")
|
self.logger.info(f" ✅ Step A1-Branch completed - thread_id: {continuation_id_a1_branch[:8]}...")
|
||||||
conversation_chains['A1_Branch'] = continuation_id_a1_branch
|
conversation_chains["A1_Branch"] = continuation_id_a1_branch
|
||||||
|
|
||||||
# === ANALYSIS: Validate thread relationships and history traversal ===
|
# === ANALYSIS: Validate thread relationships and history traversal ===
|
||||||
self.logger.info(" 📊 Analyzing conversation chain structure...")
|
self.logger.info(" 📊 Analyzing conversation chain structure...")
|
||||||
|
|
||||||
# Get logs and extract thread relationships
|
# Get logs and extract thread relationships
|
||||||
logs = self.get_recent_server_logs()
|
logs = self.get_recent_server_logs()
|
||||||
thread_creation_logs = self.extract_thread_creation_logs(logs)
|
thread_creation_logs = self.extract_thread_creation_logs(logs)
|
||||||
history_traversal_logs = self.extract_history_traversal_logs(logs)
|
history_traversal_logs = self.extract_history_traversal_logs(logs)
|
||||||
|
|
||||||
self.logger.info(f" Found {len(thread_creation_logs)} thread creation logs")
|
self.logger.info(f" Found {len(thread_creation_logs)} thread creation logs")
|
||||||
self.logger.info(f" Found {len(history_traversal_logs)} history traversal logs")
|
self.logger.info(f" Found {len(history_traversal_logs)} history traversal logs")
|
||||||
|
|
||||||
# Debug: Show what we found
|
# Debug: Show what we found
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
self.logger.debug(" Thread creation logs found:")
|
self.logger.debug(" Thread creation logs found:")
|
||||||
for log in thread_creation_logs:
|
for log in thread_creation_logs:
|
||||||
self.logger.debug(f" {log['thread_id'][:8]}... parent: {log['parent_id'][:8] if log['parent_id'] else 'None'}...")
|
self.logger.debug(
|
||||||
|
f" {log['thread_id'][:8]}... parent: {log['parent_id'][:8] if log['parent_id'] else 'None'}..."
|
||||||
|
)
|
||||||
self.logger.debug(" History traversal logs found:")
|
self.logger.debug(" History traversal logs found:")
|
||||||
for log in history_traversal_logs:
|
for log in history_traversal_logs:
|
||||||
self.logger.debug(f" {log['thread_id'][:8]}... chain length: {log['chain_length']}")
|
self.logger.debug(f" {log['thread_id'][:8]}... chain length: {log['chain_length']}")
|
||||||
|
|
||||||
# Build expected thread relationships
|
# Build expected thread relationships
|
||||||
expected_relationships = []
|
expected_relationships = []
|
||||||
|
|
||||||
# Note: A1 and B1 won't appear in thread creation logs because they're new conversations (no parent)
|
# Note: A1 and B1 won't appear in thread creation logs because they're new conversations (no parent)
|
||||||
# Only continuation threads (A2, A3, B2, A1-Branch) will appear in creation logs
|
# Only continuation threads (A2, A3, B2, A1-Branch) will appear in creation logs
|
||||||
|
|
||||||
# Find logs for each continuation thread
|
# Find logs for each continuation thread
|
||||||
a2_log = next((log for log in thread_creation_logs if log['thread_id'] == continuation_id_a2), None)
|
a2_log = next((log for log in thread_creation_logs if log["thread_id"] == continuation_id_a2), None)
|
||||||
a3_log = next((log for log in thread_creation_logs if log['thread_id'] == continuation_id_a3), None)
|
a3_log = next((log for log in thread_creation_logs if log["thread_id"] == continuation_id_a3), None)
|
||||||
b2_log = next((log for log in thread_creation_logs if log['thread_id'] == continuation_id_b2), None)
|
b2_log = next((log for log in thread_creation_logs if log["thread_id"] == continuation_id_b2), None)
|
||||||
a1_branch_log = next((log for log in thread_creation_logs if log['thread_id'] == continuation_id_a1_branch), None)
|
a1_branch_log = next(
|
||||||
|
(log for log in thread_creation_logs if log["thread_id"] == continuation_id_a1_branch), None
|
||||||
|
)
|
||||||
|
|
||||||
# A2 should have A1 as parent
|
# A2 should have A1 as parent
|
||||||
if a2_log:
|
if a2_log:
|
||||||
expected_relationships.append(("A2 has A1 as parent", a2_log['parent_id'] == continuation_id_a1))
|
expected_relationships.append(("A2 has A1 as parent", a2_log["parent_id"] == continuation_id_a1))
|
||||||
|
|
||||||
# A3 should have A2 as parent
|
# A3 should have A2 as parent
|
||||||
if a3_log:
|
if a3_log:
|
||||||
expected_relationships.append(("A3 has A2 as parent", a3_log['parent_id'] == continuation_id_a2))
|
expected_relationships.append(("A3 has A2 as parent", a3_log["parent_id"] == continuation_id_a2))
|
||||||
|
|
||||||
# B2 should have B1 as parent (independent chain)
|
# B2 should have B1 as parent (independent chain)
|
||||||
if b2_log:
|
if b2_log:
|
||||||
expected_relationships.append(("B2 has B1 as parent", b2_log['parent_id'] == continuation_id_b1))
|
expected_relationships.append(("B2 has B1 as parent", b2_log["parent_id"] == continuation_id_b1))
|
||||||
|
|
||||||
# A1-Branch should have A1 as parent (branching)
|
# A1-Branch should have A1 as parent (branching)
|
||||||
if a1_branch_log:
|
if a1_branch_log:
|
||||||
expected_relationships.append(("A1-Branch has A1 as parent", a1_branch_log['parent_id'] == continuation_id_a1))
|
expected_relationships.append(
|
||||||
|
("A1-Branch has A1 as parent", a1_branch_log["parent_id"] == continuation_id_a1)
|
||||||
|
)
|
||||||
|
|
||||||
# Validate history traversal
|
# Validate history traversal
|
||||||
traversal_validations = []
|
traversal_validations = []
|
||||||
|
|
||||||
# History traversal logs are only generated when conversation history is built from scratch
|
# History traversal logs are only generated when conversation history is built from scratch
|
||||||
# (not when history is already embedded in the prompt by server.py)
|
# (not when history is already embedded in the prompt by server.py)
|
||||||
# So we should expect at least 1 traversal log, but not necessarily for every continuation
|
# So we should expect at least 1 traversal log, but not necessarily for every continuation
|
||||||
|
|
||||||
if len(history_traversal_logs) > 0:
|
if len(history_traversal_logs) > 0:
|
||||||
# Validate that any traversal logs we find have reasonable chain lengths
|
# Validate that any traversal logs we find have reasonable chain lengths
|
||||||
for log in history_traversal_logs:
|
for log in history_traversal_logs:
|
||||||
thread_id = log['thread_id']
|
thread_id = log["thread_id"]
|
||||||
chain_length = log['chain_length']
|
chain_length = log["chain_length"]
|
||||||
|
|
||||||
# Chain length should be at least 2 for any continuation thread
|
# Chain length should be at least 2 for any continuation thread
|
||||||
# (original thread + continuation thread)
|
# (original thread + continuation thread)
|
||||||
is_valid_length = chain_length >= 2
|
is_valid_length = chain_length >= 2
|
||||||
|
|
||||||
# Try to identify which thread this is for better validation
|
# Try to identify which thread this is for better validation
|
||||||
thread_description = "Unknown thread"
|
thread_description = "Unknown thread"
|
||||||
if thread_id == continuation_id_a2:
|
if thread_id == continuation_id_a2:
|
||||||
@@ -327,12 +323,16 @@ class TestClass:
|
|||||||
elif thread_id == continuation_id_a1_branch:
|
elif thread_id == continuation_id_a1_branch:
|
||||||
thread_description = "A1-Branch (should be 2-thread chain)"
|
thread_description = "A1-Branch (should be 2-thread chain)"
|
||||||
is_valid_length = chain_length == 2
|
is_valid_length = chain_length == 2
|
||||||
|
|
||||||
traversal_validations.append((f"{thread_description[:8]}... has valid chain length", is_valid_length))
|
traversal_validations.append(
|
||||||
|
(f"{thread_description[:8]}... has valid chain length", is_valid_length)
|
||||||
|
)
|
||||||
|
|
||||||
# Also validate we found at least one traversal (shows the system is working)
|
# Also validate we found at least one traversal (shows the system is working)
|
||||||
traversal_validations.append(("At least one history traversal occurred", len(history_traversal_logs) >= 1))
|
traversal_validations.append(
|
||||||
|
("At least one history traversal occurred", len(history_traversal_logs) >= 1)
|
||||||
|
)
|
||||||
|
|
||||||
# === VALIDATION RESULTS ===
|
# === VALIDATION RESULTS ===
|
||||||
self.logger.info(" 📊 Thread Relationship Validation:")
|
self.logger.info(" 📊 Thread Relationship Validation:")
|
||||||
relationship_passed = 0
|
relationship_passed = 0
|
||||||
@@ -341,7 +341,7 @@ class TestClass:
|
|||||||
self.logger.info(f" {status} {desc}")
|
self.logger.info(f" {status} {desc}")
|
||||||
if passed:
|
if passed:
|
||||||
relationship_passed += 1
|
relationship_passed += 1
|
||||||
|
|
||||||
self.logger.info(" 📊 History Traversal Validation:")
|
self.logger.info(" 📊 History Traversal Validation:")
|
||||||
traversal_passed = 0
|
traversal_passed = 0
|
||||||
for desc, passed in traversal_validations:
|
for desc, passed in traversal_validations:
|
||||||
@@ -349,31 +349,35 @@ class TestClass:
|
|||||||
self.logger.info(f" {status} {desc}")
|
self.logger.info(f" {status} {desc}")
|
||||||
if passed:
|
if passed:
|
||||||
traversal_passed += 1
|
traversal_passed += 1
|
||||||
|
|
||||||
# === SUCCESS CRITERIA ===
|
# === SUCCESS CRITERIA ===
|
||||||
total_relationship_checks = len(expected_relationships)
|
total_relationship_checks = len(expected_relationships)
|
||||||
total_traversal_checks = len(traversal_validations)
|
total_traversal_checks = len(traversal_validations)
|
||||||
|
|
||||||
self.logger.info(f" 📊 Validation Summary:")
|
self.logger.info(" 📊 Validation Summary:")
|
||||||
self.logger.info(f" Thread relationships: {relationship_passed}/{total_relationship_checks}")
|
self.logger.info(f" Thread relationships: {relationship_passed}/{total_relationship_checks}")
|
||||||
self.logger.info(f" History traversal: {traversal_passed}/{total_traversal_checks}")
|
self.logger.info(f" History traversal: {traversal_passed}/{total_traversal_checks}")
|
||||||
|
|
||||||
# Success requires at least 80% of validations to pass
|
# Success requires at least 80% of validations to pass
|
||||||
relationship_success = relationship_passed >= (total_relationship_checks * 0.8)
|
relationship_success = relationship_passed >= (total_relationship_checks * 0.8)
|
||||||
|
|
||||||
# If no traversal checks were possible, it means no traversal logs were found
|
# If no traversal checks were possible, it means no traversal logs were found
|
||||||
# This could indicate an issue since we expect at least some history building
|
# This could indicate an issue since we expect at least some history building
|
||||||
if total_traversal_checks == 0:
|
if total_traversal_checks == 0:
|
||||||
self.logger.warning(" No history traversal logs found - this may indicate conversation history is always pre-embedded")
|
self.logger.warning(
|
||||||
|
" No history traversal logs found - this may indicate conversation history is always pre-embedded"
|
||||||
|
)
|
||||||
# Still consider it successful since the thread relationships are what matter most
|
# Still consider it successful since the thread relationships are what matter most
|
||||||
traversal_success = True
|
traversal_success = True
|
||||||
else:
|
else:
|
||||||
traversal_success = traversal_passed >= (total_traversal_checks * 0.8)
|
traversal_success = traversal_passed >= (total_traversal_checks * 0.8)
|
||||||
|
|
||||||
overall_success = relationship_success and traversal_success
|
overall_success = relationship_success and traversal_success
|
||||||
|
|
||||||
self.logger.info(f" 📊 Conversation Chain Structure:")
|
self.logger.info(" 📊 Conversation Chain Structure:")
|
||||||
self.logger.info(f" Chain A: {continuation_id_a1[:8]} → {continuation_id_a2[:8]} → {continuation_id_a3[:8]}")
|
self.logger.info(
|
||||||
|
f" Chain A: {continuation_id_a1[:8]} → {continuation_id_a2[:8]} → {continuation_id_a3[:8]}"
|
||||||
|
)
|
||||||
self.logger.info(f" Chain B: {continuation_id_b1[:8]} → {continuation_id_b2[:8]}")
|
self.logger.info(f" Chain B: {continuation_id_b1[:8]} → {continuation_id_b2[:8]}")
|
||||||
self.logger.info(f" Branch: {continuation_id_a1[:8]} → {continuation_id_a1_branch[:8]}")
|
self.logger.info(f" Branch: {continuation_id_a1[:8]} → {continuation_id_a1_branch[:8]}")
|
||||||
|
|
||||||
@@ -394,13 +398,13 @@ class TestClass:
|
|||||||
def main():
|
def main():
|
||||||
"""Run the conversation chain validation test"""
|
"""Run the conversation chain validation test"""
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
verbose = "--verbose" in sys.argv or "-v" in sys.argv
|
verbose = "--verbose" in sys.argv or "-v" in sys.argv
|
||||||
test = ConversationChainValidationTest(verbose=verbose)
|
test = ConversationChainValidationTest(verbose=verbose)
|
||||||
|
|
||||||
success = test.run_test()
|
success = test.run_test()
|
||||||
sys.exit(0 if success else 1)
|
sys.exit(0 if success else 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class O3ModelSelectionTest(BaseSimulatorTest):
|
|||||||
# Read logs directly from the log file - more reliable than docker logs --since
|
# Read logs directly from the log file - more reliable than docker logs --since
|
||||||
cmd = ["docker", "exec", self.container_name, "tail", "-n", "200", "/tmp/mcp_server.log"]
|
cmd = ["docker", "exec", self.container_name, "tail", "-n", "200", "/tmp/mcp_server.log"]
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
return result.stdout
|
return result.stdout
|
||||||
else:
|
else:
|
||||||
@@ -49,7 +49,7 @@ class O3ModelSelectionTest(BaseSimulatorTest):
|
|||||||
self.setup_test_files()
|
self.setup_test_files()
|
||||||
|
|
||||||
# Get timestamp for log filtering
|
# Get timestamp for log filtering
|
||||||
start_time = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
|
datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
|
||||||
|
|
||||||
# Test 1: Explicit O3 model selection
|
# Test 1: Explicit O3 model selection
|
||||||
self.logger.info(" 1: Testing explicit O3 model selection")
|
self.logger.info(" 1: Testing explicit O3 model selection")
|
||||||
@@ -115,37 +115,26 @@ def multiply(x, y):
|
|||||||
|
|
||||||
self.logger.info(" ✅ O3 with codereview tool completed")
|
self.logger.info(" ✅ O3 with codereview tool completed")
|
||||||
|
|
||||||
# Validate model usage from server logs
|
# Validate model usage from server logs
|
||||||
self.logger.info(" 4: Validating model usage in logs")
|
self.logger.info(" 4: Validating model usage in logs")
|
||||||
logs = self.get_recent_server_logs()
|
logs = self.get_recent_server_logs()
|
||||||
|
|
||||||
# Check for OpenAI API calls (this proves O3 models are being used)
|
# Check for OpenAI API calls (this proves O3 models are being used)
|
||||||
openai_api_logs = [
|
openai_api_logs = [line for line in logs.split("\n") if "Sending request to openai API" in line]
|
||||||
line for line in logs.split("\n")
|
|
||||||
if "Sending request to openai API" in line
|
|
||||||
]
|
|
||||||
|
|
||||||
# Check for OpenAI HTTP responses (confirms successful O3 calls)
|
# Check for OpenAI HTTP responses (confirms successful O3 calls)
|
||||||
openai_http_logs = [
|
openai_http_logs = [
|
||||||
line for line in logs.split("\n")
|
line for line in logs.split("\n") if "HTTP Request: POST https://api.openai.com" in line
|
||||||
if "HTTP Request: POST https://api.openai.com" in line
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check for received responses from OpenAI
|
# Check for received responses from OpenAI
|
||||||
openai_response_logs = [
|
openai_response_logs = [line for line in logs.split("\n") if "Received response from openai API" in line]
|
||||||
line for line in logs.split("\n")
|
|
||||||
if "Received response from openai API" in line
|
|
||||||
]
|
|
||||||
|
|
||||||
# Check that we have both chat and codereview tool calls to OpenAI
|
# Check that we have both chat and codereview tool calls to OpenAI
|
||||||
chat_openai_logs = [
|
chat_openai_logs = [line for line in logs.split("\n") if "Sending request to openai API for chat" in line]
|
||||||
line for line in logs.split("\n")
|
|
||||||
if "Sending request to openai API for chat" in line
|
|
||||||
]
|
|
||||||
|
|
||||||
codereview_openai_logs = [
|
codereview_openai_logs = [
|
||||||
line for line in logs.split("\n")
|
line for line in logs.split("\n") if "Sending request to openai API for codereview" in line
|
||||||
if "Sending request to openai API for codereview" in line
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Validation criteria - we expect 3 OpenAI calls (2 chat + 1 codereview)
|
# Validation criteria - we expect 3 OpenAI calls (2 chat + 1 codereview)
|
||||||
@@ -178,7 +167,7 @@ def multiply(x, y):
|
|||||||
("OpenAI HTTP requests successful", openai_http_success),
|
("OpenAI HTTP requests successful", openai_http_success),
|
||||||
("OpenAI responses received", openai_responses_received),
|
("OpenAI responses received", openai_responses_received),
|
||||||
("Chat tool used OpenAI", chat_calls_to_openai),
|
("Chat tool used OpenAI", chat_calls_to_openai),
|
||||||
("Codereview tool used OpenAI", codereview_calls_to_openai)
|
("Codereview tool used OpenAI", codereview_calls_to_openai),
|
||||||
]
|
]
|
||||||
|
|
||||||
passed_criteria = sum(1 for _, passed in success_criteria if passed)
|
passed_criteria = sum(1 for _, passed in success_criteria if passed)
|
||||||
@@ -214,4 +203,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -10,9 +10,8 @@ This test validates that:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import subprocess
|
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Tuple
|
import subprocess
|
||||||
|
|
||||||
from .base_test import BaseSimulatorTest
|
from .base_test import BaseSimulatorTest
|
||||||
|
|
||||||
@@ -33,7 +32,7 @@ class TokenAllocationValidationTest(BaseSimulatorTest):
|
|||||||
try:
|
try:
|
||||||
cmd = ["docker", "exec", self.container_name, "tail", "-n", "300", "/tmp/mcp_server.log"]
|
cmd = ["docker", "exec", self.container_name, "tail", "-n", "300", "/tmp/mcp_server.log"]
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
return result.stdout
|
return result.stdout
|
||||||
else:
|
else:
|
||||||
@@ -43,13 +42,13 @@ class TokenAllocationValidationTest(BaseSimulatorTest):
|
|||||||
self.logger.error(f"Failed to get server logs: {e}")
|
self.logger.error(f"Failed to get server logs: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def extract_conversation_usage_logs(self, logs: str) -> List[Dict[str, int]]:
|
def extract_conversation_usage_logs(self, logs: str) -> list[dict[str, int]]:
|
||||||
"""Extract actual conversation token usage from server logs"""
|
"""Extract actual conversation token usage from server logs"""
|
||||||
usage_logs = []
|
usage_logs = []
|
||||||
|
|
||||||
# Look for conversation debug logs that show actual usage
|
# Look for conversation debug logs that show actual usage
|
||||||
lines = logs.split('\n')
|
lines = logs.split("\n")
|
||||||
|
|
||||||
for i, line in enumerate(lines):
|
for i, line in enumerate(lines):
|
||||||
if "[CONVERSATION_DEBUG] Token budget calculation:" in line:
|
if "[CONVERSATION_DEBUG] Token budget calculation:" in line:
|
||||||
# Found start of token budget log, extract the following lines
|
# Found start of token budget log, extract the following lines
|
||||||
@@ -57,47 +56,47 @@ class TokenAllocationValidationTest(BaseSimulatorTest):
|
|||||||
for j in range(1, 8): # Next 7 lines contain the usage details
|
for j in range(1, 8): # Next 7 lines contain the usage details
|
||||||
if i + j < len(lines):
|
if i + j < len(lines):
|
||||||
detail_line = lines[i + j]
|
detail_line = lines[i + j]
|
||||||
|
|
||||||
# Parse Total capacity: 1,048,576
|
# Parse Total capacity: 1,048,576
|
||||||
if "Total capacity:" in detail_line:
|
if "Total capacity:" in detail_line:
|
||||||
match = re.search(r'Total capacity:\s*([\d,]+)', detail_line)
|
match = re.search(r"Total capacity:\s*([\d,]+)", detail_line)
|
||||||
if match:
|
if match:
|
||||||
usage['total_capacity'] = int(match.group(1).replace(',', ''))
|
usage["total_capacity"] = int(match.group(1).replace(",", ""))
|
||||||
|
|
||||||
# Parse Content allocation: 838,860
|
# Parse Content allocation: 838,860
|
||||||
elif "Content allocation:" in detail_line:
|
elif "Content allocation:" in detail_line:
|
||||||
match = re.search(r'Content allocation:\s*([\d,]+)', detail_line)
|
match = re.search(r"Content allocation:\s*([\d,]+)", detail_line)
|
||||||
if match:
|
if match:
|
||||||
usage['content_allocation'] = int(match.group(1).replace(',', ''))
|
usage["content_allocation"] = int(match.group(1).replace(",", ""))
|
||||||
|
|
||||||
# Parse Conversation tokens: 12,345
|
# Parse Conversation tokens: 12,345
|
||||||
elif "Conversation tokens:" in detail_line:
|
elif "Conversation tokens:" in detail_line:
|
||||||
match = re.search(r'Conversation tokens:\s*([\d,]+)', detail_line)
|
match = re.search(r"Conversation tokens:\s*([\d,]+)", detail_line)
|
||||||
if match:
|
if match:
|
||||||
usage['conversation_tokens'] = int(match.group(1).replace(',', ''))
|
usage["conversation_tokens"] = int(match.group(1).replace(",", ""))
|
||||||
|
|
||||||
# Parse Remaining tokens: 825,515
|
# Parse Remaining tokens: 825,515
|
||||||
elif "Remaining tokens:" in detail_line:
|
elif "Remaining tokens:" in detail_line:
|
||||||
match = re.search(r'Remaining tokens:\s*([\d,]+)', detail_line)
|
match = re.search(r"Remaining tokens:\s*([\d,]+)", detail_line)
|
||||||
if match:
|
if match:
|
||||||
usage['remaining_tokens'] = int(match.group(1).replace(',', ''))
|
usage["remaining_tokens"] = int(match.group(1).replace(",", ""))
|
||||||
|
|
||||||
if usage: # Only add if we found some usage data
|
if usage: # Only add if we found some usage data
|
||||||
usage_logs.append(usage)
|
usage_logs.append(usage)
|
||||||
|
|
||||||
return usage_logs
|
return usage_logs
|
||||||
|
|
||||||
def extract_conversation_token_usage(self, logs: str) -> List[int]:
|
def extract_conversation_token_usage(self, logs: str) -> list[int]:
|
||||||
"""Extract conversation token usage from logs"""
|
"""Extract conversation token usage from logs"""
|
||||||
usage_values = []
|
usage_values = []
|
||||||
|
|
||||||
# Look for conversation token usage logs
|
# Look for conversation token usage logs
|
||||||
pattern = r'Conversation history token usage:\s*([\d,]+)'
|
pattern = r"Conversation history token usage:\s*([\d,]+)"
|
||||||
matches = re.findall(pattern, logs)
|
matches = re.findall(pattern, logs)
|
||||||
|
|
||||||
for match in matches:
|
for match in matches:
|
||||||
usage_values.append(int(match.replace(',', '')))
|
usage_values.append(int(match.replace(",", "")))
|
||||||
|
|
||||||
return usage_values
|
return usage_values
|
||||||
|
|
||||||
def run_test(self) -> bool:
|
def run_test(self) -> bool:
|
||||||
@@ -111,11 +110,11 @@ class TokenAllocationValidationTest(BaseSimulatorTest):
|
|||||||
# Create additional test files for this test - make them substantial enough to see token differences
|
# Create additional test files for this test - make them substantial enough to see token differences
|
||||||
file1_content = """def fibonacci(n):
|
file1_content = """def fibonacci(n):
|
||||||
'''Calculate fibonacci number recursively
|
'''Calculate fibonacci number recursively
|
||||||
|
|
||||||
This is a classic recursive algorithm that demonstrates
|
This is a classic recursive algorithm that demonstrates
|
||||||
the exponential time complexity of naive recursion.
|
the exponential time complexity of naive recursion.
|
||||||
For large values of n, this becomes very slow.
|
For large values of n, this becomes very slow.
|
||||||
|
|
||||||
Time complexity: O(2^n)
|
Time complexity: O(2^n)
|
||||||
Space complexity: O(n) due to call stack
|
Space complexity: O(n) due to call stack
|
||||||
'''
|
'''
|
||||||
@@ -125,10 +124,10 @@ class TokenAllocationValidationTest(BaseSimulatorTest):
|
|||||||
|
|
||||||
def factorial(n):
|
def factorial(n):
|
||||||
'''Calculate factorial using recursion
|
'''Calculate factorial using recursion
|
||||||
|
|
||||||
More efficient than fibonacci as each value
|
More efficient than fibonacci as each value
|
||||||
is calculated only once.
|
is calculated only once.
|
||||||
|
|
||||||
Time complexity: O(n)
|
Time complexity: O(n)
|
||||||
Space complexity: O(n) due to call stack
|
Space complexity: O(n) due to call stack
|
||||||
'''
|
'''
|
||||||
@@ -157,14 +156,14 @@ if __name__ == "__main__":
|
|||||||
for i in range(10):
|
for i in range(10):
|
||||||
print(f" F({i}) = {fibonacci(i)}")
|
print(f" F({i}) = {fibonacci(i)}")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
file2_content = """class Calculator:
|
file2_content = """class Calculator:
|
||||||
'''Advanced calculator class with error handling and logging'''
|
'''Advanced calculator class with error handling and logging'''
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.history = []
|
self.history = []
|
||||||
self.last_result = 0
|
self.last_result = 0
|
||||||
|
|
||||||
def add(self, a, b):
|
def add(self, a, b):
|
||||||
'''Addition with history tracking'''
|
'''Addition with history tracking'''
|
||||||
result = a + b
|
result = a + b
|
||||||
@@ -172,7 +171,7 @@ if __name__ == "__main__":
|
|||||||
self.history.append(operation)
|
self.history.append(operation)
|
||||||
self.last_result = result
|
self.last_result = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def multiply(self, a, b):
|
def multiply(self, a, b):
|
||||||
'''Multiplication with history tracking'''
|
'''Multiplication with history tracking'''
|
||||||
result = a * b
|
result = a * b
|
||||||
@@ -180,20 +179,20 @@ if __name__ == "__main__":
|
|||||||
self.history.append(operation)
|
self.history.append(operation)
|
||||||
self.last_result = result
|
self.last_result = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def divide(self, a, b):
|
def divide(self, a, b):
|
||||||
'''Division with error handling and history tracking'''
|
'''Division with error handling and history tracking'''
|
||||||
if b == 0:
|
if b == 0:
|
||||||
error_msg = f"Division by zero error: {a} / {b}"
|
error_msg = f"Division by zero error: {a} / {b}"
|
||||||
self.history.append(error_msg)
|
self.history.append(error_msg)
|
||||||
raise ValueError("Cannot divide by zero")
|
raise ValueError("Cannot divide by zero")
|
||||||
|
|
||||||
result = a / b
|
result = a / b
|
||||||
operation = f"{a} / {b} = {result}"
|
operation = f"{a} / {b} = {result}"
|
||||||
self.history.append(operation)
|
self.history.append(operation)
|
||||||
self.last_result = result
|
self.last_result = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def power(self, base, exponent):
|
def power(self, base, exponent):
|
||||||
'''Exponentiation with history tracking'''
|
'''Exponentiation with history tracking'''
|
||||||
result = base ** exponent
|
result = base ** exponent
|
||||||
@@ -201,11 +200,11 @@ if __name__ == "__main__":
|
|||||||
self.history.append(operation)
|
self.history.append(operation)
|
||||||
self.last_result = result
|
self.last_result = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_history(self):
|
def get_history(self):
|
||||||
'''Return calculation history'''
|
'''Return calculation history'''
|
||||||
return self.history.copy()
|
return self.history.copy()
|
||||||
|
|
||||||
def clear_history(self):
|
def clear_history(self):
|
||||||
'''Clear calculation history'''
|
'''Clear calculation history'''
|
||||||
self.history.clear()
|
self.history.clear()
|
||||||
@@ -215,32 +214,32 @@ if __name__ == "__main__":
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
calc = Calculator()
|
calc = Calculator()
|
||||||
print("=== Calculator Demo ===")
|
print("=== Calculator Demo ===")
|
||||||
|
|
||||||
# Perform various calculations
|
# Perform various calculations
|
||||||
print(f"Addition: {calc.add(10, 20)}")
|
print(f"Addition: {calc.add(10, 20)}")
|
||||||
print(f"Multiplication: {calc.multiply(5, 8)}")
|
print(f"Multiplication: {calc.multiply(5, 8)}")
|
||||||
print(f"Division: {calc.divide(100, 4)}")
|
print(f"Division: {calc.divide(100, 4)}")
|
||||||
print(f"Power: {calc.power(2, 8)}")
|
print(f"Power: {calc.power(2, 8)}")
|
||||||
|
|
||||||
print("\\nCalculation History:")
|
print("\\nCalculation History:")
|
||||||
for operation in calc.get_history():
|
for operation in calc.get_history():
|
||||||
print(f" {operation}")
|
print(f" {operation}")
|
||||||
|
|
||||||
print(f"\\nLast result: {calc.last_result}")
|
print(f"\\nLast result: {calc.last_result}")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create test files
|
# Create test files
|
||||||
file1_path = self.create_additional_test_file("math_functions.py", file1_content)
|
file1_path = self.create_additional_test_file("math_functions.py", file1_content)
|
||||||
file2_path = self.create_additional_test_file("calculator.py", file2_content)
|
file2_path = self.create_additional_test_file("calculator.py", file2_content)
|
||||||
|
|
||||||
# Track continuation IDs to validate each step generates new ones
|
# Track continuation IDs to validate each step generates new ones
|
||||||
continuation_ids = []
|
continuation_ids = []
|
||||||
|
|
||||||
# Step 1: Initial chat with first file
|
# Step 1: Initial chat with first file
|
||||||
self.logger.info(" Step 1: Initial chat with file1 - checking token allocation")
|
self.logger.info(" Step 1: Initial chat with file1 - checking token allocation")
|
||||||
|
|
||||||
step1_start_time = datetime.datetime.now()
|
datetime.datetime.now()
|
||||||
|
|
||||||
response1, continuation_id1 = self.call_mcp_tool(
|
response1, continuation_id1 = self.call_mcp_tool(
|
||||||
"chat",
|
"chat",
|
||||||
{
|
{
|
||||||
@@ -260,31 +259,33 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Get logs and analyze file processing (Step 1 is new conversation, no conversation debug logs expected)
|
# Get logs and analyze file processing (Step 1 is new conversation, no conversation debug logs expected)
|
||||||
logs_step1 = self.get_recent_server_logs()
|
logs_step1 = self.get_recent_server_logs()
|
||||||
|
|
||||||
# For Step 1, check for file embedding logs instead of conversation usage
|
# For Step 1, check for file embedding logs instead of conversation usage
|
||||||
file_embedding_logs_step1 = [
|
file_embedding_logs_step1 = [
|
||||||
line for line in logs_step1.split('\n')
|
line
|
||||||
if 'successfully embedded' in line and 'files' in line and 'tokens' in line
|
for line in logs_step1.split("\n")
|
||||||
|
if "successfully embedded" in line and "files" in line and "tokens" in line
|
||||||
]
|
]
|
||||||
|
|
||||||
if not file_embedding_logs_step1:
|
if not file_embedding_logs_step1:
|
||||||
self.logger.error(" ❌ Step 1: No file embedding logs found")
|
self.logger.error(" ❌ Step 1: No file embedding logs found")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Extract file token count from embedding logs
|
# Extract file token count from embedding logs
|
||||||
step1_file_tokens = 0
|
step1_file_tokens = 0
|
||||||
for log in file_embedding_logs_step1:
|
for log in file_embedding_logs_step1:
|
||||||
# Look for pattern like "successfully embedded 1 files (146 tokens)"
|
# Look for pattern like "successfully embedded 1 files (146 tokens)"
|
||||||
import re
|
import re
|
||||||
match = re.search(r'\((\d+) tokens\)', log)
|
|
||||||
|
match = re.search(r"\((\d+) tokens\)", log)
|
||||||
if match:
|
if match:
|
||||||
step1_file_tokens = int(match.group(1))
|
step1_file_tokens = int(match.group(1))
|
||||||
break
|
break
|
||||||
|
|
||||||
self.logger.info(f" 📊 Step 1 File Processing - Embedded files: {step1_file_tokens:,} tokens")
|
self.logger.info(f" 📊 Step 1 File Processing - Embedded files: {step1_file_tokens:,} tokens")
|
||||||
|
|
||||||
# Validate that file1 is actually mentioned in the embedding logs (check for actual filename)
|
# Validate that file1 is actually mentioned in the embedding logs (check for actual filename)
|
||||||
file1_mentioned = any('math_functions.py' in log for log in file_embedding_logs_step1)
|
file1_mentioned = any("math_functions.py" in log for log in file_embedding_logs_step1)
|
||||||
if not file1_mentioned:
|
if not file1_mentioned:
|
||||||
# Debug: show what files were actually found in the logs
|
# Debug: show what files were actually found in the logs
|
||||||
self.logger.debug(" 📋 Files found in embedding logs:")
|
self.logger.debug(" 📋 Files found in embedding logs:")
|
||||||
@@ -300,8 +301,10 @@ if __name__ == "__main__":
|
|||||||
# Continue test - the important thing is that files were processed
|
# Continue test - the important thing is that files were processed
|
||||||
|
|
||||||
# Step 2: Different tool continuing same conversation - should build conversation history
|
# Step 2: Different tool continuing same conversation - should build conversation history
|
||||||
self.logger.info(" Step 2: Analyze tool continuing chat conversation - checking conversation history buildup")
|
self.logger.info(
|
||||||
|
" Step 2: Analyze tool continuing chat conversation - checking conversation history buildup"
|
||||||
|
)
|
||||||
|
|
||||||
response2, continuation_id2 = self.call_mcp_tool(
|
response2, continuation_id2 = self.call_mcp_tool(
|
||||||
"analyze",
|
"analyze",
|
||||||
{
|
{
|
||||||
@@ -314,12 +317,12 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not response2 or not continuation_id2:
|
if not response2 or not continuation_id2:
|
||||||
self.logger.error(" ❌ Step 2 failed - no response or continuation ID")
|
self.logger.error(" ❌ Step 2 failed - no response or continuation ID")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.logger.info(f" ✅ Step 2 completed with continuation_id: {continuation_id2[:8]}...")
|
self.logger.info(f" ✅ Step 2 completed with continuation_id: {continuation_id2[:8]}...")
|
||||||
continuation_ids.append(continuation_id2)
|
continuation_ids.append(continuation_id2)
|
||||||
|
|
||||||
# Validate that we got a different continuation ID
|
# Validate that we got a different continuation ID
|
||||||
if continuation_id2 == continuation_id1:
|
if continuation_id2 == continuation_id1:
|
||||||
self.logger.error(" ❌ Step 2: Got same continuation ID as Step 1 - continuation not working")
|
self.logger.error(" ❌ Step 2: Got same continuation ID as Step 1 - continuation not working")
|
||||||
@@ -328,33 +331,37 @@ if __name__ == "__main__":
|
|||||||
# Get logs and analyze token usage
|
# Get logs and analyze token usage
|
||||||
logs_step2 = self.get_recent_server_logs()
|
logs_step2 = self.get_recent_server_logs()
|
||||||
usage_step2 = self.extract_conversation_usage_logs(logs_step2)
|
usage_step2 = self.extract_conversation_usage_logs(logs_step2)
|
||||||
|
|
||||||
if len(usage_step2) < 2:
|
if len(usage_step2) < 2:
|
||||||
self.logger.warning(f" ⚠️ Step 2: Only found {len(usage_step2)} conversation usage logs, expected at least 2")
|
self.logger.warning(
|
||||||
# Debug: Look for any CONVERSATION_DEBUG logs
|
f" ⚠️ Step 2: Only found {len(usage_step2)} conversation usage logs, expected at least 2"
|
||||||
conversation_debug_lines = [line for line in logs_step2.split('\n') if 'CONVERSATION_DEBUG' in line]
|
)
|
||||||
|
# Debug: Look for any CONVERSATION_DEBUG logs
|
||||||
|
conversation_debug_lines = [line for line in logs_step2.split("\n") if "CONVERSATION_DEBUG" in line]
|
||||||
self.logger.debug(f" 📋 Found {len(conversation_debug_lines)} CONVERSATION_DEBUG lines in step 2")
|
self.logger.debug(f" 📋 Found {len(conversation_debug_lines)} CONVERSATION_DEBUG lines in step 2")
|
||||||
|
|
||||||
if conversation_debug_lines:
|
if conversation_debug_lines:
|
||||||
self.logger.debug(" 📋 Recent CONVERSATION_DEBUG lines:")
|
self.logger.debug(" 📋 Recent CONVERSATION_DEBUG lines:")
|
||||||
for line in conversation_debug_lines[-10:]: # Show last 10
|
for line in conversation_debug_lines[-10:]: # Show last 10
|
||||||
self.logger.debug(f" {line}")
|
self.logger.debug(f" {line}")
|
||||||
|
|
||||||
# If we have at least 1 usage log, continue with adjusted expectations
|
# If we have at least 1 usage log, continue with adjusted expectations
|
||||||
if len(usage_step2) >= 1:
|
if len(usage_step2) >= 1:
|
||||||
self.logger.info(" 📋 Continuing with single usage log for analysis")
|
self.logger.info(" 📋 Continuing with single usage log for analysis")
|
||||||
else:
|
else:
|
||||||
self.logger.error(" ❌ No conversation usage logs found at all")
|
self.logger.error(" ❌ No conversation usage logs found at all")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
latest_usage_step2 = usage_step2[-1] # Get most recent usage
|
latest_usage_step2 = usage_step2[-1] # Get most recent usage
|
||||||
self.logger.info(f" 📊 Step 2 Token Usage - Total Capacity: {latest_usage_step2.get('total_capacity', 0):,}, "
|
self.logger.info(
|
||||||
f"Conversation: {latest_usage_step2.get('conversation_tokens', 0):,}, "
|
f" 📊 Step 2 Token Usage - Total Capacity: {latest_usage_step2.get('total_capacity', 0):,}, "
|
||||||
f"Remaining: {latest_usage_step2.get('remaining_tokens', 0):,}")
|
f"Conversation: {latest_usage_step2.get('conversation_tokens', 0):,}, "
|
||||||
|
f"Remaining: {latest_usage_step2.get('remaining_tokens', 0):,}"
|
||||||
|
)
|
||||||
|
|
||||||
# Step 3: Continue conversation with additional file - should show increased token usage
|
# Step 3: Continue conversation with additional file - should show increased token usage
|
||||||
self.logger.info(" Step 3: Continue conversation with file1 + file2 - checking token growth")
|
self.logger.info(" Step 3: Continue conversation with file1 + file2 - checking token growth")
|
||||||
|
|
||||||
response3, continuation_id3 = self.call_mcp_tool(
|
response3, continuation_id3 = self.call_mcp_tool(
|
||||||
"chat",
|
"chat",
|
||||||
{
|
{
|
||||||
@@ -376,26 +383,30 @@ if __name__ == "__main__":
|
|||||||
# Get logs and analyze final token usage
|
# Get logs and analyze final token usage
|
||||||
logs_step3 = self.get_recent_server_logs()
|
logs_step3 = self.get_recent_server_logs()
|
||||||
usage_step3 = self.extract_conversation_usage_logs(logs_step3)
|
usage_step3 = self.extract_conversation_usage_logs(logs_step3)
|
||||||
|
|
||||||
self.logger.info(f" 📋 Found {len(usage_step3)} total conversation usage logs")
|
self.logger.info(f" 📋 Found {len(usage_step3)} total conversation usage logs")
|
||||||
|
|
||||||
if len(usage_step3) < 3:
|
if len(usage_step3) < 3:
|
||||||
self.logger.warning(f" ⚠️ Step 3: Only found {len(usage_step3)} conversation usage logs, expected at least 3")
|
self.logger.warning(
|
||||||
|
f" ⚠️ Step 3: Only found {len(usage_step3)} conversation usage logs, expected at least 3"
|
||||||
|
)
|
||||||
# Let's check if we have at least some logs to work with
|
# Let's check if we have at least some logs to work with
|
||||||
if len(usage_step3) == 0:
|
if len(usage_step3) == 0:
|
||||||
self.logger.error(" ❌ No conversation usage logs found at all")
|
self.logger.error(" ❌ No conversation usage logs found at all")
|
||||||
# Debug: show some recent logs
|
# Debug: show some recent logs
|
||||||
recent_lines = logs_step3.split('\n')[-50:]
|
recent_lines = logs_step3.split("\n")[-50:]
|
||||||
self.logger.debug(" 📋 Recent log lines:")
|
self.logger.debug(" 📋 Recent log lines:")
|
||||||
for line in recent_lines:
|
for line in recent_lines:
|
||||||
if line.strip() and "CONVERSATION_DEBUG" in line:
|
if line.strip() and "CONVERSATION_DEBUG" in line:
|
||||||
self.logger.debug(f" {line}")
|
self.logger.debug(f" {line}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
latest_usage_step3 = usage_step3[-1] # Get most recent usage
|
latest_usage_step3 = usage_step3[-1] # Get most recent usage
|
||||||
self.logger.info(f" 📊 Step 3 Token Usage - Total Capacity: {latest_usage_step3.get('total_capacity', 0):,}, "
|
self.logger.info(
|
||||||
f"Conversation: {latest_usage_step3.get('conversation_tokens', 0):,}, "
|
f" 📊 Step 3 Token Usage - Total Capacity: {latest_usage_step3.get('total_capacity', 0):,}, "
|
||||||
f"Remaining: {latest_usage_step3.get('remaining_tokens', 0):,}")
|
f"Conversation: {latest_usage_step3.get('conversation_tokens', 0):,}, "
|
||||||
|
f"Remaining: {latest_usage_step3.get('remaining_tokens', 0):,}"
|
||||||
|
)
|
||||||
|
|
||||||
# Validation: Check token processing and conversation history
|
# Validation: Check token processing and conversation history
|
||||||
self.logger.info(" 📋 Validating token processing and conversation history...")
|
self.logger.info(" 📋 Validating token processing and conversation history...")
|
||||||
@@ -405,14 +416,14 @@ if __name__ == "__main__":
|
|||||||
step2_remaining = 0
|
step2_remaining = 0
|
||||||
step3_conversation = 0
|
step3_conversation = 0
|
||||||
step3_remaining = 0
|
step3_remaining = 0
|
||||||
|
|
||||||
if len(usage_step2) > 0:
|
if len(usage_step2) > 0:
|
||||||
step2_conversation = latest_usage_step2.get('conversation_tokens', 0)
|
step2_conversation = latest_usage_step2.get("conversation_tokens", 0)
|
||||||
step2_remaining = latest_usage_step2.get('remaining_tokens', 0)
|
step2_remaining = latest_usage_step2.get("remaining_tokens", 0)
|
||||||
|
|
||||||
if len(usage_step3) >= len(usage_step2) + 1: # Should have one more log than step2
|
if len(usage_step3) >= len(usage_step2) + 1: # Should have one more log than step2
|
||||||
step3_conversation = latest_usage_step3.get('conversation_tokens', 0)
|
step3_conversation = latest_usage_step3.get("conversation_tokens", 0)
|
||||||
step3_remaining = latest_usage_step3.get('remaining_tokens', 0)
|
step3_remaining = latest_usage_step3.get("remaining_tokens", 0)
|
||||||
else:
|
else:
|
||||||
# Use step2 values as fallback
|
# Use step2 values as fallback
|
||||||
step3_conversation = step2_conversation
|
step3_conversation = step2_conversation
|
||||||
@@ -421,62 +432,78 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Validation criteria
|
# Validation criteria
|
||||||
criteria = []
|
criteria = []
|
||||||
|
|
||||||
# 1. Step 1 should have processed files successfully
|
# 1. Step 1 should have processed files successfully
|
||||||
step1_processed_files = step1_file_tokens > 0
|
step1_processed_files = step1_file_tokens > 0
|
||||||
criteria.append(("Step 1 processed files successfully", step1_processed_files))
|
criteria.append(("Step 1 processed files successfully", step1_processed_files))
|
||||||
|
|
||||||
# 2. Step 2 should have conversation history (if continuation worked)
|
# 2. Step 2 should have conversation history (if continuation worked)
|
||||||
step2_has_conversation = step2_conversation > 0 if len(usage_step2) > 0 else True # Pass if no logs (might be different issue)
|
step2_has_conversation = (
|
||||||
|
step2_conversation > 0 if len(usage_step2) > 0 else True
|
||||||
|
) # Pass if no logs (might be different issue)
|
||||||
step2_has_remaining = step2_remaining > 0 if len(usage_step2) > 0 else True
|
step2_has_remaining = step2_remaining > 0 if len(usage_step2) > 0 else True
|
||||||
criteria.append(("Step 2 has conversation history", step2_has_conversation))
|
criteria.append(("Step 2 has conversation history", step2_has_conversation))
|
||||||
criteria.append(("Step 2 has remaining tokens", step2_has_remaining))
|
criteria.append(("Step 2 has remaining tokens", step2_has_remaining))
|
||||||
|
|
||||||
# 3. Step 3 should show conversation growth
|
# 3. Step 3 should show conversation growth
|
||||||
step3_has_conversation = step3_conversation >= step2_conversation if len(usage_step3) > len(usage_step2) else True
|
step3_has_conversation = (
|
||||||
|
step3_conversation >= step2_conversation if len(usage_step3) > len(usage_step2) else True
|
||||||
|
)
|
||||||
criteria.append(("Step 3 maintains conversation history", step3_has_conversation))
|
criteria.append(("Step 3 maintains conversation history", step3_has_conversation))
|
||||||
|
|
||||||
# 4. Check that we got some conversation usage logs for continuation calls
|
# 4. Check that we got some conversation usage logs for continuation calls
|
||||||
has_conversation_logs = len(usage_step3) > 0
|
has_conversation_logs = len(usage_step3) > 0
|
||||||
criteria.append(("Found conversation usage logs", has_conversation_logs))
|
criteria.append(("Found conversation usage logs", has_conversation_logs))
|
||||||
|
|
||||||
# 5. Validate unique continuation IDs per response
|
# 5. Validate unique continuation IDs per response
|
||||||
unique_continuation_ids = len(set(continuation_ids)) == len(continuation_ids)
|
unique_continuation_ids = len(set(continuation_ids)) == len(continuation_ids)
|
||||||
criteria.append(("Each response generated unique continuation ID", unique_continuation_ids))
|
criteria.append(("Each response generated unique continuation ID", unique_continuation_ids))
|
||||||
|
|
||||||
# 6. Validate continuation IDs were different from each step
|
# 6. Validate continuation IDs were different from each step
|
||||||
step_ids_different = len(continuation_ids) == 3 and continuation_ids[0] != continuation_ids[1] and continuation_ids[1] != continuation_ids[2]
|
step_ids_different = (
|
||||||
|
len(continuation_ids) == 3
|
||||||
|
and continuation_ids[0] != continuation_ids[1]
|
||||||
|
and continuation_ids[1] != continuation_ids[2]
|
||||||
|
)
|
||||||
criteria.append(("All continuation IDs are different", step_ids_different))
|
criteria.append(("All continuation IDs are different", step_ids_different))
|
||||||
|
|
||||||
# Log detailed analysis
|
# Log detailed analysis
|
||||||
self.logger.info(f" 📊 Token Processing Analysis:")
|
self.logger.info(" 📊 Token Processing Analysis:")
|
||||||
self.logger.info(f" Step 1 - File tokens: {step1_file_tokens:,} (new conversation)")
|
self.logger.info(f" Step 1 - File tokens: {step1_file_tokens:,} (new conversation)")
|
||||||
self.logger.info(f" Step 2 - Conversation: {step2_conversation:,}, Remaining: {step2_remaining:,}")
|
self.logger.info(f" Step 2 - Conversation: {step2_conversation:,}, Remaining: {step2_remaining:,}")
|
||||||
self.logger.info(f" Step 3 - Conversation: {step3_conversation:,}, Remaining: {step3_remaining:,}")
|
self.logger.info(f" Step 3 - Conversation: {step3_conversation:,}, Remaining: {step3_remaining:,}")
|
||||||
|
|
||||||
# Log continuation ID analysis
|
# Log continuation ID analysis
|
||||||
self.logger.info(f" 📊 Continuation ID Analysis:")
|
self.logger.info(" 📊 Continuation ID Analysis:")
|
||||||
self.logger.info(f" Step 1 ID: {continuation_ids[0][:8]}... (generated)")
|
self.logger.info(f" Step 1 ID: {continuation_ids[0][:8]}... (generated)")
|
||||||
self.logger.info(f" Step 2 ID: {continuation_ids[1][:8]}... (generated from Step 1)")
|
self.logger.info(f" Step 2 ID: {continuation_ids[1][:8]}... (generated from Step 1)")
|
||||||
self.logger.info(f" Step 3 ID: {continuation_ids[2][:8]}... (generated from Step 2)")
|
self.logger.info(f" Step 3 ID: {continuation_ids[2][:8]}... (generated from Step 2)")
|
||||||
|
|
||||||
# Check for file mentions in step 3 (should include both files)
|
# Check for file mentions in step 3 (should include both files)
|
||||||
# Look for file processing in conversation memory logs and tool embedding logs
|
# Look for file processing in conversation memory logs and tool embedding logs
|
||||||
file2_mentioned_step3 = any('calculator.py' in log for log in logs_step3.split('\n') if ('embedded' in log.lower() and ('conversation' in log.lower() or 'tool' in log.lower())))
|
file2_mentioned_step3 = any(
|
||||||
file1_still_mentioned_step3 = any('math_functions.py' in log for log in logs_step3.split('\n') if ('embedded' in log.lower() and ('conversation' in log.lower() or 'tool' in log.lower())))
|
"calculator.py" in log
|
||||||
|
for log in logs_step3.split("\n")
|
||||||
self.logger.info(f" 📊 File Processing in Step 3:")
|
if ("embedded" in log.lower() and ("conversation" in log.lower() or "tool" in log.lower()))
|
||||||
|
)
|
||||||
|
file1_still_mentioned_step3 = any(
|
||||||
|
"math_functions.py" in log
|
||||||
|
for log in logs_step3.split("\n")
|
||||||
|
if ("embedded" in log.lower() and ("conversation" in log.lower() or "tool" in log.lower()))
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(" 📊 File Processing in Step 3:")
|
||||||
self.logger.info(f" File1 (math_functions.py) mentioned: {file1_still_mentioned_step3}")
|
self.logger.info(f" File1 (math_functions.py) mentioned: {file1_still_mentioned_step3}")
|
||||||
self.logger.info(f" File2 (calculator.py) mentioned: {file2_mentioned_step3}")
|
self.logger.info(f" File2 (calculator.py) mentioned: {file2_mentioned_step3}")
|
||||||
|
|
||||||
# Add file increase validation
|
# Add file increase validation
|
||||||
step3_file_increase = file2_mentioned_step3 # New file should be visible
|
step3_file_increase = file2_mentioned_step3 # New file should be visible
|
||||||
criteria.append(("Step 3 shows new file being processed", step3_file_increase))
|
criteria.append(("Step 3 shows new file being processed", step3_file_increase))
|
||||||
|
|
||||||
# Check validation criteria
|
# Check validation criteria
|
||||||
passed_criteria = sum(1 for _, passed in criteria if passed)
|
passed_criteria = sum(1 for _, passed in criteria if passed)
|
||||||
total_criteria = len(criteria)
|
total_criteria = len(criteria)
|
||||||
|
|
||||||
self.logger.info(f" 📊 Validation criteria: {passed_criteria}/{total_criteria}")
|
self.logger.info(f" 📊 Validation criteria: {passed_criteria}/{total_criteria}")
|
||||||
for criterion, passed in criteria:
|
for criterion, passed in criteria:
|
||||||
status = "✅" if passed else "❌"
|
status = "✅" if passed else "❌"
|
||||||
@@ -484,15 +511,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Check for file embedding logs
|
# Check for file embedding logs
|
||||||
file_embedding_logs = [
|
file_embedding_logs = [
|
||||||
line for line in logs_step3.split('\n')
|
line for line in logs_step3.split("\n") if "tool embedding" in line and "files" in line
|
||||||
if 'tool embedding' in line and 'files' in line
|
|
||||||
]
|
|
||||||
|
|
||||||
conversation_logs = [
|
|
||||||
line for line in logs_step3.split('\n')
|
|
||||||
if 'conversation history' in line.lower()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
conversation_logs = [line for line in logs_step3.split("\n") if "conversation history" in line.lower()]
|
||||||
|
|
||||||
self.logger.info(f" 📊 File embedding logs: {len(file_embedding_logs)}")
|
self.logger.info(f" 📊 File embedding logs: {len(file_embedding_logs)}")
|
||||||
self.logger.info(f" 📊 Conversation history logs: {len(conversation_logs)}")
|
self.logger.info(f" 📊 Conversation history logs: {len(conversation_logs)}")
|
||||||
|
|
||||||
@@ -516,13 +539,13 @@ if __name__ == "__main__":
|
|||||||
def main():
|
def main():
|
||||||
"""Run the token allocation validation test"""
|
"""Run the token allocation validation test"""
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
verbose = "--verbose" in sys.argv or "-v" in sys.argv
|
verbose = "--verbose" in sys.argv or "-v" in sys.argv
|
||||||
test = TokenAllocationValidationTest(verbose=verbose)
|
test = TokenAllocationValidationTest(verbose=verbose)
|
||||||
|
|
||||||
success = test.run_test()
|
success = test.run_test()
|
||||||
sys.exit(0 if success else 1)
|
sys.exit(0 if success else 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
# Tests for Gemini MCP Server
|
# Tests for Zen MCP Server
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Pytest configuration for Gemini MCP Server tests
|
Pytest configuration for Zen MCP Server tests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -27,13 +27,15 @@ os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash-exp"
|
|||||||
|
|
||||||
# Force reload of config module to pick up the env var
|
# Force reload of config module to pick up the env var
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import config
|
import config
|
||||||
|
|
||||||
importlib.reload(config)
|
importlib.reload(config)
|
||||||
|
|
||||||
# Set MCP_PROJECT_ROOT to a temporary directory for tests
|
# Set MCP_PROJECT_ROOT to a temporary directory for tests
|
||||||
# This provides a safe sandbox for file operations during testing
|
# This provides a safe sandbox for file operations during testing
|
||||||
# Create a temporary directory that will be used as the project root for all tests
|
# Create a temporary directory that will be used as the project root for all tests
|
||||||
test_root = tempfile.mkdtemp(prefix="gemini_mcp_test_")
|
test_root = tempfile.mkdtemp(prefix="zen_mcp_test_")
|
||||||
os.environ["MCP_PROJECT_ROOT"] = test_root
|
os.environ["MCP_PROJECT_ROOT"] = test_root
|
||||||
|
|
||||||
# Configure asyncio for Windows compatibility
|
# Configure asyncio for Windows compatibility
|
||||||
@@ -42,9 +44,9 @@ if sys.platform == "win32":
|
|||||||
|
|
||||||
# Register providers for all tests
|
# Register providers for all tests
|
||||||
from providers import ModelProviderRegistry
|
from providers import ModelProviderRegistry
|
||||||
|
from providers.base import ProviderType
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
from providers.base import ProviderType
|
|
||||||
|
|
||||||
# Register providers at test startup
|
# Register providers at test startup
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
"""Helper functions for test mocking."""
|
"""Helper functions for test mocking."""
|
||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
from providers.base import ModelCapabilities, ProviderType
|
|
||||||
|
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
|
||||||
|
|
||||||
|
|
||||||
def create_mock_provider(model_name="gemini-2.0-flash-exp", max_tokens=1_048_576):
|
def create_mock_provider(model_name="gemini-2.0-flash-exp", max_tokens=1_048_576):
|
||||||
"""Create a properly configured mock provider."""
|
"""Create a properly configured mock provider."""
|
||||||
mock_provider = Mock()
|
mock_provider = Mock()
|
||||||
|
|
||||||
# Set up capabilities
|
# Set up capabilities
|
||||||
mock_capabilities = ModelCapabilities(
|
mock_capabilities = ModelCapabilities(
|
||||||
provider=ProviderType.GOOGLE,
|
provider=ProviderType.GOOGLE,
|
||||||
@@ -17,14 +19,14 @@ def create_mock_provider(model_name="gemini-2.0-flash-exp", max_tokens=1_048_576
|
|||||||
supports_system_prompts=True,
|
supports_system_prompts=True,
|
||||||
supports_streaming=True,
|
supports_streaming=True,
|
||||||
supports_function_calling=True,
|
supports_function_calling=True,
|
||||||
temperature_range=(0.0, 2.0),
|
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_provider.get_capabilities.return_value = mock_capabilities
|
mock_provider.get_capabilities.return_value = mock_capabilities
|
||||||
mock_provider.get_provider_type.return_value = ProviderType.GOOGLE
|
mock_provider.get_provider_type.return_value = ProviderType.GOOGLE
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.validate_model_name.return_value = True
|
mock_provider.validate_model_name.return_value = True
|
||||||
|
|
||||||
# Set up generate_content response
|
# Set up generate_content response
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.content = "Test response"
|
mock_response.content = "Test response"
|
||||||
@@ -33,7 +35,7 @@ def create_mock_provider(model_name="gemini-2.0-flash-exp", max_tokens=1_048_576
|
|||||||
mock_response.friendly_name = "Gemini"
|
mock_response.friendly_name = "Gemini"
|
||||||
mock_response.provider = ProviderType.GOOGLE
|
mock_response.provider = ProviderType.GOOGLE
|
||||||
mock_response.metadata = {"finish_reason": "STOP"}
|
mock_response.metadata = {"finish_reason": "STOP"}
|
||||||
|
|
||||||
mock_provider.generate_content.return_value = mock_response
|
mock_provider.generate_content.return_value = mock_response
|
||||||
|
|
||||||
return mock_provider
|
return mock_provider
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Tests for auto mode functionality"""
|
"""Tests for auto mode functionality"""
|
||||||
|
|
||||||
import os
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import patch, Mock
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from mcp.types import TextContent
|
|
||||||
from tools.analyze import AnalyzeTool
|
from tools.analyze import AnalyzeTool
|
||||||
|
|
||||||
|
|
||||||
@@ -16,23 +16,24 @@ class TestAutoMode:
|
|||||||
"""Test that auto mode is detected correctly"""
|
"""Test that auto mode is detected correctly"""
|
||||||
# Save original
|
# Save original
|
||||||
original = os.environ.get("DEFAULT_MODEL", "")
|
original = os.environ.get("DEFAULT_MODEL", "")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test auto mode
|
# Test auto mode
|
||||||
os.environ["DEFAULT_MODEL"] = "auto"
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
import config
|
import config
|
||||||
|
|
||||||
importlib.reload(config)
|
importlib.reload(config)
|
||||||
|
|
||||||
assert config.DEFAULT_MODEL == "auto"
|
assert config.DEFAULT_MODEL == "auto"
|
||||||
assert config.IS_AUTO_MODE is True
|
assert config.IS_AUTO_MODE is True
|
||||||
|
|
||||||
# Test non-auto mode
|
# Test non-auto mode
|
||||||
os.environ["DEFAULT_MODEL"] = "pro"
|
os.environ["DEFAULT_MODEL"] = "pro"
|
||||||
importlib.reload(config)
|
importlib.reload(config)
|
||||||
|
|
||||||
assert config.DEFAULT_MODEL == "pro"
|
assert config.DEFAULT_MODEL == "pro"
|
||||||
assert config.IS_AUTO_MODE is False
|
assert config.IS_AUTO_MODE is False
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore
|
# Restore
|
||||||
if original:
|
if original:
|
||||||
@@ -44,7 +45,7 @@ class TestAutoMode:
|
|||||||
def test_model_capabilities_descriptions(self):
|
def test_model_capabilities_descriptions(self):
|
||||||
"""Test that model capabilities are properly defined"""
|
"""Test that model capabilities are properly defined"""
|
||||||
from config import MODEL_CAPABILITIES_DESC
|
from config import MODEL_CAPABILITIES_DESC
|
||||||
|
|
||||||
# Check all expected models are present
|
# Check all expected models are present
|
||||||
expected_models = ["flash", "pro", "o3", "o3-mini"]
|
expected_models = ["flash", "pro", "o3", "o3-mini"]
|
||||||
for model in expected_models:
|
for model in expected_models:
|
||||||
@@ -56,25 +57,26 @@ class TestAutoMode:
|
|||||||
"""Test that tool schemas require model in auto mode"""
|
"""Test that tool schemas require model in auto mode"""
|
||||||
# Save original
|
# Save original
|
||||||
original = os.environ.get("DEFAULT_MODEL", "")
|
original = os.environ.get("DEFAULT_MODEL", "")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Enable auto mode
|
# Enable auto mode
|
||||||
os.environ["DEFAULT_MODEL"] = "auto"
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
import config
|
import config
|
||||||
|
|
||||||
importlib.reload(config)
|
importlib.reload(config)
|
||||||
|
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
|
|
||||||
# Model should be required
|
# Model should be required
|
||||||
assert "model" in schema["required"]
|
assert "model" in schema["required"]
|
||||||
|
|
||||||
# Model field should have detailed descriptions
|
# Model field should have detailed descriptions
|
||||||
model_schema = schema["properties"]["model"]
|
model_schema = schema["properties"]["model"]
|
||||||
assert "enum" in model_schema
|
assert "enum" in model_schema
|
||||||
assert "flash" in model_schema["enum"]
|
assert "flash" in model_schema["enum"]
|
||||||
assert "Choose the best model" in model_schema["description"]
|
assert "Choose the best model" in model_schema["description"]
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore
|
# Restore
|
||||||
if original:
|
if original:
|
||||||
@@ -88,10 +90,10 @@ class TestAutoMode:
|
|||||||
# This test uses the default from conftest.py which sets non-auto mode
|
# This test uses the default from conftest.py which sets non-auto mode
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
|
|
||||||
# Model should not be required
|
# Model should not be required
|
||||||
assert "model" not in schema["required"]
|
assert "model" not in schema["required"]
|
||||||
|
|
||||||
# Model field should have simpler description
|
# Model field should have simpler description
|
||||||
model_schema = schema["properties"]["model"]
|
model_schema = schema["properties"]["model"]
|
||||||
assert "enum" not in model_schema
|
assert "enum" not in model_schema
|
||||||
@@ -102,29 +104,27 @@ class TestAutoMode:
|
|||||||
"""Test that auto mode enforces model parameter"""
|
"""Test that auto mode enforces model parameter"""
|
||||||
# Save original
|
# Save original
|
||||||
original = os.environ.get("DEFAULT_MODEL", "")
|
original = os.environ.get("DEFAULT_MODEL", "")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Enable auto mode
|
# Enable auto mode
|
||||||
os.environ["DEFAULT_MODEL"] = "auto"
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
import config
|
import config
|
||||||
|
|
||||||
importlib.reload(config)
|
importlib.reload(config)
|
||||||
|
|
||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
|
|
||||||
# Mock the provider to avoid real API calls
|
# Mock the provider to avoid real API calls
|
||||||
with patch.object(tool, 'get_model_provider') as mock_provider:
|
with patch.object(tool, "get_model_provider"):
|
||||||
# Execute without model parameter
|
# Execute without model parameter
|
||||||
result = await tool.execute({
|
result = await tool.execute({"files": ["/tmp/test.py"], "prompt": "Analyze this"})
|
||||||
"files": ["/tmp/test.py"],
|
|
||||||
"prompt": "Analyze this"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Should get error
|
# Should get error
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
response = result[0].text
|
response = result[0].text
|
||||||
assert "error" in response
|
assert "error" in response
|
||||||
assert "Model parameter is required" in response
|
assert "Model parameter is required" in response
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore
|
# Restore
|
||||||
if original:
|
if original:
|
||||||
@@ -136,45 +136,57 @@ class TestAutoMode:
|
|||||||
def test_model_field_schema_generation(self):
|
def test_model_field_schema_generation(self):
|
||||||
"""Test the get_model_field_schema method"""
|
"""Test the get_model_field_schema method"""
|
||||||
from tools.base import BaseTool
|
from tools.base import BaseTool
|
||||||
|
|
||||||
# Create a minimal concrete tool for testing
|
# Create a minimal concrete tool for testing
|
||||||
class TestTool(BaseTool):
|
class TestTool(BaseTool):
|
||||||
def get_name(self): return "test"
|
def get_name(self):
|
||||||
def get_description(self): return "test"
|
return "test"
|
||||||
def get_input_schema(self): return {}
|
|
||||||
def get_system_prompt(self): return ""
|
def get_description(self):
|
||||||
def get_request_model(self): return None
|
return "test"
|
||||||
async def prepare_prompt(self, request): return ""
|
|
||||||
|
def get_input_schema(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_system_prompt(self):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_request_model(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def prepare_prompt(self, request):
|
||||||
|
return ""
|
||||||
|
|
||||||
tool = TestTool()
|
tool = TestTool()
|
||||||
|
|
||||||
# Save original
|
# Save original
|
||||||
original = os.environ.get("DEFAULT_MODEL", "")
|
original = os.environ.get("DEFAULT_MODEL", "")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test auto mode
|
# Test auto mode
|
||||||
os.environ["DEFAULT_MODEL"] = "auto"
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
import config
|
import config
|
||||||
|
|
||||||
importlib.reload(config)
|
importlib.reload(config)
|
||||||
|
|
||||||
schema = tool.get_model_field_schema()
|
schema = tool.get_model_field_schema()
|
||||||
assert "enum" in schema
|
assert "enum" in schema
|
||||||
assert all(model in schema["enum"] for model in ["flash", "pro", "o3"])
|
assert all(model in schema["enum"] for model in ["flash", "pro", "o3"])
|
||||||
assert "Choose the best model" in schema["description"]
|
assert "Choose the best model" in schema["description"]
|
||||||
|
|
||||||
# Test normal mode
|
# Test normal mode
|
||||||
os.environ["DEFAULT_MODEL"] = "pro"
|
os.environ["DEFAULT_MODEL"] = "pro"
|
||||||
importlib.reload(config)
|
importlib.reload(config)
|
||||||
|
|
||||||
schema = tool.get_model_field_schema()
|
schema = tool.get_model_field_schema()
|
||||||
assert "enum" not in schema
|
assert "enum" not in schema
|
||||||
assert "Available:" in schema["description"]
|
assert "Available:" in schema["description"]
|
||||||
assert "'pro'" in schema["description"]
|
assert "'pro'" in schema["description"]
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore
|
# Restore
|
||||||
if original:
|
if original:
|
||||||
os.environ["DEFAULT_MODEL"] = original
|
os.environ["DEFAULT_MODEL"] = original
|
||||||
else:
|
else:
|
||||||
os.environ.pop("DEFAULT_MODEL", None)
|
os.environ.pop("DEFAULT_MODEL", None)
|
||||||
importlib.reload(config)
|
importlib.reload(config)
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ when Gemini doesn't explicitly ask a follow-up question.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
from tools.base import BaseTool, ToolRequest
|
from tools.base import BaseTool, ToolRequest
|
||||||
from tools.models import ContinuationOffer, ToolOutput
|
from tools.models import ContinuationOffer, ToolOutput
|
||||||
from utils.conversation_memory import MAX_CONVERSATION_TURNS
|
from utils.conversation_memory import MAX_CONVERSATION_TURNS
|
||||||
@@ -125,7 +125,7 @@ class TestClaudeContinuationOffers:
|
|||||||
content="Analysis complete. The code looks good.",
|
content="Analysis complete. The code looks good.",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -176,7 +176,7 @@ class TestClaudeContinuationOffers:
|
|||||||
content=content_with_followup,
|
content=content_with_followup,
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -220,7 +220,7 @@ class TestClaudeContinuationOffers:
|
|||||||
content="Continued analysis complete.",
|
content="Continued analysis complete.",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ Tests for dynamic context request and collaboration features
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
from tools.analyze import AnalyzeTool
|
from tools.analyze import AnalyzeTool
|
||||||
from tools.debug import DebugIssueTool
|
from tools.debug import DebugIssueTool
|
||||||
from tools.models import ClarificationRequest, ToolOutput
|
from tools.models import ClarificationRequest, ToolOutput
|
||||||
@@ -41,10 +41,7 @@ class TestDynamicContextRequests:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=clarification_json,
|
content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -85,10 +82,7 @@ class TestDynamicContextRequests:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=normal_response,
|
content=normal_response, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -112,10 +106,7 @@ class TestDynamicContextRequests:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=malformed_json,
|
content=malformed_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -155,10 +146,7 @@ class TestDynamicContextRequests:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=clarification_json,
|
content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -245,10 +233,7 @@ class TestCollaborationWorkflow:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=clarification_json,
|
content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -287,10 +272,7 @@ class TestCollaborationWorkflow:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=clarification_json,
|
content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -317,10 +299,7 @@ class TestCollaborationWorkflow:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content=final_response,
|
content=final_response, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result2 = await tool.execute(
|
result2 = await tool.execute(
|
||||||
|
|||||||
@@ -2,21 +2,20 @@
|
|||||||
Test that conversation history is correctly mapped to tool-specific fields
|
Test that conversation history is correctly mapped to tool-specific fields
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from providers.base import ProviderType
|
||||||
from server import reconstruct_thread_context
|
from server import reconstruct_thread_context
|
||||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||||
from providers.base import ProviderType
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_conversation_history_field_mapping():
|
async def test_conversation_history_field_mapping():
|
||||||
"""Test that enhanced prompts are mapped to prompt field for all tools"""
|
"""Test that enhanced prompts are mapped to prompt field for all tools"""
|
||||||
|
|
||||||
# Test data for different tools - all use 'prompt' now
|
# Test data for different tools - all use 'prompt' now
|
||||||
test_cases = [
|
test_cases = [
|
||||||
{
|
{
|
||||||
@@ -40,7 +39,7 @@ async def test_conversation_history_field_mapping():
|
|||||||
"original_value": "My analysis so far",
|
"original_value": "My analysis so far",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
for test_case in test_cases:
|
for test_case in test_cases:
|
||||||
# Create mock conversation context
|
# Create mock conversation context
|
||||||
mock_context = ThreadContext(
|
mock_context = ThreadContext(
|
||||||
@@ -63,7 +62,7 @@ async def test_conversation_history_field_mapping():
|
|||||||
],
|
],
|
||||||
initial_context={},
|
initial_context={},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock get_thread to return our test context
|
# Mock get_thread to return our test context
|
||||||
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
|
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
|
||||||
with patch("utils.conversation_memory.add_turn", return_value=True):
|
with patch("utils.conversation_memory.add_turn", return_value=True):
|
||||||
@@ -71,43 +70,44 @@ async def test_conversation_history_field_mapping():
|
|||||||
# Mock provider registry to avoid model lookup errors
|
# Mock provider registry to avoid model lookup errors
|
||||||
with patch("providers.registry.ModelProviderRegistry.get_provider_for_model") as mock_get_provider:
|
with patch("providers.registry.ModelProviderRegistry.get_provider_for_model") as mock_get_provider:
|
||||||
from providers.base import ModelCapabilities
|
from providers.base import ModelCapabilities
|
||||||
|
|
||||||
mock_provider = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_provider.get_capabilities.return_value = ModelCapabilities(
|
mock_provider.get_capabilities.return_value = ModelCapabilities(
|
||||||
provider=ProviderType.GOOGLE,
|
provider=ProviderType.GOOGLE,
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
friendly_name="Gemini",
|
friendly_name="Gemini",
|
||||||
max_tokens=200000,
|
max_tokens=200000,
|
||||||
supports_extended_thinking=True
|
supports_extended_thinking=True,
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
# Mock conversation history building
|
# Mock conversation history building
|
||||||
mock_build.return_value = (
|
mock_build.return_value = (
|
||||||
"=== CONVERSATION HISTORY ===\nPrevious conversation content\n=== END HISTORY ===",
|
"=== CONVERSATION HISTORY ===\nPrevious conversation content\n=== END HISTORY ===",
|
||||||
1000 # mock token count
|
1000, # mock token count
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create arguments with continuation_id
|
# Create arguments with continuation_id
|
||||||
arguments = {
|
arguments = {
|
||||||
"continuation_id": "test-thread-123",
|
"continuation_id": "test-thread-123",
|
||||||
"prompt": test_case["original_value"],
|
"prompt": test_case["original_value"],
|
||||||
"files": ["/test/file2.py"],
|
"files": ["/test/file2.py"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Call reconstruct_thread_context
|
# Call reconstruct_thread_context
|
||||||
enhanced_args = await reconstruct_thread_context(arguments)
|
enhanced_args = await reconstruct_thread_context(arguments)
|
||||||
|
|
||||||
# Verify the enhanced prompt is in the prompt field
|
# Verify the enhanced prompt is in the prompt field
|
||||||
assert "prompt" in enhanced_args
|
assert "prompt" in enhanced_args
|
||||||
enhanced_value = enhanced_args["prompt"]
|
enhanced_value = enhanced_args["prompt"]
|
||||||
|
|
||||||
# Should contain conversation history
|
# Should contain conversation history
|
||||||
assert "=== CONVERSATION HISTORY ===" in enhanced_value
|
assert "=== CONVERSATION HISTORY ===" in enhanced_value
|
||||||
assert "Previous conversation content" in enhanced_value
|
assert "Previous conversation content" in enhanced_value
|
||||||
|
|
||||||
# Should contain the new user input
|
# Should contain the new user input
|
||||||
assert "=== NEW USER INPUT ===" in enhanced_value
|
assert "=== NEW USER INPUT ===" in enhanced_value
|
||||||
assert test_case["original_value"] in enhanced_value
|
assert test_case["original_value"] in enhanced_value
|
||||||
|
|
||||||
# Should have token budget
|
# Should have token budget
|
||||||
assert "_remaining_tokens" in enhanced_args
|
assert "_remaining_tokens" in enhanced_args
|
||||||
assert enhanced_args["_remaining_tokens"] > 0
|
assert enhanced_args["_remaining_tokens"] > 0
|
||||||
@@ -116,7 +116,7 @@ async def test_conversation_history_field_mapping():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_tool_defaults_to_prompt():
|
async def test_unknown_tool_defaults_to_prompt():
|
||||||
"""Test that unknown tools default to using 'prompt' field"""
|
"""Test that unknown tools default to using 'prompt' field"""
|
||||||
|
|
||||||
mock_context = ThreadContext(
|
mock_context = ThreadContext(
|
||||||
thread_id="test-thread-456",
|
thread_id="test-thread-456",
|
||||||
tool_name="unknown_tool",
|
tool_name="unknown_tool",
|
||||||
@@ -125,7 +125,7 @@ async def test_unknown_tool_defaults_to_prompt():
|
|||||||
turns=[],
|
turns=[],
|
||||||
initial_context={},
|
initial_context={},
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
|
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
|
||||||
with patch("utils.conversation_memory.add_turn", return_value=True):
|
with patch("utils.conversation_memory.add_turn", return_value=True):
|
||||||
with patch("utils.conversation_memory.build_conversation_history", return_value=("History", 500)):
|
with patch("utils.conversation_memory.build_conversation_history", return_value=("History", 500)):
|
||||||
@@ -133,9 +133,9 @@ async def test_unknown_tool_defaults_to_prompt():
|
|||||||
"continuation_id": "test-thread-456",
|
"continuation_id": "test-thread-456",
|
||||||
"prompt": "User input",
|
"prompt": "User input",
|
||||||
}
|
}
|
||||||
|
|
||||||
enhanced_args = await reconstruct_thread_context(arguments)
|
enhanced_args = await reconstruct_thread_context(arguments)
|
||||||
|
|
||||||
# Should default to 'prompt' field
|
# Should default to 'prompt' field
|
||||||
assert "prompt" in enhanced_args
|
assert "prompt" in enhanced_args
|
||||||
assert "History" in enhanced_args["prompt"]
|
assert "History" in enhanced_args["prompt"]
|
||||||
@@ -145,27 +145,27 @@ async def test_unknown_tool_defaults_to_prompt():
|
|||||||
async def test_tool_parameter_standardization():
|
async def test_tool_parameter_standardization():
|
||||||
"""Test that all tools use standardized 'prompt' parameter"""
|
"""Test that all tools use standardized 'prompt' parameter"""
|
||||||
from tools.analyze import AnalyzeRequest
|
from tools.analyze import AnalyzeRequest
|
||||||
from tools.debug import DebugIssueRequest
|
|
||||||
from tools.codereview import CodeReviewRequest
|
from tools.codereview import CodeReviewRequest
|
||||||
from tools.thinkdeep import ThinkDeepRequest
|
from tools.debug import DebugIssueRequest
|
||||||
from tools.precommit import PrecommitRequest
|
from tools.precommit import PrecommitRequest
|
||||||
|
from tools.thinkdeep import ThinkDeepRequest
|
||||||
|
|
||||||
# Test analyze tool uses prompt
|
# Test analyze tool uses prompt
|
||||||
analyze = AnalyzeRequest(files=["/test.py"], prompt="What does this do?")
|
analyze = AnalyzeRequest(files=["/test.py"], prompt="What does this do?")
|
||||||
assert analyze.prompt == "What does this do?"
|
assert analyze.prompt == "What does this do?"
|
||||||
|
|
||||||
# Test debug tool uses prompt
|
# Test debug tool uses prompt
|
||||||
debug = DebugIssueRequest(prompt="Error occurred")
|
debug = DebugIssueRequest(prompt="Error occurred")
|
||||||
assert debug.prompt == "Error occurred"
|
assert debug.prompt == "Error occurred"
|
||||||
|
|
||||||
# Test codereview tool uses prompt
|
# Test codereview tool uses prompt
|
||||||
review = CodeReviewRequest(files=["/test.py"], prompt="Review this")
|
review = CodeReviewRequest(files=["/test.py"], prompt="Review this")
|
||||||
assert review.prompt == "Review this"
|
assert review.prompt == "Review this"
|
||||||
|
|
||||||
# Test thinkdeep tool uses prompt
|
# Test thinkdeep tool uses prompt
|
||||||
think = ThinkDeepRequest(prompt="My analysis")
|
think = ThinkDeepRequest(prompt="My analysis")
|
||||||
assert think.prompt == "My analysis"
|
assert think.prompt == "My analysis"
|
||||||
|
|
||||||
# Test precommit tool uses prompt (optional)
|
# Test precommit tool uses prompt (optional)
|
||||||
precommit = PrecommitRequest(path="/repo", prompt="Fix bug")
|
precommit = PrecommitRequest(path="/repo", prompt="Fix bug")
|
||||||
assert precommit.prompt == "Fix bug"
|
assert precommit.prompt == "Fix bug"
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ Claude had shared in earlier turns.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
from tools.base import BaseTool, ToolRequest
|
from tools.base import BaseTool, ToolRequest
|
||||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||||
|
|
||||||
@@ -116,7 +116,7 @@ class TestConversationHistoryBugFix:
|
|||||||
content="Response with conversation context",
|
content="Response with conversation context",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_provider.generate_content.side_effect = capture_prompt
|
mock_provider.generate_content.side_effect = capture_prompt
|
||||||
@@ -176,7 +176,7 @@ class TestConversationHistoryBugFix:
|
|||||||
content="Response without history",
|
content="Response without history",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_provider.generate_content.side_effect = capture_prompt
|
mock_provider.generate_content.side_effect = capture_prompt
|
||||||
@@ -214,7 +214,7 @@ class TestConversationHistoryBugFix:
|
|||||||
content="New conversation response",
|
content="New conversation response",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_provider.generate_content.side_effect = capture_prompt
|
mock_provider.generate_content.side_effect = capture_prompt
|
||||||
@@ -298,7 +298,7 @@ class TestConversationHistoryBugFix:
|
|||||||
content="Analysis of new files complete",
|
content="Analysis of new files complete",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_provider.generate_content.side_effect = capture_prompt
|
mock_provider.generate_content.side_effect = capture_prompt
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ allowing multi-turn conversations to span multiple tool types.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
from tools.base import BaseTool, ToolRequest
|
from tools.base import BaseTool, ToolRequest
|
||||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||||
|
|
||||||
@@ -117,7 +117,7 @@ class TestCrossToolContinuation:
|
|||||||
content=content_with_followup,
|
content=content_with_followup,
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -165,7 +165,7 @@ class TestCrossToolContinuation:
|
|||||||
content="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks.",
|
content="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks.",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -285,7 +285,7 @@ class TestCrossToolContinuation:
|
|||||||
content="Security review of auth.py shows vulnerabilities",
|
content="Security review of auth.py shows vulnerabilities",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from mcp.types import TextContent
|
from mcp.types import TextContent
|
||||||
@@ -77,7 +76,7 @@ class TestLargePromptHandling:
|
|||||||
content="This is a test response",
|
content="This is a test response",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -102,7 +101,7 @@ class TestLargePromptHandling:
|
|||||||
content="Processed large prompt",
|
content="Processed large prompt",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -214,7 +213,7 @@ class TestLargePromptHandling:
|
|||||||
content="Success",
|
content="Success",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -247,7 +246,7 @@ class TestLargePromptHandling:
|
|||||||
content="Success",
|
content="Success",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -278,7 +277,7 @@ class TestLargePromptHandling:
|
|||||||
content="Success",
|
content="Success",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -300,7 +299,7 @@ class TestLargePromptHandling:
|
|||||||
content="Success",
|
content="Success",
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
|
|||||||
@@ -1,141 +0,0 @@
|
|||||||
"""
|
|
||||||
Live integration tests for google-genai library
|
|
||||||
These tests require GEMINI_API_KEY to be set and will make real API calls
|
|
||||||
|
|
||||||
To run these tests manually:
|
|
||||||
python tests/test_live_integration.py
|
|
||||||
|
|
||||||
Note: These tests are excluded from regular pytest runs to avoid API rate limits.
|
|
||||||
They confirm that the google-genai library integration works correctly with live data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path to allow imports
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from tools.analyze import AnalyzeTool
|
|
||||||
from tools.thinkdeep import ThinkDeepTool
|
|
||||||
|
|
||||||
|
|
||||||
async def run_manual_live_tests():
|
|
||||||
"""Run live tests manually without pytest"""
|
|
||||||
print("🚀 Running manual live integration tests...")
|
|
||||||
|
|
||||||
# Check API key
|
|
||||||
if not os.environ.get("GEMINI_API_KEY"):
|
|
||||||
print("❌ GEMINI_API_KEY not found. Set it to run live tests.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test google-genai import
|
|
||||||
|
|
||||||
print("✅ google-genai library import successful")
|
|
||||||
|
|
||||||
# Test tool integration
|
|
||||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
|
||||||
f.write("def hello(): return 'world'")
|
|
||||||
temp_path = f.name
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test AnalyzeTool
|
|
||||||
tool = AnalyzeTool()
|
|
||||||
result = await tool.execute(
|
|
||||||
{
|
|
||||||
"files": [temp_path],
|
|
||||||
"prompt": "What does this code do?",
|
|
||||||
"thinking_mode": "low",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if result and result[0].text:
|
|
||||||
print("✅ AnalyzeTool live test successful")
|
|
||||||
else:
|
|
||||||
print("❌ AnalyzeTool live test failed")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Test ThinkDeepTool
|
|
||||||
think_tool = ThinkDeepTool()
|
|
||||||
result = await think_tool.execute(
|
|
||||||
{
|
|
||||||
"prompt": "Testing live integration",
|
|
||||||
"thinking_mode": "minimal", # Fast test
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if result and result[0].text and "Extended Analysis" in result[0].text:
|
|
||||||
print("✅ ThinkDeepTool live test successful")
|
|
||||||
else:
|
|
||||||
print("❌ ThinkDeepTool live test failed")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Test collaboration/clarification request
|
|
||||||
print("\n🔄 Testing dynamic context request (collaboration)...")
|
|
||||||
|
|
||||||
# Create a specific test case designed to trigger clarification
|
|
||||||
# We'll use analyze tool with a question that requires seeing files
|
|
||||||
analyze_tool = AnalyzeTool()
|
|
||||||
|
|
||||||
# Ask about dependencies without providing package files
|
|
||||||
result = await analyze_tool.execute(
|
|
||||||
{
|
|
||||||
"files": [temp_path], # Only Python file, no package.json
|
|
||||||
"prompt": "What npm packages and their versions does this project depend on? List all dependencies.",
|
|
||||||
"thinking_mode": "minimal", # Fast test
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if result and result[0].text:
|
|
||||||
response_data = json.loads(result[0].text)
|
|
||||||
print(f" Response status: {response_data['status']}")
|
|
||||||
|
|
||||||
if response_data["status"] == "requires_clarification":
|
|
||||||
print("✅ Dynamic context request successfully triggered!")
|
|
||||||
clarification = json.loads(response_data["content"])
|
|
||||||
print(f" Gemini asks: {clarification.get('question', 'N/A')}")
|
|
||||||
if "files_needed" in clarification:
|
|
||||||
print(f" Files requested: {clarification['files_needed']}")
|
|
||||||
# Verify it's asking for package-related files
|
|
||||||
expected_files = [
|
|
||||||
"package.json",
|
|
||||||
"package-lock.json",
|
|
||||||
"yarn.lock",
|
|
||||||
]
|
|
||||||
if any(f in str(clarification["files_needed"]) for f in expected_files):
|
|
||||||
print(" ✅ Correctly identified missing package files!")
|
|
||||||
else:
|
|
||||||
print(" ⚠️ Unexpected files requested")
|
|
||||||
else:
|
|
||||||
# This is a failure - we specifically designed this to need clarification
|
|
||||||
print("❌ Expected clarification request but got direct response")
|
|
||||||
print(" This suggests the dynamic context feature may not be working")
|
|
||||||
print(" Response:", response_data.get("content", "")[:200])
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
print("❌ Collaboration test failed - no response")
|
|
||||||
return False
|
|
||||||
|
|
||||||
finally:
|
|
||||||
Path(temp_path).unlink(missing_ok=True)
|
|
||||||
|
|
||||||
print("\n🎉 All manual live tests passed!")
|
|
||||||
print("✅ google-genai library working correctly")
|
|
||||||
print("✅ All tools can make live API calls")
|
|
||||||
print("✅ Thinking modes functioning properly")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Live test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Run live tests when script is executed directly
|
|
||||||
success = asyncio.run(run_manual_live_tests())
|
|
||||||
exit(0 if success else 1)
|
|
||||||
@@ -167,9 +167,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
|||||||
add_turn(thread_id, "assistant", "First response", files=[config_path], tool_name="precommit")
|
add_turn(thread_id, "assistant", "First response", files=[config_path], tool_name="precommit")
|
||||||
|
|
||||||
# Second request with continuation - should skip already embedded files
|
# Second request with continuation - should skip already embedded files
|
||||||
PrecommitRequest(
|
PrecommitRequest(path=temp_dir, files=[config_path], continuation_id=thread_id, prompt="Follow-up review")
|
||||||
path=temp_dir, files=[config_path], continuation_id=thread_id, prompt="Follow-up review"
|
|
||||||
)
|
|
||||||
|
|
||||||
files_to_embed_2 = tool.filter_new_files([config_path], thread_id)
|
files_to_embed_2 = tool.filter_new_files([config_path], thread_id)
|
||||||
assert len(files_to_embed_2) == 0, "Continuation should skip already embedded files"
|
assert len(files_to_embed_2) == 0, "Continuation should skip already embedded files"
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ normal-sized prompts after implementing the large prompt handling feature.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -33,7 +32,7 @@ class TestPromptRegression:
|
|||||||
content=text,
|
content=text,
|
||||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||||
model_name="gemini-2.0-flash-exp",
|
model_name="gemini-2.0-flash-exp",
|
||||||
metadata={"finish_reason": "STOP"}
|
metadata={"finish_reason": "STOP"},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _create_response
|
return _create_response
|
||||||
@@ -47,7 +46,9 @@ class TestPromptRegression:
|
|||||||
mock_provider = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = mock_model_response("This is a helpful response about Python.")
|
mock_provider.generate_content.return_value = mock_model_response(
|
||||||
|
"This is a helpful response about Python."
|
||||||
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
result = await tool.execute({"prompt": "Explain Python decorators"})
|
result = await tool.execute({"prompt": "Explain Python decorators"})
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
"""Tests for the model provider abstraction system"""
|
"""Tests for the model provider abstraction system"""
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
import os
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from providers import ModelProviderRegistry, ModelProvider, ModelResponse, ModelCapabilities
|
from providers import ModelProviderRegistry, ModelResponse
|
||||||
from providers.base import ProviderType
|
from providers.base import ProviderType
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai import OpenAIModelProvider
|
from providers.openai import OpenAIModelProvider
|
||||||
@@ -12,56 +11,56 @@ from providers.openai import OpenAIModelProvider
|
|||||||
|
|
||||||
class TestModelProviderRegistry:
|
class TestModelProviderRegistry:
|
||||||
"""Test the model provider registry"""
|
"""Test the model provider registry"""
|
||||||
|
|
||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
"""Clear registry before each test"""
|
"""Clear registry before each test"""
|
||||||
ModelProviderRegistry._providers.clear()
|
ModelProviderRegistry._providers.clear()
|
||||||
ModelProviderRegistry._initialized_providers.clear()
|
ModelProviderRegistry._initialized_providers.clear()
|
||||||
|
|
||||||
def test_register_provider(self):
|
def test_register_provider(self):
|
||||||
"""Test registering a provider"""
|
"""Test registering a provider"""
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
assert ProviderType.GOOGLE in ModelProviderRegistry._providers
|
assert ProviderType.GOOGLE in ModelProviderRegistry._providers
|
||||||
assert ModelProviderRegistry._providers[ProviderType.GOOGLE] == GeminiModelProvider
|
assert ModelProviderRegistry._providers[ProviderType.GOOGLE] == GeminiModelProvider
|
||||||
|
|
||||||
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
|
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
|
||||||
def test_get_provider(self):
|
def test_get_provider(self):
|
||||||
"""Test getting a provider instance"""
|
"""Test getting a provider instance"""
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
||||||
|
|
||||||
assert provider is not None
|
assert provider is not None
|
||||||
assert isinstance(provider, GeminiModelProvider)
|
assert isinstance(provider, GeminiModelProvider)
|
||||||
assert provider.api_key == "test-key"
|
assert provider.api_key == "test-key"
|
||||||
|
|
||||||
@patch.dict(os.environ, {}, clear=True)
|
@patch.dict(os.environ, {}, clear=True)
|
||||||
def test_get_provider_no_api_key(self):
|
def test_get_provider_no_api_key(self):
|
||||||
"""Test getting provider without API key returns None"""
|
"""Test getting provider without API key returns None"""
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
||||||
|
|
||||||
assert provider is None
|
assert provider is None
|
||||||
|
|
||||||
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
|
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
|
||||||
def test_get_provider_for_model(self):
|
def test_get_provider_for_model(self):
|
||||||
"""Test getting provider for a specific model"""
|
"""Test getting provider for a specific model"""
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp")
|
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp")
|
||||||
|
|
||||||
assert provider is not None
|
assert provider is not None
|
||||||
assert isinstance(provider, GeminiModelProvider)
|
assert isinstance(provider, GeminiModelProvider)
|
||||||
|
|
||||||
def test_get_available_providers(self):
|
def test_get_available_providers(self):
|
||||||
"""Test getting list of available providers"""
|
"""Test getting list of available providers"""
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
providers = ModelProviderRegistry.get_available_providers()
|
providers = ModelProviderRegistry.get_available_providers()
|
||||||
|
|
||||||
assert len(providers) == 2
|
assert len(providers) == 2
|
||||||
assert ProviderType.GOOGLE in providers
|
assert ProviderType.GOOGLE in providers
|
||||||
assert ProviderType.OPENAI in providers
|
assert ProviderType.OPENAI in providers
|
||||||
@@ -69,50 +68,50 @@ class TestModelProviderRegistry:
|
|||||||
|
|
||||||
class TestGeminiProvider:
|
class TestGeminiProvider:
|
||||||
"""Test Gemini model provider"""
|
"""Test Gemini model provider"""
|
||||||
|
|
||||||
def test_provider_initialization(self):
|
def test_provider_initialization(self):
|
||||||
"""Test provider initialization"""
|
"""Test provider initialization"""
|
||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
assert provider.api_key == "test-key"
|
assert provider.api_key == "test-key"
|
||||||
assert provider.get_provider_type() == ProviderType.GOOGLE
|
assert provider.get_provider_type() == ProviderType.GOOGLE
|
||||||
|
|
||||||
def test_get_capabilities(self):
|
def test_get_capabilities(self):
|
||||||
"""Test getting model capabilities"""
|
"""Test getting model capabilities"""
|
||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
capabilities = provider.get_capabilities("gemini-2.0-flash-exp")
|
capabilities = provider.get_capabilities("gemini-2.0-flash-exp")
|
||||||
|
|
||||||
assert capabilities.provider == ProviderType.GOOGLE
|
assert capabilities.provider == ProviderType.GOOGLE
|
||||||
assert capabilities.model_name == "gemini-2.0-flash-exp"
|
assert capabilities.model_name == "gemini-2.0-flash-exp"
|
||||||
assert capabilities.max_tokens == 1_048_576
|
assert capabilities.max_tokens == 1_048_576
|
||||||
assert not capabilities.supports_extended_thinking
|
assert not capabilities.supports_extended_thinking
|
||||||
|
|
||||||
def test_get_capabilities_pro_model(self):
|
def test_get_capabilities_pro_model(self):
|
||||||
"""Test getting capabilities for Pro model with thinking support"""
|
"""Test getting capabilities for Pro model with thinking support"""
|
||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
capabilities = provider.get_capabilities("gemini-2.5-pro-preview-06-05")
|
capabilities = provider.get_capabilities("gemini-2.5-pro-preview-06-05")
|
||||||
|
|
||||||
assert capabilities.supports_extended_thinking
|
assert capabilities.supports_extended_thinking
|
||||||
|
|
||||||
def test_model_shorthand_resolution(self):
|
def test_model_shorthand_resolution(self):
|
||||||
"""Test model shorthand resolution"""
|
"""Test model shorthand resolution"""
|
||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
assert provider.validate_model_name("flash")
|
assert provider.validate_model_name("flash")
|
||||||
assert provider.validate_model_name("pro")
|
assert provider.validate_model_name("pro")
|
||||||
|
|
||||||
capabilities = provider.get_capabilities("flash")
|
capabilities = provider.get_capabilities("flash")
|
||||||
assert capabilities.model_name == "gemini-2.0-flash-exp"
|
assert capabilities.model_name == "gemini-2.0-flash-exp"
|
||||||
|
|
||||||
def test_supports_thinking_mode(self):
|
def test_supports_thinking_mode(self):
|
||||||
"""Test thinking mode support detection"""
|
"""Test thinking mode support detection"""
|
||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
assert not provider.supports_thinking_mode("gemini-2.0-flash-exp")
|
assert not provider.supports_thinking_mode("gemini-2.0-flash-exp")
|
||||||
assert provider.supports_thinking_mode("gemini-2.5-pro-preview-06-05")
|
assert provider.supports_thinking_mode("gemini-2.5-pro-preview-06-05")
|
||||||
|
|
||||||
@patch("google.genai.Client")
|
@patch("google.genai.Client")
|
||||||
def test_generate_content(self, mock_client_class):
|
def test_generate_content(self, mock_client_class):
|
||||||
"""Test content generation"""
|
"""Test content generation"""
|
||||||
@@ -131,15 +130,11 @@ class TestGeminiProvider:
|
|||||||
mock_response.usage_metadata = mock_usage
|
mock_response.usage_metadata = mock_usage
|
||||||
mock_client.models.generate_content.return_value = mock_response
|
mock_client.models.generate_content.return_value = mock_response
|
||||||
mock_client_class.return_value = mock_client
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
provider = GeminiModelProvider(api_key="test-key")
|
provider = GeminiModelProvider(api_key="test-key")
|
||||||
|
|
||||||
response = provider.generate_content(
|
response = provider.generate_content(prompt="Test prompt", model_name="gemini-2.0-flash-exp", temperature=0.7)
|
||||||
prompt="Test prompt",
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
temperature=0.7
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, ModelResponse)
|
assert isinstance(response, ModelResponse)
|
||||||
assert response.content == "Generated content"
|
assert response.content == "Generated content"
|
||||||
assert response.model_name == "gemini-2.0-flash-exp"
|
assert response.model_name == "gemini-2.0-flash-exp"
|
||||||
@@ -151,38 +146,38 @@ class TestGeminiProvider:
|
|||||||
|
|
||||||
class TestOpenAIProvider:
|
class TestOpenAIProvider:
|
||||||
"""Test OpenAI model provider"""
|
"""Test OpenAI model provider"""
|
||||||
|
|
||||||
def test_provider_initialization(self):
|
def test_provider_initialization(self):
|
||||||
"""Test provider initialization"""
|
"""Test provider initialization"""
|
||||||
provider = OpenAIModelProvider(api_key="test-key", organization="test-org")
|
provider = OpenAIModelProvider(api_key="test-key", organization="test-org")
|
||||||
|
|
||||||
assert provider.api_key == "test-key"
|
assert provider.api_key == "test-key"
|
||||||
assert provider.organization == "test-org"
|
assert provider.organization == "test-org"
|
||||||
assert provider.get_provider_type() == ProviderType.OPENAI
|
assert provider.get_provider_type() == ProviderType.OPENAI
|
||||||
|
|
||||||
def test_get_capabilities_o3(self):
|
def test_get_capabilities_o3(self):
|
||||||
"""Test getting O3 model capabilities"""
|
"""Test getting O3 model capabilities"""
|
||||||
provider = OpenAIModelProvider(api_key="test-key")
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
capabilities = provider.get_capabilities("o3-mini")
|
capabilities = provider.get_capabilities("o3-mini")
|
||||||
|
|
||||||
assert capabilities.provider == ProviderType.OPENAI
|
assert capabilities.provider == ProviderType.OPENAI
|
||||||
assert capabilities.model_name == "o3-mini"
|
assert capabilities.model_name == "o3-mini"
|
||||||
assert capabilities.max_tokens == 200_000
|
assert capabilities.max_tokens == 200_000
|
||||||
assert not capabilities.supports_extended_thinking
|
assert not capabilities.supports_extended_thinking
|
||||||
|
|
||||||
def test_validate_model_names(self):
|
def test_validate_model_names(self):
|
||||||
"""Test model name validation"""
|
"""Test model name validation"""
|
||||||
provider = OpenAIModelProvider(api_key="test-key")
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
assert provider.validate_model_name("o3")
|
assert provider.validate_model_name("o3")
|
||||||
assert provider.validate_model_name("o3-mini")
|
assert provider.validate_model_name("o3-mini")
|
||||||
assert not provider.validate_model_name("gpt-4o")
|
assert not provider.validate_model_name("gpt-4o")
|
||||||
assert not provider.validate_model_name("invalid-model")
|
assert not provider.validate_model_name("invalid-model")
|
||||||
|
|
||||||
def test_no_thinking_mode_support(self):
|
def test_no_thinking_mode_support(self):
|
||||||
"""Test that no OpenAI models support thinking mode"""
|
"""Test that no OpenAI models support thinking mode"""
|
||||||
provider = OpenAIModelProvider(api_key="test-key")
|
provider = OpenAIModelProvider(api_key="test-key")
|
||||||
|
|
||||||
assert not provider.supports_thinking_mode("o3")
|
assert not provider.supports_thinking_mode("o3")
|
||||||
assert not provider.supports_thinking_mode("o3-mini")
|
assert not provider.supports_thinking_mode("o3-mini")
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ Tests for the main server functionality
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from server import handle_call_tool, handle_list_tools
|
from server import handle_call_tool, handle_list_tools
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
|
|
||||||
|
|
||||||
class TestServerTools:
|
class TestServerTools:
|
||||||
@@ -56,10 +56,7 @@ class TestServerTools:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Chat response",
|
content="Chat response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -81,6 +78,6 @@ class TestServerTools:
|
|||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
|
||||||
response = result[0].text
|
response = result[0].text
|
||||||
assert "Gemini MCP Server v" in response # Version agnostic check
|
assert "Zen MCP Server v" in response # Version agnostic check
|
||||||
assert "Available Tools:" in response
|
assert "Available Tools:" in response
|
||||||
assert "thinkdeep" in response
|
assert "thinkdeep" in response
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ Tests for thinking_mode functionality across all tools
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
from tools.analyze import AnalyzeTool
|
from tools.analyze import AnalyzeTool
|
||||||
from tools.codereview import CodeReviewTool
|
from tools.codereview import CodeReviewTool
|
||||||
from tools.debug import DebugIssueTool
|
from tools.debug import DebugIssueTool
|
||||||
@@ -45,10 +45,7 @@ class TestThinkingModes:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = True
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Minimal thinking response",
|
content="Minimal thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -66,7 +63,9 @@ class TestThinkingModes:
|
|||||||
# Verify generate_content was called with thinking_mode
|
# Verify generate_content was called with thinking_mode
|
||||||
mock_provider.generate_content.assert_called_once()
|
mock_provider.generate_content.assert_called_once()
|
||||||
call_kwargs = mock_provider.generate_content.call_args[1]
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
assert call_kwargs.get("thinking_mode") == "minimal" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None) # thinking_mode parameter
|
assert call_kwargs.get("thinking_mode") == "minimal" or (
|
||||||
|
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
|
||||||
|
) # thinking_mode parameter
|
||||||
|
|
||||||
# Parse JSON response
|
# Parse JSON response
|
||||||
import json
|
import json
|
||||||
@@ -83,10 +82,7 @@ class TestThinkingModes:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = True
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Low thinking response",
|
content="Low thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -104,7 +100,9 @@ class TestThinkingModes:
|
|||||||
# Verify generate_content was called with thinking_mode
|
# Verify generate_content was called with thinking_mode
|
||||||
mock_provider.generate_content.assert_called_once()
|
mock_provider.generate_content.assert_called_once()
|
||||||
call_kwargs = mock_provider.generate_content.call_args[1]
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
assert call_kwargs.get("thinking_mode") == "low" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)
|
assert call_kwargs.get("thinking_mode") == "low" or (
|
||||||
|
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
|
||||||
|
)
|
||||||
|
|
||||||
assert "Code Review" in result[0].text
|
assert "Code Review" in result[0].text
|
||||||
|
|
||||||
@@ -116,10 +114,7 @@ class TestThinkingModes:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = True
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Medium thinking response",
|
content="Medium thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -136,7 +131,9 @@ class TestThinkingModes:
|
|||||||
# Verify generate_content was called with thinking_mode
|
# Verify generate_content was called with thinking_mode
|
||||||
mock_provider.generate_content.assert_called_once()
|
mock_provider.generate_content.assert_called_once()
|
||||||
call_kwargs = mock_provider.generate_content.call_args[1]
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
assert call_kwargs.get("thinking_mode") == "medium" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)
|
assert call_kwargs.get("thinking_mode") == "medium" or (
|
||||||
|
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
|
||||||
|
)
|
||||||
|
|
||||||
assert "Debug Analysis" in result[0].text
|
assert "Debug Analysis" in result[0].text
|
||||||
|
|
||||||
@@ -148,10 +145,7 @@ class TestThinkingModes:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = True
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="High thinking response",
|
content="High thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -169,7 +163,9 @@ class TestThinkingModes:
|
|||||||
# Verify generate_content was called with thinking_mode
|
# Verify generate_content was called with thinking_mode
|
||||||
mock_provider.generate_content.assert_called_once()
|
mock_provider.generate_content.assert_called_once()
|
||||||
call_kwargs = mock_provider.generate_content.call_args[1]
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
assert call_kwargs.get("thinking_mode") == "high" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)
|
assert call_kwargs.get("thinking_mode") == "high" or (
|
||||||
|
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("tools.base.BaseTool.get_model_provider")
|
@patch("tools.base.BaseTool.get_model_provider")
|
||||||
@@ -179,10 +175,7 @@ class TestThinkingModes:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = True
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Max thinking response",
|
content="Max thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -199,7 +192,9 @@ class TestThinkingModes:
|
|||||||
# Verify generate_content was called with thinking_mode
|
# Verify generate_content was called with thinking_mode
|
||||||
mock_provider.generate_content.assert_called_once()
|
mock_provider.generate_content.assert_called_once()
|
||||||
call_kwargs = mock_provider.generate_content.call_args[1]
|
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||||
assert call_kwargs.get("thinking_mode") == "high" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)
|
assert call_kwargs.get("thinking_mode") == "high" or (
|
||||||
|
not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None
|
||||||
|
)
|
||||||
|
|
||||||
assert "Extended Analysis by Gemini" in result[0].text
|
assert "Extended Analysis by Gemini" in result[0].text
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ Tests for individual tool implementations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from tests.mock_helpers import create_mock_provider
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.mock_helpers import create_mock_provider
|
||||||
from tools import AnalyzeTool, ChatTool, CodeReviewTool, DebugIssueTool, ThinkDeepTool
|
from tools import AnalyzeTool, ChatTool, CodeReviewTool, DebugIssueTool, ThinkDeepTool
|
||||||
|
|
||||||
|
|
||||||
@@ -37,10 +37,7 @@ class TestThinkDeepTool:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = True
|
mock_provider.supports_thinking_mode.return_value = True
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Extended analysis",
|
content="Extended analysis", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -91,10 +88,7 @@ class TestCodeReviewTool:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Security issues found",
|
content="Security issues found", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -139,10 +133,7 @@ class TestDebugIssueTool:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Root cause: race condition",
|
content="Root cause: race condition", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -190,10 +181,7 @@ class TestAnalyzeTool:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Architecture analysis",
|
content="Architecture analysis", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
@@ -307,10 +295,7 @@ class TestAbsolutePathValidation:
|
|||||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||||
mock_provider.supports_thinking_mode.return_value = False
|
mock_provider.supports_thinking_mode.return_value = False
|
||||||
mock_provider.generate_content.return_value = Mock(
|
mock_provider.generate_content.return_value = Mock(
|
||||||
content="Analysis complete",
|
content="Analysis complete", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
|
||||||
usage={},
|
|
||||||
model_name="gemini-2.0-flash-exp",
|
|
||||||
metadata={}
|
|
||||||
)
|
)
|
||||||
mock_get_provider.return_value = mock_provider
|
mock_get_provider.return_value = mock_provider
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Tool implementations for Gemini MCP Server
|
Tool implementations for Zen MCP Server
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .analyze import AnalyzeTool
|
from .analyze import AnalyzeTool
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ class AnalyzeTool(BaseTool):
|
|||||||
},
|
},
|
||||||
"required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
|
"required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
|
|||||||
195
tools/base.py
195
tools/base.py
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Base class for all Gemini MCP tools
|
Base class for all Zen MCP tools
|
||||||
|
|
||||||
This module provides the abstract base class that all tools must inherit from.
|
This module provides the abstract base class that all tools must inherit from.
|
||||||
It defines the contract that tools must implement and provides common functionality
|
It defines the contract that tools must implement and provides common functionality
|
||||||
@@ -24,8 +24,8 @@ from mcp.types import TextContent
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from config import DEFAULT_MODEL, MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
|
from config import DEFAULT_MODEL, MAX_CONTEXT_TOKENS, MCP_PROMPT_SIZE_LIMIT
|
||||||
|
from providers import ModelProvider, ModelProviderRegistry
|
||||||
from utils import check_token_limit
|
from utils import check_token_limit
|
||||||
from providers import ModelProviderRegistry, ModelProvider, ModelResponse
|
|
||||||
from utils.conversation_memory import (
|
from utils.conversation_memory import (
|
||||||
MAX_CONVERSATION_TURNS,
|
MAX_CONVERSATION_TURNS,
|
||||||
add_turn,
|
add_turn,
|
||||||
@@ -146,21 +146,21 @@ class BaseTool(ABC):
|
|||||||
def get_model_field_schema(self) -> dict[str, Any]:
|
def get_model_field_schema(self) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Generate the model field schema based on auto mode configuration.
|
Generate the model field schema based on auto mode configuration.
|
||||||
|
|
||||||
When auto mode is enabled, the model parameter becomes required
|
When auto mode is enabled, the model parameter becomes required
|
||||||
and includes detailed descriptions of each model's capabilities.
|
and includes detailed descriptions of each model's capabilities.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict containing the model field JSON schema
|
Dict containing the model field JSON schema
|
||||||
"""
|
"""
|
||||||
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
|
from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
|
||||||
|
|
||||||
if IS_AUTO_MODE:
|
if IS_AUTO_MODE:
|
||||||
# In auto mode, model is required and we provide detailed descriptions
|
# In auto mode, model is required and we provide detailed descriptions
|
||||||
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
|
model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
|
||||||
for model, desc in MODEL_CAPABILITIES_DESC.items():
|
for model, desc in MODEL_CAPABILITIES_DESC.items():
|
||||||
model_desc_parts.append(f"- '{model}': {desc}")
|
model_desc_parts.append(f"- '{model}': {desc}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "\n".join(model_desc_parts),
|
"description": "\n".join(model_desc_parts),
|
||||||
@@ -169,12 +169,12 @@ class BaseTool(ABC):
|
|||||||
else:
|
else:
|
||||||
# Normal mode - model is optional with default
|
# Normal mode - model is optional with default
|
||||||
available_models = list(MODEL_CAPABILITIES_DESC.keys())
|
available_models = list(MODEL_CAPABILITIES_DESC.keys())
|
||||||
models_str = ', '.join(f"'{m}'" for m in available_models)
|
models_str = ", ".join(f"'{m}'" for m in available_models)
|
||||||
return {
|
return {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": f"Model to use. Available: {models_str}. Defaults to '{DEFAULT_MODEL}' if not specified.",
|
"description": f"Model to use. Available: {models_str}. Defaults to '{DEFAULT_MODEL}' if not specified.",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_default_temperature(self) -> float:
|
def get_default_temperature(self) -> float:
|
||||||
"""
|
"""
|
||||||
Return the default temperature setting for this tool.
|
Return the default temperature setting for this tool.
|
||||||
@@ -257,9 +257,7 @@ class BaseTool(ABC):
|
|||||||
# Safety check: If no files are marked as embedded but we have a continuation_id,
|
# Safety check: If no files are marked as embedded but we have a continuation_id,
|
||||||
# this might indicate an issue with conversation history. Be conservative.
|
# this might indicate an issue with conversation history. Be conservative.
|
||||||
if not embedded_files:
|
if not embedded_files:
|
||||||
logger.debug(
|
logger.debug(f"{self.name} tool: No files found in conversation history for thread {continuation_id}")
|
||||||
f"{self.name} tool: No files found in conversation history for thread {continuation_id}"
|
|
||||||
)
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[FILES] {self.name}: No embedded files found, returning all {len(requested_files)} requested files"
|
f"[FILES] {self.name}: No embedded files found, returning all {len(requested_files)} requested files"
|
||||||
)
|
)
|
||||||
@@ -324,7 +322,7 @@ class BaseTool(ABC):
|
|||||||
"""
|
"""
|
||||||
if not request_files:
|
if not request_files:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Note: Even if conversation history is already embedded, we still need to process
|
# Note: Even if conversation history is already embedded, we still need to process
|
||||||
# any NEW files that aren't in the conversation history yet. The filter_new_files
|
# any NEW files that aren't in the conversation history yet. The filter_new_files
|
||||||
# method will correctly identify which files need to be embedded.
|
# method will correctly identify which files need to be embedded.
|
||||||
@@ -345,48 +343,60 @@ class BaseTool(ABC):
|
|||||||
# First check if model_context was passed from server.py
|
# First check if model_context was passed from server.py
|
||||||
model_context = None
|
model_context = None
|
||||||
if arguments:
|
if arguments:
|
||||||
model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get("_model_context")
|
model_context = arguments.get("_model_context") or getattr(self, "_current_arguments", {}).get(
|
||||||
|
"_model_context"
|
||||||
|
)
|
||||||
|
|
||||||
if model_context:
|
if model_context:
|
||||||
# Use the passed model context
|
# Use the passed model context
|
||||||
try:
|
try:
|
||||||
token_allocation = model_context.calculate_token_allocation()
|
token_allocation = model_context.calculate_token_allocation()
|
||||||
effective_max_tokens = token_allocation.file_tokens - reserve_tokens
|
effective_max_tokens = token_allocation.file_tokens - reserve_tokens
|
||||||
logger.debug(f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
|
logger.debug(
|
||||||
f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total")
|
f"[FILES] {self.name}: Using passed model context for {model_context.model_name}: "
|
||||||
|
f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[FILES] {self.name}: Error using passed model context: {e}")
|
logger.warning(f"[FILES] {self.name}: Error using passed model context: {e}")
|
||||||
# Fall through to manual calculation
|
# Fall through to manual calculation
|
||||||
model_context = None
|
model_context = None
|
||||||
|
|
||||||
if not model_context:
|
if not model_context:
|
||||||
# Manual calculation as fallback
|
# Manual calculation as fallback
|
||||||
model_name = getattr(self, "_current_model_name", None) or DEFAULT_MODEL
|
model_name = getattr(self, "_current_model_name", None) or DEFAULT_MODEL
|
||||||
try:
|
try:
|
||||||
provider = self.get_model_provider(model_name)
|
provider = self.get_model_provider(model_name)
|
||||||
capabilities = provider.get_capabilities(model_name)
|
capabilities = provider.get_capabilities(model_name)
|
||||||
|
|
||||||
# Calculate content allocation based on model capacity
|
# Calculate content allocation based on model capacity
|
||||||
if capabilities.max_tokens < 300_000:
|
if capabilities.max_tokens < 300_000:
|
||||||
# Smaller context models: 60% content, 40% response
|
# Smaller context models: 60% content, 40% response
|
||||||
model_content_tokens = int(capabilities.max_tokens * 0.6)
|
model_content_tokens = int(capabilities.max_tokens * 0.6)
|
||||||
else:
|
else:
|
||||||
# Larger context models: 80% content, 20% response
|
# Larger context models: 80% content, 20% response
|
||||||
model_content_tokens = int(capabilities.max_tokens * 0.8)
|
model_content_tokens = int(capabilities.max_tokens * 0.8)
|
||||||
|
|
||||||
effective_max_tokens = model_content_tokens - reserve_tokens
|
effective_max_tokens = model_content_tokens - reserve_tokens
|
||||||
logger.debug(f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
|
logger.debug(
|
||||||
f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total")
|
f"[FILES] {self.name}: Using model-specific limit for {model_name}: "
|
||||||
|
f"{model_content_tokens:,} content tokens from {capabilities.max_tokens:,} total"
|
||||||
|
)
|
||||||
except (ValueError, AttributeError) as e:
|
except (ValueError, AttributeError) as e:
|
||||||
# Handle specific errors: provider not found, model not supported, missing attributes
|
# Handle specific errors: provider not found, model not supported, missing attributes
|
||||||
logger.warning(f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}")
|
logger.warning(
|
||||||
|
f"[FILES] {self.name}: Could not get model capabilities for {model_name}: {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
# Fall back to conservative default for safety
|
# Fall back to conservative default for safety
|
||||||
from config import MAX_CONTENT_TOKENS
|
from config import MAX_CONTENT_TOKENS
|
||||||
|
|
||||||
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
|
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Catch any other unexpected errors
|
# Catch any other unexpected errors
|
||||||
logger.error(f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}")
|
logger.error(
|
||||||
|
f"[FILES] {self.name}: Unexpected error getting model capabilities: {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
from config import MAX_CONTENT_TOKENS
|
from config import MAX_CONTENT_TOKENS
|
||||||
|
|
||||||
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
|
effective_max_tokens = min(MAX_CONTENT_TOKENS, 100_000) - reserve_tokens
|
||||||
|
|
||||||
# Ensure we have a reasonable minimum budget
|
# Ensure we have a reasonable minimum budget
|
||||||
@@ -394,12 +404,16 @@ class BaseTool(ABC):
|
|||||||
|
|
||||||
files_to_embed = self.filter_new_files(request_files, continuation_id)
|
files_to_embed = self.filter_new_files(request_files, continuation_id)
|
||||||
logger.debug(f"[FILES] {self.name}: Will embed {len(files_to_embed)} files after filtering")
|
logger.debug(f"[FILES] {self.name}: Will embed {len(files_to_embed)} files after filtering")
|
||||||
|
|
||||||
# Log the specific files for debugging/testing
|
# Log the specific files for debugging/testing
|
||||||
if files_to_embed:
|
if files_to_embed:
|
||||||
logger.info(f"[FILE_PROCESSING] {self.name} tool will embed new files: {', '.join([os.path.basename(f) for f in files_to_embed])}")
|
logger.info(
|
||||||
|
f"[FILE_PROCESSING] {self.name} tool will embed new files: {', '.join([os.path.basename(f) for f in files_to_embed])}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"[FILE_PROCESSING] {self.name} tool: No new files to embed (all files already in conversation history)")
|
logger.info(
|
||||||
|
f"[FILE_PROCESSING] {self.name} tool: No new files to embed (all files already in conversation history)"
|
||||||
|
)
|
||||||
|
|
||||||
content_parts = []
|
content_parts = []
|
||||||
|
|
||||||
@@ -688,20 +702,20 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
|
|
||||||
# Check if we have continuation_id - if so, conversation history is already embedded
|
# Check if we have continuation_id - if so, conversation history is already embedded
|
||||||
continuation_id = getattr(request, "continuation_id", None)
|
continuation_id = getattr(request, "continuation_id", None)
|
||||||
|
|
||||||
if continuation_id:
|
if continuation_id:
|
||||||
# When continuation_id is present, server.py has already injected the
|
# When continuation_id is present, server.py has already injected the
|
||||||
# conversation history into the appropriate field. We need to check if
|
# conversation history into the appropriate field. We need to check if
|
||||||
# the prompt already contains conversation history marker.
|
# the prompt already contains conversation history marker.
|
||||||
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
|
logger.debug(f"Continuing {self.name} conversation with thread {continuation_id}")
|
||||||
|
|
||||||
# Store the original arguments to detect enhanced prompts
|
# Store the original arguments to detect enhanced prompts
|
||||||
self._has_embedded_history = False
|
self._has_embedded_history = False
|
||||||
|
|
||||||
# Check if conversation history is already embedded in the prompt field
|
# Check if conversation history is already embedded in the prompt field
|
||||||
field_value = getattr(request, "prompt", "")
|
field_value = getattr(request, "prompt", "")
|
||||||
field_name = "prompt"
|
field_name = "prompt"
|
||||||
|
|
||||||
if "=== CONVERSATION HISTORY ===" in field_value:
|
if "=== CONVERSATION HISTORY ===" in field_value:
|
||||||
# Conversation history is already embedded, use it directly
|
# Conversation history is already embedded, use it directly
|
||||||
prompt = field_value
|
prompt = field_value
|
||||||
@@ -714,9 +728,10 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
else:
|
else:
|
||||||
# New conversation, prepare prompt normally
|
# New conversation, prepare prompt normally
|
||||||
prompt = await self.prepare_prompt(request)
|
prompt = await self.prepare_prompt(request)
|
||||||
|
|
||||||
# Add follow-up instructions for new conversations
|
# Add follow-up instructions for new conversations
|
||||||
from server import get_follow_up_instructions
|
from server import get_follow_up_instructions
|
||||||
|
|
||||||
follow_up_instructions = get_follow_up_instructions(0) # New conversation, turn 0
|
follow_up_instructions = get_follow_up_instructions(0) # New conversation, turn 0
|
||||||
prompt = f"{prompt}\n\n{follow_up_instructions}"
|
prompt = f"{prompt}\n\n{follow_up_instructions}"
|
||||||
logger.debug(f"Added follow-up instructions for new {self.name} conversation")
|
logger.debug(f"Added follow-up instructions for new {self.name} conversation")
|
||||||
@@ -725,9 +740,10 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
model_name = getattr(request, "model", None)
|
model_name = getattr(request, "model", None)
|
||||||
if not model_name:
|
if not model_name:
|
||||||
model_name = DEFAULT_MODEL
|
model_name = DEFAULT_MODEL
|
||||||
|
|
||||||
# In auto mode, model parameter is required
|
# In auto mode, model parameter is required
|
||||||
from config import IS_AUTO_MODE
|
from config import IS_AUTO_MODE
|
||||||
|
|
||||||
if IS_AUTO_MODE and model_name.lower() == "auto":
|
if IS_AUTO_MODE and model_name.lower() == "auto":
|
||||||
error_output = ToolOutput(
|
error_output = ToolOutput(
|
||||||
status="error",
|
status="error",
|
||||||
@@ -735,10 +751,10 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
content_type="text",
|
content_type="text",
|
||||||
)
|
)
|
||||||
return [TextContent(type="text", text=error_output.model_dump_json())]
|
return [TextContent(type="text", text=error_output.model_dump_json())]
|
||||||
|
|
||||||
# Store model name for use by helper methods like _prepare_file_content_for_prompt
|
# Store model name for use by helper methods like _prepare_file_content_for_prompt
|
||||||
self._current_model_name = model_name
|
self._current_model_name = model_name
|
||||||
|
|
||||||
temperature = getattr(request, "temperature", None)
|
temperature = getattr(request, "temperature", None)
|
||||||
if temperature is None:
|
if temperature is None:
|
||||||
temperature = self.get_default_temperature()
|
temperature = self.get_default_temperature()
|
||||||
@@ -748,14 +764,14 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
|
|
||||||
# Get the appropriate model provider
|
# Get the appropriate model provider
|
||||||
provider = self.get_model_provider(model_name)
|
provider = self.get_model_provider(model_name)
|
||||||
|
|
||||||
# Validate and correct temperature for this model
|
# Validate and correct temperature for this model
|
||||||
temperature, temp_warnings = self._validate_and_correct_temperature(model_name, temperature)
|
temperature, temp_warnings = self._validate_and_correct_temperature(model_name, temperature)
|
||||||
|
|
||||||
# Log any temperature corrections
|
# Log any temperature corrections
|
||||||
for warning in temp_warnings:
|
for warning in temp_warnings:
|
||||||
logger.warning(warning)
|
logger.warning(warning)
|
||||||
|
|
||||||
# Get system prompt for this tool
|
# Get system prompt for this tool
|
||||||
system_prompt = self.get_system_prompt()
|
system_prompt = self.get_system_prompt()
|
||||||
|
|
||||||
@@ -763,16 +779,16 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.name}")
|
logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.name}")
|
||||||
logger.info(f"Using model: {model_name} via {provider.get_provider_type().value} provider")
|
logger.info(f"Using model: {model_name} via {provider.get_provider_type().value} provider")
|
||||||
logger.debug(f"Prompt length: {len(prompt)} characters")
|
logger.debug(f"Prompt length: {len(prompt)} characters")
|
||||||
|
|
||||||
# Generate content with provider abstraction
|
# Generate content with provider abstraction
|
||||||
model_response = provider.generate_content(
|
model_response = provider.generate_content(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
|
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Received response from {provider.get_provider_type().value} API for {self.name}")
|
logger.info(f"Received response from {provider.get_provider_type().value} API for {self.name}")
|
||||||
|
|
||||||
# Process the model's response
|
# Process the model's response
|
||||||
@@ -781,11 +797,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
|
|
||||||
# Parse response to check for clarification requests or format output
|
# Parse response to check for clarification requests or format output
|
||||||
# Pass model info for conversation tracking
|
# Pass model info for conversation tracking
|
||||||
model_info = {
|
model_info = {"provider": provider, "model_name": model_name, "model_response": model_response}
|
||||||
"provider": provider,
|
|
||||||
"model_name": model_name,
|
|
||||||
"model_response": model_response
|
|
||||||
}
|
|
||||||
tool_output = self._parse_response(raw_text, request, model_info)
|
tool_output = self._parse_response(raw_text, request, model_info)
|
||||||
logger.info(f"Successfully completed {self.name} tool execution")
|
logger.info(f"Successfully completed {self.name} tool execution")
|
||||||
|
|
||||||
@@ -819,15 +831,15 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None
|
thinking_mode=thinking_mode if provider.supports_thinking_mode(model_name) else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if retry_response.content:
|
if retry_response.content:
|
||||||
# If successful, process normally
|
# If successful, process normally
|
||||||
retry_model_info = {
|
retry_model_info = {
|
||||||
"provider": provider,
|
"provider": provider,
|
||||||
"model_name": model_name,
|
"model_name": model_name,
|
||||||
"model_response": retry_response
|
"model_response": retry_response,
|
||||||
}
|
}
|
||||||
tool_output = self._parse_response(retry_response.content, request, retry_model_info)
|
tool_output = self._parse_response(retry_response.content, request, retry_model_info)
|
||||||
return [TextContent(type="text", text=tool_output.model_dump_json())]
|
return [TextContent(type="text", text=tool_output.model_dump_json())]
|
||||||
@@ -916,7 +928,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
model_provider = None
|
model_provider = None
|
||||||
model_name = None
|
model_name = None
|
||||||
model_metadata = None
|
model_metadata = None
|
||||||
|
|
||||||
if model_info:
|
if model_info:
|
||||||
provider = model_info.get("provider")
|
provider = model_info.get("provider")
|
||||||
if provider:
|
if provider:
|
||||||
@@ -924,11 +936,8 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
model_name = model_info.get("model_name")
|
model_name = model_info.get("model_name")
|
||||||
model_response = model_info.get("model_response")
|
model_response = model_info.get("model_response")
|
||||||
if model_response:
|
if model_response:
|
||||||
model_metadata = {
|
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
|
||||||
"usage": model_response.usage,
|
|
||||||
"metadata": model_response.metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
success = add_turn(
|
success = add_turn(
|
||||||
continuation_id,
|
continuation_id,
|
||||||
"assistant",
|
"assistant",
|
||||||
@@ -986,7 +995,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _create_follow_up_response(self, content: str, follow_up_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
|
def _create_follow_up_response(
|
||||||
|
self, content: str, follow_up_data: dict, request, model_info: Optional[dict] = None
|
||||||
|
) -> ToolOutput:
|
||||||
"""
|
"""
|
||||||
Create a response with follow-up question for conversation threading.
|
Create a response with follow-up question for conversation threading.
|
||||||
|
|
||||||
@@ -1001,13 +1012,13 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
# Always create a new thread (with parent linkage if continuation)
|
# Always create a new thread (with parent linkage if continuation)
|
||||||
continuation_id = getattr(request, "continuation_id", None)
|
continuation_id = getattr(request, "continuation_id", None)
|
||||||
request_files = getattr(request, "files", []) or []
|
request_files = getattr(request, "files", []) or []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create new thread with parent linkage if continuing
|
# Create new thread with parent linkage if continuing
|
||||||
thread_id = create_thread(
|
thread_id = create_thread(
|
||||||
tool_name=self.name,
|
tool_name=self.name,
|
||||||
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
|
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
|
||||||
parent_thread_id=continuation_id # Link to parent thread if continuing
|
parent_thread_id=continuation_id, # Link to parent thread if continuing
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the assistant's response with follow-up
|
# Add the assistant's response with follow-up
|
||||||
@@ -1015,7 +1026,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
model_provider = None
|
model_provider = None
|
||||||
model_name = None
|
model_name = None
|
||||||
model_metadata = None
|
model_metadata = None
|
||||||
|
|
||||||
if model_info:
|
if model_info:
|
||||||
provider = model_info.get("provider")
|
provider = model_info.get("provider")
|
||||||
if provider:
|
if provider:
|
||||||
@@ -1023,11 +1034,8 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
model_name = model_info.get("model_name")
|
model_name = model_info.get("model_name")
|
||||||
model_response = model_info.get("model_response")
|
model_response = model_info.get("model_response")
|
||||||
if model_response:
|
if model_response:
|
||||||
model_metadata = {
|
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
|
||||||
"usage": model_response.usage,
|
|
||||||
"metadata": model_response.metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
add_turn(
|
add_turn(
|
||||||
thread_id, # Add to the new thread
|
thread_id, # Add to the new thread
|
||||||
"assistant",
|
"assistant",
|
||||||
@@ -1088,6 +1096,12 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
Returns:
|
Returns:
|
||||||
Dict with continuation data if opportunity should be offered, None otherwise
|
Dict with continuation data if opportunity should be offered, None otherwise
|
||||||
"""
|
"""
|
||||||
|
# Skip continuation offers in test mode
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.getenv("PYTEST_CURRENT_TEST"):
|
||||||
|
return None
|
||||||
|
|
||||||
continuation_id = getattr(request, "continuation_id", None)
|
continuation_id = getattr(request, "continuation_id", None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1117,7 +1131,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
# If anything fails, don't offer continuation
|
# If anything fails, don't offer continuation
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _create_continuation_offer_response(self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None) -> ToolOutput:
|
def _create_continuation_offer_response(
|
||||||
|
self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None
|
||||||
|
) -> ToolOutput:
|
||||||
"""
|
"""
|
||||||
Create a response offering Claude the opportunity to continue conversation.
|
Create a response offering Claude the opportunity to continue conversation.
|
||||||
|
|
||||||
@@ -1133,9 +1149,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
# Create new thread for potential continuation (with parent link if continuing)
|
# Create new thread for potential continuation (with parent link if continuing)
|
||||||
continuation_id = getattr(request, "continuation_id", None)
|
continuation_id = getattr(request, "continuation_id", None)
|
||||||
thread_id = create_thread(
|
thread_id = create_thread(
|
||||||
tool_name=self.name,
|
tool_name=self.name,
|
||||||
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
|
initial_request=request.model_dump() if hasattr(request, "model_dump") else {},
|
||||||
parent_thread_id=continuation_id # Link to parent if this is a continuation
|
parent_thread_id=continuation_id, # Link to parent if this is a continuation
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add this response as the first turn (assistant turn)
|
# Add this response as the first turn (assistant turn)
|
||||||
@@ -1144,7 +1160,7 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
model_provider = None
|
model_provider = None
|
||||||
model_name = None
|
model_name = None
|
||||||
model_metadata = None
|
model_metadata = None
|
||||||
|
|
||||||
if model_info:
|
if model_info:
|
||||||
provider = model_info.get("provider")
|
provider = model_info.get("provider")
|
||||||
if provider:
|
if provider:
|
||||||
@@ -1152,16 +1168,13 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
model_name = model_info.get("model_name")
|
model_name = model_info.get("model_name")
|
||||||
model_response = model_info.get("model_response")
|
model_response = model_info.get("model_response")
|
||||||
if model_response:
|
if model_response:
|
||||||
model_metadata = {
|
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
|
||||||
"usage": model_response.usage,
|
|
||||||
"metadata": model_response.metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
add_turn(
|
add_turn(
|
||||||
thread_id,
|
thread_id,
|
||||||
"assistant",
|
"assistant",
|
||||||
content,
|
content,
|
||||||
files=request_files,
|
files=request_files,
|
||||||
tool_name=self.name,
|
tool_name=self.name,
|
||||||
model_provider=model_provider,
|
model_provider=model_provider,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@@ -1260,11 +1273,11 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
def _validate_and_correct_temperature(self, model_name: str, temperature: float) -> tuple[float, list[str]]:
|
def _validate_and_correct_temperature(self, model_name: str, temperature: float) -> tuple[float, list[str]]:
|
||||||
"""
|
"""
|
||||||
Validate and correct temperature for the specified model.
|
Validate and correct temperature for the specified model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Name of the model to validate temperature for
|
model_name: Name of the model to validate temperature for
|
||||||
temperature: Temperature value to validate
|
temperature: Temperature value to validate
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (corrected_temperature, warning_messages)
|
Tuple of (corrected_temperature, warning_messages)
|
||||||
"""
|
"""
|
||||||
@@ -1272,9 +1285,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
provider = self.get_model_provider(model_name)
|
provider = self.get_model_provider(model_name)
|
||||||
capabilities = provider.get_capabilities(model_name)
|
capabilities = provider.get_capabilities(model_name)
|
||||||
constraint = capabilities.temperature_constraint
|
constraint = capabilities.temperature_constraint
|
||||||
|
|
||||||
warnings = []
|
warnings = []
|
||||||
|
|
||||||
if not constraint.validate(temperature):
|
if not constraint.validate(temperature):
|
||||||
corrected = constraint.get_corrected_value(temperature)
|
corrected = constraint.get_corrected_value(temperature)
|
||||||
warning = (
|
warning = (
|
||||||
@@ -1283,9 +1296,9 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
)
|
)
|
||||||
warnings.append(warning)
|
warnings.append(warning)
|
||||||
return corrected, warnings
|
return corrected, warnings
|
||||||
|
|
||||||
return temperature, warnings
|
return temperature, warnings
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If validation fails for any reason, use the original temperature
|
# If validation fails for any reason, use the original temperature
|
||||||
# and log a warning (but don't fail the request)
|
# and log a warning (but don't fail the request)
|
||||||
@@ -1308,26 +1321,28 @@ If any of these would strengthen your analysis, specify what Claude should searc
|
|||||||
"""
|
"""
|
||||||
# Get provider from registry
|
# Get provider from registry
|
||||||
provider = ModelProviderRegistry.get_provider_for_model(model_name)
|
provider = ModelProviderRegistry.get_provider_for_model(model_name)
|
||||||
|
|
||||||
if not provider:
|
if not provider:
|
||||||
# Try to determine provider from model name patterns
|
# Try to determine provider from model name patterns
|
||||||
if "gemini" in model_name.lower() or model_name.lower() in ["flash", "pro"]:
|
if "gemini" in model_name.lower() or model_name.lower() in ["flash", "pro"]:
|
||||||
# Register Gemini provider if not already registered
|
# Register Gemini provider if not already registered
|
||||||
from providers.gemini import GeminiModelProvider
|
|
||||||
from providers.base import ProviderType
|
from providers.base import ProviderType
|
||||||
|
from providers.gemini import GeminiModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)
|
||||||
elif "gpt" in model_name.lower() or "o3" in model_name.lower():
|
elif "gpt" in model_name.lower() or "o3" in model_name.lower():
|
||||||
# Register OpenAI provider if not already registered
|
# Register OpenAI provider if not already registered
|
||||||
from providers.openai import OpenAIModelProvider
|
|
||||||
from providers.base import ProviderType
|
from providers.base import ProviderType
|
||||||
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
|
||||||
|
|
||||||
if not provider:
|
if not provider:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No provider found for model '{model_name}'. "
|
f"No provider found for model '{model_name}'. "
|
||||||
f"Ensure the appropriate API key is set and the model name is correct."
|
f"Ensure the appropriate API key is set and the model name is correct."
|
||||||
)
|
)
|
||||||
|
|
||||||
return provider
|
return provider
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ class ChatTool(BaseTool):
|
|||||||
},
|
},
|
||||||
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
|
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
|
|||||||
@@ -44,7 +44,10 @@ class CodeReviewRequest(ToolRequest):
|
|||||||
description="User's summary of what the code does, expected behavior, constraints, and review objectives",
|
description="User's summary of what the code does, expected behavior, constraints, and review objectives",
|
||||||
)
|
)
|
||||||
review_type: str = Field("full", description="Type of review: full|security|performance|quick")
|
review_type: str = Field("full", description="Type of review: full|security|performance|quick")
|
||||||
focus_on: Optional[str] = Field(None, description="Specific aspects to focus on, or additional context that would help understand areas of concern")
|
focus_on: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Specific aspects to focus on, or additional context that would help understand areas of concern",
|
||||||
|
)
|
||||||
standards: Optional[str] = Field(None, description="Coding standards or guidelines to enforce")
|
standards: Optional[str] = Field(None, description="Coding standards or guidelines to enforce")
|
||||||
severity_filter: str = Field(
|
severity_filter: str = Field(
|
||||||
"all",
|
"all",
|
||||||
@@ -137,7 +140,7 @@ class CodeReviewTool(BaseTool):
|
|||||||
},
|
},
|
||||||
"required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
|
"required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class DebugIssueTool(BaseTool):
|
|||||||
},
|
},
|
||||||
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
|
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
@@ -201,7 +201,7 @@ Focus on finding the root cause and providing actionable solutions."""
|
|||||||
model_name = "the model"
|
model_name = "the model"
|
||||||
if model_info and model_info.get("model_response"):
|
if model_info and model_info.get("model_response"):
|
||||||
model_name = model_info["model_response"].friendly_name or "the model"
|
model_name = model_info["model_response"].friendly_name or "the model"
|
||||||
|
|
||||||
return f"""{response}
|
return f"""{response}
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class Precommit(BaseTool):
|
|||||||
# Ensure model parameter has enhanced description
|
# Ensure model parameter has enhanced description
|
||||||
if "properties" in schema and "model" in schema["properties"]:
|
if "properties" in schema and "model" in schema["properties"]:
|
||||||
schema["properties"]["model"] = self.get_model_field_schema()
|
schema["properties"]["model"] = self.get_model_field_schema()
|
||||||
|
|
||||||
# In auto mode, model is required
|
# In auto mode, model is required
|
||||||
if IS_AUTO_MODE and "required" in schema:
|
if IS_AUTO_MODE and "required" in schema:
|
||||||
if "model" not in schema["required"]:
|
if "model" not in schema["required"]:
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ class ThinkDeepTool(BaseTool):
|
|||||||
},
|
},
|
||||||
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
|
"required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
@@ -195,7 +195,7 @@ Please provide deep analysis that extends Claude's thinking with:
|
|||||||
model_name = "your fellow developer"
|
model_name = "your fellow developer"
|
||||||
if model_info and model_info.get("model_response"):
|
if model_info and model_info.get("model_response"):
|
||||||
model_name = model_info["model_response"].friendly_name or "your fellow developer"
|
model_name = model_info["model_response"].friendly_name or "your fellow developer"
|
||||||
|
|
||||||
return f"""{response}
|
return f"""{response}
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Utility functions for Gemini MCP Server
|
Utility functions for Zen MCP Server
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .file_utils import CODE_EXTENSIONS, expand_paths, read_file_content, read_files
|
from .file_utils import CODE_EXTENSIONS, expand_paths, read_file_content, read_files
|
||||||
|
|||||||
@@ -312,41 +312,41 @@ def add_turn(
|
|||||||
def get_thread_chain(thread_id: str, max_depth: int = 20) -> list[ThreadContext]:
|
def get_thread_chain(thread_id: str, max_depth: int = 20) -> list[ThreadContext]:
|
||||||
"""
|
"""
|
||||||
Traverse the parent chain to get all threads in conversation sequence.
|
Traverse the parent chain to get all threads in conversation sequence.
|
||||||
|
|
||||||
Retrieves the complete conversation chain by following parent_thread_id
|
Retrieves the complete conversation chain by following parent_thread_id
|
||||||
links. Returns threads in chronological order (oldest first).
|
links. Returns threads in chronological order (oldest first).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: Starting thread ID
|
thread_id: Starting thread ID
|
||||||
max_depth: Maximum chain depth to prevent infinite loops
|
max_depth: Maximum chain depth to prevent infinite loops
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[ThreadContext]: All threads in chain, oldest first
|
list[ThreadContext]: All threads in chain, oldest first
|
||||||
"""
|
"""
|
||||||
chain = []
|
chain = []
|
||||||
current_id = thread_id
|
current_id = thread_id
|
||||||
seen_ids = set()
|
seen_ids = set()
|
||||||
|
|
||||||
# Build chain from current to oldest
|
# Build chain from current to oldest
|
||||||
while current_id and len(chain) < max_depth:
|
while current_id and len(chain) < max_depth:
|
||||||
# Prevent circular references
|
# Prevent circular references
|
||||||
if current_id in seen_ids:
|
if current_id in seen_ids:
|
||||||
logger.warning(f"[THREAD] Circular reference detected in thread chain at {current_id}")
|
logger.warning(f"[THREAD] Circular reference detected in thread chain at {current_id}")
|
||||||
break
|
break
|
||||||
|
|
||||||
seen_ids.add(current_id)
|
seen_ids.add(current_id)
|
||||||
|
|
||||||
context = get_thread(current_id)
|
context = get_thread(current_id)
|
||||||
if not context:
|
if not context:
|
||||||
logger.debug(f"[THREAD] Thread {current_id} not found in chain traversal")
|
logger.debug(f"[THREAD] Thread {current_id} not found in chain traversal")
|
||||||
break
|
break
|
||||||
|
|
||||||
chain.append(context)
|
chain.append(context)
|
||||||
current_id = context.parent_thread_id
|
current_id = context.parent_thread_id
|
||||||
|
|
||||||
# Reverse to get chronological order (oldest first)
|
# Reverse to get chronological order (oldest first)
|
||||||
chain.reverse()
|
chain.reverse()
|
||||||
|
|
||||||
logger.debug(f"[THREAD] Retrieved chain of {len(chain)} threads for {thread_id}")
|
logger.debug(f"[THREAD] Retrieved chain of {len(chain)} threads for {thread_id}")
|
||||||
return chain
|
return chain
|
||||||
|
|
||||||
@@ -400,7 +400,7 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
|
|||||||
full file contents from all referenced files. Files are embedded only ONCE at the
|
full file contents from all referenced files. Files are embedded only ONCE at the
|
||||||
start, even if referenced in multiple turns, to prevent duplication and optimize
|
start, even if referenced in multiple turns, to prevent duplication and optimize
|
||||||
token usage.
|
token usage.
|
||||||
|
|
||||||
If the thread has a parent chain, this function traverses the entire chain to
|
If the thread has a parent chain, this function traverses the entire chain to
|
||||||
include the complete conversation history.
|
include the complete conversation history.
|
||||||
|
|
||||||
@@ -429,21 +429,21 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
|
|||||||
if context.parent_thread_id:
|
if context.parent_thread_id:
|
||||||
# This thread has a parent, get the full chain
|
# This thread has a parent, get the full chain
|
||||||
chain = get_thread_chain(context.thread_id)
|
chain = get_thread_chain(context.thread_id)
|
||||||
|
|
||||||
# Collect all turns from all threads in chain
|
# Collect all turns from all threads in chain
|
||||||
all_turns = []
|
all_turns = []
|
||||||
all_files_set = set()
|
all_files_set = set()
|
||||||
total_turns = 0
|
total_turns = 0
|
||||||
|
|
||||||
for thread in chain:
|
for thread in chain:
|
||||||
all_turns.extend(thread.turns)
|
all_turns.extend(thread.turns)
|
||||||
total_turns += len(thread.turns)
|
total_turns += len(thread.turns)
|
||||||
|
|
||||||
# Collect files from this thread
|
# Collect files from this thread
|
||||||
for turn in thread.turns:
|
for turn in thread.turns:
|
||||||
if turn.files:
|
if turn.files:
|
||||||
all_files_set.update(turn.files)
|
all_files_set.update(turn.files)
|
||||||
|
|
||||||
all_files = list(all_files_set)
|
all_files = list(all_files_set)
|
||||||
logger.debug(f"[THREAD] Built history from {len(chain)} threads with {total_turns} total turns")
|
logger.debug(f"[THREAD] Built history from {len(chain)} threads with {total_turns} total turns")
|
||||||
else:
|
else:
|
||||||
@@ -451,7 +451,7 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
|
|||||||
all_turns = context.turns
|
all_turns = context.turns
|
||||||
total_turns = len(context.turns)
|
total_turns = len(context.turns)
|
||||||
all_files = get_conversation_file_list(context)
|
all_files = get_conversation_file_list(context)
|
||||||
|
|
||||||
if not all_turns:
|
if not all_turns:
|
||||||
return "", 0
|
return "", 0
|
||||||
|
|
||||||
@@ -459,18 +459,19 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
|
|||||||
|
|
||||||
# Get model-specific token allocation early (needed for both files and turns)
|
# Get model-specific token allocation early (needed for both files and turns)
|
||||||
if model_context is None:
|
if model_context is None:
|
||||||
from utils.model_context import ModelContext
|
|
||||||
from config import DEFAULT_MODEL
|
from config import DEFAULT_MODEL
|
||||||
|
from utils.model_context import ModelContext
|
||||||
|
|
||||||
model_context = ModelContext(DEFAULT_MODEL)
|
model_context = ModelContext(DEFAULT_MODEL)
|
||||||
|
|
||||||
token_allocation = model_context.calculate_token_allocation()
|
token_allocation = model_context.calculate_token_allocation()
|
||||||
max_file_tokens = token_allocation.file_tokens
|
max_file_tokens = token_allocation.file_tokens
|
||||||
max_history_tokens = token_allocation.history_tokens
|
max_history_tokens = token_allocation.history_tokens
|
||||||
|
|
||||||
logger.debug(f"[HISTORY] Using model-specific limits for {model_context.model_name}:")
|
logger.debug(f"[HISTORY] Using model-specific limits for {model_context.model_name}:")
|
||||||
logger.debug(f"[HISTORY] Max file tokens: {max_file_tokens:,}")
|
logger.debug(f"[HISTORY] Max file tokens: {max_file_tokens:,}")
|
||||||
logger.debug(f"[HISTORY] Max history tokens: {max_history_tokens:,}")
|
logger.debug(f"[HISTORY] Max history tokens: {max_history_tokens:,}")
|
||||||
|
|
||||||
history_parts = [
|
history_parts = [
|
||||||
"=== CONVERSATION HISTORY ===",
|
"=== CONVERSATION HISTORY ===",
|
||||||
f"Thread: {context.thread_id}",
|
f"Thread: {context.thread_id}",
|
||||||
@@ -584,13 +585,13 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
|
|||||||
)
|
)
|
||||||
|
|
||||||
history_parts.append("Previous conversation turns:")
|
history_parts.append("Previous conversation turns:")
|
||||||
|
|
||||||
# Build conversation turns bottom-up (most recent first) but present chronologically
|
# Build conversation turns bottom-up (most recent first) but present chronologically
|
||||||
# This ensures we include as many recent turns as possible within the token budget
|
# This ensures we include as many recent turns as possible within the token budget
|
||||||
turn_entries = [] # Will store (index, formatted_turn_content) for chronological ordering
|
turn_entries = [] # Will store (index, formatted_turn_content) for chronological ordering
|
||||||
total_turn_tokens = 0
|
total_turn_tokens = 0
|
||||||
file_embedding_tokens = sum(model_context.estimate_tokens(part) for part in history_parts)
|
file_embedding_tokens = sum(model_context.estimate_tokens(part) for part in history_parts)
|
||||||
|
|
||||||
# Process turns in reverse order (most recent first) to prioritize recent context
|
# Process turns in reverse order (most recent first) to prioritize recent context
|
||||||
for idx in range(len(all_turns) - 1, -1, -1):
|
for idx in range(len(all_turns) - 1, -1, -1):
|
||||||
turn = all_turns[idx]
|
turn = all_turns[idx]
|
||||||
@@ -599,16 +600,16 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
|
|||||||
|
|
||||||
# Build the complete turn content
|
# Build the complete turn content
|
||||||
turn_parts = []
|
turn_parts = []
|
||||||
|
|
||||||
# Add turn header with tool attribution for cross-tool tracking
|
# Add turn header with tool attribution for cross-tool tracking
|
||||||
turn_header = f"\n--- Turn {turn_num} ({role_label}"
|
turn_header = f"\n--- Turn {turn_num} ({role_label}"
|
||||||
if turn.tool_name:
|
if turn.tool_name:
|
||||||
turn_header += f" using {turn.tool_name}"
|
turn_header += f" using {turn.tool_name}"
|
||||||
|
|
||||||
# Add model info if available
|
# Add model info if available
|
||||||
if turn.model_provider and turn.model_name:
|
if turn.model_provider and turn.model_name:
|
||||||
turn_header += f" via {turn.model_provider}/{turn.model_name}"
|
turn_header += f" via {turn.model_provider}/{turn.model_name}"
|
||||||
|
|
||||||
turn_header += ") ---"
|
turn_header += ") ---"
|
||||||
turn_parts.append(turn_header)
|
turn_parts.append(turn_header)
|
||||||
|
|
||||||
@@ -624,11 +625,11 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
|
|||||||
# Add follow-up question if present
|
# Add follow-up question if present
|
||||||
if turn.follow_up_question:
|
if turn.follow_up_question:
|
||||||
turn_parts.append(f"\n[Gemini's Follow-up: {turn.follow_up_question}]")
|
turn_parts.append(f"\n[Gemini's Follow-up: {turn.follow_up_question}]")
|
||||||
|
|
||||||
# Calculate tokens for this turn
|
# Calculate tokens for this turn
|
||||||
turn_content = "\n".join(turn_parts)
|
turn_content = "\n".join(turn_parts)
|
||||||
turn_tokens = model_context.estimate_tokens(turn_content)
|
turn_tokens = model_context.estimate_tokens(turn_content)
|
||||||
|
|
||||||
# Check if adding this turn would exceed history budget
|
# Check if adding this turn would exceed history budget
|
||||||
if file_embedding_tokens + total_turn_tokens + turn_tokens > max_history_tokens:
|
if file_embedding_tokens + total_turn_tokens + turn_tokens > max_history_tokens:
|
||||||
# Stop adding turns - we've reached the limit
|
# Stop adding turns - we've reached the limit
|
||||||
@@ -639,18 +640,18 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
|
|||||||
logger.debug(f"[HISTORY] Would total: {file_embedding_tokens + total_turn_tokens + turn_tokens:,}")
|
logger.debug(f"[HISTORY] Would total: {file_embedding_tokens + total_turn_tokens + turn_tokens:,}")
|
||||||
logger.debug(f"[HISTORY] Budget: {max_history_tokens:,}")
|
logger.debug(f"[HISTORY] Budget: {max_history_tokens:,}")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Add this turn to our list (we'll reverse it later for chronological order)
|
# Add this turn to our list (we'll reverse it later for chronological order)
|
||||||
turn_entries.append((idx, turn_content))
|
turn_entries.append((idx, turn_content))
|
||||||
total_turn_tokens += turn_tokens
|
total_turn_tokens += turn_tokens
|
||||||
|
|
||||||
# Reverse to get chronological order (oldest first)
|
# Reverse to get chronological order (oldest first)
|
||||||
turn_entries.reverse()
|
turn_entries.reverse()
|
||||||
|
|
||||||
# Add the turns in chronological order
|
# Add the turns in chronological order
|
||||||
for _, turn_content in turn_entries:
|
for _, turn_content in turn_entries:
|
||||||
history_parts.append(turn_content)
|
history_parts.append(turn_content)
|
||||||
|
|
||||||
# Log what we included
|
# Log what we included
|
||||||
included_turns = len(turn_entries)
|
included_turns = len(turn_entries)
|
||||||
total_turns = len(all_turns)
|
total_turns = len(all_turns)
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ ensuring that token limits are properly calculated based on the current model
|
|||||||
being used, not global constants.
|
being used, not global constants.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
from dataclasses import dataclass
|
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from providers import ModelProviderRegistry, ModelCapabilities
|
|
||||||
from config import DEFAULT_MODEL
|
from config import DEFAULT_MODEL
|
||||||
|
from providers import ModelCapabilities, ModelProviderRegistry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -19,12 +19,13 @@ logger = logging.getLogger(__name__)
|
|||||||
@dataclass
|
@dataclass
|
||||||
class TokenAllocation:
|
class TokenAllocation:
|
||||||
"""Token allocation strategy for a model."""
|
"""Token allocation strategy for a model."""
|
||||||
|
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
content_tokens: int
|
content_tokens: int
|
||||||
response_tokens: int
|
response_tokens: int
|
||||||
file_tokens: int
|
file_tokens: int
|
||||||
history_tokens: int
|
history_tokens: int
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def available_for_prompt(self) -> int:
|
def available_for_prompt(self) -> int:
|
||||||
"""Tokens available for the actual prompt after allocations."""
|
"""Tokens available for the actual prompt after allocations."""
|
||||||
@@ -34,17 +35,17 @@ class TokenAllocation:
|
|||||||
class ModelContext:
|
class ModelContext:
|
||||||
"""
|
"""
|
||||||
Encapsulates model-specific information and token calculations.
|
Encapsulates model-specific information and token calculations.
|
||||||
|
|
||||||
This class provides a single source of truth for all model-related
|
This class provides a single source of truth for all model-related
|
||||||
token calculations, ensuring consistency across the system.
|
token calculations, ensuring consistency across the system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self._provider = None
|
self._provider = None
|
||||||
self._capabilities = None
|
self._capabilities = None
|
||||||
self._token_allocation = None
|
self._token_allocation = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider(self):
|
def provider(self):
|
||||||
"""Get the model provider lazily."""
|
"""Get the model provider lazily."""
|
||||||
@@ -53,78 +54,78 @@ class ModelContext:
|
|||||||
if not self._provider:
|
if not self._provider:
|
||||||
raise ValueError(f"No provider found for model: {self.model_name}")
|
raise ValueError(f"No provider found for model: {self.model_name}")
|
||||||
return self._provider
|
return self._provider
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def capabilities(self) -> ModelCapabilities:
|
def capabilities(self) -> ModelCapabilities:
|
||||||
"""Get model capabilities lazily."""
|
"""Get model capabilities lazily."""
|
||||||
if self._capabilities is None:
|
if self._capabilities is None:
|
||||||
self._capabilities = self.provider.get_capabilities(self.model_name)
|
self._capabilities = self.provider.get_capabilities(self.model_name)
|
||||||
return self._capabilities
|
return self._capabilities
|
||||||
|
|
||||||
def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation:
|
def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation:
|
||||||
"""
|
"""
|
||||||
Calculate token allocation based on model capacity.
|
Calculate token allocation based on model capacity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
reserved_for_response: Override response token reservation
|
reserved_for_response: Override response token reservation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TokenAllocation with calculated budgets
|
TokenAllocation with calculated budgets
|
||||||
"""
|
"""
|
||||||
total_tokens = self.capabilities.max_tokens
|
total_tokens = self.capabilities.max_tokens
|
||||||
|
|
||||||
# Dynamic allocation based on model capacity
|
# Dynamic allocation based on model capacity
|
||||||
if total_tokens < 300_000:
|
if total_tokens < 300_000:
|
||||||
# Smaller context models (O3, GPT-4O): Conservative allocation
|
# Smaller context models (O3): Conservative allocation
|
||||||
content_ratio = 0.6 # 60% for content
|
content_ratio = 0.6 # 60% for content
|
||||||
response_ratio = 0.4 # 40% for response
|
response_ratio = 0.4 # 40% for response
|
||||||
file_ratio = 0.3 # 30% of content for files
|
file_ratio = 0.3 # 30% of content for files
|
||||||
history_ratio = 0.5 # 50% of content for history
|
history_ratio = 0.5 # 50% of content for history
|
||||||
else:
|
else:
|
||||||
# Larger context models (Gemini): More generous allocation
|
# Larger context models (Gemini): More generous allocation
|
||||||
content_ratio = 0.8 # 80% for content
|
content_ratio = 0.8 # 80% for content
|
||||||
response_ratio = 0.2 # 20% for response
|
response_ratio = 0.2 # 20% for response
|
||||||
file_ratio = 0.4 # 40% of content for files
|
file_ratio = 0.4 # 40% of content for files
|
||||||
history_ratio = 0.4 # 40% of content for history
|
history_ratio = 0.4 # 40% of content for history
|
||||||
|
|
||||||
# Calculate allocations
|
# Calculate allocations
|
||||||
content_tokens = int(total_tokens * content_ratio)
|
content_tokens = int(total_tokens * content_ratio)
|
||||||
response_tokens = reserved_for_response or int(total_tokens * response_ratio)
|
response_tokens = reserved_for_response or int(total_tokens * response_ratio)
|
||||||
|
|
||||||
# Sub-allocations within content budget
|
# Sub-allocations within content budget
|
||||||
file_tokens = int(content_tokens * file_ratio)
|
file_tokens = int(content_tokens * file_ratio)
|
||||||
history_tokens = int(content_tokens * history_ratio)
|
history_tokens = int(content_tokens * history_ratio)
|
||||||
|
|
||||||
allocation = TokenAllocation(
|
allocation = TokenAllocation(
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
content_tokens=content_tokens,
|
content_tokens=content_tokens,
|
||||||
response_tokens=response_tokens,
|
response_tokens=response_tokens,
|
||||||
file_tokens=file_tokens,
|
file_tokens=file_tokens,
|
||||||
history_tokens=history_tokens
|
history_tokens=history_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Token allocation for {self.model_name}:")
|
logger.debug(f"Token allocation for {self.model_name}:")
|
||||||
logger.debug(f" Total: {allocation.total_tokens:,}")
|
logger.debug(f" Total: {allocation.total_tokens:,}")
|
||||||
logger.debug(f" Content: {allocation.content_tokens:,} ({content_ratio:.0%})")
|
logger.debug(f" Content: {allocation.content_tokens:,} ({content_ratio:.0%})")
|
||||||
logger.debug(f" Response: {allocation.response_tokens:,} ({response_ratio:.0%})")
|
logger.debug(f" Response: {allocation.response_tokens:,} ({response_ratio:.0%})")
|
||||||
logger.debug(f" Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)")
|
logger.debug(f" Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)")
|
||||||
logger.debug(f" History: {allocation.history_tokens:,} ({history_ratio:.0%} of content)")
|
logger.debug(f" History: {allocation.history_tokens:,} ({history_ratio:.0%} of content)")
|
||||||
|
|
||||||
return allocation
|
return allocation
|
||||||
|
|
||||||
def estimate_tokens(self, text: str) -> int:
|
def estimate_tokens(self, text: str) -> int:
|
||||||
"""
|
"""
|
||||||
Estimate token count for text using model-specific tokenizer.
|
Estimate token count for text using model-specific tokenizer.
|
||||||
|
|
||||||
For now, uses simple estimation. Can be enhanced with model-specific
|
For now, uses simple estimation. Can be enhanced with model-specific
|
||||||
tokenizers (tiktoken for OpenAI, etc.) in the future.
|
tokenizers (tiktoken for OpenAI, etc.) in the future.
|
||||||
"""
|
"""
|
||||||
# TODO: Integrate model-specific tokenizers
|
# TODO: Integrate model-specific tokenizers
|
||||||
# For now, use conservative estimation
|
# For now, use conservative estimation
|
||||||
return len(text) // 3 # Conservative estimate
|
return len(text) // 3 # Conservative estimate
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_arguments(cls, arguments: Dict[str, Any]) -> "ModelContext":
|
def from_arguments(cls, arguments: dict[str, Any]) -> "ModelContext":
|
||||||
"""Create ModelContext from tool arguments."""
|
"""Create ModelContext from tool arguments."""
|
||||||
model_name = arguments.get("model") or DEFAULT_MODEL
|
model_name = arguments.get("model") or DEFAULT_MODEL
|
||||||
return cls(model_name)
|
return cls(model_name)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Gemini MCP Server - Entry point for backward compatibility
|
Zen MCP Server - Entry point for backward compatibility
|
||||||
This file exists to maintain compatibility with existing configurations.
|
This file exists to maintain compatibility with existing configurations.
|
||||||
The main implementation is now in server.py
|
The main implementation is now in server.py
|
||||||
"""
|
"""
|
||||||
Reference in New Issue
Block a user