feat: depending on the number of tools in use, this change should save ~50% of overall tokens used. fixes https://github.com/BeehiveInnovations/zen-mcp-server/issues/255 but also refactored individual tools to instead encourage the agent to use the listmodels tool if needed.

This commit is contained in:
Fahad
2025-10-01 21:40:31 +04:00
parent 5ff27f5b3e
commit d9449c7bb6
10 changed files with 74 additions and 262 deletions

View File

@@ -87,6 +87,10 @@ class ChatTool(SimpleTool):
the same schema generation approach while still benefiting from SimpleTool
convenience methods.
"""
required_fields = ["prompt"]
if self.is_effective_auto_mode():
required_fields.append("model")
schema = {
"type": "object",
"properties": {
@@ -121,7 +125,7 @@ class ChatTool(SimpleTool):
"description": COMMON_FIELD_DESCRIPTIONS["continuation_id"],
},
},
"required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []),
"required": required_fields,
}
return schema

View File

@@ -298,182 +298,30 @@ class BaseTool(ABC):
Returns:
Dict containing the model field JSON schema
"""
import os
from config import DEFAULT_MODEL
# Check if OpenRouter is configured
has_openrouter = bool(
os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here"
)
# Use the centralized effective auto mode check
if self.is_effective_auto_mode():
# In auto mode, model is required and we provide detailed descriptions
model_desc_parts = [
"IMPORTANT: Use the model specified by the user if provided, OR select the most suitable model "
"for this specific task based on the requirements and capabilities listed below:"
]
# Get descriptions from enabled providers
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
# Map provider types to readable names
provider_names = {
ProviderType.GOOGLE: "Gemini models",
ProviderType.OPENAI: "OpenAI models",
ProviderType.XAI: "X.AI GROK models",
ProviderType.DIAL: "DIAL models",
ProviderType.CUSTOM: "Custom models",
ProviderType.OPENROUTER: "OpenRouter models",
}
# Check available providers and add their model descriptions
# Start with native providers
for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]:
# Only if this is registered / available
provider = ModelProviderRegistry.get_provider(provider_type)
if provider:
provider_section_added = False
for model_name in provider.list_models(respect_restrictions=True):
try:
# Get model config to extract description
model_config = provider.SUPPORTED_MODELS.get(model_name)
if model_config and model_config.description:
if not provider_section_added:
model_desc_parts.append(
f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:"
)
provider_section_added = True
model_desc_parts.append(f"- '{model_name}': {model_config.description}")
except Exception:
# Skip models without descriptions
continue
# Add custom models if custom API is configured
custom_url = os.getenv("CUSTOM_API_URL")
if custom_url:
# Load custom models from registry
try:
registry = self._get_openrouter_registry()
model_desc_parts.append(f"\nCustom models via {custom_url}:")
# Find all custom models (is_custom=true)
for alias in registry.list_aliases():
config = registry.resolve(alias)
# Check if this is a custom model that requires custom endpoints
if config and config.is_custom:
# Format context window
context_tokens = config.context_window
if context_tokens >= 1_000_000:
context_str = f"{context_tokens // 1_000_000}M"
elif context_tokens >= 1_000:
context_str = f"{context_tokens // 1_000}K"
else:
context_str = str(context_tokens)
desc_line = f"- '{alias}' ({context_str} context): {config.description}"
if desc_line not in model_desc_parts: # Avoid duplicates
model_desc_parts.append(desc_line)
except Exception as e:
import logging
logging.debug(f"Failed to load custom model descriptions: {e}")
model_desc_parts.append(f"\nCustom models: Models available via {custom_url}")
if has_openrouter:
# Add OpenRouter models with descriptions
try:
import logging
registry = self._get_openrouter_registry()
# Group models by their model_name to avoid duplicates
seen_models = set()
model_configs = []
for alias in registry.list_aliases():
config = registry.resolve(alias)
if config and config.model_name not in seen_models:
seen_models.add(config.model_name)
model_configs.append((alias, config))
# Sort by context window (descending) then by alias
model_configs.sort(key=lambda x: (-x[1].context_window, x[0]))
if model_configs:
model_desc_parts.append("\nOpenRouter models (use these aliases):")
for alias, config in model_configs: # Show ALL models so the CLI can choose
# Format context window in human-readable form
context_tokens = config.context_window
if context_tokens >= 1_000_000:
context_str = f"{context_tokens // 1_000_000}M"
elif context_tokens >= 1_000:
context_str = f"{context_tokens // 1_000}K"
else:
context_str = str(context_tokens)
# Build description line
if config.description:
desc = f"- '{alias}' ({context_str} context): {config.description}"
else:
# Fallback to showing the model name if no description
desc = f"- '{alias}' ({context_str} context): {config.model_name}"
model_desc_parts.append(desc)
# Show all models - no truncation needed
except Exception as e:
# Log for debugging but don't fail
import logging
logging.debug(f"Failed to load OpenRouter model descriptions: {e}")
# Fallback to simple message
model_desc_parts.append(
"\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter."
)
# Get all available models for the enum
all_models = self._get_available_models()
return {
"type": "string",
"description": "\n".join(model_desc_parts),
"enum": all_models,
}
else:
# Normal mode - model is optional with default
available_models = self._get_available_models()
models_str = ", ".join(f"'{m}'" for m in available_models) # Show ALL models so the CLI can choose
description = f"Model to use. Native models: {models_str}."
if has_openrouter:
# Add OpenRouter aliases
try:
registry = self._get_openrouter_registry()
aliases = registry.list_aliases()
# Show ALL aliases from the configuration
if aliases:
# Show all aliases so the CLI knows every option available
all_aliases = sorted(aliases)
alias_list = ", ".join(f"'{a}'" for a in all_aliases)
description += f" OpenRouter aliases: {alias_list}."
else:
description += " OpenRouter: Any model available on openrouter.ai."
except Exception:
description += (
" OpenRouter: Any model available on openrouter.ai "
"(e.g., 'gpt-4', 'claude-4-opus', 'mistral-large')."
)
description += f" Defaults to '{DEFAULT_MODEL}' if not specified."
description = (
"Currently in auto model selection mode. IMPORTANT: Use the model specified by the user if provided, OR select the most suitable model by calling the "
"`listmodels` tool to obtain the full catalog with capabilities, aliases, and context window info."
)
return {
"type": "string",
"description": description,
}
description = (
f"The default model is '{DEFAULT_MODEL}'. Override by supplying another supported model name ONLY when requested by the user. "
"If needed, use the `listmodels` tool to obtain a full model catalog along with their capability details."
)
return {
"type": "string",
"description": description,
}
def get_default_temperature(self) -> float:
"""
Return the default temperature setting for this tool.

View File

@@ -58,6 +58,7 @@ class SchemaBuilder:
required_fields: list[str] = None,
model_field_schema: dict[str, Any] = None,
auto_mode: bool = False,
require_model: bool = False,
) -> dict[str, Any]:
"""
Build complete schema for simple tools.
@@ -88,8 +89,8 @@ class SchemaBuilder:
properties.update(tool_specific_fields)
# Build required fields list
required = required_fields or []
if auto_mode and "model" not in required:
required = list(required_fields) if required_fields else []
if (auto_mode or require_model) and "model" not in required:
required.append("model")
# Build the complete schema

View File

@@ -148,9 +148,10 @@ class SimpleTool(BaseTool):
Returns:
Complete JSON schema for the tool
"""
required_fields = list(self.get_required_fields())
return SchemaBuilder.build_schema(
tool_specific_fields=self.get_tool_fields(),
required_fields=self.get_required_fields(),
required_fields=required_fields,
model_field_schema=self.get_model_field_schema(),
auto_mode=self.is_effective_auto_mode(),
)

View File

@@ -93,6 +93,7 @@ class WorkflowSchemaBuilder:
tool_name: str = None,
excluded_workflow_fields: list[str] = None,
excluded_common_fields: list[str] = None,
require_model: bool = False,
) -> dict[str, Any]:
"""
Build complete schema for workflow tools.
@@ -142,7 +143,7 @@ class WorkflowSchemaBuilder:
required = standard_required + (required_fields or [])
if auto_mode and "model" not in required:
if (auto_mode or require_model) and "model" not in required:
required.append("model")
# Build the complete schema