WIP major refactor and features
@@ -15,9 +15,20 @@ parent_dir = Path(__file__).resolve().parent.parent
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

# Set dummy API key for tests if not already set
# Set dummy API keys for tests if not already set
if "GEMINI_API_KEY" not in os.environ:
    os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = "dummy-key-for-tests"

# Set default model to a specific value for tests to avoid auto mode
# This prevents all tests from failing due to missing model parameter
os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash-exp"

# Force reload of config module to pick up the env var
import importlib
import config
importlib.reload(config)

# Set MCP_PROJECT_ROOT to a temporary directory for tests
# This provides a safe sandbox for file operations during testing
@@ -29,6 +40,16 @@ os.environ["MCP_PROJECT_ROOT"] = test_root
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# Register providers for all tests
from providers import ModelProviderRegistry
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.base import ProviderType

# Register providers at test startup
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)


@pytest.fixture
def project_path(tmp_path):
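A minimal sketch of how a test can resolve one of the providers registered above. The lookup method `get_provider_for_model` is assumed from the way it is patched later in this commit (`providers.registry.ModelProviderRegistry.get_provider_for_model`); it is not defined anywhere in this diff:

```python
# Sketch only: get_provider_for_model is assumed from the mocks later in
# this commit, not shown in this diff.
from providers import ModelProviderRegistry

provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp")
if provider is not None:
    print(provider.get_provider_type())  # expected: ProviderType.GOOGLE
```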
tests/mock_helpers.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Helper functions for test mocking."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
from providers.base import ModelCapabilities, ProviderType
|
||||
|
||||
def create_mock_provider(model_name="gemini-2.0-flash-exp", max_tokens=1_048_576):
|
||||
"""Create a properly configured mock provider."""
|
||||
mock_provider = Mock()
|
||||
|
||||
# Set up capabilities
|
||||
mock_capabilities = ModelCapabilities(
|
||||
provider=ProviderType.GOOGLE,
|
||||
model_name=model_name,
|
||||
friendly_name="Gemini",
|
||||
max_tokens=max_tokens,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
temperature_range=(0.0, 2.0),
|
||||
)
|
||||
|
||||
mock_provider.get_capabilities.return_value = mock_capabilities
|
||||
mock_provider.get_provider_type.return_value = ProviderType.GOOGLE
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.validate_model_name.return_value = True
|
||||
|
||||
# Set up generate_content response
|
||||
mock_response = Mock()
|
||||
mock_response.content = "Test response"
|
||||
mock_response.usage = {"input_tokens": 10, "output_tokens": 20}
|
||||
mock_response.model_name = model_name
|
||||
mock_response.friendly_name = "Gemini"
|
||||
mock_response.provider = ProviderType.GOOGLE
|
||||
mock_response.metadata = {"finish_reason": "STOP"}
|
||||
|
||||
mock_provider.generate_content.return_value = mock_response
|
||||
|
||||
return mock_provider
|
||||
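The test hunks later in this commit consume this helper in one recurring pattern; a condensed sketch, where `tool` stands for whichever tool instance a given test builds:

```python
# Condensed from the test hunks below; `tool` is whichever tool instance a
# given test exercises (ChatTool, AnalyzeTool, DebugIssueTool, ...).
from unittest.mock import patch

from tests.mock_helpers import create_mock_provider


async def run_tool_with_mock_provider(tool, arguments):
    """Run tool.execute with the provider layer stubbed out."""
    with patch.object(tool, "get_model_provider") as mock_get_provider:
        mock_get_provider.return_value = create_mock_provider()
        # The canned reply has content="Test response", so assertions can
        # focus on the tool's own formatting rather than real model output.
        return await tool.execute(arguments)
```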
tests/test_auto_mode.py (new file, 180 lines)
@@ -0,0 +1,180 @@
"""Tests for auto mode functionality"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
import importlib
|
||||
|
||||
from mcp.types import TextContent
|
||||
from tools.analyze import AnalyzeTool
|
||||
|
||||
|
||||
class TestAutoMode:
|
||||
"""Test auto mode configuration and behavior"""
|
||||
|
||||
def test_auto_mode_detection(self):
|
||||
"""Test that auto mode is detected correctly"""
|
||||
# Save original
|
||||
original = os.environ.get("DEFAULT_MODEL", "")
|
||||
|
||||
try:
|
||||
# Test auto mode
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
import config
|
||||
importlib.reload(config)
|
||||
|
||||
assert config.DEFAULT_MODEL == "auto"
|
||||
assert config.IS_AUTO_MODE is True
|
||||
|
||||
# Test non-auto mode
|
||||
os.environ["DEFAULT_MODEL"] = "pro"
|
||||
importlib.reload(config)
|
||||
|
||||
assert config.DEFAULT_MODEL == "pro"
|
||||
assert config.IS_AUTO_MODE is False
|
||||
|
||||
finally:
|
||||
# Restore
|
||||
if original:
|
||||
os.environ["DEFAULT_MODEL"] = original
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
importlib.reload(config)
|
||||
|
||||
def test_model_capabilities_descriptions(self):
|
||||
"""Test that model capabilities are properly defined"""
|
||||
from config import MODEL_CAPABILITIES_DESC
|
||||
|
||||
# Check all expected models are present
|
||||
expected_models = ["flash", "pro", "o3", "o3-mini", "gpt-4o"]
|
||||
for model in expected_models:
|
||||
assert model in MODEL_CAPABILITIES_DESC
|
||||
assert isinstance(MODEL_CAPABILITIES_DESC[model], str)
|
||||
assert len(MODEL_CAPABILITIES_DESC[model]) > 50 # Meaningful description
|
||||
|
||||
def test_tool_schema_in_auto_mode(self):
|
||||
"""Test that tool schemas require model in auto mode"""
|
||||
# Save original
|
||||
original = os.environ.get("DEFAULT_MODEL", "")
|
||||
|
||||
try:
|
||||
# Enable auto mode
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
import config
|
||||
importlib.reload(config)
|
||||
|
||||
tool = AnalyzeTool()
|
||||
schema = tool.get_input_schema()
|
||||
|
||||
# Model should be required
|
||||
assert "model" in schema["required"]
|
||||
|
||||
# Model field should have detailed descriptions
|
||||
model_schema = schema["properties"]["model"]
|
||||
assert "enum" in model_schema
|
||||
assert "flash" in model_schema["enum"]
|
||||
assert "Choose the best model" in model_schema["description"]
|
||||
|
||||
finally:
|
||||
# Restore
|
||||
if original:
|
||||
os.environ["DEFAULT_MODEL"] = original
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
importlib.reload(config)
|
||||
|
||||
def test_tool_schema_in_normal_mode(self):
|
||||
"""Test that tool schemas don't require model in normal mode"""
|
||||
# This test uses the default from conftest.py which sets non-auto mode
|
||||
tool = AnalyzeTool()
|
||||
schema = tool.get_input_schema()
|
||||
|
||||
# Model should not be required
|
||||
assert "model" not in schema["required"]
|
||||
|
||||
# Model field should have simpler description
|
||||
model_schema = schema["properties"]["model"]
|
||||
assert "enum" not in model_schema
|
||||
assert "Available:" in model_schema["description"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_requires_model_parameter(self):
|
||||
"""Test that auto mode enforces model parameter"""
|
||||
# Save original
|
||||
original = os.environ.get("DEFAULT_MODEL", "")
|
||||
|
||||
try:
|
||||
# Enable auto mode
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
import config
|
||||
importlib.reload(config)
|
||||
|
||||
tool = AnalyzeTool()
|
||||
|
||||
# Mock the provider to avoid real API calls
|
||||
with patch.object(tool, 'get_model_provider') as mock_provider:
|
||||
# Execute without model parameter
|
||||
result = await tool.execute({
|
||||
"files": ["/tmp/test.py"],
|
||||
"prompt": "Analyze this"
|
||||
})
|
||||
|
||||
# Should get error
|
||||
assert len(result) == 1
|
||||
response = result[0].text
|
||||
assert "error" in response
|
||||
assert "Model parameter is required" in response
|
||||
|
||||
finally:
|
||||
# Restore
|
||||
if original:
|
||||
os.environ["DEFAULT_MODEL"] = original
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
importlib.reload(config)
|
||||
|
||||
def test_model_field_schema_generation(self):
|
||||
"""Test the get_model_field_schema method"""
|
||||
from tools.base import BaseTool
|
||||
|
||||
# Create a minimal concrete tool for testing
|
||||
class TestTool(BaseTool):
|
||||
def get_name(self): return "test"
|
||||
def get_description(self): return "test"
|
||||
def get_input_schema(self): return {}
|
||||
def get_system_prompt(self): return ""
|
||||
def get_request_model(self): return None
|
||||
async def prepare_prompt(self, request): return ""
|
||||
|
||||
tool = TestTool()
|
||||
|
||||
# Save original
|
||||
original = os.environ.get("DEFAULT_MODEL", "")
|
||||
|
||||
try:
|
||||
# Test auto mode
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
import config
|
||||
importlib.reload(config)
|
||||
|
||||
schema = tool.get_model_field_schema()
|
||||
assert "enum" in schema
|
||||
assert all(model in schema["enum"] for model in ["flash", "pro", "o3"])
|
||||
assert "Choose the best model" in schema["description"]
|
||||
|
||||
# Test normal mode
|
||||
os.environ["DEFAULT_MODEL"] = "pro"
|
||||
importlib.reload(config)
|
||||
|
||||
schema = tool.get_model_field_schema()
|
||||
assert "enum" not in schema
|
||||
assert "Available:" in schema["description"]
|
||||
assert "'pro'" in schema["description"]
|
||||
|
||||
finally:
|
||||
# Restore
|
||||
if original:
|
||||
os.environ["DEFAULT_MODEL"] = original
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
importlib.reload(config)
|
||||
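A rough sketch of the two `model` field shapes these assertions imply. The exact wording comes from tools/base.py, which is not part of this diff, so the values below are illustrative only:

```python
# Illustrative only -- reconstructed from the assertions above, not from the
# real get_model_field_schema() implementation.
auto_mode_model_field = {
    "type": "string",
    "enum": ["flash", "pro", "o3", "o3-mini", "gpt-4o"],
    "description": "Choose the best model for this task: ...",
}

normal_mode_model_field = {
    "type": "string",
    "description": "Available: flash, pro, o3, ... Defaults to 'pro'.",
}
```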
@@ -7,6 +7,7 @@ when Gemini doesn't explicitly ask a follow-up question.
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
@@ -116,20 +117,20 @@ class TestClaudeContinuationOffers:
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Mock the model to return a response without follow-up question
|
||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(parts=[Mock(text="Analysis complete. The code looks good.")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Analysis complete. The code looks good.",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Execute tool with new conversation
|
||||
arguments = {"prompt": "Analyze this code"}
|
||||
arguments = {"prompt": "Analyze this code", "model": "flash"}
|
||||
response = await self.tool.execute(arguments)
|
||||
|
||||
# Parse response
|
||||
@@ -157,15 +158,12 @@ class TestClaudeContinuationOffers:
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Mock the model to return a response WITH follow-up question
|
||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(
|
||||
parts=[
|
||||
Mock(
|
||||
text="""Analysis complete. The code looks good.
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
# Include follow-up JSON in the content
|
||||
content_with_followup = """Analysis complete. The code looks good.
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -174,14 +172,13 @@ class TestClaudeContinuationOffers:
|
||||
"ui_hint": "Examining error handling would help ensure robustness"
|
||||
}
|
||||
```"""
|
||||
)
|
||||
]
|
||||
),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=content_with_followup,
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Execute tool
|
||||
arguments = {"prompt": "Analyze this code"}
|
||||
@@ -215,17 +212,17 @@ class TestClaudeContinuationOffers:
|
||||
mock_client.get.return_value = thread_context.model_dump_json()
|
||||
|
||||
# Mock the model
|
||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(parts=[Mock(text="Continued analysis complete.")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Continued analysis complete.",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Execute tool with continuation_id
|
||||
arguments = {"prompt": "Continue the analysis", "continuation_id": "12345678-1234-1234-1234-123456789012"}
|
||||
|
||||
@@ -4,6 +4,7 @@ Tests for dynamic context request and collaboration features
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -24,8 +25,8 @@ class TestDynamicContextRequests:
|
||||
return DebugIssueTool()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("tools.base.BaseTool.create_model")
|
||||
async def test_clarification_request_parsing(self, mock_create_model, analyze_tool):
|
||||
@patch("tools.base.BaseTool.get_model_provider")
|
||||
async def test_clarification_request_parsing(self, mock_get_provider, analyze_tool):
|
||||
"""Test that tools correctly parse clarification requests"""
|
||||
# Mock model to return a clarification request
|
||||
clarification_json = json.dumps(
|
||||
@@ -36,16 +37,21 @@ class TestDynamicContextRequests:
|
||||
}
|
||||
)
|
||||
|
||||
mock_model = Mock()
|
||||
mock_model.generate_content.return_value = Mock(
|
||||
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=clarification_json,
|
||||
usage={},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={}
|
||||
)
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await analyze_tool.execute(
|
||||
{
|
||||
"files": ["/absolute/path/src/index.js"],
|
||||
"question": "Analyze the dependencies used in this project",
|
||||
"prompt": "Analyze the dependencies used in this project",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -62,8 +68,8 @@ class TestDynamicContextRequests:
|
||||
assert clarification["files_needed"] == ["package.json", "package-lock.json"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("tools.base.BaseTool.create_model")
|
||||
async def test_normal_response_not_parsed_as_clarification(self, mock_create_model, debug_tool):
|
||||
@patch("tools.base.BaseTool.get_model_provider")
|
||||
async def test_normal_response_not_parsed_as_clarification(self, mock_get_provider, debug_tool):
|
||||
"""Test that normal responses are not mistaken for clarification requests"""
|
||||
normal_response = """
|
||||
## Summary
|
||||
@@ -75,13 +81,18 @@ class TestDynamicContextRequests:
|
||||
**Root Cause:** The module 'utils' is not imported
|
||||
"""
|
||||
|
||||
mock_model = Mock()
|
||||
mock_model.generate_content.return_value = Mock(
|
||||
candidates=[Mock(content=Mock(parts=[Mock(text=normal_response)]))]
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=normal_response,
|
||||
usage={},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={}
|
||||
)
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await debug_tool.execute({"error_description": "NameError: name 'utils' is not defined"})
|
||||
result = await debug_tool.execute({"prompt": "NameError: name 'utils' is not defined"})
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
@@ -92,18 +103,23 @@ class TestDynamicContextRequests:
|
||||
assert "Summary" in response_data["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("tools.base.BaseTool.create_model")
|
||||
async def test_malformed_clarification_request_treated_as_normal(self, mock_create_model, analyze_tool):
|
||||
@patch("tools.base.BaseTool.get_model_provider")
|
||||
async def test_malformed_clarification_request_treated_as_normal(self, mock_get_provider, analyze_tool):
|
||||
"""Test that malformed JSON clarification requests are treated as normal responses"""
|
||||
malformed_json = '{"status": "requires_clarification", "question": "Missing closing brace"'
|
||||
malformed_json = '{"status": "requires_clarification", "prompt": "Missing closing brace"'
|
||||
|
||||
mock_model = Mock()
|
||||
mock_model.generate_content.return_value = Mock(
|
||||
candidates=[Mock(content=Mock(parts=[Mock(text=malformed_json)]))]
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=malformed_json,
|
||||
usage={},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={}
|
||||
)
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "question": "What does this do?"})
|
||||
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "prompt": "What does this do?"})
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
@@ -113,8 +129,8 @@ class TestDynamicContextRequests:
|
||||
assert malformed_json in response_data["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("tools.base.BaseTool.create_model")
|
||||
async def test_clarification_with_suggested_action(self, mock_create_model, debug_tool):
|
||||
@patch("tools.base.BaseTool.get_model_provider")
|
||||
async def test_clarification_with_suggested_action(self, mock_get_provider, debug_tool):
|
||||
"""Test clarification request with suggested next action"""
|
||||
clarification_json = json.dumps(
|
||||
{
|
||||
@@ -124,7 +140,7 @@ class TestDynamicContextRequests:
|
||||
"suggested_next_action": {
|
||||
"tool": "debug",
|
||||
"args": {
|
||||
"error_description": "Connection timeout to database",
|
||||
"prompt": "Connection timeout to database",
|
||||
"files": [
|
||||
"/config/database.yml",
|
||||
"/src/db.py",
|
||||
@@ -135,15 +151,20 @@ class TestDynamicContextRequests:
|
||||
}
|
||||
)
|
||||
|
||||
mock_model = Mock()
|
||||
mock_model.generate_content.return_value = Mock(
|
||||
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=clarification_json,
|
||||
usage={},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={}
|
||||
)
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await debug_tool.execute(
|
||||
{
|
||||
"error_description": "Connection timeout to database",
|
||||
"prompt": "Connection timeout to database",
|
||||
"files": ["/absolute/logs/error.log"],
|
||||
}
|
||||
)
|
||||
@@ -187,12 +208,12 @@ class TestDynamicContextRequests:
|
||||
assert request.suggested_next_action["tool"] == "analyze"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("tools.base.BaseTool.create_model")
|
||||
async def test_error_response_format(self, mock_create_model, analyze_tool):
|
||||
@patch("tools.base.BaseTool.get_model_provider")
|
||||
async def test_error_response_format(self, mock_get_provider, analyze_tool):
|
||||
"""Test error response format"""
|
||||
mock_create_model.side_effect = Exception("API connection failed")
|
||||
mock_get_provider.side_effect = Exception("API connection failed")
|
||||
|
||||
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "question": "Analyze this"})
|
||||
result = await analyze_tool.execute({"files": ["/absolute/path/test.py"], "prompt": "Analyze this"})
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
@@ -206,8 +227,8 @@ class TestCollaborationWorkflow:
|
||||
"""Test complete collaboration workflows"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("tools.base.BaseTool.create_model")
|
||||
async def test_dependency_analysis_triggers_clarification(self, mock_create_model):
|
||||
@patch("tools.base.BaseTool.get_model_provider")
|
||||
async def test_dependency_analysis_triggers_clarification(self, mock_get_provider):
|
||||
"""Test that asking about dependencies without package files triggers clarification"""
|
||||
tool = AnalyzeTool()
|
||||
|
||||
@@ -220,17 +241,22 @@ class TestCollaborationWorkflow:
|
||||
}
|
||||
)
|
||||
|
||||
mock_model = Mock()
|
||||
mock_model.generate_content.return_value = Mock(
|
||||
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=clarification_json,
|
||||
usage={},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={}
|
||||
)
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Ask about dependencies with only source files
|
||||
result = await tool.execute(
|
||||
{
|
||||
"files": ["/absolute/path/src/index.js"],
|
||||
"question": "What npm packages and versions does this project use?",
|
||||
"prompt": "What npm packages and versions does this project use?",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -243,8 +269,8 @@ class TestCollaborationWorkflow:
|
||||
assert "package.json" in str(clarification["files_needed"]), "Should specifically request package.json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("tools.base.BaseTool.create_model")
|
||||
async def test_multi_step_collaboration(self, mock_create_model):
|
||||
@patch("tools.base.BaseTool.get_model_provider")
|
||||
async def test_multi_step_collaboration(self, mock_get_provider):
|
||||
"""Test a multi-step collaboration workflow"""
|
||||
tool = DebugIssueTool()
|
||||
|
||||
@@ -257,15 +283,20 @@ class TestCollaborationWorkflow:
|
||||
}
|
||||
)
|
||||
|
||||
mock_model = Mock()
|
||||
mock_model.generate_content.return_value = Mock(
|
||||
candidates=[Mock(content=Mock(parts=[Mock(text=clarification_json)]))]
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=clarification_json,
|
||||
usage={},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={}
|
||||
)
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result1 = await tool.execute(
|
||||
{
|
||||
"error_description": "Database connection timeout",
|
||||
"prompt": "Database connection timeout",
|
||||
"error_context": "Timeout after 30s",
|
||||
}
|
||||
)
|
||||
@@ -285,13 +316,16 @@ class TestCollaborationWorkflow:
|
||||
**Root Cause:** The config.py file shows the database host is set to 'localhost' but the database is running on a different server.
|
||||
"""
|
||||
|
||||
mock_model.generate_content.return_value = Mock(
|
||||
candidates=[Mock(content=Mock(parts=[Mock(text=final_response)]))]
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=final_response,
|
||||
usage={},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={}
|
||||
)
|
||||
|
||||
result2 = await tool.execute(
|
||||
{
|
||||
"error_description": "Database connection timeout",
|
||||
"prompt": "Database connection timeout",
|
||||
"error_context": "Timeout after 30s",
|
||||
"files": ["/absolute/path/config.py"], # Additional context provided
|
||||
}
|
||||
|
||||
@@ -31,7 +31,8 @@ class TestConfig:

    def test_model_config(self):
        """Test model configuration"""
        assert DEFAULT_MODEL == "gemini-2.5-pro-preview-06-05"
        # DEFAULT_MODEL is set in conftest.py for tests
        assert DEFAULT_MODEL == "gemini-2.0-flash-exp"
        assert MAX_CONTEXT_TOKENS == 1_000_000

    def test_temperature_defaults(self):
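The updated assertion relies on conftest.py exporting DEFAULT_MODEL before config is reloaded. A minimal sketch of the config logic these tests assume (config.py itself is not shown in this diff):

```python
# Assumed shape of config.py, inferred from the tests -- not the actual file.
import os

DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gemini-2.5-pro-preview-06-05")
IS_AUTO_MODE = DEFAULT_MODEL == "auto"
MAX_CONTEXT_TOKENS = 1_000_000
```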
tests/test_conversation_field_mapping.py (new file, 171 lines)
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Test that conversation history is correctly mapped to tool-specific fields
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
from datetime import datetime
|
||||
|
||||
from server import reconstruct_thread_context
|
||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||
from providers.base import ProviderType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_history_field_mapping():
|
||||
"""Test that enhanced prompts are mapped to prompt field for all tools"""
|
||||
|
||||
# Test data for different tools - all use 'prompt' now
|
||||
test_cases = [
|
||||
{
|
||||
"tool_name": "analyze",
|
||||
"original_value": "What does this code do?",
|
||||
},
|
||||
{
|
||||
"tool_name": "chat",
|
||||
"original_value": "Explain this concept",
|
||||
},
|
||||
{
|
||||
"tool_name": "debug",
|
||||
"original_value": "Getting undefined error",
|
||||
},
|
||||
{
|
||||
"tool_name": "codereview",
|
||||
"original_value": "Review this implementation",
|
||||
},
|
||||
{
|
||||
"tool_name": "thinkdeep",
|
||||
"original_value": "My analysis so far",
|
||||
},
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
# Create mock conversation context
|
||||
mock_context = ThreadContext(
|
||||
thread_id="test-thread-123",
|
||||
tool_name=test_case["tool_name"],
|
||||
created_at=datetime.now().isoformat(),
|
||||
last_updated_at=datetime.now().isoformat(),
|
||||
turns=[
|
||||
ConversationTurn(
|
||||
role="user",
|
||||
content="Previous user message",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
files=["/test/file1.py"],
|
||||
),
|
||||
ConversationTurn(
|
||||
role="assistant",
|
||||
content="Previous assistant response",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
),
|
||||
],
|
||||
initial_context={},
|
||||
)
|
||||
|
||||
# Mock get_thread to return our test context
|
||||
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
|
||||
with patch("utils.conversation_memory.add_turn", return_value=True):
|
||||
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
|
||||
# Mock provider registry to avoid model lookup errors
|
||||
with patch("providers.registry.ModelProviderRegistry.get_provider_for_model") as mock_get_provider:
|
||||
from providers.base import ModelCapabilities
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_capabilities.return_value = ModelCapabilities(
|
||||
provider=ProviderType.GOOGLE,
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
friendly_name="Gemini",
|
||||
max_tokens=200000,
|
||||
supports_extended_thinking=True
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
# Mock conversation history building
|
||||
mock_build.return_value = (
|
||||
"=== CONVERSATION HISTORY ===\nPrevious conversation content\n=== END HISTORY ===",
|
||||
1000 # mock token count
|
||||
)
|
||||
|
||||
# Create arguments with continuation_id
|
||||
arguments = {
|
||||
"continuation_id": "test-thread-123",
|
||||
"prompt": test_case["original_value"],
|
||||
"files": ["/test/file2.py"],
|
||||
}
|
||||
|
||||
# Call reconstruct_thread_context
|
||||
enhanced_args = await reconstruct_thread_context(arguments)
|
||||
|
||||
# Verify the enhanced prompt is in the prompt field
|
||||
assert "prompt" in enhanced_args
|
||||
enhanced_value = enhanced_args["prompt"]
|
||||
|
||||
# Should contain conversation history
|
||||
assert "=== CONVERSATION HISTORY ===" in enhanced_value
|
||||
assert "Previous conversation content" in enhanced_value
|
||||
|
||||
# Should contain the new user input
|
||||
assert "=== NEW USER INPUT ===" in enhanced_value
|
||||
assert test_case["original_value"] in enhanced_value
|
||||
|
||||
# Should have token budget
|
||||
assert "_remaining_tokens" in enhanced_args
|
||||
assert enhanced_args["_remaining_tokens"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_defaults_to_prompt():
|
||||
"""Test that unknown tools default to using 'prompt' field"""
|
||||
|
||||
mock_context = ThreadContext(
|
||||
thread_id="test-thread-456",
|
||||
tool_name="unknown_tool",
|
||||
created_at=datetime.now().isoformat(),
|
||||
last_updated_at=datetime.now().isoformat(),
|
||||
turns=[],
|
||||
initial_context={},
|
||||
)
|
||||
|
||||
with patch("utils.conversation_memory.get_thread", return_value=mock_context):
|
||||
with patch("utils.conversation_memory.add_turn", return_value=True):
|
||||
with patch("utils.conversation_memory.build_conversation_history", return_value=("History", 500)):
|
||||
arguments = {
|
||||
"continuation_id": "test-thread-456",
|
||||
"prompt": "User input",
|
||||
}
|
||||
|
||||
enhanced_args = await reconstruct_thread_context(arguments)
|
||||
|
||||
# Should default to 'prompt' field
|
||||
assert "prompt" in enhanced_args
|
||||
assert "History" in enhanced_args["prompt"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_parameter_standardization():
|
||||
"""Test that all tools use standardized 'prompt' parameter"""
|
||||
from tools.analyze import AnalyzeRequest
|
||||
from tools.debug import DebugIssueRequest
|
||||
from tools.codereview import CodeReviewRequest
|
||||
from tools.thinkdeep import ThinkDeepRequest
|
||||
from tools.precommit import PrecommitRequest
|
||||
|
||||
# Test analyze tool uses prompt
|
||||
analyze = AnalyzeRequest(files=["/test.py"], prompt="What does this do?")
|
||||
assert analyze.prompt == "What does this do?"
|
||||
|
||||
# Test debug tool uses prompt
|
||||
debug = DebugIssueRequest(prompt="Error occurred")
|
||||
assert debug.prompt == "Error occurred"
|
||||
|
||||
# Test codereview tool uses prompt
|
||||
review = CodeReviewRequest(files=["/test.py"], prompt="Review this")
|
||||
assert review.prompt == "Review this"
|
||||
|
||||
# Test thinkdeep tool uses prompt
|
||||
think = ThinkDeepRequest(prompt="My analysis")
|
||||
assert think.prompt == "My analysis"
|
||||
|
||||
# Test precommit tool uses prompt (optional)
|
||||
precommit = PrecommitRequest(path="/repo", prompt="Fix bug")
|
||||
assert precommit.prompt == "Fix bug"
|
||||
@@ -12,6 +12,7 @@ Claude had shared in earlier turns.
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
@@ -94,7 +95,7 @@ class TestConversationHistoryBugFix:
|
||||
files=["/src/auth.py", "/tests/test_auth.py"], # Files from codereview tool
|
||||
),
|
||||
],
|
||||
initial_context={"question": "Analyze authentication security"},
|
||||
initial_context={"prompt": "Analyze authentication security"},
|
||||
)
|
||||
|
||||
# Mock add_turn to return success
|
||||
@@ -103,23 +104,23 @@ class TestConversationHistoryBugFix:
|
||||
# Mock the model to capture what prompt it receives
|
||||
captured_prompt = None
|
||||
|
||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(parts=[Mock(text="Response with conversation context")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
|
||||
def capture_prompt(prompt):
|
||||
def capture_prompt(prompt, **kwargs):
|
||||
nonlocal captured_prompt
|
||||
captured_prompt = prompt
|
||||
return mock_response
|
||||
return Mock(
|
||||
content="Response with conversation context",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
|
||||
mock_model.generate_content.side_effect = capture_prompt
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_provider.generate_content.side_effect = capture_prompt
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Execute tool with continuation_id
|
||||
# In the corrected flow, server.py:reconstruct_thread_context
|
||||
@@ -163,23 +164,23 @@ class TestConversationHistoryBugFix:
|
||||
|
||||
captured_prompt = None
|
||||
|
||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(parts=[Mock(text="Response without history")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
|
||||
def capture_prompt(prompt):
|
||||
def capture_prompt(prompt, **kwargs):
|
||||
nonlocal captured_prompt
|
||||
captured_prompt = prompt
|
||||
return mock_response
|
||||
return Mock(
|
||||
content="Response without history",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
|
||||
mock_model.generate_content.side_effect = capture_prompt
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_provider.generate_content.side_effect = capture_prompt
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Execute tool with continuation_id for non-existent thread
|
||||
# In the real flow, server.py would have already handled the missing thread
|
||||
@@ -201,23 +202,23 @@ class TestConversationHistoryBugFix:
|
||||
|
||||
captured_prompt = None
|
||||
|
||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(parts=[Mock(text="New conversation response")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
|
||||
def capture_prompt(prompt):
|
||||
def capture_prompt(prompt, **kwargs):
|
||||
nonlocal captured_prompt
|
||||
captured_prompt = prompt
|
||||
return mock_response
|
||||
return Mock(
|
||||
content="New conversation response",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
|
||||
mock_model.generate_content.side_effect = capture_prompt
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_provider.generate_content.side_effect = capture_prompt
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Execute tool without continuation_id (new conversation)
|
||||
arguments = {"prompt": "Start new conversation", "files": ["/src/new_file.py"]}
|
||||
@@ -275,7 +276,7 @@ class TestConversationHistoryBugFix:
|
||||
files=["/src/auth.py", "/tests/test_auth.py"], # auth.py referenced again + new file
|
||||
),
|
||||
],
|
||||
initial_context={"question": "Analyze authentication security"},
|
||||
initial_context={"prompt": "Analyze authentication security"},
|
||||
)
|
||||
|
||||
# Mock get_thread to return our test context
|
||||
@@ -285,23 +286,23 @@ class TestConversationHistoryBugFix:
|
||||
# Mock the model to capture what prompt it receives
|
||||
captured_prompt = None
|
||||
|
||||
with patch.object(self.tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(parts=[Mock(text="Analysis of new files complete")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
|
||||
def capture_prompt(prompt):
|
||||
def capture_prompt(prompt, **kwargs):
|
||||
nonlocal captured_prompt
|
||||
captured_prompt = prompt
|
||||
return mock_response
|
||||
return Mock(
|
||||
content="Analysis of new files complete",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
|
||||
mock_model.generate_content.side_effect = capture_prompt
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_provider.generate_content.side_effect = capture_prompt
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Mock read_files to simulate file existence and capture its calls
|
||||
with patch("tools.base.read_files") as mock_read_files:
|
||||
|
||||
@@ -166,7 +166,7 @@ class TestConversationMemory:
|
||||
initial_context={},
|
||||
)
|
||||
|
||||
history, tokens = build_conversation_history(context)
|
||||
history, tokens = build_conversation_history(context, model_context=None)
|
||||
|
||||
# Test basic structure
|
||||
assert "CONVERSATION HISTORY" in history
|
||||
@@ -207,7 +207,7 @@ class TestConversationMemory:
|
||||
initial_context={},
|
||||
)
|
||||
|
||||
history, tokens = build_conversation_history(context)
|
||||
history, tokens = build_conversation_history(context, model_context=None)
|
||||
assert history == ""
|
||||
assert tokens == 0
|
||||
|
||||
@@ -374,7 +374,7 @@ class TestConversationFlow:
|
||||
initial_context={},
|
||||
)
|
||||
|
||||
history, tokens = build_conversation_history(context)
|
||||
history, tokens = build_conversation_history(context, model_context=None)
|
||||
expected_turn_text = f"Turn {test_max}/{MAX_CONVERSATION_TURNS}"
|
||||
assert expected_turn_text in history
|
||||
|
||||
@@ -763,7 +763,7 @@ class TestConversationFlow:
|
||||
)
|
||||
|
||||
# Build conversation history (should handle token limits gracefully)
|
||||
history, tokens = build_conversation_history(context)
|
||||
history, tokens = build_conversation_history(context, model_context=None)
|
||||
|
||||
# Verify the history was built successfully
|
||||
assert "=== CONVERSATION HISTORY ===" in history
|
||||
|
||||
@@ -7,6 +7,7 @@ allowing multi-turn conversations to span multiple tool types.
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
@@ -98,15 +99,12 @@ class TestCrossToolContinuation:
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Step 1: Analysis tool creates a conversation with follow-up
|
||||
with patch.object(self.analysis_tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(
|
||||
parts=[
|
||||
Mock(
|
||||
text="""Found potential security issues in authentication logic.
|
||||
with patch.object(self.analysis_tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
# Include follow-up JSON in the content
|
||||
content_with_followup = """Found potential security issues in authentication logic.
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -115,14 +113,13 @@ class TestCrossToolContinuation:
|
||||
"ui_hint": "Security review recommended"
|
||||
}
|
||||
```"""
|
||||
)
|
||||
]
|
||||
),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=content_with_followup,
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Execute analysis tool
|
||||
arguments = {"code": "function authenticate(user) { return true; }"}
|
||||
@@ -160,23 +157,17 @@ class TestCrossToolContinuation:
|
||||
mock_client.get.side_effect = mock_get_side_effect
|
||||
|
||||
# Step 3: Review tool uses the same continuation_id
|
||||
with patch.object(self.review_tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(
|
||||
parts=[
|
||||
Mock(
|
||||
text="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks."
|
||||
)
|
||||
]
|
||||
),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(self.review_tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks.",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Execute review tool with the continuation_id from analysis tool
|
||||
arguments = {
|
||||
@@ -247,7 +238,7 @@ class TestCrossToolContinuation:
|
||||
# Build conversation history
|
||||
from utils.conversation_memory import build_conversation_history
|
||||
|
||||
history, tokens = build_conversation_history(thread_context)
|
||||
history, tokens = build_conversation_history(thread_context, model_context=None)
|
||||
|
||||
# Verify tool names are included in the history
|
||||
assert "Turn 1 (Gemini using test_analysis)" in history
|
||||
@@ -286,17 +277,17 @@ class TestCrossToolContinuation:
|
||||
mock_get_thread.return_value = existing_context
|
||||
|
||||
# Mock review tool response
|
||||
with patch.object(self.review_tool, "create_model") as mock_create_model:
|
||||
mock_model = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [
|
||||
Mock(
|
||||
content=Mock(parts=[Mock(text="Security review of auth.py shows vulnerabilities")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(self.review_tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Security review of auth.py shows vulnerabilities",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Execute review tool with additional files
|
||||
arguments = {
|
||||
|
||||
@@ -11,6 +11,7 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
|
||||
import pytest
|
||||
from mcp.types import TextContent
|
||||
@@ -68,17 +69,17 @@ class TestLargePromptHandling:
|
||||
tool = ChatTool()
|
||||
|
||||
# Mock the model to avoid actual API calls
|
||||
with patch.object(tool, "create_model") as mock_create_model:
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [
|
||||
MagicMock(
|
||||
content=MagicMock(parts=[MagicMock(text="This is a test response")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = MagicMock(
|
||||
content="This is a test response",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await tool.execute({"prompt": normal_prompt})
|
||||
|
||||
@@ -93,17 +94,17 @@ class TestLargePromptHandling:
|
||||
tool = ChatTool()
|
||||
|
||||
# Mock the model
|
||||
with patch.object(tool, "create_model") as mock_create_model:
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [
|
||||
MagicMock(
|
||||
content=MagicMock(parts=[MagicMock(text="Processed large prompt")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = MagicMock(
|
||||
content="Processed large prompt",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Mock read_file_content to avoid security checks
|
||||
with patch("tools.base.read_file_content") as mock_read_file:
|
||||
@@ -123,8 +124,11 @@ class TestLargePromptHandling:
|
||||
mock_read_file.assert_called_once_with(temp_prompt_file)
|
||||
|
||||
# Verify the large content was used
|
||||
call_args = mock_model.generate_content.call_args[0][0]
|
||||
assert large_prompt in call_args
|
||||
# generate_content is called with keyword arguments
|
||||
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||
prompt_arg = call_kwargs.get("prompt")
|
||||
assert prompt_arg is not None
|
||||
assert large_prompt in prompt_arg
|
||||
|
||||
# Cleanup
|
||||
temp_dir = os.path.dirname(temp_prompt_file)
|
||||
@@ -134,7 +138,7 @@ class TestLargePromptHandling:
|
||||
async def test_thinkdeep_large_analysis(self, large_prompt):
|
||||
"""Test that thinkdeep tool detects large current_analysis."""
|
||||
tool = ThinkDeepTool()
|
||||
result = await tool.execute({"current_analysis": large_prompt})
|
||||
result = await tool.execute({"prompt": large_prompt})
|
||||
|
||||
assert len(result) == 1
|
||||
output = json.loads(result[0].text)
|
||||
@@ -148,7 +152,7 @@ class TestLargePromptHandling:
|
||||
{
|
||||
"files": ["/some/file.py"],
|
||||
"focus_on": large_prompt,
|
||||
"context": "Test code review for validation purposes",
|
||||
"prompt": "Test code review for validation purposes",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -160,7 +164,7 @@ class TestLargePromptHandling:
|
||||
async def test_review_changes_large_original_request(self, large_prompt):
|
||||
"""Test that review_changes tool detects large original_request."""
|
||||
tool = Precommit()
|
||||
result = await tool.execute({"path": "/some/path", "original_request": large_prompt})
|
||||
result = await tool.execute({"path": "/some/path", "prompt": large_prompt})
|
||||
|
||||
assert len(result) == 1
|
||||
output = json.loads(result[0].text)
|
||||
@@ -170,7 +174,7 @@ class TestLargePromptHandling:
|
||||
async def test_debug_large_error_description(self, large_prompt):
|
||||
"""Test that debug tool detects large error_description."""
|
||||
tool = DebugIssueTool()
|
||||
result = await tool.execute({"error_description": large_prompt})
|
||||
result = await tool.execute({"prompt": large_prompt})
|
||||
|
||||
assert len(result) == 1
|
||||
output = json.loads(result[0].text)
|
||||
@@ -180,7 +184,7 @@ class TestLargePromptHandling:
|
||||
async def test_debug_large_error_context(self, large_prompt, normal_prompt):
|
||||
"""Test that debug tool detects large error_context."""
|
||||
tool = DebugIssueTool()
|
||||
result = await tool.execute({"error_description": normal_prompt, "error_context": large_prompt})
|
||||
result = await tool.execute({"prompt": normal_prompt, "error_context": large_prompt})
|
||||
|
||||
assert len(result) == 1
|
||||
output = json.loads(result[0].text)
|
||||
@@ -190,7 +194,7 @@ class TestLargePromptHandling:
|
||||
async def test_analyze_large_question(self, large_prompt):
|
||||
"""Test that analyze tool detects large question."""
|
||||
tool = AnalyzeTool()
|
||||
result = await tool.execute({"files": ["/some/file.py"], "question": large_prompt})
|
||||
result = await tool.execute({"files": ["/some/file.py"], "prompt": large_prompt})
|
||||
|
||||
assert len(result) == 1
|
||||
output = json.loads(result[0].text)
|
||||
@@ -202,17 +206,17 @@ class TestLargePromptHandling:
|
||||
tool = ChatTool()
|
||||
other_file = "/some/other/file.py"
|
||||
|
||||
with patch.object(tool, "create_model") as mock_create_model:
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [
|
||||
MagicMock(
|
||||
content=MagicMock(parts=[MagicMock(text="Success")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = MagicMock(
|
||||
content="Success",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Mock the centralized file preparation method to avoid file system access
|
||||
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
|
||||
@@ -235,17 +239,17 @@ class TestLargePromptHandling:
|
||||
tool = ChatTool()
|
||||
exact_prompt = "x" * MCP_PROMPT_SIZE_LIMIT
|
||||
|
||||
with patch.object(tool, "create_model") as mock_create_model:
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [
|
||||
MagicMock(
|
||||
content=MagicMock(parts=[MagicMock(text="Success")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = MagicMock(
|
||||
content="Success",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await tool.execute({"prompt": exact_prompt})
|
||||
output = json.loads(result[0].text)
|
||||
@@ -266,17 +270,17 @@ class TestLargePromptHandling:
|
||||
"""Test empty prompt without prompt.txt file."""
|
||||
tool = ChatTool()
|
||||
|
||||
with patch.object(tool, "create_model") as mock_create_model:
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [
|
||||
MagicMock(
|
||||
content=MagicMock(parts=[MagicMock(text="Success")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = MagicMock(
|
||||
content="Success",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await tool.execute({"prompt": ""})
|
||||
output = json.loads(result[0].text)
|
||||
@@ -288,17 +292,17 @@ class TestLargePromptHandling:
|
||||
tool = ChatTool()
|
||||
bad_file = "/nonexistent/prompt.txt"
|
||||
|
||||
with patch.object(tool, "create_model") as mock_create_model:
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [
|
||||
MagicMock(
|
||||
content=MagicMock(parts=[MagicMock(text="Success")]),
|
||||
finish_reason="STOP",
|
||||
)
|
||||
]
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_create_model.return_value = mock_model
|
||||
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_provider_type.return_value = MagicMock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = MagicMock(
|
||||
content="Success",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Should continue with empty prompt when file can't be read
|
||||
result = await tool.execute({"prompt": "", "files": [bad_file]})
|
||||
|
||||
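Note: the tests above all rely on the same provider-mocking pattern. Below is a minimal, self-contained sketch of that pattern for reference; it is not part of the diff, the import path for ChatTool is assumed, and the stub values are illustrative only.

import json
from unittest.mock import MagicMock, patch

from tools import ChatTool  # assumed import path


async def provider_mocking_sketch():
    tool = ChatTool()
    # Patch the provider lookup so no real Gemini/OpenAI call is made
    with patch.object(tool, "get_model_provider") as mock_get_provider:
        mock_provider = MagicMock()
        mock_provider.get_provider_type.return_value = MagicMock(value="google")
        mock_provider.supports_thinking_mode.return_value = False
        mock_provider.generate_content.return_value = MagicMock(
            content="Success",
            usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
            model_name="gemini-2.0-flash-exp",
            metadata={"finish_reason": "STOP"},
        )
        mock_get_provider.return_value = mock_provider

        # Tools return a list of messages whose first item carries a JSON payload
        result = await tool.execute({"prompt": "Hello"})
        output = json.loads(result[0].text)
        assert output["status"] == "success"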
@@ -49,7 +49,7 @@ async def run_manual_live_tests():
result = await tool.execute(
{
"files": [temp_path],
"question": "What does this code do?",
"prompt": "What does this code do?",
"thinking_mode": "low",
}
)
@@ -64,7 +64,7 @@ async def run_manual_live_tests():
think_tool = ThinkDeepTool()
result = await think_tool.execute(
{
"current_analysis": "Testing live integration",
"prompt": "Testing live integration",
"thinking_mode": "minimal", # Fast test
}
)
@@ -86,7 +86,7 @@ async def run_manual_live_tests():
result = await analyze_tool.execute(
{
"files": [temp_path], # Only Python file, no package.json
"question": "What npm packages and their versions does this project depend on? List all dependencies.",
"prompt": "What npm packages and their versions does this project depend on? List all dependencies.",
"thinking_mode": "minimal", # Fast test
}
)
@@ -28,7 +28,7 @@ class TestPrecommitTool:
schema = tool.get_input_schema()
assert schema["type"] == "object"
assert "path" in schema["properties"]
assert "original_request" in schema["properties"]
assert "prompt" in schema["properties"]
assert "compare_to" in schema["properties"]
assert "review_type" in schema["properties"]

@@ -36,7 +36,7 @@ class TestPrecommitTool:
"""Test request model default values"""
request = PrecommitRequest(path="/some/absolute/path")
assert request.path == "/some/absolute/path"
assert request.original_request is None
assert request.prompt is None
assert request.compare_to is None
assert request.include_staged is True
assert request.include_unstaged is True
@@ -48,7 +48,7 @@ class TestPrecommitTool:
@pytest.mark.asyncio
async def test_relative_path_rejected(self, tool):
"""Test that relative paths are rejected"""
result = await tool.execute({"path": "./relative/path", "original_request": "Test"})
result = await tool.execute({"path": "./relative/path", "prompt": "Test"})
assert len(result) == 1
response = json.loads(result[0].text)
assert response["status"] == "error"
@@ -128,7 +128,7 @@ class TestPrecommitTool:

request = PrecommitRequest(
path="/absolute/repo/path",
original_request="Add hello message",
prompt="Add hello message",
review_type="security",
)
result = await tool.prepare_prompt(request)

@@ -124,7 +124,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
temp_dir, config_path = temp_repo

# Create request with files parameter
request = PrecommitRequest(path=temp_dir, files=[config_path], original_request="Test configuration changes")
request = PrecommitRequest(path=temp_dir, files=[config_path], prompt="Test configuration changes")

# Generate the prompt
prompt = await tool.prepare_prompt(request)
@@ -152,7 +152,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
# Mock conversation memory functions to use our mock redis
with patch("utils.conversation_memory.get_redis_client", return_value=mock_redis):
# First request - should embed file content
PrecommitRequest(path=temp_dir, files=[config_path], original_request="First review")
PrecommitRequest(path=temp_dir, files=[config_path], prompt="First review")

# Simulate conversation thread creation
from utils.conversation_memory import add_turn, create_thread
@@ -168,7 +168,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging

# Second request with continuation - should skip already embedded files
PrecommitRequest(
path=temp_dir, files=[config_path], continuation_id=thread_id, original_request="Follow-up review"
path=temp_dir, files=[config_path], continuation_id=thread_id, prompt="Follow-up review"
)

files_to_embed_2 = tool.filter_new_files([config_path], thread_id)
@@ -182,7 +182,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
request = PrecommitRequest(
path=temp_dir,
files=[config_path],
original_request="Validate prompt structure",
prompt="Validate prompt structure",
review_type="full",
severity_filter="high",
)
@@ -191,7 +191,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging

# Split prompt into sections
sections = {
"original_request": "## Original Request",
"prompt": "## Original Request",
"review_parameters": "## Review Parameters",
"repo_summary": "## Repository Changes Summary",
"context_files_summary": "## Context Files Summary",
@@ -207,7 +207,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
section_indices[name] = index

# Verify sections appear in logical order
assert section_indices["original_request"] < section_indices["review_parameters"]
assert section_indices["prompt"] < section_indices["review_parameters"]
assert section_indices["review_parameters"] < section_indices["repo_summary"]
assert section_indices["git_diffs"] < section_indices["additional_context"]
assert section_indices["additional_context"] < section_indices["review_instructions"]
@@ -7,6 +7,7 @@ normal-sized prompts after implementing the large prompt handling feature.

import json
from unittest.mock import MagicMock, patch
from tests.mock_helpers import create_mock_provider

import pytest

@@ -24,16 +25,16 @@ class TestPromptRegression:
@pytest.fixture
def mock_model_response(self):
"""Create a mock model response."""
from unittest.mock import Mock

def _create_response(text="Test response"):
mock_response = MagicMock()
mock_response.candidates = [
MagicMock(
content=MagicMock(parts=[MagicMock(text=text)]),
finish_reason="STOP",
)
]
return mock_response
# Return a Mock that acts like ModelResponse
return Mock(
content=text,
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
metadata={"finish_reason": "STOP"}
)

return _create_response

@@ -42,10 +43,12 @@ class TestPromptRegression:
"""Test chat tool with normal prompt."""
tool = ChatTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response("This is a helpful response about Python.")
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response("This is a helpful response about Python.")
mock_get_provider.return_value = mock_provider

result = await tool.execute({"prompt": "Explain Python decorators"})

@@ -54,18 +57,20 @@ class TestPromptRegression:
assert output["status"] == "success"
assert "helpful response about Python" in output["content"]

# Verify model was called
mock_model.generate_content.assert_called_once()
# Verify provider was called
mock_provider.generate_content.assert_called_once()

@pytest.mark.asyncio
async def test_chat_with_files(self, mock_model_response):
"""Test chat tool with files parameter."""
tool = ChatTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider

# Mock file reading through the centralized method
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
@@ -83,16 +88,18 @@ class TestPromptRegression:
"""Test thinkdeep tool with normal analysis."""
tool = ThinkDeepTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"Here's a deeper analysis with edge cases..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

result = await tool.execute(
{
"current_analysis": "I think we should use a cache for performance",
"prompt": "I think we should use a cache for performance",
"problem_context": "Building a high-traffic API",
"focus_areas": ["scalability", "reliability"],
}
@@ -109,12 +116,14 @@ class TestPromptRegression:
"""Test codereview tool with normal inputs."""
tool = CodeReviewTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"Found 3 issues: 1) Missing error handling..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

# Mock file reading
with patch("tools.base.read_files") as mock_read_files:
@@ -125,7 +134,7 @@ class TestPromptRegression:
"files": ["/path/to/code.py"],
"review_type": "security",
"focus_on": "Look for SQL injection vulnerabilities",
"context": "Test code review for validation purposes",
"prompt": "Test code review for validation purposes",
}
)

@@ -139,12 +148,14 @@ class TestPromptRegression:
"""Test review_changes tool with normal original_request."""
tool = Precommit()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"Changes look good, implementing feature as requested..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

# Mock git operations
with patch("tools.precommit.find_git_repositories") as mock_find_repos:
@@ -158,7 +169,7 @@ class TestPromptRegression:
result = await tool.execute(
{
"path": "/path/to/repo",
"original_request": "Add user authentication feature with JWT tokens",
"prompt": "Add user authentication feature with JWT tokens",
}
)
@@ -171,16 +182,18 @@ class TestPromptRegression:
"""Test debug tool with normal error description."""
tool = DebugIssueTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"Root cause: The variable is undefined. Fix: Initialize it..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

result = await tool.execute(
{
"error_description": "TypeError: Cannot read property 'name' of undefined",
"prompt": "TypeError: Cannot read property 'name' of undefined",
"error_context": "at line 42 in user.js\n console.log(user.name)",
"runtime_info": "Node.js v16.14.0",
}
@@ -197,12 +210,14 @@ class TestPromptRegression:
"""Test analyze tool with normal question."""
tool = AnalyzeTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response(
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response(
"The code follows MVC pattern with clear separation..."
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

# Mock file reading
with patch("tools.base.read_files") as mock_read_files:
@@ -211,7 +226,7 @@ class TestPromptRegression:
result = await tool.execute(
{
"files": ["/path/to/project"],
"question": "What design patterns are used in this codebase?",
"prompt": "What design patterns are used in this codebase?",
"analysis_type": "architecture",
}
)
@@ -226,10 +241,12 @@ class TestPromptRegression:
"""Test tools work with empty optional fields."""
tool = ChatTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider

# Test with no files parameter
result = await tool.execute({"prompt": "Hello"})
@@ -243,10 +260,12 @@ class TestPromptRegression:
"""Test that thinking modes are properly passed through."""
tool = ChatTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider

result = await tool.execute({"prompt": "Test", "thinking_mode": "high", "temperature": 0.8})

@@ -254,21 +273,24 @@ class TestPromptRegression:
output = json.loads(result[0].text)
assert output["status"] == "success"

# Verify create_model was called with correct parameters
mock_create_model.assert_called_once()
call_args = mock_create_model.call_args
assert call_args[0][2] == "high" # thinking_mode
assert call_args[0][1] == 0.8 # temperature
# Verify generate_content was called with correct parameters
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("temperature") == 0.8
# thinking_mode would be passed if the provider supports it
# In this test, we set supports_thinking_mode to False, so it won't be passed

@pytest.mark.asyncio
async def test_special_characters_in_prompts(self, mock_model_response):
"""Test prompts with special characters work correctly."""
tool = ChatTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider

special_prompt = 'Test with "quotes" and\nnewlines\tand tabs'
result = await tool.execute({"prompt": special_prompt})
@@ -282,10 +304,12 @@ class TestPromptRegression:
"""Test handling of various file path formats."""
tool = AnalyzeTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider

with patch("tools.base.read_files") as mock_read_files:
mock_read_files.return_value = "Content"
@@ -297,7 +321,7 @@ class TestPromptRegression:
"/Users/name/project/src/",
"/home/user/code.js",
],
"question": "Analyze these files",
"prompt": "Analyze these files",
}
)

@@ -311,10 +335,12 @@ class TestPromptRegression:
"""Test handling of unicode content in prompts."""
tool = ChatTool()

with patch.object(tool, "create_model") as mock_create_model:
mock_model = MagicMock()
mock_model.generate_content.return_value = mock_model_response()
mock_create_model.return_value = mock_model
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = mock_model_response()
mock_get_provider.return_value = mock_provider

unicode_prompt = "Explain this: 你好世界 مرحبا بالعالم"
result = await tool.execute({"prompt": unicode_prompt})
187
tests/test_providers.py
Normal file
@@ -0,0 +1,187 @@
"""Tests for the model provider abstraction system"""

import pytest
from unittest.mock import Mock, patch
import os

from providers import ModelProviderRegistry, ModelProvider, ModelResponse, ModelCapabilities
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider


class TestModelProviderRegistry:
"""Test the model provider registry"""

def setup_method(self):
"""Clear registry before each test"""
ModelProviderRegistry._providers.clear()
ModelProviderRegistry._initialized_providers.clear()

def test_register_provider(self):
"""Test registering a provider"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

assert ProviderType.GOOGLE in ModelProviderRegistry._providers
assert ModelProviderRegistry._providers[ProviderType.GOOGLE] == GeminiModelProvider

@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
def test_get_provider(self):
"""Test getting a provider instance"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)

assert provider is not None
assert isinstance(provider, GeminiModelProvider)
assert provider.api_key == "test-key"

@patch.dict(os.environ, {}, clear=True)
def test_get_provider_no_api_key(self):
"""Test getting provider without API key returns None"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

provider = ModelProviderRegistry.get_provider(ProviderType.GOOGLE)

assert provider is None

@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"})
def test_get_provider_for_model(self):
"""Test getting provider for a specific model"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp")

assert provider is not None
assert isinstance(provider, GeminiModelProvider)

def test_get_available_providers(self):
"""Test getting list of available providers"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

providers = ModelProviderRegistry.get_available_providers()

assert len(providers) == 2
assert ProviderType.GOOGLE in providers
assert ProviderType.OPENAI in providers


class TestGeminiProvider:
"""Test Gemini model provider"""

def test_provider_initialization(self):
"""Test provider initialization"""
provider = GeminiModelProvider(api_key="test-key")

assert provider.api_key == "test-key"
assert provider.get_provider_type() == ProviderType.GOOGLE

def test_get_capabilities(self):
"""Test getting model capabilities"""
provider = GeminiModelProvider(api_key="test-key")

capabilities = provider.get_capabilities("gemini-2.0-flash-exp")

assert capabilities.provider == ProviderType.GOOGLE
assert capabilities.model_name == "gemini-2.0-flash-exp"
assert capabilities.max_tokens == 1_048_576
assert not capabilities.supports_extended_thinking

def test_get_capabilities_pro_model(self):
"""Test getting capabilities for Pro model with thinking support"""
provider = GeminiModelProvider(api_key="test-key")

capabilities = provider.get_capabilities("gemini-2.5-pro-preview-06-05")

assert capabilities.supports_extended_thinking

def test_model_shorthand_resolution(self):
"""Test model shorthand resolution"""
provider = GeminiModelProvider(api_key="test-key")

assert provider.validate_model_name("flash")
assert provider.validate_model_name("pro")

capabilities = provider.get_capabilities("flash")
assert capabilities.model_name == "gemini-2.0-flash-exp"

def test_supports_thinking_mode(self):
"""Test thinking mode support detection"""
provider = GeminiModelProvider(api_key="test-key")

assert not provider.supports_thinking_mode("gemini-2.0-flash-exp")
assert provider.supports_thinking_mode("gemini-2.5-pro-preview-06-05")

@patch("google.genai.Client")
def test_generate_content(self, mock_client_class):
"""Test content generation"""
# Mock the client
mock_client = Mock()
mock_response = Mock()
mock_response.text = "Generated content"
# Mock candidates for finish_reason
mock_candidate = Mock()
mock_candidate.finish_reason = "STOP"
mock_response.candidates = [mock_candidate]
# Mock usage metadata
mock_usage = Mock()
mock_usage.prompt_token_count = 10
mock_usage.candidates_token_count = 20
mock_response.usage_metadata = mock_usage
mock_client.models.generate_content.return_value = mock_response
mock_client_class.return_value = mock_client

provider = GeminiModelProvider(api_key="test-key")

response = provider.generate_content(
prompt="Test prompt",
model_name="gemini-2.0-flash-exp",
temperature=0.7
)

assert isinstance(response, ModelResponse)
assert response.content == "Generated content"
assert response.model_name == "gemini-2.0-flash-exp"
assert response.provider == ProviderType.GOOGLE
assert response.usage["input_tokens"] == 10
assert response.usage["output_tokens"] == 20
assert response.usage["total_tokens"] == 30
class TestOpenAIProvider:
"""Test OpenAI model provider"""

def test_provider_initialization(self):
"""Test provider initialization"""
provider = OpenAIModelProvider(api_key="test-key", organization="test-org")

assert provider.api_key == "test-key"
assert provider.organization == "test-org"
assert provider.get_provider_type() == ProviderType.OPENAI

def test_get_capabilities_o3(self):
"""Test getting O3 model capabilities"""
provider = OpenAIModelProvider(api_key="test-key")

capabilities = provider.get_capabilities("o3-mini")

assert capabilities.provider == ProviderType.OPENAI
assert capabilities.model_name == "o3-mini"
assert capabilities.max_tokens == 200_000
assert not capabilities.supports_extended_thinking

def test_validate_model_names(self):
"""Test model name validation"""
provider = OpenAIModelProvider(api_key="test-key")

assert provider.validate_model_name("o3-mini")
assert provider.validate_model_name("gpt-4o")
assert not provider.validate_model_name("invalid-model")

def test_no_thinking_mode_support(self):
"""Test that no OpenAI models support thinking mode"""
provider = OpenAIModelProvider(api_key="test-key")

assert not provider.supports_thinking_mode("o3-mini")
assert not provider.supports_thinking_mode("gpt-4o")
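For orientation, the registry these tests exercise is used roughly as sketched below. This is not part of the diff: it assumes GEMINI_API_KEY is set in the environment and would issue a real request if run.

from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider

# Register the concrete providers once at startup
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

# The registry resolves a provider from the model name; it returns None when the
# corresponding API key is missing (see test_get_provider_no_api_key above).
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp")
if provider is not None:
    response = provider.generate_content(
        prompt="Say hello",
        model_name="gemini-2.0-flash-exp",
        temperature=0.7,
    )
    print(response.content, response.usage)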
@@ -3,6 +3,7 @@ Tests for the main server functionality
"""

from unittest.mock import Mock, patch
from tests.mock_helpers import create_mock_provider

import pytest

@@ -42,31 +43,36 @@ class TestServerTools:
assert "Unknown tool: unknown_tool" in result[0].text

@pytest.mark.asyncio
async def test_handle_chat(self):
@patch("tools.base.BaseTool.get_model_provider")
async def test_handle_chat(self, mock_get_provider):
"""Test chat functionality"""
# Set test environment
import os

os.environ["PYTEST_CURRENT_TEST"] = "test"

# Create a mock for the model
with patch("tools.base.BaseTool.create_model") as mock_create:
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Chat response")]))]
)
mock_create.return_value = mock_model
# Create a mock for the provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Chat response",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_get_provider.return_value = mock_provider

result = await handle_call_tool("chat", {"prompt": "Hello Gemini"})
result = await handle_call_tool("chat", {"prompt": "Hello Gemini"})

assert len(result) == 1
# Parse JSON response
import json
assert len(result) == 1
# Parse JSON response
import json

response_data = json.loads(result[0].text)
assert response_data["status"] == "success"
assert "Chat response" in response_data["content"]
assert "Claude's Turn" in response_data["content"]
response_data = json.loads(result[0].text)
assert response_data["status"] == "success"
assert "Chat response" in response_data["content"]
assert "Claude's Turn" in response_data["content"]

@pytest.mark.asyncio
async def test_handle_get_version(self):
@@ -3,6 +3,7 @@ Tests for thinking_mode functionality across all tools
"""

from unittest.mock import Mock, patch
from tests.mock_helpers import create_mock_provider

import pytest

@@ -37,28 +38,35 @@ class TestThinkingModes:
), f"{tool.__class__.__name__} should default to {expected_default}"

@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_minimal(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_minimal(self, mock_get_provider):
"""Test minimal thinking mode"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Minimal thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Minimal thinking response",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

tool = AnalyzeTool()
result = await tool.execute(
{
"files": ["/absolute/path/test.py"],
"question": "What is this?",
"prompt": "What is this?",
"thinking_mode": "minimal",
}
)

# Verify create_model was called with correct thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "minimal" # thinking_mode parameter
mock_get_provider.assert_called_once()
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "minimal" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None) # thinking_mode parameter

# Parse JSON response
import json
@@ -68,102 +76,130 @@ class TestThinkingModes:
assert response_data["content"].startswith("Analysis:")

@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_low(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_low(self, mock_get_provider):
"""Test low thinking mode"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Low thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Low thinking response",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

tool = CodeReviewTool()
result = await tool.execute(
{
"files": ["/absolute/path/test.py"],
"thinking_mode": "low",
"context": "Test code review for validation purposes",
"prompt": "Test code review for validation purposes",
}
)

# Verify create_model was called with correct thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "low"
mock_get_provider.assert_called_once()
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "low" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)

assert "Code Review" in result[0].text
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_medium(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_medium(self, mock_get_provider):
"""Test medium thinking mode (default for most tools)"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Medium thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Medium thinking response",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

tool = DebugIssueTool()
result = await tool.execute(
{
"error_description": "Test error",
"prompt": "Test error",
# Not specifying thinking_mode, should use default (medium)
}
)

# Verify create_model was called with default thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "medium"
mock_get_provider.assert_called_once()
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "medium" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)

assert "Debug Analysis" in result[0].text

@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_high(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_high(self, mock_get_provider):
"""Test high thinking mode"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="High thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="High thinking response",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

tool = AnalyzeTool()
await tool.execute(
{
"files": ["/absolute/path/complex.py"],
"question": "Analyze architecture",
"prompt": "Analyze architecture",
"thinking_mode": "high",
}
)

# Verify create_model was called with correct thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "high"
mock_get_provider.assert_called_once()
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "high" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)
@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_thinking_mode_max(self, mock_create_model):
@patch("tools.base.BaseTool.get_model_provider")
async def test_thinking_mode_max(self, mock_get_provider):
"""Test max thinking mode (default for thinkdeep)"""
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Max thinking response")]))]
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Max thinking response",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

tool = ThinkDeepTool()
result = await tool.execute(
{
"current_analysis": "Initial analysis",
"prompt": "Initial analysis",
# Not specifying thinking_mode, should use default (high)
}
)

# Verify create_model was called with default thinking_mode
mock_create_model.assert_called_once()
args = mock_create_model.call_args[0]
assert args[2] == "high"
mock_get_provider.assert_called_once()
# Verify generate_content was called with thinking_mode
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
assert call_kwargs.get("thinking_mode") == "high" or (not mock_provider.supports_thinking_mode.return_value and call_kwargs.get("thinking_mode") is None)

assert "Extended Analysis by Gemini" in result[0].text
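The compound assertions above read the keyword arguments recorded by the provider mock. A small helper in this style would express the same check more readably; it is a sketch only, the helper name is invented, and the condition mirrors the asserts in the diff exactly.

def assert_thinking_mode(mock_provider, expected):
    """Assert the thinking_mode forwarded to generate_content, tolerating providers without support."""
    call_kwargs = mock_provider.generate_content.call_args[1]
    assert call_kwargs.get("thinking_mode") == expected or (
        not mock_provider.supports_thinking_mode.return_value
        and call_kwargs.get("thinking_mode") is None
    )

# Usage inside a test, after tool.execute(...) has run:
# assert_thinking_mode(mock_provider, "minimal")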
@@ -4,6 +4,7 @@ Tests for individual tool implementations

import json
from unittest.mock import Mock, patch
from tests.mock_helpers import create_mock_provider

import pytest

@@ -24,23 +25,28 @@ class TestThinkDeepTool:
assert tool.get_default_temperature() == 0.7

schema = tool.get_input_schema()
assert "current_analysis" in schema["properties"]
assert schema["required"] == ["current_analysis"]
assert "prompt" in schema["properties"]
assert schema["required"] == ["prompt"]

@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_execute_success(self, mock_create_model, tool):
@patch("tools.base.BaseTool.get_model_provider")
async def test_execute_success(self, mock_get_provider, tool):
"""Test successful execution"""
# Mock model
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Extended analysis")]))]
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Extended analysis",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

result = await tool.execute(
{
"current_analysis": "Initial analysis",
"prompt": "Initial analysis",
"problem_context": "Building a cache",
"focus_areas": ["performance", "scalability"],
}
@@ -69,30 +75,35 @@ class TestCodeReviewTool:

schema = tool.get_input_schema()
assert "files" in schema["properties"]
assert "context" in schema["properties"]
assert schema["required"] == ["files", "context"]
assert "prompt" in schema["properties"]
assert schema["required"] == ["files", "prompt"]

@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_execute_with_review_type(self, mock_create_model, tool, tmp_path):
@patch("tools.base.BaseTool.get_model_provider")
async def test_execute_with_review_type(self, mock_get_provider, tool, tmp_path):
"""Test execution with specific review type"""
# Create test file
test_file = tmp_path / "test.py"
test_file.write_text("def insecure(): pass", encoding="utf-8")

# Mock model
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Security issues found")]))]
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Security issues found",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

result = await tool.execute(
{
"files": [str(test_file)],
"review_type": "security",
"focus_on": "authentication",
"context": "Test code review for validation purposes",
"prompt": "Test code review for validation purposes",
}
)

@@ -116,23 +127,28 @@ class TestDebugIssueTool:
assert tool.get_default_temperature() == 0.2

schema = tool.get_input_schema()
assert "error_description" in schema["properties"]
assert schema["required"] == ["error_description"]
assert "prompt" in schema["properties"]
assert schema["required"] == ["prompt"]

@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_execute_with_context(self, mock_create_model, tool):
@patch("tools.base.BaseTool.get_model_provider")
async def test_execute_with_context(self, mock_get_provider, tool):
"""Test execution with error context"""
# Mock model
mock_model = Mock()
mock_model.generate_content.return_value = Mock(
candidates=[Mock(content=Mock(parts=[Mock(text="Root cause: race condition")]))]
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Root cause: race condition",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_create_model.return_value = mock_model
mock_get_provider.return_value = mock_provider

result = await tool.execute(
{
"error_description": "Test fails intermittently",
"prompt": "Test fails intermittently",
"error_context": "AssertionError in test_async",
"previous_attempts": "Added sleep, still fails",
}
@@ -158,30 +174,33 @@ class TestAnalyzeTool:

schema = tool.get_input_schema()
assert "files" in schema["properties"]
assert "question" in schema["properties"]
assert set(schema["required"]) == {"files", "question"}
assert "prompt" in schema["properties"]
assert set(schema["required"]) == {"files", "prompt"}

@pytest.mark.asyncio
@patch("tools.base.BaseTool.create_model")
async def test_execute_with_analysis_type(self, mock_model, tool, tmp_path):
@patch("tools.base.BaseTool.get_model_provider")
async def test_execute_with_analysis_type(self, mock_get_provider, tool, tmp_path):
"""Test execution with specific analysis type"""
# Create test file
test_file = tmp_path / "module.py"
test_file.write_text("class Service: pass", encoding="utf-8")

# Mock response
mock_response = Mock()
mock_response.candidates = [Mock()]
mock_response.candidates[0].content.parts = [Mock(text="Architecture analysis")]

mock_instance = Mock()
mock_instance.generate_content.return_value = mock_response
mock_model.return_value = mock_instance
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Architecture analysis",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_get_provider.return_value = mock_provider

result = await tool.execute(
{
"files": [str(test_file)],
"question": "What's the structure?",
"prompt": "What's the structure?",
"analysis_type": "architecture",
"output_format": "summary",
}
@@ -203,7 +222,7 @@ class TestAbsolutePathValidation:
result = await tool.execute(
{
"files": ["./relative/path.py", "/absolute/path.py"],
"question": "What does this do?",
"prompt": "What does this do?",
}
)

@@ -221,7 +240,7 @@ class TestAbsolutePathValidation:
{
"files": ["../parent/file.py"],
"review_type": "full",
"context": "Test code review for validation purposes",
"prompt": "Test code review for validation purposes",
}
)

@@ -237,7 +256,7 @@ class TestAbsolutePathValidation:
tool = DebugIssueTool()
result = await tool.execute(
{
"error_description": "Something broke",
"prompt": "Something broke",
"files": ["src/main.py"], # relative path
}
)
@@ -252,7 +271,7 @@ class TestAbsolutePathValidation:
async def test_thinkdeep_tool_relative_path_rejected(self):
"""Test that thinkdeep tool rejects relative paths"""
tool = ThinkDeepTool()
result = await tool.execute({"current_analysis": "My analysis", "files": ["./local/file.py"]})
result = await tool.execute({"prompt": "My analysis", "files": ["./local/file.py"]})

assert len(result) == 1
response = json.loads(result[0].text)
@@ -278,21 +297,24 @@ class TestAbsolutePathValidation:
assert "code.py" in response["content"]

@pytest.mark.asyncio
@patch("tools.AnalyzeTool.create_model")
async def test_analyze_tool_accepts_absolute_paths(self, mock_model):
@patch("tools.AnalyzeTool.get_model_provider")
async def test_analyze_tool_accepts_absolute_paths(self, mock_get_provider):
"""Test that analyze tool accepts absolute paths"""
tool = AnalyzeTool()

# Mock the model response
mock_response = Mock()
mock_response.candidates = [Mock()]
mock_response.candidates[0].content.parts = [Mock(text="Analysis complete")]
# Mock provider
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Analysis complete",
usage={},
model_name="gemini-2.0-flash-exp",
metadata={}
)
mock_get_provider.return_value = mock_provider

mock_instance = Mock()
mock_instance.generate_content.return_value = mock_response
mock_model.return_value = mock_instance

result = await tool.execute({"files": ["/absolute/path/file.py"], "question": "What does this do?"})
result = await tool.execute({"files": ["/absolute/path/file.py"], "prompt": "What does this do?"})

assert len(result) == 1
response = json.loads(result[0].text)