Native support for xAI Grok 3

- Model shorthand mapping related fixes
- Comprehensive auto-mode related tests
@@ -23,9 +23,11 @@ Inherit from `ModelProvider` when:

### Option B: OpenAI-Compatible Provider (Simplified)

Inherit from `OpenAICompatibleProvider` when:

- Your API follows OpenAI's chat completion format
- You want to reuse the existing implementation for most functionality
- You only need to define model capabilities and validation

⚠️ **CRITICAL**: If your provider has model aliases (shorthands), you **MUST** override `generate_content()` to resolve aliases before API calls. See the implementation example below.

## Step-by-Step Guide

### 1. Add Provider Type to Enum

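As a minimal sketch of this step (the `EXAMPLE` member and its value are placeholders, and the existing members shown are assumptions about the enum's current contents):

```python
# providers/base.py -- hypothetical addition; EXAMPLE is a placeholder name
from enum import Enum

class ProviderType(Enum):
    GOOGLE = "google"
    OPENAI = "openai"
    EXAMPLE = "example"  # Add your new provider here
```
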
@@ -177,8 +179,11 @@ For providers with OpenAI-compatible APIs, the implementation is much simpler:

"""Example provider using OpenAI-compatible interface."""

import logging
from typing import Optional

from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)
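# --- Hedged sketch (not part of the diff): the class body elided between
# --- these hunks typically defines get_capabilities(). Field names on
# --- ModelCapabilities are assumptions; only temperature_constraint and
# --- its 0.0-2.0 range are grounded in the unit test later in this guide.
def get_capabilities(self, model_name: str) -> ModelCapabilities:
    resolved = self._resolve_model_name(model_name)
    return ModelCapabilities(
        provider=ProviderType.EXAMPLE,  # hypothetical enum member
        model_name=resolved,
        friendly_name="Example",
        context_window=131_072,
        temperature_constraint=RangeTemperatureConstraint(0.0, 2.0),
    )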
@@ -268,7 +273,31 @@ class ExampleProvider(OpenAICompatibleProvider):
                return shorthand_value
        return model_name

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the API with proper model name resolution."""
        # CRITICAL: Resolve the model alias before making the API call.
        # This ensures aliases like "large" are sent as "example-model-large" to the API.
        resolved_model_name = self._resolve_model_name(model_name)

        # Call the parent implementation with the resolved model name
        return super().generate_content(
            prompt=prompt,
            model_name=resolved_model_name,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

    # Note: count_tokens is inherited from OpenAICompatibleProvider
```

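For context on the `return shorthand_value` / `return model_name` tail in the hunk above, here is one hedged shape `_resolve_model_name()` can take, assuming shorthands are stored as string entries alongside full model names in `SUPPORTED_MODELS` (the layout and module path are assumptions):

```python
from providers.openai_compatible import OpenAICompatibleProvider  # path assumed

class ExampleProvider(OpenAICompatibleProvider):
    # Layout assumption: shorthand keys map to full-name strings
    SUPPORTED_MODELS = {
        "example-model-large": {"context_window": 131_072},  # full model entry
        "large": "example-model-large",                       # shorthand -> full name
    }

    def _resolve_model_name(self, model_name: str) -> str:
        """Map a shorthand like "large" to the full API model name."""
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name
```
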
### 3. Update Registry Configuration

@@ -291,7 +320,32 @@ def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]
    # ... rest of the method
```

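A hedged sketch of what the elided method body does; only the signature comes from the hunk header, and the mapping layout is an assumption:

```python
import os
from typing import Optional

from providers.base import ProviderType

class ModelProviderRegistry:  # sketch of just this one method
    @classmethod
    def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
        # Map each provider type to the environment variable holding its key
        key_mapping = {
            ProviderType.GOOGLE: "GEMINI_API_KEY",
            ProviderType.OPENAI: "OPENAI_API_KEY",
            ProviderType.EXAMPLE: "EXAMPLE_API_KEY",  # Add your provider's variable
        }
        env_var = key_mapping.get(provider_type)
        return os.getenv(env_var) if env_var else None
```
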
### 4. Configure Docker Environment Variables

**CRITICAL**: You must add your provider's environment variables to `docker-compose.yml` for them to be available in the Docker container.

Add your API key and restriction variables to the `environment` section:

```yaml
services:
  zen-mcp:
    # ... other configuration ...
    environment:
      - GEMINI_API_KEY=${GEMINI_API_KEY:-}
      - OPENAI_API_KEY=${OPENAI_API_KEY:-}
      - EXAMPLE_API_KEY=${EXAMPLE_API_KEY:-}  # Add this line
      # OpenRouter support
      - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
      # ... other variables ...
      # Model usage restrictions
      - OPENAI_ALLOWED_MODELS=${OPENAI_ALLOWED_MODELS:-}
      - GOOGLE_ALLOWED_MODELS=${GOOGLE_ALLOWED_MODELS:-}
      - EXAMPLE_ALLOWED_MODELS=${EXAMPLE_ALLOWED_MODELS:-}  # Add this line
```

⚠️ **Without this step**, the Docker container won't have access to your environment variables, and your provider won't be registered even if the API key is set in your `.env` file.

### 5. Register Provider in server.py

The `configure_providers()` function in `server.py` handles provider registration. You need to:

@@ -355,7 +409,7 @@ def configure_providers():
    )
```

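A hedged sketch of the wiring this step adds inside `configure_providers()`; the `register_provider` call mirrors how the existing providers are wired up, and `ExampleProvider` plus its module path are hypothetical:

```python
# Inside configure_providers() in server.py -- hedged sketch
import logging
import os

from providers.base import ProviderType
from providers.example import ExampleProvider  # hypothetical module
from providers.registry import ModelProviderRegistry

logger = logging.getLogger(__name__)

example_key = os.getenv("EXAMPLE_API_KEY")
# Skip registration when the key is unset or still the .env placeholder
if example_key and example_key != "your-example-api-key-here":
    ModelProviderRegistry.register_provider(ProviderType.EXAMPLE, ExampleProvider)
    logger.info("Example provider registered")
```
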
### 6. Add Model Capabilities for Auto Mode

Update `config.py` to add your models to `MODEL_CAPABILITIES_DESC`:

@@ -372,9 +426,9 @@ MODEL_CAPABILITIES_DESC = {
}
```

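As a hedged sketch of the entries this hunk adds (the description wording is a guess based on the dict's purpose, and the model names are placeholders):

```python
# config.py -- hypothetical entries; descriptions are illustrative only
MODEL_CAPABILITIES_DESC = {
    # ... existing entries ...
    "example-model-large": "Example Large (131K context) - strong general reasoning",
    "large": "Alias for example-model-large",
}
```
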
### 7. Update Documentation

#### 7.1. Update README.md

Add your provider to the quickstart section:

@@ -396,9 +450,9 @@ Also update the .env file example:
# EXAMPLE_API_KEY=your-example-api-key-here  # Add this
```

### 8. Write Tests

#### 8.1. Unit Tests

Create `tests/test_example_provider.py`:

@@ -460,7 +514,7 @@ class TestExampleProvider:
        assert capabilities.temperature_constraint.max_temp == 2.0
```

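Alongside the capability assertions above, it's worth covering alias resolution directly; a minimal sketch, where the constructor signature, module path, and model names are assumptions:

```python
# tests/test_example_provider.py -- hedged sketch
from providers.example import ExampleProvider  # hypothetical module path

class TestExampleProviderAliases:
    def test_alias_resolution(self):
        provider = ExampleProvider(api_key="test-key")
        # A shorthand must resolve to the full name the API expects
        assert provider._resolve_model_name("large") == "example-model-large"
        # Full names pass through unchanged
        assert provider._resolve_model_name("example-model-large") == "example-model-large"
```
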
#### 8.2. Simulator Tests (Real-World Validation)

Create a simulator test to validate that your provider works correctly in real-world scenarios. Create `simulator_tests/test_example_models.py`:

@@ -696,6 +750,36 @@ SUPPORTED_MODELS = {

The `_resolve_model_name()` method handles this mapping automatically.

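The simulator test files build on the repo's own harness (not shown in this diff). As a rough standalone illustration of what such a test verifies, namely that an alias survives a real `generate_content` call, with all names hypothetical and assuming `ModelResponse` records which model served the request:

```python
import os

from providers.example import ExampleProvider  # hypothetical module path

def check_alias_roundtrip() -> None:
    # End-to-end check: needs a real EXAMPLE_API_KEY in the environment
    provider = ExampleProvider(api_key=os.environ["EXAMPLE_API_KEY"])
    response = provider.generate_content(
        prompt="Reply with one word.",
        model_name="large",  # alias; must be resolved before hitting the API
        temperature=0.2,
    )
    # The response should report the resolved model, not the alias
    assert response.model_name == "example-model-large"
```
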
## Critical Implementation Requirements

### Alias Resolution for OpenAI-Compatible Providers

If you inherit from `OpenAICompatibleProvider` and define model aliases, you **MUST** override `generate_content()` to resolve aliases before API calls. This is because:

1. **The base `OpenAICompatibleProvider.generate_content()`** sends the original model name directly to the API
2. **Your API expects the full model name**, not the alias
3. **Without resolution**, requests like `model="large"` will fail with 404/400 errors

**Examples of providers that need this:**

- XAI provider: `"grok"` → `"grok-3"`
- OpenAI provider: `"mini"` → `"o4-mini"`
- Custom provider: `"fast"` → `"llama-3.1-8b-instruct"`

**Example implementation pattern:**

```python
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
    # CRITICAL: Resolve alias before API call
    resolved_model_name = self._resolve_model_name(model_name)

    # Pass resolved name to parent
    return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
```

**Providers that DON'T need this:**

- Gemini provider (has its own `generate_content` implementation)
- OpenRouter provider (already implements this pattern)
- Providers without aliases

## Best Practices

1. **Always validate model names** against supported models and restrictions

@@ -715,6 +799,7 @@ Before submitting your PR:

- [ ] Provider implementation complete with all required methods
- [ ] API key mapping added to `_get_api_key_for_provider()` in `providers/registry.py`
- [ ] Provider added to `PROVIDER_PRIORITY_ORDER` in `registry.py` (if native provider)
- [ ] **Environment variables added to `docker-compose.yml`** (API key and restrictions)
- [ ] Provider imported and registered in `server.py`'s `configure_providers()`
- [ ] API key checking added to `configure_providers()` function
- [ ] Error message updated to include the new provider