feat: all native providers now read from catalog files, like the OpenRouter / Custom configs. This allows for greater control over the capabilities.
This commit is contained in:
108
providers/xai.py
108
providers/xai.py
@@ -7,7 +7,8 @@ if TYPE_CHECKING:
|
||||
from tools.models import ToolModelCategory
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
from .xai_registry import XAIModelRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,72 +22,53 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
||||
|
||||
FRIENDLY_NAME = "X.AI"
|
||||
|
||||
# Model configurations using ModelCapabilities objects
|
||||
MODEL_CAPABILITIES = {
|
||||
"grok-4": ModelCapabilities(
|
||||
provider=ProviderType.XAI,
|
||||
model_name="grok-4",
|
||||
friendly_name="X.AI (Grok 4)",
|
||||
intelligence_score=16,
|
||||
context_window=256_000, # 256K tokens
|
||||
max_output_tokens=256_000, # 256K tokens max output
|
||||
supports_extended_thinking=True, # Grok-4 supports reasoning mode
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True, # Function calling supported
|
||||
supports_json_mode=True, # Structured outputs supported
|
||||
supports_images=True, # Multimodal capabilities
|
||||
max_image_size_mb=20.0, # Standard image size limit
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="GROK-4 (256K context) - Frontier multimodal reasoning model with advanced capabilities",
|
||||
aliases=["grok", "grok4", "grok-4"],
|
||||
),
|
||||
"grok-3": ModelCapabilities(
|
||||
provider=ProviderType.XAI,
|
||||
model_name="grok-3",
|
||||
friendly_name="X.AI (Grok 3)",
|
||||
intelligence_score=13,
|
||||
context_window=131_072, # 131K tokens
|
||||
max_output_tokens=131072,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet
|
||||
supports_images=False, # Assuming GROK is text-only for now
|
||||
max_image_size_mb=0.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
|
||||
aliases=["grok3"],
|
||||
),
|
||||
"grok-3-fast": ModelCapabilities(
|
||||
provider=ProviderType.XAI,
|
||||
model_name="grok-3-fast",
|
||||
friendly_name="X.AI (Grok 3 Fast)",
|
||||
intelligence_score=12,
|
||||
context_window=131_072, # 131K tokens
|
||||
max_output_tokens=131072,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet
|
||||
supports_images=False, # Assuming GROK is text-only for now
|
||||
max_image_size_mb=0.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
|
||||
aliases=["grok3fast", "grokfast", "grok3-fast"],
|
||||
),
|
||||
}
|
||||
# Capability table keyed by model name. Starts empty and is replaced by
# _ensure_registry() with a copy of the registry's model_map.
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
# Registry instance cached on the class; None until a load succeeds
# (and reset to None when a load attempt fails).
_registry: Optional[XAIModelRegistry] = None
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
    """Create an X.AI provider bound to the given API key.

    Args:
        api_key: Credential passed through unchanged to the parent
            constructor.
        **kwargs: Extra options forwarded to the parent constructor. When
            no ``base_url`` is supplied, the X.AI endpoint is used.
    """
    # Make sure the class-level capability registry has been loaded.
    self._ensure_registry()

    # Fall back to the X.AI endpoint only if the caller gave no base URL.
    if "base_url" not in kwargs:
        kwargs["base_url"] = "https://api.x.ai/v1"

    super().__init__(api_key, **kwargs)

    # NOTE(review): helper is defined outside this view; presumably it drops
    # capability data cached before the registry was loaded - confirm.
    self._invalidate_capability_cache()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registry access
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
def _ensure_registry(cls, *, force_reload: bool = False) -> None:
    """Populate ``MODEL_CAPABILITIES`` from the X.AI model registry.

    Does nothing when a registry is already cached on the class, unless
    ``force_reload`` is true. If constructing the registry raises, a warning
    is logged and both ``_registry`` and ``MODEL_CAPABILITIES`` are reset to
    their empty state, so a later call will attempt the load again.

    Args:
        force_reload: Rebuild the registry even if one is already cached.
    """
    already_loaded = cls._registry is not None
    if already_loaded and not force_reload:
        return

    try:
        loaded = XAIModelRegistry()
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.warning("Unable to load X.AI model registry: %s", exc)
        cls._registry = None
        cls.MODEL_CAPABILITIES = {}
    else:
        cls._registry = loaded
        # Store a shallow copy so the class table is a separate dict from
        # the registry's own mapping.
        cls.MODEL_CAPABILITIES = dict(loaded.model_map)
|
||||
|
||||
@classmethod
def reload_registry(cls) -> None:
    """Rebuild the capability registry even if one is already cached.

    Thin wrapper around ``_ensure_registry(force_reload=True)``; intended
    mainly for tests.
    """
    cls._ensure_registry(force_reload=True)
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
    """Return the parent's capability listing once the registry is ensured.

    Returns:
        Whatever the parent implementation of this method returns, after
        ``_ensure_registry()`` has had a chance to load the registry.
    """
    self._ensure_registry()
    capabilities = super().get_all_model_capabilities()
    return capabilities
|
||||
|
||||
def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]:
    """Return a copy of the registry's model map, or ``None`` if unloaded.

    Returns:
        A new dict built from the cached registry's ``model_map``, or
        ``None`` when no registry is currently cached.
    """
    registry = self._registry
    return None if registry is None else dict(registry.model_map)
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
@@ -135,3 +117,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
||||
return "grok-3-fast"
|
||||
# Fall back to any available model
|
||||
return allowed_models[0]
|
||||
|
||||
|
||||
# Load registry data at import time, presumably so MODEL_CAPABILITIES is
# already populated before any provider instance is constructed.
XAIModelProvider._ensure_registry()
|
||||
|
||||
Reference in New Issue
Block a user