Use ModelCapabilities consistently instead of dictionaries
Moved aliases into SUPPORTED_MODELS as part of each model's ModelCapabilities entry instead of separate shorthand keys, more in line with how custom_models are declared. Further refactoring to clean up some code.
@@ -140,6 +140,19 @@ class ModelCapabilities:
    max_image_size_mb: float = 0.0  # Maximum total size for all images in MB
    supports_temperature: bool = True  # Whether model accepts temperature parameter in API calls

    # Additional fields for comprehensive model information
    description: str = ""  # Human-readable description of the model
    aliases: list[str] = field(default_factory=list)  # Alternative names/shortcuts for the model

    # JSON mode support (for providers that support structured output)
    supports_json_mode: bool = False

    # Thinking mode support (for models with thinking capabilities)
    max_thinking_tokens: int = 0  # Maximum thinking tokens for extended reasoning models

    # Custom model flag (for models that only work with custom endpoints)
    is_custom: bool = False  # Whether this model requires custom API endpoints

    # Temperature constraint object - preferred way to define temperature limits
    temperature_constraint: TemperatureConstraint = field(
        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
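For illustration, a minimal sketch of what this refactor changes in a provider's declarations (the model and alias names below are placeholders invented for this example, not taken from the diff):

# Before: plain capability dicts, with aliases as separate string-valued keys
SUPPORTED_MODELS = {
    "example-model-2025-01-01": {"context_window": 100_000},
    "example": "example-model-2025-01-01",  # shorthand lived as a string entry
}

# After: one ModelCapabilities object per model, aliases declared inline
SUPPORTED_MODELS = {
    "example-model-2025-01-01": ModelCapabilities(
        provider=ProviderType.CUSTOM,  # placeholder provider type
        model_name="example-model-2025-01-01",
        friendly_name="Example",
        context_window=100_000,
        aliases=["example"],
    ),
}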
@@ -251,7 +264,7 @@ class ModelProvider(ABC):
        capabilities = self.get_capabilities(model_name)

        # Check if model supports temperature at all
        if hasattr(capabilities, "supports_temperature") and not capabilities.supports_temperature:
        if not capabilities.supports_temperature:
            return None

        # Get temperature range
@@ -290,19 +303,109 @@ class ModelProvider(ABC):
        """Check if the model supports extended thinking mode."""
        pass

    @abstractmethod
    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
        """Get model configurations for this provider.

        This is a hook method that subclasses can override to provide
        their model configurations from different sources.

        Returns:
            Dictionary mapping model names to their ModelCapabilities objects
        """
        # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
        if hasattr(self, "SUPPORTED_MODELS"):
            return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
        return {}

    def get_all_model_aliases(self) -> dict[str, list[str]]:
        """Get all model aliases for this provider.

        This is a hook method that subclasses can override to provide
        aliases from different sources.

        Returns:
            Dictionary mapping model names to their list of aliases
        """
        # Default implementation extracts from ModelCapabilities objects
        aliases = {}
        for model_name, capabilities in self.get_model_configurations().items():
            if capabilities.aliases:
                aliases[model_name] = capabilities.aliases
        return aliases

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name.

        This implementation uses the hook methods to support different
        model configuration sources.

        Args:
            model_name: Model name that may be an alias

        Returns:
            Resolved model name
        """
        # Get model configurations from the hook method
        model_configs = self.get_model_configurations()

        # First check if it's already a base model name (case-sensitive exact match)
        if model_name in model_configs:
            return model_name

        # Check case-insensitively for both base models and aliases
        model_name_lower = model_name.lower()

        # Check base model names case-insensitively
        for base_model in model_configs:
            if base_model.lower() == model_name_lower:
                return base_model

        # Check aliases from the hook method
        all_aliases = self.get_all_model_aliases()
        for base_model, aliases in all_aliases.items():
            if any(alias.lower() == model_name_lower for alias in aliases):
                return base_model

        # If not found, return as-is
        return model_name

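A quick sketch of the resolution behavior this gives every provider (the model and alias names are hypothetical, assuming a configuration with "example-model-2025-01-01" declaring aliases=["example"]):

# provider._resolve_model_name("example-model-2025-01-01")  # exact match -> unchanged
# provider._resolve_model_name("EXAMPLE")        # alias, case-insensitive -> "example-model-2025-01-01"
# provider._resolve_model_name("unknown-model")  # not found -> returned as-is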
    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        This implementation uses the get_model_configurations() hook
        to support different model configuration sources.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        pass
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service() if respect_restrictions else None
        models = []

        # Get model configurations from the hook method
        model_configs = self.get_model_configurations()

        for model_name in model_configs:
            # Check restrictions if enabled
            if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
                continue

            # Add the base model
            models.append(model_name)

        # Get aliases from the hook method
        all_aliases = self.get_all_model_aliases()
        for model_name, aliases in all_aliases.items():
            # Only add aliases for models that passed restriction check
            if model_name in models:
                models.extend(aliases)

        return models

    @abstractmethod
    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

@@ -312,21 +415,22 @@ class ModelProvider(ABC):
        Returns:
            List of all model names and alias targets known by this provider
        """
        pass
        all_models = set()

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name.
        # Get model configurations from the hook method
        model_configs = self.get_model_configurations()

        Base implementation returns the model name unchanged.
        Subclasses should override to provide alias resolution.
        # Add all base model names
        for model_name in model_configs:
            all_models.add(model_name.lower())

        Args:
            model_name: Model name that may be an alias
        # Get aliases from the hook method and add them
        all_aliases = self.get_all_model_aliases()
        for _model_name, aliases in all_aliases.items():
            for alias in aliases:
                all_models.add(alias.lower())

        Returns:
            Resolved model name
        """
        return model_name
        return list(all_models)

    def close(self):
        """Clean up any resources held by the provider.

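Taken together, the hooks mean a minimal provider can be sketched roughly like this (a hypothetical subclass, not part of this diff; ExampleProvider and its model entry are invented for illustration):

class ExampleProvider(ModelProvider):
    SUPPORTED_MODELS = {
        "example-model": ModelCapabilities(
            provider=ProviderType.CUSTOM,  # placeholder
            model_name="example-model",
            friendly_name="Example",
            context_window=32_000,
            aliases=["ex"],
        ),
    }
    # get_model_configurations(), get_all_model_aliases(), _resolve_model_name()
    # and list_models() are all inherited from ModelProvider and driven by the
    # ModelCapabilities entries above.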
@@ -268,65 +268,55 @@ class CustomProvider(OpenAICompatibleProvider):
    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode.

        Most custom/local models don't support extended thinking.

        Args:
            model_name: Model to check

        Returns:
            False (custom models generally don't support thinking mode)
            True if model supports thinking mode, False otherwise
        """
        # Check if model is in registry
        config = self._registry.resolve(model_name) if self._registry else None
        if config and config.is_custom:
            # Trust the config from custom_models.json
            return config.supports_extended_thinking

        # Default to False for unknown models
        return False

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.
    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
        """Get model configurations from the registry.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.
        For CustomProvider, we convert registry configurations to ModelCapabilities objects.

        Returns:
            List of model names available from this provider
            Dictionary mapping model names to their ModelCapabilities objects
        """
        from utils.model_restrictions import get_restriction_service
        from .base import ProviderType

        restriction_service = get_restriction_service() if respect_restrictions else None
        models = []
        configs = {}

        if self._registry:
            # Get all models from the registry
            all_models = self._registry.list_models()
            aliases = self._registry.list_aliases()

            # Add models that are validated by the custom provider
            for model_name in all_models + aliases:
                # Use the provider's validation logic to determine if this model
                # is appropriate for the custom endpoint
            # Get all models from registry
            for model_name in self._registry.list_models():
                # Only include custom models that this provider validates
                if self.validate_model_name(model_name):
                    # Check restrictions if enabled
                    if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
                        continue
                    config = self._registry.resolve(model_name)
                    if config and config.is_custom:
                        # Convert OpenRouterModelConfig to ModelCapabilities
                        capabilities = config.to_capabilities()
                        # Override provider type to CUSTOM for local models
                        capabilities.provider = ProviderType.CUSTOM
                        capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
                        configs[model_name] = capabilities

                    models.append(model_name)
            return configs

        return models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.
    def get_all_model_aliases(self) -> dict[str, list[str]]:
        """Get all model aliases from the registry.

        Returns:
            List of all model names and alias targets known by this provider
            Dictionary mapping model names to their list of aliases
        """
        all_models = set()

        if self._registry:
            # Get all models and aliases from the registry
            all_models.update(model.lower() for model in self._registry.list_models())
            all_models.update(alias.lower() for alias in self._registry.list_aliases())

            # For each alias, also add its target
            for alias in self._registry.list_aliases():
                config = self._registry.resolve(alias)
                if config:
                    all_models.add(config.model_name.lower())

            return list(all_models)
        # Since aliases are now included in the configurations,
        # we can use the base class implementation
        return super().get_all_model_aliases()

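The conversion above relies on the registry config's to_capabilities() helper, whose implementation is outside this diff. A rough sketch of what such a conversion presumably does (the registry-side field names here are assumptions, not confirmed by the diff):

def to_capabilities(self) -> ModelCapabilities:
    # Assumed mapping from a registry entry (e.g. a custom_models.json /
    # OpenRouterModelConfig record) onto the shared ModelCapabilities dataclass.
    return ModelCapabilities(
        provider=ProviderType.OPENROUTER,  # overridden by the caller as needed
        model_name=self.model_name,
        friendly_name=self.model_name,
        context_window=self.context_window,
        supports_extended_thinking=self.supports_extended_thinking,
        aliases=list(self.aliases),
        is_custom=self.is_custom,
    )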
@@ -10,7 +10,7 @@ from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
    create_temperature_constraint,
)
from .openai_compatible import OpenAICompatibleProvider

@@ -30,63 +30,161 @@ class DIALModelProvider(OpenAICompatibleProvider):
    MAX_RETRIES = 4
    RETRY_DELAYS = [1, 3, 5, 8]  # seconds

    # Supported DIAL models (these can be customized based on your DIAL deployment)
    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
        "o3-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "o4-mini-2025-04-16": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0": {
            "context_window": 200_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
            "context_window": 200_000,
            "supports_extended_thinking": True,  # Thinking mode variant
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-03-25-google-search": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,  # DIAL doesn't expose thinking mode
            "supports_vision": True,
        },
        "gemini-2.5-pro-preview-05-06": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        "gemini-2.5-flash-preview-05-20": {
            "context_window": 1_000_000,
            "supports_extended_thinking": False,
            "supports_vision": True,
        },
        # Shorthands
        "o3": "o3-2025-04-16",
        "o4-mini": "o4-mini-2025-04-16",
        "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
        "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
        "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
        "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
        "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
        "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
        "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
        "o3-2025-04-16": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="o3-2025-04-16",
            friendly_name="DIAL (O3)",
            context_window=200_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # DIAL may not expose function calling
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=False,  # O3 models don't accept temperature
            temperature_constraint=create_temperature_constraint("fixed"),
            description="OpenAI O3 via DIAL - Strong reasoning model",
            aliases=["o3"],
        ),
        "o4-mini-2025-04-16": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="o4-mini-2025-04-16",
            friendly_name="DIAL (O4-mini)",
            context_window=200_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # DIAL may not expose function calling
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=False,  # O4 models don't accept temperature
            temperature_constraint=create_temperature_constraint("fixed"),
            description="OpenAI O4-mini via DIAL - Fast reasoning model",
            aliases=["o4-mini"],
        ),
        "anthropic.claude-sonnet-4-20250514-v1:0": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="anthropic.claude-sonnet-4-20250514-v1:0",
            friendly_name="DIAL (Sonnet 4)",
            context_window=200_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # Claude doesn't have function calling
            supports_json_mode=False,  # Claude doesn't have JSON mode
            supports_images=True,
            max_image_size_mb=5.0,
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="Claude Sonnet 4 via DIAL - Balanced performance",
            aliases=["sonnet-4"],
        ),
        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
            friendly_name="DIAL (Sonnet 4 Thinking)",
            context_window=200_000,
            supports_extended_thinking=True,  # Thinking mode variant
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # Claude doesn't have function calling
            supports_json_mode=False,  # Claude doesn't have JSON mode
            supports_images=True,
            max_image_size_mb=5.0,
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="Claude Sonnet 4 with thinking mode via DIAL",
            aliases=["sonnet-4-thinking"],
        ),
        "anthropic.claude-opus-4-20250514-v1:0": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="anthropic.claude-opus-4-20250514-v1:0",
            friendly_name="DIAL (Opus 4)",
            context_window=200_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # Claude doesn't have function calling
            supports_json_mode=False,  # Claude doesn't have JSON mode
            supports_images=True,
            max_image_size_mb=5.0,
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="Claude Opus 4 via DIAL - Most capable Claude model",
            aliases=["opus-4"],
        ),
        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking",
            friendly_name="DIAL (Opus 4 Thinking)",
            context_window=200_000,
            supports_extended_thinking=True,  # Thinking mode variant
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # Claude doesn't have function calling
            supports_json_mode=False,  # Claude doesn't have JSON mode
            supports_images=True,
            max_image_size_mb=5.0,
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="Claude Opus 4 with thinking mode via DIAL",
            aliases=["opus-4-thinking"],
        ),
        "gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="gemini-2.5-pro-preview-03-25-google-search",
            friendly_name="DIAL (Gemini 2.5 Pro Search)",
            context_window=1_000_000,
            supports_extended_thinking=False,  # DIAL doesn't expose thinking mode
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # DIAL may not expose function calling
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="Gemini 2.5 Pro with Google Search via DIAL",
            aliases=["gemini-2.5-pro-search"],
        ),
        "gemini-2.5-pro-preview-05-06": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="gemini-2.5-pro-preview-05-06",
            friendly_name="DIAL (Gemini 2.5 Pro)",
            context_window=1_000_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # DIAL may not expose function calling
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="Gemini 2.5 Pro via DIAL - Deep reasoning",
            aliases=["gemini-2.5-pro"],
        ),
        "gemini-2.5-flash-preview-05-20": ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name="gemini-2.5-flash-preview-05-20",
            friendly_name="DIAL (Gemini Flash 2.5)",
            context_window=1_000_000,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=False,  # DIAL may not expose function calling
            supports_json_mode=True,
            supports_images=True,
            max_image_size_mb=20.0,
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="Gemini 2.5 Flash via DIAL - Ultra-fast",
            aliases=["gemini-2.5-flash"],
        ),
    }

    def __init__(self, api_key: str, **kwargs):
@@ -181,20 +279,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        return ModelCapabilities(
            provider=ProviderType.DIAL,
            model_name=resolved_name,
            friendly_name=self.FRIENDLY_NAME,
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_images=config.get("supports_vision", False),
            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
        )
        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
        return self.SUPPORTED_MODELS[resolved_name]

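With the table above, capability lookup becomes a plain dictionary read. A small usage sketch, written as comments because the setup details are assumed rather than shown in this diff:

# dial = DIALModelProvider(api_key="...")    # hypothetical instantiation
# caps = dial.get_capabilities("o3")         # alias resolves to "o3-2025-04-16"
# caps.context_window                        # -> 200_000
# caps.supports_temperature                  # -> False (fixed-temperature reasoning model)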
    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
@@ -211,7 +297,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Check against base class allowed_models if configured
@@ -231,20 +317,6 @@ class DIALModelProvider(OpenAICompatibleProvider):

        return True

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name.

        Args:
            model_name: Model name or shorthand

        Returns:
            Full model name
        """
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name

    def _get_deployment_client(self, deployment: str):
        """Get or create a cached client for a specific deployment.

@@ -357,7 +429,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
        # Check model capabilities
        try:
            capabilities = self.get_capabilities(model_name)
            supports_temperature = getattr(capabilities, "supports_temperature", True)
            supports_temperature = capabilities.supports_temperature
        except Exception as e:
            logger.debug(f"Failed to check temperature support for {model_name}: {e}")
            supports_temperature = True
@@ -441,63 +513,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
        """
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
            return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
        if resolved_name in self.SUPPORTED_MODELS:
            return self.SUPPORTED_MODELS[resolved_name].supports_images

        # Fall back to parent implementation for unknown models
        return super()._supports_vision(model_name)

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        # Get all model keys (both full names and aliases)
        all_models = list(self.SUPPORTED_MODELS.keys())

        if not respect_restrictions:
            return all_models

        # Apply restrictions if configured
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()

        # Filter based on restrictions
        allowed_models = []
        for model in all_models:
            resolved_name = self._resolve_model_name(model)
            if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
                allowed_models.append(model)

        return allowed_models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        This is used for validation purposes to ensure restriction policies
        can validate against both aliases and their target model names.

        Returns:
            List of all model names and alias targets known by this provider
        """
        # Collect all unique model names (both aliases and targets)
        all_models = set()

        for key, value in self.SUPPORTED_MODELS.items():
            # Add the key (could be alias or full name)
            all_models.add(key)

            # If it's an alias (string value), add the target too
            if isinstance(value, str):
                all_models.add(value)

        return sorted(all_models)

    def close(self):
        """Clean up HTTP clients when provider is closed."""
        logger.info("Closing DIAL provider HTTP clients...")

@@ -9,7 +9,7 @@ from typing import Optional
from google import genai
from google.genai import types

from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint

logger = logging.getLogger(__name__)

@@ -17,47 +17,79 @@ logger = logging.getLogger(__name__)
class GeminiModelProvider(ModelProvider):
    """Google Gemini model provider implementation."""

    # Model configurations
    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
        "gemini-2.0-flash": {
            "context_window": 1_048_576,  # 1M tokens
            "supports_extended_thinking": True,  # Experimental thinking mode
            "max_thinking_tokens": 24576,  # Same as 2.5 flash for consistency
            "supports_images": True,  # Vision capability
            "max_image_size_mb": 20.0,  # Conservative 20MB limit for reliability
            "description": "Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
        },
        "gemini-2.0-flash-lite": {
            "context_window": 1_048_576,  # 1M tokens
            "supports_extended_thinking": False,  # Not supported per user request
            "max_thinking_tokens": 0,  # No thinking support
            "supports_images": False,  # Does not support images
            "max_image_size_mb": 0.0,  # No image support
            "description": "Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
        },
        "gemini-2.5-flash": {
            "context_window": 1_048_576,  # 1M tokens
            "supports_extended_thinking": True,
            "max_thinking_tokens": 24576,  # Flash 2.5 thinking budget limit
            "supports_images": True,  # Vision capability
            "max_image_size_mb": 20.0,  # Conservative 20MB limit for reliability
            "description": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
        },
        "gemini-2.5-pro": {
            "context_window": 1_048_576,  # 1M tokens
            "supports_extended_thinking": True,
            "max_thinking_tokens": 32768,  # Pro 2.5 thinking budget limit
            "supports_images": True,  # Vision capability
            "max_image_size_mb": 32.0,  # Higher limit for Pro model
            "description": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
        },
        # Shorthands
        "flash": "gemini-2.5-flash",
        "flash-2.0": "gemini-2.0-flash",
        "flash2": "gemini-2.0-flash",
        "flashlite": "gemini-2.0-flash-lite",
        "flash-lite": "gemini-2.0-flash-lite",
        "pro": "gemini-2.5-pro",
        "gemini-2.0-flash": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.0-flash",
            friendly_name="Gemini (Flash 2.0)",
            context_window=1_048_576,  # 1M tokens
            supports_extended_thinking=True,  # Experimental thinking mode
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # Vision capability
            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            max_thinking_tokens=24576,  # Same as 2.5 flash for consistency
            description="Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
            aliases=["flash-2.0", "flash2"],
        ),
        "gemini-2.0-flash-lite": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.0-flash-lite",
            friendly_name="Gemini (Flash Lite 2.0)",
            context_window=1_048_576,  # 1M tokens
            supports_extended_thinking=False,  # Not supported per user request
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=False,  # Does not support images
            max_image_size_mb=0.0,  # No image support
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
            aliases=["flashlite", "flash-lite"],
        ),
        "gemini-2.5-flash": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.5-flash",
            friendly_name="Gemini (Flash 2.5)",
            context_window=1_048_576,  # 1M tokens
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # Vision capability
            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            max_thinking_tokens=24576,  # Flash 2.5 thinking budget limit
            description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
            aliases=["flash", "flash2.5"],
        ),
        "gemini-2.5-pro": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.5-pro",
            friendly_name="Gemini (Pro 2.5)",
            context_window=1_048_576,  # 1M tokens
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # Vision capability
            max_image_size_mb=32.0,  # Higher limit for Pro model
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            max_thinking_tokens=32768,  # Max thinking tokens for Pro model
            description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
            aliases=["pro", "gemini pro", "gemini-pro"],
        ),
    }

    # Thinking mode configurations - percentages of model's max_thinking_tokens
@@ -70,6 +102,14 @@ class GeminiModelProvider(ModelProvider):
        "max": 1.0,  # 100% of max - full thinking budget
    }

    # Model-specific thinking token limits
    MAX_THINKING_TOKENS = {
        "gemini-2.0-flash": 24576,  # Same as 2.5 flash for consistency
        "gemini-2.0-flash-lite": 0,  # No thinking support
        "gemini-2.5-flash": 24576,  # Flash 2.5 thinking budget limit
        "gemini-2.5-pro": 32768,  # Pro 2.5 thinking budget limit
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize Gemini provider with API key."""
        super().__init__(api_key, **kwargs)
@@ -100,25 +140,8 @@ class GeminiModelProvider(ModelProvider):
        if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
            raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        # Gemini models support 0.0-2.0 temperature range
        temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)

        return ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name=resolved_name,
            friendly_name="Gemini",
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_images=config.get("supports_images", False),
            max_image_size_mb=config.get("max_image_size_mb", 0.0),
            supports_temperature=True,  # Gemini models accept temperature parameter
            temperature_constraint=temp_constraint,
        )
        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
        return self.SUPPORTED_MODELS[resolved_name]

    def generate_content(
        self,
@@ -179,8 +202,8 @@ class GeminiModelProvider(ModelProvider):
        if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
            # Get model's max thinking tokens and calculate actual budget
            model_config = self.SUPPORTED_MODELS.get(resolved_name)
            if model_config and "max_thinking_tokens" in model_config:
                max_thinking_tokens = model_config["max_thinking_tokens"]
            if model_config and model_config.max_thinking_tokens > 0:
                max_thinking_tokens = model_config.max_thinking_tokens
                actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
                generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget)

@@ -258,7 +281,7 @@ class GeminiModelProvider(ModelProvider):
        resolved_name = self._resolve_model_name(model_name)

        # First check if model is supported
        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Then check if model is allowed by restrictions
@@ -281,78 +304,20 @@ class GeminiModelProvider(ModelProvider):
    def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
        """Get actual thinking token budget for a model and thinking mode."""
        resolved_name = self._resolve_model_name(model_name)
        model_config = self.SUPPORTED_MODELS.get(resolved_name, {})
        model_config = self.SUPPORTED_MODELS.get(resolved_name)

        if not model_config.get("supports_extended_thinking", False):
        if not model_config or not model_config.supports_extended_thinking:
            return 0

        if thinking_mode not in self.THINKING_BUDGETS:
            return 0

        max_thinking_tokens = model_config.get("max_thinking_tokens", 0)
        max_thinking_tokens = model_config.max_thinking_tokens
        if max_thinking_tokens == 0:
            return 0

        return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])

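A worked example of the budget arithmetic above (only the "max" entry of THINKING_BUDGETS is visible in this diff; the other mode percentages are elided here):

# For gemini-2.5-pro, max_thinking_tokens = 32768, so:
# get_thinking_budget("pro", "max")  -> int(32768 * 1.0) = 32768
# A model without thinking support (e.g. gemini-2.0-flash-lite,
# max_thinking_tokens = 0) always yields 0.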
    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service() if respect_restrictions else None
        models = []

        for model_name, config in self.SUPPORTED_MODELS.items():
            # Handle both base models (dict configs) and aliases (string values)
            if isinstance(config, str):
                # This is an alias - check if the target model would be allowed
                target_model = config
                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
                    continue
                # Allow the alias
                models.append(model_name)
            else:
                # This is a base model with config dict
                # Check restrictions if enabled
                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
                    continue
                models.append(model_name)

        return models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        Returns:
            List of all model names and alias targets known by this provider
        """
        all_models = set()

        for model_name, config in self.SUPPORTED_MODELS.items():
            # Add the model name itself
            all_models.add(model_name.lower())

            # If it's an alias (string value), add the target model too
            if isinstance(config, str):
                all_models.add(config.lower())

        return list(all_models)

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name."""
        # Check if it's a shorthand
        shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from Gemini response."""
        usage = {}

@@ -17,71 +17,110 @@ logger = logging.getLogger(__name__)
class OpenAIModelProvider(OpenAICompatibleProvider):
    """Official OpenAI API provider (api.openai.com)."""

    # Model configurations
    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
        "o3": {
            "context_window": 200_000,  # 200K tokens
            "supports_extended_thinking": False,
            "supports_images": True,  # O3 models support vision
            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
            "supports_temperature": False,  # O3 models don't accept temperature parameter
            "temperature_constraint": "fixed",  # Fixed at 1.0
            "description": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
        },
        "o3-mini": {
            "context_window": 200_000,  # 200K tokens
            "supports_extended_thinking": False,
            "supports_images": True,  # O3 models support vision
            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
            "supports_temperature": False,  # O3 models don't accept temperature parameter
            "temperature_constraint": "fixed",  # Fixed at 1.0
            "description": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
        },
        "o3-pro-2025-06-10": {
            "context_window": 200_000,  # 200K tokens
            "supports_extended_thinking": False,
            "supports_images": True,  # O3 models support vision
            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
            "supports_temperature": False,  # O3 models don't accept temperature parameter
            "temperature_constraint": "fixed",  # Fixed at 1.0
            "description": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
        },
        # Aliases
        "o3-pro": "o3-pro-2025-06-10",
        "o4-mini": {
            "context_window": 200_000,  # 200K tokens
            "supports_extended_thinking": False,
            "supports_images": True,  # O4 models support vision
            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
            "supports_temperature": False,  # O4 models don't accept temperature parameter
            "temperature_constraint": "fixed",  # Fixed at 1.0
            "description": "Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
        },
        "o4-mini-high": {
            "context_window": 200_000,  # 200K tokens
            "supports_extended_thinking": False,
            "supports_images": True,  # O4 models support vision
            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
            "supports_temperature": False,  # O4 models don't accept temperature parameter
            "temperature_constraint": "fixed",  # Fixed at 1.0
            "description": "Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
        },
        "gpt-4.1-2025-04-14": {
            "context_window": 1_000_000,  # 1M tokens
            "supports_extended_thinking": False,
            "supports_images": True,  # GPT-4.1 supports vision
            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
            "supports_temperature": True,  # Regular models accept temperature parameter
            "temperature_constraint": "range",  # 0.0-2.0 range
            "description": "GPT-4.1 (1M context) - Advanced reasoning model with large context window",
        },
        # Shorthands
        "mini": "o4-mini",  # Default 'mini' to latest mini model
        "o3mini": "o3-mini",
        "o4mini": "o4-mini",
        "o4minihigh": "o4-mini-high",
        "o4minihi": "o4-mini-high",
        "gpt4.1": "gpt-4.1-2025-04-14",
        "o3": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="o3",
            friendly_name="OpenAI (O3)",
            context_window=200_000,  # 200K tokens
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # O3 models support vision
            max_image_size_mb=20.0,  # 20MB per OpenAI docs
            supports_temperature=False,  # O3 models don't accept temperature parameter
            temperature_constraint=create_temperature_constraint("fixed"),
            description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
            aliases=[],
        ),
        "o3-mini": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="o3-mini",
            friendly_name="OpenAI (O3-mini)",
            context_window=200_000,  # 200K tokens
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # O3 models support vision
            max_image_size_mb=20.0,  # 20MB per OpenAI docs
            supports_temperature=False,  # O3 models don't accept temperature parameter
            temperature_constraint=create_temperature_constraint("fixed"),
            description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
            aliases=["o3mini", "o3-mini"],
        ),
        "o3-pro-2025-06-10": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="o3-pro-2025-06-10",
            friendly_name="OpenAI (O3-Pro)",
            context_window=200_000,  # 200K tokens
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # O3 models support vision
            max_image_size_mb=20.0,  # 20MB per OpenAI docs
            supports_temperature=False,  # O3 models don't accept temperature parameter
            temperature_constraint=create_temperature_constraint("fixed"),
            description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
            aliases=["o3-pro"],
        ),
        "o4-mini": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="o4-mini",
            friendly_name="OpenAI (O4-mini)",
            context_window=200_000,  # 200K tokens
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # O4 models support vision
            max_image_size_mb=20.0,  # 20MB per OpenAI docs
            supports_temperature=False,  # O4 models don't accept temperature parameter
            temperature_constraint=create_temperature_constraint("fixed"),
            description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
            aliases=["mini", "o4mini"],
        ),
        "o4-mini-high": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="o4-mini-high",
            friendly_name="OpenAI (O4-mini-high)",
            context_window=200_000,  # 200K tokens
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # O4 models support vision
            max_image_size_mb=20.0,  # 20MB per OpenAI docs
            supports_temperature=False,  # O4 models don't accept temperature parameter
            temperature_constraint=create_temperature_constraint("fixed"),
            description="Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
            aliases=["o4minihigh", "o4minihi", "mini-high"],
        ),
        "gpt-4.1-2025-04-14": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="gpt-4.1-2025-04-14",
            friendly_name="OpenAI (GPT 4.1)",
            context_window=1_000_000,  # 1M tokens
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # GPT-4.1 supports vision
            max_image_size_mb=20.0,  # 20MB per OpenAI docs
            supports_temperature=True,  # Regular models accept temperature parameter
            temperature_constraint=create_temperature_constraint("range"),
            description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
            aliases=["gpt4.1"],
        ),
    }

    def __init__(self, api_key: str, **kwargs):
@@ -95,7 +134,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        # Resolve shorthand
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported OpenAI model: {model_name}")

        # Check if model is allowed by restrictions
@@ -105,27 +144,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
            raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        # Get temperature constraints and support from configuration
        supports_temperature = config.get("supports_temperature", True)  # Default to True for backward compatibility
        temp_constraint_type = config.get("temperature_constraint", "range")  # Default to range
        temp_constraint = create_temperature_constraint(temp_constraint_type)

        return ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name=model_name,
            friendly_name="OpenAI",
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_images=config.get("supports_images", False),
            max_image_size_mb=config.get("max_image_size_mb", 0.0),
            supports_temperature=supports_temperature,
            temperature_constraint=temp_constraint,
        )
        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
        return self.SUPPORTED_MODELS[resolved_name]

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
@@ -136,7 +156,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        resolved_name = self._resolve_model_name(model_name)

        # First check if model is supported
        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Then check if model is allowed by restrictions
@@ -177,61 +197,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        # Currently no OpenAI models support extended thinking
        # This may change with future O3 models
        return False

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service() if respect_restrictions else None
        models = []

        for model_name, config in self.SUPPORTED_MODELS.items():
            # Handle both base models (dict configs) and aliases (string values)
            if isinstance(config, str):
                # This is an alias - check if the target model would be allowed
                target_model = config
                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
                    continue
                # Allow the alias
                models.append(model_name)
            else:
                # This is a base model with config dict
                # Check restrictions if enabled
                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
                    continue
                models.append(model_name)

        return models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        Returns:
            List of all model names and alias targets known by this provider
        """
        all_models = set()

        for model_name, config in self.SUPPORTED_MODELS.items():
            # Add the model name itself
            all_models.add(model_name.lower())

            # If it's an alias (string value), add the target model too
            if isinstance(config, str):
                all_models.add(config.lower())

        return list(all_models)

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name."""
        # Check if it's a shorthand
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name

@@ -270,3 +270,39 @@ class OpenRouterProvider(OpenAICompatibleProvider):
                    all_models.add(config.model_name.lower())

        return list(all_models)

    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
        """Get model configurations from the registry.

        For OpenRouter, we convert registry configurations to ModelCapabilities objects.

        Returns:
            Dictionary mapping model names to their ModelCapabilities objects
        """
        configs = {}

        if self._registry:
            # Get all models from registry
            for model_name in self._registry.list_models():
                # Only include models that this provider validates
                if self.validate_model_name(model_name):
                    config = self._registry.resolve(model_name)
                    if config and not config.is_custom:  # Only OpenRouter models, not custom ones
                        # Convert OpenRouterModelConfig to ModelCapabilities
                        capabilities = config.to_capabilities()
                        # Override provider type to OPENROUTER
                        capabilities.provider = ProviderType.OPENROUTER
                        capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
                        configs[model_name] = capabilities

        return configs

    def get_all_model_aliases(self) -> dict[str, list[str]]:
        """Get all model aliases from the registry.

        Returns:
            Dictionary mapping model names to their list of aliases
        """
        # Since aliases are now included in the configurations,
        # we can use the base class implementation
        return super().get_all_model_aliases()

providers/xai.py
@@ -7,7 +7,7 @@ from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
    create_temperature_constraint,
)
from .openai_compatible import OpenAICompatibleProvider

@@ -19,23 +19,42 @@ class XAIModelProvider(OpenAICompatibleProvider):

    FRIENDLY_NAME = "X.AI"

    # Model configurations
    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
        "grok-3": {
            "context_window": 131_072,  # 131K tokens
            "supports_extended_thinking": False,
            "description": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
        },
        "grok-3-fast": {
            "context_window": 131_072,  # 131K tokens
            "supports_extended_thinking": False,
            "description": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
        },
        # Shorthands for convenience
        "grok": "grok-3",  # Default to grok-3
        "grok3": "grok-3",
        "grok3fast": "grok-3-fast",
        "grokfast": "grok-3-fast",
        "grok-3": ModelCapabilities(
            provider=ProviderType.XAI,
            model_name="grok-3",
            friendly_name="X.AI (Grok 3)",
            context_window=131_072,  # 131K tokens
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
            supports_images=False,  # Assuming GROK is text-only for now
            max_image_size_mb=0.0,
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
            aliases=["grok", "grok3"],
        ),
        "grok-3-fast": ModelCapabilities(
            provider=ProviderType.XAI,
            model_name="grok-3-fast",
            friendly_name="X.AI (Grok 3 Fast)",
            context_window=131_072,  # 131K tokens
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
            supports_images=False,  # Assuming GROK is text-only for now
            max_image_size_mb=0.0,
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
            aliases=["grok3fast", "grokfast", "grok3-fast"],
        ),
    }

    def __init__(self, api_key: str, **kwargs):
@@ -49,7 +68,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
        # Resolve shorthand
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported X.AI model: {model_name}")

        # Check if model is allowed by restrictions
@@ -59,23 +78,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
        if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
            raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")

        config = self.SUPPORTED_MODELS[resolved_name]

        # Define temperature constraints for GROK models
        # GROK supports the standard OpenAI temperature range
        temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)

        return ModelCapabilities(
            provider=ProviderType.XAI,
            model_name=resolved_name,
            friendly_name=self.FRIENDLY_NAME,
            context_window=config["context_window"],
            supports_extended_thinking=config["supports_extended_thinking"],
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            temperature_constraint=temp_constraint,
        )
        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
        return self.SUPPORTED_MODELS[resolved_name]

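As with the other providers, alias lookup now flows through the shared base-class resolution. A small illustrative sketch, in comments because the instantiation details are assumed rather than shown here:

# xai = XAIModelProvider(api_key="...")             # hypothetical setup
# xai.get_capabilities("grok").model_name           # -> "grok-3"
# xai.get_capabilities("grok3fast").friendly_name   # -> "X.AI (Grok 3 Fast)"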
    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
@@ -86,7 +90,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
        resolved_name = self._resolve_model_name(model_name)

        # First check if model is supported
        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Then check if model is allowed by restrictions
@@ -127,61 +131,3 @@ class XAIModelProvider(OpenAICompatibleProvider):
        # Currently GROK models do not support extended thinking
        # This may change with future GROK model releases
        return False

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service() if respect_restrictions else None
        models = []

        for model_name, config in self.SUPPORTED_MODELS.items():
            # Handle both base models (dict configs) and aliases (string values)
            if isinstance(config, str):
                # This is an alias - check if the target model would be allowed
                target_model = config
                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
                    continue
                # Allow the alias
                models.append(model_name)
            else:
                # This is a base model with config dict
                # Check restrictions if enabled
                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
                    continue
                models.append(model_name)

        return models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        Returns:
            List of all model names and alias targets known by this provider
        """
        all_models = set()

        for model_name, config in self.SUPPORTED_MODELS.items():
            # Add the model name itself
            all_models.add(model_name.lower())

            # If it's an alias (string value), add the target model too
            if isinstance(config, str):
                all_models.add(config.lower())

        return list(all_models)

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name."""
        # Check if it's a shorthand
        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
        if isinstance(shorthand_value, str):
            return shorthand_value
        return model_name