feat: Add DISABLED_TOOLS environment variable for selective tool disabling (#127)
## Description This PR adds support for selectively disabling tools via the DISABLED_TOOLS environment variable, allowing users to customize which MCP tools are available in their Zen server instance. This feature enables better control over tool availability for security, performance, or organizational requirements. ## Changes Made - [x] Added `DISABLED_TOOLS` environment variable support to selectively disable tools - [x] Implemented tool filtering logic with protection for essential tools (version, listmodels) - [x] Added comprehensive validation with warnings for unknown tools and attempts to disable essential tools - [x] Updated `.env.example` with DISABLED_TOOLS documentation and examples - [x] Added comprehensive test suite (16 tests) covering all edge cases - [x] No breaking changes - feature is opt-in with default behavior unchanged ## Configuration Add to `.env` file: ```bash # Optional: Tool Selection # Comma-separated list of tools to disable. If not set, all tools are enabled. # Essential tools (version, listmodels) cannot be disabled. # Available tools: chat, thinkdeep, planner, consensus, codereview, precommit, # debug, docgen, analyze, refactor, tracer, testgen # Examples: # DISABLED_TOOLS= # All tools enabled (default) # DISABLED_TOOLS=debug,tracer # Disable debug and tracer tools # DISABLED_TOOLS=planner,consensus # Disable planning tools
This commit is contained in:
committed by
GitHub
parent
3b250c95df
commit
a355b80afc
10
.env.example
10
.env.example
@@ -143,3 +143,13 @@ MAX_CONVERSATION_TURNS=20
|
|||||||
# ERROR: Shows only errors
|
# ERROR: Shows only errors
|
||||||
LOG_LEVEL=DEBUG
|
LOG_LEVEL=DEBUG
|
||||||
|
|
||||||
|
# Optional: Tool Selection
|
||||||
|
# Comma-separated list of tools to disable. If not set, all tools are enabled.
|
||||||
|
# Essential tools (version, listmodels) cannot be disabled.
|
||||||
|
# Available tools: chat, thinkdeep, planner, consensus, codereview, precommit,
|
||||||
|
# debug, docgen, analyze, refactor, tracer, testgen
|
||||||
|
# Examples:
|
||||||
|
# DISABLED_TOOLS= # All tools enabled (default)
|
||||||
|
# DISABLED_TOOLS=debug,tracer # Disable debug and tracer tools
|
||||||
|
# DISABLED_TOOLS=planner,consensus # Disable planning tools
|
||||||
|
|
||||||
|
|||||||
92
server.py
92
server.py
@@ -158,6 +158,97 @@ logger = logging.getLogger(__name__)
|
|||||||
# This name is used by MCP clients to identify and connect to this specific server
|
# This name is used by MCP clients to identify and connect to this specific server
|
||||||
server: Server = Server("zen-server")
|
server: Server = Server("zen-server")
|
||||||
|
|
||||||
|
|
||||||
|
# Constants for tool filtering
|
||||||
|
ESSENTIAL_TOOLS = {"version", "listmodels"}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_disabled_tools_env() -> set[str]:
|
||||||
|
"""
|
||||||
|
Parse the DISABLED_TOOLS environment variable into a set of tool names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of lowercase tool names to disable, empty set if none specified
|
||||||
|
"""
|
||||||
|
disabled_tools_env = os.getenv("DISABLED_TOOLS", "").strip()
|
||||||
|
if not disabled_tools_env:
|
||||||
|
return set()
|
||||||
|
return {t.strip().lower() for t in disabled_tools_env.split(",") if t.strip()}
|
||||||
|
|
||||||
|
|
||||||
|
def validate_disabled_tools(disabled_tools: set[str], all_tools: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Validate the disabled tools list and log appropriate warnings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
disabled_tools: Set of tool names requested to be disabled
|
||||||
|
all_tools: Dictionary of all available tool instances
|
||||||
|
"""
|
||||||
|
essential_disabled = disabled_tools & ESSENTIAL_TOOLS
|
||||||
|
if essential_disabled:
|
||||||
|
logger.warning(f"Cannot disable essential tools: {sorted(essential_disabled)}")
|
||||||
|
unknown_tools = disabled_tools - set(all_tools.keys())
|
||||||
|
if unknown_tools:
|
||||||
|
logger.warning(f"Unknown tools in DISABLED_TOOLS: {sorted(unknown_tools)}")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_tool_filter(all_tools: dict[str, Any], disabled_tools: set[str]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Apply the disabled tools filter to create the final tools dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_tools: Dictionary of all available tool instances
|
||||||
|
disabled_tools: Set of tool names to disable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing only enabled tools
|
||||||
|
"""
|
||||||
|
enabled_tools = {}
|
||||||
|
for tool_name, tool_instance in all_tools.items():
|
||||||
|
if tool_name in ESSENTIAL_TOOLS or tool_name not in disabled_tools:
|
||||||
|
enabled_tools[tool_name] = tool_instance
|
||||||
|
else:
|
||||||
|
logger.debug(f"Tool '{tool_name}' disabled via DISABLED_TOOLS")
|
||||||
|
return enabled_tools
|
||||||
|
|
||||||
|
|
||||||
|
def log_tool_configuration(disabled_tools: set[str], enabled_tools: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Log the final tool configuration for visibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
disabled_tools: Set of tool names that were requested to be disabled
|
||||||
|
enabled_tools: Dictionary of tools that remain enabled
|
||||||
|
"""
|
||||||
|
if not disabled_tools:
|
||||||
|
logger.info("All tools enabled (DISABLED_TOOLS not set)")
|
||||||
|
return
|
||||||
|
actual_disabled = disabled_tools - ESSENTIAL_TOOLS
|
||||||
|
if actual_disabled:
|
||||||
|
logger.debug(f"Disabled tools: {sorted(actual_disabled)}")
|
||||||
|
logger.info(f"Active tools: {sorted(enabled_tools.keys())}")
|
||||||
|
|
||||||
|
|
||||||
|
def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Filter tools based on DISABLED_TOOLS environment variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_tools: Dictionary of all available tool instances
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Filtered dictionary containing only enabled tools
|
||||||
|
"""
|
||||||
|
disabled_tools = parse_disabled_tools_env()
|
||||||
|
if not disabled_tools:
|
||||||
|
log_tool_configuration(disabled_tools, all_tools)
|
||||||
|
return all_tools
|
||||||
|
validate_disabled_tools(disabled_tools, all_tools)
|
||||||
|
enabled_tools = apply_tool_filter(all_tools, disabled_tools)
|
||||||
|
log_tool_configuration(disabled_tools, enabled_tools)
|
||||||
|
return enabled_tools
|
||||||
|
|
||||||
|
|
||||||
# Initialize the tool registry with all available AI-powered tools
|
# Initialize the tool registry with all available AI-powered tools
|
||||||
# Each tool provides specialized functionality for different development tasks
|
# Each tool provides specialized functionality for different development tasks
|
||||||
# Tools are instantiated once and reused across requests (stateless design)
|
# Tools are instantiated once and reused across requests (stateless design)
|
||||||
@@ -178,6 +269,7 @@ TOOLS = {
|
|||||||
"listmodels": ListModelsTool(), # List all available AI models by provider
|
"listmodels": ListModelsTool(), # List all available AI models by provider
|
||||||
"version": VersionTool(), # Display server version and system information
|
"version": VersionTool(), # Display server version and system information
|
||||||
}
|
}
|
||||||
|
TOOLS = filter_disabled_tools(TOOLS)
|
||||||
|
|
||||||
# Rich prompt templates for all tools
|
# Rich prompt templates for all tools
|
||||||
PROMPT_TEMPLATES = {
|
PROMPT_TEMPLATES = {
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class TestAutoModeCustomProviderOnly:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Clear all other provider keys
|
# Clear all other provider keys
|
||||||
clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]
|
clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]
|
||||||
|
|
||||||
with patch.dict(os.environ, test_env, clear=False):
|
with patch.dict(os.environ, test_env, clear=False):
|
||||||
# Ensure other provider keys are not set
|
# Ensure other provider keys are not set
|
||||||
@@ -109,7 +109,7 @@ class TestAutoModeCustomProviderOnly:
|
|||||||
|
|
||||||
with patch.dict(os.environ, test_env, clear=False):
|
with patch.dict(os.environ, test_env, clear=False):
|
||||||
# Clear other provider keys
|
# Clear other provider keys
|
||||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
|
||||||
if key in os.environ:
|
if key in os.environ:
|
||||||
del os.environ[key]
|
del os.environ[key]
|
||||||
|
|
||||||
@@ -177,7 +177,7 @@ class TestAutoModeCustomProviderOnly:
|
|||||||
|
|
||||||
with patch.dict(os.environ, test_env, clear=False):
|
with patch.dict(os.environ, test_env, clear=False):
|
||||||
# Clear other provider keys
|
# Clear other provider keys
|
||||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
|
||||||
if key in os.environ:
|
if key in os.environ:
|
||||||
del os.environ[key]
|
del os.environ[key]
|
||||||
|
|
||||||
|
|||||||
140
tests/test_disabled_tools.py
Normal file
140
tests/test_disabled_tools.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""Tests for DISABLED_TOOLS environment variable functionality."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from server import (
|
||||||
|
apply_tool_filter,
|
||||||
|
parse_disabled_tools_env,
|
||||||
|
validate_disabled_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Mock the tool classes since we're testing the filtering logic
|
||||||
|
class MockTool:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisabledTools:
|
||||||
|
"""Test suite for DISABLED_TOOLS functionality."""
|
||||||
|
|
||||||
|
def test_parse_disabled_tools_empty(self):
|
||||||
|
"""Empty string returns empty set (no tools disabled)."""
|
||||||
|
with patch.dict(os.environ, {"DISABLED_TOOLS": ""}):
|
||||||
|
assert parse_disabled_tools_env() == set()
|
||||||
|
|
||||||
|
def test_parse_disabled_tools_not_set(self):
|
||||||
|
"""Unset variable returns empty set."""
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
# Ensure DISABLED_TOOLS is not in environment
|
||||||
|
if "DISABLED_TOOLS" in os.environ:
|
||||||
|
del os.environ["DISABLED_TOOLS"]
|
||||||
|
assert parse_disabled_tools_env() == set()
|
||||||
|
|
||||||
|
def test_parse_disabled_tools_single(self):
|
||||||
|
"""Single tool name parsed correctly."""
|
||||||
|
with patch.dict(os.environ, {"DISABLED_TOOLS": "debug"}):
|
||||||
|
assert parse_disabled_tools_env() == {"debug"}
|
||||||
|
|
||||||
|
def test_parse_disabled_tools_multiple(self):
|
||||||
|
"""Multiple tools with spaces parsed correctly."""
|
||||||
|
with patch.dict(os.environ, {"DISABLED_TOOLS": "debug, analyze, refactor"}):
|
||||||
|
assert parse_disabled_tools_env() == {"debug", "analyze", "refactor"}
|
||||||
|
|
||||||
|
def test_parse_disabled_tools_extra_spaces(self):
|
||||||
|
"""Extra spaces and empty items handled correctly."""
|
||||||
|
with patch.dict(os.environ, {"DISABLED_TOOLS": " debug , , analyze , "}):
|
||||||
|
assert parse_disabled_tools_env() == {"debug", "analyze"}
|
||||||
|
|
||||||
|
def test_parse_disabled_tools_duplicates(self):
|
||||||
|
"""Duplicate entries handled correctly (set removes duplicates)."""
|
||||||
|
with patch.dict(os.environ, {"DISABLED_TOOLS": "debug,analyze,debug"}):
|
||||||
|
assert parse_disabled_tools_env() == {"debug", "analyze"}
|
||||||
|
|
||||||
|
def test_tool_filtering_logic(self):
|
||||||
|
"""Test the complete filtering logic using the actual server functions."""
|
||||||
|
# Simulate ALL_TOOLS
|
||||||
|
ALL_TOOLS = {
|
||||||
|
"chat": MockTool("chat"),
|
||||||
|
"debug": MockTool("debug"),
|
||||||
|
"analyze": MockTool("analyze"),
|
||||||
|
"version": MockTool("version"),
|
||||||
|
"listmodels": MockTool("listmodels"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test case 1: No tools disabled
|
||||||
|
disabled_tools = set()
|
||||||
|
enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
|
||||||
|
|
||||||
|
assert len(enabled_tools) == 5 # All tools included
|
||||||
|
assert set(enabled_tools.keys()) == set(ALL_TOOLS.keys())
|
||||||
|
|
||||||
|
# Test case 2: Disable some regular tools
|
||||||
|
disabled_tools = {"debug", "analyze"}
|
||||||
|
enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
|
||||||
|
|
||||||
|
assert len(enabled_tools) == 3 # chat, version, listmodels
|
||||||
|
assert "debug" not in enabled_tools
|
||||||
|
assert "analyze" not in enabled_tools
|
||||||
|
assert "chat" in enabled_tools
|
||||||
|
assert "version" in enabled_tools
|
||||||
|
assert "listmodels" in enabled_tools
|
||||||
|
|
||||||
|
# Test case 3: Attempt to disable essential tools
|
||||||
|
disabled_tools = {"version", "chat"}
|
||||||
|
enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
|
||||||
|
|
||||||
|
assert "version" in enabled_tools # Essential tool not disabled
|
||||||
|
assert "chat" not in enabled_tools # Regular tool disabled
|
||||||
|
assert "listmodels" in enabled_tools # Essential tool included
|
||||||
|
|
||||||
|
def test_unknown_tools_warning(self, caplog):
|
||||||
|
"""Test that unknown tool names generate appropriate warnings."""
|
||||||
|
ALL_TOOLS = {
|
||||||
|
"chat": MockTool("chat"),
|
||||||
|
"debug": MockTool("debug"),
|
||||||
|
"analyze": MockTool("analyze"),
|
||||||
|
"version": MockTool("version"),
|
||||||
|
"listmodels": MockTool("listmodels"),
|
||||||
|
}
|
||||||
|
disabled_tools = {"chat", "unknown_tool", "another_unknown"}
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
validate_disabled_tools(disabled_tools, ALL_TOOLS)
|
||||||
|
assert "Unknown tools in DISABLED_TOOLS: ['another_unknown', 'unknown_tool']" in caplog.text
|
||||||
|
|
||||||
|
def test_essential_tools_warning(self, caplog):
|
||||||
|
"""Test warning when trying to disable essential tools."""
|
||||||
|
ALL_TOOLS = {
|
||||||
|
"chat": MockTool("chat"),
|
||||||
|
"debug": MockTool("debug"),
|
||||||
|
"analyze": MockTool("analyze"),
|
||||||
|
"version": MockTool("version"),
|
||||||
|
"listmodels": MockTool("listmodels"),
|
||||||
|
}
|
||||||
|
disabled_tools = {"version", "chat", "debug"}
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
validate_disabled_tools(disabled_tools, ALL_TOOLS)
|
||||||
|
assert "Cannot disable essential tools: ['version']" in caplog.text
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env_value,expected",
|
||||||
|
[
|
||||||
|
("", set()), # Empty string
|
||||||
|
(" ", set()), # Only spaces
|
||||||
|
(",,,", set()), # Only commas
|
||||||
|
("chat", {"chat"}), # Single tool
|
||||||
|
("chat,debug", {"chat", "debug"}), # Multiple tools
|
||||||
|
("chat, debug, analyze", {"chat", "debug", "analyze"}), # With spaces
|
||||||
|
("chat,debug,chat", {"chat", "debug"}), # Duplicates
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_disabled_tools_parametrized(self, env_value, expected):
|
||||||
|
"""Parametrized tests for various input formats."""
|
||||||
|
with patch.dict(os.environ, {"DISABLED_TOOLS": env_value}):
|
||||||
|
assert parse_disabled_tools_env() == expected
|
||||||
Reference in New Issue
Block a user