fix: improved conversation retrieval
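
Assistant replies were previously written to conversation memory only when the
request carried a continuation_id, so the first response of a brand-new thread
was never persisted and could not be retrieved on the follow-up turn. This
change records that first reply while the continuation offer response is being
built, factors the duplicated add_turn() bookkeeping into a shared
_record_assistant_turn() helper, and adds a regression test.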
tests/test_conversation_continuation_integration.py (new file, 36 lines)

@@ -0,0 +1,36 @@
+"""Integration test for conversation continuation persistence."""
+
+from tools.chat import ChatRequest, ChatTool
+from utils.conversation_memory import get_thread
+from utils.storage_backend import get_storage_backend
+
+
+def test_first_response_persisted_in_conversation_history():
+    """Ensure the assistant's initial reply is stored for newly created threads."""
+
+    # Clear in-memory storage to avoid cross-test contamination
+    storage = get_storage_backend()
+    storage._store.clear()  # type: ignore[attr-defined]
+
+    tool = ChatTool()
+    request = ChatRequest(prompt="First question?", model="local-llama")
+    response_text = "Here is the initial answer."
+
+    # Mimic the first tool invocation (no continuation_id supplied)
+    continuation_data = tool._create_continuation_offer(request, model_info={"model_name": "local-llama"})
+    tool._create_continuation_offer_response(
+        response_text,
+        continuation_data,
+        request,
+        {"model_name": "local-llama", "provider": "custom"},
+    )
+
+    thread_id = continuation_data["continuation_id"]
+    thread = get_thread(thread_id)
+
+    assert thread is not None
+    assert [turn.role for turn in thread.turns] == ["user", "assistant"]
+    assert thread.turns[-1].content == response_text
+
+    # Cleanup storage for subsequent tests
+    storage._store.clear()  # type: ignore[attr-defined]
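
The test drives the continuation plumbing directly through the private
_create_continuation_offer / _create_continuation_offer_response helpers
rather than the full tool execution path, and asserts that a brand-new thread
ends up with exactly a "user" turn followed by an "assistant" turn. Assuming a
standard pytest setup, it can be run in isolation with
pytest tests/test_conversation_continuation_integration.py.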
@@ -582,44 +582,7 @@ class SimpleTool(BaseTool):
         # Handle conversation continuation like old base.py
         continuation_id = self.get_request_continuation_id(request)
         if continuation_id:
-            # Add turn to conversation memory
-            from utils.conversation_memory import add_turn
-
-            # Extract model metadata for conversation tracking
-            model_provider = None
-            model_name = None
-            model_metadata = None
-
-            if model_info:
-                provider = model_info.get("provider")
-                if provider:
-                    # Handle both provider objects and string values
-                    if isinstance(provider, str):
-                        model_provider = provider
-                    else:
-                        try:
-                            model_provider = provider.get_provider_type().value
-                        except AttributeError:
-                            # Fallback if provider doesn't have get_provider_type method
-                            model_provider = str(provider)
-                model_name = model_info.get("model_name")
-                model_response = model_info.get("model_response")
-                if model_response:
-                    model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
-
-            # Only add the assistant's response to the conversation
-            # The user's turn is handled elsewhere (when thread is created/continued)
-            add_turn(
-                continuation_id,  # thread_id as positional argument
-                "assistant",  # role as positional argument
-                raw_text,  # content as positional argument
-                files=self.get_request_files(request),
-                images=self.get_request_images(request),
-                tool_name=self.get_name(),
-                model_provider=model_provider,
-                model_name=model_name,
-                model_metadata=model_metadata,
-            )
+            self._record_assistant_turn(continuation_id, raw_text, request, model_info)
 
         # Create continuation offer like old base.py
         continuation_data = self._create_continuation_offer(request, model_info)
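
Note that the metadata-extraction logic removed above is not dropped: it moves
essentially verbatim into the new _record_assistant_turn() helper added in the
final hunk below, so the continuation branch and the offer-response path share
one implementation.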
@@ -708,6 +671,14 @@ class SimpleTool(BaseTool):
         from tools.models import ContinuationOffer, ToolOutput
 
         try:
+            if not self.get_request_continuation_id(request):
+                self._record_assistant_turn(
+                    continuation_data["continuation_id"],
+                    content,
+                    request,
+                    model_info,
+                )
+
             continuation_offer = ContinuationOffer(
                 continuation_id=continuation_data["continuation_id"],
                 note=continuation_data["note"],
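
Together with the previous hunk, the "if not ... continuation_id" guard keeps
the two persistence paths mutually exclusive. A minimal sketch of the
resulting control flow, using names from this diff with the surrounding method
bodies elided:

    # Continuing an existing thread: the turn is recorded in the main
    # execution path (previous hunk).
    continuation_id = self.get_request_continuation_id(request)
    if continuation_id:
        self._record_assistant_turn(continuation_id, raw_text, request, model_info)

    # Brand-new thread: the request has no continuation_id, so the turn is
    # recorded while the continuation offer response is built (this hunk).
    if not self.get_request_continuation_id(request):
        self._record_assistant_turn(continuation_data["continuation_id"], content, request, model_info)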
@@ -743,6 +714,45 @@ class SimpleTool(BaseTool):
             # Fallback to simple success if continuation offer fails
             return ToolOutput(status="success", content=content, content_type="text")
 
+    def _record_assistant_turn(self, continuation_id: str, response_text: str, request, model_info: Optional[dict]) -> None:
+        """Persist an assistant response in conversation memory."""
+
+        if not continuation_id:
+            return
+
+        from utils.conversation_memory import add_turn
+
+        model_provider = None
+        model_name = None
+        model_metadata = None
+
+        if model_info:
+            provider = model_info.get("provider")
+            if provider:
+                if isinstance(provider, str):
+                    model_provider = provider
+                else:
+                    try:
+                        model_provider = provider.get_provider_type().value
+                    except AttributeError:
+                        model_provider = str(provider)
+            model_name = model_info.get("model_name")
+            model_response = model_info.get("model_response")
+            if model_response:
+                model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
+
+        add_turn(
+            continuation_id,
+            "assistant",
+            response_text,
+            files=self.get_request_files(request),
+            images=self.get_request_images(request),
+            tool_name=self.get_name(),
+            model_provider=model_provider,
+            model_name=model_name,
+            model_metadata=model_metadata,
+        )
+
     # Convenience methods for common tool patterns
 
     def build_standard_prompt(
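
For illustration, a hypothetical direct call to the new helper, mirroring the
model_info shape used by the integration test above (the thread id and text
below are made up):

    tool._record_assistant_turn(
        "a1b2c3d4",                     # hypothetical continuation/thread id
        "Here is the initial answer.",  # assistant text persisted as the latest turn
        request,                        # original request; supplies files and images
        {"model_name": "local-llama", "provider": "custom"},  # optional model metadata
    )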