feat: DIAL provider implementation (#112)
## Description This PR implements a new [DIAL](https://dialx.ai/dial_api) (Data & AI Layer) provider for the Zen MCP Server, enabling unified access to multiple AI models through the DIAL API platform. DIAL provides enterprise-grade AI model access with deployment-specific routing similar to Azure OpenAI. ## Changes Made - [x] Added support of atexit: - Ensures automatic cleanup of provider resources (HTTP clients, connection pools) on server shutdown - Fixed bug using ModelProviderRegistry.get_available_providers() instead of accessing private _providers - Works with SIGTERM/Ctrl+C for graceful shutdown in both development and containerized environments - [x] Added new DIAL provider (`providers/dial.py`) inheriting from `OpenAICompatibleProvider` - [x] Updated server.py to register DIAL provider during initialization - [x] Updated provider registry to include DIAL provider type - [x] Implemented deployment-specific routing for DIAL's Azure OpenAI-style endpoints - [x] Implemented performance optimizations: - Connection pooling with httpx for better performance - Thread-safe client caching with double-check locking pattern - Proper resource cleanup with `close()` method - [x] Added comprehensive unit tests with 16 test cases (`tests/test_dial_provider.py`) - [x] Added DIAL configuration to `.env.example` with documentation - [x] Added support for configurable API version via `DIAL_API_VERSION` environment variable - [x] Added DIAL model restrictions support via `DIAL_ALLOWED_MODELS` environment variable ### Supported DIAL Models: - OpenAI models: o3, o4-mini (and their dated versions) - Google models: gemini-2.5-pro, gemini-2.5-flash (including search variant) - Anthropic models: Claude 4 Opus/Sonnet (with and without thinking mode) ### Environment Variables: - `DIAL_API_KEY`: Required API key for DIAL authentication - `DIAL_API_HOST`: Optional base URL (defaults to https://core.dialx.ai) - `DIAL_API_VERSION`: Optional API version header (defaults to 2025-01-01-preview) - `DIAL_ALLOWED_MODELS`: Optional comma-separated list of allowed models ### Breaking Changes: - None ### Dependencies: - No new dependencies added (uses existing OpenAI SDK with custom routing)
This commit is contained in:
committed by
GitHub
parent
4ae0344b14
commit
0623ce3546
34
server.py
34
server.py
@@ -19,6 +19,7 @@ as defined by the MCP protocol.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -271,6 +272,7 @@ def configure_providers():
|
||||
from providers import ModelProviderRegistry
|
||||
from providers.base import ProviderType
|
||||
from providers.custom import CustomProvider
|
||||
from providers.dial import DIALModelProvider
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
@@ -303,6 +305,13 @@ def configure_providers():
|
||||
has_native_apis = True
|
||||
logger.info("X.AI API key found - GROK models available")
|
||||
|
||||
# Check for DIAL API key
|
||||
dial_key = os.getenv("DIAL_API_KEY")
|
||||
if dial_key and dial_key != "your_dial_api_key_here":
|
||||
valid_providers.append("DIAL")
|
||||
has_native_apis = True
|
||||
logger.info("DIAL API key found - DIAL models available")
|
||||
|
||||
# Check for OpenRouter API key
|
||||
openrouter_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
|
||||
@@ -336,6 +345,8 @@ def configure_providers():
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
if xai_key and xai_key != "your_xai_api_key_here":
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
if dial_key and dial_key != "your_dial_api_key_here":
|
||||
ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)
|
||||
|
||||
# 2. Custom provider second (for local/private models)
|
||||
if has_custom:
|
||||
@@ -358,6 +369,7 @@ def configure_providers():
|
||||
"- GEMINI_API_KEY for Gemini models\n"
|
||||
"- OPENAI_API_KEY for OpenAI o3 model\n"
|
||||
"- XAI_API_KEY for X.AI GROK models\n"
|
||||
"- DIAL_API_KEY for DIAL models\n"
|
||||
"- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
|
||||
"- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)"
|
||||
)
|
||||
@@ -376,6 +388,25 @@ def configure_providers():
|
||||
if len(priority_info) > 1:
|
||||
logger.info(f"Provider priority: {' → '.join(priority_info)}")
|
||||
|
||||
# Register cleanup function for providers
|
||||
def cleanup_providers():
|
||||
"""Clean up all registered providers on shutdown."""
|
||||
try:
|
||||
registry = ModelProviderRegistry()
|
||||
if hasattr(registry, "_initialized_providers"):
|
||||
for provider in list(registry._initialized_providers.items()):
|
||||
try:
|
||||
if provider and hasattr(provider, "close"):
|
||||
provider.close()
|
||||
except Exception:
|
||||
# Logger might be closed during shutdown
|
||||
pass
|
||||
except Exception:
|
||||
# Silently ignore any errors during cleanup
|
||||
pass
|
||||
|
||||
atexit.register(cleanup_providers)
|
||||
|
||||
# Check and log model restrictions
|
||||
restriction_service = get_restriction_service()
|
||||
restrictions = restriction_service.get_restriction_summary()
|
||||
@@ -390,7 +421,8 @@ def configure_providers():
|
||||
|
||||
# Validate restrictions against known models
|
||||
provider_instances = {}
|
||||
for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI]:
|
||||
provider_types_to_validate = [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]
|
||||
for provider_type in provider_types_to_validate:
|
||||
provider = ModelProviderRegistry.get_provider(provider_type)
|
||||
if provider:
|
||||
provider_instances[provider_type] = provider
|
||||
|
||||
Reference in New Issue
Block a user