# Adding a New Provider

This guide explains how to add support for a new AI model provider to the Zen MCP Server. Follow these steps to integrate providers like Anthropic, Cohere, or any API that provides AI model access.

## Overview

The provider system in Zen MCP Server is designed to be extensible. Each provider:

- Inherits from a base class (`ModelProvider` or `OpenAICompatibleProvider`)
- Implements required methods for model interaction
- Is registered in the provider registry by the server
- Has its API key configured via environment variables
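
Concretely, the minimum contract looks like this (a sketch only; full signatures and a working implementation follow in step 2):

```python
from providers.base import ModelProvider

# Sketch of the provider contract described above -- see step 2 for real code.
class MyProvider(ModelProvider):
    SUPPORTED_MODELS = {...}  # full model names plus shorthand aliases

    def get_capabilities(self, model_name): ...  # context window, temperature, features
    def generate_content(self, prompt, model_name, **kwargs): ...
    def count_tokens(self, text, model_name): ...
    def get_provider_type(self): ...
    def validate_model_name(self, model_name): ...
```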

## Implementation Paths

You have two options when implementing a new provider:

### Option A: Native Provider (Full Implementation)

Inherit from `ModelProvider` when:

- Your API has unique features not compatible with OpenAI's format
- You need full control over the implementation
- You want to implement custom features like extended thinking

### Option B: OpenAI-Compatible Provider (Simplified)

Inherit from `OpenAICompatibleProvider` when:

- Your API follows OpenAI's chat completion format
- You want to reuse the existing `generate_content` and `count_tokens` implementations
- You only need to define model capabilities and validation

## Step-by-Step Guide

### 1. Add Provider Type to Enum

First, add your provider to the `ProviderType` enum in `providers/base.py`:

```python
class ProviderType(Enum):
    """Supported model provider types."""

    GOOGLE = "google"
    OPENAI = "openai"
    OPENROUTER = "openrouter"
    CUSTOM = "custom"
    EXAMPLE = "example"  # Add your provider here
```

### 2. Create the Provider Implementation

#### Option A: Native Provider Implementation

Create a new file in the `providers/` directory (e.g., `providers/example.py`):

```python
"""Example model provider implementation."""

import logging
from typing import Optional

from .base import (
    ModelCapabilities,
    ModelProvider,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)
from utils.model_restrictions import get_restriction_service

logger = logging.getLogger(__name__)


class ExampleModelProvider(ModelProvider):
    """Example model provider implementation."""

    SUPPORTED_MODELS = {
        "example-large-v1": {
            "context_window": 100_000,
            "supports_extended_thinking": False,
        },
        "example-small-v1": {
            "context_window": 50_000,
            "supports_extended_thinking": False,
        },
        # Shorthands
        "large": "example-large-v1",
        "small": "example-small-v1",
    }

    def __init__(self, api_key: str, **kwargs):
        super().__init__(api_key, **kwargs)
        # Initialize your API client here

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model_name}")

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        return ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name=resolved_name,
            friendly_name="Example",
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
        )

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        resolved_name = self._resolve_model_name(model_name)
        self.validate_parameters(resolved_name, temperature)

        # Call your API here
        # response = your_api_call(...)

        return ModelResponse(
            content="",  # From API response
            usage={
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
            },
            model_name=resolved_name,
            friendly_name="Example",
            provider=ProviderType.EXAMPLE,
        )

    def count_tokens(self, text: str, model_name: str) -> int:
        # Implement your tokenization or use estimation
        return len(text) // 4

    def get_provider_type(self) -> ProviderType:
        return ProviderType.EXAMPLE

    def validate_model_name(self, model_name: str) -> bool:
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return False

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            logger.debug(f"Example model '{model_name}' -> '{resolved_name}' blocked by restrictions")
            return False

        return True

    def supports_thinking_mode(self, model_name: str) -> bool:
        capabilities = self.get_capabilities(model_name)
        return capabilities.supports_extended_thinking

    def _resolve_model_name(self, model_name: str) -> str:
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name
```
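
As a quick sanity check, this is how the class above behaves once instantiated (a sketch; the key is a placeholder, no restriction policy is assumed, and `generate_content` still needs a real API call wired in):

```python
provider = ExampleModelProvider(api_key="sk-example-placeholder")

caps = provider.get_capabilities("large")  # shorthand resolves to example-large-v1
print(caps.context_window)                 # 100000

assert provider.validate_model_name("small")
assert not provider.validate_model_name("not-a-model")
```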

#### Option B: OpenAI-Compatible Provider Implementation

For providers with OpenAI-compatible APIs, the implementation is much simpler:

```python
"""Example provider using OpenAI-compatible interface."""

import logging

from .base import (
    ModelCapabilities,
    ProviderType,
    RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider

logger = logging.getLogger(__name__)


class ExampleProvider(OpenAICompatibleProvider):
    """Example provider using OpenAI-compatible API."""

    FRIENDLY_NAME = "Example"

    # Define supported models
    SUPPORTED_MODELS = {
        "example-model-large": {
            "context_window": 128_000,
            "supports_extended_thinking": False,
        },
        "example-model-small": {
            "context_window": 32_000,
            "supports_extended_thinking": False,
        },
        # Shorthands
        "large": "example-model-large",
        "small": "example-model-small",
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize provider with API key."""
        # Set your API base URL
        kwargs.setdefault("base_url", "https://api.example.com/v1")
        super().__init__(api_key, **kwargs)

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model."""
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model_name}")

        # Check restrictions
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        return ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name=resolved_name,
            friendly_name=self.FRIENDLY_NAME,
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            temperature_constraint=RangeTemperatureConstraint(0.0, 1.0, 0.7),
        )

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.EXAMPLE

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported."""
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return False

        # Check restrictions
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            return False

        return True

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name."""
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name

    # Note: generate_content and count_tokens are inherited from OpenAICompatibleProvider
```
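
Because `generate_content` and `count_tokens` come from the base class, the subclass above is already usable end to end. A sketch with placeholder values:

```python
provider = ExampleProvider(api_key="sk-example-placeholder")

# Both calls are inherited from OpenAICompatibleProvider
tokens = provider.count_tokens("Hello, world!", model_name="small")
response = provider.generate_content(prompt="Say hi.", model_name="small")
print(response.content)
```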

### 3. Update Registry Configuration

#### 3.1. Add Environment Variable Mapping

Update `providers/registry.py` to map your provider's API key:

```python
@classmethod
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
    """Get API key for a provider from environment variables."""
    key_mapping = {
        ProviderType.GOOGLE: "GEMINI_API_KEY",
        ProviderType.OPENAI: "OPENAI_API_KEY",
        ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
        ProviderType.CUSTOM: "CUSTOM_API_KEY",
        ProviderType.EXAMPLE: "EXAMPLE_API_KEY",  # Add this line
    }
    # ... rest of the method
```

### 4. Register Provider in server.py

The `configure_providers()` function in `server.py` handles provider registration; the changes you need are listed below.

**Note**: The provider priority is hardcoded in `registry.py`. If you're adding a new native provider (like Example), you'll need to update the `PROVIDER_PRIORITY_ORDER` in `get_provider_for_model()`:

```python
# In providers/registry.py
PROVIDER_PRIORITY_ORDER = [
    ProviderType.GOOGLE,      # Direct Gemini access
    ProviderType.OPENAI,      # Direct OpenAI access
    ProviderType.EXAMPLE,     # Add your native provider here
    ProviderType.CUSTOM,      # Local/self-hosted models
    ProviderType.OPENROUTER,  # Catch-all (must stay last)
]
```

Native providers should be placed BEFORE CUSTOM and OPENROUTER to ensure they get priority for their models.

1. **Import your provider class** at the top of `server.py`:

   ```python
   from providers.example import ExampleModelProvider
   ```

2. **Add API key checking** in the `configure_providers()` function:

   ```python
   def configure_providers():
       """Configure and validate AI providers based on available API keys."""
       # ... existing code ...

       # Check for Example API key
       example_key = os.getenv("EXAMPLE_API_KEY")
       if example_key and example_key != "your_example_api_key_here":
           valid_providers.append("Example")
           has_native_apis = True
           logger.info("Example API key found - Example models available")
   ```

3. **Register the provider** in the appropriate section:

   ```python
   # Register providers in priority order:
   # 1. Native APIs first (most direct and efficient)
   if has_native_apis:
       if gemini_key and gemini_key != "your_gemini_api_key_here":
           ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
       if openai_key and openai_key != "your_openai_api_key_here":
           ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
       if example_key and example_key != "your_example_api_key_here":
           ModelProviderRegistry.register_provider(ProviderType.EXAMPLE, ExampleModelProvider)
   ```

4. **Update the error message** to include your provider:

   ```python
   if not valid_providers:
       raise ValueError(
           "At least one API configuration is required. Please set either:\n"
           "- GEMINI_API_KEY for Gemini models\n"
           "- OPENAI_API_KEY for OpenAI o3 model\n"
           "- EXAMPLE_API_KEY for Example models\n"  # Add this
           "- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
           "- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)"
       )
   ```

### 5. Add Model Capabilities for Auto Mode

Update `config.py` to add your models to `MODEL_CAPABILITIES_DESC`:

```python
MODEL_CAPABILITIES_DESC = {
    # ... existing models ...

    # Example models - Available when EXAMPLE_API_KEY is configured
    "large": "Example Large (100K context) - High capacity model for complex tasks",
    "small": "Example Small (50K context) - Fast model for simple tasks",
    # Full model names
    "example-large-v1": "Example Large (100K context) - High capacity model",
    "example-small-v1": "Example Small (50K context) - Fast lightweight model",
}
```

### 6. Update Documentation

#### 6.1. Update README.md

Add your provider to the quickstart section:

```markdown
### 1. Get API Keys (at least one required)

**Option B: Native APIs**
- **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey)
- **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys)
- **Example**: Visit [Example API Console](https://example.com/api-keys) # Add this
```

Also update the `.env` file example:

```markdown
# Edit .env to add your API keys
# GEMINI_API_KEY=your-gemini-api-key-here
# OPENAI_API_KEY=your-openai-api-key-here
# EXAMPLE_API_KEY=your-example-api-key-here  # Add this
```

### 7. Write Tests

#### 7.1. Unit Tests

Create `tests/test_example_provider.py`:

```python
"""Tests for Example provider implementation."""

import os
from unittest.mock import patch

import pytest

from providers.base import ProviderType
from providers.example import ExampleModelProvider


class TestExampleProvider:
    """Test Example provider functionality."""

    @patch.dict(os.environ, {"EXAMPLE_API_KEY": "test-key"})
    def test_initialization(self):
        """Test provider initialization."""
        provider = ExampleModelProvider("test-key")
        assert provider.api_key == "test-key"
        assert provider.get_provider_type() == ProviderType.EXAMPLE

    def test_model_validation(self):
        """Test model name validation."""
        provider = ExampleModelProvider("test-key")

        # Test valid models
        assert provider.validate_model_name("large") is True
        assert provider.validate_model_name("example-large-v1") is True

        # Test invalid model
        assert provider.validate_model_name("invalid-model") is False

    def test_resolve_model_name(self):
        """Test model name resolution."""
        provider = ExampleModelProvider("test-key")

        # Test shorthand resolution
        assert provider._resolve_model_name("large") == "example-large-v1"
        assert provider._resolve_model_name("small") == "example-small-v1"

        # Test full name passthrough
        assert provider._resolve_model_name("example-large-v1") == "example-large-v1"

    def test_get_capabilities(self):
        """Test getting model capabilities."""
        provider = ExampleModelProvider("test-key")

        capabilities = provider.get_capabilities("large")
        assert capabilities.model_name == "example-large-v1"
        assert capabilities.friendly_name == "Example"
        assert capabilities.context_window == 100_000
        assert capabilities.provider == ProviderType.EXAMPLE

        # Test temperature range
        assert capabilities.temperature_constraint.min_temp == 0.0
        assert capabilities.temperature_constraint.max_temp == 2.0
```
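
With the file in place, you can run just these tests with `python -m pytest tests/test_example_provider.py -v`.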

#### 7.2. Simulator Tests (Real-World Validation)

Create a simulator test to validate that your provider works correctly in real-world scenarios. Create `simulator_tests/test_example_models.py`:

```python
"""
Example Provider Model Tests

Tests that verify Example provider functionality including:
- Model alias resolution
- API integration
- Conversation continuity
- Error handling
"""

from .base_test import BaseSimulatorTest


class TestExampleModels(BaseSimulatorTest):
    """Test Example provider functionality"""

    @property
    def test_name(self) -> str:
        return "example_models"

    @property
    def test_description(self) -> str:
        return "Example provider model functionality and integration"

    def run_test(self) -> bool:
        """Test Example provider models"""
        try:
            self.logger.info("Test: Example provider functionality")

            # Check if Example API key is configured
            check_result = self.check_env_var("EXAMPLE_API_KEY")
            if not check_result:
                self.logger.info("  ⚠️ Example API key not configured - skipping test")
                return True  # Skip, not fail

            # Test 1: Shorthand alias mapping
            self.logger.info("  1: Testing 'large' alias mapping")

            response1, continuation_id = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Example Large model!' and nothing else.",
                    "model": "large",  # Should map to example-large-v1
                    "temperature": 0.1,
                },
            )

            if not response1:
                self.logger.error("  ❌ Large alias test failed")
                return False

            self.logger.info("  ✅ Large alias call completed")

            # Test 2: Direct model name
            self.logger.info("  2: Testing direct model name (example-small-v1)")

            response2, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Say 'Hello from Example Small model!' and nothing else.",
                    "model": "example-small-v1",
                    "temperature": 0.1,
                },
            )

            if not response2:
                self.logger.error("  ❌ Direct model name test failed")
                return False

            self.logger.info("  ✅ Direct model name call completed")

            # Test 3: Conversation continuity
            self.logger.info("  3: Testing conversation continuity")

            response3, new_continuation_id = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "Remember this number: 99. What number did I just tell you?",
                    "model": "large",
                    "temperature": 0.1,
                },
            )

            if not response3 or not new_continuation_id:
                self.logger.error("  ❌ Failed to start conversation")
                return False

            # Continue conversation
            response4, _ = self.call_mcp_tool(
                "chat",
                {
                    "prompt": "What was the number I told you earlier?",
                    "model": "large",
                    "continuation_id": new_continuation_id,
                    "temperature": 0.1,
                },
            )

            if not response4:
                self.logger.error("  ❌ Failed to continue conversation")
                return False

            if "99" in response4:
                self.logger.info("  ✅ Conversation continuity working")
            else:
                self.logger.warning("  ⚠️ Model may not have remembered the number")

            # Test 4: Check logs for proper provider usage
            self.logger.info("  4: Validating Example provider usage in logs")
            logs = self.get_recent_server_logs()

            # Look for evidence of Example provider usage
            example_logs = [line for line in logs.split("\n") if "example" in line.lower()]
            model_resolution_logs = [
                line
                for line in logs.split("\n")
                if "Resolved model" in line and "example" in line.lower()
            ]

            self.logger.info(f"  Example-related logs: {len(example_logs)}")
            self.logger.info(f"  Model resolution logs: {len(model_resolution_logs)}")

            # Success criteria
            api_used = len(example_logs) > 0
            models_resolved = len(model_resolution_logs) > 0

            if api_used and models_resolved:
                self.logger.info("  ✅ Example provider tests passed")
                return True
            else:
                self.logger.error("  ❌ Example provider tests failed")
                return False

        except Exception as e:
            self.logger.error(f"Example provider test failed: {e}")
            return False


def main():
    """Run the Example provider tests"""
    import sys

    verbose = "--verbose" in sys.argv or "-v" in sys.argv
    test = TestExampleModels(verbose=verbose)

    success = test.run_test()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
```
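
Note that the test uses a relative import (`from .base_test import BaseSimulatorTest`), so run it as a module, e.g. `python -m simulator_tests.test_example_models --verbose`, rather than as a bare script.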

The simulator test is crucial because it:

- Validates your provider works in the actual Docker environment
- Tests real API integration, not just mocked behavior
- Verifies model name resolution works correctly
- Checks conversation continuity across requests
- Examines server logs to ensure proper provider selection

See `simulator_tests/test_openrouter_models.py` for a complete real-world example.

## Model Name Mapping and Provider Priority

### How Model Name Resolution Works

When a user requests a model (e.g., "pro", "o3", "example-large-v1"), the system:

1. **Checks providers in priority order** (defined in `registry.py`):

   ```python
   PROVIDER_PRIORITY_ORDER = [
       ProviderType.GOOGLE,      # Native Gemini API
       ProviderType.OPENAI,      # Native OpenAI API
       ProviderType.CUSTOM,      # Local/self-hosted
       ProviderType.OPENROUTER,  # Catch-all for everything else
   ]
   ```

2. **For each provider**, calls `validate_model_name()`:
   - Native providers (Gemini, OpenAI) return `True` only for their specific models
   - OpenRouter returns `True` for ANY model (it's the catch-all)
   - The first provider that validates the model handles the request (sketched below)
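
A minimal sketch of that lookup (illustrative only; the real `get_provider_for_model()` in `registry.py` may differ, and `get_provider()` is an assumed helper name):

```python
# Hypothetical sketch of the loop inside get_provider_for_model()
for provider_type in PROVIDER_PRIORITY_ORDER:
    provider = ModelProviderRegistry.get_provider(provider_type)  # assumed helper
    if provider and provider.validate_model_name(model_name):
        return provider  # first provider that validates the model wins
```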

### Example: Model "gemini-2.5-pro"

1. **Gemini provider** checks: YES, it's in my SUPPORTED_MODELS → Gemini handles it
2. OpenAI skips (Gemini already handled it)
3. OpenRouter never sees it

### Example: Model "claude-3-opus"

1. **Gemini provider** checks: NO, not my model → skip
2. **OpenAI provider** checks: NO, not my model → skip
3. **Custom provider** checks: NO, not configured → skip
4. **OpenRouter provider** checks: YES, I accept all models → OpenRouter handles it

### Implementing Model Name Validation

Your provider's `validate_model_name()` should look like this:

```python
def validate_model_name(self, model_name: str) -> bool:
    resolved_name = self._resolve_model_name(model_name)

    # Only accept models you explicitly support
    if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
        return False

    # Check restrictions
    restriction_service = get_restriction_service()
    if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
        logger.debug(f"Example model '{model_name}' -> '{resolved_name}' blocked by restrictions")
        return False

    return True
```

**Important**: Native providers should ONLY return `True` for models they explicitly support. This ensures they get priority over proxy providers like OpenRouter.

### Model Shorthands

Each provider can define shorthands in their `SUPPORTED_MODELS`:

```python
SUPPORTED_MODELS = {
    "example-large-v1": { ... },  # Full model name
    "large": "example-large-v1",  # Shorthand mapping
}
```
The `_resolve_model_name()` method handles this mapping automatically.
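
For example, given a provider built from the table above:

```python
provider._resolve_model_name("large")             # -> "example-large-v1"
provider._resolve_model_name("example-large-v1")  # -> "example-large-v1" (passthrough)
```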

## Best Practices

1. **Always validate model names** against supported models and restrictions
2. **Be specific in validation** - only accept models you actually support
3. **Handle API errors gracefully** with proper error messages
4. **Include retry logic** for transient errors (see `gemini.py`; a generic sketch follows this list)
5. **Log important events** for debugging (initialization, model resolution, errors)
6. **Support model shorthands** for better user experience
7. **Document supported models** clearly in your provider class
8. **Test thoroughly** including error cases and edge conditions
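
For point 4, a generic retry sketch (an illustration, not the actual logic in `gemini.py`; narrow the exception handling to your SDK's transient error types):

```python
import logging
import time

logger = logging.getLogger(__name__)


def call_with_retries(request_fn, max_retries: int = 3):
    """Retry a flaky API call with simple exponential backoff."""
    for attempt in range(max_retries):
        try:
            return request_fn()
        except Exception as exc:  # replace with your SDK's transient error types
            if attempt == max_retries - 1:
                raise
            delay = 2**attempt  # 1s, 2s, 4s
            logger.warning(f"Transient error ({exc}); retrying in {delay}s")
            time.sleep(delay)
```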

## Checklist

Before submitting your PR:

- [ ] Provider type added to `ProviderType` enum in `providers/base.py`
- [ ] Provider implementation complete with all required methods
- [ ] API key mapping added to `_get_api_key_for_provider()` in `providers/registry.py`
- [ ] Provider added to `PROVIDER_PRIORITY_ORDER` in `registry.py` (if native provider)
- [ ] Provider imported and registered in `server.py`'s `configure_providers()`
- [ ] API key checking added to `configure_providers()` function
- [ ] Error message updated to include new provider
- [ ] Model capabilities added to `config.py` for auto mode
- [ ] Documentation updated (README.md)
- [ ] Unit tests written and passing (`tests/test_<provider>.py`)
- [ ] Simulator tests written and passing (`simulator_tests/test_<provider>_models.py`)
- [ ] Integration tested with actual API calls
- [ ] Code follows project style (run linting)
- [ ] PR follows the template requirements

## Need Help?

- Look at existing providers (`gemini.py`, `openai.py`) for examples
- Check the base classes for method signatures and requirements
- Run tests frequently during development
- Ask questions in GitHub issues if stuck