From 0623ce3546fd744877b4fa09cf8943b4bb5c85dc Mon Sep 17 00:00:00 2001
From: Illya Havsiyevych <44289086+illya-havsiyevych@users.noreply.github.com>
Date: Mon, 23 Jun 2025 13:07:10 +0300
Subject: [PATCH] feat: DIAL provider implementation (#112)
## Description
This PR implements a new [DIAL](https://dialx.ai/dial_api) (Data & AI Layer) provider for the Zen MCP Server, enabling unified access to multiple AI models through the DIAL API platform. DIAL provides enterprise-grade AI model access with deployment-specific routing similar to Azure OpenAI.
## Changes Made
- [x] Added support for atexit:
- Ensures automatic cleanup of provider resources (HTTP clients, connection pools) on server shutdown
  - Fixed a bug by using ModelProviderRegistry.get_available_providers() instead of accessing the private _providers attribute
- Works with SIGTERM/Ctrl+C for graceful shutdown in both development and containerized environments
- [x] Added new DIAL provider (`providers/dial.py`) inheriting from `OpenAICompatibleProvider`
- [x] Updated server.py to register DIAL provider during initialization
- [x] Updated provider registry to include DIAL provider type
- [x] Implemented deployment-specific routing for DIAL's Azure OpenAI-style endpoints
- [x] Implemented performance optimizations:
- Connection pooling with httpx for better performance
- Thread-safe client caching with double-check locking pattern
- Proper resource cleanup with `close()` method
- [x] Added comprehensive unit tests with 16 test cases (`tests/test_dial_provider.py`)
- [x] Added DIAL configuration to `.env.example` with documentation
- [x] Added support for configurable API version via `DIAL_API_VERSION` environment variable
- [x] Added DIAL model restrictions support via `DIAL_ALLOWED_MODELS` environment variable
### Supported DIAL Models:
- OpenAI models: o3, o4-mini (and their dated versions)
- Google models: gemini-2.5-pro, gemini-2.5-flash (including search variant)
- Anthropic models: Claude 4 Opus/Sonnet (with and without thinking mode)
### Environment Variables:
- `DIAL_API_KEY`: Required API key for DIAL authentication
- `DIAL_API_HOST`: Optional base URL (defaults to https://core.dialx.ai)
- `DIAL_API_VERSION`: Optional API version header (defaults to 2024-12-01-preview)
- `DIAL_ALLOWED_MODELS`: Optional comma-separated list of allowed models
### Breaking Changes:
- None
### Dependencies:
- No new dependencies added (uses existing OpenAI SDK with custom routing)
---
.env.example | 45 +++-
README.md | 14 +-
providers/base.py | 10 +
providers/dial.py | 525 ++++++++++++++++++++++++++++++++++++
providers/registry.py | 2 +
run-server.sh | 3 +
server.py | 34 ++-
tests/test_dial_provider.py | 273 +++++++++++++++++++
tools/listmodels.py | 1 +
utils/model_restrictions.py | 2 +
10 files changed, 900 insertions(+), 9 deletions(-)
create mode 100644 providers/dial.py
create mode 100644 tests/test_dial_provider.py
diff --git a/.env.example b/.env.example
index a7e6376..1d88d4c 100644
--- a/.env.example
+++ b/.env.example
@@ -3,8 +3,11 @@
# API Keys - At least one is required
#
-# IMPORTANT: Use EITHER OpenRouter OR native APIs (Gemini/OpenAI), not both!
-# Having both creates ambiguity about which provider serves each model.
+# IMPORTANT: Choose ONE approach:
+# - Native APIs (Gemini/OpenAI/XAI) for direct access
+# - DIAL for unified enterprise access
+# - OpenRouter for unified cloud access
+# Having multiple unified providers creates ambiguity about which serves each model.
#
# Option 1: Use native APIs (recommended for direct access)
# Get your Gemini API key from: https://makersuite.google.com/app/apikey
@@ -16,6 +19,12 @@ OPENAI_API_KEY=your_openai_api_key_here
# Get your X.AI API key from: https://console.x.ai/
XAI_API_KEY=your_xai_api_key_here
+# Get your DIAL API key and configure host URL
+# DIAL provides unified access to multiple AI models through a single API
+DIAL_API_KEY=your_dial_api_key_here
+# DIAL_API_HOST=https://core.dialx.ai # Optional: Base URL without /openai suffix (auto-appended)
+# DIAL_API_VERSION=2024-12-01-preview  # Optional: API version header for DIAL requests (matches provider default)
+
# Option 2: Use OpenRouter for access to multiple models through one API
# Get your OpenRouter API key from: https://openrouter.ai/
# If using OpenRouter, comment out the native API keys above
@@ -27,7 +36,8 @@ OPENROUTER_API_KEY=your_openrouter_api_key_here
# CUSTOM_MODEL_NAME=llama3.2 # Default model name
# Optional: Default model to use
-# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high' etc
+# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high',
+# 'grok', 'opus-4', 'sonnet-4', or any DIAL model if DIAL is configured
# When set to 'auto', Claude will select the best model for each task
# Defaults to 'auto' if not specified
DEFAULT_MODEL=auto
@@ -70,6 +80,26 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
# - grok3 (shorthand for grok-3)
# - grokfast (shorthand for grok-3-fast)
#
+# Supported DIAL models (when available in your DIAL deployment):
+# - o3-2025-04-16 (200K context, latest O3 release)
+# - o4-mini-2025-04-16 (200K context, latest O4 mini)
+# - o3 (shorthand for o3-2025-04-16)
+# - o4-mini (shorthand for o4-mini-2025-04-16)
+# - anthropic.claude-sonnet-4-20250514-v1:0 (200K context, Claude 4 Sonnet)
+# - anthropic.claude-sonnet-4-20250514-v1:0-with-thinking (200K context, Claude 4 Sonnet with thinking mode)
+# - anthropic.claude-opus-4-20250514-v1:0 (200K context, Claude 4 Opus)
+# - anthropic.claude-opus-4-20250514-v1:0-with-thinking (200K context, Claude 4 Opus with thinking mode)
+# - sonnet-4 (shorthand for Claude 4 Sonnet)
+# - sonnet-4-thinking (shorthand for Claude 4 Sonnet with thinking)
+# - opus-4 (shorthand for Claude 4 Opus)
+# - opus-4-thinking (shorthand for Claude 4 Opus with thinking)
+# - gemini-2.5-pro-preview-03-25-google-search (1M context, with Google Search)
+# - gemini-2.5-pro-preview-05-06 (1M context, latest preview)
+# - gemini-2.5-flash-preview-05-20 (1M context, latest flash preview)
+# - gemini-2.5-pro (shorthand for gemini-2.5-pro-preview-05-06)
+# - gemini-2.5-pro-search (shorthand for gemini-2.5-pro-preview-03-25-google-search)
+# - gemini-2.5-flash (shorthand for gemini-2.5-flash-preview-05-20)
+#
# Examples:
# OPENAI_ALLOWED_MODELS=o3-mini,o4-mini,mini # Only allow mini models (cost control)
# GOOGLE_ALLOWED_MODELS=flash # Only allow Flash (fast responses)
@@ -77,21 +107,26 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
# OPENAI_ALLOWED_MODELS=o4-mini # Single model standardization
# GOOGLE_ALLOWED_MODELS=flash,pro # Allow both Gemini models
# XAI_ALLOWED_MODELS=grok,grok-3-fast # Allow both GROK variants
+# DIAL_ALLOWED_MODELS=o3,o4-mini # Only allow O3/O4 models via DIAL
+# DIAL_ALLOWED_MODELS=opus-4,sonnet-4 # Only Claude 4 models (without thinking)
+# DIAL_ALLOWED_MODELS=opus-4-thinking,sonnet-4-thinking # Only Claude 4 with thinking mode
+# DIAL_ALLOWED_MODELS=gemini-2.5-pro,gemini-2.5-flash # Only Gemini 2.5 models via DIAL
#
# Note: These restrictions apply even in 'auto' mode - Claude will only pick from allowed models
# OPENAI_ALLOWED_MODELS=
# GOOGLE_ALLOWED_MODELS=
# XAI_ALLOWED_MODELS=
+# DIAL_ALLOWED_MODELS=
# Optional: Custom model configuration file path
# Override the default location of custom_models.json
# CUSTOM_MODELS_CONFIG_PATH=/path/to/your/custom_models.json
-# Note: Redis is no longer used - conversations are stored in memory
+# Note: Conversations are stored in memory during the session
# Optional: Conversation timeout (hours)
# How long AI-to-AI conversation threads persist before expiring
-# Longer timeouts use more Redis memory but allow resuming conversations later
+# Longer timeouts use more memory but allow resuming conversations later
# Defaults to 3 hours if not specified
CONVERSATION_TIMEOUT_HOURS=3
diff --git a/README.md b/README.md
index c540380..40552da 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
[zen_web.webm](https://github.com/user-attachments/assets/851e3911-7f06-47c0-a4ab-a2601236697c)
- 🤖 Claude + [Gemini / OpenAI / Grok / OpenRouter / Ollama / Any Model] = Your Ultimate AI Development Team
+ 🤖 Claude + [Gemini / OpenAI / Grok / OpenRouter / DIAL / Ollama / Any Model] = Your Ultimate AI Development Team
@@ -145,6 +145,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
- **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models.
- **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) to get an API key for O3 model access.
- **X.AI**: Visit [X.AI Console](https://console.x.ai/) to get an API key for GROK model access.
+- **DIAL**: Visit [DIAL Platform](https://dialx.ai/) to get an API key for accessing multiple models through their unified API. DIAL is an open-source AI orchestration platform that provides vendor-agnostic access to models from major providers, open-source community, and self-hosted deployments. [API Documentation](https://dialx.ai/dial_api)
**Option C: Custom API Endpoints (Local models like Ollama, vLLM)**
[Please see the setup guide](docs/custom_models.md#option-2-custom-api-setup-ollama-vllm-etc). With a custom API you can use:
@@ -154,7 +155,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
- **Text Generation WebUI**: Popular local interface for running models
- **Any OpenAI-compatible API**: Custom endpoints for your own infrastructure
-> **Note:** Using all three options may create ambiguity about which provider / model to use if there is an overlap.
+> **Note:** Using multiple provider options may create ambiguity about which provider / model to use if there is an overlap.
> If all APIs are configured, native APIs will take priority when there is a clash in model name, such as for `gemini` and `o3`.
> Configure your model aliases and give them unique names in [`conf/custom_models.json`](conf/custom_models.json)
@@ -192,6 +193,12 @@ nano .env
# GEMINI_API_KEY=your-gemini-api-key-here # For Gemini models
# OPENAI_API_KEY=your-openai-api-key-here # For O3 model
# OPENROUTER_API_KEY=your-openrouter-key # For OpenRouter (see docs/custom_models.md)
+# DIAL_API_KEY=your-dial-api-key-here # For DIAL platform
+
+# For DIAL (optional configuration):
+# DIAL_API_HOST=https://core.dialx.ai # Default DIAL host (optional)
+# DIAL_API_VERSION=2024-12-01-preview # API version (optional)
+# DIAL_ALLOWED_MODELS=o3,gemini-2.5-pro # Restrict to specific models (optional)
# For local models (Ollama, vLLM, etc.):
# CUSTOM_API_URL=http://localhost:11434/v1 # Ollama example
@@ -537,10 +544,11 @@ Configure the Zen MCP Server through environment variables in your `.env` file.
DEFAULT_MODEL=auto
GEMINI_API_KEY=your-gemini-key
OPENAI_API_KEY=your-openai-key
+DIAL_API_KEY=your-dial-key # Optional: Access to multiple models via DIAL
```
**Key Configuration Options:**
-- **API Keys**: Native APIs (Gemini, OpenAI, X.AI), OpenRouter, or Custom endpoints (Ollama, vLLM)
+- **API Keys**: Native APIs (Gemini, OpenAI, X.AI), OpenRouter, DIAL, or Custom endpoints (Ollama, vLLM)
- **Model Selection**: Auto mode or specific model defaults
- **Usage Restrictions**: Control which models can be used for cost control
- **Conversation Settings**: Timeout, turn limits, memory configuration
diff --git a/providers/base.py b/providers/base.py
index e0b3882..c8b1ec7 100644
--- a/providers/base.py
+++ b/providers/base.py
@@ -17,6 +17,7 @@ class ProviderType(Enum):
XAI = "xai"
OPENROUTER = "openrouter"
CUSTOM = "custom"
+ DIAL = "dial"
class TemperatureConstraint(ABC):
@@ -326,3 +327,12 @@ class ModelProvider(ABC):
Resolved model name
"""
return model_name
+
+ def close(self):
+ """Clean up any resources held by the provider.
+
+ Default implementation does nothing.
+ Subclasses should override if they hold resources that need cleanup.
+ """
+ # Base implementation: no resources to clean up
+ return
diff --git a/providers/dial.py b/providers/dial.py
new file mode 100644
index 0000000..617858c
--- /dev/null
+++ b/providers/dial.py
@@ -0,0 +1,525 @@
+"""DIAL (Data & AI Layer) model provider implementation."""
+
+import logging
+import os
+import threading
+import time
+from typing import Optional
+
+from .base import (
+ ModelCapabilities,
+ ModelResponse,
+ ProviderType,
+ RangeTemperatureConstraint,
+)
+from .openai_compatible import OpenAICompatibleProvider
+
+logger = logging.getLogger(__name__)
+
+
+class DIALModelProvider(OpenAICompatibleProvider):
+ """DIAL provider using OpenAI-compatible API.
+
+ DIAL provides access to various AI models through a unified API interface.
+ Supports GPT, Claude, Gemini, and other models via DIAL deployments.
+ """
+
+ FRIENDLY_NAME = "DIAL"
+
+ # Retry configuration for API calls
+ MAX_RETRIES = 4
+ RETRY_DELAYS = [1, 3, 5, 8] # seconds
+
+ # Supported DIAL models (these can be customized based on your DIAL deployment)
+ SUPPORTED_MODELS = {
+ "o3-2025-04-16": {
+ "context_window": 200_000,
+ "supports_extended_thinking": False,
+ "supports_vision": True,
+ },
+ "o4-mini-2025-04-16": {
+ "context_window": 200_000,
+ "supports_extended_thinking": False,
+ "supports_vision": True,
+ },
+ "anthropic.claude-sonnet-4-20250514-v1:0": {
+ "context_window": 200_000,
+ "supports_extended_thinking": False,
+ "supports_vision": True,
+ },
+ "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
+ "context_window": 200_000,
+ "supports_extended_thinking": True, # Thinking mode variant
+ "supports_vision": True,
+ },
+ "anthropic.claude-opus-4-20250514-v1:0": {
+ "context_window": 200_000,
+ "supports_extended_thinking": False,
+ "supports_vision": True,
+ },
+ "anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
+ "context_window": 200_000,
+ "supports_extended_thinking": True, # Thinking mode variant
+ "supports_vision": True,
+ },
+ "gemini-2.5-pro-preview-03-25-google-search": {
+ "context_window": 1_000_000,
+ "supports_extended_thinking": False, # DIAL doesn't expose thinking mode
+ "supports_vision": True,
+ },
+ "gemini-2.5-pro-preview-05-06": {
+ "context_window": 1_000_000,
+ "supports_extended_thinking": False,
+ "supports_vision": True,
+ },
+ "gemini-2.5-flash-preview-05-20": {
+ "context_window": 1_000_000,
+ "supports_extended_thinking": False,
+ "supports_vision": True,
+ },
+ # Shorthands
+ "o3": "o3-2025-04-16",
+ "o4-mini": "o4-mini-2025-04-16",
+ "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
+ "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
+ "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
+ "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
+ "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
+ "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
+ "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
+ }
+
+ def __init__(self, api_key: str, **kwargs):
+ """Initialize DIAL provider with API key and host.
+
+ Args:
+ api_key: DIAL API key for authentication
+ **kwargs: Additional configuration options
+ """
+ # Get DIAL API host from environment or kwargs
+ dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"
+
+ # DIAL uses /openai endpoint for OpenAI-compatible API
+ if not dial_host.endswith("/openai"):
+ dial_host = f"{dial_host.rstrip('/')}/openai"
+
+ kwargs["base_url"] = dial_host
+
+ # Get API version from environment or use default
+ self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")
+
+ # Add DIAL-specific headers
+ # DIAL uses Api-Key header instead of Authorization: Bearer
+ # Reference: https://dialx.ai/dial_api#section/Authorization
+ self.DEFAULT_HEADERS = {
+ "Api-Key": api_key,
+ }
+
+ # Store the actual API key for use in Api-Key header
+ self._dial_api_key = api_key
+
+ # Pass a placeholder API key to OpenAI client - we'll override the auth header in httpx
+ # The actual authentication happens via the Api-Key header in the httpx client
+ super().__init__("placeholder-not-used", **kwargs)
+
+ # Cache for deployment-specific clients to avoid recreating them on each request
+ self._deployment_clients = {}
+ # Lock to ensure thread-safe client creation
+ self._client_lock = threading.Lock()
+
+ # Create a SINGLE shared httpx client for the provider instance
+ import httpx
+
+ # Create custom event hooks to remove Authorization header
+ def remove_auth_header(request):
+ """Remove Authorization header that OpenAI client adds."""
+ # httpx headers are case-insensitive, so we need to check all variations
+ headers_to_remove = []
+ for header_name in request.headers:
+ if header_name.lower() == "authorization":
+ headers_to_remove.append(header_name)
+
+ for header_name in headers_to_remove:
+ del request.headers[header_name]
+
+ self._http_client = httpx.Client(
+ timeout=self.timeout_config,
+ verify=True,
+ follow_redirects=True,
+ headers=self.DEFAULT_HEADERS.copy(), # Include DIAL headers including Api-Key
+ limits=httpx.Limits(
+ max_keepalive_connections=5,
+ max_connections=10,
+ keepalive_expiry=30.0,
+ ),
+ event_hooks={"request": [remove_auth_header]},
+ )
+
+ logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
+
+ def get_capabilities(self, model_name: str) -> ModelCapabilities:
+ """Get capabilities for a specific model.
+
+ Args:
+ model_name: Name of the model (can be shorthand)
+
+ Returns:
+ ModelCapabilities object
+
+ Raises:
+ ValueError: If model is not supported or not allowed
+ """
+ resolved_name = self._resolve_model_name(model_name)
+
+ if resolved_name not in self.SUPPORTED_MODELS:
+ raise ValueError(f"Unsupported DIAL model: {model_name}")
+
+ # Check restrictions
+ from utils.model_restrictions import get_restriction_service
+
+ restriction_service = get_restriction_service()
+ if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
+ raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
+
+ config = self.SUPPORTED_MODELS[resolved_name]
+
+ return ModelCapabilities(
+ provider=ProviderType.DIAL,
+ model_name=resolved_name,
+ friendly_name=self.FRIENDLY_NAME,
+ context_window=config["context_window"],
+ supports_extended_thinking=config["supports_extended_thinking"],
+ supports_system_prompts=True,
+ supports_streaming=True,
+ supports_function_calling=True,
+ supports_images=config.get("supports_vision", False),
+ temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
+ )
+
+ def get_provider_type(self) -> ProviderType:
+ """Get the provider type."""
+ return ProviderType.DIAL
+
+ def validate_model_name(self, model_name: str) -> bool:
+ """Validate if the model name is supported.
+
+ Args:
+ model_name: Model name to validate
+
+ Returns:
+ True if model is supported and allowed, False otherwise
+ """
+ resolved_name = self._resolve_model_name(model_name)
+
+ if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+ return False
+
+ # Check against base class allowed_models if configured
+ if self.allowed_models is not None:
+ # Check both original and resolved names (case-insensitive)
+ if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
+ logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
+ return False
+
+ # Also check restrictions via ModelRestrictionService
+ from utils.model_restrictions import get_restriction_service
+
+ restriction_service = get_restriction_service()
+ if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
+ logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
+ return False
+
+ return True
+
+ def _resolve_model_name(self, model_name: str) -> str:
+ """Resolve model shorthand to full name.
+
+ Args:
+ model_name: Model name or shorthand
+
+ Returns:
+ Full model name
+ """
+ shorthand_value = self.SUPPORTED_MODELS.get(model_name)
+ if isinstance(shorthand_value, str):
+ return shorthand_value
+ return model_name
+
+ def _get_deployment_client(self, deployment: str):
+ """Get or create a cached client for a specific deployment.
+
+ This avoids recreating OpenAI clients on every request, improving performance.
+ Reuses the shared HTTP client for connection pooling.
+
+ Args:
+ deployment: The deployment/model name
+
+ Returns:
+ OpenAI client configured for the specific deployment
+ """
+ # Check if client already exists without locking for performance
+ if deployment in self._deployment_clients:
+ return self._deployment_clients[deployment]
+
+ # Use lock to ensure thread-safe client creation
+ with self._client_lock:
+ # Double-check pattern: check again inside the lock
+ if deployment not in self._deployment_clients:
+ from openai import OpenAI
+
+ # Build deployment-specific URL
+ base_url = str(self.client.base_url)
+ if base_url.endswith("/"):
+ base_url = base_url[:-1]
+
+ # Remove /openai suffix if present to reconstruct properly
+ if base_url.endswith("/openai"):
+ base_url = base_url[:-7]
+
+ deployment_url = f"{base_url}/openai/deployments/{deployment}"
+
+ # Create and cache the client, REUSING the shared http_client
+ # Use placeholder API key - Authorization header will be removed by http_client event hook
+ self._deployment_clients[deployment] = OpenAI(
+ api_key="placeholder-not-used",
+ base_url=deployment_url,
+ http_client=self._http_client, # Pass the shared client with Api-Key header
+ default_query={"api-version": self.api_version}, # Add api-version as query param
+ )
+
+ return self._deployment_clients[deployment]
+
+ def generate_content(
+ self,
+ prompt: str,
+ model_name: str,
+ system_prompt: Optional[str] = None,
+ temperature: float = 0.7,
+ max_output_tokens: Optional[int] = None,
+ images: Optional[list[str]] = None,
+ **kwargs,
+ ) -> ModelResponse:
+ """Generate content using DIAL's deployment-specific endpoint.
+
+ DIAL uses Azure OpenAI-style deployment endpoints:
+ /openai/deployments/{deployment}/chat/completions
+
+ Args:
+ prompt: User prompt
+ model_name: Model name or alias
+ system_prompt: Optional system prompt
+ temperature: Sampling temperature
+ max_output_tokens: Maximum tokens to generate
+ **kwargs: Additional provider-specific parameters
+
+ Returns:
+ ModelResponse with generated content and metadata
+ """
+ # Validate model name against allow-list
+ if not self.validate_model_name(model_name):
+ raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
+
+ # Validate parameters
+ self.validate_parameters(model_name, temperature)
+
+ # Prepare messages
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ # Build user message content
+ user_message_content = []
+ if prompt:
+ user_message_content.append({"type": "text", "text": prompt})
+
+ if images and self._supports_vision(model_name):
+ for img_path in images:
+ processed_image = self._process_image(img_path)
+ if processed_image:
+ user_message_content.append(processed_image)
+ elif images:
+ logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
+
+ # Add user message. If only text, content will be a string, otherwise a list.
+ if len(user_message_content) == 1 and user_message_content[0]["type"] == "text":
+ messages.append({"role": "user", "content": prompt})
+ else:
+ messages.append({"role": "user", "content": user_message_content})
+
+ # Resolve model name
+ resolved_model = self._resolve_model_name(model_name)
+
+ # Build completion parameters
+ completion_params = {
+ "model": resolved_model,
+ "messages": messages,
+ }
+
+ # Check model capabilities
+ try:
+ capabilities = self.get_capabilities(model_name)
+ supports_temperature = getattr(capabilities, "supports_temperature", True)
+ except Exception as e:
+ logger.debug(f"Failed to check temperature support for {model_name}: {e}")
+ supports_temperature = True
+
+ # Add temperature parameter if supported
+ if supports_temperature:
+ completion_params["temperature"] = temperature
+
+ # Add max tokens if specified and model supports it
+ if max_output_tokens and supports_temperature:
+ completion_params["max_tokens"] = max_output_tokens
+
+ # Add additional parameters
+ for key, value in kwargs.items():
+ if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
+ if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
+ continue
+ completion_params[key] = value
+
+ # DIAL-specific: Get cached client for deployment endpoint
+ deployment_client = self._get_deployment_client(resolved_model)
+
+ # Retry logic with progressive delays
+ last_exception = None
+
+ for attempt in range(self.MAX_RETRIES):
+ try:
+ # Generate completion using deployment-specific client
+ response = deployment_client.chat.completions.create(**completion_params)
+
+ # Extract content and usage
+ content = response.choices[0].message.content
+ usage = self._extract_usage(response)
+
+ return ModelResponse(
+ content=content,
+ usage=usage,
+ model_name=model_name,
+ friendly_name=self.FRIENDLY_NAME,
+ provider=self.get_provider_type(),
+ metadata={
+ "finish_reason": response.choices[0].finish_reason,
+ "model": response.model,
+ "id": response.id,
+ "created": response.created,
+ },
+ )
+
+ except Exception as e:
+ last_exception = e
+
+ # Check if this is a retryable error
+ is_retryable = self._is_error_retryable(e)
+
+ if not is_retryable:
+ # Non-retryable error, raise immediately
+ raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")
+
+ # If this isn't the last attempt and error is retryable, wait and retry
+ if attempt < self.MAX_RETRIES - 1:
+ delay = self.RETRY_DELAYS[attempt]
+ logger.info(
+ f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
+ )
+ time.sleep(delay)
+ continue
+
+ # All retries exhausted
+ raise ValueError(
+ f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
+ )
+
+ def _supports_vision(self, model_name: str) -> bool:
+ """Check if the model supports vision (image processing).
+
+ Args:
+ model_name: Model name to check
+
+ Returns:
+ True if model supports vision, False otherwise
+ """
+ resolved_name = self._resolve_model_name(model_name)
+
+ if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+ return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
+
+ # Fall back to parent implementation for unknown models
+ return super()._supports_vision(model_name)
+
+ def list_models(self, respect_restrictions: bool = True) -> list[str]:
+ """Return a list of model names supported by this provider.
+
+ Args:
+ respect_restrictions: Whether to apply provider-specific restriction logic.
+
+ Returns:
+ List of model names available from this provider
+ """
+ # Get all model keys (both full names and aliases)
+ all_models = list(self.SUPPORTED_MODELS.keys())
+
+ if not respect_restrictions:
+ return all_models
+
+ # Apply restrictions if configured
+ from utils.model_restrictions import get_restriction_service
+
+ restriction_service = get_restriction_service()
+
+ # Filter based on restrictions
+ allowed_models = []
+ for model in all_models:
+ resolved_name = self._resolve_model_name(model)
+ if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
+ allowed_models.append(model)
+
+ return allowed_models
+
+ def list_all_known_models(self) -> list[str]:
+ """Return all model names known by this provider, including alias targets.
+
+ This is used for validation purposes to ensure restriction policies
+ can validate against both aliases and their target model names.
+
+ Returns:
+ List of all model names and alias targets known by this provider
+ """
+ # Collect all unique model names (both aliases and targets)
+ all_models = set()
+
+ for key, value in self.SUPPORTED_MODELS.items():
+ # Add the key (could be alias or full name)
+ all_models.add(key)
+
+ # If it's an alias (string value), add the target too
+ if isinstance(value, str):
+ all_models.add(value)
+
+ return sorted(all_models)
+
+ def close(self):
+ """Clean up HTTP clients when provider is closed."""
+ logger.info("Closing DIAL provider HTTP clients...")
+
+ # Clear the deployment clients cache
+ # Note: We don't need to close individual OpenAI clients since they
+ # use the shared httpx.Client which we close separately
+ self._deployment_clients.clear()
+
+ # Close the shared HTTP client
+ if hasattr(self, "_http_client"):
+ try:
+ self._http_client.close()
+ logger.debug("Closed shared HTTP client")
+ except Exception as e:
+ logger.warning(f"Error closing shared HTTP client: {e}")
+
+ # Also close the client created by the superclass (OpenAICompatibleProvider)
+ # as it holds its own httpx.Client instance that is not used by DIAL's generate_content
+ if hasattr(self, "client") and self.client and hasattr(self.client, "close"):
+ try:
+ self.client.close()
+ logger.debug("Closed superclass's OpenAI client")
+ except Exception as e:
+ logger.warning(f"Error closing superclass's OpenAI client: {e}")
diff --git a/providers/registry.py b/providers/registry.py
index a5efcf0..baa9222 100644
--- a/providers/registry.py
+++ b/providers/registry.py
@@ -118,6 +118,7 @@ class ModelProviderRegistry:
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.XAI, # Direct X.AI GROK access
+ ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]
@@ -237,6 +238,7 @@ class ModelProviderRegistry:
ProviderType.XAI: "XAI_API_KEY",
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth
+ ProviderType.DIAL: "DIAL_API_KEY",
}
env_var = key_mapping.get(provider_type)
diff --git a/run-server.sh b/run-server.sh
index d2d0ebe..243f0e0 100755
--- a/run-server.sh
+++ b/run-server.sh
@@ -883,6 +883,7 @@ setup_env_file() {
"GEMINI_API_KEY:your_gemini_api_key_here"
"OPENAI_API_KEY:your_openai_api_key_here"
"XAI_API_KEY:your_xai_api_key_here"
+ "DIAL_API_KEY:your_dial_api_key_here"
"OPENROUTER_API_KEY:your_openrouter_api_key_here"
)
@@ -934,6 +935,7 @@ validate_api_keys() {
"GEMINI_API_KEY:your_gemini_api_key_here"
"OPENAI_API_KEY:your_openai_api_key_here"
"XAI_API_KEY:your_xai_api_key_here"
+ "DIAL_API_KEY:your_dial_api_key_here"
"OPENROUTER_API_KEY:your_openrouter_api_key_here"
)
@@ -961,6 +963,7 @@ validate_api_keys() {
echo " GEMINI_API_KEY=your-actual-key" >&2
echo " OPENAI_API_KEY=your-actual-key" >&2
echo " XAI_API_KEY=your-actual-key" >&2
+ echo " DIAL_API_KEY=your-actual-key" >&2
echo " OPENROUTER_API_KEY=your-actual-key" >&2
echo "" >&2
print_info "After adding your API keys, run ./run-server.sh again" >&2
diff --git a/server.py b/server.py
index 1b0f969..19904fb 100644
--- a/server.py
+++ b/server.py
@@ -19,6 +19,7 @@ as defined by the MCP protocol.
"""
import asyncio
+import atexit
import logging
import os
import sys
@@ -271,6 +272,7 @@ def configure_providers():
from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.custom import CustomProvider
+ from providers.dial import DIALModelProvider
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.openrouter import OpenRouterProvider
@@ -303,6 +305,13 @@ def configure_providers():
has_native_apis = True
logger.info("X.AI API key found - GROK models available")
+ # Check for DIAL API key
+ dial_key = os.getenv("DIAL_API_KEY")
+ if dial_key and dial_key != "your_dial_api_key_here":
+ valid_providers.append("DIAL")
+ has_native_apis = True
+ logger.info("DIAL API key found - DIAL models available")
+
# Check for OpenRouter API key
openrouter_key = os.getenv("OPENROUTER_API_KEY")
if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
@@ -336,6 +345,8 @@ def configure_providers():
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
if xai_key and xai_key != "your_xai_api_key_here":
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+ if dial_key and dial_key != "your_dial_api_key_here":
+ ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)
# 2. Custom provider second (for local/private models)
if has_custom:
@@ -358,6 +369,7 @@ def configure_providers():
"- GEMINI_API_KEY for Gemini models\n"
"- OPENAI_API_KEY for OpenAI o3 model\n"
"- XAI_API_KEY for X.AI GROK models\n"
+ "- DIAL_API_KEY for DIAL models\n"
"- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
"- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)"
)
@@ -376,6 +388,25 @@ def configure_providers():
if len(priority_info) > 1:
logger.info(f"Provider priority: {' → '.join(priority_info)}")
+ # Register cleanup function for providers
+ def cleanup_providers():
+ """Clean up all registered providers on shutdown."""
+ try:
+ registry = ModelProviderRegistry()
+ if hasattr(registry, "_initialized_providers"):
+ for provider in list(registry._initialized_providers.values()):
+ try:
+ if provider and hasattr(provider, "close"):
+ provider.close()
+ except Exception:
+ # Logger might be closed during shutdown
+ pass
+ except Exception:
+ # Silently ignore any errors during cleanup
+ pass
+
+ atexit.register(cleanup_providers)
+
# Check and log model restrictions
restriction_service = get_restriction_service()
restrictions = restriction_service.get_restriction_summary()
@@ -390,7 +421,8 @@ def configure_providers():
# Validate restrictions against known models
provider_instances = {}
- for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI]:
+ provider_types_to_validate = [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]
+ for provider_type in provider_types_to_validate:
provider = ModelProviderRegistry.get_provider(provider_type)
if provider:
provider_instances[provider_type] = provider
diff --git a/tests/test_dial_provider.py b/tests/test_dial_provider.py
new file mode 100644
index 0000000..4a22cb6
--- /dev/null
+++ b/tests/test_dial_provider.py
@@ -0,0 +1,273 @@
+"""Tests for DIAL provider implementation."""
+
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from providers.base import ProviderType
+from providers.dial import DIALModelProvider
+
+
+class TestDIALProvider:
+ """Test DIAL provider functionality."""
+
+ @patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": "https://test.dialx.ai"})
+ def test_initialization_with_host(self):
+ """Test provider initialization with custom host."""
+ provider = DIALModelProvider("test-key")
+ assert provider._dial_api_key == "test-key" # Check internal API key storage
+ assert provider.api_key == "placeholder-not-used" # OpenAI client uses placeholder, auth header removed by hook
+ assert provider.base_url == "https://test.dialx.ai/openai"
+ assert provider.get_provider_type() == ProviderType.DIAL
+
+ @patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": ""}, clear=True)
+ def test_initialization_default_host(self):
+ """Test provider initialization with default host."""
+ provider = DIALModelProvider("test-key")
+ assert provider._dial_api_key == "test-key" # Check internal API key storage
+ assert provider.api_key == "placeholder-not-used" # OpenAI client uses placeholder, auth header removed by hook
+ assert provider.base_url == "https://core.dialx.ai/openai"
+
+ def test_initialization_host_normalization(self):
+ """Test that host URL is normalized to include /openai suffix."""
+ # Test with host missing /openai
+ provider = DIALModelProvider("test-key", base_url="https://custom.dialx.ai")
+ assert provider.base_url == "https://custom.dialx.ai/openai"
+
+ # Test with host already having /openai
+ provider = DIALModelProvider("test-key", base_url="https://custom.dialx.ai/openai")
+ assert provider.base_url == "https://custom.dialx.ai/openai"
+
+ @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
+ @patch("utils.model_restrictions._restriction_service", None)
+ def test_model_validation(self):
+ """Test model name validation."""
+ provider = DIALModelProvider("test-key")
+
+ # Test valid models
+ assert provider.validate_model_name("o3-2025-04-16") is True
+ assert provider.validate_model_name("o3") is True # Shorthand
+ assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True
+ assert provider.validate_model_name("opus-4") is True # Shorthand
+ assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is True
+ assert provider.validate_model_name("gemini-2.5-pro") is True # Shorthand
+
+ # Test invalid model
+ assert provider.validate_model_name("invalid-model") is False
+
+ def test_resolve_model_name(self):
+ """Test model name resolution for shorthands."""
+ provider = DIALModelProvider("test-key")
+
+ # Test shorthand resolution
+ assert provider._resolve_model_name("o3") == "o3-2025-04-16"
+ assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
+ assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"
+ assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
+ assert provider._resolve_model_name("gemini-2.5-pro") == "gemini-2.5-pro-preview-05-06"
+ assert provider._resolve_model_name("gemini-2.5-flash") == "gemini-2.5-flash-preview-05-20"
+
+ # Test full name passthrough
+ assert provider._resolve_model_name("o3-2025-04-16") == "o3-2025-04-16"
+ assert (
+ provider._resolve_model_name("anthropic.claude-opus-4-20250514-v1:0")
+ == "anthropic.claude-opus-4-20250514-v1:0"
+ )
+
+ @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
+ @patch("utils.model_restrictions._restriction_service", None)
+ def test_get_capabilities(self):
+ """Test getting model capabilities."""
+ provider = DIALModelProvider("test-key")
+
+ # Test O3 capabilities
+ capabilities = provider.get_capabilities("o3")
+ assert capabilities.model_name == "o3-2025-04-16"
+ assert capabilities.friendly_name == "DIAL"
+ assert capabilities.context_window == 200_000
+ assert capabilities.provider == ProviderType.DIAL
+ assert capabilities.supports_images is True
+ assert capabilities.supports_extended_thinking is False
+
+ # Test Claude 4 capabilities
+ capabilities = provider.get_capabilities("opus-4")
+ assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0"
+ assert capabilities.context_window == 200_000
+ assert capabilities.supports_images is True
+ assert capabilities.supports_extended_thinking is False
+
+ # Test Claude 4 with thinking mode
+ capabilities = provider.get_capabilities("opus-4-thinking")
+ assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0-with-thinking"
+ assert capabilities.context_window == 200_000
+ assert capabilities.supports_images is True
+ assert capabilities.supports_extended_thinking is True
+
+ # Test Gemini capabilities
+ capabilities = provider.get_capabilities("gemini-2.5-pro")
+ assert capabilities.model_name == "gemini-2.5-pro-preview-05-06"
+ assert capabilities.context_window == 1_000_000
+ assert capabilities.supports_images is True
+
+ # Test temperature constraint
+ assert capabilities.temperature_constraint.min_temp == 0.0
+ assert capabilities.temperature_constraint.max_temp == 2.0
+ assert capabilities.temperature_constraint.default_temp == 0.7
+
+ @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
+ @patch("utils.model_restrictions._restriction_service", None)
+ def test_get_capabilities_invalid_model(self):
+ """Test that get_capabilities raises for invalid models."""
+ provider = DIALModelProvider("test-key")
+
+ with pytest.raises(ValueError, match="Unsupported DIAL model"):
+ provider.get_capabilities("invalid-model")
+
+ @patch("utils.model_restrictions.get_restriction_service")
+ def test_get_capabilities_restricted_model(self, mock_get_restriction):
+ """Test that get_capabilities respects model restrictions."""
+ provider = DIALModelProvider("test-key")
+
+ # Mock restriction service to block the model
+ mock_service = MagicMock()
+ mock_service.is_allowed.return_value = False
+ mock_get_restriction.return_value = mock_service
+
+ with pytest.raises(ValueError, match="not allowed by restriction policy"):
+ provider.get_capabilities("o3")
+
+ @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
+ @patch("utils.model_restrictions._restriction_service", None)
+ def test_supports_vision(self):
+ """Test vision support detection."""
+ provider = DIALModelProvider("test-key")
+
+ # Test models with vision support
+ assert provider._supports_vision("o3-2025-04-16") is True
+ assert provider._supports_vision("o3") is True # Via resolution
+ assert provider._supports_vision("anthropic.claude-opus-4-20250514-v1:0") is True
+ assert provider._supports_vision("gemini-2.5-pro-preview-05-06") is True
+
+ # Test unknown model (falls back to parent implementation)
+ assert provider._supports_vision("unknown-model") is False
+
+ @patch("openai.OpenAI") # Mock the OpenAI class directly from openai module
+ def test_generate_content_with_alias(self, mock_openai_class):
+ """Test that generate_content properly resolves aliases and uses deployment routing."""
+ # Create mock client
+ mock_client = MagicMock()
+ mock_response = MagicMock()
+ mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
+ mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
+ mock_response.model = "gpt-4"
+ mock_response.id = "test-id"
+ mock_response.created = 1234567890
+ mock_response.choices[0].finish_reason = "stop"
+
+ mock_client.chat.completions.create.return_value = mock_response
+ mock_openai_class.return_value = mock_client
+
+ provider = DIALModelProvider("test-key")
+
+ # Generate content with shorthand
+ response = provider.generate_content(prompt="Test prompt", model_name="o3", temperature=0.7) # Shorthand
+
+ # Verify OpenAI was instantiated with deployment-specific URL
+ mock_openai_class.assert_called_once()
+ call_args = mock_openai_class.call_args
+ assert "/deployments/o3-2025-04-16" in call_args[1]["base_url"]
+
+ # Verify the resolved model name was passed to the API
+ mock_client.chat.completions.create.assert_called_once()
+ create_call_args = mock_client.chat.completions.create.call_args
+ assert create_call_args[1]["model"] == "o3-2025-04-16" # Resolved name
+
+ # Verify response
+ assert response.content == "Test response"
+ assert response.model_name == "o3" # Original name preserved
+ assert response.metadata["model"] == "gpt-4" # API returned model name from mock
+
+ def test_provider_type(self):
+ """Test provider type identification."""
+ provider = DIALModelProvider("test-key")
+ assert provider.get_provider_type() == ProviderType.DIAL
+
+ def test_friendly_name(self):
+ """Test provider friendly name."""
+ provider = DIALModelProvider("test-key")
+ assert provider.FRIENDLY_NAME == "DIAL"
+
+ @patch.dict(os.environ, {"DIAL_API_VERSION": "2024-12-01"})
+ def test_configurable_api_version(self):
+ """Test that API version can be configured via environment variable."""
+ provider = DIALModelProvider("test-key")
+ # Check that the custom API version is stored
+ assert provider.api_version == "2024-12-01"
+
+ def test_default_api_version(self):
+ """Test that default API version is used when not configured."""
+ # Clear any existing DIAL_API_VERSION from environment
+ with patch.dict(os.environ, {}, clear=True):
+ # Keep other env vars but ensure DIAL_API_VERSION is not set
+ if "DIAL_API_VERSION" in os.environ:
+ del os.environ["DIAL_API_VERSION"]
+
+ provider = DIALModelProvider("test-key")
+ # Check that the default API version is used
+ assert provider.api_version == "2024-12-01-preview"
+ # Check that Api-Key header is set
+ assert provider.DEFAULT_HEADERS["Api-Key"] == "test-key"
+
+ @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": "o3-2025-04-16,anthropic.claude-opus-4-20250514-v1:0"})
+ @patch("utils.model_restrictions._restriction_service", None)
+ def test_allowed_models_restriction(self):
+ """Test model allow-list functionality."""
+ provider = DIALModelProvider("test-key")
+
+ # These should be allowed
+ assert provider.validate_model_name("o3-2025-04-16") is True
+ assert provider.validate_model_name("o3") is True # Alias for o3-2025-04-16
+ assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True
+ assert provider.validate_model_name("opus-4") is True # Resolves to anthropic.claude-opus-4-20250514-v1:0
+
+ # These should be blocked
+ assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is False
+ assert provider.validate_model_name("o4-mini-2025-04-16") is False
+ assert provider.validate_model_name("sonnet-4") is False # sonnet-4 is not in allowed list
+
+ @patch("httpx.Client")
+ @patch("openai.OpenAI")
+ def test_close_method(self, mock_openai_class, mock_httpx_client_class):
+ """Test that the close method properly closes HTTP clients."""
+ # Mock the httpx.Client instance that DIALModelProvider will create
+ mock_shared_http_client = MagicMock()
+ mock_httpx_client_class.return_value = mock_shared_http_client
+
+ # Mock the OpenAI client instances
+ mock_openai_client_1 = MagicMock()
+ mock_openai_client_2 = MagicMock()
+ # Configure side_effect to return different mocks for subsequent calls
+ mock_openai_class.side_effect = [mock_openai_client_1, mock_openai_client_2]
+
+ provider = DIALModelProvider("test-key")
+
+ # Mock the superclass's _client attribute directly
+ mock_superclass_client = MagicMock()
+ provider._client = mock_superclass_client
+
+ # Simulate getting clients for two different deployments to populate _deployment_clients
+ provider._get_deployment_client("model_a")
+ provider._get_deployment_client("model_b")
+
+ # Now call close
+ provider.close()
+
+ # Assert that the shared httpx client's close method was called
+ mock_shared_http_client.close.assert_called_once()
+
+ # Assert that the superclass client's close method was called
+ mock_superclass_client.close.assert_called_once()
+
+ # Assert that the deployment clients cache is cleared
+ assert not provider._deployment_clients
diff --git a/tools/listmodels.py b/tools/listmodels.py
index 6a623b9..265fbcc 100644
--- a/tools/listmodels.py
+++ b/tools/listmodels.py
@@ -84,6 +84,7 @@ class ListModelsTool(BaseTool):
ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
ProviderType.OPENAI: {"name": "OpenAI", "env_key": "OPENAI_API_KEY"},
ProviderType.XAI: {"name": "X.AI (Grok)", "env_key": "XAI_API_KEY"},
+ ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
}
# Check each native provider type
diff --git a/utils/model_restrictions.py b/utils/model_restrictions.py
index 0b7ff25..834c0a2 100644
--- a/utils/model_restrictions.py
+++ b/utils/model_restrictions.py
@@ -11,6 +11,7 @@ Environment Variables:
- GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
- XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
- OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models
+- DIAL_ALLOWED_MODELS: Comma-separated list of allowed DIAL models
Example:
OPENAI_ALLOWED_MODELS=o3-mini,o4-mini
@@ -44,6 +45,7 @@ class ModelRestrictionService:
ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
ProviderType.XAI: "XAI_ALLOWED_MODELS",
ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
+ ProviderType.DIAL: "DIAL_ALLOWED_MODELS",
}
def __init__(self):