Migration from Docker to Standalone Python Server (#73)
* Migration from docker to standalone server Migration handling Fixed tests Use simpler in-memory storage Support for concurrent logging to disk Simplified direct connections to localhost * Migration from docker / redis to standalone script Updated tests Updated run script Fixed requirements Use dotenv Ask if user would like to install MCP in Claude Desktop once Updated docs * More cleanup and references to docker removed * Cleanup * Comments * Fixed tests * Fix GitHub Actions workflow for standalone Python architecture - Install requirements-dev.txt for pytest and testing dependencies - Remove Docker setup from simulation tests (now standalone) - Simplify linting job to use requirements-dev.txt - Update simulation tests to run directly without Docker Fixes unit test failures in CI due to missing pytest dependency. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Remove simulation tests from GitHub Actions - Removed simulation-tests job that makes real API calls - Keep only unit tests (mocked, no API costs) and linting - Simulation tests should be run manually with real API keys - Reduces CI costs and complexity GitHub Actions now only runs: - Unit tests (569 tests, all mocked) - Code quality checks (ruff, black) 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Fixed tests * Fixed tests --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
9d72545ecd
commit
4151c3c3a5
@@ -6,7 +6,6 @@ import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -33,11 +32,8 @@ import config # noqa: E402
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Set WORKSPACE_ROOT to a temporary directory for tests
|
||||
# This provides a safe sandbox for file operations during testing
|
||||
# Create a temporary directory that will be used as the workspace for all tests
|
||||
test_root = tempfile.mkdtemp(prefix="zen_mcp_test_")
|
||||
os.environ["WORKSPACE_ROOT"] = test_root
|
||||
# Note: This creates a test sandbox environment
|
||||
# Tests create their own temporary directories as needed
|
||||
|
||||
# Configure asyncio for Windows compatibility
|
||||
if sys.platform == "win32":
|
||||
@@ -47,7 +43,7 @@ if sys.platform == "win32":
|
||||
from providers import ModelProviderRegistry # noqa: E402
|
||||
from providers.base import ProviderType # noqa: E402
|
||||
from providers.gemini import GeminiModelProvider # noqa: E402
|
||||
from providers.openai import OpenAIModelProvider # noqa: E402
|
||||
from providers.openai_provider import OpenAIModelProvider # noqa: E402
|
||||
from providers.xai import XAIModelProvider # noqa: E402
|
||||
|
||||
# Register providers at test startup
|
||||
@@ -59,14 +55,11 @@ ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
@pytest.fixture
|
||||
def project_path(tmp_path):
|
||||
"""
|
||||
Provides a temporary directory within the WORKSPACE_ROOT sandbox for tests.
|
||||
This ensures all file operations during tests are within the allowed directory.
|
||||
Provides a temporary directory for tests.
|
||||
This ensures all file operations during tests are isolated.
|
||||
"""
|
||||
# Get the test workspace root
|
||||
test_root = Path(os.environ.get("WORKSPACE_ROOT", "/tmp"))
|
||||
|
||||
# Create a subdirectory for this specific test
|
||||
test_dir = test_root / f"test_{tmp_path.name}"
|
||||
test_dir = tmp_path / "test_workspace"
|
||||
test_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return test_dir
|
||||
|
||||
@@ -10,7 +10,7 @@ from unittest.mock import patch
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from utils.model_restrictions import ModelRestrictionService
|
||||
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class TestAutoModeComprehensive:
|
||||
|
||||
# Re-register providers for subsequent tests (like conftest.py does)
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
@@ -178,7 +178,7 @@ class TestAutoModeComprehensive:
|
||||
|
||||
# Register providers based on configuration
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
@@ -349,7 +349,7 @@ class TestAutoModeComprehensive:
|
||||
|
||||
# Register all native providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
@@ -460,7 +460,7 @@ class TestAutoModeComprehensive:
|
||||
|
||||
# Register providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
@@ -86,7 +86,7 @@ class TestAutoModeProviderSelection:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Register only OpenAI provider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
@@ -127,7 +127,7 @@ class TestAutoModeProviderSelection:
|
||||
|
||||
# Register both providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
@@ -212,7 +212,7 @@ class TestAutoModeProviderSelection:
|
||||
|
||||
# Register both providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
@@ -256,7 +256,7 @@ class TestAutoModeProviderSelection:
|
||||
|
||||
# Register all providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
@@ -307,7 +307,7 @@ class TestAutoModeProviderSelection:
|
||||
|
||||
# Register all providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
@@ -16,7 +16,7 @@ import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from utils.model_restrictions import ModelRestrictionService
|
||||
|
||||
|
||||
|
||||
@@ -61,16 +61,16 @@ class TestClaudeContinuationOffers:
|
||||
# Set default model to avoid effective auto mode
|
||||
self.tool.default_model = "gemini-2.5-flash-preview-05-20"
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_new_conversation_offers_continuation(self, mock_redis):
|
||||
async def test_new_conversation_offers_continuation(self, mock_storage):
|
||||
"""Test that new conversations offer Claude continuation opportunity"""
|
||||
# Create tool AFTER providers are registered (in conftest.py fixture)
|
||||
tool = ClaudeContinuationTool()
|
||||
tool.default_model = "gemini-2.5-flash-preview-05-20"
|
||||
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock the model
|
||||
with patch.object(tool, "get_model_provider") as mock_get_provider:
|
||||
@@ -97,12 +97,12 @@ class TestClaudeContinuationOffers:
|
||||
assert "continuation_offer" in response_data
|
||||
assert response_data["continuation_offer"]["remaining_turns"] == MAX_CONVERSATION_TURNS - 1
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_existing_conversation_still_offers_continuation(self, mock_redis):
|
||||
async def test_existing_conversation_still_offers_continuation(self, mock_storage):
|
||||
"""Test that existing threaded conversations still offer continuation if turns remain"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock existing thread context with 2 turns
|
||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||
@@ -155,12 +155,12 @@ class TestClaudeContinuationOffers:
|
||||
# MAX_CONVERSATION_TURNS - 2 existing - 1 new = remaining
|
||||
assert response_data["continuation_offer"]["remaining_turns"] == MAX_CONVERSATION_TURNS - 3
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_full_response_flow_with_continuation_offer(self, mock_redis):
|
||||
async def test_full_response_flow_with_continuation_offer(self, mock_storage):
|
||||
"""Test complete response flow that creates continuation offer"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock the model to return a response without follow-up question
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
@@ -193,12 +193,12 @@ class TestClaudeContinuationOffers:
|
||||
assert "You have" in offer["note"]
|
||||
assert "more exchange(s) available" in offer["note"]
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_continuation_always_offered_with_natural_language(self, mock_redis):
|
||||
async def test_continuation_always_offered_with_natural_language(self, mock_storage):
|
||||
"""Test that continuation is always offered with natural language prompts"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock the model to return a response with natural language follow-up
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
@@ -229,12 +229,12 @@ I'd be happy to examine the error handling patterns in more detail if that would
|
||||
assert "continuation_offer" in response_data
|
||||
assert response_data["continuation_offer"]["remaining_turns"] == MAX_CONVERSATION_TURNS - 1
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_threaded_conversation_with_continuation_offer(self, mock_redis):
|
||||
async def test_threaded_conversation_with_continuation_offer(self, mock_storage):
|
||||
"""Test that threaded conversations still get continuation offers when turns remain"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock existing thread context
|
||||
from utils.conversation_memory import ThreadContext
|
||||
@@ -274,12 +274,12 @@ I'd be happy to examine the error handling patterns in more detail if that would
|
||||
assert response_data.get("continuation_offer") is not None
|
||||
assert response_data["continuation_offer"]["remaining_turns"] == MAX_CONVERSATION_TURNS - 1
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_max_turns_reached_no_continuation_offer(self, mock_redis):
|
||||
async def test_max_turns_reached_no_continuation_offer(self, mock_storage):
|
||||
"""Test that no continuation is offered when max turns would be exceeded"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock existing thread context at max turns
|
||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||
@@ -338,12 +338,12 @@ class TestContinuationIntegration:
|
||||
# Set default model to avoid effective auto mode
|
||||
self.tool.default_model = "gemini-2.5-flash-preview-05-20"
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_continuation_offer_creates_proper_thread(self, mock_redis):
|
||||
async def test_continuation_offer_creates_proper_thread(self, mock_storage):
|
||||
"""Test that continuation offers create properly formatted threads"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock the get call that add_turn makes to retrieve the existing thread
|
||||
# We'll set this up after the first setex call
|
||||
@@ -402,12 +402,12 @@ class TestContinuationIntegration:
|
||||
assert thread_context["initial_context"]["prompt"] == "Initial analysis"
|
||||
assert thread_context["initial_context"]["files"] == ["/test/file.py"]
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_claude_can_use_continuation_id(self, mock_redis):
|
||||
async def test_claude_can_use_continuation_id(self, mock_storage):
|
||||
"""Test that Claude can use the provided continuation_id in subsequent calls"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Step 1: Initial request creates continuation offer
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
|
||||
@@ -10,7 +10,7 @@ import pytest
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
from tools.analyze import AnalyzeTool
|
||||
from tools.debug import DebugIssueTool
|
||||
from tools.models import ClarificationRequest, ToolOutput
|
||||
from tools.models import FilesNeededRequest, ToolOutput
|
||||
|
||||
|
||||
class TestDynamicContextRequests:
|
||||
@@ -31,8 +31,8 @@ class TestDynamicContextRequests:
|
||||
# Mock model to return a clarification request
|
||||
clarification_json = json.dumps(
|
||||
{
|
||||
"status": "clarification_required",
|
||||
"question": "I need to see the package.json file to understand dependencies",
|
||||
"status": "files_required_to_continue",
|
||||
"mandatory_instructions": "I need to see the package.json file to understand dependencies",
|
||||
"files_needed": ["package.json", "package-lock.json"],
|
||||
}
|
||||
)
|
||||
@@ -56,12 +56,16 @@ class TestDynamicContextRequests:
|
||||
|
||||
# Parse the response
|
||||
response_data = json.loads(result[0].text)
|
||||
assert response_data["status"] == "clarification_required"
|
||||
assert response_data["status"] == "files_required_to_continue"
|
||||
assert response_data["content_type"] == "json"
|
||||
|
||||
# Parse the clarification request
|
||||
clarification = json.loads(response_data["content"])
|
||||
assert clarification["question"] == "I need to see the package.json file to understand dependencies"
|
||||
# Check that the enhanced instructions contain the original message and additional guidance
|
||||
expected_start = "I need to see the package.json file to understand dependencies"
|
||||
assert clarification["mandatory_instructions"].startswith(expected_start)
|
||||
assert "IMPORTANT GUIDANCE:" in clarification["mandatory_instructions"]
|
||||
assert "Use FULL absolute paths" in clarification["mandatory_instructions"]
|
||||
assert clarification["files_needed"] == ["package.json", "package-lock.json"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -100,7 +104,7 @@ class TestDynamicContextRequests:
|
||||
@patch("tools.base.BaseTool.get_model_provider")
|
||||
async def test_malformed_clarification_request_treated_as_normal(self, mock_get_provider, analyze_tool):
|
||||
"""Test that malformed JSON clarification requests are treated as normal responses"""
|
||||
malformed_json = '{"status": "clarification_required", "prompt": "Missing closing brace"'
|
||||
malformed_json = '{"status": "files_required_to_continue", "prompt": "Missing closing brace"'
|
||||
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
@@ -125,8 +129,8 @@ class TestDynamicContextRequests:
|
||||
"""Test clarification request with suggested next action"""
|
||||
clarification_json = json.dumps(
|
||||
{
|
||||
"status": "clarification_required",
|
||||
"question": "I need to see the database configuration to diagnose the connection error",
|
||||
"status": "files_required_to_continue",
|
||||
"mandatory_instructions": "I need to see the database configuration to diagnose the connection error",
|
||||
"files_needed": ["config/database.yml", "src/db.py"],
|
||||
"suggested_next_action": {
|
||||
"tool": "debug",
|
||||
@@ -160,7 +164,7 @@ class TestDynamicContextRequests:
|
||||
assert len(result) == 1
|
||||
|
||||
response_data = json.loads(result[0].text)
|
||||
assert response_data["status"] == "clarification_required"
|
||||
assert response_data["status"] == "files_required_to_continue"
|
||||
|
||||
clarification = json.loads(response_data["content"])
|
||||
assert "suggested_next_action" in clarification
|
||||
@@ -184,17 +188,54 @@ class TestDynamicContextRequests:
|
||||
assert parsed["metadata"]["tool_name"] == "test"
|
||||
|
||||
def test_clarification_request_model(self):
|
||||
"""Test ClarificationRequest model"""
|
||||
request = ClarificationRequest(
|
||||
question="Need more context",
|
||||
"""Test FilesNeededRequest model"""
|
||||
request = FilesNeededRequest(
|
||||
mandatory_instructions="Need more context",
|
||||
files_needed=["file1.py", "file2.py"],
|
||||
suggested_next_action={"tool": "analyze", "args": {}},
|
||||
)
|
||||
|
||||
assert request.question == "Need more context"
|
||||
assert request.mandatory_instructions == "Need more context"
|
||||
assert len(request.files_needed) == 2
|
||||
assert request.suggested_next_action["tool"] == "analyze"
|
||||
|
||||
def test_mandatory_instructions_enhancement(self):
|
||||
"""Test that mandatory_instructions are enhanced with additional guidance"""
|
||||
from tools.base import BaseTool
|
||||
|
||||
# Create a dummy tool instance for testing
|
||||
class TestTool(BaseTool):
|
||||
def get_name(self):
|
||||
return "test"
|
||||
|
||||
def get_description(self):
|
||||
return "test"
|
||||
|
||||
def get_request_model(self):
|
||||
return None
|
||||
|
||||
def prepare_prompt(self, request):
|
||||
return ""
|
||||
|
||||
def get_system_prompt(self):
|
||||
return ""
|
||||
|
||||
def get_input_schema(self):
|
||||
return {}
|
||||
|
||||
tool = TestTool()
|
||||
original = "I need additional files to proceed"
|
||||
enhanced = tool._enhance_mandatory_instructions(original)
|
||||
|
||||
# Verify the original instructions are preserved
|
||||
assert enhanced.startswith(original)
|
||||
|
||||
# Verify additional guidance is added
|
||||
assert "IMPORTANT GUIDANCE:" in enhanced
|
||||
assert "CRITICAL for providing accurate analysis" in enhanced
|
||||
assert "Use FULL absolute paths" in enhanced
|
||||
assert "continuation_id to continue" in enhanced
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("tools.base.BaseTool.get_model_provider")
|
||||
async def test_error_response_format(self, mock_get_provider, analyze_tool):
|
||||
@@ -223,8 +264,8 @@ class TestCollaborationWorkflow:
|
||||
# Mock Gemini to request package.json when asked about dependencies
|
||||
clarification_json = json.dumps(
|
||||
{
|
||||
"status": "clarification_required",
|
||||
"question": "I need to see the package.json file to analyze npm dependencies",
|
||||
"status": "files_required_to_continue",
|
||||
"mandatory_instructions": "I need to see the package.json file to analyze npm dependencies",
|
||||
"files_needed": ["package.json", "package-lock.json"],
|
||||
}
|
||||
)
|
||||
@@ -247,7 +288,7 @@ class TestCollaborationWorkflow:
|
||||
|
||||
response = json.loads(result[0].text)
|
||||
assert (
|
||||
response["status"] == "clarification_required"
|
||||
response["status"] == "files_required_to_continue"
|
||||
), "Should request clarification when asked about dependencies without package files"
|
||||
|
||||
clarification = json.loads(response["content"])
|
||||
@@ -262,8 +303,8 @@ class TestCollaborationWorkflow:
|
||||
# Step 1: Initial request returns clarification needed
|
||||
clarification_json = json.dumps(
|
||||
{
|
||||
"status": "clarification_required",
|
||||
"question": "I need to see the configuration file to understand the connection settings",
|
||||
"status": "files_required_to_continue",
|
||||
"mandatory_instructions": "I need to see the configuration file to understand the connection settings",
|
||||
"files_needed": ["config.py"],
|
||||
}
|
||||
)
|
||||
@@ -284,7 +325,7 @@ class TestCollaborationWorkflow:
|
||||
)
|
||||
|
||||
response1 = json.loads(result1[0].text)
|
||||
assert response1["status"] == "clarification_required"
|
||||
assert response1["status"] == "files_required_to_continue"
|
||||
|
||||
# Step 2: Claude would provide additional context and re-invoke
|
||||
# This simulates the second call with more context
|
||||
|
||||
@@ -26,11 +26,11 @@ from utils.conversation_memory import (
|
||||
class TestConversationMemory:
|
||||
"""Test the conversation memory system for stateless MCP requests"""
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_create_thread(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_create_thread(self, mock_storage):
|
||||
"""Test creating a new thread"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
thread_id = create_thread("chat", {"prompt": "Hello", "files": ["/test.py"]})
|
||||
|
||||
@@ -43,11 +43,11 @@ class TestConversationMemory:
|
||||
assert call_args[0][0] == f"thread:{thread_id}" # key
|
||||
assert call_args[0][1] == CONVERSATION_TIMEOUT_SECONDS # TTL from configuration
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_get_thread_valid(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_get_thread_valid(self, mock_storage):
|
||||
"""Test retrieving an existing thread"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
test_uuid = "12345678-1234-1234-1234-123456789012"
|
||||
|
||||
@@ -69,27 +69,27 @@ class TestConversationMemory:
|
||||
assert context.tool_name == "chat"
|
||||
mock_client.get.assert_called_once_with(f"thread:{test_uuid}")
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_get_thread_invalid_uuid(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_get_thread_invalid_uuid(self, mock_storage):
|
||||
"""Test handling invalid UUID"""
|
||||
context = get_thread("invalid-uuid")
|
||||
assert context is None
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_get_thread_not_found(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_get_thread_not_found(self, mock_storage):
|
||||
"""Test handling thread not found"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
mock_client.get.return_value = None
|
||||
|
||||
context = get_thread("12345678-1234-1234-1234-123456789012")
|
||||
assert context is None
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_add_turn_success(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_add_turn_success(self, mock_storage):
|
||||
"""Test adding a turn to existing thread"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
test_uuid = "12345678-1234-1234-1234-123456789012"
|
||||
|
||||
@@ -111,11 +111,11 @@ class TestConversationMemory:
|
||||
mock_client.get.assert_called_once()
|
||||
mock_client.setex.assert_called_once()
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_add_turn_max_limit(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_add_turn_max_limit(self, mock_storage):
|
||||
"""Test turn limit enforcement"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
test_uuid = "12345678-1234-1234-1234-123456789012"
|
||||
|
||||
@@ -237,11 +237,11 @@ class TestConversationMemory:
|
||||
class TestConversationFlow:
|
||||
"""Test complete conversation flows simulating stateless MCP requests"""
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_complete_conversation_cycle(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_complete_conversation_cycle(self, mock_storage):
|
||||
"""Test a complete 5-turn conversation until limit reached"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Simulate independent MCP request cycles
|
||||
|
||||
@@ -341,13 +341,13 @@ class TestConversationFlow:
|
||||
success = add_turn(thread_id, "user", "This should be rejected")
|
||||
assert success is False # CONVERSATION STOPS HERE
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_invalid_continuation_id_error(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_invalid_continuation_id_error(self, mock_storage):
|
||||
"""Test that invalid continuation IDs raise proper error for restart"""
|
||||
from server import reconstruct_thread_context
|
||||
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
mock_client.get.return_value = None # Thread not found
|
||||
|
||||
arguments = {"continuation_id": "invalid-uuid-12345", "prompt": "Continue conversation"}
|
||||
@@ -439,11 +439,11 @@ class TestConversationFlow:
|
||||
expected_remaining = MAX_CONVERSATION_TURNS - 1
|
||||
assert f"({expected_remaining} exchanges remaining)" in instructions
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_complete_conversation_with_dynamic_turns(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_complete_conversation_with_dynamic_turns(self, mock_storage):
|
||||
"""Test complete conversation respecting MAX_CONVERSATION_TURNS dynamically"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
thread_id = create_thread("chat", {"prompt": "Start conversation"})
|
||||
|
||||
@@ -495,16 +495,16 @@ class TestConversationFlow:
|
||||
success = add_turn(thread_id, "user", "This should fail")
|
||||
assert success is False, f"Turn {MAX_CONVERSATION_TURNS + 1} should fail"
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict(os.environ, {"GEMINI_API_KEY": "test-key", "OPENAI_API_KEY": ""}, clear=False)
|
||||
def test_conversation_with_files_and_context_preservation(self, mock_redis):
|
||||
def test_conversation_with_files_and_context_preservation(self, mock_storage):
|
||||
"""Test complete conversation flow with file tracking and context preservation"""
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
ModelProviderRegistry.clear_cache()
|
||||
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Start conversation with files
|
||||
thread_id = create_thread("analyze", {"prompt": "Analyze this codebase", "files": ["/project/src/"]})
|
||||
@@ -648,11 +648,11 @@ class TestConversationFlow:
|
||||
|
||||
assert turn_1_pos < turn_2_pos < turn_3_pos
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_stateless_request_isolation(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_stateless_request_isolation(self, mock_storage):
|
||||
"""Test that each request cycle is independent but shares context via Redis"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Simulate two different "processes" accessing same thread
|
||||
thread_id = "12345678-1234-1234-1234-123456789012"
|
||||
|
||||
@@ -93,12 +93,12 @@ class TestCrossToolContinuation:
|
||||
self.analysis_tool = MockAnalysisTool()
|
||||
self.review_tool = MockReviewTool()
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_continuation_id_works_across_different_tools(self, mock_redis):
|
||||
async def test_continuation_id_works_across_different_tools(self, mock_storage):
|
||||
"""Test that a continuation_id from one tool can be used with another tool"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Step 1: Analysis tool creates a conversation with continuation offer
|
||||
with patch.object(self.analysis_tool, "get_model_provider") as mock_get_provider:
|
||||
@@ -195,11 +195,11 @@ I'd be happy to review these security findings in detail if that would be helpfu
|
||||
assert second_turn["tool_name"] == "test_review" # New tool name
|
||||
assert "Critical security vulnerability confirmed" in second_turn["content"]
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_cross_tool_conversation_history_includes_tool_names(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_cross_tool_conversation_history_includes_tool_names(self, mock_storage):
|
||||
"""Test that conversation history properly shows which tool was used for each turn"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Create a thread context with turns from different tools
|
||||
thread_context = ThreadContext(
|
||||
@@ -247,13 +247,13 @@ I'd be happy to review these security findings in detail if that would be helpfu
|
||||
assert "Review complete: 2 critical, 1 minor issue" in history
|
||||
assert "Deep analysis: Root cause identified" in history
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch("utils.conversation_memory.get_thread")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_cross_tool_conversation_with_files_context(self, mock_get_thread, mock_redis):
|
||||
async def test_cross_tool_conversation_with_files_context(self, mock_get_thread, mock_storage):
|
||||
"""Test that file context is preserved across tool switches"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Create existing context with files from analysis tool
|
||||
existing_context = ThreadContext(
|
||||
@@ -317,12 +317,12 @@ I'd be happy to review these security findings in detail if that would be helpfu
|
||||
analysis_turn = final_context["turns"][0] # First turn (analysis tool)
|
||||
assert analysis_turn["files"] == ["/src/auth.py", "/src/utils.py"]
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch("utils.conversation_memory.get_thread")
|
||||
def test_thread_preserves_original_tool_name(self, mock_get_thread, mock_redis):
|
||||
def test_thread_preserves_original_tool_name(self, mock_get_thread, mock_storage):
|
||||
"""Test that the thread's original tool_name is preserved even when other tools contribute"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Create existing thread from analysis tool
|
||||
existing_context = ThreadContext(
|
||||
|
||||
@@ -31,8 +31,9 @@ class TestCustomProvider:
|
||||
|
||||
def test_provider_initialization_missing_url(self):
|
||||
"""Test CustomProvider raises error when URL is missing."""
|
||||
with pytest.raises(ValueError, match="Custom API URL must be provided"):
|
||||
CustomProvider(api_key="test-key")
|
||||
with patch.dict(os.environ, {"CUSTOM_API_URL": ""}, clear=False):
|
||||
with pytest.raises(ValueError, match="Custom API URL must be provided"):
|
||||
CustomProvider(api_key="test-key")
|
||||
|
||||
def test_validate_model_names_always_true(self):
|
||||
"""Test CustomProvider accepts any model name."""
|
||||
|
||||
@@ -121,10 +121,10 @@ def helper_function():
|
||||
assert any(str(Path(f).resolve()) == expected_resolved for f in captured_files)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
@patch("providers.ModelProviderRegistry.get_provider_for_model")
|
||||
async def test_conversation_continuation_with_directory_files(
|
||||
self, mock_get_provider, mock_redis, tool, temp_directory_with_files
|
||||
self, mock_get_provider, mock_storage, tool, temp_directory_with_files
|
||||
):
|
||||
"""Test that conversation continuation works correctly with directory expansion"""
|
||||
# Setup mock Redis client with in-memory storage
|
||||
@@ -140,7 +140,7 @@ def helper_function():
|
||||
|
||||
mock_client.get.side_effect = mock_get
|
||||
mock_client.setex.side_effect = mock_setex
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Setup mock provider
|
||||
mock_provider = create_mock_provider()
|
||||
@@ -196,8 +196,8 @@ def helper_function():
|
||||
# This test shows the fix is working - conversation continuation properly filters out
|
||||
# already-embedded files. The exact length depends on whether any new files are found.
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_get_conversation_embedded_files_with_expanded_files(self, mock_redis, tool, temp_directory_with_files):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_get_conversation_embedded_files_with_expanded_files(self, mock_storage, tool, temp_directory_with_files):
|
||||
"""Test that get_conversation_embedded_files returns expanded files"""
|
||||
# Setup mock Redis client with in-memory storage
|
||||
mock_client = Mock()
|
||||
@@ -212,7 +212,7 @@ def helper_function():
|
||||
|
||||
mock_client.get.side_effect = mock_get
|
||||
mock_client.setex.side_effect = mock_setex
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
directory = temp_directory_with_files["directory"]
|
||||
expected_files = temp_directory_with_files["files"]
|
||||
@@ -237,8 +237,8 @@ def helper_function():
|
||||
assert set(embedded_files) == set(expected_files)
|
||||
assert directory not in embedded_files
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_file_filtering_with_mixed_files_and_directories(self, mock_redis, tool, temp_directory_with_files):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_file_filtering_with_mixed_files_and_directories(self, mock_storage, tool, temp_directory_with_files):
|
||||
"""Test file filtering when request contains both individual files and directories"""
|
||||
# Setup mock Redis client with in-memory storage
|
||||
mock_client = Mock()
|
||||
@@ -253,7 +253,7 @@ def helper_function():
|
||||
|
||||
mock_client.get.side_effect = mock_get
|
||||
mock_client.setex.side_effect = mock_setex
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
directory = temp_directory_with_files["directory"]
|
||||
python_file = temp_directory_with_files["python_file"]
|
||||
|
||||
@@ -1,320 +0,0 @@
|
||||
"""
|
||||
Integration tests for Docker path translation
|
||||
|
||||
These tests verify the actual behavior when running in a Docker-like environment
|
||||
by creating temporary directories and testing the path translation logic.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# We'll reload the module to test different environment configurations
|
||||
import utils.file_utils
|
||||
|
||||
|
||||
def test_docker_path_translation_integration():
|
||||
"""Test path translation in a simulated Docker environment"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Set up directories
|
||||
host_workspace = Path(tmpdir) / "host_workspace"
|
||||
host_workspace.mkdir()
|
||||
container_workspace = Path(tmpdir) / "container_workspace"
|
||||
container_workspace.mkdir()
|
||||
|
||||
# Create a test file structure
|
||||
(host_workspace / "src").mkdir()
|
||||
test_file = host_workspace / "src" / "test.py"
|
||||
test_file.write_text("# test file")
|
||||
|
||||
# Set environment variables and reload the module
|
||||
original_env = os.environ.copy()
|
||||
try:
|
||||
os.environ["WORKSPACE_ROOT"] = str(host_workspace)
|
||||
|
||||
# Reload the modules to pick up new environment variables
|
||||
# Need to reload security_config first since it sets WORKSPACE_ROOT
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
# Properly mock the CONTAINER_WORKSPACE
|
||||
with patch("utils.file_utils.CONTAINER_WORKSPACE", container_workspace):
|
||||
# Test the translation
|
||||
from utils.file_utils import translate_path_for_environment
|
||||
|
||||
# This should translate the host path to container path
|
||||
host_path = str(test_file)
|
||||
result = translate_path_for_environment(host_path)
|
||||
|
||||
# Verify the translation worked
|
||||
expected = str(container_workspace / "src" / "test.py")
|
||||
assert result == expected
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
|
||||
def test_docker_security_validation():
|
||||
"""Test that path traversal attempts are properly blocked"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Set up directories
|
||||
host_workspace = Path(tmpdir) / "workspace"
|
||||
host_workspace.mkdir()
|
||||
secret_dir = Path(tmpdir) / "secret"
|
||||
secret_dir.mkdir()
|
||||
secret_file = secret_dir / "password.txt"
|
||||
secret_file.write_text("secret")
|
||||
|
||||
# Create a symlink inside workspace pointing to secret
|
||||
symlink = host_workspace / "link_to_secret"
|
||||
symlink.symlink_to(secret_file)
|
||||
|
||||
original_env = os.environ.copy()
|
||||
try:
|
||||
os.environ["WORKSPACE_ROOT"] = str(host_workspace)
|
||||
|
||||
# Reload the modules
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
# Properly mock the CONTAINER_WORKSPACE
|
||||
with patch("utils.file_utils.CONTAINER_WORKSPACE", Path("/workspace")):
|
||||
from utils.file_utils import resolve_and_validate_path
|
||||
|
||||
# Trying to access the symlink should fail
|
||||
with pytest.raises(PermissionError):
|
||||
resolve_and_validate_path(str(symlink))
|
||||
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
|
||||
def test_no_docker_environment():
|
||||
"""Test that paths are unchanged when Docker environment is not set"""
|
||||
|
||||
original_env = os.environ.copy()
|
||||
try:
|
||||
# Clear Docker-related environment variables
|
||||
os.environ.pop("WORKSPACE_ROOT", None)
|
||||
|
||||
# Reload the module
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
from utils.file_utils import translate_path_for_environment
|
||||
|
||||
# Path should remain unchanged
|
||||
test_path = "/some/random/path.py"
|
||||
assert translate_path_for_environment(test_path) == test_path
|
||||
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
|
||||
def test_review_changes_docker_path_translation():
|
||||
"""Test that review_changes tool properly translates Docker paths"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Set up directories to simulate Docker mount
|
||||
host_workspace = Path(tmpdir) / "host_workspace"
|
||||
host_workspace.mkdir()
|
||||
container_workspace = Path(tmpdir) / "container_workspace"
|
||||
container_workspace.mkdir()
|
||||
|
||||
# Create a git repository in the container workspace
|
||||
project_dir = container_workspace / "project"
|
||||
project_dir.mkdir()
|
||||
|
||||
# Initialize git repo
|
||||
import subprocess
|
||||
|
||||
subprocess.run(["git", "init"], cwd=project_dir, capture_output=True)
|
||||
|
||||
# Create a test file
|
||||
test_file = project_dir / "test.py"
|
||||
test_file.write_text("print('hello')")
|
||||
|
||||
# Stage the file
|
||||
subprocess.run(["git", "add", "test.py"], cwd=project_dir, capture_output=True)
|
||||
|
||||
original_env = os.environ.copy()
|
||||
try:
|
||||
# Simulate Docker environment
|
||||
os.environ["WORKSPACE_ROOT"] = str(host_workspace)
|
||||
|
||||
# Reload the modules
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
# Properly mock the CONTAINER_WORKSPACE and reload precommit module
|
||||
with patch("utils.file_utils.CONTAINER_WORKSPACE", container_workspace):
|
||||
# Need to also patch it in the modules that import it
|
||||
with patch("utils.security_config.CONTAINER_WORKSPACE", container_workspace):
|
||||
# Import after patching to get updated environment
|
||||
from tools.precommit import Precommit
|
||||
|
||||
# Create tool instance
|
||||
tool = Precommit()
|
||||
|
||||
# Test path translation in prepare_prompt
|
||||
request = tool.get_request_model()(
|
||||
path=str(host_workspace / "project"), # Host path that needs translation
|
||||
review_type="quick",
|
||||
severity_filter="all",
|
||||
)
|
||||
|
||||
# This should translate the path and find the git repository
|
||||
import asyncio
|
||||
|
||||
result = asyncio.run(tool.prepare_prompt(request))
|
||||
|
||||
# Should find the repository (not raise an error about inaccessible path)
|
||||
# If we get here without exception, the path was successfully translated
|
||||
assert isinstance(result, str)
|
||||
# The result should contain git diff information or indicate no changes
|
||||
assert "No git repositories found" not in result or "changes" in result.lower()
|
||||
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
|
||||
def test_review_changes_docker_path_error():
|
||||
"""Test that review_changes tool raises error for inaccessible paths"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Set up directories to simulate Docker mount
|
||||
host_workspace = Path(tmpdir) / "host_workspace"
|
||||
host_workspace.mkdir()
|
||||
container_workspace = Path(tmpdir) / "container_workspace"
|
||||
container_workspace.mkdir()
|
||||
|
||||
# Create a path outside the mounted workspace
|
||||
outside_path = Path(tmpdir) / "outside_workspace"
|
||||
outside_path.mkdir()
|
||||
|
||||
original_env = os.environ.copy()
|
||||
try:
|
||||
# Simulate Docker environment
|
||||
os.environ["WORKSPACE_ROOT"] = str(host_workspace)
|
||||
|
||||
# Reload the modules
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
# Properly mock the CONTAINER_WORKSPACE
|
||||
with patch("utils.file_utils.CONTAINER_WORKSPACE", container_workspace):
|
||||
with patch("utils.security_config.CONTAINER_WORKSPACE", container_workspace):
|
||||
# Import after patching to get updated environment
|
||||
from tools.precommit import Precommit
|
||||
|
||||
# Create tool instance
|
||||
tool = Precommit()
|
||||
|
||||
# Test path translation with an inaccessible path
|
||||
request = tool.get_request_model()(
|
||||
path=str(outside_path), # Path outside the mounted workspace
|
||||
review_type="quick",
|
||||
severity_filter="all",
|
||||
)
|
||||
|
||||
# This should raise a ValueError
|
||||
import asyncio
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
asyncio.run(tool.prepare_prompt(request))
|
||||
|
||||
# Check the error message
|
||||
assert "not accessible from within the Docker container" in str(exc_info.value)
|
||||
assert "mounted workspace" in str(exc_info.value)
|
||||
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
|
||||
def test_double_translation_prevention():
|
||||
"""Test that already-translated paths are not double-translated"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Set up directories
|
||||
host_workspace = Path(tmpdir) / "host_workspace"
|
||||
host_workspace.mkdir()
|
||||
container_workspace = Path(tmpdir) / "container_workspace"
|
||||
container_workspace.mkdir()
|
||||
|
||||
original_env = os.environ.copy()
|
||||
try:
|
||||
os.environ["WORKSPACE_ROOT"] = str(host_workspace)
|
||||
|
||||
# Reload the modules
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
# Properly mock the CONTAINER_WORKSPACE
|
||||
with patch("utils.file_utils.CONTAINER_WORKSPACE", container_workspace):
|
||||
from utils.file_utils import translate_path_for_environment
|
||||
|
||||
# Test 1: Normal translation
|
||||
host_path = str(host_workspace / "src" / "main.py")
|
||||
translated_once = translate_path_for_environment(host_path)
|
||||
expected = str(container_workspace / "src" / "main.py")
|
||||
assert translated_once == expected
|
||||
|
||||
# Test 2: Double translation should return the same path
|
||||
translated_twice = translate_path_for_environment(translated_once)
|
||||
assert translated_twice == translated_once
|
||||
assert translated_twice == expected
|
||||
|
||||
# Test 3: Container workspace root should not be double-translated
|
||||
root_path = str(container_workspace)
|
||||
translated_root = translate_path_for_environment(root_path)
|
||||
assert translated_root == root_path
|
||||
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
import utils.security_config
|
||||
|
||||
importlib.reload(utils.security_config)
|
||||
importlib.reload(utils.file_utils)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -5,12 +5,10 @@ Test file protection mechanisms to ensure MCP doesn't scan:
|
||||
3. Excluded directories
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from utils.file_utils import (
|
||||
MCP_SIGNATURE_FILES,
|
||||
expand_paths,
|
||||
get_user_home_directory,
|
||||
is_home_directory_root,
|
||||
@@ -21,25 +19,31 @@ from utils.file_utils import (
|
||||
class TestMCPDirectoryDetection:
|
||||
"""Test MCP self-detection to prevent scanning its own code."""
|
||||
|
||||
def test_detect_mcp_directory_with_all_signatures(self, tmp_path):
|
||||
"""Test detection when all signature files are present."""
|
||||
# Create a fake MCP directory with signature files
|
||||
for sig_file in list(MCP_SIGNATURE_FILES)[:4]: # Use 4 files
|
||||
if "/" in sig_file:
|
||||
(tmp_path / sig_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
(tmp_path / sig_file).touch()
|
||||
def test_detect_mcp_directory_dynamically(self, tmp_path):
|
||||
"""Test dynamic MCP directory detection based on script location."""
|
||||
# The is_mcp_directory function now uses __file__ to detect MCP location
|
||||
# It checks if the given path is a subdirectory of the MCP server
|
||||
from pathlib import Path
|
||||
|
||||
assert is_mcp_directory(tmp_path) is True
|
||||
import utils.file_utils
|
||||
|
||||
def test_no_detection_with_few_signatures(self, tmp_path):
|
||||
"""Test no detection with only 1-2 signature files."""
|
||||
# Create only 2 signature files (less than threshold)
|
||||
for sig_file in list(MCP_SIGNATURE_FILES)[:2]:
|
||||
if "/" in sig_file:
|
||||
(tmp_path / sig_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
(tmp_path / sig_file).touch()
|
||||
# Get the actual MCP server directory
|
||||
mcp_server_dir = Path(utils.file_utils.__file__).parent.parent.resolve()
|
||||
|
||||
assert is_mcp_directory(tmp_path) is False
|
||||
# Test that the MCP server directory itself is detected
|
||||
assert is_mcp_directory(mcp_server_dir) is True
|
||||
|
||||
# Test that a subdirectory of MCP is also detected
|
||||
if (mcp_server_dir / "tools").exists():
|
||||
assert is_mcp_directory(mcp_server_dir / "tools") is True
|
||||
|
||||
def test_no_detection_on_non_mcp_directory(self, tmp_path):
|
||||
"""Test no detection on directories outside MCP."""
|
||||
# Any directory outside the MCP server should not be detected
|
||||
non_mcp_dir = tmp_path / "some_other_project"
|
||||
non_mcp_dir.mkdir()
|
||||
|
||||
assert is_mcp_directory(non_mcp_dir) is False
|
||||
|
||||
def test_no_detection_on_regular_directory(self, tmp_path):
|
||||
"""Test no detection on regular project directories."""
|
||||
@@ -59,7 +63,11 @@ class TestMCPDirectoryDetection:
|
||||
|
||||
def test_mcp_directory_excluded_from_scan(self, tmp_path):
|
||||
"""Test that MCP directories are excluded during path expansion."""
|
||||
# Create a project with MCP as subdirectory
|
||||
# For this test, we need to mock is_mcp_directory since we can't
|
||||
# actually create the MCP directory structure in tmp_path
|
||||
from unittest.mock import patch as mock_patch
|
||||
|
||||
# Create a project with a subdirectory we'll pretend is MCP
|
||||
project_root = tmp_path / "my_project"
|
||||
project_root.mkdir()
|
||||
|
||||
@@ -67,19 +75,18 @@ class TestMCPDirectoryDetection:
|
||||
(project_root / "app.py").write_text("# My app")
|
||||
(project_root / "config.py").write_text("# Config")
|
||||
|
||||
# Create MCP subdirectory
|
||||
mcp_dir = project_root / "gemini-mcp-server"
|
||||
mcp_dir.mkdir()
|
||||
for sig_file in list(MCP_SIGNATURE_FILES)[:4]:
|
||||
if "/" in sig_file:
|
||||
(mcp_dir / sig_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
(mcp_dir / sig_file).write_text("# MCP file")
|
||||
# Create a subdirectory that we'll mock as MCP
|
||||
fake_mcp_dir = project_root / "gemini-mcp-server"
|
||||
fake_mcp_dir.mkdir()
|
||||
(fake_mcp_dir / "server.py").write_text("# MCP server")
|
||||
(fake_mcp_dir / "test.py").write_text("# Should not be included")
|
||||
|
||||
# Also add a regular file to MCP dir
|
||||
(mcp_dir / "test.py").write_text("# Should not be included")
|
||||
# Mock is_mcp_directory to return True for our fake MCP dir
|
||||
def mock_is_mcp(path):
|
||||
return str(path).endswith("gemini-mcp-server")
|
||||
|
||||
# Scan the project - use parent as SECURITY_ROOT to avoid workspace root check
|
||||
with patch("utils.file_utils.SECURITY_ROOT", tmp_path):
|
||||
# Scan the project with mocked MCP detection
|
||||
with mock_patch("utils.file_utils.is_mcp_directory", side_effect=mock_is_mcp):
|
||||
files = expand_paths([str(project_root)])
|
||||
|
||||
# Verify project files are included but MCP files are not
|
||||
@@ -135,42 +142,45 @@ class TestHomeDirectoryProtection:
|
||||
"""Test that home directory root is excluded during path expansion."""
|
||||
with patch("utils.file_utils.get_user_home_directory") as mock_home:
|
||||
mock_home.return_value = tmp_path
|
||||
with patch("utils.file_utils.SECURITY_ROOT", tmp_path):
|
||||
# Try to scan home directory
|
||||
files = expand_paths([str(tmp_path)])
|
||||
# Should return empty as home root is skipped
|
||||
assert files == []
|
||||
# Try to scan home directory
|
||||
files = expand_paths([str(tmp_path)])
|
||||
# Should return empty as home root is skipped
|
||||
assert files == []
|
||||
|
||||
|
||||
class TestUserHomeEnvironmentVariable:
|
||||
"""Test USER_HOME environment variable handling."""
|
||||
|
||||
def test_user_home_from_env(self):
|
||||
"""Test USER_HOME is used when set."""
|
||||
test_home = "/Users/dockeruser"
|
||||
with patch.dict(os.environ, {"USER_HOME": test_home}):
|
||||
def test_user_home_from_pathlib(self):
|
||||
"""Test that get_user_home_directory uses Path.home()."""
|
||||
with patch("pathlib.Path.home") as mock_home:
|
||||
mock_home.return_value = Path("/Users/testuser")
|
||||
home = get_user_home_directory()
|
||||
assert home == Path(test_home).resolve()
|
||||
assert home == Path("/Users/testuser")
|
||||
|
||||
def test_fallback_to_workspace_root_in_docker(self):
|
||||
"""Test fallback to WORKSPACE_ROOT in Docker when USER_HOME not set."""
|
||||
with patch("utils.file_utils.WORKSPACE_ROOT", "/Users/realuser"):
|
||||
with patch("utils.file_utils.CONTAINER_WORKSPACE") as mock_container:
|
||||
mock_container.exists.return_value = True
|
||||
# Clear USER_HOME to test fallback
|
||||
with patch.dict(os.environ, {"USER_HOME": ""}, clear=False):
|
||||
home = get_user_home_directory()
|
||||
assert str(home) == "/Users/realuser"
|
||||
def test_get_home_directory_uses_pathlib(self):
|
||||
"""Test that get_user_home_directory always uses Path.home()."""
|
||||
with patch("pathlib.Path.home") as mock_home:
|
||||
mock_home.return_value = Path("/home/testuser")
|
||||
home = get_user_home_directory()
|
||||
assert home == Path("/home/testuser")
|
||||
# Verify Path.home() was called
|
||||
mock_home.assert_called_once()
|
||||
|
||||
def test_fallback_to_system_home(self):
|
||||
"""Test fallback to system home when not in Docker."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with patch("utils.file_utils.CONTAINER_WORKSPACE") as mock_container:
|
||||
mock_container.exists.return_value = False
|
||||
with patch("pathlib.Path.home") as mock_home:
|
||||
mock_home.return_value = Path("/home/user")
|
||||
home = get_user_home_directory()
|
||||
assert home == Path("/home/user")
|
||||
def test_home_directory_on_different_platforms(self):
|
||||
"""Test home directory detection on different platforms."""
|
||||
# Test different platform home directories
|
||||
test_homes = [
|
||||
Path("/Users/john"), # macOS
|
||||
Path("/home/ubuntu"), # Linux
|
||||
Path("C:\\Users\\John"), # Windows
|
||||
]
|
||||
|
||||
for test_home in test_homes:
|
||||
with patch("pathlib.Path.home") as mock_home:
|
||||
mock_home.return_value = test_home
|
||||
home = get_user_home_directory()
|
||||
assert home == test_home
|
||||
|
||||
|
||||
class TestExcludedDirectories:
|
||||
@@ -198,8 +208,7 @@ class TestExcludedDirectories:
|
||||
src.mkdir()
|
||||
(src / "utils.py").write_text("# Utils")
|
||||
|
||||
with patch("utils.file_utils.SECURITY_ROOT", tmp_path):
|
||||
files = expand_paths([str(project)])
|
||||
files = expand_paths([str(project)])
|
||||
|
||||
file_names = [Path(f).name for f in files]
|
||||
|
||||
@@ -226,8 +235,7 @@ class TestExcludedDirectories:
|
||||
# Create an allowed file
|
||||
(project / "index.js").write_text("// Index")
|
||||
|
||||
with patch("utils.file_utils.SECURITY_ROOT", tmp_path):
|
||||
files = expand_paths([str(project)])
|
||||
files = expand_paths([str(project)])
|
||||
|
||||
file_names = [Path(f).name for f in files]
|
||||
|
||||
@@ -254,10 +262,12 @@ class TestIntegrationScenarios:
|
||||
# MCP cloned inside the project
|
||||
mcp = user_project / "tools" / "gemini-mcp-server"
|
||||
mcp.mkdir(parents=True)
|
||||
for sig_file in list(MCP_SIGNATURE_FILES)[:4]:
|
||||
if "/" in sig_file:
|
||||
(mcp / sig_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
(mcp / sig_file).write_text("# MCP code")
|
||||
# Create typical MCP files
|
||||
(mcp / "server.py").write_text("# MCP server code")
|
||||
(mcp / "config.py").write_text("# MCP config")
|
||||
tools_dir = mcp / "tools"
|
||||
tools_dir.mkdir()
|
||||
(tools_dir / "chat.py").write_text("# Chat tool")
|
||||
(mcp / "LICENSE").write_text("MIT License")
|
||||
(mcp / "README.md").write_text("# Gemini MCP")
|
||||
|
||||
@@ -266,7 +276,11 @@ class TestIntegrationScenarios:
|
||||
node_modules.mkdir()
|
||||
(node_modules / "package.json").write_text("{}")
|
||||
|
||||
with patch("utils.file_utils.SECURITY_ROOT", tmp_path):
|
||||
# Mock is_mcp_directory for this test
|
||||
def mock_is_mcp(path):
|
||||
return "gemini-mcp-server" in str(path)
|
||||
|
||||
with patch("utils.file_utils.is_mcp_directory", side_effect=mock_is_mcp):
|
||||
files = expand_paths([str(user_project)])
|
||||
|
||||
file_paths = [str(f) for f in files]
|
||||
@@ -278,23 +292,28 @@ class TestIntegrationScenarios:
|
||||
|
||||
# MCP files should NOT be included
|
||||
assert not any("gemini-mcp-server" in p for p in file_paths)
|
||||
assert not any("zen_server.py" in p for p in file_paths)
|
||||
assert not any("server.py" in p for p in file_paths)
|
||||
|
||||
# node_modules should NOT be included
|
||||
assert not any("node_modules" in p for p in file_paths)
|
||||
|
||||
def test_cannot_scan_above_workspace_root(self, tmp_path):
|
||||
"""Test that we cannot scan outside the workspace root."""
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir()
|
||||
def test_security_without_workspace_root(self, tmp_path):
|
||||
"""Test that security still works with the new security model."""
|
||||
# The system now relies on is_dangerous_path and is_home_directory_root
|
||||
# for security protection
|
||||
|
||||
# Create a file in workspace
|
||||
(workspace / "allowed.py").write_text("# Allowed")
|
||||
# Test that we can scan regular project directories
|
||||
project_dir = tmp_path / "my_project"
|
||||
project_dir.mkdir()
|
||||
(project_dir / "app.py").write_text("# App")
|
||||
|
||||
# Create a file outside workspace
|
||||
(tmp_path / "outside.py").write_text("# Outside")
|
||||
files = expand_paths([str(project_dir)])
|
||||
assert len(files) == 1
|
||||
assert "app.py" in files[0]
|
||||
|
||||
with patch("utils.file_utils.SECURITY_ROOT", workspace):
|
||||
# Try to expand paths outside workspace - should return empty list
|
||||
# Test that home directory root is still protected
|
||||
with patch("utils.file_utils.get_user_home_directory") as mock_home:
|
||||
mock_home.return_value = tmp_path
|
||||
# Scanning home root should return empty
|
||||
files = expand_paths([str(tmp_path)])
|
||||
assert files == [] # Path outside workspace is skipped silently
|
||||
assert files == []
|
||||
|
||||
@@ -80,11 +80,11 @@ class TestImageSupportIntegration:
|
||||
expected = ["shared.png", "new_diagram.png", "middle.png", "old_diagram.png"]
|
||||
assert image_list == expected
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_add_turn_with_images(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_add_turn_with_images(self, mock_storage):
|
||||
"""Test adding a conversation turn with images."""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock the Redis operations to return success
|
||||
mock_client.set.return_value = True
|
||||
@@ -348,11 +348,11 @@ class TestImageSupportIntegration:
|
||||
importlib.reload(config)
|
||||
ModelProviderRegistry._instance = None
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_cross_tool_image_context_preservation(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_cross_tool_image_context_preservation(self, mock_storage):
|
||||
"""Test that images are preserved across different tools in conversation."""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock the Redis operations to return success
|
||||
mock_client.set.return_value = True
|
||||
@@ -521,11 +521,11 @@ class TestImageSupportIntegration:
|
||||
result = tool._validate_image_limits(None, "test_model")
|
||||
assert result is None
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_conversation_memory_thread_chaining_with_images(self, mock_redis):
|
||||
@patch("utils.conversation_memory.get_storage")
|
||||
def test_conversation_memory_thread_chaining_with_images(self, mock_storage):
|
||||
"""Test that images work correctly with conversation thread chaining."""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Mock the Redis operations to return success
|
||||
mock_client.set.return_value = True
|
||||
|
||||
@@ -39,7 +39,7 @@ class TestIntelligentFallback:
|
||||
def test_prefers_openai_o3_mini_when_available(self):
|
||||
"""Test that o4-mini is preferred when OpenAI API key is available"""
|
||||
# Register only OpenAI provider for this test
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
@@ -62,7 +62,7 @@ class TestIntelligentFallback:
|
||||
"""Test that OpenAI is preferred when both API keys are available"""
|
||||
# Register both OpenAI and Gemini providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
@@ -75,7 +75,7 @@ class TestIntelligentFallback:
|
||||
"""Test fallback behavior when no API keys are available"""
|
||||
# Register providers but with no API keys available
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
@@ -86,7 +86,7 @@ class TestIntelligentFallback:
|
||||
def test_available_providers_with_keys(self):
|
||||
"""Test the get_available_providers_with_keys method"""
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
|
||||
# Clear and register providers
|
||||
@@ -119,7 +119,7 @@ class TestIntelligentFallback:
|
||||
patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
|
||||
):
|
||||
# Register only OpenAI provider for this test
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
|
||||
@@ -246,9 +246,9 @@ class TestLargePromptHandling:
|
||||
|
||||
assert len(result) == 1
|
||||
output = json.loads(result[0].text)
|
||||
# The precommit tool may return success or clarification_required depending on git state
|
||||
# The precommit tool may return success or files_required_to_continue depending on git state
|
||||
# The core fix ensures large prompts are detected at the right time
|
||||
assert output["status"] in ["success", "clarification_required", "resend_prompt"]
|
||||
assert output["status"] in ["success", "files_required_to_continue", "resend_prompt"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_large_error_description(self, large_prompt):
|
||||
@@ -298,17 +298,26 @@ class TestLargePromptHandling:
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Mock the centralized file preparation method to avoid file system access
|
||||
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
|
||||
mock_prepare_files.return_value = ("File content", [other_file])
|
||||
# Mock handle_prompt_file to verify prompt.txt is handled
|
||||
with patch.object(tool, "handle_prompt_file") as mock_handle_prompt:
|
||||
# Return the prompt content and updated files list (without prompt.txt)
|
||||
mock_handle_prompt.return_value = ("Large prompt content from file", [other_file])
|
||||
|
||||
await tool.execute({"prompt": "", "files": [temp_prompt_file, other_file]})
|
||||
# Mock the centralized file preparation method
|
||||
with patch.object(tool, "_prepare_file_content_for_prompt") as mock_prepare_files:
|
||||
mock_prepare_files.return_value = ("File content", [other_file])
|
||||
|
||||
# Verify prompt.txt was removed from files list
|
||||
mock_prepare_files.assert_called_once()
|
||||
files_arg = mock_prepare_files.call_args[0][0]
|
||||
assert len(files_arg) == 1
|
||||
assert files_arg[0] == other_file
|
||||
# Use a small prompt to avoid triggering size limit
|
||||
await tool.execute({"prompt": "Test prompt", "files": [temp_prompt_file, other_file]})
|
||||
|
||||
# Verify handle_prompt_file was called with the original files list
|
||||
mock_handle_prompt.assert_called_once_with([temp_prompt_file, other_file])
|
||||
|
||||
# Verify _prepare_file_content_for_prompt was called with the updated files list (without prompt.txt)
|
||||
mock_prepare_files.assert_called_once()
|
||||
files_arg = mock_prepare_files.call_args[0][0]
|
||||
assert len(files_arg) == 1
|
||||
assert files_arg[0] == other_file
|
||||
|
||||
temp_dir = os.path.dirname(temp_prompt_file)
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from utils.model_restrictions import ModelRestrictionService
|
||||
|
||||
|
||||
@@ -677,7 +677,7 @@ class TestAutoModeWithRestrictions:
|
||||
# Clear registry and register only OpenAI and Gemini providers
|
||||
ModelProviderRegistry._instance = None
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
@@ -195,7 +195,7 @@ class TestOldBehaviorSimulation:
|
||||
Verify that our fix provides comprehensive alias->target coverage.
|
||||
"""
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
# Test real providers to ensure they implement our fix correctly
|
||||
providers = [OpenAIModelProvider(api_key="test-key"), GeminiModelProvider(api_key="test-key")]
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
|
||||
class TestOpenAIProvider:
|
||||
|
||||
@@ -115,7 +115,7 @@ class TestPlannerTool:
|
||||
"""Test execute method for subsequent planning step."""
|
||||
tool = PlannerTool()
|
||||
arguments = {
|
||||
"step": "Set up Docker containers for each microservice",
|
||||
"step": "Set up deployment configuration for each microservice",
|
||||
"step_number": 2,
|
||||
"total_steps": 8,
|
||||
"next_step_required": True,
|
||||
|
||||
@@ -4,7 +4,6 @@ Enhanced tests for precommit tool using mock storage to test real logic
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -50,21 +49,18 @@ class TestPrecommitToolWithMockStore:
|
||||
"""Test precommit tool with mock storage to validate actual logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
def mock_storage(self):
|
||||
"""Create mock Redis client"""
|
||||
return MockRedisClient()
|
||||
|
||||
@pytest.fixture
|
||||
def tool(self, mock_redis, temp_repo):
|
||||
def tool(self, mock_storage, temp_repo):
|
||||
"""Create tool instance with mocked Redis"""
|
||||
temp_dir, _ = temp_repo
|
||||
tool = Precommit()
|
||||
|
||||
# Mock the Redis client getter and SECURITY_ROOT to allow access to temp files
|
||||
with (
|
||||
patch("utils.conversation_memory.get_redis_client", return_value=mock_redis),
|
||||
patch("utils.file_utils.SECURITY_ROOT", Path(temp_dir).resolve()),
|
||||
):
|
||||
# Mock the Redis client getter to use our mock storage
|
||||
with patch("utils.conversation_memory.get_storage", return_value=mock_storage):
|
||||
yield tool
|
||||
|
||||
@pytest.fixture
|
||||
@@ -112,7 +108,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_duplicate_file_content_in_prompt(self, tool, temp_repo, mock_redis):
|
||||
async def test_no_duplicate_file_content_in_prompt(self, tool, temp_repo, mock_storage):
|
||||
"""Test that file content appears in expected locations
|
||||
|
||||
This test validates our design decision that files can legitimately appear in both:
|
||||
@@ -145,12 +141,12 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
||||
# This is intentional and provides comprehensive context to the AI
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_memory_integration(self, tool, temp_repo, mock_redis):
|
||||
async def test_conversation_memory_integration(self, tool, temp_repo, mock_storage):
|
||||
"""Test that conversation memory works with mock storage"""
|
||||
temp_dir, config_path = temp_repo
|
||||
|
||||
# Mock conversation memory functions to use our mock redis
|
||||
with patch("utils.conversation_memory.get_redis_client", return_value=mock_redis):
|
||||
with patch("utils.conversation_memory.get_storage", return_value=mock_storage):
|
||||
# First request - should embed file content
|
||||
PrecommitRequest(path=temp_dir, files=[config_path], prompt="First review")
|
||||
|
||||
@@ -173,7 +169,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
||||
assert len(files_to_embed_2) == 0, "Continuation should skip already embedded files"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_structure_integrity(self, tool, temp_repo, mock_redis):
|
||||
async def test_prompt_structure_integrity(self, tool, temp_repo, mock_storage):
|
||||
"""Test that the prompt structure is well-formed and doesn't have content duplication"""
|
||||
temp_dir, config_path = temp_repo
|
||||
|
||||
@@ -227,7 +223,7 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
||||
assert '__version__ = "1.0.0"' not in after_file_section
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_content_formatting(self, tool, temp_repo, mock_redis):
|
||||
async def test_file_content_formatting(self, tool, temp_repo, mock_storage):
|
||||
"""Test that file content is properly formatted without duplication"""
|
||||
temp_dir, config_path = temp_repo
|
||||
|
||||
@@ -254,18 +250,18 @@ TEMPERATURE_ANALYTICAL = 0.2 # For code review, debugging
|
||||
assert file_content.count('__version__ = "1.0.0"') == 1
|
||||
|
||||
|
||||
def test_mock_redis_basic_operations():
|
||||
def test_mock_storage_basic_operations():
|
||||
"""Test that our mock Redis implementation works correctly"""
|
||||
mock_redis = MockRedisClient()
|
||||
mock_storage = MockRedisClient()
|
||||
|
||||
# Test basic operations
|
||||
assert mock_redis.get("nonexistent") is None
|
||||
assert mock_redis.exists("nonexistent") == 0
|
||||
assert mock_storage.get("nonexistent") is None
|
||||
assert mock_storage.exists("nonexistent") == 0
|
||||
|
||||
mock_redis.set("test_key", "test_value")
|
||||
assert mock_redis.get("test_key") == "test_value"
|
||||
assert mock_redis.exists("test_key") == 1
|
||||
mock_storage.set("test_key", "test_value")
|
||||
assert mock_storage.get("test_key") == "test_value"
|
||||
assert mock_storage.exists("test_key") == 1
|
||||
|
||||
assert mock_redis.delete("test_key") == 1
|
||||
assert mock_redis.get("test_key") is None
|
||||
assert mock_redis.delete("test_key") == 0 # Already deleted
|
||||
assert mock_storage.delete("test_key") == 1
|
||||
assert mock_storage.get("test_key") is None
|
||||
assert mock_storage.delete("test_key") == 0 # Already deleted
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
from providers import ModelProviderRegistry, ModelResponse
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
|
||||
class TestModelProviderRegistry:
|
||||
|
||||
@@ -3,7 +3,7 @@ Test to verify structured error code-based retry logic.
|
||||
"""
|
||||
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
|
||||
def test_openai_structured_error_retry_logic():
|
||||
|
||||
@@ -84,17 +84,15 @@ class TestSpecialStatusParsing:
|
||||
assert result.content_type == "json"
|
||||
assert "pending_tests" in result.content
|
||||
|
||||
def test_clarification_required_still_works(self):
|
||||
"""Test that existing clarification_required still works"""
|
||||
response_json = (
|
||||
'{"status": "clarification_required", "question": "What files need review?", "files_needed": ["src/"]}'
|
||||
)
|
||||
def test_files_required_to_continue_still_works(self):
|
||||
"""Test that existing files_required_to_continue still works"""
|
||||
response_json = '{"status": "files_required_to_continue", "mandatory_instructions": "What files need review?", "files_needed": ["src/"]}'
|
||||
|
||||
result = self.tool._parse_response(response_json, self.request)
|
||||
|
||||
assert result.status == "clarification_required"
|
||||
assert result.status == "files_required_to_continue"
|
||||
assert result.content_type == "json"
|
||||
assert "question" in result.content
|
||||
assert "mandatory_instructions" in result.content
|
||||
|
||||
def test_invalid_status_payload(self):
|
||||
"""Test that invalid payloads for known statuses are handled gracefully"""
|
||||
@@ -127,7 +125,7 @@ class TestSpecialStatusParsing:
|
||||
|
||||
def test_malformed_json_handled(self):
|
||||
"""Test that malformed JSON is handled gracefully"""
|
||||
response_text = '{"status": "clarification_required", "question": "incomplete json'
|
||||
response_text = '{"status": "files_required_to_continue", "question": "incomplete json'
|
||||
|
||||
result = self.tool._parse_response(response_text, self.request)
|
||||
|
||||
@@ -192,8 +190,8 @@ class TestSpecialStatusParsing:
|
||||
"""Test that special status responses preserve exact JSON format for Claude"""
|
||||
test_cases = [
|
||||
{
|
||||
"input": '{"status": "clarification_required", "question": "What framework to use?", "files_needed": ["tests/"]}',
|
||||
"expected_fields": ["status", "question", "files_needed"],
|
||||
"input": '{"status": "files_required_to_continue", "mandatory_instructions": "What framework to use?", "files_needed": ["tests/"]}',
|
||||
"expected_fields": ["status", "mandatory_instructions", "files_needed"],
|
||||
},
|
||||
{
|
||||
"input": '{"status": "full_codereview_required", "reason": "Codebase too large"}',
|
||||
@@ -223,9 +221,20 @@ class TestSpecialStatusParsing:
|
||||
parsed_content = json.loads(result.content)
|
||||
for field in test_case["expected_fields"]:
|
||||
assert field in parsed_content, f"Field {field} missing from {input_data['status']} response"
|
||||
assert (
|
||||
parsed_content[field] == input_data[field]
|
||||
), f"Field {field} value mismatch in {input_data['status']} response"
|
||||
|
||||
# Special handling for mandatory_instructions which gets enhanced
|
||||
if field == "mandatory_instructions" and input_data["status"] == "files_required_to_continue":
|
||||
# Check that enhanced instructions contain the original message
|
||||
assert parsed_content[field].startswith(
|
||||
input_data[field]
|
||||
), f"Enhanced {field} should start with original value in {input_data['status']} response"
|
||||
assert (
|
||||
"IMPORTANT GUIDANCE:" in parsed_content[field]
|
||||
), f"Enhanced {field} should contain guidance in {input_data['status']} response"
|
||||
else:
|
||||
assert (
|
||||
parsed_content[field] == input_data[field]
|
||||
), f"Field {field} value mismatch in {input_data['status']} response"
|
||||
|
||||
def test_focused_review_required_parsing(self):
|
||||
"""Test that focused_review_required status is parsed correctly"""
|
||||
|
||||
@@ -29,12 +29,14 @@ class TestFileUtils:
|
||||
assert "Error: File does not exist" in content
|
||||
assert tokens > 0
|
||||
|
||||
def test_read_file_content_outside_project_root(self):
|
||||
"""Test that paths outside project root are rejected"""
|
||||
# Try to read a file outside the project root
|
||||
def test_read_file_content_safe_files_allowed(self):
|
||||
"""Test that safe files outside the original project root are now allowed"""
|
||||
# In the new security model, safe files like /etc/passwd
|
||||
# can be read as they're not in the dangerous paths list
|
||||
content, tokens = read_file_content("/etc/passwd")
|
||||
assert "--- ERROR ACCESSING FILE:" in content
|
||||
assert "Path outside workspace" in content
|
||||
# Should successfully read the file
|
||||
assert "--- BEGIN FILE: /etc/passwd ---" in content
|
||||
assert "--- END FILE: /etc/passwd ---" in content
|
||||
assert tokens > 0
|
||||
|
||||
def test_read_file_content_relative_path_rejected(self):
|
||||
|
||||
Reference in New Issue
Block a user