From 0623ce3546fd744877b4fa09cf8943b4bb5c85dc Mon Sep 17 00:00:00 2001 From: Illya Havsiyevych <44289086+illya-havsiyevych@users.noreply.github.com> Date: Mon, 23 Jun 2025 13:07:10 +0300 Subject: [PATCH] feat: DIAL provider implementation (#112) ## Description This PR implements a new [DIAL](https://dialx.ai/dial_api) (Data & AI Layer) provider for the Zen MCP Server, enabling unified access to multiple AI models through the DIAL API platform. DIAL provides enterprise-grade AI model access with deployment-specific routing similar to Azure OpenAI. ## Changes Made - [x] Added support for atexit: - Ensures automatic cleanup of provider resources (HTTP clients, connection pools) on server shutdown - Fixed a bug by using ModelProviderRegistry.get_available_providers() instead of accessing the private _providers attribute - Works with SIGTERM/Ctrl+C for graceful shutdown in both development and containerized environments - [x] Added new DIAL provider (`providers/dial.py`) inheriting from `OpenAICompatibleProvider` - [x] Updated server.py to register DIAL provider during initialization - [x] Updated provider registry to include DIAL provider type - [x] Implemented deployment-specific routing for DIAL's Azure OpenAI-style endpoints - [x] Implemented performance optimizations: - Connection pooling with httpx for better performance - Thread-safe client caching with double-check locking pattern - Proper resource cleanup with `close()` method - [x] Added comprehensive unit tests with 16 test cases (`tests/test_dial_provider.py`) - [x] Added DIAL configuration to `.env.example` with documentation - [x] Added support for configurable API version via `DIAL_API_VERSION` environment variable - [x] Added DIAL model restrictions support via `DIAL_ALLOWED_MODELS` environment variable ### Supported DIAL Models: - OpenAI models: o3, o4-mini (and their dated versions) - Google models: gemini-2.5-pro, gemini-2.5-flash (including search variant) - Anthropic models: Claude 4 Opus/Sonnet (with and without thinking mode) ### Environment Variables: - `DIAL_API_KEY`: Required API key for DIAL authentication - `DIAL_API_HOST`: Optional base URL (defaults to https://core.dialx.ai) - `DIAL_API_VERSION`: Optional API version header (defaults to 2024-12-01-preview) - `DIAL_ALLOWED_MODELS`: Optional comma-separated list of allowed models ### Breaking Changes: - None ### Dependencies: - No new dependencies added (uses existing OpenAI SDK with custom routing) --- .env.example | 45 +++- README.md | 14 +- providers/base.py | 10 + providers/dial.py | 525 ++++++++++++++++++++++++++++++++++++ providers/registry.py | 2 + run-server.sh | 3 + server.py | 34 ++- tests/test_dial_provider.py | 273 +++++++++++++++++++ tools/listmodels.py | 1 + utils/model_restrictions.py | 2 + 10 files changed, 900 insertions(+), 9 deletions(-) create mode 100644 providers/dial.py create mode 100644 tests/test_dial_provider.py diff --git a/.env.example b/.env.example index a7e6376..1d88d4c 100644 --- a/.env.example +++ b/.env.example @@ -3,8 +3,11 @@ # API Keys - At least one is required # -# IMPORTANT: Use EITHER OpenRouter OR native APIs (Gemini/OpenAI), not both! -# Having both creates ambiguity about which provider serves each model. +# IMPORTANT: Choose ONE approach: +# - Native APIs (Gemini/OpenAI/XAI) for direct access +# - DIAL for unified enterprise access +# - OpenRouter for unified cloud access +# Having multiple unified providers creates ambiguity about which serves each model.
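+# For example (illustrative): to standardize on DIAL, set DIAL_API_KEY below and leave OPENROUTER_API_KEY unset, so a name like 'o3' resolves through a single provider.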
# # Option 1: Use native APIs (recommended for direct access) # Get your Gemini API key from: https://makersuite.google.com/app/apikey @@ -16,6 +19,12 @@ OPENAI_API_KEY=your_openai_api_key_here # Get your X.AI API key from: https://console.x.ai/ XAI_API_KEY=your_xai_api_key_here +# Get your DIAL API key and configure host URL +# DIAL provides unified access to multiple AI models through a single API +DIAL_API_KEY=your_dial_api_key_here +# DIAL_API_HOST=https://core.dialx.ai # Optional: Base URL without /openai suffix (auto-appended) +# DIAL_API_VERSION=2024-12-01-preview # Optional: API version header for DIAL requests + # Option 2: Use OpenRouter for access to multiple models through one API # Get your OpenRouter API key from: https://openrouter.ai/ # If using OpenRouter, comment out the native API keys above @@ -27,7 +36,8 @@ OPENROUTER_API_KEY=your_openrouter_api_key_here # CUSTOM_MODEL_NAME=llama3.2 # Default model name # Optional: Default model to use -# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high' etc +# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high', +# 'grok', 'opus-4', 'sonnet-4', or any DIAL model if DIAL is configured # When set to 'auto', Claude will select the best model for each task # Defaults to 'auto' if not specified DEFAULT_MODEL=auto @@ -70,6 +80,26 @@ DEFAULT_THINKING_MODE_THINKDEEP=high # - grok3 (shorthand for grok-3) # - grokfast (shorthand for grok-3-fast) # +# Supported DIAL models (when available in your DIAL deployment): +# - o3-2025-04-16 (200K context, latest O3 release) +# - o4-mini-2025-04-16 (200K context, latest O4 mini) +# - o3 (shorthand for o3-2025-04-16) +# - o4-mini (shorthand for o4-mini-2025-04-16) +# - anthropic.claude-sonnet-4-20250514-v1:0 (200K context, Claude 4 Sonnet) +# - anthropic.claude-sonnet-4-20250514-v1:0-with-thinking (200K context, Claude 4 Sonnet with thinking mode) +# - anthropic.claude-opus-4-20250514-v1:0 (200K context, Claude 4 Opus) +# - anthropic.claude-opus-4-20250514-v1:0-with-thinking (200K context, Claude 4 Opus with thinking mode) +# - sonnet-4 (shorthand for Claude 4 Sonnet) +# - sonnet-4-thinking (shorthand for Claude 4 Sonnet with thinking) +# - opus-4 (shorthand for Claude 4 Opus) +# - opus-4-thinking (shorthand for Claude 4 Opus with thinking) +# - gemini-2.5-pro-preview-03-25-google-search (1M context, with Google Search) +# - gemini-2.5-pro-preview-05-06 (1M context, latest preview) +# - gemini-2.5-flash-preview-05-20 (1M context, latest flash preview) +# - gemini-2.5-pro (shorthand for gemini-2.5-pro-preview-05-06) +# - gemini-2.5-pro-search (shorthand for gemini-2.5-pro-preview-03-25-google-search) +# - gemini-2.5-flash (shorthand for gemini-2.5-flash-preview-05-20) +# # Examples: # OPENAI_ALLOWED_MODELS=o3-mini,o4-mini,mini # Only allow mini models (cost control) # GOOGLE_ALLOWED_MODELS=flash # Only allow Flash (fast responses) @@ -77,21 +107,26 @@ DEFAULT_THINKING_MODE_THINKDEEP=high # OPENAI_ALLOWED_MODELS=o4-mini # Single model standardization # GOOGLE_ALLOWED_MODELS=flash,pro # Allow both Gemini models # XAI_ALLOWED_MODELS=grok,grok-3-fast # Allow both GROK variants +# DIAL_ALLOWED_MODELS=o3,o4-mini # Only allow O3/O4 models via DIAL +# DIAL_ALLOWED_MODELS=opus-4,sonnet-4 # Only Claude 4 models (without thinking) +# DIAL_ALLOWED_MODELS=opus-4-thinking,sonnet-4-thinking # Only Claude 4 with thinking mode +# DIAL_ALLOWED_MODELS=gemini-2.5-pro,gemini-2.5-flash # Only Gemini 2.5 models via DIAL # #
Note: These restrictions apply even in 'auto' mode - Claude will only pick from allowed models # OPENAI_ALLOWED_MODELS= # GOOGLE_ALLOWED_MODELS= # XAI_ALLOWED_MODELS= +# DIAL_ALLOWED_MODELS= # Optional: Custom model configuration file path # Override the default location of custom_models.json # CUSTOM_MODELS_CONFIG_PATH=/path/to/your/custom_models.json -# Note: Redis is no longer used - conversations are stored in memory +# Note: Conversations are stored in memory during the session # Optional: Conversation timeout (hours) # How long AI-to-AI conversation threads persist before expiring -# Longer timeouts use more Redis memory but allow resuming conversations later +# Longer timeouts use more memory but allow resuming conversations later # Defaults to 3 hours if not specified CONVERSATION_TIMEOUT_HOURS=3 diff --git a/README.md b/README.md index c540380..40552da 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [zen_web.webm](https://github.com/user-attachments/assets/851e3911-7f06-47c0-a4ab-a2601236697c)
- 🤖 Claude + [Gemini / OpenAI / Grok / OpenRouter / Ollama / Any Model] = Your Ultimate AI Development Team + 🤖 Claude + [Gemini / OpenAI / Grok / OpenRouter / DIAL / Ollama / Any Model] = Your Ultimate AI Development Team

@@ -145,6 +145,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan - **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models. - **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) to get an API key for O3 model access. - **X.AI**: Visit [X.AI Console](https://console.x.ai/) to get an API key for GROK model access. +- **DIAL**: Visit [DIAL Platform](https://dialx.ai/) to get an API key for accessing multiple models through their unified API. DIAL is an open-source AI orchestration platform that provides vendor-agnostic access to models from major providers, the open-source community, and self-hosted deployments. [API Documentation](https://dialx.ai/dial_api) **Option C: Custom API Endpoints (Local models like Ollama, vLLM)** [Please see the setup guide](docs/custom_models.md#option-2-custom-api-setup-ollama-vllm-etc). With a custom API you can use: @@ -154,7 +155,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan - **Text Generation WebUI**: Popular local interface for running models - **Any OpenAI-compatible API**: Custom endpoints for your own infrastructure -> **Note:** Using all three options may create ambiguity about which provider / model to use if there is an overlap. +> **Note:** Using multiple provider options may create ambiguity about which provider / model to use if there is an overlap. > If all APIs are configured, native APIs will take priority when there is a clash in model name, such as for `gemini` and `o3`. > Configure your model aliases and give them unique names in [`conf/custom_models.json`](conf/custom_models.json) @@ -192,6 +193,12 @@ nano .env # GEMINI_API_KEY=your-gemini-api-key-here # For Gemini models # OPENAI_API_KEY=your-openai-api-key-here # For O3 model # OPENROUTER_API_KEY=your-openrouter-key # For OpenRouter (see docs/custom_models.md) +# DIAL_API_KEY=your-dial-api-key-here # For DIAL platform + +# For DIAL (optional configuration): +# DIAL_API_HOST=https://core.dialx.ai # Default DIAL host (optional) +# DIAL_API_VERSION=2024-12-01-preview # API version (optional) +# DIAL_ALLOWED_MODELS=o3,gemini-2.5-pro # Restrict to specific models (optional) # For local models (Ollama, vLLM, etc.): # CUSTOM_API_URL=http://localhost:11434/v1 # Ollama example @@ -537,10 +544,11 @@ Configure the Zen MCP Server through environment variables in your `.env` file. 
DEFAULT_MODEL=auto GEMINI_API_KEY=your-gemini-key OPENAI_API_KEY=your-openai-key +DIAL_API_KEY=your-dial-key # Optional: Access to multiple models via DIAL ``` **Key Configuration Options:** -- **API Keys**: Native APIs (Gemini, OpenAI, X.AI), OpenRouter, or Custom endpoints (Ollama, vLLM) +- **API Keys**: Native APIs (Gemini, OpenAI, X.AI), OpenRouter, DIAL, or Custom endpoints (Ollama, vLLM) - **Model Selection**: Auto mode or specific model defaults - **Usage Restrictions**: Control which models can be used for cost control - **Conversation Settings**: Timeout, turn limits, memory configuration diff --git a/providers/base.py b/providers/base.py index e0b3882..c8b1ec7 100644 --- a/providers/base.py +++ b/providers/base.py @@ -17,6 +17,7 @@ class ProviderType(Enum): XAI = "xai" OPENROUTER = "openrouter" CUSTOM = "custom" + DIAL = "dial" class TemperatureConstraint(ABC): @@ -326,3 +327,12 @@ class ModelProvider(ABC): Resolved model name """ return model_name + + def close(self): + """Clean up any resources held by the provider. + + Default implementation does nothing. + Subclasses should override if they hold resources that need cleanup. + """ + # Base implementation: no resources to clean up + return diff --git a/providers/dial.py b/providers/dial.py new file mode 100644 index 0000000..617858c --- /dev/null +++ b/providers/dial.py @@ -0,0 +1,525 @@ +"""DIAL (Data & AI Layer) model provider implementation.""" + +import logging +import os +import threading +import time +from typing import Optional + +from .base import ( + ModelCapabilities, + ModelResponse, + ProviderType, + RangeTemperatureConstraint, +) +from .openai_compatible import OpenAICompatibleProvider + +logger = logging.getLogger(__name__) + + +class DIALModelProvider(OpenAICompatibleProvider): + """DIAL provider using OpenAI-compatible API. + + DIAL provides access to various AI models through a unified API interface. + Supports GPT, Claude, Gemini, and other models via DIAL deployments. 
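+
+    Example (illustrative sketch; assumes a valid DIAL API key and an "o3"
+    deployment enabled in your DIAL instance):
+
+        provider = DIALModelProvider(api_key="your-dial-api-key")
+        response = provider.generate_content(
+            prompt="Summarize this diff",
+            model_name="o3",  # shorthand, resolved to o3-2025-04-16
+        )
+        print(response.content)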
+ """ + + FRIENDLY_NAME = "DIAL" + + # Retry configuration for API calls + MAX_RETRIES = 4 + RETRY_DELAYS = [1, 3, 5, 8] # seconds + + # Supported DIAL models (these can be customized based on your DIAL deployment) + SUPPORTED_MODELS = { + "o3-2025-04-16": { + "context_window": 200_000, + "supports_extended_thinking": False, + "supports_vision": True, + }, + "o4-mini-2025-04-16": { + "context_window": 200_000, + "supports_extended_thinking": False, + "supports_vision": True, + }, + "anthropic.claude-sonnet-4-20250514-v1:0": { + "context_window": 200_000, + "supports_extended_thinking": False, + "supports_vision": True, + }, + "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": { + "context_window": 200_000, + "supports_extended_thinking": True, # Thinking mode variant + "supports_vision": True, + }, + "anthropic.claude-opus-4-20250514-v1:0": { + "context_window": 200_000, + "supports_extended_thinking": False, + "supports_vision": True, + }, + "anthropic.claude-opus-4-20250514-v1:0-with-thinking": { + "context_window": 200_000, + "supports_extended_thinking": True, # Thinking mode variant + "supports_vision": True, + }, + "gemini-2.5-pro-preview-03-25-google-search": { + "context_window": 1_000_000, + "supports_extended_thinking": False, # DIAL doesn't expose thinking mode + "supports_vision": True, + }, + "gemini-2.5-pro-preview-05-06": { + "context_window": 1_000_000, + "supports_extended_thinking": False, + "supports_vision": True, + }, + "gemini-2.5-flash-preview-05-20": { + "context_window": 1_000_000, + "supports_extended_thinking": False, + "supports_vision": True, + }, + # Shorthands + "o3": "o3-2025-04-16", + "o4-mini": "o4-mini-2025-04-16", + "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0", + "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking", + "opus-4": "anthropic.claude-opus-4-20250514-v1:0", + "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking", + "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06", + "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search", + "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20", + } + + def __init__(self, api_key: str, **kwargs): + """Initialize DIAL provider with API key and host. 
+ + Args: + api_key: DIAL API key for authentication + **kwargs: Additional configuration options + """ + # Get DIAL API host from environment or kwargs + dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai" + + # DIAL uses /openai endpoint for OpenAI-compatible API + if not dial_host.endswith("/openai"): + dial_host = f"{dial_host.rstrip('/')}/openai" + + kwargs["base_url"] = dial_host + + # Get API version from environment or use default + self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview") + + # Add DIAL-specific headers + # DIAL uses Api-Key header instead of Authorization: Bearer + # Reference: https://dialx.ai/dial_api#section/Authorization + self.DEFAULT_HEADERS = { + "Api-Key": api_key, + } + + # Store the actual API key for use in Api-Key header + self._dial_api_key = api_key + + # Pass a placeholder API key to OpenAI client - we'll override the auth header in httpx + # The actual authentication happens via the Api-Key header in the httpx client + super().__init__("placeholder-not-used", **kwargs) + + # Cache for deployment-specific clients to avoid recreating them on each request + self._deployment_clients = {} + # Lock to ensure thread-safe client creation + self._client_lock = threading.Lock() + + # Create a SINGLE shared httpx client for the provider instance + import httpx + + # Create custom event hooks to remove Authorization header + def remove_auth_header(request): + """Remove Authorization header that OpenAI client adds.""" + # httpx headers are case-insensitive, so we need to check all variations + headers_to_remove = [] + for header_name in request.headers: + if header_name.lower() == "authorization": + headers_to_remove.append(header_name) + + for header_name in headers_to_remove: + del request.headers[header_name] + + self._http_client = httpx.Client( + timeout=self.timeout_config, + verify=True, + follow_redirects=True, + headers=self.DEFAULT_HEADERS.copy(), # Include DIAL headers including Api-Key + limits=httpx.Limits( + max_keepalive_connections=5, + max_connections=10, + keepalive_expiry=30.0, + ), + event_hooks={"request": [remove_auth_header]}, + ) + + logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}") + + def get_capabilities(self, model_name: str) -> ModelCapabilities: + """Get capabilities for a specific model. 
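+
+        Shorthands are resolved first; for example, get_capabilities("o3")
+        reports the capabilities of "o3-2025-04-16".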
+ + Args: + model_name: Name of the model (can be shorthand) + + Returns: + ModelCapabilities object + + Raises: + ValueError: If model is not supported or not allowed + """ + resolved_name = self._resolve_model_name(model_name) + + if resolved_name not in self.SUPPORTED_MODELS: + raise ValueError(f"Unsupported DIAL model: {model_name}") + + # Check restrictions + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name): + raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.") + + config = self.SUPPORTED_MODELS[resolved_name] + + return ModelCapabilities( + provider=ProviderType.DIAL, + model_name=resolved_name, + friendly_name=self.FRIENDLY_NAME, + context_window=config["context_window"], + supports_extended_thinking=config["supports_extended_thinking"], + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_images=config.get("supports_vision", False), + temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7), + ) + + def get_provider_type(self) -> ProviderType: + """Get the provider type.""" + return ProviderType.DIAL + + def validate_model_name(self, model_name: str) -> bool: + """Validate if the model name is supported. + + Args: + model_name: Model name to validate + + Returns: + True if model is supported and allowed, False otherwise + """ + resolved_name = self._resolve_model_name(model_name) + + if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + return False + + # Check against base class allowed_models if configured + if self.allowed_models is not None: + # Check both original and resolved names (case-insensitive) + if model_name.lower() not in self.allowed_models and resolved_name.lower() not in self.allowed_models: + logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' not in allowed_models list") + return False + + # Also check restrictions via ModelRestrictionService + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name): + logger.debug(f"DIAL model '{model_name}' -> '{resolved_name}' blocked by restrictions") + return False + + return True + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve model shorthand to full name. + + Args: + model_name: Model name or shorthand + + Returns: + Full model name + """ + shorthand_value = self.SUPPORTED_MODELS.get(model_name) + if isinstance(shorthand_value, str): + return shorthand_value + return model_name + + def _get_deployment_client(self, deployment: str): + """Get or create a cached client for a specific deployment. + + This avoids recreating OpenAI clients on every request, improving performance. + Reuses the shared HTTP client for connection pooling. 
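+
+        For example, with the default host, deployment "o3-2025-04-16" maps to
+        https://core.dialx.ai/openai/deployments/o3-2025-04-16, and requests go
+        to .../chat/completions with api-version passed as a query parameter.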
+ + Args: + deployment: The deployment/model name + + Returns: + OpenAI client configured for the specific deployment + """ + # Check if client already exists without locking for performance + if deployment in self._deployment_clients: + return self._deployment_clients[deployment] + + # Use lock to ensure thread-safe client creation + with self._client_lock: + # Double-check pattern: check again inside the lock + if deployment not in self._deployment_clients: + from openai import OpenAI + + # Build deployment-specific URL + base_url = str(self.client.base_url) + if base_url.endswith("/"): + base_url = base_url[:-1] + + # Remove /openai suffix if present to reconstruct properly + if base_url.endswith("/openai"): + base_url = base_url[:-7] + + deployment_url = f"{base_url}/openai/deployments/{deployment}" + + # Create and cache the client, REUSING the shared http_client + # Use placeholder API key - Authorization header will be removed by http_client event hook + self._deployment_clients[deployment] = OpenAI( + api_key="placeholder-not-used", + base_url=deployment_url, + http_client=self._http_client, # Pass the shared client with Api-Key header + default_query={"api-version": self.api_version}, # Add api-version as query param + ) + + return self._deployment_clients[deployment] + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.7, + max_output_tokens: Optional[int] = None, + images: Optional[list[str]] = None, + **kwargs, + ) -> ModelResponse: + """Generate content using DIAL's deployment-specific endpoint. + + DIAL uses Azure OpenAI-style deployment endpoints: + /openai/deployments/{deployment}/chat/completions + + Args: + prompt: User prompt + model_name: Model name or alias + system_prompt: Optional system prompt + temperature: Sampling temperature + max_output_tokens: Maximum tokens to generate + images: Optional list of image paths or data URLs to include + **kwargs: Additional provider-specific parameters + + Returns: + ModelResponse with generated content and metadata + """ + # Validate model name against allow-list + if not self.validate_model_name(model_name): + raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}") + + # Validate parameters + self.validate_parameters(model_name, temperature) + + # Prepare messages + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + # Build user message content + user_message_content = [] + if prompt: + user_message_content.append({"type": "text", "text": prompt}) + + if images and self._supports_vision(model_name): + for img_path in images: + processed_image = self._process_image(img_path) + if processed_image: + user_message_content.append(processed_image) + elif images: + logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)") + + # Add user message. If only text, content will be a string; otherwise a list.
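+        # e.g. content is the plain string "Hello", or a list of parts such as
+        # [{"type": "text", ...}, <image part>] (image part shape assumed to
+        # follow the OpenAI-style content-parts format from _process_image)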
+ if len(user_message_content) == 1 and user_message_content[0]["type"] == "text": + messages.append({"role": "user", "content": prompt}) + else: + messages.append({"role": "user", "content": user_message_content}) + + # Resolve model name + resolved_model = self._resolve_model_name(model_name) + + # Build completion parameters + completion_params = { + "model": resolved_model, + "messages": messages, + } + + # Check model capabilities + try: + capabilities = self.get_capabilities(model_name) + supports_temperature = getattr(capabilities, "supports_temperature", True) + except Exception as e: + logger.debug(f"Failed to check temperature support for {model_name}: {e}") + supports_temperature = True + + # Add temperature parameter if supported + if supports_temperature: + completion_params["temperature"] = temperature + + # Add max tokens if specified; temperature support is used as a proxy for + # models that accept standard OpenAI sampling/limit parameters + if max_output_tokens and supports_temperature: + completion_params["max_tokens"] = max_output_tokens + + # Add additional parameters + for key, value in kwargs.items(): + if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]: + if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]: + continue + completion_params[key] = value + + # DIAL-specific: Get cached client for deployment endpoint + deployment_client = self._get_deployment_client(resolved_model) + + # Retry logic with progressive delays + last_exception = None + + for attempt in range(self.MAX_RETRIES): + try: + # Generate completion using deployment-specific client + response = deployment_client.chat.completions.create(**completion_params) + + # Extract content and usage + content = response.choices[0].message.content + usage = self._extract_usage(response) + + return ModelResponse( + content=content, + usage=usage, + model_name=model_name, + friendly_name=self.FRIENDLY_NAME, + provider=self.get_provider_type(), + metadata={ + "finish_reason": response.choices[0].finish_reason, + "model": response.model, + "id": response.id, + "created": response.created, + }, + ) + + except Exception as e: + last_exception = e + + # Check if this is a retryable error + is_retryable = self._is_error_retryable(e) + + if not is_retryable: + # Non-retryable error, raise immediately + raise ValueError(f"DIAL API error for model {model_name}: {str(e)}") + + # If this isn't the last attempt and error is retryable, wait and retry + if attempt < self.MAX_RETRIES - 1: + delay = self.RETRY_DELAYS[attempt] + logger.info( + f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}" + ) + time.sleep(delay) + continue + + # All retries exhausted + raise ValueError( + f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}" + ) + + def _supports_vision(self, model_name: str) -> bool: + """Check if the model supports vision (image processing). + + Args: + model_name: Model name to check + + Returns: + True if model supports vision, False otherwise + """ + resolved_name = self._resolve_model_name(model_name) + + if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False) + + # Fall back to parent implementation for unknown models + return super()._supports_vision(model_name) + + def list_models(self, respect_restrictions: bool = True) -> list[str]: + """Return a list of model names supported by this provider.
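+
+        Both full names and shorthands are included (e.g. "o3-2025-04-16" and
+        "o3"), filtered by any configured DIAL model restrictions.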
+ + Args: + respect_restrictions: Whether to apply provider-specific restriction logic. + + Returns: + List of model names available from this provider + """ + # Get all model keys (both full names and aliases) + all_models = list(self.SUPPORTED_MODELS.keys()) + + if not respect_restrictions: + return all_models + + # Apply restrictions if configured + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + + # Filter based on restrictions + allowed_models = [] + for model in all_models: + resolved_name = self._resolve_model_name(model) + if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model): + allowed_models.append(model) + + return allowed_models + + def list_all_known_models(self) -> list[str]: + """Return all model names known by this provider, including alias targets. + + This is used for validation purposes to ensure restriction policies + can validate against both aliases and their target model names. + + Returns: + List of all model names and alias targets known by this provider + """ + # Collect all unique model names (both aliases and targets) + all_models = set() + + for key, value in self.SUPPORTED_MODELS.items(): + # Add the key (could be alias or full name) + all_models.add(key) + + # If it's an alias (string value), add the target too + if isinstance(value, str): + all_models.add(value) + + return sorted(all_models) + + def close(self): + """Clean up HTTP clients when provider is closed.""" + logger.info("Closing DIAL provider HTTP clients...") + + # Clear the deployment clients cache + # Note: We don't need to close individual OpenAI clients since they + # use the shared httpx.Client which we close separately + self._deployment_clients.clear() + + # Close the shared HTTP client + if hasattr(self, "_http_client"): + try: + self._http_client.close() + logger.debug("Closed shared HTTP client") + except Exception as e: + logger.warning(f"Error closing shared HTTP client: {e}") + + # Also close the client created by the superclass (OpenAICompatibleProvider) + # as it holds its own httpx.Client instance that is not used by DIAL's generate_content + if hasattr(self, "client") and self.client and hasattr(self.client, "close"): + try: + self.client.close() + logger.debug("Closed superclass's OpenAI client") + except Exception as e: + logger.warning(f"Error closing superclass's OpenAI client: {e}") diff --git a/providers/registry.py b/providers/registry.py index a5efcf0..baa9222 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -118,6 +118,7 @@ class ModelProviderRegistry: ProviderType.GOOGLE, # Direct Gemini access ProviderType.OPENAI, # Direct OpenAI access ProviderType.XAI, # Direct X.AI GROK access + ProviderType.DIAL, # DIAL unified API access ProviderType.CUSTOM, # Local/self-hosted models ProviderType.OPENROUTER, # Catch-all for cloud models ] @@ -237,6 +238,7 @@ class ModelProviderRegistry: ProviderType.XAI: "XAI_API_KEY", ProviderType.OPENROUTER: "OPENROUTER_API_KEY", ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth + ProviderType.DIAL: "DIAL_API_KEY", } env_var = key_mapping.get(provider_type) diff --git a/run-server.sh b/run-server.sh index d2d0ebe..243f0e0 100755 --- a/run-server.sh +++ b/run-server.sh @@ -883,6 +883,7 @@ setup_env_file() { "GEMINI_API_KEY:your_gemini_api_key_here" "OPENAI_API_KEY:your_openai_api_key_here" "XAI_API_KEY:your_xai_api_key_here" + "DIAL_API_KEY:your_dial_api_key_here" 
"OPENROUTER_API_KEY:your_openrouter_api_key_here" ) @@ -934,6 +935,7 @@ validate_api_keys() { "GEMINI_API_KEY:your_gemini_api_key_here" "OPENAI_API_KEY:your_openai_api_key_here" "XAI_API_KEY:your_xai_api_key_here" + "DIAL_API_KEY:your_dial_api_key_here" "OPENROUTER_API_KEY:your_openrouter_api_key_here" ) @@ -961,6 +963,7 @@ validate_api_keys() { echo " GEMINI_API_KEY=your-actual-key" >&2 echo " OPENAI_API_KEY=your-actual-key" >&2 echo " XAI_API_KEY=your-actual-key" >&2 + echo " DIAL_API_KEY=your-actual-key" >&2 echo " OPENROUTER_API_KEY=your-actual-key" >&2 echo "" >&2 print_info "After adding your API keys, run ./run-server.sh again" >&2 diff --git a/server.py b/server.py index 1b0f969..19904fb 100644 --- a/server.py +++ b/server.py @@ -19,6 +19,7 @@ as defined by the MCP protocol. """ import asyncio +import atexit import logging import os import sys @@ -271,6 +272,7 @@ def configure_providers(): from providers import ModelProviderRegistry from providers.base import ProviderType from providers.custom import CustomProvider + from providers.dial import DIALModelProvider from providers.gemini import GeminiModelProvider from providers.openai_provider import OpenAIModelProvider from providers.openrouter import OpenRouterProvider @@ -303,6 +305,13 @@ def configure_providers(): has_native_apis = True logger.info("X.AI API key found - GROK models available") + # Check for DIAL API key + dial_key = os.getenv("DIAL_API_KEY") + if dial_key and dial_key != "your_dial_api_key_here": + valid_providers.append("DIAL") + has_native_apis = True + logger.info("DIAL API key found - DIAL models available") + # Check for OpenRouter API key openrouter_key = os.getenv("OPENROUTER_API_KEY") if openrouter_key and openrouter_key != "your_openrouter_api_key_here": @@ -336,6 +345,8 @@ def configure_providers(): ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider) if xai_key and xai_key != "your_xai_api_key_here": ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider) + if dial_key and dial_key != "your_dial_api_key_here": + ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider) # 2. 
Custom provider second (for local/private models) if has_custom: @@ -358,6 +369,7 @@ "- GEMINI_API_KEY for Gemini models\n" "- OPENAI_API_KEY for OpenAI o3 model\n" "- XAI_API_KEY for X.AI GROK models\n" + "- DIAL_API_KEY for DIAL models\n" "- OPENROUTER_API_KEY for OpenRouter (multiple models)\n" "- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)" ) @@ -376,6 +388,25 @@ if len(priority_info) > 1: logger.info(f"Provider priority: {' → '.join(priority_info)}") + # Register cleanup function for providers + def cleanup_providers(): + """Clean up all registered providers on shutdown.""" + try: + registry = ModelProviderRegistry() + if hasattr(registry, "_initialized_providers"): + for provider in list(registry._initialized_providers.values()): + try: + if provider and hasattr(provider, "close"): + provider.close() + except Exception: + # Logger might be closed during shutdown + pass + except Exception: + # Silently ignore any errors during cleanup + pass + + atexit.register(cleanup_providers) + # Check and log model restrictions restriction_service = get_restriction_service() restrictions = restriction_service.get_restriction_summary() @@ -390,7 +421,8 @@ # Validate restrictions against known models provider_instances = {} - for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI]: + provider_types_to_validate = [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL] + for provider_type in provider_types_to_validate: provider = ModelProviderRegistry.get_provider(provider_type) if provider: provider_instances[provider_type] = provider diff --git a/tests/test_dial_provider.py b/tests/test_dial_provider.py new file mode 100644 index 0000000..4a22cb6 --- /dev/null +++ b/tests/test_dial_provider.py @@ -0,0 +1,273 @@ +"""Tests for DIAL provider implementation.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from providers.base import ProviderType +from providers.dial import DIALModelProvider + + +class TestDIALProvider: + """Test DIAL provider functionality.""" + + @patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": "https://test.dialx.ai"}) + def test_initialization_with_host(self): + """Test provider initialization with custom host.""" + provider = DIALModelProvider("test-key") + assert provider._dial_api_key == "test-key" # Check internal API key storage + assert provider.api_key == "placeholder-not-used" # OpenAI client uses placeholder, auth header removed by hook + assert provider.base_url == "https://test.dialx.ai/openai" + assert provider.get_provider_type() == ProviderType.DIAL + + @patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": ""}, clear=True) + def test_initialization_default_host(self): + """Test provider initialization with default host.""" + provider = DIALModelProvider("test-key") + assert provider._dial_api_key == "test-key" # Check internal API key storage + assert provider.api_key == "placeholder-not-used" # OpenAI client uses placeholder, auth header removed by hook + assert provider.base_url == "https://core.dialx.ai/openai" + + def test_initialization_host_normalization(self): + """Test that host URL is normalized to include /openai suffix.""" + # Test with host missing /openai + provider = DIALModelProvider("test-key", base_url="https://custom.dialx.ai") + assert provider.base_url == "https://custom.dialx.ai/openai" + + # Test with host already having /openai + provider = 
DIALModelProvider("test-key", base_url="https://custom.dialx.ai/openai") + assert provider.base_url == "https://custom.dialx.ai/openai" + + @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False) + @patch("utils.model_restrictions._restriction_service", None) + def test_model_validation(self): + """Test model name validation.""" + provider = DIALModelProvider("test-key") + + # Test valid models + assert provider.validate_model_name("o3-2025-04-16") is True + assert provider.validate_model_name("o3") is True # Shorthand + assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True + assert provider.validate_model_name("opus-4") is True # Shorthand + assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is True + assert provider.validate_model_name("gemini-2.5-pro") is True # Shorthand + + # Test invalid model + assert provider.validate_model_name("invalid-model") is False + + def test_resolve_model_name(self): + """Test model name resolution for shorthands.""" + provider = DIALModelProvider("test-key") + + # Test shorthand resolution + assert provider._resolve_model_name("o3") == "o3-2025-04-16" + assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16" + assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0" + assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0" + assert provider._resolve_model_name("gemini-2.5-pro") == "gemini-2.5-pro-preview-05-06" + assert provider._resolve_model_name("gemini-2.5-flash") == "gemini-2.5-flash-preview-05-20" + + # Test full name passthrough + assert provider._resolve_model_name("o3-2025-04-16") == "o3-2025-04-16" + assert ( + provider._resolve_model_name("anthropic.claude-opus-4-20250514-v1:0") + == "anthropic.claude-opus-4-20250514-v1:0" + ) + + @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False) + @patch("utils.model_restrictions._restriction_service", None) + def test_get_capabilities(self): + """Test getting model capabilities.""" + provider = DIALModelProvider("test-key") + + # Test O3 capabilities + capabilities = provider.get_capabilities("o3") + assert capabilities.model_name == "o3-2025-04-16" + assert capabilities.friendly_name == "DIAL" + assert capabilities.context_window == 200_000 + assert capabilities.provider == ProviderType.DIAL + assert capabilities.supports_images is True + assert capabilities.supports_extended_thinking is False + + # Test Claude 4 capabilities + capabilities = provider.get_capabilities("opus-4") + assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0" + assert capabilities.context_window == 200_000 + assert capabilities.supports_images is True + assert capabilities.supports_extended_thinking is False + + # Test Claude 4 with thinking mode + capabilities = provider.get_capabilities("opus-4-thinking") + assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0-with-thinking" + assert capabilities.context_window == 200_000 + assert capabilities.supports_images is True + assert capabilities.supports_extended_thinking is True + + # Test Gemini capabilities + capabilities = provider.get_capabilities("gemini-2.5-pro") + assert capabilities.model_name == "gemini-2.5-pro-preview-05-06" + assert capabilities.context_window == 1_000_000 + assert capabilities.supports_images is True + + # Test temperature constraint + assert capabilities.temperature_constraint.min_temp == 0.0 + assert capabilities.temperature_constraint.max_temp == 2.0 + assert 
capabilities.temperature_constraint.default_temp == 0.7 + + @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False) + @patch("utils.model_restrictions._restriction_service", None) + def test_get_capabilities_invalid_model(self): + """Test that get_capabilities raises for invalid models.""" + provider = DIALModelProvider("test-key") + + with pytest.raises(ValueError, match="Unsupported DIAL model"): + provider.get_capabilities("invalid-model") + + @patch("utils.model_restrictions.get_restriction_service") + def test_get_capabilities_restricted_model(self, mock_get_restriction): + """Test that get_capabilities respects model restrictions.""" + provider = DIALModelProvider("test-key") + + # Mock restriction service to block the model + mock_service = MagicMock() + mock_service.is_allowed.return_value = False + mock_get_restriction.return_value = mock_service + + with pytest.raises(ValueError, match="not allowed by restriction policy"): + provider.get_capabilities("o3") + + @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False) + @patch("utils.model_restrictions._restriction_service", None) + def test_supports_vision(self): + """Test vision support detection.""" + provider = DIALModelProvider("test-key") + + # Test models with vision support + assert provider._supports_vision("o3-2025-04-16") is True + assert provider._supports_vision("o3") is True # Via resolution + assert provider._supports_vision("anthropic.claude-opus-4-20250514-v1:0") is True + assert provider._supports_vision("gemini-2.5-pro-preview-05-06") is True + + # Test unknown model (falls back to parent implementation) + assert provider._supports_vision("unknown-model") is False + + @patch("openai.OpenAI") # Mock the OpenAI class directly from openai module + def test_generate_content_with_alias(self, mock_openai_class): + """Test that generate_content properly resolves aliases and uses deployment routing.""" + # Create mock client + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))] + mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + mock_response.model = "gpt-4" + mock_response.id = "test-id" + mock_response.created = 1234567890 + mock_response.choices[0].finish_reason = "stop" + + mock_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_client + + provider = DIALModelProvider("test-key") + + # Generate content with shorthand + response = provider.generate_content(prompt="Test prompt", model_name="o3", temperature=0.7) # Shorthand + + # Verify OpenAI was instantiated with deployment-specific URL + mock_openai_class.assert_called_once() + call_args = mock_openai_class.call_args + assert "/deployments/o3-2025-04-16" in call_args[1]["base_url"] + + # Verify the resolved model name was passed to the API + mock_client.chat.completions.create.assert_called_once() + create_call_args = mock_client.chat.completions.create.call_args + assert create_call_args[1]["model"] == "o3-2025-04-16" # Resolved name + + # Verify response + assert response.content == "Test response" + assert response.model_name == "o3" # Original name preserved + assert response.metadata["model"] == "gpt-4" # API returned model name from mock + + def test_provider_type(self): + """Test provider type identification.""" + provider = DIALModelProvider("test-key") + assert provider.get_provider_type() == ProviderType.DIAL + + def test_friendly_name(self): + """Test provider 
friendly name.""" + provider = DIALModelProvider("test-key") + assert provider.FRIENDLY_NAME == "DIAL" + + @patch.dict(os.environ, {"DIAL_API_VERSION": "2024-12-01"}) + def test_configurable_api_version(self): + """Test that API version can be configured via environment variable.""" + provider = DIALModelProvider("test-key") + # Check that the custom API version is stored + assert provider.api_version == "2024-12-01" + + def test_default_api_version(self): + """Test that default API version is used when not configured.""" + # Clear any existing DIAL_API_VERSION from environment + with patch.dict(os.environ, {}, clear=True): + # clear=True empties the environment, so DIAL_API_VERSION cannot be set; + # the explicit delete below is defensive + if "DIAL_API_VERSION" in os.environ: + del os.environ["DIAL_API_VERSION"] + + provider = DIALModelProvider("test-key") + # Check that the default API version is used + assert provider.api_version == "2024-12-01-preview" + # Check that Api-Key header is set + assert provider.DEFAULT_HEADERS["Api-Key"] == "test-key" + + @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": "o3-2025-04-16,anthropic.claude-opus-4-20250514-v1:0"}) + @patch("utils.model_restrictions._restriction_service", None) + def test_allowed_models_restriction(self): + """Test model allow-list functionality.""" + provider = DIALModelProvider("test-key") + + # These should be allowed + assert provider.validate_model_name("o3-2025-04-16") is True + assert provider.validate_model_name("o3") is True # Alias for o3-2025-04-16 + assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True + assert provider.validate_model_name("opus-4") is True # Resolves to anthropic.claude-opus-4-20250514-v1:0 + + # These should be blocked + assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is False + assert provider.validate_model_name("o4-mini-2025-04-16") is False + assert provider.validate_model_name("sonnet-4") is False # sonnet-4 is not in allowed list + + @patch("httpx.Client") + @patch("openai.OpenAI") + def test_close_method(self, mock_openai_class, mock_httpx_client_class): + """Test that the close method properly closes HTTP clients.""" + # Mock the httpx.Client instance that DIALModelProvider will create + mock_shared_http_client = MagicMock() + mock_httpx_client_class.return_value = mock_shared_http_client + + # Mock the OpenAI client instances + mock_openai_client_1 = MagicMock() + mock_openai_client_2 = MagicMock() + # Configure side_effect to return different mocks for subsequent calls + mock_openai_class.side_effect = [mock_openai_client_1, mock_openai_client_2] + + provider = DIALModelProvider("test-key") + + # Mock the superclass's _client attribute directly + mock_superclass_client = MagicMock() + provider._client = mock_superclass_client + + # Simulate getting clients for two different deployments to populate _deployment_clients + provider._get_deployment_client("model_a") + provider._get_deployment_client("model_b") + + # Now call close + provider.close() + + # Assert that the shared httpx client's close method was called + mock_shared_http_client.close.assert_called_once() + + # Assert that the superclass client's close method was called + mock_superclass_client.close.assert_called_once() + + # Assert that the deployment clients cache is cleared + assert not provider._deployment_clients diff --git a/tools/listmodels.py b/tools/listmodels.py index 6a623b9..265fbcc 100644 --- a/tools/listmodels.py +++ b/tools/listmodels.py @@ -84,6 +84,7 @@ class ListModelsTool(BaseTool): ProviderType.GOOGLE: {"name": 
"Google Gemini", "env_key": "GEMINI_API_KEY"}, ProviderType.OPENAI: {"name": "OpenAI", "env_key": "OPENAI_API_KEY"}, ProviderType.XAI: {"name": "X.AI (Grok)", "env_key": "XAI_API_KEY"}, + ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"}, } # Check each native provider type diff --git a/utils/model_restrictions.py b/utils/model_restrictions.py index 0b7ff25..834c0a2 100644 --- a/utils/model_restrictions.py +++ b/utils/model_restrictions.py @@ -11,6 +11,7 @@ Environment Variables: - GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models - XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models - OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models +- DIAL_ALLOWED_MODELS: Comma-separated list of allowed DIAL models Example: OPENAI_ALLOWED_MODELS=o3-mini,o4-mini @@ -44,6 +45,7 @@ class ModelRestrictionService: ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS", ProviderType.XAI: "XAI_ALLOWED_MODELS", ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS", + ProviderType.DIAL: "DIAL_ALLOWED_MODELS", } def __init__(self):