"""DIAL (Data & AI Layer) model provider implementation."""

import logging
import os
import threading
from typing import Optional

from .openai_compatible import OpenAICompatibleProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint

logger = logging.getLogger(__name__)


class DIALModelProvider(OpenAICompatibleProvider):
    """Client for the DIAL (Data & AI Layer) aggregation service.

    DIAL exposes several third-party models behind a single OpenAI-compatible
    endpoint. This provider wraps the service, publishes capability metadata
    for the known deployments, and centralises retry/backoff settings tailored
    to DIAL's latency characteristics.
    """
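
    # Minimal usage sketch (illustrative; assumes a valid DIAL API key is
    # available in the environment):
    #
    #     provider = DIALModelProvider(api_key=os.environ["DIAL_API_KEY"])
    #     response = provider.generate_content(prompt="Hello", model_name="o3")
    #     print(response.content)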
FRIENDLY_NAME = "DIAL"
|
|
|
|
# Retry configuration for API calls
|
|
MAX_RETRIES = 4
|
|
RETRY_DELAYS = [1, 3, 5, 8] # seconds
|
|
|
|

    # Model configurations using ModelCapabilities objects
    MODEL_CAPABILITIES = {
        "o3-2025-04-16": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="o3-2025-04-16",
            friendly_name="DIAL (O3)",
            intelligence_score=14,
            context_window=200_000,
            max_output_tokens=100_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # DIAL may not expose function calling
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=False,  # O3 models don't accept temperature
            temperature_constraint=TemperatureConstraint.create("fixed"),
            description="OpenAI O3 via DIAL - Strong reasoning model",
            aliases=["o3"],
        ),
        "o4-mini-2025-04-16": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="o4-mini-2025-04-16",
            friendly_name="DIAL (O4-mini)",
            intelligence_score=11,
            context_window=200_000,
            max_output_tokens=100_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # DIAL may not expose function calling
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=False,  # O4 models don't accept temperature
            temperature_constraint=TemperatureConstraint.create("fixed"),
            description="OpenAI O4-mini via DIAL - Fast reasoning model",
            aliases=["o4-mini"],
        ),
        "anthropic.claude-sonnet-4.1-20250805-v1:0": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="anthropic.claude-sonnet-4.1-20250805-v1:0",
            friendly_name="DIAL (Sonnet 4.1)",
            intelligence_score=10,
            context_window=200_000,
            max_output_tokens=64_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=5.0,
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="Claude Sonnet 4.1 via DIAL - Balanced performance",
            aliases=["sonnet-4.1", "sonnet-4"],
        ),
        "anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking",
            friendly_name="DIAL (Sonnet 4.1 Thinking)",
            intelligence_score=11,
            context_window=200_000,
            max_output_tokens=64_000,
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=5.0,
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="Claude Sonnet 4.1 with thinking mode via DIAL",
            aliases=["sonnet-4.1-thinking", "sonnet-4-thinking"],
        ),
        "anthropic.claude-opus-4.1-20250805-v1:0": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="anthropic.claude-opus-4.1-20250805-v1:0",
            friendly_name="DIAL (Opus 4.1)",
            intelligence_score=14,
            context_window=200_000,
            max_output_tokens=64_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=5.0,
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="Claude Opus 4.1 via DIAL - Most capable Claude model",
            aliases=["opus-4.1", "opus-4"],
        ),
        "anthropic.claude-opus-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="anthropic.claude-opus-4.1-20250805-v1:0-with-thinking",
            friendly_name="DIAL (Opus 4.1 Thinking)",
            intelligence_score=15,
            context_window=200_000,
            max_output_tokens=64_000,
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=5.0,
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="Claude Opus 4.1 with thinking mode via DIAL",
            aliases=["opus-4.1-thinking", "opus-4-thinking"],
        ),
        "gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="gemini-2.5-pro-preview-03-25-google-search",
            friendly_name="DIAL (Gemini 2.5 Pro Search)",
            intelligence_score=17,
            context_window=1_000_000,
            max_output_tokens=65_536,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="Gemini 2.5 Pro with Google Search via DIAL",
            aliases=["gemini-2.5-pro-search"],
        ),
        "gemini-2.5-pro-preview-05-06": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="gemini-2.5-pro-preview-05-06",
            friendly_name="DIAL (Gemini 2.5 Pro)",
            intelligence_score=18,
            context_window=1_000_000,
            max_output_tokens=65_536,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="Gemini 2.5 Pro via DIAL - Deep reasoning",
            aliases=["gemini-2.5-pro"],
        ),
        "gemini-2.5-flash-preview-05-20": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="gemini-2.5-flash-preview-05-20",
            friendly_name="DIAL (Gemini Flash 2.5)",
            intelligence_score=10,
            context_window=1_000_000,
            max_output_tokens=65_536,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="Gemini 2.5 Flash via DIAL - Ultra-fast",
            aliases=["gemini-2.5-flash"],
        ),
    }
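
    # Note: alias resolution (e.g. "o3" -> "o3-2025-04-16") is handled by the
    # inherited _resolve_model_name helper using the aliases declared above;
    # the mapping shown here is illustrative of the expected behaviour.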

    def __init__(self, api_key: str, **kwargs):
        """Initialize the DIAL provider with an API key and host.

        Args:
            api_key: DIAL API key for authentication
            **kwargs: Additional configuration options (e.g. base_url)
        """
        # Get the DIAL API host from kwargs or the environment, defaulting to the public host
        dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"

        # DIAL serves its OpenAI-compatible API under the /openai path
        if not dial_host.endswith("/openai"):
            dial_host = f"{dial_host.rstrip('/')}/openai"

        kwargs["base_url"] = dial_host

        # Get the API version from the environment or use the default
        self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")

        # Add DIAL-specific headers.
        # DIAL uses an Api-Key header instead of Authorization: Bearer.
        # Reference: https://dialx.ai/dial_api#section/Authorization
        self.DEFAULT_HEADERS = {
            "Api-Key": api_key,
        }
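
        # Equivalent raw request, for reference (illustrative sketch; the path
        # and query shape are inferred from the deployment routing below):
        #
        #     curl "$DIAL_API_HOST/openai/deployments/<model>/chat/completions?api-version=2024-12-01-preview" \
        #          -H "Api-Key: $DIAL_API_KEY" \
        #          -d '{"messages": [{"role": "user", "content": "Hello"}]}'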

        # Store the actual API key for use in the Api-Key header
        self._dial_api_key = api_key

        # Pass a placeholder API key to the OpenAI client; real authentication
        # happens via the Api-Key header on the shared httpx client, and the
        # Authorization header the SDK generates is stripped by a request hook.
        super().__init__("placeholder-not-used", **kwargs)

        # Cache deployment-specific clients to avoid recreating them on each request
        self._deployment_clients = {}
        # Lock to ensure thread-safe client creation
        self._client_lock = threading.Lock()

        # Create a SINGLE shared httpx client for the provider instance
        import httpx

        # Custom request hook that strips the Authorization header the OpenAI
        # SDK adds for the placeholder key, leaving only DIAL's Api-Key header
        def remove_auth_header(request):
            """Remove the Authorization header that the OpenAI client adds."""
            # httpx headers are case-insensitive, so check all variations
            headers_to_remove = []
            for header_name in request.headers:
                if header_name.lower() == "authorization":
                    headers_to_remove.append(header_name)

            for header_name in headers_to_remove:
                del request.headers[header_name]

        self._http_client = httpx.Client(
            timeout=self.timeout_config,
            verify=True,
            follow_redirects=True,
            headers=self.DEFAULT_HEADERS.copy(),  # Include DIAL headers, notably Api-Key
            limits=httpx.Limits(
                max_keepalive_connections=5,
                max_connections=10,
                keepalive_expiry=30.0,
            ),
            event_hooks={"request": [remove_auth_header]},
        )

        logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.DIAL

    def _get_deployment_client(self, deployment: str):
        """Get or create a cached client for a specific deployment.

        This avoids recreating OpenAI clients on every request, improving
        performance, and reuses the shared HTTP client for connection pooling.

        Args:
            deployment: The deployment/model name

        Returns:
            OpenAI client configured for the specific deployment
        """
        # Fast path: check for an existing client without taking the lock
        if deployment in self._deployment_clients:
            return self._deployment_clients[deployment]

        # Use a lock to ensure thread-safe client creation
        with self._client_lock:
            # Double-checked locking: test again inside the lock
            if deployment not in self._deployment_clients:
                from openai import OpenAI

                # Build the deployment-specific URL
                base_url = str(self.client.base_url)
                if base_url.endswith("/"):
                    base_url = base_url[:-1]

                # Strip the /openai suffix, if present, so the path can be rebuilt cleanly
                if base_url.endswith("/openai"):
                    base_url = base_url[: -len("/openai")]

                deployment_url = f"{base_url}/openai/deployments/{deployment}"
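                # Illustrative result with the default host (sketch):
                #   https://core.dialx.ai/openai/deployments/o3-2025-04-16
                # Each request then carries ?api-version=<DIAL_API_VERSION>.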

                # Create and cache the client, REUSING the shared http_client.
                # The placeholder key is never used for authentication: the
                # Authorization header is stripped by the shared client's hook.
                self._deployment_clients[deployment] = OpenAI(
                    api_key="placeholder-not-used",
                    base_url=deployment_url,
                    http_client=self._http_client,  # Shared client carrying the Api-Key header
                    default_query={"api-version": self.api_version},  # api-version as a query param
                )

        return self._deployment_clients[deployment]

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: Optional[int] = None,
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using DIAL's deployment-specific endpoint.

        DIAL uses Azure OpenAI-style deployment endpoints:
        /openai/deployments/{deployment}/chat/completions

        Args:
            prompt: User prompt
            model_name: Model name or alias
            system_prompt: Optional system prompt
            temperature: Sampling temperature
            max_output_tokens: Maximum tokens to generate
            images: Optional paths to images to include with the prompt
            **kwargs: Additional provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
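        # Illustrative call (sketch; assumes the "o3" alias resolves as
        # declared in MODEL_CAPABILITIES above):
        #
        #     response = provider.generate_content(
        #         prompt="Summarise this diff", model_name="o3",
        #         system_prompt="You are a code reviewer.",
        #     )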
        # Validate the model name against the allow-list
        if not self.validate_model_name(model_name):
            raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")

        # Validate parameters and fetch capabilities
        self.validate_parameters(model_name, temperature)
        capabilities = self.get_capabilities(model_name)

        # Prepare messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Build the user message content
        user_message_content = []
        if prompt:
            user_message_content.append({"type": "text", "text": prompt})

        if images and capabilities.supports_images:
            for img_path in images:
                processed_image = self._process_image(img_path)
                if processed_image:
                    user_message_content.append(processed_image)
        elif images:
            logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")

        # Add the user message. Text-only content is sent as a plain string,
        # otherwise as a list of content parts (see the sketch below).
        if len(user_message_content) == 1 and user_message_content[0]["type"] == "text":
            messages.append({"role": "user", "content": prompt})
        else:
            messages.append({"role": "user", "content": user_message_content})
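
        # Sketch of a resulting multimodal payload (the exact image-part shape
        # is produced by _process_image; the data-URL form below is an
        # assumption based on the OpenAI-style content-part convention):
        #   [{"role": "system", "content": "..."},
        #    {"role": "user", "content": [
        #        {"type": "text", "text": "..."},
        #        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        #    ]}]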

        # Resolve any alias to the canonical model name
        resolved_model = self._resolve_model_name(model_name)

        # Build completion parameters
        completion_params = {
            "model": resolved_model,
            "messages": messages,
        }

        # Determine temperature support from capabilities
        supports_temperature = capabilities.supports_temperature

        # Add the temperature parameter if supported
        if supports_temperature:
            completion_params["temperature"] = temperature

        # Add max tokens if specified. supports_temperature also gates this:
        # the fixed-temperature reasoning deployments (O3/O4) do not accept
        # the standard max_tokens parameter, so it is skipped for them.
        if max_output_tokens and supports_temperature:
            completion_params["max_tokens"] = max_output_tokens

        # Pass through a whitelisted set of additional parameters
        for key, value in kwargs.items():
            if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
                # Sampling parameters are skipped for fixed-temperature models
                if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
                    continue
                completion_params[key] = value

        # DIAL-specific: get the cached client for this deployment's endpoint
        deployment_client = self._get_deployment_client(resolved_model)

        attempt_counter = {"value": 0}

        def _attempt() -> ModelResponse:
            attempt_counter["value"] += 1
            response = deployment_client.chat.completions.create(**completion_params)

            content = response.choices[0].message.content
            usage = self._extract_usage(response)

            return ModelResponse(
                content=content,
                usage=usage,
                model_name=model_name,
                friendly_name=self.FRIENDLY_NAME,
                provider=self.get_provider_type(),
                metadata={
                    "finish_reason": response.choices[0].finish_reason,
                    "model": response.model,
                    "id": response.id,
                    "created": response.created,
                },
            )

        try:
            return self._run_with_retries(
                operation=_attempt,
                max_attempts=self.MAX_RETRIES,
                delays=self.RETRY_DELAYS,
                log_prefix=f"DIAL API ({resolved_model})",
            )
        except Exception as exc:
            attempts = max(attempt_counter["value"], 1)
            if attempts == 1:
                raise ValueError(f"DIAL API error for model {model_name}: {exc}") from exc

            raise ValueError(f"DIAL API error for model {model_name} after {attempts} attempts: {exc}") from exc

    def close(self) -> None:
        """Clean up HTTP clients when the provider is closed."""
        logger.info("Closing DIAL provider HTTP clients...")

        # Clear the deployment client cache. The individual OpenAI clients do
        # not need closing: they all share the single httpx.Client below.
        self._deployment_clients.clear()

        # Close the shared HTTP client
        if hasattr(self, "_http_client"):
            try:
                self._http_client.close()
                logger.debug("Closed shared HTTP client")
            except Exception as e:
                logger.warning(f"Error closing shared HTTP client: {e}")

        # Also close the client created by the superclass (OpenAICompatibleProvider),
        # which holds its own httpx.Client instance that DIAL's generate_content does not use
        if hasattr(self, "client") and self.client and hasattr(self.client, "close"):
            try:
                self.client.close()
                logger.debug("Closed superclass's OpenAI client")
            except Exception as e:
                logger.warning(f"Error closing superclass's OpenAI client: {e}")
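
    # Lifecycle sketch (illustrative): code that constructs this provider
    # directly should close it when done so pooled connections are released.
    #
    #     provider = DIALModelProvider(api_key=os.environ["DIAL_API_KEY"])
    #     try:
    #         response = provider.generate_content(prompt="Hi", model_name="o4-mini")
    #     finally:
    #         provider.close()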