diff --git a/docs/adding_providers.md b/docs/adding_providers.md index 21abd53..c29404c 100644 --- a/docs/adding_providers.md +++ b/docs/adding_providers.md @@ -15,7 +15,7 @@ Each provider: **Option A: Full Provider (`ModelProvider`)** - For APIs with unique features or custom authentication - Complete control over API calls and response handling -- Implement `generate_content()` and `get_provider_type()`; override `get_all_model_capabilities()` to expose your catalogue and extend `_lookup_capabilities()` / `_ensure_model_allowed()` only when you need registry lookups or custom restriction rules (override `count_tokens()` only when you have a provider-accurate tokenizer) +- Populate `MODEL_CAPABILITIES`, implement `generate_content()` and `get_provider_type()`, and only override `get_all_model_capabilities()` / `_lookup_capabilities()` when your catalogue comes from a registry or remote source (override `count_tokens()` only when you have a provider-accurate tokenizer) **Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)** - For APIs that follow OpenAI's chat completion format diff --git a/providers/base.py b/providers/base.py index 05b688a..fd316c1 100644 --- a/providers/base.py +++ b/providers/base.py @@ -70,8 +70,11 @@ class ModelProvider(ABC): return self._finalise_capabilities(capabilities, resolved_name, model_name) def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: - """Return the provider's statically declared model capabilities.""" + """Return statically declared capabilities when available.""" + model_map = getattr(self, "MODEL_CAPABILITIES", None) + if isinstance(model_map, dict) and model_map: + return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)} return {} def list_models( diff --git a/providers/gemini.py b/providers/gemini.py index f333cf2..de3fa4d 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -135,11 +135,6 @@ class GeminiModelProvider(ModelProvider): # Capability surface # ------------------------------------------------------------------ - def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: - """Return statically defined Gemini capabilities.""" - - return dict(self.MODEL_CAPABILITIES) - # ------------------------------------------------------------------ # Client access # ------------------------------------------------------------------ diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py index 8714186..2da361d 100644 --- a/providers/openai_compatible.py +++ b/providers/openai_compatible.py @@ -79,14 +79,6 @@ class OpenAICompatibleProvider(ModelProvider): f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}" ) - def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: - """Return statically declared capabilities for OpenAI-compatible providers.""" - - model_map = getattr(self, "MODEL_CAPABILITIES", None) - if isinstance(model_map, dict): - return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)} - return {} - def _parse_allowed_models(self) -> Optional[set[str]]: """Parse allowed models from environment variable.