Files
my-pal-mcp-server/tests/test_cross_tool_continuation.py
2025-06-19 05:37:40 +04:00

373 lines
15 KiB
Python

"""
Test suite for cross-tool continuation functionality.

Verifies that continuation IDs work properly across different tools,
allowing a multi-turn conversation to span multiple tool types.
"""
import json
import os
from unittest.mock import Mock, patch
import pytest
from pydantic import Field
from tests.mock_helpers import create_mock_provider
from tools.base import BaseTool, ToolRequest
from utils.conversation_memory import ConversationTurn, ThreadContext
class AnalysisRequest(ToolRequest):
    """Request model handed to the mock analysis tool in these tests."""

    # Mandatory: the source text the analysis tool is asked to look at.
    code: str = Field(
        ...,
        description="Code to analyze",
    )
class ReviewRequest(ToolRequest):
    """Request model handed to the mock review tool in these tests."""

    # Mandatory: the findings text the review tool is asked to assess.
    findings: str = Field(
        ...,
        description="Analysis findings to review",
    )
    # Optional: file paths accompanying the review; a fresh empty list by default.
    files: list[str] = Field(
        default_factory=list,
        description="Optional files to review",
    )
class MockAnalysisTool(BaseTool):
    """Mock analysis tool for cross-tool testing.

    Supplies the name, description, input schema, system prompt, request
    model and prompt builder for a fake "analysis" tool. The tests replace
    ``get_model_provider`` with a mock, so no real model is contacted.
    """

    def get_name(self) -> str:
        """Return the tool name, ``"test_analysis"``."""
        return "test_analysis"

    def get_description(self) -> str:
        """Return a short human-readable description of the tool."""
        return "Test analysis tool"

    def get_input_schema(self) -> dict:
        """Return the JSON Schema describing the arguments this tool accepts.

        ``code`` is mandatory (``AnalysisRequest`` declares it without a
        default); ``continuation_id`` is optional and is therefore simply
        absent from the object-level ``required`` list.
        """
        return {
            "type": "object",
            "properties": {
                "code": {"type": "string"},
                "continuation_id": {"type": "string"},
            },
            # JSON Schema expresses requiredness as an array of property names
            # on the object; a boolean ``required`` on a property is not valid.
            "required": ["code"],
        }

    def get_system_prompt(self) -> str:
        """Return the fixed system prompt embedded in every prepared prompt."""
        return "Analyze the provided code"

    def get_request_model(self):
        """Return the request class used for this tool's arguments."""
        return AnalysisRequest

    async def prepare_prompt(self, request) -> str:
        """Build the prompt text from the system prompt and ``request.code``."""
        return f"System: {self.get_system_prompt()}\nCode: {request.code}"
class MockReviewTool(BaseTool):
    """Mock review tool for cross-tool testing.

    Supplies the name, description, input schema, system prompt, request
    model and prompt builder for a fake "review" tool. The tests replace
    ``get_model_provider`` with a mock, so no real model is contacted.
    """

    def get_name(self) -> str:
        """Return the tool name, ``"test_review"``."""
        return "test_review"

    def get_description(self) -> str:
        """Return a short human-readable description of the tool."""
        return "Test review tool"

    def get_input_schema(self) -> dict:
        """Return the JSON Schema describing the arguments this tool accepts.

        ``findings`` is mandatory (``ReviewRequest`` declares it without a
        default). ``files`` and ``continuation_id`` are optional and are
        therefore absent from the object-level ``required`` list.
        """
        return {
            "type": "object",
            "properties": {
                "findings": {"type": "string"},
                # Declared on ReviewRequest as list[str]; listed here so the
                # schema matches the arguments the request model accepts.
                "files": {"type": "array", "items": {"type": "string"}},
                "continuation_id": {"type": "string"},
            },
            # JSON Schema expresses requiredness as an array of property names
            # on the object; a boolean ``required`` on a property is not valid.
            "required": ["findings"],
        }

    def get_system_prompt(self) -> str:
        """Return the fixed system prompt embedded in every prepared prompt."""
        return "Review the analysis findings"

    def get_request_model(self):
        """Return the request class used for this tool's arguments."""
        return ReviewRequest

    async def prepare_prompt(self, request) -> str:
        """Build the prompt text from the system prompt and ``request.findings``."""
        return f"System: {self.get_system_prompt()}\nFindings: {request.findings}"
