From d9449c7bb607caff3f0454f210ddfc36256c738a Mon Sep 17 00:00:00 2001 From: Fahad Date: Wed, 1 Oct 2025 21:40:31 +0400 Subject: [PATCH] feat: depending on the number of tools in use, this change should save ~50% of overall tokens used. fixes https://github.com/BeehiveInnovations/zen-mcp-server/issues/255 but also refactored individual tools to instead encourage the agent to use the listmodels tool if needed. --- tests/test_auto_mode.py | 34 ++--- tests/test_auto_mode_comprehensive.py | 86 ++++-------- tests/test_chat_simple.py | 13 +- tests/test_model_enumeration.py | 4 +- tests/test_per_tool_model_defaults.py | 2 +- tools/chat.py | 6 +- tools/shared/base_tool.py | 180 ++------------------------ tools/shared/schema_builders.py | 5 +- tools/simple/base.py | 3 +- tools/workflow/schema_builders.py | 3 +- 10 files changed, 74 insertions(+), 262 deletions(-) diff --git a/tests/test_auto_mode.py b/tests/test_auto_mode.py index f96feb3..e30e8b6 100644 --- a/tests/test_auto_mode.py +++ b/tests/test_auto_mode.py @@ -92,9 +92,9 @@ class TestAutoMode: # Model field should have detailed descriptions model_schema = schema["properties"]["model"] - assert "enum" in model_schema - assert "flash" in model_schema["enum"] - assert "select the most suitable model" in model_schema["description"] + assert "enum" not in model_schema + assert "auto model selection" in model_schema["description"].lower() + assert "listmodels" in model_schema["description"] finally: # Restore @@ -111,14 +111,14 @@ class TestAutoMode: tool = ChatTool() schema = tool.get_input_schema() - # Model should not be required + # Model should not be required when default model is configured assert "model" not in schema["required"] # Model field should have simpler description model_schema = schema["properties"]["model"] assert "enum" not in model_schema - assert "Native models:" in model_schema["description"] - assert "Defaults to" in model_schema["description"] + assert "listmodels" in model_schema["description"] + assert "default model" in model_schema["description"].lower() @pytest.mark.asyncio async def test_auto_mode_requires_model_parameter(self): @@ -287,19 +287,10 @@ class TestAutoMode: importlib.reload(config) schema = tool.get_model_field_schema() - assert "enum" in schema - # Test that some basic models are available (those that should be available with dummy keys) - available_models = schema["enum"] - # Check for models that should be available with basic provider setup - expected_basic_models = ["flash", "pro"] # Gemini models from conftest.py - for model in expected_basic_models: - if model not in available_models: - print(f"Missing expected model: {model}") - print(f"Available models: {available_models}") - assert any( - model in available_models for model in expected_basic_models - ), f"None of {expected_basic_models} found in {available_models}" - assert "select the most suitable model" in schema["description"] + assert "enum" not in schema + assert schema["type"] == "string" + assert "auto model selection" in schema["description"] + assert "listmodels" in schema["description"] # Test normal mode os.environ["DEFAULT_MODEL"] = "pro" @@ -307,10 +298,9 @@ class TestAutoMode: schema = tool.get_model_field_schema() assert "enum" not in schema - # Check for the new schema format - assert "Model to use." in schema["description"] + assert schema["type"] == "string" assert "'pro'" in schema["description"] - assert "Defaults to" in schema["description"] + assert "listmodels" in schema["description"] finally: # Restore diff --git a/tests/test_auto_mode_comprehensive.py b/tests/test_auto_mode_comprehensive.py index d4736f0..a68db41 100644 --- a/tests/test_auto_mode_comprehensive.py +++ b/tests/test_auto_mode_comprehensive.py @@ -291,58 +291,28 @@ class TestAutoModeComprehensive: # Should have model as required field assert "model" in schema["required"] - # Should include all model options from global config + # In auto mode, the schema should now have a description field + # instructing users to use the listmodels tool instead of an enum model_schema = schema["properties"]["model"] - assert "enum" in model_schema + assert "type" in model_schema + assert model_schema["type"] == "string" + assert "description" in model_schema - available_models = model_schema["enum"] + # Check that the description mentions using listmodels tool + description = model_schema["description"] + assert "listmodels" in description.lower() + assert "auto" in description.lower() or "selection" in description.lower() - # Should include Gemini models - assert "flash" in available_models - assert "pro" in available_models - assert "gemini-2.5-flash" in available_models - assert "gemini-2.5-pro" in available_models + # Should NOT have enum field anymore - this is the new behavior + assert "enum" not in model_schema - # After the fix, schema only shows models from enabled providers - # This prevents model namespace collisions and misleading users - # If only Gemini is configured, only Gemini models should appear - provider_count = len( - [ - key - for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"] - if os.getenv(key) and os.getenv(key) != f"your_{key.lower()}_here" - ] - ) + # After the design change, the system directs users to use listmodels + # instead of enumerating all models in the schema + # This prevents model namespace collisions and keeps the schema cleaner - if provider_count == 1 and os.getenv("GEMINI_API_KEY"): - # Only Gemini configured - should only show Gemini models - non_gemini_models = [ - m - for m in available_models - if not m.startswith("gemini") - and m - not in [ - "flash", - "pro", - "flash-2.0", - "flash2", - "flashlite", - "flash-lite", - "flash2.5", - "gemini pro", - "gemini-pro", - ] - ] - assert ( - len(non_gemini_models) == 0 - ), f"Found non-Gemini models when only Gemini configured: {non_gemini_models}" - else: - # Multiple providers or OpenRouter - should include various models - # Only check if models are available if their providers might be configured - if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"): - assert any("o3" in m or "o4" in m for m in available_models), "No OpenAI models found" - if os.getenv("XAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"): - assert any("grok" in m for m in available_models), "No XAI models found" + # With the new design change, we no longer enumerate models in the schema + # The listmodels tool should be used to discover available models + # This test now validates the schema structure rather than model enumeration def test_auto_mode_schema_with_all_providers(self): """Test that auto mode schema includes models from all available providers.""" @@ -380,21 +350,21 @@ class TestAutoModeComprehensive: tool = AnalyzeTool() schema = tool.get_input_schema() + # In auto mode with multiple providers, should still use the new schema format model_schema = schema["properties"]["model"] - available_models = model_schema["enum"] + assert "type" in model_schema + assert model_schema["type"] == "string" + assert "description" in model_schema - # Should include models from all providers - # Gemini models - assert "flash" in available_models - assert "pro" in available_models + # Check that the description mentions using listmodels tool + description = model_schema["description"] + assert "listmodels" in description.lower() - # OpenAI models - assert "o3" in available_models - assert "o4-mini" in available_models + # Should NOT have enum field - uses listmodels tool instead + assert "enum" not in model_schema - # XAI models - assert "grok" in available_models - assert "grok-3" in available_models + # With multiple providers configured, the listmodels tool + # would show models from all providers when called @pytest.mark.asyncio async def test_auto_mode_model_parameter_required_error(self): diff --git a/tests/test_chat_simple.py b/tests/test_chat_simple.py index 34064be..d6fbf19 100644 --- a/tests/test_chat_simple.py +++ b/tests/test_chat_simple.py @@ -84,15 +84,14 @@ class TestChatTool: assert schema["type"] == "string" assert "description" in schema - # In auto mode, should have enum. In normal mode, should have model descriptions + # Description should route callers to listmodels, regardless of mode + assert "listmodels" in schema["description"] if self.tool.is_effective_auto_mode(): - assert "enum" in schema - assert len(schema["enum"]) > 0 - assert "IMPORTANT:" in schema["description"] + assert "auto model selection" in schema["description"] else: - # Normal mode - should have model descriptions in description - assert "Model to use" in schema["description"] - assert "Native models:" in schema["description"] + import config + + assert f"'{config.DEFAULT_MODEL}'" in schema["description"] @pytest.mark.asyncio async def test_prompt_preparation(self): diff --git a/tests/test_model_enumeration.py b/tests/test_model_enumeration.py index ef30b56..6dc390b 100644 --- a/tests/test_model_enumeration.py +++ b/tests/test_model_enumeration.py @@ -130,9 +130,7 @@ class TestModelEnumeration: models = tool._get_available_models() for alias in ("local-llama", "llama3.2"): - assert ( - alias not in models - ), f"Custom model alias '{alias}' should remain hidden without CUSTOM_API_URL" + assert alias not in models, f"Custom model alias '{alias}' should remain hidden without CUSTOM_API_URL" def test_no_duplicates_with_overlapping_providers(self): """Test that models aren't duplicated when multiple providers offer the same model.""" diff --git a/tests/test_per_tool_model_defaults.py b/tests/test_per_tool_model_defaults.py index 167df88..1099dbf 100644 --- a/tests/test_per_tool_model_defaults.py +++ b/tests/test_per_tool_model_defaults.py @@ -465,7 +465,7 @@ class TestSchemaGeneration: tool = ThinkDeepTool() schema = tool.get_input_schema() - # Model should NOT be required + # Model should remain optional when DEFAULT_MODEL is available assert "model" not in schema["required"] diff --git a/tools/chat.py b/tools/chat.py index dca2c60..3854561 100644 --- a/tools/chat.py +++ b/tools/chat.py @@ -87,6 +87,10 @@ class ChatTool(SimpleTool): the same schema generation approach while still benefiting from SimpleTool convenience methods. """ + required_fields = ["prompt"] + if self.is_effective_auto_mode(): + required_fields.append("model") + schema = { "type": "object", "properties": { @@ -121,7 +125,7 @@ class ChatTool(SimpleTool): "description": COMMON_FIELD_DESCRIPTIONS["continuation_id"], }, }, - "required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []), + "required": required_fields, } return schema diff --git a/tools/shared/base_tool.py b/tools/shared/base_tool.py index 87bb1b4..adb77f1 100644 --- a/tools/shared/base_tool.py +++ b/tools/shared/base_tool.py @@ -298,182 +298,30 @@ class BaseTool(ABC): Returns: Dict containing the model field JSON schema """ - import os from config import DEFAULT_MODEL - # Check if OpenRouter is configured - has_openrouter = bool( - os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here" - ) - # Use the centralized effective auto mode check if self.is_effective_auto_mode(): - # In auto mode, model is required and we provide detailed descriptions - model_desc_parts = [ - "IMPORTANT: Use the model specified by the user if provided, OR select the most suitable model " - "for this specific task based on the requirements and capabilities listed below:" - ] - - # Get descriptions from enabled providers - from providers.base import ProviderType - from providers.registry import ModelProviderRegistry - - # Map provider types to readable names - provider_names = { - ProviderType.GOOGLE: "Gemini models", - ProviderType.OPENAI: "OpenAI models", - ProviderType.XAI: "X.AI GROK models", - ProviderType.DIAL: "DIAL models", - ProviderType.CUSTOM: "Custom models", - ProviderType.OPENROUTER: "OpenRouter models", - } - - # Check available providers and add their model descriptions - - # Start with native providers - for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]: - # Only if this is registered / available - provider = ModelProviderRegistry.get_provider(provider_type) - if provider: - provider_section_added = False - for model_name in provider.list_models(respect_restrictions=True): - try: - # Get model config to extract description - model_config = provider.SUPPORTED_MODELS.get(model_name) - if model_config and model_config.description: - if not provider_section_added: - model_desc_parts.append( - f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:" - ) - provider_section_added = True - model_desc_parts.append(f"- '{model_name}': {model_config.description}") - except Exception: - # Skip models without descriptions - continue - - # Add custom models if custom API is configured - custom_url = os.getenv("CUSTOM_API_URL") - if custom_url: - # Load custom models from registry - try: - registry = self._get_openrouter_registry() - model_desc_parts.append(f"\nCustom models via {custom_url}:") - - # Find all custom models (is_custom=true) - for alias in registry.list_aliases(): - config = registry.resolve(alias) - # Check if this is a custom model that requires custom endpoints - if config and config.is_custom: - # Format context window - context_tokens = config.context_window - if context_tokens >= 1_000_000: - context_str = f"{context_tokens // 1_000_000}M" - elif context_tokens >= 1_000: - context_str = f"{context_tokens // 1_000}K" - else: - context_str = str(context_tokens) - - desc_line = f"- '{alias}' ({context_str} context): {config.description}" - if desc_line not in model_desc_parts: # Avoid duplicates - model_desc_parts.append(desc_line) - except Exception as e: - import logging - - logging.debug(f"Failed to load custom model descriptions: {e}") - model_desc_parts.append(f"\nCustom models: Models available via {custom_url}") - - if has_openrouter: - # Add OpenRouter models with descriptions - try: - import logging - - registry = self._get_openrouter_registry() - - # Group models by their model_name to avoid duplicates - seen_models = set() - model_configs = [] - - for alias in registry.list_aliases(): - config = registry.resolve(alias) - if config and config.model_name not in seen_models: - seen_models.add(config.model_name) - model_configs.append((alias, config)) - - # Sort by context window (descending) then by alias - model_configs.sort(key=lambda x: (-x[1].context_window, x[0])) - - if model_configs: - model_desc_parts.append("\nOpenRouter models (use these aliases):") - for alias, config in model_configs: # Show ALL models so the CLI can choose - # Format context window in human-readable form - context_tokens = config.context_window - if context_tokens >= 1_000_000: - context_str = f"{context_tokens // 1_000_000}M" - elif context_tokens >= 1_000: - context_str = f"{context_tokens // 1_000}K" - else: - context_str = str(context_tokens) - - # Build description line - if config.description: - desc = f"- '{alias}' ({context_str} context): {config.description}" - else: - # Fallback to showing the model name if no description - desc = f"- '{alias}' ({context_str} context): {config.model_name}" - model_desc_parts.append(desc) - - # Show all models - no truncation needed - except Exception as e: - # Log for debugging but don't fail - import logging - - logging.debug(f"Failed to load OpenRouter model descriptions: {e}") - # Fallback to simple message - model_desc_parts.append( - "\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter." - ) - - # Get all available models for the enum - all_models = self._get_available_models() - - return { - "type": "string", - "description": "\n".join(model_desc_parts), - "enum": all_models, - } - else: - # Normal mode - model is optional with default - available_models = self._get_available_models() - models_str = ", ".join(f"'{m}'" for m in available_models) # Show ALL models so the CLI can choose - - description = f"Model to use. Native models: {models_str}." - if has_openrouter: - # Add OpenRouter aliases - try: - registry = self._get_openrouter_registry() - aliases = registry.list_aliases() - - # Show ALL aliases from the configuration - if aliases: - # Show all aliases so the CLI knows every option available - all_aliases = sorted(aliases) - alias_list = ", ".join(f"'{a}'" for a in all_aliases) - description += f" OpenRouter aliases: {alias_list}." - else: - description += " OpenRouter: Any model available on openrouter.ai." - except Exception: - description += ( - " OpenRouter: Any model available on openrouter.ai " - "(e.g., 'gpt-4', 'claude-4-opus', 'mistral-large')." - ) - description += f" Defaults to '{DEFAULT_MODEL}' if not specified." - + description = ( + "Currently in auto model selection mode. IMPORTANT: Use the model specified by the user if provided, OR select the most suitable model by calling the " + "`listmodels` tool to obtain the full catalog with capabilities, aliases, and context window info." + ) return { "type": "string", "description": description, } + description = ( + f"The default model is '{DEFAULT_MODEL}'. Override by supplying another supported model name ONLY when requested by the user. " + "If needed, use the `listmodels` tool to obtain a full model catalog along with their capability details." + ) + + return { + "type": "string", + "description": description, + } + def get_default_temperature(self) -> float: """ Return the default temperature setting for this tool. diff --git a/tools/shared/schema_builders.py b/tools/shared/schema_builders.py index 6b319d8..dd0146c 100644 --- a/tools/shared/schema_builders.py +++ b/tools/shared/schema_builders.py @@ -58,6 +58,7 @@ class SchemaBuilder: required_fields: list[str] = None, model_field_schema: dict[str, Any] = None, auto_mode: bool = False, + require_model: bool = False, ) -> dict[str, Any]: """ Build complete schema for simple tools. @@ -88,8 +89,8 @@ class SchemaBuilder: properties.update(tool_specific_fields) # Build required fields list - required = required_fields or [] - if auto_mode and "model" not in required: + required = list(required_fields) if required_fields else [] + if (auto_mode or require_model) and "model" not in required: required.append("model") # Build the complete schema diff --git a/tools/simple/base.py b/tools/simple/base.py index ad369fa..b3dc611 100644 --- a/tools/simple/base.py +++ b/tools/simple/base.py @@ -148,9 +148,10 @@ class SimpleTool(BaseTool): Returns: Complete JSON schema for the tool """ + required_fields = list(self.get_required_fields()) return SchemaBuilder.build_schema( tool_specific_fields=self.get_tool_fields(), - required_fields=self.get_required_fields(), + required_fields=required_fields, model_field_schema=self.get_model_field_schema(), auto_mode=self.is_effective_auto_mode(), ) diff --git a/tools/workflow/schema_builders.py b/tools/workflow/schema_builders.py index 7858fc8..4ae1e27 100644 --- a/tools/workflow/schema_builders.py +++ b/tools/workflow/schema_builders.py @@ -93,6 +93,7 @@ class WorkflowSchemaBuilder: tool_name: str = None, excluded_workflow_fields: list[str] = None, excluded_common_fields: list[str] = None, + require_model: bool = False, ) -> dict[str, Any]: """ Build complete schema for workflow tools. @@ -142,7 +143,7 @@ class WorkflowSchemaBuilder: required = standard_required + (required_fields or []) - if auto_mode and "model" not in required: + if (auto_mode or require_model) and "model" not in required: required.append("model") # Build the complete schema