Categorize tools into model capability categories to help determine which type of model to pick in auto mode.

- Encourage Claude to pick the best model for the job automatically when in auto mode.
- Lots of new tests to ensure automatic model picking works reliably, based on user preference or when a matching model is not found or is ambiguous.
- Improved error reporting when a bogus model is requested that is not configured or available.
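At a glance, the flow this commit introduces (a condensed, illustrative sketch of the pieces in the diff below, not the full implementation):

    from enum import Enum

    # Each tool declares what kind of model it needs (tools/models.py):
    class ToolModelCategory(Enum):
        EXTENDED_REASONING = "extended_reasoning"  # thinkdeep, debug, analyze, precommit
        FAST_RESPONSE = "fast_response"            # chat
        BALANCED = "balanced"                      # codereview, and the BaseTool default

    # In auto mode the registry maps that category to a concrete model
    # (providers/registry.py): EXTENDED_REASONING -> "o3" or "pro",
    # FAST_RESPONSE -> "o3-mini" or "flash", BALANCED -> the existing fallback logic.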
.gitignore (vendored)
@@ -106,6 +106,7 @@ celerybeat.pid
 
 # Environments
 .env
+.env~
 .venv
 env/
 venv/
config.py
@@ -14,7 +14,7 @@ import os
 # These values are used in server responses and for tracking releases
 # IMPORTANT: This is the single source of truth for version and author info
 # Semantic versioning: MAJOR.MINOR.PATCH
-__version__ = "4.2.2"
+__version__ = "4.3.0"
 # Last update date in ISO format
 __updated__ = "2025-06-14"
 # Primary maintainer
providers/registry.py
@@ -2,10 +2,13 @@
 
 import logging
 import os
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from .base import ModelProvider, ProviderType
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 
 class ModelProviderRegistry:
     """Registry for managing model providers."""
@@ -189,27 +192,50 @@ class ModelProviderRegistry:
         return os.getenv(env_var)
 
     @classmethod
-    def get_preferred_fallback_model(cls) -> str:
-        """Get the preferred fallback model based on available API keys.
+    def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
+        """Get the preferred fallback model based on available API keys and tool category.
 
         This method checks which providers have valid API keys and returns
         a sensible default model for auto mode fallback situations.
 
         Priority order:
         1. OpenAI o3-mini (balanced performance/cost) if OpenAI API key available
         2. Gemini 2.0 Flash (fast and efficient) if Gemini API key available
         3. OpenAI o3 (high performance) if OpenAI API key available
         4. Gemini 2.5 Pro (deep reasoning) if Gemini API key available
         5. Fallback to gemini-2.5-flash-preview-05-20 (most common case)
+        Args:
+            tool_category: Optional category to influence model selection
+
         Returns:
             Model name string for fallback use
         """
+        # Import here to avoid circular import
+        from tools.models import ToolModelCategory
+
         # Check provider availability by trying to get instances
         openai_available = cls.get_provider(ProviderType.OPENAI) is not None
         gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
 
         # Priority order: prefer balanced models first, then high-performance
+        if tool_category == ToolModelCategory.EXTENDED_REASONING:
+            # Prefer thinking-capable models for deep reasoning tools
+            if openai_available:
+                return "o3"  # O3 for deep reasoning
+            elif gemini_available:
+                return "pro"  # Gemini Pro with thinking mode
+            else:
+                # Try to find thinking-capable model from custom/openrouter
+                thinking_model = cls._find_extended_thinking_model()
+                if thinking_model:
+                    return thinking_model
+                # Fallback to pro if nothing found
+                return "gemini-2.5-pro-preview-06-05"
+
+        elif tool_category == ToolModelCategory.FAST_RESPONSE:
+            # Prefer fast, cost-efficient models
+            if openai_available:
+                return "o3-mini"  # Fast and efficient
+            elif gemini_available:
+                return "flash"  # Gemini Flash for speed
+            else:
+                # Default to flash
+                return "gemini-2.5-flash-preview-05-20"
+
+        # BALANCED or no category specified - use existing balanced logic
         if openai_available:
             return "o3-mini"  # Balanced performance/cost
         elif gemini_available:
@@ -219,6 +245,51 @@ class ModelProviderRegistry:
         # This maintains backward compatibility for tests
         return "gemini-2.5-flash-preview-05-20"
 
+    @classmethod
+    def _find_extended_thinking_model(cls) -> Optional[str]:
+        """Find a model suitable for extended reasoning from custom/openrouter providers.
+
+        Returns:
+            Model name if found, None otherwise
+        """
+        # Check custom provider first
+        custom_provider = cls.get_provider(ProviderType.CUSTOM)
+        if custom_provider:
+            # Check if it's a CustomProvider and has thinking models
+            try:
+                from providers.custom import CustomProvider
+
+                if isinstance(custom_provider, CustomProvider) and hasattr(custom_provider, "model_registry"):
+                    for model_name, config in custom_provider.model_registry.items():
+                        if config.get("supports_extended_thinking", False):
+                            return model_name
+            except ImportError:
+                pass
+
+        # Then check OpenRouter for high-context/powerful models
+        openrouter_provider = cls.get_provider(ProviderType.OPENROUTER)
+        if openrouter_provider:
+            # Prefer models known for deep reasoning
+            preferred_models = [
+                "anthropic/claude-3.5-sonnet",
+                "anthropic/claude-3-opus-20240229",
+                "meta-llama/llama-3.1-70b-instruct",
+                "google/gemini-pro-1.5",
+                "mistralai/mixtral-8x7b-instruct",
+            ]
+            for model in preferred_models:
+                try:
+                    if openrouter_provider.validate_model_name(model):
+                        return model
+                except Exception as e:
+                    # Log the error for debugging purposes but continue searching
+                    import logging
+
+                    logging.warning(f"Model validation for '{model}' on OpenRouter failed: {e}")
+                    continue
+
+        return None
+
     @classmethod
     def get_available_providers_with_keys(cls) -> list[ProviderType]:
         """Get list of provider types that have valid API keys.
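For illustration, a caller exercises the new category-aware fallback like this (a sketch mirroring the tests below; the concrete model returned depends on which API keys are configured):

    from providers.registry import ModelProviderRegistry
    from tools.models import ToolModelCategory

    # Deep-reasoning tools prefer "o3" (OpenAI) or "pro" (Gemini)
    ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)

    # Chat-style tools prefer fast, cost-efficient models: "o3-mini" or "flash"
    ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

    # No category keeps the existing balanced behaviour
    ModelProviderRegistry.get_preferred_fallback_model()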
tests/test_auto_mode.py
@@ -75,7 +75,7 @@ class TestAutoMode:
             model_schema = schema["properties"]["model"]
             assert "enum" in model_schema
             assert "flash" in model_schema["enum"]
-            assert "Choose the best model" in model_schema["description"]
+            assert "select the most suitable model" in model_schema["description"]
 
         finally:
             # Restore
@@ -134,6 +134,58 @@ class TestAutoMode:
                 os.environ.pop("DEFAULT_MODEL", None)
             importlib.reload(config)
 
+    @pytest.mark.asyncio
+    async def test_unavailable_model_error_message(self):
+        """Test that unavailable model shows helpful error with available models"""
+        # Save original
+        original = os.environ.get("DEFAULT_MODEL", "")
+
+        try:
+            # Enable auto mode
+            os.environ["DEFAULT_MODEL"] = "auto"
+            import config
+
+            importlib.reload(config)
+
+            tool = AnalyzeTool()
+
+            # Mock the provider to simulate o3 not being available
+            with patch("providers.registry.ModelProviderRegistry.get_provider_for_model") as mock_provider:
+                # Mock that o3 is not available but flash/pro are
+                def mock_get_provider(model_name):
+                    if model_name in ["flash", "pro", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-06-05"]:
+                        # Return a mock provider for available models
+                        from unittest.mock import MagicMock
+
+                        return MagicMock()
+                    else:
+                        # o3 and others are not available
+                        return None
+
+                mock_provider.side_effect = mock_get_provider
+
+                # Execute with unavailable model
+                result = await tool.execute(
+                    {"files": ["/tmp/test.py"], "prompt": "Analyze this", "model": "o3"}  # This model is not available
+                )
+
+                # Should get error with helpful message
+                assert len(result) == 1
+                response = result[0].text
+                assert "error" in response
+                assert "Model 'o3' is not available" in response
+                assert "Available models:" in response
+                # Should list the available models
+                assert "flash" in response or "pro" in response
+
+        finally:
+            # Restore
+            if original:
+                os.environ["DEFAULT_MODEL"] = original
+            else:
+                os.environ.pop("DEFAULT_MODEL", None)
+            importlib.reload(config)
+
     def test_model_field_schema_generation(self):
         """Test the get_model_field_schema method"""
         from tools.base import BaseTool
@@ -173,7 +225,7 @@ class TestAutoMode:
         schema = tool.get_model_field_schema()
         assert "enum" in schema
         assert all(model in schema["enum"] for model in ["flash", "pro", "o3"])
-        assert "Choose the best model" in schema["description"]
+        assert "select the most suitable model" in schema["description"]
 
         # Test normal mode
         os.environ["DEFAULT_MODEL"] = "pro"
tests/test_per_tool_model_defaults.py (new file, 417 lines)
@@ -0,0 +1,417 @@
+"""
+Test per-tool model default selection functionality
+"""
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from providers.registry import ModelProviderRegistry, ProviderType
+from tools.analyze import AnalyzeTool
+from tools.base import BaseTool
+from tools.chat import ChatTool
+from tools.codereview import CodeReviewTool
+from tools.debug import DebugIssueTool
+from tools.models import ToolModelCategory
+from tools.precommit import Precommit
+from tools.thinkdeep import ThinkDeepTool
+
+
+class TestToolModelCategories:
+    """Test that each tool returns the correct model category."""
+
+    def test_thinkdeep_category(self):
+        tool = ThinkDeepTool()
+        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
+
+    def test_debug_category(self):
+        tool = DebugIssueTool()
+        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
+
+    def test_analyze_category(self):
+        tool = AnalyzeTool()
+        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
+
+    def test_precommit_category(self):
+        tool = Precommit()
+        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
+
+    def test_chat_category(self):
+        tool = ChatTool()
+        assert tool.get_model_category() == ToolModelCategory.FAST_RESPONSE
+
+    def test_codereview_category(self):
+        tool = CodeReviewTool()
+        assert tool.get_model_category() == ToolModelCategory.BALANCED
+
+    def test_base_tool_default_category(self):
+        # Test that BaseTool defaults to BALANCED
+        class TestTool(BaseTool):
+            def get_name(self):
+                return "test"
+
+            def get_description(self):
+                return "test"
+
+            def get_input_schema(self):
+                return {}
+
+            def get_system_prompt(self):
+                return "test"
+
+            def get_request_model(self):
+                return MagicMock
+
+            async def prepare_prompt(self, request):
+                return "test"
+
+        tool = TestTool()
+        assert tool.get_model_category() == ToolModelCategory.BALANCED
+
+
+class TestModelSelection:
+    """Test model selection based on tool categories."""
+
+    def test_extended_reasoning_with_openai(self):
+        """Test EXTENDED_REASONING prefers o3 when OpenAI is available."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock OpenAI available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
+            assert model == "o3"
+
+    def test_extended_reasoning_with_gemini_only(self):
+        """Test EXTENDED_REASONING prefers pro when only Gemini is available."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock only Gemini available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
+            assert model == "pro"
+
+    def test_fast_response_with_openai(self):
+        """Test FAST_RESPONSE prefers o3-mini when OpenAI is available."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock OpenAI available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+            assert model == "o3-mini"
+
+    def test_fast_response_with_gemini_only(self):
+        """Test FAST_RESPONSE prefers flash when only Gemini is available."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock only Gemini available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+            assert model == "flash"
+
+    def test_balanced_category_fallback(self):
+        """Test BALANCED category uses existing logic."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock OpenAI available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
+            assert model == "o3-mini"  # Balanced prefers o3-mini when OpenAI available
+
+    def test_no_category_uses_balanced_logic(self):
+        """Test that no category specified uses balanced logic."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock Gemini available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model()
+            assert model == "gemini-2.5-flash-preview-05-20"
+
+
+class TestCustomProviderFallback:
+    """Test fallback to custom/openrouter providers."""
+
+    @patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
+    def test_extended_reasoning_custom_fallback(self, mock_find_thinking):
+        """Test EXTENDED_REASONING falls back to custom thinking model."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # No native providers available
+            mock_get_provider.return_value = None
+            mock_find_thinking.return_value = "custom/thinking-model"
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
+            assert model == "custom/thinking-model"
+            mock_find_thinking.assert_called_once()
+
+    @patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
+    def test_extended_reasoning_final_fallback(self, mock_find_thinking):
+        """Test EXTENDED_REASONING falls back to pro when no custom found."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # No providers available
+            mock_get_provider.return_value = None
+            mock_find_thinking.return_value = None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
+            assert model == "gemini-2.5-pro-preview-06-05"
+
+
+class TestAutoModeErrorMessages:
+    """Test that auto mode error messages include suggested models."""
+
+    @pytest.mark.asyncio
+    async def test_thinkdeep_auto_error_message(self):
+        """Test ThinkDeep tool suggests appropriate model in auto mode."""
+        with patch("config.IS_AUTO_MODE", True):
+            with patch("config.DEFAULT_MODEL", "auto"):
+                with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+                    # Mock Gemini available
+                    mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
+
+                    tool = ThinkDeepTool()
+                    result = await tool.execute({"prompt": "test", "model": "auto"})
+
+                    assert len(result) == 1
+                    assert "Model parameter is required in auto mode" in result[0].text
+                    assert "Suggested model for thinkdeep: 'pro'" in result[0].text
+                    assert "(category: extended_reasoning)" in result[0].text
+
+    @pytest.mark.asyncio
+    async def test_chat_auto_error_message(self):
+        """Test Chat tool suggests appropriate model in auto mode."""
+        with patch("config.IS_AUTO_MODE", True):
+            with patch("config.DEFAULT_MODEL", "auto"):
+                with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+                    # Mock OpenAI available
+                    mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
+
+                    tool = ChatTool()
+                    result = await tool.execute({"prompt": "test", "model": "auto"})
+
+                    assert len(result) == 1
+                    assert "Model parameter is required in auto mode" in result[0].text
+                    assert "Suggested model for chat: 'o3-mini'" in result[0].text
+                    assert "(category: fast_response)" in result[0].text
+
+
+class TestFileContentPreparation:
+    """Test that file content preparation uses tool-specific model for capacity."""
+
+    @patch("tools.base.read_files")
+    @patch("tools.base.logger")
+    def test_auto_mode_uses_tool_category(self, mock_logger, mock_read_files):
+        """Test that auto mode uses tool-specific model for capacity estimation."""
+        mock_read_files.return_value = "file content"
+
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock provider with capabilities
+            mock_provider = MagicMock()
+            mock_provider.get_capabilities.return_value = MagicMock(context_window=1_000_000)
+            mock_get_provider.side_effect = lambda ptype: mock_provider if ptype == ProviderType.GOOGLE else None
+
+            # Create a tool and test file content preparation
+            tool = ThinkDeepTool()
+            tool._current_model_name = "auto"
+
+            # Call the method
+            tool._prepare_file_content_for_prompt(["/test/file.py"], None, "test")
+
+            # Check that it logged the correct message
+            debug_calls = [call for call in mock_logger.debug.call_args_list if "Auto mode detected" in str(call)]
+            assert len(debug_calls) > 0
+            assert "using pro for extended_reasoning tool capacity estimation" in str(debug_calls[0])
+
+
+class TestProviderHelperMethods:
+    """Test the helper methods for finding models from custom/openrouter."""
+
+    def test_find_extended_thinking_model_custom(self):
+        """Test finding thinking model from custom provider."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            from providers.custom import CustomProvider
+
+            # Mock custom provider with thinking model
+            mock_custom = MagicMock(spec=CustomProvider)
+            mock_custom.model_registry = {
+                "model1": {"supports_extended_thinking": False},
+                "model2": {"supports_extended_thinking": True},
+                "model3": {"supports_extended_thinking": False},
+            }
+            mock_get_provider.side_effect = lambda ptype: mock_custom if ptype == ProviderType.CUSTOM else None
+
+            model = ModelProviderRegistry._find_extended_thinking_model()
+            assert model == "model2"
+
+    def test_find_extended_thinking_model_openrouter(self):
+        """Test finding thinking model from openrouter."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock openrouter provider
+            mock_openrouter = MagicMock()
+            mock_openrouter.validate_model_name.side_effect = lambda m: m == "anthropic/claude-3.5-sonnet"
+            mock_get_provider.side_effect = lambda ptype: mock_openrouter if ptype == ProviderType.OPENROUTER else None
+
+            model = ModelProviderRegistry._find_extended_thinking_model()
+            assert model == "anthropic/claude-3.5-sonnet"
+
+    def test_find_extended_thinking_model_none_found(self):
+        """Test when no thinking model is found."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # No providers available
+            mock_get_provider.return_value = None
+
+            model = ModelProviderRegistry._find_extended_thinking_model()
+            assert model is None
+
+
+class TestEffectiveAutoMode:
+    """Test the is_effective_auto_mode method."""
+
+    def test_explicit_auto_mode(self):
+        """Test when DEFAULT_MODEL is explicitly 'auto'."""
+        with patch("config.DEFAULT_MODEL", "auto"):
+            with patch("config.IS_AUTO_MODE", True):
+                tool = ChatTool()
+                assert tool.is_effective_auto_mode() is True
+
+    def test_unavailable_model_triggers_auto_mode(self):
+        """Test when DEFAULT_MODEL is set but not available."""
+        with patch("config.DEFAULT_MODEL", "o3"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    mock_get_provider.return_value = None  # Model not available
+
+                    tool = ChatTool()
+                    assert tool.is_effective_auto_mode() is True
+
+    def test_available_model_no_auto_mode(self):
+        """Test when DEFAULT_MODEL is set and available."""
+        with patch("config.DEFAULT_MODEL", "pro"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    mock_get_provider.return_value = MagicMock()  # Model is available
+
+                    tool = ChatTool()
+                    assert tool.is_effective_auto_mode() is False
+
+
+class TestRuntimeModelSelection:
+    """Test runtime model selection behavior."""
+
+    @pytest.mark.asyncio
+    async def test_explicit_auto_in_request(self):
+        """Test when Claude explicitly passes model='auto'."""
+        with patch("config.DEFAULT_MODEL", "pro"):  # DEFAULT_MODEL is a real model
+            with patch("config.IS_AUTO_MODE", False):  # Not in auto mode
+                tool = ThinkDeepTool()
+                result = await tool.execute({"prompt": "test", "model": "auto"})
+
+                # Should require model selection even though DEFAULT_MODEL is valid
+                assert len(result) == 1
+                assert "Model parameter is required in auto mode" in result[0].text
+
+    @pytest.mark.asyncio
+    async def test_unavailable_model_in_request(self):
+        """Test when Claude passes an unavailable model."""
+        with patch("config.DEFAULT_MODEL", "pro"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    # Model is not available
+                    mock_get_provider.return_value = None
+
+                    tool = ChatTool()
+                    result = await tool.execute({"prompt": "test", "model": "gpt-5-turbo"})
+
+                    # Should require model selection
+                    assert len(result) == 1
+                    # When a specific model is requested but not available, error message is different
+                    assert "gpt-5-turbo" in result[0].text
+                    assert "is not available" in result[0].text
+                    assert "(category: fast_response)" in result[0].text
+
+
+class TestSchemaGeneration:
+    """Test schema generation with different configurations."""
+
+    def test_schema_with_explicit_auto_mode(self):
+        """Test schema when DEFAULT_MODEL='auto'."""
+        with patch("config.DEFAULT_MODEL", "auto"):
+            with patch("config.IS_AUTO_MODE", True):
+                tool = ChatTool()
+                schema = tool.get_input_schema()
+
+                # Model should be required
+                assert "model" in schema["required"]
+
+    def test_schema_with_unavailable_default_model(self):
+        """Test schema when DEFAULT_MODEL is set but unavailable."""
+        with patch("config.DEFAULT_MODEL", "o3"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    mock_get_provider.return_value = None  # Model not available
+
+                    tool = AnalyzeTool()
+                    schema = tool.get_input_schema()
+
+                    # Model should be required due to unavailable DEFAULT_MODEL
+                    assert "model" in schema["required"]
+
+    def test_schema_with_available_default_model(self):
+        """Test schema when DEFAULT_MODEL is available."""
+        with patch("config.DEFAULT_MODEL", "pro"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    mock_get_provider.return_value = MagicMock()  # Model is available
+
+                    tool = ThinkDeepTool()
+                    schema = tool.get_input_schema()
+
+                    # Model should NOT be required
+                    assert "model" not in schema["required"]
+
+
+class TestUnavailableModelFallback:
+    """Test fallback behavior when DEFAULT_MODEL is not available."""
+
+    @pytest.mark.asyncio
+    async def test_unavailable_default_model_fallback(self):
+        """Test that unavailable DEFAULT_MODEL triggers auto mode behavior."""
+        with patch("config.DEFAULT_MODEL", "o3"):  # Set DEFAULT_MODEL to a specific model
+            with patch("config.IS_AUTO_MODE", False):  # Not in auto mode
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    # Model is not available (no provider)
+                    mock_get_provider.return_value = None
+
+                    tool = ThinkDeepTool()
+                    result = await tool.execute({"prompt": "test"})  # No model specified
+
+                    # Should get auto mode error since model is unavailable
+                    assert len(result) == 1
+                    # When DEFAULT_MODEL is unavailable, the error message indicates the model is not available
+                    assert "o3" in result[0].text
+                    assert "is not available" in result[0].text
+                    # The suggested model depends on which providers are available
+                    # Just check that it suggests a model for the extended_reasoning category
+                    assert "(category: extended_reasoning)" in result[0].text
+
+    @pytest.mark.asyncio
+    async def test_available_default_model_no_fallback(self):
+        """Test that available DEFAULT_MODEL works normally."""
+        with patch("config.DEFAULT_MODEL", "pro"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    # Model is available
+                    mock_provider = MagicMock()
+                    mock_provider.generate_content.return_value = MagicMock(content="Test response", metadata={})
+                    mock_get_provider.return_value = mock_provider
+
+                    # Mock the provider lookup in BaseTool.get_model_provider
+                    with patch.object(BaseTool, "get_model_provider") as mock_get_model_provider:
+                        mock_get_model_provider.return_value = mock_provider
+
+                        tool = ChatTool()
+                        result = await tool.execute({"prompt": "test"})  # No model specified
+
+                        # Should work normally, not require model parameter
+                        assert len(result) == 1
+                        output = json.loads(result[0].text)
+                        assert output["status"] == "success"
+                        assert "Test response" in output["content"]
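The new suite can be run on its own in the usual way, for example:

    pytest tests/test_per_tool_model_defaults.py -v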
tools/analyze.py
@@ -2,11 +2,14 @@
 Analyze tool - General-purpose code and file analysis
 """
 
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import TEMPERATURE_ANALYTICAL
 from prompts import ANALYZE_PROMPT
@@ -42,8 +45,6 @@ class AnalyzeTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -95,7 +96,7 @@ class AnalyzeTool(BaseTool):
                     "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["files", "prompt"] + (["model"] if self.is_effective_auto_mode() else []),
         }
 
         return schema
@@ -106,6 +107,12 @@ class AnalyzeTool(BaseTool):
     def get_default_temperature(self) -> float:
         return TEMPERATURE_ANALYTICAL
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """Analyze requires deep understanding and reasoning"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.EXTENDED_REASONING
+
     def get_request_model(self):
         return AnalyzeRequest
tools/base.py
@@ -17,11 +17,14 @@ import json
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 from mcp.types import TextContent
 from pydantic import BaseModel, Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import MCP_PROMPT_SIZE_LIMIT
 from providers import ModelProvider, ModelProviderRegistry
 from utils import check_token_limit
@@ -156,6 +159,88 @@ class BaseTool(ABC):
         """
         pass
 
+    def is_effective_auto_mode(self) -> bool:
+        """
+        Check if we're in effective auto mode for schema generation.
+
+        This determines whether the model parameter should be required in the tool schema.
+        Used at initialization time when schemas are generated.
+
+        Returns:
+            bool: True if model parameter should be required in the schema
+        """
+        from config import DEFAULT_MODEL, IS_AUTO_MODE
+        from providers.registry import ModelProviderRegistry
+
+        # Case 1: Explicit auto mode
+        if IS_AUTO_MODE:
+            return True
+
+        # Case 2: Model not available
+        if DEFAULT_MODEL.lower() != "auto":
+            provider = ModelProviderRegistry.get_provider_for_model(DEFAULT_MODEL)
+            if not provider:
+                return True
+
+        return False
+
+    def _should_require_model_selection(self, model_name: str) -> bool:
+        """
+        Check if we should require Claude to select a model at runtime.
+
+        This is called during request execution to determine if we need
+        to return an error asking Claude to provide a model parameter.
+
+        Args:
+            model_name: The model name from the request or DEFAULT_MODEL
+
+        Returns:
+            bool: True if we should require model selection
+        """
+        # Case 1: Model is explicitly "auto"
+        if model_name.lower() == "auto":
+            return True
+
+        # Case 2: Requested model is not available
+        from providers.registry import ModelProviderRegistry
+
+        provider = ModelProviderRegistry.get_provider_for_model(model_name)
+        if not provider:
+            logger = logging.getLogger(f"tools.{self.name}")
+            logger.warning(
+                f"Model '{model_name}' is not available with current API keys. Requiring model selection."
+            )
+            return True
+
+        return False
+
+    def _get_available_models(self) -> list[str]:
+        """
+        Get list of models that are actually available with current API keys.
+
+        Returns:
+            List of available model names
+        """
+        from config import MODEL_CAPABILITIES_DESC
+        from providers.base import ProviderType
+        from providers.registry import ModelProviderRegistry
+
+        available_models = []
+
+        # Check each model in our capabilities list
+        for model_name in MODEL_CAPABILITIES_DESC.keys():
+            provider = ModelProviderRegistry.get_provider_for_model(model_name)
+            if provider:
+                available_models.append(model_name)
+
+        # Also check if OpenRouter is available (it accepts any model)
+        openrouter_provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
+        if openrouter_provider and not available_models:
+            # If only OpenRouter is available, suggest using any model through it
+            available_models.append("any model via OpenRouter")
+
+        return available_models if available_models else ["none - please configure API keys"]
+
     def get_model_field_schema(self) -> dict[str, Any]:
         """
         Generate the model field schema based on auto mode configuration.
@@ -168,16 +253,20 @@ class BaseTool(ABC):
         """
         import os
 
-        from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
+        from config import DEFAULT_MODEL, MODEL_CAPABILITIES_DESC
 
        # Check if OpenRouter is configured
         has_openrouter = bool(
             os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here"
         )
 
-        if IS_AUTO_MODE:
+        # Use the centralized effective auto mode check
+        if self.is_effective_auto_mode():
             # In auto mode, model is required and we provide detailed descriptions
-            model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
+            model_desc_parts = [
+                "IMPORTANT: Use the model specified by the user if provided, OR select the most suitable model "
+                "for this specific task based on the requirements and capabilities listed below:"
+            ]
             for model, desc in MODEL_CAPABILITIES_DESC.items():
                 model_desc_parts.append(f"- '{model}': {desc}")
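In effective auto mode, the generated model field then ends up looking roughly like this (illustrative only; the actual enum and descriptions are built from MODEL_CAPABILITIES_DESC and the configured providers):

    {
        "type": "string",
        "enum": ["flash", "pro", "o3", "o3-mini"],
        "description": "IMPORTANT: Use the model specified by the user if provided, OR select the most suitable model for this specific task based on the requirements and capabilities listed below: ..."
    }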
@@ -302,6 +391,21 @@ class BaseTool(ABC):
         """
         return "medium"  # Default to medium thinking for better reasoning
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """
+        Return the model category for this tool.
+
+        Model category influences which model is selected in auto mode.
+        Override to specify whether your tool needs extended reasoning,
+        fast response, or balanced capabilities.
+
+        Returns:
+            ToolModelCategory: Category that influences model selection
+        """
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.BALANCED
+
     def get_conversation_embedded_files(self, continuation_id: Optional[str]) -> list[str]:
         """
         Get list of files already embedded in conversation history.
@@ -474,11 +578,13 @@ class BaseTool(ABC):
         if model_name.lower() == "auto":
             from providers.registry import ModelProviderRegistry
 
-            # Use the preferred fallback model for capacity estimation
+            # Use tool-specific fallback model for capacity estimation
             # This properly handles different providers (OpenAI=200K, Gemini=1M)
-            fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
+            tool_category = self.get_model_category()
+            fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
             logger.debug(
-                f"[FILES] {self.name}: Auto mode detected, using {fallback_model} for capacity estimation"
+                f"[FILES] {self.name}: Auto mode detected, using {fallback_model} "
+                f"for {tool_category.value} tool capacity estimation"
             )
 
         try:
@@ -898,13 +1004,39 @@ When recommending searches, be specific about what information you need and why
 
         model_name = DEFAULT_MODEL
 
-        # In auto mode, model parameter is required
-        from config import IS_AUTO_MODE
+        # Check if we need Claude to select a model
+        # This happens when:
+        # 1. The model is explicitly "auto"
+        # 2. The requested model is not available
+        if self._should_require_model_selection(model_name):
+            # Get suggested model based on tool category
+            from providers.registry import ModelProviderRegistry
+
+            tool_category = self.get_model_category()
+            suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
+
+            # Build error message based on why selection is required
+            if model_name.lower() == "auto":
+                error_message = (
+                    f"Model parameter is required in auto mode. "
+                    f"Suggested model for {self.name}: '{suggested_model}' "
+                    f"(category: {tool_category.value})"
+                )
+            else:
+                # Model was specified but not available
+                # Get list of available models
+                available_models = self._get_available_models()
+
+                error_message = (
+                    f"Model '{model_name}' is not available with current API keys. "
+                    f"Available models: {', '.join(available_models)}. "
+                    f"Suggested model for {self.name}: '{suggested_model}' "
+                    f"(category: {tool_category.value})"
+                )
 
-        if IS_AUTO_MODE and model_name.lower() == "auto":
             error_output = ToolOutput(
                 status="error",
-                content="Model parameter is required. Please specify which model to use for this task.",
+                content=error_message,
                 content_type="text",
             )
             return [TextContent(type="text", text=error_output.model_dump_json())]
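When this error path fires, the client receives a serialized ToolOutput along these lines (illustrative values; the model list depends on the configured API keys):

    {
        "status": "error",
        "content": "Model 'o3' is not available with current API keys. Available models: flash, pro. Suggested model for chat: 'flash' (category: fast_response)",
        "content_type": "text"
    }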
tools/chat.py
@@ -2,11 +2,14 @@
 Chat tool - General development chat and collaborative thinking
 """
 
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import TEMPERATURE_BALANCED
 from prompts import CHAT_PROMPT
@@ -44,8 +47,6 @@ class ChatTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -80,7 +81,7 @@ class ChatTool(BaseTool):
                     "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []),
         }
 
         return schema
@@ -91,6 +92,12 @@ class ChatTool(BaseTool):
     def get_default_temperature(self) -> float:
         return TEMPERATURE_BALANCED
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """Chat prioritizes fast responses and cost efficiency"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.FAST_RESPONSE
+
     def get_request_model(self):
         return ChatRequest
tools/codereview.py
@@ -82,8 +82,6 @@ class CodeReviewTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -138,7 +136,7 @@ class CodeReviewTool(BaseTool):
                     "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["files", "prompt"] + (["model"] if self.is_effective_auto_mode() else []),
         }
 
         return schema
tools/debug.py
@@ -2,11 +2,14 @@
 Debug Issue tool - Root cause analysis and debugging assistance
 """
 
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import TEMPERATURE_ANALYTICAL
 from prompts import DEBUG_ISSUE_PROMPT
@@ -50,8 +53,6 @@ class DebugIssueTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -98,7 +99,7 @@ class DebugIssueTool(BaseTool):
                     "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []),
         }
 
         return schema
@@ -109,6 +110,12 @@ class DebugIssueTool(BaseTool):
     def get_default_temperature(self) -> float:
         return TEMPERATURE_ANALYTICAL
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """Debug requires deep analysis and reasoning"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.EXTENDED_REASONING
+
     def get_request_model(self):
         return DebugIssueRequest
tools/models.py
@@ -2,11 +2,20 @@
 Data models for tool responses and interactions
 """
 
+from enum import Enum
 from typing import Any, Literal, Optional
 
 from pydantic import BaseModel, Field
 
 
+class ToolModelCategory(Enum):
+    """Categories for tool model selection based on requirements."""
+
+    EXTENDED_REASONING = "extended_reasoning"  # Requires deep thinking capabilities
+    FAST_RESPONSE = "fast_response"  # Speed and cost efficiency preferred
+    BALANCED = "balanced"  # Balance of capability and performance
+
+
 class ContinuationOffer(BaseModel):
     """Offer for Claude to continue conversation when Gemini doesn't ask follow-up"""
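Any tool opts into a category by overriding the new BaseTool hook; a minimal sketch with a hypothetical tool (mirroring the overrides elsewhere in this commit):

    from tools.base import BaseTool
    from tools.models import ToolModelCategory

    class SummarizeTool(BaseTool):  # hypothetical example, not part of this commit
        def get_model_category(self) -> ToolModelCategory:
            # Summaries favour speed and cost over deep reasoning
            return ToolModelCategory.FAST_RESPONSE

Tools that do not override the hook inherit BALANCED from BaseTool.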
tools/precommit.py
@@ -9,11 +9,14 @@ This provides comprehensive context for AI analysis - not a duplication bug.
 """
 
 import os
-from typing import Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from prompts.tool_prompts import PRECOMMIT_PROMPT
 from utils.file_utils import translate_file_paths, translate_path_for_environment
 from utils.git_utils import find_git_repositories, get_git_status, run_git_command
@@ -100,30 +103,83 @@ class Precommit(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
-        schema = self.get_request_model().model_json_schema()
-        # Ensure model parameter has enhanced description
-        if "properties" in schema and "model" in schema["properties"]:
-            schema["properties"]["model"] = self.get_model_field_schema()
-
-        # In auto mode, model is required
-        if IS_AUTO_MODE and "required" in schema:
-            if "model" not in schema["required"]:
-                schema["required"].append("model")
-        # Ensure use_websearch is in the schema with proper description
-        if "properties" in schema and "use_websearch" not in schema["properties"]:
-            schema["properties"]["use_websearch"] = {
-                "type": "boolean",
-                "description": "Enable web search for documentation, best practices, and current information. Particularly useful for: brainstorming sessions, architectural design discussions, exploring industry best practices, working with specific frameworks/technologies, researching solutions to complex problems, or when current documentation and community insights would enhance the analysis.",
-                "default": True,
-            }
-        # Add continuation_id parameter
-        if "properties" in schema and "continuation_id" not in schema["properties"]:
-            schema["properties"]["continuation_id"] = {
-                "type": "string",
-                "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
-            }
+        schema = {
+            "type": "object",
+            "title": "PrecommitRequest",
+            "description": "Request model for precommit tool",
+            "properties": {
+                "path": {
+                    "type": "string",
+                    "description": "Starting directory to search for git repositories (must be absolute path).",
+                },
+                "model": self.get_model_field_schema(),
+                "prompt": {
+                    "type": "string",
+                    "description": "The original user request description for the changes. Provides critical context for the review.",
+                },
+                "compare_to": {
+                    "type": "string",
+                    "description": "Optional: A git ref (branch, tag, commit hash) to compare against. If not provided, reviews local staged and unstaged changes.",
+                },
+                "include_staged": {
+                    "type": "boolean",
+                    "default": True,
+                    "description": "Include staged changes in the review. Only applies if 'compare_to' is not set.",
+                },
+                "include_unstaged": {
+                    "type": "boolean",
+                    "default": True,
+                    "description": "Include uncommitted (unstaged) changes in the review. Only applies if 'compare_to' is not set.",
+                },
+                "focus_on": {
+                    "type": "string",
+                    "description": "Specific aspects to focus on (e.g., 'logic for user authentication', 'database query efficiency').",
+                },
+                "review_type": {
+                    "type": "string",
+                    "enum": ["full", "security", "performance", "quick"],
+                    "default": "full",
+                    "description": "Type of review to perform on the changes.",
+                },
+                "severity_filter": {
+                    "type": "string",
+                    "enum": ["critical", "high", "medium", "all"],
+                    "default": "all",
+                    "description": "Minimum severity level to report on the changes.",
+                },
+                "max_depth": {
+                    "type": "integer",
+                    "default": 5,
+                    "description": "Maximum depth to search for nested git repositories to prevent excessive recursion.",
+                },
+                "temperature": {
+                    "type": "number",
+                    "description": "Temperature for the response (0.0 to 1.0). Lower values are more focused and deterministic.",
+                    "minimum": 0,
+                    "maximum": 1,
+                },
+                "thinking_mode": {
+                    "type": "string",
+                    "enum": ["minimal", "low", "medium", "high", "max"],
+                    "description": "Thinking depth mode for the assistant.",
+                },
+                "files": {
+                    "type": "array",
+                    "items": {"type": "string"},
+                    "description": "Optional files or directories to provide as context (must be absolute paths). These files are not part of the changes but provide helpful context like configs, docs, or related code.",
+                },
+                "use_websearch": {
+                    "type": "boolean",
+                    "description": "Enable web search for documentation, best practices, and current information. Particularly useful for: brainstorming sessions, architectural design discussions, exploring industry best practices, working with specific frameworks/technologies, researching solutions to complex problems, or when current documentation and community insights would enhance the analysis.",
+                    "default": True,
+                },
+                "continuation_id": {
+                    "type": "string",
+                    "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
+                },
+            },
+            "required": ["path"] + (["model"] if self.is_effective_auto_mode() else []),
+        }
         return schema
 
     def get_system_prompt(self) -> str:
@@ -138,6 +194,12 @@ class Precommit(BaseTool):
 
         return TEMPERATURE_ANALYTICAL
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """Precommit requires thorough analysis and reasoning"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.EXTENDED_REASONING
+
     async def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
         """Override execute to check original_request size before processing"""
         # First validate request
tools/thinkdeep.py
@@ -2,11 +2,14 @@
 ThinkDeep tool - Extended reasoning and problem-solving
 """
 
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import TEMPERATURE_CREATIVE
 from prompts import THINKDEEP_PROMPT
@@ -48,8 +51,6 @@ class ThinkDeepTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -93,7 +94,7 @@ class ThinkDeepTool(BaseTool):
                     "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []),
         }
 
         return schema
@@ -110,6 +111,12 @@ class ThinkDeepTool(BaseTool):
 
         return DEFAULT_THINKING_MODE_THINKDEEP
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """ThinkDeep requires extended reasoning capabilities"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.EXTENDED_REASONING
+
     def get_request_model(self):
         return ThinkDeepRequest