Grok-4 support

This commit is contained in:
Fahad
2025-08-08 09:39:07 +05:00
parent 6e7f07c49d
commit 7f37efcbfe
13 changed files with 78 additions and 39 deletions

View File

@@ -202,7 +202,7 @@ class ModelProvider(ABC):
prompt: str, prompt: str,
model_name: str, model_name: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
temperature: float = 0.7, temperature: float = 0.3,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ModelResponse: ) -> ModelResponse:

View File

@@ -236,7 +236,7 @@ class CustomProvider(OpenAICompatibleProvider):
prompt: str, prompt: str,
model_name: str, model_name: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
temperature: float = 0.7, temperature: float = 0.3,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ModelResponse: ) -> ModelResponse:

View File

@@ -375,7 +375,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
prompt: str, prompt: str,
model_name: str, model_name: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
temperature: float = 0.7, temperature: float = 0.3,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
images: Optional[list[str]] = None, images: Optional[list[str]] = None,
**kwargs, **kwargs,

View File

@@ -155,7 +155,7 @@ class GeminiModelProvider(ModelProvider):
prompt: str, prompt: str,
model_name: str, model_name: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
temperature: float = 0.7, temperature: float = 0.3,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
thinking_mode: str = "medium", thinking_mode: str = "medium",
images: Optional[list[str]] = None, images: Optional[list[str]] = None,

View File

@@ -389,7 +389,7 @@ class OpenAICompatibleProvider(ModelProvider):
prompt: str, prompt: str,
model_name: str, model_name: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
temperature: float = 0.7, temperature: float = 0.3,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
images: Optional[list[str]] = None, images: Optional[list[str]] = None,
**kwargs, **kwargs,

View File

@@ -221,7 +221,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
prompt: str, prompt: str,
model_name: str, model_name: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
temperature: float = 0.7, temperature: float = 0.3,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ModelResponse: ) -> ModelResponse:

View File

@@ -158,7 +158,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
prompt: str, prompt: str,
model_name: str, model_name: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
temperature: float = 0.7, temperature: float = 0.3,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ModelResponse: ) -> ModelResponse:

View File

