# Adding a New Provider

This guide explains how to add support for a new AI model provider, such as Anthropic, Cohere, or any other API that provides AI model access, to the Zen MCP Server. The provider system is designed to be extensible and follows a simple pattern.

## Overview

Each provider:

- Inherits from `ModelProvider` (base class) or `OpenAICompatibleProvider` (for OpenAI-compatible APIs)
- Defines its supported models using `ModelCapabilities` objects
- Implements a few core abstract methods for model interaction
- Is registered by the server, with its API key configured via environment variables

## Choose Your Implementation Path

You have two options when implementing a new provider:

### Option A: Native Provider (Full Implementation)

Inherit from `ModelProvider` when:

- Your API has unique features or custom authentication that don't fit OpenAI's format
- You need full control over API calls and response handling
- You want to implement custom features such as extended thinking

Required methods: `generate_content()`, `count_tokens()`, `get_capabilities()`, `validate_model_name()`, `supports_thinking_mode()`, `get_provider_type()`

### Option B: OpenAI-Compatible Provider (Simplified)

Inherit from `OpenAICompatibleProvider` when:

- Your API follows OpenAI's chat completion format
- You only need to define model configurations, capabilities, and validation
- You want to inherit all of the API handling automatically

⚠️ **Important**: If your provider defines model aliases (shorthands such as `"gpt"` → `"gpt-4"`), you **must** override `generate_content()` to resolve aliases before API calls. See the implementation example below.

## Step-by-Step Guide

### 1. Add the Provider Type

Add your provider to the `ProviderType` enum in `providers/base.py`:

```python
class ProviderType(Enum):
    """Supported model provider types."""

    GOOGLE = "google"
    OPENAI = "openai"
    OPENROUTER = "openrouter"
    CUSTOM = "custom"
    EXAMPLE = "example"  # Add your provider here
```

### 2. Create the Provider Implementation

#### Option A: Full Provider (Native Implementation)

Create a new file in the `providers/` directory, e.g. `providers/example.py`:

```python
"""Example model provider implementation."""

import logging
from typing import Optional

from utils.model_restrictions import get_restriction_service

from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint

logger = logging.getLogger(__name__)


class ExampleModelProvider(ModelProvider):
    """Example model provider implementation."""

    # Define models using ModelCapabilities objects (like the Gemini provider)
    SUPPORTED_MODELS = {
        "example-large-v1": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-large-v1",
            friendly_name="Example Large",
            context_window=100_000,
            max_output_tokens=50_000,
            supports_extended_thinking=False,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            description="Large model for complex tasks",
            aliases=["large", "big"],
        ),
        "example-small-v1": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-small-v1",
            friendly_name="Example Small",
            context_window=50_000,
            max_output_tokens=16_000,
            supports_extended_thinking=False,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            description="Fast model for simple tasks",
            aliases=["small", "fast"],
        ),
    }

    def __init__(self, api_key: str, **kwargs):
        # Store the API key and any provider-specific options on the base class
        super().__init__(api_key, **kwargs)

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model_name}")

        # Apply model restrictions if configured
        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        return self.SUPPORTED_MODELS[resolved_name]

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        resolved_name = self._resolve_model_name(model_name)
        self.validate_parameters(resolved_name, temperature)

        # Call your API here
        # response = your_api_client.generate(...)

        return ModelResponse(
            content="",  # From your API response
            usage={"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
            model_name=resolved_name,
            friendly_name="Example",
            provider=ProviderType.EXAMPLE,
        )

    def count_tokens(self, text: str, model_name: str) -> int:
        # Implement your tokenizer here, or fall back to a rough estimate
        return len(text) // 4

    def get_provider_type(self) -> ProviderType:
        return ProviderType.EXAMPLE

    def validate_model_name(self, model_name: str) -> bool:
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Respect model restriction policies
        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            logger.debug(f"Example model '{model_name}' -> '{resolved_name}' blocked by restrictions")
            return False

        return True

    def supports_thinking_mode(self, model_name: str) -> bool:
        capabilities = self.get_capabilities(model_name)
        return capabilities.supports_extended_thinking

    def _resolve_model_name(self, model_name: str) -> str:
        # Resolve aliases declared in ModelCapabilities to canonical model names
        for canonical_name, capabilities in self.SUPPORTED_MODELS.items():
            if model_name == canonical_name or model_name in capabilities.aliases:
                return canonical_name
        return model_name
```
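
The `len(text) // 4` estimate above is intentionally crude. If your models use a tokenizer you can run locally, count tokens properly instead. The sketch below is an illustration only: it assumes the third-party `tiktoken` package and its `cl100k_base` encoding, which is an OpenAI encoding and may not match your provider's actual tokenizer.

```python
import tiktoken


def count_tokens(self, text: str, model_name: str) -> int:
    # Illustrative sketch: swap in your provider's real tokenizer if one is available
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))
```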

#### Option B: OpenAI-Compatible Provider (Simplified)

For providers with OpenAI-compatible APIs, the implementation is much simpler:

```python
"""Example OpenAI-compatible provider."""

import logging

from utils.model_restrictions import get_restriction_service

from .base import ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint
from .openai_compatible import OpenAICompatibleProvider

logger = logging.getLogger(__name__)


class ExampleProvider(OpenAICompatibleProvider):
    """Example provider using an OpenAI-compatible API."""

    FRIENDLY_NAME = "Example"

    # Define models using ModelCapabilities (consistent with other providers)
    SUPPORTED_MODELS = {
        "example-model-large": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-model-large",
            friendly_name="Example Large",
            context_window=128_000,
            max_output_tokens=64_000,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            aliases=["large", "big"],
        ),
        "example-model-small": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-model-small",
            friendly_name="Example Small",
            context_window=32_000,
            max_output_tokens=16_000,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            aliases=["small", "fast"],
        ),
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize the provider with its API key and base URL."""
        kwargs.setdefault("base_url", "https://api.example.com/v1")
        super().__init__(api_key, **kwargs)

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model."""
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model_name}")

        # Check restrictions
        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        return self.SUPPORTED_MODELS[resolved_name]

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.EXAMPLE

    def validate_model_name(self, model_name: str) -> bool:
        """Validate that the model name is supported."""
        resolved_name = self._resolve_model_name(model_name)
        return resolved_name in self.SUPPORTED_MODELS

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve a model alias to its full name."""
        for canonical_name, capabilities in self.SUPPORTED_MODELS.items():
            if model_name == canonical_name or model_name in capabilities.aliases:
                return canonical_name
        return model_name

    def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
        """Generate content, resolving aliases before hitting the API."""
        # IMPORTANT: Resolve aliases so "large" is sent as "example-model-large"
        resolved_model_name = self._resolve_model_name(model_name)

        # Call the parent implementation with the resolved model name
        return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)

    # Note: count_tokens() is inherited from OpenAICompatibleProvider
```
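
To see why the alias resolution matters, here is a hypothetical call against the provider defined above (the API key and prompt are placeholders); the shorthand only works because `generate_content()` resolves it before the request is sent:

```python
# Hypothetical usage of the ExampleProvider sketched above
provider = ExampleProvider(api_key="sk-example")

# "large" is resolved to "example-model-large" before the HTTP request is made
response = provider.generate_content(prompt="Summarize this repo", model_name="large")
print(response.model_name)  # the resolved name, e.g. "example-model-large"
```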

### 3. Add the API Key Mapping in the Registry

Add the environment variable mapping in `providers/registry.py`:

```python
@classmethod
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
    """Get API key for a provider from environment variables."""
    key_mapping = {
        ProviderType.GOOGLE: "GEMINI_API_KEY",
        ProviderType.OPENAI: "OPENAI_API_KEY",
        ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
        ProviderType.CUSTOM: "CUSTOM_API_KEY",
        ProviderType.EXAMPLE: "EXAMPLE_API_KEY",  # Add this line
    }
    # ... rest of the method
```
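
For orientation, the elided remainder of the method essentially looks up the mapped variable in the environment. This is a minimal sketch of that idea, not the actual code in `registry.py`:

```python
    # Sketch only: the real method in providers/registry.py may differ
    env_var = key_mapping.get(provider_type)
    return os.getenv(env_var) if env_var else None
```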

### 4. Register the Provider in `server.py`

The `configure_providers()` function in `server.py` handles provider registration.

**Note**: The provider priority is hardcoded in `registry.py`. If you are adding a new native provider (like Example), update `PROVIDER_PRIORITY_ORDER` in `get_provider_for_model()`:

```python
# In providers/registry.py
PROVIDER_PRIORITY_ORDER = [
    ProviderType.GOOGLE,      # Direct Gemini access
    ProviderType.OPENAI,      # Direct OpenAI access
    ProviderType.EXAMPLE,     # Add your native provider here
    ProviderType.CUSTOM,      # Local/self-hosted models
    ProviderType.OPENROUTER,  # Catch-all (must stay last)
]
```

Native providers should be placed BEFORE `CUSTOM` and `OPENROUTER` so they get priority for their own models.

Then, in `server.py`:

1. **Import your provider class** at the top of `server.py`:

   ```python
   from providers.example import ExampleModelProvider
   ```

2. **Add API key checking** in the `configure_providers()` function:

   ```python
   def configure_providers():
       """Configure and validate AI providers based on available API keys."""
       # ... existing code ...

       # Check for Example API key
       example_key = os.getenv("EXAMPLE_API_KEY")
       if example_key and example_key != "your_example_api_key_here":
           valid_providers.append("Example")
           has_native_apis = True
           logger.info("Example API key found - Example models available")
   ```

3. **Register the provider** in the appropriate section:

   ```python
   # Register providers in priority order:
   # 1. Native APIs first (most direct and efficient)
   if has_native_apis:
       if gemini_key and gemini_key != "your_gemini_api_key_here":
           ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
       if openai_key and openai_key != "your_openai_api_key_here":
           ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
       if example_key and example_key != "your_example_api_key_here":
           ModelProviderRegistry.register_provider(ProviderType.EXAMPLE, ExampleModelProvider)
   ```

4. **Update the error message** to include your provider:

   ```python
   if not valid_providers:
       raise ValueError(
           "At least one API configuration is required. Please set either:\n"
           "- GEMINI_API_KEY for Gemini models\n"
           "- OPENAI_API_KEY for OpenAI o3 model\n"
           "- EXAMPLE_API_KEY for Example models\n"  # Add this
           "- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
           "- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)"
       )
   ```

### 5. Add Model Descriptions for Auto Mode

Give each entry in `SUPPORTED_MODELS` a `description` so Claude can choose the best model for each task in auto mode:

```python
# In your provider implementation
SUPPORTED_MODELS = {
    "example-large-v1": ModelCapabilities(
        # ... other fields as shown in step 2 ...
        description="Example Large (100K context) - High capacity model for complex tasks",
        aliases=["large", "big"],
    ),
    "example-small-v1": ModelCapabilities(
        # ... other fields as shown in step 2 ...
        description="Example Small (50K context) - Fast model for simple tasks",
        aliases=["small", "fast"],
    ),
}
```

The descriptions should be detailed enough to help Claude understand when to use each model. Include context about performance, capabilities, cost, and ideal use cases.

### 6. Update Documentation

#### 6.1. Update README.md

Add your provider to the quickstart section:

```markdown
### 1. Get API Keys (at least one required)

**Option B: Native APIs**
- **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey)
- **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys)
- **Example**: Visit [Example API Console](https://example.com/api-keys)
```

Also update the `.env` file example:

```markdown
# Edit .env to add your API keys
# GEMINI_API_KEY=your-gemini-api-key-here
# OPENAI_API_KEY=your-openai-api-key-here
# EXAMPLE_API_KEY=your-example-api-key-here
```

### 7. Configure Your Environment

Finally, add the real key to your own `.env` file (`EXAMPLE_API_KEY=your_api_key_here`). You can also disable specific tools there if you want a leaner setup, for example `DISABLED_TOOLS=debug,tracer`.
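
If you also want to restrict which Example models may be used, the restriction service in `utils/model_restrictions.py` reads per-provider environment variables. The variable name below is an assumption for illustration only; check `model_restrictions.py` for the exact name your provider type maps to:

```bash
# Hypothetical restriction variable - confirm the exact name in utils/model_restrictions.py
EXAMPLE_ALLOWED_MODELS=example-large-v1,example-small-v1
```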

### 8. Write Tests

#### 8.1. Unit Tests

Create `tests/test_example_provider.py`:

```python
"""Tests for Example provider implementation."""

import os
from unittest.mock import patch

import pytest

from providers.base import ProviderType
from providers.example import ExampleModelProvider


class TestExampleProvider:
    """Test Example provider functionality."""

    @patch.dict(os.environ, {"EXAMPLE_API_KEY": "test-key"})
    def test_initialization(self):
        """Test provider initialization."""
        provider = ExampleModelProvider("test-key")
        assert provider.api_key == "test-key"
        assert provider.get_provider_type() == ProviderType.EXAMPLE

    def test_model_validation(self):
        """Test model name validation."""
        provider = ExampleModelProvider("test-key")

        # Valid models
        assert provider.validate_model_name("large") is True
        assert provider.validate_model_name("example-large-v1") is True

        # Invalid model
        assert provider.validate_model_name("invalid-model") is False

    def test_resolve_model_name(self):
        """Test model name resolution."""
        provider = ExampleModelProvider("test-key")

        # Alias resolution
        assert provider._resolve_model_name("large") == "example-large-v1"
        assert provider._resolve_model_name("small") == "example-small-v1"

        # Full name passthrough
        assert provider._resolve_model_name("example-large-v1") == "example-large-v1"

    def test_get_capabilities(self):
        """Test getting model capabilities."""
        provider = ExampleModelProvider("test-key")

        capabilities = provider.get_capabilities("large")
        assert capabilities.model_name == "example-large-v1"
        assert capabilities.friendly_name == "Example Large"
        assert capabilities.context_window == 100_000
        assert capabilities.provider == ProviderType.EXAMPLE

        # Temperature range
        assert capabilities.temperature_constraint.min_temp == 0.0
        assert capabilities.temperature_constraint.max_temp == 2.0
```
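
Run the unit tests before opening a PR; the standard pytest invocation works:

```bash
python -m pytest tests/test_example_provider.py -v
```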

#### 8.2. Simulator Tests (Real-World Validation)

Create a simulator test to validate that your provider works correctly in real-world scenarios. Create `simulator_tests/test_example_models.py`:

```python
"""
Example Provider Model Tests

Tests that verify Example provider functionality including:
- Model alias resolution
- API integration
- Conversation continuity
- Error handling
"""

from .base_test import BaseSimulatorTest


class TestExampleModels(BaseSimulatorTest):
    """Test Example provider functionality"""

    @property
    def test_name(self) -> str:
        return "example_models"

    @property
    def test_description(self) -> str:
        return "Example provider model functionality and integration"

    def run_test(self) -> bool:
        """Test Example provider models"""
        try:
            self.logger.info("Test: Example provider functionality")

            # Check if the Example API key is configured
            check_result = self.check_env_var("EXAMPLE_API_KEY")
            if not check_result:
                self.logger.info("  ⚠️ Example API key not configured - skipping test")
                return True  # Skip, not fail

            # Test 1: Shorthand alias mapping
            self.logger.info("  1: Testing 'large' alias mapping")

            response1, continuation_id = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Example Large model!' and nothing else.",
                    "model": "large",  # Should map to example-large-v1
                    "temperature": 0.1,
                },
            )

            if not response1:
                self.logger.error("  ❌ Large alias test failed")
                return False

            self.logger.info("  ✅ Large alias call completed")

            # Test 2: Direct model name
            self.logger.info("  2: Testing direct model name (example-small-v1)")

            response2, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Example Small model!' and nothing else.",
                    "model": "example-small-v1",
                    "temperature": 0.1,
                },
            )

            if not response2:
                self.logger.error("  ❌ Direct model name test failed")
                return False

            self.logger.info("  ✅ Direct model name call completed")

            # Test 3: Conversation continuity
            self.logger.info("  3: Testing conversation continuity")

            response3, new_continuation_id = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Remember this number: 99. What number did I just tell you?",
                    "model": "large",
                    "temperature": 0.1,
                },
            )

            if not response3 or not new_continuation_id:
                self.logger.error("  ❌ Failed to start conversation")
                return False

            # Continue the conversation
            response4, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "What was the number I told you earlier?",
                    "model": "large",
                    "continuation_id": new_continuation_id,
                    "temperature": 0.1,
                },
            )

            if not response4:
                self.logger.error("  ❌ Failed to continue conversation")
                return False

            if "99" in response4:
                self.logger.info("  ✅ Conversation continuity working")
            else:
                self.logger.warning("  ⚠️ Model may not have remembered the number")

            # Test 4: Check logs for proper provider usage
            self.logger.info("  4: Validating Example provider usage in logs")
            logs = self.get_recent_server_logs()

            # Look for evidence of Example provider usage
            example_logs = [line for line in logs.split("\n") if "example" in line.lower()]
            model_resolution_logs = [
                line
                for line in logs.split("\n")
                if "Resolved model" in line and "example" in line.lower()
            ]

            self.logger.info(f"   Example-related logs: {len(example_logs)}")
            self.logger.info(f"   Model resolution logs: {len(model_resolution_logs)}")

            # Success criteria
            api_used = len(example_logs) > 0
            models_resolved = len(model_resolution_logs) > 0

            if api_used and models_resolved:
                self.logger.info("  ✅ Example provider tests passed")
                return True
            else:
                self.logger.error("  ❌ Example provider tests failed")
                return False

        except Exception as e:
            self.logger.error(f"Example provider test failed: {e}")
            return False


def main():
    """Run the Example provider tests"""
    import sys

    verbose = "--verbose" in sys.argv or "-v" in sys.argv
    test = TestExampleModels(verbose=verbose)

    success = test.run_test()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
```

The simulator test is crucial because it:

- Validates that your provider works in the actual server environment
- Tests real API integration, not just mocked behavior
- Verifies that model name resolution works correctly
- Checks conversation continuity across requests
- Examines server logs to ensure the right provider is selected

See `simulator_tests/test_openrouter_models.py` for a complete real-world example.

## Model Name Mapping and Provider Priority

### Provider Priority

When a user requests a model, providers are checked in priority order:

1. **Native providers** (Gemini, OpenAI, Example) handle their own specific models
2. **Custom provider** handles local/self-hosted models
3. **OpenRouter** is the catch-all for everything else

### How Model Name Resolution Works

When a user requests a model (e.g., "pro", "o3", "example-large-v1"), the system:

1. **Checks providers in priority order** (defined in `registry.py`):

   ```python
   PROVIDER_PRIORITY_ORDER = [
       ProviderType.GOOGLE,      # Native Gemini API
       ProviderType.OPENAI,      # Native OpenAI API
       ProviderType.CUSTOM,      # Local/self-hosted
       ProviderType.OPENROUTER,  # Catch-all for everything else
   ]
   ```

2. **Calls `validate_model_name()` on each provider**:
   - Native providers (Gemini, OpenAI) return `True` only for their specific models
   - OpenRouter returns `True` for ANY model (it is the catch-all)
   - The first provider that validates the model handles the request

### Example: Model "gemini-2.5-pro"

1. **Gemini provider** checks: YES, it's in my `SUPPORTED_MODELS` → Gemini handles it
2. OpenAI is skipped (Gemini already handled it)
3. OpenRouter never sees it

### Example: Model "claude-4-opus"

1. **Gemini provider** checks: NO, not my model → skip
2. **OpenAI provider** checks: NO, not my model → skip
3. **Custom provider** checks: NO, not configured → skip
4. **OpenRouter provider** checks: YES, I accept all models → OpenRouter handles it

A simplified sketch of this lookup follows.
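
The sketch below shows the idea only; the real `get_provider_for_model()` in `providers/registry.py` has more bookkeeping, so treat the function body as an assumption rather than the actual implementation:

```python
# Simplified sketch of priority-based provider selection, not the actual registry code
def get_provider_for_model(model_name: str):
    for provider_type in PROVIDER_PRIORITY_ORDER:
        provider = ModelProviderRegistry.get_provider(provider_type)  # None if not registered
        if provider and provider.validate_model_name(model_name):
            return provider  # first match wins
    return None
```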

### Implementing Model Name Validation

Your provider's `validate_model_name()` should return `True` **only** for models you explicitly support:

```python
def validate_model_name(self, model_name: str) -> bool:
    resolved_name = self._resolve_model_name(model_name)

    # Only accept models you explicitly support
    if resolved_name not in self.SUPPORTED_MODELS:
        return False

    # Check restrictions
    restriction_service = get_restriction_service()
    if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
        logger.debug(f"Example model '{model_name}' -> '{resolved_name}' blocked by restrictions")
        return False

    return True
```

**Important**: Native providers should ONLY return `True` for models they explicitly support. This ensures they take priority over proxy providers like OpenRouter without claiming other providers' models.

### Model Aliases and Shorthands

Aliases such as `"large"` or `"fast"` are declared on each model through the `aliases` field of its `ModelCapabilities` entry, and `_resolve_model_name()` maps them back to the canonical model name, as shown in the provider implementations above. Always resolve the alias first, then look up capabilities and send the resolved name to the API. If you inherit from `OpenAICompatibleProvider`, see the critical requirement below.

## Critical Implementation Requirements

### Alias Resolution for OpenAI-Compatible Providers

If you inherit from `OpenAICompatibleProvider` and define model aliases, you **MUST** override `generate_content()` to resolve aliases before API calls. This is because:

1. The base `OpenAICompatibleProvider.generate_content()` sends the original model name directly to the API
2. Your API expects the full model name, not the alias
3. Without resolution, requests like `model="large"` fail with 404/400 errors because the API doesn't recognize the alias

**Examples of providers that need this:**

- XAI provider: `"grok"` → `"grok-3"`
- OpenAI provider: `"mini"` → `"o4-mini"`
- Custom provider: `"fast"` → `"llama-3.1-8b-instruct"`

**Example implementation pattern:**

```python
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
    # CRITICAL: Resolve the alias before the API call
    resolved_model_name = self._resolve_model_name(model_name)

    # Pass the resolved name to the parent implementation
    return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
```

**Providers that DON'T need this:**

- The Gemini provider (has its own `generate_content()` implementation)
- The OpenRouter provider (already implements this pattern)
- Providers without aliases

## Best Practices

1. **Validate model names** against both supported models and restriction policies
2. **Be specific in validation** - only accept models you actually support
3. **Use `ModelCapabilities` objects** consistently (like the Gemini provider)
4. **Handle API errors gracefully** with clear error messages
5. **Include retry logic** for transient errors (see `gemini.py` for an example)
6. **Log important events** for debugging (initialization, model resolution, errors)
7. **Support model aliases** for a better user experience
8. **Document supported models** clearly in your provider class
9. **Test thoroughly**, including error cases, edge conditions, and real API calls
10. **Follow the existing patterns** in `providers/gemini.py` and `providers/custom.py`

## Checklist

Before submitting your PR:

- [ ] Provider type added to `ProviderType` enum in `providers/base.py`
- [ ] Provider implementation complete with all required methods
- [ ] API key mapping added to `_get_api_key_for_provider()` in `providers/registry.py`
- [ ] Provider added to `PROVIDER_PRIORITY_ORDER` in `registry.py` (if native provider)
- [ ] Environment variables added to your `.env` file (API key and any restrictions)
- [ ] Provider imported and registered in `server.py`'s `configure_providers()`
- [ ] API key checking added to `configure_providers()`
- [ ] Error message updated to include the new provider
- [ ] Model descriptions added for auto mode
- [ ] Documentation updated (README.md)
- [ ] Unit tests written and passing (`tests/test_<provider>.py`)
- [ ] Simulator tests written and passing (`simulator_tests/test_<provider>_models.py`)
- [ ] Integration tested with actual API calls
- [ ] Code follows project style (run linting)
- [ ] PR follows the template requirements

## Need Help?

See the existing implementations for reference:

- **Full provider**: `providers/gemini.py`
- **OpenAI-compatible**: `providers/custom.py`
- **Base classes**: `providers/base.py`

Check the base classes for method signatures and requirements, run the tests frequently during development, and ask questions in GitHub issues if you get stuck. The modern approach uses `ModelCapabilities` objects directly in `SUPPORTED_MODELS`, which keeps implementations clean and consistent.

# Adding Tools to Zen MCP Server

This guide explains how to add new tools to the Zen MCP Server. Tools are specialized interfaces that let Claude interact with AI models for tasks such as code review, debugging, consensus gathering, and collaborative thinking.

## Tool Types

Zen supports two tool architectures:

### Simple Tools

- **Pattern**: single request → AI response → formatted output
- **Use cases**: chat, quick analysis, straightforward tasks
- **Benefits**: clean, lightweight, easy to implement
- **Base class**: `SimpleTool` (`tools/simple/base.py`)

### Multi-step Workflow Tools

- **Pattern**: step-by-step investigation, with Claude pausing between steps to investigate
- **Use cases**: complex analysis, debugging, code review, security audits
- **Benefits**: systematic investigation, expert analysis integration, better results for complex tasks
- **Base class**: `WorkflowTool` (`tools/workflow/base.py`)

**Recommendation**: Use workflow tools for most complex analysis tasks; forcing a systematic investigation produces significantly better results.

Whatever the architecture, every tool must:

- Inherit from a tool base class and implement its abstract methods
- Define a Pydantic request model for validation
- Have a system prompt in `systemprompts/`
- Be registered in `server.py`
- Handle file/image inputs and conversation threading

**Key features** provided by the framework: automatic conversation threading, file deduplication, token management, model-specific capabilities, web search integration, and comprehensive error handling.

## Core Architecture

### Components

1. **BaseTool** (`tools/base.py`): abstract base with conversation memory, file handling, and model management
2. **Request models**: Pydantic validation with common fields (model, temperature, thinking_mode, continuation_id, images, use_websearch)
3. **System prompts**: AI behavior configuration with placeholders for dynamic content
4. **Model context**: automatic provider resolution and token allocation

### Execution Flow

1. **MCP boundary**: parameter validation, file security checks, image validation
2. **Model resolution**: automatic provider selection and capability checking
3. **Conversation context**: history reconstruction and file deduplication
4. **Prompt preparation**: system prompt + user content + file content + conversation history
5. **AI generation**: provider-agnostic model calls with retry logic
6. **Response processing**: format output, offer continuation, store in conversation memory

## Step-by-Step Implementation Guide

The walkthrough below builds a code-analysis tool directly on `BaseTool`, which exposes every customization point. The same steps apply to `SimpleTool` and `WorkflowTool` tools; they simply implement fewer methods (see the examples at the end of step 3).

### 1. Create the Tool File

Create `tools/example.py` with the imports and structure below:

```python
"""
Example tool - Intelligent code analysis and recommendations

This tool provides comprehensive code analysis including style, performance,
and maintainability recommendations for development teams.
"""

from typing import TYPE_CHECKING, Any, Optional

from pydantic import Field

if TYPE_CHECKING:
    from tools.models import ToolModelCategory

from config import TEMPERATURE_BALANCED
from systemprompts import EXAMPLE_PROMPT  # You'll create this in step 4

from .base import BaseTool, ToolRequest

# No need to import ToolOutput or logging - handled by the base class
```

**Key points:**

- Use a `TYPE_CHECKING` import for `ToolModelCategory` to avoid circular imports
- Import temperature constants from `config.py`
- The system prompt is imported from `systemprompts/`
- The base class handles all common functionality

### 2. Define the Request Model

Create a Pydantic model inheriting from `ToolRequest`:

```python
class ExampleRequest(ToolRequest):
    """Request model for the example tool."""

    # Required field - main user input
    prompt: str = Field(
        ...,
        description=(
            "Detailed description of the code analysis needed. Include specific areas "
            "of concern, goals, and any constraints. The more context provided, "
            "the more targeted and valuable the analysis will be."
        ),
    )

    # Optional file input with a proper default
    files: Optional[list[str]] = Field(
        default_factory=list,  # Use a factory for mutable defaults
        description="Code files to analyze (must be absolute paths)",
    )

    # Tool-specific parameters
    analysis_depth: Optional[str] = Field(
        default="standard",
        description="Analysis depth: 'quick', 'standard', or 'comprehensive'",
    )

    focus_areas: Optional[list[str]] = Field(
        default_factory=list,
        description="Specific areas to focus on (e.g., 'performance', 'security', 'maintainability')",
    )

    # The images, use_websearch, and continuation_id fields are inherited from ToolRequest
```

**Key points:**

- Use `default_factory=list` for mutable defaults (not `default=None`)
- Common fields (images, use_websearch, continuation_id, model, temperature) are inherited
- Detailed descriptions help Claude understand when and how to use each parameter
- Define only tool-specific parameters here (a quick usage sketch follows below)
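
As a quick sanity check, the request model can be exercised on its own before wiring up the tool. This is a minimal sketch using only the fields defined above; the print output comments reflect the defaults declared there:

```python
# Minimal sketch: exercising the ExampleRequest model defined above
request = ExampleRequest(prompt="Review this module for performance issues")
print(request.analysis_depth)  # "standard" (default)
print(request.focus_areas)     # [] (default_factory)

# Omitting the required prompt raises a pydantic validation error
try:
    ExampleRequest()
except Exception as exc:
    print(f"Validation failed as expected: {exc}")
```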

### 3. Implement the Tool Class

Implement the 6 required abstract methods:

```python
class ExampleTool(BaseTool):
    """Intelligent code analysis and recommendations tool."""

    def get_name(self) -> str:
        """Return unique tool identifier (used by MCP clients)."""
        return "example"

    def get_description(self) -> str:
        """Return detailed description to help Claude understand when to use this tool."""
        return (
            "CODE ANALYSIS & RECOMMENDATIONS - Provides comprehensive code analysis including "
            "style improvements, performance optimizations, and maintainability suggestions. "
            "Perfect for: code reviews, refactoring planning, performance analysis, best practices "
            "validation. Supports multi-file analysis with focus areas. Use 'comprehensive' analysis "
            "for complex codebases, 'standard' for regular reviews, 'quick' for simple checks."
        )

    def get_input_schema(self) -> dict[str, Any]:
        """Generate JSON schema - inherit common fields from the base class."""
        schema = {
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": (
                        "Detailed description of the code analysis needed. Include specific areas "
                        "of concern, goals, and any constraints."
                    ),
                },
                "files": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Code files to analyze (must be absolute paths)",
                },
                "analysis_depth": {
                    "type": "string",
                    "enum": ["quick", "standard", "comprehensive"],
                    "description": "Analysis depth level",
                    "default": "standard",
                },
                "focus_areas": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Specific areas to focus on (e.g., 'performance', 'security')",
                },
                # Common fields added automatically by the base class
                "model": self.get_model_field_schema(),
                "temperature": {
                    "type": "number",
                    "description": "Response creativity (0-1, default varies by tool)",
                    "minimum": 0,
                    "maximum": 1,
                },
                "thinking_mode": {
                    "type": "string",
                    "enum": ["minimal", "low", "medium", "high", "max"],
                    "description": "Thinking depth: minimal (0.5% of model max), low (8%), medium (33%), high (67%), max (100%)",
                },
                "use_websearch": {
                    "type": "boolean",
                    "description": "Enable web search for current best practices and documentation",
                    "default": True,
                },
                "images": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional screenshots or diagrams for visual context",
                },
                "continuation_id": {
                    "type": "string",
                    "description": "Thread continuation ID for multi-turn conversations",
                },
            },
            "required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []),
        }
        return schema

    def get_system_prompt(self) -> str:
        """Return the system prompt that configures AI behavior."""
        return EXAMPLE_PROMPT

    def get_request_model(self):
        """Return the Pydantic request model class used for validation."""
        return ExampleRequest

    async def prepare_prompt(self, request: ExampleRequest) -> str:
        """Prepare the complete prompt: user request + file content + context."""
        # Handle large prompts via the prompt.txt file mechanism
        prompt_content, updated_files = self.handle_prompt_file(request.files)
        user_content = prompt_content if prompt_content else request.prompt

        # Check MCP transport size limits on user input
        size_check = self.check_prompt_size(user_content)
        if size_check:
            from tools.models import ToolOutput

            raise ValueError(f"MCP_SIZE_CHECK:{ToolOutput(**size_check).model_dump_json()}")

        # Update the files list if prompt.txt was found
        if updated_files is not None:
            request.files = updated_files

        # Add focus areas to the user content
        if request.focus_areas:
            focus_text = "\n\nFocus areas: " + ", ".join(request.focus_areas)
            user_content += focus_text

        # Add file content using the centralized handler (handles deduplication & token limits)
        if request.files:
            file_content, processed_files = self._prepare_file_content_for_prompt(
                request.files, request.continuation_id, "Code files"
            )
            self._actually_processed_files = processed_files  # For conversation memory
            if file_content:
                user_content = f"{user_content}\n\n=== CODE FILES ===\n{file_content}\n=== END FILES ==="

        # Validate that the final prompt doesn't exceed the model context window
        self._validate_token_limit(user_content, "Prompt content")

        # Add the web search instruction if enabled
        websearch_instruction = self.get_websearch_instruction(
            request.use_websearch,
            """Consider searching for:
- Current best practices for the technologies used
- Recent security advisories or performance improvements
- Community solutions to similar code patterns""",
        )

        return f"""{self.get_system_prompt()}{websearch_instruction}

=== ANALYSIS REQUEST ===
Analysis Depth: {request.analysis_depth}

{user_content}
=== END REQUEST ===

Provide comprehensive code analysis with specific, actionable recommendations:"""

    # Optional: Override these methods for customization
    def get_default_temperature(self) -> float:
        return TEMPERATURE_BALANCED  # 0.5 - good for analytical tasks

    def get_model_category(self) -> "ToolModelCategory":
        from tools.models import ToolModelCategory

        return ToolModelCategory.BALANCED  # Standard analysis capabilities

    def wants_line_numbers_by_default(self) -> bool:
        return True  # Essential for precise code feedback

    def format_response(self, response: str, request: ExampleRequest, model_info: Optional[dict] = None) -> str:
        """Add custom formatting - the base class handles continuation offers automatically."""
        return f"{response}\n\n---\n\n**Next Steps:** Review recommendations and prioritize implementation based on impact."
```

**Key implementation notes:**

- **Schema inheritance**: common fields are handled by the base class automatically
- **MCP size checking**: required for large prompt handling
- **File processing**: use `_prepare_file_content_for_prompt()` for conversation-aware deduplication
- **Error handling**: `check_prompt_size()` and `_validate_token_limit()` prevent crashes
- **Web search**: use `get_websearch_instruction()` for a consistent implementation

Most new tools don't need to implement the full `BaseTool` surface shown above. If your tool fits the simple or workflow pattern, inherit from the lighter-weight base classes instead.

#### Simple Tool Example (`SimpleTool`)

```python
from tools.simple.base import SimpleTool


class ChatTool(SimpleTool):
    def get_name(self) -> str:
        return "chat"

    def get_description(self) -> str:
        return "GENERAL CHAT & COLLABORATIVE THINKING..."

    def get_tool_fields(self) -> dict:
        return {
            "prompt": {
                "type": "string",
                "description": "Your question or idea...",
            },
            "files": SimpleTool.FILES_FIELD,  # Reuse common field
        }

    def get_required_fields(self) -> list[str]:
        return ["prompt"]

    async def prepare_prompt(self, request) -> str:
        return self.prepare_chat_style_prompt(request)
```

#### Workflow Tool Example (`WorkflowTool`)

```python
from tools.workflow.base import WorkflowTool


class DebugTool(WorkflowTool):
    def get_name(self) -> str:
        return "debug"

    def get_description(self) -> str:
        return "DEBUG & ROOT CAUSE ANALYSIS - Step-by-step investigation..."

    def get_required_actions(self, step_number, confidence, findings, total_steps):
        if step_number == 1:
            return ["Search for code related to issue", "Examine relevant files"]
        return ["Trace execution flow", "Verify hypothesis with code evidence"]

    def should_call_expert_analysis(self, consolidated_findings):
        return len(consolidated_findings.relevant_files) > 0

    def prepare_expert_analysis_context(self, consolidated_findings):
        return f"Investigation findings: {consolidated_findings.findings}"
```

#### What Each Base Class Expects

**Simple tools**

- Inherit from `SimpleTool`
- Implement: `get_name()`, `get_description()`, `get_tool_fields()`, `prepare_prompt()`
- Override: `get_required_fields()`, `format_response()` (optional)

**Workflow tools**

- Inherit from `WorkflowTool`
- Implement: `get_name()`, `get_description()`, `get_required_actions()`, `should_call_expert_analysis()`, `prepare_expert_analysis_context()`
- Override: `get_tool_fields()` (optional)

Whichever base class you use, you still need to create a system prompt, register the tool in `server.py`, and add it to the `TOOLS` dictionary, as described in the next steps.

### 4. Create the System Prompt

Create `systemprompts/example_prompt.py`:

```python
"""System prompt for the example code analysis tool."""

EXAMPLE_PROMPT = """You are an expert code analyst and software engineering consultant specializing in comprehensive code review and optimization recommendations.

Your analysis should cover:

TECHNICAL ANALYSIS:
- Code structure, organization, and architectural patterns
- Performance implications and optimization opportunities
- Security vulnerabilities and defensive programming practices
- Maintainability factors and technical debt assessment
- Best practices adherence and industry standards compliance

RECOMMENDATIONS FORMAT:
1. **Critical Issues** - Security, bugs, or breaking problems (fix immediately)
2. **Performance Optimizations** - Specific improvements with expected impact
3. **Code Quality Improvements** - Maintainability, readability, and structure
4. **Best Practices** - Industry standards and modern patterns
5. **Future Considerations** - Scalability and extensibility suggestions

ANALYSIS GUIDELINES:
- Reference specific line numbers when discussing code (file:line format)
- Provide concrete, actionable recommendations with examples
- Explain the "why" behind each suggestion
- Consider the broader system context and trade-offs
- Prioritize suggestions by impact and implementation difficulty

Be precise, practical, and constructive in your analysis. Focus on improvements that provide tangible value to the development team."""
```

**Add to `systemprompts/__init__.py`:**

```python
from .example_prompt import EXAMPLE_PROMPT
```

**Key elements of a good system prompt:**

- Clear role definition and expertise area
- Structured output format that's useful for developers
- Specific guidelines for code references and explanations
- Focus on actionable, prioritized recommendations
### 5. Register the Tool

**Step 5.1: Import in `server.py`**

```python
from tools.example import ExampleTool
```

**Step 5.2: Add to the `TOOLS` dictionary in `server.py`**

```python
TOOLS = {
    "thinkdeep": ThinkDeepTool(),
    "codereview": CodeReviewTool(),
    "debug": DebugIssueTool(),
    "analyze": AnalyzeTool(),
    "chat": ChatTool(),
    "example": ExampleTool(),  # Add your tool here
    # ... other tools
}
```

**That's it!** The server automatically:
- Exposes the tool via the MCP protocol
- Handles request validation and routing
- Manages model resolution and provider selection
- Implements conversation threading and file deduplication

### 6. Write Tests

Create `tests/test_example.py`:

```python
"""Tests for the example tool."""

import pytest
from unittest.mock import Mock, patch

from tools.example import ExampleTool, ExampleRequest
from tools.models import ToolModelCategory


class TestExampleTool:
    """Test suite for ExampleTool."""

    def test_tool_metadata(self):
        """Test basic tool metadata and configuration."""
        tool = ExampleTool()

        assert tool.get_name() == "example"
        assert "CODE ANALYSIS" in tool.get_description()
        assert tool.get_default_temperature() == 0.5  # TEMPERATURE_BALANCED
        assert tool.get_model_category() == ToolModelCategory.BALANCED
        assert tool.wants_line_numbers_by_default() is True

    def test_request_validation(self):
        """Test Pydantic request model validation."""
        # Valid request
        request = ExampleRequest(prompt="Analyze this code for performance issues")
        assert request.prompt == "Analyze this code for performance issues"
        assert request.analysis_depth == "standard"  # default
        assert request.focus_areas == []  # default_factory

        # Invalid request (missing required field)
        with pytest.raises(ValueError):
            ExampleRequest()  # Missing prompt

    def test_input_schema_generation(self):
        """Test JSON schema generation for MCP client."""
        tool = ExampleTool()
        schema = tool.get_input_schema()

        assert schema["type"] == "object"
        assert "prompt" in schema["properties"]
        assert "prompt" in schema["required"]
        assert "analysis_depth" in schema["properties"]

        # Common fields should be present
        assert "model" in schema["properties"]
        assert "continuation_id" in schema["properties"]
        assert "images" in schema["properties"]

    def test_model_category_for_auto_mode(self):
        """Test model category affects auto mode selection."""
        tool = ExampleTool()
        category = tool.get_model_category()

        # Should match expected category for provider selection
        assert category == ToolModelCategory.BALANCED

    @pytest.mark.asyncio
    async def test_prepare_prompt_basic(self):
        """Test prompt preparation with basic input."""
        tool = ExampleTool()
        request = ExampleRequest(
            prompt="Review this code",
            analysis_depth="comprehensive",
            focus_areas=["performance", "security"]
        )

        # Mock validation methods
        with patch.object(tool, 'check_prompt_size', return_value=None):
            with patch.object(tool, '_validate_token_limit'):
                with patch.object(tool, 'get_websearch_instruction', return_value=""):
                    prompt = await tool.prepare_prompt(request)

        assert "Review this code" in prompt
        assert "performance, security" in prompt
        assert "comprehensive" in prompt
        assert "ANALYSIS REQUEST" in prompt

    @pytest.mark.asyncio
    async def test_file_handling_with_deduplication(self):
        """Test file processing with conversation-aware deduplication."""
        tool = ExampleTool()
        request = ExampleRequest(
            prompt="Analyze these files",
            files=["/path/to/file1.py", "/path/to/file2.py"],
            continuation_id="test-thread-123"
        )

        # Mock file processing
        with patch.object(tool, 'check_prompt_size', return_value=None):
            with patch.object(tool, '_validate_token_limit'):
                with patch.object(tool, 'get_websearch_instruction', return_value=""):
                    with patch.object(tool, '_prepare_file_content_for_prompt') as mock_prep:
                        mock_prep.return_value = ("file content", ["/path/to/file1.py"])

                        prompt = await tool.prepare_prompt(request)

                        # Should call centralized file handler with continuation_id
                        mock_prep.assert_called_once_with(
                            ["/path/to/file1.py", "/path/to/file2.py"],
                            "test-thread-123",
                            "Code files"
                        )

        assert "CODE FILES" in prompt
        assert "file content" in prompt

    @pytest.mark.asyncio
    async def test_prompt_file_handling(self):
        """Test prompt.txt file handling for large inputs."""
        tool = ExampleTool()
        request = ExampleRequest(
            prompt="small prompt",  # Will be replaced
            files=["/path/to/prompt.txt", "/path/to/other.py"]
        )

        # Mock prompt.txt handling
        with patch.object(tool, 'handle_prompt_file') as mock_handle:
            mock_handle.return_value = ("Large prompt content from file", ["/path/to/other.py"])
            with patch.object(tool, 'check_prompt_size', return_value=None):
                with patch.object(tool, '_validate_token_limit'):
                    with patch.object(tool, 'get_websearch_instruction', return_value=""):
                        with patch.object(tool, '_prepare_file_content_for_prompt', return_value=("", [])):
                            prompt = await tool.prepare_prompt(request)

        assert "Large prompt content from file" in prompt
        mock_handle.assert_called_once()

    def test_format_response_customization(self):
        """Test custom response formatting."""
        tool = ExampleTool()
        request = ExampleRequest(prompt="test")

        formatted = tool.format_response("Analysis complete", request)

        assert "Analysis complete" in formatted
        assert "Next Steps:" in formatted
        assert "prioritize implementation" in formatted


# Integration test (requires actual model context)
class TestExampleToolIntegration:
    """Integration tests that require full tool setup."""

    def setup_method(self):
        """Set up model context for integration tests."""
        # Initialize model context for file processing
        from utils.model_context import ModelContext

        self.tool = ExampleTool()
        self.tool._model_context = ModelContext("flash")  # Test model

    @pytest.mark.asyncio
    async def test_full_prompt_preparation(self):
        """Test complete prompt preparation flow."""
        request = ExampleRequest(
            prompt="Analyze this codebase for security issues",
            analysis_depth="comprehensive",
            focus_areas=["security", "performance"]
        )

        # Mock file system and validation
        with patch.object(self.tool, 'check_prompt_size', return_value=None):
            with patch.object(self.tool, '_validate_token_limit'):
                with patch.object(self.tool, 'get_websearch_instruction', return_value="\nWEB_SEARCH_ENABLED"):
                    prompt = await self.tool.prepare_prompt(request)

        # Verify complete prompt structure
        assert self.tool.get_system_prompt() in prompt
        assert "WEB_SEARCH_ENABLED" in prompt
        assert "security, performance" in prompt
        assert "comprehensive" in prompt
        assert "ANALYSIS REQUEST" in prompt
```

**Key Testing Patterns:**

- **Metadata Tests**: Verify tool configuration and schema generation
- **Validation Tests**: Test Pydantic request models and edge cases
- **Prompt Tests**: Mock external dependencies, test prompt composition
- **Integration Tests**: Test the full flow with model context
- **File Handling**: Test conversation-aware deduplication
- **Error Cases**: Test size limits and validation failures

## Essential Gotchas & Best Practices

### Critical Requirements

**🚨 MUST DO:**

1. **Inherit from `ToolRequest`**: Request models MUST inherit from `ToolRequest` to get the common fields
2. **Use `default_factory=list`**: For mutable defaults, never use `default=[]` - it causes shared-state bugs
3. **Implement all 6 abstract methods**: `get_name()`, `get_description()`, `get_input_schema()`, `get_system_prompt()`, `get_request_model()`, `prepare_prompt()`
4. **Handle MCP size limits**: Call `check_prompt_size()` on user input in `prepare_prompt()`
5. **Use centralized file processing**: Call `_prepare_file_content_for_prompt()` for conversation-aware deduplication
6. **Register in `server.py`**: Import the tool and add it to the `TOOLS` dictionary
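
As a compact reminder, the abstract surface of a `BaseTool` subclass looks roughly like this. This is a sketch, not the library's definitive API: the `BaseTool` import path is assumed, and `ExampleRequest` is the request model defined earlier in this guide.

```python
from tools.base import BaseTool  # import path assumed


class ExampleTool(BaseTool):
    def get_name(self) -> str:
        return "example"

    def get_description(self) -> str:
        return "COMPREHENSIVE CODE ANALYSIS - Reviews code for quality, security, and performance."

    def get_input_schema(self) -> dict:
        # Typically built from the request model plus the common fields
        # (model, continuation_id, images, ...).
        ...

    def get_system_prompt(self) -> str:
        from systemprompts import EXAMPLE_PROMPT
        return EXAMPLE_PROMPT

    def get_request_model(self):
        return ExampleRequest  # Pydantic model inheriting from ToolRequest, same module

    async def prepare_prompt(self, request) -> str:
        # Build the full prompt; see the sketch under "Advanced Features" below.
        ...
```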

**🚨 COMMON MISTAKES:**

- **Forgetting TYPE_CHECKING**: Import `ToolModelCategory` under `TYPE_CHECKING` to avoid circular imports (see the sketch after this list)
- **Hardcoding models**: Use `get_model_category()` instead of hardcoding model selection
- **Ignoring continuation_id**: File processing should pass `continuation_id` for deduplication
- **Missing error handling**: Always validate token limits with `_validate_token_limit()`
- **Wrong default patterns**: Use `default_factory=list`, not `default=None`, for file lists
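
The `TYPE_CHECKING` pattern mentioned above looks like this. It is a sketch: `tools.models.ToolModelCategory` is taken from the test examples in this guide, while the rest of the module layout is assumed.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for type hints, so the runtime import graph stays acyclic.
    from tools.models import ToolModelCategory


class ExampleTool:  # base class omitted for brevity
    def get_model_category(self) -> "ToolModelCategory":
        # Deferred runtime import avoids the circular dependency.
        from tools.models import ToolModelCategory

        return ToolModelCategory.BALANCED
```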

### File Handling Patterns

```python
# ✅ CORRECT: Conversation-aware file processing
file_content, processed_files = self._prepare_file_content_for_prompt(
    request.files, request.continuation_id, "Context files"
)
self._actually_processed_files = processed_files  # For conversation memory

# ❌ WRONG: Direct file reading (no deduplication)
file_content = read_files(request.files)
```

### Request Model Patterns

```python
# ✅ CORRECT: Proper defaults and inheritance
class MyToolRequest(ToolRequest):
    files: Optional[list[str]] = Field(default_factory=list, ...)
    options: Optional[list[str]] = Field(default_factory=list, ...)

# ❌ WRONG: Shared mutable defaults
class MyToolRequest(ToolRequest):
    files: Optional[list[str]] = Field(default=[], ...)  # BUG!
```

### Testing Requirements

**Required Tests:**
- Tool metadata (name, description, category)
- Request validation (valid/invalid cases)
- Schema generation for MCP
- Prompt preparation with mocks
- File handling with conversation IDs
- Error cases (size limits, validation failures)

### Model Categories Guide

- **FAST_RESPONSE**: Chat, simple queries, quick tasks (→ o4-mini, flash)
- **BALANCED**: Standard analysis, code review, general tasks (→ o3-mini, pro)
- **EXTENDED_REASONING**: Complex debugging, deep analysis (→ o3, pro with high thinking)

### Advanced Features

- **Conversation Threading**: Automatic if a `continuation_id` is provided
- **File Deduplication**: Automatic via `_prepare_file_content_for_prompt()`
- **Web Search**: Use `get_websearch_instruction()` for a consistent implementation
- **Image Support**: Inherited from `ToolRequest`, validated automatically
- **Large Prompts**: Handled via the `check_prompt_size()` → prompt.txt mechanism
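
Putting these pieces together, a typical `prepare_prompt()` wires them up in roughly this order. This is a sketch: `handle_prompt_file()`, `check_prompt_size()`, `get_websearch_instruction()`, and `_validate_token_limit()` appear in the tests above only as mocks, so their exact signatures are assumptions here.

```python
async def prepare_prompt(self, request: ExampleRequest) -> str:
    # 1. Large prompts may arrive as a prompt.txt file; if so, swap it in.
    prompt_content, remaining_files = self.handle_prompt_file(request.files)
    user_prompt = prompt_content or request.prompt

    # 2. Enforce the MCP transport limit on the user-supplied prompt.
    self.check_prompt_size(user_prompt)

    # 3. Conversation-aware file embedding with deduplication.
    file_content, processed_files = self._prepare_file_content_for_prompt(
        remaining_files, request.continuation_id, "Code files"
    )
    self._actually_processed_files = processed_files

    # 4. Optional web-search guidance, added consistently across tools.
    websearch = self.get_websearch_instruction()

    full_prompt = (
        f"{self.get_system_prompt()}"
        f"\n\n=== CODE FILES ===\n{file_content}\n=== END CODE FILES ==="
        f"\n\n=== ANALYSIS REQUEST ===\n{user_prompt}\n=== END ANALYSIS REQUEST ==="
        f"{websearch}"
    )

    # 5. Final token-budget check against the active model's context window.
    self._validate_token_limit(full_prompt, "Full prompt")
    return full_prompt
```

The section markers (`CODE FILES`, `ANALYSIS REQUEST`) match what the unit tests above assert on.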

## Quick Checklist

**Before Submitting a PR:**

- [ ] Tool inherits from `BaseTool`, request model from `ToolRequest`
- [ ] All 6 abstract methods implemented
- [ ] System prompt created in `systemprompts/`
- [ ] Tool registered in the `TOOLS` dict in `server.py`
- [ ] Comprehensive unit tests written
- [ ] File handling uses `_prepare_file_content_for_prompt()`
- [ ] MCP size checking with `check_prompt_size()`
- [ ] Token validation with `_validate_token_limit()`
- [ ] Proper model category selected
- [ ] No hardcoded model names

### Running Tests

Run these before committing:

```bash
# Unit tests for your tool
pytest tests/test_example.py -xvs

# Simulator test for your tool
python communication_simulator_test.py --individual your_tool_validation

# Run all tests
./code_quality_checks.sh

# Quick comprehensive test
python communication_simulator_test.py --quick
```

## Complete Example

The example tool we built provides:

- **Comprehensive code analysis** with configurable depth
- **Multi-file support** with conversation-aware deduplication
- **Focus areas** for targeted analysis
- **Web search integration** for current best practices
- **Image support** for screenshots/diagrams
- **Conversation threading** for follow-up discussions
- **Automatic model selection** based on task complexity

**Usage by Claude:**

```json
{
  "tool": "example",
  "arguments": {
    "prompt": "Analyze this codebase for security vulnerabilities and performance issues",
    "files": ["/path/to/src/", "/path/to/config.py"],
    "analysis_depth": "comprehensive",
    "focus_areas": ["security", "performance"],
    "model": "o3"
  }
}
```

The tool automatically handles file deduplication, validates inputs, manages token limits, and offers continuation opportunities for deeper analysis.

## Examples to Study

- **Simple Tool**: `tools/chat.py` - Clean request/response pattern
- **Workflow Tool**: `tools/debug.py` - Multi-step investigation with expert analysis

**Recommendation**: Start with existing tools as templates and explore the base classes to understand the available hooks and methods.

---

**Need Help?** Look at existing tools like `chat.py` and `consensus.py` for reference implementations, or check GitHub issues for support.