Improved prompts to encourage better investigative flow

Improved abstraction
Fahad
2025-06-19 10:56:39 +04:00
parent fccfb0d999
commit 43485dadd6
3 changed files with 231 additions and 99 deletions

View File

@@ -512,3 +512,91 @@ class TestDebugToolIntegration:
        assert parsed_response["status"] == "investigation_in_progress"
        assert parsed_response["step_number"] == 1
        assert parsed_response["continuation_id"] == "debug-flow-uuid"

    @pytest.mark.asyncio
    async def test_model_context_initialization_in_expert_analysis(self):
        """Real integration test that model context is properly initialized when expert analysis is called."""
        tool = DebugIssueTool()

        # Do NOT manually set up model context - let the method do it itself
        # Set up investigation state for final step
        tool.initial_issue = "Memory leak investigation"
        tool.investigation_history = [
            {
                "step_number": 1,
                "step": "Initial investigation",
                "findings": "Found memory issues",
                "files_checked": [],
            }
        ]
        tool.consolidated_findings = {
            "files_checked": set(),
            "relevant_files": set(),  # No files to avoid file I/O in this test
            "relevant_methods": {"process_data"},
            "findings": ["Step 1: Found memory issues"],
            "hypotheses": [],
            "images": [],
        }

        # Test the _call_expert_analysis method directly to verify ModelContext is properly handled
        # This is the real test - we're testing that the method can be called without the ModelContext error
        try:
            # Only mock the API call itself, not the model resolution infrastructure
            from unittest.mock import MagicMock

            mock_provider = MagicMock()
            mock_response = MagicMock()
            mock_response.content = '{"status": "analysis_complete", "summary": "Test completed"}'
            mock_provider.generate_content.return_value = mock_response

            # Use the real get_model_provider method but override its result to avoid API calls
            original_get_provider = tool.get_model_provider
            tool.get_model_provider = lambda model_name: mock_provider

            try:
                # Create mock arguments and request for model resolution
                from tools.debug import DebugInvestigationRequest

                mock_arguments = {"model": None}  # No model specified, should fall back to DEFAULT_MODEL
                mock_request = DebugInvestigationRequest(
                    step="Test step",
                    step_number=1,
                    total_steps=1,
                    next_step_required=False,
                    findings="Test findings",
                )

                # This should NOT raise a ModelContext error - the method should set up context itself
                result = await tool._call_expert_analysis(
                    initial_issue="Test issue",
                    investigation_summary="Test summary",
                    relevant_files=[],  # Empty to avoid file operations
                    relevant_methods=["test_method"],
                    final_hypothesis="Test hypothesis",
                    error_context=None,
                    images=[],
                    model_info=None,  # No pre-resolved model info
                    arguments=mock_arguments,  # Provide arguments for model resolution
                    request=mock_request,  # Provide request for model resolution
                )

                # Should complete without ModelContext error
                assert "error" not in result
                assert result["status"] == "analysis_complete"

                # Verify the model context was actually set up
                assert hasattr(tool, "_model_context")
                assert hasattr(tool, "_current_model_name")

                # Should use DEFAULT_MODEL when no model specified
                from config import DEFAULT_MODEL

                assert tool._current_model_name == DEFAULT_MODEL
            finally:
                # Restore original method
                tool.get_model_provider = original_get_provider
        except RuntimeError as e:
            if "ModelContext not initialized" in str(e):
                pytest.fail("ModelContext error still occurs - the fix is not working properly")
            else:
                raise  # Re-raise other RuntimeErrors
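
A side note on the stubbing pattern above: the manual save/restore of get_model_provider works, but unittest.mock.patch.object performs the same override and restores the original automatically, even if an assertion fails first. A minimal sketch reusing the names from the test above; the _stub_provider helper is illustrative and not part of this commit:

from unittest.mock import MagicMock, patch


def _stub_provider(tool):
    """Context manager that swaps tool.get_model_provider for a canned provider (illustrative helper)."""
    mock_provider = MagicMock()
    mock_provider.generate_content.return_value = MagicMock(
        content='{"status": "analysis_complete", "summary": "Test completed"}'
    )
    # patch.object restores the original attribute on exit, even when the body raises.
    return patch.object(tool, "get_model_provider", lambda model_name: mock_provider)


# Usage inside the test body:
#     with _stub_provider(tool):
#         result = await tool._call_expert_analysis(...)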

View File

@@ -1325,62 +1325,19 @@ When recommending searches, be specific about what information you need and why
        # Extract and validate images from request
        images = getattr(request, "images", None) or []

        # MODEL RESOLUTION NOW HAPPENS AT MCP BOUNDARY
        # Extract pre-resolved model context from server.py
        model_context = self._current_arguments.get("_model_context")
        resolved_model_name = self._current_arguments.get("_resolved_model_name")

        if model_context and resolved_model_name:
            # Model was already resolved at MCP boundary
            model_name = resolved_model_name
            logger.debug(f"Using pre-resolved model '{model_name}' from MCP boundary")
        else:
            # Fallback for direct execute calls
            model_name = getattr(request, "model", None)
            if not model_name:
                from config import DEFAULT_MODEL

                model_name = DEFAULT_MODEL
            logger.debug(f"Using fallback model resolution for '{model_name}' (test mode)")

            # For tests: Check if we should require model selection (auto mode)
            if self._should_require_model_selection(model_name):
                # Get suggested model based on tool category
                from providers.registry import ModelProviderRegistry

                tool_category = self.get_model_category()
                suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)

                # Build error message based on why selection is required
                if model_name.lower() == "auto":
                    error_message = (
                        f"Model parameter is required in auto mode. "
                        f"Suggested model for {self.name}: '{suggested_model}' "
                        f"(category: {tool_category.value})"
                    )
                else:
                    # Model was specified but not available
                    available_models = self._get_available_models()
                    error_message = (
                        f"Model '{model_name}' is not available with current API keys. "
                        f"Available models: {', '.join(available_models)}. "
                        f"Suggested model for {self.name}: '{suggested_model}' "
                        f"(category: {tool_category.value})"
                    )

        # Use centralized model resolution
        try:
            model_name, model_context = self._resolve_model_context(self._current_arguments, request)
        except ValueError as e:
            # Model resolution failed, return error
            error_output = ToolOutput(
                status="error",
                content=error_message,
                content=str(e),
                content_type="text",
            )
            return [TextContent(type="text", text=error_output.model_dump_json())]

            # Create model context for tests
            from utils.model_context import ModelContext

            model_context = ModelContext(model_name)

        # Store resolved model name for use by helper methods
        # Store resolved model name and context for use by helper methods
        self._current_model_name = model_name
        self._model_context = model_context
@@ -1929,6 +1886,77 @@ When recommending searches, be specific about what information you need and why
            logger.warning(f"Temperature validation failed for {model_name}: {e}")
            return temperature, [f"Temperature validation failed: {e}"]

    def _resolve_model_context(self, arguments: dict[str, Any], request) -> tuple[str, Any]:
        """
        Resolve model context and name using centralized logic.

        This method extracts the model resolution logic from execute() so it can be
        reused by tools that override execute() (like debug tool) without duplicating code.

        Args:
            arguments: Dictionary of arguments from the MCP client
            request: The validated request object

        Returns:
            tuple[str, ModelContext]: (resolved_model_name, model_context)

        Raises:
            ValueError: If model resolution fails or model selection is required
        """
        logger = logging.getLogger(f"tools.{self.name}")

        # MODEL RESOLUTION NOW HAPPENS AT MCP BOUNDARY
        # Extract pre-resolved model context from server.py
        model_context = arguments.get("_model_context")
        resolved_model_name = arguments.get("_resolved_model_name")

        if model_context and resolved_model_name:
            # Model was already resolved at MCP boundary
            model_name = resolved_model_name
            logger.debug(f"Using pre-resolved model '{model_name}' from MCP boundary")
        else:
            # Fallback for direct execute calls
            model_name = getattr(request, "model", None)
            if not model_name:
                from config import DEFAULT_MODEL

                model_name = DEFAULT_MODEL
            logger.debug(f"Using fallback model resolution for '{model_name}' (test mode)")

            # For tests: Check if we should require model selection (auto mode)
            if self._should_require_model_selection(model_name):
                # Get suggested model based on tool category
                from providers.registry import ModelProviderRegistry

                tool_category = self.get_model_category()
                suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)

                # Build error message based on why selection is required
                if model_name.lower() == "auto":
                    error_message = (
                        f"Model parameter is required in auto mode. "
                        f"Suggested model for {self.name}: '{suggested_model}' "
                        f"(category: {tool_category.value})"
                    )
                else:
                    # Model was specified but not available
                    available_models = self._get_available_models()
                    error_message = (
                        f"Model '{model_name}' is not available with current API keys. "
                        f"Available models: {', '.join(available_models)}. "
                        f"Suggested model for {self.name}: '{suggested_model}' "
                        f"(category: {tool_category.value})"
                    )

                raise ValueError(error_message)

            # Create model context for tests
            from utils.model_context import ModelContext

            model_context = ModelContext(model_name)

        return model_name, model_context
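
Both branches above converge on the same (model_name, model_context) pair, which the caller is expected to store before generating content. A rough sketch of the fallback path for a tool that overrides execute(); it reuses the tool and mock_request objects from the test in the first file, so treat it as an illustration rather than code from this commit:

# No pre-resolved "_model_context"/"_resolved_model_name" keys in the arguments,
# so the helper falls back to request.model and finally DEFAULT_MODEL.
model_name, model_context = tool._resolve_model_context({"model": None}, mock_request)
tool._current_model_name = model_name  # mirrors what execute() stores above
tool._model_context = model_context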
    def get_model_provider(self, model_name: str) -> ModelProvider:
        """
        Get a model provider for the specified model.

View File

@@ -21,52 +21,48 @@ logger = logging.getLogger(__name__)
# Field descriptions for the investigation steps
DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS = {
    "step": (
        "Your current investigation step. For the first step, describe the issue/error to investigate. "
        "For subsequent steps, describe what you're investigating, what code you're examining, "
        "what patterns you're looking for, or what hypothesis you're testing."
        "Describe what you're currently investigating. In step 1, clearly state the issue to investigate and begin "
        "thinking deeply about where the problem might originate. In all subsequent steps, continue uncovering relevant "
        "code, examining patterns, and formulating hypotheses with deliberate attention to detail."
    ),
    "step_number": "Current step number in the investigation sequence (starts at 1)",
    "total_steps": "Current estimate of total investigation steps needed (can be adjusted as investigation progresses)",
    "next_step_required": "Whether another investigation step is required",
    "step_number": "Current step number in the investigation sequence (starts at 1).",
    "total_steps": "Estimate of total investigation steps expected (adjustable as the process evolves).",
    "next_step_required": "Whether another investigation step is needed after this one.",
    "findings": (
        "Current findings from this investigation step. Include code patterns discovered, "
        "potential causes identified, hypotheses formed, or evidence gathered."
        "Summarize discoveries in this step. Think critically and include relevant code behavior, suspicious patterns, "
        "evidence collected, and any partial conclusions or leads."
    ),
    "files_checked": (
        "List of files you've examined so far in the investigation (cumulative list). "
        "Include all files you've looked at, even if they turned out to be irrelevant."
        "List all files examined during the investigation so far. Include even files ruled out, as this tracks your exploration path."
    ),
    "relevant_files": (
        "List of files that are definitely related to the issue (subset of files_checked). "
        "Only include files that contain code directly related to the problem."
        "Subset of files_checked that contain code directly relevant to the issue. Only list those that are directly tied to the root cause or its effects."
    ),
    "relevant_methods": (
        "List of specific methods/functions that are involved in the issue. "
        "Format: 'ClassName.methodName' or 'functionName'"
        "List specific methods/functions clearly tied to the issue. Use 'ClassName.methodName' or 'functionName' format."
    ),
    "hypothesis": (
        "Your current working hypothesis about the root cause. This can be updated/revised "
        "as the investigation progresses."
        "Formulate your current best guess about the underlying cause. This is a working theory and may evolve based on further evidence."
    ),
    "confidence": "Your confidence level in the current hypothesis: 'low', 'medium', or 'high'",
    "backtrack_from_step": "If you need to revise a previous finding, which step number to backtrack from",
    "continuation_id": "Thread continuation ID for multi-turn investigation sessions",
    "confidence": "How confident you are in the current hypothesis: 'low', 'medium', or 'high'.",
    "backtrack_from_step": "If a previous step needs revision, specify the step number to backtrack from.",
    "continuation_id": "Continuation token used for linking multi-step investigations.",
    "images": (
        "Optional images showing error screens, UI issues, logs displays, or visual debugging information "
        "that help understand the issue (must be FULL absolute paths - DO NOT SHORTEN)"
        "Optional. Include full absolute paths to visual debugging images (UI issues, logs, error screens) that help clarify the issue."
    ),
}
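
These descriptions are meant to feed the request model's field metadata. The model definition itself is outside this hunk, so the following is only a sketch of the usual wiring, assuming a Pydantic BaseModel; the DebugInvestigationRequest name and its required fields are taken from the test in the first file, the rest is assumed:

from typing import Optional

from pydantic import BaseModel, Field


class DebugInvestigationRequest(BaseModel):
    # Sketch only - the real model lives in tools/debug.py and carries more fields.
    step: str = Field(..., description=DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS["step"])
    step_number: int = Field(..., description=DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS["step_number"])
    total_steps: int = Field(..., description=DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS["total_steps"])
    next_step_required: bool = Field(..., description=DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS["next_step_required"])
    findings: str = Field(..., description=DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS["findings"])
    hypothesis: Optional[str] = Field(None, description=DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS["hypothesis"])
    confidence: Optional[str] = Field(None, description=DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS["confidence"])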
# Field descriptions for the final debug request
DEBUG_FIELD_DESCRIPTIONS = {
    "initial_issue": "The original issue description that started the investigation",
    "investigation_summary": "Complete summary of the systematic investigation performed",
    "findings": "Consolidated findings from all investigation steps",
    "files": "Essential files identified during investigation (must be FULL absolute paths - DO NOT SHORTEN)",
    "error_context": "Stack trace, logs, or error context discovered during investigation",
    "relevant_methods": "List of methods/functions identified as involved in the issue",
    "hypothesis": "Final hypothesis about the root cause after investigation",
    "images": "Optional images showing error screens, UI issues, or visual debugging information",
    "initial_issue": "Describe the original problem that triggered the investigation.",
    "investigation_summary": (
        "Full overview of the systematic investigation process. Reflect deep thinking and each step's contribution to narrowing down the issue."
    ),
    "findings": "Final list of critical insights and discoveries across all steps.",
    "files": "Essential files referenced during investigation (must be full absolute paths).",
    "error_context": "Logs, tracebacks, or execution details that support the root cause hypothesis.",
    "relevant_methods": "List of all methods/functions identified as directly involved.",
    "hypothesis": "Final, most likely explanation of the root cause based on evidence.",
    "images": "Optional screenshots or visual materials that helped diagnose the issue.",
}
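
For orientation, this is roughly the shape of the final-analysis payload those descriptions document; every value below is illustrative and not taken from the commit:

example_final_request = {
    "initial_issue": "Background worker leaks memory over long runs",
    "investigation_summary": "Traced allocations from the worker loop down to the cache layer over three steps",
    "findings": "Cache entries are added per request but never evicted",
    "files": ["/abs/path/to/worker/cache.py"],  # full absolute paths, as required above
    "error_context": "MemoryError raised after several hours of sustained load",
    "relevant_methods": ["RequestCache.store"],
    "hypothesis": "Unbounded growth in RequestCache.store is the root cause",
    "images": [],
}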
@@ -268,7 +264,9 @@ class DebugIssueTool(BaseTool):
        # Create thread for first step
        if not continuation_id and request.step_number == 1:
            continuation_id = create_thread("debug", arguments)
            # Clean arguments to remove non-serializable fields
            clean_args = {k: v for k, v in arguments.items() if k not in ["_model_context", "_resolved_model_name"]}
            continuation_id = create_thread("debug", clean_args)

            # Store initial issue description
            self.initial_issue = request.step
@@ -356,8 +354,9 @@ class DebugIssueTool(BaseTool):
                final_hypothesis=request.hypothesis,
                error_context=self._extract_error_context(),
                images=list(set(self.consolidated_findings["images"])),  # Unique images
                model_info=arguments.get("_model_context"),
                model_override=arguments.get("model"),  # Pass model selection from final step
                model_info=arguments.get("_model_context"),  # Use pre-resolved model context from server.py
                arguments=arguments,  # Pass arguments for model resolution
                request=request,  # Pass request for model resolution
            )

            # Combine investigation and expert analysis
@@ -478,9 +477,36 @@ class DebugIssueTool(BaseTool):
        error_context: Optional[str],
        images: list[str],
        model_info: Optional[Any] = None,
        model_override: Optional[str] = None,
        arguments: Optional[dict] = None,
        request: Optional[Any] = None,
    ) -> dict:
        """Call AI model for expert analysis of the investigation"""

        # Set up model context when we actually need it for expert analysis
        # Use the same model resolution logic as the base class
        if model_info:
            # Use pre-resolved model context from server.py (normal case)
            self._model_context = model_info
            model_name = model_info.model_name
        else:
            # Use centralized model resolution from base class
            if arguments and request:
                try:
                    model_name, model_context = self._resolve_model_context(arguments, request)
                    self._model_context = model_context
                except ValueError as e:
                    # Model resolution failed, return error
                    return {"error": f"Model resolution failed: {str(e)}", "status": "model_resolution_error"}
            else:
                # Last resort fallback if no arguments/request provided
                from config import DEFAULT_MODEL
                from utils.model_context import ModelContext

                model_name = DEFAULT_MODEL
                self._model_context = ModelContext(model_name)

        # Store model name for use by other methods
        self._current_model_name = model_name
        provider = self.get_model_provider(model_name)

        # Prepare the debug prompt with all investigation context
        prompt_parts = [
            f"=== ISSUE DESCRIPTION ===\n{initial_issue}\n=== END DESCRIPTION ===",
@@ -517,16 +543,6 @@ class DebugIssueTool(BaseTool):
        full_prompt = "\n".join(prompt_parts)

        # Get appropriate model and provider
        from config import DEFAULT_MODEL
        from providers.registry import ModelProviderRegistry

        model_name = model_override or DEFAULT_MODEL  # Use override if provided
        provider = ModelProviderRegistry.get_provider_for_model(model_name)
        if not provider:
            return {"error": f"No provider available for model {model_name}", "status": "provider_error"}

        # Generate AI response
        try:
            full_analysis_prompt = f"{self.get_system_prompt()}\n\n{full_prompt}\n\nPlease debug this issue following the structured format in the system prompt."