feat: DIAL provider implementation (#112)
## Description

This PR implements a new [DIAL](https://dialx.ai/dial_api) (Data & AI Layer) provider for the Zen MCP Server, enabling unified access to multiple AI models through the DIAL API platform. DIAL provides enterprise-grade AI model access with deployment-specific routing similar to Azure OpenAI.

## Changes Made

- [x] Added support for `atexit`:
  - Ensures automatic cleanup of provider resources (HTTP clients, connection pools) on server shutdown
  - Fixed a bug by using `ModelProviderRegistry.get_available_providers()` instead of accessing the private `_providers` attribute
  - Works with SIGTERM/Ctrl+C for graceful shutdown in both development and containerized environments
- [x] Added new DIAL provider (`providers/dial.py`) inheriting from `OpenAICompatibleProvider`
- [x] Updated `server.py` to register the DIAL provider during initialization
- [x] Updated the provider registry to include the DIAL provider type
- [x] Implemented deployment-specific routing for DIAL's Azure OpenAI-style endpoints (see the request sketch below)
- [x] Implemented performance optimizations:
  - Connection pooling with `httpx` for better performance
  - Thread-safe client caching with the double-check locking pattern
  - Proper resource cleanup via a `close()` method
- [x] Added comprehensive unit tests with 16 test cases (`tests/test_dial_provider.py`)
- [x] Added DIAL configuration to `.env.example` with documentation
- [x] Added support for a configurable API version via the `DIAL_API_VERSION` environment variable
- [x] Added DIAL model restrictions support via the `DIAL_ALLOWED_MODELS` environment variable

### Supported DIAL Models:

- OpenAI models: o3, o4-mini (and their dated versions)
- Google models: gemini-2.5-pro, gemini-2.5-flash (including the search variant)
- Anthropic models: Claude 4 Opus/Sonnet (with and without thinking mode)

### Environment Variables:

- `DIAL_API_KEY`: Required API key for DIAL authentication
- `DIAL_API_HOST`: Optional base URL (defaults to https://core.dialx.ai)
- `DIAL_API_VERSION`: Optional API version header (defaults to 2025-01-01-preview)
- `DIAL_ALLOWED_MODELS`: Optional comma-separated list of allowed models

### Breaking Changes:

- None

### Dependencies:

- No new dependencies added (uses the existing OpenAI SDK with custom routing)
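For reviewers unfamiliar with DIAL's Azure-style routing: here is a minimal sketch of the request shape the new provider produces, assuming a reachable DIAL endpoint and a valid key. The host, model, key, and api-version values are placeholders taken from the defaults documented in this PR; the URL and header scheme mirror what `providers/dial.py` constructs.

```python
import httpx

# Sketch only: DIAL routes per deployment, Azure OpenAI style, and
# authenticates with an "Api-Key" header rather than "Authorization: Bearer".
host = "https://core.dialx.ai"          # DIAL_API_HOST default
deployment = "o3-2025-04-16"            # a resolved model/deployment name
url = f"{host}/openai/deployments/{deployment}/chat/completions"

response = httpx.post(
    url,
    headers={"Api-Key": "your_dial_api_key_here"},      # placeholder key
    params={"api-version": "2024-12-01-preview"},       # example DIAL_API_VERSION
    json={
        "model": deployment,
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.7,
    },
)
print(response.json()["choices"][0]["message"]["content"])
```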
Parent: 4ae0344b14 · Commit: 0623ce3546

**.env.example** (45 changed lines)
```diff
@@ -3,8 +3,11 @@
 # API Keys - At least one is required
 #
-# IMPORTANT: Use EITHER OpenRouter OR native APIs (Gemini/OpenAI), not both!
-# Having both creates ambiguity about which provider serves each model.
+# IMPORTANT: Choose ONE approach:
+# - Native APIs (Gemini/OpenAI/XAI) for direct access
+# - DIAL for unified enterprise access
+# - OpenRouter for unified cloud access
+# Having multiple unified providers creates ambiguity about which serves each model.
 #
 # Option 1: Use native APIs (recommended for direct access)
 # Get your Gemini API key from: https://makersuite.google.com/app/apikey
@@ -16,6 +19,12 @@ OPENAI_API_KEY=your_openai_api_key_here
 # Get your X.AI API key from: https://console.x.ai/
 XAI_API_KEY=your_xai_api_key_here
 
+# Get your DIAL API key and configure host URL
+# DIAL provides unified access to multiple AI models through a single API
+DIAL_API_KEY=your_dial_api_key_here
+# DIAL_API_HOST=https://core.dialx.ai # Optional: Base URL without /openai suffix (auto-appended)
+# DIAL_API_VERSION=2025-01-01-preview # Optional: API version header for DIAL requests
+
 # Option 2: Use OpenRouter for access to multiple models through one API
 # Get your OpenRouter API key from: https://openrouter.ai/
 # If using OpenRouter, comment out the native API keys above
@@ -27,7 +36,8 @@ OPENROUTER_API_KEY=your_openrouter_api_key_here
 # CUSTOM_MODEL_NAME=llama3.2 # Default model name
 
 # Optional: Default model to use
-# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high' etc
+# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high',
+# 'grok', 'opus-4', 'sonnet-4', or any DIAL model if DIAL is configured
 # When set to 'auto', Claude will select the best model for each task
 # Defaults to 'auto' if not specified
 DEFAULT_MODEL=auto
@@ -70,6 +80,26 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
 # - grok3 (shorthand for grok-3)
 # - grokfast (shorthand for grok-3-fast)
 #
+# Supported DIAL models (when available in your DIAL deployment):
+# - o3-2025-04-16 (200K context, latest O3 release)
+# - o4-mini-2025-04-16 (200K context, latest O4 mini)
+# - o3 (shorthand for o3-2025-04-16)
+# - o4-mini (shorthand for o4-mini-2025-04-16)
+# - anthropic.claude-sonnet-4-20250514-v1:0 (200K context, Claude 4 Sonnet)
+# - anthropic.claude-sonnet-4-20250514-v1:0-with-thinking (200K context, Claude 4 Sonnet with thinking mode)
+# - anthropic.claude-opus-4-20250514-v1:0 (200K context, Claude 4 Opus)
+# - anthropic.claude-opus-4-20250514-v1:0-with-thinking (200K context, Claude 4 Opus with thinking mode)
+# - sonnet-4 (shorthand for Claude 4 Sonnet)
+# - sonnet-4-thinking (shorthand for Claude 4 Sonnet with thinking)
+# - opus-4 (shorthand for Claude 4 Opus)
+# - opus-4-thinking (shorthand for Claude 4 Opus with thinking)
+# - gemini-2.5-pro-preview-03-25-google-search (1M context, with Google Search)
+# - gemini-2.5-pro-preview-05-06 (1M context, latest preview)
+# - gemini-2.5-flash-preview-05-20 (1M context, latest flash preview)
+# - gemini-2.5-pro (shorthand for gemini-2.5-pro-preview-05-06)
+# - gemini-2.5-pro-search (shorthand for gemini-2.5-pro-preview-03-25-google-search)
+# - gemini-2.5-flash (shorthand for gemini-2.5-flash-preview-05-20)
+#
 # Examples:
 # OPENAI_ALLOWED_MODELS=o3-mini,o4-mini,mini # Only allow mini models (cost control)
 # GOOGLE_ALLOWED_MODELS=flash # Only allow Flash (fast responses)
@@ -77,21 +107,26 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
 # OPENAI_ALLOWED_MODELS=o4-mini # Single model standardization
 # GOOGLE_ALLOWED_MODELS=flash,pro # Allow both Gemini models
 # XAI_ALLOWED_MODELS=grok,grok-3-fast # Allow both GROK variants
+# DIAL_ALLOWED_MODELS=o3,o4-mini # Only allow O3/O4 models via DIAL
+# DIAL_ALLOWED_MODELS=opus-4,sonnet-4 # Only Claude 4 models (without thinking)
+# DIAL_ALLOWED_MODELS=opus-4-thinking,sonnet-4-thinking # Only Claude 4 with thinking mode
+# DIAL_ALLOWED_MODELS=gemini-2.5-pro,gemini-2.5-flash # Only Gemini 2.5 models via DIAL
 #
 # Note: These restrictions apply even in 'auto' mode - Claude will only pick from allowed models
 # OPENAI_ALLOWED_MODELS=
 # GOOGLE_ALLOWED_MODELS=
 # XAI_ALLOWED_MODELS=
+# DIAL_ALLOWED_MODELS=
 
 # Optional: Custom model configuration file path
 # Override the default location of custom_models.json
 # CUSTOM_MODELS_CONFIG_PATH=/path/to/your/custom_models.json
 
-# Note: Redis is no longer used - conversations are stored in memory
+# Note: Conversations are stored in memory during the session
 
 # Optional: Conversation timeout (hours)
 # How long AI-to-AI conversation threads persist before expiring
-# Longer timeouts use more Redis memory but allow resuming conversations later
+# Longer timeouts use more memory but allow resuming conversations later
 # Defaults to 3 hours if not specified
 CONVERSATION_TIMEOUT_HOURS=3
 
```
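As a rough illustration of how a comma-separated allow-list such as `DIAL_ALLOWED_MODELS` is interpreted (the helper below is hypothetical; the in-tree `ModelRestrictionService`, shown at the end of this diff, additionally checks aliases against their resolved names):

```python
import os

def parse_allowed(var: str) -> set[str] | None:
    """Return a lowercase set of allowed names, or None when unrestricted."""
    raw = os.getenv(var, "").strip()
    if not raw:
        return None  # empty or unset means no restriction
    return {name.strip().lower() for name in raw.split(",") if name.strip()}

allowed = parse_allowed("DIAL_ALLOWED_MODELS")
for candidate in ("o3", "opus-4", "gemini-2.5-flash"):
    ok = allowed is None or candidate.lower() in allowed
    print(f"{candidate}: {'allowed' if ok else 'blocked'}")
```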
**README.md** (14 changed lines)
````diff
@@ -3,7 +3,7 @@
 [zen_web.webm](https://github.com/user-attachments/assets/851e3911-7f06-47c0-a4ab-a2601236697c)
 
 <div align="center">
-<b>🤖 Claude + [Gemini / OpenAI / Grok / OpenRouter / Ollama / Any Model] = Your Ultimate AI Development Team</b>
+<b>🤖 Claude + [Gemini / OpenAI / Grok / OpenRouter / DIAL / Ollama / Any Model] = Your Ultimate AI Development Team</b>
 </div>
 
 <br/>
@@ -145,6 +145,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
 - **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models.
 - **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) to get an API key for O3 model access.
 - **X.AI**: Visit [X.AI Console](https://console.x.ai/) to get an API key for GROK model access.
+- **DIAL**: Visit [DIAL Platform](https://dialx.ai/) to get an API key for accessing multiple models through their unified API. DIAL is an open-source AI orchestration platform that provides vendor-agnostic access to models from major providers, open-source community, and self-hosted deployments. [API Documentation](https://dialx.ai/dial_api)
 
 **Option C: Custom API Endpoints (Local models like Ollama, vLLM)**
 [Please see the setup guide](docs/custom_models.md#option-2-custom-api-setup-ollama-vllm-etc). With a custom API you can use:
@@ -154,7 +155,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
 - **Text Generation WebUI**: Popular local interface for running models
 - **Any OpenAI-compatible API**: Custom endpoints for your own infrastructure
 
-> **Note:** Using all three options may create ambiguity about which provider / model to use if there is an overlap.
+> **Note:** Using multiple provider options may create ambiguity about which provider / model to use if there is an overlap.
 > If all APIs are configured, native APIs will take priority when there is a clash in model name, such as for `gemini` and `o3`.
 > Configure your model aliases and give them unique names in [`conf/custom_models.json`](conf/custom_models.json)
 
@@ -192,6 +193,12 @@ nano .env
 # GEMINI_API_KEY=your-gemini-api-key-here # For Gemini models
 # OPENAI_API_KEY=your-openai-api-key-here # For O3 model
 # OPENROUTER_API_KEY=your-openrouter-key # For OpenRouter (see docs/custom_models.md)
+# DIAL_API_KEY=your-dial-api-key-here # For DIAL platform
+
+# For DIAL (optional configuration):
+# DIAL_API_HOST=https://core.dialx.ai # Default DIAL host (optional)
+# DIAL_API_VERSION=2024-12-01-preview # API version (optional)
+# DIAL_ALLOWED_MODELS=o3,gemini-2.5-pro # Restrict to specific models (optional)
 
 # For local models (Ollama, vLLM, etc.):
 # CUSTOM_API_URL=http://localhost:11434/v1 # Ollama example
@@ -537,10 +544,11 @@ Configure the Zen MCP Server through environment variables in your `.env` file.
 DEFAULT_MODEL=auto
 GEMINI_API_KEY=your-gemini-key
 OPENAI_API_KEY=your-openai-key
+DIAL_API_KEY=your-dial-key # Optional: Access to multiple models via DIAL
 ```
 
 **Key Configuration Options:**
-- **API Keys**: Native APIs (Gemini, OpenAI, X.AI), OpenRouter, or Custom endpoints (Ollama, vLLM)
+- **API Keys**: Native APIs (Gemini, OpenAI, X.AI), OpenRouter, DIAL, or Custom endpoints (Ollama, vLLM)
 - **Model Selection**: Auto mode or specific model defaults
 - **Usage Restrictions**: Control which models can be used for cost control
 - **Conversation Settings**: Timeout, turn limits, memory configuration
````
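The optional `DIAL_API_HOST` documented above is normalized before use. A sketch of the rule applied in `providers/dial.py` (the sample hosts are placeholders):

```python
def normalize_dial_host(host: str) -> str:
    # DIAL serves its OpenAI-compatible API under /openai; append it at most once.
    host = host.rstrip("/")
    return host if host.endswith("/openai") else f"{host}/openai"

print(normalize_dial_host("https://core.dialx.ai"))          # https://core.dialx.ai/openai
print(normalize_dial_host("https://core.dialx.ai/openai/"))  # https://core.dialx.ai/openai
```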
**providers/base.py**

```diff
@@ -17,6 +17,7 @@ class ProviderType(Enum):
     XAI = "xai"
     OPENROUTER = "openrouter"
     CUSTOM = "custom"
+    DIAL = "dial"
 
 
 class TemperatureConstraint(ABC):
@@ -326,3 +327,12 @@ class ModelProvider(ABC):
             Resolved model name
         """
         return model_name
+
+    def close(self):
+        """Clean up any resources held by the provider.
+
+        Default implementation does nothing.
+        Subclasses should override if they hold resources that need cleanup.
+        """
+        # Base implementation: no resources to clean up
+        return
```
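A sketch of the contract this hook establishes (the subclass and its client attribute here are hypothetical; `DIALModelProvider.close()` below is the real in-tree override):

```python
import httpx

class PooledProvider:
    """Hypothetical provider that owns an httpx client and releases it on close()."""

    def __init__(self) -> None:
        self._http_client = httpx.Client()  # pooled connections, long-lived

    def close(self) -> None:
        # Override of the base no-op: release pooled connections explicitly.
        self._http_client.close()
```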
**providers/dial.py** (new file, 525 lines)
@@ -0,0 +1,525 @@

```python
"""DIAL (Data & AI Layer) model provider implementation."""

import logging
import os
import threading
import time
from typing import Optional

from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider

logger = logging.getLogger(__name__)


class DIALModelProvider(OpenAICompatibleProvider):
    """DIAL provider using OpenAI-compatible API.

    DIAL provides access to various AI models through a unified API interface.
    Supports GPT, Claude, Gemini, and other models via DIAL deployments.
    """

    FRIENDLY_NAME = "DIAL"

    # Retry configuration for API calls
    MAX_RETRIES = 4
    RETRY_DELAYS = [1, 3, 5, 8]  # seconds

    # Supported DIAL models (these can be customized based on your DIAL deployment)
    SUPPORTED_MODELS = {
        "o3-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "o4-mini-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-03-25-google-search": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,  # DIAL doesn't expose thinking mode
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-05-06": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "gemini-2.5-flash-preview-05-20": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        # Shorthands
        "o3": "o3-2025-04-16",
        "o4-mini": "o4-mini-2025-04-16",
        "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
        "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
        "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
        "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
        "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
        "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
        "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize DIAL provider with API key and host.

        Args:
            api_key: DIAL API key for authentication
            **kwargs: Additional configuration options
        """
        # Get DIAL API host from environment or kwargs
        dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"

        # DIAL uses /openai endpoint for OpenAI-compatible API
        if not dial_host.endswith("/openai"):
            dial_host = f"{dial_host.rstrip('/')}/openai"

        kwargs["base_url"] = dial_host

        # Get API version from environment or use default
        self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")

        # Add DIAL-specific headers
        # DIAL uses Api-Key header instead of Authorization: Bearer
        # Reference: https://dialx.ai/dial_api#section/Authorization
        self.DEFAULT_HEADERS = {
            "Api-Key": api_key,
        }

        # Store the actual API key for use in Api-Key header
        self._dial_api_key = api_key

        # Pass a placeholder API key to OpenAI client - we'll override the auth header in httpx
        # The actual authentication happens via the Api-Key header in the httpx client
        super().__init__("placeholder-not-used", **kwargs)

        # Cache for deployment-specific clients to avoid recreating them on each request
        self._deployment_clients = {}
        # Lock to ensure thread-safe client creation
        self._client_lock = threading.Lock()

        # Create a SINGLE shared httpx client for the provider instance
        import httpx

        # Create custom event hooks to remove Authorization header
        def remove_auth_header(request):
            """Remove Authorization header that OpenAI client adds."""
            # httpx headers are case-insensitive, so we need to check all variations
            headers_to_remove = []
            for header_name in request.headers:
                if header_name.lower() == "authorization":
                    headers_to_remove.append(header_name)

            for header_name in headers_to_remove:
                del request.headers[header_name]

        self._http_client = httpx.Client(
            timeout=self.timeout_config,
            verify=True,
            follow_redirects=True,
            headers=self.DEFAULT_HEADERS.copy(),  # Include DIAL headers including Api-Key
            limits=httpx.Limits(
                max_keepalive_connections=5,
                max_connections=10,
                keepalive_expiry=30.0,
            ),
            event_hooks={"request": [remove_auth_header]},
        )

        logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model.

        Args:
            model_name: Name of the model (can be shorthand)

        Returns:
            ModelCapabilities object

        Raises:
            ValueError: If model is not supported or not allowed
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported DIAL model: {model_name}")

        # Check restrictions
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        return ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name=resolved_name,
            friendly_name=self.FRIENDLY_NAME,
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_images=config.get("supports_vision", False),
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
        )

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.DIAL

    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported.

        Args:
            model_name: Model name to validate

        Returns:
            True if model is supported and allowed, False otherwise
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return False

        # Check against base class allowed_models if configured
        if self.allowed_models is not None:
            # Check both original and resolved names (case-insensitive)
            if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models:
                logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list")
                return False

        # Also check restrictions via ModelRestrictionService
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions")
            return False

        return True

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name.

        Args:
            model_name: Model name or shorthand

        Returns:
            Full model name
        """
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name

    def _get_deployment_client(self, deployment: str):
        """Get or create a cached client for a specific deployment.

        This avoids recreating OpenAI clients on every request, improving performance.
        Reuses the shared HTTP client for connection pooling.

        Args:
            deployment: The deployment/model name

        Returns:
            OpenAI client configured for the specific deployment
        """
        # Check if client already exists without locking for performance
        if deployment in self._deployment_clients:
            return self._deployment_clients[deployment]

        # Use lock to ensure thread-safe client creation
        with self._client_lock:
            # Double-check pattern: check again inside the lock
            if deployment not in self._deployment_clients:
                from openai import OpenAI

                # Build deployment-specific URL
                base_url = str(self.client.base_url)
                if base_url.endswith("/"):
                    base_url = base_url[:-1]

                # Remove /openai suffix if present to reconstruct properly
                if base_url.endswith("/openai"):
                    base_url = base_url[:-7]

                deployment_url = f"{base_url}/openai/deployments/{deployment}"

                # Create and cache the client, REUSING the shared http_client
                # Use placeholder API key - Authorization header will be removed by http_client event hook
                self._deployment_clients[deployment] = OpenAI(
                    api_key="placeholder-not-used",
                    base_url=deployment_url,
                    http_client=self._http_client,  # Pass the shared client with Api-Key header
                    default_query={"api-version": self.api_version},  # Add api-version as query param
                )

        return self._deployment_clients[deployment]

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using DIAL's deployment-specific endpoint.

        DIAL uses Azure OpenAI-style deployment endpoints:
        /openai/deployments/{deployment}/chat/completions

        Args:
            prompt: User prompt
            model_name: Model name or alias
            system_prompt: Optional system prompt
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        # Validate model name against allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")

        # Validate parameters
        self.validate_parameters(model_name, temperature)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Build user message content
        user_message_content = []
        if prompt:
            user_message_content.append({"type": "text", "text": prompt})

        if images and self._supports_vision(model_name):
            for img_path in images:
                processed_image = self._process_image(img_path)
                if processed_image:
                    user_message_content.append(processed_image)
        elif images:
            logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")

        # Add user message. If only text, content will be a string, otherwise a list.
        if len(user_message_content) == 1 and user_message_content[0]["type"] == "text":
            messages.append({"role": "user", "content": prompt})
        else:
            messages.append({"role": "user", "content": user_message_content})

        # Resolve model name
        resolved_model = self._resolve_model_name(model_name)

        # Build completion parameters
        completion_params = {
            "model": resolved_model,
            "messages": messages,
        }

        # Check model capabilities
        try:
            capabilities = self.get_capabilities(model_name)
            supports_temperature = getattr(capabilities, "supports_temperature", True)
        except Exception as e:
            logger.debug(f"Failed to check temperature support for {model_name}: {e}")
            supports_temperature = True

        # Add temperature parameter if supported
        if supports_temperature:
            completion_params["temperature"] = temperature

        # Add max tokens if specified and model supports it
        if max_output_tokens and supports_temperature:
            completion_params["max_tokens"] = max_output_tokens

        # Add additional parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
                    continue
                completion_params[key] = value

        # DIAL-specific: Get cached client for deployment endpoint
        deployment_client = self._get_deployment_client(resolved_model)

        # Retry logic with progressive delays
        last_exception = None

        for attempt in range(self.MAX_RETRIES):
            try:
                # Generate completion using deployment-specific client
                response = deployment_client.chat.completions.create(**completion_params)

                # Extract content and usage
                content = response.choices[0].message.content
                usage = self._extract_usage(response)

                return ModelResponse(
                    content=content,
                    usage=usage,
                    model_name=model_name,
                    friendly_name=self.FRIENDLY_NAME,
                    provider=self.get_provider_type(),
                    metadata={
                        "finish_reason": response.choices[0].finish_reason,
                        "model": response.model,
                        "id": response.id,
                        "created": response.created,
                    },
                )

            except Exception as e:
                last_exception = e

                # Check if this is a retryable error
                is_retryable = self._is_error_retryable(e)

                if not is_retryable:
                    # Non-retryable error, raise immediately
                    raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")

                # If this isn't the last attempt and error is retryable, wait and retry
                if attempt < self.MAX_RETRIES - 1:
                    delay = self.RETRY_DELAYS[attempt]
                    logger.info(
                        f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
                    )
                    time.sleep(delay)
                    continue

        # All retries exhausted
        raise ValueError(
            f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
        )

    def _supports_vision(self, model_name: str) -> bool:
        """Check if the model supports vision (image processing).

        Args:
            model_name: Model name to check

        Returns:
            True if model supports vision, False otherwise
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)

        # Fall back to parent implementation for unknown models
        return super()._supports_vision(model_name)

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        # Get all model keys (both full names and aliases)
        all_models = list(self.SUPPORTED_MODELS.keys())

        if not respect_restrictions:
            return all_models

        # Apply restrictions if configured
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()

        # Filter based on restrictions
        allowed_models = []
        for model in all_models:
            resolved_name = self._resolve_model_name(model)
            if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
                allowed_models.append(model)

        return allowed_models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        This is used for validation purposes to ensure restriction policies
        can validate against both aliases and their target model names.

        Returns:
            List of all model names and alias targets known by this provider
        """
        # Collect all unique model names (both aliases and targets)
        all_models = set()

        for key, value in self.SUPPORTED_MODELS.items():
            # Add the key (could be alias or full name)
            all_models.add(key)

            # If it's an alias (string value), add the target too
            if isinstance(value, str):
                all_models.add(value)

        return sorted(all_models)

    def close(self):
        """Clean up HTTP clients when provider is closed."""
        logger.info("Closing DIAL provider HTTP clients...")

        # Clear the deployment clients cache
        # Note: We don't need to close individual OpenAI clients since they
        # use the shared httpx.Client which we close separately
        self._deployment_clients.clear()

        # Close the shared HTTP client
        if hasattr(self, "_http_client"):
            try:
                self._http_client.close()
                logger.debug("Closed shared HTTP client")
            except Exception as e:
                logger.warning(f"Error closing shared HTTP client: {e}")

        # Also close the client created by the superclass (OpenAICompatibleProvider)
        # as it holds its own httpx.Client instance that is not used by DIAL's generate_content
        if hasattr(self, "client") and self.client and hasattr(self.client, "close"):
            try:
                self.client.close()
                logger.debug("Closed superclass's OpenAI client")
            except Exception as e:
                logger.warning(f"Error closing superclass's OpenAI client: {e}")
```
```diff
@@ -118,6 +118,7 @@ class ModelProviderRegistry:
         ProviderType.GOOGLE,  # Direct Gemini access
         ProviderType.OPENAI,  # Direct OpenAI access
         ProviderType.XAI,  # Direct X.AI GROK access
+        ProviderType.DIAL,  # DIAL unified API access
         ProviderType.CUSTOM,  # Local/self-hosted models
         ProviderType.OPENROUTER,  # Catch-all for cloud models
     ]
@@ -237,6 +238,7 @@ class ModelProviderRegistry:
         ProviderType.XAI: "XAI_API_KEY",
         ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
         ProviderType.CUSTOM: "CUSTOM_API_KEY",  # Can be empty for providers that don't need auth
+        ProviderType.DIAL: "DIAL_API_KEY",
     }
 
     env_var = key_mapping.get(provider_type)
```
**run-server.sh**

```diff
@@ -883,6 +883,7 @@ setup_env_file() {
         "GEMINI_API_KEY:your_gemini_api_key_here"
         "OPENAI_API_KEY:your_openai_api_key_here"
         "XAI_API_KEY:your_xai_api_key_here"
+        "DIAL_API_KEY:your_dial_api_key_here"
         "OPENROUTER_API_KEY:your_openrouter_api_key_here"
     )
 
@@ -934,6 +935,7 @@ validate_api_keys() {
         "GEMINI_API_KEY:your_gemini_api_key_here"
         "OPENAI_API_KEY:your_openai_api_key_here"
         "XAI_API_KEY:your_xai_api_key_here"
+        "DIAL_API_KEY:your_dial_api_key_here"
         "OPENROUTER_API_KEY:your_openrouter_api_key_here"
     )
 
@@ -961,6 +963,7 @@ validate_api_keys() {
     echo "  GEMINI_API_KEY=your-actual-key" >&2
     echo "  OPENAI_API_KEY=your-actual-key" >&2
     echo "  XAI_API_KEY=your-actual-key" >&2
+    echo "  DIAL_API_KEY=your-actual-key" >&2
     echo "  OPENROUTER_API_KEY=your-actual-key" >&2
     echo "" >&2
     print_info "After adding your API keys, run ./run-server.sh again" >&2
```
**server.py** (34 changed lines)
```diff
@@ -19,6 +19,7 @@ as defined by the MCP protocol.
 """
 
 import asyncio
+import atexit
 import logging
 import os
 import sys
@@ -271,6 +272,7 @@ def configure_providers():
     from providers import ModelProviderRegistry
     from providers.base import ProviderType
     from providers.custom import CustomProvider
+    from providers.dial import DIALModelProvider
     from providers.gemini import GeminiModelProvider
     from providers.openai_provider import OpenAIModelProvider
     from providers.openrouter import OpenRouterProvider
@@ -303,6 +305,13 @@ def configure_providers():
         has_native_apis = True
         logger.info("X.AI API key found - GROK models available")
 
+    # Check for DIAL API key
+    dial_key = os.getenv("DIAL_API_KEY")
+    if dial_key and dial_key != "your_dial_api_key_here":
+        valid_providers.append("DIAL")
+        has_native_apis = True
+        logger.info("DIAL API key found - DIAL models available")
+
     # Check for OpenRouter API key
     openrouter_key = os.getenv("OPENROUTER_API_KEY")
     if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
@@ -336,6 +345,8 @@ def configure_providers():
         ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
     if xai_key and xai_key != "your_xai_api_key_here":
         ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+    if dial_key and dial_key != "your_dial_api_key_here":
+        ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)
 
     # 2. Custom provider second (for local/private models)
     if has_custom:
@@ -358,6 +369,7 @@ def configure_providers():
             "- GEMINI_API_KEY for Gemini models\n"
             "- OPENAI_API_KEY for OpenAI o3 model\n"
             "- XAI_API_KEY for X.AI GROK models\n"
+            "- DIAL_API_KEY for DIAL models\n"
             "- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
             "- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)"
         )
@@ -376,6 +388,25 @@ def configure_providers():
     if len(priority_info) > 1:
         logger.info(f"Provider priority: {' → '.join(priority_info)}")
 
+    # Register cleanup function for providers
+    def cleanup_providers():
+        """Clean up all registered providers on shutdown."""
+        try:
+            registry = ModelProviderRegistry()
+            if hasattr(registry, "_initialized_providers"):
+                for provider in list(registry._initialized_providers.values()):
+                    try:
+                        if provider and hasattr(provider, "close"):
+                            provider.close()
+                    except Exception:
+                        # Logger might be closed during shutdown
+                        pass
+        except Exception:
+            # Silently ignore any errors during cleanup
+            pass
+
+    atexit.register(cleanup_providers)
+
     # Check and log model restrictions
     restriction_service = get_restriction_service()
     restrictions = restriction_service.get_restriction_summary()
@@ -390,7 +421,8 @@ def configure_providers():
 
     # Validate restrictions against known models
     provider_instances = {}
-    for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI]:
+    provider_types_to_validate = [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]
+    for provider_type in provider_types_to_validate:
         provider = ModelProviderRegistry.get_provider(provider_type)
         if provider:
             provider_instances[provider_type] = provider
```
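The shutdown wiring above follows the standard `atexit` pattern; stripped to its essentials (the pooled client here is a hypothetical stand-in for a provider's resources):

```python
import atexit
import httpx

client = httpx.Client()  # long-lived pooled resource

def cleanup() -> None:
    # Runs on normal interpreter exit, including Ctrl+C once KeyboardInterrupt
    # unwinds normally; SIGTERM is covered only when the process translates it
    # into a normal shutdown.
    try:
        client.close()
    except Exception:
        pass  # never let cleanup raise during shutdown

atexit.register(cleanup)
```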
**tests/test_dial_provider.py** (new file, 273 lines)
@@ -0,0 +1,273 @@

```python
"""Tests for DIAL provider implementation."""

import os
from unittest.mock import MagicMock, patch

import pytest

from providers.base import ProviderType
from providers.dial import DIALModelProvider


class TestDIALProvider:
    """Test DIAL provider functionality."""

    @patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": "https://test.dialx.ai"})
    def test_initialization_with_host(self):
        """Test provider initialization with custom host."""
        provider = DIALModelProvider("test-key")
        assert provider._dial_api_key == "test-key"  # Check internal API key storage
        assert provider.api_key == "placeholder-not-used"  # OpenAI client uses placeholder, auth header removed by hook
        assert provider.base_url == "https://test.dialx.ai/openai"
        assert provider.get_provider_type() == ProviderType.DIAL

    @patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": ""}, clear=True)
    def test_initialization_default_host(self):
        """Test provider initialization with default host."""
        provider = DIALModelProvider("test-key")
        assert provider._dial_api_key == "test-key"  # Check internal API key storage
        assert provider.api_key == "placeholder-not-used"  # OpenAI client uses placeholder, auth header removed by hook
        assert provider.base_url == "https://core.dialx.ai/openai"

    def test_initialization_host_normalization(self):
        """Test that host URL is normalized to include /openai suffix."""
        # Test with host missing /openai
        provider = DIALModelProvider("test-key", base_url="https://custom.dialx.ai")
        assert provider.base_url == "https://custom.dialx.ai/openai"

        # Test with host already having /openai
        provider = DIALModelProvider("test-key", base_url="https://custom.dialx.ai/openai")
        assert provider.base_url == "https://custom.dialx.ai/openai"

    @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
    @patch("utils.model_restrictions._restriction_service", None)
    def test_model_validation(self):
        """Test model name validation."""
        provider = DIALModelProvider("test-key")

        # Test valid models
        assert provider.validate_model_name("o3-2025-04-16") is True
        assert provider.validate_model_name("o3") is True  # Shorthand
        assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True
        assert provider.validate_model_name("opus-4") is True  # Shorthand
        assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is True
        assert provider.validate_model_name("gemini-2.5-pro") is True  # Shorthand

        # Test invalid model
        assert provider.validate_model_name("invalid-model") is False

    def test_resolve_model_name(self):
        """Test model name resolution for shorthands."""
        provider = DIALModelProvider("test-key")

        # Test shorthand resolution
        assert provider._resolve_model_name("o3") == "o3-2025-04-16"
        assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
        assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"
        assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
        assert provider._resolve_model_name("gemini-2.5-pro") == "gemini-2.5-pro-preview-05-06"
        assert provider._resolve_model_name("gemini-2.5-flash") == "gemini-2.5-flash-preview-05-20"

        # Test full name passthrough
        assert provider._resolve_model_name("o3-2025-04-16") == "o3-2025-04-16"
        assert (
            provider._resolve_model_name("anthropic.claude-opus-4-20250514-v1:0")
            == "anthropic.claude-opus-4-20250514-v1:0"
        )

    @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
    @patch("utils.model_restrictions._restriction_service", None)
    def test_get_capabilities(self):
        """Test getting model capabilities."""
        provider = DIALModelProvider("test-key")

        # Test O3 capabilities
        capabilities = provider.get_capabilities("o3")
        assert capabilities.model_name == "o3-2025-04-16"
        assert capabilities.friendly_name == "DIAL"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.DIAL
        assert capabilities.supports_images is True
        assert capabilities.supports_extended_thinking is False

        # Test Claude 4 capabilities
        capabilities = provider.get_capabilities("opus-4")
        assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0"
        assert capabilities.context_window == 200_000
        assert capabilities.supports_images is True
        assert capabilities.supports_extended_thinking is False

        # Test Claude 4 with thinking mode
        capabilities = provider.get_capabilities("opus-4-thinking")
        assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0-with-thinking"
        assert capabilities.context_window == 200_000
        assert capabilities.supports_images is True
        assert capabilities.supports_extended_thinking is True

        # Test Gemini capabilities
        capabilities = provider.get_capabilities("gemini-2.5-pro")
        assert capabilities.model_name == "gemini-2.5-pro-preview-05-06"
        assert capabilities.context_window == 1_000_000
        assert capabilities.supports_images is True

        # Test temperature constraint
        assert capabilities.temperature_constraint.min_temp == 0.0
        assert capabilities.temperature_constraint.max_temp == 2.0
        assert capabilities.temperature_constraint.default_temp == 0.7

    @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
    @patch("utils.model_restrictions._restriction_service", None)
    def test_get_capabilities_invalid_model(self):
        """Test that get_capabilities raises for invalid models."""
        provider = DIALModelProvider("test-key")

        with pytest.raises(ValueError, match="Unsupported DIAL model"):
            provider.get_capabilities("invalid-model")

    @patch("utils.model_restrictions.get_restriction_service")
    def test_get_capabilities_restricted_model(self, mock_get_restriction):
        """Test that get_capabilities respects model restrictions."""
        provider = DIALModelProvider("test-key")

        # Mock restriction service to block the model
        mock_service = MagicMock()
        mock_service.is_allowed.return_value = False
        mock_get_restriction.return_value = mock_service

        with pytest.raises(ValueError, match="not allowed by restriction policy"):
            provider.get_capabilities("o3")

    @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
    @patch("utils.model_restrictions._restriction_service", None)
    def test_supports_vision(self):
        """Test vision support detection."""
        provider = DIALModelProvider("test-key")

        # Test models with vision support
        assert provider._supports_vision("o3-2025-04-16") is True
        assert provider._supports_vision("o3") is True  # Via resolution
        assert provider._supports_vision("anthropic.claude-opus-4-20250514-v1:0") is True
        assert provider._supports_vision("gemini-2.5-pro-preview-05-06") is True

        # Test unknown model (falls back to parent implementation)
        assert provider._supports_vision("unknown-model") is False

    @patch("openai.OpenAI")  # Mock the OpenAI class directly from openai module
    def test_generate_content_with_alias(self, mock_openai_class):
        """Test that generate_content properly resolves aliases and uses deployment routing."""
        # Create mock client
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
        mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
        mock_response.model = "gpt-4"
        mock_response.id = "test-id"
        mock_response.created = 1234567890
        mock_response.choices[0].finish_reason = "stop"

        mock_client.chat.completions.create.return_value = mock_response
        mock_openai_class.return_value = mock_client

        provider = DIALModelProvider("test-key")

        # Generate content with shorthand
        response = provider.generate_content(prompt="Test prompt", model_name="o3", temperature=0.7)  # Shorthand

        # Verify OpenAI was instantiated with deployment-specific URL
        mock_openai_class.assert_called_once()
        call_args = mock_openai_class.call_args
        assert "/deployments/o3-2025-04-16" in call_args[1]["base_url"]

        # Verify the resolved model name was passed to the API
        mock_client.chat.completions.create.assert_called_once()
        create_call_args = mock_client.chat.completions.create.call_args
        assert create_call_args[1]["model"] == "o3-2025-04-16"  # Resolved name

        # Verify response
        assert response.content == "Test response"
        assert response.model_name == "o3"  # Original name preserved
        assert response.metadata["model"] == "gpt-4"  # API returned model name from mock

    def test_provider_type(self):
        """Test provider type identification."""
        provider = DIALModelProvider("test-key")
        assert provider.get_provider_type() == ProviderType.DIAL

    def test_friendly_name(self):
        """Test provider friendly name."""
        provider = DIALModelProvider("test-key")
        assert provider.FRIENDLY_NAME == "DIAL"

    @patch.dict(os.environ, {"DIAL_API_VERSION": "2024-12-01"})
    def test_configurable_api_version(self):
        """Test that API version can be configured via environment variable."""
        provider = DIALModelProvider("test-key")
        # Check that the custom API version is stored
        assert provider.api_version == "2024-12-01"

    def test_default_api_version(self):
        """Test that default API version is used when not configured."""
        # Clear any existing DIAL_API_VERSION from environment
        with patch.dict(os.environ, {}, clear=True):
            # Keep other env vars but ensure DIAL_API_VERSION is not set
            if "DIAL_API_VERSION" in os.environ:
                del os.environ["DIAL_API_VERSION"]

            provider = DIALModelProvider("test-key")
            # Check that the default API version is used
            assert provider.api_version == "2024-12-01-preview"
            # Check that Api-Key header is set
            assert provider.DEFAULT_HEADERS["Api-Key"] == "test-key"

    @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": "o3-2025-04-16,anthropic.claude-opus-4-20250514-v1:0"})
    @patch("utils.model_restrictions._restriction_service", None)
    def test_allowed_models_restriction(self):
        """Test model allow-list functionality."""
        provider = DIALModelProvider("test-key")

        # These should be allowed
        assert provider.validate_model_name("o3-2025-04-16") is True
        assert provider.validate_model_name("o3") is True  # Alias for o3-2025-04-16
        assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True
        assert provider.validate_model_name("opus-4") is True  # Resolves to anthropic.claude-opus-4-20250514-v1:0

        # These should be blocked
        assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is False
        assert provider.validate_model_name("o4-mini-2025-04-16") is False
        assert provider.validate_model_name("sonnet-4") is False  # sonnet-4 is not in allowed list

    @patch("httpx.Client")
    @patch("openai.OpenAI")
    def test_close_method(self, mock_openai_class, mock_httpx_client_class):
        """Test that the close method properly closes HTTP clients."""
        # Mock the httpx.Client instance that DIALModelProvider will create
        mock_shared_http_client = MagicMock()
        mock_httpx_client_class.return_value = mock_shared_http_client

        # Mock the OpenAI client instances
        mock_openai_client_1 = MagicMock()
        mock_openai_client_2 = MagicMock()
        # Configure side_effect to return different mocks for subsequent calls
        mock_openai_class.side_effect = [mock_openai_client_1, mock_openai_client_2]

        provider = DIALModelProvider("test-key")

        # Mock the superclass's _client attribute directly
        mock_superclass_client = MagicMock()
        provider._client = mock_superclass_client

        # Simulate getting clients for two different deployments to populate _deployment_clients
        provider._get_deployment_client("model_a")
        provider._get_deployment_client("model_b")

        # Now call close
        provider.close()

        # Assert that the shared httpx client's close method was called
        mock_shared_http_client.close.assert_called_once()

        # Assert that the superclass client's close method was called
        mock_superclass_client.close.assert_called_once()

        # Assert that the deployment clients cache is cleared
        assert not provider._deployment_clients
```
```diff
@@ -84,6 +84,7 @@ class ListModelsTool(BaseTool):
             ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
             ProviderType.OPENAI: {"name": "OpenAI", "env_key": "OPENAI_API_KEY"},
             ProviderType.XAI: {"name": "X.AI (Grok)", "env_key": "XAI_API_KEY"},
+            ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
         }
 
         # Check each native provider type
```
**utils/model_restrictions.py**

```diff
@@ -11,6 +11,7 @@ Environment Variables:
 - GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
 - XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
 - OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models
+- DIAL_ALLOWED_MODELS: Comma-separated list of allowed DIAL models
 
 Example:
     OPENAI_ALLOWED_MODELS=o3-mini,o4-mini
@@ -44,6 +45,7 @@ class ModelRestrictionService:
         ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
         ProviderType.XAI: "XAI_ALLOWED_MODELS",
         ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
+        ProviderType.DIAL: "DIAL_ALLOWED_MODELS",
     }
 
     def __init__(self):
```
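As used by the DIAL provider above, the service gates both the user-facing alias and the resolved deployment name; a sketch of the call site (the model names are examples):

```python
from providers.base import ProviderType
from utils.model_restrictions import get_restriction_service

service = get_restriction_service()

# Both the alias and the resolved deployment name are checked, so either
# form may appear in DIAL_ALLOWED_MODELS.
if service.is_allowed(ProviderType.DIAL, "o3-2025-04-16", "o3"):
    print("o3 is allowed")
```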