This commit is contained in:
Fahad
2025-12-11 20:08:17 +00:00
parent 8b16405f06
commit 514c9c58fc
12 changed files with 157 additions and 232 deletions

View File

@@ -26,6 +26,10 @@ class XAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
REGISTRY_CLASS = XAIModelRegistry
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
# Canonical model identifiers used for category routing.
PRIMARY_MODEL = "grok-4-1-fast-reasoning"
FALLBACK_MODEL = "grok-4"
def __init__(self, api_key: str, **kwargs):
"""Initialize X.AI provider with API key."""
# Set X.AI base URL
@@ -54,32 +58,27 @@ class XAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
return None
if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer GROK-4 for advanced reasoning with thinking mode
if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3"
# Fall back to any available model
# Prefer Grok 4.1 Fast Reasoning for advanced tasks
if self.PRIMARY_MODEL in allowed_models:
return self.PRIMARY_MODEL
if self.FALLBACK_MODEL in allowed_models:
return self.FALLBACK_MODEL
return allowed_models[0]
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer GROK-3-Fast for speed, then GROK-4
if "grok-3-fast" in allowed_models:
return "grok-3-fast"
elif "grok-4" in allowed_models:
return "grok-4"
# Fall back to any available model
# Prefer Grok 4.1 Fast Reasoning for speed as well (latest fast SKU).
if self.PRIMARY_MODEL in allowed_models:
return self.PRIMARY_MODEL
if self.FALLBACK_MODEL in allowed_models:
return self.FALLBACK_MODEL
return allowed_models[0]
else: # BALANCED or default
# Prefer GROK-4 for balanced use (best overall capabilities)
if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3"
elif "grok-3-fast" in allowed_models:
return "grok-3-fast"
# Fall back to any available model
# Prefer Grok 4.1 Fast Reasoning for balanced use.
if self.PRIMARY_MODEL in allowed_models:
return self.PRIMARY_MODEL
if self.FALLBACK_MODEL in allowed_models:
return self.FALLBACK_MODEL
return allowed_models[0]