diff --git a/.gitignore b/.gitignore
index ceb055a..9675212 100644
--- a/.gitignore
+++ b/.gitignore
@@ -106,6 +106,7 @@ celerybeat.pid
 
 # Environments
 .env
+.env~
 .venv
 env/
 venv/
diff --git a/config.py b/config.py
index 9e13fdf..c487f75 100644
--- a/config.py
+++ b/config.py
@@ -14,7 +14,7 @@ import os
 # These values are used in server responses and for tracking releases
 # IMPORTANT: This is the single source of truth for version and author info
 # Semantic versioning: MAJOR.MINOR.PATCH
-__version__ = "4.2.2"
+__version__ = "4.3.0"
 # Last update date in ISO format
 __updated__ = "2025-06-14"
 # Primary maintainer
diff --git a/providers/registry.py b/providers/registry.py
index 44f75d9..1a326ba 100644
--- a/providers/registry.py
+++ b/providers/registry.py
@@ -2,10 +2,13 @@
 
 import logging
 import os
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from .base import ModelProvider, ProviderType
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 
 class ModelProviderRegistry:
     """Registry for managing model providers."""
@@ -189,27 +192,50 @@ class ModelProviderRegistry:
         return os.getenv(env_var)
 
     @classmethod
-    def get_preferred_fallback_model(cls) -> str:
-        """Get the preferred fallback model based on available API keys.
+    def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
+        """Get the preferred fallback model based on available API keys and tool category.
 
         This method checks which providers have valid API keys and returns
         a sensible default model for auto mode fallback situations.
 
-        Priority order:
-        1. OpenAI o3-mini (balanced performance/cost) if OpenAI API key available
-        2. Gemini 2.0 Flash (fast and efficient) if Gemini API key available
-        3. OpenAI o3 (high performance) if OpenAI API key available
-        4. Gemini 2.5 Pro (deep reasoning) if Gemini API key available
-        5. Fallback to gemini-2.5-flash-preview-05-20 (most common case)
+        Args:
+            tool_category: Optional category to influence model selection
 
         Returns:
            Model name string for fallback use
         """
+        # Import here to avoid circular import
+        from tools.models import ToolModelCategory
+
         # Check provider availability by trying to get instances
         openai_available = cls.get_provider(ProviderType.OPENAI) is not None
         gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
 
-        # Priority order: prefer balanced models first, then high-performance
+        if tool_category == ToolModelCategory.EXTENDED_REASONING:
+            # Prefer thinking-capable models for deep reasoning tools
+            if openai_available:
+                return "o3"  # O3 for deep reasoning
+            elif gemini_available:
+                return "pro"  # Gemini Pro with thinking mode
+            else:
+                # Try to find thinking-capable model from custom/openrouter
+                thinking_model = cls._find_extended_thinking_model()
+                if thinking_model:
+                    return thinking_model
+                # Fallback to pro if nothing found
+                return "gemini-2.5-pro-preview-06-05"
+
+        elif tool_category == ToolModelCategory.FAST_RESPONSE:
+            # Prefer fast, cost-efficient models
+            if openai_available:
+                return "o3-mini"  # Fast and efficient
+            elif gemini_available:
+                return "flash"  # Gemini Flash for speed
+            else:
+                # Default to flash
+                return "gemini-2.5-flash-preview-05-20"
+
+        # BALANCED or no category specified - use existing balanced logic
         if openai_available:
             return "o3-mini"  # Balanced performance/cost
         elif gemini_available:
@@ -219,6 +245,51 @@ class ModelProviderRegistry:
             # This maintains backward compatibility for tests
             return "gemini-2.5-flash-preview-05-20"
 
+    @classmethod
+    def _find_extended_thinking_model(cls) -> Optional[str]:
+        """Find a model suitable for extended reasoning from custom/openrouter providers.
+
+        Returns:
+            Model name if found, None otherwise
+        """
+        # Check custom provider first
+        custom_provider = cls.get_provider(ProviderType.CUSTOM)
+        if custom_provider:
+            # Check if it's a CustomModelProvider and has thinking models
+            try:
+                from providers.custom import CustomProvider
+
+                if isinstance(custom_provider, CustomProvider) and hasattr(custom_provider, "model_registry"):
+                    for model_name, config in custom_provider.model_registry.items():
+                        if config.get("supports_extended_thinking", False):
+                            return model_name
+            except ImportError:
+                pass
+
+        # Then check OpenRouter for high-context/powerful models
+        openrouter_provider = cls.get_provider(ProviderType.OPENROUTER)
+        if openrouter_provider:
+            # Prefer models known for deep reasoning
+            preferred_models = [
+                "anthropic/claude-3.5-sonnet",
+                "anthropic/claude-3-opus-20240229",
+                "meta-llama/llama-3.1-70b-instruct",
+                "google/gemini-pro-1.5",
+                "mistralai/mixtral-8x7b-instruct",
+            ]
+            for model in preferred_models:
+                try:
+                    if openrouter_provider.validate_model_name(model):
+                        return model
+                except Exception as e:
+                    # Log the error for debugging purposes but continue searching
+                    import logging
+
+                    logging.warning(f"Model validation for '{model}' on OpenRouter failed: {e}")
+                    continue
+
+        return None
+
     @classmethod
     def get_available_providers_with_keys(cls) -> list[ProviderType]:
         """Get list of provider types that have valid API keys.
