Lots of tests with live simulation to validate conversation continuation / preservation work across requests
This commit is contained in:
255
simulator_tests/base_test.py
Normal file
255
simulator_tests/base_test.py
Normal file
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Base Test Class for Communication Simulator Tests
|
||||
|
||||
Provides common functionality and utilities for all simulator tests.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class BaseSimulatorTest:
    """Base class for all communication simulator tests.

    Provides the shared plumbing used by concrete simulator tests:

    * creation and removal of a small set of fixture files,
    * a helper that drives the MCP server through ``docker exec`` using a
      newline-delimited JSON-RPC exchange (initialize, initialized
      notification, tools/call),
    * parsing helpers for the tool response and its continuation id.

    Subclasses must implement :meth:`run_test` and the ``test_name`` /
    ``test_description`` properties.
    """

    def __init__(self, verbose: bool = False):
        """Initialise shared state and logging.

        Args:
            verbose: When true, logging is configured at DEBUG level and
                :meth:`run_command` logs each command before running it.
        """
        self.verbose = verbose
        # Maps a short label ("python", "config") to the fixture file path.
        self.test_files = {}
        # Directory holding the fixture files; set by setup_test_files().
        self.test_dir = None
        self.container_name = "gemini-mcp-server"
        # Not used by this base class itself; kept for subclasses.
        self.redis_container = "gemini-mcp-redis"

        # Configure logging. basicConfig() does nothing if the root logger
        # already has handlers, so repeated instantiation is harmless.
        log_level = logging.DEBUG if verbose else logging.INFO
        logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s")
        self.logger = logging.getLogger(self.__class__.__name__)

    def setup_test_files(self):
        """Create the fixture files used by the simulation.

        Writes ``test_module.py`` and ``config.json`` into a
        ``test_simulation_files`` directory under the current working
        directory, and records their paths in ``self.test_files`` under the
        keys ``"python"`` and ``"config"``.
        """
        # Test Python file
        python_content = '''"""
Sample Python module for testing MCP conversation continuity
"""

def fibonacci(n):
    """Calculate fibonacci number recursively"""
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

def factorial(n):
    """Calculate factorial iteratively"""
    result = 1
    for i in range(1, n + 1):
        result *= i
    return result

class Calculator:
    """Simple calculator class"""

    def __init__(self):
        self.history = []

    def add(self, a, b):
        result = a + b
        self.history.append(f"{a} + {b} = {result}")
        return result

    def multiply(self, a, b):
        result = a * b
        self.history.append(f"{a} * {b} = {result}")
        return result
'''

        # Test configuration file
        config_content = """{
  "database": {
    "host": "localhost",
    "port": 5432,
    "name": "testdb",
    "ssl": true
  },
  "cache": {
    "redis_url": "redis://localhost:6379",
    "ttl": 3600
  },
  "logging": {
    "level": "INFO",
    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  }
}"""

        # Create files in the current project directory
        self.test_dir = os.path.join(os.getcwd(), "test_simulation_files")
        os.makedirs(self.test_dir, exist_ok=True)

        test_py = os.path.join(self.test_dir, "test_module.py")
        test_config = os.path.join(self.test_dir, "config.json")

        # Explicit encoding so the fixtures are identical on every platform.
        with open(test_py, "w", encoding="utf-8") as f:
            f.write(python_content)
        with open(test_config, "w", encoding="utf-8") as f:
            f.write(config_content)

        self.test_files = {"python": test_py, "config": test_config}
        self.logger.debug(f"Created test files: {list(self.test_files.values())}")

    @staticmethod
    def _build_mcp_input(tool_name: str, params: dict) -> str:
        """Return the newline-delimited JSON-RPC payload for one tool call."""
        init_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "clientInfo": {"name": "communication-simulator", "version": "1.0.0"},
            },
        }
        initialized_notification = {"jsonrpc": "2.0", "method": "notifications/initialized"}
        tool_request = {
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {"name": tool_name, "arguments": params},
        }
        messages = (init_request, initialized_notification, tool_request)
        # One JSON document per line, with a trailing newline after the last.
        return "\n".join(json.dumps(message) for message in messages) + "\n"

    def call_mcp_tool(self, tool_name: str, params: dict) -> Tuple[Optional[str], Optional[str]]:
        """Call an MCP tool via Claude CLI (docker exec).

        Sends an initialize request, the initialized notification and a
        ``tools/call`` request (id 2) to ``python server.py`` inside
        ``self.container_name`` and parses the reply to the tool call.

        Args:
            tool_name: Name of the MCP tool to invoke.
            params: Arguments passed to the tool.

        Returns:
            ``(response_text, continuation_id)``. ``continuation_id`` is
            ``None`` when the response carries none. ``(None, None)`` is
            returned when the command fails, times out, raises, or yields no
            non-empty tool response.
        """
        try:
            input_data = self._build_mcp_input(tool_name, params)

            # Simulate Claude CLI calling the MCP server via docker exec
            docker_cmd = ["docker", "exec", "-i", self.container_name, "python", "server.py"]

            self.logger.debug(f"Calling MCP tool {tool_name} with proper initialization")

            result = subprocess.run(
                docker_cmd,
                input=input_data,
                text=True,
                capture_output=True,
                timeout=300,  # 5 minute timeout
            )

            if result.returncode != 0:
                self.logger.error(f"Docker exec failed: {result.stderr}")
                return None, None

            # The tool call was sent with id 2, so that is the reply we want.
            response_data = self._parse_mcp_response(result.stdout, expected_id=2)
            if not response_data:
                return None, None

            continuation_id = self._extract_continuation_id(response_data)
            return response_data, continuation_id

        except subprocess.TimeoutExpired:
            self.logger.error(f"MCP tool call timed out: {tool_name}")
            return None, None
        except Exception as e:
            # Deliberately broad: a simulator test should report failure
            # (e.g. docker missing) rather than crash the whole run.
            self.logger.error(f"MCP tool call failed: {e}")
            return None, None

    @staticmethod
    def _first_text(result) -> Optional[str]:
        """Return the ``text`` of the first content item in a tool result, if any."""
        if isinstance(result, dict):
            # Current format: {"content": [{"text": ...}, ...]}
            items = result.get("content")
        elif isinstance(result, list):
            # Legacy format: the result is the content list itself.
            items = result
        else:
            return None

        if not isinstance(items, list) or not items:
            return None
        first = items[0]
        if not isinstance(first, dict):
            return None
        return first.get("text", "")

    def _parse_mcp_response(self, stdout: str, expected_id: int = 2) -> Optional[str]:
        """Parse MCP JSON-RPC response from stdout.

        Scans ``stdout`` line by line for a JSON object whose ``id`` equals
        ``expected_id``. Lines that are not JSON objects are skipped, so
        stray output cannot hide a valid response.

        Args:
            stdout: Raw standard output captured from the server process.
            expected_id: JSON-RPC id of the request whose reply is wanted.

        Returns:
            The text of the first content item of the matching result, or
            ``None`` if the server returned an error for that id or no
            usable response was found.
        """
        # Split on "\n" only: str.splitlines() would also split on separators
        # such as U+2028, which can legally appear inside a JSON string.
        for raw_line in stdout.split("\n"):
            line = raw_line.strip()
            if not line.startswith("{"):
                continue

            try:
                response = json.loads(line)
            except json.JSONDecodeError as e:
                # A brace-prefixed log line is not fatal; keep looking.
                self.logger.debug(f"Skipping non-JSON line in MCP output: {e}")
                continue

            if not isinstance(response, dict) or response.get("id") != expected_id:
                continue

            if "result" in response:
                text = self._first_text(response["result"])
                if text is not None:
                    return text
            elif "error" in response:
                self.logger.error(f"MCP error: {response['error']}")
                return None

        # If we get here, log all responses for debugging
        self.logger.warning(f"No valid tool call response found for ID {expected_id}")
        self.logger.debug(f"Full stdout: {stdout}")
        return None

    def _extract_continuation_id(self, response_text: str) -> Optional[str]:
        """Extract continuation_id from response metadata.

        The response text is parsed as JSON and searched, in order, for
        ``metadata.thread_id``, ``follow_up_request.continuation_id`` and
        ``continuation_offer.continuation_id``.

        Args:
            response_text: Tool response text, expected to be a JSON document.

        Returns:
            The first id found, or ``None`` if the text is not JSON, is not
            a JSON object, or contains none of the known locations.
        """
        try:
            response_data = json.loads(response_text)
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers non-string input such as None.
            self.logger.debug(f"Failed to parse response for continuation_id: {e}")
            return None

        if isinstance(response_data, dict):
            locations = (
                ("metadata", "thread_id"),
                ("follow_up_request", "continuation_id"),
                ("continuation_offer", "continuation_id"),
            )
            for section, key in locations:
                container = response_data.get(section)
                # A section may be null or a non-object; only dicts can hold an id.
                if isinstance(container, dict) and key in container:
                    return container[key]

        self.logger.debug(f"No continuation_id found in response: {response_data}")
        return None

    def run_command(self, cmd: list[str], check: bool = True, capture_output: bool = False, **kwargs):
        """Run a shell command with logging.

        Args:
            cmd: Command and arguments, passed to ``subprocess.run`` as a list.
            check: Forwarded to ``subprocess.run``.
            capture_output: Forwarded to ``subprocess.run``.
            **kwargs: Any further ``subprocess.run`` keyword arguments.

        Returns:
            The ``subprocess.CompletedProcess`` produced by ``subprocess.run``.
        """
        if self.verbose:
            self.logger.debug(f"Running: {' '.join(cmd)}")

        return subprocess.run(cmd, check=check, capture_output=capture_output, **kwargs)

    def cleanup_test_files(self):
        """Clean up test files.

        Removes the fixture directory recorded in ``self.test_dir`` if it
        exists; does nothing otherwise.
        """
        # getattr() keeps this safe on instances whose __init__ never ran.
        test_dir = getattr(self, "test_dir", None)
        if test_dir and os.path.exists(test_dir):
            import shutil

            shutil.rmtree(test_dir)
            self.logger.debug(f"Removed test files directory: {test_dir}")

    def run_test(self) -> bool:
        """Run the test - to be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement run_test()")

    @property
    def test_name(self) -> str:
        """Get the test name - to be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement test_name property")

    @property
    def test_description(self) -> str:
        """Get the test description - to be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement test_description property")
|
||||
Reference in New Issue
Block a user