feat: DIAL provider implementation (#112)
## Description

This PR implements a new [DIAL](https://dialx.ai/dial_api) (Data & AI Layer) provider for the Zen MCP Server, enabling unified access to multiple AI models through the DIAL API platform. DIAL provides enterprise-grade AI model access with deployment-specific routing similar to Azure OpenAI.

## Changes Made

- [x] Added `atexit` support (see the sketch below):
  - Ensures automatic cleanup of provider resources (HTTP clients, connection pools) on server shutdown
  - Fixed a bug by using `ModelProviderRegistry.get_available_providers()` instead of accessing the private `_providers` attribute
  - Works with SIGTERM/Ctrl+C for graceful shutdown in both development and containerized environments
- [x] Added new DIAL provider (`providers/dial.py`) inheriting from `OpenAICompatibleProvider`
- [x] Updated `server.py` to register the DIAL provider during initialization
- [x] Updated the provider registry to include the DIAL provider type
- [x] Implemented deployment-specific routing for DIAL's Azure OpenAI-style endpoints (see the usage sketch below)
- [x] Implemented performance optimizations:
  - Connection pooling with httpx for better performance
  - Thread-safe client caching with a double-check locking pattern
  - Proper resource cleanup via a `close()` method
- [x] Added comprehensive unit tests with 16 test cases (`tests/test_dial_provider.py`)
- [x] Added DIAL configuration to `.env.example` with documentation (example below)
- [x] Added support for a configurable API version via the `DIAL_API_VERSION` environment variable
- [x] Added DIAL model restrictions support via the `DIAL_ALLOWED_MODELS` environment variable

### Supported DIAL Models:

- OpenAI models: o3, o4-mini (and their dated versions)
- Google models: gemini-2.5-pro, gemini-2.5-flash (including the search variant)
- Anthropic models: Claude 4 Opus/Sonnet (with and without thinking mode)

### Environment Variables:

- `DIAL_API_KEY`: Required API key for DIAL authentication
- `DIAL_API_HOST`: Optional base URL (defaults to https://core.dialx.ai)
- `DIAL_API_VERSION`: Optional `api-version` query parameter (defaults to `2024-12-01-preview`)
- `DIAL_ALLOWED_MODELS`: Optional comma-separated list of allowed models

### Breaking Changes:

- None

### Dependencies:

- No new dependencies added (uses the existing OpenAI SDK with custom routing)
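### Example Configuration:

A sketch of what the `.env.example` additions look like; values here are illustrative placeholders, and the exact comments in the committed file may differ:

```env
# DIAL provider configuration (illustrative placeholders)
DIAL_API_KEY=your-dial-api-key                  # required
DIAL_API_HOST=https://core.dialx.ai             # optional; this is the default
DIAL_API_VERSION=2024-12-01-preview             # optional; this is the code default
DIAL_ALLOWED_MODELS=o3,o4-mini,gemini-2.5-pro   # optional comma-separated allow-list
```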
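### Shutdown Cleanup Sketch:

A minimal sketch of the `atexit` hook described under "Changes Made", assuming `ModelProviderRegistry.get_available_providers()` yields provider instances; the import path and the hook's exact shape are assumptions here, and the real wiring lives in `server.py`:

```python
import atexit


def _cleanup_providers() -> None:
    """Close provider resources (HTTP clients, connection pools) on shutdown."""
    # Public accessor instead of the private _providers attribute (the bug fix
    # noted above). Import path is an assumption for this sketch.
    from providers.registry import ModelProviderRegistry

    for provider in ModelProviderRegistry.get_available_providers():
        close = getattr(provider, "close", None)
        if callable(close):
            close()


# Runs on normal interpreter exit, including Ctrl+C once KeyboardInterrupt
# unwinds; a SIGTERM handler that calls sys.exit() triggers it as well.
atexit.register(_cleanup_providers)
```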
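### Usage Sketch:

Illustrative only (not a tested snippet): how the provider routes a request through DIAL's Azure OpenAI-style deployment endpoints, per the code in this diff:

```python
from providers.dial import DIALModelProvider

# Reads DIAL_API_HOST / DIAL_API_VERSION from the environment when not passed in.
provider = DIALModelProvider(api_key="your-dial-api-key")

# "o3" is a shorthand that resolves to "o3-2025-04-16"; the request goes to
#   {DIAL_API_HOST}/openai/deployments/o3-2025-04-16/chat/completions?api-version=...
response = provider.generate_content(
    prompt="Say hello in one sentence.",
    model_name="o3",
    temperature=0.7,
)
print(response.content)

provider.close()  # release the shared httpx connection pool
```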
Parent: 4ae0344b14
Commit: 0623ce3546

providers/dial.py (new file, 525 lines)
@@ -0,0 +1,525 @@
"""DIAL (Data & AI Layer) model provider implementation."""

import logging
import os
import threading
import time
from typing import Optional

from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider

logger = logging.getLogger(__name__)


class DIALModelProvider(OpenAICompatibleProvider):
    """DIAL provider using OpenAI-compatible API.

    DIAL provides access to various AI models through a unified API interface.
    Supports GPT, Claude, Gemini, and other models via DIAL deployments.
    """

    FRIENDLY_NAME = "DIAL"

    # Retry configuration for API calls
    MAX_RETRIES = 4
    RETRY_DELAYS = [1, 3, 5, 8]  # seconds

    # Supported DIAL models (these can be customized based on your DIAL deployment)
    SUPPORTED_MODELS = {
        "o3-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "o4-mini-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-03-25-google-search": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,  # DIAL doesn't expose thinking mode
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-05-06": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "gemini-2.5-flash-preview-05-20": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        # Shorthands
        "o3": "o3-2025-04-16",
        "o4-mini": "o4-mini-2025-04-16",
        "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
        "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
        "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
        "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
        "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
        "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
        "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize DIAL provider with API key and host.

        Args:
            api_key: DIAL API key for authentication
            **kwargs: Additional configuration options
        """
        # Get DIAL API host from environment or kwargs
        dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"

        # DIAL uses /openai endpoint for OpenAI-compatible API
        if not dial_host.endswith("/openai"):
            dial_host = f"{dial_host.rstrip('/')}/openai"

        kwargs["base_url"] = dial_host

        # Get API version from environment or use default
        self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")

        # Add DIAL-specific headers
        # DIAL uses Api-Key header instead of Authorization: Bearer
        # Reference: https://dialx.ai/dial_api#section/Authorization
        self.DEFAULT_HEADERS = {
            "Api-Key": api_key,
        }

        # Store the actual API key for use in Api-Key header
        self._dial_api_key = api_key

        # Pass a placeholder API key to OpenAI client - we'll override the auth header in httpx
        # The actual authentication happens via the Api-Key header in the httpx client
        super().__init__("placeholder-not-used", **kwargs)

        # Cache for deployment-specific clients to avoid recreating them on each request
        self._deployment_clients = {}
        # Lock to ensure thread-safe client creation
        self._client_lock = threading.Lock()

        # Create a SINGLE shared httpx client for the provider instance
        import httpx

        # Create custom event hooks to remove Authorization header
        def remove_auth_header(request):
            """Remove Authorization header that OpenAI client adds."""
            # httpx headers are case-insensitive, so we need to check all variations
            headers_to_remove = []
            for header_name in request.headers:
                if header_name.lower() == "authorization":
                    headers_to_remove.append(header_name)

            for header_name in headers_to_remove:
                del request.headers[header_name]

        self._http_client = httpx.Client(
            timeout=self.timeout_config,
            verify=True,
            follow_redirects=True,
            headers=self.DEFAULT_HEADERS.copy(),  # Include DIAL headers including Api-Key
            limits=httpx.Limits(
                max_keepalive_connections=5,
                max_connections=10,
                keepalive_expiry=30.0,
            ),
            event_hooks={"request": [remove_auth_header]},
        )

        logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model.

        Args:
            model_name: Name of the model (can be shorthand)

        Returns:
            ModelCapabilities object

        Raises:
            ValueError: If model is not supported or not allowed
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported DIAL model: {model_name}")

        # Check restrictions
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        return ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name=resolved_name,
            friendly_name=self.FRIENDLY_NAME,
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_images=config.get("supports_vision", False),
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
        )

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.DIAL

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported.

        Args:
            model_name: Model name to validate

        Returns:
            True if model is supported and allowed, False otherwise
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return False

        # Check against base class allowed_models if configured
        if self.allowed_models is not None:
            # Check both original and resolved names (case-insensitive)
            if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
                logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
                return False

        # Also check restrictions via ModelRestrictionService
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
            return False

        return True

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name.

        Args:
            model_name: Model name or shorthand

        Returns:
            Full model name
        """
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name

    def _get_deployment_client(self, deployment: str):
        """Get or create a cached client for a specific deployment.

        This avoids recreating OpenAI clients on every request, improving performance.
        Reuses the shared HTTP client for connection pooling.

        Args:
            deployment: The deployment/model name

        Returns:
            OpenAI client configured for the specific deployment
        """
        # Check if client already exists without locking for performance
        if deployment in self._deployment_clients:
            return self._deployment_clients[deployment]

        # Use lock to ensure thread-safe client creation
        with self._client_lock:
            # Double-check pattern: check again inside the lock
            if deployment not in self._deployment_clients:
                from openai import OpenAI

                # Build deployment-specific URL
                base_url = str(self.client.base_url)
                if base_url.endswith("/"):
                    base_url = base_url[:-1]

                # Remove /openai suffix if present to reconstruct properly
                if base_url.endswith("/openai"):
                    base_url = base_url[:-7]

                deployment_url = f"{base_url}/openai/deployments/{deployment}"

                # Create and cache the client, REUSING the shared http_client
                # Use placeholder API key - Authorization header will be removed by http_client event hook
                self._deployment_clients[deployment] = OpenAI(
                    api_key="placeholder-not-used",
                    base_url=deployment_url,
                    http_client=self._http_client,  # Pass the shared client with Api-Key header
                    default_query={"api-version": self.api_version},  # Add api-version as query param
                )

        return self._deployment_clients[deployment]

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using DIAL's deployment-specific endpoint.

        DIAL uses Azure OpenAI-style deployment endpoints:
        /openai/deployments/{deployment}/chat/completions

        Args:
            prompt: User prompt
            model_name: Model name or alias
            system_prompt: Optional system prompt
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            images: Optional list of image paths or data URLs to include
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")

        # Validate parameters
        self.validate_parameters(model_name, temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Build user message content
        user_message_content = []
        if prompt:
            user_message_content.append({"type": "text", "text": prompt})

        if images and self._supports_vision(model_name):
            for img_path in images:
                processed_image = self._process_image(img_path)
                if processed_image:
                    user_message_content.append(processed_image)
        elif images:
            logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")

        # Add user message. If only text, content will be a string, otherwise a list.
        if len(user_message_content) == 1 and user_message_content[0]["type"] == "text":
            messages.append({"role": "user", "content": prompt})
        else:
            messages.append({"role": "user", "content": user_message_content})

        # Resolve model name
        resolved_model = self._resolve_model_name(model_name)

        # Build completion parameters
        completion_params = {
            "model": resolved_model,
            "messages": messages,
        }

        # Check model capabilities
        try:
            capabilities = self.get_capabilities(model_name)
            supports_temperature = getattr(capabilities, "supports_temperature", True)
        except Exception as e:
            logger.debug(f"Failed to check temperature support for {model_name}: {e}")
            supports_temperature = True

        # Add temperature parameter if supported
        if supports_temperature:
            completion_params["temperature"] = temperature

        # Add max tokens if specified; the temperature check doubles as a proxy
        # for reasoning models that also reject max_tokens
        if max_output_tokens and supports_temperature:
            completion_params["max_tokens"] = max_output_tokens

        # Add additional parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
                    continue
                completion_params[key] = value

        # DIAL-specific: Get cached client for deployment endpoint
        deployment_client = self._get_deployment_client(resolved_model)

        # Retry logic with progressive delays
        last_exception = None

        for attempt in range(self.MAX_RETRIES):
            try:
                # Generate completion using deployment-specific client
                response = deployment_client.chat.completions.create(**completion_params)

                # Extract content and usage
                content = response.choices[0].message.content
                usage = self._extract_usage(response)

                return ModelResponse(
                    content=content,
                    usage=usage,
                    model_name=model_name,
                    friendly_name=self.FRIENDLY_NAME,
                    provider=self.get_provider_type(),
                    metadata={
                        "finish_reason": response.choices[0].finish_reason,
                        "model": response.model,
                        "id": response.id,
                        "created": response.created,
                    },
                )

            except Exception as e:
                last_exception = e

                # Check if this is a retryable error
                is_retryable = self._is_error_retryable(e)

                if not is_retryable:
                    # Non-retryable error, raise immediately
                    raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")

                # If this isn't the last attempt and error is retryable, wait and retry
                if attempt < self.MAX_RETRIES - 1:
                    delay = self.RETRY_DELAYS[attempt]
                    logger.info(
                        f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), retrying in {delay}s: {str(e)}"
                    )
                    time.sleep(delay)
                    continue

        # All retries exhausted
        raise ValueError(
            f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
        )

    def _supports_vision(self, model_name: str) -> bool:
        """Check if the model supports vision (image processing).

        Args:
            model_name: Model name to check

        Returns:
            True if model supports vision, False otherwise
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)

        # Fall back to parent implementation for unknown models
        return super()._supports_vision(model_name)

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        # Get all model keys (both full names and aliases)
        all_models = list(self.SUPPORTED_MODELS.keys())

        if not respect_restrictions:
            return all_models

        # Apply restrictions if configured
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()

        # Filter based on restrictions
        allowed_models = []
        for model in all_models:
            resolved_name = self._resolve_model_name(model)
            if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
                allowed_models.append(model)

        return allowed_models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        This is used for validation purposes to ensure restriction policies
        can validate against both aliases and their target model names.

        Returns:
            List of all model names and alias targets known by this provider
        """
        # Collect all unique model names (both aliases and targets)
        all_models = set()

        for key, value in self.SUPPORTED_MODELS.items():
            # Add the key (could be alias or full name)
            all_models.add(key)

            # If it's an alias (string value), add the target too
            if isinstance(value, str):
                all_models.add(value)

        return sorted(all_models)

    def close(self):
        """Clean up HTTP clients when provider is closed."""
        logger.info("Closing DIAL provider HTTP clients...")

        # Clear the deployment clients cache
        # Note: We don't need to close individual OpenAI clients since they
        # use the shared httpx.Client which we close separately
        self._deployment_clients.clear()

        # Close the shared HTTP client
        if hasattr(self, "_http_client"):
            try:
                self._http_client.close()
                logger.debug("Closed shared HTTP client")
            except Exception as e:
                logger.warning(f"Error closing shared HTTP client: {e}")

        # Also close the client created by the superclass (OpenAICompatibleProvider)
        # as it holds its own httpx.Client instance that is not used by DIAL's generate_content
        if hasattr(self, "client") and self.client and hasattr(self.client, "close"):
            try:
                self.client.close()
                logger.debug("Closed superclass's OpenAI client")
            except Exception as e:
                logger.warning(f"Error closing superclass's OpenAI client: {e}")