fix: improved conversation retrieval
This commit is contained in:
36
tests/test_conversation_continuation_integration.py
Normal file
36
tests/test_conversation_continuation_integration.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Integration test for conversation continuation persistence."""
|
||||
|
||||
from tools.chat import ChatRequest, ChatTool
|
||||
from utils.conversation_memory import get_thread
|
||||
from utils.storage_backend import get_storage_backend
|
||||
|
||||
|
||||
def test_first_response_persisted_in_conversation_history():
|
||||
"""Ensure the assistant's initial reply is stored for newly created threads."""
|
||||
|
||||
# Clear in-memory storage to avoid cross-test contamination
|
||||
storage = get_storage_backend()
|
||||
storage._store.clear() # type: ignore[attr-defined]
|
||||
|
||||
tool = ChatTool()
|
||||
request = ChatRequest(prompt="First question?", model="local-llama")
|
||||
response_text = "Here is the initial answer."
|
||||
|
||||
# Mimic the first tool invocation (no continuation_id supplied)
|
||||
continuation_data = tool._create_continuation_offer(request, model_info={"model_name": "local-llama"})
|
||||
tool._create_continuation_offer_response(
|
||||
response_text,
|
||||
continuation_data,
|
||||
request,
|
||||
{"model_name": "local-llama", "provider": "custom"},
|
||||
)
|
||||
|
||||
thread_id = continuation_data["continuation_id"]
|
||||
thread = get_thread(thread_id)
|
||||
|
||||
assert thread is not None
|
||||
assert [turn.role for turn in thread.turns] == ["user", "assistant"]
|
||||
assert thread.turns[-1].content == response_text
|
||||
|
||||
# Cleanup storage for subsequent tests
|
||||
storage._store.clear() # type: ignore[attr-defined]
|
||||
@@ -582,44 +582,7 @@ class SimpleTool(BaseTool):
|
||||
# Handle conversation continuation like old base.py
|
||||
continuation_id = self.get_request_continuation_id(request)
|
||||
if continuation_id:
|
||||
# Add turn to conversation memory
|
||||
from utils.conversation_memory import add_turn
|
||||
|
||||
# Extract model metadata for conversation tracking
|
||||
model_provider = None
|
||||
model_name = None
|
||||
model_metadata = None
|
||||
|
||||
if model_info:
|
||||
provider = model_info.get("provider")
|
||||
if provider:
|
||||
# Handle both provider objects and string values
|
||||
if isinstance(provider, str):
|
||||
model_provider = provider
|
||||
else:
|
||||
try:
|
||||
model_provider = provider.get_provider_type().value
|
||||
except AttributeError:
|
||||
# Fallback if provider doesn't have get_provider_type method
|
||||
model_provider = str(provider)
|
||||
model_name = model_info.get("model_name")
|
||||
model_response = model_info.get("model_response")
|
||||
if model_response:
|
||||
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
|
||||
|
||||
# Only add the assistant's response to the conversation
|
||||
# The user's turn is handled elsewhere (when thread is created/continued)
|
||||
add_turn(
|
||||
continuation_id, # thread_id as positional argument
|
||||
"assistant", # role as positional argument
|
||||
raw_text, # content as positional argument
|
||||
files=self.get_request_files(request),
|
||||
images=self.get_request_images(request),
|
||||
tool_name=self.get_name(),
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
model_metadata=model_metadata,
|
||||
)
|
||||
self._record_assistant_turn(continuation_id, raw_text, request, model_info)
|
||||
|
||||
# Create continuation offer like old base.py
|
||||
continuation_data = self._create_continuation_offer(request, model_info)
|
||||
@@ -708,6 +671,14 @@ class SimpleTool(BaseTool):
|
||||
from tools.models import ContinuationOffer, ToolOutput
|
||||
|
||||
try:
|
||||
if not self.get_request_continuation_id(request):
|
||||
self._record_assistant_turn(
|
||||
continuation_data["continuation_id"],
|
||||
content,
|
||||
request,
|
||||
model_info,
|
||||
)
|
||||
|
||||
continuation_offer = ContinuationOffer(
|
||||
continuation_id=continuation_data["continuation_id"],
|
||||
note=continuation_data["note"],
|
||||
@@ -743,6 +714,45 @@ class SimpleTool(BaseTool):
|
||||
# Fallback to simple success if continuation offer fails
|
||||
return ToolOutput(status="success", content=content, content_type="text")
|
||||
|
||||
def _record_assistant_turn(self, continuation_id: str, response_text: str, request, model_info: Optional[dict]) -> None:
|
||||
"""Persist an assistant response in conversation memory."""
|
||||
|
||||
if not continuation_id:
|
||||
return
|
||||
|
||||
from utils.conversation_memory import add_turn
|
||||
|
||||
model_provider = None
|
||||
model_name = None
|
||||
model_metadata = None
|
||||
|
||||
if model_info:
|
||||
provider = model_info.get("provider")
|
||||
if provider:
|
||||
if isinstance(provider, str):
|
||||
model_provider = provider
|
||||
else:
|
||||
try:
|
||||
model_provider = provider.get_provider_type().value
|
||||
except AttributeError:
|
||||
model_provider = str(provider)
|
||||
model_name = model_info.get("model_name")
|
||||
model_response = model_info.get("model_response")
|
||||
if model_response:
|
||||
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
|
||||
|
||||
add_turn(
|
||||
continuation_id,
|
||||
"assistant",
|
||||
response_text,
|
||||
files=self.get_request_files(request),
|
||||
images=self.get_request_images(request),
|
||||
tool_name=self.get_name(),
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
model_metadata=model_metadata,
|
||||
)
|
||||
|
||||
# Convenience methods for common tool patterns
|
||||
|
||||
def build_standard_prompt(
|
||||
|
||||
Reference in New Issue
Block a user