feat!: Full code can now be generated by an external model and shared with the AI tool (Claude Code / Codex etc)!

model definitions now support a new `allow_code_generation` flag, only to be used with higher reasoning models such as GPT-5-Pro and-Gemini 2.5-Pro

 When `true`, the `chat` tool can now request the external model to generate a full implementation / update / instructions etc and then share the implementation with the calling agent.

 This effectively allows us to utilize more powerful models such as GPT-5-Pro to generate code for us or entire implementations (which are either API-only or part of the $200 Pro plan from within the ChatGPT app)
This commit is contained in:
Fahad
2025-10-07 18:49:13 +04:00
parent 04f7ce5b03
commit ece8a5ebed
29 changed files with 1008 additions and 122 deletions

File diff suppressed because one or more lines are too long

View File

@@ -137,7 +137,7 @@ class TestAutoMode:
importlib.reload(config)
@pytest.mark.asyncio
async def test_auto_mode_requires_model_parameter(self):
async def test_auto_mode_requires_model_parameter(self, tmp_path):
"""Test that auto mode enforces model parameter"""
# Save original
original = os.environ.get("DEFAULT_MODEL", "")
@@ -154,7 +154,7 @@ class TestAutoMode:
# Mock the provider to avoid real API calls
with patch.object(tool, "get_model_provider"):
# Execute without model parameter
result = await tool.execute({"prompt": "Test prompt"})
result = await tool.execute({"prompt": "Test prompt", "working_directory": str(tmp_path)})
# Should get error
assert len(result) == 1

View File

