Files
my-pal-mcp-server/tests/test_chat_cross_model_continuation.py
Fahad d2773f488a fix: configure codex with a longer timeout
refactor: param names
2025-10-21 10:35:44 +04:00

205 lines
7.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Cross-provider continuation tests for ChatTool."""
from __future__ import annotations
import json
import os
import re
import uuid
from pathlib import Path
import pytest
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tests.transport_helpers import inject_transport
from tools.chat import ChatTool
# Directory containing this test module; both cassette folders live beside it.
_TESTS_DIR = Path(__file__).parent

# Recorded OpenAI HTTP traffic for step 2 of the test.
CASSETTE_DIR = _TESTS_DIR.joinpath("openai_cassettes")
CASSETTE_DIR.mkdir(exist_ok=True)
OPENAI_CASSETTE_PATH = CASSETTE_DIR.joinpath("chat_cross_step2_gpt5_reminder.json")

# Recorded Gemini interaction for step 1 of the test. The replay id is exported
# via GOOGLE_GENAI_REPLAY_ID and the path is the file checked for on disk.
GEMINI_CASSETTE_DIR = _TESTS_DIR.joinpath("gemini_cassettes")
GEMINI_CASSETTE_DIR.mkdir(exist_ok=True)
GEMINI_REPLAY_ID = "chat_cross/step1_gemini25_flash_number/mldev"
GEMINI_REPLAY_PATH = GEMINI_CASSETTE_DIR.joinpath(
    "chat_cross",
    "step1_gemini25_flash_number",
    "mldev.json",
)

# Fixed UUID substituted for uuid4() in conversation memory during the test so
# the thread id does not vary between runs.
FIXED_THREAD_ID = uuid.UUID("dbadc23e-c0f4-4853-982f-6c5bc722b5de")
WORD_TO_NUMBER = {
"one": 1,
"two": 2,
"three": 3,
"four": 4,
"five": 5,
"six": 6,
"seven": 7,
"eight": 8,
"nine": 9,
"ten": 10,
}
def _extract_number(text: str) -> str:
digit_match = re.search(r"\b(\d{1,2})\b", text)
if digit_match:
return digit_match.group(1)
lower_text = text.lower()
for word, value in WORD_TO_NUMBER.items():
if re.search(rf"\b{word}\b", lower_text):
return str(value)
return ""
def _configure_provider_environment(patcher, env_updates, keys_to_clear, recording_mode):
    """Set up env vars and the provider registry for one step of the test.

    Both steps need exactly the same configuration, so it lives here to keep
    them from drifting apart.

    Args:
        patcher: The object yielded by ``monkeypatch.context()``; all
            environment changes are made through it.
        env_updates: Mapping with ``DEFAULT_MODEL``, ``OPENAI_API_KEY`` and
            ``GEMINI_API_KEY`` values.
        keys_to_clear: Names of environment variables to delete.
        recording_mode: When true the keys from *env_updates* are exported and
            the Gemini client mode is "record"; otherwise placeholder keys are
            exported and the mode is "replay".
    """
    patcher.setenv("DEFAULT_MODEL", env_updates["DEFAULT_MODEL"])
    patcher.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-flash")
    patcher.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")

    if recording_mode:
        patcher.setenv("OPENAI_API_KEY", env_updates["OPENAI_API_KEY"])
        patcher.setenv("GEMINI_API_KEY", env_updates["GEMINI_API_KEY"])
        patcher.setenv("GOOGLE_GENAI_CLIENT_MODE", "record")
    else:
        patcher.setenv("OPENAI_API_KEY", "dummy-key-for-replay")
        patcher.setenv("GEMINI_API_KEY", "dummy-key-for-replay")
        patcher.setenv("GOOGLE_GENAI_CLIENT_MODE", "replay")

    patcher.setenv("GOOGLE_GENAI_REPLAYS_DIRECTORY", str(GEMINI_CASSETTE_DIR))
    patcher.setenv("GOOGLE_GENAI_REPLAY_ID", GEMINI_REPLAY_ID)

    for key in keys_to_clear:
        patcher.delenv(key, raising=False)

    # Start from an empty registry and register only the two providers under
    # test. The provider imports stay local, as in the original test body.
    ModelProviderRegistry.reset_for_testing()
    from providers.gemini import GeminiModelProvider
    from providers.openai import OpenAIModelProvider

    ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
    ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)


@pytest.mark.asyncio
@pytest.mark.no_mock_provider
async def test_chat_cross_model_continuation(monkeypatch, tmp_path):
    """Verify continuation across Gemini then OpenAI using recorded interactions.

    Step 1 asks gemini-2.5-flash to pick a number and captures the
    continuation id from the response. Step 2 passes that id to gpt-5 and
    checks that it reports the same number.

    If either cassette file is missing the test runs in recording mode, which
    needs real OPENAI_API_KEY and GEMINI_API_KEY values; without them the test
    is skipped.
    """
    env_updates = {
        "DEFAULT_MODEL": "auto",
        "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY", ""),
        "GEMINI_API_KEY": os.getenv("GEMINI_API_KEY", ""),
    }
    keys_to_clear = [
        "XAI_API_KEY",
        "OPENROUTER_API_KEY",
        "ANTHROPIC_API_KEY",
        "MISTRAL_API_KEY",
        "CUSTOM_API_KEY",
        "CUSTOM_API_URL",
    ]

    recording_mode = not OPENAI_CASSETTE_PATH.exists() or not GEMINI_REPLAY_PATH.exists()
    if recording_mode:
        openai_key = env_updates["OPENAI_API_KEY"].strip()
        gemini_key = env_updates["GEMINI_API_KEY"].strip()
        if (not openai_key or openai_key.startswith("dummy")) or (not gemini_key or gemini_key.startswith("dummy")):
            pytest.skip(
                "Cross-provider cassette missing and OPENAI_API_KEY/GEMINI_API_KEY not configured. Provide real keys to record."
            )

    GEMINI_REPLAY_PATH.parent.mkdir(parents=True, exist_ok=True)

    # The registry is global state. Reset it in ``finally`` so a failing
    # assertion does not leave these providers registered for later tests.
    try:
        # Step 1: Gemini picks a number
        with monkeypatch.context() as m:
            _configure_provider_environment(m, env_updates, keys_to_clear, recording_mode)

            from utils import conversation_memory

            # Pin the generated thread id so it is identical on every run.
            m.setattr(conversation_memory.uuid, "uuid4", lambda: FIXED_THREAD_ID)

            chat_tool = ChatTool()
            working_directory = str(tmp_path)

            step1_args = {
                "prompt": "Pick a number between 1 and 10 and respond with JUST that number.",
                "model": "gemini-2.5-flash",
                "temperature": 0.2,
                "working_directory_absolute_path": working_directory,
            }

            step1_result = await chat_tool.execute(step1_args)
            assert step1_result and step1_result[0].type == "text"

            step1_data = json.loads(step1_result[0].text)
            assert step1_data["status"] in {"success", "continuation_available"}
            assert step1_data.get("metadata", {}).get("provider_used") == "google"

            continuation_offer = step1_data.get("continuation_offer")
            assert continuation_offer is not None
            continuation_id = continuation_offer["continuation_id"]
            assert continuation_id

            chosen_number = _extract_number(step1_data["content"])
            assert chosen_number.isdigit()
            assert 1 <= int(chosen_number) <= 10

            # Ensure replay is flushed for Gemini recordings
            gemini_provider = ModelProviderRegistry.get_provider_for_model("gemini-2.5-flash")
            if gemini_provider is not None:
                try:
                    client = gemini_provider.client
                    if hasattr(client, "close"):
                        client.close()
                finally:
                    # Drop the cached client even if close() raised.
                    if hasattr(gemini_provider, "_client"):
                        gemini_provider._client = None

        assert GEMINI_REPLAY_PATH.exists()

        # Step 2: gpt-5 recalls the number via continuation
        with monkeypatch.context() as m:
            _configure_provider_environment(m, env_updates, keys_to_clear, recording_mode)

            # NOTE(review): the transport is installed through the outer
            # ``monkeypatch`` fixture, not ``m``, exactly as before; it therefore
            # stays in place until test teardown. Confirm this is intended.
            inject_transport(monkeypatch, OPENAI_CASSETTE_PATH)

            chat_tool = ChatTool()
            step2_args = {
                "prompt": "Remind me, what number did you pick, respond with JUST that number.",
                "model": "gpt-5",
                "continuation_id": continuation_id,
                "temperature": 0.2,
                "working_directory_absolute_path": working_directory,
            }

            step2_result = await chat_tool.execute(step2_args)
            assert step2_result and step2_result[0].type == "text"

            step2_data = json.loads(step2_result[0].text)
            assert step2_data["status"] in {"success", "continuation_available"}
            assert step2_data.get("metadata", {}).get("provider_used") == "openai"

            recalled_number = _extract_number(step2_data["content"])
            assert recalled_number == chosen_number

        assert OPENAI_CASSETTE_PATH.exists()
    finally:
        ModelProviderRegistry.reset_for_testing()