New tool: testgen

Generates unit tests and encourages the model to auto-detect the test framework and testing style from existing sample tests (if available)
Fahad
2025-06-14 15:41:47 +04:00
parent 7d33aafcab
commit 4086306c58
14 changed files with 1118 additions and 9 deletions
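For reference, the request shape exercised by the new tests is roughly the following (a minimal sketch; the absolute paths are hypothetical, and test_examples is optional):

from tools.testgen import TestGenRequest, TestGenTool

# Hypothetical paths; "files" and "prompt" are required, "test_examples" is optional
request = TestGenRequest(
    files=["/abs/path/calculator.py"],
    prompt="Generate comprehensive tests for the calculator functions",
    test_examples=["/abs/path/test_small.py"],  # smallest representative tests act as a style reference
)

# The tool's execute() entry point takes the same fields as a plain dict, e.g.
# await TestGenTool().execute({"files": [...], "prompt": "..."})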


@@ -26,10 +26,11 @@ class TestServerTools:
         assert "analyze" in tool_names
         assert "chat" in tool_names
         assert "precommit" in tool_names
+        assert "testgen" in tool_names
         assert "get_version" in tool_names
-        # Should have exactly 7 tools
-        assert len(tools) == 7
+        # Should have exactly 8 tools (including testgen)
+        assert len(tools) == 8
         # Check descriptions are verbose
         for tool in tools:

tests/test_testgen.py (new file, 381 additions)

@@ -0,0 +1,381 @@
"""
Tests for TestGen tool implementation
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch
import pytest
from tests.mock_helpers import create_mock_provider
from tools.testgen import TestGenRequest, TestGenTool
class TestTestGenTool:
"""Test the TestGen tool"""
@pytest.fixture
def tool(self):
return TestGenTool()
    @pytest.fixture
    def temp_files(self):
        """Create temporary test files"""
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            # Create sample code files
            code_file = temp_path / "calculator.py"
            code_file.write_text(
                """
def add(a, b):
    '''Add two numbers'''
    return a + b

def divide(a, b):
    '''Divide two numbers'''
    if b == 0:
        raise ValueError("Cannot divide by zero")
    return a / b
"""
            )

            # Create sample test files (different sizes)
            small_test = temp_path / "test_small.py"
            small_test.write_text(
                """
import unittest

class TestBasic(unittest.TestCase):
    def test_simple(self):
        self.assertEqual(1 + 1, 2)
"""
            )

            large_test = temp_path / "test_large.py"
            large_test.write_text(
                """
import unittest
from unittest.mock import Mock, patch

class TestComprehensive(unittest.TestCase):
    def setUp(self):
        self.mock_data = Mock()

    def test_feature_one(self):
        # Comprehensive test with lots of setup
        result = self.process_data()
        self.assertIsNotNone(result)

    def test_feature_two(self):
        # Another comprehensive test
        with patch('some.module') as mock_module:
            mock_module.return_value = 'test'
            result = self.process_data()
            self.assertEqual(result, 'expected')

    def process_data(self):
        return "test_result"
"""
            )

            yield {
                "temp_dir": temp_dir,
                "code_file": str(code_file),
                "small_test": str(small_test),
                "large_test": str(large_test),
            }
    def test_tool_metadata(self, tool):
        """Test tool metadata"""
        assert tool.get_name() == "testgen"
        assert "COMPREHENSIVE TEST GENERATION" in tool.get_description()
        assert "BE SPECIFIC about scope" in tool.get_description()
        assert tool.get_default_temperature() == 0.2  # Analytical temperature

        # Check model category
        from tools.models import ToolModelCategory

        assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING

    def test_input_schema_structure(self, tool):
        """Test input schema structure"""
        schema = tool.get_input_schema()

        # Required fields
        assert "files" in schema["properties"]
        assert "prompt" in schema["properties"]
        assert "files" in schema["required"]
        assert "prompt" in schema["required"]

        # Optional fields
        assert "test_examples" in schema["properties"]
        assert "thinking_mode" in schema["properties"]
        assert "continuation_id" in schema["properties"]

        # Should not have temperature or use_websearch
        assert "temperature" not in schema["properties"]
        assert "use_websearch" not in schema["properties"]

        # Check test_examples description
        test_examples_desc = schema["properties"]["test_examples"]["description"]
        assert "absolute paths" in test_examples_desc
        assert "smallest representative tests" in test_examples_desc

    def test_request_model_validation(self):
        """Test request model validation"""
        # Valid request
        valid_request = TestGenRequest(files=["/tmp/test.py"], prompt="Generate tests for calculator functions")
        assert valid_request.files == ["/tmp/test.py"]
        assert valid_request.prompt == "Generate tests for calculator functions"
        assert valid_request.test_examples is None

        # With test examples
        request_with_examples = TestGenRequest(
            files=["/tmp/test.py"], prompt="Generate tests", test_examples=["/tmp/test_example.py"]
        )
        assert request_with_examples.test_examples == ["/tmp/test_example.py"]

        # Invalid request (missing required fields)
        with pytest.raises(ValueError):
            TestGenRequest(files=["/tmp/test.py"])  # Missing prompt
    @pytest.mark.asyncio
    @patch("tools.base.BaseTool.get_model_provider")
    async def test_execute_success(self, mock_get_provider, tool, temp_files):
        """Test successful execution"""
        # Mock provider
        mock_provider = create_mock_provider()
        mock_provider.get_provider_type.return_value = Mock(value="google")
        mock_provider.generate_content.return_value = Mock(
            content="Generated comprehensive test suite with edge cases",
            usage={"input_tokens": 100, "output_tokens": 200},
            model_name="gemini-2.5-flash-preview-05-20",
            metadata={"finish_reason": "STOP"},
        )
        mock_get_provider.return_value = mock_provider

        result = await tool.execute(
            {"files": [temp_files["code_file"]], "prompt": "Generate comprehensive tests for the calculator functions"}
        )

        # Verify result structure
        assert len(result) == 1
        response_data = json.loads(result[0].text)
        assert response_data["status"] == "success"
        assert "Generated comprehensive test suite" in response_data["content"]

    @pytest.mark.asyncio
    @patch("tools.base.BaseTool.get_model_provider")
    async def test_execute_with_test_examples(self, mock_get_provider, tool, temp_files):
        """Test execution with test examples"""
        mock_provider = create_mock_provider()
        mock_provider.generate_content.return_value = Mock(
            content="Generated tests following the provided examples",
            usage={"input_tokens": 150, "output_tokens": 250},
            model_name="gemini-2.5-flash-preview-05-20",
            metadata={"finish_reason": "STOP"},
        )
        mock_get_provider.return_value = mock_provider

        result = await tool.execute(
            {
                "files": [temp_files["code_file"]],
                "prompt": "Generate tests following existing patterns",
                "test_examples": [temp_files["small_test"]],
            }
        )

        # Verify result
        assert len(result) == 1
        response_data = json.loads(result[0].text)
        assert response_data["status"] == "success"
    def test_process_test_examples_empty(self, tool):
        """Test processing empty test examples"""
        content, note = tool._process_test_examples([], None)
        assert content == ""
        assert note == ""

    def test_process_test_examples_budget_allocation(self, tool, temp_files):
        """Test token budget allocation for test examples"""
        with patch.object(tool, "filter_new_files") as mock_filter:
            mock_filter.return_value = [temp_files["small_test"], temp_files["large_test"]]

            with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
                mock_prepare.return_value = "Mocked test content"

                # Test with available tokens
                content, note = tool._process_test_examples(
                    [temp_files["small_test"], temp_files["large_test"]], None, available_tokens=100000
                )

                # Should allocate 25% of 100k = 25k tokens for test examples
                mock_prepare.assert_called_once()
                call_args = mock_prepare.call_args
                assert call_args[1]["max_tokens"] == 25000  # 25% of 100k

    def test_process_test_examples_size_sorting(self, tool, temp_files):
        """Test that test examples are sorted by size (smallest first)"""
        with patch.object(tool, "filter_new_files") as mock_filter:
            # Return files in random order
            mock_filter.return_value = [temp_files["large_test"], temp_files["small_test"]]

            with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
                mock_prepare.return_value = "test content"

                tool._process_test_examples(
                    [temp_files["large_test"], temp_files["small_test"]], None, available_tokens=50000
                )

                # Check that files were passed in size order (smallest first)
                call_args = mock_prepare.call_args[0]
                files_passed = call_args[0]

                # Verify smallest file comes first
                assert files_passed[0] == temp_files["small_test"]
                assert files_passed[1] == temp_files["large_test"]
    @pytest.mark.asyncio
    async def test_prepare_prompt_structure(self, tool, temp_files):
        """Test prompt preparation structure"""
        request = TestGenRequest(files=[temp_files["code_file"]], prompt="Test the calculator functions")

        with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
            mock_prepare.return_value = "mocked file content"

            prompt = await tool.prepare_prompt(request)

            # Check prompt structure
            assert "=== USER CONTEXT ===" in prompt
            assert "Test the calculator functions" in prompt
            assert "=== CODE TO TEST ===" in prompt
            assert "mocked file content" in prompt
            assert tool.get_system_prompt() in prompt

    @pytest.mark.asyncio
    async def test_prepare_prompt_with_examples(self, tool, temp_files):
        """Test prompt preparation with test examples"""
        request = TestGenRequest(
            files=[temp_files["code_file"]], prompt="Generate tests", test_examples=[temp_files["small_test"]]
        )

        with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
            mock_prepare.return_value = "mocked content"

            with patch.object(tool, "_process_test_examples") as mock_process:
                mock_process.return_value = ("test examples content", "Note: examples included")

                prompt = await tool.prepare_prompt(request)

                # Check test examples section
                assert "=== TEST EXAMPLES FOR STYLE REFERENCE ===" in prompt
                assert "test examples content" in prompt
                assert "Note: examples included" in prompt

    def test_format_response(self, tool):
        """Test response formatting"""
        request = TestGenRequest(files=["/tmp/test.py"], prompt="Generate tests")
        raw_response = "Generated test cases with edge cases"

        formatted = tool.format_response(raw_response, request)

        # Check formatting includes next steps
        assert raw_response in formatted
        assert "**Next Steps:**" in formatted
        assert "Review Generated Tests" in formatted
        assert "Setup Test Environment" in formatted
    @pytest.mark.asyncio
    async def test_error_handling_invalid_files(self, tool):
        """Test error handling for invalid file paths"""
        result = await tool.execute(
            {"files": ["relative/path.py"], "prompt": "Generate tests"}  # Invalid: not absolute
        )

        # Should return error for relative path
        response_data = json.loads(result[0].text)
        assert response_data["status"] == "error"
        assert "absolute" in response_data["content"]

    @pytest.mark.asyncio
    async def test_large_prompt_handling(self, tool):
        """Test handling of large prompts"""
        large_prompt = "x" * 60000  # Exceeds MCP_PROMPT_SIZE_LIMIT

        result = await tool.execute({"files": ["/tmp/test.py"], "prompt": large_prompt})

        # Should return resend_prompt status
        response_data = json.loads(result[0].text)
        assert response_data["status"] == "resend_prompt"
        assert "too large" in response_data["content"]

    def test_token_budget_calculation(self, tool):
        """Test token budget calculation logic"""
        # Mock model capabilities
        with patch.object(tool, "get_model_provider") as mock_get_provider:
            mock_provider = create_mock_provider(context_window=200000)
            mock_get_provider.return_value = mock_provider

            # Simulate model name being set
            tool._current_model_name = "test-model"

            with patch.object(tool, "_process_test_examples") as mock_process:
                mock_process.return_value = ("test content", "")

                with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
                    mock_prepare.return_value = "code content"

                    request = TestGenRequest(
                        files=["/tmp/test.py"], prompt="Test prompt", test_examples=["/tmp/example.py"]
                    )

                    # This should trigger token budget calculation
                    import asyncio

                    asyncio.run(tool.prepare_prompt(request))

                    # Verify test examples got 25% of 150k tokens (75% of 200k context)
                    mock_process.assert_called_once()
                    call_args = mock_process.call_args[0]
                    assert call_args[2] == 150000  # 75% of 200k context window
    @pytest.mark.asyncio
    async def test_continuation_support(self, tool, temp_files):
        """Test continuation ID support"""
        with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
            mock_prepare.return_value = "code content"

            request = TestGenRequest(
                files=[temp_files["code_file"]], prompt="Continue testing", continuation_id="test-thread-123"
            )

            await tool.prepare_prompt(request)

            # Verify continuation_id was passed to _prepare_file_content_for_prompt
            # The method should be called twice (once for code, once for test examples logic)
            assert mock_prepare.call_count >= 1

            # Check that continuation_id was passed in at least one call
            calls = mock_prepare.call_args_list
            continuation_passed = any(
                call[0][1] == "test-thread-123" for call in calls  # continuation_id is second argument
            )
            assert continuation_passed, f"continuation_id not passed. Calls: {calls}"

    def test_no_websearch_in_prompt(self, tool, temp_files):
        """Test that web search instructions are not included"""
        request = TestGenRequest(files=[temp_files["code_file"]], prompt="Generate tests")

        with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
            mock_prepare.return_value = "code content"

            import asyncio

            prompt = asyncio.run(tool.prepare_prompt(request))

            # Should not contain web search instructions
            assert "WEB SEARCH CAPABILITY" not in prompt
            assert "web search" not in prompt.lower()


@@ -284,6 +284,22 @@ class TestAbsolutePathValidation:
         assert "must be absolute" in response["content"]
         assert "code.py" in response["content"]

+    @pytest.mark.asyncio
+    async def test_testgen_tool_relative_path_rejected(self):
+        """Test that testgen tool rejects relative paths"""
+        from tools import TestGenTool
+
+        tool = TestGenTool()
+        result = await tool.execute(
+            {"files": ["src/main.py"], "prompt": "Generate tests for the functions"}  # relative path
+        )
+
+        assert len(result) == 1
+        response = json.loads(result[0].text)
+        assert response["status"] == "error"
+        assert "must be absolute" in response["content"]
+        assert "src/main.py" in response["content"]
+
     @pytest.mark.asyncio
     @patch("tools.AnalyzeTool.get_model_provider")
     async def test_analyze_tool_accepts_absolute_paths(self, mock_get_provider):