my-pal-mcp-server/providers/dial.py
Illya Havsiyevych 0623ce3546 feat: DIAL provider implementation (#112)
## Description

This PR implements a new [DIAL](https://dialx.ai/dial_api) (Data & AI Layer) provider for the Zen MCP Server, enabling unified access to multiple AI models through the DIAL API platform. DIAL provides enterprise-grade AI model access with deployment-specific routing similar to Azure OpenAI.
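
For context, deployment-specific routing means each model is addressed at its own Azure-style endpoint rather than a single `/chat/completions` route. A minimal sketch of what a raw call looks like (host, deployment name, and api-version are illustrative; the provider below wires all of this up automatically):

```python
# Illustrative only: how a DIAL deployment endpoint is addressed.
# DIAL authenticates with an "Api-Key" header (not "Authorization: Bearer")
# and routes per deployment:
#   POST {host}/openai/deployments/{deployment}/chat/completions?api-version=...
import httpx

host = "https://core.dialx.ai"       # DIAL_API_HOST
deployment = "o3-2025-04-16"         # deployment/model name
api_version = "2024-12-01-preview"   # DIAL_API_VERSION

response = httpx.post(
    f"{host}/openai/deployments/{deployment}/chat/completions",
    params={"api-version": api_version},
    headers={"Api-Key": "your-dial-api-key"},  # placeholder key
    json={"messages": [{"role": "user", "content": "Hello"}]},
    timeout=30.0,
)
print(response.json()["choices"][0]["message"]["content"])
```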

## Changes Made

- [x] Added `atexit` support (see the registration sketch after this list):
  - Ensures automatic cleanup of provider resources (HTTP clients, connection pools) on server shutdown
  - Fixed a bug by switching to `ModelProviderRegistry.get_available_providers()` instead of accessing the private `_providers` attribute
  - Works with SIGTERM/Ctrl+C for graceful shutdown in both development and containerized environments
- [x] Added new DIAL provider (`providers/dial.py`) inheriting from `OpenAICompatibleProvider`
- [x] Updated server.py to register DIAL provider during initialization
- [x] Updated provider registry to include DIAL provider type
- [x] Implemented deployment-specific routing for DIAL's Azure OpenAI-style endpoints
- [x] Implemented performance optimizations:
  - Connection pooling with httpx for better performance
  - Thread-safe client caching with double-check locking pattern
  - Proper resource cleanup with `close()` method
- [x] Added comprehensive unit tests with 16 test cases (`tests/test_dial_provider.py`)
- [x] Added DIAL configuration to `.env.example` with documentation
- [x] Added support for configurable API version via `DIAL_API_VERSION` environment variable
- [x] Added DIAL model restrictions support via `DIAL_ALLOWED_MODELS` environment variable
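
The `atexit` hook itself lives outside this file; a minimal sketch of the registration pattern, assuming the import path and that `get_available_providers()` yields provider instances (both taken from the PR text, not the diff):

```python
import atexit
import logging

# Assumed import path; adjust to the project's actual module layout.
from providers.registry import ModelProviderRegistry


def _cleanup_providers() -> None:
    """Close every registered provider's HTTP clients and connection pools."""
    # get_available_providers() is assumed here to yield provider instances,
    # which is the public accessor the PR switched to (instead of _providers).
    for provider in ModelProviderRegistry.get_available_providers():
        try:
            provider.close()
        except Exception as exc:
            logging.warning(f"Error closing provider {provider}: {exc}")


# atexit runs on normal interpreter exit and on Ctrl+C (SIGINT raises
# KeyboardInterrupt); plain SIGTERM needs a handler that calls sys.exit()
# so these hooks still fire in containerized environments.
atexit.register(_cleanup_providers)
```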

### Supported DIAL Models:
- OpenAI models: o3, o4-mini (and their dated versions)
- Google models: gemini-2.5-pro, gemini-2.5-flash (including search variant)
- Anthropic models: Claude 4 Opus/Sonnet (with and without thinking mode)

### Environment Variables:
- `DIAL_API_KEY`: Required API key for DIAL authentication
- `DIAL_API_HOST`: Optional base URL (defaults to https://core.dialx.ai)
- `DIAL_API_VERSION`: Optional API version, sent as the `api-version` query parameter (defaults to 2024-12-01-preview, matching the code below)
- `DIAL_ALLOWED_MODELS`: Optional comma-separated list of allowed models
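
A sample `.env` fragment matching the `.env.example` additions (all values illustrative):

```
# DIAL configuration (illustrative values)
DIAL_API_KEY=your-dial-api-key
DIAL_API_HOST=https://core.dialx.ai
DIAL_API_VERSION=2024-12-01-preview
DIAL_ALLOWED_MODELS=o3,o4-mini,gemini-2.5-pro
```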

### Breaking Changes:
- None

### Dependencies:
- No new dependencies added (uses existing OpenAI SDK with custom routing)

"""DIAL (Data & AI Layer) model provider implementation."""
import logging
import os
import threading
import time
from typing import Optional
from .base import (
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider
logger = logging.getLogger(__name__)
class DIALModelProvider(OpenAICompatibleProvider):
    """DIAL provider using OpenAI-compatible API.

    DIAL provides access to various AI models through a unified API interface.
    Supports GPT, Claude, Gemini, and other models via DIAL deployments.
    """

    FRIENDLY_NAME = "DIAL"

    # Retry configuration for API calls
    MAX_RETRIES = 4
    RETRY_DELAYS = [1, 3, 5, 8]  # seconds

    # Supported DIAL models (these can be customized based on your DIAL deployment)
    SUPPORTED_MODELS = {
        "o3-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "o4-mini-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-03-25-google-search": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,  # DIAL doesn't expose thinking mode
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-05-06": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "gemini-2.5-flash-preview-05-20": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        # Shorthands (aliases mapping to full deployment names)
        "o3": "o3-2025-04-16",
        "o4-mini": "o4-mini-2025-04-16",
        "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
        "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
        "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
        "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
        "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
        "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
        "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
    }
    def __init__(self, api_key: str, **kwargs):
        """Initialize DIAL provider with API key and host.

        Args:
            api_key: DIAL API key for authentication
            **kwargs: Additional configuration options
        """
        # Get DIAL API host from kwargs or environment
        dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"

        # DIAL uses the /openai endpoint for its OpenAI-compatible API
        if not dial_host.endswith("/openai"):
            dial_host = f"{dial_host.rstrip('/')}/openai"
        kwargs["base_url"] = dial_host

        # Get API version from environment or use default
        self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")

        # Add DIAL-specific headers.
        # DIAL uses an Api-Key header instead of Authorization: Bearer.
        # Reference: https://dialx.ai/dial_api#section/Authorization
        self.DEFAULT_HEADERS = {
            "Api-Key": api_key,
        }

        # Store the actual API key for use in the Api-Key header
        self._dial_api_key = api_key

        # Pass a placeholder API key to the OpenAI client; the real
        # authentication happens via the Api-Key header on the httpx client,
        # and the Authorization header the SDK adds is stripped below.
        super().__init__("placeholder-not-used", **kwargs)

        # Cache for deployment-specific clients to avoid recreating them on each request
        self._deployment_clients = {}
        # Lock to ensure thread-safe client creation
        self._client_lock = threading.Lock()

        # Create a SINGLE shared httpx client for the provider instance
        import httpx

        # Custom event hook to remove the Authorization header the OpenAI client adds
        def remove_auth_header(request):
            """Remove Authorization header that the OpenAI client adds."""
            # httpx headers are case-insensitive, so check all variations
            headers_to_remove = []
            for header_name in request.headers:
                if header_name.lower() == "authorization":
                    headers_to_remove.append(header_name)
            for header_name in headers_to_remove:
                del request.headers[header_name]

        self._http_client = httpx.Client(
            timeout=self.timeout_config,
            verify=True,
            follow_redirects=True,
            headers=self.DEFAULT_HEADERS.copy(),  # Include DIAL headers, including Api-Key
            limits=httpx.Limits(
                max_keepalive_connections=5,
                max_connections=10,
                keepalive_expiry=30.0,
            ),
            event_hooks={"request": [remove_auth_header]},
        )

        logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model.

        Args:
            model_name: Name of the model (can be shorthand)

        Returns:
            ModelCapabilities object

        Raises:
            ValueError: If model is not supported or not allowed
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported DIAL model: {model_name}")

        # Check restrictions
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        return ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name=resolved_name,
            friendly_name=self.FRIENDLY_NAME,
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_images=config.get("supports_vision", False),
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
        )
    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.DIAL

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported.

        Args:
            model_name: Model name to validate

        Returns:
            True if model is supported and allowed, False otherwise
        """
        resolved_name = self._resolve_model_name(model_name)

        # The resolved name must map to a concrete config dict (not another alias)
        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return False

        # Check against base class allowed_models if configured
        if self.allowed_models is not None:
            # Check both original and resolved names (case-insensitive)
            if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
                logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
                return False

        # Also check restrictions via ModelRestrictionService
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
            return False

        return True

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name.

        Args:
            model_name: Model name or shorthand

        Returns:
            Full model name
        """
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name
    def _get_deployment_client(self, deployment: str):
        """Get or create a cached client for a specific deployment.

        This avoids recreating OpenAI clients on every request, improving
        performance. Reuses the shared HTTP client for connection pooling.

        Args:
            deployment: The deployment/model name

        Returns:
            OpenAI client configured for the specific deployment
        """
        # Fast path: check for an existing client without locking
        if deployment in self._deployment_clients:
            return self._deployment_clients[deployment]

        # Use lock to ensure thread-safe client creation
        with self._client_lock:
            # Double-check pattern: check again inside the lock
            if deployment not in self._deployment_clients:
                from openai import OpenAI

                # Build deployment-specific URL
                base_url = str(self.client.base_url)
                if base_url.endswith("/"):
                    base_url = base_url[:-1]
                # Remove the /openai suffix if present so it can be reinserted correctly
                if base_url.endswith("/openai"):
                    base_url = base_url[: -len("/openai")]
                deployment_url = f"{base_url}/openai/deployments/{deployment}"

                # Create and cache the client, REUSING the shared http_client.
                # The placeholder API key is never sent: the http_client's event
                # hook strips the Authorization header the SDK adds.
                self._deployment_clients[deployment] = OpenAI(
                    api_key="placeholder-not-used",
                    base_url=deployment_url,
                    http_client=self._http_client,  # Shared client carrying the Api-Key header
                    default_query={"api-version": self.api_version},  # api-version as query param
                )

        return self._deployment_clients[deployment]
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using DIAL's deployment-specific endpoint.

        DIAL uses Azure OpenAI-style deployment endpoints:
        /openai/deployments/{deployment}/chat/completions

        Args:
            prompt: User prompt
            model_name: Model name or alias
            system_prompt: Optional system prompt
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            images: Optional images to attach (vision-capable models only)
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(
                f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}"
            )

        # Validate parameters
        self.validate_parameters(model_name, temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Build user message content
        user_message_content = []
        if prompt:
            user_message_content.append({"type": "text", "text": prompt})

        if images and self._supports_vision(model_name):
            for img_path in images:
                processed_image = self._process_image(img_path)
                if processed_image:
                    user_message_content.append(processed_image)
        elif images:
            logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")

        # Add user message. If only text, content is a plain string; otherwise a content-part list.
        if len(user_message_content) == 1 and user_message_content[0]["type"] == "text":
            messages.append({"role": "user", "content": prompt})
        else:
            messages.append({"role": "user", "content": user_message_content})

        # Resolve model name
        resolved_model = self._resolve_model_name(model_name)

        # Build completion parameters
        completion_params = {
            "model": resolved_model,
            "messages": messages,
        }

        # Check model capabilities
        try:
            capabilities = self.get_capabilities(model_name)
            supports_temperature = getattr(capabilities, "supports_temperature", True)
        except Exception as e:
            logger.debug(f"Failed to check temperature support for {model_name}: {e}")
            supports_temperature = True

        # Add temperature parameter if supported
        if supports_temperature:
            completion_params["temperature"] = temperature

        # Add max tokens if specified. supports_temperature doubles as a proxy for
        # models that accept standard sampling parameters (reasoning models such
        # as the o-series reject max_tokens as well as temperature).
        if max_output_tokens and supports_temperature:
            completion_params["max_tokens"] = max_output_tokens

        # Pass through whitelisted additional parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
                    continue
                completion_params[key] = value

        # DIAL-specific: get the cached client for this deployment's endpoint
        deployment_client = self._get_deployment_client(resolved_model)

        # Retry logic with progressive delays
        last_exception = None

        for attempt in range(self.MAX_RETRIES):
            try:
                # Generate completion using the deployment-specific client
                response = deployment_client.chat.completions.create(**completion_params)

                # Extract content and usage
                content = response.choices[0].message.content
                usage = self._extract_usage(response)

                return ModelResponse(
                    content=content,
                    usage=usage,
                    model_name=model_name,
                    friendly_name=self.FRIENDLY_NAME,
                    provider=self.get_provider_type(),
                    metadata={
                        "finish_reason": response.choices[0].finish_reason,
                        "model": response.model,
                        "id": response.id,
                        "created": response.created,
                    },
                )
            except Exception as e:
                last_exception = e

                # Non-retryable errors are raised immediately
                if not self._is_error_retryable(e):
                    raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")

                # If this isn't the last attempt, wait before retrying
                if attempt < self.MAX_RETRIES - 1:
                    delay = self.RETRY_DELAYS[attempt]
                    logger.info(
                        f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), retrying in {delay}s: {str(e)}"
                    )
                    time.sleep(delay)

        # All retries exhausted
        raise ValueError(
            f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
        )
    def _supports_vision(self, model_name: str) -> bool:
        """Check if the model supports vision (image processing).

        Args:
            model_name: Model name to check

        Returns:
            True if model supports vision, False otherwise
        """
        resolved_name = self._resolve_model_name(model_name)
        if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
        # Fall back to parent implementation for unknown models
        return super()._supports_vision(model_name)

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic

        Returns:
            List of model names available from this provider
        """
        # Get all model keys (both full names and aliases)
        all_models = list(self.SUPPORTED_MODELS.keys())

        if not respect_restrictions:
            return all_models

        # Apply restrictions if configured
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()

        # Filter based on restrictions
        allowed_models = []
        for model in all_models:
            resolved_name = self._resolve_model_name(model)
            if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
                allowed_models.append(model)

        return allowed_models
    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        This is used for validation purposes so that restriction policies
        can validate against both aliases and their target model names.

        Returns:
            List of all model names and alias targets known by this provider
        """
        # Collect all unique model names (both aliases and targets)
        all_models = set()
        for key, value in self.SUPPORTED_MODELS.items():
            # Add the key (could be an alias or a full name)
            all_models.add(key)
            # If it's an alias (string value), add the target too
            if isinstance(value, str):
                all_models.add(value)
        return sorted(all_models)

    def close(self):
        """Clean up HTTP clients when the provider is closed."""
        logger.info("Closing DIAL provider HTTP clients...")

        # Clear the deployment clients cache.
        # Individual OpenAI clients need no explicit close: they share the
        # single httpx.Client, which is closed separately below.
        self._deployment_clients.clear()

        # Close the shared HTTP client
        if hasattr(self, "_http_client"):
            try:
                self._http_client.close()
                logger.debug("Closed shared HTTP client")
            except Exception as e:
                logger.warning(f"Error closing shared HTTP client: {e}")

        # Also close the client created by the superclass (OpenAICompatibleProvider),
        # as it holds its own httpx.Client instance that DIAL's generate_content does not use
        if hasattr(self, "client") and self.client and hasattr(self.client, "close"):
            try:
                self.client.close()
                logger.debug("Closed superclass's OpenAI client")
            except Exception as e:
                logger.warning(f"Error closing superclass's OpenAI client: {e}")