Generates unit tests and encourages the model to auto-detect the test framework and testing style from existing samples (if available).
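A minimal usage sketch of the call pattern exercised by the tests below (the file paths are placeholders, and a configured model provider is assumed to be available at runtime):

import asyncio

from tools.testgen import TestGenTool

tool = TestGenTool()

# execute() expects absolute file paths and returns a list of responses;
# the first element carries a JSON payload with "status" and "content".
result = asyncio.run(
    tool.execute(
        {
            "files": ["/abs/path/calculator.py"],  # code to generate tests for (placeholder path)
            "prompt": "Generate tests for the calculator functions",
            "test_examples": ["/abs/path/test_existing.py"],  # optional style reference (placeholder path)
        }
    )
)
print(result[0].text)

When test_examples are supplied, the tests below verify that the smallest examples are preferred and that roughly 25% of the available token budget is reserved for them.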
382 lines · 15 KiB · Python
"""
|
|
Tests for TestGen tool implementation
|
|
"""
|
|
|
|
import json
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
|
|
from tests.mock_helpers import create_mock_provider
|
|
from tools.testgen import TestGenRequest, TestGenTool
|
|
|
|
|
|
class TestTestGenTool:
|
|
"""Test the TestGen tool"""
|
|
|
|
@pytest.fixture
|
|
def tool(self):
|
|
return TestGenTool()
|
|
|
|
@pytest.fixture
|
|
def temp_files(self):
|
|
"""Create temporary test files"""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
temp_path = Path(temp_dir)
|
|
|
|
# Create sample code files
|
|
code_file = temp_path / "calculator.py"
|
|
code_file.write_text(
|
|
"""
|
|
def add(a, b):
|
|
'''Add two numbers'''
|
|
return a + b
|
|
|
|
def divide(a, b):
|
|
'''Divide two numbers'''
|
|
if b == 0:
|
|
raise ValueError("Cannot divide by zero")
|
|
return a / b
|
|
"""
|
|
)
|
|
|
|
# Create sample test files (different sizes)
|
|
small_test = temp_path / "test_small.py"
|
|
small_test.write_text(
|
|
"""
|
|
import unittest
|
|
|
|
class TestBasic(unittest.TestCase):
|
|
def test_simple(self):
|
|
self.assertEqual(1 + 1, 2)
|
|
"""
|
|
)
|
|
|
|
large_test = temp_path / "test_large.py"
|
|
large_test.write_text(
|
|
"""
|
|
import unittest
|
|
from unittest.mock import Mock, patch
|
|
|
|
class TestComprehensive(unittest.TestCase):
|
|
def setUp(self):
|
|
self.mock_data = Mock()
|
|
|
|
def test_feature_one(self):
|
|
# Comprehensive test with lots of setup
|
|
result = self.process_data()
|
|
self.assertIsNotNone(result)
|
|
|
|
def test_feature_two(self):
|
|
# Another comprehensive test
|
|
with patch('some.module') as mock_module:
|
|
mock_module.return_value = 'test'
|
|
result = self.process_data()
|
|
self.assertEqual(result, 'expected')
|
|
|
|
def process_data(self):
|
|
return "test_result"
|
|
"""
|
|
)
|
|
|
|
yield {
|
|
"temp_dir": temp_dir,
|
|
"code_file": str(code_file),
|
|
"small_test": str(small_test),
|
|
"large_test": str(large_test),
|
|
}
|
|
|
|
def test_tool_metadata(self, tool):
|
|
"""Test tool metadata"""
|
|
assert tool.get_name() == "testgen"
|
|
assert "COMPREHENSIVE TEST GENERATION" in tool.get_description()
|
|
assert "BE SPECIFIC about scope" in tool.get_description()
|
|
assert tool.get_default_temperature() == 0.2 # Analytical temperature
|
|
|
|
# Check model category
|
|
from tools.models import ToolModelCategory
|
|
|
|
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
|
|
|
|
def test_input_schema_structure(self, tool):
|
|
"""Test input schema structure"""
|
|
schema = tool.get_input_schema()
|
|
|
|
# Required fields
|
|
assert "files" in schema["properties"]
|
|
assert "prompt" in schema["properties"]
|
|
assert "files" in schema["required"]
|
|
assert "prompt" in schema["required"]
|
|
|
|
# Optional fields
|
|
assert "test_examples" in schema["properties"]
|
|
assert "thinking_mode" in schema["properties"]
|
|
assert "continuation_id" in schema["properties"]
|
|
|
|
# Should not have temperature or use_websearch
|
|
assert "temperature" not in schema["properties"]
|
|
assert "use_websearch" not in schema["properties"]
|
|
|
|
# Check test_examples description
|
|
test_examples_desc = schema["properties"]["test_examples"]["description"]
|
|
assert "absolute paths" in test_examples_desc
|
|
assert "smallest representative tests" in test_examples_desc
|
|
|
|
def test_request_model_validation(self):
|
|
"""Test request model validation"""
|
|
# Valid request
|
|
valid_request = TestGenRequest(files=["/tmp/test.py"], prompt="Generate tests for calculator functions")
|
|
assert valid_request.files == ["/tmp/test.py"]
|
|
assert valid_request.prompt == "Generate tests for calculator functions"
|
|
assert valid_request.test_examples is None
|
|
|
|
# With test examples
|
|
request_with_examples = TestGenRequest(
|
|
files=["/tmp/test.py"], prompt="Generate tests", test_examples=["/tmp/test_example.py"]
|
|
)
|
|
assert request_with_examples.test_examples == ["/tmp/test_example.py"]
|
|
|
|
# Invalid request (missing required fields)
|
|
with pytest.raises(ValueError):
|
|
TestGenRequest(files=["/tmp/test.py"]) # Missing prompt
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("tools.base.BaseTool.get_model_provider")
|
|
async def test_execute_success(self, mock_get_provider, tool, temp_files):
|
|
"""Test successful execution"""
|
|
# Mock provider
|
|
mock_provider = create_mock_provider()
|
|
mock_provider.get_provider_type.return_value = Mock(value="google")
|
|
mock_provider.generate_content.return_value = Mock(
|
|
content="Generated comprehensive test suite with edge cases",
|
|
usage={"input_tokens": 100, "output_tokens": 200},
|
|
model_name="gemini-2.5-flash-preview-05-20",
|
|
metadata={"finish_reason": "STOP"},
|
|
)
|
|
mock_get_provider.return_value = mock_provider
|
|
|
|
result = await tool.execute(
|
|
{"files": [temp_files["code_file"]], "prompt": "Generate comprehensive tests for the calculator functions"}
|
|
)
|
|
|
|
# Verify result structure
|
|
assert len(result) == 1
|
|
response_data = json.loads(result[0].text)
|
|
assert response_data["status"] == "success"
|
|
assert "Generated comprehensive test suite" in response_data["content"]
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("tools.base.BaseTool.get_model_provider")
|
|
async def test_execute_with_test_examples(self, mock_get_provider, tool, temp_files):
|
|
"""Test execution with test examples"""
|
|
mock_provider = create_mock_provider()
|
|
mock_provider.generate_content.return_value = Mock(
|
|
content="Generated tests following the provided examples",
|
|
usage={"input_tokens": 150, "output_tokens": 250},
|
|
model_name="gemini-2.5-flash-preview-05-20",
|
|
metadata={"finish_reason": "STOP"},
|
|
)
|
|
mock_get_provider.return_value = mock_provider
|
|
|
|
result = await tool.execute(
|
|
{
|
|
"files": [temp_files["code_file"]],
|
|
"prompt": "Generate tests following existing patterns",
|
|
"test_examples": [temp_files["small_test"]],
|
|
}
|
|
)
|
|
|
|
# Verify result
|
|
assert len(result) == 1
|
|
response_data = json.loads(result[0].text)
|
|
assert response_data["status"] == "success"
|
|
|
|
def test_process_test_examples_empty(self, tool):
|
|
"""Test processing empty test examples"""
|
|
content, note = tool._process_test_examples([], None)
|
|
assert content == ""
|
|
assert note == ""
|
|
|
|
def test_process_test_examples_budget_allocation(self, tool, temp_files):
|
|
"""Test token budget allocation for test examples"""
|
|
with patch.object(tool, "filter_new_files") as mock_filter:
|
|
mock_filter.return_value = [temp_files["small_test"], temp_files["large_test"]]
|
|
|
|
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
|
|
mock_prepare.return_value = "Mocked test content"
|
|
|
|
# Test with available tokens
|
|
content, note = tool._process_test_examples(
|
|
[temp_files["small_test"], temp_files["large_test"]], None, available_tokens=100000
|
|
)
|
|
|
|
# Should allocate 25% of 100k = 25k tokens for test examples
|
|
mock_prepare.assert_called_once()
|
|
call_args = mock_prepare.call_args
|
|
assert call_args[1]["max_tokens"] == 25000 # 25% of 100k
|
|
|
|
def test_process_test_examples_size_sorting(self, tool, temp_files):
|
|
"""Test that test examples are sorted by size (smallest first)"""
|
|
with patch.object(tool, "filter_new_files") as mock_filter:
|
|
# Return files in random order
|
|
mock_filter.return_value = [temp_files["large_test"], temp_files["small_test"]]
|
|
|
|
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
|
|
mock_prepare.return_value = "test content"
|
|
|
|
tool._process_test_examples(
|
|
[temp_files["large_test"], temp_files["small_test"]], None, available_tokens=50000
|
|
)
|
|
|
|
# Check that files were passed in size order (smallest first)
|
|
call_args = mock_prepare.call_args[0]
|
|
files_passed = call_args[0]
|
|
|
|
# Verify smallest file comes first
|
|
assert files_passed[0] == temp_files["small_test"]
|
|
assert files_passed[1] == temp_files["large_test"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_prepare_prompt_structure(self, tool, temp_files):
|
|
"""Test prompt preparation structure"""
|
|
request = TestGenRequest(files=[temp_files["code_file"]], prompt="Test the calculator functions")
|
|
|
|
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
|
|
mock_prepare.return_value = "mocked file content"
|
|
|
|
prompt = await tool.prepare_prompt(request)
|
|
|
|
# Check prompt structure
|
|
assert "=== USER CONTEXT ===" in prompt
|
|
assert "Test the calculator functions" in prompt
|
|
assert "=== CODE TO TEST ===" in prompt
|
|
assert "mocked file content" in prompt
|
|
assert tool.get_system_prompt() in prompt
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_prepare_prompt_with_examples(self, tool, temp_files):
|
|
"""Test prompt preparation with test examples"""
|
|
request = TestGenRequest(
|
|
files=[temp_files["code_file"]], prompt="Generate tests", test_examples=[temp_files["small_test"]]
|
|
)
|
|
|
|
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
|
|
mock_prepare.return_value = "mocked content"
|
|
|
|
with patch.object(tool, "_process_test_examples") as mock_process:
|
|
mock_process.return_value = ("test examples content", "Note: examples included")
|
|
|
|
prompt = await tool.prepare_prompt(request)
|
|
|
|
# Check test examples section
|
|
assert "=== TEST EXAMPLES FOR STYLE REFERENCE ===" in prompt
|
|
assert "test examples content" in prompt
|
|
assert "Note: examples included" in prompt
|
|
|
|
def test_format_response(self, tool):
|
|
"""Test response formatting"""
|
|
request = TestGenRequest(files=["/tmp/test.py"], prompt="Generate tests")
|
|
|
|
raw_response = "Generated test cases with edge cases"
|
|
formatted = tool.format_response(raw_response, request)
|
|
|
|
# Check formatting includes next steps
|
|
assert raw_response in formatted
|
|
assert "**Next Steps:**" in formatted
|
|
assert "Review Generated Tests" in formatted
|
|
assert "Setup Test Environment" in formatted
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_error_handling_invalid_files(self, tool):
|
|
"""Test error handling for invalid file paths"""
|
|
result = await tool.execute(
|
|
{"files": ["relative/path.py"], "prompt": "Generate tests"} # Invalid: not absolute
|
|
)
|
|
|
|
# Should return error for relative path
|
|
response_data = json.loads(result[0].text)
|
|
assert response_data["status"] == "error"
|
|
assert "absolute" in response_data["content"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_large_prompt_handling(self, tool):
|
|
"""Test handling of large prompts"""
|
|
large_prompt = "x" * 60000 # Exceeds MCP_PROMPT_SIZE_LIMIT
|
|
|
|
result = await tool.execute({"files": ["/tmp/test.py"], "prompt": large_prompt})
|
|
|
|
# Should return resend_prompt status
|
|
response_data = json.loads(result[0].text)
|
|
assert response_data["status"] == "resend_prompt"
|
|
assert "too large" in response_data["content"]
|
|
|
|
def test_token_budget_calculation(self, tool):
|
|
"""Test token budget calculation logic"""
|
|
# Mock model capabilities
|
|
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
|
mock_provider = create_mock_provider(context_window=200000)
|
|
mock_get_provider.return_value = mock_provider
|
|
|
|
# Simulate model name being set
|
|
tool._current_model_name = "test-model"
|
|
|
|
with patch.object(tool, "_process_test_examples") as mock_process:
|
|
mock_process.return_value = ("test content", "")
|
|
|
|
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
|
|
mock_prepare.return_value = "code content"
|
|
|
|
request = TestGenRequest(
|
|
files=["/tmp/test.py"], prompt="Test prompt", test_examples=["/tmp/example.py"]
|
|
)
|
|
|
|
# This should trigger token budget calculation
|
|
import asyncio
|
|
|
|
asyncio.run(tool.prepare_prompt(request))
|
|
|
|
# Verify test examples got 25% of 150k tokens (75% of 200k context)
|
|
mock_process.assert_called_once()
|
|
call_args = mock_process.call_args[0]
|
|
assert call_args[2] == 150000 # 75% of 200k context window
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_continuation_support(self, tool, temp_files):
|
|
"""Test continuation ID support"""
|
|
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
|
|
mock_prepare.return_value = "code content"
|
|
|
|
request = TestGenRequest(
|
|
files=[temp_files["code_file"]], prompt="Continue testing", continuation_id="test-thread-123"
|
|
)
|
|
|
|
await tool.prepare_prompt(request)
|
|
|
|
# Verify continuation_id was passed to _prepare_file_content_for_prompt
|
|
# The method should be called twice (once for code, once for test examples logic)
|
|
assert mock_prepare.call_count >= 1
|
|
|
|
# Check that continuation_id was passed in at least one call
|
|
calls = mock_prepare.call_args_list
|
|
continuation_passed = any(
|
|
call[0][1] == "test-thread-123" for call in calls # continuation_id is second argument
|
|
)
|
|
assert continuation_passed, f"continuation_id not passed. Calls: {calls}"
|
|
|
|
def test_no_websearch_in_prompt(self, tool, temp_files):
|
|
"""Test that web search instructions are not included"""
|
|
request = TestGenRequest(files=[temp_files["code_file"]], prompt="Generate tests")
|
|
|
|
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare:
|
|
mock_prepare.return_value = "code content"
|
|
|
|
import asyncio
|
|
|
|
prompt = asyncio.run(tool.prepare_prompt(request))
|
|
|
|
# Should not contain web search instructions
|
|
assert "WEB SEARCH CAPABILITY" not in prompt
|
|
assert "web search" not in prompt.lower()
|