diff --git a/README.md b/README.md index 49f4edf..c1acceb 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,6 @@ Just ask Claude naturally: - **Pre-commit validation?** → `review_changes` (validate git changes before committing) - **Something's broken?** → `debug_issue` (root cause analysis, error tracing) - **Want to understand code?** → `analyze` (architecture, patterns, dependencies) -- **Check models?** → `list_models` (see available Gemini models) - **Server info?** → `get_version` (version and configuration details) **Tools Overview:** @@ -132,8 +131,7 @@ Just ask Claude naturally: 4. [`review_changes`](#4-review_changes---pre-commit-validation) - Validate git changes before committing 5. [`debug_issue`](#5-debug_issue---expert-debugging-assistant) - Root cause analysis and debugging 6. [`analyze`](#6-analyze---smart-file-analysis) - General-purpose file and code analysis -7. [`list_models`](#7-list_models---see-available-gemini-models) - List available Gemini models -8. [`get_version`](#8-get_version---server-information) - Get server version and configuration +7. [`get_version`](#7-get_version---server-information) - Get server version and configuration ### 1. `chat` - General Development Chat & Collaborative Thinking **Your thinking partner - bounce ideas, get second opinions, brainstorm collaboratively** @@ -346,13 +344,7 @@ Combine your findings with gemini's to create a comprehensive security report." **Triggers:** analyze, examine, look at, understand, inspect -### 7. `list_models` - See Available Gemini Models -``` -"Use gemini to list available models" -"Get gemini to show me what models I can use" -``` - -### 8. `get_version` - Server Information +### 7. `get_version` - Server Information ``` "Use gemini for its version" "Get gemini to show server configuration" @@ -530,7 +522,7 @@ All tools support a `thinking_mode` parameter that controls Gemini's thinking bu The server includes several configurable properties that control its behavior: ### Model Configuration -- **`DEFAULT_MODEL`**: `"gemini-2.5-pro-preview-06-05"` - The latest Gemini 2.5 Pro model with native thinking support +- **`GEMINI_MODEL`**: `"gemini-2.5-pro-preview-06-05"` - The latest Gemini 2.5 Pro model with native thinking support - **`MAX_CONTEXT_TOKENS`**: `1,000,000` - Maximum input context (1M tokens for Gemini 2.5 Pro) ### Temperature Defaults diff --git a/config.py b/config.py index 8734eb5..61949e5 100644 --- a/config.py +++ b/config.py @@ -1,20 +1,43 @@ """ Configuration and constants for Gemini MCP Server + +This module centralizes all configuration settings for the Gemini MCP Server. +It defines model configurations, token limits, temperature defaults, and other +constants used throughout the application. + +Configuration values can be overridden by environment variables where appropriate. 
""" # Version and metadata -__version__ = "2.8.0" -__updated__ = "2025-09-09" -__author__ = "Fahad Gilani" +# These values are used in server responses and for tracking releases +__version__ = "2.8.0" # Semantic versioning: MAJOR.MINOR.PATCH +__updated__ = "2025-09-09" # Last update date in ISO format +__author__ = "Fahad Gilani" # Primary maintainer # Model configuration -DEFAULT_MODEL = "gemini-2.5-pro-preview-06-05" -THINKING_MODEL = ( - "gemini-2.0-flash-thinking-exp" # Enhanced reasoning model for think_deeper -) +# GEMINI_MODEL: The Gemini model used for all AI operations +# This should be a stable, high-performance model suitable for code analysis +GEMINI_MODEL = "gemini-2.5-pro-preview-06-05" + +# MAX_CONTEXT_TOKENS: Maximum number of tokens that can be included in a single request +# This limit includes both the prompt and expected response +# Gemini Pro models support up to 1M tokens, but practical usage should reserve +# space for the model's response (typically 50K-100K tokens reserved) MAX_CONTEXT_TOKENS = 1_000_000 # 1M tokens for Gemini Pro # Temperature defaults for different tool types +# Temperature controls the randomness/creativity of model responses +# Lower values (0.0-0.3) produce more deterministic, focused responses +# Higher values (0.7-1.0) produce more creative, varied responses + +# TEMPERATURE_ANALYTICAL: Used for tasks requiring precision and consistency +# Ideal for code review, debugging, and error analysis where accuracy is critical TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging + +# TEMPERATURE_BALANCED: Middle ground for general conversations +# Provides a good balance between consistency and helpful variety TEMPERATURE_BALANCED = 0.5 # For general chat + +# TEMPERATURE_CREATIVE: Higher temperature for exploratory tasks +# Used when brainstorming, exploring alternatives, or architectural discussions TEMPERATURE_CREATIVE = 0.7 # For architecture, deep thinking diff --git a/requirements.txt b/requirements.txt index d734008..5c740a8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ mcp>=1.0.0 google-genai>=1.19.0 -python-dotenv>=1.0.0 pydantic>=2.0.0 # Development dependencies diff --git a/server.py b/server.py index bb06ada..5c464bb 100644 --- a/server.py +++ b/server.py @@ -1,5 +1,21 @@ """ Gemini MCP Server - Main server implementation + +This module implements the core MCP (Model Context Protocol) server that provides +AI-powered tools for code analysis, review, and assistance using Google's Gemini models. + +The server follows the MCP specification to expose various AI tools as callable functions +that can be used by MCP clients (like Claude). Each tool provides specialized functionality +such as code review, debugging, deep thinking, and general chat capabilities. + +Key Components: +- MCP Server: Handles protocol communication and tool discovery +- Tool Registry: Maps tool names to their implementations +- Request Handler: Processes incoming tool calls and returns formatted responses +- Configuration: Manages API keys and model settings + +The server runs on stdio (standard input/output) and communicates using JSON-RPC messages +as defined by the MCP protocol. 
""" import asyncio @@ -9,14 +25,13 @@ import sys from datetime import datetime from typing import Any, Dict, List -from google import genai from mcp.server import Server from mcp.server.models import InitializationOptions from mcp.server.stdio import stdio_server from mcp.types import TextContent, Tool from config import ( - DEFAULT_MODEL, + GEMINI_MODEL, MAX_CONTEXT_TOKENS, __author__, __updated__, @@ -31,41 +46,67 @@ from tools import ( ThinkDeeperTool, ) -# Configure logging +# Configure logging for server operations +# Set to INFO level to capture important operational messages without being too verbose logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Create the MCP server instance +# Create the MCP server instance with a unique name identifier +# This name is used by MCP clients to identify and connect to this specific server server: Server = Server("gemini-server") -# Initialize tools +# Initialize the tool registry with all available AI-powered tools +# Each tool provides specialized functionality for different development tasks +# Tools are instantiated once and reused across requests (stateless design) TOOLS = { - "think_deeper": ThinkDeeperTool(), - "review_code": ReviewCodeTool(), - "debug_issue": DebugIssueTool(), - "analyze": AnalyzeTool(), - "chat": ChatTool(), - "review_changes": ReviewChanges(), + "think_deeper": ThinkDeeperTool(), # Extended reasoning for complex problems + "review_code": ReviewCodeTool(), # Comprehensive code review and quality analysis + "debug_issue": DebugIssueTool(), # Root cause analysis and debugging assistance + "analyze": AnalyzeTool(), # General-purpose file and code analysis + "chat": ChatTool(), # Interactive development chat and brainstorming + "review_changes": ReviewChanges(), # Pre-commit review of git changes } def configure_gemini(): - """Configure Gemini API with the provided API key""" + """ + Configure Gemini API with the provided API key. + + This function validates that the GEMINI_API_KEY environment variable is set. + The actual API key is used when creating Gemini clients within individual tools + to ensure proper isolation and error handling. + + Raises: + ValueError: If GEMINI_API_KEY environment variable is not set + """ api_key = os.getenv("GEMINI_API_KEY") if not api_key: raise ValueError( "GEMINI_API_KEY environment variable is required. " "Please set it with your Gemini API key." ) - # API key is used when creating clients in tools + # Note: We don't store the API key globally for security reasons + # Each tool creates its own Gemini client with the API key when needed logger.info("Gemini API key found") @server.list_tools() async def handle_list_tools() -> List[Tool]: - """List all available tools with verbose descriptions""" + """ + List all available tools with their descriptions and input schemas. + + This handler is called by MCP clients during initialization to discover + what tools are available. 
Each tool provides: + - name: Unique identifier for the tool + - description: Detailed explanation of what the tool does + - inputSchema: JSON Schema defining the expected parameters + + Returns: + List of Tool objects representing all available tools + """ tools = [] + # Add all registered AI-powered tools from the TOOLS registry for tool in TOOLS.values(): tools.append( Tool( @@ -75,17 +116,10 @@ async def handle_list_tools() -> List[Tool]: ) ) - # Add utility tools + # Add utility tools that provide server metadata and configuration info + # These tools don't require AI processing but are useful for clients tools.extend( [ - Tool( - name="list_models", - description=( - "LIST AVAILABLE MODELS - Show all Gemini models you can use. " - "Lists model names, descriptions, and which one is the default." - ), - inputSchema={"type": "object", "properties": {}}, - ), Tool( name="get_version", description=( @@ -102,100 +136,65 @@ async def handle_list_tools() -> List[Tool]: @server.call_tool() async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: - """Handle tool execution requests""" + """ + Handle incoming tool execution requests from MCP clients. - # Handle dynamic tools + This is the main request dispatcher that routes tool calls to their + appropriate handlers. It supports both AI-powered tools (from TOOLS registry) + and utility tools (implemented as static functions). + + Args: + name: The name of the tool to execute + arguments: Dictionary of arguments to pass to the tool + + Returns: + List of TextContent objects containing the tool's response + """ + + # Route to AI-powered tools that require Gemini API calls if name in TOOLS: tool = TOOLS[name] return await tool.execute(arguments) - # Handle static tools - elif name == "list_models": - return await handle_list_models() - + # Route to utility tools that provide server information elif name == "get_version": return await handle_get_version() + # Handle unknown tool requests gracefully else: return [TextContent(type="text", text=f"Unknown tool: {name}")] -async def handle_list_models() -> List[TextContent]: - """List available Gemini models""" - try: - import json - - # Get API key - api_key = os.getenv("GEMINI_API_KEY") - if not api_key: - return [TextContent(type="text", text="Error: GEMINI_API_KEY not set")] - - client = genai.Client(api_key=api_key) - models = [] - - # List models using the new API - try: - model_list = client.models.list() - for model_info in model_list: - models.append( - { - "name": getattr(model_info, "id", "Unknown"), - "display_name": getattr( - model_info, - "display_name", - getattr(model_info, "id", "Unknown"), - ), - "description": getattr( - model_info, "description", "No description" - ), - "is_default": getattr(model_info, "id", "").endswith( - DEFAULT_MODEL - ), - } - ) - - except Exception: - # Fallback: return some known models - models = [ - { - "name": "gemini-2.5-pro-preview-06-05", - "display_name": "Gemini 2.5 Pro", - "description": "Latest Gemini 2.5 Pro model", - "is_default": True, - }, - { - "name": "gemini-2.0-flash-thinking-exp", - "display_name": "Gemini 2.0 Flash Thinking", - "description": "Enhanced reasoning model", - "is_default": False, - }, - ] - - return [TextContent(type="text", text=json.dumps(models, indent=2))] - - except Exception as e: - return [TextContent(type="text", text=f"Error listing models: {str(e)}")] - - async def handle_get_version() -> List[TextContent]: - """Get version and configuration information""" + """ + Get comprehensive 
version and configuration information about the server. + + Provides details about the server version, configuration settings, + available tools, and runtime environment. Useful for debugging and + understanding the server's capabilities. + + Returns: + Formatted text with version and configuration details + """ + # Gather comprehensive server information version_info = { "version": __version__, "updated": __updated__, "author": __author__, - "default_model": DEFAULT_MODEL, + "gemini_model": GEMINI_MODEL, "max_context_tokens": f"{MAX_CONTEXT_TOKENS:,}", "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", "server_started": datetime.now().isoformat(), - "available_tools": list(TOOLS.keys()) + ["chat", "list_models", "get_version"], + "available_tools": list(TOOLS.keys()) + ["get_version"], } + # Format the information in a human-readable way text = f"""Gemini MCP Server v{__version__} Updated: {__updated__} Author: {__author__} Configuration: -- Default Model: {DEFAULT_MODEL} +- Gemini Model: {GEMINI_MODEL} - Max Context: {MAX_CONTEXT_TOKENS:,} tokens - Python: {version_info['python_version']} - Started: {version_info['server_started']} @@ -209,11 +208,21 @@ For updates, visit: https://github.com/BeehiveInnovations/gemini-mcp-server""" async def main(): - """Main entry point for the server""" - # Configure Gemini API + """ + Main entry point for the MCP server. + + Initializes the Gemini API configuration and starts the server using + stdio transport. The server will continue running until the client + disconnects or an error occurs. + + The server communicates via standard input/output streams using the + MCP protocol's JSON-RPC message format. + """ + # Validate that Gemini API key is available before starting configure_gemini() - # Run the server using stdio transport + # Run the server using stdio transport (standard input/output) + # This allows the server to be launched by MCP clients as a subprocess async with stdio_server() as (read_stream, write_stream): await server.run( read_stream, @@ -221,7 +230,7 @@ async def main(): InitializationOptions( server_name="gemini", server_version=__version__, - capabilities={"tools": {}}, + capabilities={"tools": {}}, # Advertise tool support capability ), ) diff --git a/setup.py b/setup.py index 4df009f..491be40 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ setup( install_requires=[ "mcp>=1.0.0", "google-genai>=1.19.0", - "python-dotenv>=1.0.0", + "pydantic>=2.0.0", ], extras_require={ "dev": [ diff --git a/tests/test_config.py b/tests/test_config.py index 50c09c5..1582aa2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,7 +3,7 @@ Tests for configuration """ from config import ( - DEFAULT_MODEL, + GEMINI_MODEL, MAX_CONTEXT_TOKENS, TEMPERATURE_ANALYTICAL, TEMPERATURE_BALANCED, @@ -31,7 +31,7 @@ class TestConfig: def test_model_config(self): """Test model configuration""" - assert DEFAULT_MODEL == "gemini-2.5-pro-preview-06-05" + assert GEMINI_MODEL == "gemini-2.5-pro-preview-06-05" assert MAX_CONTEXT_TOKENS == 1_000_000 def test_temperature_defaults(self): diff --git a/tests/test_server.py b/tests/test_server.py index 08e0038..9874b27 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,7 +2,6 @@ Tests for the main server functionality """ -import json from unittest.mock import Mock, patch import pytest @@ -26,11 +25,10 @@ class TestServerTools: assert "analyze" in tool_names assert "chat" in tool_names assert "review_changes" in tool_names - assert "list_models" 
in tool_names assert "get_version" in tool_names - # Should have exactly 8 tools - assert len(tools) == 8 + # Should have exactly 7 tools + assert len(tools) == 7 # Check descriptions are verbose for tool in tools: @@ -69,22 +67,6 @@ class TestServerTools: assert response_data["status"] == "success" assert response_data["content"] == "Chat response" - @pytest.mark.asyncio - async def test_handle_list_models(self): - """Test listing models""" - result = await handle_call_tool("list_models", {}) - assert len(result) == 1 - - # Check if we got models or an error - text = result[0].text - if "Error" in text: - # API key not set in test environment - assert "GEMINI_API_KEY" in text - else: - # Should have models - models = json.loads(text) - assert len(models) >= 1 - @pytest.mark.asyncio async def test_handle_get_version(self): """Test getting version info""" diff --git a/tools/base.py b/tools/base.py index 427be1f..9475a66 100644 --- a/tools/base.py +++ b/tools/base.py @@ -1,5 +1,16 @@ """ Base class for all Gemini MCP tools + +This module provides the abstract base class that all tools must inherit from. +It defines the contract that tools must implement and provides common functionality +for request validation, error handling, and response formatting. + +Key responsibilities: +- Define the tool interface (abstract methods that must be implemented) +- Handle request validation and file path security +- Manage Gemini model creation with appropriate configurations +- Standardize response formatting and error handling +- Support for clarification requests when more information is needed """ from abc import ABC, abstractmethod @@ -16,7 +27,13 @@ from .models import ToolOutput, ClarificationRequest class ToolRequest(BaseModel): - """Base request model for all tools""" + """ + Base request model for all tools. + + This Pydantic model defines common parameters that can be used by any tool. + Tools can extend this model to add their specific parameters while inheriting + these common fields. + """ model: Optional[str] = Field( None, description="Model to use (defaults to Gemini 2.5 Pro)" @@ -24,6 +41,8 @@ class ToolRequest(BaseModel): temperature: Optional[float] = Field( None, description="Temperature for response (tool-specific defaults)" ) + # Thinking mode controls how much computational budget the model uses for reasoning + # Higher values allow for more complex reasoning but increase latency and cost thinking_mode: Optional[Literal["minimal", "low", "medium", "high", "max"]] = Field( None, description="Thinking depth: minimal (128), low (2048), medium (8192), high (16384), max (32768)", @@ -31,52 +50,130 @@ class ToolRequest(BaseModel): class BaseTool(ABC): - """Base class for all Gemini tools""" + """ + Abstract base class for all Gemini tools. + + This class defines the interface that all tools must implement and provides + common functionality for request handling, model creation, and response formatting. + + To create a new tool: + 1. Create a new class that inherits from BaseTool + 2. Implement all abstract methods + 3. Define a request model that inherits from ToolRequest + 4. Register the tool in server.py's TOOLS dictionary + """ def __init__(self): + # Cache tool metadata at initialization to avoid repeated calls self.name = self.get_name() self.description = self.get_description() self.default_temperature = self.get_default_temperature() @abstractmethod def get_name(self) -> str: - """Return the tool name""" + """ + Return the unique name identifier for this tool. 
+ + This name is used by MCP clients to invoke the tool and must be + unique across all registered tools. + + Returns: + str: The tool's unique name (e.g., "review_code", "analyze") + """ pass @abstractmethod def get_description(self) -> str: - """Return the verbose tool description for Claude""" + """ + Return a detailed description of what this tool does. + + This description is shown to MCP clients (like Claude) to help them + understand when and how to use the tool. It should be comprehensive + and include trigger phrases. + + Returns: + str: Detailed tool description with usage examples + """ pass @abstractmethod def get_input_schema(self) -> Dict[str, Any]: - """Return the JSON schema for tool inputs""" + """ + Return the JSON Schema that defines this tool's parameters. + + This schema is used by MCP clients to validate inputs before + sending requests. It should match the tool's request model. + + Returns: + Dict[str, Any]: JSON Schema object defining required and optional parameters + """ pass @abstractmethod def get_system_prompt(self) -> str: - """Return the system prompt for this tool""" + """ + Return the system prompt that configures the AI model's behavior. + + This prompt sets the context and instructions for how the model + should approach the task. It's prepended to the user's request. + + Returns: + str: System prompt with role definition and instructions + """ pass def get_default_temperature(self) -> float: - """Return default temperature for this tool""" + """ + Return the default temperature setting for this tool. + + Override this method to set tool-specific temperature defaults. + Lower values (0.0-0.3) for analytical tasks, higher (0.7-1.0) for creative tasks. + + Returns: + float: Default temperature between 0.0 and 1.0 + """ return 0.5 def get_default_thinking_mode(self) -> str: - """Return default thinking_mode for this tool""" + """ + Return the default thinking mode for this tool. + + Thinking mode controls computational budget for reasoning. + Override for tools that need more or less reasoning depth. + + Returns: + str: One of "minimal", "low", "medium", "high", "max" + """ return "medium" # Default to medium thinking for better reasoning @abstractmethod def get_request_model(self): - """Return the Pydantic model for request validation""" + """ + Return the Pydantic model class used for validating requests. + + This model should inherit from ToolRequest and define all + parameters specific to this tool. + + Returns: + Type[ToolRequest]: The request model class + """ pass def validate_file_paths(self, request) -> Optional[str]: """ Validate that all file paths in the request are absolute. - Returns error message if validation fails, None if all paths are valid. + + This is a critical security function that prevents path traversal attacks + and ensures all file access is properly controlled. All file paths must + be absolute to avoid ambiguity and security issues. 
+ + Args: + request: The validated request object + + Returns: + Optional[str]: Error message if validation fails, None if all paths are valid """ - # Check if request has 'files' attribute + # Check if request has 'files' attribute (used by most tools) if hasattr(request, "files") and request.files: for file_path in request.files: if not os.path.isabs(file_path): @@ -86,7 +183,7 @@ class BaseTool(ABC): f"Please provide the full absolute path starting with '/'" ) - # Check if request has 'path' attribute (for review_changes) + # Check if request has 'path' attribute (used by review_changes tool) if hasattr(request, "path") and request.path: if not os.path.isabs(request.path): return ( @@ -98,13 +195,31 @@ class BaseTool(ABC): return None async def execute(self, arguments: Dict[str, Any]) -> List[TextContent]: - """Execute the tool with given arguments""" + """ + Execute the tool with the provided arguments. + + This is the main entry point for tool execution. It handles: + 1. Request validation using the tool's Pydantic model + 2. File path security validation + 3. Prompt preparation + 4. Model creation and configuration + 5. Response generation and formatting + 6. Error handling and recovery + + Args: + arguments: Dictionary of arguments from the MCP client + + Returns: + List[TextContent]: Formatted response as MCP TextContent objects + """ try: - # Validate request + # Validate request using the tool's Pydantic model + # This ensures all required fields are present and properly typed request_model = self.get_request_model() request = request_model(**arguments) - # Validate file paths + # Validate file paths for security + # This prevents path traversal attacks and ensures proper access control path_error = self.validate_file_paths(request) if path_error: error_output = ToolOutput( @@ -114,13 +229,14 @@ class BaseTool(ABC): ) return [TextContent(type="text", text=error_output.model_dump_json())] - # Prepare the prompt + # Prepare the full prompt by combining system prompt with user request + # This is delegated to the tool implementation for customization prompt = await self.prepare_prompt(request) - # Get model configuration - from config import DEFAULT_MODEL + # Extract model configuration from request or use defaults + from config import GEMINI_MODEL - model_name = getattr(request, "model", None) or DEFAULT_MODEL + model_name = getattr(request, "model", None) or GEMINI_MODEL temperature = getattr(request, "temperature", None) if temperature is None: temperature = self.get_default_temperature() @@ -128,20 +244,23 @@ class BaseTool(ABC): if thinking_mode is None: thinking_mode = self.get_default_thinking_mode() - # Create and configure model + # Create model instance with appropriate configuration + # This handles both regular models and thinking-enabled models model = self.create_model(model_name, temperature, thinking_mode) - # Generate response + # Generate AI response using the configured model response = model.generate_content(prompt) - # Handle response and create standardized output + # Process the model's response if response.candidates and response.candidates[0].content.parts: raw_text = response.candidates[0].content.parts[0].text - # Check if this is a clarification request + # Parse response to check for clarification requests or format output tool_output = self._parse_response(raw_text, request) else: + # Handle cases where the model couldn't generate a response + # This might happen due to safety filters or other constraints finish_reason = ( 
response.candidates[0].finish_reason if response.candidates @@ -153,10 +272,12 @@ class BaseTool(ABC): content_type="text", ) - # Serialize the standardized output as JSON + # Return standardized JSON response for consistent client handling return [TextContent(type="text", text=tool_output.model_dump_json())] except Exception as e: + # Catch all exceptions to prevent server crashes + # Return error information in standardized format error_output = ToolOutput( status="error", content=f"Error in {self.name}: {str(e)}", @@ -165,7 +286,19 @@ class BaseTool(ABC): return [TextContent(type="text", text=error_output.model_dump_json())] def _parse_response(self, raw_text: str, request) -> ToolOutput: - """Parse the raw response and determine if it's a clarification request""" + """ + Parse the raw response and determine if it's a clarification request. + + Some tools may return JSON indicating they need more information. + This method detects such responses and formats them appropriately. + + Args: + raw_text: The raw text response from the model + request: The original request for context + + Returns: + ToolOutput: Standardized output object + """ try: # Try to parse as JSON to check for clarification requests potential_json = json.loads(raw_text.strip()) @@ -214,40 +347,79 @@ class BaseTool(ABC): @abstractmethod async def prepare_prompt(self, request) -> str: - """Prepare the full prompt for Gemini""" + """ + Prepare the complete prompt for the Gemini model. + + This method should combine the system prompt with the user's request + and any additional context (like file contents) needed for the task. + + Args: + request: The validated request object + + Returns: + str: Complete prompt ready for the model + """ pass def format_response(self, response: str, request) -> str: - """Format the response for display (can be overridden)""" + """ + Format the model's response for display. + + Override this method to add tool-specific formatting like headers, + summaries, or structured output. Default implementation returns + the response unchanged. + + Args: + response: The raw response from the model + request: The original request for context + + Returns: + str: Formatted response + """ return response def create_model( self, model_name: str, temperature: float, thinking_mode: str = "medium" ): - """Create a configured Gemini model with thinking configuration""" - # Map thinking modes to budget values + """ + Create a configured Gemini model instance. + + This method handles model creation with appropriate settings including + temperature and thinking budget configuration for models that support it. 
+ + Args: + model_name: Name of the Gemini model to use + temperature: Temperature setting for response generation + thinking_mode: Thinking depth mode (affects computational budget) + + Returns: + Model instance configured and ready for generation + """ + # Map thinking modes to computational budget values + # Higher budgets allow for more complex reasoning but increase latency thinking_budgets = { - "minimal": 128, # Minimum for 2.5 Pro - "low": 2048, - "medium": 8192, - "high": 16384, - "max": 32768, + "minimal": 128, # Minimum for 2.5 Pro - fast responses + "low": 2048, # Light reasoning tasks + "medium": 8192, # Balanced reasoning (default) + "high": 16384, # Complex analysis + "max": 32768, # Maximum reasoning depth } thinking_budget = thinking_budgets.get(thinking_mode, 8192) - # For models supporting thinking config, use the new API - # Skip in test environment to allow mocking + # Gemini 2.5 models support thinking configuration for enhanced reasoning + # Skip special handling in test environment to allow mocking if "2.5" in model_name and not os.environ.get("PYTEST_CURRENT_TEST"): try: - # Get API key + # Retrieve API key for Gemini client creation api_key = os.environ.get("GEMINI_API_KEY") if not api_key: raise ValueError("GEMINI_API_KEY environment variable is required") client = genai.Client(api_key=api_key) - # Create a wrapper to match the expected interface + # Create a wrapper class to provide a consistent interface + # This abstracts the differences between API versions class ModelWrapper: def __init__( self, client, model_name, temperature, thinking_budget @@ -270,7 +442,8 @@ class BaseTool(ABC): ), ) - # Convert to match expected format + # Wrap the response to match the expected format + # This ensures compatibility across different API versions class ResponseWrapper: def __init__(self, text): self.text = text @@ -302,18 +475,19 @@ class BaseTool(ABC): return ModelWrapper(client, model_name, temperature, thinking_budget) except Exception: - # Fall back to regular genai model if new API fails + # Fall back to regular API if thinking configuration fails + # This ensures the tool remains functional even with API changes pass - # For non-2.5 models or if thinking not needed, use regular API - # Get API key + # For models that don't support thinking configuration, use standard API api_key = os.environ.get("GEMINI_API_KEY") if not api_key: raise ValueError("GEMINI_API_KEY environment variable is required") client = genai.Client(api_key=api_key) - # Create wrapper for consistency + # Create a simple wrapper for models without thinking configuration + # This provides the same interface as the thinking-enabled wrapper class SimpleModelWrapper: def __init__(self, client, model_name, temperature): self.client = client diff --git a/tools/review_code.py b/tools/review_code.py index 085f112..a4593f5 100644 --- a/tools/review_code.py +++ b/tools/review_code.py @@ -1,5 +1,17 @@ """ Code Review tool - Comprehensive code analysis and review + +This tool provides professional-grade code review capabilities using +Gemini's understanding of code patterns, best practices, and common issues. +It can analyze individual files or entire codebases, providing actionable +feedback categorized by severity. 
+ +Key Features: +- Multi-file and directory support +- Configurable review types (full, security, performance, quick) +- Severity-based issue filtering +- Custom focus areas and coding standards +- Structured output with specific remediation steps """ from typing import Any, Dict, List, Optional @@ -14,7 +26,13 @@ from .base import BaseTool, ToolRequest class ReviewCodeRequest(ToolRequest): - """Request model for review_code tool""" + """ + Request model for the code review tool. + + This model defines all parameters that can be used to customize + the code review process, from selecting files to specifying + review focus and standards. + """ files: List[str] = Field( ..., @@ -36,7 +54,13 @@ class ReviewCodeRequest(ToolRequest): class ReviewCodeTool(BaseTool): - """Professional code review tool""" + """ + Professional code review tool implementation. + + This tool analyzes code for bugs, security vulnerabilities, performance + issues, and code quality problems. It provides detailed feedback with + severity ratings and specific remediation steps. + """ def get_name(self) -> str: return "review_code" @@ -105,11 +129,25 @@ class ReviewCodeTool(BaseTool): return ReviewCodeRequest async def prepare_prompt(self, request: ReviewCodeRequest) -> str: - """Prepare the code review prompt""" - # Read all files + """ + Prepare the code review prompt with customized instructions. + + This method reads the requested files, validates token limits, + and constructs a detailed prompt based on the review parameters. + + Args: + request: The validated review request + + Returns: + str: Complete prompt for the Gemini model + + Raises: + ValueError: If the code exceeds token limits + """ + # Read all requested files, expanding directories as needed file_content, summary = read_files(request.files) - # Check token limits + # Validate that the code fits within model context limits within_limit, estimated_tokens = check_token_limit(file_content) if not within_limit: raise ValueError( @@ -117,7 +155,7 @@ class ReviewCodeTool(BaseTool): f"Maximum is {MAX_CONTEXT_TOKENS:,} tokens." ) - # Build review instructions + # Build customized review instructions based on review type review_focus = [] if request.review_type == "security": review_focus.append( @@ -132,12 +170,15 @@ class ReviewCodeTool(BaseTool): "Provide a quick review focusing on critical issues only" ) + # Add any additional focus areas specified by the user if request.focus_on: review_focus.append(f"Pay special attention to: {request.focus_on}") + # Include custom coding standards if provided if request.standards: review_focus.append(f"Enforce these standards: {request.standards}") + # Apply severity filtering to reduce noise if requested if request.severity_filter != "all": review_focus.append( f"Only report issues of {request.severity_filter} severity or higher" @@ -145,7 +186,7 @@ class ReviewCodeTool(BaseTool): focus_instruction = "\n".join(review_focus) if review_focus else "" - # Combine everything + # Construct the complete prompt with system instructions and code full_prompt = f"""{self.get_system_prompt()} {focus_instruction} @@ -159,7 +200,19 @@ Please provide a comprehensive code review following the format specified in the return full_prompt def format_response(self, response: str, request: ReviewCodeRequest) -> str: - """Format the review response""" + """ + Format the review response with appropriate headers. + + Adds context about the review type and focus area to help + users understand the scope of the review. 
+ + Args: + response: The raw review from the model + request: The original request for context + + Returns: + str: Formatted response with headers + """ header = f"Code Review ({request.review_type.upper()})" if request.focus_on: header += f" - Focus: {request.focus_on}" diff --git a/utils/file_utils.py b/utils/file_utils.py index 749bba3..f562a5f 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -1,5 +1,21 @@ """ File reading utilities with directory support and token management + +This module provides secure file access functionality for the MCP server. +It implements critical security measures to prevent unauthorized file access +and manages token limits to ensure efficient API usage. + +Key Features: +- Path validation and sandboxing to prevent directory traversal attacks +- Support for both individual files and recursive directory reading +- Token counting and management to stay within API limits +- Automatic file type detection and filtering +- Comprehensive error handling with informative messages + +Security Model: +- All file access is restricted to PROJECT_ROOT and its subdirectories +- Absolute paths are required to prevent ambiguity +- Symbolic links are resolved to ensure they stay within bounds """ import os @@ -10,9 +26,12 @@ from .token_utils import estimate_tokens, MAX_CONTEXT_TOKENS # Get project root from environment or use current directory # This defines the sandbox directory where file access is allowed +# Security: All file operations are restricted to this directory and its children PROJECT_ROOT = Path(os.environ.get("MCP_PROJECT_ROOT", os.getcwd())).resolve() -# Security: Prevent running with overly permissive root +# Critical Security Check: Prevent running with overly permissive root +# Setting PROJECT_ROOT to "/" would allow access to the entire filesystem, +# which is a severe security vulnerability if str(PROJECT_ROOT) == "/": raise RuntimeError( "Security Error: MCP_PROJECT_ROOT cannot be set to '/'. " @@ -20,7 +39,8 @@ if str(PROJECT_ROOT) == "/": ) -# Common code file extensions +# Common code file extensions that are automatically included when processing directories +# This set can be extended to support additional file types CODE_EXTENSIONS = { ".py", ".js", @@ -75,11 +95,16 @@ def resolve_and_validate_path(path_str: str) -> Path: """ Validates that a path is absolute and resolves it. + This is the primary security function that ensures all file access + is properly sandboxed. It enforces two critical security policies: + 1. All paths must be absolute (no ambiguity) + 2. All paths must resolve to within PROJECT_ROOT (sandboxing) + Args: path_str: Path string (must be absolute) Returns: - Resolved Path object + Resolved Path object that is guaranteed to be within PROJECT_ROOT Raises: ValueError: If path is not absolute @@ -88,17 +113,19 @@ def resolve_and_validate_path(path_str: str) -> Path: # Create a Path object from the user-provided path user_path = Path(path_str) - # Require absolute paths + # Security Policy 1: Require absolute paths to prevent ambiguity + # Relative paths could be interpreted differently depending on working directory if not user_path.is_absolute(): raise ValueError( f"Relative paths are not supported. Please provide an absolute path.\n" f"Received: {path_str}" ) - # Resolve the absolute path + # Resolve the absolute path (follows symlinks, removes .. and .) 
resolved_path = user_path.resolve() - # Security check: ensure the resolved path is within PROJECT_ROOT + # Security Policy 2: Ensure the resolved path is within PROJECT_ROOT + # This prevents directory traversal attacks (e.g., /project/../../../etc/passwd) try: resolved_path.relative_to(PROJECT_ROOT) except ValueError: @@ -115,12 +142,16 @@ def expand_paths(paths: List[str], extensions: Optional[Set[str]] = None) -> Lis """ Expand paths to individual files, handling both files and directories. + This function recursively walks directories to find all matching files. + It automatically filters out hidden files and common non-code directories + like __pycache__ to avoid including generated or system files. + Args: - paths: List of file or directory paths - extensions: Optional set of file extensions to include + paths: List of file or directory paths (must be absolute) + extensions: Optional set of file extensions to include (defaults to CODE_EXTENSIONS) Returns: - List of individual file paths + List of individual file paths, sorted for consistent ordering """ if extensions is None: extensions = CODE_EXTENSIONS @@ -130,9 +161,10 @@ def expand_paths(paths: List[str], extensions: Optional[Set[str]] = None) -> Lis for path in paths: try: + # Validate each path for security before processing path_obj = resolve_and_validate_path(path) except (ValueError, PermissionError): - # Skip invalid paths + # Skip invalid paths silently to allow partial success continue if not path_obj.exists(): @@ -145,51 +177,61 @@ def expand_paths(paths: List[str], extensions: Optional[Set[str]] = None) -> Lis seen.add(str(path_obj)) elif path_obj.is_dir(): - # Walk directory recursively + # Walk directory recursively to find all files for root, dirs, files in os.walk(path_obj): - # Skip hidden directories and __pycache__ + # Filter directories in-place to skip hidden and cache directories + # This prevents descending into .git, .venv, __pycache__, etc. dirs[:] = [ d for d in dirs if not d.startswith(".") and d != "__pycache__" ] for file in files: - # Skip hidden files + # Skip hidden files (e.g., .DS_Store, .gitignore) if file.startswith("."): continue file_path = Path(root) / file - # Check extension + # Filter by extension if specified if not extensions or file_path.suffix.lower() in extensions: full_path = str(file_path) + # Use set to prevent duplicates if full_path not in seen: expanded_files.append(full_path) seen.add(full_path) - # Sort for consistent ordering + # Sort for consistent ordering across different runs + # This makes output predictable and easier to debug expanded_files.sort() return expanded_files def read_file_content(file_path: str, max_size: int = 1_000_000) -> Tuple[str, int]: """ - Read a single file and format it for Gemini. + Read a single file and format it for inclusion in AI prompts. + + This function handles various error conditions gracefully and always + returns formatted content, even for errors. This ensures the AI model + gets context about what files were attempted but couldn't be read. 
Args: file_path: Path to file (must be absolute) - max_size: Maximum file size to read + max_size: Maximum file size to read (default 1MB to prevent memory issues) Returns: - (formatted_content, estimated_tokens) + Tuple of (formatted_content, estimated_tokens) + Content is wrapped with clear delimiters for AI parsing """ try: + # Validate path security before any file operations path = resolve_and_validate_path(file_path) except (ValueError, PermissionError) as e: + # Return error in a format that provides context to the AI content = f"\n--- ERROR ACCESSING FILE: {file_path} ---\nError: {str(e)}\n--- END FILE ---\n" return content, estimate_tokens(content) try: - # Check if path exists and is a file + # Validate file existence and type if not path.exists(): content = f"\n--- FILE NOT FOUND: {file_path} ---\nError: File does not exist\n--- END FILE ---\n" return content, estimate_tokens(content) @@ -198,17 +240,19 @@ def read_file_content(file_path: str, max_size: int = 1_000_000) -> Tuple[str, i content = f"\n--- NOT A FILE: {file_path} ---\nError: Path is not a file\n--- END FILE ---\n" return content, estimate_tokens(content) - # Check file size + # Check file size to prevent memory exhaustion file_size = path.stat().st_size if file_size > max_size: content = f"\n--- FILE TOO LARGE: {file_path} ---\nFile size: {file_size:,} bytes (max: {max_size:,})\n--- END FILE ---\n" return content, estimate_tokens(content) - # Read the file + # Read the file with UTF-8 encoding, replacing invalid characters + # This ensures we can handle files with mixed encodings with open(path, "r", encoding="utf-8", errors="replace") as f: file_content = f.read() - # Format with clear delimiters for Gemini + # Format with clear delimiters that help the AI understand file boundaries + # Using consistent markers makes it easier for the model to parse formatted = f"\n--- BEGIN FILE: {file_path} ---\n{file_content}\n--- END FILE: {file_path} ---\n" return formatted, estimate_tokens(formatted) @@ -226,14 +270,21 @@ def read_files( """ Read multiple files and optional direct code with smart token management. + This function implements intelligent token budgeting to maximize the amount + of relevant content that can be included in an AI prompt while staying + within token limits. It prioritizes direct code and reads files until + the token budget is exhausted. 
+ Args: - file_paths: List of file or directory paths - code: Optional direct code to include + file_paths: List of file or directory paths (absolute paths required) + code: Optional direct code to include (prioritized over files) max_tokens: Maximum tokens to use (defaults to MAX_CONTEXT_TOKENS) - reserve_tokens: Tokens to reserve for prompt and response + reserve_tokens: Tokens to reserve for prompt and response (default 50K) Returns: - (full_content, brief_summary) + Tuple of (full_content, brief_summary) + - full_content: All file contents formatted for AI consumption + - brief_summary: Human-readable summary of what was processed """ if max_tokens is None: max_tokens = MAX_CONTEXT_TOKENS @@ -247,7 +298,8 @@ def read_files( files_skipped = [] dirs_processed = [] - # First, handle direct code if provided + # Priority 1: Handle direct code if provided + # Direct code is prioritized because it's explicitly provided by the user if code: formatted_code = ( f"\n--- BEGIN DIRECT CODE ---\n{code}\n--- END DIRECT CODE ---\n" @@ -258,19 +310,23 @@ def read_files( content_parts.append(formatted_code) total_tokens += code_tokens available_tokens -= code_tokens + # Create a preview for the summary code_preview = code[:50] + "..." if len(code) > 50 else code summary_parts.append(f"Direct code: {code_preview}") else: summary_parts.append("Direct code skipped (too large)") - # Expand all paths to get individual files + # Priority 2: Process file paths if file_paths: - # Track which paths are directories + # Track which paths are directories for summary for path in file_paths: - if Path(path).is_dir(): - dirs_processed.append(path) + try: + if Path(path).is_dir(): + dirs_processed.append(path) + except Exception: + pass # Ignore invalid paths - # Expand to get all files + # Expand directories to get all individual files all_files = expand_paths(file_paths) if not all_files and file_paths: @@ -279,7 +335,7 @@ def read_files( f"\n--- NO FILES FOUND ---\nProvided paths: {', '.join(file_paths)}\n--- END ---\n" ) else: - # Read files up to token limit + # Read files sequentially until token limit is reached for file_path in all_files: if total_tokens >= available_tokens: files_skipped.append(file_path) @@ -293,9 +349,10 @@ def read_files( total_tokens += file_tokens files_read.append(file_path) else: + # File too large for remaining budget files_skipped.append(file_path) - # Build summary + # Build human-readable summary of what was processed if dirs_processed: summary_parts.append(f"Processed {len(dirs_processed)} dir(s)") if files_read: @@ -305,11 +362,12 @@ def read_files( if total_tokens > 0: summary_parts.append(f"~{total_tokens:,} tokens used") - # Add skipped files note if any were skipped + # Add informative note about skipped files to help users understand + # what was omitted and why if files_skipped: skip_note = "\n\n--- SKIPPED FILES (TOKEN LIMIT) ---\n" skip_note += f"Total skipped: {len(files_skipped)}\n" - # Show first 10 skipped files + # Show first 10 skipped files as examples for i, file_path in enumerate(files_skipped[:10]): skip_note += f" - {file_path}\n" if len(files_skipped) > 10: diff --git a/utils/git_utils.py b/utils/git_utils.py index f54aed6..87fead0 100644 --- a/utils/git_utils.py +++ b/utils/git_utils.py @@ -1,5 +1,20 @@ """ Git utilities for finding repositories and generating diffs. + +This module provides Git integration functionality for the MCP server, +enabling tools to work with version control information. 
It handles +repository discovery, status checking, and diff generation. + +Key Features: +- Recursive repository discovery with depth limits +- Safe command execution with timeouts +- Comprehensive status information extraction +- Support for staged and unstaged changes + +Security Considerations: +- All git commands are run with timeouts to prevent hanging +- Repository discovery ignores common build/dependency directories +- Error handling for permission-denied scenarios """ import subprocess @@ -8,16 +23,18 @@ from pathlib import Path # Directories to ignore when searching for git repositories +# These are typically build artifacts, dependencies, or cache directories +# that don't contain source code and would slow down repository discovery IGNORED_DIRS = { - "node_modules", - "__pycache__", - "venv", - "env", - "build", - "dist", - "target", - ".tox", - ".pytest_cache", + "node_modules", # Node.js dependencies + "__pycache__", # Python bytecode cache + "venv", # Python virtual environment + "env", # Alternative virtual environment name + "build", # Common build output directory + "dist", # Distribution/release builds + "target", # Maven/Rust build output + ".tox", # Tox testing environments + ".pytest_cache", # Pytest cache directory } @@ -25,38 +42,45 @@ def find_git_repositories(start_path: str, max_depth: int = 5) -> List[str]: """ Recursively find all git repositories starting from the given path. + This function walks the directory tree looking for .git directories, + which indicate the root of a git repository. It respects depth limits + to prevent excessive recursion in deep directory structures. + Args: - start_path: Directory to start searching from - max_depth: Maximum depth to search (prevents excessive recursion) + start_path: Directory to start searching from (must be absolute) + max_depth: Maximum depth to search (default 5 prevents excessive recursion) Returns: - List of absolute paths to git repositories + List of absolute paths to git repositories, sorted alphabetically """ repositories = [] start_path = Path(start_path).resolve() def _find_repos(current_path: Path, current_depth: int): + # Stop recursion if we've reached maximum depth if current_depth > max_depth: return try: - # Check if current directory is a git repo + # Check if current directory contains a .git directory git_dir = current_path / ".git" if git_dir.exists() and git_dir.is_dir(): repositories.append(str(current_path)) - # Don't search inside .git directory + # Don't search inside git repositories for nested repos + # This prevents finding submodules which should be handled separately return - # Search subdirectories + # Search subdirectories for more repositories for item in current_path.iterdir(): if item.is_dir() and not item.name.startswith("."): - # Skip common non-code directories + # Skip common non-code directories to improve performance if item.name in IGNORED_DIRS: continue _find_repos(item, current_depth + 1) except PermissionError: - # Skip directories we can't access + # Skip directories we don't have permission to read + # This is common for system directories or other users' files pass _find_repos(start_path, 0) @@ -67,16 +91,28 @@ def run_git_command(repo_path: str, command: List[str]) -> Tuple[bool, str]: """ Run a git command in the specified repository. 
+ This function provides a safe way to execute git commands with: + - Timeout protection (30 seconds) to prevent hanging + - Proper error handling and output capture + - Working directory context management + Args: - repo_path: Path to the git repository - command: Git command as a list of arguments + repo_path: Path to the git repository (working directory) + command: Git command as a list of arguments (excluding 'git' itself) Returns: Tuple of (success, output/error) + - success: True if command returned 0, False otherwise + - output/error: stdout if successful, stderr or error message if failed """ try: + # Execute git command with safety measures result = subprocess.run( - ["git"] + command, cwd=repo_path, capture_output=True, text=True, timeout=30 + ["git"] + command, + cwd=repo_path, # Run in repository directory + capture_output=True, # Capture stdout and stderr + text=True, # Return strings instead of bytes + timeout=30, # Prevent hanging on slow operations ) if result.returncode == 0: @@ -85,21 +121,36 @@ def run_git_command(repo_path: str, command: List[str]) -> Tuple[bool, str]: return False, result.stderr except subprocess.TimeoutExpired: - return False, "Command timed out" + return False, "Command timed out after 30 seconds" except Exception as e: - return False, str(e) + return False, f"Git command failed: {str(e)}" def get_git_status(repo_path: str) -> Dict[str, any]: """ - Get the current git status of a repository. + Get comprehensive git status information for a repository. + + This function gathers various pieces of repository state including: + - Current branch name + - Commits ahead/behind upstream + - Lists of staged, unstaged, and untracked files + + The function is resilient to repositories without remotes or + in detached HEAD state. 
Args: repo_path: Path to the git repository Returns: - Dictionary with status information + Dictionary with status information: + - branch: Current branch name (empty if detached) + - ahead: Number of commits ahead of upstream + - behind: Number of commits behind upstream + - staged_files: List of files with staged changes + - unstaged_files: List of files with unstaged changes + - untracked_files: List of untracked files """ + # Initialize status structure with default values status = { "branch": "", "ahead": 0, @@ -109,12 +160,12 @@ def get_git_status(repo_path: str) -> Dict[str, any]: "untracked_files": [], } - # Get current branch + # Get current branch name (empty if in detached HEAD state) success, branch = run_git_command(repo_path, ["branch", "--show-current"]) if success: status["branch"] = branch.strip() - # Get ahead/behind info + # Get ahead/behind information relative to upstream branch if status["branch"]: success, ahead_behind = run_git_command( repo_path, @@ -131,33 +182,38 @@ def get_git_status(repo_path: str) -> Dict[str, any]: if len(parts) == 2: status["behind"] = int(parts[0]) status["ahead"] = int(parts[1]) - # else: Could not get ahead/behind status (branch may not have upstream) + # Note: This will fail gracefully if branch has no upstream set - # Get file status + # Get file status using porcelain format for machine parsing + # Format: XY filename where X=staged status, Y=unstaged status success, status_output = run_git_command(repo_path, ["status", "--porcelain"]) if success: for line in status_output.strip().split("\n"): if not line: continue - status_code = line[:2] - path_info = line[3:] + status_code = line[:2] # Two-character status code + path_info = line[3:] # Filename (after space) - # Handle staged changes + # Parse staged changes (first character of status code) if status_code[0] == "R": - # Format is "old_path -> new_path" for renamed files + # Special handling for renamed files + # Format is "old_path -> new_path" if " -> " in path_info: _, new_path = path_info.split(" -> ", 1) status["staged_files"].append(new_path) else: status["staged_files"].append(path_info) elif status_code[0] in ["M", "A", "D", "C"]: + # M=modified, A=added, D=deleted, C=copied status["staged_files"].append(path_info) - # Handle unstaged changes + # Parse unstaged changes (second character of status code) if status_code[1] in ["M", "D"]: + # M=modified, D=deleted in working tree status["unstaged_files"].append(path_info) elif status_code == "??": + # Untracked files have special marker "??" status["untracked_files"].append(path_info) return status diff --git a/utils/token_utils.py b/utils/token_utils.py index 7cb2b7a..3764185 100644 --- a/utils/token_utils.py +++ b/utils/token_utils.py @@ -1,5 +1,12 @@ """ -Token counting utilities +Token counting utilities for managing API context limits + +This module provides functions for estimating token counts to ensure +requests stay within the Gemini API's context window limits. + +Note: The estimation uses a simple character-to-token ratio which is +approximate. For production systems requiring precise token counts, +consider using the actual tokenizer for the specific model. """ from typing import Tuple @@ -8,14 +15,40 @@ from config import MAX_CONTEXT_TOKENS def estimate_tokens(text: str) -> int: - """Estimate token count (rough: 1 token ≈ 4 characters)""" + """ + Estimate token count using a character-based approximation. 
+ + This uses a rough heuristic where 1 token ≈ 4 characters, which is + a reasonable approximation for English text. The actual token count + may vary based on: + - Language (non-English text may have different ratios) + - Code vs prose (code often has more tokens per character) + - Special characters and formatting + + Args: + text: The text to estimate tokens for + + Returns: + int: Estimated number of tokens + """ return len(text) // 4 def check_token_limit(text: str) -> Tuple[bool, int]: """ - Check if text exceeds token limit. - Returns: (is_within_limit, estimated_tokens) + Check if text exceeds the maximum token limit for Gemini models. + + This function is used to validate that prepared prompts will fit + within the model's context window, preventing API errors and ensuring + reliable operation. + + Args: + text: The text to check + + Returns: + Tuple[bool, int]: (is_within_limit, estimated_tokens) + - is_within_limit: True if the text fits within MAX_CONTEXT_TOKENS + - estimated_tokens: The estimated token count """ estimated = estimate_tokens(text) return estimated <= MAX_CONTEXT_TOKENS, estimated
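Not part of the patch above — a minimal sketch of how the pieces this diff touches compose when a tool prepares a prompt, following the pattern shown in `ReviewCodeTool.prepare_prompt`. The `build_prompt` helper and the `=== FILES ===` delimiter are hypothetical illustrations; `read_files`, `check_token_limit`, and `MAX_CONTEXT_TOKENS` are used with the signatures shown in the diff.

```python
# Sketch only (assumes this repo's config and utils packages are importable).
from typing import List

from config import MAX_CONTEXT_TOKENS
from utils.file_utils import read_files
from utils.token_utils import check_token_limit


def build_prompt(system_prompt: str, files: List[str]) -> str:
    """Hypothetical helper mirroring how a tool might budget its prompt."""
    # Expand files/directories and format their contents with BEGIN/END markers
    file_content, summary = read_files(files)

    # Validate the combined content against the model's context window
    within_limit, estimated_tokens = check_token_limit(file_content)
    if not within_limit:
        raise ValueError(
            f"Content too large: ~{estimated_tokens:,} tokens. "
            f"Maximum is {MAX_CONTEXT_TOKENS:,} tokens."
        )

    # Delimiter text is illustrative; each tool defines its own prompt layout
    return f"{system_prompt}\n\n=== FILES ({summary}) ===\n{file_content}\n=== END FILES ==="
```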