diff --git a/README.md b/README.md index a67be15..4b4f3fb 100644 --- a/README.md +++ b/README.md @@ -606,7 +606,14 @@ For detailed configuration options, see the [Advanced Usage Guide](docs/advanced ## Testing -For information on running tests and contributing, see the [Testing Guide](docs/testing.md). +For information on running tests, see the [Testing Guide](docs/testing.md). + +## Contributing + +We welcome contributions! Please see our comprehensive guides: +- [Contributing Guide](docs/contributions.md) - Code standards, PR process, and requirements +- [Adding a New Provider](docs/adding_providers.md) - Step-by-step guide for adding AI providers +- [Adding a New Tool](docs/adding_tools.md) - Step-by-step guide for creating new tools ## License diff --git a/docs/adding_providers.md b/docs/adding_providers.md new file mode 100644 index 0000000..182230c --- /dev/null +++ b/docs/adding_providers.md @@ -0,0 +1,734 @@ +# Adding a New Provider + +This guide explains how to add support for a new AI model provider to the Zen MCP Server. Follow these steps to integrate providers like Anthropic, Cohere, or any API that provides AI model access. + +## Overview + +The provider system in Zen MCP Server is designed to be extensible. Each provider: +- Inherits from a base class (`ModelProvider` or `OpenAICompatibleProvider`) +- Implements required methods for model interaction +- Is registered in the provider registry by the server +- Has its API key configured via environment variables + +## Implementation Paths + +You have two options when implementing a new provider: + +### Option A: Native Provider (Full Implementation) +Inherit from `ModelProvider` when: +- Your API has unique features not compatible with OpenAI's format +- You need full control over the implementation +- You want to implement custom features like extended thinking + +### Option B: OpenAI-Compatible Provider (Simplified) +Inherit from `OpenAICompatibleProvider` when: +- Your API follows OpenAI's chat completion format +- You want to reuse existing implementation for `generate_content` and `count_tokens` +- You only need to define model capabilities and validation + +## Step-by-Step Guide + +### 1. Add Provider Type to Enum + +First, add your provider to the `ProviderType` enum in `providers/base.py`: + +```python +class ProviderType(Enum): + """Supported model provider types.""" + + GOOGLE = "google" + OPENAI = "openai" + OPENROUTER = "openrouter" + CUSTOM = "custom" + EXAMPLE = "example" # Add your provider here +``` + +### 2. 
Create the Provider Implementation + +#### Option A: Native Provider Implementation + +Create a new file in the `providers/` directory (e.g., `providers/example.py`): + +```python +"""Example model provider implementation.""" + +import logging +from typing import Optional +from .base import ( + ModelCapabilities, + ModelProvider, + ModelResponse, + ProviderType, + RangeTemperatureConstraint, +) +from utils.model_restrictions import get_restriction_service + +logger = logging.getLogger(__name__) + + +class ExampleModelProvider(ModelProvider): + """Example model provider implementation.""" + + SUPPORTED_MODELS = { + "example-large-v1": { + "context_window": 100_000, + "supports_extended_thinking": False, + }, + "example-small-v1": { + "context_window": 50_000, + "supports_extended_thinking": False, + }, + # Shorthands + "large": "example-large-v1", + "small": "example-small-v1", + } + + def __init__(self, api_key: str, **kwargs): + super().__init__(api_key, **kwargs) + # Initialize your API client here + + def get_capabilities(self, model_name: str) -> ModelCapabilities: + resolved_name = self._resolve_model_name(model_name) + + if resolved_name not in self.SUPPORTED_MODELS: + raise ValueError(f"Unsupported model: {model_name}") + + restriction_service = get_restriction_service() + if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name): + raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.") + + config = self.SUPPORTED_MODELS[resolved_name] + + return ModelCapabilities( + provider=ProviderType.EXAMPLE, + model_name=resolved_name, + friendly_name="Example", + context_window=config["context_window"], + supports_extended_thinking=config["supports_extended_thinking"], + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7), + ) + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.7, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> ModelResponse: + resolved_name = self._resolve_model_name(model_name) + self.validate_parameters(resolved_name, temperature) + + # Call your API here + # response = your_api_call(...) 
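+        #
+        # A minimal, hedged sketch of what this step usually looks like.
+        # `self.client` and its `complete()` method are placeholders, not a
+        # real SDK; substitute your provider's actual client call and map its
+        # response fields into the ModelResponse returned below:
+        #
+        # response = self.client.complete(
+        #     model=resolved_name,
+        #     prompt=prompt,
+        #     system=system_prompt,
+        #     temperature=temperature,
+        #     max_tokens=max_output_tokens,
+        # )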
+ + return ModelResponse( + content="", # From API response + usage={ + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + }, + model_name=resolved_name, + friendly_name="Example", + provider=ProviderType.EXAMPLE, + ) + + def count_tokens(self, text: str, model_name: str) -> int: + # Implement your tokenization or use estimation + return len(text) // 4 + + def get_provider_type(self) -> ProviderType: + return ProviderType.EXAMPLE + + def validate_model_name(self, model_name: str) -> bool: + resolved_name = self._resolve_model_name(model_name) + + if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + return False + + restriction_service = get_restriction_service() + if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name): + logger.debug(f"Example model '{model_name}' -> '{resolved_name}' blocked by restrictions") + return False + + return True + + def supports_thinking_mode(self, model_name: str) -> bool: + capabilities = self.get_capabilities(model_name) + return capabilities.supports_extended_thinking + + def _resolve_model_name(self, model_name: str) -> str: + shorthand_value = self.SUPPORTED_MODELS.get(model_name) + if isinstance(shorthand_value, str): + return shorthand_value + return model_name +``` + +#### Option B: OpenAI-Compatible Provider Implementation + +For providers with OpenAI-compatible APIs, the implementation is much simpler: + +```python +"""Example provider using OpenAI-compatible interface.""" + +import logging +from .base import ( + ModelCapabilities, + ProviderType, + RangeTemperatureConstraint, +) +from .openai_compatible import OpenAICompatibleProvider + +logger = logging.getLogger(__name__) + + +class ExampleProvider(OpenAICompatibleProvider): + """Example provider using OpenAI-compatible API.""" + + FRIENDLY_NAME = "Example" + + # Define supported models + SUPPORTED_MODELS = { + "example-model-large": { + "context_window": 128_000, + "supports_extended_thinking": False, + }, + "example-model-small": { + "context_window": 32_000, + "supports_extended_thinking": False, + }, + # Shorthands + "large": "example-model-large", + "small": "example-model-small", + } + + def __init__(self, api_key: str, **kwargs): + """Initialize provider with API key.""" + # Set your API base URL + kwargs.setdefault("base_url", "https://api.example.com/v1") + super().__init__(api_key, **kwargs) + + def get_capabilities(self, model_name: str) -> ModelCapabilities: + """Get capabilities for a specific model.""" + resolved_name = self._resolve_model_name(model_name) + + if resolved_name not in self.SUPPORTED_MODELS: + raise ValueError(f"Unsupported model: {model_name}") + + # Check restrictions + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name): + raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.") + + config = self.SUPPORTED_MODELS[resolved_name] + + return ModelCapabilities( + provider=ProviderType.EXAMPLE, + model_name=resolved_name, + friendly_name=self.FRIENDLY_NAME, + context_window=config["context_window"], + supports_extended_thinking=config["supports_extended_thinking"], + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + temperature_constraint=RangeTemperatureConstraint(0.0, 1.0, 0.7), + ) + + def get_provider_type(self) -> ProviderType: + """Get the provider 
type.""" + return ProviderType.EXAMPLE + + def validate_model_name(self, model_name: str) -> bool: + """Validate if the model name is supported.""" + resolved_name = self._resolve_model_name(model_name) + + if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + return False + + # Check restrictions + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name): + return False + + return True + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve model shorthand to full name.""" + shorthand_value = self.SUPPORTED_MODELS.get(model_name) + if isinstance(shorthand_value, str): + return shorthand_value + return model_name + + # Note: generate_content and count_tokens are inherited from OpenAICompatibleProvider +``` + +### 3. Update Registry Configuration + +#### 3.1. Add Environment Variable Mapping + +Update `providers/registry.py` to map your provider's API key: + +```python +@classmethod +def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]: + """Get API key for a provider from environment variables.""" + key_mapping = { + ProviderType.GOOGLE: "GEMINI_API_KEY", + ProviderType.OPENAI: "OPENAI_API_KEY", + ProviderType.OPENROUTER: "OPENROUTER_API_KEY", + ProviderType.CUSTOM: "CUSTOM_API_KEY", + ProviderType.EXAMPLE: "EXAMPLE_API_KEY", # Add this line + } + # ... rest of the method +``` + +### 4. Register Provider in server.py + +The `configure_providers()` function in `server.py` handles provider registration. You need to: + +**Note**: The provider priority is hardcoded in `registry.py`. If you're adding a new native provider (like Example), you'll need to update the `PROVIDER_PRIORITY_ORDER` in `get_provider_for_model()`: + +```python +# In providers/registry.py +PROVIDER_PRIORITY_ORDER = [ + ProviderType.GOOGLE, # Direct Gemini access + ProviderType.OPENAI, # Direct OpenAI access + ProviderType.EXAMPLE, # Add your native provider here + ProviderType.CUSTOM, # Local/self-hosted models + ProviderType.OPENROUTER, # Catch-all (must stay last) +] +``` + +Native providers should be placed BEFORE CUSTOM and OPENROUTER to ensure they get priority for their models. + +1. **Import your provider class** at the top of `server.py`: +```python +from providers.example import ExampleModelProvider +``` + +2. **Add API key checking** in the `configure_providers()` function: +```python +def configure_providers(): + """Configure and validate AI providers based on available API keys.""" + # ... existing code ... + + # Check for Example API key + example_key = os.getenv("EXAMPLE_API_KEY") + if example_key and example_key != "your_example_api_key_here": + valid_providers.append("Example") + has_native_apis = True + logger.info("Example API key found - Example models available") +``` + +3. **Register the provider** in the appropriate section: +```python + # Register providers in priority order: + # 1. 
Native APIs first (most direct and efficient) + if has_native_apis: + if gemini_key and gemini_key != "your_gemini_api_key_here": + ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider) + if openai_key and openai_key != "your_openai_api_key_here": + ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider) + if example_key and example_key != "your_example_api_key_here": + ModelProviderRegistry.register_provider(ProviderType.EXAMPLE, ExampleModelProvider) +``` + +4. **Update error message** to include your provider: +```python + if not valid_providers: + raise ValueError( + "At least one API configuration is required. Please set either:\n" + "- GEMINI_API_KEY for Gemini models\n" + "- OPENAI_API_KEY for OpenAI o3 model\n" + "- EXAMPLE_API_KEY for Example models\n" # Add this + "- OPENROUTER_API_KEY for OpenRouter (multiple models)\n" + "- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)" + ) +``` + +### 5. Add Model Capabilities for Auto Mode + +Update `config.py` to add your models to `MODEL_CAPABILITIES_DESC`: + +```python +MODEL_CAPABILITIES_DESC = { + # ... existing models ... + + # Example models - Available when EXAMPLE_API_KEY is configured + "large": "Example Large (100K context) - High capacity model for complex tasks", + "small": "Example Small (50K context) - Fast model for simple tasks", + # Full model names + "example-large-v1": "Example Large (100K context) - High capacity model", + "example-small-v1": "Example Small (50K context) - Fast lightweight model", +} +``` + +### 6. Update Documentation + +#### 6.1. Update README.md + +Add your provider to the quickstart section: + +```markdown +### 1. Get API Keys (at least one required) + +**Option B: Native APIs** +- **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) +- **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) +- **Example**: Visit [Example API Console](https://example.com/api-keys) # Add this +``` + +Also update the .env file example: + +```markdown +# Edit .env to add your API keys +# GEMINI_API_KEY=your-gemini-api-key-here +# OPENAI_API_KEY=your-openai-api-key-here +# EXAMPLE_API_KEY=your-example-api-key-here # Add this +``` + +### 7. Write Tests + +#### 7.1. 
Unit Tests + +Create `tests/test_example_provider.py`: + +```python +"""Tests for Example provider implementation.""" + +import os +from unittest.mock import patch +import pytest + +from providers.example import ExampleModelProvider +from providers.base import ProviderType + + +class TestExampleProvider: + """Test Example provider functionality.""" + + @patch.dict(os.environ, {"EXAMPLE_API_KEY": "test-key"}) + def test_initialization(self): + """Test provider initialization.""" + provider = ExampleModelProvider("test-key") + assert provider.api_key == "test-key" + assert provider.get_provider_type() == ProviderType.EXAMPLE + + def test_model_validation(self): + """Test model name validation.""" + provider = ExampleModelProvider("test-key") + + # Test valid models + assert provider.validate_model_name("large") is True + assert provider.validate_model_name("example-large-v1") is True + + # Test invalid model + assert provider.validate_model_name("invalid-model") is False + + def test_resolve_model_name(self): + """Test model name resolution.""" + provider = ExampleModelProvider("test-key") + + # Test shorthand resolution + assert provider._resolve_model_name("large") == "example-large-v1" + assert provider._resolve_model_name("small") == "example-small-v1" + + # Test full name passthrough + assert provider._resolve_model_name("example-large-v1") == "example-large-v1" + + def test_get_capabilities(self): + """Test getting model capabilities.""" + provider = ExampleModelProvider("test-key") + + capabilities = provider.get_capabilities("large") + assert capabilities.model_name == "example-large-v1" + assert capabilities.friendly_name == "Example" + assert capabilities.context_window == 100_000 + assert capabilities.provider == ProviderType.EXAMPLE + + # Test temperature range + assert capabilities.temperature_constraint.min_temp == 0.0 + assert capabilities.temperature_constraint.max_temp == 2.0 +``` + +#### 7.2. Simulator Tests (Real-World Validation) + +Create a simulator test to validate that your provider works correctly in real-world scenarios. Create `simulator_tests/test_example_models.py`: + +```python +""" +Example Provider Model Tests + +Tests that verify Example provider functionality including: +- Model alias resolution +- API integration +- Conversation continuity +- Error handling +""" + +from .base_test import BaseSimulatorTest + + +class TestExampleModels(BaseSimulatorTest): + """Test Example provider functionality""" + + @property + def test_name(self) -> str: + return "example_models" + + @property + def test_description(self) -> str: + return "Example provider model functionality and integration" + + def run_test(self) -> bool: + """Test Example provider models""" + try: + self.logger.info("Test: Example provider functionality") + + # Check if Example API key is configured + check_result = self.check_env_var("EXAMPLE_API_KEY") + if not check_result: + self.logger.info(" ⚠️ Example API key not configured - skipping test") + return True # Skip, not fail + + # Test 1: Shorthand alias mapping + self.logger.info(" 1: Testing 'large' alias mapping") + + response1, continuation_id = self.call_mcp_tool( + "chat", + { + "prompt": "Say 'Hello from Example Large model!' 
and nothing else.", + "model": "large", # Should map to example-large-v1 + "temperature": 0.1, + } + ) + + if not response1: + self.logger.error(" ❌ Large alias test failed") + return False + + self.logger.info(" ✅ Large alias call completed") + + # Test 2: Direct model name + self.logger.info(" 2: Testing direct model name (example-small-v1)") + + response2, _ = self.call_mcp_tool( + "chat", + { + "prompt": "Say 'Hello from Example Small model!' and nothing else.", + "model": "example-small-v1", + "temperature": 0.1, + } + ) + + if not response2: + self.logger.error(" ❌ Direct model name test failed") + return False + + self.logger.info(" ✅ Direct model name call completed") + + # Test 3: Conversation continuity + self.logger.info(" 3: Testing conversation continuity") + + response3, new_continuation_id = self.call_mcp_tool( + "chat", + { + "prompt": "Remember this number: 99. What number did I just tell you?", + "model": "large", + "temperature": 0.1, + } + ) + + if not response3 or not new_continuation_id: + self.logger.error(" ❌ Failed to start conversation") + return False + + # Continue conversation + response4, _ = self.call_mcp_tool( + "chat", + { + "prompt": "What was the number I told you earlier?", + "model": "large", + "continuation_id": new_continuation_id, + "temperature": 0.1, + } + ) + + if not response4: + self.logger.error(" ❌ Failed to continue conversation") + return False + + if "99" in response4: + self.logger.info(" ✅ Conversation continuity working") + else: + self.logger.warning(" ⚠️ Model may not have remembered the number") + + # Test 4: Check logs for proper provider usage + self.logger.info(" 4: Validating Example provider usage in logs") + logs = self.get_recent_server_logs() + + # Look for evidence of Example provider usage + example_logs = [line for line in logs.split("\n") if "example" in line.lower()] + model_resolution_logs = [ + line for line in logs.split("\n") + if "Resolved model" in line and "example" in line.lower() + ] + + self.logger.info(f" Example-related logs: {len(example_logs)}") + self.logger.info(f" Model resolution logs: {len(model_resolution_logs)}") + + # Success criteria + api_used = len(example_logs) > 0 + models_resolved = len(model_resolution_logs) > 0 + + if api_used and models_resolved: + self.logger.info(" ✅ Example provider tests passed") + return True + else: + self.logger.error(" ❌ Example provider tests failed") + return False + + except Exception as e: + self.logger.error(f"Example provider test failed: {e}") + return False + + +def main(): + """Run the Example provider tests""" + import sys + + verbose = "--verbose" in sys.argv or "-v" in sys.argv + test = TestExampleModels(verbose=verbose) + + success = test.run_test() + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() +``` + +The simulator test is crucial because it: +- Validates your provider works in the actual Docker environment +- Tests real API integration, not just mocked behavior +- Verifies model name resolution works correctly +- Checks conversation continuity across requests +- Examines server logs to ensure proper provider selection + +See `simulator_tests/test_openrouter_models.py` for a complete real-world example. + +## Model Name Mapping and Provider Priority + +### How Model Name Resolution Works + +When a user requests a model (e.g., "pro", "o3", "example-large-v1"), the system: + +1. 
**Checks providers in priority order** (defined in `registry.py`): + ```python + PROVIDER_PRIORITY_ORDER = [ + ProviderType.GOOGLE, # Native Gemini API + ProviderType.OPENAI, # Native OpenAI API + ProviderType.CUSTOM, # Local/self-hosted + ProviderType.OPENROUTER, # Catch-all for everything else + ] + ``` + +2. **For each provider**, calls `validate_model_name()`: + - Native providers (Gemini, OpenAI) return `true` only for their specific models + - OpenRouter returns `true` for ANY model (it's the catch-all) + - First provider that validates the model handles the request + +### Example: Model "gemini-2.5-pro" + +1. **Gemini provider** checks: YES, it's in my SUPPORTED_MODELS → Gemini handles it +2. OpenAI skips (Gemini already handled it) +3. OpenRouter never sees it + +### Example: Model "claude-3-opus" + +1. **Gemini provider** checks: NO, not my model → skip +2. **OpenAI provider** checks: NO, not my model → skip +3. **Custom provider** checks: NO, not configured → skip +4. **OpenRouter provider** checks: YES, I accept all models → OpenRouter handles it + +### Implementing Model Name Validation + +Your provider's `validate_model_name()` should: + +```python +def validate_model_name(self, model_name: str) -> bool: + resolved_name = self._resolve_model_name(model_name) + + # Only accept models you explicitly support + if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + return False + + # Check restrictions + restriction_service = get_restriction_service() + if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name): + logger.debug(f"Example model '{model_name}' -> '{resolved_name}' blocked by restrictions") + return False + + return True +``` + +**Important**: Native providers should ONLY return `true` for models they explicitly support. This ensures they get priority over proxy providers like OpenRouter. + +### Model Shorthands + +Each provider can define shorthands in their SUPPORTED_MODELS: + +```python +SUPPORTED_MODELS = { + "example-large-v1": { ... }, # Full model name + "large": "example-large-v1", # Shorthand mapping +} +``` + +The `_resolve_model_name()` method handles this mapping automatically. + +## Best Practices + +1. **Always validate model names** against supported models and restrictions +2. **Be specific in validation** - only accept models you actually support +3. **Handle API errors gracefully** with proper error messages +4. **Include retry logic** for transient errors (see `gemini.py` for example) +5. **Log important events** for debugging (initialization, model resolution, errors) +6. **Support model shorthands** for better user experience +7. **Document supported models** clearly in your provider class +8. 
**Test thoroughly** including error cases and edge conditions + +## Checklist + +Before submitting your PR: + +- [ ] Provider type added to `ProviderType` enum in `providers/base.py` +- [ ] Provider implementation complete with all required methods +- [ ] API key mapping added to `_get_api_key_for_provider()` in `providers/registry.py` +- [ ] Provider added to `PROVIDER_PRIORITY_ORDER` in `registry.py` (if native provider) +- [ ] Provider imported and registered in `server.py`'s `configure_providers()` +- [ ] API key checking added to `configure_providers()` function +- [ ] Error message updated to include new provider +- [ ] Model capabilities added to `config.py` for auto mode +- [ ] Documentation updated (README.md) +- [ ] Unit tests written and passing (`tests/test_.py`) +- [ ] Simulator tests written and passing (`simulator_tests/test__models.py`) +- [ ] Integration tested with actual API calls +- [ ] Code follows project style (run linting) +- [ ] PR follows the template requirements + +## Need Help? + +- Look at existing providers (`gemini.py`, `openai.py`) for examples +- Check the base classes for method signatures and requirements +- Run tests frequently during development +- Ask questions in GitHub issues if stuck \ No newline at end of file diff --git a/docs/adding_tools.md b/docs/adding_tools.md new file mode 100644 index 0000000..2875f52 --- /dev/null +++ b/docs/adding_tools.md @@ -0,0 +1,732 @@ +# Adding a New Tool + +This guide explains how to add a new tool to the Zen MCP Server. Tools are the primary way Claude interacts with the AI models, providing specialized capabilities like code review, debugging, test generation, and more. + +## Overview + +The tool system in Zen MCP Server is designed to be extensible. Each tool: +- Inherits from the `BaseTool` class +- Implements required abstract methods +- Defines a request model for parameter validation +- Is registered in the server's tool registry +- Can leverage different AI models based on task requirements + +## Architecture Overview + +### Key Components + +1. **BaseTool** (`tools/base.py`): Abstract base class providing common functionality +2. **Request Models**: Pydantic models for input validation +3. **System Prompts**: Specialized prompts that configure AI behavior +4. **Tool Registry**: Registration system in `server.py` + +### Tool Lifecycle + +1. Claude calls the tool with parameters +2. Parameters are validated using Pydantic +3. File paths are security-checked +4. Prompt is prepared with system instructions +5. AI model generates response +6. Response is formatted and returned + +## Step-by-Step Implementation Guide + +### 1. Create the Tool File + +Create a new file in the `tools/` directory (e.g., `tools/example.py`): + +```python +""" +Example tool - Brief description of what your tool does + +This tool provides [specific functionality] to help developers [achieve goal]. +Key features: +- Feature 1 +- Feature 2 +- Feature 3 +""" + +import logging +from typing import Any, Optional + +from mcp.types import TextContent +from pydantic import Field + +from config import TEMPERATURE_BALANCED +from systemprompts import EXAMPLE_PROMPT # You'll create this + +from .base import BaseTool, ToolRequest +from .models import ToolOutput + +logger = logging.getLogger(__name__) +``` + +### 2. 
Define the Request Model + +Create a Pydantic model that inherits from `ToolRequest`: + +```python +class ExampleRequest(ToolRequest): + """Request model for the example tool.""" + + # Required parameters + prompt: str = Field( + ..., + description="The main input/question for the tool" + ) + + # Optional parameters with defaults + files: Optional[list[str]] = Field( + default=None, + description="Files to analyze (must be absolute paths)" + ) + + focus_area: Optional[str] = Field( + default=None, + description="Specific aspect to focus on" + ) + + # You can add tool-specific parameters + output_format: Optional[str] = Field( + default="detailed", + description="Output format: 'summary', 'detailed', or 'actionable'" + ) +``` + +### 3. Implement the Tool Class + +```python +class ExampleTool(BaseTool): + """Implementation of the example tool.""" + + def get_name(self) -> str: + """Return the tool's unique identifier.""" + return "example" + + def get_description(self) -> str: + """Return detailed description for Claude.""" + return ( + "EXAMPLE TOOL - Brief tagline describing the tool's purpose. " + "Use this tool when you need to [specific use cases]. " + "Perfect for: [scenario 1], [scenario 2], [scenario 3]. " + "Supports [key features]. Choose thinking_mode based on " + "[guidance for mode selection]. " + "Note: If you're not currently using a top-tier model such as " + "Opus 4 or above, these tools can provide enhanced capabilities." + ) + + def get_input_schema(self) -> dict[str, Any]: + """Define the JSON schema for tool parameters.""" + schema = { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The main input/question for the tool", + }, + "files": { + "type": "array", + "items": {"type": "string"}, + "description": "Files to analyze (must be absolute paths)", + }, + "focus_area": { + "type": "string", + "description": "Specific aspect to focus on", + }, + "output_format": { + "type": "string", + "enum": ["summary", "detailed", "actionable"], + "description": "Output format type", + "default": "detailed", + }, + "model": self.get_model_field_schema(), + "temperature": { + "type": "number", + "description": "Temperature (0-1, default varies by tool)", + "minimum": 0, + "maximum": 1, + }, + "thinking_mode": { + "type": "string", + "enum": ["minimal", "low", "medium", "high", "max"], + "description": "Thinking depth: minimal (0.5% of model max), " + "low (8%), medium (33%), high (67%), max (100%)", + }, + "continuation_id": { + "type": "string", + "description": "Thread continuation ID for multi-turn conversations", + }, + }, + "required": ["prompt"] + ( + ["model"] if self.is_effective_auto_mode() else [] + ), + } + return schema + + def get_system_prompt(self) -> str: + """Return the system prompt for this tool.""" + return EXAMPLE_PROMPT # Defined in systemprompts/ + + def get_default_temperature(self) -> float: + """Return default temperature for this tool.""" + # Use predefined constants from config.py: + # TEMPERATURE_CREATIVE (0.7) - For creative tasks + # TEMPERATURE_BALANCED (0.5) - For balanced tasks + # TEMPERATURE_ANALYTICAL (0.2) - For analytical tasks + return TEMPERATURE_BALANCED + + def get_model_category(self): + """Specify which type of model this tool needs.""" + from tools.models import ToolModelCategory + + # Choose based on your tool's needs: + # FAST_RESPONSE - Quick responses, cost-efficient (chat, simple queries) + # BALANCED - Standard analysis and generation + # EXTENDED_REASONING - Complex analysis, deep thinking 
(debug, review) + return ToolModelCategory.BALANCED + + def get_request_model(self): + """Return the request model class.""" + return ExampleRequest + + def wants_line_numbers_by_default(self) -> bool: + """Whether to add line numbers to code files.""" + # Return True if your tool benefits from precise line references + # (e.g., code review, debugging, refactoring) + # Return False for general analysis or token-sensitive operations + return False + + async def prepare_prompt(self, request: ExampleRequest) -> str: + """ + Prepare the complete prompt for the AI model. + + This method combines: + - System prompt (behavior configuration) + - User request + - File contents (if provided) + - Additional context + """ + # Check for prompt.txt in files (handles large prompts) + prompt_content, updated_files = self.handle_prompt_file(request.files) + if prompt_content: + request.prompt = prompt_content + if updated_files is not None: + request.files = updated_files + + # Build the prompt parts + prompt_parts = [] + + # Add main request + prompt_parts.append(f"=== USER REQUEST ===") + prompt_parts.append(f"Focus Area: {request.focus_area}" if request.focus_area else "") + prompt_parts.append(f"Output Format: {request.output_format}") + prompt_parts.append(request.prompt) + prompt_parts.append("=== END REQUEST ===") + + # Add file contents if provided + if request.files: + # Use the centralized file handling (respects continuation) + file_content = self._prepare_file_content_for_prompt( + request.files, + request.continuation_id, + "Files to analyze" + ) + if file_content: + prompt_parts.append("\n=== FILES ===") + prompt_parts.append(file_content) + prompt_parts.append("=== END FILES ===") + + # Validate token limits + full_prompt = "\n".join(filter(None, prompt_parts)) + self._validate_token_limit(full_prompt, "Prompt") + + return full_prompt + + def format_response(self, response: str, request: ExampleRequest, + model_info: Optional[dict] = None) -> str: + """ + Format the AI's response for display. + + Override this to add custom formatting, headers, or structure. + The base class handles special status parsing automatically. + """ + # Example: Add a footer with next steps + return f"{response}\n\n---\n\n**Next Steps:** Review the analysis above and proceed with implementation." +``` + +### 4. Handle Large Prompts (Optional) + +If your tool might receive large text inputs, override the `execute` method: + +```python +async def execute(self, arguments: dict[str, Any]) -> list[TextContent]: + """Override to check prompt size before processing.""" + # Validate request first + request_model = self.get_request_model() + request = request_model(**arguments) + + # Check if prompt is too large for MCP limits + size_check = self.check_prompt_size(request.prompt) + if size_check: + return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())] + + # Continue with normal execution + return await super().execute(arguments) +``` + +### 5. Create the System Prompt + +Create a new file in `systemprompts/` (e.g., `systemprompts/example_prompt.py`): + +```python +"""System prompt for the example tool.""" + +EXAMPLE_PROMPT = """You are an AI assistant specialized in [tool purpose]. + +Your role is to [primary responsibility] by [approach/methodology]. + +Key principles: +1. [Principle 1] +2. [Principle 2] +3. 
[Principle 3] + +When analyzing content: +- [Guideline 1] +- [Guideline 2] +- [Guideline 3] + +Output format: +- Start with a brief summary +- Provide detailed analysis organized by [structure] +- Include specific examples and recommendations +- End with actionable next steps + +Remember to: +- Be specific and reference exact locations (file:line) when discussing code +- Provide practical, implementable suggestions +- Consider the broader context and implications +- Maintain a helpful, constructive tone +""" +``` + +Add the import to `systemprompts/__init__.py`: + +```python +from .example_prompt import EXAMPLE_PROMPT +``` + +### 6. Register the Tool + +#### 6.1. Import in server.py + +Add the import at the top of `server.py`: + +```python +from tools.example import ExampleTool +``` + +#### 6.2. Add to TOOLS Dictionary + +Find the `TOOLS` dictionary in `server.py` and add your tool: + +```python +TOOLS = { + "analyze": AnalyzeTool(), + "chat": ChatTool(), + "review_code": CodeReviewTool(), + "debug": DebugTool(), + "review_changes": PreCommitTool(), + "generate_tests": TestGenTool(), + "thinkdeep": ThinkDeepTool(), + "refactor": RefactorTool(), + "example": ExampleTool(), # Add your tool here +} +``` + +### 7. Write Tests + +Create unit tests in `tests/test_example.py`: + +```python +"""Tests for the example tool.""" + +import pytest +from unittest.mock import Mock, patch + +from tools.example import ExampleTool, ExampleRequest +from tools.models import ToolModelCategory + + +class TestExampleTool: + """Test suite for ExampleTool.""" + + def test_tool_metadata(self): + """Test tool metadata methods.""" + tool = ExampleTool() + + assert tool.get_name() == "example" + assert "EXAMPLE TOOL" in tool.get_description() + assert tool.get_default_temperature() == 0.5 + assert tool.get_model_category() == ToolModelCategory.BALANCED + + def test_request_validation(self): + """Test request model validation.""" + # Valid request + request = ExampleRequest(prompt="Test prompt") + assert request.prompt == "Test prompt" + assert request.output_format == "detailed" # default + + # Invalid request (missing required field) + with pytest.raises(ValueError): + ExampleRequest() + + def test_input_schema(self): + """Test input schema generation.""" + tool = ExampleTool() + schema = tool.get_input_schema() + + assert schema["type"] == "object" + assert "prompt" in schema["properties"] + assert "prompt" in schema["required"] + assert "model" in schema["properties"] + + @pytest.mark.asyncio + async def test_prepare_prompt(self): + """Test prompt preparation.""" + tool = ExampleTool() + request = ExampleRequest( + prompt="Analyze this code", + focus_area="performance", + output_format="summary" + ) + + with patch.object(tool, '_validate_token_limit'): + prompt = await tool.prepare_prompt(request) + + assert "USER REQUEST" in prompt + assert "Analyze this code" in prompt + assert "Focus Area: performance" in prompt + assert "Output Format: summary" in prompt + + @pytest.mark.asyncio + async def test_file_handling(self): + """Test file content handling.""" + tool = ExampleTool() + request = ExampleRequest( + prompt="Analyze", + files=["/path/to/file.py"] + ) + + # Mock file reading + with patch.object(tool, '_prepare_file_content_for_prompt') as mock_prep: + mock_prep.return_value = "file contents" + with patch.object(tool, '_validate_token_limit'): + prompt = await tool.prepare_prompt(request) + + assert "FILES" in prompt + assert "file contents" in prompt +``` + +### 8. 
Add Simulator Tests (Optional) + +For tools that interact with external systems, create simulator tests in `simulator_tests/test_example_basic.py`: + +```python +"""Basic simulator test for example tool.""" + +from simulator_tests.base_test import SimulatorTest + + +class TestExampleBasic(SimulatorTest): + """Test basic example tool functionality.""" + + def test_example_analysis(self): + """Test basic analysis with example tool.""" + result = self.call_tool( + "example", + { + "prompt": "Analyze the architecture of this codebase", + "model": "flash", + "output_format": "summary" + } + ) + + self.assert_tool_success(result) + self.assert_content_contains(result, ["architecture", "summary"]) +``` + +### 9. Update Documentation + +Add your tool to the README.md in the tools section: + +```markdown +### Available Tools + +- **example** - Brief description of what the tool does + - Use cases: [scenario 1], [scenario 2] + - Supports: [key features] + - Best model: `balanced` category for standard analysis +``` + +## Advanced Features + +### Understanding Conversation Memory + +The `continuation_id` feature enables multi-turn conversations using the conversation memory system (`utils/conversation_memory.py`). Here's how it works: + +1. **Thread Creation**: When a tool wants to enable follow-up conversations, it creates a thread +2. **Turn Storage**: Each exchange (user/assistant) is stored as a turn with metadata +3. **Cross-Tool Continuation**: Any tool can continue a conversation started by another tool +4. **Automatic History**: When `continuation_id` is provided, the full conversation history is reconstructed + +Key concepts: +- **ThreadContext**: Contains all conversation turns, files, and metadata +- **ConversationTurn**: Single exchange with role, content, timestamp, files, tool attribution +- **Thread Chains**: Conversations can have parent threads for extended discussions +- **Turn Limits**: Default 20 turns (configurable via MAX_CONVERSATION_TURNS) + +Example flow: +```python +# Tool A creates thread +thread_id = create_thread("analyze", request_data) + +# Tool A adds its response +add_turn(thread_id, "assistant", response, files=[...], tool_name="analyze") + +# Tool B continues the same conversation +context = get_thread(thread_id) # Gets full history +# Tool B sees all previous turns and files +``` + +### Supporting Special Response Types + +Tools can return special status responses for complex interactions. These are defined in `tools/models.py`: + +```python +# Currently supported special statuses: +SPECIAL_STATUS_MODELS = { + "need_clarification": NeedClarificationModel, + "focused_review_required": FocusedReviewRequiredModel, + "more_review_required": MoreReviewRequiredModel, + "more_testgen_required": MoreTestGenRequiredModel, + "more_refactor_required": MoreRefactorRequiredModel, + "resend_prompt": ResendPromptModel, +} +``` + +Example implementation: +```python +# In your tool's format_response or within the AI response: +if need_clarification: + return json.dumps({ + "status": "need_clarification", + "questions": ["What specific aspect should I focus on?"], + "context": "I need more information to proceed" + }) + +# For custom review status: +if more_analysis_needed: + return json.dumps({ + "status": "focused_review_required", + "files": ["/path/to/file1.py", "/path/to/file2.py"], + "focus": "security", + "reason": "Found potential SQL injection vulnerabilities" + }) +``` + +To add a new custom response type: + +1. 
Define the model in `tools/models.py`: +```python +class CustomStatusModel(BaseModel): + """Model for custom status responses""" + status: Literal["custom_status"] + custom_field: str + details: dict[str, Any] +``` + +2. Register it in `SPECIAL_STATUS_MODELS`: +```python +SPECIAL_STATUS_MODELS = { + # ... existing statuses ... + "custom_status": CustomStatusModel, +} +``` + +3. The base tool will automatically handle parsing and validation + +### Token Management + +For tools processing large amounts of data: + +```python +# Calculate available tokens dynamically +def prepare_large_content(self, files: list[str], remaining_budget: int): + # Reserve tokens for response + reserve_tokens = 5000 + + # Use model-specific limits + effective_max = remaining_budget - reserve_tokens + + # Process files with budget + content = self._prepare_file_content_for_prompt( + files, + continuation_id, + "Analysis files", + max_tokens=effective_max, + reserve_tokens=reserve_tokens + ) +``` + +### Web Search Integration + +Enable web search for tools that benefit from current information: + +```python +# In prepare_prompt: +websearch_instruction = self.get_websearch_instruction( + request.use_websearch, + """Consider searching for: + - Current best practices for [topic] + - Recent updates to [technology] + - Community solutions for [problem]""" +) + +full_prompt = f"{system_prompt}{websearch_instruction}\n\n{user_content}" +``` + +## Best Practices + +1. **Clear Tool Descriptions**: Write descriptive text that helps Claude understand when to use your tool +2. **Proper Validation**: Use Pydantic models for robust input validation +3. **Security First**: Always validate file paths are absolute +4. **Token Awareness**: Handle large inputs gracefully with prompt.txt mechanism +5. **Model Selection**: Choose appropriate model category for your tool's complexity +6. **Line Numbers**: Enable for tools needing precise code references +7. **Error Handling**: Provide helpful error messages for common issues +8. **Testing**: Write comprehensive unit tests and simulator tests +9. **Documentation**: Include examples and use cases in your description + +## Common Pitfalls to Avoid + +1. **Don't Skip Validation**: Always validate inputs, especially file paths +2. **Don't Ignore Token Limits**: Use `_validate_token_limit` and handle large prompts +3. **Don't Hardcode Models**: Use model categories for flexibility +4. **Don't Forget Tests**: Every tool needs tests for reliability +5. **Don't Break Conventions**: Follow existing patterns from other tools + +## Testing Your Tool + +### Manual Testing + +1. Start the server with your tool registered +2. Use Claude Desktop to call your tool +3. Test various parameter combinations +4. 
Verify error handling + +### Automated Testing + +```bash +# Run unit tests +pytest tests/test_example.py -xvs + +# Run all tests to ensure no regressions +pytest -xvs + +# Run simulator tests if applicable +python communication_simulator_test.py +``` + +## Checklist + +Before submitting your PR: + +- [ ] Tool class created inheriting from `BaseTool` +- [ ] All abstract methods implemented +- [ ] Request model defined with proper validation +- [ ] System prompt created in `systemprompts/` +- [ ] Tool registered in `server.py` +- [ ] Unit tests written and passing +- [ ] Simulator tests added (if applicable) +- [ ] Documentation updated +- [ ] Code follows project style (ruff, black, isort) +- [ ] Large prompt handling implemented (if needed) +- [ ] Security validation for file paths +- [ ] Appropriate model category selected +- [ ] Tool description is clear and helpful + +## Example: Complete Simple Tool + +Here's a minimal but complete example tool: + +```python +""" +Simple calculator tool for mathematical operations. +""" + +from typing import Any, Optional +from mcp.types import TextContent +from pydantic import Field + +from config import TEMPERATURE_ANALYTICAL +from .base import BaseTool, ToolRequest +from .models import ToolOutput + + +class CalculateRequest(ToolRequest): + """Request model for calculator tool.""" + + expression: str = Field( + ..., + description="Mathematical expression to evaluate" + ) + + +class CalculatorTool(BaseTool): + """Simple calculator tool.""" + + def get_name(self) -> str: + return "calculate" + + def get_description(self) -> str: + return ( + "CALCULATOR - Evaluates mathematical expressions. " + "Use this for calculations, conversions, and math problems." + ) + + def get_input_schema(self) -> dict[str, Any]: + schema = { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate", + }, + "model": self.get_model_field_schema(), + }, + "required": ["expression"] + ( + ["model"] if self.is_effective_auto_mode() else [] + ), + } + return schema + + def get_system_prompt(self) -> str: + return """You are a mathematical assistant. Evaluate the expression + and explain the calculation steps clearly.""" + + def get_default_temperature(self) -> float: + return TEMPERATURE_ANALYTICAL + + def get_request_model(self): + return CalculateRequest + + async def prepare_prompt(self, request: CalculateRequest) -> str: + return f"Calculate: {request.expression}\n\nShow your work step by step." +``` + +## Need Help? + +- Look at existing tools (`chat.py`, `refactor.py`) for examples +- Check `base.py` for available helper methods +- Review test files for testing patterns +- Ask questions in GitHub issues if stuck \ No newline at end of file diff --git a/docs/contributions.md b/docs/contributions.md new file mode 100644 index 0000000..59a402d --- /dev/null +++ b/docs/contributions.md @@ -0,0 +1,233 @@ +# Contributing to Zen MCP Server + +Thank you for your interest in contributing to Zen MCP Server! This guide will help you understand our development process, coding standards, and how to submit high-quality contributions. + +## Getting Started + +1. **Fork the repository** on GitHub +2. **Clone your fork** locally +3. **Set up the development environment**: + ```bash + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + pip install -r requirements.txt + ``` +4. 
**Create a feature branch** from `main`: + ```bash + git checkout -b feat/your-feature-name + ``` + +## Development Process + +### 1. Code Quality Standards + +We maintain high code quality standards. **All contributions must pass our automated checks**. + +#### Required Code Quality Checks + +Before submitting any PR, run these commands: + +```bash +# Activate virtual environment first +source venv/bin/activate + +# Run all linting checks (MUST pass 100%) +ruff check . +black --check . +isort --check-only . + +# Auto-fix issues if needed +ruff check . --fix +black . +isort . + +# Run complete unit test suite (MUST pass 100%) +python -m pytest -xvs + +# Run simulator tests for tool changes +python communication_simulator_test.py +``` + +**Important**: +- **Every single test must pass** - we have zero tolerance for failing tests in CI +- All linting must pass cleanly (ruff, black, isort) +- Import sorting must be correct +- Tests failing in GitHub Actions will result in PR rejection + +### 2. Testing Requirements + +#### When to Add Tests + +1. **New features MUST include tests**: + - Add unit tests in `tests/` for new functions or classes + - Test both success and error cases + +2. **Tool changes require simulator tests**: + - Add simulator tests in `simulator_tests/` for new or modified tools + - Use realistic prompts that demonstrate the feature + - Validate output through Docker logs + +3. **Bug fixes require regression tests**: + - Add a test that would have caught the bug + - Ensure the bug cannot reoccur + +#### Test Naming Conventions +- Unit tests: `test__.py` +- Simulator tests: `test__.py` + +### 3. Pull Request Process + +#### PR Title Format + +Your PR title MUST follow one of these formats: + +**Version Bumping Prefixes** (trigger Docker build + version bump): +- `feat: ` - New features (MINOR version bump) +- `fix: ` - Bug fixes (PATCH version bump) +- `breaking: ` or `BREAKING CHANGE: ` - Breaking changes (MAJOR version bump) +- `perf: ` - Performance improvements (PATCH version bump) +- `refactor: ` - Code refactoring (PATCH version bump) + +**Non-Version Prefixes** (no version bump): +- `docs: ` - Documentation only +- `chore: ` - Maintenance tasks +- `test: ` - Test additions/changes +- `ci: ` - CI/CD changes +- `style: ` - Code style changes + +**Docker Build Options**: +- `docker: ` - Force Docker build without version bump +- `docs+docker: ` - Documentation + Docker build +- `chore+docker: ` - Maintenance + Docker build + +#### PR Checklist + +Use our [PR template](../.github/pull_request_template.md) and ensure: + +- [ ] PR title follows the format guidelines above +- [ ] Activated venv and ran all linting +- [ ] Self-review completed +- [ ] Tests added for ALL changes +- [ ] Documentation updated as needed +- [ ] All unit tests passing +- [ ] Relevant simulator tests passing (if tool changes) +- [ ] Ready for review + +### 4. Code Style Guidelines + +#### Python Code Style +- Follow PEP 8 with Black formatting +- Use type hints for function parameters and returns +- Add docstrings to all public functions and classes +- Keep functions focused and under 50 lines when possible +- Use descriptive variable names + +#### Example: +```python +def process_model_response( + response: ModelResponse, + max_tokens: Optional[int] = None +) -> ProcessedResult: + """Process and validate model response. 
+ + Args: + response: Raw response from the model provider + max_tokens: Optional token limit for truncation + + Returns: + ProcessedResult with validated and formatted content + + Raises: + ValueError: If response is invalid or exceeds limits + """ + # Implementation here +``` + +#### Import Organization +Imports must be organized by isort into these groups: +1. Standard library imports +2. Third-party imports +3. Local application imports + +### 5. Specific Contribution Types + +#### Adding a New Provider +See our detailed guide: [Adding a New Provider](./adding_providers.md) + +#### Adding a New Tool +See our detailed guide: [Adding a New Tool](./adding_tools.md) + +#### Modifying Existing Tools +1. Ensure backward compatibility unless explicitly breaking +2. Update all affected tests +3. Update documentation if behavior changes +4. Add simulator tests for new functionality + +### 6. Documentation Standards + +- Update README.md for user-facing changes +- Add docstrings to all new code +- Update relevant docs/ files +- Include examples for new features +- Keep documentation concise and clear + +### 7. Commit Message Guidelines + +Write clear, descriptive commit messages: +- First line: Brief summary (50 chars or less) +- Blank line +- Detailed explanation if needed +- Reference issues: "Fixes #123" + +Example: +``` +feat: Add retry logic to Gemini provider + +Implements exponential backoff for transient errors +in Gemini API calls. Retries up to 2 times with +configurable delays. + +Fixes #45 +``` + +## Common Issues and Solutions + +### Linting Failures +```bash +# Auto-fix most issues +ruff check . --fix +black . +isort . +``` + +### Test Failures +- Check test output for specific errors +- Run individual tests for debugging: `pytest tests/test_specific.py -xvs` +- Ensure Docker is running for simulator tests + +### Import Errors +- Verify virtual environment is activated +- Check all dependencies are installed: `pip install -r requirements.txt` + +## Getting Help + +- **Questions**: Open a GitHub issue with the "question" label +- **Bug Reports**: Use the bug report template +- **Feature Requests**: Use the feature request template +- **Discussions**: Use GitHub Discussions for general topics + +## Code of Conduct + +- Be respectful and inclusive +- Welcome newcomers and help them get started +- Focus on constructive feedback +- Assume good intentions + +## Recognition + +Contributors are recognized in: +- GitHub contributors page +- Release notes for significant contributions +- Special mentions for exceptional work + +Thank you for contributing to Zen MCP Server! Your efforts help make this tool better for everyone. \ No newline at end of file diff --git a/docs/testing.md b/docs/testing.md index 76104e1..e198b08 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -124,56 +124,24 @@ Validate real-world usage scenarios by simulating actual Claude prompts: - **Token allocation**: Context window management in practice - **Redis validation**: Conversation persistence and retrieval -## Contributing: Test Requirements +## Contributing -When contributing to this project: +For detailed contribution guidelines, testing requirements, and code quality standards, please see our [Contributing Guide](./contributions.md). -1. **New features MUST include tests**: - - Add unit tests in `tests/` for new functions or classes - - Test both success and error cases - -2. 
**Tool changes require simulator tests**: - - Add simulator tests in `simulator_tests/` for new or modified tools - - Use realistic prompts that demonstrate the feature - - Validate output through Docker logs - -3. **Test naming conventions**: - - Unit tests: `test__.py` - - Simulator tests: `test__.py` +### Quick Testing Reference -4. **Before submitting PR - Complete Validation Checklist**: - ```bash - # Activate virtual environment first as needed - source venv/bin/activate - - # Run all linting tools (must pass 100%) - ruff check . - black --check . - isort --check-only . - - # Auto-fix issues if needed - ruff check . --fix - black . - isort . - - # Run complete unit test suite (must pass 100%) - python -m pytest -xvs - - # Run simulator tests for tool changes - python communication_simulator_test.py - ``` +```bash +# Activate virtual environment +source venv/bin/activate -5. **GitHub Actions Compliance**: - - **Every single test must pass** - we have zero tolerance for failing tests in CI - - All linting must pass cleanly (ruff, black, isort) - - Import sorting must be correct - - Virtual environment activation is required for consistent results - - Tests failing in GitHub Actions will result in PR rejection +# Run linting checks +ruff check . && black --check . && isort --check-only . -6. **Contribution Standards**: - - Follow the [PR template](../.github/pull_request_template.md) requirements exactly - - Check every box in the template checklist before submitting - - Include comprehensive tests for all new functionality - - Ensure backward compatibility unless explicitly breaking +# Run unit tests +python -m pytest -xvs -Remember: Tests are documentation. They show how features are intended to be used and help prevent regressions. **Quality over speed** - take the time to ensure everything passes locally before pushing. \ No newline at end of file +# Run simulator tests (for tool changes) +python communication_simulator_test.py +``` + +Remember: All tests must pass before submitting a PR. See the [Contributing Guide](./contributions.md) for complete requirements. 
\ No newline at end of file diff --git a/server.py b/server.py index ae8d038..ecae98b 100644 --- a/server.py +++ b/server.py @@ -24,7 +24,7 @@ import os import sys import time from datetime import datetime -from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler +from logging.handlers import RotatingFileHandler from typing import Any from mcp.server import Server diff --git a/simulator_tests/__init__.py b/simulator_tests/__init__.py index 2ec4f74..7e51b47 100644 --- a/simulator_tests/__init__.py +++ b/simulator_tests/__init__.py @@ -11,9 +11,11 @@ from .test_content_validation import ContentValidationTest from .test_conversation_chain_validation import ConversationChainValidationTest from .test_cross_tool_comprehensive import CrossToolComprehensiveTest from .test_cross_tool_continuation import CrossToolContinuationTest +from .test_line_number_validation import LineNumberValidationTest from .test_logs_validation import LogsValidationTest from .test_model_thinking_config import TestModelThinkingConfig from .test_o3_model_selection import O3ModelSelectionTest +from .test_o3_pro_expensive import O3ProExpensiveTest from .test_ollama_custom_url import OllamaCustomUrlTest from .test_openrouter_fallback import OpenRouterFallbackTest from .test_openrouter_models import OpenRouterModelsTest @@ -30,6 +32,7 @@ TEST_REGISTRY = { "per_tool_deduplication": PerToolDeduplicationTest, "cross_tool_continuation": CrossToolContinuationTest, "cross_tool_comprehensive": CrossToolComprehensiveTest, + "line_number_validation": LineNumberValidationTest, "logs_validation": LogsValidationTest, "redis_validation": RedisValidationTest, "model_thinking_config": TestModelThinkingConfig, @@ -41,6 +44,7 @@ TEST_REGISTRY = { "testgen_validation": TestGenValidationTest, "refactor_validation": RefactorValidationTest, "conversation_chain_validation": ConversationChainValidationTest, + # "o3_pro_expensive": O3ProExpensiveTest, # COMMENTED OUT - too expensive to run by default } __all__ = [ @@ -50,10 +54,12 @@ __all__ = [ "PerToolDeduplicationTest", "CrossToolContinuationTest", "CrossToolComprehensiveTest", + "LineNumberValidationTest", "LogsValidationTest", "RedisValidationTest", "TestModelThinkingConfig", "O3ModelSelectionTest", + "O3ProExpensiveTest", "OllamaCustomUrlTest", "OpenRouterFallbackTest", "OpenRouterModelsTest", diff --git a/simulator_tests/test_line_number_validation.py b/simulator_tests/test_line_number_validation.py new file mode 100644 index 0000000..714bb8d --- /dev/null +++ b/simulator_tests/test_line_number_validation.py @@ -0,0 +1,177 @@ +""" +Test to validate line number handling across different tools +""" + +import json +import os + +from .base_test import BaseSimulatorTest + + +class LineNumberValidationTest(BaseSimulatorTest): + """Test that validates correct line number handling in chat, analyze, and refactor tools""" + + @property + def test_name(self) -> str: + return "line_number_validation" + + @property + def test_description(self) -> str: + return "Line number handling validation across tools" + + def run_test(self) -> bool: + """Test line number handling in different tools""" + try: + self.logger.info("Test: Line number handling validation") + + # Setup test files + self.setup_test_files() + + # Create a test file with known content + test_file_content = '''# Example code with specific elements +def calculate_total(items): + """Calculate total with tax""" + subtotal = 0 + tax_rate = 0.08 # Line 5 - tax_rate defined + + for item in items: # Line 7 - loop starts + if item.price 
> 0: + subtotal += item.price + + tax_amount = subtotal * tax_rate # Line 11 + return subtotal + tax_amount + +def validate_data(data): + """Validate input data""" # Line 15 + required_fields = ["name", "email", "age"] # Line 16 + + for field in required_fields: + if field not in data: + raise ValueError(f"Missing field: {field}") + + return True # Line 22 +''' + + test_file_path = os.path.join(self.test_dir, "line_test.py") + with open(test_file_path, "w") as f: + f.write(test_file_content) + + self.logger.info(f"Created test file: {test_file_path}") + + # Test 1: Chat tool asking about specific line + self.logger.info(" 1.1: Testing chat tool with line number question") + content, continuation_id = self.call_mcp_tool( + "chat", + { + "prompt": "Where is tax_rate defined in this file? Please tell me the exact line number.", + "files": [test_file_path], + "model": "flash", + }, + ) + + if content: + # Check if the response mentions line 5 + if "line 5" in content.lower() or "line 5" in content: + self.logger.info(" ✅ Chat tool correctly identified tax_rate at line 5") + else: + self.logger.warning(f" ⚠️ Chat tool response didn't mention line 5: {content[:200]}...") + else: + self.logger.error(" ❌ Chat tool request failed") + return False + + # Test 2: Analyze tool with line number reference + self.logger.info(" 1.2: Testing analyze tool with line number analysis") + content, continuation_id = self.call_mcp_tool( + "analyze", + { + "prompt": "What happens between lines 7-11 in this code? Focus on the loop logic.", + "files": [test_file_path], + "model": "flash", + }, + ) + + if content: + # Check if the response references the loop + if any(term in content.lower() for term in ["loop", "iterate", "line 7", "lines 7"]): + self.logger.info(" ✅ Analyze tool correctly analyzed the specified line range") + else: + self.logger.warning(" ⚠️ Analyze tool response unclear about line range") + else: + self.logger.error(" ❌ Analyze tool request failed") + return False + + # Test 3: Refactor tool with line number precision + self.logger.info(" 1.3: Testing refactor tool line number precision") + content, continuation_id = self.call_mcp_tool( + "refactor", + { + "prompt": "Analyze this code for refactoring opportunities", + "files": [test_file_path], + "refactor_type": "codesmells", + "model": "flash", + }, + ) + + if content: + try: + # Parse the JSON response + result = json.loads(content) + if result.get("status") == "refactor_analysis_complete": + opportunities = result.get("refactor_opportunities", []) + if opportunities: + # Check if line numbers are precise + has_line_refs = any( + opp.get("start_line") is not None and opp.get("end_line") is not None + for opp in opportunities + ) + if has_line_refs: + self.logger.info(" ✅ Refactor tool provided precise line number references") + # Log some examples + for opp in opportunities[:2]: + if opp.get("start_line"): + self.logger.info( + f" - Issue at lines {opp['start_line']}-{opp['end_line']}: {opp.get('issue', '')[:50]}..." 
+ ) + else: + self.logger.warning(" ⚠️ Refactor tool response missing line numbers") + else: + self.logger.info(" ℹ️ No refactoring opportunities found (code might be too clean)") + except json.JSONDecodeError: + self.logger.warning(" ⚠️ Refactor tool response not valid JSON") + else: + self.logger.error(" ❌ Refactor tool request failed") + return False + + # Test 4: Validate log patterns + self.logger.info(" 1.4: Validating line number processing in logs") + + # Get logs from container + result = self.run_command( + ["docker", "exec", self.container_name, "tail", "-500", "/tmp/mcp_server.log"], capture_output=True + ) + + logs = "" + if result.returncode == 0: + logs = result.stdout.decode() + + # Check for line number formatting patterns + line_number_patterns = ["Line numbers for", "enabled", "│", "line number"] # The line number separator + + found_patterns = 0 + for pattern in line_number_patterns: + if pattern in logs: + found_patterns += 1 + + self.logger.info(f" Found {found_patterns}/{len(line_number_patterns)} line number patterns in logs") + + if found_patterns >= 2: + self.logger.info(" ✅ Line number processing confirmed in logs") + else: + self.logger.warning(" ⚠️ Limited line number processing evidence in logs") + + self.logger.info(" ✅ Line number validation test completed successfully") + return True + + except Exception as e: + self.logger.error(f"Line number validation test failed: {type(e).__name__}: {e}") + return False diff --git a/simulator_tests/test_refactor_validation.py b/simulator_tests/test_refactor_validation.py index 8ee17e5..579a39f 100644 --- a/simulator_tests/test_refactor_validation.py +++ b/simulator_tests/test_refactor_validation.py @@ -9,6 +9,7 @@ Tests the refactor tool with a simple code smell example to validate: """ import json + from .base_test import BaseSimulatorTest @@ -32,7 +33,7 @@ class RefactorValidationTest(BaseSimulatorTest): self.setup_test_files() # Create a simple Python file with obvious code smells - code_with_smells = '''# Code with obvious smells for testing + code_with_smells = """# Code with obvious smells for testing def process_data(data): # Code smell: Magic number if len(data) > 42: @@ -57,22 +58,22 @@ def handle_everything(user_input, config, database): if not user_input: print("Error: No input") # Code smell: print instead of logging return - + # Processing processed = user_input.strip().lower() - + # Database operation connection = database.connect() data = connection.query("SELECT * FROM users") # Code smell: SQL in code - + # Business logic mixed with data access valid_users = [] for row in data: if row[2] == processed: # Code smell: Magic index valid_users.append(row) - + return valid_users -''' +""" # Create test file test_file = self.create_additional_test_file("smelly_code.py", code_with_smells) @@ -88,7 +89,7 @@ def handle_everything(user_input, config, database): "refactor_type": "codesmells", "model": "flash", "thinking_mode": "low", # Keep it fast for testing - } + }, ) if not response: @@ -96,14 +97,14 @@ def handle_everything(user_input, config, database): return False self.logger.info(" ✅ Got refactor response") - + # Parse response to check for line references try: response_data = json.loads(response) - + # Debug: log the response structure self.logger.debug(f"Response keys: {list(response_data.keys())}") - + # Extract the actual content if it's wrapped if "content" in response_data: # The actual refactoring data is in the content field @@ -114,93 +115,91 @@ def handle_everything(user_input, config, database): 
if content.endswith("```"): content = content[:-3] # Remove ``` content = content.strip() - + # Find the end of the JSON object - handle truncated responses # Count braces to find where the JSON ends brace_count = 0 json_end = -1 in_string = False escape_next = False - + for i, char in enumerate(content): if escape_next: escape_next = False continue - if char == '\\': + if char == "\\": escape_next = True continue if char == '"' and not escape_next: in_string = not in_string if not in_string: - if char == '{': + if char == "{": brace_count += 1 - elif char == '}': + elif char == "}": brace_count -= 1 if brace_count == 0: json_end = i + 1 break - + if json_end > 0: content = content[:json_end] - + # Parse the inner JSON inner_data = json.loads(content) self.logger.debug(f"Inner data keys: {list(inner_data.keys())}") else: inner_data = response_data - + # Check that we got refactoring suggestions (might be called refactor_opportunities) refactorings_key = None for key in ["refactorings", "refactor_opportunities"]: if key in inner_data: refactorings_key = key break - + if not refactorings_key: self.logger.error("No refactorings found in response") self.logger.error(f"Response structure: {json.dumps(inner_data, indent=2)[:500]}...") return False - + refactorings = inner_data[refactorings_key] if not isinstance(refactorings, list) or len(refactorings) == 0: self.logger.error("Empty refactorings list") return False - + # Validate that we have line references for code smells # Flash model typically detects these issues: # - Lines 4-18: process_data function (magic number, nested loops, duplicate code) # - Lines 11-14: duplicate code blocks # - Lines 21-40: handle_everything god function - expected_line_ranges = [ - (4, 18), # process_data function issues - (11, 14), # duplicate code - (21, 40), # god function - ] - + self.logger.debug(f"Refactorings found: {len(refactorings)}") for i, ref in enumerate(refactorings[:3]): # Log first 3 - self.logger.debug(f"Refactoring {i}: start_line={ref.get('start_line')}, end_line={ref.get('end_line')}, type={ref.get('type')}") - + self.logger.debug( + f"Refactoring {i}: start_line={ref.get('start_line')}, end_line={ref.get('end_line')}, type={ref.get('type')}" + ) + found_references = [] for refactoring in refactorings: # Check for line numbers in various fields start_line = refactoring.get("start_line") end_line = refactoring.get("end_line") location = refactoring.get("location", "") - + # Add found line numbers if start_line: found_references.append(f"line {start_line}") if end_line and end_line != start_line: found_references.append(f"line {end_line}") - + # Also extract from location string import re - line_matches = re.findall(r'line[s]?\s+(\d+)', location.lower()) + + line_matches = re.findall(r"line[s]?\s+(\d+)", location.lower()) found_references.extend([f"line {num}" for num in line_matches]) - + self.logger.info(f" 📍 Found line references: {found_references}") - + # Check that flash found the expected refactoring areas found_ranges = [] for refactoring in refactorings: @@ -208,71 +207,70 @@ def handle_everything(user_input, config, database): end = refactoring.get("end_line") if start and end: found_ranges.append((start, end)) - + self.logger.info(f" 📍 Found refactoring ranges: {found_ranges}") - + # Verify we found issues in the main problem areas # Check if we have issues detected in process_data function area (lines 2-18) process_data_issues = [r for r in found_ranges if r[0] >= 2 and r[1] <= 18] # Check if we have issues detected in 
handle_everything function area (lines 21-40) god_function_issues = [r for r in found_ranges if r[0] >= 21 and r[1] <= 40] - + self.logger.info(f" 📍 Issues in process_data area (lines 2-18): {len(process_data_issues)}") self.logger.info(f" 📍 Issues in handle_everything area (lines 21-40): {len(god_function_issues)}") - + if len(process_data_issues) >= 1 and len(god_function_issues) >= 1: - self.logger.info(f" ✅ Flash correctly identified code smells in both major areas") + self.logger.info(" ✅ Flash correctly identified code smells in both major areas") self.logger.info(f" ✅ Found {len(refactorings)} total refactoring opportunities") - + # Verify we have reasonable number of total issues if len(refactorings) >= 3: - self.logger.info(f" ✅ Refactoring analysis validation passed") + self.logger.info(" ✅ Refactoring analysis validation passed") else: self.logger.warning(f" ⚠️ Only {len(refactorings)} refactorings found (expected >= 3)") else: - self.logger.error(f" ❌ Flash didn't find enough issues in expected areas") + self.logger.error(" ❌ Flash didn't find enough issues in expected areas") self.logger.error(f" - process_data area: found {len(process_data_issues)}, expected >= 1") self.logger.error(f" - handle_everything area: found {len(god_function_issues)}, expected >= 1") return False - + except json.JSONDecodeError as e: self.logger.error(f"Failed to parse refactor response as JSON: {e}") return False - + # Validate logs self.logger.info(" 📋 Validating execution logs...") - + # Get server logs from the actual log file inside the container result = self.run_command( - ["docker", "exec", self.container_name, "tail", "-500", "/tmp/mcp_server.log"], - capture_output=True + ["docker", "exec", self.container_name, "tail", "-500", "/tmp/mcp_server.log"], capture_output=True ) - + if result.returncode == 0: logs = result.stdout.decode() + result.stderr.decode() - + # Look for refactor tool execution patterns refactor_patterns = [ "[REFACTOR]", "refactor tool", "codesmells", "Token budget", - "Code files embedded successfully" + "Code files embedded successfully", ] - + patterns_found = 0 for pattern in refactor_patterns: if pattern in logs: patterns_found += 1 self.logger.debug(f" ✅ Found log pattern: {pattern}") - + if patterns_found >= 3: self.logger.info(f" ✅ Log validation passed ({patterns_found}/{len(refactor_patterns)} patterns)") else: self.logger.warning(f" ⚠️ Only found {patterns_found}/{len(refactor_patterns)} log patterns") else: self.logger.warning(" ⚠️ Could not retrieve Docker logs") - + self.logger.info(" ✅ Refactor tool validation completed successfully") return True @@ -280,4 +278,4 @@ def handle_everything(user_input, config, database): self.logger.error(f"Refactor validation test failed: {e}") return False finally: - self.cleanup_test_files() \ No newline at end of file + self.cleanup_test_files() diff --git a/systemprompts/analyze_prompt.py b/systemprompts/analyze_prompt.py index 0026ffc..7460042 100644 --- a/systemprompts/analyze_prompt.py +++ b/systemprompts/analyze_prompt.py @@ -8,6 +8,13 @@ You are a senior software analyst performing a holistic technical audit of the g to help engineers understand how a codebase aligns with long-term goals, architectural soundness, scalability, and maintainability—not just spot routine code-review issues. +CRITICAL LINE NUMBER INSTRUCTIONS +Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be +included in any code you generate. 
Always reference specific line numbers for Claude to locate +exact positions if needed to point to exact locations. Include a very short code excerpt alongside for clarity. +Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code +snippets. + IF MORE INFORMATION IS NEEDED If you need additional context (e.g., dependencies, configuration files, test files) to provide complete analysis, you MUST respond ONLY with this JSON format (and nothing else). Do NOT ask for the same file you've been provided unless diff --git a/systemprompts/chat_prompt.py b/systemprompts/chat_prompt.py index 61445cb..0a0b6a8 100644 --- a/systemprompts/chat_prompt.py +++ b/systemprompts/chat_prompt.py @@ -6,6 +6,13 @@ CHAT_PROMPT = """ You are a senior engineering thought-partner collaborating with Claude. Your mission is to brainstorm, validate ideas, and offer well-reasoned second opinions on technical decisions. +CRITICAL LINE NUMBER INSTRUCTIONS +Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be +included in any code you generate. Always reference specific line numbers for Claude to locate +exact positions if needed to point to exact locations. Include a very short code excerpt alongside for clarity. +Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code +snippets. + IF MORE INFORMATION IS NEEDED If Claude is discussing specific code, functions, or project components that was not given as part of the context, and you need additional context (e.g., related files, configuration, dependencies, test files) to provide meaningful diff --git a/systemprompts/codereview_prompt.py b/systemprompts/codereview_prompt.py index 5a5cd22..3665845 100644 --- a/systemprompts/codereview_prompt.py +++ b/systemprompts/codereview_prompt.py @@ -8,6 +8,13 @@ You are an expert code reviewer with deep knowledge of software-engineering best performance, maintainability, and architecture. Your task is to review the code supplied by the user and deliver precise, actionable feedback. +CRITICAL LINE NUMBER INSTRUCTIONS +Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be +included in any code you generate. Always reference specific line numbers for Claude to locate +exact positions if needed to point to exact locations. Include a very short code excerpt alongside for clarity. +Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code +snippets. + IF MORE INFORMATION IS NEEDED If you need additional context (e.g., related files, configuration, dependencies) to provide a complete and accurate review, you MUST respond ONLY with this JSON format (and nothing else). Do NOT ask for the @@ -15,11 +22,6 @@ same file you've been provided unless for some reason its content is missing or {"status": "clarification_required", "question": "", "files_needed": ["[file name here]", "[or some folder/]"]} -CRITICAL LINE NUMBER INSTRUCTIONS -Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be included -in any code you generate. Always reference specific line numbers for precise feedback. Include exact line numbers in -your issue descriptions. - CRITICAL: Align your review with the user's context and expectations. Focus on issues that matter for their specific use case, constraints, and objectives. 
Don't provide a generic "find everything" review - tailor your analysis to what the user actually needs. diff --git a/systemprompts/debug_prompt.py b/systemprompts/debug_prompt.py index a92e6b2..738accd 100644 --- a/systemprompts/debug_prompt.py +++ b/systemprompts/debug_prompt.py @@ -7,6 +7,13 @@ ROLE You are an expert debugger and problem-solver. Analyze errors, trace root causes, and propose the minimal fix required. Bugs can ONLY be found and fixed from given code. These cannot be made up or imagined. +CRITICAL LINE NUMBER INSTRUCTIONS +Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be +included in any code you generate. Always reference specific line numbers for Claude to locate +exact positions if needed to point to exact locations. Include a very short code excerpt alongside for clarity. +Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code +snippets. + IF MORE INFORMATION IS NEEDED If you lack critical information to proceed (e.g., missing files, ambiguous error details, insufficient context), OR if the provided diagnostics (log files, crash reports, stack traces) appear irrelevant, @@ -15,11 +22,6 @@ Do NOT ask for the same file you've been provided unless for some reason its con {"status": "clarification_required", "question": "", "files_needed": ["[file name here]", "[or some folder/]"]} -CRITICAL LINE NUMBER INSTRUCTIONS -Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be included -in any code you generate. Always reference specific line numbers for precise feedback. Include exact line numbers in -your issue descriptions. - CRITICAL: Your primary objective is to identify the root cause of the specific issue at hand and suggest the minimal fix required to resolve it. Stay focused on the main problem - avoid suggesting extensive refactoring, architectural changes, or unrelated improvements. diff --git a/systemprompts/precommit_prompt.py b/systemprompts/precommit_prompt.py index c6888a0..0c9a60b 100644 --- a/systemprompts/precommit_prompt.py +++ b/systemprompts/precommit_prompt.py @@ -6,6 +6,13 @@ PRECOMMIT_PROMPT = """ ROLE You are an expert pre-commit reviewer. Analyse git diffs as a senior developer giving a final sign-off to production. +CRITICAL LINE NUMBER INSTRUCTIONS +Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be +included in any code you generate. Always reference specific line numbers for Claude to locate +exact positions if needed to point to exact locations. Include a very short code excerpt alongside for clarity. +Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code +snippets. + IF MORE INFORMATION IS NEEDED If you need additional context (e.g., related files not in the diff, test files, configuration) to provide thorough analysis and without this context your review would be ineffective or biased, you MUST respond ONLY with this JSON diff --git a/systemprompts/refactor_prompt.py b/systemprompts/refactor_prompt.py index d560c2b..9232cd7 100644 --- a/systemprompts/refactor_prompt.py +++ b/systemprompts/refactor_prompt.py @@ -8,6 +8,12 @@ You are a principal software engineer specializing in intelligent code refactori opportunities and provide precise, actionable suggestions with exact line-number references that Claude can implement directly. 
+CRITICAL LINE NUMBER INSTRUCTIONS +Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be +included in any code you generate. Always reference specific line numbers for Claude to locate exact positions. +Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code +snippets. + IF MORE INFORMATION IS NEEDED If you need additional context (e.g., related files, configuration, dependencies) to provide accurate refactoring recommendations, you MUST respond ONLY with this JSON format (and nothing else). Do NOT ask for the same file you've @@ -92,12 +98,6 @@ handling and type safety. NOTE: Can only be applied AFTER decomposition if large **organization**: Improve organization and structure - group related functionality, improve file structure, standardize naming, clarify module boundaries. NOTE: Can only be applied AFTER decomposition if large files exist. -CRITICAL LINE NUMBER INSTRUCTIONS -Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be -included in any code you generate. Always reference specific line numbers for Claude to locate exact positions. -Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code -snippets. - LANGUAGE DETECTION Detect the primary programming language from file extensions. Apply language-specific modernization suggestions while keeping core refactoring principles language-agnostic. diff --git a/systemprompts/testgen_prompt.py b/systemprompts/testgen_prompt.py index 3166ddd..0d8e2de 100644 --- a/systemprompts/testgen_prompt.py +++ b/systemprompts/testgen_prompt.py @@ -8,6 +8,13 @@ You are a principal software engineer who specialises in writing bullet-proof pr high-signal test suites. You reason about control flow, data flow, mutation, concurrency, failure modes, and security in equal measure. Your mission: design and write tests that surface real-world defects before code ever leaves CI. +CRITICAL LINE NUMBER INSTRUCTIONS +Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be +included in any code you generate. Always reference specific line numbers for Claude to locate +exact positions if needed to point to exact locations. Include a very short code excerpt alongside for clarity. +Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code +snippets. + IF MORE INFORMATION IS NEEDED If you need additional context (e.g., test framework details, dependencies, existing test patterns) to provide accurate test generation, you MUST respond ONLY with this JSON format (and nothing else). Do NOT ask for the @@ -15,11 +22,6 @@ same file you've been provided unless for some reason its content is missing or {"status": "clarification_required", "question": "", "files_needed": ["[file name here]", "[or some folder/]"]} -CRITICAL LINE NUMBER INSTRUCTIONS -Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be included -in any code you generate. Always reference specific line numbers for precise feedback. Include exact line numbers in -your issue descriptions. 
- MULTI-AGENT WORKFLOW You sequentially inhabit five expert personas—each passes a concise artefact to the next: diff --git a/systemprompts/thinkdeep_prompt.py b/systemprompts/thinkdeep_prompt.py index f4bdc68..2e48397 100644 --- a/systemprompts/thinkdeep_prompt.py +++ b/systemprompts/thinkdeep_prompt.py @@ -7,6 +7,13 @@ ROLE You are a senior engineering collaborator working with Claude on complex software problems. Claude will send you content—analysis, prompts, questions, ideas, or theories—to deepen, validate, and extend. +CRITICAL LINE NUMBER INSTRUCTIONS +Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be +included in any code you generate. Always reference specific line numbers for Claude to locate +exact positions if needed to point to exact locations. Include a very short code excerpt alongside for clarity. +Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code +snippets. + IF MORE INFORMATION IS NEEDED If you need additional context (e.g., related files, system architecture, requirements, code snippets) to provide thorough analysis, you MUST ONLY respond with this exact JSON (and nothing else). Do NOT ask for the same file you've diff --git a/test_line_number_accuracy.py b/test_line_number_accuracy.py new file mode 100644 index 0000000..7c11d3f --- /dev/null +++ b/test_line_number_accuracy.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +""" +Test script to verify line number accuracy in the MCP server +""" + +import asyncio +import json + +from tools.analyze import AnalyzeTool +from tools.chat import ChatTool + + +async def test_line_number_reporting(): + """Test if tools report accurate line numbers when analyzing code""" + + print("=== Testing Line Number Accuracy ===\n") + + # Test 1: Analyze tool with line numbers + analyze_tool = AnalyzeTool() + + # Create a request that asks about specific line numbers + analyze_request = { + "files": ["/Users/fahad/Developer/gemini-mcp-server/test_line_numbers.py"], + "prompt": "Find all the lines where 'ignore_patterns' is assigned a list value. Report the exact line numbers.", + "model": "flash", # Use a real model + } + + print("1. Testing Analyze tool:") + print(f" Prompt: {analyze_request['prompt']}") + + try: + response = await analyze_tool.execute(analyze_request) + result = json.loads(response[0].text) + + if result["status"] == "success": + print(f" Response excerpt: {result['content'][:200]}...") + else: + print(f" Error: {result}") + except Exception as e: + print(f" Exception: {e}") + + print("\n" + "=" * 50 + "\n") + + # Test 2: Chat tool to simulate the user's scenario + chat_tool = ChatTool() + + chat_request = { + "files": ["/Users/fahad/Developer/loganalyzer/main.py"], + "prompt": "Tell me the exact line number where 'ignore_patterns' is assigned a list in the file. Be precise about the line number.", + "model": "flash", + } + + print("2. 
Testing Chat tool with user's actual file:") + print(f" File: {chat_request['files'][0]}") + print(f" Prompt: {chat_request['prompt']}") + + try: + response = await chat_tool.execute(chat_request) + result = json.loads(response[0].text) + + if result["status"] == "success": + print(f" Response excerpt: {result['content'][:300]}...") + else: + print(f" Error: {result}") + except Exception as e: + print(f" Exception: {e}") + + print("\n=== Test Complete ===") + + +if __name__ == "__main__": + asyncio.run(test_line_number_reporting()) diff --git a/test_line_numbers.py b/test_line_numbers.py new file mode 100644 index 0000000..8dcd963 --- /dev/null +++ b/test_line_numbers.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +"""Test file to verify line number accuracy""" + + +# Line 4: Empty line above +def example_function(): + """Line 6: Docstring""" + # Line 7: Comment + pass # Line 8 + + +# Line 10: Another comment +class TestClass: + """Line 12: Class docstring""" + + def __init__(self): + """Line 15: Init docstring""" + # Line 16: This is where we'll test + self.test_variable = "Line 17" + + def method_one(self): + """Line 20: Method docstring""" + # Line 21: Important assignment below + ignore_patterns = ["pattern1", "pattern2", "pattern3"] # Line 22: This is our test line + return ignore_patterns + + +# Line 25: More code below +def another_function(): + """Line 27: Another docstring""" + # Line 28: Another assignment + ignore_patterns = ["different", "patterns"] # Line 29: Second occurrence + return ignore_patterns + + +# Line 32: End of file marker diff --git a/tests/test_line_numbers_integration.py b/tests/test_line_numbers_integration.py new file mode 100644 index 0000000..2c43f77 --- /dev/null +++ b/tests/test_line_numbers_integration.py @@ -0,0 +1,49 @@ +""" +Integration test demonstrating that all tools get line numbers by default. 
+""" + +from tools.analyze import AnalyzeTool +from tools.chat import ChatTool +from tools.codereview import CodeReviewTool +from tools.debug import DebugIssueTool +from tools.precommit import Precommit +from tools.refactor import RefactorTool +from tools.testgen import TestGenTool + + +class TestLineNumbersIntegration: + """Test that all tools inherit line number behavior correctly.""" + + def test_all_tools_want_line_numbers(self): + """Verify that all tools want line numbers by default.""" + tools = [ + ChatTool(), + AnalyzeTool(), + CodeReviewTool(), + DebugIssueTool(), + RefactorTool(), + TestGenTool(), + Precommit(), + ] + + for tool in tools: + assert tool.wants_line_numbers_by_default(), f"{tool.get_name()} should want line numbers by default" + + def test_no_tools_override_line_numbers(self): + """Verify that no tools override the base class line number behavior.""" + # Check that tools don't have their own wants_line_numbers_by_default method + tools_classes = [ + ChatTool, + AnalyzeTool, + CodeReviewTool, + DebugIssueTool, + RefactorTool, + TestGenTool, + Precommit, + ] + + for tool_class in tools_classes: + # Check if the method is defined in the tool class itself + # (not inherited from base) + has_override = "wants_line_numbers_by_default" in tool_class.__dict__ + assert not has_override, f"{tool_class.__name__} should not override wants_line_numbers_by_default" diff --git a/tests/test_precommit_diff_formatting.py b/tests/test_precommit_diff_formatting.py new file mode 100644 index 0000000..4ee42cb --- /dev/null +++ b/tests/test_precommit_diff_formatting.py @@ -0,0 +1,163 @@ +""" +Test to verify that precommit tool formats diffs correctly without line numbers. +This test focuses on the diff formatting logic rather than full integration. 
+""" + +from tools.precommit import Precommit + + +class TestPrecommitDiffFormatting: + """Test that precommit correctly formats diffs without line numbers.""" + + def test_git_diff_formatting_has_no_line_numbers(self): + """Test that git diff output is preserved without line number additions.""" + # Sample git diff output + git_diff = """diff --git a/example.py b/example.py +index 1234567..abcdefg 100644 +--- a/example.py ++++ b/example.py +@@ -1,5 +1,8 @@ + def hello(): +- print("Hello, World!") ++ print("Hello, Universe!") # Changed this line + + def goodbye(): + print("Goodbye!") ++ ++def new_function(): ++ print("This is new") +""" + + # Simulate how precommit formats a diff + repo_name = "test_repo" + file_path = "example.py" + diff_header = f"\n--- BEGIN DIFF: {repo_name} / {file_path} (unstaged) ---\n" + diff_footer = f"\n--- END DIFF: {repo_name} / {file_path} ---\n" + formatted_diff = diff_header + git_diff + diff_footer + + # Verify the diff doesn't contain line number markers (│) + assert "│" not in formatted_diff, "Git diffs should NOT have line number markers" + + # Verify the diff preserves git's own line markers + assert "@@ -1,5 +1,8 @@" in formatted_diff + assert '- print("Hello, World!")' in formatted_diff + assert '+ print("Hello, Universe!")' in formatted_diff + + def test_untracked_file_diff_formatting(self): + """Test that untracked files formatted as diffs don't have line numbers.""" + # Simulate untracked file content + file_content = """def new_function(): + return "I am new" + +class NewClass: + pass +""" + + # Simulate how precommit formats untracked files as diffs + repo_name = "test_repo" + file_path = "new_file.py" + + diff_header = f"\n--- BEGIN DIFF: {repo_name} / {file_path} (untracked - new file) ---\n" + diff_content = f"+++ b/{file_path}\n" + + # Add each line with + prefix (simulating new file diff) + for _line_num, line in enumerate(file_content.splitlines(), 1): + diff_content += f"+{line}\n" + + diff_footer = f"\n--- END DIFF: {repo_name} / {file_path} ---\n" + formatted_diff = diff_header + diff_content + diff_footer + + # Verify no line number markers + assert "│" not in formatted_diff, "Untracked file diffs should NOT have line number markers" + + # Verify diff format + assert "+++ b/new_file.py" in formatted_diff + assert "+def new_function():" in formatted_diff + assert '+ return "I am new"' in formatted_diff + + def test_compare_to_diff_formatting(self): + """Test that compare_to mode diffs don't have line numbers.""" + # Sample git diff for compare_to mode + git_diff = """diff --git a/config.py b/config.py +index abc123..def456 100644 +--- a/config.py ++++ b/config.py +@@ -10,7 +10,7 @@ class Config: + def __init__(self): + self.debug = False +- self.timeout = 30 ++ self.timeout = 60 # Increased timeout + self.retries = 3 +""" + + # Format as compare_to diff + repo_name = "test_repo" + file_path = "config.py" + compare_ref = "v1.0" + + diff_header = f"\n--- BEGIN DIFF: {repo_name} / {file_path} (compare to {compare_ref}) ---\n" + diff_footer = f"\n--- END DIFF: {repo_name} / {file_path} ---\n" + formatted_diff = diff_header + git_diff + diff_footer + + # Verify no line number markers + assert "│" not in formatted_diff, "Compare-to diffs should NOT have line number markers" + + # Verify diff markers + assert "@@ -10,7 +10,7 @@ class Config:" in formatted_diff + assert "- self.timeout = 30" in formatted_diff + assert "+ self.timeout = 60 # Increased timeout" in formatted_diff + + def test_base_tool_default_line_numbers(self): + """Test that the 
base tool wants line numbers by default.""" + tool = Precommit() + assert tool.wants_line_numbers_by_default(), "Base tool should want line numbers by default" + + def test_context_files_want_line_numbers(self): + """Test that precommit tool inherits base class behavior for line numbers.""" + tool = Precommit() + + # The precommit tool should want line numbers by default (inherited from base) + assert tool.wants_line_numbers_by_default() + + # This means when it calls read_files for context files, + # it will pass include_line_numbers=True + + def test_diff_sections_in_prompt(self): + """Test the structure of diff sections in the final prompt.""" + # Create sample prompt sections + diff_section = """ +## Git Diffs + +--- BEGIN DIFF: repo / file.py (staged) --- +diff --git a/file.py b/file.py +index 123..456 100644 +--- a/file.py ++++ b/file.py +@@ -1,3 +1,4 @@ + def main(): + print("Hello") ++ print("World") +--- END DIFF: repo / file.py --- +""" + + context_section = """ +## Additional Context Files +The following files are provided for additional context. They have NOT been modified. + +--- BEGIN FILE: /path/to/context.py --- + 1│ # Context file + 2│ def helper(): + 3│ pass +--- END FILE: /path/to/context.py --- +""" + + # Verify diff section has no line numbers + assert "│" not in diff_section, "Diff section should not have line number markers" + + # Verify context section has line numbers + assert "│" in context_section, "Context section should have line number markers" + + # Verify the sections are clearly separated + assert "## Git Diffs" in diff_section + assert "## Additional Context Files" in context_section + assert "have NOT been modified" in context_section diff --git a/tests/test_precommit_line_numbers.py b/tests/test_precommit_line_numbers.py new file mode 100644 index 0000000..5b5ae77 --- /dev/null +++ b/tests/test_precommit_line_numbers.py @@ -0,0 +1,165 @@ +""" +Test to verify that precommit tool handles line numbers correctly: +- Diffs should NOT have line numbers (they have their own diff markers) +- Additional context files SHOULD have line numbers +""" + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tools.precommit import Precommit, PrecommitRequest + + +class TestPrecommitLineNumbers: + """Test that precommit correctly handles line numbers for diffs vs context files.""" + + @pytest.fixture + def tool(self): + """Create a Precommit tool instance.""" + return Precommit() + + @pytest.fixture + def mock_provider(self): + """Create a mock provider.""" + provider = MagicMock() + provider.get_provider_type.return_value.value = "test" + + # Mock the model response + model_response = MagicMock() + model_response.content = "Test review response" + model_response.usage = {"total_tokens": 100} + model_response.metadata = {"finish_reason": "stop"} + model_response.friendly_name = "test-model" + + provider.generate_content = AsyncMock(return_value=model_response) + provider.get_capabilities.return_value = MagicMock( + context_window=200000, + temperature_constraint=MagicMock( + validate=lambda x: True, get_corrected_value=lambda x: x, get_description=lambda: "0.0 to 1.0" + ), + ) + provider.supports_thinking_mode.return_value = False + + return provider + + @pytest.mark.asyncio + async def test_diffs_have_no_line_numbers_but_context_files_do(self, tool, mock_provider, tmp_path): + """Test that git diffs don't have line numbers but context files do.""" + # Use the workspace root for test files + import tempfile + + test_workspace = 
tempfile.mkdtemp(prefix="test_precommit_") + + # Create a context file in the workspace + context_file = os.path.join(test_workspace, "context.py") + with open(context_file, "w") as f: + f.write( + """# This is a context file +def context_function(): + return "This should have line numbers" +""" + ) + + # Mock git commands to return predictable output + def mock_run_git_command(repo_path, command): + if command == ["status", "--porcelain"]: + return True, " M example.py" + elif command == ["diff", "--name-only"]: + return True, "example.py" + elif command == ["diff", "--", "example.py"]: + # Return a sample diff - this should NOT have line numbers added + return ( + True, + """diff --git a/example.py b/example.py +index 1234567..abcdefg 100644 +--- a/example.py ++++ b/example.py +@@ -1,5 +1,8 @@ + def hello(): +- print("Hello, World!") ++ print("Hello, Universe!") # Changed this line + + def goodbye(): + print("Goodbye!") ++ ++def new_function(): ++ print("This is new") +""", + ) + else: + return True, "" + + # Create request with context file + request = PrecommitRequest( + path=test_workspace, + prompt="Review my changes", + files=[context_file], # This should get line numbers + include_staged=False, + include_unstaged=True, + ) + + # Mock the tool's provider and git functions + with ( + patch.object(tool, "get_model_provider", return_value=mock_provider), + patch("tools.precommit.run_git_command", side_effect=mock_run_git_command), + patch("tools.precommit.find_git_repositories", return_value=[test_workspace]), + patch( + "tools.precommit.get_git_status", + return_value={ + "branch": "main", + "ahead": 0, + "behind": 0, + "staged_files": [], + "unstaged_files": ["example.py"], + "untracked_files": [], + }, + ), + ): + + # Prepare the prompt + prompt = await tool.prepare_prompt(request) + + # Print prompt sections for debugging if test fails + # print("\n=== PROMPT OUTPUT ===") + # print(prompt) + # print("=== END PROMPT ===\n") + + # Verify that diffs don't have line numbers + assert "--- BEGIN DIFF:" in prompt + assert "--- END DIFF:" in prompt + + # Check that the diff content doesn't have line number markers (│) + # Find diff section + diff_start = prompt.find("--- BEGIN DIFF:") + diff_end = prompt.find("--- END DIFF:", diff_start) + len("--- END DIFF:") + if diff_start != -1 and diff_end > diff_start: + diff_section = prompt[diff_start:diff_end] + assert "│" not in diff_section, "Diff section should NOT have line number markers" + + # Verify the diff has its own line markers + assert "@@ -1,5 +1,8 @@" in diff_section + assert '- print("Hello, World!")' in diff_section + assert '+ print("Hello, Universe!") # Changed this line' in diff_section + + # Verify that context files DO have line numbers + if "--- BEGIN FILE:" in prompt: + # Extract context file section + file_start = prompt.find("--- BEGIN FILE:") + file_end = prompt.find("--- END FILE:", file_start) + len("--- END FILE:") + if file_start != -1 and file_end > file_start: + context_section = prompt[file_start:file_end] + + # Context files should have line number markers + assert "│" in context_section, "Context file section SHOULD have line number markers" + + # Verify specific line numbers in context file + assert "1│ # This is a context file" in context_section + assert "2│ def context_function():" in context_section + assert '3│ return "This should have line numbers"' in context_section + + def test_base_tool_wants_line_numbers_by_default(self, tool): + """Verify that the base tool configuration wants line numbers by 
default.""" + # The precommit tool should inherit the base behavior + assert tool.wants_line_numbers_by_default(), "Base tool should want line numbers by default" diff --git a/tools/base.py b/tools/base.py index 9dc8c57..fa5bb77 100644 --- a/tools/base.py +++ b/tools/base.py @@ -399,20 +399,22 @@ class BaseTool(ABC): """ Return whether this tool wants line numbers added to code files by default. - Tools that benefit from precise line references (refactor, codereview, debug) - should return True. Tools that prioritize token efficiency or don't need - precise references can return False. + By default, ALL tools get line numbers for precise code references. + Line numbers are essential for accurate communication about code locations. Line numbers add ~8-10% token overhead but provide precise targeting for: - Code review feedback ("SQL injection on line 45") - Debug error locations ("Memory leak in loop at lines 123-156") - Test generation targets ("Generate tests for method at lines 78-95") - Refactoring guidance ("Extract method from lines 67-89") + - General code discussions ("Where is X defined?" -> "Line 42") + + The only exception is when reading diffs, which have their own line markers. Returns: bool: True if line numbers should be added by default for this tool """ - return False # Conservative default - tools opt-in as needed + return True # All tools get line numbers by default for consistency def get_default_thinking_mode(self) -> str: """ diff --git a/tools/codereview.py b/tools/codereview.py index 98fa2cc..0d18b14 100644 --- a/tools/codereview.py +++ b/tools/codereview.py @@ -148,9 +148,7 @@ class CodeReviewTool(BaseTool): def get_default_temperature(self) -> float: return TEMPERATURE_ANALYTICAL - def wants_line_numbers_by_default(self) -> bool: - """Code review tool needs line numbers for precise feedback""" - return True + # Line numbers are enabled by default from base class for precise feedback def get_request_model(self): return CodeReviewRequest diff --git a/tools/debug.py b/tools/debug.py index 22837de..3732150 100644 --- a/tools/debug.py +++ b/tools/debug.py @@ -111,9 +111,7 @@ class DebugIssueTool(BaseTool): def get_default_temperature(self) -> float: return TEMPERATURE_ANALYTICAL - def wants_line_numbers_by_default(self) -> bool: - """Debug tool needs line numbers for precise error location""" - return True + # Line numbers are enabled by default from base class for precise error location def get_model_category(self) -> "ToolModelCategory": """Debug requires deep analysis and reasoning""" diff --git a/tools/refactor.py b/tools/refactor.py index beae2f4..50a1156 100644 --- a/tools/refactor.py +++ b/tools/refactor.py @@ -143,9 +143,7 @@ class RefactorTool(BaseTool): def get_default_temperature(self) -> float: return TEMPERATURE_ANALYTICAL - def wants_line_numbers_by_default(self) -> bool: - """Refactor tool needs line numbers for precise targeting""" - return True + # Line numbers are enabled by default from base class for precise targeting def get_model_category(self): """Refactor tool requires extended reasoning for comprehensive analysis""" @@ -159,7 +157,7 @@ class RefactorTool(BaseTool): async def execute(self, arguments: dict[str, Any]) -> list[TextContent]: """Override execute to check prompt size before processing""" logger.info(f"[REFACTOR] execute called with arguments: {list(arguments.keys())}") - + # First validate request request_model = self.get_request_model() request = request_model(**arguments) @@ -168,10 +166,10 @@ class RefactorTool(BaseTool): if 
request.prompt: size_check = self.check_prompt_size(request.prompt) if size_check: - logger.info(f"[REFACTOR] Prompt size check triggered, returning early") + logger.info("[REFACTOR] Prompt size check triggered, returning early") return [TextContent(type="text", text=ToolOutput(**size_check).model_dump_json())] - logger.info(f"[REFACTOR] Prompt size OK, calling super().execute()") + logger.info("[REFACTOR] Prompt size OK, calling super().execute()") # Continue with normal execution return await super().execute(arguments) diff --git a/tools/testgen.py b/tools/testgen.py index 7f14898..05462a5 100644 --- a/tools/testgen.py +++ b/tools/testgen.py @@ -134,9 +134,7 @@ class TestGenTool(BaseTool): def get_default_temperature(self) -> float: return TEMPERATURE_ANALYTICAL - def wants_line_numbers_by_default(self) -> bool: - """Test generation tool needs line numbers for precise targeting""" - return True + # Line numbers are enabled by default from base class for precise targeting def get_model_category(self): """TestGen requires extended reasoning for comprehensive test analysis"""
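
As a rough, self-contained illustration of the "LINE│ code" presentation that the updated system prompts describe and that the new precommit tests assert for context files — a minimal sketch only, not the server's actual file-reading code; the helper name and the right-aligned four-character padding are assumptions made for illustration:

```python
# Illustrative sketch only -- not the server's actual implementation.
# Mimics the "  N│ code" style referenced by the CRITICAL LINE NUMBER
# INSTRUCTIONS blocks and asserted in tests/test_precommit_line_numbers.py
# (e.g. "1│ # This is a context file"). Helper name and padding width
# are assumptions for this sketch.

def add_line_numbers(source: str) -> str:
    """Prefix each line with a right-aligned number and the '│' separator."""
    return "\n".join(
        f"{number:4d}│ {line}"
        for number, line in enumerate(source.splitlines(), start=1)
    )


if __name__ == "__main__":
    sample = '# This is a context file\ndef context_function():\n    return "This should have line numbers"\n'
    print(add_line_numbers(sample))
    # Prints:
    #    1│ # This is a context file
    #    2│ def context_function():
    #    3│     return "This should have line numbers"
```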
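
And a hypothetical example of how an individual tool could still opt out of the new default introduced in `tools/base.py` by overriding `wants_line_numbers_by_default()`; `SummaryTool` is invented for illustration and its required abstract methods are omitted, so this is a sketch of the override mechanism rather than a working tool:

```python
# Hypothetical opt-out sketch, assuming the repository layout in this patch.
# Only the override is the point; a real tool must also implement the
# abstract methods defined on BaseTool (name, schema, execute, prompts, etc.).
from tools.base import BaseTool


class SummaryTool(BaseTool):
    """Invented example tool that prefers raw, unnumbered file content."""

    def wants_line_numbers_by_default(self) -> bool:
        # Skip the ~8-10% token overhead when precise line references
        # are not needed; git diffs are already excluded because they
        # carry their own line markers.
        return False
```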