Native support for xAI Grok3
Model shorthand mapping related fixes Comprehensive auto-mode related tests
This commit is contained in:
344
tests/test_auto_mode_provider_selection.py
Normal file
344
tests/test_auto_mode_provider_selection.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""Test auto mode provider selection logic specifically"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from tools.models import ToolModelCategory
|
||||
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
class TestAutoModeProviderSelection:
|
||||
"""Test the core auto mode provider selection logic"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up clean state before each test."""
|
||||
# Clear restriction service cache
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
# Clear provider registry
|
||||
registry = ModelProviderRegistry()
|
||||
registry._providers.clear()
|
||||
registry._initialized_providers.clear()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up after each test."""
|
||||
# Clear restriction service cache
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
def test_gemini_only_fallback_selection(self):
|
||||
"""Test auto mode fallback when only Gemini is available."""
|
||||
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
# Set up environment - only Gemini available
|
||||
os.environ["GEMINI_API_KEY"] = "test-key"
|
||||
for key in ["OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Register only Gemini provider
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
# Test fallback selection for different categories
|
||||
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
|
||||
ToolModelCategory.EXTENDED_REASONING
|
||||
)
|
||||
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
|
||||
|
||||
# Should select appropriate Gemini models
|
||||
assert extended_reasoning in ["gemini-2.5-pro-preview-06-05", "pro"]
|
||||
assert fast_response in ["gemini-2.5-flash-preview-05-20", "flash"]
|
||||
assert balanced in ["gemini-2.5-flash-preview-05-20", "flash"]
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
for key, value in original_env.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
else:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_openai_only_fallback_selection(self):
|
||||
"""Test auto mode fallback when only OpenAI is available."""
|
||||
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
# Set up environment - only OpenAI available
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
for key in ["GEMINI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Register only OpenAI provider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
# Test fallback selection for different categories
|
||||
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
|
||||
ToolModelCategory.EXTENDED_REASONING
|
||||
)
|
||||
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
|
||||
|
||||
# Should select appropriate OpenAI models
|
||||
assert extended_reasoning in ["o3", "o3-mini", "o4-mini"] # Any available OpenAI model for reasoning
|
||||
assert fast_response in ["o4-mini", "o3-mini"] # Prefer faster models
|
||||
assert balanced in ["o4-mini", "o3-mini"] # Balanced selection
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
for key, value in original_env.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
else:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_both_gemini_and_openai_priority(self):
|
||||
"""Test auto mode when both Gemini and OpenAI are available."""
|
||||
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
# Set up environment - both Gemini and OpenAI available
|
||||
os.environ["GEMINI_API_KEY"] = "test-key"
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
for key in ["XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Register both providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
# Test fallback selection for different categories
|
||||
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
|
||||
ToolModelCategory.EXTENDED_REASONING
|
||||
)
|
||||
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
|
||||
# Should prefer OpenAI for reasoning (based on fallback logic)
|
||||
assert extended_reasoning == "o3" # Should prefer O3 for extended reasoning
|
||||
|
||||
# Should prefer OpenAI for fast response
|
||||
assert fast_response == "o4-mini" # Should prefer O4-mini for fast response
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
for key, value in original_env.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
else:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_xai_only_fallback_selection(self):
|
||||
"""Test auto mode fallback when only XAI is available."""
|
||||
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
# Set up environment - only XAI available
|
||||
os.environ["XAI_API_KEY"] = "test-key"
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Register only XAI provider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
|
||||
# Test fallback selection for different categories
|
||||
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
|
||||
ToolModelCategory.EXTENDED_REASONING
|
||||
)
|
||||
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
|
||||
# Should fallback to available models or default fallbacks
|
||||
# Since XAI models are not explicitly handled in fallback logic,
|
||||
# it should fall back to the hardcoded defaults
|
||||
assert extended_reasoning is not None
|
||||
assert fast_response is not None
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
for key, value in original_env.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
else:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_available_models_respects_restrictions(self):
|
||||
"""Test that get_available_models respects model restrictions."""
|
||||
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENAI_ALLOWED_MODELS"]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
# Set up environment with restrictions
|
||||
os.environ["GEMINI_API_KEY"] = "test-key"
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
os.environ["OPENAI_ALLOWED_MODELS"] = "o4-mini" # Only allow o4-mini
|
||||
|
||||
# Clear restriction service to pick up new restrictions
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
# Register both providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
# Get available models with restrictions
|
||||
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||
|
||||
# Should include allowed OpenAI model
|
||||
assert "o4-mini" in available_models
|
||||
assert available_models["o4-mini"] == ProviderType.OPENAI
|
||||
|
||||
# Should NOT include restricted OpenAI models
|
||||
assert "o3" not in available_models
|
||||
assert "o3-mini" not in available_models
|
||||
|
||||
# Should include all Gemini models (no restrictions)
|
||||
assert "gemini-2.5-flash-preview-05-20" in available_models
|
||||
assert available_models["gemini-2.5-flash-preview-05-20"] == ProviderType.GOOGLE
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
for key, value in original_env.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
else:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_model_validation_across_providers(self):
|
||||
"""Test that model validation works correctly across different providers."""
|
||||
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
# Set up all providers
|
||||
os.environ["GEMINI_API_KEY"] = "test-key"
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
os.environ["XAI_API_KEY"] = "test-key"
|
||||
|
||||
# Register all providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
|
||||
# Test model validation - each provider should handle its own models
|
||||
# Gemini models
|
||||
gemini_provider = ModelProviderRegistry.get_provider_for_model("flash")
|
||||
assert gemini_provider is not None
|
||||
assert gemini_provider.get_provider_type() == ProviderType.GOOGLE
|
||||
|
||||
# OpenAI models
|
||||
openai_provider = ModelProviderRegistry.get_provider_for_model("o3")
|
||||
assert openai_provider is not None
|
||||
assert openai_provider.get_provider_type() == ProviderType.OPENAI
|
||||
|
||||
# XAI models
|
||||
xai_provider = ModelProviderRegistry.get_provider_for_model("grok")
|
||||
assert xai_provider is not None
|
||||
assert xai_provider.get_provider_type() == ProviderType.XAI
|
||||
|
||||
# Invalid model should return None
|
||||
invalid_provider = ModelProviderRegistry.get_provider_for_model("invalid-model-name")
|
||||
assert invalid_provider is None
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
for key, value in original_env.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
else:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_alias_resolution_before_api_calls(self):
|
||||
"""Test that model aliases are resolved before being passed to providers."""
|
||||
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
# Set up all providers
|
||||
os.environ["GEMINI_API_KEY"] = "test-key"
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
os.environ["XAI_API_KEY"] = "test-key"
|
||||
|
||||
# Register all providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
|
||||
# Test that providers resolve aliases correctly
|
||||
test_cases = [
|
||||
("flash", ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20"),
|
||||
("pro", ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05"),
|
||||
("mini", ProviderType.OPENAI, "o4-mini"),
|
||||
("o3mini", ProviderType.OPENAI, "o3-mini"),
|
||||
("grok", ProviderType.XAI, "grok-3"),
|
||||
("grokfast", ProviderType.XAI, "grok-3-fast"),
|
||||
]
|
||||
|
||||
for alias, expected_provider_type, expected_resolved_name in test_cases:
|
||||
provider = ModelProviderRegistry.get_provider_for_model(alias)
|
||||
assert provider is not None, f"No provider found for alias '{alias}'"
|
||||
assert provider.get_provider_type() == expected_provider_type, f"Wrong provider for '{alias}'"
|
||||
|
||||
# Test alias resolution
|
||||
resolved_name = provider._resolve_model_name(alias)
|
||||
assert (
|
||||
resolved_name == expected_resolved_name
|
||||
), f"Alias '{alias}' should resolve to '{expected_resolved_name}', got '{resolved_name}'"
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
for key, value in original_env.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
else:
|
||||
os.environ.pop(key, None)
|
||||
Reference in New Issue
Block a user