Updated guides

Fahad
2025-06-23 19:44:01 +04:00
parent a355b80afc
commit 4faa661c6d
2 changed files with 265 additions and 1273 deletions

# Adding a New Provider
This guide explains how to add support for a new AI model provider (such as Anthropic, Cohere, or any other API that exposes AI models) to the Zen MCP Server. The provider system is designed to be extensible and follows a simple pattern.
## Overview
Each provider:
- Inherits from `ModelProvider` (base class) or `OpenAICompatibleProvider` (for OpenAI-compatible APIs)
- Defines supported models using `ModelCapabilities` objects
- Implements a few core abstract methods
- Gets registered automatically via environment variables
## Choose Your Implementation Path
You have two options when implementing a new provider:

**Option A: Full Provider (`ModelProvider`)**
- For APIs with unique features or custom authentication
- Complete control over API calls and response handling
- Required methods: `generate_content()`, `count_tokens()`, `get_capabilities()`, `validate_model_name()`, `supports_thinking_mode()`, `get_provider_type()`

**Option B: OpenAI-Compatible (`OpenAICompatibleProvider`)**
- For APIs that follow OpenAI's chat completion format
- Only need to define model configurations, capabilities, and validation
- Inherits all API handling automatically

⚠️ **Important**: If your provider defines model aliases (e.g. `"gpt"` → `"gpt-4"`), you **must** override `generate_content()` to resolve them before API calls. See the implementation examples below.
## Step-by-Step Guide
### 1. Add Provider Type
Add your provider to the `ProviderType` enum in `providers/base.py`:
```python
class ProviderType(Enum):
"""Supported model provider types."""
GOOGLE = "google"
OPENAI = "openai"
OPENROUTER = "openrouter"
CUSTOM = "custom"
EXAMPLE = "example" # Add your provider here
EXAMPLE = "example" # Add this
```
### 2. Create the Provider Implementation
#### Option A: Full Provider (Native Implementation)
Create a new file in the `providers/` directory, e.g. `providers/example.py`:
```python
"""Example model provider implementation."""
import logging
from typing import Optional
from .base import (
ModelCapabilities,
ModelProvider,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
logger = logging.getLogger(__name__)
class ExampleModelProvider(ModelProvider):
"""Example model provider implementation."""
# Define models using ModelCapabilities objects (like Gemini provider)
SUPPORTED_MODELS = {
"example-large-v1": {
"context_window": 100_000,
"supports_extended_thinking": False,
},
"example-small-v1": {
"context_window": 50_000,
"supports_extended_thinking": False,
},
# Shorthands
"large": "example-large-v1",
"small": "example-small-v1",
"example-large": ModelCapabilities(
provider=ProviderType.EXAMPLE,
model_name="example-large",
friendly_name="Example Large",
context_window=100_000,
max_output_tokens=50_000,
supports_extended_thinking=False,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
description="Large model for complex tasks",
aliases=["large", "big"],
),
"example-small": ModelCapabilities(
provider=ProviderType.EXAMPLE,
model_name="example-small",
friendly_name="Example Small",
context_window=32_000,
max_output_tokens=16_000,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
description="Fast model for simple tasks",
aliases=["small", "fast"],
),
}
    def __init__(self, api_key: str, **kwargs):
        """Initialize the provider with your API key."""
        super().__init__(api_key, **kwargs)
        # Set up your API client here if needed

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model."""
        resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported model: {model_name}")
# Apply restrictions if needed
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
raise ValueError(f"Model '{model_name}' is not allowed.")
        return self.SUPPORTED_MODELS[resolved_name]
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
resolved_name = self._resolve_model_name(model_name)
self.validate_parameters(resolved_name, temperature)
        # Call your API here
        # response = your_api_client.generate(...)
return ModelResponse(
content="", # From API response
usage={
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
},
content="Generated response", # From your API
usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
model_name=resolved_name,
friendly_name="Example",
provider=ProviderType.EXAMPLE,
)
def count_tokens(self, text: str, model_name: str) -> int:
        # Implement your tokenization or use a simple estimation
        return len(text) // 4
def get_provider_type(self) -> ProviderType:
return ProviderType.EXAMPLE
def validate_model_name(self, model_name: str) -> bool:
resolved_name = self._resolve_model_name(model_name)
        return resolved_name in self.SUPPORTED_MODELS
def supports_thinking_mode(self, model_name: str) -> bool:
capabilities = self.get_capabilities(model_name)
return capabilities.supports_extended_thinking
    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve an alias (e.g. 'large') to its full model name."""
        # The base class may already handle this via ModelCapabilities.aliases
        for full_name, capabilities in self.SUPPORTED_MODELS.items():
            if model_name == full_name or model_name in capabilities.aliases:
                return full_name
        return model_name
```
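As a quick sanity check, you can exercise the provider above directly. This is purely illustrative and uses the placeholder model names, aliases, and `EXAMPLE_API_KEY` variable from this guide:
```python
import os

from providers.example import ExampleModelProvider

provider = ExampleModelProvider(api_key=os.getenv("EXAMPLE_API_KEY", "test-key"))

assert provider.validate_model_name("large")      # alias of "example-large"
caps = provider.get_capabilities("big")           # aliases resolve to the full name
print(caps.model_name, caps.context_window)       # example-large 100000

response = provider.generate_content(
    prompt="Say hello in one sentence.",
    model_name="small",                           # resolved to "example-small"
    temperature=0.2,
)
print(response.model_name, response.usage)
```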
#### Option B: OpenAI-Compatible Provider (Simplified)
For providers with OpenAI-compatible APIs, the implementation is much simpler:
```python
"""Example provider using OpenAI-compatible interface."""
"""Example OpenAI-compatible provider."""
import logging
from .base import (
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider
logger = logging.getLogger(__name__)
class ExampleProvider(OpenAICompatibleProvider):
"""Example provider using OpenAI-compatible API."""
"""Example OpenAI-compatible provider."""
FRIENDLY_NAME = "Example"
    # Define models using ModelCapabilities (consistent with other providers)
SUPPORTED_MODELS = {
"example-model-large": {
"context_window": 128_000,
"supports_extended_thinking": False,
},
"example-model-small": {
"context_window": 32_000,
"supports_extended_thinking": False,
},
# Shorthands
"large": "example-model-large",
"small": "example-model-small",
"example-model-large": ModelCapabilities(
provider=ProviderType.EXAMPLE,
model_name="example-model-large",
friendly_name="Example Large",
context_window=128_000,
max_output_tokens=64_000,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
aliases=["large", "big"],
),
}
def __init__(self, api_key: str, **kwargs):
"""Initialize provider with API key."""
# Set your API base URL
kwargs.setdefault("base_url", "https://api.example.com/v1")
super().__init__(api_key, **kwargs)
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model."""
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported model: {model_name}")
# Check restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
        return self.SUPPORTED_MODELS[resolved_name]
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.EXAMPLE
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported."""
resolved_name = self._resolve_model_name(model_name)
        return resolved_name in self.SUPPORTED_MODELS
    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve a model alias to its full name."""
        # The base class may already handle this via ModelCapabilities.aliases
        for full_name, capabilities in self.SUPPORTED_MODELS.items():
            if model_name == full_name or model_name in capabilities.aliases:
                return full_name
        return model_name
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
# IMPORTANT: Resolve aliases before API call
resolved_model_name = self._resolve_model_name(model_name)
return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
```
### 3. Register Your Provider
Update `providers/registry.py` to map your provider's API key to an environment variable:
```python
@classmethod
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
"""Get API key for a provider from environment variables."""
key_mapping = {
ProviderType.GOOGLE: "GEMINI_API_KEY",
ProviderType.OPENAI: "OPENAI_API_KEY",
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
ProviderType.CUSTOM: "CUSTOM_API_KEY",
ProviderType.EXAMPLE: "EXAMPLE_API_KEY", # Add this line
}
# ... rest of the method
```
Add to `server.py`:
1. **Import your provider**:
```python
"""Tests for Example provider implementation."""
import os
from unittest.mock import patch
import pytest
from providers.example import ExampleModelProvider
from providers.base import ProviderType
class TestExampleProvider:
"""Test Example provider functionality."""
@patch.dict(os.environ, {"EXAMPLE_API_KEY": "test-key"})
def test_initialization(self):
"""Test provider initialization."""
provider = ExampleModelProvider("test-key")
assert provider.api_key == "test-key"
assert provider.get_provider_type() == ProviderType.EXAMPLE
def test_model_validation(self):
"""Test model name validation."""
provider = ExampleModelProvider("test-key")
# Test valid models
assert provider.validate_model_name("large") is True
assert provider.validate_model_name("example-large-v1") is True
# Test invalid model
assert provider.validate_model_name("invalid-model") is False
def test_resolve_model_name(self):
"""Test model name resolution."""
provider = ExampleModelProvider("test-key")
# Test shorthand resolution
assert provider._resolve_model_name("large") == "example-large-v1"
assert provider._resolve_model_name("small") == "example-small-v1"
# Test full name passthrough
assert provider._resolve_model_name("example-large-v1") == "example-large-v1"
def test_get_capabilities(self):
"""Test getting model capabilities."""
provider = ExampleModelProvider("test-key")
capabilities = provider.get_capabilities("large")
assert capabilities.model_name == "example-large-v1"
assert capabilities.friendly_name == "Example"
assert capabilities.context_window == 100_000
assert capabilities.provider == ProviderType.EXAMPLE
# Test temperature range
assert capabilities.temperature_constraint.min_temp == 0.0
assert capabilities.temperature_constraint.max_temp == 2.0
```
#### 8.2. Simulator Tests (Real-World Validation)
2. **Add to `configure_providers()` function**:
```python
# Check for Example API key
example_key = os.getenv("EXAMPLE_API_KEY")
if example_key:
ModelProviderRegistry.register_provider(ProviderType.EXAMPLE, ExampleModelProvider)
logger.info("Example API key found - Example models available")
```
3. **Add to provider priority** (in `providers/registry.py`):
```python
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE,
ProviderType.OPENAI,
ProviderType.EXAMPLE, # Add your provider here
ProviderType.CUSTOM, # Local models
ProviderType.OPENROUTER, # Catch-all (keep last)
]
```
### 4. Environment Configuration
Add to your `.env` file:
```bash
# Your provider's API key
EXAMPLE_API_KEY=your_api_key_here
# Optional: Disable specific tools
DISABLED_TOOLS=debug,tracer
```
**Note**: The `description` field in `ModelCapabilities` helps Claude choose the best model in auto mode.
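Descriptions should be detailed enough for Claude to know when to pick each model - mention context size, strengths, and ideal use cases. For example (illustrative values matching the Example provider above):
```python
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint

# One SUPPORTED_MODELS entry with a description written for auto mode
EXAMPLE_LARGE = ModelCapabilities(
    provider=ProviderType.EXAMPLE,
    model_name="example-large",
    friendly_name="Example Large",
    context_window=100_000,
    max_output_tokens=50_000,
    temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
    description="Example Large (100K context) - high capacity model for complex analysis and long documents",
    aliases=["large", "big"],
)
```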
### 5. Test Your Provider
Create basic tests to verify your implementation:
```python
"""
Example Provider Model Tests
# Test model validation
provider = ExampleModelProvider("test-key")
assert provider.validate_model_name("large") == True
assert provider.validate_model_name("unknown") == False
Tests that verify Example provider functionality including:
- Model alias resolution
- API integration
- Conversation continuity
- Error handling
"""
# Test capabilities
caps = provider.get_capabilities("large")
assert caps.context_window > 0
assert caps.provider == ProviderType.EXAMPLE
```
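For a more permanent home, the same checks can live in a small pytest module, e.g. `tests/test_example_provider.py`. The class and model names below are the placeholders used throughout this guide:
```python
"""Tests for the Example provider implementation."""
import pytest

from providers.base import ProviderType
from providers.example import ExampleModelProvider


def test_model_validation():
    provider = ExampleModelProvider("test-key")
    assert provider.validate_model_name("large") is True
    assert provider.validate_model_name("unknown-model") is False


def test_capabilities():
    provider = ExampleModelProvider("test-key")
    caps = provider.get_capabilities("large")
    assert caps.provider == ProviderType.EXAMPLE
    assert caps.context_window > 0


def test_unknown_model_raises():
    provider = ExampleModelProvider("test-key")
    with pytest.raises(ValueError):
        provider.get_capabilities("not-a-real-model")
```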
## Key Concepts
### Provider Priority
When a user requests a model, providers are checked in priority order:
1. **Native providers** (Gemini, OpenAI, Example) - handle their specific models
2. **Custom provider** - handles local/self-hosted models
3. **OpenRouter** - catch-all for everything else
For each provider in that order, the registry calls `validate_model_name()`; the first provider that validates the requested model handles it. Native providers (Gemini, OpenAI) return `True` only for their specific models, while OpenRouter returns `True` for any model because it is the catch-all.
### Example: Model "gemini-2.5-pro"
1. **Gemini provider** checks: YES, it's in my SUPPORTED_MODELS → Gemini handles it
2. OpenAI skips (Gemini already handled it)
3. OpenRouter never sees it
### Example: Model "claude-4-opus"
1. **Gemini provider** checks: NO, not my model → skip
2. **OpenAI provider** checks: NO, not my model → skip
3. **Custom provider** checks: NO, not configured → skip
4. **OpenRouter provider** checks: YES, I accept all models → OpenRouter handles it
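The lookup itself is a simple first-match loop. The sketch below is illustrative only - not the actual `ModelProviderRegistry` code - and assumes a dict of registered provider instances keyed by `ProviderType` plus the `PROVIDER_PRIORITY_ORDER` list shown earlier:
```python
from typing import Optional

from providers.base import ModelProvider, ProviderType


def find_provider_for_model(
    model_name: str,
    registered: dict[ProviderType, ModelProvider],
    priority_order: list[ProviderType],
) -> Optional[ModelProvider]:
    """Return the first registered provider that claims the requested model."""
    for provider_type in priority_order:
        provider = registered.get(provider_type)
        if provider is not None and provider.validate_model_name(model_name):
            return provider  # first match wins, so native providers must come before catch-alls
    return None
```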
### Model Validation
Your `validate_model_name()` should **only** return `True` for models you explicitly support:
```python
def validate_model_name(self, model_name: str) -> bool:
resolved_name = self._resolve_model_name(model_name)
    # Only accept models you explicitly support
    return resolved_name in self.SUPPORTED_MODELS
```
**Important**: Native providers should ONLY return `True` for models they explicitly support. This ensures they get priority over proxy providers like OpenRouter.
### Model Aliases
The base class handles alias resolution automatically via the `aliases` field in `ModelCapabilities`.
## Important Notes
### Alias Resolution in OpenAI-Compatible Providers
If using `OpenAICompatibleProvider` with aliases, **you must override `generate_content()`** to resolve aliases before API calls. The base `OpenAICompatibleProvider.generate_content()` sends whatever model name it receives straight to the API, and the API expects the full model name, not the alias.
**Examples of providers that need this:**
- XAI provider: `"grok"` → `"grok-3"`
- OpenAI provider: `"mini"` → `"o4-mini"`
- Custom provider: `"fast"` → `"llama-3.1-8b-instruct"`
**Example implementation pattern:**
```python
def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
    # Resolve the alias before making the API call
resolved_model_name = self._resolve_model_name(model_name)
# Pass resolved name to parent
return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
```
**Providers that DON'T need this:**
- Gemini provider (has its own generate_content implementation)
- OpenRouter provider (already implements this pattern)
- Providers without aliases
Without this, API calls with aliases like `"large"` will fail because your API doesn't recognize the alias.
## Best Practices
- **Be specific in model validation** - only accept models you actually support
- **Use ModelCapabilities objects** consistently (like the Gemini provider)
- **Include descriptive aliases** for better user experience
- **Handle API errors gracefully** and include retry logic for transient failures (see `gemini.py`)
- **Log important events** (initialization, model resolution, errors) for debugging
- **Test with real API calls** to verify everything works
- **Follow the existing patterns** in `providers/gemini.py` and `providers/custom.py`
## Quick Checklist
Before submitting your PR:
- [ ] Added to `ProviderType` enum in `providers/base.py`
- [ ] Created provider class with all required methods
- [ ] Added API key mapping in `providers/registry.py`
- [ ] Added to provider priority order in `registry.py`
- [ ] Imported and registered in `server.py`
- [ ] Basic tests verify model validation and capabilities
- [ ] Tested with real API calls
- [ ] Code follows project style (run linting)
## Need Help?
- Look at the existing implementations for reference: **full provider** in `providers/gemini.py`, **OpenAI-compatible** in `providers/custom.py`, **base classes** in `providers/base.py`
- Check the base classes for method signatures and requirements
- Run tests frequently during development
- Ask questions in GitHub issues if stuck
The modern approach uses `ModelCapabilities` objects directly in `SUPPORTED_MODELS`, making the implementation much cleaner and more consistent.