refactor: new base class for model registry / loading
This commit is contained in:
@@ -15,12 +15,13 @@ from utils.image_utils import validate_image
|
|||||||
|
|
||||||
from .base import ModelProvider
|
from .base import ModelProvider
|
||||||
from .gemini_registry import GeminiModelRegistry
|
from .gemini_registry import GeminiModelRegistry
|
||||||
|
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||||
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GeminiModelProvider(ModelProvider):
|
class GeminiModelProvider(RegistryBackedProviderMixin, ModelProvider):
|
||||||
"""First-party Gemini integration built on the official Google SDK.
|
"""First-party Gemini integration built on the official Google SDK.
|
||||||
|
|
||||||
The provider advertises detailed thinking-mode budgets, handles optional
|
The provider advertises detailed thinking-mode budgets, handles optional
|
||||||
@@ -28,8 +29,8 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
request to the Gemini APIs.
|
request to the Gemini APIs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
REGISTRY_CLASS = GeminiModelRegistry
|
||||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
||||||
_registry: Optional[GeminiModelRegistry] = None
|
|
||||||
|
|
||||||
# Thinking mode configurations - percentages of model's max_thinking_tokens
|
# Thinking mode configurations - percentages of model's max_thinking_tokens
|
||||||
# These percentages work across all models that support thinking
|
# These percentages work across all models that support thinking
|
||||||
@@ -59,43 +60,6 @@ class GeminiModelProvider(ModelProvider):
|
|||||||
self._timeout_override = self._resolve_http_timeout()
|
self._timeout_override = self._resolve_http_timeout()
|
||||||
self._invalidate_capability_cache()
|
self._invalidate_capability_cache()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Registry access
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _ensure_registry(cls, *, force_reload: bool = False) -> None:
|
|
||||||
"""Load capability registry into MODEL_CAPABILITIES."""
|
|
||||||
|
|
||||||
if cls._registry is not None and not force_reload:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
registry = GeminiModelRegistry()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
logger.warning("Unable to load Gemini model registry: %s", exc)
|
|
||||||
cls._registry = None
|
|
||||||
cls.MODEL_CAPABILITIES = {}
|
|
||||||
return
|
|
||||||
|
|
||||||
cls._registry = registry
|
|
||||||
cls.MODEL_CAPABILITIES = dict(registry.model_map)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def reload_registry(cls) -> None:
|
|
||||||
"""Force registry reload (primarily for tests)."""
|
|
||||||
|
|
||||||
cls._ensure_registry(force_reload=True)
|
|
||||||
|
|
||||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
|
||||||
self._ensure_registry()
|
|
||||||
return super().get_all_model_capabilities()
|
|
||||||
|
|
||||||
def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]:
|
|
||||||
if self._registry is None:
|
|
||||||
return None
|
|
||||||
return dict(self._registry.model_map)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Capability surface
|
# Capability surface
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
@@ -8,12 +8,13 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
from .openai_registry import OpenAIModelRegistry
|
from .openai_registry import OpenAIModelRegistry
|
||||||
|
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||||
from .shared import ModelCapabilities, ProviderType
|
from .shared import ModelCapabilities, ProviderType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModelProvider(OpenAICompatibleProvider):
|
class OpenAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
|
||||||
"""Implementation that talks to api.openai.com using rich model metadata.
|
"""Implementation that talks to api.openai.com using rich model metadata.
|
||||||
|
|
||||||
In addition to the built-in catalogue, the provider can surface models
|
In addition to the built-in catalogue, the provider can surface models
|
||||||
@@ -21,8 +22,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
OpenAI-compatible gateways) while still respecting restriction policies.
|
OpenAI-compatible gateways) while still respecting restriction policies.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
REGISTRY_CLASS = OpenAIModelRegistry
|
||||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
||||||
_registry: Optional[OpenAIModelRegistry] = None
|
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize OpenAI provider with API key."""
|
"""Initialize OpenAI provider with API key."""
|
||||||
@@ -32,43 +33,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
|||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
self._invalidate_capability_cache()
|
self._invalidate_capability_cache()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Registry access
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _ensure_registry(cls, *, force_reload: bool = False) -> None:
|
|
||||||
"""Load capability registry into MODEL_CAPABILITIES."""
|
|
||||||
|
|
||||||
if cls._registry is not None and not force_reload:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
registry = OpenAIModelRegistry()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
logger.warning("Unable to load OpenAI model registry: %s", exc)
|
|
||||||
cls._registry = None
|
|
||||||
cls.MODEL_CAPABILITIES = {}
|
|
||||||
return
|
|
||||||
|
|
||||||
cls._registry = registry
|
|
||||||
cls.MODEL_CAPABILITIES = dict(registry.model_map)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def reload_registry(cls) -> None:
|
|
||||||
"""Force registry reload (primarily for tests)."""
|
|
||||||
|
|
||||||
cls._ensure_registry(force_reload=True)
|
|
||||||
|
|
||||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
|
||||||
self._ensure_registry()
|
|
||||||
return super().get_all_model_capabilities()
|
|
||||||
|
|
||||||
def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]:
|
|
||||||
if self._registry is None:
|
|
||||||
return None
|
|
||||||
return dict(self._registry.model_map)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Capability surface
|
# Capability surface
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
84
providers/registry_provider_mixin.py
Normal file
84
providers/registry_provider_mixin.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Mixin for providers backed by capability registries.
|
||||||
|
|
||||||
|
This mixin centralises the boilerplate for providers that expose their model
|
||||||
|
capabilities via JSON configuration files. Subclasses only need to set
|
||||||
|
``REGISTRY_CLASS`` to an appropriate :class:`CapabilityModelRegistry` and the
|
||||||
|
mix-in will take care of:
|
||||||
|
|
||||||
|
* Populating ``MODEL_CAPABILITIES`` exactly once per process (with optional
|
||||||
|
reload support for tests).
|
||||||
|
* Lazily exposing the registry contents through the standard provider hooks
|
||||||
|
(:meth:`get_all_model_capabilities` and :meth:`get_model_registry`).
|
||||||
|
* Providing defensive logging when a registry cannot be constructed so the
|
||||||
|
provider can degrade gracefully instead of raising during import.
|
||||||
|
|
||||||
|
Using this helper keeps individual provider implementations focused on their
|
||||||
|
SDK-specific behaviour while ensuring capability loading is consistent across
|
||||||
|
OpenAI, Gemini, X.AI, and other native backends.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
from .model_registry_base import CapabilityModelRegistry
|
||||||
|
from .shared import ModelCapabilities
|
||||||
|
|
||||||
|
|
||||||
|
class RegistryBackedProviderMixin:
|
||||||
|
"""Shared helper for providers that load capabilities from JSON registries."""
|
||||||
|
|
||||||
|
REGISTRY_CLASS: ClassVar[type[CapabilityModelRegistry] | None] = None
|
||||||
|
_registry: ClassVar[CapabilityModelRegistry | None] = None
|
||||||
|
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _registry_logger(cls) -> logging.Logger:
|
||||||
|
"""Return the logger used for registry lifecycle messages."""
|
||||||
|
return logging.getLogger(cls.__module__)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _ensure_registry(cls, *, force_reload: bool = False) -> None:
|
||||||
|
"""Populate ``MODEL_CAPABILITIES`` from the configured registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_reload: When ``True`` the registry is re-created even if it
|
||||||
|
was previously loaded. This is primarily used by tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if cls.REGISTRY_CLASS is None: # pragma: no cover - defensive programming
|
||||||
|
raise RuntimeError(f"{cls.__name__} must define REGISTRY_CLASS.")
|
||||||
|
|
||||||
|
if cls._registry is not None and not force_reload:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
registry = cls.REGISTRY_CLASS()
|
||||||
|
except Exception as exc: # pragma: no cover - registry failures shouldn't break the provider
|
||||||
|
cls._registry_logger().warning("Unable to load %s registry: %s", cls.__name__, exc)
|
||||||
|
cls._registry = None
|
||||||
|
cls.MODEL_CAPABILITIES = {}
|
||||||
|
return
|
||||||
|
|
||||||
|
cls._registry = registry
|
||||||
|
cls.MODEL_CAPABILITIES = dict(registry.model_map)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reload_registry(cls) -> None:
|
||||||
|
"""Force a registry reload (used in tests)."""
|
||||||
|
|
||||||
|
cls._ensure_registry(force_reload=True)
|
||||||
|
|
||||||
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
"""Return the registry-backed ``MODEL_CAPABILITIES`` map."""
|
||||||
|
|
||||||
|
self._ensure_registry()
|
||||||
|
return super().get_all_model_capabilities()
|
||||||
|
|
||||||
|
def get_model_registry(self) -> dict[str, ModelCapabilities] | None:
|
||||||
|
"""Return a copy of the underlying registry map when available."""
|
||||||
|
|
||||||
|
if self._registry is None:
|
||||||
|
return None
|
||||||
|
return dict(self._registry.model_map)
|
||||||
@@ -7,13 +7,14 @@ if TYPE_CHECKING:
|
|||||||
from tools.models import ToolModelCategory
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||||
from .shared import ModelCapabilities, ProviderType
|
from .shared import ModelCapabilities, ProviderType
|
||||||
from .xai_registry import XAIModelRegistry
|
from .xai_registry import XAIModelRegistry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class XAIModelProvider(OpenAICompatibleProvider):
|
class XAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
|
||||||
"""Integration for X.AI's GROK models exposed over an OpenAI-style API.
|
"""Integration for X.AI's GROK models exposed over an OpenAI-style API.
|
||||||
|
|
||||||
Publishes capability metadata for the officially supported deployments and
|
Publishes capability metadata for the officially supported deployments and
|
||||||
@@ -22,8 +23,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
|||||||
|
|
||||||
FRIENDLY_NAME = "X.AI"
|
FRIENDLY_NAME = "X.AI"
|
||||||
|
|
||||||
|
REGISTRY_CLASS = XAIModelRegistry
|
||||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
||||||
_registry: Optional[XAIModelRegistry] = None
|
|
||||||
|
|
||||||
def __init__(self, api_key: str, **kwargs):
|
def __init__(self, api_key: str, **kwargs):
|
||||||
"""Initialize X.AI provider with API key."""
|
"""Initialize X.AI provider with API key."""
|
||||||
@@ -33,43 +34,6 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
|||||||
super().__init__(api_key, **kwargs)
|
super().__init__(api_key, **kwargs)
|
||||||
self._invalidate_capability_cache()
|
self._invalidate_capability_cache()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Registry access
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _ensure_registry(cls, *, force_reload: bool = False) -> None:
|
|
||||||
"""Load capability registry into MODEL_CAPABILITIES."""
|
|
||||||
|
|
||||||
if cls._registry is not None and not force_reload:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
registry = XAIModelRegistry()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
logger.warning("Unable to load X.AI model registry: %s", exc)
|
|
||||||
cls._registry = None
|
|
||||||
cls.MODEL_CAPABILITIES = {}
|
|
||||||
return
|
|
||||||
|
|
||||||
cls._registry = registry
|
|
||||||
cls.MODEL_CAPABILITIES = dict(registry.model_map)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def reload_registry(cls) -> None:
|
|
||||||
"""Force registry reload (primarily for tests)."""
|
|
||||||
|
|
||||||
cls._ensure_registry(force_reload=True)
|
|
||||||
|
|
||||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
|
||||||
self._ensure_registry()
|
|
||||||
return super().get_all_model_capabilities()
|
|
||||||
|
|
||||||
def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]:
|
|
||||||
if self._registry is None:
|
|
||||||
return None
|
|
||||||
return dict(self._registry.model_map)
|
|
||||||
|
|
||||||
def get_provider_type(self) -> ProviderType:
|
def get_provider_type(self) -> ProviderType:
|
||||||
"""Get the provider type."""
|
"""Get the provider type."""
|
||||||
return ProviderType.XAI
|
return ProviderType.XAI
|
||||||
|
|||||||
@@ -21,7 +21,14 @@ py-modules = ["server", "config"]
|
|||||||
"*" = ["conf/*.json"]
|
"*" = ["conf/*.json"]
|
||||||
|
|
||||||
[tool.setuptools.data-files]
|
[tool.setuptools.data-files]
|
||||||
"conf" = ["conf/custom_models.json", "conf/openrouter_models.json", "conf/azure_models.json"]
|
"conf" = [
|
||||||
|
"conf/custom_models.json",
|
||||||
|
"conf/openrouter_models.json",
|
||||||
|
"conf/azure_models.json",
|
||||||
|
"conf/openai_models.json",
|
||||||
|
"conf/gemini_models.json",
|
||||||
|
"conf/xai_models.json",
|
||||||
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
zen-mcp-server = "server:run"
|
zen-mcp-server = "server:run"
|
||||||
|
|||||||
Reference in New Issue
Block a user