@@ -24,6 +24,24 @@ class XAIModelProvider(OpenAICompatibleProvider):
# Model configurations using ModelCapabilities objects # Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"grok-4": ModelCapabilities(
provider=ProviderType.XAI,
model_name="grok-4",
friendly_name="X.AI (Grok 4)",
context_window=256_000, # 256K tokens
max_output_tokens=256_000, # 256K tokens max output
supports_extended_thinking=True, # Grok-4 supports reasoning mode
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True, # Function calling supported
supports_json_mode=True, # Structured outputs supported
supports_images=True, # Multimodal capabilities
max_image_size_mb=20.0, # Standard image size limit
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="GROK-4 (256K context) - Frontier multimodal reasoning model with advanced capabilities",
aliases=["grok", "grok4", "grok-4"],
),
"grok-3": ModelCapabilities( "grok-3": ModelCapabilities(
provider=ProviderType.XAI, provider=ProviderType.XAI,
model_name="grok-3", model_name="grok-3",
@@ -40,7 +58,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
supports_temperature=True, supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"), temperature_constraint=create_temperature_constraint("range"),
description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis", description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
aliases=["grok", "grok3"], aliases=["grok3"],
), ),
"grok-3-fast": ModelCapabilities( "grok-3-fast": ModelCapabilities(
provider=ProviderType.XAI, provider=ProviderType.XAI,
@@ -113,7 +131,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
prompt: str, prompt: str,
model_name: str, model_name: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
temperature: float = 0.7, temperature: float = 0.3,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ModelResponse: ) -> ModelResponse:
@@ -133,8 +151,10 @@ class XAIModelProvider(OpenAICompatibleProvider):
def supports_thinking_mode(self, model_name: str) -> bool: def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode.""" """Check if the model supports extended thinking mode."""
# Currently GROK models do not support extended thinking resolved_name = self._resolve_model_name(model_name)
# This may change with future GROK model releases capabilities = self.SUPPORTED_MODELS.get(resolved_name)
if capabilities:
return capabilities.supports_extended_thinking
return False return False
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]: def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
@@ -153,22 +173,28 @@ class XAIModelProvider(OpenAICompatibleProvider):
return None return None
if category == ToolModelCategory.EXTENDED_REASONING: if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer GROK-3 for reasoning # Prefer GROK-4 for advanced reasoning with thinking mode
if "grok-3" in allowed_models: if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3" return "grok-3"
# Fall back to any available model # Fall back to any available model
return allowed_models[0] return allowed_models[0]
elif category == ToolModelCategory.FAST_RESPONSE: elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer GROK-3-Fast for speed # Prefer GROK-3-Fast for speed, then GROK-4
if "grok-3-fast" in allowed_models: if "grok-3-fast" in allowed_models:
return "grok-3-fast" return "grok-3-fast"
elif "grok-4" in allowed_models:
return "grok-4"
# Fall back to any available model # Fall back to any available model
return allowed_models[0] return allowed_models[0]
else: # BALANCED or default else: # BALANCED or default
# Prefer standard GROK-3 for balanced use # Prefer GROK-4 for balanced use (best overall capabilities)
if "grok-3" in allowed_models: if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3" return "grok-3"
elif "grok-3-fast" in allowed_models: elif "grok-3-fast" in allowed_models:
return "grok-3-fast" return "grok-3-fast"

View File

@@ -43,8 +43,8 @@ class XAIModelsTest(BaseSimulatorTest):
# Setup test files for later use # Setup test files for later use
self.setup_test_files() self.setup_test_files()
# Test 1: 'grok' alias (should map to grok-3) # Test 1: 'grok' alias (should map to grok-4)
self.logger.info(" 1: Testing 'grok' alias (should map to grok-3)") self.logger.info(" 1: Testing 'grok' alias (should map to grok-4)")
response1, continuation_id = self.call_mcp_tool( response1, continuation_id = self.call_mcp_tool(
"chat", "chat",

View File

@@ -108,9 +108,9 @@ class TestAutoModeComprehensive:
"OPENROUTER_API_KEY": None, "OPENROUTER_API_KEY": None,
}, },
{ {
"EXTENDED_REASONING": "grok-3", # GROK-3 for reasoning "EXTENDED_REASONING": "grok-4", # GROK-4 for reasoning (now preferred)
"FAST_RESPONSE": "grok-3-fast", # GROK-3-fast for speed "FAST_RESPONSE": "grok-3-fast", # GROK-3-fast for speed
"BALANCED": "grok-3", # GROK-3 as balanced "BALANCED": "grok-4", # GROK-4 as balanced (now preferred)
}, },
), ),
# Both Gemini and OpenAI available - Google comes first in priority # Both Gemini and OpenAI available - Google comes first in priority

View File

@@ -320,7 +320,7 @@ class TestAutoModeProviderSelection:
("pro", ProviderType.GOOGLE, "gemini-2.5-pro"), ("pro", ProviderType.GOOGLE, "gemini-2.5-pro"),
("mini", ProviderType.OPENAI, "gpt-5-mini"), # "mini" now resolves to gpt-5-mini ("mini", ProviderType.OPENAI, "gpt-5-mini"), # "mini" now resolves to gpt-5-mini
("o3mini", ProviderType.OPENAI, "o3-mini"), ("o3mini", ProviderType.OPENAI, "o3-mini"),
("grok", ProviderType.XAI, "grok-3"), ("grok", ProviderType.XAI, "grok-4"),
("grokfast", ProviderType.XAI, "grok-3-fast"), ("grokfast", ProviderType.XAI, "grok-3-fast"),
] ]

View File

@@ -76,19 +76,21 @@ class TestSupportedModelsAliases:
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list" assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
# Test specific aliases # Test specific aliases
assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases assert "grok" in provider.SUPPORTED_MODELS["grok-4"].aliases
assert "grok4" in provider.SUPPORTED_MODELS["grok-4"].aliases
assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
# Test alias resolution # Test alias resolution
assert provider._resolve_model_name("grok") == "grok-3" assert provider._resolve_model_name("grok") == "grok-4"
assert provider._resolve_model_name("grok4") == "grok-4"
assert provider._resolve_model_name("grok3") == "grok-3" assert provider._resolve_model_name("grok3") == "grok-3"
assert provider._resolve_model_name("grok3fast") == "grok-3-fast" assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
assert provider._resolve_model_name("grokfast") == "grok-3-fast" assert provider._resolve_model_name("grokfast") == "grok-3-fast"
# Test case insensitive resolution # Test case insensitive resolution
assert provider._resolve_model_name("Grok") == "grok-3" assert provider._resolve_model_name("Grok") == "grok-4"
assert provider._resolve_model_name("GROKFAST") == "grok-3-fast" assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"
def test_dial_provider_aliases(self): def test_dial_provider_aliases(self):

View File

@@ -62,7 +62,8 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key") provider = XAIModelProvider("test-key")
# Test shorthand resolution # Test shorthand resolution
assert provider._resolve_model_name("grok") == "grok-3" assert provider._resolve_model_name("grok") == "grok-4"
assert provider._resolve_model_name("grok4") == "grok-4"
assert provider._resolve_model_name("grok3") == "grok-3" assert provider._resolve_model_name("grok3") == "grok-3"
assert provider._resolve_model_name("grokfast") == "grok-3-fast" assert provider._resolve_model_name("grokfast") == "grok-3-fast"
assert provider._resolve_model_name("grok3fast") == "grok-3-fast" assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
@@ -106,8 +107,8 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key") provider = XAIModelProvider("test-key")
capabilities = provider.get_capabilities("grok") capabilities = provider.get_capabilities("grok")
assert capabilities.model_name == "grok-3" # Should resolve to full name assert capabilities.model_name == "grok-4" # Should resolve to full name
assert capabilities.context_window == 131_072 assert capabilities.context_window == 256_000
capabilities_fast = provider.get_capabilities("grokfast") capabilities_fast = provider.get_capabilities("grokfast")
assert capabilities_fast.model_name == "grok-3-fast" # Should resolve to full name assert capabilities_fast.model_name == "grok-3-fast" # Should resolve to full name
@@ -119,13 +120,15 @@ class TestXAIProvider:
with pytest.raises(ValueError, match="Unsupported X.AI model"): with pytest.raises(ValueError, match="Unsupported X.AI model"):
provider.get_capabilities("invalid-model") provider.get_capabilities("invalid-model")
def test_no_thinking_mode_support(self): def test_thinking_mode_support(self):
"""Test that X.AI models don't support thinking mode.""" """Test X.AI model thinking mode support - grok-4 supports it, earlier models don't."""
provider = XAIModelProvider("test-key") provider = XAIModelProvider("test-key")
assert not provider.supports_thinking_mode("grok-3") assert not provider.supports_thinking_mode("grok-3")
assert not provider.supports_thinking_mode("grok-3-fast") assert not provider.supports_thinking_mode("grok-3-fast")
assert not provider.supports_thinking_mode("grok") assert provider.supports_thinking_mode("grok-4") # grok-4 supports thinking mode
assert provider.supports_thinking_mode("grok") # resolves to grok-4
assert provider.supports_thinking_mode("grok4") # resolves to grok-4
assert not provider.supports_thinking_mode("grokfast") assert not provider.supports_thinking_mode("grokfast")
def test_provider_type(self): def test_provider_type(self):
@@ -145,7 +148,10 @@ class TestXAIProvider:
# grok-3 should be allowed # grok-3 should be allowed
assert provider.validate_model_name("grok-3") is True assert provider.validate_model_name("grok-3") is True
assert provider.validate_model_name("grok") is True # Shorthand for grok-3 assert provider.validate_model_name("grok3") is True # Shorthand for grok-3
# grok should be blocked (resolves to grok-4 which is not allowed)
assert provider.validate_model_name("grok") is False
# grok-3-fast should be blocked by restrictions # grok-3-fast should be blocked by restrictions
assert provider.validate_model_name("grok-3-fast") is False assert provider.validate_model_name("grok-3-fast") is False
@@ -161,7 +167,7 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key") provider = XAIModelProvider("test-key")
# Shorthand "grok" should be allowed (resolves to grok-3) # Shorthand "grok" should be allowed (resolves to grok-4)
assert provider.validate_model_name("grok") is True assert provider.validate_model_name("grok") is True
# Full name "grok-3" should NOT be allowed (only shorthand "grok" is in restriction list) # Full name "grok-3" should NOT be allowed (only shorthand "grok" is in restriction list)
@@ -219,6 +225,7 @@ class TestXAIProvider:
provider = XAIModelProvider("test-key") provider = XAIModelProvider("test-key")
# Check that all expected base models are present # Check that all expected base models are present
assert "grok-4" in provider.SUPPORTED_MODELS
assert "grok-3" in provider.SUPPORTED_MODELS assert "grok-3" in provider.SUPPORTED_MODELS
assert "grok-3-fast" in provider.SUPPORTED_MODELS assert "grok-3-fast" in provider.SUPPORTED_MODELS
@@ -234,8 +241,12 @@ class TestXAIProvider:
assert grok3_config.supports_extended_thinking is False assert grok3_config.supports_extended_thinking is False
# Check aliases are correctly structured # Check aliases are correctly structured
assert "grok" in grok3_config.aliases assert "grok3" in grok3_config.aliases # grok3 resolves to grok-3
assert "grok3" in grok3_config.aliases
# Check grok-4 aliases
grok4_config = provider.SUPPORTED_MODELS["grok-4"]
assert "grok" in grok4_config.aliases # grok resolves to grok-4
assert "grok4" in grok4_config.aliases
grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"] grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
assert "grok3fast" in grok3fast_config.aliases assert "grok3fast" in grok3fast_config.aliases
@@ -246,7 +257,7 @@ class TestXAIProvider:
"""Test that generate_content resolves aliases before making API calls. """Test that generate_content resolves aliases before making API calls.
This is the CRITICAL test that ensures aliases like 'grok' get resolved This is the CRITICAL test that ensures aliases like 'grok' get resolved
to 'grok-3' before being sent to X.AI API. to 'grok-4' before being sent to X.AI API.
""" """
# Set up mock OpenAI client # Set up mock OpenAI client
mock_client = MagicMock() mock_client = MagicMock()
@@ -271,15 +282,15 @@ class TestXAIProvider:
# Call generate_content with alias 'grok' # Call generate_content with alias 'grok'
result = provider.generate_content( result = provider.generate_content(
prompt="Test prompt", model_name="grok", temperature=0.7 # This should be resolved to "grok-3" prompt="Test prompt", model_name="grok", temperature=0.7 # This should be resolved to "grok-4"
) )
# Verify the API was called with the RESOLVED model name # Verify the API was called with the RESOLVED model name
mock_client.chat.completions.create.assert_called_once() mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args[1] call_kwargs = mock_client.chat.completions.create.call_args[1]
# CRITICAL ASSERTION: The API should receive "grok-3", not "grok" # CRITICAL ASSERTION: The API should receive "grok-4", not "grok"
assert call_kwargs["model"] == "grok-3", f"Expected 'grok-3' but API received '{call_kwargs['model']}'" assert call_kwargs["model"] == "grok-4", f"Expected 'grok-4' but API received '{call_kwargs['model']}'"
# Verify other parameters # Verify other parameters
assert call_kwargs["temperature"] == 0.7 assert call_kwargs["temperature"] == 0.7
@@ -289,7 +300,7 @@ class TestXAIProvider:
# Verify response # Verify response
assert result.content == "Test response" assert result.content == "Test response"
assert result.model_name == "grok-3" # Should be the resolved name assert result.model_name == "grok-4" # Should be the resolved name
@patch("providers.openai_compatible.OpenAI") @patch("providers.openai_compatible.OpenAI")
def test_generate_content_other_aliases(self, mock_openai_class): def test_generate_content_other_aliases(self, mock_openai_class):