# Adding a New Provider

This guide explains how to add support for a new AI model provider to the Zen MCP Server. The provider system is designed to be extensible and follows a simple pattern, so you can integrate providers such as Anthropic, Cohere, or any other API that offers AI model access.

## Overview

Each provider:

- Inherits from `ModelProvider` (the base class) or `OpenAICompatibleProvider` (for OpenAI-compatible APIs)
- Defines supported models using `ModelCapabilities` objects
- Implements a few core abstract methods for model interaction
- Is registered in the provider registry when its API key environment variable is set

## Choose Your Implementation Path

You have two options when implementing a new provider:

### Option A: Native Provider (Full Implementation)

Inherit from `ModelProvider` when:

- Your API has unique features or custom authentication that don't fit OpenAI's format
- You need full control over API calls and response handling
- You want to implement custom features such as extended thinking

Required methods: `generate_content()`, `count_tokens()`, `get_capabilities()`, `validate_model_name()`, `supports_thinking_mode()`, `get_provider_type()`

### Option B: OpenAI-Compatible Provider (Simplified)

Inherit from `OpenAICompatibleProvider` when:

- Your API follows OpenAI's chat completion format
- You want to reuse the existing implementation for most functionality
- You only need to define model capabilities and validation

⚠️ **Important**: If your provider defines model aliases (for example, `"gpt"` → `"gpt-4"`), you **must** override `generate_content()` to resolve aliases before API calls. See the implementation example below.

## Step-by-Step Guide

### 1. Add Provider Type

Add your provider to the `ProviderType` enum in `providers/base.py`:

```python
class ProviderType(Enum):
    """Supported model provider types."""

    GOOGLE = "google"
    OPENAI = "openai"
    OPENROUTER = "openrouter"
    CUSTOM = "custom"
    EXAMPLE = "example"  # Add your provider here
```

### 2. Create the Provider Implementation

#### Option A: Native Provider (Full Implementation)

Create `providers/example.py`:

```python
"""Example model provider implementation."""

import logging
from typing import Optional

from utils.model_restrictions import get_restriction_service

from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint

logger = logging.getLogger(__name__)


class ExampleModelProvider(ModelProvider):
    """Example model provider implementation."""

    # Define models using ModelCapabilities objects (like the Gemini provider).
    # Aliases such as "large" are resolved to the full model name by the base class.
    SUPPORTED_MODELS = {
        "example-large": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-large",
            friendly_name="Example Large",
            context_window=100_000,
            max_output_tokens=50_000,
            supports_extended_thinking=False,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            description="Large model for complex tasks",
            aliases=["large", "big"],
        ),
        "example-small": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-small",
            friendly_name="Example Small",
            context_window=32_000,
            max_output_tokens=16_000,
            supports_extended_thinking=False,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            description="Fast model for simple tasks",
            aliases=["small", "fast"],
        ),
    }

    def __init__(self, api_key: str, **kwargs):
        super().__init__(api_key, **kwargs)
        # Initialize your API client here

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model_name}")

        # Apply restrictions if configured
        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        return self.SUPPORTED_MODELS[resolved_name]

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        resolved_name = self._resolve_model_name(model_name)
        self.validate_parameters(resolved_name, temperature)

        # Call your API here
        # response = your_api_client.generate(...)

        return ModelResponse(
            content="",  # From your API response
            usage={"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
            model_name=resolved_name,
            friendly_name="Example",
            provider=ProviderType.EXAMPLE,
        )

    def count_tokens(self, text: str, model_name: str) -> int:
        # Use your tokenizer if available; a rough estimate works as a fallback
        return len(text) // 4

    def get_provider_type(self) -> ProviderType:
        return ProviderType.EXAMPLE

    def validate_model_name(self, model_name: str) -> bool:
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Respect any model restriction policy
        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            logger.debug(f"Example model '{model_name}' -> '{resolved_name}' blocked by restrictions")
            return False

        return True

    def supports_thinking_mode(self, model_name: str) -> bool:
        return self.get_capabilities(model_name).supports_extended_thinking
```
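
Before wiring the provider into the server, it can be handy to sanity-check it directly. A minimal sketch, assuming the `providers/example.py` module above (none of these calls hit the network):

```python
from providers.example import ExampleModelProvider

provider = ExampleModelProvider(api_key="your-example-api-key")

# Aliases resolve to the canonical model names defined in SUPPORTED_MODELS
assert provider.validate_model_name("large")
capabilities = provider.get_capabilities("large")
print(capabilities.model_name, capabilities.context_window)  # example-large 100000
```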

#### Option B: OpenAI-Compatible Provider (Simplified)

For providers with OpenAI-compatible APIs, the implementation is much simpler:

```python
"""Example OpenAI-compatible provider."""

import logging

from utils.model_restrictions import get_restriction_service

from .base import ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint
from .openai_compatible import OpenAICompatibleProvider

logger = logging.getLogger(__name__)


class ExampleProvider(OpenAICompatibleProvider):
    """Example provider using an OpenAI-compatible API."""

    FRIENDLY_NAME = "Example"

    # Define models using ModelCapabilities (consistent with other providers)
    SUPPORTED_MODELS = {
        "example-model-large": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-model-large",
            friendly_name="Example Large",
            context_window=128_000,
            max_output_tokens=64_000,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            aliases=["large", "big"],
        ),
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize the provider with its API base URL."""
        kwargs.setdefault("base_url", "https://api.example.com/v1")
        super().__init__(api_key, **kwargs)

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model."""
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model_name}")

        # Check restrictions
        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        return self.SUPPORTED_MODELS[resolved_name]

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.EXAMPLE

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported."""
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Check restrictions
        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            return False

        return True

    def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
        """Generate content with proper model alias resolution."""
        # IMPORTANT: Resolve the alias so "large" is sent to the API as "example-model-large"
        resolved_model_name = self._resolve_model_name(model_name)

        # Call the parent implementation with the resolved model name
        return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)

    # Note: count_tokens() is inherited from OpenAICompatibleProvider
```

### 3. Register Your Provider

#### 3.1. Add Environment Variable Mapping

Update `providers/registry.py` so the registry knows which environment variable holds your provider's API key:

```python
@classmethod
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
    """Get API key for a provider from environment variables."""
    key_mapping = {
        ProviderType.GOOGLE: "GEMINI_API_KEY",
        ProviderType.OPENAI: "OPENAI_API_KEY",
        ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
        ProviderType.CUSTOM: "CUSTOM_API_KEY",
        ProviderType.EXAMPLE: "EXAMPLE_API_KEY",  # Add this line
    }
    # ... rest of the method
```

#### 3.2. Add Your Provider to the Priority Order

The provider priority is hardcoded in `providers/registry.py`. If you are adding a new native provider, add it to `PROVIDER_PRIORITY_ORDER` in `get_provider_for_model()`. Native providers should be placed **before** `CUSTOM` and `OPENROUTER` so they get priority for their own models:

```python
# In providers/registry.py
PROVIDER_PRIORITY_ORDER = [
    ProviderType.GOOGLE,      # Direct Gemini access
    ProviderType.OPENAI,      # Direct OpenAI access
    ProviderType.EXAMPLE,     # Add your native provider here
    ProviderType.CUSTOM,      # Local/self-hosted models
    ProviderType.OPENROUTER,  # Catch-all (must stay last)
]
```

#### 3.3. Register the Provider in `server.py`

The `configure_providers()` function in `server.py` handles provider registration:

1. **Import your provider class** at the top of `server.py`:
```python
from providers.example import ExampleModelProvider
```

2. **Add API key checking** in the `configure_providers()` function:
```python
def configure_providers():
    """Configure and validate AI providers based on available API keys."""
    # ... existing code ...

    # Check for Example API key
    example_key = os.getenv("EXAMPLE_API_KEY")
    if example_key and example_key != "your_example_api_key_here":
        valid_providers.append("Example")
        has_native_apis = True
        logger.info("Example API key found - Example models available")
```

3. **Register the provider** in the native APIs section:
```python
# Register providers in priority order:
# 1. Native APIs first (most direct and efficient)
if has_native_apis:
    if gemini_key and gemini_key != "your_gemini_api_key_here":
        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
    if openai_key and openai_key != "your_openai_api_key_here":
        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
    if example_key and example_key != "your_example_api_key_here":
        ModelProviderRegistry.register_provider(ProviderType.EXAMPLE, ExampleModelProvider)
```

4. **Update the error message** so it mentions your provider:
```python
if not valid_providers:
    raise ValueError(
        "At least one API configuration is required. Please set either:\n"
        "- GEMINI_API_KEY for Gemini models\n"
        "- OPENAI_API_KEY for OpenAI o3 model\n"
        "- EXAMPLE_API_KEY for Example models\n"  # Add this
        "- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
        "- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)"
    )
```

### 4. Add Model Descriptions for Auto Mode

Give each entry in `SUPPORTED_MODELS` a `description`. These descriptions help Claude choose the best model for each task in auto mode:

```python
# In your provider's SUPPORTED_MODELS
"example-large": ModelCapabilities(
    provider=ProviderType.EXAMPLE,
    model_name="example-large",
    friendly_name="Example Large",
    context_window=100_000,
    supports_extended_thinking=False,
    description="Example Large (100K context) - High capacity model for complex tasks",
    aliases=["large", "big"],
),
"example-small": ModelCapabilities(
    provider=ProviderType.EXAMPLE,
    model_name="example-small",
    friendly_name="Example Small",
    context_window=32_000,
    supports_extended_thinking=False,
    description="Example Small (32K context) - Fast model for simple tasks",
    aliases=["small", "fast"],
),
```

The descriptions should be detailed enough for Claude to understand when to use each model. Include context about performance, capabilities, cost, and ideal use cases.

### 5. Update Documentation

Add your provider to the README.md quickstart section:

```markdown
### 1. Get API Keys (at least one required)

**Option B: Native APIs**
- **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey)
- **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys)
- **Example**: Visit [Example API Console](https://example.com/api-keys)  # Add this
```

Also update the `.env` file example:

```markdown
# Edit .env to add your API keys
# GEMINI_API_KEY=your-gemini-api-key-here
# OPENAI_API_KEY=your-openai-api-key-here
# EXAMPLE_API_KEY=your-example-api-key-here  # Add this
```

### 6. Environment Configuration

Add to your `.env` file:

```bash
# Your provider's API key
EXAMPLE_API_KEY=your_api_key_here

# Optional: Disable specific tools
DISABLED_TOOLS=debug,tracer
```

### 7. Write Tests

#### 7.1. Unit Tests

Create `tests/test_example_provider.py`:

```python
"""Tests for Example provider implementation."""

import os
from unittest.mock import patch

from providers.base import ProviderType
from providers.example import ExampleModelProvider


class TestExampleProvider:
    """Test Example provider functionality."""

    @patch.dict(os.environ, {"EXAMPLE_API_KEY": "test-key"})
    def test_initialization(self):
        """Test provider initialization."""
        provider = ExampleModelProvider("test-key")
        assert provider.api_key == "test-key"
        assert provider.get_provider_type() == ProviderType.EXAMPLE

    def test_model_validation(self):
        """Test model name validation."""
        provider = ExampleModelProvider("test-key")

        # Valid models and aliases
        assert provider.validate_model_name("large") is True
        assert provider.validate_model_name("example-large") is True

        # Invalid model
        assert provider.validate_model_name("invalid-model") is False

    def test_resolve_model_name(self):
        """Test model alias resolution."""
        provider = ExampleModelProvider("test-key")

        # Alias resolution
        assert provider._resolve_model_name("large") == "example-large"
        assert provider._resolve_model_name("small") == "example-small"

        # Full name passthrough
        assert provider._resolve_model_name("example-large") == "example-large"

    def test_get_capabilities(self):
        """Test getting model capabilities."""
        provider = ExampleModelProvider("test-key")

        capabilities = provider.get_capabilities("large")
        assert capabilities.model_name == "example-large"
        assert capabilities.friendly_name == "Example Large"
        assert capabilities.context_window == 100_000
        assert capabilities.provider == ProviderType.EXAMPLE

        # Temperature range
        assert capabilities.temperature_constraint.min_temp == 0.0
        assert capabilities.temperature_constraint.max_temp == 2.0
```

#### 7.2. Simulator Tests (Real-World Validation)

Create a simulator test to validate that your provider works correctly in real-world scenarios. Create `simulator_tests/test_example_models.py`:

```python
"""
Example Provider Model Tests

Tests that verify Example provider functionality including:
- Model alias resolution
- API integration
- Conversation continuity
- Error handling
"""

from .base_test import BaseSimulatorTest


class TestExampleModels(BaseSimulatorTest):
    """Test Example provider functionality"""

    @property
    def test_name(self) -> str:
        return "example_models"

    @property
    def test_description(self) -> str:
        return "Example provider model functionality and integration"

    def run_test(self) -> bool:
        """Test Example provider models"""
        try:
            self.logger.info("Test: Example provider functionality")

            # Check if the Example API key is configured
            check_result = self.check_env_var("EXAMPLE_API_KEY")
            if not check_result:
                self.logger.info("  ⚠️ Example API key not configured - skipping test")
                return True  # Skip, not fail

            # Test 1: Alias mapping
            self.logger.info("  1: Testing 'large' alias mapping")
            response1, continuation_id = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Example Large model!' and nothing else.",
                    "model": "large",  # Should map to example-large
                    "temperature": 0.1,
                },
            )
            if not response1:
                self.logger.error("  ❌ Large alias test failed")
                return False
            self.logger.info("  ✅ Large alias call completed")

            # Test 2: Direct model name
            self.logger.info("  2: Testing direct model name (example-small)")
            response2, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Example Small model!' and nothing else.",
                    "model": "example-small",
                    "temperature": 0.1,
                },
            )
            if not response2:
                self.logger.error("  ❌ Direct model name test failed")
                return False
            self.logger.info("  ✅ Direct model name call completed")

            # Test 3: Conversation continuity
            self.logger.info("  3: Testing conversation continuity")
            response3, new_continuation_id = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Remember this number: 99. What number did I just tell you?",
                    "model": "large",
                    "temperature": 0.1,
                },
            )
            if not response3 or not new_continuation_id:
                self.logger.error("  ❌ Failed to start conversation")
                return False

            # Continue the conversation
            response4, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "What was the number I told you earlier?",
                    "model": "large",
                    "continuation_id": new_continuation_id,
                    "temperature": 0.1,
                },
            )
            if not response4:
                self.logger.error("  ❌ Failed to continue conversation")
                return False

            if "99" in response4:
                self.logger.info("  ✅ Conversation continuity working")
            else:
                self.logger.warning("  ⚠️ Model may not have remembered the number")

            # Test 4: Check logs for proper provider usage
            self.logger.info("  4: Validating Example provider usage in logs")
            logs = self.get_recent_server_logs()

            # Look for evidence of Example provider usage
            example_logs = [line for line in logs.split("\n") if "example" in line.lower()]
            model_resolution_logs = [
                line for line in logs.split("\n") if "Resolved model" in line and "example" in line.lower()
            ]

            self.logger.info(f"   Example-related logs: {len(example_logs)}")
            self.logger.info(f"   Model resolution logs: {len(model_resolution_logs)}")

            # Success criteria
            api_used = len(example_logs) > 0
            models_resolved = len(model_resolution_logs) > 0

            if api_used and models_resolved:
                self.logger.info("  ✅ Example provider tests passed")
                return True
            else:
                self.logger.error("  ❌ Example provider tests failed")
                return False

        except Exception as e:
            self.logger.error(f"Example provider test failed: {e}")
            return False


def main():
    """Run the Example provider tests"""
    import sys

    verbose = "--verbose" in sys.argv or "-v" in sys.argv
    test = TestExampleModels(verbose=verbose)

    success = test.run_test()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
```

The simulator test is crucial because it:
- Validates your provider works in the actual server environment
- Tests real API integration, not just mocked behavior
- Verifies model name resolution works correctly
- Checks conversation continuity across requests
- Examines server logs to ensure proper provider selection

See `simulator_tests/test_openrouter_models.py` for a complete real-world example.
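
To run these locally, the commands below are the usual pattern (assuming the project's standard pytest setup; the simulator test defines its own `main()` entry point, as shown above):

```bash
# Unit tests
python -m pytest tests/test_example_provider.py -v

# Simulator test (needs EXAMPLE_API_KEY configured and the server environment available)
python simulator_tests/test_example_models.py --verbose
```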

## Key Concepts

### Provider Priority and Model Name Resolution

When a user requests a model (e.g., "pro", "o3", "example-large"), the system checks providers in priority order, as defined in `registry.py`:

```python
PROVIDER_PRIORITY_ORDER = [
    ProviderType.GOOGLE,      # Native Gemini API
    ProviderType.OPENAI,      # Native OpenAI API
    ProviderType.CUSTOM,      # Local/self-hosted models
    ProviderType.OPENROUTER,  # Catch-all for everything else
]
```

For each provider, the registry calls `validate_model_name()`:

- Native providers (Gemini, OpenAI, Example) return `True` only for their specific models
- The custom provider handles local/self-hosted models
- OpenRouter returns `True` for ANY model (it's the catch-all)
- The first provider that validates the model handles the request, as the sketch below illustrates
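
Conceptually, the selection step looks something like the following. This is an illustrative sketch only, not the actual `registry.py` code; the real logic lives in `ModelProviderRegistry.get_provider_for_model()`:

```python
from typing import Optional

from providers.base import ModelProvider, ProviderType


def choose_provider(
    model_name: str,
    priority_order: list[ProviderType],
    registered_providers: dict[ProviderType, ModelProvider],
) -> Optional[ModelProvider]:
    """Illustrative sketch: walk providers in priority order and stop at the first match."""
    for provider_type in priority_order:
        provider = registered_providers.get(provider_type)
        if provider and provider.validate_model_name(model_name):
            return provider  # First provider that validates the model wins
    return None
```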

### Example: Model "gemini-2.5-pro"

1. **Gemini provider** checks: YES, it's in my SUPPORTED_MODELS → Gemini handles it
2. OpenAI skips it (Gemini already handled it)
3. OpenRouter never sees it

### Example: Model "claude-4-opus"

1. **Gemini provider** checks: NO, not my model → skip
2. **OpenAI provider** checks: NO, not my model → skip
3. **Custom provider** checks: NO, not configured → skip
4. **OpenRouter provider** checks: YES, I accept all models → OpenRouter handles it

### Model Validation

Your provider's `validate_model_name()` should return `True` **only** for models you explicitly support. This is what lets native providers take priority over proxy providers like OpenRouter:

```python
def validate_model_name(self, model_name: str) -> bool:
    resolved_name = self._resolve_model_name(model_name)

    # Only accept models you explicitly support
    if resolved_name not in self.SUPPORTED_MODELS:
        return False

    # Check restrictions
    restriction_service = get_restriction_service()
    if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
        logger.debug(f"Example model '{model_name}' -> '{resolved_name}' blocked by restrictions")
        return False

    return True
```

### Model Aliases

Shorthands such as `"large"` → `"example-large"` are declared via the `aliases` field on each `ModelCapabilities` entry. The base class handles alias resolution automatically via the `aliases` field in `ModelCapabilities`.
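
For intuition, alias resolution amounts to scanning the `aliases` declared on each `ModelCapabilities` entry. A minimal sketch of the idea (illustrative only, not the base class's actual implementation):

```python
def _resolve_model_name(self, model_name: str) -> str:
    """Illustrative sketch: map an alias such as "large" to its canonical model name."""
    wanted = model_name.lower()
    for canonical, capabilities in self.SUPPORTED_MODELS.items():
        aliases = [alias.lower() for alias in (capabilities.aliases or [])]
        if wanted == canonical.lower() or wanted in aliases:
            return canonical
    return model_name  # Unknown names pass through unchanged
```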

## Important Notes

### Alias Resolution in OpenAI-Compatible Providers

If you inherit from `OpenAICompatibleProvider` and define model aliases, you **must** override `generate_content()` to resolve aliases before API calls, because:

1. The base `OpenAICompatibleProvider.generate_content()` sends the original model name directly to the API
2. Your API expects the full model name, not the alias
3. Without resolution, requests like `model="large"` fail with 404/400 errors

**Examples of providers that need this:**
- XAI provider: `"grok"` → `"grok-3"`
- OpenAI provider: `"mini"` → `"o4-mini"`
- Custom provider: `"fast"` → `"llama-3.1-8b-instruct"`

**Example implementation pattern:**

```python
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
    # CRITICAL: Resolve the alias before the API call
    resolved_model_name = self._resolve_model_name(model_name)

    # Pass the resolved name to the parent implementation
    return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
```

**Providers that DON'T need this:**
- Gemini provider (has its own `generate_content()` implementation)
- OpenRouter provider (already implements this pattern)
- Providers without aliases

## Best Practices

1. **Be specific in model validation** - only accept models you actually support
2. **Use `ModelCapabilities` objects** consistently (like the Gemini provider)
3. **Include descriptive aliases** for a better user experience
4. **Handle API errors gracefully** and include retry logic for transient errors (see `gemini.py` and the sketch below)
5. **Log important events** for debugging (initialization, model resolution, errors)
6. **Document supported models** clearly in your provider class
7. **Test thoroughly**, including error cases, edge conditions, and real API calls
8. **Follow the existing patterns** in `providers/gemini.py` and `providers/custom.py`
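
For the retry point above, a minimal sketch of what transient-error handling can look like (illustrative; the exact pattern used by `gemini.py` may differ, and the transient error types depend on your API client):

```python
import logging
import time

logger = logging.getLogger(__name__)


def call_with_retries(api_call, max_attempts: int = 3, base_delay: float = 1.0):
    """Retry a callable with simple exponential backoff on transient errors."""
    for attempt in range(1, max_attempts + 1):
        try:
            return api_call()
        except (TimeoutError, ConnectionError) as exc:  # Adjust to your API's transient error types
            if attempt == max_attempts:
                raise
            delay = base_delay * (2 ** (attempt - 1))
            logger.warning(f"Transient error ({exc}); retrying in {delay:.1f}s ({attempt}/{max_attempts})")
            time.sleep(delay)
```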

## Checklist

Before submitting your PR:

- [ ] Provider type added to the `ProviderType` enum in `providers/base.py`
- [ ] Provider implementation complete with all required methods
- [ ] API key mapping added to `_get_api_key_for_provider()` in `providers/registry.py`
- [ ] Provider added to `PROVIDER_PRIORITY_ORDER` in `registry.py` (if native provider)
- [ ] Environment variables added to `.env` (API key and any restrictions)
- [ ] Provider imported and registered in `server.py`'s `configure_providers()`, including API key checking
- [ ] Error message updated to include the new provider
- [ ] Model descriptions added for auto mode
- [ ] Documentation updated (README.md)
- [ ] Unit tests written and passing (`tests/test_<provider>.py`)
- [ ] Simulator tests written and passing (`simulator_tests/test_<provider>_models.py`)
- [ ] Integration tested with actual API calls
- [ ] Code follows project style (run linting)
- [ ] PR follows the template requirements

## Need Help?

See the existing implementations for reference:

- **Full provider**: `providers/gemini.py`
- **OpenAI-compatible**: `providers/custom.py`
- **Base classes**: `providers/base.py`

Check the base classes for method signatures and requirements, run the tests frequently during development, and ask questions in GitHub issues if you get stuck. The modern approach uses `ModelCapabilities` objects directly in `SUPPORTED_MODELS`, which keeps implementations clean and consistent across providers.