Use ModelCapabilities consistently instead of dictionaries

Moved aliases as part of SUPPORTED_MODELS instead of shorthand, more in line with how custom_models are declared
Further refactoring to cleanup some code
This commit is contained in:
Fahad
2025-06-23 16:58:59 +04:00
parent e94c028a3f
commit 498ea88293
16 changed files with 850 additions and 605 deletions

View File

@@ -10,7 +10,7 @@ from .base import (
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
create_temperature_constraint,
)
from .openai_compatible import OpenAICompatibleProvider
@@ -30,63 +30,161 @@ class DIALModelProvider(OpenAICompatibleProvider):
MAX_RETRIES = 4
RETRY_DELAYS = [1, 3, 5, 8] # seconds
# Supported DIAL models (these can be customized based on your DIAL deployment)
# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {
"o3-2025-04-16": {
"context_window": 200_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"o4-mini-2025-04-16": {
"context_window": 200_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"anthropic.claude-sonnet-4-20250514-v1:0": {
"context_window": 200_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
"context_window": 200_000,
"supports_extended_thinking": True, # Thinking mode variant
"supports_vision": True,
},
"anthropic.claude-opus-4-20250514-v1:0": {
"context_window": 200_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
"context_window": 200_000,
"supports_extended_thinking": True, # Thinking mode variant
"supports_vision": True,
},
"gemini-2.5-pro-preview-03-25-google-search": {
"context_window": 1_000_000,
"supports_extended_thinking": False, # DIAL doesn't expose thinking mode
"supports_vision": True,
},
"gemini-2.5-pro-preview-05-06": {
"context_window": 1_000_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
"gemini-2.5-flash-preview-05-20": {
"context_window": 1_000_000,
"supports_extended_thinking": False,
"supports_vision": True,
},
# Shorthands
"o3": "o3-2025-04-16",
"o4-mini": "o4-mini-2025-04-16",
"sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
"sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
"opus-4": "anthropic.claude-opus-4-20250514-v1:0",
"opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
"gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
"gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
"gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
"o3-2025-04-16": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="o3-2025-04-16",
friendly_name="DIAL (O3)",
context_window=200_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=False, # O3 models don't accept temperature
temperature_constraint=create_temperature_constraint("fixed"),
description="OpenAI O3 via DIAL - Strong reasoning model",
aliases=["o3"],
),
"o4-mini-2025-04-16": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="o4-mini-2025-04-16",
friendly_name="DIAL (O4-mini)",
context_window=200_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=False, # O4 models don't accept temperature
temperature_constraint=create_temperature_constraint("fixed"),
description="OpenAI O4-mini via DIAL - Fast reasoning model",
aliases=["o4-mini"],
),
"anthropic.claude-sonnet-4-20250514-v1:0": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
friendly_name="DIAL (Sonnet 4)",
context_window=200_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # Claude doesn't have function calling
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Claude Sonnet 4 via DIAL - Balanced performance",
aliases=["sonnet-4"],
),
"anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
friendly_name="DIAL (Sonnet 4 Thinking)",
context_window=200_000,
supports_extended_thinking=True, # Thinking mode variant
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # Claude doesn't have function calling
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Claude Sonnet 4 with thinking mode via DIAL",
aliases=["sonnet-4-thinking"],
),
"anthropic.claude-opus-4-20250514-v1:0": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-opus-4-20250514-v1:0",
friendly_name="DIAL (Opus 4)",
context_window=200_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # Claude doesn't have function calling
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Claude Opus 4 via DIAL - Most capable Claude model",
aliases=["opus-4"],
),
"anthropic.claude-opus-4-20250514-v1:0-with-thinking": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking",
friendly_name="DIAL (Opus 4 Thinking)",
context_window=200_000,
supports_extended_thinking=True, # Thinking mode variant
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # Claude doesn't have function calling
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Claude Opus 4 with thinking mode via DIAL",
aliases=["opus-4-thinking"],
),
"gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-pro-preview-03-25-google-search",
friendly_name="DIAL (Gemini 2.5 Pro Search)",
context_window=1_000_000,
supports_extended_thinking=False, # DIAL doesn't expose thinking mode
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Gemini 2.5 Pro with Google Search via DIAL",
aliases=["gemini-2.5-pro-search"],
),
"gemini-2.5-pro-preview-05-06": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-pro-preview-05-06",
friendly_name="DIAL (Gemini 2.5 Pro)",
context_window=1_000_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Gemini 2.5 Pro via DIAL - Deep reasoning",
aliases=["gemini-2.5-pro"],
),
"gemini-2.5-flash-preview-05-20": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-flash-preview-05-20",
friendly_name="DIAL (Gemini Flash 2.5)",
context_window=1_000_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Gemini 2.5 Flash via DIAL - Ultra-fast",
aliases=["gemini-2.5-flash"],
),
}
def __init__(self, api_key: str, **kwargs):
@@ -181,20 +279,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
config = self.SUPPORTED_MODELS[resolved_name]
return ModelCapabilities(
provider=ProviderType.DIAL,
model_name=resolved_name,
friendly_name=self.FRIENDLY_NAME,
context_window=config["context_window"],
supports_extended_thinking=config["supports_extended_thinking"],
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_images=config.get("supports_vision", False),
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
)
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
return self.SUPPORTED_MODELS[resolved_name]
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
@@ -211,7 +297,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
if resolved_name not in self.SUPPORTED_MODELS:
return False
# Check against base class allowed_models if configured
@@ -231,20 +317,6 @@ class DIALModelProvider(OpenAICompatibleProvider):
return True
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name.
Args:
model_name: Model name or shorthand
Returns:
Full model name
"""
shorthand_value = self.SUPPORTED_MODELS.get(model_name)
if isinstance(shorthand_value, str):
return shorthand_value
return model_name
def _get_deployment_client(self, deployment: str):
"""Get or create a cached client for a specific deployment.
@@ -357,7 +429,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
# Check model capabilities
try:
capabilities = self.get_capabilities(model_name)
supports_temperature = getattr(capabilities, "supports_temperature", True)
supports_temperature = capabilities.supports_temperature
except Exception as e:
logger.debug(f"Failed to check temperature support for {model_name}: {e}")
supports_temperature = True
@@ -441,63 +513,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
"""
resolved_name = self._resolve_model_name(model_name)
if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
if resolved_name in self.SUPPORTED_MODELS:
return self.SUPPORTED_MODELS[resolved_name].supports_images
# Fall back to parent implementation for unknown models
return super()._supports_vision(model_name)
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
Returns:
List of model names available from this provider
"""
# Get all model keys (both full names and aliases)
all_models = list(self.SUPPORTED_MODELS.keys())
if not respect_restrictions:
return all_models
# Apply restrictions if configured
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
# Filter based on restrictions
allowed_models = []
for model in all_models:
resolved_name = self._resolve_model_name(model)
if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
allowed_models.append(model)
return allowed_models
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
This is used for validation purposes to ensure restriction policies
can validate against both aliases and their target model names.
Returns:
List of all model names and alias targets known by this provider
"""
# Collect all unique model names (both aliases and targets)
all_models = set()
for key, value in self.SUPPORTED_MODELS.items():
# Add the key (could be alias or full name)
all_models.add(key)
# If it's an alias (string value), add the target too
if isinstance(value, str):
all_models.add(value)
return sorted(all_models)
def close(self):
"""Clean up HTTP clients when provider is closed."""
logger.info("Closing DIAL provider HTTP clients...")