Tests for Grok 4.
This commit is contained in:
@@ -320,7 +320,8 @@ class TestAutoModeProviderSelection:
|
|||||||
("pro", ProviderType.GOOGLE, "gemini-2.5-pro"),
|
("pro", ProviderType.GOOGLE, "gemini-2.5-pro"),
|
||||||
("mini", ProviderType.OPENAI, "o4-mini"),
|
("mini", ProviderType.OPENAI, "o4-mini"),
|
||||||
("o3mini", ProviderType.OPENAI, "o3-mini"),
|
("o3mini", ProviderType.OPENAI, "o3-mini"),
|
||||||
("grok", ProviderType.XAI, "grok-3"),
|
("grok", ProviderType.XAI, "grok-4-0709"),
|
||||||
|
("grok3", ProviderType.XAI, "grok-3"),
|
||||||
("grokfast", ProviderType.XAI, "grok-3-fast"),
|
("grokfast", ProviderType.XAI, "grok-3-fast"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -75,19 +75,25 @@ class TestSupportedModelsAliases:
|
|||||||
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
||||||
|
|
||||||
# Test specific aliases
|
# Test specific aliases
|
||||||
assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases
|
assert "grok" in provider.SUPPORTED_MODELS["grok-4-0709"].aliases
|
||||||
|
assert "grok-4" in provider.SUPPORTED_MODELS["grok-4-0709"].aliases
|
||||||
|
assert "grok-4-latest" in provider.SUPPORTED_MODELS["grok-4-0709"].aliases
|
||||||
|
assert "grok4" in provider.SUPPORTED_MODELS["grok-4-0709"].aliases
|
||||||
assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
|
assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
|
||||||
assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
|
assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
|
||||||
assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
|
assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
|
||||||
|
|
||||||
# Test alias resolution
|
# Test alias resolution
|
||||||
assert provider._resolve_model_name("grok") == "grok-3"
|
assert provider._resolve_model_name("grok") == "grok-4-0709"
|
||||||
|
assert provider._resolve_model_name("grok4") == "grok-4-0709"
|
||||||
|
assert provider._resolve_model_name("grok-4") == "grok-4-0709"
|
||||||
assert provider._resolve_model_name("grok3") == "grok-3"
|
assert provider._resolve_model_name("grok3") == "grok-3"
|
||||||
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
|
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
|
||||||
assert provider._resolve_model_name("grokfast") == "grok-3-fast"
|
assert provider._resolve_model_name("grokfast") == "grok-3-fast"
|
||||||
|
|
||||||
# Test case insensitive resolution
|
# Test case insensitive resolution
|
||||||
assert provider._resolve_model_name("Grok") == "grok-3"
|
assert provider._resolve_model_name("Grok") == "grok-4-0709"
|
||||||
|
assert provider._resolve_model_name("GROK4") == "grok-4-0709"
|
||||||
assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"
|
assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"
|
||||||
|
|
||||||
def test_dial_provider_aliases(self):
|
def test_dial_provider_aliases(self):
|
||||||
|
|||||||
@@ -45,6 +45,10 @@ class TestXAIProvider:
|
|||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Test valid models
|
# Test valid models
|
||||||
|
assert provider.validate_model_name("grok-4-0709") is True
|
||||||
|
assert provider.validate_model_name("grok-4") is True
|
||||||
|
assert provider.validate_model_name("grok-4-latest") is True
|
||||||
|
assert provider.validate_model_name("grok4") is True
|
||||||
assert provider.validate_model_name("grok-3") is True
|
assert provider.validate_model_name("grok-3") is True
|
||||||
assert provider.validate_model_name("grok-3-fast") is True
|
assert provider.validate_model_name("grok-3-fast") is True
|
||||||
assert provider.validate_model_name("grok") is True
|
assert provider.validate_model_name("grok") is True
|
||||||
@@ -62,12 +66,16 @@ class TestXAIProvider:
|
|||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Test shorthand resolution
|
# Test shorthand resolution
|
||||||
assert provider._resolve_model_name("grok") == "grok-3"
|
assert provider._resolve_model_name("grok") == "grok-4-0709"
|
||||||
|
assert provider._resolve_model_name("grok4") == "grok-4-0709"
|
||||||
|
assert provider._resolve_model_name("grok-4") == "grok-4-0709"
|
||||||
|
assert provider._resolve_model_name("grok-4-latest") == "grok-4-0709"
|
||||||
assert provider._resolve_model_name("grok3") == "grok-3"
|
assert provider._resolve_model_name("grok3") == "grok-3"
|
||||||
assert provider._resolve_model_name("grokfast") == "grok-3-fast"
|
assert provider._resolve_model_name("grokfast") == "grok-3-fast"
|
||||||
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
|
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
|
||||||
|
|
||||||
# Test full name passthrough
|
# Test full name passthrough
|
||||||
|
assert provider._resolve_model_name("grok-4-0709") == "grok-4-0709"
|
||||||
assert provider._resolve_model_name("grok-3") == "grok-3"
|
assert provider._resolve_model_name("grok-3") == "grok-3"
|
||||||
assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"
|
assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"
|
||||||
|
|
||||||
@@ -90,6 +98,27 @@ class TestXAIProvider:
|
|||||||
assert capabilities.temperature_constraint.max_temp == 2.0
|
assert capabilities.temperature_constraint.max_temp == 2.0
|
||||||
assert capabilities.temperature_constraint.default_temp == 0.7
|
assert capabilities.temperature_constraint.default_temp == 0.7
|
||||||
|
|
||||||
|
def test_get_capabilities_grok4(self):
|
||||||
|
"""Test getting model capabilities for GROK-4."""
|
||||||
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
|
capabilities = provider.get_capabilities("grok-4-0709")
|
||||||
|
assert capabilities.model_name == "grok-4-0709"
|
||||||
|
assert capabilities.friendly_name == "X.AI (Grok 4)"
|
||||||
|
assert capabilities.context_window == 256_000
|
||||||
|
assert capabilities.provider == ProviderType.XAI
|
||||||
|
assert capabilities.supports_extended_thinking is True
|
||||||
|
assert capabilities.supports_system_prompts is True
|
||||||
|
assert capabilities.supports_streaming is True
|
||||||
|
assert capabilities.supports_function_calling is True
|
||||||
|
assert capabilities.supports_json_mode is True
|
||||||
|
assert capabilities.supports_images is True
|
||||||
|
|
||||||
|
# Test temperature range
|
||||||
|
assert capabilities.temperature_constraint.min_temp == 0.0
|
||||||
|
assert capabilities.temperature_constraint.max_temp == 2.0
|
||||||
|
assert capabilities.temperature_constraint.default_temp == 0.7
|
||||||
|
|
||||||
def test_get_capabilities_grok3_fast(self):
|
def test_get_capabilities_grok3_fast(self):
|
||||||
"""Test getting model capabilities for GROK-3 Fast."""
|
"""Test getting model capabilities for GROK-3 Fast."""
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
@@ -106,8 +135,12 @@ class TestXAIProvider:
|
|||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
capabilities = provider.get_capabilities("grok")
|
capabilities = provider.get_capabilities("grok")
|
||||||
assert capabilities.model_name == "grok-3" # Should resolve to full name
|
assert capabilities.model_name == "grok-4-0709" # Should resolve to full name
|
||||||
assert capabilities.context_window == 131_072
|
assert capabilities.context_window == 256_000
|
||||||
|
|
||||||
|
capabilities_3 = provider.get_capabilities("grok3")
|
||||||
|
assert capabilities_3.model_name == "grok-3" # Should resolve to full name
|
||||||
|
assert capabilities_3.context_window == 131_072
|
||||||
|
|
||||||
capabilities_fast = provider.get_capabilities("grokfast")
|
capabilities_fast = provider.get_capabilities("grokfast")
|
||||||
assert capabilities_fast.model_name == "grok-3-fast" # Should resolve to full name
|
assert capabilities_fast.model_name == "grok-3-fast" # Should resolve to full name
|
||||||
@@ -119,13 +152,19 @@ class TestXAIProvider:
|
|||||||
with pytest.raises(ValueError, match="Unsupported X.AI model"):
|
with pytest.raises(ValueError, match="Unsupported X.AI model"):
|
||||||
provider.get_capabilities("invalid-model")
|
provider.get_capabilities("invalid-model")
|
||||||
|
|
||||||
def test_no_thinking_mode_support(self):
|
def test_thinking_mode_support(self):
|
||||||
"""Test that X.AI models don't support thinking mode."""
|
"""Test thinking mode support for X.AI models."""
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
|
# Grok-4 supports thinking mode
|
||||||
|
assert provider.supports_thinking_mode("grok-4-0709") is True
|
||||||
|
assert provider.supports_thinking_mode("grok-4") is True
|
||||||
|
assert provider.supports_thinking_mode("grok") is True # Resolves to grok-4
|
||||||
|
|
||||||
|
# Grok-3 models don't support thinking mode
|
||||||
assert not provider.supports_thinking_mode("grok-3")
|
assert not provider.supports_thinking_mode("grok-3")
|
||||||
assert not provider.supports_thinking_mode("grok-3-fast")
|
assert not provider.supports_thinking_mode("grok-3-fast")
|
||||||
assert not provider.supports_thinking_mode("grok")
|
assert not provider.supports_thinking_mode("grok3")
|
||||||
assert not provider.supports_thinking_mode("grokfast")
|
assert not provider.supports_thinking_mode("grokfast")
|
||||||
|
|
||||||
def test_provider_type(self):
|
def test_provider_type(self):
|
||||||
@@ -145,7 +184,11 @@ class TestXAIProvider:
|
|||||||
|
|
||||||
# grok-3 should be allowed
|
# grok-3 should be allowed
|
||||||
assert provider.validate_model_name("grok-3") is True
|
assert provider.validate_model_name("grok-3") is True
|
||||||
assert provider.validate_model_name("grok") is True # Shorthand for grok-3
|
assert provider.validate_model_name("grok3") is True # Shorthand for grok-3
|
||||||
|
|
||||||
|
# grok-4 and its aliases should be blocked
|
||||||
|
assert provider.validate_model_name("grok-4-0709") is False
|
||||||
|
assert provider.validate_model_name("grok") is False # Now resolves to grok-4
|
||||||
|
|
||||||
# grok-3-fast should be blocked by restrictions
|
# grok-3-fast should be blocked by restrictions
|
||||||
assert provider.validate_model_name("grok-3-fast") is False
|
assert provider.validate_model_name("grok-3-fast") is False
|
||||||
@@ -161,10 +204,13 @@ class TestXAIProvider:
|
|||||||
|
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Shorthand "grok" should be allowed (resolves to grok-3)
|
# Shorthand "grok" should be allowed (resolves to grok-4-0709)
|
||||||
assert provider.validate_model_name("grok") is True
|
assert provider.validate_model_name("grok") is True
|
||||||
|
|
||||||
# Full name "grok-3" should NOT be allowed (only shorthand "grok" is in restriction list)
|
# Full name "grok-4-0709" should NOT be allowed (only shorthand "grok" is in restriction list)
|
||||||
|
assert provider.validate_model_name("grok-4-0709") is False
|
||||||
|
|
||||||
|
# "grok-3" should NOT be allowed (not in restriction list)
|
||||||
assert provider.validate_model_name("grok-3") is False
|
assert provider.validate_model_name("grok-3") is False
|
||||||
|
|
||||||
# "grok-3-fast" should be allowed (explicitly listed)
|
# "grok-3-fast" should be allowed (explicitly listed)
|
||||||
@@ -173,7 +219,7 @@ class TestXAIProvider:
|
|||||||
# Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
|
# Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
|
||||||
assert provider.validate_model_name("grokfast") is True
|
assert provider.validate_model_name("grokfast") is True
|
||||||
|
|
||||||
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3"})
|
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3,grok-4-0709"})
|
||||||
def test_both_shorthand_and_full_name_allowed(self):
|
def test_both_shorthand_and_full_name_allowed(self):
|
||||||
"""Test that both shorthand and full name can be allowed."""
|
"""Test that both shorthand and full name can be allowed."""
|
||||||
# Clear cached restriction service
|
# Clear cached restriction service
|
||||||
@@ -184,8 +230,9 @@ class TestXAIProvider:
|
|||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Both shorthand and full name should be allowed
|
# Both shorthand and full name should be allowed
|
||||||
assert provider.validate_model_name("grok") is True
|
assert provider.validate_model_name("grok") is True # Resolves to grok-4-0709
|
||||||
assert provider.validate_model_name("grok-3") is True
|
assert provider.validate_model_name("grok-3") is True
|
||||||
|
assert provider.validate_model_name("grok-4-0709") is True
|
||||||
|
|
||||||
# Other models should not be allowed
|
# Other models should not be allowed
|
||||||
assert provider.validate_model_name("grok-3-fast") is False
|
assert provider.validate_model_name("grok-3-fast") is False
|
||||||
@@ -201,10 +248,12 @@ class TestXAIProvider:
|
|||||||
|
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
|
assert provider.validate_model_name("grok-4-0709") is True
|
||||||
assert provider.validate_model_name("grok-3") is True
|
assert provider.validate_model_name("grok-3") is True
|
||||||
assert provider.validate_model_name("grok-3-fast") is True
|
assert provider.validate_model_name("grok-3-fast") is True
|
||||||
assert provider.validate_model_name("grok") is True
|
assert provider.validate_model_name("grok") is True
|
||||||
assert provider.validate_model_name("grokfast") is True
|
assert provider.validate_model_name("grokfast") is True
|
||||||
|
assert provider.validate_model_name("grok4") is True
|
||||||
|
|
||||||
def test_friendly_name(self):
|
def test_friendly_name(self):
|
||||||
"""Test friendly name constant."""
|
"""Test friendly name constant."""
|
||||||
@@ -219,22 +268,30 @@ class TestXAIProvider:
|
|||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Check that all expected base models are present
|
# Check that all expected base models are present
|
||||||
|
assert "grok-4-0709" in provider.SUPPORTED_MODELS
|
||||||
assert "grok-3" in provider.SUPPORTED_MODELS
|
assert "grok-3" in provider.SUPPORTED_MODELS
|
||||||
assert "grok-3-fast" in provider.SUPPORTED_MODELS
|
assert "grok-3-fast" in provider.SUPPORTED_MODELS
|
||||||
|
|
||||||
# Check model configs have required fields
|
# Check model configs have required fields
|
||||||
from providers.base import ModelCapabilities
|
from providers.base import ModelCapabilities
|
||||||
|
|
||||||
grok3_config = provider.SUPPORTED_MODELS["grok-3"]
|
grok4_config = provider.SUPPORTED_MODELS["grok-4-0709"]
|
||||||
assert isinstance(grok3_config, ModelCapabilities)
|
assert isinstance(grok4_config, ModelCapabilities)
|
||||||
assert hasattr(grok3_config, "context_window")
|
assert hasattr(grok4_config, "context_window")
|
||||||
assert hasattr(grok3_config, "supports_extended_thinking")
|
assert hasattr(grok4_config, "supports_extended_thinking")
|
||||||
assert hasattr(grok3_config, "aliases")
|
assert hasattr(grok4_config, "aliases")
|
||||||
assert grok3_config.context_window == 131_072
|
assert grok4_config.context_window == 256_000
|
||||||
assert grok3_config.supports_extended_thinking is False
|
assert grok4_config.supports_extended_thinking is True
|
||||||
|
|
||||||
# Check aliases are correctly structured
|
# Check aliases are correctly structured
|
||||||
assert "grok" in grok3_config.aliases
|
assert "grok" in grok4_config.aliases
|
||||||
|
assert "grok-4" in grok4_config.aliases
|
||||||
|
assert "grok-4-latest" in grok4_config.aliases
|
||||||
|
assert "grok4" in grok4_config.aliases
|
||||||
|
|
||||||
|
grok3_config = provider.SUPPORTED_MODELS["grok-3"]
|
||||||
|
assert grok3_config.context_window == 131_072
|
||||||
|
assert grok3_config.supports_extended_thinking is False
|
||||||
assert "grok3" in grok3_config.aliases
|
assert "grok3" in grok3_config.aliases
|
||||||
|
|
||||||
grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
|
grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
|
||||||
@@ -257,7 +314,7 @@ class TestXAIProvider:
|
|||||||
mock_response.choices = [MagicMock()]
|
mock_response.choices = [MagicMock()]
|
||||||
mock_response.choices[0].message.content = "Test response"
|
mock_response.choices[0].message.content = "Test response"
|
||||||
mock_response.choices[0].finish_reason = "stop"
|
mock_response.choices[0].finish_reason = "stop"
|
||||||
mock_response.model = "grok-3" # API returns the resolved model name
|
mock_response.model = "grok-4-0709" # API returns the resolved model name
|
||||||
mock_response.id = "test-id"
|
mock_response.id = "test-id"
|
||||||
mock_response.created = 1234567890
|
mock_response.created = 1234567890
|
||||||
mock_response.usage = MagicMock()
|
mock_response.usage = MagicMock()
|
||||||
@@ -271,15 +328,17 @@ class TestXAIProvider:
|
|||||||
|
|
||||||
# Call generate_content with alias 'grok'
|
# Call generate_content with alias 'grok'
|
||||||
result = provider.generate_content(
|
result = provider.generate_content(
|
||||||
prompt="Test prompt", model_name="grok", temperature=0.7 # This should be resolved to "grok-3"
|
prompt="Test prompt", model_name="grok", temperature=0.7 # This should be resolved to "grok-4-0709"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the API was called with the RESOLVED model name
|
# Verify the API was called with the RESOLVED model name
|
||||||
mock_client.chat.completions.create.assert_called_once()
|
mock_client.chat.completions.create.assert_called_once()
|
||||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||||
|
|
||||||
# CRITICAL ASSERTION: The API should receive "grok-3", not "grok"
|
# CRITICAL ASSERTION: The API should receive "grok-4-0709", not "grok"
|
||||||
assert call_kwargs["model"] == "grok-3", f"Expected 'grok-3' but API received '{call_kwargs['model']}'"
|
assert (
|
||||||
|
call_kwargs["model"] == "grok-4-0709"
|
||||||
|
), f"Expected 'grok-4-0709' but API received '{call_kwargs['model']}'"
|
||||||
|
|
||||||
# Verify other parameters
|
# Verify other parameters
|
||||||
assert call_kwargs["temperature"] == 0.7
|
assert call_kwargs["temperature"] == 0.7
|
||||||
@@ -289,7 +348,7 @@ class TestXAIProvider:
|
|||||||
|
|
||||||
# Verify response
|
# Verify response
|
||||||
assert result.content == "Test response"
|
assert result.content == "Test response"
|
||||||
assert result.model_name == "grok-3" # Should be the resolved name
|
assert result.model_name == "grok-4-0709" # Should be the resolved name
|
||||||
|
|
||||||
@patch("providers.openai_compatible.OpenAI")
|
@patch("providers.openai_compatible.OpenAI")
|
||||||
def test_generate_content_other_aliases(self, mock_openai_class):
|
def test_generate_content_other_aliases(self, mock_openai_class):
|
||||||
@@ -311,6 +370,17 @@ class TestXAIProvider:
|
|||||||
|
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
|
# Test grok4 -> grok-4-0709
|
||||||
|
mock_response.model = "grok-4-0709"
|
||||||
|
provider.generate_content(prompt="Test", model_name="grok4", temperature=0.7)
|
||||||
|
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||||
|
assert call_kwargs["model"] == "grok-4-0709"
|
||||||
|
|
||||||
|
# Test grok-4 -> grok-4-0709
|
||||||
|
provider.generate_content(prompt="Test", model_name="grok-4", temperature=0.7)
|
||||||
|
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||||
|
assert call_kwargs["model"] == "grok-4-0709"
|
||||||
|
|
||||||
# Test grok3 -> grok-3
|
# Test grok3 -> grok-3
|
||||||
mock_response.model = "grok-3"
|
mock_response.model = "grok-3"
|
||||||
provider.generate_content(prompt="Test", model_name="grok3", temperature=0.7)
|
provider.generate_content(prompt="Test", model_name="grok3", temperature=0.7)
|
||||||
|
|||||||
Reference in New Issue
Block a user