refactor: new base class for model registry / loading
This commit is contained in:
@@ -8,12 +8,13 @@ if TYPE_CHECKING:
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openai_registry import OpenAIModelRegistry
|
||||
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
class OpenAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
|
||||
"""Implementation that talks to api.openai.com using rich model metadata.
|
||||
|
||||
In addition to the built-in catalogue, the provider can surface models
|
||||
@@ -21,8 +22,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
OpenAI-compatible gateways) while still respecting restriction policies.
|
||||
"""
|
||||
|
||||
REGISTRY_CLASS = OpenAIModelRegistry
|
||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
||||
_registry: Optional[OpenAIModelRegistry] = None
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize OpenAI provider with API key."""
|
||||
@@ -32,43 +33,6 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
super().__init__(api_key, **kwargs)
|
||||
self._invalidate_capability_cache()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registry access
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def _ensure_registry(cls, *, force_reload: bool = False) -> None:
|
||||
"""Load capability registry into MODEL_CAPABILITIES."""
|
||||
|
||||
if cls._registry is not None and not force_reload:
|
||||
return
|
||||
|
||||
try:
|
||||
registry = OpenAIModelRegistry()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.warning("Unable to load OpenAI model registry: %s", exc)
|
||||
cls._registry = None
|
||||
cls.MODEL_CAPABILITIES = {}
|
||||
return
|
||||
|
||||
cls._registry = registry
|
||||
cls.MODEL_CAPABILITIES = dict(registry.model_map)
|
||||
|
||||
@classmethod
|
||||
def reload_registry(cls) -> None:
|
||||
"""Force registry reload (primarily for tests)."""
|
||||
|
||||
cls._ensure_registry(force_reload=True)
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
self._ensure_registry()
|
||||
return super().get_all_model_capabilities()
|
||||
|
||||
def get_model_registry(self) -> Optional[dict[str, ModelCapabilities]]:
|
||||
if self._registry is None:
|
||||
return None
|
||||
return dict(self._registry.model_map)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Capability surface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user