Native support for xAI Grok3

Model shorthand mapping related fixes
Comprehensive auto-mode related tests
Fahad
2025-06-15 12:21:44 +04:00
parent 4becd70a82
commit 6304b7af6b
24 changed files with 2278 additions and 58 deletions

@@ -23,9 +23,11 @@ Inherit from `ModelProvider` when:
### Option B: OpenAI-Compatible Provider (Simplified)
Inherit from `OpenAICompatibleProvider` when:
- Your API follows OpenAI's chat completion format
- You want to reuse existing implementation for `generate_content` and `count_tokens`
- You want to reuse existing implementation for most functionality
- You only need to define model capabilities and validation
⚠️ **CRITICAL**: If your provider has model aliases (shorthands), you **MUST** override `generate_content()` to resolve aliases before API calls. See implementation example below.
## Step-by-Step Guide
### 1. Add Provider Type to Enum
@@ -177,8 +179,11 @@ For providers with OpenAI-compatible APIs, the implementation is much simpler:
"""Example provider using OpenAI-compatible interface."""
import logging
from typing import Optional
from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)
@@ -268,7 +273,31 @@ class ExampleProvider(OpenAICompatibleProvider):
            return shorthand_value
        return model_name
# Note: generate_content and count_tokens are inherited from OpenAICompatibleProvider
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using API with proper model name resolution."""
        # CRITICAL: Resolve model alias before making API call
        # This ensures aliases like "large" get sent as "example-model-large" to the API
        resolved_model_name = self._resolve_model_name(model_name)
        # Call parent implementation with resolved model name
        return super().generate_content(
            prompt=prompt,
            model_name=resolved_model_name,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )
    # Note: count_tokens is inherited from OpenAICompatibleProvider
```
### 3. Update Registry Configuration
@@ -291,7 +320,32 @@ def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]
# ... rest of the method
```
### 4. Register Provider in server.py
### 4. Configure Docker Environment Variables
**CRITICAL**: You must add your provider's environment variables to `docker-compose.yml` for them to be available in the Docker container.
Add your API key and restriction variables to the `environment` section:
```yaml
services:
  zen-mcp:
    # ... other configuration ...
    environment:
      - GEMINI_API_KEY=${GEMINI_API_KEY:-}
      - OPENAI_API_KEY=${OPENAI_API_KEY:-}
      - EXAMPLE_API_KEY=${EXAMPLE_API_KEY:-} # Add this line
      # OpenRouter support
      - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
      # ... other variables ...
      # Model usage restrictions
      - OPENAI_ALLOWED_MODELS=${OPENAI_ALLOWED_MODELS:-}
      - GOOGLE_ALLOWED_MODELS=${GOOGLE_ALLOWED_MODELS:-}
      - EXAMPLE_ALLOWED_MODELS=${EXAMPLE_ALLOWED_MODELS:-} # Add this line
```
⚠️ **Without this step**, the Docker container won't have access to your environment variables, and your provider won't be registered even if the API key is set in your `.env` file.
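Note that the `${VAR:-}` form passes an empty string into the container when the variable is unset on the host, so a key can look present while being blank. The following is a minimal sketch of treating blank values as missing. `read_api_key` is a hypothetical helper used only for illustration; it is not part of the codebase.

```python
import os
from typing import Optional


def read_api_key(var_name: str) -> Optional[str]:
    """Return the value of var_name, or None if it is unset or blank.

    Hypothetical helper: with `${VAR:-}` in docker-compose.yml, an unset
    host variable arrives in the container as an empty string.
    """
    value = os.environ.get(var_name, "").strip()
    return value or None


if __name__ == "__main__":
    # Quick check to run inside the container.
    for name in ("EXAMPLE_API_KEY", "EXAMPLE_ALLOWED_MODELS"):
        status = "set" if read_api_key(name) else "missing or empty"
        print(f"{name}: {status}")
```

Running this inside the container is a quick way to confirm the variables from `.env` actually reached it.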
### 5. Register Provider in server.py
The `configure_providers()` function in `server.py` handles provider registration. You need to:
@@ -355,7 +409,7 @@ def configure_providers():
)
```
### 5. Add Model Capabilities for Auto Mode
### 6. Add Model Capabilities for Auto Mode
Update `config.py` to add your models to `MODEL_CAPABILITIES_DESC`:
@@ -372,9 +426,9 @@ MODEL_CAPABILITIES_DESC = {
}
```
### 6. Update Documentation
### 7. Update Documentation
#### 6.1. Update README.md
#### 7.1. Update README.md
Add your provider to the quickstart section:
@@ -396,9 +450,9 @@ Also update the .env file example:
# EXAMPLE_API_KEY=your-example-api-key-here # Add this
```
### 7. Write Tests
### 8. Write Tests
#### 7.1. Unit Tests
#### 8.1. Unit Tests
Create `tests/test_example_provider.py`:
@@ -460,7 +514,7 @@ class TestExampleProvider:
assert capabilities.temperature_constraint.max_temp == 2.0
```
#### 7.2. Simulator Tests (Real-World Validation)
#### 8.2. Simulator Tests (Real-World Validation)
Create a simulator test to validate that your provider works correctly in real-world scenarios. Create `simulator_tests/test_example_models.py`:
@@ -696,6 +750,36 @@ SUPPORTED_MODELS = {
The `_resolve_model_name()` method handles this mapping automatically.
## Critical Implementation Requirements
### Alias Resolution for OpenAI-Compatible Providers
If you inherit from `OpenAICompatibleProvider` and define model aliases, you **MUST** override `generate_content()` to resolve aliases before API calls. This is because:
1. **The base `OpenAICompatibleProvider.generate_content()`** sends the original model name directly to the API
2. **Your API expects the full model name**, not the alias
3. **Without resolution**, requests like `model="large"` will fail with 404/400 errors
**Examples of providers that need this:**
- xAI provider: `"grok"` → `"grok-3"`
- OpenAI provider: `"mini"` → `"o4-mini"`
- Custom provider: `"fast"` → `"llama-3.1-8b-instruct"`
**Example implementation pattern:**
```python
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
    # CRITICAL: Resolve alias before API call
    resolved_model_name = self._resolve_model_name(model_name)
    # Pass resolved name to parent
    return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
```
**Providers that DON'T need this:**
- Gemini provider (has its own `generate_content` implementation)
- OpenRouter provider (already implements this pattern)
- Providers without aliases
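One way to guard against regressions is a test that checks which model name reaches the parent class. The sketch below is self-contained and rests on assumptions: `StubCompatibleProvider` is a hypothetical stand-in for `OpenAICompatibleProvider`, and the shape of `SUPPORTED_MODELS` (string values acting as aliases) is illustrative, not copied from the codebase.

```python
class StubCompatibleProvider:
    """Hypothetical stand-in for OpenAICompatibleProvider."""

    def generate_content(self, prompt: str, model_name: str, **kwargs) -> dict:
        # The real base class would call the API here; this stub only
        # echoes the model name it was given so a test can inspect it.
        return {"model": model_name, "prompt": prompt}


class ExampleProvider(StubCompatibleProvider):
    # Assumed layout: a string value marks an alias for a full model name.
    SUPPORTED_MODELS = {
        "example-model-large": {"max_tokens": 100_000},
        "large": "example-model-large",
    }

    def _resolve_model_name(self, model_name: str) -> str:
        target = self.SUPPORTED_MODELS.get(model_name)
        return target if isinstance(target, str) else model_name

    def generate_content(self, prompt: str, model_name: str, **kwargs) -> dict:
        # Resolve the alias first, then delegate to the parent.
        resolved = self._resolve_model_name(model_name)
        return super().generate_content(prompt=prompt, model_name=resolved, **kwargs)


def test_alias_is_resolved_before_parent_call() -> None:
    provider = ExampleProvider()
    # The alias is expanded to the full model name.
    assert provider.generate_content("hi", "large")["model"] == "example-model-large"
    # Full and unknown names pass through unchanged.
    assert provider.generate_content("hi", "example-model-large")["model"] == "example-model-large"
    assert provider.generate_content("hi", "unknown")["model"] == "unknown"
```

In a real test, the stub would likely be replaced by mocking the client used by the actual base class, so the assertion covers the name sent to the API.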
## Best Practices
1. **Always validate model names** against supported models and restrictions
@@ -715,6 +799,7 @@ Before submitting your PR:
- [ ] Provider implementation complete with all required methods
- [ ] API key mapping added to `_get_api_key_for_provider()` in `providers/registry.py`
- [ ] Provider added to `PROVIDER_PRIORITY_ORDER` in `registry.py` (if native provider)
- [ ] **Environment variables added to `docker-compose.yml`** (API key and restrictions)
- [ ] Provider imported and registered in `server.py`'s `configure_providers()`
- [ ] API key checking added to `configure_providers()` function
- [ ] Error message updated to include new provider