|
|
|
|
@@ -47,17 +47,21 @@ class TestXAIProvider:
|
|
|
|
|
# Test valid models
|
|
|
|
|
assert provider.validate_model_name("grok-4") is True
|
|
|
|
|
assert provider.validate_model_name("grok4") is True
|
|
|
|
|
assert provider.validate_model_name("grok-3") is True
|
|
|
|
|
assert provider.validate_model_name("grok-3-fast") is True
|
|
|
|
|
assert provider.validate_model_name("grok") is True
|
|
|
|
|
assert provider.validate_model_name("grok3") is True
|
|
|
|
|
assert provider.validate_model_name("grokfast") is True
|
|
|
|
|
assert provider.validate_model_name("grok3fast") is True
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast") is True
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast-reasoning") is True
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast-reasoning-latest") is True
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast") is True
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast-reasoning") is True
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast-reasoning-latest") is True
|
|
|
|
|
|
|
|
|
|
# Test invalid model
|
|
|
|
|
assert provider.validate_model_name("invalid-model") is False
|
|
|
|
|
assert provider.validate_model_name("gpt-4") is False
|
|
|
|
|
assert provider.validate_model_name("gemini-pro") is False
|
|
|
|
|
assert provider.validate_model_name("grok-3") is False
|
|
|
|
|
assert provider.validate_model_name("grok-3-fast") is False
|
|
|
|
|
assert provider.validate_model_name("grokfast") is False
|
|
|
|
|
|
|
|
|
|
def test_resolve_model_name(self):
|
|
|
|
|
"""Test model name resolution."""
|
|
|
|
|
@@ -66,33 +70,12 @@ class TestXAIProvider:
|
|
|
|
|
# Test shorthand resolution
|
|
|
|
|
assert provider._resolve_model_name("grok") == "grok-4"
|
|
|
|
|
assert provider._resolve_model_name("grok4") == "grok-4"
|
|
|
|
|
assert provider._resolve_model_name("grok3") == "grok-3"
|
|
|
|
|
assert provider._resolve_model_name("grokfast") == "grok-3-fast"
|
|
|
|
|
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
|
|
|
|
|
assert provider._resolve_model_name("grok-4.1-fast-reasoning") == "grok-4-1-fast-reasoning"
|
|
|
|
|
assert provider._resolve_model_name("grok-4.1-fast-reasoning-latest") == "grok-4-1-fast-reasoning"
|
|
|
|
|
|
|
|
|
|
# Test full name passthrough
|
|
|
|
|
assert provider._resolve_model_name("grok-4") == "grok-4"
|
|
|
|
|
assert provider._resolve_model_name("grok-3") == "grok-3"
|
|
|
|
|
assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"
|
|
|
|
|
|
|
|
|
|
def test_get_capabilities_grok3(self):
|
|
|
|
|
"""Test getting model capabilities for GROK-3."""
|
|
|
|
|
provider = XAIModelProvider("test-key")
|
|
|
|
|
|
|
|
|
|
capabilities = provider.get_capabilities("grok-3")
|
|
|
|
|
assert capabilities.model_name == "grok-3"
|
|
|
|
|
assert capabilities.friendly_name == "X.AI (Grok 3)"
|
|
|
|
|
assert capabilities.context_window == 131_072
|
|
|
|
|
assert capabilities.provider == ProviderType.XAI
|
|
|
|
|
assert not capabilities.supports_extended_thinking
|
|
|
|
|
assert capabilities.supports_system_prompts is True
|
|
|
|
|
assert capabilities.supports_streaming is True
|
|
|
|
|
assert capabilities.supports_function_calling is True
|
|
|
|
|
|
|
|
|
|
# Test temperature range
|
|
|
|
|
assert capabilities.temperature_constraint.min_temp == 0.0
|
|
|
|
|
assert capabilities.temperature_constraint.max_temp == 2.0
|
|
|
|
|
assert capabilities.temperature_constraint.default_temp == 0.3
|
|
|
|
|
assert provider._resolve_model_name("grok-4.1-fast") == "grok-4-1-fast-reasoning"
|
|
|
|
|
|
|
|
|
|
def test_get_capabilities_grok4(self):
|
|
|
|
|
"""Test getting model capabilities for GROK-4."""
|
|
|
|
|
@@ -115,16 +98,19 @@ class TestXAIProvider:
|
|
|
|
|
assert capabilities.temperature_constraint.max_temp == 2.0
|
|
|
|
|
assert capabilities.temperature_constraint.default_temp == 0.3
|
|
|
|
|
|
|
|
|
|
def test_get_capabilities_grok3_fast(self):
|
|
|
|
|
"""Test getting model capabilities for GROK-3 Fast."""
|
|
|
|
|
def test_get_capabilities_grok4_1_fast(self):
|
|
|
|
|
"""Test getting model capabilities for GROK-4.1 Fast Reasoning."""
|
|
|
|
|
provider = XAIModelProvider("test-key")
|
|
|
|
|
|
|
|
|
|
capabilities = provider.get_capabilities("grok-3-fast")
|
|
|
|
|
assert capabilities.model_name == "grok-3-fast"
|
|
|
|
|
assert capabilities.friendly_name == "X.AI (Grok 3 Fast)"
|
|
|
|
|
assert capabilities.context_window == 131_072
|
|
|
|
|
capabilities = provider.get_capabilities("grok-4.1-fast")
|
|
|
|
|
assert capabilities.model_name == "grok-4-1-fast-reasoning"
|
|
|
|
|
assert capabilities.friendly_name == "X.AI (Grok 4.1 Fast Reasoning)"
|
|
|
|
|
assert capabilities.context_window == 2_000_000
|
|
|
|
|
assert capabilities.provider == ProviderType.XAI
|
|
|
|
|
assert not capabilities.supports_extended_thinking
|
|
|
|
|
assert capabilities.supports_extended_thinking is True
|
|
|
|
|
assert capabilities.supports_function_calling is True
|
|
|
|
|
assert capabilities.supports_json_mode is True
|
|
|
|
|
assert capabilities.supports_images is True
|
|
|
|
|
|
|
|
|
|
def test_get_capabilities_with_shorthand(self):
|
|
|
|
|
"""Test getting model capabilities with shorthand."""
|
|
|
|
|
@@ -134,8 +120,8 @@ class TestXAIProvider:
|
|
|
|
|
assert capabilities.model_name == "grok-4" # Should resolve to full name
|
|
|
|
|
assert capabilities.context_window == 256_000
|
|
|
|
|
|
|
|
|
|
capabilities_fast = provider.get_capabilities("grokfast")
|
|
|
|
|
assert capabilities_fast.model_name == "grok-3-fast" # Should resolve to full name
|
|
|
|
|
capabilities_fast = provider.get_capabilities("grok-4.1-fast-reasoning")
|
|
|
|
|
assert capabilities_fast.model_name == "grok-4-1-fast-reasoning" # Should resolve to full name
|
|
|
|
|
|
|
|
|
|
def test_unsupported_model_capabilities(self):
|
|
|
|
|
"""Test error handling for unsupported models."""
|
|
|
|
|
@@ -148,20 +134,23 @@ class TestXAIProvider:
|
|
|
|
|
"""X.AI capabilities should expose extended thinking support correctly."""
|
|
|
|
|
provider = XAIModelProvider("test-key")
|
|
|
|
|
|
|
|
|
|
thinking_aliases = ["grok-4", "grok", "grok4"]
|
|
|
|
|
thinking_aliases = [
|
|
|
|
|
"grok-4",
|
|
|
|
|
"grok",
|
|
|
|
|
"grok4",
|
|
|
|
|
"grok-4.1-fast",
|
|
|
|
|
"grok-4.1-fast-reasoning",
|
|
|
|
|
"grok-4.1-fast-reasoning-latest",
|
|
|
|
|
]
|
|
|
|
|
for alias in thinking_aliases:
|
|
|
|
|
assert provider.get_capabilities(alias).supports_extended_thinking is True
|
|
|
|
|
|
|
|
|
|
non_thinking_aliases = ["grok-3", "grok-3-fast", "grokfast"]
|
|
|
|
|
for alias in non_thinking_aliases:
|
|
|
|
|
assert provider.get_capabilities(alias).supports_extended_thinking is False
|
|
|
|
|
|
|
|
|
|
def test_provider_type(self):
|
|
|
|
|
"""Test provider type identification."""
|
|
|
|
|
provider = XAIModelProvider("test-key")
|
|
|
|
|
assert provider.get_provider_type() == ProviderType.XAI
|
|
|
|
|
|
|
|
|
|
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok-3"})
|
|
|
|
|
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok-4"})
|
|
|
|
|
def test_model_restrictions(self):
|
|
|
|
|
"""Test model restrictions functionality."""
|
|
|
|
|
# Clear cached restriction service
|
|
|
|
|
@@ -173,20 +162,17 @@ class TestXAIProvider:
|
|
|
|
|
|
|
|
|
|
provider = XAIModelProvider("test-key")
|
|
|
|
|
|
|
|
|
|
# grok-3 should be allowed
|
|
|
|
|
assert provider.validate_model_name("grok-3") is True
|
|
|
|
|
assert provider.validate_model_name("grok3") is True # Shorthand for grok-3
|
|
|
|
|
# grok-4 should be allowed (including alias)
|
|
|
|
|
assert provider.validate_model_name("grok-4") is True
|
|
|
|
|
assert provider.validate_model_name("grok") is True
|
|
|
|
|
|
|
|
|
|
# grok should be blocked (resolves to grok-4 which is not allowed)
|
|
|
|
|
assert provider.validate_model_name("grok") is False
|
|
|
|
|
# grok-4.1-fast should be blocked by restrictions
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast") is False
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast-reasoning") is False
|
|
|
|
|
|
|
|
|
|
# grok-3-fast should be blocked by restrictions
|
|
|
|
|
assert provider.validate_model_name("grok-3-fast") is False
|
|
|
|
|
assert provider.validate_model_name("grokfast") is False
|
|
|
|
|
|
|
|
|
|
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3-fast"})
|
|
|
|
|
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok-4.1-fast-reasoning"})
|
|
|
|
|
def test_multiple_model_restrictions(self):
|
|
|
|
|
"""Test multiple models in restrictions."""
|
|
|
|
|
"""Restrictions should allow aliases for Grok 4.1 Fast."""
|
|
|
|
|
# Clear cached restriction service
|
|
|
|
|
import utils.model_restrictions
|
|
|
|
|
from providers.registry import ModelProviderRegistry
|
|
|
|
|
@@ -196,24 +182,18 @@ class TestXAIProvider:
|
|
|
|
|
|
|
|
|
|
provider = XAIModelProvider("test-key")
|
|
|
|
|
|
|
|
|
|
# Shorthand "grok" should be allowed (resolves to grok-4)
|
|
|
|
|
assert provider.validate_model_name("grok") is True
|
|
|
|
|
# Alias should be allowed (resolves to grok-4.1-fast)
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast-reasoning") is True
|
|
|
|
|
|
|
|
|
|
# Full name "grok-4" should NOT be allowed (only shorthand "grok" is in restriction list)
|
|
|
|
|
# Canonical name is not allowed unless explicitly listed
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast") is False
|
|
|
|
|
|
|
|
|
|
# grok-4 should NOT be allowed
|
|
|
|
|
assert provider.validate_model_name("grok-4") is False
|
|
|
|
|
|
|
|
|
|
# "grok-3" should NOT be allowed (not in restriction list)
|
|
|
|
|
assert provider.validate_model_name("grok-3") is False
|
|
|
|
|
|
|
|
|
|
# "grok-3-fast" should be allowed (explicitly listed)
|
|
|
|
|
assert provider.validate_model_name("grok-3-fast") is True
|
|
|
|
|
|
|
|
|
|
# Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
|
|
|
|
|
assert provider.validate_model_name("grokfast") is True
|
|
|
|
|
|
|
|
|
|
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3,grok-4"})
|
|
|
|
|
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-4.1-fast"})
|
|
|
|
|
def test_both_shorthand_and_full_name_allowed(self):
|
|
|
|
|
"""Test that both shorthand and full name can be allowed."""
|
|
|
|
|
"""Test that aliases and canonical names can be allowed together."""
|
|
|
|
|
# Clear cached restriction service
|
|
|
|
|
import utils.model_restrictions
|
|
|
|
|
|
|
|
|
|
@@ -223,12 +203,8 @@ class TestXAIProvider:
|
|
|
|
|
|
|
|
|
|
# Both shorthand and full name should be allowed
|
|
|
|
|
assert provider.validate_model_name("grok") is True # Resolves to grok-4
|
|
|
|
|
assert provider.validate_model_name("grok-3") is True
|
|
|
|
|
assert provider.validate_model_name("grok-4") is True
|
|
|
|
|
|
|
|
|
|
# Other models should not be allowed
|
|
|
|
|
assert provider.validate_model_name("grok-3-fast") is False
|
|
|
|
|
assert provider.validate_model_name("grokfast") is False
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast") is True
|
|
|
|
|
|
|
|
|
|
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": ""})
|
|
|
|
|
def test_empty_restrictions_allows_all(self):
|
|
|
|
|
@@ -241,10 +217,9 @@ class TestXAIProvider:
|
|
|
|
|
provider = XAIModelProvider("test-key")
|
|
|
|
|
|
|
|
|
|
assert provider.validate_model_name("grok-4") is True
|
|
|
|
|
assert provider.validate_model_name("grok-3") is True
|
|
|
|
|
assert provider.validate_model_name("grok-3-fast") is True
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast") is True
|
|
|
|
|
assert provider.validate_model_name("grok-4.1-fast-reasoning") is True
|
|
|
|
|
assert provider.validate_model_name("grok") is True
|
|
|
|
|
assert provider.validate_model_name("grokfast") is True
|
|
|
|
|
assert provider.validate_model_name("grok4") is True
|
|
|
|
|
|
|
|
|
|
def test_friendly_name(self):
|
|
|
|
|
@@ -252,8 +227,8 @@ class TestXAIProvider:
|
|
|
|
|
provider = XAIModelProvider("test-key")
|
|
|
|
|
assert provider.FRIENDLY_NAME == "X.AI"
|
|
|
|
|
|
|
|
|
|
capabilities = provider.get_capabilities("grok-3")
|
|
|
|
|
assert capabilities.friendly_name == "X.AI (Grok 3)"
|
|
|
|
|
capabilities = provider.get_capabilities("grok-4")
|
|
|
|
|
assert capabilities.friendly_name == "X.AI (Grok 4)"
|
|
|
|
|
|
|
|
|
|
def test_supported_models_structure(self):
|
|
|
|
|
"""Test that MODEL_CAPABILITIES has the correct structure."""
|
|
|
|
|
@@ -261,8 +236,7 @@ class TestXAIProvider:
|
|
|
|
|
|
|
|
|
|
# Check that all expected base models are present
|
|
|
|
|
assert "grok-4" in provider.MODEL_CAPABILITIES
|
|
|
|
|
assert "grok-3" in provider.MODEL_CAPABILITIES
|
|
|
|
|
assert "grok-3-fast" in provider.MODEL_CAPABILITIES
|
|
|
|
|
assert "grok-4-1-fast-reasoning" in provider.MODEL_CAPABILITIES
|
|
|
|
|
|
|
|
|
|
# Check model configs have required fields
|
|
|
|
|
from providers.shared import ModelCapabilities
|
|
|
|
|
@@ -280,20 +254,11 @@ class TestXAIProvider:
|
|
|
|
|
assert "grok-4" in grok4_config.aliases
|
|
|
|
|
assert "grok4" in grok4_config.aliases
|
|
|
|
|
|
|
|
|
|
grok3_config = provider.MODEL_CAPABILITIES["grok-3"]
|
|
|
|
|
assert grok3_config.context_window == 131_072
|
|
|
|
|
assert grok3_config.supports_extended_thinking is False
|
|
|
|
|
# Check aliases are correctly structured
|
|
|
|
|
assert "grok3" in grok3_config.aliases # grok3 resolves to grok-3
|
|
|
|
|
|
|
|
|
|
# Check grok-4 aliases
|
|
|
|
|
grok4_config = provider.MODEL_CAPABILITIES["grok-4"]
|
|
|
|
|
assert "grok" in grok4_config.aliases # grok resolves to grok-4
|
|
|
|
|
assert "grok4" in grok4_config.aliases
|
|
|
|
|
|
|
|
|
|
grok3fast_config = provider.MODEL_CAPABILITIES["grok-3-fast"]
|
|
|
|
|
assert "grok3fast" in grok3fast_config.aliases
|
|
|
|
|
assert "grokfast" in grok3fast_config.aliases
|
|
|
|
|
grok41fast_config = provider.MODEL_CAPABILITIES["grok-4-1-fast-reasoning"]
|
|
|
|
|
assert grok41fast_config.context_window == 2_000_000
|
|
|
|
|
assert grok41fast_config.supports_extended_thinking is True
|
|
|
|
|
assert "grok-4.1-fast" in grok41fast_config.aliases
|
|
|
|
|
assert "grok-4.1-fast-reasoning" in grok41fast_config.aliases
|
|
|
|
|
|
|
|
|
|
@patch("providers.openai_compatible.OpenAI")
|
|
|
|
|
def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
|
|
|
|
|
@@ -376,19 +341,13 @@ class TestXAIProvider:
|
|
|
|
|
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
|
|
|
|
assert call_kwargs["model"] == "grok-4"
|
|
|
|
|
|
|
|
|
|
# Test grok3 -> grok-3
|
|
|
|
|
mock_response.model = "grok-3"
|
|
|
|
|
provider.generate_content(prompt="Test", model_name="grok3", temperature=0.7)
|
|
|
|
|
# Test grok-4.1-fast-reasoning -> grok-4-1-fast-reasoning
|
|
|
|
|
mock_response.model = "grok-4-1-fast-reasoning"
|
|
|
|
|
provider.generate_content(prompt="Test", model_name="grok-4.1-fast-reasoning", temperature=0.7)
|
|
|
|
|
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
|
|
|
|
assert call_kwargs["model"] == "grok-3"
|
|
|
|
|
assert call_kwargs["model"] == "grok-4-1-fast-reasoning"
|
|
|
|
|
|
|
|
|
|
# Test grokfast -> grok-3-fast
|
|
|
|
|
mock_response.model = "grok-3-fast"
|
|
|
|
|
provider.generate_content(prompt="Test", model_name="grokfast", temperature=0.7)
|
|
|
|
|
# Test grok-4.1-fast -> grok-4-1-fast-reasoning
|
|
|
|
|
provider.generate_content(prompt="Test", model_name="grok-4.1-fast", temperature=0.7)
|
|
|
|
|
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
|
|
|
|
assert call_kwargs["model"] == "grok-3-fast"
|
|
|
|
|
|
|
|
|
|
# Test grok3fast -> grok-3-fast
|
|
|
|
|
provider.generate_content(prompt="Test", model_name="grok3fast", temperature=0.7)
|
|
|
|
|
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
|
|
|
|
assert call_kwargs["model"] == "grok-3-fast"
|
|
|
|
|
assert call_kwargs["model"] == "grok-4-1-fast-reasoning"
|
|
|
|
|
|