@@ -200,7 +200,7 @@ class TestAutoModeComprehensive:
assert tool.get_model_category() == expected_category
@pytest.mark.asyncio
async def test_auto_mode_with_gemini_only_uses_correct_models(self):
async def test_auto_mode_with_gemini_only_uses_correct_models(self, tmp_path):
"""Test that auto mode with only Gemini uses flash for fast tools and pro for reasoning tools."""
provider_config = {
@@ -234,9 +234,13 @@ class TestAutoModeComprehensive:
)
with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
workdir = tmp_path / "chat_artifacts"
workdir.mkdir(parents=True, exist_ok=True)
# Test ChatTool (FAST_RESPONSE) - should prefer flash
chat_tool = ChatTool()
await chat_tool.execute({"prompt": "test", "model": "auto"}) # This should trigger auto selection
await chat_tool.execute(
{"prompt": "test", "model": "auto", "working_directory": str(workdir)}
) # This should trigger auto selection
# In auto mode, the tool should get an error requiring model selection
# but the suggested model should be flash
@@ -355,7 +359,7 @@ class TestAutoModeComprehensive:
# would show models from all providers when called
@pytest.mark.asyncio
async def test_auto_mode_model_parameter_required_error(self):
async def test_auto_mode_model_parameter_required_error(self, tmp_path):
"""Test that auto mode properly requires model parameter and suggests correct model."""
provider_config = {
@@ -384,9 +388,12 @@ class TestAutoModeComprehensive:
# Test with ChatTool (FAST_RESPONSE category)
chat_tool = ChatTool()
workdir = tmp_path / "chat_artifacts"
workdir.mkdir(parents=True, exist_ok=True)
result = await chat_tool.execute(
{
"prompt": "test"
"prompt": "test",
"working_directory": str(workdir),
# Note: no model parameter provided in auto mode
}
)
@@ -508,7 +515,7 @@ class TestAutoModeComprehensive:
assert fast_response is not None
@pytest.mark.asyncio
async def test_actual_model_name_resolution_in_auto_mode(self):
async def test_actual_model_name_resolution_in_auto_mode(self, tmp_path):
"""Test that when a model is selected in auto mode, the tool executes successfully."""
provider_config = {
@@ -547,7 +554,11 @@ class TestAutoModeComprehensive:
with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
chat_tool = ChatTool()
result = await chat_tool.execute({"prompt": "test", "model": "flash"}) # Use alias in auto mode
workdir = tmp_path / "chat_artifacts"
workdir.mkdir(parents=True, exist_ok=True)
result = await chat_tool.execute(
{"prompt": "test", "model": "flash", "working_directory": str(workdir)}
) # Use alias in auto mode
# Should succeed with proper model resolution
assert len(result) == 1

View File

@@ -0,0 +1,113 @@
"""Integration test for Chat tool code generation with Gemini 2.5 Pro.
This test uses the Google Gemini SDK's built-in record/replay support. To refresh the
cassette, delete the existing JSON file under
``tests/gemini_cassettes/chat_codegen/gemini25_pro_calculator/mldev.json`` and run:
```
GEMINI_API_KEY=<real-key> pytest tests/test_chat_codegen_integration.py::test_chat_codegen_saves_file
```
The test will automatically record a new interaction when the cassette is missing and
the environment variable `GEMINI_API_KEY` is set to a valid key.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
import pytest
from providers.gemini import GeminiModelProvider
from providers.registry import ModelProviderRegistry, ProviderType
from tools.chat import ChatTool
REPLAYS_ROOT = Path(__file__).parent / "gemini_cassettes"
CASSETTE_DIR = REPLAYS_ROOT / "chat_codegen"
CASSETTE_PATH = CASSETTE_DIR / "gemini25_pro_calculator" / "mldev.json"
CASSETTE_REPLAY_ID = "chat_codegen/gemini25_pro_calculator/mldev"
@pytest.mark.asyncio
@pytest.mark.no_mock_provider
async def test_chat_codegen_saves_file(monkeypatch, tmp_path):
"""Ensure Gemini 2.5 Pro responses create zen_generated.code when code is emitted."""
CASSETTE_PATH.parent.mkdir(parents=True, exist_ok=True)
recording_mode = not CASSETTE_PATH.exists()
gemini_key = os.getenv("GEMINI_API_KEY", "")
if recording_mode:
if not gemini_key or gemini_key.startswith("dummy"):
pytest.skip("Cassette missing and GEMINI_API_KEY not configured. Provide a real key to record.")
client_mode = "record"
else:
gemini_key = "dummy-key-for-replay"
client_mode = "replay"
with monkeypatch.context() as m:
m.setenv("GEMINI_API_KEY", gemini_key)
m.setenv("DEFAULT_MODEL", "auto")
m.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro")
m.setenv("GOOGLE_GENAI_CLIENT_MODE", client_mode)
m.setenv("GOOGLE_GENAI_REPLAYS_DIRECTORY", str(REPLAYS_ROOT))
m.setenv("GOOGLE_GENAI_REPLAY_ID", CASSETTE_REPLAY_ID)
# Clear other provider keys to avoid unintended routing
for key in ["OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "CUSTOM_API_KEY"]:
m.delenv(key, raising=False)
ModelProviderRegistry.reset_for_testing()
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
working_dir = tmp_path / "codegen"
working_dir.mkdir()
preexisting = working_dir / "zen_generated.code"
preexisting.write_text("stale contents", encoding="utf-8")
chat_tool = ChatTool()
prompt = (
"Please generate a Python module with functions `add` and `multiply` that perform"
" basic addition and multiplication. Produce the response using the structured"
" <GENERATED-CODE> format so the assistant can apply the files directly."
)
result = await chat_tool.execute(
{
"prompt": prompt,
"model": "gemini-2.5-pro",
"working_directory": str(working_dir),
}
)
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.5-pro")
if provider is not None:
try:
provider.client.close()
except AttributeError:
pass
# Reset restriction service cache to avoid leaking allowed-model config
try:
from utils import model_restrictions
model_restrictions._restriction_service = None # type: ignore[attr-defined]
except Exception:
pass
assert result and result[0].type == "text"
payload = json.loads(result[0].text)
assert payload["status"] in {"success", "continuation_available"}
artifact_path = working_dir / "zen_generated.code"
assert artifact_path.exists()
saved = artifact_path.read_text()
assert "<GENERATED-CODE>" in saved
assert "<NEWFILE:" in saved
assert "def add" in saved and "def multiply" in saved
assert "stale contents" not in saved
artifact_path.unlink()

View File

@@ -55,7 +55,7 @@ def _extract_number(text: str) -> str:
@pytest.mark.asyncio
@pytest.mark.no_mock_provider
async def test_chat_cross_model_continuation(monkeypatch):
async def test_chat_cross_model_continuation(monkeypatch, tmp_path):
"""Verify continuation across Gemini then OpenAI using recorded interactions."""
env_updates = {
@@ -115,10 +115,13 @@ async def test_chat_cross_model_continuation(monkeypatch):
m.setattr(conversation_memory.uuid, "uuid4", lambda: FIXED_THREAD_ID)
chat_tool = ChatTool()
working_directory = str(tmp_path)
step1_args = {
"prompt": "Pick a number between 1 and 10 and respond with JUST that number.",
"model": "gemini-2.5-flash",
"temperature": 0.2,
"working_directory": working_directory,
}
step1_result = await chat_tool.execute(step1_args)
@@ -183,6 +186,7 @@ async def test_chat_cross_model_continuation(monkeypatch):
"model": "gpt-5",
"continuation_id": continuation_id,
"temperature": 0.2,
"working_directory": working_directory,
}
step2_result = await chat_tool.execute(step2_args)

View File

@@ -23,7 +23,7 @@ CASSETTE_CONTINUATION_PATH = CASSETTE_DIR / "chat_gpt5_continuation.json"
@pytest.mark.asyncio
@pytest.mark.no_mock_provider
async def test_chat_auto_mode_with_openai(monkeypatch):
async def test_chat_auto_mode_with_openai(monkeypatch, tmp_path):
"""Ensure ChatTool in auto mode selects gpt-5 via OpenAI and returns a valid response."""
# Prepare environment so only OpenAI is available in auto mode
env_updates = {
@@ -63,10 +63,12 @@ async def test_chat_auto_mode_with_openai(monkeypatch):
# Execute ChatTool request targeting gpt-5 directly (server normally resolves auto→model)
chat_tool = ChatTool()
working_directory = str(tmp_path)
arguments = {
"prompt": "Use chat with gpt5 and ask how far the moon is from earth.",
"model": "gpt-5",
"temperature": 1.0,
"working_directory": working_directory,
}
result = await chat_tool.execute(arguments)
@@ -87,7 +89,7 @@ async def test_chat_auto_mode_with_openai(monkeypatch):
@pytest.mark.asyncio
@pytest.mark.no_mock_provider
async def test_chat_openai_continuation(monkeypatch):
async def test_chat_openai_continuation(monkeypatch, tmp_path):
"""Verify continuation_id workflow against gpt-5 using recorded OpenAI responses."""
env_updates = {
@@ -126,12 +128,14 @@ async def test_chat_openai_continuation(monkeypatch):
m.setattr(conversation_memory.uuid, "uuid4", lambda: fixed_thread_id)
chat_tool = ChatTool()
working_directory = str(tmp_path)
# First message: obtain continuation_id
first_args = {
"prompt": "In one word, which sells better: iOS app or macOS app?",
"model": "gpt-5",
"temperature": 1.0,
"working_directory": working_directory,
}
first_result = await chat_tool.execute(first_args)
@@ -152,6 +156,7 @@ async def test_chat_openai_continuation(monkeypatch):
"model": "gpt-5",
"continuation_id": continuation_id,
"temperature": 1.0,
"working_directory": working_directory,
}
second_result = await chat_tool.execute(second_args)

View File

@@ -38,12 +38,14 @@ class TestChatTool:
# Required fields
assert "prompt" in schema["required"]
assert "working_directory" in schema["required"]
# Properties
properties = schema["properties"]
assert "prompt" in properties
assert "files" in properties
assert "images" in properties
assert "working_directory" in properties
def test_request_model_validation(self):
"""Test that the request model validates correctly"""
@@ -54,6 +56,7 @@ class TestChatTool:
"images": ["test.png"],
"model": "anthropic/claude-opus-4.1",
"temperature": 0.7,
"working_directory": "/tmp", # Dummy absolute path
}
request = ChatRequest(**request_data)
@@ -62,6 +65,7 @@ class TestChatTool:
assert request.images == ["test.png"]
assert request.model == "anthropic/claude-opus-4.1"
assert request.temperature == 0.7
assert request.working_directory == "/tmp"
def test_required_fields(self):
"""Test that required fields are enforced"""
@@ -69,7 +73,7 @@ class TestChatTool:
from pydantic import ValidationError
with pytest.raises(ValidationError):
ChatRequest(model="anthropic/claude-opus-4.1")
ChatRequest(model="anthropic/claude-opus-4.1", working_directory="/tmp")
def test_model_availability(self):
"""Test that model availability works"""
@@ -96,7 +100,7 @@ class TestChatTool:
@pytest.mark.asyncio
async def test_prompt_preparation(self):
"""Test that prompt preparation works correctly"""
request = ChatRequest(prompt="Test prompt", files=[])
request = ChatRequest(prompt="Test prompt", files=[], working_directory="/tmp")
# Mock the system prompt and file handling
with patch.object(self.tool, "get_system_prompt", return_value="System prompt"):
@@ -113,7 +117,7 @@ class TestChatTool:
def test_response_formatting(self):
"""Test that response formatting works correctly"""
response = "Test response content"
request = ChatRequest(prompt="Test")
request = ChatRequest(prompt="Test", working_directory="/tmp")
formatted = self.tool.format_response(response, request)
@@ -146,6 +150,7 @@ class TestChatTool:
required_fields = self.tool.get_required_fields()
assert "prompt" in required_fields
assert "working_directory" in required_fields
class TestChatRequestModel:
@@ -160,10 +165,11 @@ class TestChatRequestModel:
assert "context" in CHAT_FIELD_DESCRIPTIONS["prompt"]
assert "full-paths" in CHAT_FIELD_DESCRIPTIONS["files"] or "absolute" in CHAT_FIELD_DESCRIPTIONS["files"]
assert "visual context" in CHAT_FIELD_DESCRIPTIONS["images"]
assert "directory" in CHAT_FIELD_DESCRIPTIONS["working_directory"].lower()
def test_default_values(self):
"""Test that default values work correctly"""
request = ChatRequest(prompt="Test")
request = ChatRequest(prompt="Test", working_directory="/tmp")
assert request.prompt == "Test"
assert request.files == [] # Should default to empty list
@@ -173,7 +179,7 @@ class TestChatRequestModel:
"""Test that ChatRequest properly inherits from ToolRequest"""
from tools.shared.base_models import ToolRequest
request = ChatRequest(prompt="Test")
request = ChatRequest(prompt="Test", working_directory="/tmp")
assert isinstance(request, ToolRequest)
# Should have inherited fields

View File

@@ -5,7 +5,7 @@ from utils.conversation_memory import get_thread
from utils.storage_backend import get_storage_backend
def test_first_response_persisted_in_conversation_history():
def test_first_response_persisted_in_conversation_history(tmp_path):
"""Ensure the assistant's initial reply is stored for newly created threads."""
# Clear in-memory storage to avoid cross-test contamination
@@ -13,7 +13,7 @@ def test_first_response_persisted_in_conversation_history():
storage._store.clear() # type: ignore[attr-defined]
tool = ChatTool()
request = ChatRequest(prompt="First question?", model="local-llama")
request = ChatRequest(prompt="First question?", model="local-llama", working_directory=str(tmp_path))
response_text = "Here is the initial answer."
# Mimic the first tool invocation (no continuation_id supplied)

View File

@@ -91,6 +91,7 @@ def helper_function():
"prompt": "Analyze this codebase structure",
"files": [directory], # Directory path, not individual files
"model": "flash",
"working_directory": directory,
}
# Execute the tool
@@ -168,6 +169,7 @@ def helper_function():
"files": [directory], # Same directory again
"model": "flash",
"continuation_id": thread_id,
"working_directory": directory,
}
# Mock to capture file filtering behavior
@@ -299,6 +301,7 @@ def helper_function():
"prompt": "Analyze this code",
"files": [directory],
"model": "flash",
"working_directory": directory,
}
result = await tool.execute(request_args)

