diff --git a/tests/test_conversation_continuation_integration.py b/tests/test_conversation_continuation_integration.py new file mode 100644 index 0000000..153bd16 --- /dev/null +++ b/tests/test_conversation_continuation_integration.py @@ -0,0 +1,36 @@ +"""Integration test for conversation continuation persistence.""" + +from tools.chat import ChatRequest, ChatTool +from utils.conversation_memory import get_thread +from utils.storage_backend import get_storage_backend + + +def test_first_response_persisted_in_conversation_history(): + """Ensure the assistant's initial reply is stored for newly created threads.""" + + # Clear in-memory storage to avoid cross-test contamination + storage = get_storage_backend() + storage._store.clear() # type: ignore[attr-defined] + + tool = ChatTool() + request = ChatRequest(prompt="First question?", model="local-llama") + response_text = "Here is the initial answer." + + # Mimic the first tool invocation (no continuation_id supplied) + continuation_data = tool._create_continuation_offer(request, model_info={"model_name": "local-llama"}) + tool._create_continuation_offer_response( + response_text, + continuation_data, + request, + {"model_name": "local-llama", "provider": "custom"}, + ) + + thread_id = continuation_data["continuation_id"] + thread = get_thread(thread_id) + + assert thread is not None + assert [turn.role for turn in thread.turns] == ["user", "assistant"] + assert thread.turns[-1].content == response_text + + # Cleanup storage for subsequent tests + storage._store.clear() # type: ignore[attr-defined] diff --git a/tools/simple/base.py b/tools/simple/base.py index 166b09a..df7dd4a 100644 --- a/tools/simple/base.py +++ b/tools/simple/base.py @@ -582,44 +582,7 @@ class SimpleTool(BaseTool): # Handle conversation continuation like old base.py continuation_id = self.get_request_continuation_id(request) if continuation_id: - # Add turn to conversation memory - from utils.conversation_memory import add_turn - - # Extract model metadata for conversation tracking - model_provider = None - model_name = None - model_metadata = None - - if model_info: - provider = model_info.get("provider") - if provider: - # Handle both provider objects and string values - if isinstance(provider, str): - model_provider = provider - else: - try: - model_provider = provider.get_provider_type().value - except AttributeError: - # Fallback if provider doesn't have get_provider_type method - model_provider = str(provider) - model_name = model_info.get("model_name") - model_response = model_info.get("model_response") - if model_response: - model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata} - - # Only add the assistant's response to the conversation - # The user's turn is handled elsewhere (when thread is created/continued) - add_turn( - continuation_id, # thread_id as positional argument - "assistant", # role as positional argument - raw_text, # content as positional argument - files=self.get_request_files(request), - images=self.get_request_images(request), - tool_name=self.get_name(), - model_provider=model_provider, - model_name=model_name, - model_metadata=model_metadata, - ) + self._record_assistant_turn(continuation_id, raw_text, request, model_info) # Create continuation offer like old base.py continuation_data = self._create_continuation_offer(request, model_info) @@ -708,6 +671,14 @@ class SimpleTool(BaseTool): from tools.models import ContinuationOffer, ToolOutput try: + if not self.get_request_continuation_id(request): + self._record_assistant_turn( + continuation_data["continuation_id"], + content, + request, + model_info, + ) + continuation_offer = ContinuationOffer( continuation_id=continuation_data["continuation_id"], note=continuation_data["note"], @@ -743,6 +714,45 @@ class SimpleTool(BaseTool): # Fallback to simple success if continuation offer fails return ToolOutput(status="success", content=content, content_type="text") + def _record_assistant_turn(self, continuation_id: str, response_text: str, request, model_info: Optional[dict]) -> None: + """Persist an assistant response in conversation memory.""" + + if not continuation_id: + return + + from utils.conversation_memory import add_turn + + model_provider = None + model_name = None + model_metadata = None + + if model_info: + provider = model_info.get("provider") + if provider: + if isinstance(provider, str): + model_provider = provider + else: + try: + model_provider = provider.get_provider_type().value + except AttributeError: + model_provider = str(provider) + model_name = model_info.get("model_name") + model_response = model_info.get("model_response") + if model_response: + model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata} + + add_turn( + continuation_id, + "assistant", + response_text, + files=self.get_request_files(request), + images=self.get_request_images(request), + tool_name=self.get_name(), + model_provider=model_provider, + model_name=model_name, + model_metadata=model_metadata, + ) + # Convenience methods for common tool patterns def build_standard_prompt(