feat!: Full code can now be generated by an external model and shared with the AI tool (Claude Code / Codex etc)!
Model definitions now support a new `allow_code_generation` flag, intended only for higher-reasoning models such as GPT-5-Pro and Gemini 2.5 Pro. When `true`, the `chat` tool can ask the external model to generate a complete implementation, update, or set of instructions, and then share that output with the calling agent. This effectively lets us use more powerful models such as GPT-5-Pro (which are otherwise API-only, or part of the $200 Pro plan from within the ChatGPT app) to generate code or entire implementations for us.
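As an illustration, a provider might advertise the capability roughly like this. This is a minimal sketch: `allow_code_generation`, `context_window`, `supports_extended_thinking`, and `aliases` are attributes referenced by this commit, while the remaining field names and all values are assumptions, not the real `ModelCapabilities` signature.

    # Hypothetical sketch -- the actual ModelCapabilities definition lives in
    # providers/shared and may require different or additional fields.
    from providers.shared import ModelCapabilities

    gpt5_pro = ModelCapabilities(
        model_name="gpt-5-pro",           # assumed field name and value
        context_window=400_000,           # assumed value, for illustration only
        supports_extended_thinking=True,  # attribute referenced in this commit
        allow_code_generation=True,       # the new flag gating structured code generation
        aliases=["gpt5pro"],              # assumed alias
    )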
tools/chat.py (186 lines changed)
@@ -6,15 +6,20 @@ brainstorming, problem-solving, and collaborative thinking. It supports file con
 images, and conversation continuation for seamless multi-turn interactions.
 """

 import logging
 import os
 import re
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional

 from pydantic import Field

 if TYPE_CHECKING:
     from providers.shared import ModelCapabilities
     from tools.models import ToolModelCategory

 from config import TEMPERATURE_BALANCED
-from systemprompts import CHAT_PROMPT
+from systemprompts import CHAT_PROMPT, GENERATE_CODE_PROMPT
 from tools.shared.base_models import COMMON_FIELD_DESCRIPTIONS, ToolRequest

 from .simple.base import SimpleTool
@@ -27,6 +32,9 @@ CHAT_FIELD_DESCRIPTIONS = {
     ),
     "files": "absolute file or folder paths for code context (do NOT shorten).",
     "images": "Optional absolute image paths or base64 for visual context when helpful.",
+    "working_directory": (
+        "Absolute full directory path where the assistant AI can save generated code for implementation. The directory must already exist."
+    ),
 }
@@ -36,6 +44,7 @@ class ChatRequest(ToolRequest):
     prompt: str = Field(..., description=CHAT_FIELD_DESCRIPTIONS["prompt"])
     files: Optional[list[str]] = Field(default_factory=list, description=CHAT_FIELD_DESCRIPTIONS["files"])
     images: Optional[list[str]] = Field(default_factory=list, description=CHAT_FIELD_DESCRIPTIONS["images"])
+    working_directory: str = Field(..., description=CHAT_FIELD_DESCRIPTIONS["working_directory"])


 class ChatTool(SimpleTool):
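A usage sketch of the updated request model: `working_directory` is now required alongside `prompt`, while the prompt text and paths below are purely illustrative.

    request = ChatRequest(
        prompt="Implement the CSV parser we discussed",
        files=["/abs/path/to/project/parser.py"],       # absolute paths only
        working_directory="/abs/path/to/project/out",   # existing absolute directory
    )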
@@ -49,6 +58,10 @@ class ChatTool(SimpleTool):
     Chat tool with 100% behavioral compatibility.
     """

+    def __init__(self) -> None:
+        super().__init__()
+        self._last_recordable_response: Optional[str] = None
+
     def get_name(self) -> str:
         return "chat"
@@ -58,9 +71,20 @@ class ChatTool(SimpleTool):
         "getting second opinions, and exploring ideas. Use for ideas, validations, questions, and thoughtful explanations."
     )

+    def get_annotations(self) -> Optional[dict[str, Any]]:
+        """Chat writes generated artifacts when code-generation is enabled."""
+
+        return {"readOnlyHint": False}
+
     def get_system_prompt(self) -> str:
         return CHAT_PROMPT

+    def get_capability_system_prompts(self, capabilities: Optional["ModelCapabilities"]) -> list[str]:
+        prompts = list(super().get_capability_system_prompts(capabilities))
+        if capabilities and capabilities.allow_code_generation:
+            prompts.append(GENERATE_CODE_PROMPT)
+        return prompts
+
     def get_default_temperature(self) -> float:
         return TEMPERATURE_BALANCED
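The extra prompt fragment is therefore opt-in per model. A minimal sketch of the gating, using a stand-in capabilities object rather than the real provider registry:

    class FakeCaps:                  # stand-in; real objects come from providers/shared
        allow_code_generation = True

    tool = ChatTool()
    tool.get_capability_system_prompts(FakeCaps())  # -> [GENERATE_CODE_PROMPT]
    tool.get_capability_system_prompts(None)        # -> [] (no fragment appended)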
@@ -85,7 +109,7 @@ class ChatTool(SimpleTool):
         the same schema generation approach while still benefiting from SimpleTool
         convenience methods.
         """
-        required_fields = ["prompt"]
+        required_fields = ["prompt", "working_directory"]
         if self.is_effective_auto_mode():
             required_fields.append("model")
@@ -106,6 +130,10 @@
                 "items": {"type": "string"},
                 "description": CHAT_FIELD_DESCRIPTIONS["images"],
             },
+            "working_directory": {
+                "type": "string",
+                "description": CHAT_FIELD_DESCRIPTIONS["working_directory"],
+            },
             "model": self.get_model_field_schema(),
             "temperature": {
                 "type": "number",
@@ -159,7 +187,7 @@

     def get_required_fields(self) -> list[str]:
         """Required fields for ChatSimple tool"""
-        return ["prompt"]
+        return ["prompt", "working_directory"]

     # === Hook Method Implementations ===
@@ -173,17 +201,165 @@
         # Use SimpleTool's Chat-style prompt preparation
         return self.prepare_chat_style_prompt(request)

+    def _validate_file_paths(self, request) -> Optional[str]:
+        """Extend validation to cover the working directory path."""
+
+        error = super()._validate_file_paths(request)
+        if error:
+            return error
+
+        working_directory = getattr(request, "working_directory", None)
+        if working_directory:
+            expanded = os.path.expanduser(working_directory)
+            if not os.path.isabs(expanded):
+                return (
+                    "Error: 'working_directory' must be an absolute path (you may use '~' which will be expanded). "
+                    f"Received: {working_directory}"
+                )
+        return None
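# --- Illustration, not part of the diff: a minimal sketch of the path rule
# --- enforced above (the sample paths are made up).
import os

for candidate in ("~/projects/out", "/tmp/out", "relative/out"):
    expanded = os.path.expanduser(candidate)
    print(candidate, "->", "accepted" if os.path.isabs(expanded) else "rejected")
# ~/projects/out -> accepted (expands to an absolute path under $HOME)
# /tmp/out       -> accepted
# relative/out   -> rejected (produces the error message above)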
     def format_response(self, response: str, request: ChatRequest, model_info: Optional[dict] = None) -> str:
         """
         Format the chat response to match the original Chat tool exactly.
         """
-        return (
-            f"{response}\n\n---\n\nAGENT'S TURN: Evaluate this perspective alongside your analysis to "
-            "form a comprehensive solution and continue with the user's request and task at hand."
-        )
+        self._last_recordable_response = None
+        body = response
+        recordable_override: Optional[str] = None
+
+        if self._model_supports_code_generation():
+            block, remainder = self._extract_generated_code_block(response)
+            if block:
+                sanitized_text = remainder.strip()
+                try:
+                    artifact_path = self._persist_generated_code_block(block, request.working_directory)
+                except Exception as exc:  # pragma: no cover - rare filesystem failures
+                    logger.error("Failed to persist generated code block: %s", exc, exc_info=True)
+                    warning = (
+                        f"WARNING: Unable to write zen_generated.code inside '{request.working_directory}'. "
+                        "Check the path permissions and re-run. The generated code block is included below for manual handling."
+                    )
+
+                    history_copy = self._join_sections(sanitized_text, warning) if sanitized_text else warning
+                    recordable_override = history_copy
+
+                    sanitized_warning = history_copy.strip()
+                    body = f"{sanitized_warning}\n\n{block.strip()}".strip()
+                else:
+                    if not sanitized_text:
+                        sanitized_text = "Generated code saved to zen_generated.code. Follow the structured instructions in that file exactly before continuing."
+
+                    instruction = self._build_agent_instruction(artifact_path)
+                    body = self._join_sections(sanitized_text, instruction)
+
+        final_output = (
+            f"{body}\n\n---\n\nAGENT'S TURN: Evaluate this perspective alongside your analysis to "
+            "form a comprehensive solution and continue with the user's request and task at hand."
+        )
+
+        if recordable_override is not None:
+            self._last_recordable_response = (
+                f"{recordable_override}\n\n---\n\nAGENT'S TURN: Evaluate this perspective alongside your analysis to "
+                "form a comprehensive solution and continue with the user's request and task at hand."
+            )
+        else:
+            self._last_recordable_response = final_output
+
+        return final_output
+
+    def _record_assistant_turn(
+        self, continuation_id: str, response_text: str, request, model_info: Optional[dict]
+    ) -> None:
+        recordable = self._last_recordable_response if self._last_recordable_response is not None else response_text
+        try:
+            super()._record_assistant_turn(continuation_id, recordable, request, model_info)
+        finally:
+            self._last_recordable_response = None
+
+    def _model_supports_code_generation(self) -> bool:
+        context = getattr(self, "_model_context", None)
+        if not context:
+            return False
+
+        try:
+            capabilities = context.capabilities
+        except Exception:  # pragma: no cover - defensive fallback
+            return False
+
+        return bool(capabilities.allow_code_generation)
+    def _extract_generated_code_block(self, text: str) -> tuple[Optional[str], str]:
+        match = re.search(r"<GENERATED-CODE>.*?</GENERATED-CODE>", text, flags=re.DOTALL | re.IGNORECASE)
+        if not match:
+            return None, text
+
+        block = match.group(0)
+        before = text[: match.start()].rstrip()
+        after = text[match.end() :].lstrip()
+
+        if before and after:
+            remainder = f"{before}\n\n{after}"
+        else:
+            remainder = before or after
+
+        return block, remainder or ""
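# --- Illustration, not part of the diff: how the extraction above splits a
# --- reply (the reply text is invented; only the regex comes from the commit).
import re

reply = (
    "Here is the plan.\n"
    "<GENERATED-CODE>...full implementation block...</GENERATED-CODE>\n"
    "Ping me if anything is unclear."
)
match = re.search(r"<GENERATED-CODE>.*?</GENERATED-CODE>", reply, flags=re.DOTALL | re.IGNORECASE)
block = match.group(0)  # the tagged block, destined for zen_generated.code
remainder = f"{reply[: match.start()].rstrip()}\n\n{reply[match.end():].lstrip()}"
# remainder == "Here is the plan.\n\nPing me if anything is unclear."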
+    def _persist_generated_code_block(self, block: str, working_directory: str) -> Path:
+        expanded = os.path.expanduser(working_directory)
+        target_dir = Path(expanded).resolve()
+        target_dir.mkdir(parents=True, exist_ok=True)
+
+        target_file = target_dir / "zen_generated.code"
+        if target_file.exists():
+            try:
+                target_file.unlink()
+            except OSError as exc:
+                logger.warning("Unable to remove existing zen_generated.code: %s", exc)
+
+        content = block if block.endswith("\n") else f"{block}\n"
+        target_file.write_text(content, encoding="utf-8")
+        logger.info("Generated code artifact written to %s", target_file)
+        return target_file
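# --- Illustration, not part of the diff: the artifact name is fixed, so a
# --- second generation into the same directory replaces the first.
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    artifact = Path(tmp) / "zen_generated.code"
    artifact.write_text("<GENERATED-CODE>v1</GENERATED-CODE>\n", encoding="utf-8")
    artifact.write_text("<GENERATED-CODE>v2</GENERATED-CODE>\n", encoding="utf-8")
    assert "v2" in artifact.read_text(encoding="utf-8")  # old content replaced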
+    @staticmethod
+    def _build_agent_instruction(artifact_path: Path) -> str:
+        return (
+            f"CONTINUING FROM PREVIOUS DISCUSSION: The coding assistant has analyzed our conversation context and generated "
+            f"a structured implementation plan at `{artifact_path}`. This is a direct continuation of our discussion—all previous "
+            "context, requirements, and shared code remain relevant.\n"
+            "\n"
+            f"MANDATORY NEXT STEP: Open `{artifact_path}` immediately and review the implementation plan:\n"
+            "1. Read the step-by-step instructions—they reference our previous discussion. You may need to read the file in parts if it's too long.\n"
+            "2. Review each <NEWFILE:…> or <UPDATED_EXISTING_FILE:…> section in the context of what we've discussed\n"
+            "3. Verify the proposed changes align with the requirements and code we've already shared\n"
+            "4. Check for syntax errors, missing imports, or incomplete implementations\n"
+            "\n"
+            "Then systematically apply the changes:\n"
+            "- Create new files or update existing ones as instructed, maintaining code style consistency\n"
+            "- If updating existing code we discussed earlier, carefully preserve unmodified sections\n"
+            "- Run syntax validation after each modification\n"
+            "- Execute relevant tests to confirm functionality\n"
+            "- Verify the implementation works end-to-end with existing code\n"
+            "\n"
+            "Remember: This builds upon our conversation. The generated code reflects the full context of what we've discussed, "
+            "including any files, requirements, or constraints mentioned earlier. Proceed with implementation immediately.\n"
+            "\n"
+            "Only after you finish applying ALL the changes completely: delete `zen_generated.code` so stale instructions do not linger."
+        )
+
+    @staticmethod
+    def _join_sections(*sections: str) -> str:
+        chunks: list[str] = []
+        for section in sections:
+            if section:
+                trimmed = section.strip()
+                if trimmed:
+                    chunks.append(trimmed)
+        return "\n\n".join(chunks)

     def get_websearch_guidance(self) -> str:
         """
         Return Chat tool-style web search guidance.
         """
         return self.get_chat_style_websearch_guidance()

+
+logger = logging.getLogger(__name__)
@@ -140,6 +140,8 @@ class ListModelsTool(BaseTool):
         except AttributeError:
             description = "No description available"
         lines = [header, f" - {context_str}", f" - {description}"]
+        if capabilities.allow_code_generation:
+            lines.append(" - Supports structured code generation")
         return lines

     # Check each native provider type
@@ -187,6 +189,8 @@ class ListModelsTool(BaseTool):

             output_lines.append(f"- `{model_name}` - {context_str}")
             output_lines.append(f"  - {description}")
+            if capabilities.allow_code_generation:
+                output_lines.append("  - Supports structured code generation")

             for alias in capabilities.aliases or []:
                 if alias != model_name:
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Optional
 from mcp.types import TextContent

 if TYPE_CHECKING:
+    from providers.shared import ModelCapabilities
     from tools.models import ToolModelCategory

 from config import MCP_PROMPT_SIZE_LIMIT
@@ -165,6 +166,42 @@ class BaseTool(ABC):
         """
         pass

+    def get_capability_system_prompts(self, capabilities: Optional["ModelCapabilities"]) -> list[str]:
+        """Return additional system prompt snippets gated on model capabilities.
+
+        Subclasses can override this hook to append capability-specific
+        instructions (for example, enabling code-generation exports when a
+        model advertises support). The default implementation returns an empty
+        list so no extra instructions are appended.
+
+        Args:
+            capabilities: The resolved capabilities for the active model.
+
+        Returns:
+            List of prompt fragments to append after the base system prompt.
+        """
+
+        return []
+
+    def _augment_system_prompt_with_capabilities(
+        self, base_prompt: str, capabilities: Optional["ModelCapabilities"]
+    ) -> str:
+        """Merge capability-driven prompt addenda with the base system prompt."""
+
+        additions: list[str] = []
+        if capabilities is not None:
+            additions = [fragment.strip() for fragment in self.get_capability_system_prompts(capabilities) if fragment]
+
+        if not additions:
+            return base_prompt
+
+        addition_text = "\n\n".join(additions)
+        if not base_prompt:
+            return addition_text
+
+        suffix = "" if base_prompt.endswith("\n\n") else "\n\n"
+        return f"{base_prompt}{suffix}{addition_text}"
+
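# --- Illustration, not part of the diff: the pure string rules used above,
# --- with made-up base prompt and fragments.
base_prompt = "You are a helpful assistant."
additions = ["Fragment A.", "Fragment B."]  # as returned by the hook
addition_text = "\n\n".join(additions)
suffix = "" if base_prompt.endswith("\n\n") else "\n\n"
merged = f"{base_prompt}{suffix}{addition_text}"
# merged == "You are a helpful assistant.\n\nFragment A.\n\nFragment B."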
     def get_annotations(self) -> Optional[dict[str, Any]]:
         """
         Return optional annotations for this tool.
@@ -413,13 +450,16 @@
         for rank, canonical_name, capabilities in filtered[:limit]:
             details: list[str] = []

-            context_str = self._format_context_window(getattr(capabilities, "context_window", 0))
+            context_str = self._format_context_window(capabilities.context_window)
             if context_str:
                 details.append(context_str)

-            if getattr(capabilities, "supports_extended_thinking", False):
+            if capabilities.supports_extended_thinking:
                 details.append("thinking")

+            if capabilities.allow_code_generation:
+                details.append("code-gen")
+
             base = f"{canonical_name} (score {rank}"
             if details:
                 base = f"{base}, {', '.join(details)}"
@@ -404,11 +404,15 @@ class SimpleTool(BaseTool):

         # Get the provider from model context (clean OOP - no re-fetching)
         provider = self._model_context.provider
+        capabilities = self._model_context.capabilities

         # Get system prompt for this tool
         base_system_prompt = self.get_system_prompt()
+        capability_augmented_prompt = self._augment_system_prompt_with_capabilities(
+            base_system_prompt, capabilities
+        )
         language_instruction = self.get_language_instruction()
-        system_prompt = language_instruction + base_system_prompt
+        system_prompt = language_instruction + capability_augmented_prompt

         # Generate AI response using the provider
         logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.get_name()}")
@@ -423,7 +427,6 @@
             logger.debug(f"Prompt length: {len(prompt)} characters (~{estimated_tokens:,} tokens)")

         # Resolve model capabilities for feature gating
-        capabilities = self._model_context.capabilities
         supports_thinking = capabilities.supports_extended_thinking

         # Generate content with provider abstraction
@@ -1480,8 +1480,11 @@ class BaseWorkflowMixin(ABC):

         # Get system prompt for this tool with localization support
         base_system_prompt = self.get_system_prompt()
+        capability_augmented_prompt = self._augment_system_prompt_with_capabilities(
+            base_system_prompt, getattr(self._model_context, "capabilities", None)
+        )
         language_instruction = self.get_language_instruction()
-        system_prompt = language_instruction + base_system_prompt
+        system_prompt = language_instruction + capability_augmented_prompt

         # Check if tool wants system prompt embedded in main prompt
         if self.should_embed_system_prompt():
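Taken together, the pieces above give the following runtime flow for a code-generation-capable model. This is a simplified sketch of the orchestration (real execution goes through `SimpleTool.execute()`), not additional code from the commit:

    # 1. The chat schema now requires `working_directory`, validated by
    #    _validate_file_paths() as an absolute (or '~'-expanded) path.
    # 2. When the resolved model's capabilities set allow_code_generation,
    #    get_capability_system_prompts() appends GENERATE_CODE_PROMPT, and
    #    _augment_system_prompt_with_capabilities() merges it after CHAT_PROMPT.
    # 3. format_response() extracts any <GENERATED-CODE> block from the reply,
    #    persists it to <working_directory>/zen_generated.code, and swaps the
    #    block for _build_agent_instruction(), which tells the calling agent to
    #    apply the plan and delete the artifact when done.
    # 4. _record_assistant_turn() stores the sanitized text in conversation
    #    history, so the raw code block never bloats the continuation context.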