feat: depending on the number of tools in use, this change should save ~50% of overall tokens used. fixes https://github.com/BeehiveInnovations/zen-mcp-server/issues/255 but also refactored individual tools to instead encourage the agent to use the listmodels tool if needed.
This commit is contained in:
@@ -92,9 +92,9 @@ class TestAutoMode:
|
|||||||
|
|
||||||
# Model field should have detailed descriptions
|
# Model field should have detailed descriptions
|
||||||
model_schema = schema["properties"]["model"]
|
model_schema = schema["properties"]["model"]
|
||||||
assert "enum" in model_schema
|
assert "enum" not in model_schema
|
||||||
assert "flash" in model_schema["enum"]
|
assert "auto model selection" in model_schema["description"].lower()
|
||||||
assert "select the most suitable model" in model_schema["description"]
|
assert "listmodels" in model_schema["description"]
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore
|
# Restore
|
||||||
@@ -111,14 +111,14 @@ class TestAutoMode:
|
|||||||
tool = ChatTool()
|
tool = ChatTool()
|
||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
|
|
||||||
# Model should not be required
|
# Model should not be required when default model is configured
|
||||||
assert "model" not in schema["required"]
|
assert "model" not in schema["required"]
|
||||||
|
|
||||||
# Model field should have simpler description
|
# Model field should have simpler description
|
||||||
model_schema = schema["properties"]["model"]
|
model_schema = schema["properties"]["model"]
|
||||||
assert "enum" not in model_schema
|
assert "enum" not in model_schema
|
||||||
assert "Native models:" in model_schema["description"]
|
assert "listmodels" in model_schema["description"]
|
||||||
assert "Defaults to" in model_schema["description"]
|
assert "default model" in model_schema["description"].lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_auto_mode_requires_model_parameter(self):
|
async def test_auto_mode_requires_model_parameter(self):
|
||||||
@@ -287,19 +287,10 @@ class TestAutoMode:
|
|||||||
importlib.reload(config)
|
importlib.reload(config)
|
||||||
|
|
||||||
schema = tool.get_model_field_schema()
|
schema = tool.get_model_field_schema()
|
||||||
assert "enum" in schema
|
assert "enum" not in schema
|
||||||
# Test that some basic models are available (those that should be available with dummy keys)
|
assert schema["type"] == "string"
|
||||||
available_models = schema["enum"]
|
assert "auto model selection" in schema["description"]
|
||||||
# Check for models that should be available with basic provider setup
|
assert "listmodels" in schema["description"]
|
||||||
expected_basic_models = ["flash", "pro"] # Gemini models from conftest.py
|
|
||||||
for model in expected_basic_models:
|
|
||||||
if model not in available_models:
|
|
||||||
print(f"Missing expected model: {model}")
|
|
||||||
print(f"Available models: {available_models}")
|
|
||||||
assert any(
|
|
||||||
model in available_models for model in expected_basic_models
|
|
||||||
), f"None of {expected_basic_models} found in {available_models}"
|
|
||||||
assert "select the most suitable model" in schema["description"]
|
|
||||||
|
|
||||||
# Test normal mode
|
# Test normal mode
|
||||||
os.environ["DEFAULT_MODEL"] = "pro"
|
os.environ["DEFAULT_MODEL"] = "pro"
|
||||||
@@ -307,10 +298,9 @@ class TestAutoMode:
|
|||||||
|
|
||||||
schema = tool.get_model_field_schema()
|
schema = tool.get_model_field_schema()
|
||||||
assert "enum" not in schema
|
assert "enum" not in schema
|
||||||
# Check for the new schema format
|
assert schema["type"] == "string"
|
||||||
assert "Model to use." in schema["description"]
|
|
||||||
assert "'pro'" in schema["description"]
|
assert "'pro'" in schema["description"]
|
||||||
assert "Defaults to" in schema["description"]
|
assert "listmodels" in schema["description"]
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore
|
# Restore
|
||||||
|
|||||||
@@ -291,58 +291,28 @@ class TestAutoModeComprehensive:
|
|||||||
# Should have model as required field
|
# Should have model as required field
|
||||||
assert "model" in schema["required"]
|
assert "model" in schema["required"]
|
||||||
|
|
||||||
# Should include all model options from global config
|
# In auto mode, the schema should now have a description field
|
||||||
|
# instructing users to use the listmodels tool instead of an enum
|
||||||
model_schema = schema["properties"]["model"]
|
model_schema = schema["properties"]["model"]
|
||||||
assert "enum" in model_schema
|
assert "type" in model_schema
|
||||||
|
assert model_schema["type"] == "string"
|
||||||
|
assert "description" in model_schema
|
||||||
|
|
||||||
available_models = model_schema["enum"]
|
# Check that the description mentions using listmodels tool
|
||||||
|
description = model_schema["description"]
|
||||||
|
assert "listmodels" in description.lower()
|
||||||
|
assert "auto" in description.lower() or "selection" in description.lower()
|
||||||
|
|
||||||
# Should include Gemini models
|
# Should NOT have enum field anymore - this is the new behavior
|
||||||
assert "flash" in available_models
|
assert "enum" not in model_schema
|
||||||
assert "pro" in available_models
|
|
||||||
assert "gemini-2.5-flash" in available_models
|
|
||||||
assert "gemini-2.5-pro" in available_models
|
|
||||||
|
|
||||||
# After the fix, schema only shows models from enabled providers
|
# After the design change, the system directs users to use listmodels
|
||||||
# This prevents model namespace collisions and misleading users
|
# instead of enumerating all models in the schema
|
||||||
# If only Gemini is configured, only Gemini models should appear
|
# This prevents model namespace collisions and keeps the schema cleaner
|
||||||
provider_count = len(
|
|
||||||
[
|
|
||||||
key
|
|
||||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]
|
|
||||||
if os.getenv(key) and os.getenv(key) != f"your_{key.lower()}_here"
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if provider_count == 1 and os.getenv("GEMINI_API_KEY"):
|
# With the new design change, we no longer enumerate models in the schema
|
||||||
# Only Gemini configured - should only show Gemini models
|
# The listmodels tool should be used to discover available models
|
||||||
non_gemini_models = [
|
# This test now validates the schema structure rather than model enumeration
|
||||||
m
|
|
||||||
for m in available_models
|
|
||||||
if not m.startswith("gemini")
|
|
||||||
and m
|
|
||||||
not in [
|
|
||||||
"flash",
|
|
||||||
"pro",
|
|
||||||
"flash-2.0",
|
|
||||||
"flash2",
|
|
||||||
"flashlite",
|
|
||||||
"flash-lite",
|
|
||||||
"flash2.5",
|
|
||||||
"gemini pro",
|
|
||||||
"gemini-pro",
|
|
||||||
]
|
|
||||||
]
|
|
||||||
assert (
|
|
||||||
len(non_gemini_models) == 0
|
|
||||||
), f"Found non-Gemini models when only Gemini configured: {non_gemini_models}"
|
|
||||||
else:
|
|
||||||
# Multiple providers or OpenRouter - should include various models
|
|
||||||
# Only check if models are available if their providers might be configured
|
|
||||||
if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
|
|
||||||
assert any("o3" in m or "o4" in m for m in available_models), "No OpenAI models found"
|
|
||||||
if os.getenv("XAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
|
|
||||||
assert any("grok" in m for m in available_models), "No XAI models found"
|
|
||||||
|
|
||||||
def test_auto_mode_schema_with_all_providers(self):
|
def test_auto_mode_schema_with_all_providers(self):
|
||||||
"""Test that auto mode schema includes models from all available providers."""
|
"""Test that auto mode schema includes models from all available providers."""
|
||||||
@@ -380,21 +350,21 @@ class TestAutoModeComprehensive:
|
|||||||
tool = AnalyzeTool()
|
tool = AnalyzeTool()
|
||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
|
|
||||||
|
# In auto mode with multiple providers, should still use the new schema format
|
||||||
model_schema = schema["properties"]["model"]
|
model_schema = schema["properties"]["model"]
|
||||||
available_models = model_schema["enum"]
|
assert "type" in model_schema
|
||||||
|
assert model_schema["type"] == "string"
|
||||||
|
assert "description" in model_schema
|
||||||
|
|
||||||
# Should include models from all providers
|
# Check that the description mentions using listmodels tool
|
||||||
# Gemini models
|
description = model_schema["description"]
|
||||||
assert "flash" in available_models
|
assert "listmodels" in description.lower()
|
||||||
assert "pro" in available_models
|
|
||||||
|
|
||||||
# OpenAI models
|
# Should NOT have enum field - uses listmodels tool instead
|
||||||
assert "o3" in available_models
|
assert "enum" not in model_schema
|
||||||
assert "o4-mini" in available_models
|
|
||||||
|
|
||||||
# XAI models
|
# With multiple providers configured, the listmodels tool
|
||||||
assert "grok" in available_models
|
# would show models from all providers when called
|
||||||
assert "grok-3" in available_models
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_auto_mode_model_parameter_required_error(self):
|
async def test_auto_mode_model_parameter_required_error(self):
|
||||||
|
|||||||
@@ -84,15 +84,14 @@ class TestChatTool:
|
|||||||
assert schema["type"] == "string"
|
assert schema["type"] == "string"
|
||||||
assert "description" in schema
|
assert "description" in schema
|
||||||
|
|
||||||
# In auto mode, should have enum. In normal mode, should have model descriptions
|
# Description should route callers to listmodels, regardless of mode
|
||||||
|
assert "listmodels" in schema["description"]
|
||||||
if self.tool.is_effective_auto_mode():
|
if self.tool.is_effective_auto_mode():
|
||||||
assert "enum" in schema
|
assert "auto model selection" in schema["description"]
|
||||||
assert len(schema["enum"]) > 0
|
|
||||||
assert "IMPORTANT:" in schema["description"]
|
|
||||||
else:
|
else:
|
||||||
# Normal mode - should have model descriptions in description
|
import config
|
||||||
assert "Model to use" in schema["description"]
|
|
||||||
assert "Native models:" in schema["description"]
|
assert f"'{config.DEFAULT_MODEL}'" in schema["description"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_prompt_preparation(self):
|
async def test_prompt_preparation(self):
|
||||||
|
|||||||
@@ -130,9 +130,7 @@ class TestModelEnumeration:
|
|||||||
models = tool._get_available_models()
|
models = tool._get_available_models()
|
||||||
|
|
||||||
for alias in ("local-llama", "llama3.2"):
|
for alias in ("local-llama", "llama3.2"):
|
||||||
assert (
|
assert alias not in models, f"Custom model alias '{alias}' should remain hidden without CUSTOM_API_URL"
|
||||||
alias not in models
|
|
||||||
), f"Custom model alias '{alias}' should remain hidden without CUSTOM_API_URL"
|
|
||||||
|
|
||||||
def test_no_duplicates_with_overlapping_providers(self):
|
def test_no_duplicates_with_overlapping_providers(self):
|
||||||
"""Test that models aren't duplicated when multiple providers offer the same model."""
|
"""Test that models aren't duplicated when multiple providers offer the same model."""
|
||||||
|
|||||||
@@ -465,7 +465,7 @@ class TestSchemaGeneration:
|
|||||||
tool = ThinkDeepTool()
|
tool = ThinkDeepTool()
|
||||||
schema = tool.get_input_schema()
|
schema = tool.get_input_schema()
|
||||||
|
|
||||||
# Model should NOT be required
|
# Model should remain optional when DEFAULT_MODEL is available
|
||||||
assert "model" not in schema["required"]
|
assert "model" not in schema["required"]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -87,6 +87,10 @@ class ChatTool(SimpleTool):
|
|||||||
the same schema generation approach while still benefiting from SimpleTool
|
the same schema generation approach while still benefiting from SimpleTool
|
||||||
convenience methods.
|
convenience methods.
|
||||||
"""
|
"""
|
||||||
|
required_fields = ["prompt"]
|
||||||
|
if self.is_effective_auto_mode():
|
||||||
|
required_fields.append("model")
|
||||||
|
|
||||||
schema = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -121,7 +125,7 @@ class ChatTool(SimpleTool):
|
|||||||
"description": COMMON_FIELD_DESCRIPTIONS["continuation_id"],
|
"description": COMMON_FIELD_DESCRIPTIONS["continuation_id"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["prompt"] + (["model"] if self.is_effective_auto_mode() else []),
|
"required": required_fields,
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|||||||
@@ -298,182 +298,30 @@ class BaseTool(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict containing the model field JSON schema
|
Dict containing the model field JSON schema
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
|
|
||||||
from config import DEFAULT_MODEL
|
from config import DEFAULT_MODEL
|
||||||
|
|
||||||
# Check if OpenRouter is configured
|
|
||||||
has_openrouter = bool(
|
|
||||||
os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY") != "your_openrouter_api_key_here"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use the centralized effective auto mode check
|
# Use the centralized effective auto mode check
|
||||||
if self.is_effective_auto_mode():
|
if self.is_effective_auto_mode():
|
||||||
# In auto mode, model is required and we provide detailed descriptions
|
description = (
|
||||||
model_desc_parts = [
|
"Currently in auto model selection mode. IMPORTANT: Use the model specified by the user if provided, OR select the most suitable model by calling the "
|
||||||
"IMPORTANT: Use the model specified by the user if provided, OR select the most suitable model "
|
"`listmodels` tool to obtain the full catalog with capabilities, aliases, and context window info."
|
||||||
"for this specific task based on the requirements and capabilities listed below:"
|
)
|
||||||
]
|
|
||||||
|
|
||||||
# Get descriptions from enabled providers
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.registry import ModelProviderRegistry
|
|
||||||
|
|
||||||
# Map provider types to readable names
|
|
||||||
provider_names = {
|
|
||||||
ProviderType.GOOGLE: "Gemini models",
|
|
||||||
ProviderType.OPENAI: "OpenAI models",
|
|
||||||
ProviderType.XAI: "X.AI GROK models",
|
|
||||||
ProviderType.DIAL: "DIAL models",
|
|
||||||
ProviderType.CUSTOM: "Custom models",
|
|
||||||
ProviderType.OPENROUTER: "OpenRouter models",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check available providers and add their model descriptions
|
|
||||||
|
|
||||||
# Start with native providers
|
|
||||||
for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]:
|
|
||||||
# Only if this is registered / available
|
|
||||||
provider = ModelProviderRegistry.get_provider(provider_type)
|
|
||||||
if provider:
|
|
||||||
provider_section_added = False
|
|
||||||
for model_name in provider.list_models(respect_restrictions=True):
|
|
||||||
try:
|
|
||||||
# Get model config to extract description
|
|
||||||
model_config = provider.SUPPORTED_MODELS.get(model_name)
|
|
||||||
if model_config and model_config.description:
|
|
||||||
if not provider_section_added:
|
|
||||||
model_desc_parts.append(
|
|
||||||
f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:"
|
|
||||||
)
|
|
||||||
provider_section_added = True
|
|
||||||
model_desc_parts.append(f"- '{model_name}': {model_config.description}")
|
|
||||||
except Exception:
|
|
||||||
# Skip models without descriptions
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Add custom models if custom API is configured
|
|
||||||
custom_url = os.getenv("CUSTOM_API_URL")
|
|
||||||
if custom_url:
|
|
||||||
# Load custom models from registry
|
|
||||||
try:
|
|
||||||
registry = self._get_openrouter_registry()
|
|
||||||
model_desc_parts.append(f"\nCustom models via {custom_url}:")
|
|
||||||
|
|
||||||
# Find all custom models (is_custom=true)
|
|
||||||
for alias in registry.list_aliases():
|
|
||||||
config = registry.resolve(alias)
|
|
||||||
# Check if this is a custom model that requires custom endpoints
|
|
||||||
if config and config.is_custom:
|
|
||||||
# Format context window
|
|
||||||
context_tokens = config.context_window
|
|
||||||
if context_tokens >= 1_000_000:
|
|
||||||
context_str = f"{context_tokens // 1_000_000}M"
|
|
||||||
elif context_tokens >= 1_000:
|
|
||||||
context_str = f"{context_tokens // 1_000}K"
|
|
||||||
else:
|
|
||||||
context_str = str(context_tokens)
|
|
||||||
|
|
||||||
desc_line = f"- '{alias}' ({context_str} context): {config.description}"
|
|
||||||
if desc_line not in model_desc_parts: # Avoid duplicates
|
|
||||||
model_desc_parts.append(desc_line)
|
|
||||||
except Exception as e:
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.debug(f"Failed to load custom model descriptions: {e}")
|
|
||||||
model_desc_parts.append(f"\nCustom models: Models available via {custom_url}")
|
|
||||||
|
|
||||||
if has_openrouter:
|
|
||||||
# Add OpenRouter models with descriptions
|
|
||||||
try:
|
|
||||||
import logging
|
|
||||||
|
|
||||||
registry = self._get_openrouter_registry()
|
|
||||||
|
|
||||||
# Group models by their model_name to avoid duplicates
|
|
||||||
seen_models = set()
|
|
||||||
model_configs = []
|
|
||||||
|
|
||||||
for alias in registry.list_aliases():
|
|
||||||
config = registry.resolve(alias)
|
|
||||||
if config and config.model_name not in seen_models:
|
|
||||||
seen_models.add(config.model_name)
|
|
||||||
model_configs.append((alias, config))
|
|
||||||
|
|
||||||
# Sort by context window (descending) then by alias
|
|
||||||
model_configs.sort(key=lambda x: (-x[1].context_window, x[0]))
|
|
||||||
|
|
||||||
if model_configs:
|
|
||||||
model_desc_parts.append("\nOpenRouter models (use these aliases):")
|
|
||||||
for alias, config in model_configs: # Show ALL models so the CLI can choose
|
|
||||||
# Format context window in human-readable form
|
|
||||||
context_tokens = config.context_window
|
|
||||||
if context_tokens >= 1_000_000:
|
|
||||||
context_str = f"{context_tokens // 1_000_000}M"
|
|
||||||
elif context_tokens >= 1_000:
|
|
||||||
context_str = f"{context_tokens // 1_000}K"
|
|
||||||
else:
|
|
||||||
context_str = str(context_tokens)
|
|
||||||
|
|
||||||
# Build description line
|
|
||||||
if config.description:
|
|
||||||
desc = f"- '{alias}' ({context_str} context): {config.description}"
|
|
||||||
else:
|
|
||||||
# Fallback to showing the model name if no description
|
|
||||||
desc = f"- '{alias}' ({context_str} context): {config.model_name}"
|
|
||||||
model_desc_parts.append(desc)
|
|
||||||
|
|
||||||
# Show all models - no truncation needed
|
|
||||||
except Exception as e:
|
|
||||||
# Log for debugging but don't fail
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.debug(f"Failed to load OpenRouter model descriptions: {e}")
|
|
||||||
# Fallback to simple message
|
|
||||||
model_desc_parts.append(
|
|
||||||
"\nOpenRouter models: If configured, you can also use ANY model available on OpenRouter."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get all available models for the enum
|
|
||||||
all_models = self._get_available_models()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "string",
|
|
||||||
"description": "\n".join(model_desc_parts),
|
|
||||||
"enum": all_models,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Normal mode - model is optional with default
|
|
||||||
available_models = self._get_available_models()
|
|
||||||
models_str = ", ".join(f"'{m}'" for m in available_models) # Show ALL models so the CLI can choose
|
|
||||||
|
|
||||||
description = f"Model to use. Native models: {models_str}."
|
|
||||||
if has_openrouter:
|
|
||||||
# Add OpenRouter aliases
|
|
||||||
try:
|
|
||||||
registry = self._get_openrouter_registry()
|
|
||||||
aliases = registry.list_aliases()
|
|
||||||
|
|
||||||
# Show ALL aliases from the configuration
|
|
||||||
if aliases:
|
|
||||||
# Show all aliases so the CLI knows every option available
|
|
||||||
all_aliases = sorted(aliases)
|
|
||||||
alias_list = ", ".join(f"'{a}'" for a in all_aliases)
|
|
||||||
description += f" OpenRouter aliases: {alias_list}."
|
|
||||||
else:
|
|
||||||
description += " OpenRouter: Any model available on openrouter.ai."
|
|
||||||
except Exception:
|
|
||||||
description += (
|
|
||||||
" OpenRouter: Any model available on openrouter.ai "
|
|
||||||
"(e.g., 'gpt-4', 'claude-4-opus', 'mistral-large')."
|
|
||||||
)
|
|
||||||
description += f" Defaults to '{DEFAULT_MODEL}' if not specified."
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": description,
|
"description": description,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
description = (
|
||||||
|
f"The default model is '{DEFAULT_MODEL}'. Override by supplying another supported model name ONLY when requested by the user. "
|
||||||
|
"If needed, use the `listmodels` tool to obtain a full model catalog along with their capability details."
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "string",
|
||||||
|
"description": description,
|
||||||
|
}
|
||||||
|
|
||||||
def get_default_temperature(self) -> float:
|
def get_default_temperature(self) -> float:
|
||||||
"""
|
"""
|
||||||
Return the default temperature setting for this tool.
|
Return the default temperature setting for this tool.
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class SchemaBuilder:
|
|||||||
required_fields: list[str] = None,
|
required_fields: list[str] = None,
|
||||||
model_field_schema: dict[str, Any] = None,
|
model_field_schema: dict[str, Any] = None,
|
||||||
auto_mode: bool = False,
|
auto_mode: bool = False,
|
||||||
|
require_model: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Build complete schema for simple tools.
|
Build complete schema for simple tools.
|
||||||
@@ -88,8 +89,8 @@ class SchemaBuilder:
|
|||||||
properties.update(tool_specific_fields)
|
properties.update(tool_specific_fields)
|
||||||
|
|
||||||
# Build required fields list
|
# Build required fields list
|
||||||
required = required_fields or []
|
required = list(required_fields) if required_fields else []
|
||||||
if auto_mode and "model" not in required:
|
if (auto_mode or require_model) and "model" not in required:
|
||||||
required.append("model")
|
required.append("model")
|
||||||
|
|
||||||
# Build the complete schema
|
# Build the complete schema
|
||||||
|
|||||||
@@ -148,9 +148,10 @@ class SimpleTool(BaseTool):
|
|||||||
Returns:
|
Returns:
|
||||||
Complete JSON schema for the tool
|
Complete JSON schema for the tool
|
||||||
"""
|
"""
|
||||||
|
required_fields = list(self.get_required_fields())
|
||||||
return SchemaBuilder.build_schema(
|
return SchemaBuilder.build_schema(
|
||||||
tool_specific_fields=self.get_tool_fields(),
|
tool_specific_fields=self.get_tool_fields(),
|
||||||
required_fields=self.get_required_fields(),
|
required_fields=required_fields,
|
||||||
model_field_schema=self.get_model_field_schema(),
|
model_field_schema=self.get_model_field_schema(),
|
||||||
auto_mode=self.is_effective_auto_mode(),
|
auto_mode=self.is_effective_auto_mode(),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ class WorkflowSchemaBuilder:
|
|||||||
tool_name: str = None,
|
tool_name: str = None,
|
||||||
excluded_workflow_fields: list[str] = None,
|
excluded_workflow_fields: list[str] = None,
|
||||||
excluded_common_fields: list[str] = None,
|
excluded_common_fields: list[str] = None,
|
||||||
|
require_model: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Build complete schema for workflow tools.
|
Build complete schema for workflow tools.
|
||||||
@@ -142,7 +143,7 @@ class WorkflowSchemaBuilder:
|
|||||||
|
|
||||||
required = standard_required + (required_fields or [])
|
required = standard_required + (required_fields or [])
|
||||||
|
|
||||||
if auto_mode and "model" not in required:
|
if (auto_mode or require_model) and "model" not in required:
|
||||||
required.append("model")
|
required.append("model")
|
||||||
|
|
||||||
# Build the complete schema
|
# Build the complete schema
|
||||||
|
|||||||
Reference in New Issue
Block a user