feat!: Full code can now be generated by an external model and shared with the AI tool (Claude Code / Codex etc)!
model definitions now support a new `allow_code_generation` flag, only to be used with higher reasoning models such as GPT-5-Pro and-Gemini 2.5-Pro When `true`, the `chat` tool can now request the external model to generate a full implementation / update / instructions etc and then share the implementation with the calling agent. This effectively allows us to utilize more powerful models such as GPT-5-Pro to generate code for us or entire implementations (which are either API-only or part of the $200 Pro plan from within the ChatGPT app)
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -137,7 +137,7 @@ class TestAutoMode:
|
||||
importlib.reload(config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_requires_model_parameter(self):
|
||||
async def test_auto_mode_requires_model_parameter(self, tmp_path):
|
||||
"""Test that auto mode enforces model parameter"""
|
||||
# Save original
|
||||
original = os.environ.get("DEFAULT_MODEL", "")
|
||||
@@ -154,7 +154,7 @@ class TestAutoMode:
|
||||
# Mock the provider to avoid real API calls
|
||||
with patch.object(tool, "get_model_provider"):
|
||||
# Execute without model parameter
|
||||
result = await tool.execute({"prompt": "Test prompt"})
|
||||
result = await tool.execute({"prompt": "Test prompt", "working_directory": str(tmp_path)})
|
||||
|
||||
# Should get error
|
||||
assert len(result) == 1
|
||||
|
||||
@@ -200,7 +200,7 @@ class TestAutoModeComprehensive:
|
||||
assert tool.get_model_category() == expected_category
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_gemini_only_uses_correct_models(self):
|
||||
async def test_auto_mode_with_gemini_only_uses_correct_models(self, tmp_path):
|
||||
"""Test that auto mode with only Gemini uses flash for fast tools and pro for reasoning tools."""
|
||||
|
||||
provider_config = {
|
||||
@@ -234,9 +234,13 @@ class TestAutoModeComprehensive:
|
||||
)
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
|
||||
workdir = tmp_path / "chat_artifacts"
|
||||
workdir.mkdir(parents=True, exist_ok=True)
|
||||
# Test ChatTool (FAST_RESPONSE) - should prefer flash
|
||||
chat_tool = ChatTool()
|
||||
await chat_tool.execute({"prompt": "test", "model": "auto"}) # This should trigger auto selection
|
||||
await chat_tool.execute(
|
||||
{"prompt": "test", "model": "auto", "working_directory": str(workdir)}
|
||||
) # This should trigger auto selection
|
||||
|
||||
# In auto mode, the tool should get an error requiring model selection
|
||||
# but the suggested model should be flash
|
||||
@@ -355,7 +359,7 @@ class TestAutoModeComprehensive:
|
||||
# would show models from all providers when called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_model_parameter_required_error(self):
|
||||
async def test_auto_mode_model_parameter_required_error(self, tmp_path):
|
||||
"""Test that auto mode properly requires model parameter and suggests correct model."""
|
||||
|
||||
provider_config = {
|
||||
@@ -384,9 +388,12 @@ class TestAutoModeComprehensive:
|
||||
|
||||
# Test with ChatTool (FAST_RESPONSE category)
|
||||
chat_tool = ChatTool()
|
||||
workdir = tmp_path / "chat_artifacts"
|
||||
workdir.mkdir(parents=True, exist_ok=True)
|
||||
result = await chat_tool.execute(
|
||||
{
|
||||
"prompt": "test"
|
||||
"prompt": "test",
|
||||
"working_directory": str(workdir),
|
||||
# Note: no model parameter provided in auto mode
|
||||
}
|
||||
)
|
||||
@@ -508,7 +515,7 @@ class TestAutoModeComprehensive:
|
||||
assert fast_response is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_actual_model_name_resolution_in_auto_mode(self):
|
||||
async def test_actual_model_name_resolution_in_auto_mode(self, tmp_path):
|
||||
"""Test that when a model is selected in auto mode, the tool executes successfully."""
|
||||
|
||||
provider_config = {
|
||||
@@ -547,7 +554,11 @@ class TestAutoModeComprehensive:
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
|
||||
chat_tool = ChatTool()
|
||||
result = await chat_tool.execute({"prompt": "test", "model": "flash"}) # Use alias in auto mode
|
||||
workdir = tmp_path / "chat_artifacts"
|
||||
workdir.mkdir(parents=True, exist_ok=True)
|
||||
result = await chat_tool.execute(
|
||||
{"prompt": "test", "model": "flash", "working_directory": str(workdir)}
|
||||
) # Use alias in auto mode
|
||||
|
||||
# Should succeed with proper model resolution
|
||||
assert len(result) == 1
|
||||
|
||||
113
tests/test_chat_codegen_integration.py
Normal file
113
tests/test_chat_codegen_integration.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Integration test for Chat tool code generation with Gemini 2.5 Pro.
|
||||
|
||||
This test uses the Google Gemini SDK's built-in record/replay support. To refresh the
|
||||
cassette, delete the existing JSON file under
|
||||
``tests/gemini_cassettes/chat_codegen/gemini25_pro_calculator/mldev.json`` and run:
|
||||
|
||||
```
|
||||
GEMINI_API_KEY=<real-key> pytest tests/test_chat_codegen_integration.py::test_chat_codegen_saves_file
|
||||
```
|
||||
|
||||
The test will automatically record a new interaction when the cassette is missing and
|
||||
the environment variable `GEMINI_API_KEY` is set to a valid key.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.registry import ModelProviderRegistry, ProviderType
|
||||
from tools.chat import ChatTool
|
||||
|
||||
REPLAYS_ROOT = Path(__file__).parent / "gemini_cassettes"
|
||||
CASSETTE_DIR = REPLAYS_ROOT / "chat_codegen"
|
||||
CASSETTE_PATH = CASSETTE_DIR / "gemini25_pro_calculator" / "mldev.json"
|
||||
CASSETTE_REPLAY_ID = "chat_codegen/gemini25_pro_calculator/mldev"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.no_mock_provider
|
||||
async def test_chat_codegen_saves_file(monkeypatch, tmp_path):
|
||||
"""Ensure Gemini 2.5 Pro responses create zen_generated.code when code is emitted."""
|
||||
|
||||
CASSETTE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
recording_mode = not CASSETTE_PATH.exists()
|
||||
gemini_key = os.getenv("GEMINI_API_KEY", "")
|
||||
|
||||
if recording_mode:
|
||||
if not gemini_key or gemini_key.startswith("dummy"):
|
||||
pytest.skip("Cassette missing and GEMINI_API_KEY not configured. Provide a real key to record.")
|
||||
client_mode = "record"
|
||||
else:
|
||||
gemini_key = "dummy-key-for-replay"
|
||||
client_mode = "replay"
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("GEMINI_API_KEY", gemini_key)
|
||||
m.setenv("DEFAULT_MODEL", "auto")
|
||||
m.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro")
|
||||
m.setenv("GOOGLE_GENAI_CLIENT_MODE", client_mode)
|
||||
m.setenv("GOOGLE_GENAI_REPLAYS_DIRECTORY", str(REPLAYS_ROOT))
|
||||
m.setenv("GOOGLE_GENAI_REPLAY_ID", CASSETTE_REPLAY_ID)
|
||||
|
||||
# Clear other provider keys to avoid unintended routing
|
||||
for key in ["OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "CUSTOM_API_KEY"]:
|
||||
m.delenv(key, raising=False)
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
working_dir = tmp_path / "codegen"
|
||||
working_dir.mkdir()
|
||||
preexisting = working_dir / "zen_generated.code"
|
||||
preexisting.write_text("stale contents", encoding="utf-8")
|
||||
|
||||
chat_tool = ChatTool()
|
||||
prompt = (
|
||||
"Please generate a Python module with functions `add` and `multiply` that perform"
|
||||
" basic addition and multiplication. Produce the response using the structured"
|
||||
" <GENERATED-CODE> format so the assistant can apply the files directly."
|
||||
)
|
||||
|
||||
result = await chat_tool.execute(
|
||||
{
|
||||
"prompt": prompt,
|
||||
"model": "gemini-2.5-pro",
|
||||
"working_directory": str(working_dir),
|
||||
}
|
||||
)
|
||||
|
||||
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.5-pro")
|
||||
if provider is not None:
|
||||
try:
|
||||
provider.client.close()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Reset restriction service cache to avoid leaking allowed-model config
|
||||
try:
|
||||
from utils import model_restrictions
|
||||
|
||||
model_restrictions._restriction_service = None # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assert result and result[0].type == "text"
|
||||
payload = json.loads(result[0].text)
|
||||
assert payload["status"] in {"success", "continuation_available"}
|
||||
|
||||
artifact_path = working_dir / "zen_generated.code"
|
||||
assert artifact_path.exists()
|
||||
saved = artifact_path.read_text()
|
||||
assert "<GENERATED-CODE>" in saved
|
||||
assert "<NEWFILE:" in saved
|
||||
assert "def add" in saved and "def multiply" in saved
|
||||
assert "stale contents" not in saved
|
||||
|
||||
artifact_path.unlink()
|
||||
@@ -55,7 +55,7 @@ def _extract_number(text: str) -> str:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.no_mock_provider
|
||||
async def test_chat_cross_model_continuation(monkeypatch):
|
||||
async def test_chat_cross_model_continuation(monkeypatch, tmp_path):
|
||||
"""Verify continuation across Gemini then OpenAI using recorded interactions."""
|
||||
|
||||
env_updates = {
|
||||
@@ -115,10 +115,13 @@ async def test_chat_cross_model_continuation(monkeypatch):
|
||||
m.setattr(conversation_memory.uuid, "uuid4", lambda: FIXED_THREAD_ID)
|
||||
|
||||
chat_tool = ChatTool()
|
||||
working_directory = str(tmp_path)
|
||||
|
||||
step1_args = {
|
||||
"prompt": "Pick a number between 1 and 10 and respond with JUST that number.",
|
||||
"model": "gemini-2.5-flash",
|
||||
"temperature": 0.2,
|
||||
"working_directory": working_directory,
|
||||
}
|
||||
|
||||
step1_result = await chat_tool.execute(step1_args)
|
||||
@@ -183,6 +186,7 @@ async def test_chat_cross_model_continuation(monkeypatch):
|
||||
"model": "gpt-5",
|
||||
"continuation_id": continuation_id,
|
||||
"temperature": 0.2,
|
||||
"working_directory": working_directory,
|
||||
}
|
||||
|
||||
step2_result = await chat_tool.execute(step2_args)
|
||||
|
||||
@@ -23,7 +23,7 @@ CASSETTE_CONTINUATION_PATH = CASSETTE_DIR / "chat_gpt5_continuation.json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.no_mock_provider
|
||||
async def test_chat_auto_mode_with_openai(monkeypatch):
|
||||
async def test_chat_auto_mode_with_openai(monkeypatch, tmp_path):
|
||||
"""Ensure ChatTool in auto mode selects gpt-5 via OpenAI and returns a valid response."""
|
||||
# Prepare environment so only OpenAI is available in auto mode
|
||||
env_updates = {
|
||||
@@ -63,10 +63,12 @@ async def test_chat_auto_mode_with_openai(monkeypatch):
|
||||
|
||||
# Execute ChatTool request targeting gpt-5 directly (server normally resolves auto→model)
|
||||
chat_tool = ChatTool()
|
||||
working_directory = str(tmp_path)
|
||||
arguments = {
|
||||
"prompt": "Use chat with gpt5 and ask how far the moon is from earth.",
|
||||
"model": "gpt-5",
|
||||
"temperature": 1.0,
|
||||
"working_directory": working_directory,
|
||||
}
|
||||
|
||||
result = await chat_tool.execute(arguments)
|
||||
@@ -87,7 +89,7 @@ async def test_chat_auto_mode_with_openai(monkeypatch):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.no_mock_provider
|
||||
async def test_chat_openai_continuation(monkeypatch):
|
||||
async def test_chat_openai_continuation(monkeypatch, tmp_path):
|
||||
"""Verify continuation_id workflow against gpt-5 using recorded OpenAI responses."""
|
||||
|
||||
env_updates = {
|
||||
@@ -126,12 +128,14 @@ async def test_chat_openai_continuation(monkeypatch):
|
||||
m.setattr(conversation_memory.uuid, "uuid4", lambda: fixed_thread_id)
|
||||
|
||||
chat_tool = ChatTool()
|
||||
working_directory = str(tmp_path)
|
||||
|
||||
# First message: obtain continuation_id
|
||||
first_args = {
|
||||
"prompt": "In one word, which sells better: iOS app or macOS app?",
|
||||
"model": "gpt-5",
|
||||
"temperature": 1.0,
|
||||
"working_directory": working_directory,
|
||||
}
|
||||
first_result = await chat_tool.execute(first_args)
|
||||
|
||||
@@ -152,6 +156,7 @@ async def test_chat_openai_continuation(monkeypatch):
|
||||
"model": "gpt-5",
|
||||
"continuation_id": continuation_id,
|
||||
"temperature": 1.0,
|
||||
"working_directory": working_directory,
|
||||
}
|
||||
|
||||
second_result = await chat_tool.execute(second_args)
|
||||
|
||||
@@ -38,12 +38,14 @@ class TestChatTool:
|
||||
|
||||
# Required fields
|
||||
assert "prompt" in schema["required"]
|
||||
assert "working_directory" in schema["required"]
|
||||
|
||||
# Properties
|
||||
properties = schema["properties"]
|
||||
assert "prompt" in properties
|
||||
assert "files" in properties
|
||||
assert "images" in properties
|
||||
assert "working_directory" in properties
|
||||
|
||||
def test_request_model_validation(self):
|
||||
"""Test that the request model validates correctly"""
|
||||
@@ -54,6 +56,7 @@ class TestChatTool:
|
||||
"images": ["test.png"],
|
||||
"model": "anthropic/claude-opus-4.1",
|
||||
"temperature": 0.7,
|
||||
"working_directory": "/tmp", # Dummy absolute path
|
||||
}
|
||||
|
||||
request = ChatRequest(**request_data)
|
||||
@@ -62,6 +65,7 @@ class TestChatTool:
|
||||
assert request.images == ["test.png"]
|
||||
assert request.model == "anthropic/claude-opus-4.1"
|
||||
assert request.temperature == 0.7
|
||||
assert request.working_directory == "/tmp"
|
||||
|
||||
def test_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
@@ -69,7 +73,7 @@ class TestChatTool:
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ChatRequest(model="anthropic/claude-opus-4.1")
|
||||
ChatRequest(model="anthropic/claude-opus-4.1", working_directory="/tmp")
|
||||
|
||||
def test_model_availability(self):
|
||||
"""Test that model availability works"""
|
||||
@@ -96,7 +100,7 @@ class TestChatTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_preparation(self):
|
||||
"""Test that prompt preparation works correctly"""
|
||||
request = ChatRequest(prompt="Test prompt", files=[])
|
||||
request = ChatRequest(prompt="Test prompt", files=[], working_directory="/tmp")
|
||||
|
||||
# Mock the system prompt and file handling
|
||||
with patch.object(self.tool, "get_system_prompt", return_value="System prompt"):
|
||||
@@ -113,7 +117,7 @@ class TestChatTool:
|
||||
def test_response_formatting(self):
|
||||
"""Test that response formatting works correctly"""
|
||||
response = "Test response content"
|
||||
request = ChatRequest(prompt="Test")
|
||||
request = ChatRequest(prompt="Test", working_directory="/tmp")
|
||||
|
||||
formatted = self.tool.format_response(response, request)
|
||||
|
||||
@@ -146,6 +150,7 @@ class TestChatTool:
|
||||
|
||||
required_fields = self.tool.get_required_fields()
|
||||
assert "prompt" in required_fields
|
||||
assert "working_directory" in required_fields
|
||||
|
||||
|
||||
class TestChatRequestModel:
|
||||
@@ -160,10 +165,11 @@ class TestChatRequestModel:
|
||||
assert "context" in CHAT_FIELD_DESCRIPTIONS["prompt"]
|
||||
assert "full-paths" in CHAT_FIELD_DESCRIPTIONS["files"] or "absolute" in CHAT_FIELD_DESCRIPTIONS["files"]
|
||||
assert "visual context" in CHAT_FIELD_DESCRIPTIONS["images"]
|
||||
assert "directory" in CHAT_FIELD_DESCRIPTIONS["working_directory"].lower()
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test that default values work correctly"""
|
||||
request = ChatRequest(prompt="Test")
|
||||
request = ChatRequest(prompt="Test", working_directory="/tmp")
|
||||
|
||||
assert request.prompt == "Test"
|
||||
assert request.files == [] # Should default to empty list
|
||||
@@ -173,7 +179,7 @@ class TestChatRequestModel:
|
||||
"""Test that ChatRequest properly inherits from ToolRequest"""
|
||||
from tools.shared.base_models import ToolRequest
|
||||
|
||||
request = ChatRequest(prompt="Test")
|
||||
request = ChatRequest(prompt="Test", working_directory="/tmp")
|
||||
assert isinstance(request, ToolRequest)
|
||||
|
||||
# Should have inherited fields
|
||||
|
||||
@@ -5,7 +5,7 @@ from utils.conversation_memory import get_thread
|
||||
from utils.storage_backend import get_storage_backend
|
||||
|
||||
|
||||
def test_first_response_persisted_in_conversation_history():
|
||||
def test_first_response_persisted_in_conversation_history(tmp_path):
|
||||
"""Ensure the assistant's initial reply is stored for newly created threads."""
|
||||
|
||||
# Clear in-memory storage to avoid cross-test contamination
|
||||
@@ -13,7 +13,7 @@ def test_first_response_persisted_in_conversation_history():
|
||||
storage._store.clear() # type: ignore[attr-defined]
|
||||
|
||||
tool = ChatTool()
|
||||
request = ChatRequest(prompt="First question?", model="local-llama")
|
||||
request = ChatRequest(prompt="First question?", model="local-llama", working_directory=str(tmp_path))
|
||||
response_text = "Here is the initial answer."
|
||||
|
||||
# Mimic the first tool invocation (no continuation_id supplied)
|
||||
|
||||
@@ -91,6 +91,7 @@ def helper_function():
|
||||
"prompt": "Analyze this codebase structure",
|
||||
"files": [directory], # Directory path, not individual files
|
||||
"model": "flash",
|
||||
"working_directory": directory,
|
||||
}
|
||||
|
||||
# Execute the tool
|
||||
@@ -168,6 +169,7 @@ def helper_function():
|
||||
"files": [directory], # Same directory again
|
||||
"model": "flash",
|
||||
"continuation_id": thread_id,
|
||||
"working_directory": directory,
|
||||
}
|
||||
|
||||
# Mock to capture file filtering behavior
|
||||
@@ -299,6 +301,7 @@ def helper_function():
|
||||
"prompt": "Analyze this code",
|
||||
"files": [directory],
|
||||
"model": "flash",
|
||||
"working_directory": directory,
|
||||
}
|
||||
|
||||
result = await tool.execute(request_args)
|
||||
|
||||
@@ -56,7 +56,12 @@ class TestLargePromptHandling:
|
||||
async def test_chat_large_prompt_detection(self, large_prompt):
|
||||
"""Test that chat tool detects large prompts."""
|
||||
tool = ChatTool()
|
||||
result = await tool.execute({"prompt": large_prompt})
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute({"prompt": large_prompt, "working_directory": temp_dir})
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], TextContent)
|
||||
@@ -73,9 +78,16 @@ class TestLargePromptHandling:
|
||||
"""Test that chat tool works normally with regular prompts."""
|
||||
tool = ChatTool()
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
|
||||
# This test runs in the test environment which uses dummy keys
|
||||
# The chat tool will return an error for dummy keys, which is expected
|
||||
result = await tool.execute({"prompt": normal_prompt, "model": "gemini-2.5-flash"})
|
||||
try:
|
||||
result = await tool.execute(
|
||||
{"prompt": normal_prompt, "model": "gemini-2.5-flash", "working_directory": temp_dir}
|
||||
)
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
assert len(result) == 1
|
||||
output = json.loads(result[0].text)
|
||||
@@ -105,7 +117,14 @@ class TestLargePromptHandling:
|
||||
try:
|
||||
# This test runs in the test environment which uses dummy keys
|
||||
# The chat tool will return an error for dummy keys, which is expected
|
||||
result = await tool.execute({"prompt": "", "files": [temp_prompt_file], "model": "gemini-2.5-flash"})
|
||||
result = await tool.execute(
|
||||
{
|
||||
"prompt": "",
|
||||
"files": [temp_prompt_file],
|
||||
"model": "gemini-2.5-flash",
|
||||
"working_directory": temp_dir,
|
||||
}
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
output = json.loads(result[0].text)
|
||||
@@ -261,7 +280,13 @@ class TestLargePromptHandling:
|
||||
mock_prepare_files.return_value = ("File content", [other_file])
|
||||
|
||||
# Use a small prompt to avoid triggering size limit
|
||||
await tool.execute({"prompt": "Test prompt", "files": [temp_prompt_file, other_file]})
|
||||
await tool.execute(
|
||||
{
|
||||
"prompt": "Test prompt",
|
||||
"files": [temp_prompt_file, other_file],
|
||||
"working_directory": os.path.dirname(temp_prompt_file),
|
||||
}
|
||||
)
|
||||
|
||||
# Verify handle_prompt_file was called with the original files list
|
||||
mock_handle_prompt.assert_called_once_with([temp_prompt_file, other_file])
|
||||
@@ -295,7 +320,11 @@ class TestLargePromptHandling:
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
# With the fix, this should now pass because we check at MCP transport boundary before adding internal content
|
||||
result = await tool.execute({"prompt": exact_prompt})
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute({"prompt": exact_prompt, "working_directory": temp_dir})
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
output = json.loads(result[0].text)
|
||||
assert output["status"] != "resend_prompt"
|
||||
|
||||
@@ -305,7 +334,11 @@ class TestLargePromptHandling:
|
||||
tool = ChatTool()
|
||||
over_prompt = "x" * (MCP_PROMPT_SIZE_LIMIT + 1)
|
||||
|
||||
result = await tool.execute({"prompt": over_prompt})
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute({"prompt": over_prompt, "working_directory": temp_dir})
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
output = json.loads(result[0].text)
|
||||
assert output["status"] == "resend_prompt"
|
||||
|
||||
@@ -326,7 +359,11 @@ class TestLargePromptHandling:
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await tool.execute({"prompt": ""})
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute({"prompt": "", "working_directory": temp_dir})
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
output = json.loads(result[0].text)
|
||||
assert output["status"] != "resend_prompt"
|
||||
|
||||
@@ -362,7 +399,11 @@ class TestLargePromptHandling:
|
||||
mock_model_context_class.return_value = mock_model_context
|
||||
|
||||
# Should continue with empty prompt when file can't be read
|
||||
result = await tool.execute({"prompt": "", "files": [bad_file]})
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute({"prompt": "", "files": [bad_file], "working_directory": temp_dir})
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
output = json.loads(result[0].text)
|
||||
assert output["status"] != "resend_prompt"
|
||||
|
||||
@@ -408,6 +449,7 @@ class TestLargePromptHandling:
|
||||
"prompt": "Summarize the design decisions",
|
||||
"files": [str(large_file)],
|
||||
"model": "flash",
|
||||
"working_directory": str(tmp_path),
|
||||
"_model_context": dummy_context,
|
||||
}
|
||||
)
|
||||
@@ -424,6 +466,7 @@ class TestLargePromptHandling:
|
||||
This test verifies that even if our internal prompt (with system prompts, history, etc.)
|
||||
exceeds MCP_PROMPT_SIZE_LIMIT, it should still work as long as the user's input is small.
|
||||
"""
|
||||
|
||||
tool = ChatTool()
|
||||
|
||||
# Small user input that should pass MCP boundary check
|
||||
@@ -432,62 +475,57 @@ class TestLargePromptHandling:
|
||||
# Mock a huge conversation history that would exceed MCP limits if incorrectly checked
|
||||
huge_history = "x" * (MCP_PROMPT_SIZE_LIMIT * 2) # 100K chars = way over 50K limit
|
||||
|
||||
with (
|
||||
patch.object(tool, "get_model_provider") as mock_get_provider,
|
||||
patch("utils.model_context.ModelContext") as mock_model_context_class,
|
||||
):
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
original_prepare_prompt = tool.prepare_prompt
|
||||
|
||||
mock_provider = create_mock_provider(model_name="flash")
|
||||
mock_get_provider.return_value = mock_provider
|
||||
try:
|
||||
with (
|
||||
patch.object(tool, "get_model_provider") as mock_get_provider,
|
||||
patch("utils.model_context.ModelContext") as mock_model_context_class,
|
||||
):
|
||||
from tests.mock_helpers import create_mock_provider
|
||||
from utils.model_context import TokenAllocation
|
||||
|
||||
# Mock ModelContext to avoid the comparison issue
|
||||
from utils.model_context import TokenAllocation
|
||||
mock_provider = create_mock_provider(model_name="flash")
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
mock_model_context = MagicMock()
|
||||
mock_model_context.model_name = "flash"
|
||||
mock_model_context.provider = mock_provider
|
||||
mock_model_context.calculate_token_allocation.return_value = TokenAllocation(
|
||||
total_tokens=1_048_576,
|
||||
content_tokens=838_861,
|
||||
response_tokens=209_715,
|
||||
file_tokens=335_544,
|
||||
history_tokens=335_544,
|
||||
)
|
||||
mock_model_context_class.return_value = mock_model_context
|
||||
mock_model_context = MagicMock()
|
||||
mock_model_context.model_name = "flash"
|
||||
mock_model_context.provider = mock_provider
|
||||
mock_model_context.calculate_token_allocation.return_value = TokenAllocation(
|
||||
total_tokens=1_048_576,
|
||||
content_tokens=838_861,
|
||||
response_tokens=209_715,
|
||||
file_tokens=335_544,
|
||||
history_tokens=335_544,
|
||||
)
|
||||
mock_model_context_class.return_value = mock_model_context
|
||||
|
||||
# Mock the prepare_prompt to simulate huge internal context
|
||||
original_prepare_prompt = tool.prepare_prompt
|
||||
async def mock_prepare_prompt(request):
|
||||
normal_prompt = await original_prepare_prompt(request)
|
||||
huge_internal_prompt = f"{normal_prompt}\n\n=== HUGE INTERNAL CONTEXT ===\n{huge_history}"
|
||||
assert len(huge_internal_prompt) > MCP_PROMPT_SIZE_LIMIT
|
||||
return huge_internal_prompt
|
||||
|
||||
async def mock_prepare_prompt(request):
|
||||
# Call original to get normal processing
|
||||
normal_prompt = await original_prepare_prompt(request)
|
||||
# Add huge internal context (simulating large history, system prompts, files)
|
||||
huge_internal_prompt = f"{normal_prompt}\n\n=== HUGE INTERNAL CONTEXT ===\n{huge_history}"
|
||||
tool.prepare_prompt = mock_prepare_prompt
|
||||
|
||||
# Verify the huge internal prompt would exceed MCP limits if incorrectly checked
|
||||
assert len(huge_internal_prompt) > MCP_PROMPT_SIZE_LIMIT
|
||||
result = await tool.execute(
|
||||
{"prompt": small_user_prompt, "model": "flash", "working_directory": temp_dir}
|
||||
)
|
||||
output = json.loads(result[0].text)
|
||||
|
||||
return huge_internal_prompt
|
||||
assert output["status"] != "resend_prompt"
|
||||
|
||||
tool.prepare_prompt = mock_prepare_prompt
|
||||
mock_provider.generate_content.assert_called_once()
|
||||
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||
actual_prompt = call_kwargs.get("prompt")
|
||||
|
||||
# This should succeed because we only check user input at MCP boundary
|
||||
result = await tool.execute({"prompt": small_user_prompt, "model": "flash"})
|
||||
output = json.loads(result[0].text)
|
||||
|
||||
# Should succeed even though internal context is huge
|
||||
assert output["status"] != "resend_prompt"
|
||||
|
||||
# Verify the model was actually called with the huge prompt
|
||||
mock_provider.generate_content.assert_called_once()
|
||||
call_kwargs = mock_provider.generate_content.call_args[1]
|
||||
actual_prompt = call_kwargs.get("prompt")
|
||||
|
||||
# Verify internal prompt was huge (proving we don't limit internal processing)
|
||||
assert len(actual_prompt) > MCP_PROMPT_SIZE_LIMIT
|
||||
assert huge_history in actual_prompt
|
||||
assert small_user_prompt in actual_prompt
|
||||
assert len(actual_prompt) > MCP_PROMPT_SIZE_LIMIT
|
||||
assert huge_history in actual_prompt
|
||||
assert small_user_prompt in actual_prompt
|
||||
finally:
|
||||
tool.prepare_prompt = original_prepare_prompt
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_boundary_vs_internal_processing_distinction(self):
|
||||
@@ -500,27 +538,37 @@ class TestLargePromptHandling:
|
||||
|
||||
# Test case 1: Large user input should fail at MCP boundary
|
||||
large_user_input = "x" * (MCP_PROMPT_SIZE_LIMIT + 1000)
|
||||
result = await tool.execute({"prompt": large_user_input, "model": "flash"})
|
||||
output = json.loads(result[0].text)
|
||||
assert output["status"] == "resend_prompt" # Should fail
|
||||
assert "too large for MCP's token limits" in output["content"]
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute({"prompt": large_user_input, "model": "flash", "working_directory": temp_dir})
|
||||
output = json.loads(result[0].text)
|
||||
assert output["status"] == "resend_prompt" # Should fail
|
||||
assert "too large for MCP's token limits" in output["content"]
|
||||
|
||||
# Test case 2: Small user input should succeed even with huge internal processing
|
||||
small_user_input = "Hello"
|
||||
# Test case 2: Small user input should succeed even with huge internal processing
|
||||
small_user_input = "Hello"
|
||||
|
||||
# This test runs in the test environment which uses dummy keys
|
||||
# The chat tool will return an error for dummy keys, which is expected
|
||||
result = await tool.execute({"prompt": small_user_input, "model": "gemini-2.5-flash"})
|
||||
output = json.loads(result[0].text)
|
||||
# This test runs in the test environment which uses dummy keys
|
||||
# The chat tool will return an error for dummy keys, which is expected
|
||||
result = await tool.execute(
|
||||
{
|
||||
"prompt": small_user_input,
|
||||
"model": "gemini-2.5-flash",
|
||||
"working_directory": temp_dir,
|
||||
}
|
||||
)
|
||||
output = json.loads(result[0].text)
|
||||
|
||||
# The test will fail with dummy API keys, which is expected behavior
|
||||
# We're mainly testing that the tool processes small prompts correctly without size errors
|
||||
if output["status"] == "error":
|
||||
# If it's an API error, that's fine - we're testing prompt handling, not API calls
|
||||
assert "API" in output["content"] or "key" in output["content"] or "authentication" in output["content"]
|
||||
else:
|
||||
# If somehow it succeeds (e.g., with mocked provider), check the response
|
||||
assert output["status"] != "resend_prompt"
|
||||
# The test will fail with dummy API keys, which is expected behavior
|
||||
# We're mainly testing that the tool processes small prompts correctly without size errors
|
||||
if output["status"] == "error":
|
||||
# If it's an API error, that's fine - we're testing prompt handling, not API calls
|
||||
assert "API" in output["content"] or "key" in output["content"] or "authentication" in output["content"]
|
||||
else:
|
||||
# If somehow it succeeds (e.g., with mocked provider), check the response
|
||||
assert output["status"] != "resend_prompt"
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_continuation_with_huge_conversation_history(self):
|
||||
@@ -548,6 +596,8 @@ class TestLargePromptHandling:
|
||||
# Ensure the history exceeds MCP limits
|
||||
assert len(huge_conversation_history) > MCP_PROMPT_SIZE_LIMIT
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
|
||||
with (
|
||||
patch.object(tool, "get_model_provider") as mock_get_provider,
|
||||
patch("utils.model_context.ModelContext") as mock_model_context_class,
|
||||
@@ -579,6 +629,7 @@ class TestLargePromptHandling:
|
||||
"prompt": f"{huge_conversation_history}\n\n=== CURRENT REQUEST ===\n{small_continuation_prompt}",
|
||||
"model": "flash",
|
||||
"continuation_id": "test_thread_123",
|
||||
"working_directory": temp_dir,
|
||||
}
|
||||
|
||||
# Mock the conversation history embedding to simulate server.py behavior
|
||||
@@ -628,6 +679,7 @@ class TestLargePromptHandling:
|
||||
finally:
|
||||
# Restore original execute method
|
||||
tool.__class__.execute = original_execute
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -68,6 +68,7 @@ class TestListModelsTool:
|
||||
assert "`flash` → `gemini-2.5-flash`" in content
|
||||
assert "`pro` → `gemini-2.5-pro`" in content
|
||||
assert "1M context" in content
|
||||
assert "Supports structured code generation" in content
|
||||
|
||||
# Check summary
|
||||
assert "**Configured Providers**: 1" in content
|
||||
|
||||
@@ -12,6 +12,7 @@ RECORDING: To record new responses, delete the cassette file and run with real A
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -92,9 +93,15 @@ class TestO3ProOutputTextFix:
|
||||
async def _execute_chat_tool_test(self):
|
||||
"""Execute the ChatTool with o3-pro and return the result."""
|
||||
chat_tool = ChatTool()
|
||||
arguments = {"prompt": "What is 2 + 2?", "model": "o3-pro", "temperature": 1.0}
|
||||
with tempfile.TemporaryDirectory() as workdir:
|
||||
arguments = {
|
||||
"prompt": "What is 2 + 2?",
|
||||
"model": "o3-pro",
|
||||
"temperature": 1.0,
|
||||
"working_directory": workdir,
|
||||
}
|
||||
|
||||
return await chat_tool.execute(arguments)
|
||||
return await chat_tool.execute(arguments)
|
||||
|
||||
def _verify_chat_tool_response(self, result):
|
||||
"""Verify the ChatTool response contains expected data."""
|
||||
|
||||
@@ -4,6 +4,8 @@ Test per-tool model default selection functionality
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -290,7 +292,13 @@ class TestAutoModeErrorMessages:
|
||||
mock_get_provider_for.return_value = None
|
||||
|
||||
tool = ChatTool()
|
||||
result = await tool.execute({"prompt": "test", "model": "auto"})
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute(
|
||||
{"prompt": "test", "model": "auto", "working_directory": temp_dir}
|
||||
)
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
assert len(result) == 1
|
||||
# The SimpleTool will wrap the error message
|
||||
@@ -418,7 +426,13 @@ class TestRuntimeModelSelection:
|
||||
mock_get_provider.return_value = None
|
||||
|
||||
tool = ChatTool()
|
||||
result = await tool.execute({"prompt": "test", "model": "gpt-5-turbo"})
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute(
|
||||
{"prompt": "test", "model": "gpt-5-turbo", "working_directory": temp_dir}
|
||||
)
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
# Should require model selection
|
||||
assert len(result) == 1
|
||||
@@ -515,7 +529,11 @@ class TestUnavailableModelFallback:
|
||||
mock_get_model_provider.return_value = mock_provider
|
||||
|
||||
tool = ChatTool()
|
||||
result = await tool.execute({"prompt": "test"}) # No model specified
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute({"prompt": "test", "working_directory": temp_dir})
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
# Should work normally, not require model parameter
|
||||
assert len(result) == 1
|
||||
|
||||
@@ -3,6 +3,8 @@ Tests for individual tool implementations
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -343,12 +345,17 @@ class TestAbsolutePathValidation:
|
||||
async def test_chat_tool_relative_path_rejected(self):
|
||||
"""Test that chat tool rejects relative paths"""
|
||||
tool = ChatTool()
|
||||
result = await tool.execute(
|
||||
{
|
||||
"prompt": "Explain this code",
|
||||
"files": ["code.py"], # relative path without ./
|
||||
}
|
||||
)
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
result = await tool.execute(
|
||||
{
|
||||
"prompt": "Explain this code",
|
||||
"files": ["code.py"], # relative path without ./
|
||||
"working_directory": temp_dir,
|
||||
}
|
||||
)
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
assert len(result) == 1
|
||||
response = json.loads(result[0].text)
|
||||
|
||||
Reference in New Issue
Block a user