Simplified thread continuations
Fixed and improved tests
This commit is contained in:
@@ -13,7 +13,6 @@ from pydantic import Field
|
||||
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
from tools.base import BaseTool, ToolRequest
|
||||
from tools.models import ContinuationOffer, ToolOutput
|
||||
from utils.conversation_memory import MAX_CONVERSATION_TURNS
|
||||
|
||||
|
||||
@@ -59,58 +58,97 @@ class TestClaudeContinuationOffers:
|
||||
self.tool = ClaudeContinuationTool()
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_new_conversation_offers_continuation(self, mock_redis):
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_new_conversation_offers_continuation(self, mock_redis):
|
||||
"""Test that new conversations offer Claude continuation opportunity"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Test request without continuation_id (new conversation)
|
||||
request = ContinuationRequest(prompt="Analyze this code")
|
||||
# Mock the model
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Analysis complete.",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"},
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Check continuation opportunity
|
||||
continuation_data = self.tool._check_continuation_opportunity(request)
|
||||
# Execute tool without continuation_id (new conversation)
|
||||
arguments = {"prompt": "Analyze this code"}
|
||||
response = await self.tool.execute(arguments)
|
||||
|
||||
assert continuation_data is not None
|
||||
assert continuation_data["remaining_turns"] == MAX_CONVERSATION_TURNS - 1
|
||||
assert continuation_data["tool_name"] == "test_continuation"
|
||||
# Parse response
|
||||
response_data = json.loads(response[0].text)
|
||||
|
||||
def test_existing_conversation_no_continuation_offer(self):
|
||||
"""Test that existing threaded conversations don't offer continuation"""
|
||||
# Test request with continuation_id (existing conversation)
|
||||
request = ContinuationRequest(
|
||||
prompt="Continue analysis", continuation_id="12345678-1234-1234-1234-123456789012"
|
||||
)
|
||||
|
||||
# Check continuation opportunity
|
||||
continuation_data = self.tool._check_continuation_opportunity(request)
|
||||
|
||||
assert continuation_data is None
|
||||
# Should offer continuation for new conversation
|
||||
assert response_data["status"] == "continuation_available"
|
||||
assert "continuation_offer" in response_data
|
||||
assert response_data["continuation_offer"]["remaining_turns"] == MAX_CONVERSATION_TURNS - 1
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_create_continuation_offer_response(self, mock_redis):
|
||||
"""Test creating continuation offer response"""
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_existing_conversation_still_offers_continuation(self, mock_redis):
|
||||
"""Test that existing threaded conversations still offer continuation if turns remain"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
request = ContinuationRequest(prompt="Test prompt")
|
||||
content = "This is the analysis result."
|
||||
continuation_data = {"remaining_turns": 4, "tool_name": "test_continuation"}
|
||||
# Mock existing thread context with 2 turns
|
||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||
|
||||
# Create continuation offer response
|
||||
response = self.tool._create_continuation_offer_response(content, continuation_data, request)
|
||||
thread_context = ThreadContext(
|
||||
thread_id="12345678-1234-1234-1234-123456789012",
|
||||
created_at="2023-01-01T00:00:00Z",
|
||||
last_updated_at="2023-01-01T00:01:00Z",
|
||||
tool_name="test_continuation",
|
||||
turns=[
|
||||
ConversationTurn(
|
||||
role="assistant",
|
||||
content="Previous response",
|
||||
timestamp="2023-01-01T00:00:30Z",
|
||||
tool_name="test_continuation",
|
||||
),
|
||||
ConversationTurn(
|
||||
role="user",
|
||||
content="Follow up question",
|
||||
timestamp="2023-01-01T00:01:00Z",
|
||||
),
|
||||
],
|
||||
initial_context={"prompt": "Initial analysis"},
|
||||
)
|
||||
mock_client.get.return_value = thread_context.model_dump_json()
|
||||
|
||||
assert isinstance(response, ToolOutput)
|
||||
assert response.status == "continuation_available"
|
||||
assert response.content == content
|
||||
assert response.continuation_offer is not None
|
||||
# Mock the model
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Continued analysis.",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"},
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
offer = response.continuation_offer
|
||||
assert isinstance(offer, ContinuationOffer)
|
||||
assert offer.remaining_turns == 4
|
||||
assert "continuation_id" in offer.suggested_tool_params
|
||||
assert "You have 4 more exchange(s) available" in offer.message_to_user
|
||||
# Execute tool with continuation_id
|
||||
arguments = {"prompt": "Continue analysis", "continuation_id": "12345678-1234-1234-1234-123456789012"}
|
||||
response = await self.tool.execute(arguments)
|
||||
|
||||
# Parse response
|
||||
response_data = json.loads(response[0].text)
|
||||
|
||||
# Should still offer continuation since turns remain
|
||||
assert response_data["status"] == "continuation_available"
|
||||
assert "continuation_offer" in response_data
|
||||
# 10 max - 2 existing - 1 new = 7 remaining
|
||||
assert response_data["continuation_offer"]["remaining_turns"] == 7
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_full_response_flow_with_continuation_offer(self, mock_redis):
|
||||
"""Test complete response flow that creates continuation offer"""
|
||||
mock_client = Mock()
|
||||
@@ -152,26 +190,21 @@ class TestClaudeContinuationOffers:
|
||||
assert "more exchange(s) available" in offer["message_to_user"]
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
async def test_gemini_follow_up_takes_precedence(self, mock_redis):
|
||||
"""Test that Gemini follow-up questions take precedence over continuation offers"""
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_continuation_always_offered_with_natural_language(self, mock_redis):
|
||||
"""Test that continuation is always offered with natural language prompts"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Mock the model to return a response WITH follow-up question
|
||||
# Mock the model to return a response with natural language follow-up
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
# Include follow-up JSON in the content
|
||||
# Include natural language follow-up in the content
|
||||
content_with_followup = """Analysis complete. The code looks good.
|
||||
|
||||
```json
|
||||
{
|
||||
"follow_up_question": "Would you like me to examine the error handling patterns?",
|
||||
"suggested_params": {"files": ["/src/error_handler.py"]},
|
||||
"ui_hint": "Examining error handling would help ensure robustness"
|
||||
}
|
||||
```"""
|
||||
I'd be happy to examine the error handling patterns in more detail if that would be helpful."""
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=content_with_followup,
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
@@ -187,12 +220,13 @@ class TestClaudeContinuationOffers:
|
||||
# Parse response
|
||||
response_data = json.loads(response[0].text)
|
||||
|
||||
# Should be follow-up, not continuation offer
|
||||
assert response_data["status"] == "requires_continuation"
|
||||
assert "follow_up_request" in response_data
|
||||
assert response_data.get("continuation_offer") is None
|
||||
# Should always offer continuation
|
||||
assert response_data["status"] == "continuation_available"
|
||||
assert "continuation_offer" in response_data
|
||||
assert response_data["continuation_offer"]["remaining_turns"] == MAX_CONVERSATION_TURNS - 1
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_threaded_conversation_with_continuation_offer(self, mock_redis):
|
||||
"""Test that threaded conversations still get continuation offers when turns remain"""
|
||||
mock_client = Mock()
|
||||
@@ -236,81 +270,60 @@ class TestClaudeContinuationOffers:
|
||||
assert response_data.get("continuation_offer") is not None
|
||||
assert response_data["continuation_offer"]["remaining_turns"] == 9
|
||||
|
||||
def test_max_turns_reached_no_continuation_offer(self):
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_max_turns_reached_no_continuation_offer(self, mock_redis):
|
||||
"""Test that no continuation is offered when max turns would be exceeded"""
|
||||
# Mock MAX_CONVERSATION_TURNS to be 1 for this test
|
||||
with patch("tools.base.MAX_CONVERSATION_TURNS", 1):
|
||||
request = ContinuationRequest(prompt="Test prompt")
|
||||
|
||||
# Check continuation opportunity
|
||||
continuation_data = self.tool._check_continuation_opportunity(request)
|
||||
|
||||
# Should be None because remaining_turns would be 0
|
||||
assert continuation_data is None
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_continuation_offer_thread_creation_failure_fallback(self, mock_redis):
|
||||
"""Test fallback to normal response when thread creation fails"""
|
||||
# Mock Redis to fail
|
||||
mock_client = Mock()
|
||||
mock_client.setex.side_effect = Exception("Redis failure")
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
request = ContinuationRequest(prompt="Test prompt")
|
||||
content = "Analysis result"
|
||||
continuation_data = {"remaining_turns": 4, "tool_name": "test_continuation"}
|
||||
|
||||
# Should fallback to normal response
|
||||
response = self.tool._create_continuation_offer_response(content, continuation_data, request)
|
||||
|
||||
assert response.status == "success"
|
||||
assert response.content == content
|
||||
assert response.continuation_offer is None
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_continuation_offer_message_format(self, mock_redis):
|
||||
"""Test that continuation offer message is properly formatted for Claude"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
request = ContinuationRequest(prompt="Analyze architecture")
|
||||
content = "Architecture analysis complete."
|
||||
continuation_data = {"remaining_turns": 3, "tool_name": "test_continuation"}
|
||||
# Mock existing thread context at max turns
|
||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||
|
||||
response = self.tool._create_continuation_offer_response(content, continuation_data, request)
|
||||
# Create turns at the limit (MAX_CONVERSATION_TURNS - 1 since we're about to add one)
|
||||
turns = [
|
||||
ConversationTurn(
|
||||
role="assistant" if i % 2 else "user",
|
||||
content=f"Turn {i+1}",
|
||||
timestamp="2023-01-01T00:00:00Z",
|
||||
tool_name="test_continuation",
|
||||
)
|
||||
for i in range(MAX_CONVERSATION_TURNS - 1)
|
||||
]
|
||||
|
||||
offer = response.continuation_offer
|
||||
message = offer.message_to_user
|
||||
thread_context = ThreadContext(
|
||||
thread_id="12345678-1234-1234-1234-123456789012",
|
||||
created_at="2023-01-01T00:00:00Z",
|
||||
last_updated_at="2023-01-01T00:01:00Z",
|
||||
tool_name="test_continuation",
|
||||
turns=turns,
|
||||
initial_context={"prompt": "Initial"},
|
||||
)
|
||||
mock_client.get.return_value = thread_context.model_dump_json()
|
||||
|
||||
# Check message contains key information for Claude
|
||||
assert "continue this analysis" in message
|
||||
assert "continuation_id" in message
|
||||
assert "test_continuation tool call" in message
|
||||
assert "3 more exchange(s)" in message
|
||||
# Mock the model
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Final response.",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"},
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# Check suggested params are properly formatted
|
||||
suggested_params = offer.suggested_tool_params
|
||||
assert "continuation_id" in suggested_params
|
||||
assert "prompt" in suggested_params
|
||||
assert isinstance(suggested_params["continuation_id"], str)
|
||||
# Execute tool with continuation_id at max turns
|
||||
arguments = {"prompt": "Final question", "continuation_id": "12345678-1234-1234-1234-123456789012"}
|
||||
response = await self.tool.execute(arguments)
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_continuation_offer_metadata(self, mock_redis):
|
||||
"""Test that continuation offer includes proper metadata"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
# Parse response
|
||||
response_data = json.loads(response[0].text)
|
||||
|
||||
request = ContinuationRequest(prompt="Test")
|
||||
content = "Test content"
|
||||
continuation_data = {"remaining_turns": 2, "tool_name": "test_continuation"}
|
||||
|
||||
response = self.tool._create_continuation_offer_response(content, continuation_data, request)
|
||||
|
||||
metadata = response.metadata
|
||||
assert metadata["tool_name"] == "test_continuation"
|
||||
assert metadata["remaining_turns"] == 2
|
||||
assert "thread_id" in metadata
|
||||
assert len(metadata["thread_id"]) == 36 # UUID length
|
||||
# Should NOT offer continuation since we're at max turns
|
||||
assert response_data["status"] == "success"
|
||||
assert response_data.get("continuation_offer") is None
|
||||
|
||||
|
||||
class TestContinuationIntegration:
|
||||
@@ -320,7 +333,8 @@ class TestContinuationIntegration:
|
||||
self.tool = ClaudeContinuationTool()
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_continuation_offer_creates_proper_thread(self, mock_redis):
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_continuation_offer_creates_proper_thread(self, mock_redis):
|
||||
"""Test that continuation offers create properly formatted threads"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
@@ -336,77 +350,119 @@ class TestContinuationIntegration:
|
||||
|
||||
mock_client.get.side_effect = side_effect_get
|
||||
|
||||
request = ContinuationRequest(prompt="Initial analysis", files=["/test/file.py"])
|
||||
content = "Analysis result"
|
||||
continuation_data = {"remaining_turns": 4, "tool_name": "test_continuation"}
|
||||
# Mock the model
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Analysis result",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"},
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
self.tool._create_continuation_offer_response(content, continuation_data, request)
|
||||
# Execute tool for initial analysis
|
||||
arguments = {"prompt": "Initial analysis", "files": ["/test/file.py"]}
|
||||
response = await self.tool.execute(arguments)
|
||||
|
||||
# Verify thread creation was called (should be called twice: create_thread + add_turn)
|
||||
assert mock_client.setex.call_count == 2
|
||||
# Parse response
|
||||
response_data = json.loads(response[0].text)
|
||||
|
||||
# Check the first call (create_thread)
|
||||
first_call = mock_client.setex.call_args_list[0]
|
||||
thread_key = first_call[0][0]
|
||||
assert thread_key.startswith("thread:")
|
||||
assert len(thread_key.split(":")[-1]) == 36 # UUID length
|
||||
# Should offer continuation
|
||||
assert response_data["status"] == "continuation_available"
|
||||
assert "continuation_offer" in response_data
|
||||
|
||||
# Check the second call (add_turn) which should have the assistant response
|
||||
second_call = mock_client.setex.call_args_list[1]
|
||||
thread_data = second_call[0][2]
|
||||
thread_context = json.loads(thread_data)
|
||||
# Verify thread creation was called (should be called twice: create_thread + add_turn)
|
||||
assert mock_client.setex.call_count == 2
|
||||
|
||||
assert thread_context["tool_name"] == "test_continuation"
|
||||
assert len(thread_context["turns"]) == 1 # Assistant's response added
|
||||
assert thread_context["turns"][0]["role"] == "assistant"
|
||||
assert thread_context["turns"][0]["content"] == content
|
||||
assert thread_context["turns"][0]["files"] == ["/test/file.py"] # Files from request
|
||||
assert thread_context["initial_context"]["prompt"] == "Initial analysis"
|
||||
assert thread_context["initial_context"]["files"] == ["/test/file.py"]
|
||||
# Check the first call (create_thread)
|
||||
first_call = mock_client.setex.call_args_list[0]
|
||||
thread_key = first_call[0][0]
|
||||
assert thread_key.startswith("thread:")
|
||||
assert len(thread_key.split(":")[-1]) == 36 # UUID length
|
||||
|
||||
# Check the second call (add_turn) which should have the assistant response
|
||||
second_call = mock_client.setex.call_args_list[1]
|
||||
thread_data = second_call[0][2]
|
||||
thread_context = json.loads(thread_data)
|
||||
|
||||
assert thread_context["tool_name"] == "test_continuation"
|
||||
assert len(thread_context["turns"]) == 1 # Assistant's response added
|
||||
assert thread_context["turns"][0]["role"] == "assistant"
|
||||
assert thread_context["turns"][0]["content"] == "Analysis result"
|
||||
assert thread_context["turns"][0]["files"] == ["/test/file.py"] # Files from request
|
||||
assert thread_context["initial_context"]["prompt"] == "Initial analysis"
|
||||
assert thread_context["initial_context"]["files"] == ["/test/file.py"]
|
||||
|
||||
@patch("utils.conversation_memory.get_redis_client")
|
||||
def test_claude_can_use_continuation_id(self, mock_redis):
|
||||
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
|
||||
async def test_claude_can_use_continuation_id(self, mock_redis):
|
||||
"""Test that Claude can use the provided continuation_id in subsequent calls"""
|
||||
mock_client = Mock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Step 1: Initial request creates continuation offer
|
||||
request1 = ToolRequest(prompt="Analyze code structure")
|
||||
continuation_data = {"remaining_turns": 4, "tool_name": "test_continuation"}
|
||||
response1 = self.tool._create_continuation_offer_response(
|
||||
"Structure analysis done.", continuation_data, request1
|
||||
)
|
||||
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.supports_thinking_mode.return_value = False
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Structure analysis done.",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"},
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
thread_id = response1.continuation_offer.continuation_id
|
||||
# Execute initial request
|
||||
arguments = {"prompt": "Analyze code structure"}
|
||||
response = await self.tool.execute(arguments)
|
||||
|
||||
# Step 2: Mock the thread context for Claude's follow-up
|
||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||
# Parse response
|
||||
response_data = json.loads(response[0].text)
|
||||
thread_id = response_data["continuation_offer"]["continuation_id"]
|
||||
|
||||
existing_context = ThreadContext(
|
||||
thread_id=thread_id,
|
||||
created_at="2023-01-01T00:00:00Z",
|
||||
last_updated_at="2023-01-01T00:01:00Z",
|
||||
tool_name="test_continuation",
|
||||
turns=[
|
||||
ConversationTurn(
|
||||
role="assistant",
|
||||
content="Structure analysis done.",
|
||||
timestamp="2023-01-01T00:00:30Z",
|
||||
tool_name="test_continuation",
|
||||
)
|
||||
],
|
||||
initial_context={"prompt": "Analyze code structure"},
|
||||
)
|
||||
mock_client.get.return_value = existing_context.model_dump_json()
|
||||
# Step 2: Mock the thread context for Claude's follow-up
|
||||
from utils.conversation_memory import ConversationTurn, ThreadContext
|
||||
|
||||
# Step 3: Claude uses continuation_id
|
||||
request2 = ToolRequest(prompt="Now analyze the performance aspects", continuation_id=thread_id)
|
||||
existing_context = ThreadContext(
|
||||
thread_id=thread_id,
|
||||
created_at="2023-01-01T00:00:00Z",
|
||||
last_updated_at="2023-01-01T00:01:00Z",
|
||||
tool_name="test_continuation",
|
||||
turns=[
|
||||
ConversationTurn(
|
||||
role="assistant",
|
||||
content="Structure analysis done.",
|
||||
timestamp="2023-01-01T00:00:30Z",
|
||||
tool_name="test_continuation",
|
||||
)
|
||||
],
|
||||
initial_context={"prompt": "Analyze code structure"},
|
||||
)
|
||||
mock_client.get.return_value = existing_context.model_dump_json()
|
||||
|
||||
# Should still offer continuation if there are remaining turns
|
||||
continuation_data2 = self.tool._check_continuation_opportunity(request2)
|
||||
assert continuation_data2 is not None
|
||||
assert continuation_data2["remaining_turns"] == 8 # MAX_CONVERSATION_TURNS(10) - current_turns(1) - 1
|
||||
assert continuation_data2["tool_name"] == "test_continuation"
|
||||
# Step 3: Claude uses continuation_id
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content="Performance analysis done.",
|
||||
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
model_name="gemini-2.0-flash-exp",
|
||||
metadata={"finish_reason": "STOP"},
|
||||
)
|
||||
|
||||
arguments2 = {"prompt": "Now analyze the performance aspects", "continuation_id": thread_id}
|
||||
response2 = await self.tool.execute(arguments2)
|
||||
|
||||
# Parse response
|
||||
response_data2 = json.loads(response2[0].text)
|
||||
|
||||
# Should still offer continuation if there are remaining turns
|
||||
assert response_data2["status"] == "continuation_available"
|
||||
assert "continuation_offer" in response_data2
|
||||
# 10 max - 1 existing - 1 new = 8 remaining
|
||||
assert response_data2["continuation_offer"]["remaining_turns"] == 8
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user