Fixed tests
@@ -275,8 +275,6 @@ if __name__ == "__main__":
         step1_file_tokens = 0
         for log in file_embedding_logs_step1:
             # Look for pattern like "successfully embedded 1 files (146 tokens)"
-            import re
-
             match = re.search(r"\((\d+) tokens\)", log)
             if match:
                 step1_file_tokens = int(match.group(1))
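For reference, the token-count extraction above is a single regex over the simulator log line, with `re` now expected to be imported at module level instead of inside the loop. A minimal standalone sketch (the sample log string is illustrative):

import re

# Illustrative log line matching the pattern referenced in the comment above.
log = "successfully embedded 1 files (146 tokens)"

match = re.search(r"\((\d+) tokens\)", log)
tokens = int(match.group(1)) if match else 0
assert tokens == 146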
@@ -3,6 +3,7 @@ Pytest configuration for Zen MCP Server tests
 """
 
 import asyncio
+import importlib
 import os
 import sys
 import tempfile
@@ -26,9 +27,7 @@ if "OPENAI_API_KEY" not in os.environ:
     os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash"
 
 # Force reload of config module to pick up the env var
-import importlib
-
-import config
+import config  # noqa: E402
 
 importlib.reload(config)
 
@@ -43,10 +42,10 @@ if sys.platform == "win32":
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
 
 # Register providers for all tests
-from providers import ModelProviderRegistry
-from providers.base import ProviderType
-from providers.gemini import GeminiModelProvider
-from providers.openai import OpenAIModelProvider
+from providers import ModelProviderRegistry  # noqa: E402
+from providers.base import ProviderType  # noqa: E402
+from providers.gemini import GeminiModelProvider  # noqa: E402
+from providers.openai import OpenAIModelProvider  # noqa: E402
 
 # Register providers at test startup
 ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
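Taken together, the conftest hunks above hoist importlib into the top-level import block and mark the imports that must run after the environment is prepared with # noqa: E402 (flake8's "module-level import not at top of file"). A rough sketch of the resulting module-level ordering, assuming the zen-mcp-server package layout; the placeholder API key line is an assumption for illustration:

import importlib
import os

# The environment must be configured before config and the providers are imported,
# which is why those later imports are flagged E402 and silenced with noqa.
os.environ.setdefault("GEMINI_API_KEY", "dummy-key-for-tests")  # assumption: placeholder key
os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash"

import config  # noqa: E402

importlib.reload(config)  # re-read module-level settings now that the env vars are set

from providers import ModelProviderRegistry  # noqa: E402
from providers.base import ProviderType  # noqa: E402
from providers.gemini import GeminiModelProvider  # noqa: E402

# Register providers at test startup
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)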
@@ -2,6 +2,7 @@
 Test that conversation history is correctly mapped to tool-specific fields
 """
 
+import os
 from datetime import datetime
 from unittest.mock import MagicMock, patch
 
@@ -129,12 +130,17 @@ async def test_unknown_tool_defaults_to_prompt():
     with patch("utils.conversation_memory.get_thread", return_value=mock_context):
         with patch("utils.conversation_memory.add_turn", return_value=True):
             with patch("utils.conversation_memory.build_conversation_history", return_value=("History", 500)):
-                arguments = {
-                    "continuation_id": "test-thread-456",
-                    "prompt": "User input",
-                }
-
-                enhanced_args = await reconstruct_thread_context(arguments)
+                with patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False):
+                    from providers.registry import ModelProviderRegistry
+
+                    ModelProviderRegistry.clear_cache()
+
+                    arguments = {
+                        "continuation_id": "test-thread-456",
+                        "prompt": "User input",
+                    }
+
+                    enhanced_args = await reconstruct_thread_context(arguments)
 
                 # Should default to 'prompt' field
                 assert "prompt" in enhanced_args
@@ -73,30 +73,10 @@ class TestConversationHistoryBugFix:
     async def test_conversation_history_included_with_continuation_id(self, mock_add_turn):
         """Test that conversation history (including file context) is included when using continuation_id"""
 
-        # Create a thread context with previous turns including files
-        _thread_context = ThreadContext(
-            thread_id="test-history-id",
-            created_at="2023-01-01T00:00:00Z",
-            last_updated_at="2023-01-01T00:02:00Z",
-            tool_name="analyze",  # Started with analyze tool
-            turns=[
-                ConversationTurn(
-                    role="assistant",
-                    content="I've analyzed the authentication module and found several security issues.",
-                    timestamp="2023-01-01T00:01:00Z",
-                    tool_name="analyze",
-                    files=["/src/auth.py", "/src/security.py"],  # Files from analyze tool
-                ),
-                ConversationTurn(
-                    role="assistant",
-                    content="The code review shows these files have critical vulnerabilities.",
-                    timestamp="2023-01-01T00:02:00Z",
-                    tool_name="codereview",
-                    files=["/src/auth.py", "/tests/test_auth.py"],  # Files from codereview tool
-                ),
-            ],
-            initial_context={"prompt": "Analyze authentication security"},
-        )
+        # Test setup note: This test simulates a conversation thread with previous turns
+        # containing files from different tools (analyze -> codereview)
+        # The continuation_id "test-history-id" references this implicit thread context
+        # In the real flow, server.py would reconstruct this context and add it to the prompt
 
         # Mock add_turn to return success
         mock_add_turn.return_value = True
@@ -5,6 +5,7 @@ Tests the Redis-based conversation persistence needed for AI-to-AI multi-turn
 discussions in stateless MCP environments.
 """
 
+import os
 from unittest.mock import Mock, patch
 
 import pytest
@@ -136,8 +137,13 @@ class TestConversationMemory:
 
         assert success is False
 
+    @patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
     def test_build_conversation_history(self):
         """Test building conversation history format with files and speaker identification"""
+        from providers.registry import ModelProviderRegistry
+
+        ModelProviderRegistry.clear_cache()
+
         test_uuid = "12345678-1234-1234-1234-123456789012"
 
         turns = [
@@ -339,8 +345,13 @@ class TestConversationFlow:
                 in error_msg
             )
 
+    @patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
     def test_dynamic_max_turns_configuration(self):
         """Test that all functions respect MAX_CONVERSATION_TURNS configuration"""
+        from providers.registry import ModelProviderRegistry
+
+        ModelProviderRegistry.clear_cache()
+
         # This test ensures if we change MAX_CONVERSATION_TURNS, everything updates
 
         # Test with different max values by patching the constant
@@ -465,8 +476,13 @@ class TestConversationFlow:
         assert success is False, f"Turn {MAX_CONVERSATION_TURNS + 1} should fail"
 
     @patch("utils.conversation_memory.get_redis_client")
+    @patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
     def test_conversation_with_files_and_context_preservation(self, mock_redis):
         """Test complete conversation flow with file tracking and context preservation"""
+        from providers.registry import ModelProviderRegistry
+
+        ModelProviderRegistry.clear_cache()
+
         mock_client = Mock()
         mock_redis.return_value = mock_client
 
@@ -657,11 +673,16 @@ class TestConversationFlow:
         assert retrieved_context is not None
         assert len(retrieved_context.turns) == 1
 
+    @patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
     def test_token_limit_optimization_in_conversation_history(self):
         """Test that build_conversation_history efficiently handles token limits"""
         import os
         import tempfile
 
+        from providers.registry import ModelProviderRegistry
+
+        ModelProviderRegistry.clear_cache()
+
         from utils.conversation_memory import build_conversation_history
 
         # Create test files with known content sizes
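The recurring pattern in the conversation-memory hunks above is to pin the provider environment variables for the duration of a single test and reset the provider registry cache before the body runs, so one test's provider configuration cannot leak into the next. A hedged sketch of that pattern in isolation; the test name and body are hypothetical:

import os
from unittest.mock import patch

from providers.registry import ModelProviderRegistry  # assumes the zen-mcp-server test environment


@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
def test_needs_known_provider_state():  # hypothetical test name
    # Drop any provider instances cached by earlier tests, then exercise code
    # that resolves providers from the registry under the patched environment.
    ModelProviderRegistry.clear_cache()
    ...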
@@ -6,6 +6,7 @@ allowing multi-turn conversations to span multiple tool types.
 """
 
 import json
+import os
 from unittest.mock import Mock, patch
 
 import pytest
@@ -230,9 +231,13 @@ I'd be happy to review these security findings in detail if that would be helpfu
         )
 
         # Build conversation history
+        from providers.registry import ModelProviderRegistry
         from utils.conversation_memory import build_conversation_history
 
-        history, tokens = build_conversation_history(thread_context, model_context=None)
+        # Set up provider for this test
+        with patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False):
+            ModelProviderRegistry.clear_cache()
+            history, tokens = build_conversation_history(thread_context, model_context=None)
 
         # Verify tool names are included in the history
         assert "Turn 1 (Gemini using test_analysis)" in history
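Where the patched environment is only needed around a single call, the same setup appears as a context manager, as in the last hunk; a minimal sketch, where thread_context stands in for whatever ThreadContext the test built earlier:

import os
from unittest.mock import patch

from providers.registry import ModelProviderRegistry
from utils.conversation_memory import build_conversation_history

with patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False):
    ModelProviderRegistry.clear_cache()
    # thread_context: placeholder for the ThreadContext constructed by the test
    history, tokens = build_conversation_history(thread_context, model_context=None)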