GPT-5, GPT-5-mini support
Improvements to model name resolution Improved instructions for multi-step workflows when continuation is available Improved instructions for chat tool Improved preferred model resolution, moved code from registry -> each provider Updated tests
This commit is contained in:
@@ -3,6 +3,7 @@ Test per-tool model default selection functionality
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -73,154 +74,194 @@ class TestToolModelCategories:
|
||||
class TestModelSelection:
|
||||
"""Test model selection based on tool categories."""
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up after each test to prevent state pollution."""
|
||||
ModelProviderRegistry.clear_cache()
|
||||
# Unregister all providers
|
||||
for provider_type in list(ProviderType):
|
||||
ModelProviderRegistry.unregister_provider(provider_type)
|
||||
|
||||
def test_extended_reasoning_with_openai(self):
|
||||
"""Test EXTENDED_REASONING prefers o3 when OpenAI is available."""
|
||||
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
|
||||
# Mock OpenAI models available
|
||||
mock_get_available.return_value = {
|
||||
"o3": ProviderType.OPENAI,
|
||||
"o3-mini": ProviderType.OPENAI,
|
||||
"o4-mini": ProviderType.OPENAI,
|
||||
}
|
||||
"""Test EXTENDED_REASONING with OpenAI provider."""
|
||||
# Setup with only OpenAI provider
|
||||
ModelProviderRegistry.clear_cache()
|
||||
# First unregister all providers to ensure isolation
|
||||
for provider_type in list(ProviderType):
|
||||
ModelProviderRegistry.unregister_provider(provider_type)
|
||||
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
# OpenAI prefers o3 for extended reasoning
|
||||
assert model == "o3"
|
||||
|
||||
def test_extended_reasoning_with_gemini_only(self):
|
||||
"""Test EXTENDED_REASONING prefers pro when only Gemini is available."""
|
||||
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
|
||||
# Mock only Gemini models available
|
||||
mock_get_available.return_value = {
|
||||
"gemini-2.5-pro": ProviderType.GOOGLE,
|
||||
"gemini-2.5-flash": ProviderType.GOOGLE,
|
||||
}
|
||||
# Clear cache and unregister all providers first
|
||||
ModelProviderRegistry.clear_cache()
|
||||
for provider_type in list(ProviderType):
|
||||
ModelProviderRegistry.unregister_provider(provider_type)
|
||||
|
||||
# Register only Gemini provider
|
||||
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}, clear=False):
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
# Should find the pro model for extended reasoning
|
||||
assert "pro" in model or model == "gemini-2.5-pro"
|
||||
# Gemini should return one of its models for extended reasoning
|
||||
# The default behavior may return flash when pro is not explicitly preferred
|
||||
assert model in ["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash"]
|
||||
|
||||
def test_fast_response_with_openai(self):
|
||||
"""Test FAST_RESPONSE prefers o4-mini when OpenAI is available."""
|
||||
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
|
||||
# Mock OpenAI models available
|
||||
mock_get_available.return_value = {
|
||||
"o3": ProviderType.OPENAI,
|
||||
"o3-mini": ProviderType.OPENAI,
|
||||
"o4-mini": ProviderType.OPENAI,
|
||||
}
|
||||
"""Test FAST_RESPONSE with OpenAI provider."""
|
||||
# Setup with only OpenAI provider
|
||||
ModelProviderRegistry.clear_cache()
|
||||
# First unregister all providers to ensure isolation
|
||||
for provider_type in list(ProviderType):
|
||||
ModelProviderRegistry.unregister_provider(provider_type)
|
||||
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
assert model == "o4-mini"
|
||||
# OpenAI now prefers gpt-5 for fast response (based on our new preference order)
|
||||
assert model == "gpt-5"
|
||||
|
||||
def test_fast_response_with_gemini_only(self):
|
||||
"""Test FAST_RESPONSE prefers flash when only Gemini is available."""
|
||||
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
|
||||
# Mock only Gemini models available
|
||||
mock_get_available.return_value = {
|
||||
"gemini-2.5-pro": ProviderType.GOOGLE,
|
||||
"gemini-2.5-flash": ProviderType.GOOGLE,
|
||||
}
|
||||
# Clear cache and unregister all providers first
|
||||
ModelProviderRegistry.clear_cache()
|
||||
for provider_type in list(ProviderType):
|
||||
ModelProviderRegistry.unregister_provider(provider_type)
|
||||
|
||||
# Register only Gemini provider
|
||||
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}, clear=False):
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
# Should find the flash model for fast response
|
||||
assert "flash" in model or model == "gemini-2.5-flash"
|
||||
# Gemini should return one of its models for fast response
|
||||
assert model in ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-2.5-pro"]
|
||||
|
||||
def test_balanced_category_fallback(self):
|
||||
"""Test BALANCED category uses existing logic."""
|
||||
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
|
||||
# Mock OpenAI models available
|
||||
mock_get_available.return_value = {
|
||||
"o3": ProviderType.OPENAI,
|
||||
"o3-mini": ProviderType.OPENAI,
|
||||
"o4-mini": ProviderType.OPENAI,
|
||||
}
|
||||
# Setup with only OpenAI provider
|
||||
ModelProviderRegistry.clear_cache()
|
||||
# First unregister all providers to ensure isolation
|
||||
for provider_type in list(ProviderType):
|
||||
ModelProviderRegistry.unregister_provider(provider_type)
|
||||
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
|
||||
assert model == "o4-mini" # Balanced prefers o4-mini when OpenAI available
|
||||
# OpenAI prefers gpt-5 for balanced (based on our new preference order)
|
||||
assert model == "gpt-5"
|
||||
|
||||
def test_no_category_uses_balanced_logic(self):
|
||||
"""Test that no category specified uses balanced logic."""
|
||||
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
|
||||
# Mock only Gemini models available
|
||||
mock_get_available.return_value = {
|
||||
"gemini-2.5-pro": ProviderType.GOOGLE,
|
||||
"gemini-2.5-flash": ProviderType.GOOGLE,
|
||||
}
|
||||
# Setup with only Gemini provider
|
||||
with patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"}, clear=False):
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model()
|
||||
# Should pick a reasonable default, preferring flash for balanced use
|
||||
assert "flash" in model or model == "gemini-2.5-flash"
|
||||
# Should pick flash for balanced use
|
||||
assert model == "gemini-2.5-flash"
|
||||
|
||||
|
||||
class TestFlexibleModelSelection:
|
||||
"""Test that model selection handles various naming scenarios."""
|
||||
|
||||
def test_fallback_handles_mixed_model_names(self):
|
||||
"""Test that fallback selection works with mix of full names and shorthands."""
|
||||
# Test with mix of full names and shorthands
|
||||
"""Test that fallback selection works with different providers."""
|
||||
# Test with different provider configurations
|
||||
test_cases = [
|
||||
# Case 1: Mix of OpenAI shorthands and full names
|
||||
# Case 1: OpenAI provider for extended reasoning
|
||||
{
|
||||
"available": {"o3": ProviderType.OPENAI, "o4-mini": ProviderType.OPENAI},
|
||||
"env": {"OPENAI_API_KEY": "test-key"},
|
||||
"provider_type": ProviderType.OPENAI,
|
||||
"category": ToolModelCategory.EXTENDED_REASONING,
|
||||
"expected": "o3",
|
||||
},
|
||||
# Case 2: Mix of Gemini shorthands and full names
|
||||
# Case 2: Gemini provider for fast response
|
||||
{
|
||||
"available": {
|
||||
"gemini-2.5-flash": ProviderType.GOOGLE,
|
||||
"gemini-2.5-pro": ProviderType.GOOGLE,
|
||||
},
|
||||
"env": {"GEMINI_API_KEY": "test-key"},
|
||||
"provider_type": ProviderType.GOOGLE,
|
||||
"category": ToolModelCategory.FAST_RESPONSE,
|
||||
"expected_contains": "flash",
|
||||
"expected": "gemini-2.5-flash",
|
||||
},
|
||||
# Case 3: Only shorthands available
|
||||
# Case 3: OpenAI provider for fast response
|
||||
{
|
||||
"available": {"o4-mini": ProviderType.OPENAI, "o3-mini": ProviderType.OPENAI},
|
||||
"env": {"OPENAI_API_KEY": "test-key"},
|
||||
"provider_type": ProviderType.OPENAI,
|
||||
"category": ToolModelCategory.FAST_RESPONSE,
|
||||
"expected": "o4-mini",
|
||||
"expected": "gpt-5", # Based on new preference order
|
||||
},
|
||||
]
|
||||
|
||||
for case in test_cases:
|
||||
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
|
||||
mock_get_available.return_value = case["available"]
|
||||
# Clear registry for clean test
|
||||
ModelProviderRegistry.clear_cache()
|
||||
# First unregister all providers to ensure isolation
|
||||
for provider_type in list(ProviderType):
|
||||
ModelProviderRegistry.unregister_provider(provider_type)
|
||||
|
||||
with patch.dict(os.environ, case["env"], clear=False):
|
||||
# Register the appropriate provider
|
||||
if case["provider_type"] == ProviderType.OPENAI:
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
elif case["provider_type"] == ProviderType.GOOGLE:
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(case["category"])
|
||||
|
||||
if "expected" in case:
|
||||
assert model == case["expected"], f"Failed for case: {case}"
|
||||
elif "expected_contains" in case:
|
||||
assert (
|
||||
case["expected_contains"] in model
|
||||
), f"Expected '{case['expected_contains']}' in '{model}' for case: {case}"
|
||||
assert model == case["expected"], f"Failed for case: {case}, got {model}"
|
||||
|
||||
|
||||
class TestCustomProviderFallback:
|
||||
"""Test fallback to custom/openrouter providers."""
|
||||
|
||||
@patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
|
||||
def test_extended_reasoning_custom_fallback(self, mock_find_thinking):
|
||||
"""Test EXTENDED_REASONING falls back to custom thinking model."""
|
||||
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
|
||||
# No native models available, but OpenRouter is available
|
||||
mock_get_available.return_value = {"openrouter-model": ProviderType.OPENROUTER}
|
||||
mock_find_thinking.return_value = "custom/thinking-model"
|
||||
def test_extended_reasoning_custom_fallback(self):
|
||||
"""Test EXTENDED_REASONING with custom provider."""
|
||||
# Setup with custom provider
|
||||
ModelProviderRegistry.clear_cache()
|
||||
with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:11434", "CUSTOM_API_KEY": ""}, clear=False):
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
assert model == "custom/thinking-model"
|
||||
mock_find_thinking.assert_called_once()
|
||||
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||
|
||||
@patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
|
||||
def test_extended_reasoning_final_fallback(self, mock_find_thinking):
|
||||
"""Test EXTENDED_REASONING falls back to pro when no custom found."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# No providers available
|
||||
mock_get_provider.return_value = None
|
||||
mock_find_thinking.return_value = None
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
|
||||
if provider:
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
# Should get a model from custom provider
|
||||
assert model is not None
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
assert model == "gemini-2.5-pro"
|
||||
def test_extended_reasoning_final_fallback(self):
|
||||
"""Test EXTENDED_REASONING falls back to default when no providers."""
|
||||
# Clear all providers
|
||||
ModelProviderRegistry.clear_cache()
|
||||
for provider_type in list(
|
||||
ModelProviderRegistry._instance._providers.keys() if ModelProviderRegistry._instance else []
|
||||
):
|
||||
ModelProviderRegistry.unregister_provider(provider_type)
|
||||
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
# Should fall back to hardcoded default
|
||||
assert model == "gemini-2.5-flash"
|
||||
|
||||
|
||||
class TestAutoModeErrorMessages:
|
||||
@@ -266,42 +307,45 @@ class TestAutoModeErrorMessages:
|
||||
class TestProviderHelperMethods:
|
||||
"""Test the helper methods for finding models from custom/openrouter."""
|
||||
|
||||
def test_find_extended_thinking_model_custom(self):
|
||||
"""Test finding thinking model from custom provider."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
def test_extended_reasoning_with_custom_provider(self):
|
||||
"""Test extended reasoning model selection with custom provider."""
|
||||
# Setup with custom provider
|
||||
with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:11434", "CUSTOM_API_KEY": ""}, clear=False):
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
# Mock custom provider with thinking model
|
||||
mock_custom = MagicMock(spec=CustomProvider)
|
||||
mock_custom.model_registry = {
|
||||
"model1": {"supports_extended_thinking": False},
|
||||
"model2": {"supports_extended_thinking": True},
|
||||
"model3": {"supports_extended_thinking": False},
|
||||
}
|
||||
mock_get_provider.side_effect = lambda ptype: mock_custom if ptype == ProviderType.CUSTOM else None
|
||||
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
|
||||
|
||||
model = ModelProviderRegistry._find_extended_thinking_model()
|
||||
assert model == "model2"
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
|
||||
if provider:
|
||||
# Custom provider should return a model for extended reasoning
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
assert model is not None
|
||||
|
||||
def test_find_extended_thinking_model_openrouter(self):
|
||||
"""Test finding thinking model from openrouter."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock openrouter provider
|
||||
mock_openrouter = MagicMock()
|
||||
mock_openrouter.validate_model_name.side_effect = lambda m: m == "anthropic/claude-sonnet-4"
|
||||
mock_get_provider.side_effect = lambda ptype: mock_openrouter if ptype == ProviderType.OPENROUTER else None
|
||||
def test_extended_reasoning_with_openrouter(self):
|
||||
"""Test extended reasoning model selection with OpenRouter."""
|
||||
# Setup with OpenRouter provider
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}, clear=False):
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
|
||||
model = ModelProviderRegistry._find_extended_thinking_model()
|
||||
assert model == "anthropic/claude-sonnet-4"
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||
|
||||
def test_find_extended_thinking_model_none_found(self):
|
||||
"""Test when no thinking model is found."""
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# No providers available
|
||||
mock_get_provider.return_value = None
|
||||
# OpenRouter should provide a model for extended reasoning
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
# Should return first available OpenRouter model
|
||||
assert model is not None
|
||||
|
||||
model = ModelProviderRegistry._find_extended_thinking_model()
|
||||
assert model is None
|
||||
def test_fallback_when_no_providers_available(self):
|
||||
"""Test fallback when no providers are available."""
|
||||
# Clear all providers
|
||||
ModelProviderRegistry.clear_cache()
|
||||
for provider_type in list(
|
||||
ModelProviderRegistry._instance._providers.keys() if ModelProviderRegistry._instance else []
|
||||
):
|
||||
ModelProviderRegistry.unregister_provider(provider_type)
|
||||
|
||||
# Should return hardcoded fallback
|
||||
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
|
||||
assert model == "gemini-2.5-flash"
|
||||
|
||||
|
||||
class TestEffectiveAutoMode:
|
||||
|
||||
Reference in New Issue
Block a user