View File

@@ -56,7 +56,12 @@ class TestLargePromptHandling:
async def test_chat_large_prompt_detection(self, large_prompt):
"""Test that chat tool detects large prompts."""
tool = ChatTool()
result = await tool.execute({"prompt": large_prompt})
temp_dir = tempfile.mkdtemp()
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute({"prompt": large_prompt, "working_directory": temp_dir})
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
assert len(result) == 1
assert isinstance(result[0], TextContent)
@@ -73,9 +78,16 @@ class TestLargePromptHandling:
"""Test that chat tool works normally with regular prompts."""
tool = ChatTool()
temp_dir = tempfile.mkdtemp()
# This test runs in the test environment which uses dummy keys
# The chat tool will return an error for dummy keys, which is expected
result = await tool.execute({"prompt": normal_prompt, "model": "gemini-2.5-flash"})
try:
result = await tool.execute(
{"prompt": normal_prompt, "model": "gemini-2.5-flash", "working_directory": temp_dir}
)
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
assert len(result) == 1
output = json.loads(result[0].text)
@@ -105,7 +117,14 @@ class TestLargePromptHandling:
try:
# This test runs in the test environment which uses dummy keys
# The chat tool will return an error for dummy keys, which is expected
result = await tool.execute({"prompt": "", "files": [temp_prompt_file], "model": "gemini-2.5-flash"})
result = await tool.execute(
{
"prompt": "",
"files": [temp_prompt_file],
"model": "gemini-2.5-flash",
"working_directory": temp_dir,
}
)
assert len(result) == 1
output = json.loads(result[0].text)
@@ -261,7 +280,13 @@ class TestLargePromptHandling:
mock_prepare_files.return_value = ("File content", [other_file])
# Use a small prompt to avoid triggering size limit
await tool.execute({"prompt": "Test prompt", "files": [temp_prompt_file, other_file]})
await tool.execute(
{
"prompt": "Test prompt",
"files": [temp_prompt_file, other_file],
"working_directory": os.path.dirname(temp_prompt_file),
}
)
# Verify handle_prompt_file was called with the original files list
mock_handle_prompt.assert_called_once_with([temp_prompt_file, other_file])
@@ -295,7 +320,11 @@ class TestLargePromptHandling:
mock_get_provider.return_value = mock_provider
# With the fix, this should now pass because we check at MCP transport boundary before adding internal content
result = await tool.execute({"prompt": exact_prompt})
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute({"prompt": exact_prompt, "working_directory": temp_dir})
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
output = json.loads(result[0].text)
assert output["status"] != "resend_prompt"
@@ -305,7 +334,11 @@ class TestLargePromptHandling:
tool = ChatTool()
over_prompt = "x" * (MCP_PROMPT_SIZE_LIMIT + 1)
result = await tool.execute({"prompt": over_prompt})
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute({"prompt": over_prompt, "working_directory": temp_dir})
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
output = json.loads(result[0].text)
assert output["status"] == "resend_prompt"
@@ -326,7 +359,11 @@ class TestLargePromptHandling:
)
mock_get_provider.return_value = mock_provider
result = await tool.execute({"prompt": ""})
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute({"prompt": "", "working_directory": temp_dir})
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
output = json.loads(result[0].text)
assert output["status"] != "resend_prompt"
@@ -362,7 +399,11 @@ class TestLargePromptHandling:
mock_model_context_class.return_value = mock_model_context
# Should continue with empty prompt when file can't be read
result = await tool.execute({"prompt": "", "files": [bad_file]})
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute({"prompt": "", "files": [bad_file], "working_directory": temp_dir})
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
output = json.loads(result[0].text)
assert output["status"] != "resend_prompt"
@@ -408,6 +449,7 @@ class TestLargePromptHandling:
"prompt": "Summarize the design decisions",
"files": [str(large_file)],
"model": "flash",
"working_directory": str(tmp_path),
"_model_context": dummy_context,
}
)
@@ -424,6 +466,7 @@ class TestLargePromptHandling:
This test verifies that even if our internal prompt (with system prompts, history, etc.)
exceeds MCP_PROMPT_SIZE_LIMIT, it should still work as long as the user's input is small.
"""
tool = ChatTool()
# Small user input that should pass MCP boundary check
@@ -432,62 +475,57 @@ class TestLargePromptHandling:
# Mock a huge conversation history that would exceed MCP limits if incorrectly checked
huge_history = "x" * (MCP_PROMPT_SIZE_LIMIT * 2) # 100K chars = way over 50K limit
with (
patch.object(tool, "get_model_provider") as mock_get_provider,
patch("utils.model_context.ModelContext") as mock_model_context_class,
):
from tests.mock_helpers import create_mock_provider
temp_dir = tempfile.mkdtemp()
original_prepare_prompt = tool.prepare_prompt
mock_provider = create_mock_provider(model_name="flash")
mock_get_provider.return_value = mock_provider
try:
with (
patch.object(tool, "get_model_provider") as mock_get_provider,
patch("utils.model_context.ModelContext") as mock_model_context_class,
):
from tests.mock_helpers import create_mock_provider
from utils.model_context import TokenAllocation
# Mock ModelContext to avoid the comparison issue
from utils.model_context import TokenAllocation
mock_provider = create_mock_provider(model_name="flash")
mock_get_provider.return_value = mock_provider
mock_model_context = MagicMock()
mock_model_context.model_name = "flash"
mock_model_context.provider = mock_provider
mock_model_context.calculate_token_allocation.return_value = TokenAllocation(
total_tokens=1_048_576,
content_tokens=838_861,
response_tokens=209_715,
file_tokens=335_544,
history_tokens=335_544,
)
mock_model_context_class.return_value = mock_model_context
mock_model_context = MagicMock()
mock_model_context.model_name = "flash"
mock_model_context.provider = mock_provider
mock_model_context.calculate_token_allocation.return_value = TokenAllocation(
total_tokens=1_048_576,
content_tokens=838_861,
response_tokens=209_715,
file_tokens=335_544,
history_tokens=335_544,
)
mock_model_context_class.return_value = mock_model_context
# Mock the prepare_prompt to simulate huge internal context
original_prepare_prompt = tool.prepare_prompt
async def mock_prepare_prompt(request):
normal_prompt = await original_prepare_prompt(request)
huge_internal_prompt = f"{normal_prompt}\n\n=== HUGE INTERNAL CONTEXT ===\n{huge_history}"
assert len(huge_internal_prompt) > MCP_PROMPT_SIZE_LIMIT
return huge_internal_prompt
async def mock_prepare_prompt(request):
# Call original to get normal processing
normal_prompt = await original_prepare_prompt(request)
# Add huge internal context (simulating large history, system prompts, files)
huge_internal_prompt = f"{normal_prompt}\n\n=== HUGE INTERNAL CONTEXT ===\n{huge_history}"
tool.prepare_prompt = mock_prepare_prompt
# Verify the huge internal prompt would exceed MCP limits if incorrectly checked
assert len(huge_internal_prompt) > MCP_PROMPT_SIZE_LIMIT
result = await tool.execute(
{"prompt": small_user_prompt, "model": "flash", "working_directory": temp_dir}
)
output = json.loads(result[0].text)
return huge_internal_prompt
assert output["status"] != "resend_prompt"
tool.prepare_prompt = mock_prepare_prompt
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
actual_prompt = call_kwargs.get("prompt")
# This should succeed because we only check user input at MCP boundary
result = await tool.execute({"prompt": small_user_prompt, "model": "flash"})
output = json.loads(result[0].text)
# Should succeed even though internal context is huge
assert output["status"] != "resend_prompt"
# Verify the model was actually called with the huge prompt
mock_provider.generate_content.assert_called_once()
call_kwargs = mock_provider.generate_content.call_args[1]
actual_prompt = call_kwargs.get("prompt")
# Verify internal prompt was huge (proving we don't limit internal processing)
assert len(actual_prompt) > MCP_PROMPT_SIZE_LIMIT
assert huge_history in actual_prompt
assert small_user_prompt in actual_prompt
assert len(actual_prompt) > MCP_PROMPT_SIZE_LIMIT
assert huge_history in actual_prompt
assert small_user_prompt in actual_prompt
finally:
tool.prepare_prompt = original_prepare_prompt
shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.mark.asyncio
async def test_mcp_boundary_vs_internal_processing_distinction(self):
@@ -500,27 +538,37 @@ class TestLargePromptHandling:
# Test case 1: Large user input should fail at MCP boundary
large_user_input = "x" * (MCP_PROMPT_SIZE_LIMIT + 1000)
result = await tool.execute({"prompt": large_user_input, "model": "flash"})
output = json.loads(result[0].text)
assert output["status"] == "resend_prompt" # Should fail
assert "too large for MCP's token limits" in output["content"]
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute({"prompt": large_user_input, "model": "flash", "working_directory": temp_dir})
output = json.loads(result[0].text)
assert output["status"] == "resend_prompt" # Should fail
assert "too large for MCP's token limits" in output["content"]
# Test case 2: Small user input should succeed even with huge internal processing
small_user_input = "Hello"
# Test case 2: Small user input should succeed even with huge internal processing
small_user_input = "Hello"
# This test runs in the test environment which uses dummy keys
# The chat tool will return an error for dummy keys, which is expected
result = await tool.execute({"prompt": small_user_input, "model": "gemini-2.5-flash"})
output = json.loads(result[0].text)
# This test runs in the test environment which uses dummy keys
# The chat tool will return an error for dummy keys, which is expected
result = await tool.execute(
{
"prompt": small_user_input,
"model": "gemini-2.5-flash",
"working_directory": temp_dir,
}
)
output = json.loads(result[0].text)
# The test will fail with dummy API keys, which is expected behavior
# We're mainly testing that the tool processes small prompts correctly without size errors
if output["status"] == "error":
# If it's an API error, that's fine - we're testing prompt handling, not API calls
assert "API" in output["content"] or "key" in output["content"] or "authentication" in output["content"]
else:
# If somehow it succeeds (e.g., with mocked provider), check the response
assert output["status"] != "resend_prompt"
# The test will fail with dummy API keys, which is expected behavior
# We're mainly testing that the tool processes small prompts correctly without size errors
if output["status"] == "error":
# If it's an API error, that's fine - we're testing prompt handling, not API calls
assert "API" in output["content"] or "key" in output["content"] or "authentication" in output["content"]
else:
# If somehow it succeeds (e.g., with mocked provider), check the response
assert output["status"] != "resend_prompt"
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.mark.asyncio
async def test_continuation_with_huge_conversation_history(self):
@@ -548,6 +596,8 @@ class TestLargePromptHandling:
# Ensure the history exceeds MCP limits
assert len(huge_conversation_history) > MCP_PROMPT_SIZE_LIMIT
temp_dir = tempfile.mkdtemp()
with (
patch.object(tool, "get_model_provider") as mock_get_provider,
patch("utils.model_context.ModelContext") as mock_model_context_class,
@@ -579,6 +629,7 @@ class TestLargePromptHandling:
"prompt": f"{huge_conversation_history}\n\n=== CURRENT REQUEST ===\n{small_continuation_prompt}",
"model": "flash",
"continuation_id": "test_thread_123",
"working_directory": temp_dir,
}
# Mock the conversation history embedding to simulate server.py behavior
@@ -628,6 +679,7 @@ class TestLargePromptHandling:
finally:
# Restore original execute method
tool.__class__.execute = original_execute
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":

View File

@@ -68,6 +68,7 @@ class TestListModelsTool:
assert "`flash` → `gemini-2.5-flash`" in content
assert "`pro` → `gemini-2.5-pro`" in content
assert "1M context" in content
assert "Supports structured code generation" in content
# Check summary
assert "**Configured Providers**: 1" in content

View File

@@ -12,6 +12,7 @@ RECORDING: To record new responses, delete the cassette file and run with real A
import logging
import os
import tempfile
from pathlib import Path
from unittest.mock import patch
@@ -92,9 +93,15 @@ class TestO3ProOutputTextFix:
async def _execute_chat_tool_test(self):
"""Execute the ChatTool with o3-pro and return the result."""
chat_tool = ChatTool()
arguments = {"prompt": "What is 2 + 2?", "model": "o3-pro", "temperature": 1.0}
with tempfile.TemporaryDirectory() as workdir:
arguments = {
"prompt": "What is 2 + 2?",
"model": "o3-pro",
"temperature": 1.0,
"working_directory": workdir,
}
return await chat_tool.execute(arguments)
return await chat_tool.execute(arguments)
def _verify_chat_tool_response(self, result):
"""Verify the ChatTool response contains expected data."""

View File

@@ -4,6 +4,8 @@ Test per-tool model default selection functionality
import json
import os
import shutil
import tempfile
from unittest.mock import MagicMock, patch
import pytest
@@ -290,7 +292,13 @@ class TestAutoModeErrorMessages:
mock_get_provider_for.return_value = None
tool = ChatTool()
result = await tool.execute({"prompt": "test", "model": "auto"})
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute(
{"prompt": "test", "model": "auto", "working_directory": temp_dir}
)
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
assert len(result) == 1
# The SimpleTool will wrap the error message
@@ -418,7 +426,13 @@ class TestRuntimeModelSelection:
mock_get_provider.return_value = None
tool = ChatTool()
result = await tool.execute({"prompt": "test", "model": "gpt-5-turbo"})
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute(
{"prompt": "test", "model": "gpt-5-turbo", "working_directory": temp_dir}
)
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
# Should require model selection
assert len(result) == 1
@@ -515,7 +529,11 @@ class TestUnavailableModelFallback:
mock_get_model_provider.return_value = mock_provider
tool = ChatTool()
result = await tool.execute({"prompt": "test"}) # No model specified
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute({"prompt": "test", "working_directory": temp_dir})
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
# Should work normally, not require model parameter
assert len(result) == 1

View File

@@ -3,6 +3,8 @@ Tests for individual tool implementations
"""
import json
import shutil
import tempfile
import pytest
@@ -343,12 +345,17 @@ class TestAbsolutePathValidation:
async def test_chat_tool_relative_path_rejected(self):
"""Test that chat tool rejects relative paths"""
tool = ChatTool()
result = await tool.execute(
{
"prompt": "Explain this code",
"files": ["code.py"], # relative path without ./
}
)
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute(
{
"prompt": "Explain this code",
"files": ["code.py"], # relative path without ./
"working_directory": temp_dir,
}
)
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
assert len(result) == 1
response = json.loads(result[0].text)