Native support for xAI Grok-3

Fixes for model shorthand mapping
Comprehensive auto-mode tests
Fahad
2025-06-15 12:21:44 +04:00
parent 4becd70a82
commit 6304b7af6b
24 changed files with 2278 additions and 58 deletions
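
The new tests pin down the Grok shorthand mapping: "grok" and "grok3" resolve to grok-3, "grokfast" and "grok3fast" resolve to grok-3-fast, and full names pass through unchanged. A minimal sketch of that resolution step (the table and function here are illustrative assumptions, not the actual providers/xai.py code):

GROK_ALIASES = {
    "grok": "grok-3",
    "grok3": "grok-3",
    "grokfast": "grok-3-fast",
    "grok3fast": "grok-3-fast",
}

def resolve_model_name(model_name: str) -> str:
    """Map an X.AI shorthand to its canonical model name; pass full names through."""
    return GROK_ALIASES.get(model_name.lower(), model_name)

assert resolve_model_name("grok") == "grok-3"
assert resolve_model_name("grok3fast") == "grok-3-fast"
assert resolve_model_name("grok-3-fast") == "grok-3-fast"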

View File

@@ -21,6 +21,8 @@ if "GEMINI_API_KEY" not in os.environ:
os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
if "OPENAI_API_KEY" not in os.environ:
os.environ["OPENAI_API_KEY"] = "dummy-key-for-tests"
if "XAI_API_KEY" not in os.environ:
os.environ["XAI_API_KEY"] = "dummy-key-for-tests"
# Set default model to a specific value for tests to avoid auto mode
# This prevents all tests from failing due to missing model parameter
@@ -46,10 +48,12 @@ from providers import ModelProviderRegistry # noqa: E402
from providers.base import ProviderType # noqa: E402
from providers.gemini import GeminiModelProvider # noqa: E402
from providers.openai import OpenAIModelProvider # noqa: E402
from providers.xai import XAIModelProvider # noqa: E402
# Register providers at test startup
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
@pytest.fixture
@@ -90,6 +94,18 @@ def mock_provider_availability(request, monkeypatch):
if marker:
return
# Ensure providers are registered (in case other tests cleared the registry)
from providers.base import ProviderType
registry = ModelProviderRegistry()
if ProviderType.GOOGLE not in registry._providers:
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
if ProviderType.OPENAI not in registry._providers:
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
if ProviderType.XAI not in registry._providers:
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
from unittest.mock import MagicMock
original_get_provider = ModelProviderRegistry.get_provider_for_model
@@ -119,3 +135,31 @@ def mock_provider_availability(request, monkeypatch):
return original_get_provider(model_name)
monkeypatch.setattr(ModelProviderRegistry, "get_provider_for_model", mock_get_provider_for_model)
# Also mock is_effective_auto_mode for all BaseTool instances to return False
# unless we're specifically testing auto mode behavior
from tools.base import BaseTool
def mock_is_effective_auto_mode(self):
# If this is an auto mode test file or specific auto mode test, use the real logic
test_file = request.node.fspath.basename if hasattr(request, "node") and hasattr(request.node, "fspath") else ""
test_name = request.node.name if hasattr(request, "node") else ""
# Allow auto mode for tests in auto mode files or with auto in the name
if (
"auto_mode" in test_file.lower()
or "auto" in test_name.lower()
or "intelligent_fallback" in test_file.lower()
or "per_tool_model_defaults" in test_file.lower()
):
# Call original method logic
from config import DEFAULT_MODEL
if DEFAULT_MODEL.lower() == "auto":
return True
provider = ModelProviderRegistry.get_provider_for_model(DEFAULT_MODEL)
return provider is None
# For all other tests, return False to disable auto mode
return False
monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
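
Two escape hatches are visible in this fixture: a test marked no_mock_provider skips the patching entirely, and an unmarked test keeps the real is_effective_auto_mode logic when its file or test name mentions auto (or it lives in the intelligent_fallback / per_tool_model_defaults files). A minimal illustration (hypothetical test, not part of this commit):

import pytest

@pytest.mark.no_mock_provider
def test_auto_mode_uses_real_providers():
    # The marker above bypasses the provider-availability mocking, and the
    # "auto" in the test name would keep is_effective_auto_mode real anyway.
    ...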

View File

@@ -0,0 +1,582 @@
"""Comprehensive tests for auto mode functionality across all provider combinations"""
import importlib
import os
from unittest.mock import MagicMock, patch
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from tools.analyze import AnalyzeTool
from tools.chat import ChatTool
from tools.debug import DebugIssueTool
from tools.models import ToolModelCategory
from tools.thinkdeep import ThinkDeepTool
@pytest.mark.no_mock_provider
class TestAutoModeComprehensive:
"""Test auto mode model selection across all provider combinations"""
def setup_method(self):
"""Set up clean state before each test."""
# Save original environment state for restoration
import os
self._original_default_model = os.environ.get("DEFAULT_MODEL", "")
# Clear restriction service cache
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
# Clear provider registry by resetting singleton instance
ModelProviderRegistry._instance = None
def teardown_method(self):
"""Clean up after each test."""
# Restore original DEFAULT_MODEL
import os
if self._original_default_model:
os.environ["DEFAULT_MODEL"] = self._original_default_model
elif "DEFAULT_MODEL" in os.environ:
del os.environ["DEFAULT_MODEL"]
# Reload config to pick up the restored DEFAULT_MODEL
import importlib
import config
importlib.reload(config)
# Clear restriction service cache
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
# Clear provider registry by resetting singleton instance
ModelProviderRegistry._instance = None
# Re-register providers for subsequent tests (like conftest.py does)
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.xai import XAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
@pytest.mark.parametrize(
"provider_config,expected_models",
[
# Only Gemini API available
(
{
"GEMINI_API_KEY": "real-key",
"OPENAI_API_KEY": None,
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": None,
},
{
"EXTENDED_REASONING": "gemini-2.5-pro-preview-06-05", # Pro for deep thinking
"FAST_RESPONSE": "gemini-2.5-flash-preview-05-20", # Flash for speed
"BALANCED": "gemini-2.5-flash-preview-05-20", # Flash as balanced
},
),
# Only OpenAI API available
(
{
"GEMINI_API_KEY": None,
"OPENAI_API_KEY": "real-key",
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": None,
},
{
"EXTENDED_REASONING": "o3", # O3 for deep reasoning
"FAST_RESPONSE": "o4-mini", # O4-mini for speed
"BALANCED": "o4-mini", # O4-mini as balanced
},
),
# Only X.AI API available
(
{
"GEMINI_API_KEY": None,
"OPENAI_API_KEY": None,
"XAI_API_KEY": "real-key",
"OPENROUTER_API_KEY": None,
},
{
"EXTENDED_REASONING": "grok-3", # GROK-3 for reasoning
"FAST_RESPONSE": "grok-3-fast", # GROK-3-fast for speed
"BALANCED": "grok-3", # GROK-3 as balanced
},
),
# Both Gemini and OpenAI available - should prefer based on tool category
(
{
"GEMINI_API_KEY": "real-key",
"OPENAI_API_KEY": "real-key",
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": None,
},
{
"EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning
"FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed
"BALANCED": "o4-mini", # Prefer OpenAI for balanced
},
),
# All native APIs available - should prefer based on tool category
(
{
"GEMINI_API_KEY": "real-key",
"OPENAI_API_KEY": "real-key",
"XAI_API_KEY": "real-key",
"OPENROUTER_API_KEY": None,
},
{
"EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning
"FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed
"BALANCED": "o4-mini", # Prefer OpenAI for balanced
},
),
# Only OpenRouter available - should fall back to proxy models
(
{
"GEMINI_API_KEY": None,
"OPENAI_API_KEY": None,
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": "real-key",
},
{
"EXTENDED_REASONING": "anthropic/claude-3.5-sonnet", # First preferred thinking model from OpenRouter
"FAST_RESPONSE": "anthropic/claude-3-opus", # First available OpenRouter model
"BALANCED": "anthropic/claude-3-opus", # First available OpenRouter model
},
),
],
)
def test_auto_mode_model_selection_by_provider(self, provider_config, expected_models):
"""Test that auto mode selects correct models based on available providers."""
# Set up environment with specific provider configuration
# Filter out None values and handle them separately
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
env_to_clear = [k for k, v in provider_config.items() if v is None]
with patch.dict(os.environ, env_to_set, clear=False):
# Clear the None-valued environment variables
for key in env_to_clear:
if key in os.environ:
del os.environ[key]
# Reload config to pick up auto mode
os.environ["DEFAULT_MODEL"] = "auto"
import config
importlib.reload(config)
# Register providers based on configuration
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.openrouter import OpenRouterProvider
from providers.xai import XAIModelProvider
if provider_config.get("GEMINI_API_KEY"):
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
if provider_config.get("OPENAI_API_KEY"):
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
if provider_config.get("XAI_API_KEY"):
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
if provider_config.get("OPENROUTER_API_KEY"):
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
# Test each tool category
for category_name, expected_model in expected_models.items():
category = ToolModelCategory(category_name.lower())
# Get preferred fallback model for this category
fallback_model = ModelProviderRegistry.get_preferred_fallback_model(category)
assert fallback_model == expected_model, (
f"Provider config {provider_config}: "
f"Expected {expected_model} for {category_name}, got {fallback_model}"
)
@pytest.mark.parametrize(
"tool_class,expected_category",
[
(ChatTool, ToolModelCategory.FAST_RESPONSE),
(AnalyzeTool, ToolModelCategory.EXTENDED_REASONING), # AnalyzeTool uses EXTENDED_REASONING
(DebugIssueTool, ToolModelCategory.EXTENDED_REASONING),
(ThinkDeepTool, ToolModelCategory.EXTENDED_REASONING),
],
)
def test_tool_model_categories(self, tool_class, expected_category):
"""Test that tools have the correct model categories."""
tool = tool_class()
assert tool.get_model_category() == expected_category
@pytest.mark.asyncio
async def test_auto_mode_with_gemini_only_uses_correct_models(self):
"""Test that auto mode with only Gemini uses flash for fast tools and pro for reasoning tools."""
provider_config = {
"GEMINI_API_KEY": "real-key",
"OPENAI_API_KEY": None,
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": None,
"DEFAULT_MODEL": "auto",
}
# Filter out None values to avoid patch.dict errors
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
env_to_clear = [k for k, v in provider_config.items() if v is None]
with patch.dict(os.environ, env_to_set, clear=False):
# Clear the None-valued environment variables
for key in env_to_clear:
if key in os.environ:
del os.environ[key]
import config
importlib.reload(config)
# Register only Gemini provider
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
# Mock provider to capture what model is requested
mock_provider = MagicMock()
mock_provider.generate_content.return_value = MagicMock(
content="test response", model_name="test-model", usage={"input_tokens": 10, "output_tokens": 5}
)
with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
# Test ChatTool (FAST_RESPONSE) - should prefer flash
chat_tool = ChatTool()
await chat_tool.execute({"prompt": "test", "model": "auto"}) # This should trigger auto selection
# In auto mode, the tool should get an error requiring model selection
# but the suggested model should be flash
# Reset mock for next test
ModelProviderRegistry.get_provider_for_model.reset_mock()
# Test DebugIssueTool (EXTENDED_REASONING) - should prefer pro
debug_tool = DebugIssueTool()
await debug_tool.execute({"prompt": "test error", "model": "auto"})
def test_auto_mode_schema_includes_all_available_models(self):
"""Test that auto mode schema includes all available models for user convenience."""
# Test with only Gemini available
provider_config = {
"GEMINI_API_KEY": "real-key",
"OPENAI_API_KEY": None,
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": None,
"DEFAULT_MODEL": "auto",
}
# Filter out None values to avoid patch.dict errors
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
env_to_clear = [k for k, v in provider_config.items() if v is None]
with patch.dict(os.environ, env_to_set, clear=False):
# Clear the None-valued environment variables
for key in env_to_clear:
if key in os.environ:
del os.environ[key]
import config
importlib.reload(config)
# Register only Gemini provider
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
tool = AnalyzeTool()
schema = tool.get_input_schema()
# Should have model as required field
assert "model" in schema["required"]
# Should include all model options from global config
model_schema = schema["properties"]["model"]
assert "enum" in model_schema
available_models = model_schema["enum"]
# Should include Gemini models
assert "flash" in available_models
assert "pro" in available_models
assert "gemini-2.5-flash-preview-05-20" in available_models
assert "gemini-2.5-pro-preview-06-05" in available_models
# Should also include other models (users might have OpenRouter configured)
# The schema should show all options; validation happens at runtime
assert "o3" in available_models
assert "o4-mini" in available_models
assert "grok" in available_models
assert "grok-3" in available_models
def test_auto_mode_schema_with_all_providers(self):
"""Test that auto mode schema includes models from all available providers."""
provider_config = {
"GEMINI_API_KEY": "real-key",
"OPENAI_API_KEY": "real-key",
"XAI_API_KEY": "real-key",
"OPENROUTER_API_KEY": None, # Don't include OpenRouter to avoid infinite models
"DEFAULT_MODEL": "auto",
}
# Filter out None values to avoid patch.dict errors
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
env_to_clear = [k for k, v in provider_config.items() if v is None]
with patch.dict(os.environ, env_to_set, clear=False):
# Clear the None-valued environment variables
for key in env_to_clear:
if key in os.environ:
del os.environ[key]
import config
importlib.reload(config)
# Register all native providers
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.xai import XAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
tool = AnalyzeTool()
schema = tool.get_input_schema()
model_schema = schema["properties"]["model"]
available_models = model_schema["enum"]
# Should include models from all providers
# Gemini models
assert "flash" in available_models
assert "pro" in available_models
# OpenAI models
assert "o3" in available_models
assert "o4-mini" in available_models
# XAI models
assert "grok" in available_models
assert "grok-3" in available_models
@pytest.mark.asyncio
async def test_auto_mode_model_parameter_required_error(self):
"""Test that auto mode properly requires model parameter and suggests correct model."""
provider_config = {
"GEMINI_API_KEY": "real-key",
"OPENAI_API_KEY": None,
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": None,
"DEFAULT_MODEL": "auto",
}
# Filter out None values to avoid patch.dict errors
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
env_to_clear = [k for k, v in provider_config.items() if v is None]
with patch.dict(os.environ, env_to_set, clear=False):
# Clear the None-valued environment variables
for key in env_to_clear:
if key in os.environ:
del os.environ[key]
import config
importlib.reload(config)
# Register only Gemini provider
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
# Test with ChatTool (FAST_RESPONSE category)
chat_tool = ChatTool()
result = await chat_tool.execute(
{
"prompt": "test"
# Note: no model parameter provided in auto mode
}
)
# Should get error requiring model selection
assert len(result) == 1
response_text = result[0].text
# Parse JSON response to check error
import json
response_data = json.loads(response_text)
assert response_data["status"] == "error"
assert "Model parameter is required" in response_data["content"]
assert "flash" in response_data["content"] # Should suggest flash for FAST_RESPONSE
assert "category: fast_response" in response_data["content"]
def test_model_availability_with_restrictions(self):
"""Test that auto mode respects model restrictions when selecting fallback models."""
provider_config = {
"GEMINI_API_KEY": "real-key",
"OPENAI_API_KEY": "real-key",
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": None,
"DEFAULT_MODEL": "auto",
"OPENAI_ALLOWED_MODELS": "o4-mini", # Restrict OpenAI to only o4-mini
}
# Filter out None values to avoid patch.dict errors
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
env_to_clear = [k for k, v in provider_config.items() if v is None]
with patch.dict(os.environ, env_to_set, clear=False):
# Clear the None-valued environment variables
for key in env_to_clear:
if key in os.environ:
del os.environ[key]
import config
importlib.reload(config)
# Clear restriction service to pick up new env vars
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
# Register providers
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
# Get available models - should respect restrictions
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
# Should include restricted OpenAI model
assert "o4-mini" in available_models
# Should NOT include non-restricted OpenAI models
assert "o3" not in available_models
assert "o3-mini" not in available_models
# Should still include all Gemini models (no restrictions)
assert "gemini-2.5-flash-preview-05-20" in available_models
assert "gemini-2.5-pro-preview-06-05" in available_models
def test_openrouter_fallback_when_no_native_apis(self):
"""Test that OpenRouter provides fallback models when no native APIs are available."""
provider_config = {
"GEMINI_API_KEY": None,
"OPENAI_API_KEY": None,
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": "real-key",
"DEFAULT_MODEL": "auto",
}
# Filter out None values to avoid patch.dict errors
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
env_to_clear = [k for k, v in provider_config.items() if v is None]
with patch.dict(os.environ, env_to_set, clear=False):
# Clear the None-valued environment variables
for key in env_to_clear:
if key in os.environ:
del os.environ[key]
import config
importlib.reload(config)
# Register only OpenRouter provider
from providers.openrouter import OpenRouterProvider
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
# Mock OpenRouter registry to return known models
mock_registry = MagicMock()
mock_registry.list_models.return_value = [
"google/gemini-2.5-flash-preview-05-20",
"google/gemini-2.5-pro-preview-06-05",
"openai/o3",
"openai/o4-mini",
"anthropic/claude-3-opus",
]
with patch.object(OpenRouterProvider, "_registry", mock_registry):
# Get preferred models for different categories
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
ToolModelCategory.EXTENDED_REASONING
)
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
# Should fallback to known good models even via OpenRouter
# The exact model depends on _find_extended_thinking_model implementation
assert extended_reasoning is not None
assert fast_response is not None
@pytest.mark.asyncio
async def test_actual_model_name_resolution_in_auto_mode(self):
"""Test that when a model is selected in auto mode, the tool executes successfully."""
provider_config = {
"GEMINI_API_KEY": "real-key",
"OPENAI_API_KEY": None,
"XAI_API_KEY": None,
"OPENROUTER_API_KEY": None,
"DEFAULT_MODEL": "auto",
}
# Filter out None values to avoid patch.dict errors
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
env_to_clear = [k for k, v in provider_config.items() if v is None]
with patch.dict(os.environ, env_to_set, clear=False):
# Clear the None-valued environment variables
for key in env_to_clear:
if key in os.environ:
del os.environ[key]
import config
importlib.reload(config)
# Register Gemini provider
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
# Mock the actual provider to simulate successful execution
mock_provider = MagicMock()
mock_response = MagicMock()
mock_response.content = "test response"
mock_response.model_name = "gemini-2.5-flash-preview-05-20" # The resolved name
mock_response.usage = {"input_tokens": 10, "output_tokens": 5}
# Mock _resolve_model_name to simulate alias resolution
mock_provider._resolve_model_name = lambda alias: (
"gemini-2.5-flash-preview-05-20" if alias == "flash" else alias
)
mock_provider.generate_content.return_value = mock_response
with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
chat_tool = ChatTool()
result = await chat_tool.execute({"prompt": "test", "model": "flash"}) # Use alias in auto mode
# Should succeed with proper model resolution
assert len(result) == 1
# Just verify that the tool executed successfully and didn't return an error
assert "error" not in result[0].text.lower()

View File

@@ -0,0 +1,344 @@
"""Test auto mode provider selection logic specifically"""
import os
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from tools.models import ToolModelCategory
@pytest.mark.no_mock_provider
class TestAutoModeProviderSelection:
"""Test the core auto mode provider selection logic"""
def setup_method(self):
"""Set up clean state before each test."""
# Clear restriction service cache
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
# Clear provider registry
registry = ModelProviderRegistry()
registry._providers.clear()
registry._initialized_providers.clear()
def teardown_method(self):
"""Clean up after each test."""
# Clear restriction service cache
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
def test_gemini_only_fallback_selection(self):
"""Test auto mode fallback when only Gemini is available."""
# Save original environment
original_env = {}
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
original_env[key] = os.environ.get(key)
try:
# Set up environment - only Gemini available
os.environ["GEMINI_API_KEY"] = "test-key"
for key in ["OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
os.environ.pop(key, None)
# Register only Gemini provider
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
# Test fallback selection for different categories
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
ToolModelCategory.EXTENDED_REASONING
)
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
# Should select appropriate Gemini models
assert extended_reasoning in ["gemini-2.5-pro-preview-06-05", "pro"]
assert fast_response in ["gemini-2.5-flash-preview-05-20", "flash"]
assert balanced in ["gemini-2.5-flash-preview-05-20", "flash"]
finally:
# Restore original environment
for key, value in original_env.items():
if value is not None:
os.environ[key] = value
else:
os.environ.pop(key, None)
def test_openai_only_fallback_selection(self):
"""Test auto mode fallback when only OpenAI is available."""
# Save original environment
original_env = {}
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
original_env[key] = os.environ.get(key)
try:
# Set up environment - only OpenAI available
os.environ["OPENAI_API_KEY"] = "test-key"
for key in ["GEMINI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
os.environ.pop(key, None)
# Register only OpenAI provider
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
# Test fallback selection for different categories
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
ToolModelCategory.EXTENDED_REASONING
)
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
# Should select appropriate OpenAI models
assert extended_reasoning in ["o3", "o3-mini", "o4-mini"] # Any available OpenAI model for reasoning
assert fast_response in ["o4-mini", "o3-mini"] # Prefer faster models
assert balanced in ["o4-mini", "o3-mini"] # Balanced selection
finally:
# Restore original environment
for key, value in original_env.items():
if value is not None:
os.environ[key] = value
else:
os.environ.pop(key, None)
def test_both_gemini_and_openai_priority(self):
"""Test auto mode when both Gemini and OpenAI are available."""
# Save original environment
original_env = {}
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
original_env[key] = os.environ.get(key)
try:
# Set up environment - both Gemini and OpenAI available
os.environ["GEMINI_API_KEY"] = "test-key"
os.environ["OPENAI_API_KEY"] = "test-key"
for key in ["XAI_API_KEY", "OPENROUTER_API_KEY"]:
os.environ.pop(key, None)
# Register both providers
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
# Test fallback selection for different categories
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
ToolModelCategory.EXTENDED_REASONING
)
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
# Should prefer OpenAI for reasoning (based on fallback logic)
assert extended_reasoning == "o3" # Should prefer O3 for extended reasoning
# Should prefer OpenAI for fast response
assert fast_response == "o4-mini" # Should prefer O4-mini for fast response
finally:
# Restore original environment
for key, value in original_env.items():
if value is not None:
os.environ[key] = value
else:
os.environ.pop(key, None)
def test_xai_only_fallback_selection(self):
"""Test auto mode fallback when only XAI is available."""
# Save original environment
original_env = {}
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
original_env[key] = os.environ.get(key)
try:
# Set up environment - only XAI available
os.environ["XAI_API_KEY"] = "test-key"
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY"]:
os.environ.pop(key, None)
# Register only XAI provider
from providers.xai import XAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
# Test fallback selection for different categories
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
ToolModelCategory.EXTENDED_REASONING
)
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
# Should fallback to available models or default fallbacks
# Since XAI models are not explicitly handled in fallback logic,
# it should fall back to the hardcoded defaults
assert extended_reasoning is not None
assert fast_response is not None
finally:
# Restore original environment
for key, value in original_env.items():
if value is not None:
os.environ[key] = value
else:
os.environ.pop(key, None)
def test_available_models_respects_restrictions(self):
"""Test that get_available_models respects model restrictions."""
# Save original environment
original_env = {}
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENAI_ALLOWED_MODELS"]:
original_env[key] = os.environ.get(key)
try:
# Set up environment with restrictions
os.environ["GEMINI_API_KEY"] = "test-key"
os.environ["OPENAI_API_KEY"] = "test-key"
os.environ["OPENAI_ALLOWED_MODELS"] = "o4-mini" # Only allow o4-mini
# Clear restriction service to pick up new restrictions
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
# Register both providers
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
# Get available models with restrictions
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
# Should include allowed OpenAI model
assert "o4-mini" in available_models
assert available_models["o4-mini"] == ProviderType.OPENAI
# Should NOT include restricted OpenAI models
assert "o3" not in available_models
assert "o3-mini" not in available_models
# Should include all Gemini models (no restrictions)
assert "gemini-2.5-flash-preview-05-20" in available_models
assert available_models["gemini-2.5-flash-preview-05-20"] == ProviderType.GOOGLE
finally:
# Restore original environment
for key, value in original_env.items():
if value is not None:
os.environ[key] = value
else:
os.environ.pop(key, None)
def test_model_validation_across_providers(self):
"""Test that model validation works correctly across different providers."""
# Save original environment
original_env = {}
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
original_env[key] = os.environ.get(key)
try:
# Set up all providers
os.environ["GEMINI_API_KEY"] = "test-key"
os.environ["OPENAI_API_KEY"] = "test-key"
os.environ["XAI_API_KEY"] = "test-key"
# Register all providers
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.xai import XAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
# Test model validation - each provider should handle its own models
# Gemini models
gemini_provider = ModelProviderRegistry.get_provider_for_model("flash")
assert gemini_provider is not None
assert gemini_provider.get_provider_type() == ProviderType.GOOGLE
# OpenAI models
openai_provider = ModelProviderRegistry.get_provider_for_model("o3")
assert openai_provider is not None
assert openai_provider.get_provider_type() == ProviderType.OPENAI
# XAI models
xai_provider = ModelProviderRegistry.get_provider_for_model("grok")
assert xai_provider is not None
assert xai_provider.get_provider_type() == ProviderType.XAI
# Invalid model should return None
invalid_provider = ModelProviderRegistry.get_provider_for_model("invalid-model-name")
assert invalid_provider is None
finally:
# Restore original environment
for key, value in original_env.items():
if value is not None:
os.environ[key] = value
else:
os.environ.pop(key, None)
def test_alias_resolution_before_api_calls(self):
"""Test that model aliases are resolved before being passed to providers."""
# Save original environment
original_env = {}
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
original_env[key] = os.environ.get(key)
try:
# Set up all providers
os.environ["GEMINI_API_KEY"] = "test-key"
os.environ["OPENAI_API_KEY"] = "test-key"
os.environ["XAI_API_KEY"] = "test-key"
# Register all providers
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.xai import XAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
# Test that providers resolve aliases correctly
test_cases = [
("flash", ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20"),
("pro", ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05"),
("mini", ProviderType.OPENAI, "o4-mini"),
("o3mini", ProviderType.OPENAI, "o3-mini"),
("grok", ProviderType.XAI, "grok-3"),
("grokfast", ProviderType.XAI, "grok-3-fast"),
]
for alias, expected_provider_type, expected_resolved_name in test_cases:
provider = ModelProviderRegistry.get_provider_for_model(alias)
assert provider is not None, f"No provider found for alias '{alias}'"
assert provider.get_provider_type() == expected_provider_type, f"Wrong provider for '{alias}'"
# Test alias resolution
resolved_name = provider._resolve_model_name(alias)
assert (
resolved_name == expected_resolved_name
), f"Alias '{alias}' should resolve to '{expected_resolved_name}', got '{resolved_name}'"
finally:
# Restore original environment
for key, value in original_env.items():
if value is not None:
os.environ[key] = value
else:
os.environ.pop(key, None)

View File

@@ -55,6 +55,8 @@ class TestClaudeContinuationOffers:
"""Test Claude continuation offer functionality"""
def setup_method(self):
# Note: Tool creation and schema generation happen here
# If providers are not registered yet, the tool might detect auto mode
self.tool = ClaudeContinuationTool()
# Set default model to avoid effective auto mode
self.tool.default_model = "gemini-2.5-flash-preview-05-20"
@@ -63,11 +65,15 @@ class TestClaudeContinuationOffers:
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_new_conversation_offers_continuation(self, mock_redis):
"""Test that new conversations offer Claude continuation opportunity"""
# Create tool AFTER providers are registered (in conftest.py fixture)
tool = ClaudeContinuationTool()
tool.default_model = "gemini-2.5-flash-preview-05-20"
mock_client = Mock()
mock_redis.return_value = mock_client
# Mock the model
with patch.object(self.tool, "get_model_provider") as mock_get_provider:
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
@@ -81,7 +87,7 @@ class TestClaudeContinuationOffers:
# Execute tool without continuation_id (new conversation)
arguments = {"prompt": "Analyze this code"}
response = await self.tool.execute(arguments)
response = await tool.execute(arguments)
# Parse response
response_data = json.loads(response[0].text)
@@ -177,10 +183,6 @@ class TestClaudeContinuationOffers:
assert len(response) == 1
response_data = json.loads(response[0].text)
# Debug output
if response_data.get("status") == "error":
print(f"Error content: {response_data.get('content')}")
assert response_data["status"] == "continuation_available"
assert response_data["content"] == "Analysis complete. The code looks good."
assert "continuation_offer" in response_data

View File

@@ -17,51 +17,93 @@ class TestIntelligentFallback:
"""Test intelligent model fallback logic"""
def setup_method(self):
"""Setup for each test - clear registry cache"""
ModelProviderRegistry.clear_cache()
"""Setup for each test - clear registry and reset providers"""
# Store original providers for restoration
registry = ModelProviderRegistry()
self._original_providers = registry._providers.copy()
self._original_initialized = registry._initialized_providers.copy()
# Clear registry completely
ModelProviderRegistry._instance = None
def teardown_method(self):
"""Cleanup after each test"""
ModelProviderRegistry.clear_cache()
"""Cleanup after each test - restore original providers"""
# Restore original registry state
registry = ModelProviderRegistry()
registry._providers.clear()
registry._initialized_providers.clear()
registry._providers.update(self._original_providers)
registry._initialized_providers.update(self._original_initialized)
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
def test_prefers_openai_o3_mini_when_available(self):
"""Test that o4-mini is preferred when OpenAI API key is available"""
ModelProviderRegistry.clear_cache()
# Register only OpenAI provider for this test
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o4-mini"
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_gemini_flash_when_openai_unavailable(self):
"""Test that gemini-2.5-flash-preview-05-20 is used when only Gemini API key is available"""
ModelProviderRegistry.clear_cache()
# Register only Gemini provider for this test
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "gemini-2.5-flash-preview-05-20"
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_openai_when_both_available(self):
"""Test that OpenAI is preferred when both API keys are available"""
ModelProviderRegistry.clear_cache()
# Register both OpenAI and Gemini providers
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o4-mini" # OpenAI has priority
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False)
def test_fallback_when_no_keys_available(self):
"""Test fallback behavior when no API keys are available"""
ModelProviderRegistry.clear_cache()
# Register providers but with no API keys available
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "gemini-2.5-flash-preview-05-20" # Default fallback
def test_available_providers_with_keys(self):
"""Test the get_available_providers_with_keys method"""
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
ModelProviderRegistry.clear_cache()
# Clear and register providers
ModelProviderRegistry._instance = None
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
available = ModelProviderRegistry.get_available_providers_with_keys()
assert ProviderType.OPENAI in available
assert ProviderType.GOOGLE not in available
with patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False):
ModelProviderRegistry.clear_cache()
# Clear and register providers
ModelProviderRegistry._instance = None
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
available = ModelProviderRegistry.get_available_providers_with_keys()
assert ProviderType.GOOGLE in available
assert ProviderType.OPENAI not in available
@@ -76,7 +118,10 @@ class TestIntelligentFallback:
patch("config.DEFAULT_MODEL", "auto"),
patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
):
ModelProviderRegistry.clear_cache()
# Register only OpenAI provider for this test
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
# Create a context with at least one turn so it doesn't exit early
from utils.conversation_memory import ConversationTurn
@@ -114,7 +159,10 @@ class TestIntelligentFallback:
patch("config.DEFAULT_MODEL", "auto"),
patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False),
):
ModelProviderRegistry.clear_cache()
# Register only Gemini provider for this test
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
from utils.conversation_memory import ConversationTurn

View File

@@ -243,10 +243,23 @@ class TestLargePromptHandling:
tool = ChatTool()
exact_prompt = "x" * MCP_PROMPT_SIZE_LIMIT
# With the fix, this should now pass because we check at the MCP transport boundary before adding internal content
result = await tool.execute({"prompt": exact_prompt})
output = json.loads(result[0].text)
assert output["status"] == "success"
# Mock the model provider to avoid real API calls
with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = MagicMock()
mock_provider.get_provider_type.return_value = MagicMock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = MagicMock(
content="Response to the large prompt",
usage={"input_tokens": 12000, "output_tokens": 10, "total_tokens": 12010},
model_name="gemini-2.5-flash-preview-05-20",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
# With the fix, this should now pass because we check at the MCP transport boundary before adding internal content
result = await tool.execute({"prompt": exact_prompt})
output = json.loads(result[0].text)
assert output["status"] == "success"
@pytest.mark.asyncio
async def test_boundary_case_just_over_limit(self):

View File

@@ -535,18 +535,38 @@ class TestAutoModeWithRestrictions:
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GEMINI_API_KEY": "", "OPENAI_API_KEY": "test-key"})
def test_fallback_with_shorthand_restrictions(self):
"""Test fallback model selection with shorthand restrictions."""
# Clear caches
# Clear caches and reset registry
import utils.model_restrictions
from providers.registry import ModelProviderRegistry
from tools.models import ToolModelCategory
utils.model_restrictions._restriction_service = None
ModelProviderRegistry.clear_cache()
# Even with "mini" restriction, fallback should work if provider handles it correctly
# This tests the real-world scenario
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
# Store original providers for restoration
registry = ModelProviderRegistry()
original_providers = registry._providers.copy()
original_initialized = registry._initialized_providers.copy()
# The fallback will depend on how get_available_models handles aliases
# For now, we accept either behavior and document it
assert model in ["o4-mini", "gemini-2.5-flash-preview-05-20"]
try:
# Clear registry and register only OpenAI and Gemini providers
ModelProviderRegistry._instance = None
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
# Even with "mini" restriction, fallback should work if provider handles it correctly
# This tests the real-world scenario
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
# The fallback will depend on how get_available_models handles aliases
# For now, we accept either behavior and document it
assert model in ["o4-mini", "gemini-2.5-flash-preview-05-20"]
finally:
# Restore original registry state
registry = ModelProviderRegistry()
registry._providers.clear()
registry._initialized_providers.clear()
registry._providers.update(original_providers)
registry._initialized_providers.update(original_initialized)

View File

@@ -0,0 +1,221 @@
"""Tests for OpenAI provider implementation."""
import os
from unittest.mock import MagicMock, patch
from providers.base import ProviderType
from providers.openai import OpenAIModelProvider
class TestOpenAIProvider:
"""Test OpenAI provider functionality."""
def setup_method(self):
"""Set up clean state before each test."""
# Clear restriction service cache before each test
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
def teardown_method(self):
"""Clean up after each test to avoid singleton issues."""
# Clear restriction service cache after each test
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
@patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"})
def test_initialization(self):
"""Test provider initialization."""
provider = OpenAIModelProvider("test-key")
assert provider.api_key == "test-key"
assert provider.get_provider_type() == ProviderType.OPENAI
assert provider.base_url == "https://api.openai.com/v1"
def test_initialization_with_custom_url(self):
"""Test provider initialization with custom base URL."""
provider = OpenAIModelProvider("test-key", base_url="https://custom.openai.com/v1")
assert provider.api_key == "test-key"
assert provider.base_url == "https://custom.openai.com/v1"
def test_model_validation(self):
"""Test model name validation."""
provider = OpenAIModelProvider("test-key")
# Test valid models
assert provider.validate_model_name("o3") is True
assert provider.validate_model_name("o3-mini") is True
assert provider.validate_model_name("o3-pro") is True
assert provider.validate_model_name("o4-mini") is True
assert provider.validate_model_name("o4-mini-high") is True
# Test valid aliases
assert provider.validate_model_name("mini") is True
assert provider.validate_model_name("o3mini") is True
assert provider.validate_model_name("o4mini") is True
assert provider.validate_model_name("o4minihigh") is True
assert provider.validate_model_name("o4minihi") is True
# Test invalid model
assert provider.validate_model_name("invalid-model") is False
assert provider.validate_model_name("gpt-4") is False
assert provider.validate_model_name("gemini-pro") is False
def test_resolve_model_name(self):
"""Test model name resolution."""
provider = OpenAIModelProvider("test-key")
# Test shorthand resolution
assert provider._resolve_model_name("mini") == "o4-mini"
assert provider._resolve_model_name("o3mini") == "o3-mini"
assert provider._resolve_model_name("o4mini") == "o4-mini"
assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
assert provider._resolve_model_name("o4minihi") == "o4-mini-high"
# Test full name passthrough
assert provider._resolve_model_name("o3") == "o3"
assert provider._resolve_model_name("o3-mini") == "o3-mini"
assert provider._resolve_model_name("o3-pro") == "o3-pro"
assert provider._resolve_model_name("o4-mini") == "o4-mini"
assert provider._resolve_model_name("o4-mini-high") == "o4-mini-high"
def test_get_capabilities_o3(self):
"""Test getting model capabilities for O3."""
provider = OpenAIModelProvider("test-key")
capabilities = provider.get_capabilities("o3")
assert capabilities.model_name == "o3" # Should NOT be resolved in capabilities
assert capabilities.friendly_name == "OpenAI"
assert capabilities.context_window == 200_000
assert capabilities.provider == ProviderType.OPENAI
assert not capabilities.supports_extended_thinking
assert capabilities.supports_system_prompts is True
assert capabilities.supports_streaming is True
assert capabilities.supports_function_calling is True
# Test temperature constraint (O3 has fixed temperature)
assert capabilities.temperature_constraint.value == 1.0
def test_get_capabilities_with_alias(self):
"""Test getting model capabilities with alias resolves correctly."""
provider = OpenAIModelProvider("test-key")
capabilities = provider.get_capabilities("mini")
assert capabilities.model_name == "mini" # Capabilities should show original request
assert capabilities.friendly_name == "OpenAI"
assert capabilities.context_window == 200_000
assert capabilities.provider == ProviderType.OPENAI
@patch("providers.openai_compatible.OpenAI")
def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
"""Test that generate_content resolves aliases before making API calls.
This is the CRITICAL test that was missing - verifying that aliases
like 'mini' get resolved to 'o4-mini' before being sent to the OpenAI API.
"""
# Set up mock OpenAI client
mock_client = MagicMock()
mock_openai_class.return_value = mock_client
# Mock the completion response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test response"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "o4-mini" # API returns the resolved model name
mock_response.id = "test-id"
mock_response.created = 1234567890
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
mock_response.usage.total_tokens = 15
mock_client.chat.completions.create.return_value = mock_response
provider = OpenAIModelProvider("test-key")
# Call generate_content with alias 'mini'
result = provider.generate_content(
prompt="Test prompt", model_name="mini", temperature=1.0 # This should be resolved to "o4-mini"
)
# Verify the API was called with the RESOLVED model name
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args[1]
# CRITICAL ASSERTION: The API should receive "o4-mini", not "mini"
assert call_kwargs["model"] == "o4-mini", f"Expected 'o4-mini' but API received '{call_kwargs['model']}'"
# Verify other parameters
assert call_kwargs["temperature"] == 1.0
assert len(call_kwargs["messages"]) == 1
assert call_kwargs["messages"][0]["role"] == "user"
assert call_kwargs["messages"][0]["content"] == "Test prompt"
# Verify response
assert result.content == "Test response"
assert result.model_name == "o4-mini" # Should be the resolved name
@patch("providers.openai_compatible.OpenAI")
def test_generate_content_other_aliases(self, mock_openai_class):
"""Test other alias resolutions in generate_content."""
# Set up mock
mock_client = MagicMock()
mock_openai_class.return_value = mock_client
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test response"
mock_response.choices[0].finish_reason = "stop"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
mock_response.usage.total_tokens = 15
mock_client.chat.completions.create.return_value = mock_response
provider = OpenAIModelProvider("test-key")
# Test o3mini -> o3-mini
mock_response.model = "o3-mini"
provider.generate_content(prompt="Test", model_name="o3mini", temperature=1.0)
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "o3-mini"
# Test o4minihigh -> o4-mini-high
mock_response.model = "o4-mini-high"
provider.generate_content(prompt="Test", model_name="o4minihigh", temperature=1.0)
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "o4-mini-high"
@patch("providers.openai_compatible.OpenAI")
def test_generate_content_no_alias_passthrough(self, mock_openai_class):
"""Test that full model names pass through unchanged."""
# Set up mock
mock_client = MagicMock()
mock_openai_class.return_value = mock_client
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test response"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "o3-pro"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
mock_response.usage.total_tokens = 15
mock_client.chat.completions.create.return_value = mock_response
provider = OpenAIModelProvider("test-key")
# Test full model name passes through unchanged
provider.generate_content(prompt="Test", model_name="o3-pro", temperature=1.0)
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "o3-pro" # Should be unchanged
def test_supports_thinking_mode(self):
"""Test thinking mode support (currently False for all OpenAI models)."""
provider = OpenAIModelProvider("test-key")
# All OpenAI models currently don't support thinking mode
assert provider.supports_thinking_mode("o3") is False
assert provider.supports_thinking_mode("o3-mini") is False
assert provider.supports_thinking_mode("o4-mini") is False
assert provider.supports_thinking_mode("mini") is False # Test with alias too

View File

@@ -202,9 +202,9 @@ class TestCustomProviderFallback:
@patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
def test_extended_reasoning_custom_fallback(self, mock_find_thinking):
"""Test EXTENDED_REASONING falls back to custom thinking model."""
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
# No native providers available
mock_get_provider.return_value = None
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
# No native models available, but OpenRouter is available
mock_get_available.return_value = {"openrouter-model": ProviderType.OPENROUTER}
mock_find_thinking.return_value = "custom/thinking-model"
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)

tests/test_xai_provider.py (new file, 326 lines added)
View File

@@ -0,0 +1,326 @@
"""Tests for X.AI provider implementation."""
import os
from unittest.mock import MagicMock, patch
import pytest
from providers.base import ProviderType
from providers.xai import XAIModelProvider
class TestXAIProvider:
"""Test X.AI provider functionality."""
def setup_method(self):
"""Set up clean state before each test."""
# Clear restriction service cache before each test
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
def teardown_method(self):
"""Clean up after each test to avoid singleton issues."""
# Clear restriction service cache after each test
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
@patch.dict(os.environ, {"XAI_API_KEY": "test-key"})
def test_initialization(self):
"""Test provider initialization."""
provider = XAIModelProvider("test-key")
assert provider.api_key == "test-key"
assert provider.get_provider_type() == ProviderType.XAI
assert provider.base_url == "https://api.x.ai/v1"
def test_initialization_with_custom_url(self):
"""Test provider initialization with custom base URL."""
provider = XAIModelProvider("test-key", base_url="https://custom.x.ai/v1")
assert provider.api_key == "test-key"
assert provider.base_url == "https://custom.x.ai/v1"
def test_model_validation(self):
"""Test model name validation."""
provider = XAIModelProvider("test-key")
# Test valid models
assert provider.validate_model_name("grok-3") is True
assert provider.validate_model_name("grok-3-fast") is True
assert provider.validate_model_name("grok") is True
assert provider.validate_model_name("grok3") is True
assert provider.validate_model_name("grokfast") is True
assert provider.validate_model_name("grok3fast") is True
# Test invalid model
assert provider.validate_model_name("invalid-model") is False
assert provider.validate_model_name("gpt-4") is False
assert provider.validate_model_name("gemini-pro") is False
def test_resolve_model_name(self):
"""Test model name resolution."""
provider = XAIModelProvider("test-key")
# Test shorthand resolution
assert provider._resolve_model_name("grok") == "grok-3"
assert provider._resolve_model_name("grok3") == "grok-3"
assert provider._resolve_model_name("grokfast") == "grok-3-fast"
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
# Test full name passthrough
assert provider._resolve_model_name("grok-3") == "grok-3"
assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"
def test_get_capabilities_grok3(self):
"""Test getting model capabilities for GROK-3."""
provider = XAIModelProvider("test-key")
capabilities = provider.get_capabilities("grok-3")
assert capabilities.model_name == "grok-3"
assert capabilities.friendly_name == "X.AI"
assert capabilities.context_window == 131_072
assert capabilities.provider == ProviderType.XAI
assert not capabilities.supports_extended_thinking
assert capabilities.supports_system_prompts is True
assert capabilities.supports_streaming is True
assert capabilities.supports_function_calling is True
# Test temperature range
assert capabilities.temperature_constraint.min_temp == 0.0
assert capabilities.temperature_constraint.max_temp == 2.0
assert capabilities.temperature_constraint.default_temp == 0.7
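# The ModelCapabilities object asserted above is expected to expose at least:
# model_name, friendly_name, context_window, provider, the supports_* flags, and a
# temperature_constraint carrying min_temp/max_temp/default_temp.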
def test_get_capabilities_grok3_fast(self):
"""Test getting model capabilities for GROK-3 Fast."""
provider = XAIModelProvider("test-key")
capabilities = provider.get_capabilities("grok-3-fast")
assert capabilities.model_name == "grok-3-fast"
assert capabilities.friendly_name == "X.AI"
assert capabilities.context_window == 131_072
assert capabilities.provider == ProviderType.XAI
assert not capabilities.supports_extended_thinking
def test_get_capabilities_with_shorthand(self):
"""Test getting model capabilities with shorthand."""
provider = XAIModelProvider("test-key")
capabilities = provider.get_capabilities("grok")
assert capabilities.model_name == "grok-3" # Should resolve to full name
assert capabilities.context_window == 131_072
capabilities_fast = provider.get_capabilities("grokfast")
assert capabilities_fast.model_name == "grok-3-fast" # Should resolve to full name
def test_unsupported_model_capabilities(self):
"""Test error handling for unsupported models."""
provider = XAIModelProvider("test-key")
with pytest.raises(ValueError, match="Unsupported X.AI model"):
provider.get_capabilities("invalid-model")
def test_no_thinking_mode_support(self):
"""Test that X.AI models don't support thinking mode."""
provider = XAIModelProvider("test-key")
assert not provider.supports_thinking_mode("grok-3")
assert not provider.supports_thinking_mode("grok-3-fast")
assert not provider.supports_thinking_mode("grok")
assert not provider.supports_thinking_mode("grokfast")
def test_provider_type(self):
"""Test provider type identification."""
provider = XAIModelProvider("test-key")
assert provider.get_provider_type() == ProviderType.XAI
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok-3"})
def test_model_restrictions(self):
"""Test model restrictions functionality."""
# Clear cached restriction service
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = XAIModelProvider("test-key")
# grok-3 should be allowed
assert provider.validate_model_name("grok-3") is True
assert provider.validate_model_name("grok") is True # Shorthand for grok-3
# grok-3-fast should be blocked by restrictions
assert provider.validate_model_name("grok-3-fast") is False
assert provider.validate_model_name("grokfast") is False
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3-fast"})
def test_multiple_model_restrictions(self):
"""Test multiple models in restrictions."""
# Clear cached restriction service
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = XAIModelProvider("test-key")
# Shorthand "grok" should be allowed (resolves to grok-3)
assert provider.validate_model_name("grok") is True
# Full name "grok-3" should NOT be allowed (only shorthand "grok" is in restriction list)
assert provider.validate_model_name("grok-3") is False
# "grok-3-fast" should be allowed (explicitly listed)
assert provider.validate_model_name("grok-3-fast") is True
# Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
assert provider.validate_model_name("grokfast") is True
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3"})
def test_both_shorthand_and_full_name_allowed(self):
"""Test that both shorthand and full name can be allowed."""
# Clear cached restriction service
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = XAIModelProvider("test-key")
# Both shorthand and full name should be allowed
assert provider.validate_model_name("grok") is True
assert provider.validate_model_name("grok-3") is True
# Other models should not be allowed
assert provider.validate_model_name("grok-3-fast") is False
assert provider.validate_model_name("grokfast") is False
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": ""})
def test_empty_restrictions_allows_all(self):
"""Test that empty restrictions allow all models."""
# Clear cached restriction service
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = XAIModelProvider("test-key")
assert provider.validate_model_name("grok-3") is True
assert provider.validate_model_name("grok-3-fast") is True
assert provider.validate_model_name("grok") is True
assert provider.validate_model_name("grokfast") is True
def test_friendly_name(self):
"""Test friendly name constant."""
provider = XAIModelProvider("test-key")
assert provider.FRIENDLY_NAME == "X.AI"
capabilities = provider.get_capabilities("grok-3")
assert capabilities.friendly_name == "X.AI"
def test_supported_models_structure(self):
"""Test that SUPPORTED_MODELS has the correct structure."""
provider = XAIModelProvider("test-key")
# Check that all expected models are present
assert "grok-3" in provider.SUPPORTED_MODELS
assert "grok-3-fast" in provider.SUPPORTED_MODELS
assert "grok" in provider.SUPPORTED_MODELS
assert "grok3" in provider.SUPPORTED_MODELS
assert "grokfast" in provider.SUPPORTED_MODELS
assert "grok3fast" in provider.SUPPORTED_MODELS
# Check model configs have required fields
grok3_config = provider.SUPPORTED_MODELS["grok-3"]
assert isinstance(grok3_config, dict)
assert "context_window" in grok3_config
assert "supports_extended_thinking" in grok3_config
assert grok3_config["context_window"] == 131_072
assert grok3_config["supports_extended_thinking"] is False
# Check shortcuts point to full names
assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"
@patch("providers.openai_compatible.OpenAI")
def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
"""Test that generate_content resolves aliases before making API calls.
This is the CRITICAL test that ensures aliases like 'grok' get resolved
to 'grok-3' before being sent to the X.AI API.
"""
# Set up mock OpenAI client
mock_client = MagicMock()
mock_openai_class.return_value = mock_client
# Mock the completion response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test response"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "grok-3" # API returns the resolved model name
mock_response.id = "test-id"
mock_response.created = 1234567890
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
mock_response.usage.total_tokens = 15
mock_client.chat.completions.create.return_value = mock_response
provider = XAIModelProvider("test-key")
# Call generate_content with alias 'grok'
result = provider.generate_content(
prompt="Test prompt",
model_name="grok",  # This alias should be resolved to "grok-3"
temperature=0.7,
)
# Verify the API was called with the RESOLVED model name
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args[1]
# CRITICAL ASSERTION: The API should receive "grok-3", not "grok"
assert call_kwargs["model"] == "grok-3", f"Expected 'grok-3' but API received '{call_kwargs['model']}'"
# Verify other parameters
assert call_kwargs["temperature"] == 0.7
assert len(call_kwargs["messages"]) == 1
assert call_kwargs["messages"][0]["role"] == "user"
assert call_kwargs["messages"][0]["content"] == "Test prompt"
# Verify response
assert result.content == "Test response"
assert result.model_name == "grok-3" # Should be the resolved name
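# For reference, the call the mocked client is expected to receive, reconstructed
# from the assertions above (the real call may carry additional kwargs that are
# not asserted here):
#
#     client.chat.completions.create(
#         model="grok-3",
#         messages=[{"role": "user", "content": "Test prompt"}],
#         temperature=0.7,
#     )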
@patch("providers.openai_compatible.OpenAI")
def test_generate_content_other_aliases(self, mock_openai_class):
"""Test other alias resolutions in generate_content."""
# Set up mock
mock_client = MagicMock()
mock_openai_class.return_value = mock_client
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test response"
mock_response.choices[0].finish_reason = "stop"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
mock_response.usage.total_tokens = 15
mock_client.chat.completions.create.return_value = mock_response
provider = XAIModelProvider("test-key")
# Test grok3 -> grok-3
mock_response.model = "grok-3"
provider.generate_content(prompt="Test", model_name="grok3", temperature=0.7)
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "grok-3"
# Test grokfast -> grok-3-fast
mock_response.model = "grok-3-fast"
provider.generate_content(prompt="Test", model_name="grokfast", temperature=0.7)
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "grok-3-fast"
# Test grok3fast -> grok-3-fast
provider.generate_content(prompt="Test", model_name="grok3fast", temperature=0.7)
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "grok-3-fast"