New tool: testgen
Generates unit tests and encourages the model to auto-detect the framework and testing style from an existing sample (if available)
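For context, a minimal usage sketch inferred from the tests added below; the file paths are placeholders (the tool requires absolute paths) and the response fields mirror the assertions in test_execute_success rather than a documented contract.

# Minimal sketch, assuming the request/response shape exercised by tests/test_testgen.py.
import asyncio
import json

from tools.testgen import TestGenTool


async def demo():
    tool = TestGenTool()
    result = await tool.execute(
        {
            "files": ["/abs/path/calculator.py"],           # code to generate tests for (required)
            "prompt": "Generate tests for the calculator",  # scope of the requested tests (required)
            "test_examples": ["/abs/path/test_small.py"],   # optional style/framework reference
        }
    )
    response = json.loads(result[0].text)
    print(response["status"])   # "success", "error", or "resend_prompt" per the tests
    print(response["content"])  # generated tests plus next-step guidance


asyncio.run(demo())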
@@ -26,10 +26,11 @@ class TestServerTools:
         assert "analyze" in tool_names
         assert "chat" in tool_names
         assert "precommit" in tool_names
+        assert "testgen" in tool_names
         assert "get_version" in tool_names

-        # Should have exactly 7 tools
-        assert len(tools) == 7
+        # Should have exactly 8 tools (including testgen)
+        assert len(tools) == 8

         # Check descriptions are verbose
         for tool in tools:
tests/test_testgen.py (new file, 381 lines)
@@ -0,0 +1,381 @@
"""
Tests for TestGen tool implementation
"""

import json
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from tests.mock_helpers import create_mock_provider
from tools.testgen import TestGenRequest, TestGenTool


class TestTestGenTool:
    """Test the TestGen tool"""

    @pytest.fixture
    def tool(self):
        return TestGenTool()

    @pytest.fixture
    def temp_files(self):
        """Create temporary test files"""
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            # Create sample code files
            code_file = temp_path / "calculator.py"
            code_file.write_text(
                """
def add(a, b):
    '''Add two numbers'''
    return a + b

def divide(a, b):
    '''Divide two numbers'''
    if b == 0:
        raise ValueError("Cannot divide by zero")
    return a / b
"""
            )

            # Create sample test files (different sizes)
            small_test = temp_path / "test_small.py"
            small_test.write_text(
                """
import unittest

class TestBasic(unittest.TestCase):
    def test_simple(self):
        self.assertEqual(1 + 1, 2)
"""
            )

            large_test = temp_path / "test_large.py"
            large_test.write_text(
                """
import unittest
from unittest.mock import Mock, patch

class TestComprehensive(unittest.TestCase):
    def setUp(self):
        self.mock_data = Mock()

    def test_feature_one(self):
        # Comprehensive test with lots of setup
        result = self.process_data()
        self.assertIsNotNone(result)

    def test_feature_two(self):
        # Another comprehensive test
        with patch('some.module') as mock_module:
            mock_module.return_value = 'test'
            result = self.process_data()
            self.assertEqual(result, 'expected')

    def process_data(self):
        return "test_result"
"""
            )

            yield {
                "temp_dir": temp_dir,
                "code_file": str(code_file),
                "small_test": str(small_test),
                "large_test": str(large_test),
            }

    def test_tool_metadata(self, tool):
        """Test tool metadata"""
        assert tool.get_name() == "testgen"
        assert "COMPREHENSIVE TEST GENERATION" in tool.get_description()
        assert "BE SPECIFIC about scope" in tool.get_description()
        assert tool.get_default_temperature() == 0.2  # Analytical temperature

        # Check model category
        from tools.models import ToolModelCategory

        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING

    def test_input_schema_structure(self, tool):
        """Test input schema structure"""
        schema = tool.get_input_schema()

        # Required fields
        assert "files" in schema["properties"]
        assert "prompt" in schema["properties"]
        assert "files" in schema["required"]
        assert "prompt" in schema["required"]

        # Optional fields
        assert "test_examples" in schema["properties"]
        assert "thinking_mode" in schema["properties"]
        assert "continuation_id" in schema["properties"]

        # Should not have temperature or use_websearch
        assert "temperature" not in schema["properties"]
        assert "use_websearch" not in schema["properties"]

        # Check test_examples description
        test_examples_desc = schema["properties"]["test_examples"]["description"]
        assert "absolute paths" in test_examples_desc
        assert "smallest representative tests" in test_examples_desc

    def test_request_model_validation(self):
        """Test request model validation"""
        # Valid request
        valid_request = TestGenRequest(files=["/tmp/test.py"], prompt="Generate tests for calculator functions")
        assert valid_request.files == ["/tmp/test.py"]
        assert valid_request.prompt == "Generate tests for calculator functions"
        assert valid_request.test_examples is None

        # With test examples
        request_with_examples = TestGenRequest(
            files=["/tmp/test.py"], prompt="Generate tests", test_examples=["/tmp/test_example.py"]
        )
        assert request_with_examples.test_examples == ["/tmp/test_example.py"]

        # Invalid request (missing required fields)
        with pytest.raises(ValueError):
            TestGenRequest(files=["/tmp/test.py"])  # Missing prompt

    @pytest.mark.asyncio
    @patch("tools.base.BaseTool.get_model_provider")
    async def test_execute_success(self, mock_get_provider, tool, temp_files):
        """Test successful execution"""
        # Mock provider
        mock_provider = create_mock_provider()
        mock_provider.get_provider_type.return_value = Mock(value="google")
        mock_provider.generate_content.return_value = Mock(
            content="Generated comprehensive test suite with edge cases",
            usage={"input_tokens": 100, "output_tokens": 200},
            model_name="gemini-2.5-flash-preview-05-20",
            metadata={"finish_reason": "STOP"},
        )
        mock_get_provider.return_value = mock_provider

        result = await tool.execute(
            {"files": [temp_files["code_file"]], "prompt": "Generate comprehensive tests for the calculator functions"}
        )

        # Verify result structure
        assert len(result) == 1
        response_data = json.loads(result[0].text)
        assert response_data["status"] == "success"
        assert "Generated comprehensive test suite" in response_data["content"]

    @pytest.mark.asyncio
    @patch("tools.base.BaseTool.get_model_provider")
    async def test_execute_with_test_examples(self, mock_get_provider, tool, temp_files):
        """Test execution with test examples"""
        mock_provider = create_mock_provider()
        mock_provider.generate_content.return_value = Mock(
            content="Generated tests following the provided examples",
            usage={"input_tokens": 150, "output_tokens": 250},
            model_name="gemini-2.5-flash-preview-05-20",
            metadata={"finish_reason": "STOP"},
        )
        mock_get_provider.return_value = mock_provider

        result = await tool.execute(
            {
                "files": [temp_files["code_file"]],
                "prompt": "Generate tests following existing patterns",
                "test_examples": [temp_files["small_test"]],
            }
        )

        # Verify result
        assert len(result) == 1
        response_data = json.loads(result[0].text)
        assert response_data["status"] == "success"

    def test_process_test_examples_empty(self, tool):
        """Test processing empty test examples"""
        content, note = tool._process_test_examples([], None)
        assert content == ""
        assert note == ""

    def test_process_test_examples_budget_allocation(self, tool, temp_files):
        """Test token budget allocation for test examples"""
        with patch.object(tool, "filter_new_files") as mock_filter:
            mock_filter.return_value = [temp_files["small_test"], temp_files["large_test"]]

            with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
                mock_prepare.return_value = "Mocked test content"

                # Test with available tokens
                content, note = tool._process_test_examples(
                    [temp_files["small_test"], temp_files["large_test"]], None, available_tokens=100000
                )

                # Should allocate 25% of 100k = 25k tokens for test examples
                mock_prepare.assert_called_once()
                call_args = mock_prepare.call_args
                assert call_args[1]["max_tokens"] == 25000  # 25% of 100k

    def test_process_test_examples_size_sorting(self, tool, temp_files):
        """Test that test examples are sorted by size (smallest first)"""
        with patch.object(tool, "filter_new_files") as mock_filter:
            # Return files in random order
            mock_filter.return_value = [temp_files["large_test"], temp_files["small_test"]]

            with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
                mock_prepare.return_value = "test content"

                tool._process_test_examples(
                    [temp_files["large_test"], temp_files["small_test"]], None, available_tokens=50000
                )

                # Check that files were passed in size order (smallest first)
                call_args = mock_prepare.call_args[0]
                files_passed = call_args[0]

                # Verify smallest file comes first
                assert files_passed[0] == temp_files["small_test"]
                assert files_passed[1] == temp_files["large_test"]

    @pytest.mark.asyncio
    async def test_prepare_prompt_structure(self, tool, temp_files):
        """Test prompt preparation structure"""
        request = TestGenRequest(files=[temp_files["code_file"]], prompt="Test the calculator functions")

        with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
            mock_prepare.return_value = "mocked file content"

            prompt = await tool.prepare_prompt(request)

            # Check prompt structure
            assert "=== USER CONTEXT ===" in prompt
            assert "Test the calculator functions" in prompt
            assert "=== CODE TO TEST ===" in prompt
            assert "mocked file content" in prompt
            assert tool.get_system_prompt() in prompt

    @pytest.mark.asyncio
    async def test_prepare_prompt_with_examples(self, tool, temp_files):
        """Test prompt preparation with test examples"""
        request = TestGenRequest(
            files=[temp_files["code_file"]], prompt="Generate tests", test_examples=[temp_files["small_test"]]
        )

        with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
            mock_prepare.return_value = "mocked content"

            with patch.object(tool, "_process_test_examples") as mock_process:
                mock_process.return_value = ("test examples content", "Note: examples included")

                prompt = await tool.prepare_prompt(request)

                # Check test examples section
                assert "=== TEST EXAMPLES FOR STYLE REFERENCE ===" in prompt
                assert "test examples content" in prompt
                assert "Note: examples included" in prompt

    def test_format_response(self, tool):
        """Test response formatting"""
        request = TestGenRequest(files=["/tmp/test.py"], prompt="Generate tests")

        raw_response = "Generated test cases with edge cases"
        formatted = tool.format_response(raw_response, request)

        # Check formatting includes next steps
        assert raw_response in formatted
        assert "**Next Steps:**" in formatted
        assert "Review Generated Tests" in formatted
        assert "Setup Test Environment" in formatted

    @pytest.mark.asyncio
    async def test_error_handling_invalid_files(self, tool):
        """Test error handling for invalid file paths"""
        result = await tool.execute(
            {"files": ["relative/path.py"], "prompt": "Generate tests"}  # Invalid: not absolute
        )

        # Should return error for relative path
        response_data = json.loads(result[0].text)
        assert response_data["status"] == "error"
        assert "absolute" in response_data["content"]

    @pytest.mark.asyncio
    async def test_large_prompt_handling(self, tool):
        """Test handling of large prompts"""
        large_prompt = "x" * 60000  # Exceeds MCP_PROMPT_SIZE_LIMIT

        result = await tool.execute({"files": ["/tmp/test.py"], "prompt": large_prompt})

        # Should return resend_prompt status
        response_data = json.loads(result[0].text)
        assert response_data["status"] == "resend_prompt"
        assert "too large" in response_data["content"]

    def test_token_budget_calculation(self, tool):
        """Test token budget calculation logic"""
        # Mock model capabilities
        with patch.object(tool, "get_model_provider") as mock_get_provider:
            mock_provider = create_mock_provider(context_window=200000)
            mock_get_provider.return_value = mock_provider

            # Simulate model name being set
            tool._current_model_name = "test-model"

            with patch.object(tool, "_process_test_examples") as mock_process:
                mock_process.return_value = ("test content", "")

                with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
                    mock_prepare.return_value = "code content"

                    request = TestGenRequest(
                        files=["/tmp/test.py"], prompt="Test prompt", test_examples=["/tmp/example.py"]
                    )

                    # This should trigger token budget calculation
                    import asyncio

                    asyncio.run(tool.prepare_prompt(request))

                    # Verify test examples got 25% of 150k tokens (75% of 200k context)
                    mock_process.assert_called_once()
                    call_args = mock_process.call_args[0]
                    assert call_args[2] == 150000  # 75% of 200k context window

    @pytest.mark.asyncio
    async def test_continuation_support(self, tool, temp_files):
        """Test continuation ID support"""
        with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
            mock_prepare.return_value = "code content"

            request = TestGenRequest(
                files=[temp_files["code_file"]], prompt="Continue testing", continuation_id="test-thread-123"
            )

            await tool.prepare_prompt(request)

            # Verify continuation_id was passed to _prepare_file_content_for_prompt
            # The method should be called twice (once for code, once for test examples logic)
            assert mock_prepare.call_count >= 1

            # Check that continuation_id was passed in at least one call
            calls = mock_prepare.call_args_list
            continuation_passed = any(
                call[0][1] == "test-thread-123" for call in calls  # continuation_id is second argument
            )
            assert continuation_passed, f"continuation_id not passed. Calls: {calls}"

    def test_no_websearch_in_prompt(self, tool, temp_files):
        """Test that web search instructions are not included"""
        request = TestGenRequest(files=[temp_files["code_file"]], prompt="Generate tests")

        with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
            mock_prepare.return_value = "code content"

            import asyncio

            prompt = asyncio.run(tool.prepare_prompt(request))

            # Should not contain web search instructions
            assert "WEB SEARCH CAPABILITY" not in prompt
            assert "web search" not in prompt.lower()
@@ -284,6 +284,22 @@ class TestAbsolutePathValidation:
         assert "must be absolute" in response["content"]
         assert "code.py" in response["content"]

+    @pytest.mark.asyncio
+    async def test_testgen_tool_relative_path_rejected(self):
+        """Test that testgen tool rejects relative paths"""
+        from tools import TestGenTool
+
+        tool = TestGenTool()
+        result = await tool.execute(
+            {"files": ["src/main.py"], "prompt": "Generate tests for the functions"}  # relative path
+        )
+
+        assert len(result) == 1
+        response = json.loads(result[0].text)
+        assert response["status"] == "error"
+        assert "must be absolute" in response["content"]
+        assert "src/main.py" in response["content"]
+
     @pytest.mark.asyncio
     @patch("tools.AnalyzeTool.get_model_provider")
     async def test_analyze_tool_accepts_absolute_paths(self, mock_get_provider):