diff --git a/providers/gemini.py b/providers/gemini.py index c51f96b..162fee7 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -15,12 +15,13 @@ from utils.image_utils import validate_image from .base import ModelProvider from .gemini_registry import GeminiModelRegistry +from .registry_provider_mixin import RegistryBackedProviderMixin from .shared import ModelCapabilities, ModelResponse, ProviderType logger = logging.getLogger(__name__) -class GeminiModelProvider(ModelProvider): +class GeminiModelProvider(RegistryBackedProviderMixin, ModelProvider): """First-party Gemini integration built on the official Google SDK. The provider advertises detailed thinking-mode budgets, handles optional @@ -28,8 +29,8 @@ class GeminiModelProvider(ModelProvider): request to the Gemini APIs. """ + REGISTRY_CLASS = GeminiModelRegistry MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {} - _registry: Optional[GeminiModelRegistry] = None # Thinking mode configurations - percentages of model's max_thinking_tokens # These percentages work across all models that support thinking @@ -59,43 +60,6 @@ class GeminiModelProvider(ModelProvider): self._timeout_override = self._resolve_http_timeout() self._invalidate_capability_cache() - # ------------------------------------------------------------------ - # Registry access - # ------------------------------------------------------------------ - - @classmethod - def _ensure_registry(cls, *, force_reload: bool = False) -> None: - """Load capability registry into MODEL_CAPABILITIES.""" - - if cls._registry is not None and not force_reload: - return - - try: - registry = GeminiModelRegistry() - except Exception as exc: # pragma: no cover - defensive logging - logger.warning("Unable to load Gemini model registry: %s", exc) - cls._registry = None - cls.MODEL_CAPABILITIES = {} - return - - cls._registry = registry - cls.MODEL_CAPABILITIES = dict(registry.model_map) - - @classmethod - def reload_registry(cls) -> None: - """Force registry 
reload (primarily for tests).""" - - cls._ensure_registry(force_reload=True) - - def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: - self._ensure_registry() - return super().get_all_model_capabilities() - - def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]: - if self._registry is None: - return None - return dict(self._registry.model_map) - # ------------------------------------------------------------------ # Capability surface # ------------------------------------------------------------------ diff --git a/providers/openai_provider.py b/providers/openai_provider.py index 9d72ec9..d40263a 100644 --- a/providers/openai_provider.py +++ b/providers/openai_provider.py @@ -8,12 +8,13 @@ if TYPE_CHECKING: from .openai_compatible import OpenAICompatibleProvider from .openai_registry import OpenAIModelRegistry +from .registry_provider_mixin import RegistryBackedProviderMixin from .shared import ModelCapabilities, ProviderType logger = logging.getLogger(__name__) -class OpenAIModelProvider(OpenAICompatibleProvider): +class OpenAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider): """Implementation that talks to api.openai.com using rich model metadata. In addition to the built-in catalogue, the provider can surface models @@ -21,8 +22,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider): OpenAI-compatible gateways) while still respecting restriction policies. 
""" + REGISTRY_CLASS = OpenAIModelRegistry MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {} - _registry: Optional[OpenAIModelRegistry] = None def __init__(self, api_key: str, **kwargs): """Initialize OpenAI provider with API key.""" @@ -32,43 +33,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider): super().__init__(api_key, **kwargs) self._invalidate_capability_cache() - # ------------------------------------------------------------------ - # Registry access - # ------------------------------------------------------------------ - - @classmethod - def _ensure_registry(cls, *, force_reload: bool = False) -> None: - """Load capability registry into MODEL_CAPABILITIES.""" - - if cls._registry is not None and not force_reload: - return - - try: - registry = OpenAIModelRegistry() - except Exception as exc: # pragma: no cover - defensive logging - logger.warning("Unable to load OpenAI model registry: %s", exc) - cls._registry = None - cls.MODEL_CAPABILITIES = {} - return - - cls._registry = registry - cls.MODEL_CAPABILITIES = dict(registry.model_map) - - @classmethod - def reload_registry(cls) -> None: - """Force registry reload (primarily for tests).""" - - cls._ensure_registry(force_reload=True) - - def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: - self._ensure_registry() - return super().get_all_model_capabilities() - - def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]: - if self._registry is None: - return None - return dict(self._registry.model_map) - # ------------------------------------------------------------------ # Capability surface # ------------------------------------------------------------------ diff --git a/providers/registry_provider_mixin.py b/providers/registry_provider_mixin.py new file mode 100644 index 0000000..afc85a9 --- /dev/null +++ b/providers/registry_provider_mixin.py @@ -0,0 +1,84 @@ +"""Mixin for providers backed by capability registries. 
+ +This mixin centralises the boilerplate for providers that expose their model +capabilities via JSON configuration files. Subclasses only need to set +``REGISTRY_CLASS`` to an appropriate :class:`CapabilityModelRegistry` and the +mixin will take care of: + +* Populating ``MODEL_CAPABILITIES`` once per provider class (with optional + reload support for tests; a failed load is retried on the next call). +* Lazily exposing the registry contents through the standard provider hooks + (:meth:`get_all_model_capabilities` and :meth:`get_model_registry`). +* Providing defensive logging when a registry cannot be constructed so the + provider can degrade gracefully instead of raising while loading capabilities. + +Using this helper keeps individual provider implementations focused on their +SDK-specific behaviour while ensuring capability loading is consistent across +OpenAI, Gemini, X.AI, and other native backends. +""" + +from __future__ import annotations + +import logging +from typing import ClassVar + +from .model_registry_base import CapabilityModelRegistry +from .shared import ModelCapabilities + + +class RegistryBackedProviderMixin: + """Shared helper for providers that load capabilities from JSON registries.""" + + REGISTRY_CLASS: ClassVar[type[CapabilityModelRegistry] | None] = None + _registry: ClassVar[CapabilityModelRegistry | None] = None + MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {} + + @classmethod + def _registry_logger(cls) -> logging.Logger: + """Return the logger used for registry lifecycle messages.""" + return logging.getLogger(cls.__module__) + + @classmethod + def _ensure_registry(cls, *, force_reload: bool = False) -> None: + """Populate ``MODEL_CAPABILITIES`` from the configured registry. + + Args: + force_reload: When ``True`` the registry is re-created even if it + was previously loaded. This is primarily used by tests. 
+ """ + + if cls.REGISTRY_CLASS is None: # pragma: no cover - defensive programming + raise RuntimeError(f"{cls.__name__} must define REGISTRY_CLASS.") + + if cls._registry is not None and not force_reload: + return + + try: + registry = cls.REGISTRY_CLASS() + except Exception as exc: # pragma: no cover - registry failures shouldn't break the provider + cls._registry_logger().warning("Unable to load %s registry: %s", cls.__name__, exc) + cls._registry = None + cls.MODEL_CAPABILITIES = {} + return + + cls._registry = registry + cls.MODEL_CAPABILITIES = dict(registry.model_map) + + @classmethod + def reload_registry(cls) -> None: + """Force a registry reload (used in tests).""" + + cls._ensure_registry(force_reload=True) + + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + """Return the registry-backed ``MODEL_CAPABILITIES`` map.""" + + self._ensure_registry() + return super().get_all_model_capabilities() + + def get_model_registry(self) -> dict[str, ModelCapabilities] | None: + """Return a copy of the underlying registry map when available.""" + + if self._registry is None: + return None + return dict(self._registry.model_map) diff --git a/providers/xai.py b/providers/xai.py index c51e2bb..0842f8e 100644 --- a/providers/xai.py +++ b/providers/xai.py @@ -7,13 +7,14 @@ if TYPE_CHECKING: from tools.models import ToolModelCategory from .openai_compatible import OpenAICompatibleProvider +from .registry_provider_mixin import RegistryBackedProviderMixin from .shared import ModelCapabilities, ProviderType from .xai_registry import XAIModelRegistry logger = logging.getLogger(__name__) -class XAIModelProvider(OpenAICompatibleProvider): +class XAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider): """Integration for X.AI's GROK models exposed over an OpenAI-style API. 
Publishes capability metadata for the officially supported deployments and @@ -22,8 +23,8 @@ class XAIModelProvider(OpenAICompatibleProvider): FRIENDLY_NAME = "X.AI" + REGISTRY_CLASS = XAIModelRegistry MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {} - _registry: Optional[XAIModelRegistry] = None def __init__(self, api_key: str, **kwargs): """Initialize X.AI provider with API key.""" @@ -33,43 +34,6 @@ class XAIModelProvider(OpenAICompatibleProvider): super().__init__(api_key, **kwargs) self._invalidate_capability_cache() - # ------------------------------------------------------------------ - # Registry access - # ------------------------------------------------------------------ - - @classmethod - def _ensure_registry(cls, *, force_reload: bool = False) -> None: - """Load capability registry into MODEL_CAPABILITIES.""" - - if cls._registry is not None and not force_reload: - return - - try: - registry = XAIModelRegistry() - except Exception as exc: # pragma: no cover - defensive logging - logger.warning("Unable to load X.AI model registry: %s", exc) - cls._registry = None - cls.MODEL_CAPABILITIES = {} - return - - cls._registry = registry - cls.MODEL_CAPABILITIES = dict(registry.model_map) - - @classmethod - def reload_registry(cls) -> None: - """Force registry reload (primarily for tests).""" - - cls._ensure_registry(force_reload=True) - - def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: - self._ensure_registry() - return super().get_all_model_capabilities() - - def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]: - if self._registry is None: - return None - return dict(self._registry.model_map) - def get_provider_type(self) -> ProviderType: """Get the provider type.""" return ProviderType.XAI diff --git a/pyproject.toml b/pyproject.toml index f9f91cd..b11d75c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,14 @@ py-modules = ["server", "config"] "*" = ["conf/*.json"] [tool.setuptools.data-files] -"conf" 
= ["conf/custom_models.json", "conf/openrouter_models.json", "conf/azure_models.json"] +"conf" = [ + "conf/custom_models.json", + "conf/openrouter_models.json", + "conf/azure_models.json", + "conf/openai_models.json", + "conf/gemini_models.json", + "conf/xai_models.json", +] [project.scripts] zen-mcp-server = "server:run"