refactor: new base class for model registry / loading

This commit is contained in:
Fahad
2025-10-07 12:31:34 +04:00
parent 4b988029fa
commit 02d13da897
5 changed files with 101 additions and 118 deletions

View File

@@ -7,13 +7,14 @@ if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .openai_compatible import OpenAICompatibleProvider
from .registry_provider_mixin import RegistryBackedProviderMixin
from .shared import ModelCapabilities, ProviderType
from .xai_registry import XAIModelRegistry
logger = logging.getLogger(__name__)
class XAIModelProvider(OpenAICompatibleProvider):
class XAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
"""Integration for X.AI's GROK models exposed over an OpenAI-style API.
Publishes capability metadata for the officially supported deployments and
@@ -22,8 +23,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
FRIENDLY_NAME = "X.AI"
REGISTRY_CLASS = XAIModelRegistry
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
_registry: Optional[XAIModelRegistry] = None
def __init__(self, api_key: str, **kwargs):
"""Initialize X.AI provider with API key."""
@@ -33,43 +34,6 @@ class XAIModelProvider(OpenAICompatibleProvider):
super().__init__(api_key, **kwargs)
self._invalidate_capability_cache()
# ------------------------------------------------------------------
# Registry access
# ------------------------------------------------------------------
@classmethod
def _ensure_registry(cls, *, force_reload: bool = False) -> None:
"""Load capability registry into MODEL_CAPABILITIES."""
if cls._registry is not None and not force_reload:
return
try:
registry = XAIModelRegistry()
except Exception as exc: # pragma: no cover - defensive logging
logger.warning("Unable to load X.AI model registry: %s", exc)
cls._registry = None
cls.MODEL_CAPABILITIES = {}
return
cls._registry = registry
cls.MODEL_CAPABILITIES = dict(registry.model_map)
@classmethod
def reload_registry(cls) -> None:
"""Force registry reload (primarily for tests)."""
cls._ensure_registry(force_reload=True)
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
self._ensure_registry()
return super().get_all_model_capabilities()
def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]:
if self._registry is None:
return None
return dict(self._registry.model_map)
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.XAI