From 498ea88293fea1b5f3f463ad9878aa1e6b516bc3 Mon Sep 17 00:00:00 2001
From: Fahad
Date: Mon, 23 Jun 2025 16:58:59 +0400
Subject: [PATCH] Use ModelCapabilities consistently instead of dictionaries

Moved aliases into SUPPORTED_MODELS instead of separate shorthand
entries, more in line with how custom_models are declared.

Further refactoring to clean up some code.
---
 config.py                              |   2 +-
 providers/base.py                      | 134 +++++++++--
 providers/custom.py                    |  74 +++---
 providers/dial.py                      | 301 +++++++++++++------------
 providers/gemini.py                    | 213 ++++++++---------
 providers/openai_provider.py           | 252 +++++++++------------
 providers/openrouter.py                |  36 +++
 providers/xai.py                       | 134 ++++-------
 tests/test_auto_mode.py                |   2 +-
 tests/test_auto_mode_comprehensive.py  |  13 +-
 tests/test_dial_provider.py            |   2 +-
 tests/test_openai_provider.py          |   6 +-
 tests/test_supported_models_aliases.py | 206 +++++++++++++++++
 tests/test_xai_provider.py             |  37 +--
 tools/listmodels.py                    |  25 +-
 tools/shared/base_tool.py              |  18 +-
 16 files changed, 850 insertions(+), 605 deletions(-)
 create mode 100644 tests/test_supported_models_aliases.py

diff --git a/config.py b/config.py
index d2005c1..5e8667a 100644
--- a/config.py
+++ b/config.py
@@ -14,7 +14,7 @@ import os
# These values are used in server responses and for tracking releases
# IMPORTANT: This is the single source of truth for version and author info
# Semantic versioning: MAJOR.MINOR.PATCH
-__version__ = "5.6.2"
+__version__ = "5.7.0"
# Last update date in ISO format
__updated__ = "2025-06-23"
# Primary maintainer
diff --git a/providers/base.py b/providers/base.py
index c8b1ec7..06f60fe 100644
--- a/providers/base.py
+++ b/providers/base.py
@@ -140,6 +140,19 @@ class ModelCapabilities:
    max_image_size_mb: float = 0.0  # Maximum total size for all images in MB
    supports_temperature: bool = True  # Whether model accepts temperature parameter in API calls

+    # Additional fields for comprehensive model information
+    description: str = ""  # Human-readable description of the model
+    aliases: list[str] = field(default_factory=list)  # Alternative names/shortcuts for the model
+
+    # JSON mode support (for providers that support structured output)
+    supports_json_mode: bool = False
+
+    # Thinking mode support (for models with thinking capabilities)
+    max_thinking_tokens: int = 0  # Maximum thinking tokens for extended reasoning models
+
+    # Custom model flag (for models that only work with custom endpoints)
+    is_custom: bool = False  # Whether this model requires custom API endpoints
+
    # Temperature constraint object - preferred way to define temperature limits
    temperature_constraint: TemperatureConstraint = field(
        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
@@ -251,7 +264,7 @@ class ModelProvider(ABC):
        capabilities = self.get_capabilities(model_name)

        # Check if model supports temperature at all
-        if hasattr(capabilities, "supports_temperature") and not capabilities.supports_temperature:
+        if not capabilities.supports_temperature:
            return None

        # Get temperature range
@@ -290,19 +303,109 @@ class ModelProvider(ABC):
        """Check if the model supports extended thinking mode."""
        pass

-    @abstractmethod
+    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
+        """Get model configurations for this provider.
+
+        This is a hook method that subclasses can override to provide
+        their model configurations from different sources.
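+
+        Example (an illustrative sketch only; the model name and values
+        below are hypothetical): a subclass relying on this default
+        behavior would declare something like:
+
+            SUPPORTED_MODELS = {
+                "example-model": ModelCapabilities(
+                    provider=ProviderType.OPENAI,
+                    model_name="example-model",
+                    friendly_name="Example (Model)",
+                    context_window=128_000,
+                    aliases=["example"],
+                ),
+            }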
+ + Returns: + Dictionary mapping model names to their ModelCapabilities objects + """ + # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects) + if hasattr(self, "SUPPORTED_MODELS"): + return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)} + return {} + + def get_all_model_aliases(self) -> dict[str, list[str]]: + """Get all model aliases for this provider. + + This is a hook method that subclasses can override to provide + aliases from different sources. + + Returns: + Dictionary mapping model names to their list of aliases + """ + # Default implementation extracts from ModelCapabilities objects + aliases = {} + for model_name, capabilities in self.get_model_configurations().items(): + if capabilities.aliases: + aliases[model_name] = capabilities.aliases + return aliases + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve model shorthand to full name. + + This implementation uses the hook methods to support different + model configuration sources. + + Args: + model_name: Model name that may be an alias + + Returns: + Resolved model name + """ + # Get model configurations from the hook method + model_configs = self.get_model_configurations() + + # First check if it's already a base model name (case-sensitive exact match) + if model_name in model_configs: + return model_name + + # Check case-insensitively for both base models and aliases + model_name_lower = model_name.lower() + + # Check base model names case-insensitively + for base_model in model_configs: + if base_model.lower() == model_name_lower: + return base_model + + # Check aliases from the hook method + all_aliases = self.get_all_model_aliases() + for base_model, aliases in all_aliases.items(): + if any(alias.lower() == model_name_lower for alias in aliases): + return base_model + + # If not found, return as-is + return model_name + def list_models(self, respect_restrictions: bool = True) -> list[str]: """Return a list of model names supported by this provider. + This implementation uses the get_model_configurations() hook + to support different model configuration sources. + Args: respect_restrictions: Whether to apply provider-specific restriction logic. Returns: List of model names available from this provider """ - pass + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() if respect_restrictions else None + models = [] + + # Get model configurations from the hook method + model_configs = self.get_model_configurations() + + for model_name in model_configs: + # Check restrictions if enabled + if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): + continue + + # Add the base model + models.append(model_name) + + # Get aliases from the hook method + all_aliases = self.get_all_model_aliases() + for model_name, aliases in all_aliases.items(): + # Only add aliases for models that passed restriction check + if model_name in models: + models.extend(aliases) + + return models - @abstractmethod def list_all_known_models(self) -> list[str]: """Return all model names known by this provider, including alias targets. @@ -312,21 +415,22 @@ class ModelProvider(ABC): Returns: List of all model names and alias targets known by this provider """ - pass + all_models = set() - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name. 
+ # Get model configurations from the hook method + model_configs = self.get_model_configurations() - Base implementation returns the model name unchanged. - Subclasses should override to provide alias resolution. + # Add all base model names + for model_name in model_configs: + all_models.add(model_name.lower()) - Args: - model_name: Model name that may be an alias + # Get aliases from the hook method and add them + all_aliases = self.get_all_model_aliases() + for _model_name, aliases in all_aliases.items(): + for alias in aliases: + all_models.add(alias.lower()) - Returns: - Resolved model name - """ - return model_name + return list(all_models) def close(self): """Clean up any resources held by the provider. diff --git a/providers/custom.py b/providers/custom.py index bad1062..52d9b94 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -268,65 +268,55 @@ class CustomProvider(OpenAICompatibleProvider): def supports_thinking_mode(self, model_name: str) -> bool: """Check if the model supports extended thinking mode. - Most custom/local models don't support extended thinking. - Args: model_name: Model to check Returns: - False (custom models generally don't support thinking mode) + True if model supports thinking mode, False otherwise """ + # Check if model is in registry + config = self._registry.resolve(model_name) if self._registry else None + if config and config.is_custom: + # Trust the config from custom_models.json + return config.supports_extended_thinking + + # Default to False for unknown models return False - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. + def get_model_configurations(self) -> dict[str, ModelCapabilities]: + """Get model configurations from the registry. - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. + For CustomProvider, we convert registry configurations to ModelCapabilities objects. 
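+
+        A registry entry from custom_models.json surfaces here roughly as
+        follows (a hypothetical entry; the exact field names depend on the
+        registry schema):
+
+            {"model_name": "local-llama", "aliases": ["llama"],
+             "context_window": 128000, "is_custom": true}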
Returns: - List of model names available from this provider + Dictionary mapping model names to their ModelCapabilities objects """ - from utils.model_restrictions import get_restriction_service + from .base import ProviderType - restriction_service = get_restriction_service() if respect_restrictions else None - models = [] + configs = {} if self._registry: - # Get all models from the registry - all_models = self._registry.list_models() - aliases = self._registry.list_aliases() - - # Add models that are validated by the custom provider - for model_name in all_models + aliases: - # Use the provider's validation logic to determine if this model - # is appropriate for the custom endpoint + # Get all models from registry + for model_name in self._registry.list_models(): + # Only include custom models that this provider validates if self.validate_model_name(model_name): - # Check restrictions if enabled - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): - continue + config = self._registry.resolve(model_name) + if config and config.is_custom: + # Convert OpenRouterModelConfig to ModelCapabilities + capabilities = config.to_capabilities() + # Override provider type to CUSTOM for local models + capabilities.provider = ProviderType.CUSTOM + capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})" + configs[model_name] = capabilities - models.append(model_name) + return configs - return models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. + def get_all_model_aliases(self) -> dict[str, list[str]]: + """Get all model aliases from the registry. Returns: - List of all model names and alias targets known by this provider + Dictionary mapping model names to their list of aliases """ - all_models = set() - - if self._registry: - # Get all models and aliases from the registry - all_models.update(model.lower() for model in self._registry.list_models()) - all_models.update(alias.lower() for alias in self._registry.list_aliases()) - - # For each alias, also add its target - for alias in self._registry.list_aliases(): - config = self._registry.resolve(alias) - if config: - all_models.add(config.model_name.lower()) - - return list(all_models) + # Since aliases are now included in the configurations, + # we can use the base class implementation + return super().get_all_model_aliases() diff --git a/providers/dial.py b/providers/dial.py index 617858c..f019415 100644 --- a/providers/dial.py +++ b/providers/dial.py @@ -10,7 +10,7 @@ from .base import ( ModelCapabilities, ModelResponse, ProviderType, - RangeTemperatureConstraint, + create_temperature_constraint, ) from .openai_compatible import OpenAICompatibleProvider @@ -30,63 +30,161 @@ class DIALModelProvider(OpenAICompatibleProvider): MAX_RETRIES = 4 RETRY_DELAYS = [1, 3, 5, 8] # seconds - # Supported DIAL models (these can be customized based on your DIAL deployment) + # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { - "o3-2025-04-16": { - "context_window": 200_000, - "supports_extended_thinking": False, - "supports_vision": True, - }, - "o4-mini-2025-04-16": { - "context_window": 200_000, - "supports_extended_thinking": False, - "supports_vision": True, - }, - "anthropic.claude-sonnet-4-20250514-v1:0": { - "context_window": 200_000, - "supports_extended_thinking": False, - "supports_vision": True, - }, - "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": { - "context_window": 
200_000,
-            "supports_extended_thinking": True,  # Thinking mode variant
-            "supports_vision": True,
-        },
-        "anthropic.claude-opus-4-20250514-v1:0": {
-            "context_window": 200_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
-            "context_window": 200_000,
-            "supports_extended_thinking": True,  # Thinking mode variant
-            "supports_vision": True,
-        },
-        "gemini-2.5-pro-preview-03-25-google-search": {
-            "context_window": 1_000_000,
-            "supports_extended_thinking": False,  # DIAL doesn't expose thinking mode
-            "supports_vision": True,
-        },
-        "gemini-2.5-pro-preview-05-06": {
-            "context_window": 1_000_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "gemini-2.5-flash-preview-05-20": {
-            "context_window": 1_000_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        # Shorthands
-        "o3": "o3-2025-04-16",
-        "o4-mini": "o4-mini-2025-04-16",
-        "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
-        "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
-        "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
-        "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
-        "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
-        "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
-        "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
+        "o3-2025-04-16": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="o3-2025-04-16",
+            friendly_name="DIAL (O3)",
+            context_window=200_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=False,  # O3 models don't accept temperature
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="OpenAI O3 via DIAL - Strong reasoning model",
+            aliases=["o3"],
+        ),
+        "o4-mini-2025-04-16": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="o4-mini-2025-04-16",
+            friendly_name="DIAL (O4-mini)",
+            context_window=200_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=False,  # O4 models don't accept temperature
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="OpenAI O4-mini via DIAL - Fast reasoning model",
+            aliases=["o4-mini"],
+        ),
+        "anthropic.claude-sonnet-4-20250514-v1:0": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-sonnet-4-20250514-v1:0",
+            friendly_name="DIAL (Sonnet 4)",
+            context_window=200_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Function calling not exposed via DIAL
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Sonnet 4 via DIAL - Balanced performance",
+            aliases=["sonnet-4"],
+        ),
+        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
+            friendly_name="DIAL (Sonnet 4 Thinking)",
+            context_window=200_000,
+            supports_extended_thinking=True,  # Thinking mode variant
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Function calling not exposed via DIAL
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Sonnet 4 with thinking mode via DIAL",
+            aliases=["sonnet-4-thinking"],
+        ),
+        "anthropic.claude-opus-4-20250514-v1:0": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-opus-4-20250514-v1:0",
+            friendly_name="DIAL (Opus 4)",
+            context_window=200_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Function calling not exposed via DIAL
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Opus 4 via DIAL - Most capable Claude model",
+            aliases=["opus-4"],
+        ),
+        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking",
+            friendly_name="DIAL (Opus 4 Thinking)",
+            context_window=200_000,
+            supports_extended_thinking=True,  # Thinking mode variant
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Function calling not exposed via DIAL
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Opus 4 with thinking mode via DIAL",
+            aliases=["opus-4-thinking"],
+        ),
+        "gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="gemini-2.5-pro-preview-03-25-google-search",
+            friendly_name="DIAL (Gemini 2.5 Pro Search)",
+            context_window=1_000_000,
+            supports_extended_thinking=False,  # DIAL doesn't expose thinking mode
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.5 Pro with Google Search via DIAL",
+            aliases=["gemini-2.5-pro-search"],
+        ),
+        "gemini-2.5-pro-preview-05-06": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="gemini-2.5-pro-preview-05-06",
+            friendly_name="DIAL (Gemini 2.5 Pro)",
+            context_window=1_000_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.5 Pro via DIAL - Deep reasoning",
+            aliases=["gemini-2.5-pro"],
+        ),
+        "gemini-2.5-flash-preview-05-20": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="gemini-2.5-flash-preview-05-20",
+            friendly_name="DIAL (Gemini Flash 2.5)",
+            context_window=1_000_000,
+            supports_extended_thinking=False,
+            
supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Gemini 2.5 Flash via DIAL - Ultra-fast", + aliases=["gemini-2.5-flash"], + ), } def __init__(self, api_key: str, **kwargs): @@ -181,20 +279,8 @@ class DIALModelProvider(OpenAICompatibleProvider): if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name): raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.") - config = self.SUPPORTED_MODELS[resolved_name] - - return ModelCapabilities( - provider=ProviderType.DIAL, - model_name=resolved_name, - friendly_name=self.FRIENDLY_NAME, - context_window=config["context_window"], - supports_extended_thinking=config["supports_extended_thinking"], - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=True, - supports_images=config.get("supports_vision", False), - temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7), - ) + # Return the ModelCapabilities object directly from SUPPORTED_MODELS + return self.SUPPORTED_MODELS[resolved_name] def get_provider_type(self) -> ProviderType: """Get the provider type.""" @@ -211,7 +297,7 @@ class DIALModelProvider(OpenAICompatibleProvider): """ resolved_name = self._resolve_model_name(model_name) - if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + if resolved_name not in self.SUPPORTED_MODELS: return False # Check against base class allowed_models if configured @@ -231,20 +317,6 @@ class DIALModelProvider(OpenAICompatibleProvider): return True - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name. - - Args: - model_name: Model name or shorthand - - Returns: - Full model name - """ - shorthand_value = self.SUPPORTED_MODELS.get(model_name) - if isinstance(shorthand_value, str): - return shorthand_value - return model_name - def _get_deployment_client(self, deployment: str): """Get or create a cached client for a specific deployment. @@ -357,7 +429,7 @@ class DIALModelProvider(OpenAICompatibleProvider): # Check model capabilities try: capabilities = self.get_capabilities(model_name) - supports_temperature = getattr(capabilities, "supports_temperature", True) + supports_temperature = capabilities.supports_temperature except Exception as e: logger.debug(f"Failed to check temperature support for {model_name}: {e}") supports_temperature = True @@ -441,63 +513,12 @@ class DIALModelProvider(OpenAICompatibleProvider): """ resolved_name = self._resolve_model_name(model_name) - if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict): - return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False) + if resolved_name in self.SUPPORTED_MODELS: + return self.SUPPORTED_MODELS[resolved_name].supports_images # Fall back to parent implementation for unknown models return super()._supports_vision(model_name) - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. - - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. 
- - Returns: - List of model names available from this provider - """ - # Get all model keys (both full names and aliases) - all_models = list(self.SUPPORTED_MODELS.keys()) - - if not respect_restrictions: - return all_models - - # Apply restrictions if configured - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() - - # Filter based on restrictions - allowed_models = [] - for model in all_models: - resolved_name = self._resolve_model_name(model) - if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model): - allowed_models.append(model) - - return allowed_models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. - - This is used for validation purposes to ensure restriction policies - can validate against both aliases and their target model names. - - Returns: - List of all model names and alias targets known by this provider - """ - # Collect all unique model names (both aliases and targets) - all_models = set() - - for key, value in self.SUPPORTED_MODELS.items(): - # Add the key (could be alias or full name) - all_models.add(key) - - # If it's an alias (string value), add the target too - if isinstance(value, str): - all_models.add(value) - - return sorted(all_models) - def close(self): """Clean up HTTP clients when provider is closed.""" logger.info("Closing DIAL provider HTTP clients...") diff --git a/providers/gemini.py b/providers/gemini.py index 074232f..1118699 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -9,7 +9,7 @@ from typing import Optional from google import genai from google.genai import types -from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint +from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint logger = logging.getLogger(__name__) @@ -17,47 +17,79 @@ logger = logging.getLogger(__name__) class GeminiModelProvider(ModelProvider): """Google Gemini model provider implementation.""" - # Model configurations + # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { - "gemini-2.0-flash": { - "context_window": 1_048_576, # 1M tokens - "supports_extended_thinking": True, # Experimental thinking mode - "max_thinking_tokens": 24576, # Same as 2.5 flash for consistency - "supports_images": True, # Vision capability - "max_image_size_mb": 20.0, # Conservative 20MB limit for reliability - "description": "Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input", - }, - "gemini-2.0-flash-lite": { - "context_window": 1_048_576, # 1M tokens - "supports_extended_thinking": False, # Not supported per user request - "max_thinking_tokens": 0, # No thinking support - "supports_images": False, # Does not support images - "max_image_size_mb": 0.0, # No image support - "description": "Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only", - }, - "gemini-2.5-flash": { - "context_window": 1_048_576, # 1M tokens - "supports_extended_thinking": True, - "max_thinking_tokens": 24576, # Flash 2.5 thinking budget limit - "supports_images": True, # Vision capability - "max_image_size_mb": 20.0, # Conservative 20MB limit for reliability - "description": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations", - }, - "gemini-2.5-pro": { - "context_window": 1_048_576, # 1M tokens - "supports_extended_thinking": True, - 
"max_thinking_tokens": 32768, # Pro 2.5 thinking budget limit - "supports_images": True, # Vision capability - "max_image_size_mb": 32.0, # Higher limit for Pro model - "description": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis", - }, - # Shorthands - "flash": "gemini-2.5-flash", - "flash-2.0": "gemini-2.0-flash", - "flash2": "gemini-2.0-flash", - "flashlite": "gemini-2.0-flash-lite", - "flash-lite": "gemini-2.0-flash-lite", - "pro": "gemini-2.5-pro", + "gemini-2.0-flash": ModelCapabilities( + provider=ProviderType.GOOGLE, + model_name="gemini-2.0-flash", + friendly_name="Gemini (Flash 2.0)", + context_window=1_048_576, # 1M tokens + supports_extended_thinking=True, # Experimental thinking mode + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # Vision capability + max_image_size_mb=20.0, # Conservative 20MB limit for reliability + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + max_thinking_tokens=24576, # Same as 2.5 flash for consistency + description="Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input", + aliases=["flash-2.0", "flash2"], + ), + "gemini-2.0-flash-lite": ModelCapabilities( + provider=ProviderType.GOOGLE, + model_name="gemini-2.0-flash-lite", + friendly_name="Gemin (Flash Lite 2.0)", + context_window=1_048_576, # 1M tokens + supports_extended_thinking=False, # Not supported per user request + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=False, # Does not support images + max_image_size_mb=0.0, # No image support + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only", + aliases=["flashlite", "flash-lite"], + ), + "gemini-2.5-flash": ModelCapabilities( + provider=ProviderType.GOOGLE, + model_name="gemini-2.5-flash", + friendly_name="Gemini (Flash 2.5)", + context_window=1_048_576, # 1M tokens + supports_extended_thinking=True, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # Vision capability + max_image_size_mb=20.0, # Conservative 20MB limit for reliability + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + max_thinking_tokens=24576, # Flash 2.5 thinking budget limit + description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations", + aliases=["flash", "flash2.5"], + ), + "gemini-2.5-pro": ModelCapabilities( + provider=ProviderType.GOOGLE, + model_name="gemini-2.5-pro", + friendly_name="Gemini (Pro 2.5)", + context_window=1_048_576, # 1M tokens + supports_extended_thinking=True, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # Vision capability + max_image_size_mb=32.0, # Higher limit for Pro model + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + max_thinking_tokens=32768, # Max thinking tokens for Pro model + description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis", + aliases=["pro", "gemini pro", "gemini-pro"], + ), } # Thinking mode configurations - percentages of 
model's max_thinking_tokens @@ -70,6 +102,14 @@ class GeminiModelProvider(ModelProvider): "max": 1.0, # 100% of max - full thinking budget } + # Model-specific thinking token limits + MAX_THINKING_TOKENS = { + "gemini-2.0-flash": 24576, # Same as 2.5 flash for consistency + "gemini-2.0-flash-lite": 0, # No thinking support + "gemini-2.5-flash": 24576, # Flash 2.5 thinking budget limit + "gemini-2.5-pro": 32768, # Pro 2.5 thinking budget limit + } + def __init__(self, api_key: str, **kwargs): """Initialize Gemini provider with API key.""" super().__init__(api_key, **kwargs) @@ -100,25 +140,8 @@ class GeminiModelProvider(ModelProvider): if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name): raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.") - config = self.SUPPORTED_MODELS[resolved_name] - - # Gemini models support 0.0-2.0 temperature range - temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7) - - return ModelCapabilities( - provider=ProviderType.GOOGLE, - model_name=resolved_name, - friendly_name="Gemini", - context_window=config["context_window"], - supports_extended_thinking=config["supports_extended_thinking"], - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=True, - supports_images=config.get("supports_images", False), - max_image_size_mb=config.get("max_image_size_mb", 0.0), - supports_temperature=True, # Gemini models accept temperature parameter - temperature_constraint=temp_constraint, - ) + # Return the ModelCapabilities object directly from SUPPORTED_MODELS + return self.SUPPORTED_MODELS[resolved_name] def generate_content( self, @@ -179,8 +202,8 @@ class GeminiModelProvider(ModelProvider): if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS: # Get model's max thinking tokens and calculate actual budget model_config = self.SUPPORTED_MODELS.get(resolved_name) - if model_config and "max_thinking_tokens" in model_config: - max_thinking_tokens = model_config["max_thinking_tokens"] + if model_config and model_config.max_thinking_tokens > 0: + max_thinking_tokens = model_config.max_thinking_tokens actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode]) generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget) @@ -258,7 +281,7 @@ class GeminiModelProvider(ModelProvider): resolved_name = self._resolve_model_name(model_name) # First check if model is supported - if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + if resolved_name not in self.SUPPORTED_MODELS: return False # Then check if model is allowed by restrictions @@ -281,78 +304,20 @@ class GeminiModelProvider(ModelProvider): def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int: """Get actual thinking token budget for a model and thinking mode.""" resolved_name = self._resolve_model_name(model_name) - model_config = self.SUPPORTED_MODELS.get(resolved_name, {}) + model_config = self.SUPPORTED_MODELS.get(resolved_name) - if not model_config.get("supports_extended_thinking", False): + if not model_config or not model_config.supports_extended_thinking: return 0 if thinking_mode not in self.THINKING_BUDGETS: return 0 - max_thinking_tokens = model_config.get("max_thinking_tokens", 0) + max_thinking_tokens = model_config.max_thinking_tokens if max_thinking_tokens == 0: return 0 return int(max_thinking_tokens * 
self.THINKING_BUDGETS[thinking_mode]) - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. - - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. - - Returns: - List of model names available from this provider - """ - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() if respect_restrictions else None - models = [] - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Handle both base models (dict configs) and aliases (string values) - if isinstance(config, str): - # This is an alias - check if the target model would be allowed - target_model = config - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model): - continue - # Allow the alias - models.append(model_name) - else: - # This is a base model with config dict - # Check restrictions if enabled - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): - continue - models.append(model_name) - - return models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. - - Returns: - List of all model names and alias targets known by this provider - """ - all_models = set() - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Add the model name itself - all_models.add(model_name.lower()) - - # If it's an alias (string value), add the target model too - if isinstance(config, str): - all_models.add(config.lower()) - - return list(all_models) - - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name.""" - # Check if it's a shorthand - shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower()) - if isinstance(shorthand_value, str): - return shorthand_value - return model_name - def _extract_usage(self, response) -> dict[str, int]: """Extract token usage from Gemini response.""" usage = {} diff --git a/providers/openai_provider.py b/providers/openai_provider.py index 3553673..e065ee1 100644 --- a/providers/openai_provider.py +++ b/providers/openai_provider.py @@ -17,71 +17,110 @@ logger = logging.getLogger(__name__) class OpenAIModelProvider(OpenAICompatibleProvider): """Official OpenAI API provider (api.openai.com).""" - # Model configurations + # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { - "o3": { - "context_window": 200_000, # 200K tokens - "supports_extended_thinking": False, - "supports_images": True, # O3 models support vision - "max_image_size_mb": 20.0, # 20MB per OpenAI docs - "supports_temperature": False, # O3 models don't accept temperature parameter - "temperature_constraint": "fixed", # Fixed at 1.0 - "description": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis", - }, - "o3-mini": { - "context_window": 200_000, # 200K tokens - "supports_extended_thinking": False, - "supports_images": True, # O3 models support vision - "max_image_size_mb": 20.0, # 20MB per OpenAI docs - "supports_temperature": False, # O3 models don't accept temperature parameter - "temperature_constraint": "fixed", # Fixed at 1.0 - "description": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity", - }, - "o3-pro-2025-06-10": { - "context_window": 200_000, # 200K tokens - "supports_extended_thinking": False, - "supports_images": True, # O3 models support 
vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
-        },
-        # Aliases
-        "o3-pro": "o3-pro-2025-06-10",
-        "o4-mini": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O4 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O4 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
-        },
-        "o4-mini-high": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O4 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O4 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
-        },
-        "gpt-4.1-2025-04-14": {
-            "context_window": 1_000_000,  # 1M tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # GPT-4.1 supports vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": True,  # Regular models accept temperature parameter
-            "temperature_constraint": "range",  # 0.0-2.0 range
-            "description": "GPT-4.1 (1M context) - Advanced reasoning model with large context window",
-        },
-        # Shorthands
-        "mini": "o4-mini",  # Default 'mini' to latest mini model
-        "o3mini": "o3-mini",
-        "o4mini": "o4-mini",
-        "o4minihigh": "o4-mini-high",
-        "o4minihi": "o4-mini-high",
-        "gpt4.1": "gpt-4.1-2025-04-14",
+        "o3": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3",
+            friendly_name="OpenAI (O3)",
+            context_window=200_000,  # 200K tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
+            aliases=[],
+        ),
+        "o3-mini": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3-mini",
+            friendly_name="OpenAI (O3-mini)",
+            context_window=200_000,  # 200K tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
+            aliases=["o3mini"],
+        ),
+        "o3-pro-2025-06-10": ModelCapabilities(
+            
provider=ProviderType.OPENAI, + model_name="o3-pro-2025-06-10", + friendly_name="OpenAI (O3-Pro)", + context_window=200_000, # 200K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O3 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O3 models don't accept temperature parameter + temperature_constraint=create_temperature_constraint("fixed"), + description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.", + aliases=["o3-pro"], + ), + "o4-mini": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o4-mini", + friendly_name="OpenAI (O4-mini)", + context_window=200_000, # 200K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O4 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O4 models don't accept temperature parameter + temperature_constraint=create_temperature_constraint("fixed"), + description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning", + aliases=["mini", "o4mini"], + ), + "o4-mini-high": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o4-mini-high", + friendly_name="OpenAI (O4-mini-high)", + context_window=200_000, # 200K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O4 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O4 models don't accept temperature parameter + temperature_constraint=create_temperature_constraint("fixed"), + description="Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks", + aliases=["o4minihigh", "o4minihi", "mini-high"], + ), + "gpt-4.1-2025-04-14": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="gpt-4.1-2025-04-14", + friendly_name="OpenAI (GPT 4.1)", + context_window=1_000_000, # 1M tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # GPT-4.1 supports vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=True, # Regular models accept temperature parameter + temperature_constraint=create_temperature_constraint("range"), + description="GPT-4.1 (1M context) - Advanced reasoning model with large context window", + aliases=["gpt4.1"], + ), } def __init__(self, api_key: str, **kwargs): @@ -95,7 +134,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider): # Resolve shorthand resolved_name = self._resolve_model_name(model_name) - if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str): + if resolved_name not in self.SUPPORTED_MODELS: raise ValueError(f"Unsupported OpenAI model: {model_name}") # Check if model is allowed by restrictions @@ -105,27 +144,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider): if not 
restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name): raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.") - config = self.SUPPORTED_MODELS[resolved_name] - - # Get temperature constraints and support from configuration - supports_temperature = config.get("supports_temperature", True) # Default to True for backward compatibility - temp_constraint_type = config.get("temperature_constraint", "range") # Default to range - temp_constraint = create_temperature_constraint(temp_constraint_type) - - return ModelCapabilities( - provider=ProviderType.OPENAI, - model_name=model_name, - friendly_name="OpenAI", - context_window=config["context_window"], - supports_extended_thinking=config["supports_extended_thinking"], - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=True, - supports_images=config.get("supports_images", False), - max_image_size_mb=config.get("max_image_size_mb", 0.0), - supports_temperature=supports_temperature, - temperature_constraint=temp_constraint, - ) + # Return the ModelCapabilities object directly from SUPPORTED_MODELS + return self.SUPPORTED_MODELS[resolved_name] def get_provider_type(self) -> ProviderType: """Get the provider type.""" @@ -136,7 +156,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider): resolved_name = self._resolve_model_name(model_name) # First check if model is supported - if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + if resolved_name not in self.SUPPORTED_MODELS: return False # Then check if model is allowed by restrictions @@ -177,61 +197,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider): # Currently no OpenAI models support extended thinking # This may change with future O3 models return False - - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. - - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. - - Returns: - List of model names available from this provider - """ - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() if respect_restrictions else None - models = [] - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Handle both base models (dict configs) and aliases (string values) - if isinstance(config, str): - # This is an alias - check if the target model would be allowed - target_model = config - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model): - continue - # Allow the alias - models.append(model_name) - else: - # This is a base model with config dict - # Check restrictions if enabled - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): - continue - models.append(model_name) - - return models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. 
- - Returns: - List of all model names and alias targets known by this provider - """ - all_models = set() - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Add the model name itself - all_models.add(model_name.lower()) - - # If it's an alias (string value), add the target model too - if isinstance(config, str): - all_models.add(config.lower()) - - return list(all_models) - - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name.""" - # Check if it's a shorthand - shorthand_value = self.SUPPORTED_MODELS.get(model_name) - if isinstance(shorthand_value, str): - return shorthand_value - return model_name diff --git a/providers/openrouter.py b/providers/openrouter.py index 1e22b45..5d29514 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -270,3 +270,39 @@ class OpenRouterProvider(OpenAICompatibleProvider): all_models.add(config.model_name.lower()) return list(all_models) + + def get_model_configurations(self) -> dict[str, ModelCapabilities]: + """Get model configurations from the registry. + + For OpenRouter, we convert registry configurations to ModelCapabilities objects. + + Returns: + Dictionary mapping model names to their ModelCapabilities objects + """ + configs = {} + + if self._registry: + # Get all models from registry + for model_name in self._registry.list_models(): + # Only include models that this provider validates + if self.validate_model_name(model_name): + config = self._registry.resolve(model_name) + if config and not config.is_custom: # Only OpenRouter models, not custom ones + # Convert OpenRouterModelConfig to ModelCapabilities + capabilities = config.to_capabilities() + # Override provider type to OPENROUTER + capabilities.provider = ProviderType.OPENROUTER + capabilities.friendly_name = f"{self.FRIENDLY_NAME} ({config.model_name})" + configs[model_name] = capabilities + + return configs + + def get_all_model_aliases(self) -> dict[str, list[str]]: + """Get all model aliases from the registry. 
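+
+        Example of the returned shape (illustrative only; real entries come
+        from the custom_models.json registry):
+
+            {"anthropic/claude-opus-4": ["opus"], "google/gemini-2.5-pro": ["pro"]}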
+ + Returns: + Dictionary mapping model names to their list of aliases + """ + # Since aliases are now included in the configurations, + # we can use the base class implementation + return super().get_all_model_aliases() diff --git a/providers/xai.py b/providers/xai.py index 71d5c8a..2b6fd04 100644 --- a/providers/xai.py +++ b/providers/xai.py @@ -7,7 +7,7 @@ from .base import ( ModelCapabilities, ModelResponse, ProviderType, - RangeTemperatureConstraint, + create_temperature_constraint, ) from .openai_compatible import OpenAICompatibleProvider @@ -19,23 +19,42 @@ class XAIModelProvider(OpenAICompatibleProvider): FRIENDLY_NAME = "X.AI" - # Model configurations + # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { - "grok-3": { - "context_window": 131_072, # 131K tokens - "supports_extended_thinking": False, - "description": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis", - }, - "grok-3-fast": { - "context_window": 131_072, # 131K tokens - "supports_extended_thinking": False, - "description": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive", - }, - # Shorthands for convenience - "grok": "grok-3", # Default to grok-3 - "grok3": "grok-3", - "grok3fast": "grok-3-fast", - "grokfast": "grok-3-fast", + "grok-3": ModelCapabilities( + provider=ProviderType.XAI, + model_name="grok-3", + friendly_name="X.AI (Grok 3)", + context_window=131_072, # 131K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet + supports_images=False, # Assuming GROK is text-only for now + max_image_size_mb=0.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis", + aliases=["grok", "grok3"], + ), + "grok-3-fast": ModelCapabilities( + provider=ProviderType.XAI, + model_name="grok-3-fast", + friendly_name="X.AI (Grok 3 Fast)", + context_window=131_072, # 131K tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet + supports_images=False, # Assuming GROK is text-only for now + max_image_size_mb=0.0, + supports_temperature=True, + temperature_constraint=create_temperature_constraint("range"), + description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive", + aliases=["grok3fast", "grokfast", "grok3-fast"], + ), } def __init__(self, api_key: str, **kwargs): @@ -49,7 +68,7 @@ class XAIModelProvider(OpenAICompatibleProvider): # Resolve shorthand resolved_name = self._resolve_model_name(model_name) - if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str): + if resolved_name not in self.SUPPORTED_MODELS: raise ValueError(f"Unsupported X.AI model: {model_name}") # Check if model is allowed by restrictions @@ -59,23 +78,8 @@ class XAIModelProvider(OpenAICompatibleProvider): if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name): raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.") - config = self.SUPPORTED_MODELS[resolved_name] - - # Define temperature constraints for GROK models - # GROK supports the standard OpenAI 
temperature range - temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7) - - return ModelCapabilities( - provider=ProviderType.XAI, - model_name=resolved_name, - friendly_name=self.FRIENDLY_NAME, - context_window=config["context_window"], - supports_extended_thinking=config["supports_extended_thinking"], - supports_system_prompts=True, - supports_streaming=True, - supports_function_calling=True, - temperature_constraint=temp_constraint, - ) + # Return the ModelCapabilities object directly from SUPPORTED_MODELS + return self.SUPPORTED_MODELS[resolved_name] def get_provider_type(self) -> ProviderType: """Get the provider type.""" @@ -86,7 +90,7 @@ class XAIModelProvider(OpenAICompatibleProvider): resolved_name = self._resolve_model_name(model_name) # First check if model is supported - if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): + if resolved_name not in self.SUPPORTED_MODELS: return False # Then check if model is allowed by restrictions @@ -127,61 +131,3 @@ class XAIModelProvider(OpenAICompatibleProvider): # Currently GROK models do not support extended thinking # This may change with future GROK model releases return False - - def list_models(self, respect_restrictions: bool = True) -> list[str]: - """Return a list of model names supported by this provider. - - Args: - respect_restrictions: Whether to apply provider-specific restriction logic. - - Returns: - List of model names available from this provider - """ - from utils.model_restrictions import get_restriction_service - - restriction_service = get_restriction_service() if respect_restrictions else None - models = [] - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Handle both base models (dict configs) and aliases (string values) - if isinstance(config, str): - # This is an alias - check if the target model would be allowed - target_model = config - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model): - continue - # Allow the alias - models.append(model_name) - else: - # This is a base model with config dict - # Check restrictions if enabled - if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name): - continue - models.append(model_name) - - return models - - def list_all_known_models(self) -> list[str]: - """Return all model names known by this provider, including alias targets. 
- - Returns: - List of all model names and alias targets known by this provider - """ - all_models = set() - - for model_name, config in self.SUPPORTED_MODELS.items(): - # Add the model name itself - all_models.add(model_name.lower()) - - # If it's an alias (string value), add the target model too - if isinstance(config, str): - all_models.add(config.lower()) - - return list(all_models) - - def _resolve_model_name(self, model_name: str) -> str: - """Resolve model shorthand to full name.""" - # Check if it's a shorthand - shorthand_value = self.SUPPORTED_MODELS.get(model_name) - if isinstance(shorthand_value, str): - return shorthand_value - return model_name diff --git a/tests/test_auto_mode.py b/tests/test_auto_mode.py index 1aa4376..74d8ae3 100644 --- a/tests/test_auto_mode.py +++ b/tests/test_auto_mode.py @@ -59,7 +59,7 @@ class TestAutoMode: continue # Check that model has description - description = config.get("description", "") + description = config.description if hasattr(config, "description") else "" if description: models_with_descriptions[model_name] = description diff --git a/tests/test_auto_mode_comprehensive.py b/tests/test_auto_mode_comprehensive.py index 8539fdf..4d699b0 100644 --- a/tests/test_auto_mode_comprehensive.py +++ b/tests/test_auto_mode_comprehensive.py @@ -319,7 +319,18 @@ class TestAutoModeComprehensive: m for m in available_models if not m.startswith("gemini") - and m not in ["flash", "pro", "flash-2.0", "flash2", "flashlite", "flash-lite"] + and m + not in [ + "flash", + "pro", + "flash-2.0", + "flash2", + "flashlite", + "flash-lite", + "flash2.5", + "gemini pro", + "gemini-pro", + ] ] assert ( len(non_gemini_models) == 0 diff --git a/tests/test_dial_provider.py b/tests/test_dial_provider.py index 4a22cb6..62af59c 100644 --- a/tests/test_dial_provider.py +++ b/tests/test_dial_provider.py @@ -84,7 +84,7 @@ class TestDIALProvider: # Test O3 capabilities capabilities = provider.get_capabilities("o3") assert capabilities.model_name == "o3-2025-04-16" - assert capabilities.friendly_name == "DIAL" + assert capabilities.friendly_name == "DIAL (O3)" assert capabilities.context_window == 200_000 assert capabilities.provider == ProviderType.DIAL assert capabilities.supports_images is True diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py index e9e3ae8..baab182 100644 --- a/tests/test_openai_provider.py +++ b/tests/test_openai_provider.py @@ -85,7 +85,7 @@ class TestOpenAIProvider: capabilities = provider.get_capabilities("o3") assert capabilities.model_name == "o3" # Should NOT be resolved in capabilities - assert capabilities.friendly_name == "OpenAI" + assert capabilities.friendly_name == "OpenAI (O3)" assert capabilities.context_window == 200_000 assert capabilities.provider == ProviderType.OPENAI assert not capabilities.supports_extended_thinking @@ -101,8 +101,8 @@ class TestOpenAIProvider: provider = OpenAIModelProvider("test-key") capabilities = provider.get_capabilities("mini") - assert capabilities.model_name == "mini" # Capabilities should show original request - assert capabilities.friendly_name == "OpenAI" + assert capabilities.model_name == "o4-mini" # Capabilities should show resolved model name + assert capabilities.friendly_name == "OpenAI (O4-mini)" assert capabilities.context_window == 200_000 assert capabilities.provider == ProviderType.OPENAI diff --git a/tests/test_supported_models_aliases.py b/tests/test_supported_models_aliases.py new file mode 100644 index 0000000..6ed899f --- /dev/null +++ 
@@ -0,0 +1,206 @@
+"""Test the SUPPORTED_MODELS aliases structure across all providers."""
+
+from providers.dial import DIALModelProvider
+from providers.gemini import GeminiModelProvider
+from providers.openai_provider import OpenAIModelProvider
+from providers.xai import XAIModelProvider
+
+
+class TestSupportedModelsAliases:
+    """Test that all providers have correctly structured SUPPORTED_MODELS with aliases."""
+
+    def test_gemini_provider_aliases(self):
+        """Test Gemini provider's alias structure."""
+        provider = GeminiModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "flash" in provider.SUPPORTED_MODELS["gemini-2.5-flash"].aliases
+        assert "pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro"].aliases
+        assert "flash-2.0" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
+        assert "flash2" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
+        assert "flashlite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
+        assert "flash-lite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("flash") == "gemini-2.5-flash"
+        assert provider._resolve_model_name("pro") == "gemini-2.5-pro"
+        assert provider._resolve_model_name("flash-2.0") == "gemini-2.0-flash"
+        assert provider._resolve_model_name("flash2") == "gemini-2.0-flash"
+        assert provider._resolve_model_name("flashlite") == "gemini-2.0-flash-lite"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("Flash") == "gemini-2.5-flash"
+        assert provider._resolve_model_name("PRO") == "gemini-2.5-pro"
+
+    def test_openai_provider_aliases(self):
+        """Test OpenAI provider's alias structure."""
+        provider = OpenAIModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
+        assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
+        assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases
+        assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases
+        assert "o4minihigh" in provider.SUPPORTED_MODELS["o4-mini-high"].aliases
+        assert "o4minihi" in provider.SUPPORTED_MODELS["o4-mini-high"].aliases
+        assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("mini") == "o4-mini"
+        assert provider._resolve_model_name("o3mini") == "o3-mini"
+        assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
+        assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
+        assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("Mini") == "o4-mini"
+        assert provider._resolve_model_name("O3MINI") == "o3-mini"
+
+    def test_xai_provider_aliases(self):
+        """Test XAI provider's alias structure."""
+        provider = XAIModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases
+        assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
+        assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
+        assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("grok") == "grok-3"
+        assert provider._resolve_model_name("grok3") == "grok-3"
+        assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
+        assert provider._resolve_model_name("grokfast") == "grok-3-fast"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("Grok") == "grok-3"
+        assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"
+
+    def test_dial_provider_aliases(self):
+        """Test DIAL provider's alias structure."""
+        provider = DIALModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "o3" in provider.SUPPORTED_MODELS["o3-2025-04-16"].aliases
+        assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini-2025-04-16"].aliases
+        assert "sonnet-4" in provider.SUPPORTED_MODELS["anthropic.claude-sonnet-4-20250514-v1:0"].aliases
+        assert "opus-4" in provider.SUPPORTED_MODELS["anthropic.claude-opus-4-20250514-v1:0"].aliases
+        assert "gemini-2.5-pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro-preview-05-06"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("o3") == "o3-2025-04-16"
+        assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
+        assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
+        assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("O3") == "o3-2025-04-16"
+        assert provider._resolve_model_name("SONNET-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
+
+    def test_list_models_includes_aliases(self):
+        """Test that list_models returns both base models and aliases."""
+        # Test Gemini
+        gemini_provider = GeminiModelProvider("test-key")
+        gemini_models = gemini_provider.list_models(respect_restrictions=False)
+        assert "gemini-2.5-flash" in gemini_models
+        assert "flash" in gemini_models
+        assert "gemini-2.5-pro" in gemini_models
+        assert "pro" in gemini_models
+
+        # Test OpenAI
+        openai_provider = OpenAIModelProvider("test-key")
+        openai_models = openai_provider.list_models(respect_restrictions=False)
+        assert "o4-mini" in openai_models
+        assert "mini" in openai_models
+        assert "o3-mini" in openai_models
+        assert "o3mini" in openai_models
+
+        # Test XAI
+        xai_provider = XAIModelProvider("test-key")
+        xai_models = xai_provider.list_models(respect_restrictions=False)
+        assert "grok-3" in xai_models
+        assert "grok" in xai_models
+        assert "grok-3-fast" in xai_models
+        assert "grokfast" in xai_models
+
+        # Test DIAL
+        dial_provider = DIALModelProvider("test-key")
+        dial_models = dial_provider.list_models(respect_restrictions=False)
+        assert "o3-2025-04-16" in dial_models
"o3-2025-04-16" in dial_models + assert "o3" in dial_models + + def test_list_all_known_models_includes_aliases(self): + """Test that list_all_known_models returns all models and aliases in lowercase.""" + # Test Gemini + gemini_provider = GeminiModelProvider("test-key") + gemini_all = gemini_provider.list_all_known_models() + assert "gemini-2.5-flash" in gemini_all + assert "flash" in gemini_all + assert "gemini-2.5-pro" in gemini_all + assert "pro" in gemini_all + # All should be lowercase + assert all(model == model.lower() for model in gemini_all) + + # Test OpenAI + openai_provider = OpenAIModelProvider("test-key") + openai_all = openai_provider.list_all_known_models() + assert "o4-mini" in openai_all + assert "mini" in openai_all + assert "o3-mini" in openai_all + assert "o3mini" in openai_all + # All should be lowercase + assert all(model == model.lower() for model in openai_all) + + def test_no_string_shorthand_in_supported_models(self): + """Test that no provider has string-based shorthands anymore.""" + providers = [ + GeminiModelProvider("test-key"), + OpenAIModelProvider("test-key"), + XAIModelProvider("test-key"), + DIALModelProvider("test-key"), + ] + + for provider in providers: + for model_name, config in provider.SUPPORTED_MODELS.items(): + # All values must be ModelCapabilities objects, not strings or dicts + from providers.base import ModelCapabilities + + assert isinstance(config, ModelCapabilities), ( + f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] " + f"must be a ModelCapabilities object, not {type(config).__name__}" + ) + + def test_resolve_returns_original_if_not_found(self): + """Test that _resolve_model_name returns original name if alias not found.""" + providers = [ + GeminiModelProvider("test-key"), + OpenAIModelProvider("test-key"), + XAIModelProvider("test-key"), + DIALModelProvider("test-key"), + ] + + for provider in providers: + # Test with unknown model name + assert provider._resolve_model_name("unknown-model") == "unknown-model" + assert provider._resolve_model_name("gpt-4") == "gpt-4" + assert provider._resolve_model_name("claude-3") == "claude-3" diff --git a/tests/test_xai_provider.py b/tests/test_xai_provider.py index e002636..978d9c1 100644 --- a/tests/test_xai_provider.py +++ b/tests/test_xai_provider.py @@ -77,7 +77,7 @@ class TestXAIProvider: capabilities = provider.get_capabilities("grok-3") assert capabilities.model_name == "grok-3" - assert capabilities.friendly_name == "X.AI" + assert capabilities.friendly_name == "X.AI (Grok 3)" assert capabilities.context_window == 131_072 assert capabilities.provider == ProviderType.XAI assert not capabilities.supports_extended_thinking @@ -96,7 +96,7 @@ class TestXAIProvider: capabilities = provider.get_capabilities("grok-3-fast") assert capabilities.model_name == "grok-3-fast" - assert capabilities.friendly_name == "X.AI" + assert capabilities.friendly_name == "X.AI (Grok 3 Fast)" assert capabilities.context_window == 131_072 assert capabilities.provider == ProviderType.XAI assert not capabilities.supports_extended_thinking @@ -212,31 +212,34 @@ class TestXAIProvider: assert provider.FRIENDLY_NAME == "X.AI" capabilities = provider.get_capabilities("grok-3") - assert capabilities.friendly_name == "X.AI" + assert capabilities.friendly_name == "X.AI (Grok 3)" def test_supported_models_structure(self): """Test that SUPPORTED_MODELS has the correct structure.""" provider = XAIModelProvider("test-key") - # Check that all expected models are present + # Check that all expected base models 
         assert "grok-3" in provider.SUPPORTED_MODELS
         assert "grok-3-fast" in provider.SUPPORTED_MODELS
-        assert "grok" in provider.SUPPORTED_MODELS
-        assert "grok3" in provider.SUPPORTED_MODELS
-        assert "grokfast" in provider.SUPPORTED_MODELS
-        assert "grok3fast" in provider.SUPPORTED_MODELS
 
         # Check model configs have required fields
-        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
-        assert isinstance(grok3_config, dict)
-        assert "context_window" in grok3_config
-        assert "supports_extended_thinking" in grok3_config
-        assert grok3_config["context_window"] == 131_072
-        assert grok3_config["supports_extended_thinking"] is False
+        from providers.base import ModelCapabilities
 
-        # Check shortcuts point to full names
-        assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
-        assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"
+        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
+        assert isinstance(grok3_config, ModelCapabilities)
+        assert hasattr(grok3_config, "context_window")
+        assert hasattr(grok3_config, "supports_extended_thinking")
+        assert hasattr(grok3_config, "aliases")
+        assert grok3_config.context_window == 131_072
+        assert grok3_config.supports_extended_thinking is False
+
+        # Check aliases are correctly structured
+        assert "grok" in grok3_config.aliases
+        assert "grok3" in grok3_config.aliases
+
+        grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
+        assert "grok3fast" in grok3fast_config.aliases
+        assert "grokfast" in grok3fast_config.aliases
 
     @patch("providers.openai_compatible.OpenAI")
     def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
diff --git a/tools/listmodels.py b/tools/listmodels.py
index 265fbcc..0813ee7 100644
--- a/tools/listmodels.py
+++ b/tools/listmodels.py
@@ -99,15 +99,11 @@ class ListModelsTool(BaseTool):
                 output_lines.append("**Status**: Configured and available")
                 output_lines.append("\n**Models**:")
 
-                # Get models from the provider's SUPPORTED_MODELS
-                for model_name, config in provider.SUPPORTED_MODELS.items():
-                    # Skip alias entries (string values)
-                    if isinstance(config, str):
-                        continue
-
-                    # Get description and context from the model config
-                    description = config.get("description", "No description available")
-                    context_window = config.get("context_window", 0)
+                # Get models from the provider's model configurations
+                for model_name, capabilities in provider.get_model_configurations().items():
+                    # Get description and context from the ModelCapabilities object
+                    description = capabilities.description or "No description available"
+                    context_window = capabilities.context_window
 
                     # Format context window
                     if context_window >= 1_000_000:
@@ -133,13 +129,14 @@ class ListModelsTool(BaseTool):
 
                 # Show aliases for this provider
                 aliases = []
-                for alias_name, target in provider.SUPPORTED_MODELS.items():
-                    if isinstance(target, str):  # This is an alias
-                        aliases.append(f"- `{alias_name}` → `{target}`")
+                for model_name, capabilities in provider.get_model_configurations().items():
+                    if capabilities.aliases:
+                        for alias in capabilities.aliases:
+                            aliases.append(f"- `{alias}` → `{model_name}`")
 
                 if aliases:
                     output_lines.append("\n**Aliases**:")
-                    output_lines.extend(aliases)
+                    output_lines.extend(sorted(aliases))  # Sort for consistent output
             else:
                 output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
 
@@ -237,7 +234,7 @@ class ListModelsTool(BaseTool):
 
             for alias in registry.list_aliases():
                 config = registry.resolve(alias)
-                if config and hasattr(config, "is_custom") and config.is_custom:
+                if config and config.is_custom:
                     custom_models.append((alias, config))
 
             if custom_models:
diff --git a/tools/shared/base_tool.py b/tools/shared/base_tool.py
index a98baf8..10b223f 100644
--- a/tools/shared/base_tool.py
+++ b/tools/shared/base_tool.py
@@ -256,8 +256,8 @@ class BaseTool(ABC):
             # Find all custom models (is_custom=true)
             for alias in registry.list_aliases():
                 config = registry.resolve(alias)
-                # Use hasattr for defensive programming - is_custom is optional with default False
-                if config and hasattr(config, "is_custom") and config.is_custom:
+                # Check if this is a custom model that requires custom endpoints
+                if config and config.is_custom:
                     if alias not in all_models:
                         all_models.append(alias)
         except Exception as e:
@@ -311,12 +311,16 @@ class BaseTool(ABC):
             ProviderType.GOOGLE: "Gemini models",
             ProviderType.OPENAI: "OpenAI models",
             ProviderType.XAI: "X.AI GROK models",
+            ProviderType.DIAL: "DIAL models",
             ProviderType.CUSTOM: "Custom models",
             ProviderType.OPENROUTER: "OpenRouter models",
         }
 
         # Check available providers and add their model descriptions
-        for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI]:
+
+        # Start with native providers
+        for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]:
+            # Only include providers that are actually registered and available
             provider = ModelProviderRegistry.get_provider(provider_type)
             if provider:
                 provider_section_added = False
@@ -324,13 +328,13 @@ class BaseTool(ABC):
                     try:
                         # Get model config to extract description
                         model_config = provider.SUPPORTED_MODELS.get(model_name)
-                        if isinstance(model_config, dict) and "description" in model_config:
+                        if model_config and model_config.description:
                             if not provider_section_added:
                                 model_desc_parts.append(
                                     f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:"
                                 )
                                 provider_section_added = True
-                            model_desc_parts.append(f"- '{model_name}': {model_config['description']}")
+                            model_desc_parts.append(f"- '{model_name}': {model_config.description}")
                     except Exception:
                         # Skip models without descriptions
                         continue
@@ -346,8 +350,8 @@ class BaseTool(ABC):
             # Find all custom models (is_custom=true)
             for alias in registry.list_aliases():
                 config = registry.resolve(alias)
-                # Use hasattr for defensive programming - is_custom is optional with default False
-                if config and hasattr(config, "is_custom") and config.is_custom:
+                # Check if this is a custom model that requires custom endpoints
+                if config and config.is_custom:
                     # Format context window
                     context_tokens = config.context_window
                     if context_tokens >= 1_000_000: