Native support for xAI Grok3
Model shorthand mapping fixes; comprehensive auto-mode tests.
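For orientation, the shorthand mappings the new tests exercise amount to simple alias tables. The following is a sketch inferred from the test assertions below (illustrative names, not the providers' literal source):

# Alias tables inferred from the assertions in the tests below; illustrative only.
OPENAI_ALIASES = {
    "mini": "o4-mini",
    "o3mini": "o3-mini",
    "o4mini": "o4-mini",
    "o4minihigh": "o4-mini-high",
    "o4minihi": "o4-mini-high",
}
XAI_ALIASES = {
    "grok": "grok-3",  # Grok 3 shorthand added by this commit
    "grokfast": "grok-3-fast",
}
GEMINI_ALIASES = {
    "flash": "gemini-2.5-flash-preview-05-20",
    "pro": "gemini-2.5-pro-preview-06-05",
}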
@@ -21,6 +21,8 @@ if "GEMINI_API_KEY" not in os.environ:
    os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = "dummy-key-for-tests"
+if "XAI_API_KEY" not in os.environ:
+    os.environ["XAI_API_KEY"] = "dummy-key-for-tests"

# Set default model to a specific value for tests to avoid auto mode
# This prevents all tests from failing due to missing model parameter
@@ -46,10 +48,12 @@ from providers import ModelProviderRegistry  # noqa: E402
from providers.base import ProviderType  # noqa: E402
from providers.gemini import GeminiModelProvider  # noqa: E402
from providers.openai import OpenAIModelProvider  # noqa: E402
+from providers.xai import XAIModelProvider  # noqa: E402

# Register providers at test startup
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)


@pytest.fixture
@@ -90,6 +94,18 @@ def mock_provider_availability(request, monkeypatch):
    if marker:
        return

    # Ensure providers are registered (in case other tests cleared the registry)
    from providers.base import ProviderType

    registry = ModelProviderRegistry()

    if ProviderType.GOOGLE not in registry._providers:
        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
    if ProviderType.OPENAI not in registry._providers:
        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
    if ProviderType.XAI not in registry._providers:
        ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

    from unittest.mock import MagicMock

    original_get_provider = ModelProviderRegistry.get_provider_for_model
@@ -119,3 +135,31 @@ def mock_provider_availability(request, monkeypatch):
        return original_get_provider(model_name)

    monkeypatch.setattr(ModelProviderRegistry, "get_provider_for_model", mock_get_provider_for_model)

    # Also mock is_effective_auto_mode for all BaseTool instances to return False
    # unless we're specifically testing auto mode behavior
    from tools.base import BaseTool

    def mock_is_effective_auto_mode(self):
        # If this is an auto mode test file or specific auto mode test, use the real logic
        test_file = request.node.fspath.basename if hasattr(request, "node") and hasattr(request.node, "fspath") else ""
        test_name = request.node.name if hasattr(request, "node") else ""

        # Allow auto mode for tests in auto mode files or with auto in the name
        if (
            "auto_mode" in test_file.lower()
            or "auto" in test_name.lower()
            or "intelligent_fallback" in test_file.lower()
            or "per_tool_model_defaults" in test_file.lower()
        ):
            # Call original method logic
            from config import DEFAULT_MODEL

            if DEFAULT_MODEL.lower() == "auto":
                return True
            provider = ModelProviderRegistry.get_provider_for_model(DEFAULT_MODEL)
            return provider is None
        # For all other tests, return False to disable auto mode
        return False

    monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
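For reference, a test opts out of this fixture's mocking by carrying the marker the fixture checks for. A minimal sketch (assuming the marker is registered in the project's pytest configuration):

import pytest

@pytest.mark.no_mock_provider
class TestWithRealProviders:
    def test_uses_real_registry(self):
        # mock_provider_availability sees the marker and returns early,
        # so this test exercises the real provider lookup.
        ...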
tests/test_auto_mode_comprehensive.py (new file, 582 lines)
@@ -0,0 +1,582 @@
"""Comprehensive tests for auto mode functionality across all provider combinations"""

import importlib
import os
from unittest.mock import MagicMock, patch

import pytest

from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from tools.analyze import AnalyzeTool
from tools.chat import ChatTool
from tools.debug import DebugIssueTool
from tools.models import ToolModelCategory
from tools.thinkdeep import ThinkDeepTool


@pytest.mark.no_mock_provider
class TestAutoModeComprehensive:
    """Test auto mode model selection across all provider combinations"""

    def setup_method(self):
        """Set up clean state before each test."""
        # Save original environment state for restoration
        import os

        self._original_default_model = os.environ.get("DEFAULT_MODEL", "")

        # Clear restriction service cache
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # Clear provider registry by resetting singleton instance
        ModelProviderRegistry._instance = None

    def teardown_method(self):
        """Clean up after each test."""
        # Restore original DEFAULT_MODEL
        import os

        if self._original_default_model:
            os.environ["DEFAULT_MODEL"] = self._original_default_model
        elif "DEFAULT_MODEL" in os.environ:
            del os.environ["DEFAULT_MODEL"]

        # Reload config to pick up the restored DEFAULT_MODEL
        import importlib

        import config

        importlib.reload(config)

        # Clear restriction service cache
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # Clear provider registry by resetting singleton instance
        ModelProviderRegistry._instance = None

        # Re-register providers for subsequent tests (like conftest.py does)
        from providers.gemini import GeminiModelProvider
        from providers.openai import OpenAIModelProvider
        from providers.xai import XAIModelProvider

        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
        ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

    @pytest.mark.parametrize(
        "provider_config,expected_models",
        [
            # Only Gemini API available
            (
                {
                    "GEMINI_API_KEY": "real-key",
                    "OPENAI_API_KEY": None,
                    "XAI_API_KEY": None,
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "gemini-2.5-pro-preview-06-05",  # Pro for deep thinking
                    "FAST_RESPONSE": "gemini-2.5-flash-preview-05-20",  # Flash for speed
                    "BALANCED": "gemini-2.5-flash-preview-05-20",  # Flash as balanced
                },
            ),
            # Only OpenAI API available
            (
                {
                    "GEMINI_API_KEY": None,
                    "OPENAI_API_KEY": "real-key",
                    "XAI_API_KEY": None,
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "o3",  # O3 for deep reasoning
                    "FAST_RESPONSE": "o4-mini",  # O4-mini for speed
                    "BALANCED": "o4-mini",  # O4-mini as balanced
                },
            ),
            # Only X.AI API available
            (
                {
                    "GEMINI_API_KEY": None,
                    "OPENAI_API_KEY": None,
                    "XAI_API_KEY": "real-key",
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "grok-3",  # GROK-3 for reasoning
                    "FAST_RESPONSE": "grok-3-fast",  # GROK-3-fast for speed
                    "BALANCED": "grok-3",  # GROK-3 as balanced
                },
            ),
            # Both Gemini and OpenAI available - should prefer based on tool category
            (
                {
                    "GEMINI_API_KEY": "real-key",
                    "OPENAI_API_KEY": "real-key",
                    "XAI_API_KEY": None,
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "o3",  # Prefer O3 for deep reasoning
                    "FAST_RESPONSE": "o4-mini",  # Prefer O4-mini for speed
                    "BALANCED": "o4-mini",  # Prefer OpenAI for balanced
                },
            ),
            # All native APIs available - should prefer based on tool category
            (
                {
                    "GEMINI_API_KEY": "real-key",
                    "OPENAI_API_KEY": "real-key",
                    "XAI_API_KEY": "real-key",
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "o3",  # Prefer O3 for deep reasoning
                    "FAST_RESPONSE": "o4-mini",  # Prefer O4-mini for speed
                    "BALANCED": "o4-mini",  # Prefer OpenAI for balanced
                },
            ),
            # Only OpenRouter available - should fall back to proxy models
            (
                {
                    "GEMINI_API_KEY": None,
                    "OPENAI_API_KEY": None,
                    "XAI_API_KEY": None,
                    "OPENROUTER_API_KEY": "real-key",
                },
                {
                    "EXTENDED_REASONING": "anthropic/claude-3.5-sonnet",  # First preferred thinking model from OpenRouter
                    "FAST_RESPONSE": "anthropic/claude-3-opus",  # First available OpenRouter model
                    "BALANCED": "anthropic/claude-3-opus",  # First available OpenRouter model
                },
            ),
        ],
    )
    def test_auto_mode_model_selection_by_provider(self, provider_config, expected_models):
        """Test that auto mode selects correct models based on available providers."""

        # Set up environment with specific provider configuration
        # Filter out None values and handle them separately
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            # Reload config to pick up auto mode
            os.environ["DEFAULT_MODEL"] = "auto"
            import config

            importlib.reload(config)

            # Register providers based on configuration
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider
            from providers.openrouter import OpenRouterProvider
            from providers.xai import XAIModelProvider

            if provider_config.get("GEMINI_API_KEY"):
                ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            if provider_config.get("OPENAI_API_KEY"):
                ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            if provider_config.get("XAI_API_KEY"):
                ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
            if provider_config.get("OPENROUTER_API_KEY"):
                ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

            # Test each tool category
            for category_name, expected_model in expected_models.items():
                category = ToolModelCategory(category_name.lower())

                # Get preferred fallback model for this category
                fallback_model = ModelProviderRegistry.get_preferred_fallback_model(category)

                assert fallback_model == expected_model, (
                    f"Provider config {provider_config}: "
                    f"Expected {expected_model} for {category_name}, got {fallback_model}"
                )

    @pytest.mark.parametrize(
        "tool_class,expected_category",
        [
            (ChatTool, ToolModelCategory.FAST_RESPONSE),
            (AnalyzeTool, ToolModelCategory.EXTENDED_REASONING),  # AnalyzeTool uses EXTENDED_REASONING
            (DebugIssueTool, ToolModelCategory.EXTENDED_REASONING),
            (ThinkDeepTool, ToolModelCategory.EXTENDED_REASONING),
        ],
    )
    def test_tool_model_categories(self, tool_class, expected_category):
        """Test that tools have the correct model categories."""
        tool = tool_class()
        assert tool.get_model_category() == expected_category

    @pytest.mark.asyncio
    async def test_auto_mode_with_gemini_only_uses_correct_models(self):
        """Test that auto mode with only Gemini uses flash for fast tools and pro for reasoning tools."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register only Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Mock provider to capture what model is requested
            mock_provider = MagicMock()
            mock_provider.generate_content.return_value = MagicMock(
                content="test response", model_name="test-model", usage={"input_tokens": 10, "output_tokens": 5}
            )

            with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
                # Test ChatTool (FAST_RESPONSE) - should prefer flash
                chat_tool = ChatTool()
                await chat_tool.execute({"prompt": "test", "model": "auto"})  # This should trigger auto selection

                # In auto mode, the tool should get an error requiring model selection
                # but the suggested model should be flash

                # Reset mock for next test
                ModelProviderRegistry.get_provider_for_model.reset_mock()

                # Test DebugIssueTool (EXTENDED_REASONING) - should prefer pro
                debug_tool = DebugIssueTool()
                await debug_tool.execute({"prompt": "test error", "model": "auto"})

    def test_auto_mode_schema_includes_all_available_models(self):
        """Test that auto mode schema includes all available models for user convenience."""

        # Test with only Gemini available
        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register only Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            tool = AnalyzeTool()
            schema = tool.get_input_schema()

            # Should have model as required field
            assert "model" in schema["required"]

            # Should include all model options from global config
            model_schema = schema["properties"]["model"]
            assert "enum" in model_schema

            available_models = model_schema["enum"]

            # Should include Gemini models
            assert "flash" in available_models
            assert "pro" in available_models
            assert "gemini-2.5-flash-preview-05-20" in available_models
            assert "gemini-2.5-pro-preview-06-05" in available_models

            # Should also include other models (users might have OpenRouter configured)
            # The schema should show all options; validation happens at runtime
            assert "o3" in available_models
            assert "o4-mini" in available_models
            assert "grok" in available_models
            assert "grok-3" in available_models

    def test_auto_mode_schema_with_all_providers(self):
        """Test that auto mode schema includes models from all available providers."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": "real-key",
            "XAI_API_KEY": "real-key",
            "OPENROUTER_API_KEY": None,  # Don't include OpenRouter to avoid infinite models
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register all native providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider
            from providers.xai import XAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

            tool = AnalyzeTool()
            schema = tool.get_input_schema()

            model_schema = schema["properties"]["model"]
            available_models = model_schema["enum"]

            # Should include models from all providers
            # Gemini models
            assert "flash" in available_models
            assert "pro" in available_models

            # OpenAI models
            assert "o3" in available_models
            assert "o4-mini" in available_models

            # XAI models
            assert "grok" in available_models
            assert "grok-3" in available_models

    @pytest.mark.asyncio
    async def test_auto_mode_model_parameter_required_error(self):
        """Test that auto mode properly requires model parameter and suggests correct model."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register only Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Test with ChatTool (FAST_RESPONSE category)
            chat_tool = ChatTool()
            result = await chat_tool.execute(
                {
                    "prompt": "test"
                    # Note: no model parameter provided in auto mode
                }
            )

            # Should get error requiring model selection
            assert len(result) == 1
            response_text = result[0].text

            # Parse JSON response to check error
            import json

            response_data = json.loads(response_text)

            assert response_data["status"] == "error"
            assert "Model parameter is required" in response_data["content"]
            assert "flash" in response_data["content"]  # Should suggest flash for FAST_RESPONSE
            assert "category: fast_response" in response_data["content"]

    def test_model_availability_with_restrictions(self):
        """Test that auto mode respects model restrictions when selecting fallback models."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": "real-key",
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
            "OPENAI_ALLOWED_MODELS": "o4-mini",  # Restrict OpenAI to only o4-mini
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Clear restriction service to pick up new env vars
            import utils.model_restrictions

            utils.model_restrictions._restriction_service = None

            # Register providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

            # Get available models - should respect restrictions
            available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)

            # Should include restricted OpenAI model
            assert "o4-mini" in available_models

            # Should NOT include non-restricted OpenAI models
            assert "o3" not in available_models
            assert "o3-mini" not in available_models

            # Should still include all Gemini models (no restrictions)
            assert "gemini-2.5-flash-preview-05-20" in available_models
            assert "gemini-2.5-pro-preview-06-05" in available_models

    def test_openrouter_fallback_when_no_native_apis(self):
        """Test that OpenRouter provides fallback models when no native APIs are available."""

        provider_config = {
            "GEMINI_API_KEY": None,
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": "real-key",
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register only OpenRouter provider
            from providers.openrouter import OpenRouterProvider

            ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

            # Mock OpenRouter registry to return known models
            mock_registry = MagicMock()
            mock_registry.list_models.return_value = [
                "google/gemini-2.5-flash-preview-05-20",
                "google/gemini-2.5-pro-preview-06-05",
                "openai/o3",
                "openai/o4-mini",
                "anthropic/claude-3-opus",
            ]

            with patch.object(OpenRouterProvider, "_registry", mock_registry):
                # Get preferred models for different categories
                extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                    ToolModelCategory.EXTENDED_REASONING
                )
                fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

                # Should fall back to known good models even via OpenRouter
                # The exact model depends on _find_extended_thinking_model implementation
                assert extended_reasoning is not None
                assert fast_response is not None

    @pytest.mark.asyncio
    async def test_actual_model_name_resolution_in_auto_mode(self):
        """Test that when a model is selected in auto mode, the tool executes successfully."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Mock the actual provider to simulate successful execution
            mock_provider = MagicMock()
            mock_response = MagicMock()
            mock_response.content = "test response"
            mock_response.model_name = "gemini-2.5-flash-preview-05-20"  # The resolved name
            mock_response.usage = {"input_tokens": 10, "output_tokens": 5}
            # Mock _resolve_model_name to simulate alias resolution
            mock_provider._resolve_model_name = lambda alias: (
                "gemini-2.5-flash-preview-05-20" if alias == "flash" else alias
            )
            mock_provider.generate_content.return_value = mock_response

            with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
                chat_tool = ChatTool()
                result = await chat_tool.execute({"prompt": "test", "model": "flash"})  # Use alias in auto mode

                # Should succeed with proper model resolution
                assert len(result) == 1
                # Just verify that the tool executed successfully and didn't return an error
                assert "error" not in result[0].text.lower()
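Taken together, the parametrized expectations above pin down a priority order for auto mode: OpenAI first, then Gemini, then X.AI, with OpenRouter as the proxy fallback. A condensed sketch of the selection the tests imply (the real logic lives in providers/registry.py and may differ in detail):

def sketch_preferred_fallback(category, has_openai, has_gemini, has_xai):
    # Priority inferred from the test expectations above; not the literal registry code.
    if has_openai:
        return "o3" if category == ToolModelCategory.EXTENDED_REASONING else "o4-mini"
    if has_gemini:
        if category == ToolModelCategory.EXTENDED_REASONING:
            return "gemini-2.5-pro-preview-06-05"
        return "gemini-2.5-flash-preview-05-20"
    if has_xai:
        return "grok-3-fast" if category == ToolModelCategory.FAST_RESPONSE else "grok-3"
    # OpenRouter/proxy selection is handled separately; with no keys at all,
    # the intelligent-fallback tests below expect the flash default.
    return "gemini-2.5-flash-preview-05-20"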
tests/test_auto_mode_provider_selection.py (new file, 344 lines)
@@ -0,0 +1,344 @@
"""Test auto mode provider selection logic specifically"""

import os

import pytest

from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from tools.models import ToolModelCategory


@pytest.mark.no_mock_provider
class TestAutoModeProviderSelection:
    """Test the core auto mode provider selection logic"""

    def setup_method(self):
        """Set up clean state before each test."""
        # Clear restriction service cache
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # Clear provider registry
        registry = ModelProviderRegistry()
        registry._providers.clear()
        registry._initialized_providers.clear()

    def teardown_method(self):
        """Clean up after each test."""
        # Clear restriction service cache
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def test_gemini_only_fallback_selection(self):
        """Test auto mode fallback when only Gemini is available."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment - only Gemini available
            os.environ["GEMINI_API_KEY"] = "test-key"
            for key in ["OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
                os.environ.pop(key, None)

            # Register only Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Test fallback selection for different categories
            extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                ToolModelCategory.EXTENDED_REASONING
            )
            fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
            balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)

            # Should select appropriate Gemini models
            assert extended_reasoning in ["gemini-2.5-pro-preview-06-05", "pro"]
            assert fast_response in ["gemini-2.5-flash-preview-05-20", "flash"]
            assert balanced in ["gemini-2.5-flash-preview-05-20", "flash"]

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_openai_only_fallback_selection(self):
        """Test auto mode fallback when only OpenAI is available."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment - only OpenAI available
            os.environ["OPENAI_API_KEY"] = "test-key"
            for key in ["GEMINI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
                os.environ.pop(key, None)

            # Register only OpenAI provider
            from providers.openai import OpenAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

            # Test fallback selection for different categories
            extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                ToolModelCategory.EXTENDED_REASONING
            )
            fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
            balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)

            # Should select appropriate OpenAI models
            assert extended_reasoning in ["o3", "o3-mini", "o4-mini"]  # Any available OpenAI model for reasoning
            assert fast_response in ["o4-mini", "o3-mini"]  # Prefer faster models
            assert balanced in ["o4-mini", "o3-mini"]  # Balanced selection

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_both_gemini_and_openai_priority(self):
        """Test auto mode when both Gemini and OpenAI are available."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment - both Gemini and OpenAI available
            os.environ["GEMINI_API_KEY"] = "test-key"
            os.environ["OPENAI_API_KEY"] = "test-key"
            for key in ["XAI_API_KEY", "OPENROUTER_API_KEY"]:
                os.environ.pop(key, None)

            # Register both providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

            # Test fallback selection for different categories
            extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                ToolModelCategory.EXTENDED_REASONING
            )
            fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

            # Should prefer OpenAI for reasoning (based on fallback logic)
            assert extended_reasoning == "o3"  # Should prefer O3 for extended reasoning

            # Should prefer OpenAI for fast response
            assert fast_response == "o4-mini"  # Should prefer O4-mini for fast response

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_xai_only_fallback_selection(self):
        """Test auto mode fallback when only XAI is available."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment - only XAI available
            os.environ["XAI_API_KEY"] = "test-key"
            for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY"]:
                os.environ.pop(key, None)

            # Register only XAI provider
            from providers.xai import XAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

            # Test fallback selection for different categories
            extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                ToolModelCategory.EXTENDED_REASONING
            )
            fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

            # Should fall back to available models or default fallbacks
            # Since XAI models are not explicitly handled in fallback logic,
            # it should fall back to the hardcoded defaults
            assert extended_reasoning is not None
            assert fast_response is not None

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_available_models_respects_restrictions(self):
        """Test that get_available_models respects model restrictions."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENAI_ALLOWED_MODELS"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment with restrictions
            os.environ["GEMINI_API_KEY"] = "test-key"
            os.environ["OPENAI_API_KEY"] = "test-key"
            os.environ["OPENAI_ALLOWED_MODELS"] = "o4-mini"  # Only allow o4-mini

            # Clear restriction service to pick up new restrictions
            import utils.model_restrictions

            utils.model_restrictions._restriction_service = None

            # Register both providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

            # Get available models with restrictions
            available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)

            # Should include allowed OpenAI model
            assert "o4-mini" in available_models
            assert available_models["o4-mini"] == ProviderType.OPENAI

            # Should NOT include restricted OpenAI models
            assert "o3" not in available_models
            assert "o3-mini" not in available_models

            # Should include all Gemini models (no restrictions)
            assert "gemini-2.5-flash-preview-05-20" in available_models
            assert available_models["gemini-2.5-flash-preview-05-20"] == ProviderType.GOOGLE

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_model_validation_across_providers(self):
        """Test that model validation works correctly across different providers."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up all providers
            os.environ["GEMINI_API_KEY"] = "test-key"
            os.environ["OPENAI_API_KEY"] = "test-key"
            os.environ["XAI_API_KEY"] = "test-key"

            # Register all providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider
            from providers.xai import XAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

            # Test model validation - each provider should handle its own models
            # Gemini models
            gemini_provider = ModelProviderRegistry.get_provider_for_model("flash")
            assert gemini_provider is not None
            assert gemini_provider.get_provider_type() == ProviderType.GOOGLE

            # OpenAI models
            openai_provider = ModelProviderRegistry.get_provider_for_model("o3")
            assert openai_provider is not None
            assert openai_provider.get_provider_type() == ProviderType.OPENAI

            # XAI models
            xai_provider = ModelProviderRegistry.get_provider_for_model("grok")
            assert xai_provider is not None
            assert xai_provider.get_provider_type() == ProviderType.XAI

            # Invalid model should return None
            invalid_provider = ModelProviderRegistry.get_provider_for_model("invalid-model-name")
            assert invalid_provider is None

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_alias_resolution_before_api_calls(self):
        """Test that model aliases are resolved before being passed to providers."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up all providers
            os.environ["GEMINI_API_KEY"] = "test-key"
            os.environ["OPENAI_API_KEY"] = "test-key"
            os.environ["XAI_API_KEY"] = "test-key"

            # Register all providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider
            from providers.xai import XAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

            # Test that providers resolve aliases correctly
            test_cases = [
                ("flash", ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20"),
                ("pro", ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05"),
                ("mini", ProviderType.OPENAI, "o4-mini"),
                ("o3mini", ProviderType.OPENAI, "o3-mini"),
                ("grok", ProviderType.XAI, "grok-3"),
                ("grokfast", ProviderType.XAI, "grok-3-fast"),
            ]

            for alias, expected_provider_type, expected_resolved_name in test_cases:
                provider = ModelProviderRegistry.get_provider_for_model(alias)
                assert provider is not None, f"No provider found for alias '{alias}'"
                assert provider.get_provider_type() == expected_provider_type, f"Wrong provider for '{alias}'"

                # Test alias resolution
                resolved_name = provider._resolve_model_name(alias)
                assert (
                    resolved_name == expected_resolved_name
                ), f"Alias '{alias}' should resolve to '{expected_resolved_name}', got '{resolved_name}'"

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)
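A side note on the design: the save/try/finally environment handling repeated in each test above could equally be written with pytest's monkeypatch fixture, which undoes environment changes automatically. A sketch, not part of this commit:

def test_gemini_only_fallback_selection(monkeypatch):
    monkeypatch.setenv("GEMINI_API_KEY", "test-key")
    for key in ("OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"):
        monkeypatch.delenv(key, raising=False)
    # ... register the provider and assert as above; cleanup is automatic.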
@@ -55,6 +55,8 @@ class TestClaudeContinuationOffers:
    """Test Claude continuation offer functionality"""

    def setup_method(self):
        # Note: Tool creation and schema generation happens here
        # If providers are not registered yet, tool might detect auto mode
        self.tool = ClaudeContinuationTool()
+        # Set default model to avoid effective auto mode
+        self.tool.default_model = "gemini-2.5-flash-preview-05-20"
@@ -63,11 +65,15 @@ class TestClaudeContinuationOffers:
    @patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
    async def test_new_conversation_offers_continuation(self, mock_redis):
        """Test that new conversations offer Claude continuation opportunity"""
+        # Create tool AFTER providers are registered (in conftest.py fixture)
+        tool = ClaudeContinuationTool()
+        tool.default_model = "gemini-2.5-flash-preview-05-20"

        mock_client = Mock()
        mock_redis.return_value = mock_client

        # Mock the model
-        with patch.object(self.tool, "get_model_provider") as mock_get_provider:
+        with patch.object(tool, "get_model_provider") as mock_get_provider:
            mock_provider = create_mock_provider()
            mock_provider.get_provider_type.return_value = Mock(value="google")
            mock_provider.supports_thinking_mode.return_value = False
@@ -81,7 +87,7 @@ class TestClaudeContinuationOffers:

            # Execute tool without continuation_id (new conversation)
            arguments = {"prompt": "Analyze this code"}
-            response = await self.tool.execute(arguments)
+            response = await tool.execute(arguments)

            # Parse response
            response_data = json.loads(response[0].text)
@@ -177,10 +183,6 @@ class TestClaudeContinuationOffers:
        assert len(response) == 1
        response_data = json.loads(response[0].text)

-        # Debug output
-        if response_data.get("status") == "error":
-            print(f"Error content: {response_data.get('content')}")
-
        assert response_data["status"] == "continuation_available"
        assert response_data["content"] == "Analysis complete. The code looks good."
        assert "continuation_offer" in response_data
@@ -17,51 +17,93 @@ class TestIntelligentFallback:
    """Test intelligent model fallback logic"""

    def setup_method(self):
-        """Setup for each test - clear registry cache"""
-        ModelProviderRegistry.clear_cache()
+        """Setup for each test - clear registry and reset providers"""
+        # Store original providers for restoration
+        registry = ModelProviderRegistry()
+        self._original_providers = registry._providers.copy()
+        self._original_initialized = registry._initialized_providers.copy()
+
+        # Clear registry completely
+        ModelProviderRegistry._instance = None

    def teardown_method(self):
-        """Cleanup after each test"""
-        ModelProviderRegistry.clear_cache()
+        """Cleanup after each test - restore original providers"""
+        # Restore original registry state
+        registry = ModelProviderRegistry()
+        registry._providers.clear()
+        registry._initialized_providers.clear()
+        registry._providers.update(self._original_providers)
+        registry._initialized_providers.update(self._original_initialized)

    @patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
    def test_prefers_openai_o3_mini_when_available(self):
        """Test that o4-mini is preferred when OpenAI API key is available"""
-        ModelProviderRegistry.clear_cache()
+        # Register only OpenAI provider for this test
+        from providers.openai import OpenAIModelProvider
+
+        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

        fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
        assert fallback_model == "o4-mini"

    @patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
    def test_prefers_gemini_flash_when_openai_unavailable(self):
        """Test that gemini-2.5-flash-preview-05-20 is used when only Gemini API key is available"""
-        ModelProviderRegistry.clear_cache()
+        # Register only Gemini provider for this test
+        from providers.gemini import GeminiModelProvider
+
+        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

        fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
        assert fallback_model == "gemini-2.5-flash-preview-05-20"

    @patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
    def test_prefers_openai_when_both_available(self):
        """Test that OpenAI is preferred when both API keys are available"""
-        ModelProviderRegistry.clear_cache()
+        # Register both OpenAI and Gemini providers
+        from providers.gemini import GeminiModelProvider
+        from providers.openai import OpenAIModelProvider
+
+        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

        fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
        assert fallback_model == "o4-mini"  # OpenAI has priority

    @patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False)
    def test_fallback_when_no_keys_available(self):
        """Test fallback behavior when no API keys are available"""
-        ModelProviderRegistry.clear_cache()
+        # Register providers but with no API keys available
+        from providers.gemini import GeminiModelProvider
+        from providers.openai import OpenAIModelProvider
+
+        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

        fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
        assert fallback_model == "gemini-2.5-flash-preview-05-20"  # Default fallback

    def test_available_providers_with_keys(self):
        """Test the get_available_providers_with_keys method"""
        from providers.gemini import GeminiModelProvider
        from providers.openai import OpenAIModelProvider

        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
-            ModelProviderRegistry.clear_cache()
+            # Clear and register providers
+            ModelProviderRegistry._instance = None
+            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            available = ModelProviderRegistry.get_available_providers_with_keys()
            assert ProviderType.OPENAI in available
            assert ProviderType.GOOGLE not in available

        with patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False):
-            ModelProviderRegistry.clear_cache()
+            # Clear and register providers
+            ModelProviderRegistry._instance = None
+            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            available = ModelProviderRegistry.get_available_providers_with_keys()
            assert ProviderType.GOOGLE in available
            assert ProviderType.OPENAI not in available
@@ -76,7 +118,10 @@ class TestIntelligentFallback:
            patch("config.DEFAULT_MODEL", "auto"),
            patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
        ):
-            ModelProviderRegistry.clear_cache()
+            # Register only OpenAI provider for this test
+            from providers.openai import OpenAIModelProvider
+
+            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

            # Create a context with at least one turn so it doesn't exit early
            from utils.conversation_memory import ConversationTurn
@@ -114,7 +159,10 @@ class TestIntelligentFallback:
            patch("config.DEFAULT_MODEL", "auto"),
            patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False),
        ):
-            ModelProviderRegistry.clear_cache()
+            # Register only Gemini provider for this test
+            from providers.gemini import GeminiModelProvider
+
+            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            from utils.conversation_memory import ConversationTurn
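The store-and-restore choreography in setup_method/teardown_method above could be factored into a reusable helper. A sketch, assuming the registry's private attributes shown in the diff:

import contextlib

@contextlib.contextmanager
def preserved_registry():
    # Snapshot the singleton's provider maps, then restore them on exit.
    registry = ModelProviderRegistry()
    saved_providers = registry._providers.copy()
    saved_initialized = registry._initialized_providers.copy()
    try:
        yield registry
    finally:
        registry._providers.clear()
        registry._initialized_providers.clear()
        registry._providers.update(saved_providers)
        registry._initialized_providers.update(saved_initialized)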
@@ -243,10 +243,23 @@ class TestLargePromptHandling:
        tool = ChatTool()
        exact_prompt = "x" * MCP_PROMPT_SIZE_LIMIT

-        # With the fix, this should now pass because we check at MCP transport boundary before adding internal content
-        result = await tool.execute({"prompt": exact_prompt})
-        output = json.loads(result[0].text)
-        assert output["status"] == "success"
+        # Mock the model provider to avoid real API calls
+        with patch.object(tool, "get_model_provider") as mock_get_provider:
+            mock_provider = MagicMock()
+            mock_provider.get_provider_type.return_value = MagicMock(value="google")
+            mock_provider.supports_thinking_mode.return_value = False
+            mock_provider.generate_content.return_value = MagicMock(
+                content="Response to the large prompt",
+                usage={"input_tokens": 12000, "output_tokens": 10, "total_tokens": 12010},
+                model_name="gemini-2.5-flash-preview-05-20",
+                metadata={"finish_reason": "STOP"},
+            )
+            mock_get_provider.return_value = mock_provider
+
+            # With the fix, this should now pass because we check at MCP transport boundary before adding internal content
+            result = await tool.execute({"prompt": exact_prompt})
+            output = json.loads(result[0].text)
+            assert output["status"] == "success"

    @pytest.mark.asyncio
    async def test_boundary_case_just_over_limit(self):
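The "MCP transport boundary" these comments refer to amounts to checking the raw user prompt before any internal context is appended, so a prompt of exactly MCP_PROMPT_SIZE_LIMIT characters still passes while one extra character fails. A sketch of the check the tests pin down (hypothetical helper name, not the literal server code):

def check_mcp_transport_boundary(prompt):
    # Only the user-supplied prompt counts against the limit; internal
    # system content added later does not.
    if len(prompt) > MCP_PROMPT_SIZE_LIMIT:
        raise ValueError("Prompt exceeds MCP transport size limit")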
@@ -535,18 +535,38 @@ class TestAutoModeWithRestrictions:
    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GEMINI_API_KEY": "", "OPENAI_API_KEY": "test-key"})
    def test_fallback_with_shorthand_restrictions(self):
        """Test fallback model selection with shorthand restrictions."""
-        # Clear caches
+        # Clear caches and reset registry
        import utils.model_restrictions
        from providers.registry import ModelProviderRegistry
        from tools.models import ToolModelCategory

        utils.model_restrictions._restriction_service = None
-        ModelProviderRegistry.clear_cache()
-
-        # Even with "mini" restriction, fallback should work if provider handles it correctly
-        # This tests the real-world scenario
-        model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+        # Store original providers for restoration
+        registry = ModelProviderRegistry()
+        original_providers = registry._providers.copy()
+        original_initialized = registry._initialized_providers.copy()

-        # The fallback will depend on how get_available_models handles aliases
-        # For now, we accept either behavior and document it
-        assert model in ["o4-mini", "gemini-2.5-flash-preview-05-20"]
+        try:
+            # Clear registry and register only OpenAI and Gemini providers
+            ModelProviderRegistry._instance = None
+            from providers.gemini import GeminiModelProvider
+            from providers.openai import OpenAIModelProvider
+
+            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
+            # Even with "mini" restriction, fallback should work if provider handles it correctly
+            # This tests the real-world scenario
+            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+
+            # The fallback will depend on how get_available_models handles aliases
+            # For now, we accept either behavior and document it
+            assert model in ["o4-mini", "gemini-2.5-flash-preview-05-20"]
+        finally:
+            # Restore original registry state
+            registry = ModelProviderRegistry()
+            registry._providers.clear()
+            registry._initialized_providers.clear()
+            registry._providers.update(original_providers)
+            registry._initialized_providers.update(original_initialized)
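How OPENAI_ALLOWED_MODELS plausibly feeds the restriction service, inferred from the assertions above (the real parsing lives in utils/model_restrictions.py; the subtle part, which the test deliberately documents, is whether a restriction spelled as a shorthand like "mini" is matched before or after alias resolution):

import os

def allowed_openai_models():
    # Empty or unset means "no restriction"; otherwise a comma-separated allow-list.
    raw = os.environ.get("OPENAI_ALLOWED_MODELS", "")
    names = {name.strip() for name in raw.split(",") if name.strip()}
    return names or None

def is_openai_model_allowed(model):
    allowed = allowed_openai_models()
    return allowed is None or model in allowed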
tests/test_openai_provider.py (new file, 221 lines)
@@ -0,0 +1,221 @@
"""Tests for OpenAI provider implementation."""

import os
from unittest.mock import MagicMock, patch

from providers.base import ProviderType
from providers.openai import OpenAIModelProvider


class TestOpenAIProvider:
    """Test OpenAI provider functionality."""

    def setup_method(self):
        """Set up clean state before each test."""
        # Clear restriction service cache before each test
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def teardown_method(self):
        """Clean up after each test to avoid singleton issues."""
        # Clear restriction service cache after each test
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"})
    def test_initialization(self):
        """Test provider initialization."""
        provider = OpenAIModelProvider("test-key")
        assert provider.api_key == "test-key"
        assert provider.get_provider_type() == ProviderType.OPENAI
        assert provider.base_url == "https://api.openai.com/v1"

    def test_initialization_with_custom_url(self):
        """Test provider initialization with custom base URL."""
        provider = OpenAIModelProvider("test-key", base_url="https://custom.openai.com/v1")
        assert provider.api_key == "test-key"
        assert provider.base_url == "https://custom.openai.com/v1"

    def test_model_validation(self):
        """Test model name validation."""
        provider = OpenAIModelProvider("test-key")

        # Test valid models
        assert provider.validate_model_name("o3") is True
        assert provider.validate_model_name("o3-mini") is True
        assert provider.validate_model_name("o3-pro") is True
        assert provider.validate_model_name("o4-mini") is True
        assert provider.validate_model_name("o4-mini-high") is True

        # Test valid aliases
        assert provider.validate_model_name("mini") is True
        assert provider.validate_model_name("o3mini") is True
        assert provider.validate_model_name("o4mini") is True
        assert provider.validate_model_name("o4minihigh") is True
        assert provider.validate_model_name("o4minihi") is True

        # Test invalid models
        assert provider.validate_model_name("invalid-model") is False
        assert provider.validate_model_name("gpt-4") is False
        assert provider.validate_model_name("gemini-pro") is False

    def test_resolve_model_name(self):
        """Test model name resolution."""
        provider = OpenAIModelProvider("test-key")

        # Test shorthand resolution
        assert provider._resolve_model_name("mini") == "o4-mini"
        assert provider._resolve_model_name("o3mini") == "o3-mini"
        assert provider._resolve_model_name("o4mini") == "o4-mini"
        assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
        assert provider._resolve_model_name("o4minihi") == "o4-mini-high"

        # Test full name passthrough
        assert provider._resolve_model_name("o3") == "o3"
        assert provider._resolve_model_name("o3-mini") == "o3-mini"
        assert provider._resolve_model_name("o3-pro") == "o3-pro"
        assert provider._resolve_model_name("o4-mini") == "o4-mini"
        assert provider._resolve_model_name("o4-mini-high") == "o4-mini-high"

    def test_get_capabilities_o3(self):
        """Test getting model capabilities for O3."""
        provider = OpenAIModelProvider("test-key")

        capabilities = provider.get_capabilities("o3")
        assert capabilities.model_name == "o3"  # Should NOT be resolved in capabilities
        assert capabilities.friendly_name == "OpenAI"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.OPENAI
        assert not capabilities.supports_extended_thinking
        assert capabilities.supports_system_prompts is True
        assert capabilities.supports_streaming is True
        assert capabilities.supports_function_calling is True

        # Test temperature constraint (O3 has fixed temperature)
        assert capabilities.temperature_constraint.value == 1.0

    def test_get_capabilities_with_alias(self):
        """Test getting model capabilities with alias resolves correctly."""
        provider = OpenAIModelProvider("test-key")

        capabilities = provider.get_capabilities("mini")
        assert capabilities.model_name == "mini"  # Capabilities should show original request
        assert capabilities.friendly_name == "OpenAI"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.OPENAI

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
        """Test that generate_content resolves aliases before making API calls.

        This is the CRITICAL test that was missing - verifying that aliases
        like 'mini' get resolved to 'o4-mini' before being sent to OpenAI API.
        """
        # Set up mock OpenAI client
        mock_client = MagicMock()
        mock_openai_class.return_value = mock_client

        # Mock the completion response
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "o4-mini"  # API returns the resolved model name
        mock_response.id = "test-id"
        mock_response.created = 1234567890
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_response.usage.total_tokens = 15

        mock_client.chat.completions.create.return_value = mock_response

        provider = OpenAIModelProvider("test-key")

        # Call generate_content with alias 'mini'
        result = provider.generate_content(
            prompt="Test prompt", model_name="mini", temperature=1.0  # This should be resolved to "o4-mini"
        )

        # Verify the API was called with the RESOLVED model name
        mock_client.chat.completions.create.assert_called_once()
        call_kwargs = mock_client.chat.completions.create.call_args[1]

        # CRITICAL ASSERTION: The API should receive "o4-mini", not "mini"
        assert call_kwargs["model"] == "o4-mini", f"Expected 'o4-mini' but API received '{call_kwargs['model']}'"

        # Verify other parameters
        assert call_kwargs["temperature"] == 1.0
        assert len(call_kwargs["messages"]) == 1
        assert call_kwargs["messages"][0]["role"] == "user"
        assert call_kwargs["messages"][0]["content"] == "Test prompt"

        # Verify response
        assert result.content == "Test response"
        assert result.model_name == "o4-mini"  # Should be the resolved name
|
||||
@patch("providers.openai_compatible.OpenAI")
|
||||
def test_generate_content_other_aliases(self, mock_openai_class):
|
||||
"""Test other alias resolutions in generate_content."""
|
||||
# Set up mock
|
||||
mock_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_client
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
mock_response.usage.total_tokens = 15
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
provider = OpenAIModelProvider("test-key")
|
||||
|
||||
# Test o3mini -> o3-mini
|
||||
mock_response.model = "o3-mini"
|
||||
provider.generate_content(prompt="Test", model_name="o3mini", temperature=1.0)
|
||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||
assert call_kwargs["model"] == "o3-mini"
|
||||
|
||||
# Test o4minihigh -> o4-mini-high
|
||||
mock_response.model = "o4-mini-high"
|
||||
provider.generate_content(prompt="Test", model_name="o4minihigh", temperature=1.0)
|
||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||
assert call_kwargs["model"] == "o4-mini-high"
|
||||
|
||||
@patch("providers.openai_compatible.OpenAI")
|
||||
def test_generate_content_no_alias_passthrough(self, mock_openai_class):
|
||||
"""Test that full model names pass through unchanged."""
|
||||
# Set up mock
|
||||
mock_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_client
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "o3-pro"
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
mock_response.usage.total_tokens = 15
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
provider = OpenAIModelProvider("test-key")
|
||||
|
||||
# Test full model name passes through unchanged
|
||||
provider.generate_content(prompt="Test", model_name="o3-pro", temperature=1.0)
|
||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||
assert call_kwargs["model"] == "o3-pro" # Should be unchanged
|
||||
|
||||
def test_supports_thinking_mode(self):
|
||||
"""Test thinking mode support (currently False for all OpenAI models)."""
|
||||
provider = OpenAIModelProvider("test-key")
|
||||
|
||||
# All OpenAI models currently don't support thinking mode
|
||||
assert provider.supports_thinking_mode("o3") is False
|
||||
assert provider.supports_thinking_mode("o3-mini") is False
|
||||
assert provider.supports_thinking_mode("o4-mini") is False
|
||||
assert provider.supports_thinking_mode("mini") is False # Test with alias too
|
||||
@@ -202,9 +202,9 @@ class TestCustomProviderFallback:
    @patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
    def test_extended_reasoning_custom_fallback(self, mock_find_thinking):
        """Test EXTENDED_REASONING falls back to custom thinking model."""
        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
            # No native providers available
            mock_get_provider.return_value = None
            with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
                # No native models available, but OpenRouter is available
                mock_get_available.return_value = {"openrouter-model": ProviderType.OPENROUTER}
                mock_find_thinking.return_value = "custom/thinking-model"

                model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
326 tests/test_xai_provider.py Normal file
@@ -0,0 +1,326 @@
"""Tests for X.AI provider implementation."""

import os
from unittest.mock import MagicMock, patch

import pytest

from providers.base import ProviderType
from providers.xai import XAIModelProvider


class TestXAIProvider:
    """Test X.AI provider functionality."""

    def setup_method(self):
        """Set up clean state before each test."""
        # Clear restriction service cache before each test
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def teardown_method(self):
        """Clean up after each test to avoid singleton issues."""
        # Clear restriction service cache after each test
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"XAI_API_KEY": "test-key"})
    def test_initialization(self):
        """Test provider initialization."""
        provider = XAIModelProvider("test-key")
        assert provider.api_key == "test-key"
        assert provider.get_provider_type() == ProviderType.XAI
        assert provider.base_url == "https://api.x.ai/v1"

    def test_initialization_with_custom_url(self):
        """Test provider initialization with a custom base URL."""
        provider = XAIModelProvider("test-key", base_url="https://custom.x.ai/v1")
        assert provider.api_key == "test-key"
        assert provider.base_url == "https://custom.x.ai/v1"

    def test_model_validation(self):
        """Test model name validation."""
        provider = XAIModelProvider("test-key")

        # Test valid models
        assert provider.validate_model_name("grok-3") is True
        assert provider.validate_model_name("grok-3-fast") is True
        assert provider.validate_model_name("grok") is True
        assert provider.validate_model_name("grok3") is True
        assert provider.validate_model_name("grokfast") is True
        assert provider.validate_model_name("grok3fast") is True

        # Test invalid models
        assert provider.validate_model_name("invalid-model") is False
        assert provider.validate_model_name("gpt-4") is False
        assert provider.validate_model_name("gemini-pro") is False

    def test_resolve_model_name(self):
        """Test model name resolution."""
        provider = XAIModelProvider("test-key")

        # Test shorthand resolution
        assert provider._resolve_model_name("grok") == "grok-3"
        assert provider._resolve_model_name("grok3") == "grok-3"
        assert provider._resolve_model_name("grokfast") == "grok-3-fast"
        assert provider._resolve_model_name("grok3fast") == "grok-3-fast"

        # Test full name passthrough
        assert provider._resolve_model_name("grok-3") == "grok-3"
        assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"

    def test_get_capabilities_grok3(self):
        """Test getting model capabilities for GROK-3."""
        provider = XAIModelProvider("test-key")

        capabilities = provider.get_capabilities("grok-3")
        assert capabilities.model_name == "grok-3"
        assert capabilities.friendly_name == "X.AI"
        assert capabilities.context_window == 131_072
        assert capabilities.provider == ProviderType.XAI
        assert not capabilities.supports_extended_thinking
        assert capabilities.supports_system_prompts is True
        assert capabilities.supports_streaming is True
        assert capabilities.supports_function_calling is True

        # Test temperature range
        assert capabilities.temperature_constraint.min_temp == 0.0
        assert capabilities.temperature_constraint.max_temp == 2.0
        assert capabilities.temperature_constraint.default_temp == 0.7

    def test_get_capabilities_grok3_fast(self):
        """Test getting model capabilities for GROK-3 Fast."""
        provider = XAIModelProvider("test-key")

        capabilities = provider.get_capabilities("grok-3-fast")
        assert capabilities.model_name == "grok-3-fast"
        assert capabilities.friendly_name == "X.AI"
        assert capabilities.context_window == 131_072
        assert capabilities.provider == ProviderType.XAI
        assert not capabilities.supports_extended_thinking

    def test_get_capabilities_with_shorthand(self):
        """Test getting model capabilities with a shorthand."""
        provider = XAIModelProvider("test-key")

        capabilities = provider.get_capabilities("grok")
        assert capabilities.model_name == "grok-3"  # Should resolve to the full name
        assert capabilities.context_window == 131_072

        capabilities_fast = provider.get_capabilities("grokfast")
        assert capabilities_fast.model_name == "grok-3-fast"  # Should resolve to the full name

    def test_unsupported_model_capabilities(self):
        """Test error handling for unsupported models."""
        provider = XAIModelProvider("test-key")

        with pytest.raises(ValueError, match="Unsupported X.AI model"):
            provider.get_capabilities("invalid-model")

    def test_no_thinking_mode_support(self):
        """Test that X.AI models don't support thinking mode."""
        provider = XAIModelProvider("test-key")

        assert not provider.supports_thinking_mode("grok-3")
        assert not provider.supports_thinking_mode("grok-3-fast")
        assert not provider.supports_thinking_mode("grok")
        assert not provider.supports_thinking_mode("grokfast")

    def test_provider_type(self):
        """Test provider type identification."""
        provider = XAIModelProvider("test-key")
        assert provider.get_provider_type() == ProviderType.XAI

    @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok-3"})
    def test_model_restrictions(self):
        """Test model restrictions functionality."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = XAIModelProvider("test-key")

        # grok-3 should be allowed
        assert provider.validate_model_name("grok-3") is True
        assert provider.validate_model_name("grok") is True  # Shorthand for grok-3

        # grok-3-fast should be blocked by the restrictions
        assert provider.validate_model_name("grok-3-fast") is False
        assert provider.validate_model_name("grokfast") is False

    @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3-fast"})
    def test_multiple_model_restrictions(self):
        """Test multiple models in the restriction list."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = XAIModelProvider("test-key")

        # Shorthand "grok" should be allowed (resolves to grok-3)
        assert provider.validate_model_name("grok") is True

        # Full name "grok-3" should NOT be allowed (only the shorthand "grok" is in the restriction list)
        assert provider.validate_model_name("grok-3") is False

        # "grok-3-fast" should be allowed (explicitly listed)
        assert provider.validate_model_name("grok-3-fast") is True

        # Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
        assert provider.validate_model_name("grokfast") is True

    @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3"})
    def test_both_shorthand_and_full_name_allowed(self):
        """Test that both the shorthand and the full name can be allowed."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = XAIModelProvider("test-key")

        # Both shorthand and full name should be allowed
        assert provider.validate_model_name("grok") is True
        assert provider.validate_model_name("grok-3") is True

        # Other models should not be allowed
        assert provider.validate_model_name("grok-3-fast") is False
        assert provider.validate_model_name("grokfast") is False

    @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": ""})
    def test_empty_restrictions_allows_all(self):
        """Test that empty restrictions allow all models."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = XAIModelProvider("test-key")

        assert provider.validate_model_name("grok-3") is True
        assert provider.validate_model_name("grok-3-fast") is True
        assert provider.validate_model_name("grok") is True
        assert provider.validate_model_name("grokfast") is True

    def test_friendly_name(self):
        """Test the friendly name constant."""
        provider = XAIModelProvider("test-key")
        assert provider.FRIENDLY_NAME == "X.AI"

        capabilities = provider.get_capabilities("grok-3")
        assert capabilities.friendly_name == "X.AI"

    def test_supported_models_structure(self):
        """Test that SUPPORTED_MODELS has the correct structure."""
        provider = XAIModelProvider("test-key")

        # Check that all expected models are present
        assert "grok-3" in provider.SUPPORTED_MODELS
        assert "grok-3-fast" in provider.SUPPORTED_MODELS
        assert "grok" in provider.SUPPORTED_MODELS
        assert "grok3" in provider.SUPPORTED_MODELS
        assert "grokfast" in provider.SUPPORTED_MODELS
        assert "grok3fast" in provider.SUPPORTED_MODELS

        # Check that model configs have the required fields
        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
        assert isinstance(grok3_config, dict)
        assert "context_window" in grok3_config
        assert "supports_extended_thinking" in grok3_config
        assert grok3_config["context_window"] == 131_072
        assert grok3_config["supports_extended_thinking"] is False

        # Check that shortcuts point to full names
        assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
        assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
        """Test that generate_content resolves aliases before making API calls.

        This is the CRITICAL test that ensures aliases like 'grok' get resolved
        to 'grok-3' before being sent to the X.AI API.
        """
        # Set up mock OpenAI client
        mock_client = MagicMock()
        mock_openai_class.return_value = mock_client

        # Mock the completion response
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "grok-3"  # API returns the resolved model name
        mock_response.id = "test-id"
        mock_response.created = 1234567890
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_response.usage.total_tokens = 15

        mock_client.chat.completions.create.return_value = mock_response

        provider = XAIModelProvider("test-key")

        # Call generate_content with the alias 'grok'
        result = provider.generate_content(
            prompt="Test prompt", model_name="grok", temperature=0.7  # "grok" should be resolved to "grok-3"
        )

        # Verify the API was called with the RESOLVED model name
        mock_client.chat.completions.create.assert_called_once()
        call_kwargs = mock_client.chat.completions.create.call_args[1]

        # CRITICAL ASSERTION: The API should receive "grok-3", not "grok"
        assert call_kwargs["model"] == "grok-3", f"Expected 'grok-3' but API received '{call_kwargs['model']}'"

        # Verify other parameters
        assert call_kwargs["temperature"] == 0.7
        assert len(call_kwargs["messages"]) == 1
        assert call_kwargs["messages"][0]["role"] == "user"
        assert call_kwargs["messages"][0]["content"] == "Test prompt"

        # Verify response
        assert result.content == "Test response"
        assert result.model_name == "grok-3"  # Should be the resolved name

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_other_aliases(self, mock_openai_class):
        """Test other alias resolutions in generate_content."""
        # Set up mock (MagicMock is already imported at module level)
        mock_client = MagicMock()
        mock_openai_class.return_value = mock_client
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_response.usage.total_tokens = 15
        mock_client.chat.completions.create.return_value = mock_response

        provider = XAIModelProvider("test-key")

        # Test grok3 -> grok-3
        mock_response.model = "grok-3"
        provider.generate_content(prompt="Test", model_name="grok3", temperature=0.7)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "grok-3"

        # Test grokfast -> grok-3-fast
        mock_response.model = "grok-3-fast"
        provider.generate_content(prompt="Test", model_name="grokfast", temperature=0.7)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "grok-3-fast"

        # Test grok3fast -> grok-3-fast
        provider.generate_content(prompt="Test", model_name="grok3fast", temperature=0.7)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "grok-3-fast"