Improved prompts to encourage better investigative flow
Improved abstraction
This commit is contained in:
138
tools/base.py
138
tools/base.py
@@ -1325,62 +1325,19 @@ When recommending searches, be specific about what information you need and why
|
||||
# Extract and validate images from request
|
||||
images = getattr(request, "images", None) or []
|
||||
|
||||
# MODEL RESOLUTION NOW HAPPENS AT MCP BOUNDARY
|
||||
# Extract pre-resolved model context from server.py
|
||||
model_context = self._current_arguments.get("_model_context")
|
||||
resolved_model_name = self._current_arguments.get("_resolved_model_name")
|
||||
# Use centralized model resolution
|
||||
try:
|
||||
model_name, model_context = self._resolve_model_context(self._current_arguments, request)
|
||||
except ValueError as e:
|
||||
# Model resolution failed, return error
|
||||
error_output = ToolOutput(
|
||||
status="error",
|
||||
content=str(e),
|
||||
content_type="text",
|
||||
)
|
||||
return [TextContent(type="text", text=error_output.model_dump_json())]
|
||||
|
||||
if model_context and resolved_model_name:
|
||||
# Model was already resolved at MCP boundary
|
||||
model_name = resolved_model_name
|
||||
logger.debug(f"Using pre-resolved model '{model_name}' from MCP boundary")
|
||||
else:
|
||||
# Fallback for direct execute calls
|
||||
model_name = getattr(request, "model", None)
|
||||
if not model_name:
|
||||
from config import DEFAULT_MODEL
|
||||
|
||||
model_name = DEFAULT_MODEL
|
||||
logger.debug(f"Using fallback model resolution for '{model_name}' (test mode)")
|
||||
|
||||
# For tests: Check if we should require model selection (auto mode)
|
||||
if self._should_require_model_selection(model_name):
|
||||
# Get suggested model based on tool category
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
tool_category = self.get_model_category()
|
||||
suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
|
||||
|
||||
# Build error message based on why selection is required
|
||||
if model_name.lower() == "auto":
|
||||
error_message = (
|
||||
f"Model parameter is required in auto mode. "
|
||||
f"Suggested model for {self.name}: '{suggested_model}' "
|
||||
f"(category: {tool_category.value})"
|
||||
)
|
||||
else:
|
||||
# Model was specified but not available
|
||||
available_models = self._get_available_models()
|
||||
|
||||
error_message = (
|
||||
f"Model '{model_name}' is not available with current API keys. "
|
||||
f"Available models: {', '.join(available_models)}. "
|
||||
f"Suggested model for {self.name}: '{suggested_model}' "
|
||||
f"(category: {tool_category.value})"
|
||||
)
|
||||
error_output = ToolOutput(
|
||||
status="error",
|
||||
content=error_message,
|
||||
content_type="text",
|
||||
)
|
||||
return [TextContent(type="text", text=error_output.model_dump_json())]
|
||||
|
||||
# Create model context for tests
|
||||
from utils.model_context import ModelContext
|
||||
|
||||
model_context = ModelContext(model_name)
|
||||
|
||||
# Store resolved model name for use by helper methods
|
||||
# Store resolved model name and context for use by helper methods
|
||||
self._current_model_name = model_name
|
||||
self._model_context = model_context
|
||||
|
||||
@@ -1929,6 +1886,77 @@ When recommending searches, be specific about what information you need and why
|
||||
logger.warning(f"Temperature validation failed for {model_name}: {e}")
|
||||
return temperature, [f"Temperature validation failed: {e}"]
|
||||
|
||||
def _resolve_model_context(self, arguments: dict[str, Any], request) -> tuple[str, Any]:
|
||||
"""
|
||||
Resolve model context and name using centralized logic.
|
||||
|
||||
This method extracts the model resolution logic from execute() so it can be
|
||||
reused by tools that override execute() (like debug tool) without duplicating code.
|
||||
|
||||
Args:
|
||||
arguments: Dictionary of arguments from the MCP client
|
||||
request: The validated request object
|
||||
|
||||
Returns:
|
||||
tuple[str, ModelContext]: (resolved_model_name, model_context)
|
||||
|
||||
Raises:
|
||||
ValueError: If model resolution fails or model selection is required
|
||||
"""
|
||||
logger = logging.getLogger(f"tools.{self.name}")
|
||||
|
||||
# MODEL RESOLUTION NOW HAPPENS AT MCP BOUNDARY
|
||||
# Extract pre-resolved model context from server.py
|
||||
model_context = arguments.get("_model_context")
|
||||
resolved_model_name = arguments.get("_resolved_model_name")
|
||||
|
||||
if model_context and resolved_model_name:
|
||||
# Model was already resolved at MCP boundary
|
||||
model_name = resolved_model_name
|
||||
logger.debug(f"Using pre-resolved model '{model_name}' from MCP boundary")
|
||||
else:
|
||||
# Fallback for direct execute calls
|
||||
model_name = getattr(request, "model", None)
|
||||
if not model_name:
|
||||
from config import DEFAULT_MODEL
|
||||
|
||||
model_name = DEFAULT_MODEL
|
||||
logger.debug(f"Using fallback model resolution for '{model_name}' (test mode)")
|
||||
|
||||
# For tests: Check if we should require model selection (auto mode)
|
||||
if self._should_require_model_selection(model_name):
|
||||
# Get suggested model based on tool category
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
tool_category = self.get_model_category()
|
||||
suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category)
|
||||
|
||||
# Build error message based on why selection is required
|
||||
if model_name.lower() == "auto":
|
||||
error_message = (
|
||||
f"Model parameter is required in auto mode. "
|
||||
f"Suggested model for {self.name}: '{suggested_model}' "
|
||||
f"(category: {tool_category.value})"
|
||||
)
|
||||
else:
|
||||
# Model was specified but not available
|
||||
available_models = self._get_available_models()
|
||||
|
||||
error_message = (
|
||||
f"Model '{model_name}' is not available with current API keys. "
|
||||
f"Available models: {', '.join(available_models)}. "
|
||||
f"Suggested model for {self.name}: '{suggested_model}' "
|
||||
f"(category: {tool_category.value})"
|
||||
)
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Create model context for tests
|
||||
from utils.model_context import ModelContext
|
||||
|
||||
model_context = ModelContext(model_name)
|
||||
|
||||
return model_name, model_context
|
||||
|
||||
def get_model_provider(self, model_name: str) -> ModelProvider:
|
||||
"""
|
||||
Get a model provider for the specified model.
|
||||
|
||||
Reference in New Issue
Block a user