Use ModelCapabilities consistently instead of dictionaries
Moved aliases as part of SUPPORTED_MODELS instead of shorthand, more in line with how custom_models are declared Further refactoring to cleanup some code
This commit is contained in:
@@ -10,7 +10,7 @@ from .base import (
|
||||
ModelCapabilities,
|
||||
ModelResponse,
|
||||
ProviderType,
|
||||
RangeTemperatureConstraint,
|
||||
create_temperature_constraint,
|
||||
)
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
|
||||
@@ -30,63 +30,161 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
MAX_RETRIES = 4
|
||||
RETRY_DELAYS = [1, 3, 5, 8] # seconds
|
||||
|
||||
# Supported DIAL models (these can be customized based on your DIAL deployment)
|
||||
# Model configurations using ModelCapabilities objects
|
||||
SUPPORTED_MODELS = {
|
||||
"o3-2025-04-16": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"o4-mini-2025-04-16": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": True, # Thinking mode variant
|
||||
"supports_vision": True,
|
||||
},
|
||||
"anthropic.claude-opus-4-20250514-v1:0": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": True, # Thinking mode variant
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.5-pro-preview-03-25-google-search": {
|
||||
"context_window": 1_000_000,
|
||||
"supports_extended_thinking": False, # DIAL doesn't expose thinking mode
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.5-pro-preview-05-06": {
|
||||
"context_window": 1_000_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.5-flash-preview-05-20": {
|
||||
"context_window": 1_000_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Shorthands
|
||||
"o3": "o3-2025-04-16",
|
||||
"o4-mini": "o4-mini-2025-04-16",
|
||||
"sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
|
||||
"opus-4": "anthropic.claude-opus-4-20250514-v1:0",
|
||||
"opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
|
||||
"o3-2025-04-16": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="o3-2025-04-16",
|
||||
friendly_name="DIAL (O3)",
|
||||
context_window=200_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=False, # O3 models don't accept temperature
|
||||
temperature_constraint=create_temperature_constraint("fixed"),
|
||||
description="OpenAI O3 via DIAL - Strong reasoning model",
|
||||
aliases=["o3"],
|
||||
),
|
||||
"o4-mini-2025-04-16": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="o4-mini-2025-04-16",
|
||||
friendly_name="DIAL (O4-mini)",
|
||||
context_window=200_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=False, # O4 models don't accept temperature
|
||||
temperature_constraint=create_temperature_constraint("fixed"),
|
||||
description="OpenAI O4-mini via DIAL - Fast reasoning model",
|
||||
aliases=["o4-mini"],
|
||||
),
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
friendly_name="DIAL (Sonnet 4)",
|
||||
context_window=200_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Claude doesn't have function calling
|
||||
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Claude Sonnet 4 via DIAL - Balanced performance",
|
||||
aliases=["sonnet-4"],
|
||||
),
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
|
||||
friendly_name="DIAL (Sonnet 4 Thinking)",
|
||||
context_window=200_000,
|
||||
supports_extended_thinking=True, # Thinking mode variant
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Claude doesn't have function calling
|
||||
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Claude Sonnet 4 with thinking mode via DIAL",
|
||||
aliases=["sonnet-4-thinking"],
|
||||
),
|
||||
"anthropic.claude-opus-4-20250514-v1:0": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-opus-4-20250514-v1:0",
|
||||
friendly_name="DIAL (Opus 4)",
|
||||
context_window=200_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Claude doesn't have function calling
|
||||
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Claude Opus 4 via DIAL - Most capable Claude model",
|
||||
aliases=["opus-4"],
|
||||
),
|
||||
"anthropic.claude-opus-4-20250514-v1:0-with-thinking": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking",
|
||||
friendly_name="DIAL (Opus 4 Thinking)",
|
||||
context_window=200_000,
|
||||
supports_extended_thinking=True, # Thinking mode variant
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Claude doesn't have function calling
|
||||
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Claude Opus 4 with thinking mode via DIAL",
|
||||
aliases=["opus-4-thinking"],
|
||||
),
|
||||
"gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="gemini-2.5-pro-preview-03-25-google-search",
|
||||
friendly_name="DIAL (Gemini 2.5 Pro Search)",
|
||||
context_window=1_000_000,
|
||||
supports_extended_thinking=False, # DIAL doesn't expose thinking mode
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Gemini 2.5 Pro with Google Search via DIAL",
|
||||
aliases=["gemini-2.5-pro-search"],
|
||||
),
|
||||
"gemini-2.5-pro-preview-05-06": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="gemini-2.5-pro-preview-05-06",
|
||||
friendly_name="DIAL (Gemini 2.5 Pro)",
|
||||
context_window=1_000_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Gemini 2.5 Pro via DIAL - Deep reasoning",
|
||||
aliases=["gemini-2.5-pro"],
|
||||
),
|
||||
"gemini-2.5-flash-preview-05-20": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="gemini-2.5-flash-preview-05-20",
|
||||
friendly_name="DIAL (Gemini Flash 2.5)",
|
||||
context_window=1_000_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Gemini 2.5 Flash via DIAL - Ultra-fast",
|
||||
aliases=["gemini-2.5-flash"],
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
@@ -181,20 +279,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
||||
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
config = self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name=resolved_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
context_window=config["context_window"],
|
||||
supports_extended_thinking=config["supports_extended_thinking"],
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_images=config.get("supports_vision", False),
|
||||
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
|
||||
)
|
||||
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
|
||||
return self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
@@ -211,7 +297,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
"""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
return False
|
||||
|
||||
# Check against base class allowed_models if configured
|
||||
@@ -231,20 +317,6 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
|
||||
return True
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name.
|
||||
|
||||
Args:
|
||||
model_name: Model name or shorthand
|
||||
|
||||
Returns:
|
||||
Full model name
|
||||
"""
|
||||
shorthand_value = self.SUPPORTED_MODELS.get(model_name)
|
||||
if isinstance(shorthand_value, str):
|
||||
return shorthand_value
|
||||
return model_name
|
||||
|
||||
def _get_deployment_client(self, deployment: str):
|
||||
"""Get or create a cached client for a specific deployment.
|
||||
|
||||
@@ -357,7 +429,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
# Check model capabilities
|
||||
try:
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
supports_temperature = getattr(capabilities, "supports_temperature", True)
|
||||
supports_temperature = capabilities.supports_temperature
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to check temperature support for {model_name}: {e}")
|
||||
supports_temperature = True
|
||||
@@ -441,63 +513,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
"""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
|
||||
return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
|
||||
if resolved_name in self.SUPPORTED_MODELS:
|
||||
return self.SUPPORTED_MODELS[resolved_name].supports_images
|
||||
|
||||
# Fall back to parent implementation for unknown models
|
||||
return super()._supports_vision(model_name)
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
# Get all model keys (both full names and aliases)
|
||||
all_models = list(self.SUPPORTED_MODELS.keys())
|
||||
|
||||
if not respect_restrictions:
|
||||
return all_models
|
||||
|
||||
# Apply restrictions if configured
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
|
||||
# Filter based on restrictions
|
||||
allowed_models = []
|
||||
for model in all_models:
|
||||
resolved_name = self._resolve_model_name(model)
|
||||
if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
|
||||
allowed_models.append(model)
|
||||
|
||||
return allowed_models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
This is used for validation purposes to ensure restriction policies
|
||||
can validate against both aliases and their target model names.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
# Collect all unique model names (both aliases and targets)
|
||||
all_models = set()
|
||||
|
||||
for key, value in self.SUPPORTED_MODELS.items():
|
||||
# Add the key (could be alias or full name)
|
||||
all_models.add(key)
|
||||
|
||||
# If it's an alias (string value), add the target too
|
||||
if isinstance(value, str):
|
||||
all_models.add(value)
|
||||
|
||||
return sorted(all_models)
|
||||
|
||||
def close(self):
|
||||
"""Clean up HTTP clients when provider is closed."""
|
||||
logger.info("Closing DIAL provider HTTP clients...")
|
||||
|
||||
Reference in New Issue
Block a user