Update test_xai_provider.py
This commit is contained in:
committed by
GitHub
parent
8a884c57d6
commit
912cde42d1
@@ -45,7 +45,6 @@ class TestXAIProvider:
|
|||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Test valid models
|
# Test valid models
|
||||||
assert provider.validate_model_name("grok-4-0709") is True
|
|
||||||
assert provider.validate_model_name("grok-4") is True
|
assert provider.validate_model_name("grok-4") is True
|
||||||
assert provider.validate_model_name("grok-4-latest") is True
|
assert provider.validate_model_name("grok-4-latest") is True
|
||||||
assert provider.validate_model_name("grok4") is True
|
assert provider.validate_model_name("grok4") is True
|
||||||
@@ -73,7 +72,7 @@ class TestXAIProvider:
|
|||||||
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
|
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
|
||||||
|
|
||||||
# Test full name passthrough
|
# Test full name passthrough
|
||||||
assert provider._resolve_model_name("grok-4-0709") == "grok-4-0709"
|
assert provider._resolve_model_name("grok-4") == "grok-4"
|
||||||
assert provider._resolve_model_name("grok-3") == "grok-3"
|
assert provider._resolve_model_name("grok-3") == "grok-3"
|
||||||
assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"
|
assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"
|
||||||
|
|
||||||
@@ -100,8 +99,8 @@ class TestXAIProvider:
|
|||||||
"""Test getting model capabilities for GROK-4."""
|
"""Test getting model capabilities for GROK-4."""
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
capabilities = provider.get_capabilities("grok-4-0709")
|
capabilities = provider.get_capabilities("grok-4")
|
||||||
assert capabilities.model_name == "grok-4-0709"
|
assert capabilities.model_name == "grok-4"
|
||||||
assert capabilities.friendly_name == "X.AI (Grok 4)"
|
assert capabilities.friendly_name == "X.AI (Grok 4)"
|
||||||
assert capabilities.context_window == 256_000
|
assert capabilities.context_window == 256_000
|
||||||
assert capabilities.provider == ProviderType.XAI
|
assert capabilities.provider == ProviderType.XAI
|
||||||
@@ -151,7 +150,6 @@ class TestXAIProvider:
|
|||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Grok-4 supports thinking mode
|
# Grok-4 supports thinking mode
|
||||||
assert provider.supports_thinking_mode("grok-4-0709") is True
|
|
||||||
assert provider.supports_thinking_mode("grok-4") is True
|
assert provider.supports_thinking_mode("grok-4") is True
|
||||||
assert provider.supports_thinking_mode("grok") is True # Resolves to grok-4
|
assert provider.supports_thinking_mode("grok") is True # Resolves to grok-4
|
||||||
|
|
||||||
@@ -202,8 +200,8 @@ class TestXAIProvider:
|
|||||||
# Shorthand "grok" should be allowed (resolves to grok-4)
|
# Shorthand "grok" should be allowed (resolves to grok-4)
|
||||||
assert provider.validate_model_name("grok") is True
|
assert provider.validate_model_name("grok") is True
|
||||||
|
|
||||||
# Full name "grok-4-0709" should NOT be allowed (only shorthand "grok" is in restriction list)
|
# Full name "grok-4" should NOT be allowed (only shorthand "grok" is in restriction list)
|
||||||
assert provider.validate_model_name("grok-4-0709") is False
|
assert provider.validate_model_name("grok-4") is False
|
||||||
|
|
||||||
# "grok-3" should NOT be allowed (not in restriction list)
|
# "grok-3" should NOT be allowed (not in restriction list)
|
||||||
assert provider.validate_model_name("grok-3") is False
|
assert provider.validate_model_name("grok-3") is False
|
||||||
@@ -214,7 +212,7 @@ class TestXAIProvider:
|
|||||||
# Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
|
# Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
|
||||||
assert provider.validate_model_name("grokfast") is True
|
assert provider.validate_model_name("grokfast") is True
|
||||||
|
|
||||||
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3,grok-4-0709"})
|
@patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3,grok-4"})
|
||||||
def test_both_shorthand_and_full_name_allowed(self):
|
def test_both_shorthand_and_full_name_allowed(self):
|
||||||
"""Test that both shorthand and full name can be allowed."""
|
"""Test that both shorthand and full name can be allowed."""
|
||||||
# Clear cached restriction service
|
# Clear cached restriction service
|
||||||
@@ -225,9 +223,9 @@ class TestXAIProvider:
|
|||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Both shorthand and full name should be allowed
|
# Both shorthand and full name should be allowed
|
||||||
assert provider.validate_model_name("grok") is True # Resolves to grok-4-0709
|
assert provider.validate_model_name("grok") is True # Resolves to grok-4
|
||||||
assert provider.validate_model_name("grok-3") is True
|
assert provider.validate_model_name("grok-3") is True
|
||||||
assert provider.validate_model_name("grok-4-0709") is True
|
assert provider.validate_model_name("grok-4") is True
|
||||||
|
|
||||||
# Other models should not be allowed
|
# Other models should not be allowed
|
||||||
assert provider.validate_model_name("grok-3-fast") is False
|
assert provider.validate_model_name("grok-3-fast") is False
|
||||||
@@ -243,7 +241,7 @@ class TestXAIProvider:
|
|||||||
|
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
assert provider.validate_model_name("grok-4-0709") is True
|
assert provider.validate_model_name("grok-4") is True
|
||||||
assert provider.validate_model_name("grok-3") is True
|
assert provider.validate_model_name("grok-3") is True
|
||||||
assert provider.validate_model_name("grok-3-fast") is True
|
assert provider.validate_model_name("grok-3-fast") is True
|
||||||
assert provider.validate_model_name("grok") is True
|
assert provider.validate_model_name("grok") is True
|
||||||
@@ -270,7 +268,7 @@ class TestXAIProvider:
|
|||||||
# Check model configs have required fields
|
# Check model configs have required fields
|
||||||
from providers.base import ModelCapabilities
|
from providers.base import ModelCapabilities
|
||||||
|
|
||||||
grok4_config = provider.SUPPORTED_MODELS["grok-4-0709"]
|
grok4_config = provider.SUPPORTED_MODELS["grok-4"]
|
||||||
assert isinstance(grok4_config, ModelCapabilities)
|
assert isinstance(grok4_config, ModelCapabilities)
|
||||||
assert hasattr(grok4_config, "context_window")
|
assert hasattr(grok4_config, "context_window")
|
||||||
assert hasattr(grok4_config, "supports_extended_thinking")
|
assert hasattr(grok4_config, "supports_extended_thinking")
|
||||||
@@ -315,7 +313,7 @@ class TestXAIProvider:
|
|||||||
mock_response.choices = [MagicMock()]
|
mock_response.choices = [MagicMock()]
|
||||||
mock_response.choices[0].message.content = "Test response"
|
mock_response.choices[0].message.content = "Test response"
|
||||||
mock_response.choices[0].finish_reason = "stop"
|
mock_response.choices[0].finish_reason = "stop"
|
||||||
mock_response.model = "grok-4-0709" # API returns the resolved model name
|
mock_response.model = "grok-4" # API returns the resolved model name
|
||||||
mock_response.id = "test-id"
|
mock_response.id = "test-id"
|
||||||
mock_response.created = 1234567890
|
mock_response.created = 1234567890
|
||||||
mock_response.usage = MagicMock()
|
mock_response.usage = MagicMock()
|
||||||
@@ -369,16 +367,16 @@ class TestXAIProvider:
|
|||||||
|
|
||||||
provider = XAIModelProvider("test-key")
|
provider = XAIModelProvider("test-key")
|
||||||
|
|
||||||
# Test grok4 -> grok-4-0709
|
# Test grok4 -> grok-4
|
||||||
mock_response.model = "grok-4-0709"
|
mock_response.model = "grok-4"
|
||||||
provider.generate_content(prompt="Test", model_name="grok4", temperature=0.7)
|
provider.generate_content(prompt="Test", model_name="grok4", temperature=0.7)
|
||||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||||
assert call_kwargs["model"] == "grok-4-0709"
|
assert call_kwargs["model"] == "grok-4"
|
||||||
|
|
||||||
# Test grok-4 -> grok-4-0709
|
# Test grok-4 -> grok-4
|
||||||
provider.generate_content(prompt="Test", model_name="grok-4", temperature=0.7)
|
provider.generate_content(prompt="Test", model_name="grok-4", temperature=0.7)
|
||||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||||
assert call_kwargs["model"] == "grok-4-0709"
|
assert call_kwargs["model"] == "grok-4"
|
||||||
|
|
||||||
# Test grok3 -> grok-3
|
# Test grok3 -> grok-3
|
||||||
mock_response.model = "grok-3"
|
mock_response.model = "grok-3"
|
||||||
|
|||||||
Reference in New Issue
Block a user