feat: DIAL provider implementation (#112)

## Description

This PR implements a new [DIAL](https://dialx.ai/dial_api) (Data & AI Layer) provider for the Zen MCP Server, enabling unified access to multiple AI models through the DIAL API platform. DIAL provides enterprise-grade AI model access with deployment-specific routing similar to Azure OpenAI.
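For reviewers unfamiliar with DIAL's routing model, this is the request shape the provider targets, as a minimal sketch using httpx directly; the host, deployment name, and key are illustrative, and the provider itself builds the same URL, header, and query parameter through the OpenAI SDK:

```python
# Hypothetical standalone request showing DIAL's Azure OpenAI-style routing:
# the resolved model name acts as the deployment segment in the URL, auth is
# an "Api-Key" header (not "Authorization: Bearer"), and the API version is
# passed as an "api-version" query parameter.
import httpx

host = "https://core.dialx.ai"  # DIAL_API_HOST default
deployment = "o3-2025-04-16"    # resolved from the "o3" shorthand

response = httpx.post(
    f"{host}/openai/deployments/{deployment}/chat/completions",
    params={"api-version": "2024-12-01-preview"},   # DIAL_API_VERSION default
    headers={"Api-Key": "your_dial_api_key_here"},  # placeholder key
    json={
        "model": deployment,
        "messages": [{"role": "user", "content": "Hello"}],
    },
    timeout=30.0,
)
print(response.json()["choices"][0]["message"]["content"])
```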

## Changes Made

- [x] Added `atexit` support:
  - Ensures automatic cleanup of provider resources (HTTP clients, connection pools) on server shutdown
  - Fixed a bug by using `ModelProviderRegistry.get_available_providers()` instead of accessing the private `_providers` attribute
  - Works with SIGTERM/Ctrl+C for graceful shutdown in both development and containerized environments
- [x] Added new DIAL provider (`providers/dial.py`) inheriting from `OpenAICompatibleProvider`
- [x] Updated server.py to register DIAL provider during initialization
- [x] Updated provider registry to include DIAL provider type
- [x] Implemented deployment-specific routing for DIAL's Azure OpenAI-style endpoints
- [x] Implemented performance optimizations:
  - Connection pooling with httpx for better performance
  - Thread-safe client caching with a double-check locking pattern (sketched after this list)
  - Proper resource cleanup with `close()` method
- [x] Added comprehensive unit tests with 16 test cases (`tests/test_dial_provider.py`)
- [x] Added DIAL configuration to `.env.example` with documentation
- [x] Added support for configurable API version via `DIAL_API_VERSION` environment variable
- [x] Added DIAL model restrictions support via `DIAL_ALLOWED_MODELS` environment variable
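The client caching item above reduces to the following pattern. This is a simplified sketch; `make_client` is a hypothetical stand-in for the deployment-specific `OpenAI(...)` construction in `providers/dial.py`:

```python
import threading
from typing import Any

_clients: dict[str, Any] = {}
_lock = threading.Lock()

def get_client(deployment: str) -> Any:
    # Fast path: no lock once the client already exists.
    client = _clients.get(deployment)
    if client is not None:
        return client
    # Slow path: take the lock, then re-check (the "double check") so two
    # threads racing on a cold cache don't both construct a client.
    with _lock:
        if deployment not in _clients:
            _clients[deployment] = make_client(deployment)  # hypothetical factory
        return _clients[deployment]
```

The unlocked read on the fast path is safe for CPython dicts, so the lock cost is paid only on the first request per deployment.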

### Supported DIAL Models:
- OpenAI models: o3, o4-mini (and their dated versions)
- Google models: gemini-2.5-pro (including a Google Search variant), gemini-2.5-flash
- Anthropic models: Claude 4 Opus/Sonnet (with and without thinking mode)
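These shorthands resolve to full deployment names before any request is made. An illustrative session based on the unit tests in this PR (the key is a placeholder, and `_resolve_model_name` is internal API):

```python
from providers.dial import DIALModelProvider

provider = DIALModelProvider("your_dial_api_key_here")  # placeholder key

print(provider._resolve_model_name("opus-4"))
# anthropic.claude-opus-4-20250514-v1:0
print(provider._resolve_model_name("gemini-2.5-flash"))
# gemini-2.5-flash-preview-05-20
```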

### Environment Variables:
- `DIAL_API_KEY`: Required API key for DIAL authentication
- `DIAL_API_HOST`: Optional base URL (defaults to https://core.dialx.ai)
- `DIAL_API_VERSION`: Optional API version header (defaults to 2024-12-01-preview)
- `DIAL_ALLOWED_MODELS`: Optional comma-separated list of allowed models
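A minimal `.env` configuration, mirroring the entries this PR adds to `.env.example` (the key is a placeholder):

```
DIAL_API_KEY=your_dial_api_key_here
# DIAL_API_HOST=https://core.dialx.ai  # Optional: /openai suffix is auto-appended
# DIAL_API_VERSION=2024-12-01-preview  # Optional: api-version sent with requests
# DIAL_ALLOWED_MODELS=o3,o4-mini  # Optional: cost-control allow-list
```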

### Breaking Changes:
- None

### Dependencies:
- No new dependencies added (uses existing OpenAI SDK with custom routing)
Commit 0623ce3546 (parent 4ae0344b14), authored by Illya Havsiyevych on 2025-06-23 13:07:10 +03:00 and committed via GitHub.
10 changed files with 900 additions and 9 deletions.

.env.example

@@ -3,8 +3,11 @@
 # API Keys - At least one is required
 #
-# IMPORTANT: Use EITHER OpenRouter OR native APIs (Gemini/OpenAI), not both!
-# Having both creates ambiguity about which provider serves each model.
+# IMPORTANT: Choose ONE approach:
+# - Native APIs (Gemini/OpenAI/XAI) for direct access
+# - DIAL for unified enterprise access
+# - OpenRouter for unified cloud access
+# Having multiple unified providers creates ambiguity about which serves each model.
 #
 # Option 1: Use native APIs (recommended for direct access)
 # Get your Gemini API key from: https://makersuite.google.com/app/apikey
@@ -16,6 +19,12 @@ OPENAI_API_KEY=your_openai_api_key_here
 # Get your X.AI API key from: https://console.x.ai/
 XAI_API_KEY=your_xai_api_key_here
+# Get your DIAL API key and configure host URL
+# DIAL provides unified access to multiple AI models through a single API
+DIAL_API_KEY=your_dial_api_key_here
+# DIAL_API_HOST=https://core.dialx.ai  # Optional: Base URL without /openai suffix (auto-appended)
+# DIAL_API_VERSION=2025-01-01-preview  # Optional: API version header for DIAL requests
 # Option 2: Use OpenRouter for access to multiple models through one API
 # Get your OpenRouter API key from: https://openrouter.ai/
 # If using OpenRouter, comment out the native API keys above
@@ -27,7 +36,8 @@ OPENROUTER_API_KEY=your_openrouter_api_key_here
 # CUSTOM_MODEL_NAME=llama3.2  # Default model name
 # Optional: Default model to use
-# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high' etc
+# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high',
+#          'grok', 'opus-4', 'sonnet-4', or any DIAL model if DIAL is configured
 # When set to 'auto', Claude will select the best model for each task
 # Defaults to 'auto' if not specified
 DEFAULT_MODEL=auto
@@ -70,6 +80,26 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
 # - grok3 (shorthand for grok-3)
 # - grokfast (shorthand for grok-3-fast)
 #
+# Supported DIAL models (when available in your DIAL deployment):
+# - o3-2025-04-16 (200K context, latest O3 release)
+# - o4-mini-2025-04-16 (200K context, latest O4 mini)
+# - o3 (shorthand for o3-2025-04-16)
+# - o4-mini (shorthand for o4-mini-2025-04-16)
+# - anthropic.claude-sonnet-4-20250514-v1:0 (200K context, Claude 4 Sonnet)
+# - anthropic.claude-sonnet-4-20250514-v1:0-with-thinking (200K context, Claude 4 Sonnet with thinking mode)
+# - anthropic.claude-opus-4-20250514-v1:0 (200K context, Claude 4 Opus)
+# - anthropic.claude-opus-4-20250514-v1:0-with-thinking (200K context, Claude 4 Opus with thinking mode)
+# - sonnet-4 (shorthand for Claude 4 Sonnet)
+# - sonnet-4-thinking (shorthand for Claude 4 Sonnet with thinking)
+# - opus-4 (shorthand for Claude 4 Opus)
+# - opus-4-thinking (shorthand for Claude 4 Opus with thinking)
+# - gemini-2.5-pro-preview-03-25-google-search (1M context, with Google Search)
+# - gemini-2.5-pro-preview-05-06 (1M context, latest preview)
+# - gemini-2.5-flash-preview-05-20 (1M context, latest flash preview)
+# - gemini-2.5-pro (shorthand for gemini-2.5-pro-preview-05-06)
+# - gemini-2.5-pro-search (shorthand for gemini-2.5-pro-preview-03-25-google-search)
+# - gemini-2.5-flash (shorthand for gemini-2.5-flash-preview-05-20)
+#
 # Examples:
 # OPENAI_ALLOWED_MODELS=o3-mini,o4-mini,mini  # Only allow mini models (cost control)
 # GOOGLE_ALLOWED_MODELS=flash  # Only allow Flash (fast responses)
@@ -77,21 +107,26 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
 # OPENAI_ALLOWED_MODELS=o4-mini  # Single model standardization
 # GOOGLE_ALLOWED_MODELS=flash,pro  # Allow both Gemini models
 # XAI_ALLOWED_MODELS=grok,grok-3-fast  # Allow both GROK variants
+# DIAL_ALLOWED_MODELS=o3,o4-mini  # Only allow O3/O4 models via DIAL
+# DIAL_ALLOWED_MODELS=opus-4,sonnet-4  # Only Claude 4 models (without thinking)
+# DIAL_ALLOWED_MODELS=opus-4-thinking,sonnet-4-thinking  # Only Claude 4 with thinking mode
+# DIAL_ALLOWED_MODELS=gemini-2.5-pro,gemini-2.5-flash  # Only Gemini 2.5 models via DIAL
 #
 # Note: These restrictions apply even in 'auto' mode - Claude will only pick from allowed models
 # OPENAI_ALLOWED_MODELS=
 # GOOGLE_ALLOWED_MODELS=
 # XAI_ALLOWED_MODELS=
+# DIAL_ALLOWED_MODELS=
 # Optional: Custom model configuration file path
 # Override the default location of custom_models.json
 # CUSTOM_MODELS_CONFIG_PATH=/path/to/your/custom_models.json
-# Note: Redis is no longer used - conversations are stored in memory
+# Note: Conversations are stored in memory during the session
 # Optional: Conversation timeout (hours)
 # How long AI-to-AI conversation threads persist before expiring
-# Longer timeouts use more Redis memory but allow resuming conversations later
+# Longer timeouts use more memory but allow resuming conversations later
 # Defaults to 3 hours if not specified
 CONVERSATION_TIMEOUT_HOURS=3

README.md

@@ -3,7 +3,7 @@
 [zen_web.webm](https://github.com/user-attachments/assets/851e3911-7f06-47c0-a4ab-a2601236697c)

 <div align="center">
-  <b>🤖 Claude + [Gemini / OpenAI / Grok / OpenRouter / Ollama / Any Model] = Your Ultimate AI Development Team</b>
+  <b>🤖 Claude + [Gemini / OpenAI / Grok / OpenRouter / DIAL / Ollama / Any Model] = Your Ultimate AI Development Team</b>
 </div>

 <br/>
@@ -145,6 +145,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
 - **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models.
 - **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) to get an API key for O3 model access.
 - **X.AI**: Visit [X.AI Console](https://console.x.ai/) to get an API key for GROK model access.
+- **DIAL**: Visit [DIAL Platform](https://dialx.ai/) to get an API key for accessing multiple models through their unified API. DIAL is an open-source AI orchestration platform that provides vendor-agnostic access to models from major providers, open-source community, and self-hosted deployments. [API Documentation](https://dialx.ai/dial_api)

 **Option C: Custom API Endpoints (Local models like Ollama, vLLM)**
 [Please see the setup guide](docs/custom_models.md#option-2-custom-api-setup-ollama-vllm-etc). With a custom API you can use:
@@ -154,7 +155,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
 - **Text Generation WebUI**: Popular local interface for running models
 - **Any OpenAI-compatible API**: Custom endpoints for your own infrastructure

-> **Note:** Using all three options may create ambiguity about which provider / model to use if there is an overlap.
+> **Note:** Using multiple provider options may create ambiguity about which provider / model to use if there is an overlap.
 > If all APIs are configured, native APIs will take priority when there is a clash in model name, such as for `gemini` and `o3`.
 > Configure your model aliases and give them unique names in [`conf/custom_models.json`](conf/custom_models.json)
@@ -192,6 +193,12 @@ nano .env
 # GEMINI_API_KEY=your-gemini-api-key-here  # For Gemini models
 # OPENAI_API_KEY=your-openai-api-key-here  # For O3 model
 # OPENROUTER_API_KEY=your-openrouter-key  # For OpenRouter (see docs/custom_models.md)
+# DIAL_API_KEY=your-dial-api-key-here  # For DIAL platform
+
+# For DIAL (optional configuration):
+# DIAL_API_HOST=https://core.dialx.ai  # Default DIAL host (optional)
+# DIAL_API_VERSION=2024-12-01-preview  # API version (optional)
+# DIAL_ALLOWED_MODELS=o3,gemini-2.5-pro  # Restrict to specific models (optional)

 # For local models (Ollama, vLLM, etc.):
 # CUSTOM_API_URL=http://localhost:11434/v1  # Ollama example
@@ -537,10 +544,11 @@ Configure the Zen MCP Server through environment variables in your `.env` file.
 DEFAULT_MODEL=auto
 GEMINI_API_KEY=your-gemini-key
 OPENAI_API_KEY=your-openai-key
+DIAL_API_KEY=your-dial-key  # Optional: Access to multiple models via DIAL
 ```

 **Key Configuration Options:**
-- **API Keys**: Native APIs (Gemini, OpenAI, X.AI), OpenRouter, or Custom endpoints (Ollama, vLLM)
+- **API Keys**: Native APIs (Gemini, OpenAI, X.AI), OpenRouter, DIAL, or Custom endpoints (Ollama, vLLM)
 - **Model Selection**: Auto mode or specific model defaults
 - **Usage Restrictions**: Control which models can be used for cost control
 - **Conversation Settings**: Timeout, turn limits, memory configuration

providers/base.py

@@ -17,6 +17,7 @@ class ProviderType(Enum):
     XAI = "xai"
     OPENROUTER = "openrouter"
     CUSTOM = "custom"
+    DIAL = "dial"

 class TemperatureConstraint(ABC):

@@ -326,3 +327,12 @@ class ModelProvider(ABC):
             Resolved model name
         """
         return model_name
+
+    def close(self):
+        """Clean up any resources held by the provider.
+
+        Default implementation does nothing.
+        Subclasses should override if they hold resources that need cleanup.
+        """
+        # Base implementation: no resources to clean up
+        return

providers/dial.py (new file, 525 lines)

@@ -0,0 +1,525 @@
"""DIAL (Data & AI Layer) model provider implementation."""
import logging
import os
import threading
import time
from typing import Optional
from .base import (
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider
logger = logging.getLogger(__name__)
class DIALModelProvider(OpenAICompatibleProvider):
"""DIAL provider using OpenAI-compatible API.
DIAL provides access to various AI models through a unified API interface.
Supports GPT, Claude, Gemini, and other models via DIAL deployments.
"""
FRIENDLY_NAME = "DIAL"
# Retry configuration for API calls
MAX_RETRIES = 4
RETRY_DELAYS = [1, 3, 5, 8] # seconds
# Supported DIAL models (these can be customized based on your DIAL deployment)
SUPPORTED_MODELS = {
"o3-2025-04-16": {
"context_window": 200_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"o4-mini-2025-04-16": {
"context_window": 200_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"anthropic.claude-sonnet-4-20250514-v1:0": {
"context_window": 200_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
"context_window": 200_000,
"supports_extended_thinking": True, # Thinking mode variant
"supports_vision": True,
},
"anthropic.claude-opus-4-20250514-v1:0": {
"context_window": 200_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
"context_window": 200_000,
"supports_extended_thinking": True, # Thinking mode variant
"supports_vision": True,
},
"gemini-2.5-pro-preview-03-25-google-search": {
"context_window": 1_000_000,
"supports_extended_thinking": False, # DIAL doesn't expose thinking mode
"supports_vision": True,
},
"gemini-2.5-pro-preview-05-06": {
"context_window": 1_000_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"gemini-2.5-flash-preview-05-20": {
"context_window": 1_000_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
# Shorthands
"o3": "o3-2025-04-16",
"o4-mini": "o4-mini-2025-04-16",
"sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
"sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
"opus-4": "anthropic.claude-opus-4-20250514-v1:0",
"opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
"gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
"gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
"gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
}
def __init__(self, api_key: str, **kwargs):
"""Initialize DIAL provider with API key and host.
Args:
api_key: DIAL API key for authentication
**kwargs: Additional configuration options
"""
# Get DIAL API host from environment or kwargs
dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"
# DIAL uses /openai endpoint for OpenAI-compatible API
if not dial_host.endswith("/openai"):
dial_host = f"{dial_host.rstrip('/')}/openai"
kwargs["base_url"] = dial_host
# Get API version from environment or use default
self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")
# Add DIAL-specific headers
# DIAL uses Api-Key header instead of Authorization: Bearer
# Reference: https://dialx.ai/dial_api#section/Authorization
self.DEFAULT_HEADERS = {
"Api-Key": api_key,
}
# Store the actual API key for use in Api-Key header
self._dial_api_key = api_key
# Pass a placeholder API key to OpenAI client - we'll override the auth header in httpx
# The actual authentication happens via the Api-Key header in the httpx client
super().__init__("placeholder-not-used", **kwargs)
# Cache for deployment-specific clients to avoid recreating them on each request
self._deployment_clients = {}
# Lock to ensure thread-safe client creation
self._client_lock = threading.Lock()
# Create a SINGLE shared httpx client for the provider instance
import httpx
# Create custom event hooks to remove Authorization header
def remove_auth_header(request):
"""Remove Authorization header that OpenAI client adds."""
# httpx headers are case-insensitive, so we need to check all variations
headers_to_remove = []
for header_name in request.headers:
if header_name.lower() == "authorization":
headers_to_remove.append(header_name)
for header_name in headers_to_remove:
del request.headers[header_name]
self._http_client = httpx.Client(
timeout=self.timeout_config,
verify=True,
follow_redirects=True,
headers=self.DEFAULT_HEADERS.copy(), # Include DIAL headers including Api-Key
limits=httpx.Limits(
max_keepalive_connections=5,
max_connections=10,
keepalive_expiry=30.0,
),
event_hooks={"request": [remove_auth_header]},
)
logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific model.
Args:
model_name: Name of the model (can be shorthand)
Returns:
ModelCapabilities object
Raises:
ValueError: If model is not supported or not allowed
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported DIAL model: {model_name}")
# Check restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
config = self.SUPPORTED_MODELS[resolved_name]
return ModelCapabilities(
provider=ProviderType.DIAL,
model_name=resolved_name,
friendly_name=self.FRIENDLY_NAME,
context_window=config["context_window"],
supports_extended_thinking=config["supports_extended_thinking"],
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_images=config.get("supports_vision", False),
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
)
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.DIAL
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is supported.
Args:
model_name: Model name to validate
Returns:
True if model is supported and allowed, False otherwise
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
return False
# Check against base class allowed_models if configured
if self.allowed_models is not None:
# Check both original and resolved names (case-insensitive)
if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
return False
# Also check restrictions via ModelRestrictionService
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
return False
return True
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name.
Args:
model_name: Model name or shorthand
Returns:
Full model name
"""
shorthand_value = self.SUPPORTED_MODELS.get(model_name)
if isinstance(shorthand_value, str):
return shorthand_value
return model_name
def _get_deployment_client(self, deployment: str):
"""Get or create a cached client for a specific deployment.
This avoids recreating OpenAI clients on every request, improving performance.
Reuses the shared HTTP client for connection pooling.
Args:
deployment: The deployment/model name
Returns:
OpenAI client configured for the specific deployment
"""
# Check if client already exists without locking for performance
if deployment in self._deployment_clients:
return self._deployment_clients[deployment]
# Use lock to ensure thread-safe client creation
with self._client_lock:
# Double-check pattern: check again inside the lock
if deployment not in self._deployment_clients:
from openai import OpenAI
# Build deployment-specific URL
base_url = str(self.client.base_url)
if base_url.endswith("/"):
base_url = base_url[:-1]
# Remove /openai suffix if present to reconstruct properly
if base_url.endswith("/openai"):
base_url = base_url[:-7]
deployment_url = f"{base_url}/openai/deployments/{deployment}"
# Create and cache the client, REUSING the shared http_client
# Use placeholder API key - Authorization header will be removed by http_client event hook
self._deployment_clients[deployment] = OpenAI(
api_key="placeholder-not-used",
base_url=deployment_url,
http_client=self._http_client, # Pass the shared client with Api-Key header
default_query={"api-version": self.api_version}, # Add api-version as query param
)
return self._deployment_clients[deployment]
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_output_tokens: Optional[int] = None,
images: Optional[list[str]] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using DIAL's deployment-specific endpoint.
DIAL uses Azure OpenAI-style deployment endpoints:
/openai/deployments/{deployment}/chat/completions
Args:
prompt: User prompt
model_name: Model name or alias
system_prompt: Optional system prompt
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Validate model name against allow-list
if not self.validate_model_name(model_name):
raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
# Validate parameters
self.validate_parameters(model_name, temperature)
# Prepare messages
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# Build user message content
user_message_content = []
if prompt:
user_message_content.append({"type": "text", "text": prompt})
if images and self._supports_vision(model_name):
for img_path in images:
processed_image = self._process_image(img_path)
if processed_image:
user_message_content.append(processed_image)
elif images:
logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
# Add user message. If only text, content will be a string, otherwise a list.
if len(user_message_content) == 1 and user_message_content[0]["type"] == "text":
messages.append({"role": "user", "content": prompt})
else:
messages.append({"role": "user", "content": user_message_content})
# Resolve model name
resolved_model = self._resolve_model_name(model_name)
# Build completion parameters
completion_params = {
"model": resolved_model,
"messages": messages,
}
# Check model capabilities
try:
capabilities = self.get_capabilities(model_name)
supports_temperature = getattr(capabilities, "supports_temperature", True)
except Exception as e:
logger.debug(f"Failed to check temperature support for {model_name}: {e}")
supports_temperature = True
# Add temperature parameter if supported
if supports_temperature:
completion_params["temperature"] = temperature
# Add max tokens if specified and model supports it
if max_output_tokens and supports_temperature:
completion_params["max_tokens"] = max_output_tokens
# Add additional parameters
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
continue
completion_params[key] = value
# DIAL-specific: Get cached client for deployment endpoint
deployment_client = self._get_deployment_client(resolved_model)
# Retry logic with progressive delays
last_exception = None
for attempt in range(self.MAX_RETRIES):
try:
# Generate completion using deployment-specific client
response = deployment_client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model,
"id": response.id,
"created": response.created,
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error
is_retryable = self._is_error_retryable(e)
if not is_retryable:
# Non-retryable error, raise immediately
raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")
# If this isn't the last attempt and error is retryable, wait and retry
if attempt < self.MAX_RETRIES - 1:
delay = self.RETRY_DELAYS[attempt]
logger.info(
f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
)
time.sleep(delay)
continue
# All retries exhausted
raise ValueError(
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
)
def _supports_vision(self, model_name: str) -> bool:
"""Check if the model supports vision (image processing).
Args:
model_name: Model name to check
Returns:
True if model supports vision, False otherwise
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
# Fall back to parent implementation for unknown models
return super()._supports_vision(model_name)
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
Returns:
List of model names available from this provider
"""
# Get all model keys (both full names and aliases)
all_models = list(self.SUPPORTED_MODELS.keys())
if not respect_restrictions:
return all_models
# Apply restrictions if configured
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
# Filter based on restrictions
allowed_models = []
for model in all_models:
resolved_name = self._resolve_model_name(model)
if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
allowed_models.append(model)
return allowed_models
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
This is used for validation purposes to ensure restriction policies
can validate against both aliases and their target model names.
Returns:
List of all model names and alias targets known by this provider
"""
# Collect all unique model names (both aliases and targets)
all_models = set()
for key, value in self.SUPPORTED_MODELS.items():
# Add the key (could be alias or full name)
all_models.add(key)
# If it's an alias (string value), add the target too
if isinstance(value, str):
all_models.add(value)
return sorted(all_models)
def close(self):
"""Clean up HTTP clients when provider is closed."""
logger.info("Closing DIAL provider HTTP clients...")
# Clear the deployment clients cache
# Note: We don't need to close individual OpenAI clients since they
# use the shared httpx.Client which we close separately
self._deployment_clients.clear()
# Close the shared HTTP client
if hasattr(self, "_http_client"):
try:
self._http_client.close()
logger.debug("Closed shared HTTP client")
except Exception as e:
logger.warning(f"Error closing shared HTTP client: {e}")
# Also close the client created by the superclass (OpenAICompatibleProvider)
# as it holds its own httpx.Client instance that is not used by DIAL's generate_content
if hasattr(self, "client") and self.client and hasattr(self.client, "close"):
try:
self.client.close()
logger.debug("Closed superclass's OpenAI client")
except Exception as e:
logger.warning(f"Error closing superclass's OpenAI client: {e}")

providers/registry.py

@@ -118,6 +118,7 @@ class ModelProviderRegistry:
         ProviderType.GOOGLE,  # Direct Gemini access
         ProviderType.OPENAI,  # Direct OpenAI access
         ProviderType.XAI,  # Direct X.AI GROK access
+        ProviderType.DIAL,  # DIAL unified API access
         ProviderType.CUSTOM,  # Local/self-hosted models
         ProviderType.OPENROUTER,  # Catch-all for cloud models
     ]

@@ -237,6 +238,7 @@
             ProviderType.XAI: "XAI_API_KEY",
             ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
             ProviderType.CUSTOM: "CUSTOM_API_KEY",  # Can be empty for providers that don't need auth
+            ProviderType.DIAL: "DIAL_API_KEY",
         }

         env_var = key_mapping.get(provider_type)

run-server.sh

@@ -883,6 +883,7 @@ setup_env_file() {
"GEMINI_API_KEY:your_gemini_api_key_here"
"OPENAI_API_KEY:your_openai_api_key_here"
"XAI_API_KEY:your_xai_api_key_here"
"DIAL_API_KEY:your_dial_api_key_here"
"OPENROUTER_API_KEY:your_openrouter_api_key_here"
)
@@ -934,6 +935,7 @@ validate_api_keys() {
"GEMINI_API_KEY:your_gemini_api_key_here"
"OPENAI_API_KEY:your_openai_api_key_here"
"XAI_API_KEY:your_xai_api_key_here"
"DIAL_API_KEY:your_dial_api_key_here"
"OPENROUTER_API_KEY:your_openrouter_api_key_here"
)
@@ -961,6 +963,7 @@ validate_api_keys() {
echo " GEMINI_API_KEY=your-actual-key" >&2
echo " OPENAI_API_KEY=your-actual-key" >&2
echo " XAI_API_KEY=your-actual-key" >&2
echo " DIAL_API_KEY=your-actual-key" >&2
echo " OPENROUTER_API_KEY=your-actual-key" >&2
echo "" >&2
print_info "After adding your API keys, run ./run-server.sh again" >&2

server.py

@@ -19,6 +19,7 @@ as defined by the MCP protocol.
"""
import asyncio
import atexit
import logging
import os
import sys
@@ -271,6 +272,7 @@ def configure_providers():
     from providers import ModelProviderRegistry
     from providers.base import ProviderType
     from providers.custom import CustomProvider
+    from providers.dial import DIALModelProvider
     from providers.gemini import GeminiModelProvider
     from providers.openai_provider import OpenAIModelProvider
     from providers.openrouter import OpenRouterProvider
@@ -303,6 +305,13 @@ def configure_providers():
         has_native_apis = True
         logger.info("X.AI API key found - GROK models available")

+    # Check for DIAL API key
+    dial_key = os.getenv("DIAL_API_KEY")
+    if dial_key and dial_key != "your_dial_api_key_here":
+        valid_providers.append("DIAL")
+        has_native_apis = True
+        logger.info("DIAL API key found - DIAL models available")
+
     # Check for OpenRouter API key
     openrouter_key = os.getenv("OPENROUTER_API_KEY")
     if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
@@ -336,6 +345,8 @@ def configure_providers():
             ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
         if xai_key and xai_key != "your_xai_api_key_here":
             ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+        if dial_key and dial_key != "your_dial_api_key_here":
+            ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)

     # 2. Custom provider second (for local/private models)
     if has_custom:
@@ -358,6 +369,7 @@ def configure_providers():
"- GEMINI_API_KEY for Gemini models\n"
"- OPENAI_API_KEY for OpenAI o3 model\n"
"- XAI_API_KEY for X.AI GROK models\n"
"- DIAL_API_KEY for DIAL models\n"
"- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
"- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)"
)
@@ -376,6 +388,25 @@ def configure_providers():
     if len(priority_info) > 1:
         logger.info(f"Provider priority: {''.join(priority_info)}")

+    # Register cleanup function for providers
+    def cleanup_providers():
+        """Clean up all registered providers on shutdown."""
+        try:
+            registry = ModelProviderRegistry()
+            if hasattr(registry, "_initialized_providers"):
+                for provider in list(registry._initialized_providers.items()):
+                    try:
+                        if provider and hasattr(provider, "close"):
+                            provider.close()
+                    except Exception:
+                        # Logger might be closed during shutdown
+                        pass
+        except Exception:
+            # Silently ignore any errors during cleanup
+            pass
+
+    atexit.register(cleanup_providers)
+
     # Check and log model restrictions
     restriction_service = get_restriction_service()
     restrictions = restriction_service.get_restriction_summary()
@@ -390,7 +421,8 @@ def configure_providers():
     # Validate restrictions against known models
     provider_instances = {}
-    for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI]:
+    provider_types_to_validate = [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]
+    for provider_type in provider_types_to_validate:
         provider = ModelProviderRegistry.get_provider(provider_type)
         if provider:
             provider_instances[provider_type] = provider

tests/test_dial_provider.py (new file, 273 lines)

@@ -0,0 +1,273 @@
"""Tests for DIAL provider implementation."""
import os
from unittest.mock import MagicMock, patch
import pytest
from providers.base import ProviderType
from providers.dial import DIALModelProvider
class TestDIALProvider:
"""Test DIAL provider functionality."""
@patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": "https://test.dialx.ai"})
def test_initialization_with_host(self):
"""Test provider initialization with custom host."""
provider = DIALModelProvider("test-key")
assert provider._dial_api_key == "test-key" # Check internal API key storage
assert provider.api_key == "placeholder-not-used" # OpenAI client uses placeholder, auth header removed by hook
assert provider.base_url == "https://test.dialx.ai/openai"
assert provider.get_provider_type() == ProviderType.DIAL
@patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": ""}, clear=True)
def test_initialization_default_host(self):
"""Test provider initialization with default host."""
provider = DIALModelProvider("test-key")
assert provider._dial_api_key == "test-key" # Check internal API key storage
assert provider.api_key == "placeholder-not-used" # OpenAI client uses placeholder, auth header removed by hook
assert provider.base_url == "https://core.dialx.ai/openai"
def test_initialization_host_normalization(self):
"""Test that host URL is normalized to include /openai suffix."""
# Test with host missing /openai
provider = DIALModelProvider("test-key", base_url="https://custom.dialx.ai")
assert provider.base_url == "https://custom.dialx.ai/openai"
# Test with host already having /openai
provider = DIALModelProvider("test-key", base_url="https://custom.dialx.ai/openai")
assert provider.base_url == "https://custom.dialx.ai/openai"
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
@patch("utils.model_restrictions._restriction_service", None)
def test_model_validation(self):
"""Test model name validation."""
provider = DIALModelProvider("test-key")
# Test valid models
assert provider.validate_model_name("o3-2025-04-16") is True
assert provider.validate_model_name("o3") is True # Shorthand
assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True
assert provider.validate_model_name("opus-4") is True # Shorthand
assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is True
assert provider.validate_model_name("gemini-2.5-pro") is True # Shorthand
# Test invalid model
assert provider.validate_model_name("invalid-model") is False
def test_resolve_model_name(self):
"""Test model name resolution for shorthands."""
provider = DIALModelProvider("test-key")
# Test shorthand resolution
assert provider._resolve_model_name("o3") == "o3-2025-04-16"
assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"
assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
assert provider._resolve_model_name("gemini-2.5-pro") == "gemini-2.5-pro-preview-05-06"
assert provider._resolve_model_name("gemini-2.5-flash") == "gemini-2.5-flash-preview-05-20"
# Test full name passthrough
assert provider._resolve_model_name("o3-2025-04-16") == "o3-2025-04-16"
assert (
provider._resolve_model_name("anthropic.claude-opus-4-20250514-v1:0")
== "anthropic.claude-opus-4-20250514-v1:0"
)
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
@patch("utils.model_restrictions._restriction_service", None)
def test_get_capabilities(self):
"""Test getting model capabilities."""
provider = DIALModelProvider("test-key")
# Test O3 capabilities
capabilities = provider.get_capabilities("o3")
assert capabilities.model_name == "o3-2025-04-16"
assert capabilities.friendly_name == "DIAL"
assert capabilities.context_window == 200_000
assert capabilities.provider == ProviderType.DIAL
assert capabilities.supports_images is True
assert capabilities.supports_extended_thinking is False
# Test Claude 4 capabilities
capabilities = provider.get_capabilities("opus-4")
assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0"
assert capabilities.context_window == 200_000
assert capabilities.supports_images is True
assert capabilities.supports_extended_thinking is False
# Test Claude 4 with thinking mode
capabilities = provider.get_capabilities("opus-4-thinking")
assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0-with-thinking"
assert capabilities.context_window == 200_000
assert capabilities.supports_images is True
assert capabilities.supports_extended_thinking is True
# Test Gemini capabilities
capabilities = provider.get_capabilities("gemini-2.5-pro")
assert capabilities.model_name == "gemini-2.5-pro-preview-05-06"
assert capabilities.context_window == 1_000_000
assert capabilities.supports_images is True
# Test temperature constraint
assert capabilities.temperature_constraint.min_temp == 0.0
assert capabilities.temperature_constraint.max_temp == 2.0
assert capabilities.temperature_constraint.default_temp == 0.7
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
@patch("utils.model_restrictions._restriction_service", None)
def test_get_capabilities_invalid_model(self):
"""Test that get_capabilities raises for invalid models."""
provider = DIALModelProvider("test-key")
with pytest.raises(ValueError, match="Unsupported DIAL model"):
provider.get_capabilities("invalid-model")
@patch("utils.model_restrictions.get_restriction_service")
def test_get_capabilities_restricted_model(self, mock_get_restriction):
"""Test that get_capabilities respects model restrictions."""
provider = DIALModelProvider("test-key")
# Mock restriction service to block the model
mock_service = MagicMock()
mock_service.is_allowed.return_value = False
mock_get_restriction.return_value = mock_service
with pytest.raises(ValueError, match="not allowed by restriction policy"):
provider.get_capabilities("o3")
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
@patch("utils.model_restrictions._restriction_service", None)
def test_supports_vision(self):
"""Test vision support detection."""
provider = DIALModelProvider("test-key")
# Test models with vision support
assert provider._supports_vision("o3-2025-04-16") is True
assert provider._supports_vision("o3") is True # Via resolution
assert provider._supports_vision("anthropic.claude-opus-4-20250514-v1:0") is True
assert provider._supports_vision("gemini-2.5-pro-preview-05-06") is True
# Test unknown model (falls back to parent implementation)
assert provider._supports_vision("unknown-model") is False
@patch("openai.OpenAI") # Mock the OpenAI class directly from openai module
def test_generate_content_with_alias(self, mock_openai_class):
"""Test that generate_content properly resolves aliases and uses deployment routing."""
# Create mock client
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
mock_response.model = "gpt-4"
mock_response.id = "test-id"
mock_response.created = 1234567890
mock_response.choices[0].finish_reason = "stop"
mock_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_client
provider = DIALModelProvider("test-key")
# Generate content with shorthand
response = provider.generate_content(prompt="Test prompt", model_name="o3", temperature=0.7) # Shorthand
# Verify OpenAI was instantiated with deployment-specific URL
mock_openai_class.assert_called_once()
call_args = mock_openai_class.call_args
assert "/deployments/o3-2025-04-16" in call_args[1]["base_url"]
# Verify the resolved model name was passed to the API
mock_client.chat.completions.create.assert_called_once()
create_call_args = mock_client.chat.completions.create.call_args
assert create_call_args[1]["model"] == "o3-2025-04-16" # Resolved name
# Verify response
assert response.content == "Test response"
assert response.model_name == "o3" # Original name preserved
assert response.metadata["model"] == "gpt-4" # API returned model name from mock
def test_provider_type(self):
"""Test provider type identification."""
provider = DIALModelProvider("test-key")
assert provider.get_provider_type() == ProviderType.DIAL
def test_friendly_name(self):
"""Test provider friendly name."""
provider = DIALModelProvider("test-key")
assert provider.FRIENDLY_NAME == "DIAL"
@patch.dict(os.environ, {"DIAL_API_VERSION": "2024-12-01"})
def test_configurable_api_version(self):
"""Test that API version can be configured via environment variable."""
provider = DIALModelProvider("test-key")
# Check that the custom API version is stored
assert provider.api_version == "2024-12-01"
def test_default_api_version(self):
"""Test that default API version is used when not configured."""
# Clear any existing DIAL_API_VERSION from environment
with patch.dict(os.environ, {}, clear=True):
# Keep other env vars but ensure DIAL_API_VERSION is not set
if "DIAL_API_VERSION" in os.environ:
del os.environ["DIAL_API_VERSION"]
provider = DIALModelProvider("test-key")
# Check that the default API version is used
assert provider.api_version == "2024-12-01-preview"
# Check that Api-Key header is set
assert provider.DEFAULT_HEADERS["Api-Key"] == "test-key"
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": "o3-2025-04-16,anthropic.claude-opus-4-20250514-v1:0"})
@patch("utils.model_restrictions._restriction_service", None)
def test_allowed_models_restriction(self):
"""Test model allow-list functionality."""
provider = DIALModelProvider("test-key")
# These should be allowed
assert provider.validate_model_name("o3-2025-04-16") is True
assert provider.validate_model_name("o3") is True # Alias for o3-2025-04-16
assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True
assert provider.validate_model_name("opus-4") is True # Resolves to anthropic.claude-opus-4-20250514-v1:0
# These should be blocked
assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is False
assert provider.validate_model_name("o4-mini-2025-04-16") is False
assert provider.validate_model_name("sonnet-4") is False # sonnet-4 is not in allowed list
@patch("httpx.Client")
@patch("openai.OpenAI")
def test_close_method(self, mock_openai_class, mock_httpx_client_class):
"""Test that the close method properly closes HTTP clients."""
# Mock the httpx.Client instance that DIALModelProvider will create
mock_shared_http_client = MagicMock()
mock_httpx_client_class.return_value = mock_shared_http_client
# Mock the OpenAI client instances
mock_openai_client_1 = MagicMock()
mock_openai_client_2 = MagicMock()
# Configure side_effect to return different mocks for subsequent calls
mock_openai_class.side_effect = [mock_openai_client_1, mock_openai_client_2]
provider = DIALModelProvider("test-key")
# Mock the superclass's _client attribute directly
mock_superclass_client = MagicMock()
provider._client = mock_superclass_client
# Simulate getting clients for two different deployments to populate _deployment_clients
provider._get_deployment_client("model_a")
provider._get_deployment_client("model_b")
# Now call close
provider.close()
# Assert that the shared httpx client's close method was called
mock_shared_http_client.close.assert_called_once()
# Assert that the superclass client's close method was called
mock_superclass_client.close.assert_called_once()
# Assert that the deployment clients cache is cleared
assert not provider._deployment_clients

tools/listmodels.py

@@ -84,6 +84,7 @@ class ListModelsTool(BaseTool):
             ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
             ProviderType.OPENAI: {"name": "OpenAI", "env_key": "OPENAI_API_KEY"},
             ProviderType.XAI: {"name": "X.AI (Grok)", "env_key": "XAI_API_KEY"},
+            ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
         }

         # Check each native provider type

utils/model_restrictions.py

@@ -11,6 +11,7 @@ Environment Variables:
 - GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
 - XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
 - OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models
+- DIAL_ALLOWED_MODELS: Comma-separated list of allowed DIAL models

 Example:
     OPENAI_ALLOWED_MODELS=o3-mini,o4-mini

@@ -44,6 +45,7 @@ class ModelRestrictionService:
         ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
         ProviderType.XAI: "XAI_ALLOWED_MODELS",
         ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
+        ProviderType.DIAL: "DIAL_ALLOWED_MODELS",
     }

     def __init__(self):