test: cross tool memory recall, testing continuation via cassette recording
This commit is contained in:
196
tests/test_chat_cross_model_continuation.py
Normal file
196
tests/test_chat_cross_model_continuation.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Cross-provider continuation tests for ChatTool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from tests.transport_helpers import inject_transport
|
||||
from tools.chat import ChatTool
|
||||
|
||||
CASSETTE_DIR = Path(__file__).parent / "openai_cassettes"
|
||||
CASSETTE_DIR.mkdir(exist_ok=True)
|
||||
OPENAI_CASSETTE_PATH = CASSETTE_DIR / "chat_cross_step2_gpt5_reminder.json"
|
||||
|
||||
GEMINI_CASSETTE_DIR = Path(__file__).parent / "gemini_cassettes"
|
||||
GEMINI_CASSETTE_DIR.mkdir(exist_ok=True)
|
||||
GEMINI_REPLAY_ID = "chat_cross/step1_gemini25_flash_number/mldev"
|
||||
GEMINI_REPLAY_PATH = GEMINI_CASSETTE_DIR / "chat_cross" / "step1_gemini25_flash_number" / "mldev.json"
|
||||
|
||||
FIXED_THREAD_ID = uuid.UUID("dbadc23e-c0f4-4853-982f-6c5bc722b5de")
|
||||
|
||||
|
||||
WORD_TO_NUMBER = {
|
||||
"one": 1,
|
||||
"two": 2,
|
||||
"three": 3,
|
||||
"four": 4,
|
||||
"five": 5,
|
||||
"six": 6,
|
||||
"seven": 7,
|
||||
"eight": 8,
|
||||
"nine": 9,
|
||||
"ten": 10,
|
||||
}
|
||||
|
||||
|
||||
def _extract_number(text: str) -> str:
|
||||
digit_match = re.search(r"\b(\d{1,2})\b", text)
|
||||
if digit_match:
|
||||
return digit_match.group(1)
|
||||
|
||||
lower_text = text.lower()
|
||||
for word, value in WORD_TO_NUMBER.items():
|
||||
if re.search(rf"\b{word}\b", lower_text):
|
||||
return str(value)
|
||||
return ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.no_mock_provider
|
||||
async def test_chat_cross_model_continuation(monkeypatch):
|
||||
"""Verify continuation across Gemini then OpenAI using recorded interactions."""
|
||||
|
||||
env_updates = {
|
||||
"DEFAULT_MODEL": "auto",
|
||||
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY", ""),
|
||||
"GEMINI_API_KEY": os.getenv("GEMINI_API_KEY", ""),
|
||||
}
|
||||
keys_to_clear = [
|
||||
"XAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"ANTHROPIC_API_KEY",
|
||||
"MISTRAL_API_KEY",
|
||||
"CUSTOM_API_KEY",
|
||||
"CUSTOM_API_URL",
|
||||
]
|
||||
|
||||
recording_mode = not OPENAI_CASSETTE_PATH.exists() or not GEMINI_REPLAY_PATH.exists()
|
||||
if recording_mode:
|
||||
openai_key = env_updates["OPENAI_API_KEY"].strip()
|
||||
gemini_key = env_updates["GEMINI_API_KEY"].strip()
|
||||
if (not openai_key or openai_key.startswith("dummy")) or (not gemini_key or gemini_key.startswith("dummy")):
|
||||
pytest.skip(
|
||||
"Cross-provider cassette missing and OPENAI_API_KEY/GEMINI_API_KEY not configured. Provide real keys to record."
|
||||
)
|
||||
|
||||
GEMINI_REPLAY_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Step 1 – Gemini picks a number
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("DEFAULT_MODEL", env_updates["DEFAULT_MODEL"])
|
||||
if recording_mode:
|
||||
m.setenv("OPENAI_API_KEY", env_updates["OPENAI_API_KEY"])
|
||||
m.setenv("GEMINI_API_KEY", env_updates["GEMINI_API_KEY"])
|
||||
m.setenv("GOOGLE_GENAI_CLIENT_MODE", "record")
|
||||
else:
|
||||
m.setenv("OPENAI_API_KEY", "dummy-key-for-replay")
|
||||
m.setenv("GEMINI_API_KEY", "dummy-key-for-replay")
|
||||
m.setenv("GOOGLE_GENAI_CLIENT_MODE", "replay")
|
||||
|
||||
m.setenv("GOOGLE_GENAI_REPLAYS_DIRECTORY", str(GEMINI_CASSETTE_DIR))
|
||||
m.setenv("GOOGLE_GENAI_REPLAY_ID", GEMINI_REPLAY_ID)
|
||||
|
||||
for key in keys_to_clear:
|
||||
m.delenv(key, raising=False)
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
from utils import conversation_memory
|
||||
|
||||
m.setattr(conversation_memory.uuid, "uuid4", lambda: FIXED_THREAD_ID)
|
||||
|
||||
chat_tool = ChatTool()
|
||||
step1_args = {
|
||||
"prompt": "Pick a number between 1 and 10 and respond with JUST that number.",
|
||||
"model": "gemini-2.5-flash",
|
||||
"temperature": 0.2,
|
||||
}
|
||||
|
||||
step1_result = await chat_tool.execute(step1_args)
|
||||
assert step1_result and step1_result[0].type == "text"
|
||||
|
||||
step1_data = json.loads(step1_result[0].text)
|
||||
assert step1_data["status"] in {"success", "continuation_available"}
|
||||
assert step1_data.get("metadata", {}).get("provider_used") == "google"
|
||||
continuation_offer = step1_data.get("continuation_offer")
|
||||
assert continuation_offer is not None
|
||||
continuation_id = continuation_offer["continuation_id"]
|
||||
assert continuation_id
|
||||
|
||||
chosen_number = _extract_number(step1_data["content"])
|
||||
assert chosen_number.isdigit()
|
||||
assert 1 <= int(chosen_number) <= 10
|
||||
|
||||
# Ensure replay is flushed for Gemini recordings
|
||||
gemini_provider = ModelProviderRegistry.get_provider_for_model("gemini-2.5-flash")
|
||||
if gemini_provider is not None:
|
||||
try:
|
||||
client = gemini_provider.client
|
||||
if hasattr(client, "close"):
|
||||
client.close()
|
||||
finally:
|
||||
if hasattr(gemini_provider, "_client"):
|
||||
gemini_provider._client = None
|
||||
|
||||
assert GEMINI_REPLAY_PATH.exists()
|
||||
|
||||
# Step 2 – gpt-5 recalls the number via continuation
|
||||
with monkeypatch.context() as m:
|
||||
if recording_mode:
|
||||
m.setenv("OPENAI_API_KEY", env_updates["OPENAI_API_KEY"])
|
||||
m.setenv("GEMINI_API_KEY", env_updates["GEMINI_API_KEY"])
|
||||
m.setenv("GOOGLE_GENAI_CLIENT_MODE", "record")
|
||||
else:
|
||||
m.setenv("OPENAI_API_KEY", "dummy-key-for-replay")
|
||||
m.setenv("GEMINI_API_KEY", "dummy-key-for-replay")
|
||||
m.setenv("GOOGLE_GENAI_CLIENT_MODE", "replay")
|
||||
|
||||
m.setenv("DEFAULT_MODEL", env_updates["DEFAULT_MODEL"])
|
||||
m.setenv("GOOGLE_GENAI_REPLAYS_DIRECTORY", str(GEMINI_CASSETTE_DIR))
|
||||
m.setenv("GOOGLE_GENAI_REPLAY_ID", GEMINI_REPLAY_ID)
|
||||
for key in keys_to_clear:
|
||||
m.delenv(key, raising=False)
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
inject_transport(monkeypatch, OPENAI_CASSETTE_PATH)
|
||||
|
||||
chat_tool = ChatTool()
|
||||
step2_args = {
|
||||
"prompt": "Remind me, what number did you pick, respond with JUST that number.",
|
||||
"model": "gpt-5",
|
||||
"continuation_id": continuation_id,
|
||||
"temperature": 0.2,
|
||||
}
|
||||
|
||||
step2_result = await chat_tool.execute(step2_args)
|
||||
assert step2_result and step2_result[0].type == "text"
|
||||
|
||||
step2_data = json.loads(step2_result[0].text)
|
||||
assert step2_data["status"] in {"success", "continuation_available"}
|
||||
assert step2_data.get("metadata", {}).get("provider_used") == "openai"
|
||||
|
||||
recalled_number = _extract_number(step2_data["content"])
|
||||
assert recalled_number == chosen_number
|
||||
|
||||
assert OPENAI_CASSETTE_PATH.exists()
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
Reference in New Issue
Block a user