diff --git a/tests/test_auto_mode.py b/tests/test_auto_mode.py
index 6d63301..42142f1 100644
--- a/tests/test_auto_mode.py
+++ b/tests/test_auto_mode.py
@@ -75,7 +75,7 @@ class TestAutoMode:
             model_schema = schema["properties"]["model"]
             assert "enum" in model_schema
             assert "flash" in model_schema["enum"]
-            assert "Choose the best model" in model_schema["description"]
+            assert "select the most suitable model" in model_schema["description"]
         finally:
             # Restore
@@ -134,6 +134,58 @@ class TestAutoMode:
             os.environ.pop("DEFAULT_MODEL", None)
             importlib.reload(config)
 
+    @pytest.mark.asyncio
+    async def test_unavailable_model_error_message(self):
+        """Test that unavailable model shows helpful error with available models"""
+        # Save original
+        original = os.environ.get("DEFAULT_MODEL", "")
+
+        try:
+            # Enable auto mode
+            os.environ["DEFAULT_MODEL"] = "auto"
+            import config
+
+            importlib.reload(config)
+
+            tool = AnalyzeTool()
+
+            # Mock the provider to simulate o3 not being available
+            with patch("providers.registry.ModelProviderRegistry.get_provider_for_model") as mock_provider:
+                # Mock that o3 is not available but flash/pro are
+                def mock_get_provider(model_name):
+                    if model_name in ["flash", "pro", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-06-05"]:
+                        # Return a mock provider for available models
+                        from unittest.mock import MagicMock
+
+                        return MagicMock()
+                    else:
+                        # o3 and others are not available
+                        return None
+
+                mock_provider.side_effect = mock_get_provider
+
+                # Execute with unavailable model
+                result = await tool.execute(
+                    {"files": ["/tmp/test.py"], "prompt": "Analyze this", "model": "o3"}  # This model is not available
+                )
+
+                # Should get error with helpful message
+                assert len(result) == 1
+                response = result[0].text
+                assert "error" in response
+                assert "Model 'o3' is not available" in response
+                assert "Available models:" in response
+                # Should list the available models
+                assert "flash" in response or "pro" in response
+
+        finally:
+            # Restore
+            if original:
+                os.environ["DEFAULT_MODEL"] = original
+            else:
+                os.environ.pop("DEFAULT_MODEL", None)
+            importlib.reload(config)
+
     def test_model_field_schema_generation(self):
         """Test the get_model_field_schema method"""
         from tools.base import BaseTool
@@ -173,7 +225,7 @@ class TestAutoMode:
             schema = tool.get_model_field_schema()
             assert "enum" in schema
             assert all(model in schema["enum"] for model in ["flash", "pro", "o3"])
-            assert "Choose the best model" in schema["description"]
+            assert "select the most suitable model" in schema["description"]
 
             # Test normal mode
             os.environ["DEFAULT_MODEL"] = "pro"
diff --git a/tests/test_per_tool_model_defaults.py b/tests/test_per_tool_model_defaults.py
new file mode 100644
index 0000000..9083f4e
--- /dev/null
+++ b/tests/test_per_tool_model_defaults.py
@@ -0,0 +1,417 @@
+"""
+Test per-tool model default selection functionality
+"""
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from providers.registry import ModelProviderRegistry, ProviderType
+from tools.analyze import AnalyzeTool
+from tools.base import BaseTool
+from tools.chat import ChatTool
+from tools.codereview import CodeReviewTool
+from tools.debug import DebugIssueTool
+from tools.models import ToolModelCategory
+from tools.precommit import Precommit
+from tools.thinkdeep import ThinkDeepTool
+
+
+class TestToolModelCategories:
+    """Test that each tool returns the correct model category."""
+
+    def test_thinkdeep_category(self):
+        tool = ThinkDeepTool()
+        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
+
+    def test_debug_category(self):
+        tool = DebugIssueTool()
+        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
+
+    def test_analyze_category(self):
+        tool = AnalyzeTool()
+        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
+
+    def test_precommit_category(self):
+        tool = Precommit()
+        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
+
+    def test_chat_category(self):
+        tool = ChatTool()
+        assert tool.get_model_category() == ToolModelCategory.FAST_RESPONSE
+
+    def test_codereview_category(self):
+        tool = CodeReviewTool()
+        assert tool.get_model_category() == ToolModelCategory.BALANCED
+
+    def test_base_tool_default_category(self):
+        # Test that BaseTool defaults to BALANCED
+        class TestTool(BaseTool):
+            def get_name(self):
+                return "test"
+
+            def get_description(self):
+                return "test"
+
+            def get_input_schema(self):
+                return {}
+
+            def get_system_prompt(self):
+                return "test"
+
+            def get_request_model(self):
+                return MagicMock
+
+            async def prepare_prompt(self, request):
+                return "test"
+
+        tool = TestTool()
+        assert tool.get_model_category() == ToolModelCategory.BALANCED
+
+
+class TestModelSelection:
+    """Test model selection based on tool categories."""
+
+    def test_extended_reasoning_with_openai(self):
+        """Test EXTENDED_REASONING prefers o3 when OpenAI is available."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock OpenAI available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
+            assert model == "o3"
+
+    def test_extended_reasoning_with_gemini_only(self):
+        """Test EXTENDED_REASONING prefers pro when only Gemini is available."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock only Gemini available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
+            assert model == "pro"
+
+    def test_fast_response_with_openai(self):
+        """Test FAST_RESPONSE prefers o3-mini when OpenAI is available."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock OpenAI available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+            assert model == "o3-mini"
+
+    def test_fast_response_with_gemini_only(self):
+        """Test FAST_RESPONSE prefers flash when only Gemini is available."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock only Gemini available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+            assert model == "flash"
+
+    def test_balanced_category_fallback(self):
+        """Test BALANCED category uses existing logic."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock OpenAI available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
+            assert model == "o3-mini"  # Balanced prefers o3-mini when OpenAI available
+
+    def test_no_category_uses_balanced_logic(self):
+        """Test that no category specified uses balanced logic."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock Gemini available
+            mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model()
+            assert model == "gemini-2.5-flash-preview-05-20"
+
+
+class TestCustomProviderFallback:
+    """Test fallback to custom/openrouter providers."""
+
+    @patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
+    def test_extended_reasoning_custom_fallback(self, mock_find_thinking):
+        """Test EXTENDED_REASONING falls back to custom thinking model."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # No native providers available
+            mock_get_provider.return_value = None
+            mock_find_thinking.return_value = "custom/thinking-model"
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
+            assert model == "custom/thinking-model"
+            mock_find_thinking.assert_called_once()
+
+    @patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
+    def test_extended_reasoning_final_fallback(self, mock_find_thinking):
+        """Test EXTENDED_REASONING falls back to pro when no custom found."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # No providers available
+            mock_get_provider.return_value = None
+            mock_find_thinking.return_value = None
+
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
+            assert model == "gemini-2.5-pro-preview-06-05"
+
+
+class TestAutoModeErrorMessages:
+    """Test that auto mode error messages include suggested models."""
+
+    @pytest.mark.asyncio
+    async def test_thinkdeep_auto_error_message(self):
+        """Test ThinkDeep tool suggests appropriate model in auto mode."""
+        with patch("config.IS_AUTO_MODE", True):
+            with patch("config.DEFAULT_MODEL", "auto"):
+                with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+                    # Mock Gemini available
+                    mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.GOOGLE else None
+
+                    tool = ThinkDeepTool()
+                    result = await tool.execute({"prompt": "test", "model": "auto"})
+
+                    assert len(result) == 1
+                    assert "Model parameter is required in auto mode" in result[0].text
+                    assert "Suggested model for thinkdeep: 'pro'" in result[0].text
+                    assert "(category: extended_reasoning)" in result[0].text
+
+    @pytest.mark.asyncio
+    async def test_chat_auto_error_message(self):
+        """Test Chat tool suggests appropriate model in auto mode."""
+        with patch("config.IS_AUTO_MODE", True):
+            with patch("config.DEFAULT_MODEL", "auto"):
+                with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+                    # Mock OpenAI available
+                    mock_get_provider.side_effect = lambda ptype: MagicMock() if ptype == ProviderType.OPENAI else None
+
+                    tool = ChatTool()
+                    result = await tool.execute({"prompt": "test", "model": "auto"})
+
+                    assert len(result) == 1
+                    assert "Model parameter is required in auto mode" in result[0].text
+                    assert "Suggested model for chat: 'o3-mini'" in result[0].text
+                    assert "(category: fast_response)" in result[0].text
+
+
+class TestFileContentPreparation:
+    """Test that file content preparation uses tool-specific model for capacity."""
+
+    @patch("tools.base.read_files")
+    @patch("tools.base.logger")
+    def test_auto_mode_uses_tool_category(self, mock_logger, mock_read_files):
+        """Test that auto mode uses tool-specific model for capacity estimation."""
+        mock_read_files.return_value = "file content"
+
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock provider with capabilities
+            mock_provider = MagicMock()
+            mock_provider.get_capabilities.return_value = MagicMock(context_window=1_000_000)
+            mock_get_provider.side_effect = lambda ptype: mock_provider if ptype == ProviderType.GOOGLE else None
+
+            # Create a tool and test file content preparation
+            tool = ThinkDeepTool()
+            tool._current_model_name = "auto"
+
+            # Call the method
+            tool._prepare_file_content_for_prompt(["/test/file.py"], None, "test")
+
+            # Check that it logged the correct message
+            debug_calls = [call for call in mock_logger.debug.call_args_list if "Auto mode detected" in str(call)]
+            assert len(debug_calls) > 0
+            assert "using pro for extended_reasoning tool capacity estimation" in str(debug_calls[0])
+
+
+class TestProviderHelperMethods:
+    """Test the helper methods for finding models from custom/openrouter."""
+
+    def test_find_extended_thinking_model_custom(self):
+        """Test finding thinking model from custom provider."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            from providers.custom import CustomProvider
+
+            # Mock custom provider with thinking model
+            mock_custom = MagicMock(spec=CustomProvider)
+            mock_custom.model_registry = {
+                "model1": {"supports_extended_thinking": False},
+                "model2": {"supports_extended_thinking": True},
+                "model3": {"supports_extended_thinking": False},
+            }
+            mock_get_provider.side_effect = lambda ptype: mock_custom if ptype == ProviderType.CUSTOM else None
+
+            model = ModelProviderRegistry._find_extended_thinking_model()
+            assert model == "model2"
+
+    def test_find_extended_thinking_model_openrouter(self):
+        """Test finding thinking model from openrouter."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # Mock openrouter provider
+            mock_openrouter = MagicMock()
+            mock_openrouter.validate_model_name.side_effect = lambda m: m == "anthropic/claude-3.5-sonnet"
+            mock_get_provider.side_effect = lambda ptype: mock_openrouter if ptype == ProviderType.OPENROUTER else None
+
+            model = ModelProviderRegistry._find_extended_thinking_model()
+            assert model == "anthropic/claude-3.5-sonnet"
+
+    def test_find_extended_thinking_model_none_found(self):
+        """Test when no thinking model is found."""
+        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
+            # No providers available
+            mock_get_provider.return_value = None
+
+            model = ModelProviderRegistry._find_extended_thinking_model()
+            assert model is None
+
+
+class TestEffectiveAutoMode:
+    """Test the is_effective_auto_mode method."""
+
+    def test_explicit_auto_mode(self):
+        """Test when DEFAULT_MODEL is explicitly 'auto'."""
+        with patch("config.DEFAULT_MODEL", "auto"):
+            with patch("config.IS_AUTO_MODE", True):
+                tool = ChatTool()
+                assert tool.is_effective_auto_mode() is True
+
+    def test_unavailable_model_triggers_auto_mode(self):
+        """Test when DEFAULT_MODEL is set but not available."""
+        with patch("config.DEFAULT_MODEL", "o3"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    mock_get_provider.return_value = None  # Model not available
+
+                    tool = ChatTool()
+                    assert tool.is_effective_auto_mode() is True
+
+    def test_available_model_no_auto_mode(self):
+        """Test when DEFAULT_MODEL is set and available."""
+        with patch("config.DEFAULT_MODEL", "pro"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    mock_get_provider.return_value = MagicMock()  # Model is available
+
+                    tool = ChatTool()
+                    assert tool.is_effective_auto_mode() is False
+
+
+class TestRuntimeModelSelection:
+    """Test runtime model selection behavior."""
+
+    @pytest.mark.asyncio
+    async def test_explicit_auto_in_request(self):
+        """Test when Claude explicitly passes model='auto'."""
+        with patch("config.DEFAULT_MODEL", "pro"):  # DEFAULT_MODEL is a real model
+            with patch("config.IS_AUTO_MODE", False):  # Not in auto mode
+                tool = ThinkDeepTool()
+                result = await tool.execute({"prompt": "test", "model": "auto"})
+
+                # Should require model selection even though DEFAULT_MODEL is valid
+                assert len(result) == 1
+                assert "Model parameter is required in auto mode" in result[0].text
+
+    @pytest.mark.asyncio
+    async def test_unavailable_model_in_request(self):
+        """Test when Claude passes an unavailable model."""
+        with patch("config.DEFAULT_MODEL", "pro"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    # Model is not available
+                    mock_get_provider.return_value = None
+
+                    tool = ChatTool()
+                    result = await tool.execute({"prompt": "test", "model": "gpt-5-turbo"})
+
+                    # Should require model selection
+                    assert len(result) == 1
+                    # When a specific model is requested but not available, error message is different
+                    assert "gpt-5-turbo" in result[0].text
+                    assert "is not available" in result[0].text
+                    assert "(category: fast_response)" in result[0].text
+
+
+class TestSchemaGeneration:
+    """Test schema generation with different configurations."""
+
+    def test_schema_with_explicit_auto_mode(self):
+        """Test schema when DEFAULT_MODEL='auto'."""
+        with patch("config.DEFAULT_MODEL", "auto"):
+            with patch("config.IS_AUTO_MODE", True):
+                tool = ChatTool()
+                schema = tool.get_input_schema()
+
+                # Model should be required
+                assert "model" in schema["required"]
+
+    def test_schema_with_unavailable_default_model(self):
+        """Test schema when DEFAULT_MODEL is set but unavailable."""
+        with patch("config.DEFAULT_MODEL", "o3"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    mock_get_provider.return_value = None  # Model not available
+
+                    tool = AnalyzeTool()
+                    schema = tool.get_input_schema()
+
+                    # Model should be required due to unavailable DEFAULT_MODEL
+                    assert "model" in schema["required"]
+
+    def test_schema_with_available_default_model(self):
+        """Test schema when DEFAULT_MODEL is available."""
+        with patch("config.DEFAULT_MODEL", "pro"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    mock_get_provider.return_value = MagicMock()  # Model is available
+
+                    tool = ThinkDeepTool()
+                    schema = tool.get_input_schema()
+
+                    # Model should NOT be required
+                    assert "model" not in schema["required"]
+
+
+class TestUnavailableModelFallback:
+    """Test fallback behavior when DEFAULT_MODEL is not available."""
+
+    @pytest.mark.asyncio
+    async def test_unavailable_default_model_fallback(self):
+        """Test that unavailable DEFAULT_MODEL triggers auto mode behavior."""
+        with patch("config.DEFAULT_MODEL", "o3"):  # Set DEFAULT_MODEL to a specific model
+            with patch("config.IS_AUTO_MODE", False):  # Not in auto mode
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    # Model is not available (no provider)
+                    mock_get_provider.return_value = None
+
+                    tool = ThinkDeepTool()
+                    result = await tool.execute({"prompt": "test"})  # No model specified
+
+                    # Should get auto mode error since model is unavailable
+                    assert len(result) == 1
+                    # When DEFAULT_MODEL is unavailable, the error message indicates the model is not available
+                    assert "o3" in result[0].text
+                    assert "is not available" in result[0].text
+                    # The suggested model depends on which providers are available
+                    # Just check that it suggests a model for the extended_reasoning category
+                    assert "(category: extended_reasoning)" in result[0].text
+
+    @pytest.mark.asyncio
+    async def test_available_default_model_no_fallback(self):
+        """Test that available DEFAULT_MODEL works normally."""
+        with patch("config.DEFAULT_MODEL", "pro"):
+            with patch("config.IS_AUTO_MODE", False):
+                with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
+                    # Model is available
+                    mock_provider = MagicMock()
+                    mock_provider.generate_content.return_value = MagicMock(content="Test response", metadata={})
+                    mock_get_provider.return_value = mock_provider
+
+                    # Mock the provider lookup in BaseTool.get_model_provider
+                    with patch.object(BaseTool, "get_model_provider") as mock_get_model_provider:
+                        mock_get_model_provider.return_value = mock_provider
+
+                        tool = ChatTool()
+                        result = await tool.execute({"prompt": "test"})  # No model specified
+
+                        # Should work normally, not require model parameter
+                        assert len(result) == 1
+                        output = json.loads(result[0].text)
+                        assert output["status"] == "success"
+                        assert "Test response" in output["content"]
diff --git a/tools/analyze.py b/tools/analyze.py
index bd1f597..724d462 100644
--- a/tools/analyze.py
+++ b/tools/analyze.py
@@ -2,11 +2,14 @@
 Analyze tool - General-purpose code and file analysis
 """
 
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import TEMPERATURE_ANALYTICAL
 from prompts import ANALYZE_PROMPT
 
@@ -42,8 +45,6 @@ class AnalyzeTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -95,7 +96,7 @@ class AnalyzeTool(BaseTool):
                     "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["files", "prompt"] + (["model"] if self.is_effective_auto_mode() else []),
         }
 
         return schema
@@ -106,6 +107,12 @@ class AnalyzeTool(BaseTool):
     def get_default_temperature(self) -> float:
         return TEMPERATURE_ANALYTICAL
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """Analyze requires deep understanding and reasoning"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.EXTENDED_REASONING
+
     def get_request_model(self):
         return AnalyzeRequest
 
diff --git a/tools/base.py b/tools/base.py
index 806fd12..0e1de81 100644
--- a/tools/base.py
+++ b/tools/base.py
@@ -17,11 +17,14 @@ import json
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 from mcp.types import TextContent
 from pydantic import BaseModel, Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import MCP_PROMPT_SIZE_LIMIT
 from providers import ModelProvider, ModelProviderRegistry
 from utils import check_token_limit
@@ -156,6 +159,88 @@ class BaseTool(ABC):
         """
         pass
 
+    def is_effective_auto_mode(self) -> bool:
+        """
+        Check if we're in effective auto mode for schema generation.
+
+        This determines whether the model parameter should be required in the tool schema.
+        Used at initialization time when schemas are generated.
+
+        Returns:
+            bool: True if model parameter should be required in the schema
+        """
+        from config import DEFAULT_MODEL, IS_AUTO_MODE
+        from providers.registry import ModelProviderRegistry
+
+        # Case 1: Explicit auto mode
+        if IS_AUTO_MODE:
+            return True
+
+        # Case 2: Model not available
+        if DEFAULT_MODEL.lower() != "auto":
+            provider = ModelProviderRegistry.get_provider_for_model(DEFAULT_MODEL)
+            if not provider:
+                return True
+
+        return False
+
+    def _should_require_model_selection(self, model_name: str) -> bool:
+        """
+        Check if we should require Claude to select a model at runtime.
+
+        This is called during request execution to determine if we need
+        to return an error asking Claude to provide a model parameter.
+
+        Args:
+            model_name: The model name from the request or DEFAULT_MODEL
+
+        Returns:
+            bool: True if we should require model selection
+        """
+        # Case 1: Model is explicitly "auto"
+        if model_name.lower() == "auto":
+            return True
+
+        # Case 2: Requested model is not available
+        from providers.registry import ModelProviderRegistry
+
+        provider = ModelProviderRegistry.get_provider_for_model(model_name)
+        if not provider:
+            logger = logging.getLogger(f"tools.{self.name}")
+            logger.warning(
+                f"Model '{model_name}' is not available with current API keys. " f"Requiring model selection."
+            )
+            return True
+
+        return False
+
+    def _get_available_models(self) -> list[str]:
+        """
+        Get list of models that are actually available with current API keys.
+
+        Returns:
+            List of available model names
+        """
+        from config import MODEL_CAPABILITIES_DESC
+        from providers.base import ProviderType
+        from providers.registry import ModelProviderRegistry
+
+        available_models = []
+
+        # Check each model in our capabilities list
+        for model_name in MODEL_CAPABILITIES_DESC.keys():
+            provider = ModelProviderRegistry.get_provider_for_model(model_name)
+            if provider:
+                available_models.append(model_name)
+
+        # Also check if OpenRouter is available (it accepts any model)
+        openrouter_provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
+        if openrouter_provider and not available_models:
+            # If only OpenRouter is available, suggest using any model through it
+            available_models.append("any model via OpenRouter")
+
+        return available_models if available_models else ["none - please configure API keys"]
+
     def get_model_field_schema(self) -> dict[str, Any]:
         """
         Generate the model field schema based on auto mode configuration.
@@ -168,16 +253,20 @@ class BaseTool(ABC):
         """
         import os
 
-        from config import DEFAULT_MODEL, IS_AUTO_MODE, MODEL_CAPABILITIES_DESC
+        from config import DEFAULT_MODEL, MODEL_CAPABILITIES_DESC
 
         # Check if OpenRouter is configured
         has_openrouter = bool(
             os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here"
         )
 
-        if IS_AUTO_MODE:
+        # Use the centralized effective auto mode check
+        if self.is_effective_auto_mode():
             # In auto mode, model is required and we provide detailed descriptions
-            model_desc_parts = ["Choose the best model for this task based on these capabilities:"]
+            model_desc_parts = [
+                "IMPORTANT: Use the model specified by the user if provided, OR select the most suitable model "
+                "for this specific task based on the requirements and capabilities listed below:"
+            ]
 
             for model, desc in MODEL_CAPABILITIES_DESC.items():
                 model_desc_parts.append(f"- '{model}': {desc}")
@@ -302,6 +391,21 @@ class BaseTool(ABC):
         """
         return "medium"  # Default to medium thinking for better reasoning
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """
+        Return the model category for this tool.
+
+        Model category influences which model is selected in auto mode.
+        Override to specify whether your tool needs extended reasoning,
+        fast response, or balanced capabilities.
+
+        Returns:
+            ToolModelCategory: Category that influences model selection
+        """
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.BALANCED
+
     def get_conversation_embedded_files(self, continuation_id: Optional[str]) -> list[str]:
         """
         Get list of files already embedded in conversation history.
@@ -474,11 +578,13 @@ class BaseTool(ABC):
         if model_name.lower() == "auto":
             from providers.registry import ModelProviderRegistry
 
-            # Use the preferred fallback model for capacity estimation
+            # Use tool-specific fallback model for capacity estimation
             # This properly handles different providers (OpenAI=200K, Gemini=1M)
-            fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
+            tool_category = self.get_model_category()
+            fallback_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
             logger.debug(
-                f"[FILES] {self.name}: Auto mode detected, using {fallback_model} for capacity estimation"
+                f"[FILES] {self.name}: Auto mode detected, using {fallback_model} "
+                f"for {tool_category.value} tool capacity estimation"
             )
 
         try:
@@ -898,13 +1004,39 @@ When recommending searches, be specific about what information you need and why
 
             model_name = DEFAULT_MODEL
 
-        # In auto mode, model parameter is required
-        from config import IS_AUTO_MODE
+        # Check if we need Claude to select a model
+        # This happens when:
+        # 1. The model is explicitly "auto"
+        # 2. The requested model is not available
+        if self._should_require_model_selection(model_name):
+            # Get suggested model based on tool category
+            from providers.registry import ModelProviderRegistry
+
+            tool_category = self.get_model_category()
+            suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
+
+            # Build error message based on why selection is required
+            if model_name.lower() == "auto":
+                error_message = (
+                    f"Model parameter is required in auto mode. "
+                    f"Suggested model for {self.name}: '{suggested_model}' "
+                    f"(category: {tool_category.value})"
+                )
+            else:
+                # Model was specified but not available
+                # Get list of available models
+                available_models = self._get_available_models()
+
+                error_message = (
+                    f"Model '{model_name}' is not available with current API keys. "
+                    f"Available models: {', '.join(available_models)}. "
+                    f"Suggested model for {self.name}: '{suggested_model}' "
+                    f"(category: {tool_category.value})"
+                )
 
-        if IS_AUTO_MODE and model_name.lower() == "auto":
             error_output = ToolOutput(
                 status="error",
-                content="Model parameter is required. Please specify which model to use for this task.",
+                content=error_message,
                 content_type="text",
             )
             return [TextContent(type="text", text=error_output.model_dump_json())]
diff --git a/tools/chat.py b/tools/chat.py
index 704f71f..c263a02 100644
--- a/tools/chat.py
+++ b/tools/chat.py
@@ -2,11 +2,14 @@
 Chat tool - General development chat and collaborative thinking
 """
 
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import TEMPERATURE_BALANCED
 from prompts import CHAT_PROMPT
 
@@ -44,8 +47,6 @@ class ChatTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -80,7 +81,7 @@ class ChatTool(BaseTool):
                    "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []),
         }
 
         return schema
@@ -91,6 +92,12 @@ class ChatTool(BaseTool):
     def get_default_temperature(self) -> float:
         return TEMPERATURE_BALANCED
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """Chat prioritizes fast responses and cost efficiency"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.FAST_RESPONSE
+
     def get_request_model(self):
         return ChatRequest
 
diff --git a/tools/codereview.py b/tools/codereview.py
index e6889b3..c124bc4 100644
--- a/tools/codereview.py
+++ b/tools/codereview.py
@@ -82,8 +82,6 @@ class CodeReviewTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -138,7 +136,7 @@ class CodeReviewTool(BaseTool):
                     "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["files", "prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["files", "prompt"] + (["model"] if self.is_effective_auto_mode() else []),
         }
 
         return schema
diff --git a/tools/debug.py b/tools/debug.py
index a58758e..865d4ca 100644
--- a/tools/debug.py
+++ b/tools/debug.py
@@ -2,11 +2,14 @@
 Debug Issue tool - Root cause analysis and debugging assistance
 """
 
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import TEMPERATURE_ANALYTICAL
 from prompts import DEBUG_ISSUE_PROMPT
 
@@ -50,8 +53,6 @@ class DebugIssueTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -98,7 +99,7 @@ class DebugIssueTool(BaseTool):
                     "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []),
         }
 
         return schema
@@ -109,6 +110,12 @@ class DebugIssueTool(BaseTool):
     def get_default_temperature(self) -> float:
         return TEMPERATURE_ANALYTICAL
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """Debug requires deep analysis and reasoning"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.EXTENDED_REASONING
+
     def get_request_model(self):
         return DebugIssueRequest
 
diff --git a/tools/models.py b/tools/models.py
index 81825b9..146ecd8 100644
--- a/tools/models.py
+++ b/tools/models.py
@@ -2,11 +2,20 @@
 Data models for tool responses and interactions
 """
 
+from enum import Enum
 from typing import Any, Literal, Optional
 
 from pydantic import BaseModel, Field
 
 
+class ToolModelCategory(Enum):
+    """Categories for tool model selection based on requirements."""
+
+    EXTENDED_REASONING = "extended_reasoning"  # Requires deep thinking capabilities
+    FAST_RESPONSE = "fast_response"  # Speed and cost efficiency preferred
+    BALANCED = "balanced"  # Balance of capability and performance
+
+
 class ContinuationOffer(BaseModel):
     """Offer for Claude to continue conversation when Gemini doesn't ask follow-up"""
 
diff --git a/tools/precommit.py b/tools/precommit.py
index a73a859..fdc466f 100644
--- a/tools/precommit.py
+++ b/tools/precommit.py
@@ -9,11 +9,14 @@ This provides comprehensive context for AI analysis - not a duplication bug.
 """
 
 import os
-from typing import Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from prompts.tool_prompts import PRECOMMIT_PROMPT
 from utils.file_utils import translate_file_paths, translate_path_for_environment
 from utils.git_utils import find_git_repositories, get_git_status, run_git_command
@@ -100,30 +103,83 @@ class Precommit(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
-        schema = self.get_request_model().model_json_schema()
-        # Ensure model parameter has enhanced description
-        if "properties" in schema and "model" in schema["properties"]:
-            schema["properties"]["model"] = self.get_model_field_schema()
-
-        # In auto mode, model is required
-        if IS_AUTO_MODE and "required" in schema:
-            if "model" not in schema["required"]:
-                schema["required"].append("model")
-        # Ensure use_websearch is in the schema with proper description
-        if "properties" in schema and "use_websearch" not in schema["properties"]:
-            schema["properties"]["use_websearch"] = {
-                "type": "boolean",
-                "description": "Enable web search for documentation, best practices, and current information. Particularly useful for: brainstorming sessions, architectural design discussions, exploring industry best practices, working with specific frameworks/technologies, researching solutions to complex problems, or when current documentation and community insights would enhance the analysis.",
-                "default": True,
-            }
-        # Add continuation_id parameter
-        if "properties" in schema and "continuation_id" not in schema["properties"]:
-            schema["properties"]["continuation_id"] = {
-                "type": "string",
-                "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
-            }
+        schema = {
+            "type": "object",
+            "title": "PrecommitRequest",
+            "description": "Request model for precommit tool",
+            "properties": {
+                "path": {
+                    "type": "string",
+                    "description": "Starting directory to search for git repositories (must be absolute path).",
+                },
+                "model": self.get_model_field_schema(),
+                "prompt": {
+                    "type": "string",
+                    "description": "The original user request description for the changes. Provides critical context for the review.",
+                },
+                "compare_to": {
+                    "type": "string",
+                    "description": "Optional: A git ref (branch, tag, commit hash) to compare against. If not provided, reviews local staged and unstaged changes.",
+                },
+                "include_staged": {
+                    "type": "boolean",
+                    "default": True,
+                    "description": "Include staged changes in the review. Only applies if 'compare_to' is not set.",
+                },
+                "include_unstaged": {
+                    "type": "boolean",
+                    "default": True,
+                    "description": "Include uncommitted (unstaged) changes in the review. Only applies if 'compare_to' is not set.",
+                },
+                "focus_on": {
+                    "type": "string",
+                    "description": "Specific aspects to focus on (e.g., 'logic for user authentication', 'database query efficiency').",
+                },
+                "review_type": {
+                    "type": "string",
+                    "enum": ["full", "security", "performance", "quick"],
+                    "default": "full",
+                    "description": "Type of review to perform on the changes.",
+                },
+                "severity_filter": {
+                    "type": "string",
+                    "enum": ["critical", "high", "medium", "all"],
+                    "default": "all",
+                    "description": "Minimum severity level to report on the changes.",
+                },
+                "max_depth": {
+                    "type": "integer",
+                    "default": 5,
+                    "description": "Maximum depth to search for nested git repositories to prevent excessive recursion.",
+                },
+                "temperature": {
+                    "type": "number",
+                    "description": "Temperature for the response (0.0 to 1.0). Lower values are more focused and deterministic.",
+                    "minimum": 0,
+                    "maximum": 1,
+                },
+                "thinking_mode": {
+                    "type": "string",
+                    "enum": ["minimal", "low", "medium", "high", "max"],
+                    "description": "Thinking depth mode for the assistant.",
+                },
+                "files": {
+                    "type": "array",
+                    "items": {"type": "string"},
+                    "description": "Optional files or directories to provide as context (must be absolute paths). These files are not part of the changes but provide helpful context like configs, docs, or related code.",
+                },
+                "use_websearch": {
+                    "type": "boolean",
+                    "description": "Enable web search for documentation, best practices, and current information. Particularly useful for: brainstorming sessions, architectural design discussions, exploring industry best practices, working with specific frameworks/technologies, researching solutions to complex problems, or when current documentation and community insights would enhance the analysis.",
+                    "default": True,
+                },
+                "continuation_id": {
+                    "type": "string",
+                    "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
+                },
+            },
+            "required": ["path"] + (["model"] if self.is_effective_auto_mode() else []),
+        }
 
         return schema
 
     def get_system_prompt(self) -> str:
@@ -138,6 +194,12 @@ class Precommit(BaseTool):
 
         return TEMPERATURE_ANALYTICAL
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """Precommit requires thorough analysis and reasoning"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.EXTENDED_REASONING
+
     async def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
         """Override execute to check original_request size before processing"""
         # First validate request
diff --git a/tools/thinkdeep.py b/tools/thinkdeep.py
index e2d5f86..27a4b2e 100644
--- a/tools/thinkdeep.py
+++ b/tools/thinkdeep.py
@@ -2,11 +2,14 @@
 ThinkDeep tool - Extended reasoning and problem-solving
 """
 
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from mcp.types import TextContent
 from pydantic import Field
 
+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+
 from config import TEMPERATURE_CREATIVE
 from prompts import THINKDEEP_PROMPT
 
@@ -48,8 +51,6 @@ class ThinkDeepTool(BaseTool):
         )
 
     def get_input_schema(self) -> dict[str, Any]:
-        from config import IS_AUTO_MODE
-
         schema = {
             "type": "object",
             "properties": {
@@ -93,7 +94,7 @@ class ThinkDeepTool(BaseTool):
                     "description": "Thread continuation ID for multi-turn conversations. Can be used to continue conversations across different tools. Only provide this if continuing a previous conversation thread.",
                 },
             },
-            "required": ["prompt"] + (["model"] if IS_AUTO_MODE else []),
+            "required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []),
        }
 
         return schema
@@ -110,6 +111,12 @@ class ThinkDeepTool(BaseTool):
 
         return DEFAULT_THINKING_MODE_THINKDEEP
 
+    def get_model_category(self) -> "ToolModelCategory":
+        """ThinkDeep requires extended reasoning capabilities"""
+        from tools.models import ToolModelCategory
+
+        return ToolModelCategory.EXTENDED_REASONING
+
     def get_request_model(self):
         return ThinkDeepRequest