diff --git a/.env.example b/.env.example index 1d88d4c..b88bd70 100644 --- a/.env.example +++ b/.env.example @@ -143,3 +143,13 @@ MAX_CONVERSATION_TURNS=20 # ERROR: Shows only errors LOG_LEVEL=DEBUG +# Optional: Tool Selection +# Comma-separated list of tools to disable. If not set, all tools are enabled. +# Essential tools (version, listmodels) cannot be disabled. +# Available tools: chat, thinkdeep, planner, consensus, codereview, precommit, +# debug, docgen, analyze, refactor, tracer, testgen +# Examples: +# DISABLED_TOOLS= # All tools enabled (default) +# DISABLED_TOOLS=debug,tracer # Disable debug and tracer tools +# DISABLED_TOOLS=planner,consensus # Disable planning tools + diff --git a/server.py b/server.py index 9247aa6..ebb5ce2 100644 --- a/server.py +++ b/server.py @@ -158,6 +158,97 @@ logger = logging.getLogger(__name__) # This name is used by MCP clients to identify and connect to this specific server server: Server = Server("zen-server") + +# Constants for tool filtering +ESSENTIAL_TOOLS = {"version", "listmodels"} + + +def parse_disabled_tools_env() -> set[str]: + """ + Parse the DISABLED_TOOLS environment variable into a set of tool names. + + Returns: + Set of lowercase tool names to disable, empty set if none specified + """ + disabled_tools_env = os.getenv("DISABLED_TOOLS", "").strip() + if not disabled_tools_env: + return set() + return {t.strip().lower() for t in disabled_tools_env.split(",") if t.strip()} + + +def validate_disabled_tools(disabled_tools: set[str], all_tools: dict[str, Any]) -> None: + """ + Validate the disabled tools list and log appropriate warnings. + + Args: + disabled_tools: Set of tool names requested to be disabled + all_tools: Dictionary of all available tool instances + """ + essential_disabled = disabled_tools & ESSENTIAL_TOOLS + if essential_disabled: + logger.warning(f"Cannot disable essential tools: {sorted(essential_disabled)}") + unknown_tools = disabled_tools - set(all_tools.keys()) + if unknown_tools: + logger.warning(f"Unknown tools in DISABLED_TOOLS: {sorted(unknown_tools)}") + + +def apply_tool_filter(all_tools: dict[str, Any], disabled_tools: set[str]) -> dict[str, Any]: + """ + Apply the disabled tools filter to create the final tools dictionary. + + Args: + all_tools: Dictionary of all available tool instances + disabled_tools: Set of tool names to disable + + Returns: + Dictionary containing only enabled tools + """ + enabled_tools = {} + for tool_name, tool_instance in all_tools.items(): + if tool_name in ESSENTIAL_TOOLS or tool_name not in disabled_tools: + enabled_tools[tool_name] = tool_instance + else: + logger.debug(f"Tool '{tool_name}' disabled via DISABLED_TOOLS") + return enabled_tools + + +def log_tool_configuration(disabled_tools: set[str], enabled_tools: dict[str, Any]) -> None: + """ + Log the final tool configuration for visibility. + + Args: + disabled_tools: Set of tool names that were requested to be disabled + enabled_tools: Dictionary of tools that remain enabled + """ + if not disabled_tools: + logger.info("All tools enabled (DISABLED_TOOLS not set)") + return + actual_disabled = disabled_tools - ESSENTIAL_TOOLS + if actual_disabled: + logger.debug(f"Disabled tools: {sorted(actual_disabled)}") + logger.info(f"Active tools: {sorted(enabled_tools.keys())}") + + +def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]: + """ + Filter tools based on DISABLED_TOOLS environment variable. + + Args: + all_tools: Dictionary of all available tool instances + + Returns: + dict: Filtered dictionary containing only enabled tools + """ + disabled_tools = parse_disabled_tools_env() + if not disabled_tools: + log_tool_configuration(disabled_tools, all_tools) + return all_tools + validate_disabled_tools(disabled_tools, all_tools) + enabled_tools = apply_tool_filter(all_tools, disabled_tools) + log_tool_configuration(disabled_tools, enabled_tools) + return enabled_tools + + # Initialize the tool registry with all available AI-powered tools # Each tool provides specialized functionality for different development tasks # Tools are instantiated once and reused across requests (stateless design) @@ -178,6 +269,7 @@ TOOLS = { "listmodels": ListModelsTool(), # List all available AI models by provider "version": VersionTool(), # Display server version and system information } +TOOLS = filter_disabled_tools(TOOLS) # Rich prompt templates for all tools PROMPT_TEMPLATES = { diff --git a/tests/test_auto_mode_custom_provider_only.py b/tests/test_auto_mode_custom_provider_only.py index 5d03d4e..c97e649 100644 --- a/tests/test_auto_mode_custom_provider_only.py +++ b/tests/test_auto_mode_custom_provider_only.py @@ -70,7 +70,7 @@ class TestAutoModeCustomProviderOnly: } # Clear all other provider keys - clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"] + clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"] with patch.dict(os.environ, test_env, clear=False): # Ensure other provider keys are not set @@ -109,7 +109,7 @@ class TestAutoModeCustomProviderOnly: with patch.dict(os.environ, test_env, clear=False): # Clear other provider keys - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]: if key in os.environ: del os.environ[key] @@ -177,7 +177,7 @@ class TestAutoModeCustomProviderOnly: with patch.dict(os.environ, test_env, clear=False): # Clear other provider keys - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: + for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]: if key in os.environ: del os.environ[key] diff --git a/tests/test_disabled_tools.py b/tests/test_disabled_tools.py new file mode 100644 index 0000000..65a525f --- /dev/null +++ b/tests/test_disabled_tools.py @@ -0,0 +1,140 @@ +"""Tests for DISABLED_TOOLS environment variable functionality.""" + +import logging +import os +from unittest.mock import patch + +import pytest + +from server import ( + apply_tool_filter, + parse_disabled_tools_env, + validate_disabled_tools, +) + + +# Mock the tool classes since we're testing the filtering logic +class MockTool: + def __init__(self, name): + self.name = name + + +class TestDisabledTools: + """Test suite for DISABLED_TOOLS functionality.""" + + def test_parse_disabled_tools_empty(self): + """Empty string returns empty set (no tools disabled).""" + with patch.dict(os.environ, {"DISABLED_TOOLS": ""}): + assert parse_disabled_tools_env() == set() + + def test_parse_disabled_tools_not_set(self): + """Unset variable returns empty set.""" + with patch.dict(os.environ, {}, clear=True): + # Ensure DISABLED_TOOLS is not in environment + if "DISABLED_TOOLS" in os.environ: + del os.environ["DISABLED_TOOLS"] + assert parse_disabled_tools_env() == set() + + def test_parse_disabled_tools_single(self): + """Single tool name parsed correctly.""" + with patch.dict(os.environ, {"DISABLED_TOOLS": "debug"}): + assert parse_disabled_tools_env() == {"debug"} + + def test_parse_disabled_tools_multiple(self): + """Multiple tools with spaces parsed correctly.""" + with patch.dict(os.environ, {"DISABLED_TOOLS": "debug, analyze, refactor"}): + assert parse_disabled_tools_env() == {"debug", "analyze", "refactor"} + + def test_parse_disabled_tools_extra_spaces(self): + """Extra spaces and empty items handled correctly.""" + with patch.dict(os.environ, {"DISABLED_TOOLS": " debug , , analyze , "}): + assert parse_disabled_tools_env() == {"debug", "analyze"} + + def test_parse_disabled_tools_duplicates(self): + """Duplicate entries handled correctly (set removes duplicates).""" + with patch.dict(os.environ, {"DISABLED_TOOLS": "debug,analyze,debug"}): + assert parse_disabled_tools_env() == {"debug", "analyze"} + + def test_tool_filtering_logic(self): + """Test the complete filtering logic using the actual server functions.""" + # Simulate ALL_TOOLS + ALL_TOOLS = { + "chat": MockTool("chat"), + "debug": MockTool("debug"), + "analyze": MockTool("analyze"), + "version": MockTool("version"), + "listmodels": MockTool("listmodels"), + } + + # Test case 1: No tools disabled + disabled_tools = set() + enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools) + + assert len(enabled_tools) == 5 # All tools included + assert set(enabled_tools.keys()) == set(ALL_TOOLS.keys()) + + # Test case 2: Disable some regular tools + disabled_tools = {"debug", "analyze"} + enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools) + + assert len(enabled_tools) == 3 # chat, version, listmodels + assert "debug" not in enabled_tools + assert "analyze" not in enabled_tools + assert "chat" in enabled_tools + assert "version" in enabled_tools + assert "listmodels" in enabled_tools + + # Test case 3: Attempt to disable essential tools + disabled_tools = {"version", "chat"} + enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools) + + assert "version" in enabled_tools # Essential tool not disabled + assert "chat" not in enabled_tools # Regular tool disabled + assert "listmodels" in enabled_tools # Essential tool included + + def test_unknown_tools_warning(self, caplog): + """Test that unknown tool names generate appropriate warnings.""" + ALL_TOOLS = { + "chat": MockTool("chat"), + "debug": MockTool("debug"), + "analyze": MockTool("analyze"), + "version": MockTool("version"), + "listmodels": MockTool("listmodels"), + } + disabled_tools = {"chat", "unknown_tool", "another_unknown"} + + with caplog.at_level(logging.WARNING): + validate_disabled_tools(disabled_tools, ALL_TOOLS) + assert "Unknown tools in DISABLED_TOOLS: ['another_unknown', 'unknown_tool']" in caplog.text + + def test_essential_tools_warning(self, caplog): + """Test warning when trying to disable essential tools.""" + ALL_TOOLS = { + "chat": MockTool("chat"), + "debug": MockTool("debug"), + "analyze": MockTool("analyze"), + "version": MockTool("version"), + "listmodels": MockTool("listmodels"), + } + disabled_tools = {"version", "chat", "debug"} + + with caplog.at_level(logging.WARNING): + validate_disabled_tools(disabled_tools, ALL_TOOLS) + assert "Cannot disable essential tools: ['version']" in caplog.text + + @pytest.mark.parametrize( + "env_value,expected", + [ + ("", set()), # Empty string + (" ", set()), # Only spaces + (",,,", set()), # Only commas + ("chat", {"chat"}), # Single tool + ("chat,debug", {"chat", "debug"}), # Multiple tools + ("chat, debug, analyze", {"chat", "debug", "analyze"}), # With spaces + ("chat,debug,chat", {"chat", "debug"}), # Duplicates + ], + ) + def test_parse_disabled_tools_parametrized(self, env_value, expected): + """Parametrized tests for various input formats.""" + with patch.dict(os.environ, {"DISABLED_TOOLS": env_value}): + assert parse_disabled_tools_env() == expected