Use ModelCapabilities consistently instead of dictionaries
Moved aliases into SUPPORTED_MODELS entries instead of separate shorthand mappings, more in line with how custom_models are declared. Further refactoring to clean up some code.
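For orientation before the diff: the commit replaces per-provider config dicts (plus string-valued shorthand keys mixed into the same mapping) with typed ModelCapabilities objects that carry their own aliases. A minimal sketch of the two styles, using a trimmed stand-in for the real dataclass defined in the providers base module:

    # Trimmed stand-in for illustration only; the real ModelCapabilities has many more fields.
    from dataclasses import dataclass, field

    @dataclass
    class ModelCapabilities:
        model_name: str = ""
        context_window: int = 0
        aliases: list[str] = field(default_factory=list)

    # Before: plain dicts, with aliases as separate string-valued keys
    OLD_SUPPORTED_MODELS = {
        "o3-2025-04-16": {"context_window": 200_000},
        "o3": "o3-2025-04-16",  # shorthand entry mixed into the same dict
    }

    # After: one typed object per model, aliases carried on the object itself
    NEW_SUPPORTED_MODELS = {
        "o3-2025-04-16": ModelCapabilities(
            model_name="o3-2025-04-16",
            context_window=200_000,
            aliases=["o3"],
        ),
    }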
@@ -14,7 +14,7 @@ import os
 # These values are used in server responses and for tracking releases
 # IMPORTANT: This is the single source of truth for version and author info
 # Semantic versioning: MAJOR.MINOR.PATCH
-__version__ = "5.6.2"
+__version__ = "5.7.0"
 # Last update date in ISO format
 __updated__ = "2025-06-23"
 # Primary maintainer
@@ -140,6 +140,19 @@ class ModelCapabilities:
     max_image_size_mb: float = 0.0  # Maximum total size for all images in MB
     supports_temperature: bool = True  # Whether model accepts temperature parameter in API calls

+    # Additional fields for comprehensive model information
+    description: str = ""  # Human-readable description of the model
+    aliases: list[str] = field(default_factory=list)  # Alternative names/shortcuts for the model
+
+    # JSON mode support (for providers that support structured output)
+    supports_json_mode: bool = False
+
+    # Thinking mode support (for models with thinking capabilities)
+    max_thinking_tokens: int = 0  # Maximum thinking tokens for extended reasoning models
+
+    # Custom model flag (for models that only work with custom endpoints)
+    is_custom: bool = False  # Whether this model requires custom API endpoints
+
     # Temperature constraint object - preferred way to define temperature limits
     temperature_constraint: TemperatureConstraint = field(
         default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
@@ -251,7 +264,7 @@ class ModelProvider(ABC):
         capabilities = self.get_capabilities(model_name)

         # Check if model supports temperature at all
-        if hasattr(capabilities, "supports_temperature") and not capabilities.supports_temperature:
+        if not capabilities.supports_temperature:
            return None

         # Get temperature range
@@ -290,19 +303,109 @@ class ModelProvider(ABC):
         """Check if the model supports extended thinking mode."""
         pass

-    @abstractmethod
+    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
+        """Get model configurations for this provider.
+
+        This is a hook method that subclasses can override to provide
+        their model configurations from different sources.
+
+        Returns:
+            Dictionary mapping model names to their ModelCapabilities objects
+        """
+        # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
+        if hasattr(self, "SUPPORTED_MODELS"):
+            return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
+        return {}
+
+    def get_all_model_aliases(self) -> dict[str, list[str]]:
+        """Get all model aliases for this provider.
+
+        This is a hook method that subclasses can override to provide
+        aliases from different sources.
+
+        Returns:
+            Dictionary mapping model names to their list of aliases
+        """
+        # Default implementation extracts from ModelCapabilities objects
+        aliases = {}
+        for model_name, capabilities in self.get_model_configurations().items():
+            if capabilities.aliases:
+                aliases[model_name] = capabilities.aliases
+        return aliases
+
+    def _resolve_model_name(self, model_name: str) -> str:
+        """Resolve model shorthand to full name.
+
+        This implementation uses the hook methods to support different
+        model configuration sources.
+
+        Args:
+            model_name: Model name that may be an alias
+
+        Returns:
+            Resolved model name
+        """
+        # Get model configurations from the hook method
+        model_configs = self.get_model_configurations()
+
+        # First check if it's already a base model name (case-sensitive exact match)
+        if model_name in model_configs:
+            return model_name
+
+        # Check case-insensitively for both base models and aliases
+        model_name_lower = model_name.lower()
+
+        # Check base model names case-insensitively
+        for base_model in model_configs:
+            if base_model.lower() == model_name_lower:
+                return base_model
+
+        # Check aliases from the hook method
+        all_aliases = self.get_all_model_aliases()
+        for base_model, aliases in all_aliases.items():
+            if any(alias.lower() == model_name_lower for alias in aliases):
+                return base_model
+
+        # If not found, return as-is
+        return model_name
+
     def list_models(self, respect_restrictions: bool = True) -> list[str]:
         """Return a list of model names supported by this provider.

+        This implementation uses the get_model_configurations() hook
+        to support different model configuration sources.
+
         Args:
             respect_restrictions: Whether to apply provider-specific restriction logic.

         Returns:
             List of model names available from this provider
         """
-        pass
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service() if respect_restrictions else None
+        models = []
+
+        # Get model configurations from the hook method
+        model_configs = self.get_model_configurations()
+
+        for model_name in model_configs:
+            # Check restrictions if enabled
+            if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
+                continue
+
+            # Add the base model
+            models.append(model_name)
+
+        # Get aliases from the hook method
+        all_aliases = self.get_all_model_aliases()
+        for model_name, aliases in all_aliases.items():
+            # Only add aliases for models that passed restriction check
+            if model_name in models:
+                models.extend(aliases)
+
+        return models
+
-    @abstractmethod
     def list_all_known_models(self) -> list[str]:
         """Return all model names known by this provider, including alias targets.

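A toy walkthrough of the resolution flow added above — a stand-alone re-implementation of the same logic, for illustration only:

    from dataclasses import dataclass, field

    @dataclass
    class Caps:  # minimal stand-in for ModelCapabilities
        aliases: list[str] = field(default_factory=list)

    SUPPORTED_MODELS = {"gemini-2.5-pro": Caps(aliases=["pro"])}

    def resolve(name: str) -> str:
        if name in SUPPORTED_MODELS:                 # exact match first
            return name
        lower = name.lower()
        for base in SUPPORTED_MODELS:                # case-insensitive base names
            if base.lower() == lower:
                return base
        for base, caps in SUPPORTED_MODELS.items():  # then aliases, case-insensitively
            if any(a.lower() == lower for a in caps.aliases):
                return base
        return name                                  # unknown names pass through

    assert resolve("PRO") == "gemini-2.5-pro"
    assert resolve("unknown-model") == "unknown-model"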
@@ -312,21 +415,22 @@ class ModelProvider(ABC):
         Returns:
             List of all model names and alias targets known by this provider
         """
-        pass
-
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name.
-
-        Base implementation returns the model name unchanged.
-        Subclasses should override to provide alias resolution.
-
-        Args:
-            model_name: Model name that may be an alias
-
-        Returns:
-            Resolved model name
-        """
-        return model_name
+        all_models = set()
+
+        # Get model configurations from the hook method
+        model_configs = self.get_model_configurations()
+
+        # Add all base model names
+        for model_name in model_configs:
+            all_models.add(model_name.lower())
+
+        # Get aliases from the hook method and add them
+        all_aliases = self.get_all_model_aliases()
+        for _model_name, aliases in all_aliases.items():
+            for alias in aliases:
+                all_models.add(alias.lower())
+
+        return list(all_models)

     def close(self):
         """Clean up any resources held by the provider.
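Note the asymmetry visible in the two base-class methods above: list_models() returns names and aliases in their declared case, while list_all_known_models() lowercases everything (it feeds restriction-policy matching). A quick sketch of the effect:

    models = ["gemini-2.5-pro"]               # as returned by list_models()
    aliases = {"gemini-2.5-pro": ["PRO"]}

    known = set()
    for m in models:
        known.add(m.lower())
    for _m, al in aliases.items():
        for a in al:
            known.add(a.lower())

    assert known == {"gemini-2.5-pro", "pro"}  # everything lowercased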
@@ -268,65 +268,55 @@ class CustomProvider(OpenAICompatibleProvider):
     def supports_thinking_mode(self, model_name: str) -> bool:
         """Check if the model supports extended thinking mode.

-        Most custom/local models don't support extended thinking.
-
         Args:
             model_name: Model to check

         Returns:
-            False (custom models generally don't support thinking mode)
+            True if model supports thinking mode, False otherwise
         """
+        # Check if model is in registry
+        config = self._registry.resolve(model_name) if self._registry else None
+        if config and config.is_custom:
+            # Trust the config from custom_models.json
+            return config.supports_extended_thinking
+
+        # Default to False for unknown models
         return False

-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
+    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
+        """Get model configurations from the registry.

-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
+        For CustomProvider, we convert registry configurations to ModelCapabilities objects.

         Returns:
-            List of model names available from this provider
+            Dictionary mapping model names to their ModelCapabilities objects
         """
-        from utils.model_restrictions import get_restriction_service
+        from .base import ProviderType

-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
+        configs = {}

         if self._registry:
-            # Get all models from the registry
-            all_models = self._registry.list_models()
-            aliases = self._registry.list_aliases()
-
-            # Add models that are validated by the custom provider
-            for model_name in all_models + aliases:
-                # Use the provider's validation logic to determine if this model
-                # is appropriate for the custom endpoint
+            # Get all models from registry
+            for model_name in self._registry.list_models():
+                # Only include custom models that this provider validates
                 if self.validate_model_name(model_name):
-                    # Check restrictions if enabled
-                    if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                        continue
-
-                    models.append(model_name)
-
-        return models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-
-        if self._registry:
-            # Get all models and aliases from the registry
-            all_models.update(model.lower() for model in self._registry.list_models())
-            all_models.update(alias.lower() for alias in self._registry.list_aliases())
-
-            # For each alias, also add its target
-            for alias in self._registry.list_aliases():
-                config = self._registry.resolve(alias)
-                if config:
-                    all_models.add(config.model_name.lower())
-
-        return list(all_models)
+                    config = self._registry.resolve(model_name)
+                    if config and config.is_custom:
+                        # Convert OpenRouterModelConfig to ModelCapabilities
+                        capabilities = config.to_capabilities()
+                        # Override provider type to CUSTOM for local models
+                        capabilities.provider = ProviderType.CUSTOM
+                        capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
+                        configs[model_name] = capabilities
+
+        return configs
+
+    def get_all_model_aliases(self) -> dict[str, list[str]]:
+        """Get all model aliases from the registry.
+
+        Returns:
+            Dictionary mapping model names to their list of aliases
+        """
+        # Since aliases are now included in the configurations,
+        # we can use the base class implementation
+        return super().get_all_model_aliases()
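The conversion above leans on OpenRouterModelConfig.to_capabilities() and the new is_custom flag, both trusted straight from custom_models.json. A hypothetical registry entry that would flow through this branch (keys are illustrative only, not the authoritative schema):

    entry = {
        "model_name": "llama3.2",            # hypothetical local model
        "context_window": 128_000,
        "is_custom": True,                   # only is_custom configs are converted here
        "supports_extended_thinking": False,
    }

After resolving such an entry, the provider converts it with config.to_capabilities(), then re-tags the result for the local endpoint by overriding capabilities.provider to ProviderType.CUSTOM and prefixing the friendly name, as the diff shows.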
@@ -10,7 +10,7 @@ from .base import (
     ModelCapabilities,
     ModelResponse,
     ProviderType,
-    RangeTemperatureConstraint,
+    create_temperature_constraint,
 )
 from .openai_compatible import OpenAICompatibleProvider

@@ -30,63 +30,161 @@ class DIALModelProvider(OpenAICompatibleProvider):
     MAX_RETRIES = 4
     RETRY_DELAYS = [1, 3, 5, 8]  # seconds

-    # Supported DIAL models (these can be customized based on your DIAL deployment)
+    # Model configurations using ModelCapabilities objects
     SUPPORTED_MODELS = {
-        "o3-2025-04-16": {
-            "context_window": 200_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "o4-mini-2025-04-16": {
-            "context_window": 200_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "anthropic.claude-sonnet-4-20250514-v1:0": {
-            "context_window": 200_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
-            "context_window": 200_000,
-            "supports_extended_thinking": True,  # Thinking mode variant
-            "supports_vision": True,
-        },
-        "anthropic.claude-opus-4-20250514-v1:0": {
-            "context_window": 200_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
-            "context_window": 200_000,
-            "supports_extended_thinking": True,  # Thinking mode variant
-            "supports_vision": True,
-        },
-        "gemini-2.5-pro-preview-03-25-google-search": {
-            "context_window": 1_000_000,
-            "supports_extended_thinking": False,  # DIAL doesn't expose thinking mode
-            "supports_vision": True,
-        },
-        "gemini-2.5-pro-preview-05-06": {
-            "context_window": 1_000_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "gemini-2.5-flash-preview-05-20": {
-            "context_window": 1_000_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        # Shorthands
-        "o3": "o3-2025-04-16",
-        "o4-mini": "o4-mini-2025-04-16",
-        "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
-        "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
-        "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
-        "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
-        "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
-        "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
-        "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
+        "o3-2025-04-16": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="o3-2025-04-16",
+            friendly_name="DIAL (O3)",
+            context_window=200_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=False,  # O3 models don't accept temperature
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="OpenAI O3 via DIAL - Strong reasoning model",
+            aliases=["o3"],
+        ),
+        "o4-mini-2025-04-16": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="o4-mini-2025-04-16",
+            friendly_name="DIAL (O4-mini)",
+            context_window=200_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=False,  # O4 models don't accept temperature
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="OpenAI O4-mini via DIAL - Fast reasoning model",
+            aliases=["o4-mini"],
+        ),
+        "anthropic.claude-sonnet-4-20250514-v1:0": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-sonnet-4-20250514-v1:0",
+            friendly_name="DIAL (Sonnet 4)",
+            context_window=200_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Claude doesn't have function calling
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Sonnet 4 via DIAL - Balanced performance",
+            aliases=["sonnet-4"],
+        ),
+        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
+            friendly_name="DIAL (Sonnet 4 Thinking)",
+            context_window=200_000,
+            supports_extended_thinking=True,  # Thinking mode variant
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Claude doesn't have function calling
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Sonnet 4 with thinking mode via DIAL",
+            aliases=["sonnet-4-thinking"],
+        ),
+        "anthropic.claude-opus-4-20250514-v1:0": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-opus-4-20250514-v1:0",
+            friendly_name="DIAL (Opus 4)",
+            context_window=200_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Claude doesn't have function calling
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Opus 4 via DIAL - Most capable Claude model",
+            aliases=["opus-4"],
+        ),
+        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking",
+            friendly_name="DIAL (Opus 4 Thinking)",
+            context_window=200_000,
+            supports_extended_thinking=True,  # Thinking mode variant
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Claude doesn't have function calling
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Opus 4 with thinking mode via DIAL",
+            aliases=["opus-4-thinking"],
+        ),
+        "gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="gemini-2.5-pro-preview-03-25-google-search",
+            friendly_name="DIAL (Gemini 2.5 Pro Search)",
+            context_window=1_000_000,
+            supports_extended_thinking=False,  # DIAL doesn't expose thinking mode
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.5 Pro with Google Search via DIAL",
+            aliases=["gemini-2.5-pro-search"],
+        ),
+        "gemini-2.5-pro-preview-05-06": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="gemini-2.5-pro-preview-05-06",
+            friendly_name="DIAL (Gemini 2.5 Pro)",
+            context_window=1_000_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.5 Pro via DIAL - Deep reasoning",
+            aliases=["gemini-2.5-pro"],
+        ),
+        "gemini-2.5-flash-preview-05-20": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="gemini-2.5-flash-preview-05-20",
+            friendly_name="DIAL (Gemini Flash 2.5)",
+            context_window=1_000_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.5 Flash via DIAL - Ultra-fast",
+            aliases=["gemini-2.5-flash"],
+        ),
     }

     def __init__(self, api_key: str, **kwargs):
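The entries above describe temperature handling via create_temperature_constraint("fixed") and create_temperature_constraint("range"). Judging by the inline comments this commit deletes elsewhere ("Fixed at 1.0", "0.0-2.0 range"), the intended semantics look roughly like the sketch below; this is assumed behavior for illustration, not the actual implementation in the providers base module:

    def clamp_temperature(kind: str, requested: float) -> float:
        # Assumed semantics: "fixed" pins temperature at 1.0, "range" clamps to 0.0-2.0
        if kind == "fixed":
            return 1.0                            # O3/O4-style models ignore temperature
        if kind == "range":
            return min(max(requested, 0.0), 2.0)  # clamp into the allowed range
        raise ValueError(f"unknown constraint kind: {kind}")

    assert clamp_temperature("fixed", 0.2) == 1.0
    assert clamp_temperature("range", 2.5) == 2.0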
@@ -181,20 +279,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
         if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
             raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

-        config = self.SUPPORTED_MODELS[resolved_name]
-
-        return ModelCapabilities(
-            provider=ProviderType.DIAL,
-            model_name=resolved_name,
-            friendly_name=self.FRIENDLY_NAME,
-            context_window=config["context_window"],
-            supports_extended_thinking=config["supports_extended_thinking"],
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=True,
-            supports_images=config.get("supports_vision", False),
-            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
-        )
+        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+        return self.SUPPORTED_MODELS[resolved_name]

     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
@@ -211,7 +297,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
         """
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+        if resolved_name not in self.SUPPORTED_MODELS:
             return False

         # Check against base class allowed_models if configured
@@ -231,20 +317,6 @@ class DIALModelProvider(OpenAICompatibleProvider):

         return True

-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name.
-
-        Args:
-            model_name: Model name or shorthand
-
-        Returns:
-            Full model name
-        """
-        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
-        if isinstance(shorthand_value, str):
-            return shorthand_value
-        return model_name
-
     def _get_deployment_client(self, deployment: str):
         """Get or create a cached client for a specific deployment.

@@ -357,7 +429,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
         # Check model capabilities
         try:
             capabilities = self.get_capabilities(model_name)
-            supports_temperature = getattr(capabilities, "supports_temperature", True)
+            supports_temperature = capabilities.supports_temperature
         except Exception as e:
             logger.debug(f"Failed to check temperature support for {model_name}: {e}")
             supports_temperature = True
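The defensive getattr above is no longer needed because supports_temperature is now a dataclass field with a default of True, so every ModelCapabilities instance is guaranteed to have it. Minimal sketch of why direct attribute access is safe:

    from dataclasses import dataclass

    @dataclass
    class Caps:
        supports_temperature: bool = True  # always present, defaults to True

    caps = Caps(supports_temperature=False)
    assert caps.supports_temperature is False  # no getattr fallback required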
@@ -441,63 +513,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
         """
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
-            return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
+        if resolved_name in self.SUPPORTED_MODELS:
+            return self.SUPPORTED_MODELS[resolved_name].supports_images

         # Fall back to parent implementation for unknown models
         return super()._supports_vision(model_name)

-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-
-        Returns:
-            List of model names available from this provider
-        """
-        # Get all model keys (both full names and aliases)
-        all_models = list(self.SUPPORTED_MODELS.keys())
-
-        if not respect_restrictions:
-            return all_models
-
-        # Apply restrictions if configured
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service()
-
-        # Filter based on restrictions
-        allowed_models = []
-        for model in all_models:
-            resolved_name = self._resolve_model_name(model)
-            if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
-                allowed_models.append(model)
-
-        return allowed_models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        This is used for validation purposes to ensure restriction policies
-        can validate against both aliases and their target model names.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        # Collect all unique model names (both aliases and targets)
-        all_models = set()
-
-        for key, value in self.SUPPORTED_MODELS.items():
-            # Add the key (could be alias or full name)
-            all_models.add(key)
-
-            # If it's an alias (string value), add the target too
-            if isinstance(value, str):
-                all_models.add(value)
-
-        return sorted(all_models)
-
     def close(self):
         """Clean up HTTP clients when provider is closed."""
         logger.info("Closing DIAL provider HTTP clients...")
@@ -9,7 +9,7 @@ from typing import Optional
 from google import genai
 from google.genai import types

-from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
+from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint

 logger = logging.getLogger(__name__)

@@ -17,47 +17,79 @@ logger = logging.getLogger(__name__)
 class GeminiModelProvider(ModelProvider):
     """Google Gemini model provider implementation."""

-    # Model configurations
+    # Model configurations using ModelCapabilities objects
     SUPPORTED_MODELS = {
-        "gemini-2.0-flash": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": True,  # Experimental thinking mode
-            "max_thinking_tokens": 24576,  # Same as 2.5 flash for consistency
-            "supports_images": True,  # Vision capability
-            "max_image_size_mb": 20.0,  # Conservative 20MB limit for reliability
-            "description": "Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
-        },
-        "gemini-2.0-flash-lite": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": False,  # Not supported per user request
-            "max_thinking_tokens": 0,  # No thinking support
-            "supports_images": False,  # Does not support images
-            "max_image_size_mb": 0.0,  # No image support
-            "description": "Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
-        },
-        "gemini-2.5-flash": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": True,
-            "max_thinking_tokens": 24576,  # Flash 2.5 thinking budget limit
-            "supports_images": True,  # Vision capability
-            "max_image_size_mb": 20.0,  # Conservative 20MB limit for reliability
-            "description": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
-        },
-        "gemini-2.5-pro": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": True,
-            "max_thinking_tokens": 32768,  # Pro 2.5 thinking budget limit
-            "supports_images": True,  # Vision capability
-            "max_image_size_mb": 32.0,  # Higher limit for Pro model
-            "description": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
-        },
-        # Shorthands
-        "flash": "gemini-2.5-flash",
-        "flash-2.0": "gemini-2.0-flash",
-        "flash2": "gemini-2.0-flash",
-        "flashlite": "gemini-2.0-flash-lite",
-        "flash-lite": "gemini-2.0-flash-lite",
-        "pro": "gemini-2.5-pro",
+        "gemini-2.0-flash": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.0-flash",
+            friendly_name="Gemini (Flash 2.0)",
+            context_window=1_048_576,  # 1M tokens
+            supports_extended_thinking=True,  # Experimental thinking mode
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # Vision capability
+            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            max_thinking_tokens=24576,  # Same as 2.5 flash for consistency
+            description="Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
+            aliases=["flash-2.0", "flash2"],
+        ),
+        "gemini-2.0-flash-lite": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.0-flash-lite",
+            friendly_name="Gemini (Flash Lite 2.0)",
+            context_window=1_048_576,  # 1M tokens
+            supports_extended_thinking=False,  # Not supported per user request
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=False,  # Does not support images
+            max_image_size_mb=0.0,  # No image support
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
+            aliases=["flashlite", "flash-lite"],
+        ),
+        "gemini-2.5-flash": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.5-flash",
+            friendly_name="Gemini (Flash 2.5)",
+            context_window=1_048_576,  # 1M tokens
+            supports_extended_thinking=True,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # Vision capability
+            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            max_thinking_tokens=24576,  # Flash 2.5 thinking budget limit
+            description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
+            aliases=["flash", "flash2.5"],
+        ),
+        "gemini-2.5-pro": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.5-pro",
+            friendly_name="Gemini (Pro 2.5)",
+            context_window=1_048_576,  # 1M tokens
+            supports_extended_thinking=True,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # Vision capability
+            max_image_size_mb=32.0,  # Higher limit for Pro model
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            max_thinking_tokens=32768,  # Max thinking tokens for Pro model
+            description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
+            aliases=["pro", "gemini pro", "gemini-pro"],
+        ),
     }

     # Thinking mode configurations - percentages of model's max_thinking_tokens
@@ -70,6 +102,14 @@ class GeminiModelProvider(ModelProvider):
         "max": 1.0,  # 100% of max - full thinking budget
     }

+    # Model-specific thinking token limits
+    MAX_THINKING_TOKENS = {
+        "gemini-2.0-flash": 24576,  # Same as 2.5 flash for consistency
+        "gemini-2.0-flash-lite": 0,  # No thinking support
+        "gemini-2.5-flash": 24576,  # Flash 2.5 thinking budget limit
+        "gemini-2.5-pro": 32768,  # Pro 2.5 thinking budget limit
+    }
+
     def __init__(self, api_key: str, **kwargs):
         """Initialize Gemini provider with API key."""
         super().__init__(api_key, **kwargs)
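Worked example of the budget math these limits feed into (see the generate_content and get_thinking_budget hunks below): the budget is the model's max thinking tokens scaled by the mode's percentage. Only the "max" entry (1.0) is visible in this diff; the "medium" value of 0.33 below is illustrative:

    THINKING_BUDGETS = {"medium": 0.33, "max": 1.0}
    MAX_THINKING_TOKENS = {"gemini-2.5-pro": 32768}

    def thinking_budget(model: str, mode: str) -> int:
        # int() truncates, matching the int(...) cast used in the diff
        return int(MAX_THINKING_TOKENS.get(model, 0) * THINKING_BUDGETS.get(mode, 0))

    assert thinking_budget("gemini-2.5-pro", "max") == 32768
    assert thinking_budget("gemini-2.5-pro", "medium") == 10813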
@@ -100,25 +140,8 @@ class GeminiModelProvider(ModelProvider):
         if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
             raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")

-        config = self.SUPPORTED_MODELS[resolved_name]
-
-        # Gemini models support 0.0-2.0 temperature range
-        temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
-
-        return ModelCapabilities(
-            provider=ProviderType.GOOGLE,
-            model_name=resolved_name,
-            friendly_name="Gemini",
-            context_window=config["context_window"],
-            supports_extended_thinking=config["supports_extended_thinking"],
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=True,
-            supports_images=config.get("supports_images", False),
-            max_image_size_mb=config.get("max_image_size_mb", 0.0),
-            supports_temperature=True,  # Gemini models accept temperature parameter
-            temperature_constraint=temp_constraint,
-        )
+        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+        return self.SUPPORTED_MODELS[resolved_name]

     def generate_content(
         self,
@@ -179,8 +202,8 @@ class GeminiModelProvider(ModelProvider):
         if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
             # Get model's max thinking tokens and calculate actual budget
             model_config = self.SUPPORTED_MODELS.get(resolved_name)
-            if model_config and "max_thinking_tokens" in model_config:
-                max_thinking_tokens = model_config["max_thinking_tokens"]
+            if model_config and model_config.max_thinking_tokens > 0:
+                max_thinking_tokens = model_config.max_thinking_tokens
                 actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
                 generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget)

@@ -258,7 +281,7 @@ class GeminiModelProvider(ModelProvider):
         resolved_name = self._resolve_model_name(model_name)

         # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+        if resolved_name not in self.SUPPORTED_MODELS:
             return False

         # Then check if model is allowed by restrictions
@@ -281,78 +304,20 @@ class GeminiModelProvider(ModelProvider):
     def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
         """Get actual thinking token budget for a model and thinking mode."""
         resolved_name = self._resolve_model_name(model_name)
-        model_config = self.SUPPORTED_MODELS.get(resolved_name, {})
+        model_config = self.SUPPORTED_MODELS.get(resolved_name)

-        if not model_config.get("supports_extended_thinking", False):
+        if not model_config or not model_config.supports_extended_thinking:
             return 0

         if thinking_mode not in self.THINKING_BUDGETS:
             return 0

-        max_thinking_tokens = model_config.get("max_thinking_tokens", 0)
+        max_thinking_tokens = model_config.max_thinking_tokens
         if max_thinking_tokens == 0:
             return 0

         return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])

-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-
-        Returns:
-            List of model names available from this provider
-        """
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Handle both base models (dict configs) and aliases (string values)
-            if isinstance(config, str):
-                # This is an alias - check if the target model would be allowed
-                target_model = config
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
-                    continue
-                # Allow the alias
-                models.append(model_name)
-            else:
-                # This is a base model with config dict
-                # Check restrictions if enabled
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                    continue
-                models.append(model_name)
-
-        return models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Add the model name itself
-            all_models.add(model_name.lower())
-
-            # If it's an alias (string value), add the target model too
-            if isinstance(config, str):
-                all_models.add(config.lower())
-
-        return list(all_models)
-
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name."""
-        # Check if it's a shorthand
-        shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
-        if isinstance(shorthand_value, str):
-            return shorthand_value
-        return model_name
-
     def _extract_usage(self, response) -> dict[str, int]:
         """Extract token usage from Gemini response."""
         usage = {}
@@ -17,71 +17,110 @@ logger = logging.getLogger(__name__)
 class OpenAIModelProvider(OpenAICompatibleProvider):
     """Official OpenAI API provider (api.openai.com)."""

-    # Model configurations
+    # Model configurations using ModelCapabilities objects
     SUPPORTED_MODELS = {
-        "o3": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O3 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
-        },
-        "o3-mini": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O3 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
-        },
-        "o3-pro-2025-06-10": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O3 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
-        },
-        # Aliases
-        "o3-pro": "o3-pro-2025-06-10",
-        "o4-mini": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O4 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O4 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
-        },
-        "o4-mini-high": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O4 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O4 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
-        },
-        "gpt-4.1-2025-04-14": {
-            "context_window": 1_000_000,  # 1M tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # GPT-4.1 supports vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": True,  # Regular models accept temperature parameter
-            "temperature_constraint": "range",  # 0.0-2.0 range
-            "description": "GPT-4.1 (1M context) - Advanced reasoning model with large context window",
-        },
-        # Shorthands
-        "mini": "o4-mini",  # Default 'mini' to latest mini model
-        "o3mini": "o3-mini",
-        "o4mini": "o4-mini",
-        "o4minihigh": "o4-mini-high",
-        "o4minihi": "o4-mini-high",
-        "gpt4.1": "gpt-4.1-2025-04-14",
+        "o3": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3",
+            friendly_name="OpenAI (O3)",
+            context_window=200_000,  # 200K tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
+            aliases=[],
+        ),
+        "o3-mini": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3-mini",
+            friendly_name="OpenAI (O3-mini)",
+            context_window=200_000,  # 200K tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
+            aliases=["o3mini", "o3-mini"],
+        ),
+        "o3-pro-2025-06-10": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3-pro-2025-06-10",
+            friendly_name="OpenAI (O3-Pro)",
+            context_window=200_000,  # 200K tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
+            aliases=["o3-pro"],
+        ),
+        "o4-mini": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o4-mini",
+            friendly_name="OpenAI (O4-mini)",
+            context_window=200_000,  # 200K tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O4 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O4 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
+            aliases=["mini", "o4mini"],
+        ),
+        "o4-mini-high": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o4-mini-high",
+            friendly_name="OpenAI (O4-mini-high)",
+            context_window=200_000,  # 200K tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O4 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O4 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
+            aliases=["o4minihigh", "o4minihi", "mini-high"],
+        ),
+        "gpt-4.1-2025-04-14": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="gpt-4.1-2025-04-14",
+            friendly_name="OpenAI (GPT 4.1)",
+            context_window=1_000_000,  # 1M tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # GPT-4.1 supports vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=True,  # Regular models accept temperature parameter
+            temperature_constraint=create_temperature_constraint("range"),
+            description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
+            aliases=["gpt4.1"],
+        ),
     }

     def __init__(self, api_key: str, **kwargs):
@@ -95,7 +134,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         # Resolve shorthand
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
+        if resolved_name not in self.SUPPORTED_MODELS:
             raise ValueError(f"Unsupported OpenAI model: {model_name}")

         # Check if model is allowed by restrictions
@@ -105,27 +144,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
             raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")

-        config = self.SUPPORTED_MODELS[resolved_name]
-
-        # Get temperature constraints and support from configuration
-        supports_temperature = config.get("supports_temperature", True)  # Default to True for backward compatibility
-        temp_constraint_type = config.get("temperature_constraint", "range")  # Default to range
-        temp_constraint = create_temperature_constraint(temp_constraint_type)
-
-        return ModelCapabilities(
-            provider=ProviderType.OPENAI,
-            model_name=model_name,
-            friendly_name="OpenAI",
-            context_window=config["context_window"],
-            supports_extended_thinking=config["supports_extended_thinking"],
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=True,
-            supports_images=config.get("supports_images", False),
-            max_image_size_mb=config.get("max_image_size_mb", 0.0),
-            supports_temperature=supports_temperature,
-            temperature_constraint=temp_constraint,
-        )
+        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+        return self.SUPPORTED_MODELS[resolved_name]

     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
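Note: with every SUPPORTED_MODELS value now a ModelCapabilities object, get_capabilities reduces to the lookup above. A minimal usage sketch (the API key is a placeholder; expected values assume the configurations shown in this diff):

    provider = OpenAIModelProvider(api_key="test-key")
    caps = provider.get_capabilities("gpt4.1")  # alias resolves to "gpt-4.1-2025-04-14"
    print(caps.friendly_name)   # "OpenAI (GPT 4.1)"
    print(caps.context_window)  # 1000000
    print(caps.aliases)         # ["gpt4.1"]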
@@ -136,7 +156,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         resolved_name = self._resolve_model_name(model_name)

         # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+        if resolved_name not in self.SUPPORTED_MODELS:
             return False

         # Then check if model is allowed by restrictions
@@ -177,61 +197,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         # Currently no OpenAI models support extended thinking
         # This may change with future O3 models
         return False
-
-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-
-        Returns:
-            List of model names available from this provider
-        """
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Handle both base models (dict configs) and aliases (string values)
-            if isinstance(config, str):
-                # This is an alias - check if the target model would be allowed
-                target_model = config
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
-                    continue
-                # Allow the alias
-                models.append(model_name)
-            else:
-                # This is a base model with config dict
-                # Check restrictions if enabled
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                    continue
-                models.append(model_name)
-
-        return models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Add the model name itself
-            all_models.add(model_name.lower())
-
-            # If it's an alias (string value), add the target model too
-            if isinstance(config, str):
-                all_models.add(config.lower())
-
-        return list(all_models)
-
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name."""
-        # Check if it's a shorthand
-        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
-        if isinstance(shorthand_value, str):
-            return shorthand_value
-        return model_name
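Note: the list_models, list_all_known_models, and _resolve_model_name bodies deleted above relied on string-valued shorthand entries; they move to shared base-class implementations driven by the new aliases field. The shared code is not part of this hunk; a sketch of what an alias-aware resolver presumably looks like:

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve a shorthand alias to its canonical model name (case-insensitive)."""
        wanted = model_name.lower()
        for base_name, capabilities in self.SUPPORTED_MODELS.items():
            if wanted == base_name.lower() or wanted in (a.lower() for a in capabilities.aliases):
                return base_name
        return model_name  # unknown names pass through unchanged

This matches the behaviour the new tests later in this commit expect: case-insensitive resolution, with unknown names returned unchanged.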
@@ -270,3 +270,39 @@ class OpenRouterProvider(OpenAICompatibleProvider):
             all_models.add(config.model_name.lower())

         return list(all_models)
+
+    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
+        """Get model configurations from the registry.
+
+        For OpenRouter, we convert registry configurations to ModelCapabilities objects.
+
+        Returns:
+            Dictionary mapping model names to their ModelCapabilities objects
+        """
+        configs = {}
+
+        if self._registry:
+            # Get all models from registry
+            for model_name in self._registry.list_models():
+                # Only include models that this provider validates
+                if self.validate_model_name(model_name):
+                    config = self._registry.resolve(model_name)
+                    if config and not config.is_custom:  # Only OpenRouter models, not custom ones
+                        # Convert OpenRouterModelConfig to ModelCapabilities
+                        capabilities = config.to_capabilities()
+                        # Override provider type to OPENROUTER
+                        capabilities.provider = ProviderType.OPENROUTER
+                        capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})"
+                        configs[model_name] = capabilities
+
+        return configs
+
+    def get_all_model_aliases(self) -> dict[str, list[str]]:
+        """Get all model aliases from the registry.
+
+        Returns:
+            Dictionary mapping model names to their list of aliases
+        """
+        # Since aliases are now included in the configurations,
+        # we can use the base class implementation
+        return super().get_all_model_aliases()
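Note: a sketch of consuming the new OpenRouter hook (the constructor argument and registry contents are illustrative, not part of this diff):

    provider = OpenRouterProvider(api_key="test-key")
    for name, caps in provider.get_model_configurations().items():
        # Every returned entry is an OpenRouter model; custom-endpoint models are filtered out
        assert caps.provider == ProviderType.OPENROUTER
        print(f"{name}: {caps.friendly_name}, {caps.context_window} tokens")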
providers/xai.py (134 lines changed)
@@ -7,7 +7,7 @@ from .base import (
     ModelCapabilities,
     ModelResponse,
     ProviderType,
-    RangeTemperatureConstraint,
+    create_temperature_constraint,
 )
 from .openai_compatible import OpenAICompatibleProvider

@@ -19,23 +19,42 @@ class XAIModelProvider(OpenAICompatibleProvider):

     FRIENDLY_NAME = "X.AI"

-    # Model configurations
+    # Model configurations using ModelCapabilities objects
     SUPPORTED_MODELS = {
-        "grok-3": {
-            "context_window": 131_072,  # 131K tokens
-            "supports_extended_thinking": False,
-            "description": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
-        },
-        "grok-3-fast": {
-            "context_window": 131_072,  # 131K tokens
-            "supports_extended_thinking": False,
-            "description": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
-        },
-        # Shorthands for convenience
-        "grok": "grok-3",  # Default to grok-3
-        "grok3": "grok-3",
-        "grok3fast": "grok-3-fast",
-        "grokfast": "grok-3-fast",
+        "grok-3": ModelCapabilities(
+            provider=ProviderType.XAI,
+            model_name="grok-3",
+            friendly_name="X.AI (Grok 3)",
+            context_window=131_072,  # 131K tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
+            supports_images=False,  # Assuming GROK is text-only for now
+            max_image_size_mb=0.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
+            aliases=["grok", "grok3"],
+        ),
+        "grok-3-fast": ModelCapabilities(
+            provider=ProviderType.XAI,
+            model_name="grok-3-fast",
+            friendly_name="X.AI (Grok 3 Fast)",
+            context_window=131_072,  # 131K tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
+            supports_images=False,  # Assuming GROK is text-only for now
+            max_image_size_mb=0.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
+            aliases=["grok3fast", "grokfast", "grok3-fast"],
+        ),
     }

     def __init__(self, api_key: str, **kwargs):
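Note: declaring a new model is now a single ModelCapabilities entry with its aliases inline, rather than a config dict plus separate string-valued shorthand keys. A hypothetical example entry (the model name and values below are invented purely for illustration):

    "grok-4": ModelCapabilities(
        provider=ProviderType.XAI,
        model_name="grok-4",  # hypothetical model, not in this commit
        friendly_name="X.AI (Grok 4)",
        context_window=131_072,
        temperature_constraint=create_temperature_constraint("range"),
        description="Hypothetical entry for illustration only",
        aliases=["grok4"],
    ),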
@@ -49,7 +68,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
         # Resolve shorthand
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
+        if resolved_name not in self.SUPPORTED_MODELS:
             raise ValueError(f"Unsupported X.AI model: {model_name}")

         # Check if model is allowed by restrictions
@@ -59,23 +78,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
         if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
             raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")

-        config = self.SUPPORTED_MODELS[resolved_name]
-
-        # Define temperature constraints for GROK models
-        # GROK supports the standard OpenAI temperature range
-        temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
-
-        return ModelCapabilities(
-            provider=ProviderType.XAI,
-            model_name=resolved_name,
-            friendly_name=self.FRIENDLY_NAME,
-            context_window=config["context_window"],
-            supports_extended_thinking=config["supports_extended_thinking"],
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=True,
-            temperature_constraint=temp_constraint,
-        )
+        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+        return self.SUPPORTED_MODELS[resolved_name]

     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
@@ -86,7 +90,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
         resolved_name = self._resolve_model_name(model_name)

         # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+        if resolved_name not in self.SUPPORTED_MODELS:
             return False

         # Then check if model is allowed by restrictions
@@ -127,61 +131,3 @@ class XAIModelProvider(OpenAICompatibleProvider):
         # Currently GROK models do not support extended thinking
         # This may change with future GROK model releases
         return False
-
-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-
-        Returns:
-            List of model names available from this provider
-        """
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Handle both base models (dict configs) and aliases (string values)
-            if isinstance(config, str):
-                # This is an alias - check if the target model would be allowed
-                target_model = config
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
-                    continue
-                # Allow the alias
-                models.append(model_name)
-            else:
-                # This is a base model with config dict
-                # Check restrictions if enabled
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                    continue
-                models.append(model_name)
-
-        return models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Add the model name itself
-            all_models.add(model_name.lower())
-
-            # If it's an alias (string value), add the target model too
-            if isinstance(config, str):
-                all_models.add(config.lower())
-
-        return list(all_models)
-
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name."""
-        # Check if it's a shorthand
-        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
-        if isinstance(shorthand_value, str):
-            return shorthand_value
-        return model_name
@@ -59,7 +59,7 @@ class TestAutoMode:
                 continue

             # Check that model has description
-            description = config.get("description", "")
+            description = config.description if hasattr(config, "description") else ""
             if description:
                 models_with_descriptions[model_name] = description

@@ -319,7 +319,18 @@ class TestAutoModeComprehensive:
                 m
                 for m in available_models
                 if not m.startswith("gemini")
-                and m not in ["flash", "pro", "flash-2.0", "flash2", "flashlite", "flash-lite"]
+                and m
+                not in [
+                    "flash",
+                    "pro",
+                    "flash-2.0",
+                    "flash2",
+                    "flashlite",
+                    "flash-lite",
+                    "flash2.5",
+                    "gemini pro",
+                    "gemini-pro",
+                ]
             ]
             assert (
                 len(non_gemini_models) == 0
@@ -84,7 +84,7 @@ class TestDIALProvider:
         # Test O3 capabilities
         capabilities = provider.get_capabilities("o3")
         assert capabilities.model_name == "o3-2025-04-16"
-        assert capabilities.friendly_name == "DIAL"
+        assert capabilities.friendly_name == "DIAL (O3)"
         assert capabilities.context_window == 200_000
         assert capabilities.provider == ProviderType.DIAL
         assert capabilities.supports_images is True
@@ -85,7 +85,7 @@ class TestOpenAIProvider:

         capabilities = provider.get_capabilities("o3")
         assert capabilities.model_name == "o3"  # Should NOT be resolved in capabilities
-        assert capabilities.friendly_name == "OpenAI"
+        assert capabilities.friendly_name == "OpenAI (O3)"
         assert capabilities.context_window == 200_000
         assert capabilities.provider == ProviderType.OPENAI
         assert not capabilities.supports_extended_thinking
@@ -101,8 +101,8 @@ class TestOpenAIProvider:
         provider = OpenAIModelProvider("test-key")

         capabilities = provider.get_capabilities("mini")
-        assert capabilities.model_name == "mini"  # Capabilities should show original request
-        assert capabilities.friendly_name == "OpenAI"
+        assert capabilities.model_name == "o4-mini"  # Capabilities should show resolved model name
+        assert capabilities.friendly_name == "OpenAI (O4-mini)"
         assert capabilities.context_window == 200_000
         assert capabilities.provider == ProviderType.OPENAI

tests/test_supported_models_aliases.py (new file, 206 lines)
@@ -0,0 +1,206 @@
+"""Test the SUPPORTED_MODELS aliases structure across all providers."""
+
+from providers.dial import DIALModelProvider
+from providers.gemini import GeminiModelProvider
+from providers.openai_provider import OpenAIModelProvider
+from providers.xai import XAIModelProvider
+
+
+class TestSupportedModelsAliases:
+    """Test that all providers have correctly structured SUPPORTED_MODELS with aliases."""
+
+    def test_gemini_provider_aliases(self):
+        """Test Gemini provider's alias structure."""
+        provider = GeminiModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "flash" in provider.SUPPORTED_MODELS["gemini-2.5-flash"].aliases
+        assert "pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro"].aliases
+        assert "flash-2.0" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
+        assert "flash2" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
+        assert "flashlite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
+        assert "flash-lite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("flash") == "gemini-2.5-flash"
+        assert provider._resolve_model_name("pro") == "gemini-2.5-pro"
+        assert provider._resolve_model_name("flash-2.0") == "gemini-2.0-flash"
+        assert provider._resolve_model_name("flash2") == "gemini-2.0-flash"
+        assert provider._resolve_model_name("flashlite") == "gemini-2.0-flash-lite"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("Flash") == "gemini-2.5-flash"
+        assert provider._resolve_model_name("PRO") == "gemini-2.5-pro"
+
+    def test_openai_provider_aliases(self):
+        """Test OpenAI provider's alias structure."""
+        provider = OpenAIModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
+        assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
+        assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases
+        assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases
+        assert "o4minihigh" in provider.SUPPORTED_MODELS["o4-mini-high"].aliases
+        assert "o4minihi" in provider.SUPPORTED_MODELS["o4-mini-high"].aliases
+        assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("mini") == "o4-mini"
+        assert provider._resolve_model_name("o3mini") == "o3-mini"
+        assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
+        assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
+        assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("Mini") == "o4-mini"
+        assert provider._resolve_model_name("O3MINI") == "o3-mini"
+
+    def test_xai_provider_aliases(self):
+        """Test XAI provider's alias structure."""
+        provider = XAIModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases
+        assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
+        assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
+        assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("grok") == "grok-3"
+        assert provider._resolve_model_name("grok3") == "grok-3"
+        assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
+        assert provider._resolve_model_name("grokfast") == "grok-3-fast"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("Grok") == "grok-3"
+        assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"
+
+    def test_dial_provider_aliases(self):
+        """Test DIAL provider's alias structure."""
+        provider = DIALModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "o3" in provider.SUPPORTED_MODELS["o3-2025-04-16"].aliases
+        assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini-2025-04-16"].aliases
+        assert "sonnet-4" in provider.SUPPORTED_MODELS["anthropic.claude-sonnet-4-20250514-v1:0"].aliases
+        assert "opus-4" in provider.SUPPORTED_MODELS["anthropic.claude-opus-4-20250514-v1:0"].aliases
+        assert "gemini-2.5-pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro-preview-05-06"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("o3") == "o3-2025-04-16"
+        assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
+        assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
+        assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("O3") == "o3-2025-04-16"
+        assert provider._resolve_model_name("SONNET-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
+
+    def test_list_models_includes_aliases(self):
+        """Test that list_models returns both base models and aliases."""
+        # Test Gemini
+        gemini_provider = GeminiModelProvider("test-key")
+        gemini_models = gemini_provider.list_models(respect_restrictions=False)
+        assert "gemini-2.5-flash" in gemini_models
+        assert "flash" in gemini_models
+        assert "gemini-2.5-pro" in gemini_models
+        assert "pro" in gemini_models
+
+        # Test OpenAI
+        openai_provider = OpenAIModelProvider("test-key")
+        openai_models = openai_provider.list_models(respect_restrictions=False)
+        assert "o4-mini" in openai_models
+        assert "mini" in openai_models
+        assert "o3-mini" in openai_models
+        assert "o3mini" in openai_models
+
+        # Test XAI
+        xai_provider = XAIModelProvider("test-key")
+        xai_models = xai_provider.list_models(respect_restrictions=False)
+        assert "grok-3" in xai_models
+        assert "grok" in xai_models
+        assert "grok-3-fast" in xai_models
+        assert "grokfast" in xai_models
+
+        # Test DIAL
+        dial_provider = DIALModelProvider("test-key")
+        dial_models = dial_provider.list_models(respect_restrictions=False)
+        assert "o3-2025-04-16" in dial_models
+        assert "o3" in dial_models
+
+    def test_list_all_known_models_includes_aliases(self):
+        """Test that list_all_known_models returns all models and aliases in lowercase."""
+        # Test Gemini
+        gemini_provider = GeminiModelProvider("test-key")
+        gemini_all = gemini_provider.list_all_known_models()
+        assert "gemini-2.5-flash" in gemini_all
+        assert "flash" in gemini_all
+        assert "gemini-2.5-pro" in gemini_all
+        assert "pro" in gemini_all
+        # All should be lowercase
+        assert all(model == model.lower() for model in gemini_all)
+
+        # Test OpenAI
+        openai_provider = OpenAIModelProvider("test-key")
+        openai_all = openai_provider.list_all_known_models()
+        assert "o4-mini" in openai_all
+        assert "mini" in openai_all
+        assert "o3-mini" in openai_all
+        assert "o3mini" in openai_all
+        # All should be lowercase
+        assert all(model == model.lower() for model in openai_all)
+
+    def test_no_string_shorthand_in_supported_models(self):
+        """Test that no provider has string-based shorthands anymore."""
+        providers = [
+            GeminiModelProvider("test-key"),
+            OpenAIModelProvider("test-key"),
+            XAIModelProvider("test-key"),
+            DIALModelProvider("test-key"),
+        ]
+
+        for provider in providers:
+            for model_name, config in provider.SUPPORTED_MODELS.items():
+                # All values must be ModelCapabilities objects, not strings or dicts
+                from providers.base import ModelCapabilities
+
+                assert isinstance(config, ModelCapabilities), (
+                    f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
+                    f"must be a ModelCapabilities object, not {type(config).__name__}"
+                )
+
+    def test_resolve_returns_original_if_not_found(self):
+        """Test that _resolve_model_name returns original name if alias not found."""
+        providers = [
+            GeminiModelProvider("test-key"),
+            OpenAIModelProvider("test-key"),
+            XAIModelProvider("test-key"),
+            DIALModelProvider("test-key"),
+        ]
+
+        for provider in providers:
+            # Test with unknown model name
+            assert provider._resolve_model_name("unknown-model") == "unknown-model"
+            assert provider._resolve_model_name("gpt-4") == "gpt-4"
+            assert provider._resolve_model_name("claude-3") == "claude-3"
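Note: the new suite can be exercised on its own; pytest.main is pytest's standard programmatic entry point:

    import pytest

    # Run only the alias-structure tests, quietly
    pytest.main(["tests/test_supported_models_aliases.py", "-q"])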
@@ -77,7 +77,7 @@ class TestXAIProvider:

         capabilities = provider.get_capabilities("grok-3")
         assert capabilities.model_name == "grok-3"
-        assert capabilities.friendly_name == "X.AI"
+        assert capabilities.friendly_name == "X.AI (Grok 3)"
         assert capabilities.context_window == 131_072
         assert capabilities.provider == ProviderType.XAI
         assert not capabilities.supports_extended_thinking
@@ -96,7 +96,7 @@ class TestXAIProvider:

         capabilities = provider.get_capabilities("grok-3-fast")
         assert capabilities.model_name == "grok-3-fast"
-        assert capabilities.friendly_name == "X.AI"
+        assert capabilities.friendly_name == "X.AI (Grok 3 Fast)"
         assert capabilities.context_window == 131_072
         assert capabilities.provider == ProviderType.XAI
         assert not capabilities.supports_extended_thinking
@@ -212,31 +212,34 @@ class TestXAIProvider:
         assert provider.FRIENDLY_NAME == "X.AI"

         capabilities = provider.get_capabilities("grok-3")
-        assert capabilities.friendly_name == "X.AI"
+        assert capabilities.friendly_name == "X.AI (Grok 3)"

     def test_supported_models_structure(self):
         """Test that SUPPORTED_MODELS has the correct structure."""
         provider = XAIModelProvider("test-key")

-        # Check that all expected models are present
+        # Check that all expected base models are present
         assert "grok-3" in provider.SUPPORTED_MODELS
         assert "grok-3-fast" in provider.SUPPORTED_MODELS
-        assert "grok" in provider.SUPPORTED_MODELS
-        assert "grok3" in provider.SUPPORTED_MODELS
-        assert "grokfast" in provider.SUPPORTED_MODELS
-        assert "grok3fast" in provider.SUPPORTED_MODELS

         # Check model configs have required fields
-        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
-        assert isinstance(grok3_config, dict)
-        assert "context_window" in grok3_config
-        assert "supports_extended_thinking" in grok3_config
-        assert grok3_config["context_window"] == 131_072
-        assert grok3_config["supports_extended_thinking"] is False
-
-        # Check shortcuts point to full names
-        assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
-        assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"
+        from providers.base import ModelCapabilities
+
+        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
+        assert isinstance(grok3_config, ModelCapabilities)
+        assert hasattr(grok3_config, "context_window")
+        assert hasattr(grok3_config, "supports_extended_thinking")
+        assert hasattr(grok3_config, "aliases")
+        assert grok3_config.context_window == 131_072
+        assert grok3_config.supports_extended_thinking is False
+
+        # Check aliases are correctly structured
+        assert "grok" in grok3_config.aliases
+        assert "grok3" in grok3_config.aliases
+
+        grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
+        assert "grok3fast" in grok3fast_config.aliases
+        assert "grokfast" in grok3fast_config.aliases

     @patch("providers.openai_compatible.OpenAI")
     def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
@@ -99,15 +99,11 @@ class ListModelsTool(BaseTool):
             output_lines.append("**Status**: Configured and available")
             output_lines.append("\n**Models**:")

-            # Get models from the provider's SUPPORTED_MODELS
-            for model_name, config in provider.SUPPORTED_MODELS.items():
-                # Skip alias entries (string values)
-                if isinstance(config, str):
-                    continue
-
-                # Get description and context from the model config
-                description = config.get("description", "No description available")
-                context_window = config.get("context_window", 0)
+            # Get models from the provider's model configurations
+            for model_name, capabilities in provider.get_model_configurations().items():
+                # Get description and context from the ModelCapabilities object
+                description = capabilities.description or "No description available"
+                context_window = capabilities.context_window

                 # Format context window
                 if context_window >= 1_000_000:
@@ -133,13 +129,14 @@ class ListModelsTool(BaseTool):

             # Show aliases for this provider
             aliases = []
-            for alias_name, target in provider.SUPPORTED_MODELS.items():
-                if isinstance(target, str):  # This is an alias
-                    aliases.append(f"- `{alias_name}` → `{target}`")
+            for model_name, capabilities in provider.get_model_configurations().items():
+                if capabilities.aliases:
+                    for alias in capabilities.aliases:
+                        aliases.append(f"- `{alias}` → `{model_name}`")

             if aliases:
                 output_lines.append("\n**Aliases**:")
-                output_lines.extend(aliases)
+                output_lines.extend(sorted(aliases))  # Sort for consistent output
         else:
             output_lines.append(f"**Status**: Not configured (set {info['env_key']})")

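Note: a self-contained sketch of the new alias rendering, using illustrative data that mirrors the Grok entries earlier in this commit; sorted() orders by the formatted string, i.e. effectively by alias text:

    capabilities_by_name = {
        "grok-3": ["grok", "grok3"],
        "grok-3-fast": ["grok3fast", "grokfast"],
    }

    aliases = []
    for model_name, alias_list in capabilities_by_name.items():
        for alias in alias_list:
            aliases.append(f"- `{alias}` → `{model_name}`")

    for line in sorted(aliases):
        print(line)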
@@ -237,7 +234,7 @@ class ListModelsTool(BaseTool):

         for alias in registry.list_aliases():
             config = registry.resolve(alias)
-            if config and hasattr(config, "is_custom") and config.is_custom:
+            if config and config.is_custom:
                 custom_models.append((alias, config))

         if custom_models:
@@ -256,8 +256,8 @@ class BaseTool(ABC):
         # Find all custom models (is_custom=true)
         for alias in registry.list_aliases():
             config = registry.resolve(alias)
-            # Use hasattr for defensive programming - is_custom is optional with default False
-            if config and hasattr(config, "is_custom") and config.is_custom:
+            # Check if this is a custom model that requires custom endpoints
+            if config and config.is_custom:
                 if alias not in all_models:
                     all_models.append(alias)
     except Exception as e:
@@ -311,12 +311,16 @@ class BaseTool(ABC):
             ProviderType.GOOGLE: "Gemini models",
             ProviderType.OPENAI: "OpenAI models",
             ProviderType.XAI: "X.AI GROK models",
+            ProviderType.DIAL: "DIAL models",
             ProviderType.CUSTOM: "Custom models",
             ProviderType.OPENROUTER: "OpenRouter models",
         }

         # Check available providers and add their model descriptions
-        for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI]:
+        # Start with native providers
+        for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]:
+            # Only if this is registered / available
             provider = ModelProviderRegistry.get_provider(provider_type)
             if provider:
                 provider_section_added = False
@@ -324,13 +328,13 @@ class BaseTool(ABC):
                 try:
                     # Get model config to extract description
                     model_config = provider.SUPPORTED_MODELS.get(model_name)
-                    if isinstance(model_config, dict) and "description" in model_config:
+                    if model_config and model_config.description:
                         if not provider_section_added:
                             model_desc_parts.append(
                                 f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:"
                             )
                             provider_section_added = True
-                        model_desc_parts.append(f"- '{model_name}': {model_config['description']}")
+                        model_desc_parts.append(f"- '{model_name}': {model_config.description}")
                 except Exception:
                     # Skip models without descriptions
                     continue
@@ -346,8 +350,8 @@ class BaseTool(ABC):
         # Find all custom models (is_custom=true)
         for alias in registry.list_aliases():
             config = registry.resolve(alias)
-            # Use hasattr for defensive programming - is_custom is optional with default False
-            if config and hasattr(config, "is_custom") and config.is_custom:
+            # Check if this is a custom model that requires custom endpoints
+            if config and config.is_custom:
                 # Format context window
                 context_tokens = config.context_window
                 if context_tokens >= 1_000_000: