Merge branch 'feat-local_support_with_UTF-8_encoding-update' of https://github.com/GiGiDKR/zen-mcp-server into feat-local_support_with_UTF-8_encoding-update
@@ -15,6 +15,7 @@ def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576
        model_name=model_name,
        friendly_name="Gemini",
        context_window=context_window,
        max_output_tokens=8192,
        supports_extended_thinking=False,
        supports_system_prompts=True,
        supports_streaming=True,

@@ -211,7 +211,7 @@ class TestAliasTargetRestrictions:
        # Verify the polymorphic method was called
        mock_provider.list_all_known_models.assert_called_once()

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high"})  # Restrict to specific model
    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"})  # Restrict to specific model
    def test_complex_alias_chains_handled_correctly(self):
        """Test that complex alias chains are handled correctly in restrictions."""
        # Clear cached restriction service
@@ -221,12 +221,11 @@ class TestAliasTargetRestrictions:

        provider = OpenAIModelProvider(api_key="test-key")

        # Only o4-mini-high should be allowed
        assert provider.validate_model_name("o4-mini-high")
        # Only o4-mini should be allowed
        assert provider.validate_model_name("o4-mini")

        # Other models should be blocked
        assert not provider.validate_model_name("o4-mini")
        assert not provider.validate_model_name("mini")  # This resolves to o4-mini
        assert not provider.validate_model_name("o3")
        assert not provider.validate_model_name("o3-mini")

    def test_critical_regression_validation_sees_alias_targets(self):
@@ -307,7 +306,7 @@ class TestAliasTargetRestrictions:
        it appear that target-based restrictions don't work.
        """
        # Test with a made-up restriction scenario
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high,o3-mini"}):
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,o3-mini"}):
            # Clear cached restriction service
            import utils.model_restrictions

@@ -318,7 +317,7 @@ class TestAliasTargetRestrictions:

            # These specific target models should be recognized as valid
            all_known = provider.list_all_known_models()
            assert "o4-mini-high" in all_known, "Target model o4-mini-high should be known"
            assert "o4-mini" in all_known, "Target model o4-mini should be known"
            assert "o3-mini" in all_known, "Target model o3-mini should be known"

            # Validation should not warn about these being unrecognized
@@ -329,11 +328,11 @@ class TestAliasTargetRestrictions:
            # Should not warn about our allowed models being unrecognized
            all_warnings = [str(call) for call in mock_logger.warning.call_args_list]
            for warning in all_warnings:
                assert "o4-mini-high" not in warning or "not a recognized" not in warning
                assert "o4-mini" not in warning or "not a recognized" not in warning
                assert "o3-mini" not in warning or "not a recognized" not in warning

            # The restriction should actually work
            assert provider.validate_model_name("o4-mini-high")
            assert provider.validate_model_name("o4-mini")
            assert provider.validate_model_name("o3-mini")
            assert not provider.validate_model_name("o4-mini")  # not in allowed list
            assert not provider.validate_model_name("o3-pro")  # not in allowed list
            assert not provider.validate_model_name("o3")  # not in allowed list

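Taken together, the assertions above pin down the intended contract: validation resolves an alias to its canonical target first, and only then applies the OPENAI_ALLOWED_MODELS allow-list to the resolved name. A minimal sketch of that logic, assuming the helper names used by these tests (the empty-restriction fallback is an assumption, not the project's actual code):

import os

def validate_model_name(provider, name):
    # Resolve shorthands first: "mini" -> "o4-mini" (assumed helper).
    resolved = provider._resolve_model_name(name)
    allowed = {
        m.strip().lower()
        for m in os.environ.get("OPENAI_ALLOWED_MODELS", "").split(",")
        if m.strip()
    }
    if not allowed:
        # No restriction configured: any known model or alias passes.
        return resolved.lower() in provider.list_all_known_models()
    # Restricted: only names that resolve into the allow-list pass.
    return resolved.lower() in allowed
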
@@ -59,12 +59,12 @@ class TestAutoMode:
                continue

            # Check that model has description
            description = config.get("description", "")
            description = config.description if hasattr(config, "description") else ""
            if description:
                models_with_descriptions[model_name] = description

        # Check all expected models are present with meaningful descriptions
        expected_models = ["flash", "pro", "o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"]
        expected_models = ["flash", "pro", "o3", "o3-mini", "o3-pro", "o4-mini"]
        for model in expected_models:
            # Model should exist somewhere in the providers
            # Note: Some models might not be available if API keys aren't configured

@@ -319,7 +319,18 @@ class TestAutoModeComprehensive:
            m
            for m in available_models
            if not m.startswith("gemini")
            and m not in ["flash", "pro", "flash-2.0", "flash2", "flashlite", "flash-lite"]
            and m
            not in [
                "flash",
                "pro",
                "flash-2.0",
                "flash2",
                "flashlite",
                "flash-lite",
                "flash2.5",
                "gemini pro",
                "gemini-pro",
            ]
        ]
        assert (
            len(non_gemini_models) == 0

@@ -70,7 +70,7 @@ class TestAutoModeCustomProviderOnly:
        }

        # Clear all other provider keys
        clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]
        clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]

        with patch.dict(os.environ, test_env, clear=False):
            # Ensure other provider keys are not set
@@ -109,7 +109,7 @@ class TestAutoModeCustomProviderOnly:

        with patch.dict(os.environ, test_env, clear=False):
            # Clear other provider keys
            for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
                if key in os.environ:
                    del os.environ[key]

@@ -177,7 +177,7 @@ class TestAutoModeCustomProviderOnly:

        with patch.dict(os.environ, test_env, clear=False):
            # Clear other provider keys
            for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
                if key in os.environ:
                    del os.environ[key]

@@ -118,7 +118,7 @@ class TestBuggyBehaviorPrevention:
        provider = OpenAIModelProvider(api_key="test-key")

        # Simulate a scenario where admin wants to restrict specific targets
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini-high"}):
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
            # Clear cached restriction service
            import utils.model_restrictions

@@ -126,19 +126,21 @@ class TestBuggyBehaviorPrevention:

            # These should work because they're explicitly allowed
            assert provider.validate_model_name("o3-mini")
            assert provider.validate_model_name("o4-mini-high")
            assert provider.validate_model_name("o4-mini")

            # These should be blocked
            assert not provider.validate_model_name("o4-mini")  # Not in allowed list
            assert not provider.validate_model_name("o3-pro")  # Not in allowed list
            assert not provider.validate_model_name("o3")  # Not in allowed list
            assert not provider.validate_model_name("mini")  # Resolves to o4-mini, not allowed

            # This should be ALLOWED because it resolves to o4-mini which is in the allowed list
            assert provider.validate_model_name("mini")  # Resolves to o4-mini, which IS allowed

            # Verify our list_all_known_models includes the restricted models
            all_known = provider.list_all_known_models()
            assert "o3-mini" in all_known  # Should be known (and allowed)
            assert "o4-mini-high" in all_known  # Should be known (and allowed)
            assert "o4-mini" in all_known  # Should be known (but blocked)
            assert "mini" in all_known  # Should be known (but blocked)
            assert "o4-mini" in all_known  # Should be known (and allowed)
            assert "o3-pro" in all_known  # Should be known (but blocked)
            assert "mini" in all_known  # Should be known (and allowed since it resolves to o4-mini)

    def test_demonstration_of_old_vs_new_interface(self):
        """

@@ -506,17 +506,17 @@ class TestConversationFlow:
        mock_client = Mock()
        mock_storage.return_value = mock_client

        # Start conversation with files
        thread_id = create_thread("analyze", {"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]})
        # Start conversation with files using a simple tool
        thread_id = create_thread("chat", {"prompt": "Analyze this codebase", "files": ["/project/src/"]})

        # Turn 1: Claude provides context with multiple files
        initial_context = ThreadContext(
            thread_id=thread_id,
            created_at="2023-01-01T00:00:00Z",
            last_updated_at="2023-01-01T00:00:00Z",
            tool_name="analyze",
            tool_name="chat",
            turns=[],
            initial_context={"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]},
            initial_context={"prompt": "Analyze this codebase", "files": ["/project/src/"]},
        )
        mock_client.get.return_value = initial_context.model_dump_json()

@@ -45,18 +45,32 @@ class TestCustomProvider:

    def test_get_capabilities_from_registry(self):
        """Test get_capabilities returns registry capabilities when available."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
        # Save original environment
        original_env = os.environ.get("OPENROUTER_ALLOWED_MODELS")

        # Test with a model that should be in the registry (OpenRouter model) and is allowed by restrictions
        capabilities = provider.get_capabilities("o3")  # o3 is in OPENROUTER_ALLOWED_MODELS
        try:
            # Clear any restrictions
            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)

        assert capabilities.provider == ProviderType.OPENROUTER  # o3 is an OpenRouter model (is_custom=false)
        assert capabilities.context_window > 0
            provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        # Test with a custom model (is_custom=true)
        capabilities = provider.get_capabilities("local-llama")
        assert capabilities.provider == ProviderType.CUSTOM  # local-llama has is_custom=true
        assert capabilities.context_window > 0
            # Test with a model that should be in the registry (OpenRouter model)
            capabilities = provider.get_capabilities("o3")  # o3 is an OpenRouter model

            assert capabilities.provider == ProviderType.OPENROUTER  # o3 is an OpenRouter model (is_custom=false)
            assert capabilities.context_window > 0

            # Test with a custom model (is_custom=true)
            capabilities = provider.get_capabilities("local-llama")
            assert capabilities.provider == ProviderType.CUSTOM  # local-llama has is_custom=true
            assert capabilities.context_window > 0

        finally:
            # Restore original environment
            if original_env is None:
                os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
            else:
                os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env

    def test_get_capabilities_generic_fallback(self):
        """Test get_capabilities returns generic capabilities for unknown models."""

@@ -84,7 +84,7 @@ class TestDIALProvider:
        # Test O3 capabilities
        capabilities = provider.get_capabilities("o3")
        assert capabilities.model_name == "o3-2025-04-16"
        assert capabilities.friendly_name == "DIAL"
        assert capabilities.friendly_name == "DIAL (O3)"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.DIAL
        assert capabilities.supports_images is True

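The friendly_name change here recurs for OpenAI, OpenRouter, and X.AI in the hunks below: capabilities now carry a per-model label instead of the bare provider name. The pattern the assertions encode is simply this (a sketch; where the label is actually stored is an implementation detail not shown in this diff):

def build_friendly_name(provider_label, model_label):
    # "DIAL" + "O3" -> "DIAL (O3)"; "X.AI" + "Grok 3" -> "X.AI (Grok 3)"
    return f"{provider_label} ({model_label})"
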
tests/test_disabled_tools.py (new file, 140 lines)
@@ -0,0 +1,140 @@
"""Tests for DISABLED_TOOLS environment variable functionality."""

import logging
import os
from unittest.mock import patch

import pytest

from server import (
    apply_tool_filter,
    parse_disabled_tools_env,
    validate_disabled_tools,
)


# Mock the tool classes since we're testing the filtering logic
class MockTool:
    def __init__(self, name):
        self.name = name


class TestDisabledTools:
    """Test suite for DISABLED_TOOLS functionality."""

    def test_parse_disabled_tools_empty(self):
        """Empty string returns empty set (no tools disabled)."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": ""}):
            assert parse_disabled_tools_env() == set()

    def test_parse_disabled_tools_not_set(self):
        """Unset variable returns empty set."""
        with patch.dict(os.environ, {}, clear=True):
            # Ensure DISABLED_TOOLS is not in environment
            if "DISABLED_TOOLS" in os.environ:
                del os.environ["DISABLED_TOOLS"]
            assert parse_disabled_tools_env() == set()

    def test_parse_disabled_tools_single(self):
        """Single tool name parsed correctly."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": "debug"}):
            assert parse_disabled_tools_env() == {"debug"}

    def test_parse_disabled_tools_multiple(self):
        """Multiple tools with spaces parsed correctly."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": "debug, analyze, refactor"}):
            assert parse_disabled_tools_env() == {"debug", "analyze", "refactor"}

    def test_parse_disabled_tools_extra_spaces(self):
        """Extra spaces and empty items handled correctly."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": " debug , , analyze , "}):
            assert parse_disabled_tools_env() == {"debug", "analyze"}

    def test_parse_disabled_tools_duplicates(self):
        """Duplicate entries handled correctly (set removes duplicates)."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": "debug,analyze,debug"}):
            assert parse_disabled_tools_env() == {"debug", "analyze"}

    def test_tool_filtering_logic(self):
        """Test the complete filtering logic using the actual server functions."""
        # Simulate ALL_TOOLS
        ALL_TOOLS = {
            "chat": MockTool("chat"),
            "debug": MockTool("debug"),
            "analyze": MockTool("analyze"),
            "version": MockTool("version"),
            "listmodels": MockTool("listmodels"),
        }

        # Test case 1: No tools disabled
        disabled_tools = set()
        enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)

        assert len(enabled_tools) == 5  # All tools included
        assert set(enabled_tools.keys()) == set(ALL_TOOLS.keys())

        # Test case 2: Disable some regular tools
        disabled_tools = {"debug", "analyze"}
        enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)

        assert len(enabled_tools) == 3  # chat, version, listmodels
        assert "debug" not in enabled_tools
        assert "analyze" not in enabled_tools
        assert "chat" in enabled_tools
        assert "version" in enabled_tools
        assert "listmodels" in enabled_tools

        # Test case 3: Attempt to disable essential tools
        disabled_tools = {"version", "chat"}
        enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)

        assert "version" in enabled_tools  # Essential tool not disabled
        assert "chat" not in enabled_tools  # Regular tool disabled
        assert "listmodels" in enabled_tools  # Essential tool included

    def test_unknown_tools_warning(self, caplog):
        """Test that unknown tool names generate appropriate warnings."""
        ALL_TOOLS = {
            "chat": MockTool("chat"),
            "debug": MockTool("debug"),
            "analyze": MockTool("analyze"),
            "version": MockTool("version"),
            "listmodels": MockTool("listmodels"),
        }
        disabled_tools = {"chat", "unknown_tool", "another_unknown"}

        with caplog.at_level(logging.WARNING):
            validate_disabled_tools(disabled_tools, ALL_TOOLS)
        assert "Unknown tools in DISABLED_TOOLS: ['another_unknown', 'unknown_tool']" in caplog.text

    def test_essential_tools_warning(self, caplog):
        """Test warning when trying to disable essential tools."""
        ALL_TOOLS = {
            "chat": MockTool("chat"),
            "debug": MockTool("debug"),
            "analyze": MockTool("analyze"),
            "version": MockTool("version"),
            "listmodels": MockTool("listmodels"),
        }
        disabled_tools = {"version", "chat", "debug"}

        with caplog.at_level(logging.WARNING):
            validate_disabled_tools(disabled_tools, ALL_TOOLS)
        assert "Cannot disable essential tools: ['version']" in caplog.text

    @pytest.mark.parametrize(
        "env_value,expected",
        [
            ("", set()),  # Empty string
            (" ", set()),  # Only spaces
            (",,,", set()),  # Only commas
            ("chat", {"chat"}),  # Single tool
            ("chat,debug", {"chat", "debug"}),  # Multiple tools
            ("chat, debug, analyze", {"chat", "debug", "analyze"}),  # With spaces
            ("chat,debug,chat", {"chat", "debug"}),  # Duplicates
        ],
    )
    def test_parse_disabled_tools_parametrized(self, env_value, expected):
        """Parametrized tests for various input formats."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": env_value}):
            assert parse_disabled_tools_env() == expected

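The new test file fully specifies the three helpers it imports from server. A minimal implementation consistent with every assertion above (the exact essential-tool set is an assumption inferred from the tests, not copied from server.py):

import logging
import os

logger = logging.getLogger(__name__)

ESSENTIAL_TOOLS = {"version", "listmodels"}  # assumed: the set the tests imply


def parse_disabled_tools_env():
    """Split DISABLED_TOOLS on commas, tolerating spaces and empty items."""
    raw = os.environ.get("DISABLED_TOOLS", "")
    return {name.strip() for name in raw.split(",") if name.strip()}


def apply_tool_filter(all_tools, disabled):
    """Drop disabled tools, but never drop essential ones."""
    return {
        name: tool
        for name, tool in all_tools.items()
        if name not in disabled or name in ESSENTIAL_TOOLS
    }


def validate_disabled_tools(disabled, all_tools):
    """Warn about unknown names and attempts to disable essential tools."""
    unknown = sorted(disabled - set(all_tools))
    if unknown:
        logger.warning(f"Unknown tools in DISABLED_TOOLS: {unknown}")
    essential = sorted(disabled & ESSENTIAL_TOOLS)
    if essential:
        logger.warning(f"Cannot disable essential tools: {essential}")
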
@@ -483,14 +483,14 @@ class TestImageSupportIntegration:
            tool_name="chat",
        )

        # Create child thread linked to parent
        child_thread_id = create_thread("debug", {"child": "context"}, parent_thread_id=parent_thread_id)
        # Create child thread linked to parent using a simple tool
        child_thread_id = create_thread("chat", {"prompt": "child context"}, parent_thread_id=parent_thread_id)
        add_turn(
            thread_id=child_thread_id,
            role="user",
            content="Child thread with more images",
            images=["child1.png", "shared.png"],  # shared.png appears again (should prioritize newer)
            tool_name="debug",
            tool_name="chat",
        )

        # Mock child thread context for get_thread call

@@ -149,7 +149,7 @@ class TestModelEnumeration:
            ("o3", False),  # OpenAI - not available without API key
            ("grok", False),  # X.AI - not available without API key
            ("gemini-2.5-flash", False),  # Full Gemini name - not available without API key
            ("o4-mini-high", False),  # OpenAI variant - not available without API key
            ("o4-mini", False),  # OpenAI variant - not available without API key
            ("grok-3-fast", False),  # X.AI variant - not available without API key
        ],
    )

@@ -89,7 +89,7 @@ class TestModelMetadataContinuation:
    @pytest.mark.asyncio
    async def test_multiple_turns_uses_last_assistant_model(self):
        """Test that with multiple turns, the last assistant turn's model is used."""
        thread_id = create_thread("analyze", {"prompt": "analyze this"})
        thread_id = create_thread("chat", {"prompt": "analyze this"})

        # Add multiple turns with different models
        add_turn(thread_id, "assistant", "First response", model_name="gemini-2.5-flash", model_provider="google")
@@ -185,11 +185,11 @@ class TestModelMetadataContinuation:
    async def test_thread_chain_model_preservation(self):
        """Test model preservation across thread chains (parent-child relationships)."""
        # Create parent thread
        parent_id = create_thread("analyze", {"prompt": "analyze"})
        parent_id = create_thread("chat", {"prompt": "analyze"})
        add_turn(parent_id, "assistant", "Analysis", model_name="gemini-2.5-pro", model_provider="google")

        # Create child thread
        child_id = create_thread("codereview", {"prompt": "review"}, parent_thread_id=parent_id)
        # Create child thread using a simple tool instead of workflow tool
        child_id = create_thread("chat", {"prompt": "review"}, parent_thread_id=parent_id)

        # Child thread should be able to access parent's model through chain traversal
        # NOTE: Current implementation only checks current thread (not parent threads)

@@ -93,7 +93,7 @@ class TestModelRestrictionService:
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
            service = ModelRestrictionService()

            models = ["o3", "o3-mini", "o4-mini", "o4-mini-high"]
            models = ["o3", "o3-mini", "o4-mini", "o3-pro"]
            filtered = service.filter_models(ProviderType.OPENAI, models)

            assert filtered == ["o3-mini", "o4-mini"]

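The updated input list still filters down to exactly the allow-list. A sketch of filter_models consistent with this test (the internal _restrictions mapping is a hypothetical name, and the real matching may also consider aliases):

def filter_models(self, provider_type, models):
    allowed = self._restrictions.get(provider_type)  # hypothetical attribute
    if not allowed:
        return list(models)  # no restriction configured for this provider
    return [m for m in models if m.lower() in allowed]
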
@@ -573,7 +573,7 @@ class TestShorthandRestrictions:

        # Other models should not work
        assert not openai_provider.validate_model_name("o3")
        assert not openai_provider.validate_model_name("o4-mini-high")
        assert not openai_provider.validate_model_name("o3-pro")

    @patch.dict(
        os.environ,

@@ -185,7 +185,7 @@ class TestO3TemperatureParameterFixSimple:
        provider = OpenAIModelProvider(api_key="test-key")

        # Test O3/O4 models that should NOT support temperature parameter
        o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"]
        o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini"]

        for model in o3_o4_models:
            capabilities = provider.get_capabilities(model)

@@ -47,14 +47,13 @@ class TestOpenAIProvider:
        assert provider.validate_model_name("o3-mini") is True
        assert provider.validate_model_name("o3-pro") is True
        assert provider.validate_model_name("o4-mini") is True
        assert provider.validate_model_name("o4-mini-high") is True
        assert provider.validate_model_name("o4-mini") is True

        # Test valid aliases
        assert provider.validate_model_name("mini") is True
        assert provider.validate_model_name("o3mini") is True
        assert provider.validate_model_name("o4mini") is True
        assert provider.validate_model_name("o4minihigh") is True
        assert provider.validate_model_name("o4minihi") is True
        assert provider.validate_model_name("o4mini") is True

        # Test invalid model
        assert provider.validate_model_name("invalid-model") is False
@@ -69,15 +68,14 @@ class TestOpenAIProvider:
        assert provider._resolve_model_name("mini") == "o4-mini"
        assert provider._resolve_model_name("o3mini") == "o3-mini"
        assert provider._resolve_model_name("o4mini") == "o4-mini"
        assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
        assert provider._resolve_model_name("o4minihi") == "o4-mini-high"
        assert provider._resolve_model_name("o4mini") == "o4-mini"

        # Test full name passthrough
        assert provider._resolve_model_name("o3") == "o3"
        assert provider._resolve_model_name("o3-mini") == "o3-mini"
        assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
        assert provider._resolve_model_name("o4-mini") == "o4-mini"
        assert provider._resolve_model_name("o4-mini-high") == "o4-mini-high"
        assert provider._resolve_model_name("o4-mini") == "o4-mini"

    def test_get_capabilities_o3(self):
        """Test getting model capabilities for O3."""
@@ -85,7 +83,7 @@ class TestOpenAIProvider:

        capabilities = provider.get_capabilities("o3")
        assert capabilities.model_name == "o3"  # Should NOT be resolved in capabilities
        assert capabilities.friendly_name == "OpenAI"
        assert capabilities.friendly_name == "OpenAI (O3)"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.OPENAI
        assert not capabilities.supports_extended_thinking
@@ -101,8 +99,8 @@ class TestOpenAIProvider:
        provider = OpenAIModelProvider("test-key")

        capabilities = provider.get_capabilities("mini")
        assert capabilities.model_name == "mini"  # Capabilities should show original request
        assert capabilities.friendly_name == "OpenAI"
        assert capabilities.model_name == "o4-mini"  # Capabilities should show resolved model name
        assert capabilities.friendly_name == "OpenAI (O4-mini)"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.OPENAI

@@ -184,11 +182,11 @@ class TestOpenAIProvider:
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "o3-mini"

        # Test o4minihigh -> o4-mini-high
        mock_response.model = "o4-mini-high"
        provider.generate_content(prompt="Test", model_name="o4minihigh", temperature=1.0)
        # Test o4mini -> o4-mini
        mock_response.model = "o4-mini"
        provider.generate_content(prompt="Test", model_name="o4mini", temperature=1.0)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "o4-mini-high"
        assert call_kwargs["model"] == "o4-mini"

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_no_alias_passthrough(self, mock_openai_class):

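What the rewritten block checks is that the alias is resolved before the OpenAI client is invoked, so the API only ever sees canonical model names. Sketched, with the client plumbing simplified and attribute names assumed:

def generate_content(self, prompt, model_name, temperature=1.0, **kwargs):
    resolved = self._resolve_model_name(model_name)  # "o4mini" -> "o4-mini"
    return self.client.chat.completions.create(
        model=resolved,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
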
@@ -57,7 +57,7 @@ class TestOpenRouterProvider:
        caps = provider.get_capabilities("o3")
        assert caps.provider == ProviderType.OPENROUTER
        assert caps.model_name == "openai/o3"  # Resolved name
        assert caps.friendly_name == "OpenRouter"
        assert caps.friendly_name == "OpenRouter (openai/o3)"

        # Test with a model not in registry - should get generic capabilities
        caps = provider.get_capabilities("unknown-model")
@@ -77,7 +77,7 @@ class TestOpenRouterProvider:
        assert provider._resolve_model_name("o3-mini") == "openai/o3-mini"
        assert provider._resolve_model_name("o3mini") == "openai/o3-mini"
        assert provider._resolve_model_name("o4-mini") == "openai/o4-mini"
        assert provider._resolve_model_name("o4-mini-high") == "openai/o4-mini-high"
        assert provider._resolve_model_name("o4-mini") == "openai/o4-mini"
        assert provider._resolve_model_name("claude") == "anthropic/claude-sonnet-4"
        assert provider._resolve_model_name("mistral") == "mistralai/mistral-large-2411"
        assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-r1-0528"

@@ -6,8 +6,8 @@ import tempfile

import pytest

from providers.base import ProviderType
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
from providers.base import ModelCapabilities, ProviderType
from providers.openrouter_registry import OpenRouterModelRegistry


class TestOpenRouterModelRegistry:
@@ -24,7 +24,16 @@ class TestOpenRouterModelRegistry:
    def test_custom_config_path(self):
        """Test registry with custom config path."""
        # Create temporary config
        config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}
        config_data = {
            "models": [
                {
                    "model_name": "test/model-1",
                    "aliases": ["test1", "t1"],
                    "context_window": 4096,
                    "max_output_tokens": 2048,
                }
            ]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
@@ -42,7 +51,11 @@ class TestOpenRouterModelRegistry:
    def test_environment_variable_override(self):
        """Test OPENROUTER_MODELS_PATH environment variable."""
        # Create custom config
        config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}
        config_data = {
            "models": [
                {"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192, "max_output_tokens": 4096}
            ]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
@@ -110,28 +123,29 @@ class TestOpenRouterModelRegistry:
        assert registry.resolve("non-existent") is None

    def test_model_capabilities_conversion(self):
        """Test conversion to ModelCapabilities."""
        """Test that registry returns ModelCapabilities directly."""
        registry = OpenRouterModelRegistry()

        config = registry.resolve("opus")
        assert config is not None

        caps = config.to_capabilities()
        assert caps.provider == ProviderType.OPENROUTER
        assert caps.model_name == "anthropic/claude-opus-4"
        assert caps.friendly_name == "OpenRouter"
        assert caps.context_window == 200000
        assert not caps.supports_extended_thinking
        # Registry now returns ModelCapabilities objects directly
        assert config.provider == ProviderType.OPENROUTER
        assert config.model_name == "anthropic/claude-opus-4"
        assert config.friendly_name == "OpenRouter (anthropic/claude-opus-4)"
        assert config.context_window == 200000
        assert not config.supports_extended_thinking

    def test_duplicate_alias_detection(self):
        """Test that duplicate aliases are detected."""
        config_data = {
            "models": [
                {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
                {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096, "max_output_tokens": 2048},
                {
                    "model_name": "test/model-2",
                    "aliases": ["DUPE"],  # Same alias, different case
                    "context_window": 8192,
                    "max_output_tokens": 2048,
                },
            ]
        }
@@ -199,19 +213,23 @@ class TestOpenRouterModelRegistry:

    def test_model_with_all_capabilities(self):
        """Test model with all capability flags."""
        config = OpenRouterModelConfig(
        from providers.base import create_temperature_constraint

        caps = ModelCapabilities(
            provider=ProviderType.OPENROUTER,
            model_name="test/full-featured",
            friendly_name="OpenRouter (test/full-featured)",
            aliases=["full"],
            context_window=128000,
            max_output_tokens=8192,
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            description="Fully featured test model",
            temperature_constraint=create_temperature_constraint("range"),
        )

        caps = config.to_capabilities()
        assert caps.context_window == 128000
        assert caps.supports_extended_thinking
        assert caps.supports_system_prompts

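With the registry returning ModelCapabilities directly, each JSON "models" entry above maps straight onto a capabilities object, roughly as follows (a sketch using only constructor fields that appear in this diff; the required arguments in providers/base.py may differ):

def capabilities_from_entry(entry):
    return ModelCapabilities(
        provider=ProviderType.OPENROUTER,
        model_name=entry["model_name"],
        friendly_name=f"OpenRouter ({entry['model_name']})",
        aliases=entry.get("aliases", []),
        context_window=entry["context_window"],
        max_output_tokens=entry["max_output_tokens"],
        temperature_constraint=create_temperature_constraint("range"),
    )
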
tests/test_parse_model_option.py (new file, 79 lines)
@@ -0,0 +1,79 @@
"""Tests for parse_model_option function."""

from server import parse_model_option


class TestParseModelOption:
    """Test cases for model option parsing."""

    def test_openrouter_free_suffix_preserved(self):
        """Test that OpenRouter :free suffix is preserved as part of model name."""
        model, option = parse_model_option("openai/gpt-3.5-turbo:free")
        assert model == "openai/gpt-3.5-turbo:free"
        assert option is None

    def test_openrouter_beta_suffix_preserved(self):
        """Test that OpenRouter :beta suffix is preserved as part of model name."""
        model, option = parse_model_option("anthropic/claude-3-opus:beta")
        assert model == "anthropic/claude-3-opus:beta"
        assert option is None

    def test_openrouter_preview_suffix_preserved(self):
        """Test that OpenRouter :preview suffix is preserved as part of model name."""
        model, option = parse_model_option("google/gemini-pro:preview")
        assert model == "google/gemini-pro:preview"
        assert option is None

    def test_ollama_tag_parsed_as_option(self):
        """Test that Ollama tags are parsed as options."""
        model, option = parse_model_option("llama3.2:latest")
        assert model == "llama3.2"
        assert option == "latest"

    def test_consensus_stance_parsed_as_option(self):
        """Test that consensus stances are parsed as options."""
        model, option = parse_model_option("o3:for")
        assert model == "o3"
        assert option == "for"

        model, option = parse_model_option("gemini-2.5-pro:against")
        assert model == "gemini-2.5-pro"
        assert option == "against"

    def test_openrouter_unknown_suffix_parsed_as_option(self):
        """Test that unknown suffixes on OpenRouter models are parsed as options."""
        model, option = parse_model_option("openai/gpt-4:custom-tag")
        assert model == "openai/gpt-4"
        assert option == "custom-tag"

    def test_plain_model_name(self):
        """Test plain model names without colons."""
        model, option = parse_model_option("gpt-4")
        assert model == "gpt-4"
        assert option is None

    def test_url_not_parsed(self):
        """Test that URLs are not parsed for options."""
        model, option = parse_model_option("http://localhost:8080")
        assert model == "http://localhost:8080"
        assert option is None

    def test_whitespace_handling(self):
        """Test that whitespace is properly stripped."""
        model, option = parse_model_option(" openai/gpt-3.5-turbo:free ")
        assert model == "openai/gpt-3.5-turbo:free"
        assert option is None

        model, option = parse_model_option(" llama3.2 : latest ")
        assert model == "llama3.2"
        assert option == "latest"

    def test_case_insensitive_suffix_matching(self):
        """Test that OpenRouter suffix matching is case-insensitive."""
        model, option = parse_model_option("openai/gpt-3.5-turbo:FREE")
        assert model == "openai/gpt-3.5-turbo:FREE"  # Original case preserved
        assert option is None

        model, option = parse_model_option("openai/gpt-3.5-turbo:Free")
        assert model == "openai/gpt-3.5-turbo:Free"  # Original case preserved
        assert option is None

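These cases pin parse_model_option down precisely: known OpenRouter suffixes stay attached to vendor-prefixed names, URLs are never split, and anything else after a colon becomes an option. One implementation that satisfies all of them (the known-suffix set is an assumption; the real list may differ):

OPENROUTER_SUFFIXES = {"free", "beta", "preview"}  # assumed known-suffix set


def parse_model_option(value):
    value = value.strip()
    if "://" in value:  # URLs keep their colons untouched
        return value, None
    if ":" not in value:
        return value, None
    model, _, option = value.partition(":")
    model, option = model.strip(), option.strip()
    # Vendor-prefixed names ("openai/gpt-...") keep known suffixes attached
    if "/" in model and option.lower() in OPENROUTER_SUFFIXES:
        return f"{model}:{option}", None
    return model, option or None
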
@@ -58,7 +58,13 @@ class TestProviderRoutingBugs:
        """
        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
        for key in [
            "GEMINI_API_KEY",
            "OPENAI_API_KEY",
            "XAI_API_KEY",
            "OPENROUTER_API_KEY",
            "OPENROUTER_ALLOWED_MODELS",
        ]:
            original_env[key] = os.environ.get(key)

        try:
@@ -66,6 +72,7 @@ class TestProviderRoutingBugs:
            os.environ.pop("GEMINI_API_KEY", None)  # No Google API key
            os.environ.pop("OPENAI_API_KEY", None)
            os.environ.pop("XAI_API_KEY", None)
            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
            os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

            # Register only OpenRouter provider (like in server.py:configure_providers)
@@ -113,12 +120,24 @@ class TestProviderRoutingBugs:
        """
        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
        for key in [
            "GEMINI_API_KEY",
            "OPENAI_API_KEY",
            "XAI_API_KEY",
            "OPENROUTER_API_KEY",
            "OPENROUTER_ALLOWED_MODELS",
        ]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up scenario: NO API keys at all
            for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            for key in [
                "GEMINI_API_KEY",
                "OPENAI_API_KEY",
                "XAI_API_KEY",
                "OPENROUTER_API_KEY",
                "OPENROUTER_ALLOWED_MODELS",
            ]:
                os.environ.pop(key, None)

            # Create tool to test fallback logic
@@ -151,7 +170,13 @@ class TestProviderRoutingBugs:
        """
        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
        for key in [
            "GEMINI_API_KEY",
            "OPENAI_API_KEY",
            "XAI_API_KEY",
            "OPENROUTER_API_KEY",
            "OPENROUTER_ALLOWED_MODELS",
        ]:
            original_env[key] = os.environ.get(key)

        try:
@@ -160,6 +185,7 @@ class TestProviderRoutingBugs:
            os.environ["OPENAI_API_KEY"] = "test-openai-key"
            os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
            os.environ.pop("XAI_API_KEY", None)
            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions

            # Register providers in priority order (like server.py)
            from providers.gemini import GeminiModelProvider

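Every hunk in this file repeats the same save/try/finally/restore dance, now extended with OPENROUTER_ALLOWED_MODELS. The pattern is equivalent to this small helper (a sketch of the idiom, not something the tests actually import):

import os
from contextlib import contextmanager


@contextmanager
def preserve_env(*keys):
    saved = {k: os.environ.get(k) for k in keys}
    try:
        yield
    finally:
        for k, v in saved.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v
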
@@ -215,9 +215,7 @@ class TestOpenAIProvider:
        assert provider.validate_model_name("o3-mini")  # Backwards compatibility
        assert provider.validate_model_name("o4-mini")
        assert provider.validate_model_name("o4mini")
        assert provider.validate_model_name("o4-mini-high")
        assert provider.validate_model_name("o4minihigh")
        assert provider.validate_model_name("o4minihi")
        assert provider.validate_model_name("o4-mini")
        assert not provider.validate_model_name("gpt-4o")
        assert not provider.validate_model_name("invalid-model")

@@ -229,4 +227,4 @@ class TestOpenAIProvider:
        assert not provider.supports_thinking_mode("o3mini")
        assert not provider.supports_thinking_mode("o3-mini")
        assert not provider.supports_thinking_mode("o4-mini")
        assert not provider.supports_thinking_mode("o4-mini-high")
        assert not provider.supports_thinking_mode("o4-mini")

tests/test_supported_models_aliases.py (new file, 205 lines)
@@ -0,0 +1,205 @@
"""Test the SUPPORTED_MODELS aliases structure across all providers."""

from providers.dial import DIALModelProvider
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.xai import XAIModelProvider


class TestSupportedModelsAliases:
    """Test that all providers have correctly structured SUPPORTED_MODELS with aliases."""

    def test_gemini_provider_aliases(self):
        """Test Gemini provider's alias structure."""
        provider = GeminiModelProvider("test-key")

        # Check that all models have ModelCapabilities with aliases
        for model_name, config in provider.SUPPORTED_MODELS.items():
            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"

        # Test specific aliases
        assert "flash" in provider.SUPPORTED_MODELS["gemini-2.5-flash"].aliases
        assert "pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro"].aliases
        assert "flash-2.0" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
        assert "flash2" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
        assert "flashlite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
        assert "flash-lite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases

        # Test alias resolution
        assert provider._resolve_model_name("flash") == "gemini-2.5-flash"
        assert provider._resolve_model_name("pro") == "gemini-2.5-pro"
        assert provider._resolve_model_name("flash-2.0") == "gemini-2.0-flash"
        assert provider._resolve_model_name("flash2") == "gemini-2.0-flash"
        assert provider._resolve_model_name("flashlite") == "gemini-2.0-flash-lite"

        # Test case insensitive resolution
        assert provider._resolve_model_name("Flash") == "gemini-2.5-flash"
        assert provider._resolve_model_name("PRO") == "gemini-2.5-pro"

    def test_openai_provider_aliases(self):
        """Test OpenAI provider's alias structure."""
        provider = OpenAIModelProvider("test-key")

        # Check that all models have ModelCapabilities with aliases
        for model_name, config in provider.SUPPORTED_MODELS.items():
            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"

        # Test specific aliases
        assert "mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
        assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
        assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases
        assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases
        assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
        assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases

        # Test alias resolution
        assert provider._resolve_model_name("mini") == "o4-mini"
        assert provider._resolve_model_name("o3mini") == "o3-mini"
        assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
        assert provider._resolve_model_name("o4mini") == "o4-mini"
        assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14"

        # Test case insensitive resolution
        assert provider._resolve_model_name("Mini") == "o4-mini"
        assert provider._resolve_model_name("O3MINI") == "o3-mini"

    def test_xai_provider_aliases(self):
        """Test XAI provider's alias structure."""
        provider = XAIModelProvider("test-key")

        # Check that all models have ModelCapabilities with aliases
        for model_name, config in provider.SUPPORTED_MODELS.items():
            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"

        # Test specific aliases
        assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases
        assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
        assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
        assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases

        # Test alias resolution
        assert provider._resolve_model_name("grok") == "grok-3"
        assert provider._resolve_model_name("grok3") == "grok-3"
        assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
        assert provider._resolve_model_name("grokfast") == "grok-3-fast"

        # Test case insensitive resolution
        assert provider._resolve_model_name("Grok") == "grok-3"
        assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"

    def test_dial_provider_aliases(self):
        """Test DIAL provider's alias structure."""
        provider = DIALModelProvider("test-key")

        # Check that all models have ModelCapabilities with aliases
        for model_name, config in provider.SUPPORTED_MODELS.items():
            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"

        # Test specific aliases
        assert "o3" in provider.SUPPORTED_MODELS["o3-2025-04-16"].aliases
        assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini-2025-04-16"].aliases
        assert "sonnet-4" in provider.SUPPORTED_MODELS["anthropic.claude-sonnet-4-20250514-v1:0"].aliases
        assert "opus-4" in provider.SUPPORTED_MODELS["anthropic.claude-opus-4-20250514-v1:0"].aliases
        assert "gemini-2.5-pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro-preview-05-06"].aliases

        # Test alias resolution
        assert provider._resolve_model_name("o3") == "o3-2025-04-16"
        assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
        assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
        assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"

        # Test case insensitive resolution
        assert provider._resolve_model_name("O3") == "o3-2025-04-16"
        assert provider._resolve_model_name("SONNET-4") == "anthropic.claude-sonnet-4-20250514-v1:0"

    def test_list_models_includes_aliases(self):
        """Test that list_models returns both base models and aliases."""
        # Test Gemini
        gemini_provider = GeminiModelProvider("test-key")
        gemini_models = gemini_provider.list_models(respect_restrictions=False)
        assert "gemini-2.5-flash" in gemini_models
        assert "flash" in gemini_models
        assert "gemini-2.5-pro" in gemini_models
        assert "pro" in gemini_models

        # Test OpenAI
        openai_provider = OpenAIModelProvider("test-key")
        openai_models = openai_provider.list_models(respect_restrictions=False)
        assert "o4-mini" in openai_models
        assert "mini" in openai_models
        assert "o3-mini" in openai_models
        assert "o3mini" in openai_models

        # Test XAI
        xai_provider = XAIModelProvider("test-key")
        xai_models = xai_provider.list_models(respect_restrictions=False)
        assert "grok-3" in xai_models
        assert "grok" in xai_models
        assert "grok-3-fast" in xai_models
        assert "grokfast" in xai_models

        # Test DIAL
        dial_provider = DIALModelProvider("test-key")
        dial_models = dial_provider.list_models(respect_restrictions=False)
        assert "o3-2025-04-16" in dial_models
        assert "o3" in dial_models

    def test_list_all_known_models_includes_aliases(self):
        """Test that list_all_known_models returns all models and aliases in lowercase."""
        # Test Gemini
        gemini_provider = GeminiModelProvider("test-key")
        gemini_all = gemini_provider.list_all_known_models()
        assert "gemini-2.5-flash" in gemini_all
        assert "flash" in gemini_all
        assert "gemini-2.5-pro" in gemini_all
        assert "pro" in gemini_all
        # All should be lowercase
        assert all(model == model.lower() for model in gemini_all)

        # Test OpenAI
        openai_provider = OpenAIModelProvider("test-key")
        openai_all = openai_provider.list_all_known_models()
        assert "o4-mini" in openai_all
        assert "mini" in openai_all
        assert "o3-mini" in openai_all
        assert "o3mini" in openai_all
        # All should be lowercase
        assert all(model == model.lower() for model in openai_all)

    def test_no_string_shorthand_in_supported_models(self):
        """Test that no provider has string-based shorthands anymore."""
        providers = [
            GeminiModelProvider("test-key"),
            OpenAIModelProvider("test-key"),
            XAIModelProvider("test-key"),
            DIALModelProvider("test-key"),
        ]

        for provider in providers:
            for model_name, config in provider.SUPPORTED_MODELS.items():
                # All values must be ModelCapabilities objects, not strings or dicts
                from providers.base import ModelCapabilities

                assert isinstance(config, ModelCapabilities), (
                    f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
                    f"must be a ModelCapabilities object, not {type(config).__name__}"
                )

    def test_resolve_returns_original_if_not_found(self):
        """Test that _resolve_model_name returns original name if alias not found."""
        providers = [
            GeminiModelProvider("test-key"),
            OpenAIModelProvider("test-key"),
            XAIModelProvider("test-key"),
            DIALModelProvider("test-key"),
        ]

        for provider in providers:
            # Test with unknown model name
            assert provider._resolve_model_name("unknown-model") == "unknown-model"
            assert provider._resolve_model_name("gpt-4") == "gpt-4"
            assert provider._resolve_model_name("claude-3") == "claude-3"

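The resolution behavior asserted for all four providers, case-insensitive, alias-aware, and returning the input unchanged when nothing matches, comes down to a loop like this (a sketch of the presumed shared base-class logic, not copied from providers/base.py):

def _resolve_model_name(self, model_name):
    wanted = model_name.lower()
    for canonical, caps in self.SUPPORTED_MODELS.items():
        if wanted == canonical.lower() or wanted in (a.lower() for a in caps.aliases):
            return canonical
    return model_name  # unknown names pass through untouched
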
@@ -48,7 +48,13 @@ class TestWorkflowMetadata:
        """
        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
        for key in [
            "GEMINI_API_KEY",
            "OPENAI_API_KEY",
            "XAI_API_KEY",
            "OPENROUTER_API_KEY",
            "OPENROUTER_ALLOWED_MODELS",
        ]:
            original_env[key] = os.environ.get(key)

        try:
@@ -56,6 +62,7 @@ class TestWorkflowMetadata:
            os.environ.pop("GEMINI_API_KEY", None)
            os.environ.pop("OPENAI_API_KEY", None)
            os.environ.pop("XAI_API_KEY", None)
            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
            os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

            # Register OpenRouter provider
@@ -124,7 +131,13 @@ class TestWorkflowMetadata:
        """
        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
        for key in [
            "GEMINI_API_KEY",
            "OPENAI_API_KEY",
            "XAI_API_KEY",
            "OPENROUTER_API_KEY",
            "OPENROUTER_ALLOWED_MODELS",
        ]:
            original_env[key] = os.environ.get(key)

        try:
@@ -132,6 +145,7 @@ class TestWorkflowMetadata:
            os.environ.pop("GEMINI_API_KEY", None)
            os.environ.pop("OPENAI_API_KEY", None)
            os.environ.pop("XAI_API_KEY", None)
            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
            os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

            # Register OpenRouter provider
@@ -182,43 +196,60 @@ class TestWorkflowMetadata:
        """
        Test that workflow tools handle metadata gracefully when model context is missing.
        """
        # Create debug tool
        debug_tool = DebugIssueTool()
        # Save original environment
        original_env = {}
        for key in ["OPENROUTER_ALLOWED_MODELS"]:
            original_env[key] = os.environ.get(key)

        # Create arguments without model context (fallback scenario)
        arguments = {
            "step": "Test step without model context",
            "step_number": 1,
            "total_steps": 1,
            "next_step_required": False,
            "findings": "Test findings",
            "model": "flash",
            "confidence": "low",
            # No _model_context or _resolved_model_name
        }
        try:
            # Clear any restrictions
            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)

        # Execute the workflow tool
        import asyncio
            # Create debug tool
            debug_tool = DebugIssueTool()

        result = asyncio.run(debug_tool.execute_workflow(arguments))
            # Create arguments without model context (fallback scenario)
            arguments = {
                "step": "Test step without model context",
                "step_number": 1,
                "total_steps": 1,
                "next_step_required": False,
                "findings": "Test findings",
                "model": "flash",
                "confidence": "low",
                # No _model_context or _resolved_model_name
            }

        # Parse the JSON response
        assert len(result) == 1
        response_text = result[0].text
        response_data = json.loads(response_text)
            # Execute the workflow tool
            import asyncio

        # Verify metadata is still present with fallback values
        assert "metadata" in response_data, "Workflow response should include metadata even in fallback"
        metadata = response_data["metadata"]
            result = asyncio.run(debug_tool.execute_workflow(arguments))

        # Verify fallback metadata
        assert "tool_name" in metadata, "Fallback metadata should include tool_name"
        assert "model_used" in metadata, "Fallback metadata should include model_used"
        assert "provider_used" in metadata, "Fallback metadata should include provider_used"
            # Parse the JSON response
            assert len(result) == 1
            response_text = result[0].text
            response_data = json.loads(response_text)

        assert metadata["tool_name"] == "debug", "tool_name should be 'debug'"
        assert metadata["model_used"] == "flash", "model_used should be from request"
        assert metadata["provider_used"] == "unknown", "provider_used should be 'unknown' in fallback"
            # Verify metadata is still present with fallback values
            assert "metadata" in response_data, "Workflow response should include metadata even in fallback"
            metadata = response_data["metadata"]

            # Verify fallback metadata
            assert "tool_name" in metadata, "Fallback metadata should include tool_name"
            assert "model_used" in metadata, "Fallback metadata should include model_used"
            assert "provider_used" in metadata, "Fallback metadata should include provider_used"

            assert metadata["tool_name"] == "debug", "tool_name should be 'debug'"
            assert metadata["model_used"] == "flash", "model_used should be from request"
            assert metadata["provider_used"] == "unknown", "provider_used should be 'unknown' in fallback"

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = value

    @pytest.mark.no_mock_provider
    def test_workflow_metadata_preserves_existing_response_fields(self):
@@ -227,7 +258,13 @@ class TestWorkflowMetadata:
        """
        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
        for key in [
            "GEMINI_API_KEY",
            "OPENAI_API_KEY",
            "XAI_API_KEY",
            "OPENROUTER_API_KEY",
            "OPENROUTER_ALLOWED_MODELS",
        ]:
            original_env[key] = os.environ.get(key)

        try:
@@ -235,6 +272,7 @@ class TestWorkflowMetadata:
            os.environ.pop("GEMINI_API_KEY", None)
            os.environ.pop("OPENAI_API_KEY", None)
            os.environ.pop("XAI_API_KEY", None)
            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
            os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

            # Register OpenRouter provider

@@ -77,7 +77,7 @@ class TestXAIProvider:

        capabilities = provider.get_capabilities("grok-3")
        assert capabilities.model_name == "grok-3"
        assert capabilities.friendly_name == "X.AI"
        assert capabilities.friendly_name == "X.AI (Grok 3)"
        assert capabilities.context_window == 131_072
        assert capabilities.provider == ProviderType.XAI
        assert not capabilities.supports_extended_thinking
@@ -96,7 +96,7 @@ class TestXAIProvider:

        capabilities = provider.get_capabilities("grok-3-fast")
        assert capabilities.model_name == "grok-3-fast"
        assert capabilities.friendly_name == "X.AI"
        assert capabilities.friendly_name == "X.AI (Grok 3 Fast)"
        assert capabilities.context_window == 131_072
        assert capabilities.provider == ProviderType.XAI
        assert not capabilities.supports_extended_thinking
@@ -212,31 +212,34 @@ class TestXAIProvider:
        assert provider.FRIENDLY_NAME == "X.AI"

        capabilities = provider.get_capabilities("grok-3")
        assert capabilities.friendly_name == "X.AI"
        assert capabilities.friendly_name == "X.AI (Grok 3)"

    def test_supported_models_structure(self):
        """Test that SUPPORTED_MODELS has the correct structure."""
        provider = XAIModelProvider("test-key")

        # Check that all expected models are present
        # Check that all expected base models are present
        assert "grok-3" in provider.SUPPORTED_MODELS
        assert "grok-3-fast" in provider.SUPPORTED_MODELS
        assert "grok" in provider.SUPPORTED_MODELS
        assert "grok3" in provider.SUPPORTED_MODELS
        assert "grokfast" in provider.SUPPORTED_MODELS
        assert "grok3fast" in provider.SUPPORTED_MODELS

        # Check model configs have required fields
        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
        assert isinstance(grok3_config, dict)
        assert "context_window" in grok3_config
        assert "supports_extended_thinking" in grok3_config
        assert grok3_config["context_window"] == 131_072
        assert grok3_config["supports_extended_thinking"] is False
        from providers.base import ModelCapabilities

        # Check shortcuts point to full names
        assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
        assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"
        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
        assert isinstance(grok3_config, ModelCapabilities)
        assert hasattr(grok3_config, "context_window")
        assert hasattr(grok3_config, "supports_extended_thinking")
        assert hasattr(grok3_config, "aliases")
        assert grok3_config.context_window == 131_072
        assert grok3_config.supports_extended_thinking is False

        # Check aliases are correctly structured
        assert "grok" in grok3_config.aliases
        assert "grok3" in grok3_config.aliases

        grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
        assert "grok3fast" in grok3fast_config.aliases
        assert "grokfast" in grok3fast_config.aliases

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
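The structural change this test documents: shorthand strings no longer sit beside real entries in SUPPORTED_MODELS; aliases hang off the capability object itself. Old versus new shape, sketched (constructor arguments beyond those asserted here are assumptions):

# Old: aliases were string values pointing at the canonical key
#   SUPPORTED_MODELS = {"grok-3": {...}, "grok": "grok-3", "grokfast": "grok-3-fast"}
# New: one ModelCapabilities entry per base model, with aliases attached
SUPPORTED_MODELS = {
    "grok-3": ModelCapabilities(
        provider=ProviderType.XAI,
        model_name="grok-3",
        friendly_name="X.AI (Grok 3)",
        aliases=["grok", "grok3"],
        context_window=131_072,
        supports_extended_thinking=False,
    ),
}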