
Adding a New Provider

This guide explains how to add support for a new AI model provider to the Zen MCP Server. The provider system is designed to be extensible and follows a simple pattern.

Overview

Each provider:

  • Inherits from ModelProvider (base class) or OpenAICompatibleProvider (for OpenAI-compatible APIs)
  • Defines supported models using ModelCapabilities objects
  • Implements a few core abstract methods
  • Gets registered automatically via environment variables

Choose Your Implementation Path

Option A: Full Provider (ModelProvider)

  • For APIs with unique features or custom authentication
  • Complete control over API calls and response handling
  • Required methods: generate_content(), count_tokens(), get_capabilities(), validate_model_name(), supports_thinking_mode(), get_provider_type()

Option B: OpenAI-Compatible (OpenAICompatibleProvider)

  • For APIs that follow OpenAI's chat completion format
  • Only need to define: model configurations, capabilities, and validation
  • Inherits all API handling automatically

⚠️ Important: If using aliases (like "gpt" → "gpt-4"), override generate_content() to resolve them before API calls.

Step-by-Step Guide

1. Add Provider Type

Add your provider to ProviderType enum in providers/base.py:

class ProviderType(Enum):
    GOOGLE = "google"
    OPENAI = "openai"
    EXAMPLE = "example"  # Add this

2. Create the Provider Implementation

Option A: Full Provider (Native Implementation)

Create providers/example.py:

"""Example model provider implementation."""

import logging
from typing import Optional
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint

logger = logging.getLogger(__name__)

class ExampleModelProvider(ModelProvider):
    """Example model provider implementation."""
    
    # Define models using ModelCapabilities objects (like Gemini provider)
    SUPPORTED_MODELS = {
        "example-large": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-large",
            friendly_name="Example Large",
            context_window=100_000,
            max_output_tokens=50_000,
            supports_extended_thinking=False,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            description="Large model for complex tasks",
            aliases=["large", "big"],
        ),
        "example-small": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-small",
            friendly_name="Example Small",
            context_window=32_000,
            max_output_tokens=16_000,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            description="Fast model for simple tasks",
            aliases=["small", "fast"],
        ),
    }
    
    def __init__(self, api_key: str, **kwargs):
        super().__init__(api_key, **kwargs)
        # Initialize your API client here
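        # For example (hypothetical SDK, shown purely for illustration):
        # self.client = example_sdk.Client(api_key=api_key)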
    
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        resolved_name = self._resolve_model_name(model_name)
        
        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model_name}")
        
        # Apply restrictions if needed
        from utils.model_restrictions import get_restriction_service
        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.EXAMPLE, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed.")
        
        return self.SUPPORTED_MODELS[resolved_name]
    
    def generate_content(self, prompt: str, model_name: str, system_prompt: Optional[str] = None, 
                        temperature: float = 0.7, max_output_tokens: Optional[int] = None, **kwargs) -> ModelResponse:
        resolved_name = self._resolve_model_name(model_name)
        
        # Your API call logic here
        # response = your_api_client.generate(...)
        
        return ModelResponse(
            content="Generated response",  # From your API
            usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
            model_name=resolved_name,
            friendly_name="Example",
            provider=ProviderType.EXAMPLE,
        )
    
    def count_tokens(self, text: str, model_name: str) -> int:
        # Rough estimate (~4 characters per token); swap in your API's
        # tokenizer here if it provides one
        return len(text) // 4
    
    def get_provider_type(self) -> ProviderType:
        return ProviderType.EXAMPLE
    
    def validate_model_name(self, model_name: str) -> bool:
        resolved_name = self._resolve_model_name(model_name)
        return resolved_name in self.SUPPORTED_MODELS
    
    def supports_thinking_mode(self, model_name: str) -> bool:
        capabilities = self.get_capabilities(model_name)
        return capabilities.supports_extended_thinking
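
With the class in place, you can exercise it locally; alias resolution and capability lookup need no network access (the "test-key" value below is a placeholder):

provider = ExampleModelProvider("test-key")
caps = provider.get_capabilities("large")    # alias resolves to "example-large"
print(caps.model_name, caps.context_window)  # example-large 100000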

Option B: OpenAI-Compatible Provider (Simplified)

For OpenAI-compatible APIs:

"""Example OpenAI-compatible provider."""

from typing import Optional
from .base import ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint
from .openai_compatible import OpenAICompatibleProvider

class ExampleProvider(OpenAICompatibleProvider):
    """Example OpenAI-compatible provider."""
    
    FRIENDLY_NAME = "Example"
    
    # Define models using ModelCapabilities (consistent with other providers)
    SUPPORTED_MODELS = {
        "example-model-large": ModelCapabilities(
            provider=ProviderType.EXAMPLE,
            model_name="example-model-large",
            friendly_name="Example Large",
            context_window=128_000,
            max_output_tokens=64_000,
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
            aliases=["large", "big"],
        ),
    }
    
    def __init__(self, api_key: str, **kwargs):
        kwargs.setdefault("base_url", "https://api.example.com/v1")
        super().__init__(api_key, **kwargs)
    
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        resolved_name = self._resolve_model_name(model_name)
        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model_name}")
        return self.SUPPORTED_MODELS[resolved_name]
    
    def get_provider_type(self) -> ProviderType:
        return ProviderType.EXAMPLE
    
    def validate_model_name(self, model_name: str) -> bool:
        resolved_name = self._resolve_model_name(model_name)
        return resolved_name in self.SUPPORTED_MODELS
    
    def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
        # IMPORTANT: Resolve aliases before API call
        resolved_model_name = self._resolve_model_name(model_name)
        return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)

3. Register Your Provider

Add environment variable mapping in providers/registry.py:

# In _get_api_key_for_provider method:
key_mapping = {
    ProviderType.GOOGLE: "GEMINI_API_KEY",
    ProviderType.OPENAI: "OPENAI_API_KEY", 
    ProviderType.EXAMPLE: "EXAMPLE_API_KEY",  # Add this
}

Add to server.py:

  1. Import your provider:

from providers.example import ExampleModelProvider

  2. Add to configure_providers() function:

# Check for Example API key
example_key = os.getenv("EXAMPLE_API_KEY")
if example_key:
    ModelProviderRegistry.register_provider(ProviderType.EXAMPLE, ExampleModelProvider)
    logger.info("Example API key found - Example models available")

  3. Add to provider priority (in providers/registry.py):
PROVIDER_PRIORITY_ORDER = [
    ProviderType.GOOGLE,
    ProviderType.OPENAI, 
    ProviderType.EXAMPLE,     # Add your provider here
    ProviderType.CUSTOM,      # Local models
    ProviderType.OPENROUTER,  # Catch-all (keep last)
]

4. Environment Configuration

Add to your .env file:

# Your provider's API key
EXAMPLE_API_KEY=your_api_key_here

# Optional: Disable specific tools
DISABLED_TOOLS=debug,tracer

Note: The description field in ModelCapabilities helps Claude choose the best model in auto mode.

5. Test Your Provider

Create basic tests to verify your implementation:

# Test model validation
provider = ExampleModelProvider("test-key")
assert provider.validate_model_name("large") == True
assert provider.validate_model_name("unknown") == False

# Test capabilities
caps = provider.get_capabilities("large")
assert caps.context_window > 0
assert caps.provider == ProviderType.EXAMPLE
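
For an end-to-end check, a minimal smoke test along these lines can be run against the live API (this sketch assumes EXAMPLE_API_KEY is set in the environment; the model name and prompt are placeholders to adapt to your provider):

import os

provider = ExampleModelProvider(os.environ["EXAMPLE_API_KEY"])
response = provider.generate_content(
    prompt="Reply with the single word: pong",
    model_name="example-small",
    temperature=0.2,
)
assert response.content  # a non-empty reply means the wiring works
print(response.usage)    # e.g. {"input_tokens": ..., "output_tokens": ...}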

Key Concepts

Provider Priority

When a user requests a model, providers are checked in priority order (a sketch of the lookup follows this list):

  1. Native providers (Gemini, OpenAI, Example) - handle their specific models
  2. Custom provider - handles local/self-hosted models
  3. OpenRouter - catch-all for everything else
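
Conceptually, the lookup resembles the sketch below. This is illustrative only: the real logic lives in providers/registry.py, and get_provider() here stands in for whatever accessor the registry actually exposes.

def get_provider_for_model(model_name: str):
    """Illustrative sketch, not the actual registry code."""
    for provider_type in PROVIDER_PRIORITY_ORDER:
        provider = ModelProviderRegistry.get_provider(provider_type)
        if provider and provider.validate_model_name(model_name):
            return provider  # first provider that claims the model wins
    return None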

Model Validation

Your validate_model_name() should only return True for models you explicitly support:

def validate_model_name(self, model_name: str) -> bool:
    resolved_name = self._resolve_model_name(model_name)
    return resolved_name in self.SUPPORTED_MODELS  # Be specific!

Model Aliases

The base class handles alias resolution automatically via the aliases field in ModelCapabilities.
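
A reasonable mental model of that resolution (the actual implementation lives in providers/base.py) is a case-insensitive lookup over each model's aliases field:

def _resolve_model_name(self, model_name: str) -> str:
    """Illustrative sketch; see providers/base.py for the real logic."""
    name = model_name.lower()
    for canonical, capabilities in self.SUPPORTED_MODELS.items():
        if name == canonical.lower():
            return canonical
        if name in (alias.lower() for alias in capabilities.aliases):
            return canonical
    return model_name  # unknown names pass through unchanged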

Important Notes

Alias Resolution in OpenAI-Compatible Providers

If using OpenAICompatibleProvider with aliases, you must override generate_content() to resolve aliases before API calls:

def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
    # Resolve alias before API call
    resolved_model_name = self._resolve_model_name(model_name)
    return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)

Without this, API calls with aliases like "large" will fail because your API doesn't recognize the alias.

Best Practices

  • Be specific in model validation - only accept models you actually support
  • Use ModelCapabilities objects consistently (like Gemini provider)
  • Include descriptive aliases for better user experience
  • Add error handling and logging for debugging
  • Test with real API calls to verify everything works
  • Follow the existing patterns in providers/gemini.py and providers/custom.py

Quick Checklist

  • Added to ProviderType enum in providers/base.py
  • Created provider class with all required methods
  • Added API key mapping in providers/registry.py
  • Added to provider priority order in registry.py
  • Imported and registered in server.py
  • Basic tests verify model validation and capabilities
  • Tested with real API calls

Examples

See existing implementations:

  • Full provider: providers/gemini.py
  • OpenAI-compatible: providers/custom.py
  • Base classes: providers/base.py

The modern approach uses ModelCapabilities objects directly in SUPPORTED_MODELS, making the implementation much cleaner and more consistent.