Categorize tools into 'model capabilities categories' to help determine which type of model to pick when in auto mode
Encourage Claude to pick the best model for the job automatically in auto mode Lots of new tests to ensure automatic model picking works reliably based on user preference or when a matching model is not found or ambiguous Improved error reporting when bogus model is requested and is not configured or available
This commit is contained in:
417
tests/test_per_tool_model_defaults.py
Normal file
417
tests/test_per_tool_model_defaults.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
Test per-tool model default selection functionality
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.registry import ModelProviderRegistry, ProviderType
|
||||
from tools.analyze import AnalyzeTool
|
||||
from tools.base import BaseTool
|
||||
from tools.chat import ChatTool
|
||||
from tools.codereview import CodeReviewTool
|
||||
from tools.debug import DebugIssueTool
|
||||
from tools.models import ToolModelCategory
|
||||
from tools.precommit import Precommit
|
||||
from tools.thinkdeep import ThinkDeepTool
|
||||
|
||||
|
||||
class TestToolModelCategories:
|
||||
"""Test that each tool returns the correct model category."""
|
||||
|
||||
def test_thinkdeep_category(self):
|
||||
tool = ThinkDeepTool()
|
||||
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
|
||||
|
||||
def test_debug_category(self):
|
||||
tool = DebugIssueTool()
|
||||
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
|
||||
|
||||
def test_analyze_category(self):
|
||||
tool = AnalyzeTool()
|
||||
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
|
||||
|
||||
def test_precommit_category(self):
|
||||
tool = Precommit()
|
||||
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
|
||||
|
||||
def test_chat_category(self):
|
||||
tool = ChatTool()
|
||||
assert tool.get_model_category() == ToolModelCategory.FAST_RESPONSE
|
||||
|
||||
def test_codereview_category(self):
|
||||
tool = CodeReviewTool()
|
||||
assert tool.get_model_category() == ToolModelCategory.BALANCED
|
||||
|
||||
def test_base_tool_default_category(self):
|
||||
# Test that BaseTool defaults to BALANCED
|
||||
class TestTool(BaseTool):
|
||||
def get_name(self):
|
||||
return "test"
|
||||
|
||||
def get_description(self):
|
||||
return "test"
|
||||
|
||||
def get_input_schema(self):
|
||||
return {}
|
||||
|
||||
def get_system_prompt(self):
|
||||
return "test"
|
||||
|
||||
def get_request_model(self):
|
||||
return MagicMock
|
||||
|
||||
async def prepare_prompt(self, request):
|
||||
return "test"
|
||||
|
||||
tool = TestTool()
|
||||
assert tool.get_model_category() == ToolModelCategory.BALANCED
|
||||
|
||||
|
||||
class TestModelSelection:
|
||||
"""Test model selection based on tool categories."""
|
||||
|
||||
def test_extended_reasoning_with_openai(self):
|
||||
"""Test EXTENDED_REASONING prefers o3 when OpenAI is available."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock OpenAI available
|
||||
mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
assert model == "o3"
|
||||
|
||||
def test_extended_reasoning_with_gemini_only(self):
|
||||
"""Test EXTENDED_REASONING prefers pro when only Gemini is available."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock only Gemini available
|
||||
mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
assert model == "pro"
|
||||
|
||||
def test_fast_response_with_openai(self):
|
||||
"""Test FAST_RESPONSE prefers o3-mini when OpenAI is available."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock OpenAI available
|
||||
mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
assert model == "o3-mini"
|
||||
|
||||
def test_fast_response_with_gemini_only(self):
|
||||
"""Test FAST_RESPONSE prefers flash when only Gemini is available."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock only Gemini available
|
||||
mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
assert model == "flash"
|
||||
|
||||
def test_balanced_category_fallback(self):
|
||||
"""Test BALANCED category uses existing logic."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock OpenAI available
|
||||
mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
|
||||
assert model == "o3-mini" # Balanced prefers o3-mini when OpenAI available
|
||||
|
||||
def test_no_category_uses_balanced_logic(self):
|
||||
"""Test that no category specified uses balanced logic."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock Gemini available
|
||||
mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model()
|
||||
assert model == "gemini-2.5-flash-preview-05-20"
|
||||
|
||||
|
||||
class TestCustomProviderFallback:
|
||||
"""Test fallback to custom/openrouter providers."""
|
||||
|
||||
@patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
|
||||
def test_extended_reasoning_custom_fallback(self, mock_find_thinking):
|
||||
"""Test EXTENDED_REASONING falls back to custom thinking model."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# No native providers available
|
||||
mock_get_provider.return_value = None
|
||||
mock_find_thinking.return_value = "custom/thinking-model"
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
assert model == "custom/thinking-model"
|
||||
mock_find_thinking.assert_called_once()
|
||||
|
||||
@patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
|
||||
def test_extended_reasoning_final_fallback(self, mock_find_thinking):
|
||||
"""Test EXTENDED_REASONING falls back to pro when no custom found."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# No providers available
|
||||
mock_get_provider.return_value = None
|
||||
mock_find_thinking.return_value = None
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
assert model == "gemini-2.5-pro-preview-06-05"
|
||||
|
||||
|
||||
class TestAutoModeErrorMessages:
|
||||
"""Test that auto mode error messages include suggested models."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thinkdeep_auto_error_message(self):
|
||||
"""Test ThinkDeep tool suggests appropriate model in auto mode."""
|
||||
with patch("config.IS_AUTO_MODE", True):
|
||||
with patch("config.DEFAULT_MODEL", "auto"):
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock Gemini available
|
||||
mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
|
||||
|
||||
tool = ThinkDeepTool()
|
||||
result = await tool.execute({"prompt": "test", "model": "auto"})
|
||||
|
||||
assert len(result) == 1
|
||||
assert "Model parameter is required in auto mode" in result[0].text
|
||||
assert "Suggested model for thinkdeep: 'pro'" in result[0].text
|
||||
assert "(category: extended_reasoning)" in result[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_auto_error_message(self):
|
||||
"""Test Chat tool suggests appropriate model in auto mode."""
|
||||
with patch("config.IS_AUTO_MODE", True):
|
||||
with patch("config.DEFAULT_MODEL", "auto"):
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock OpenAI available
|
||||
mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
|
||||
|
||||
tool = ChatTool()
|
||||
result = await tool.execute({"prompt": "test", "model": "auto"})
|
||||
|
||||
assert len(result) == 1
|
||||
assert "Model parameter is required in auto mode" in result[0].text
|
||||
assert "Suggested model for chat: 'o3-mini'" in result[0].text
|
||||
assert "(category: fast_response)" in result[0].text
|
||||
|
||||
|
||||
class TestFileContentPreparation:
|
||||
"""Test that file content preparation uses tool-specific model for capacity."""
|
||||
|
||||
@patch("tools.base.read_files")
|
||||
@patch("tools.base.logger")
|
||||
def test_auto_mode_uses_tool_category(self, mock_logger, mock_read_files):
|
||||
"""Test that auto mode uses tool-specific model for capacity estimation."""
|
||||
mock_read_files.return_value = "file content"
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock provider with capabilities
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_capabilities.return_value = MagicMock(context_window=1_000_000)
|
||||
mock_get_provider.side_effect = lambda ptype: mock_provider if ptype == ProviderType.GOOGLE else None
|
||||
|
||||
# Create a tool and test file content preparation
|
||||
tool = ThinkDeepTool()
|
||||
tool._current_model_name = "auto"
|
||||
|
||||
# Call the method
|
||||
tool._prepare_file_content_for_prompt(["/test/file.py"], None, "test")
|
||||
|
||||
# Check that it logged the correct message
|
||||
debug_calls = [call for call in mock_logger.debug.call_args_list if "Auto mode detected" in str(call)]
|
||||
assert len(debug_calls) > 0
|
||||
assert "using pro for extended_reasoning tool capacity estimation" in str(debug_calls[0])
|
||||
|
||||
|
||||
class TestProviderHelperMethods:
|
||||
"""Test the helper methods for finding models from custom/openrouter."""
|
||||
|
||||
def test_find_extended_thinking_model_custom(self):
|
||||
"""Test finding thinking model from custom provider."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
# Mock custom provider with thinking model
|
||||
mock_custom = MagicMock(spec=CustomProvider)
|
||||
mock_custom.model_registry = {
|
||||
"model1": {"supports_extended_thinking": False},
|
||||
"model2": {"supports_extended_thinking": True},
|
||||
"model3": {"supports_extended_thinking": False},
|
||||
}
|
||||
mock_get_provider.side_effect = lambda ptype: mock_custom if ptype == ProviderType.CUSTOM else None
|
||||
|
||||
model = ModelProviderRegistry._find_extended_thinking_model()
|
||||
assert model == "model2"
|
||||
|
||||
def test_find_extended_thinking_model_openrouter(self):
|
||||
"""Test finding thinking model from openrouter."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock openrouter provider
|
||||
mock_openrouter = MagicMock()
|
||||
mock_openrouter.validate_model_name.side_effect = lambda m: m == "anthropic/claude-3.5-sonnet"
|
||||
mock_get_provider.side_effect = lambda ptype: mock_openrouter if ptype == ProviderType.OPENROUTER else None
|
||||
|
||||
model = ModelProviderRegistry._find_extended_thinking_model()
|
||||
assert model == "anthropic/claude-3.5-sonnet"
|
||||
|
||||
def test_find_extended_thinking_model_none_found(self):
|
||||
"""Test when no thinking model is found."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# No providers available
|
||||
mock_get_provider.return_value = None
|
||||
|
||||
model = ModelProviderRegistry._find_extended_thinking_model()
|
||||
assert model is None
|
||||
|
||||
|
||||
class TestEffectiveAutoMode:
|
||||
"""Test the is_effective_auto_mode method."""
|
||||
|
||||
def test_explicit_auto_mode(self):
|
||||
"""Test when DEFAULT_MODEL is explicitly 'auto'."""
|
||||
with patch("config.DEFAULT_MODEL", "auto"):
|
||||
with patch("config.IS_AUTO_MODE", True):
|
||||
tool = ChatTool()
|
||||
assert tool.is_effective_auto_mode() is True
|
||||
|
||||
def test_unavailable_model_triggers_auto_mode(self):
|
||||
"""Test when DEFAULT_MODEL is set but not available."""
|
||||
with patch("config.DEFAULT_MODEL", "o3"):
|
||||
with patch("config.IS_AUTO_MODE", False):
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
|
||||
mock_get_provider.return_value = None # Model not available
|
||||
|
||||
tool = ChatTool()
|
||||
assert tool.is_effective_auto_mode() is True
|
||||
|
||||
def test_available_model_no_auto_mode(self):
|
||||
"""Test when DEFAULT_MODEL is set and available."""
|
||||
with patch("config.DEFAULT_MODEL", "pro"):
|
||||
with patch("config.IS_AUTO_MODE", False):
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
|
||||
mock_get_provider.return_value = MagicMock() # Model is available
|
||||
|
||||
tool = ChatTool()
|
||||
assert tool.is_effective_auto_mode() is False
|
||||
|
||||
|
||||
class TestRuntimeModelSelection:
|
||||
"""Test runtime model selection behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_auto_in_request(self):
|
||||
"""Test when Claude explicitly passes model='auto'."""
|
||||
with patch("config.DEFAULT_MODEL", "pro"): # DEFAULT_MODEL is a real model
|
||||
with patch("config.IS_AUTO_MODE", False): # Not in auto mode
|
||||
tool = ThinkDeepTool()
|
||||
result = await tool.execute({"prompt": "test", "model": "auto"})
|
||||
|
||||
# Should require model selection even though DEFAULT_MODEL is valid
|
||||
assert len(result) == 1
|
||||
assert "Model parameter is required in auto mode" in result[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unavailable_model_in_request(self):
|
||||
"""Test when Claude passes an unavailable model."""
|
||||
with patch("config.DEFAULT_MODEL", "pro"):
|
||||
with patch("config.IS_AUTO_MODE", False):
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
|
||||
# Model is not available
|
||||
mock_get_provider.return_value = None
|
||||
|
||||
tool = ChatTool()
|
||||
result = await tool.execute({"prompt": "test", "model": "gpt-5-turbo"})
|
||||
|
||||
# Should require model selection
|
||||
assert len(result) == 1
|
||||
# When a specific model is requested but not available, error message is different
|
||||
assert "gpt-5-turbo" in result[0].text
|
||||
assert "is not available" in result[0].text
|
||||
assert "(category: fast_response)" in result[0].text
|
||||
|
||||
|
||||
class TestSchemaGeneration:
|
||||
"""Test schema generation with different configurations."""
|
||||
|
||||
def test_schema_with_explicit_auto_mode(self):
|
||||
"""Test schema when DEFAULT_MODEL='auto'."""
|
||||
with patch("config.DEFAULT_MODEL", "auto"):
|
||||
with patch("config.IS_AUTO_MODE", True):
|
||||
tool = ChatTool()
|
||||
schema = tool.get_input_schema()
|
||||
|
||||
# Model should be required
|
||||
assert "model" in schema["required"]
|
||||
|
||||
def test_schema_with_unavailable_default_model(self):
|
||||
"""Test schema when DEFAULT_MODEL is set but unavailable."""
|
||||
with patch("config.DEFAULT_MODEL", "o3"):
|
||||
with patch("config.IS_AUTO_MODE", False):
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
|
||||
mock_get_provider.return_value = None # Model not available
|
||||
|
||||
tool = AnalyzeTool()
|
||||
schema = tool.get_input_schema()
|
||||
|
||||
# Model should be required due to unavailable DEFAULT_MODEL
|
||||
assert "model" in schema["required"]
|
||||
|
||||
def test_schema_with_available_default_model(self):
|
||||
"""Test schema when DEFAULT_MODEL is available."""
|
||||
with patch("config.DEFAULT_MODEL", "pro"):
|
||||
with patch("config.IS_AUTO_MODE", False):
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
|
||||
mock_get_provider.return_value = MagicMock() # Model is available
|
||||
|
||||
tool = ThinkDeepTool()
|
||||
schema = tool.get_input_schema()
|
||||
|
||||
# Model should NOT be required
|
||||
assert "model" not in schema["required"]
|
||||
|
||||
|
||||
class TestUnavailableModelFallback:
|
||||
"""Test fallback behavior when DEFAULT_MODEL is not available."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unavailable_default_model_fallback(self):
|
||||
"""Test that unavailable DEFAULT_MODEL triggers auto mode behavior."""
|
||||
with patch("config.DEFAULT_MODEL", "o3"): # Set DEFAULT_MODEL to a specific model
|
||||
with patch("config.IS_AUTO_MODE", False): # Not in auto mode
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
|
||||
# Model is not available (no provider)
|
||||
mock_get_provider.return_value = None
|
||||
|
||||
tool = ThinkDeepTool()
|
||||
result = await tool.execute({"prompt": "test"}) # No model specified
|
||||
|
||||
# Should get auto mode error since model is unavailable
|
||||
assert len(result) == 1
|
||||
# When DEFAULT_MODEL is unavailable, the error message indicates the model is not available
|
||||
assert "o3" in result[0].text
|
||||
assert "is not available" in result[0].text
|
||||
# The suggested model depends on which providers are available
|
||||
# Just check that it suggests a model for the extended_reasoning category
|
||||
assert "(category: extended_reasoning)" in result[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_available_default_model_no_fallback(self):
|
||||
"""Test that available DEFAULT_MODEL works normally."""
|
||||
with patch("config.DEFAULT_MODEL", "pro"):
|
||||
with patch("config.IS_AUTO_MODE", False):
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
|
||||
# Model is available
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.generate_content.return_value = MagicMock(content="Test response", metadata={})
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Mock the provider lookup in BaseTool.get_model_provider
|
||||
with patch.object(BaseTool, "get_model_provider") as mock_get_model_provider:
|
||||
mock_get_model_provider.return_value = mock_provider
|
||||
|
||||
tool = ChatTool()
|
||||
result = await tool.execute({"prompt": "test"}) # No model specified
|
||||
|
||||
# Should work normally, not require model parameter
|
||||
assert len(result) == 1
|
||||
output = json.loads(result[0].text)
|
||||
assert output["status"] == "success"
|
||||
assert "Test response" in output["content"]
|
||||
Reference in New Issue
Block a user