## Description

This PR implements a new [DIAL](https://dialx.ai/dial_api) (Data & AI Layer) provider for the Zen MCP Server, enabling unified access to multiple AI models through the DIAL API platform. DIAL provides enterprise-grade AI model access with deployment-specific routing similar to Azure OpenAI.

## Changes Made

- [x] Added `atexit` support:
  - Ensures automatic cleanup of provider resources (HTTP clients, connection pools) on server shutdown
  - Fixed a bug by using `ModelProviderRegistry.get_available_providers()` instead of accessing the private `_providers` attribute
  - Works with SIGTERM/Ctrl+C for graceful shutdown in both development and containerized environments
- [x] Added new DIAL provider (`providers/dial.py`) inheriting from `OpenAICompatibleProvider`
- [x] Updated `server.py` to register the DIAL provider during initialization
- [x] Updated the provider registry to include the DIAL provider type
- [x] Implemented deployment-specific routing for DIAL's Azure OpenAI-style endpoints
- [x] Implemented performance optimizations:
  - Connection pooling with httpx for better performance
  - Thread-safe client caching with a double-check locking pattern
  - Proper resource cleanup via a `close()` method
- [x] Added comprehensive unit tests with 16 test cases (`tests/test_dial_provider.py`)
- [x] Added DIAL configuration to `.env.example` with documentation
- [x] Added support for a configurable API version via the `DIAL_API_VERSION` environment variable
- [x] Added DIAL model restrictions support via the `DIAL_ALLOWED_MODELS` environment variable

### Supported DIAL Models:

- OpenAI models: o3, o4-mini (and their dated versions)
- Google models: gemini-2.5-pro, gemini-2.5-flash (including the search variant)
- Anthropic models: Claude 4 Opus/Sonnet (with and without thinking mode)

### Environment Variables:

- `DIAL_API_KEY`: Required API key for DIAL authentication
- `DIAL_API_HOST`: Optional base URL (defaults to https://core.dialx.ai)
- `DIAL_API_VERSION`: Optional API version, sent as the `api-version` query parameter (defaults to 2024-12-01-preview)
- `DIAL_ALLOWED_MODELS`: Optional comma-separated list of allowed models

### Breaking Changes:

- None

### Dependencies:

- No new dependencies added (uses the existing OpenAI SDK with custom routing)
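For reviewers, here is a minimal usage sketch. It assumes direct instantiation for illustration only; in practice `server.py` registers the provider during initialization. All names are taken from the provider code below.

```python
# Hedged usage sketch; assumes direct instantiation for illustration.
import os

from providers.dial import DIALModelProvider

os.environ.setdefault("DIAL_API_HOST", "https://core.dialx.ai")   # optional
os.environ.setdefault("DIAL_API_VERSION", "2024-12-01-preview")   # optional

provider = DIALModelProvider(api_key=os.environ["DIAL_API_KEY"])

response = provider.generate_content(
    prompt="Summarize this PR in one sentence.",
    model_name="gemini-2.5-pro",  # shorthand; resolves to gemini-2.5-pro-preview-05-06
    temperature=0.7,
)
print(response.content)

provider.close()  # releases the pooled httpx client
```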
### `providers/dial.py`
"""DIAL (Data & AI Layer) model provider implementation."""
|
|
|
|
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
from typing import Optional
|
|
|
|
from .base import (
|
|
ModelCapabilities,
|
|
ModelResponse,
|
|
ProviderType,
|
|
RangeTemperatureConstraint,
|
|
)
|
|
from .openai_compatible import OpenAICompatibleProvider
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DIALModelProvider(OpenAICompatibleProvider):
    """DIAL provider using OpenAI-compatible API.

    DIAL provides access to various AI models through a unified API interface.
    Supports GPT, Claude, Gemini, and other models via DIAL deployments.
    """

    FRIENDLY_NAME = "DIAL"

    # Retry configuration for API calls
    MAX_RETRIES = 4
    RETRY_DELAYS = [1, 3, 5, 8]  # seconds

    # Supported DIAL models (these can be customized based on your DIAL deployment)
    SUPPORTED_MODELS = {
        "o3-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "o4-mini-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-03-25-google-search": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,  # DIAL doesn't expose thinking mode
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-05-06": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "gemini-2.5-flash-preview-05-20": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        # Shorthands
        "o3": "o3-2025-04-16",
        "o4-mini": "o4-mini-2025-04-16",
        "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
        "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
        "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
        "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
        "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
        "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
        "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize DIAL provider with API key and host.

        Args:
            api_key: DIAL API key for authentication
            **kwargs: Additional configuration options
        """
        # Get DIAL API host from environment or kwargs
        dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"

        # DIAL uses the /openai endpoint for its OpenAI-compatible API
        if not dial_host.endswith("/openai"):
            dial_host = f"{dial_host.rstrip('/')}/openai"

        kwargs["base_url"] = dial_host

        # Get API version from environment or use default
        self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")

        # Add DIAL-specific headers.
        # DIAL uses the Api-Key header instead of Authorization: Bearer.
        # Reference: https://dialx.ai/dial_api#section/Authorization
        self.DEFAULT_HEADERS = {
            "Api-Key": api_key,
        }

        # Store the actual API key for use in the Api-Key header
        self._dial_api_key = api_key

        # Pass a placeholder API key to the OpenAI client; we override the auth
        # header in httpx. The actual authentication happens via the Api-Key
        # header in the httpx client.
        super().__init__("placeholder-not-used", **kwargs)

        # Cache for deployment-specific clients to avoid recreating them on each request
        self._deployment_clients = {}
        # Lock to ensure thread-safe client creation
        self._client_lock = threading.Lock()

        # Create a SINGLE shared httpx client for the provider instance
        import httpx

        # Custom event hook to remove the Authorization header
        def remove_auth_header(request):
            """Remove the Authorization header that the OpenAI client adds."""
            # httpx headers are case-insensitive, so we need to check all variations
            headers_to_remove = []
            for header_name in request.headers:
                if header_name.lower() == "authorization":
                    headers_to_remove.append(header_name)

            for header_name in headers_to_remove:
                del request.headers[header_name]

        self._http_client = httpx.Client(
            timeout=self.timeout_config,
            verify=True,
            follow_redirects=True,
            headers=self.DEFAULT_HEADERS.copy(),  # Include DIAL headers, including Api-Key
            limits=httpx.Limits(
                max_keepalive_connections=5,
                max_connections=10,
                keepalive_expiry=30.0,
            ),
            event_hooks={"request": [remove_auth_header]},
        )

        logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model.

        Args:
            model_name: Name of the model (can be shorthand)

        Returns:
            ModelCapabilities object

        Raises:
            ValueError: If model is not supported or not allowed
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported DIAL model: {model_name}")

        # Check restrictions
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        return ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name=resolved_name,
            friendly_name=self.FRIENDLY_NAME,
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_images=config.get("supports_vision", False),
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
        )

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.DIAL

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported.

        Args:
            model_name: Model name to validate

        Returns:
            True if model is supported and allowed, False otherwise
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return False

        # Check against base class allowed_models if configured
        if self.allowed_models is not None:
            # Check both original and resolved names (case-insensitive)
            if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
                logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
                return False

        # Also check restrictions via ModelRestrictionService
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
            return False

        return True

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name.

        Args:
            model_name: Model name or shorthand

        Returns:
            Full model name
        """
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name

    def _get_deployment_client(self, deployment: str):
        """Get or create a cached client for a specific deployment.

        This avoids recreating OpenAI clients on every request, improving performance.
        Reuses the shared HTTP client for connection pooling.

        Args:
            deployment: The deployment/model name

        Returns:
            OpenAI client configured for the specific deployment
        """
        # Fast path: check if the client already exists without locking
        if deployment in self._deployment_clients:
            return self._deployment_clients[deployment]

        # Use lock to ensure thread-safe client creation
        with self._client_lock:
            # Double-check pattern: check again inside the lock
            if deployment not in self._deployment_clients:
                from openai import OpenAI

                # Build the deployment-specific URL, e.g.
                # https://core.dialx.ai/openai/deployments/{deployment}
                base_url = str(self.client.base_url)
                if base_url.endswith("/"):
                    base_url = base_url[:-1]

                # Remove the /openai suffix if present to reconstruct properly
                if base_url.endswith("/openai"):
                    base_url = base_url[: -len("/openai")]

                deployment_url = f"{base_url}/openai/deployments/{deployment}"

                # Create and cache the client, REUSING the shared http_client.
                # Use a placeholder API key: the Authorization header is removed
                # by the http_client's event hook and Api-Key is sent instead.
                self._deployment_clients[deployment] = OpenAI(
                    api_key="placeholder-not-used",
                    base_url=deployment_url,
                    http_client=self._http_client,  # Shared client with Api-Key header
                    default_query={"api-version": self.api_version},  # api-version query param
                )

        return self._deployment_clients[deployment]

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using DIAL's deployment-specific endpoint.

        DIAL uses Azure OpenAI-style deployment endpoints:
        /openai/deployments/{deployment}/chat/completions

        Args:
            prompt: User prompt
            model_name: Model name or alias
            system_prompt: Optional system prompt
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            images: Optional images to attach (ignored if the model lacks vision support)
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")

        # Validate parameters
        self.validate_parameters(model_name, temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Build user message content
        user_message_content = []
        if prompt:
            user_message_content.append({"type": "text", "text": prompt})

        if images and self._supports_vision(model_name):
            for img_path in images:
                processed_image = self._process_image(img_path)
                if processed_image:
                    user_message_content.append(processed_image)
        elif images:
            logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")

        # Add user message. If only text, content is a plain string; otherwise a list.
        if len(user_message_content) == 1 and user_message_content[0]["type"] == "text":
            messages.append({"role": "user", "content": prompt})
        else:
            messages.append({"role": "user", "content": user_message_content})

        # Resolve model name
        resolved_model = self._resolve_model_name(model_name)

        # Build completion parameters
        completion_params = {
            "model": resolved_model,
            "messages": messages,
        }

        # Check model capabilities
        try:
            capabilities = self.get_capabilities(model_name)
            supports_temperature = getattr(capabilities, "supports_temperature", True)
        except Exception as e:
            logger.debug(f"Failed to check temperature support for {model_name}: {e}")
            supports_temperature = True

        # Add temperature parameter if supported
        if supports_temperature:
            completion_params["temperature"] = temperature

        # Add max tokens if specified. supports_temperature doubles as a proxy for
        # reasoning models (e.g. o3), which reject max_tokens as well as temperature.
        if max_output_tokens and supports_temperature:
            completion_params["max_tokens"] = max_output_tokens

        # Pass through recognized additional parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
                    continue
                completion_params[key] = value

        # DIAL-specific: get the cached client for this deployment endpoint
        deployment_client = self._get_deployment_client(resolved_model)

        # Retry logic with progressive delays
        last_exception = None

        for attempt in range(self.MAX_RETRIES):
            try:
                # Generate completion using the deployment-specific client
                response = deployment_client.chat.completions.create(**completion_params)

                # Extract content and usage
                content = response.choices[0].message.content
                usage = self._extract_usage(response)

                return ModelResponse(
                    content=content,
                    usage=usage,
                    model_name=model_name,
                    friendly_name=self.FRIENDLY_NAME,
                    provider=self.get_provider_type(),
                    metadata={
                        "finish_reason": response.choices[0].finish_reason,
                        "model": response.model,
                        "id": response.id,
                        "created": response.created,
                    },
                )

            except Exception as e:
                last_exception = e

                # Check if this is a retryable error
                if not self._is_error_retryable(e):
                    # Non-retryable error, raise immediately
                    raise ValueError(f"DIAL API error for model {model_name}: {str(e)}") from e

                # If this isn't the last attempt, wait and retry
                if attempt < self.MAX_RETRIES - 1:
                    delay = self.RETRY_DELAYS[attempt]
                    logger.info(
                        f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), retrying in {delay}s: {str(e)}"
                    )
                    time.sleep(delay)
                    continue

        # All retries exhausted
        raise ValueError(
            f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
        ) from last_exception

    def _supports_vision(self, model_name: str) -> bool:
        """Check if the model supports vision (image processing).

        Args:
            model_name: Model name to check

        Returns:
            True if model supports vision, False otherwise
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)

        # Fall back to parent implementation for unknown models
        return super()._supports_vision(model_name)

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        # Get all model keys (both full names and aliases)
        all_models = list(self.SUPPORTED_MODELS.keys())

        if not respect_restrictions:
            return all_models

        # Apply restrictions if configured
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()

        # Filter based on restrictions
        allowed_models = []
        for model in all_models:
            resolved_name = self._resolve_model_name(model)
            if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
                allowed_models.append(model)

        return allowed_models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        This is used for validation purposes to ensure restriction policies
        can validate against both aliases and their target model names.

        Returns:
            List of all model names and alias targets known by this provider
        """
        # Collect all unique model names (both aliases and targets)
        all_models = set()

        for key, value in self.SUPPORTED_MODELS.items():
            # Add the key (could be alias or full name)
            all_models.add(key)

            # If it's an alias (string value), add the target too
            if isinstance(value, str):
                all_models.add(value)

        return sorted(all_models)

    def close(self):
        """Clean up HTTP clients when provider is closed."""
        logger.info("Closing DIAL provider HTTP clients...")

        # Clear the deployment clients cache.
        # Note: we don't need to close individual OpenAI clients since they
        # use the shared httpx.Client, which we close separately.
        self._deployment_clients.clear()

        # Close the shared HTTP client
        if hasattr(self, "_http_client"):
            try:
                self._http_client.close()
                logger.debug("Closed shared HTTP client")
            except Exception as e:
                logger.warning(f"Error closing shared HTTP client: {e}")

        # Also close the client created by the superclass (OpenAICompatibleProvider),
        # as it holds its own httpx.Client instance that is not used by DIAL's generate_content
        if hasattr(self, "client") and self.client and hasattr(self.client, "close"):
            try:
                self.client.close()
                logger.debug("Closed superclass's OpenAI client")
            except Exception as e:
                logger.warning(f"Error closing superclass's OpenAI client: {e}")
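For reference, the `atexit` cleanup described in the changelog can be wired roughly as below. This is a sketch, not the actual `server.py` diff: the registry import path and the exact shape of `ModelProviderRegistry.get_available_providers()` (assumed here to yield provider instances) are assumptions.

```python
# Hedged sketch of the atexit-based shutdown hook (import path and the
# return shape of get_available_providers() are assumptions, not the PR diff).
import atexit
import logging

from providers.registry import ModelProviderRegistry  # assumed module path

logger = logging.getLogger(__name__)


def _cleanup_providers() -> None:
    """Close provider resources (HTTP clients, connection pools) on shutdown."""
    for provider in ModelProviderRegistry.get_available_providers():
        close = getattr(provider, "close", None)
        if callable(close):
            try:
                close()
            except Exception as e:
                logger.warning(f"Error closing provider: {e}")


# Runs on normal interpreter exit, including SIGTERM/Ctrl+C shutdown paths
# that unwind through the interpreter.
atexit.register(_cleanup_providers)
```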