refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract
docs: document provider base class refactor: cleanup custom provider, it should only deal with `is_custom` model configurations fix: make sure openrouter provider does not load `is_custom` models fix: listmodels tool cleanup
This commit is contained in:
@@ -7,7 +7,7 @@ This guide explains how to add support for a new AI model provider to the Zen MC
|
||||
Each provider:
|
||||
- Inherits from `ModelProvider` (base class) or `OpenAICompatibleProvider` (for OpenAI-compatible APIs)
|
||||
- Defines supported models using `ModelCapabilities` objects
|
||||
- Implements a few core abstract methods
|
||||
- Implements the minimal abstract hooks (`get_provider_type()` and `generate_content()`)
|
||||
- Gets registered automatically via environment variables
|
||||
|
||||
## Choose Your Implementation Path
|
||||
@@ -15,11 +15,11 @@ Each provider:
|
||||
**Option A: Full Provider (`ModelProvider`)**
|
||||
- For APIs with unique features or custom authentication
|
||||
- Complete control over API calls and response handling
|
||||
- Required methods: `generate_content()`, `get_capabilities()`, `validate_model_name()`, `get_provider_type()` (override `count_tokens()` only when you have a provider-accurate tokenizer)
|
||||
- Implement `generate_content()` and `get_provider_type()`; override `get_all_model_capabilities()` to expose your catalogue and extend `_lookup_capabilities()` / `_ensure_model_allowed()` only when you need registry lookups or custom restriction rules (override `count_tokens()` only when you have a provider-accurate tokenizer)
|
||||
|
||||
**Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)**
|
||||
- For APIs that follow OpenAI's chat completion format
|
||||
- Only need to define: model configurations, capabilities, and validation
|
||||
- Supply `MODEL_CAPABILITIES`, override `get_provider_type()`, and optionally adjust configuration (the base class handles alias resolution, validation, and request wiring)
|
||||
- Inherits all API handling automatically
|
||||
|
||||
⚠️ **Important**: If using aliases (like `"gpt"` → `"gpt-4"`), override `generate_content()` to resolve them before API calls.
|
||||
@@ -62,8 +62,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ExampleModelProvider(ModelProvider):
|
||||
"""Example model provider implementation."""
|
||||
|
||||
# Define models using ModelCapabilities objects (like Gemini provider)
|
||||
|
||||
MODEL_CAPABILITIES = {
|
||||
"example-large": ModelCapabilities(
|
||||
provider=ProviderType.EXAMPLE,
|
||||
@@ -87,51 +86,47 @@ class ExampleModelProvider(ModelProvider):
|
||||
aliases=["small", "fast"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
super().__init__(api_key, **kwargs)
|
||||
# Initialize your API client here
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
return dict(self.MODEL_CAPABILITIES)
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
return ProviderType.EXAMPLE
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
prompt: str,
|
||||
model_name: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
raise ValueError(f"Unsupported model: {model_name}")
|
||||
|
||||
# Apply restrictions if needed
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
|
||||
raise ValueError(f"Model '{model_name}' is not allowed.")
|
||||
|
||||
return self.MODEL_CAPABILITIES[resolved_name]
|
||||
|
||||
def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
|
||||
# Your API call logic here
|
||||
# response = your_api_client.generate(...)
|
||||
|
||||
|
||||
return ModelResponse(
|
||||
content="Generated response", # From your API
|
||||
content="Generated response",
|
||||
usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
|
||||
model_name=resolved_name,
|
||||
friendly_name="Example",
|
||||
provider=ProviderType.EXAMPLE,
|
||||
)
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
return ProviderType.EXAMPLE
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return resolved_name in self.MODEL_CAPABILITIES
|
||||
```
|
||||
|
||||
`ModelProvider.count_tokens()` uses a simple 4-characters-per-token estimate so
|
||||
providers work out of the box. Override the method only when you can call into
|
||||
the provider's real tokenizer (for example, the OpenAI-compatible base class
|
||||
already integrates `tiktoken`).
|
||||
`ModelProvider.get_capabilities()` automatically resolves aliases, enforces the
|
||||
shared restriction service, and returns the correct `ModelCapabilities`
|
||||
instance. Override `_lookup_capabilities()` only when you source capabilities
|
||||
from a registry or remote API. `ModelProvider.count_tokens()` uses a simple
|
||||
4-characters-per-token estimate so providers work out of the box—override it
|
||||
only when you can call the provider's real tokenizer (for example, the
|
||||
OpenAI-compatible base class integrates `tiktoken`).
|
||||
|
||||
#### Option B: OpenAI-Compatible Provider (Simplified)
|
||||
|
||||
@@ -172,26 +167,16 @@ class ExampleProvider(OpenAICompatibleProvider):
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
kwargs.setdefault("base_url", "https://api.example.com/v1")
|
||||
super().__init__(api_key, **kwargs)
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
if resolved_name not in self.MODEL_CAPABILITIES:
|
||||
raise ValueError(f"Unsupported model: {model_name}")
|
||||
return self.MODEL_CAPABILITIES[resolved_name]
|
||||
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
return ProviderType.EXAMPLE
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return resolved_name in self.MODEL_CAPABILITIES
|
||||
|
||||
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
|
||||
# IMPORTANT: Resolve aliases before API call
|
||||
resolved_model_name = self._resolve_model_name(model_name)
|
||||
return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
|
||||
```
|
||||
|
||||
`OpenAICompatibleProvider` already exposes the declared models via
|
||||
`MODEL_CAPABILITIES`, resolves aliases through the shared base pipeline, and
|
||||
enforces restrictions. Most subclasses only need to provide the class metadata
|
||||
shown above.
|
||||
|
||||
### 3. Register Your Provider
|
||||
|
||||
Add environment variable mapping in `providers/registry.py`:
|
||||
@@ -237,15 +222,11 @@ DISABLED_TOOLS=debug,tracer
|
||||
Create basic tests to verify your implementation:
|
||||
|
||||
```python
|
||||
# Test model validation
|
||||
provider = ExampleModelProvider("test-key")
|
||||
assert provider.validate_model_name("large") == True
|
||||
assert provider.validate_model_name("unknown") == False
|
||||
|
||||
# Test capabilities
|
||||
caps = provider.get_capabilities("large")
|
||||
assert caps.context_window > 0
|
||||
assert caps.provider == ProviderType.EXAMPLE
|
||||
provider = ExampleModelProvider("test-key")
|
||||
capabilities = provider.get_capabilities("large")
|
||||
assert capabilities.context_window > 0
|
||||
assert capabilities.provider == ProviderType.EXAMPLE
|
||||
```
|
||||
|
||||
|
||||
@@ -259,31 +240,19 @@ When a user requests a model, providers are checked in priority order:
|
||||
3. **OpenRouter** - catch-all for everything else
|
||||
|
||||
### Model Validation
|
||||
Your `validate_model_name()` should **only** return `True` for models you explicitly support:
|
||||
|
||||
```python
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return resolved_name in self.MODEL_CAPABILITIES # Be specific!
|
||||
```
|
||||
`ModelProvider.validate_model_name()` delegates to `get_capabilities()` so most
|
||||
providers can rely on the shared implementation. Override it only when you need
|
||||
to opt out of that pipeline—for example, `CustomProvider` declines OpenRouter
|
||||
models so they fall through to the dedicated OpenRouter provider.
|
||||
|
||||
### Model Aliases
|
||||
The base class handles alias resolution automatically via the `aliases` field in `ModelCapabilities`.
|
||||
Aliases declared on `ModelCapabilities` are applied automatically via
|
||||
`_resolve_model_name()`, and both the validation and request flows call it
|
||||
before touching your SDK. Override `generate_content()` only when your provider
|
||||
needs additional alias handling beyond the shared behaviour.
|
||||
|
||||
## Important Notes
|
||||
|
||||
### Alias Resolution in OpenAI-Compatible Providers
|
||||
If using `OpenAICompatibleProvider` with aliases, **you must override `generate_content()`** to resolve aliases before API calls:
|
||||
|
||||
```python
|
||||
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
|
||||
# Resolve alias before API call
|
||||
resolved_model_name = self._resolve_model_name(model_name)
|
||||
return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
|
||||
```
|
||||
|
||||
Without this, API calls with aliases like `"large"` will fail because your API doesn't recognize the alias.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- **Be specific in model validation** - only accept models you actually support
|
||||
|
||||
Reference in New Issue
Block a user