class TestCrossToolContinuation:
    """Checks that a thread started by one tool can be continued by another."""

    def setup_method(self):
        """Create fresh instances of both mock tools before each test."""
        self.analysis_tool = MockAnalysisTool()
        self.review_tool = MockReviewTool()

    @staticmethod
    def _stub_provider(reply_text):
        """Return a mock provider whose ``generate_content`` yields ``reply_text``."""
        provider = create_mock_provider()
        provider.get_provider_type.return_value = Mock(value="google")
        provider.supports_thinking_mode.return_value = False
        provider.generate_content.return_value = Mock(
            content=reply_text,
            usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
            model_name="gemini-2.5-flash",
            metadata={"finish_reason": "STOP"},
        )
        return provider

    @staticmethod
    def _last_saved_thread(storage_client):
        """Decode the thread JSON passed to the most recent ``setex`` call."""
        # The serialized thread is read from the third positional argument.
        newest_call = storage_client.setex.call_args_list[-1]
        return json.loads(newest_call[0][2])

    # NOTE(review): PYTEST_CURRENT_TEST is blanked for the async tests below,
    # presumably to steer the code under test away from a pytest-only path; confirm.
    @patch("utils.conversation_memory.get_storage")
    @patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
    async def test_continuation_id_works_across_different_tools(self, mock_storage):
        """A continuation_id issued by the analysis tool is accepted by the review tool."""
        storage_client = Mock()
        mock_storage.return_value = storage_client

        # Step 1: the analysis tool replies in plain text (no JSON follow-up)
        # and is expected to offer a continuation.
        analysis_reply = """Found potential security issues in authentication logic.
I'd be happy to review these security findings in detail if that would be helpful."""
        with patch.object(self.analysis_tool, "get_model_provider") as provider_factory:
            provider_factory.return_value = self._stub_provider(analysis_reply)

            analysis_result = await self.analysis_tool.execute(
                {"code": "function authenticate(user) { return true; }"}
            )
            analysis_payload = json.loads(analysis_result[0].text)

            assert analysis_payload["status"] == "continuation_available"
            continuation_id = analysis_payload["continuation_offer"]["continuation_id"]

        # Step 2: stage the thread as it would look after the analysis tool ran.
        # It belongs to the analysis tool but is about to be continued by review.
        stored_thread = ThreadContext(
            thread_id=continuation_id,
            created_at="2023-01-01T00:00:00Z",
            last_updated_at="2023-01-01T00:01:00Z",
            tool_name="test_analysis",  # tool that opened the thread
            turns=[
                ConversationTurn(
                    role="assistant",
                    content="Found potential security issues in authentication logic.\n\nI'd be happy to review these security findings in detail if that would be helpful.",
                    timestamp="2023-01-01T00:00:30Z",
                    tool_name="test_analysis",  # tool that produced this turn
                )
            ],
            initial_context={"code": "function authenticate(user) { return true; }"},
        )

        # Any "thread:*" lookup returns the staged thread so add_turn can load it;
        # every other key is reported as missing.
        storage_client.get.side_effect = lambda key: (
            stored_thread.model_dump_json() if key.startswith("thread:") else None
        )

        # Step 3: the review tool is invoked with the analysis tool's continuation_id.
        with patch.object(self.review_tool, "get_model_provider") as provider_factory:
            provider_factory.return_value = self._stub_provider(
                "Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks."
            )

            review_result = await self.review_tool.execute(
                {
                    "findings": "Authentication bypass vulnerability detected",
                    "continuation_id": continuation_id,
                }
            )
            review_payload = json.loads(review_result[0].text)

            # Turns remain available, so another continuation should be offered.
            assert review_payload["status"] == "continuation_available"
            assert "Critical security vulnerability confirmed" in review_payload["content"]

        # Step 4: inspect what was persisted. At least two writes are expected:
        # one from the analysis tool's follow-up and one from the review tool's add_turn.
        persisted_writes = storage_client.setex.call_args_list
        assert len(persisted_writes) >= 2

        final_thread = self._last_saved_thread(storage_client)
        assert final_thread["thread_id"] == continuation_id
        assert final_thread["tool_name"] == "test_analysis"  # opener is preserved
        assert len(final_thread["turns"]) == 2  # staged turn + the new one

        # The appended turn must be attributed to the review tool.
        appended_turn = final_thread["turns"][1]
        assert appended_turn["role"] == "assistant"
        assert appended_turn["tool_name"] == "test_review"
        assert "Critical security vulnerability confirmed" in appended_turn["content"]

    @patch("utils.conversation_memory.get_storage")
    def test_cross_tool_conversation_history_includes_tool_names(self, mock_storage):
        """The rendered history labels every turn with the tool that produced it."""
        from providers.registry import ModelProviderRegistry
        from utils.conversation_memory import build_conversation_history

        storage_client = Mock()
        mock_storage.return_value = storage_client

        # (content, timestamp, tool) for three turns produced by three different tools.
        recorded_turns = [
            ("Analysis complete: Found 3 issues", "2023-01-01T00:01:00Z", "test_analysis"),
            ("Review complete: 2 critical, 1 minor issue", "2023-01-01T00:02:00Z", "test_review"),
            ("Deep analysis: Root cause identified", "2023-01-01T00:03:00Z", "test_thinkdeep"),
        ]
        thread = ThreadContext(
            thread_id="12345678-1234-1234-1234-123456789012",
            created_at="2023-01-01T00:00:00Z",
            last_updated_at="2023-01-01T00:03:00Z",
            tool_name="test_analysis",  # tool that opened the thread
            turns=[
                ConversationTurn(
                    role="assistant",
                    content=text,
                    timestamp=stamp,
                    tool_name=tool,
                )
                for text, stamp, tool in recorded_turns
            ],
            initial_context={"code": "test code"},
        )

        # Run with only a Gemini key configured and a cleared provider cache.
        env_overrides = {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}
        with patch.dict(os.environ, env_overrides, clear=False):
            ModelProviderRegistry.clear_cache()
            history, _token_count = build_conversation_history(thread, model_context=None)

            # Both the per-turn tool labels and the turn contents must be rendered.
            expected_fragments = (
                "Turn 1 (Gemini using test_analysis)",
                "Turn 2 (Gemini using test_review)",
                "Turn 3 (Gemini using test_thinkdeep)",
                "Analysis complete: Found 3 issues",
                "Review complete: 2 critical, 1 minor issue",
                "Deep analysis: Root cause identified",
            )
            for fragment in expected_fragments:
                assert fragment in history

    @patch("utils.conversation_memory.get_storage")
    @patch("utils.conversation_memory.get_thread")
    @patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
    async def test_cross_tool_conversation_with_files_context(self, mock_get_thread, mock_storage):
        """Per-turn file lists survive when a second tool continues the thread."""
        storage_client = Mock()
        mock_storage.return_value = storage_client

        # The thread already holds one analysis turn that carries two files.
        mock_get_thread.return_value = ThreadContext(
            thread_id="test-thread-id",
            created_at="2023-01-01T00:00:00Z",
            last_updated_at="2023-01-01T00:01:00Z",
            tool_name="test_analysis",
            turns=[
                ConversationTurn(
                    role="assistant",
                    content="Analysis of auth.py complete",
                    timestamp="2023-01-01T00:01:00Z",
                    tool_name="test_analysis",
                    files=["/src/auth.py", "/src/utils.py"],
                )
            ],
            initial_context={"code": "authentication code", "files": ["/src/auth.py"]},
        )

        with patch.object(self.review_tool, "get_model_provider") as provider_factory:
            provider_factory.return_value = self._stub_provider(
                "Security review of auth.py shows vulnerabilities"
            )

            # The review call brings one additional file of its own.
            review_args = {
                "findings": "Auth vulnerabilities found",
                "continuation_id": "test-thread-id",
                "files": ["/src/security.py"],
            }
            review_result = await self.review_tool.execute(review_args)
            review_payload = json.loads(review_result[0].text)
            assert review_payload["status"] == "continuation_available"

        # At least the review tool's add_turn must have written the thread back.
        persisted_writes = storage_client.setex.call_args_list
        assert len(persisted_writes) >= 1

        final_thread = self._last_saved_thread(storage_client)

        # Second turn: produced by the review tool, carrying only its own file.
        review_turn = final_thread["turns"][1]
        assert review_turn["tool_name"] == "test_review"
        assert review_turn["files"] == ["/src/security.py"]

        # First turn: the analysis tool's files are left untouched.
        analysis_turn = final_thread["turns"][0]
        assert analysis_turn["files"] == ["/src/auth.py", "/src/utils.py"]

    @patch("utils.conversation_memory.get_storage")
    @patch("utils.conversation_memory.get_thread")
    def test_thread_preserves_original_tool_name(self, mock_get_thread, mock_storage):
        """The thread keeps its opening tool's name after another tool adds a turn."""
        from utils.conversation_memory import add_turn

        storage_client = Mock()
        mock_storage.return_value = storage_client

        # A thread opened by the analysis tool with a single turn.
        mock_get_thread.return_value = ThreadContext(
            thread_id="test-thread-id",
            created_at="2023-01-01T00:00:00Z",
            last_updated_at="2023-01-01T00:01:00Z",
            tool_name="test_analysis",  # tool that opened the thread
            turns=[
                ConversationTurn(
                    role="assistant",
                    content="Initial analysis",
                    timestamp="2023-01-01T00:01:00Z",
                    tool_name="test_analysis",
                )
            ],
            initial_context={"code": "test"},
        )

        # Append a turn on behalf of a different tool.
        turn_added = add_turn(
            "test-thread-id",
            "assistant",
            "Review completed",
            tool_name="test_review",
        )
        assert turn_added

        # The thread-level tool name is unchanged; each turn keeps its own tool.
        saved_thread = self._last_saved_thread(storage_client)
        assert saved_thread["tool_name"] == "test_analysis"
        turn_tools = [turn["tool_name"] for turn in saved_thread["turns"]]
        assert turn_tools == ["test_analysis", "test_review"]
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    this_module = __file__
    pytest.main([this_module])