GPT-5, GPT-5-mini support

Improvements to model name resolution
Improved instructions for multi-step workflows when continuation is available
Improved instructions for chat tool
Improved preferred model resolution; moved selection code from the registry into each provider
Updated tests
Fahad
2025-08-08 08:51:34 +05:00
parent 9a4791cb06
commit 1a8ec2e12f
30 changed files with 792 additions and 483 deletions

providers/registry.py

@@ -15,6 +15,17 @@ class ModelProviderRegistry:
     _instance = None
 
+    # Provider priority order for model selection
+    # Native APIs first, then custom endpoints, then catch-all providers
+    PROVIDER_PRIORITY_ORDER = [
+        ProviderType.GOOGLE,  # Direct Gemini access
+        ProviderType.OPENAI,  # Direct OpenAI access
+        ProviderType.XAI,  # Direct X.AI GROK access
+        ProviderType.DIAL,  # DIAL unified API access
+        ProviderType.CUSTOM,  # Local/self-hosted models
+        ProviderType.OPENROUTER,  # Catch-all for cloud models
+    ]
+
     def __new__(cls):
         """Singleton pattern for registry."""
         if cls._instance is None:
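Hoisting the priority list to a class attribute gives every lookup path one shared ordering: walk the list and take the first configured provider that recognizes the model, so a native API always beats the OpenRouter catch-all. A minimal sketch of that walk, with simplified stand-ins for the project's enum and registry state:

```python
from enum import Enum

class ProviderType(Enum):
    GOOGLE = "google"
    OPENAI = "openai"
    OPENROUTER = "openrouter"

# Hypothetical registry state: which providers are configured and
# which model names each one recognizes.
PROVIDERS = {
    ProviderType.OPENAI: {"gpt-5", "gpt-5-mini", "o3"},
    ProviderType.OPENROUTER: {"gpt-5", "anthropic/claude-opus-4"},
}

PRIORITY = [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.OPENROUTER]

def provider_for(model: str):
    # First provider in priority order that knows the model wins, so the
    # native OpenAI API handles "gpt-5" even though OpenRouter also lists it.
    for ptype in PRIORITY:
        if model in PROVIDERS.get(ptype, ()):
            return ptype
    return None

assert provider_for("gpt-5") is ProviderType.OPENAI
assert provider_for("unknown-model") is None
```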
@@ -103,30 +114,19 @@ class ModelProviderRegistry:
         3. OPENROUTER - Catch-all for cloud models via unified API
 
         Args:
-            model_name: Name of the model (e.g., "gemini-2.5-flash", "o3-mini")
+            model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5")
 
         Returns:
             ModelProvider instance that supports this model
         """
         logging.debug(f"get_provider_for_model called with model_name='{model_name}'")
 
-        # Define explicit provider priority order
-        # Native APIs first, then custom endpoints, then catch-all providers
-        PROVIDER_PRIORITY_ORDER = [
-            ProviderType.GOOGLE,  # Direct Gemini access
-            ProviderType.OPENAI,  # Direct OpenAI access
-            ProviderType.XAI,  # Direct X.AI GROK access
-            ProviderType.DIAL,  # DIAL unified API access
-            ProviderType.CUSTOM,  # Local/self-hosted models
-            ProviderType.OPENROUTER,  # Catch-all for cloud models
-        ]
-
         # Check providers in priority order
         instance = cls()
         logging.debug(f"Registry instance: {instance}")
         logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
-        for provider_type in PROVIDER_PRIORITY_ORDER:
+        for provider_type in cls.PROVIDER_PRIORITY_ORDER:
             if provider_type in instance._providers:
                 logging.debug(f"Found {provider_type} in registry")
                 # Get or create provider instance
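The docstring's new "gpt5" example points at the commit's "model name resolution" line: shorthand spellings are normalized to canonical IDs before the provider walk. A hedged sketch of that normalization, with an invented alias table (the real mappings live in each provider's model configuration):

```python
# Invented alias table for illustration; real aliases are defined per provider.
ALIASES = {
    "gpt5": "gpt-5",
    "gpt5mini": "gpt-5-mini",
    "gpt5-mini": "gpt-5-mini",
}

def resolve_model_name(name: str) -> str:
    # Normalize case and whitespace, then map shorthand to a canonical ID;
    # unknown names pass through unchanged for providers to accept or reject.
    key = name.strip().lower()
    return ALIASES.get(key, key)

assert resolve_model_name(" GPT5 ") == "gpt-5"
assert resolve_model_name("gemini-2.5-flash") == "gemini-2.5-flash"
```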
@@ -244,14 +244,49 @@ class ModelProviderRegistry:
         return os.getenv(env_var)
 
+    @classmethod
+    def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]:
+        """Get a list of allowed canonical model names for a given provider.
+
+        Args:
+            provider: The provider instance to get models for
+            provider_type: The provider type for restriction checking
+
+        Returns:
+            List of model names that are both supported and allowed
+        """
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        allowed_models = []
+
+        # Get the provider's supported models
+        try:
+            # Use list_models to get all supported models (handles both regular and custom providers)
+            supported_models = provider.list_models(respect_restrictions=False)
+        except (NotImplementedError, AttributeError):
+            # Fallback to SUPPORTED_MODELS if list_models not implemented
+            try:
+                supported_models = list(provider.SUPPORTED_MODELS.keys())
+            except AttributeError:
+                supported_models = []
+
+        # Filter by restrictions
+        for model_name in supported_models:
+            if restriction_service.is_allowed(provider_type, model_name):
+                allowed_models.append(model_name)
+
+        return allowed_models
+
     @classmethod
     def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
-        """Get the preferred fallback model based on available API keys and tool category.
-
-        This method checks which providers have valid API keys and returns
-        a sensible default model for auto mode fallback situations.
-        Takes into account model restrictions when selecting fallback models.
+        """Get the preferred fallback model based on provider priority and tool category.
+
+        This method orchestrates model selection by:
+        1. Getting allowed models for each provider (respecting restrictions)
+        2. Asking providers for their preference from the allowed list
+        3. Falling back to first available model if no preference given
 
         Args:
             tool_category: Optional category to influence model selection
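To make the new helper's contract concrete, here is a self-contained sketch of the same enumerate-then-filter flow, with toy stand-ins for the provider and the restriction service (both invented here; the real service comes from utils.model_restrictions):

```python
class ToyRestrictionService:
    """Stand-in for the real restriction service."""

    def __init__(self, allowed: set[str]):
        self.allowed = allowed

    def is_allowed(self, provider_type, model_name: str) -> bool:
        return model_name in self.allowed

class ToyProvider:
    """Stand-in provider exposing the same two discovery surfaces."""

    SUPPORTED_MODELS = {"gpt-5": {}, "gpt-5-mini": {}, "o3": {}}

    def list_models(self, respect_restrictions: bool) -> list[str]:
        return list(self.SUPPORTED_MODELS)

def allowed_models_for(provider, service, provider_type=None) -> list[str]:
    # Enumerate everything the provider supports, then keep what policy allows.
    try:
        supported = provider.list_models(respect_restrictions=False)
    except (NotImplementedError, AttributeError):
        supported = list(getattr(provider, "SUPPORTED_MODELS", {}))
    return [m for m in supported if service.is_allowed(provider_type, m)]

service = ToyRestrictionService({"gpt-5-mini"})
assert allowed_models_for(ToyProvider(), service) == ["gpt-5-mini"]
```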
@@ -259,167 +294,42 @@ class ModelProviderRegistry:
         Returns:
             Model name string for fallback use
         """
         # Import here to avoid circular import
         from tools.models import ToolModelCategory
 
-        # Get available models respecting restrictions
-        available_models = cls.get_available_models(respect_restrictions=True)
-
-        # Group by provider
-        openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI]
-        gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
-        xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI]
-        openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER]
-        custom_models = [m for m, p in available_models.items() if p == ProviderType.CUSTOM]
-
-        openai_available = bool(openai_models)
-        gemini_available = bool(gemini_models)
-        xai_available = bool(xai_models)
-        openrouter_available = bool(openrouter_models)
-        custom_available = bool(custom_models)
-
-        if tool_category == ToolModelCategory.EXTENDED_REASONING:
-            # Prefer thinking-capable models for deep reasoning tools
-            if openai_available and "o3" in openai_models:
-                return "o3"  # O3 for deep reasoning
-            elif openai_available and openai_models:
-                # Fall back to any available OpenAI model
-                return openai_models[0]
-            elif xai_available and "grok-3" in xai_models:
-                return "grok-3"  # GROK-3 for deep reasoning
-            elif xai_available and xai_models:
-                # Fall back to any available XAI model
-                return xai_models[0]
-            elif gemini_available and any("pro" in m for m in gemini_models):
-                # Find the pro model (handles full names)
-                return next(m for m in gemini_models if "pro" in m)
-            elif gemini_available and gemini_models:
-                # Fall back to any available Gemini model
-                return gemini_models[0]
-            elif openrouter_available:
-                # Try to find thinking-capable model from openrouter
-                thinking_model = cls._find_extended_thinking_model()
-                if thinking_model:
-                    return thinking_model
-                # Fallback to first available OpenRouter model
-                return openrouter_models[0]
-            elif custom_available:
-                # Fallback to custom models when available
-                return custom_models[0]
-            else:
-                # Fallback to pro if nothing found
-                return "gemini-2.5-pro"
-
-        elif tool_category == ToolModelCategory.FAST_RESPONSE:
-            # Prefer fast, cost-efficient models
-            if openai_available and "o4-mini" in openai_models:
-                return "o4-mini"  # Latest, fast and efficient
-            elif openai_available and "o3-mini" in openai_models:
-                return "o3-mini"  # Second choice
-            elif openai_available and openai_models:
-                # Fall back to any available OpenAI model
-                return openai_models[0]
-            elif xai_available and "grok-3-fast" in xai_models:
-                return "grok-3-fast"  # GROK-3 Fast for speed
-            elif xai_available and xai_models:
-                # Fall back to any available XAI model
-                return xai_models[0]
-            elif gemini_available and any("flash" in m for m in gemini_models):
-                # Find the flash model (handles full names)
-                # Prefer 2.5 over 2.0 for backward compatibility
-                flash_models = [m for m in gemini_models if "flash" in m]
-                # Sort to ensure 2.5 comes before 2.0
-                flash_models_sorted = sorted(flash_models, reverse=True)
-                return flash_models_sorted[0]
-            elif gemini_available and gemini_models:
-                # Fall back to any available Gemini model
-                return gemini_models[0]
-            elif openrouter_available:
-                # Fallback to first available OpenRouter model
-                return openrouter_models[0]
-            elif custom_available:
-                # Fallback to custom models when available
-                return custom_models[0]
-            else:
-                # Default to flash
-                return "gemini-2.5-flash"
-
-        # BALANCED or no category specified - use existing balanced logic
-        if openai_available and "o4-mini" in openai_models:
-            return "o4-mini"  # Latest balanced performance/cost
-        elif openai_available and "o3-mini" in openai_models:
-            return "o3-mini"  # Second choice
-        elif openai_available and openai_models:
-            return openai_models[0]
-        elif xai_available and "grok-3" in xai_models:
-            return "grok-3"  # GROK-3 as balanced choice
-        elif xai_available and xai_models:
-            return xai_models[0]
-        elif gemini_available and any("flash" in m for m in gemini_models):
-            # Prefer 2.5 over 2.0 for backward compatibility
-            flash_models = [m for m in gemini_models if "flash" in m]
-            flash_models_sorted = sorted(flash_models, reverse=True)
-            return flash_models_sorted[0]
-        elif gemini_available and gemini_models:
-            return gemini_models[0]
-        elif openrouter_available:
-            return openrouter_models[0]
-        elif custom_available:
-            # Fallback to custom models when available
-            return custom_models[0]
-        else:
-            # No models available due to restrictions - check if any providers exist
-            if not available_models:
-                # This might happen if all models are restricted
-                logging.warning("No models available due to restrictions")
-            # Return a reasonable default for backward compatibility
-            return "gemini-2.5-flash"
-
-    @classmethod
-    def _find_extended_thinking_model(cls) -> Optional[str]:
-        """Find a model suitable for extended reasoning from custom/openrouter providers.
-
-        Returns:
-            Model name if found, None otherwise
-        """
-        # Check custom provider first
-        custom_provider = cls.get_provider(ProviderType.CUSTOM)
-        if custom_provider:
-            # Check if it's a CustomModelProvider and has thinking models
-            try:
-                from providers.custom import CustomProvider
-
-                if isinstance(custom_provider, CustomProvider) and hasattr(custom_provider, "model_registry"):
-                    for model_name, config in custom_provider.model_registry.items():
-                        if config.get("supports_extended_thinking", False):
-                            return model_name
-            except ImportError:
-                pass
-
-        # Then check OpenRouter for high-context/powerful models
-        openrouter_provider = cls.get_provider(ProviderType.OPENROUTER)
-        if openrouter_provider:
-            # Prefer models known for deep reasoning
-            preferred_models = [
-                "anthropic/claude-sonnet-4",
-                "anthropic/claude-opus-4",
-                "google/gemini-2.5-pro",
-                "google/gemini-pro-1.5",
-                "meta-llama/llama-3.1-70b-instruct",
-                "mistralai/mixtral-8x7b-instruct",
-            ]
-            for model in preferred_models:
-                try:
-                    if openrouter_provider.validate_model_name(model):
-                        return model
-                except Exception as e:
-                    # Log the error for debugging purposes but continue searching
-                    import logging
-
-                    logging.warning(f"Model validation for '{model}' on OpenRouter failed: {e}")
-
-        return None
+        effective_category = tool_category or ToolModelCategory.BALANCED
+        first_available_model = None
+
+        # Ask each provider for their preference in priority order
+        for provider_type in cls.PROVIDER_PRIORITY_ORDER:
+            provider = cls.get_provider(provider_type)
+            if provider:
+                # 1. Registry filters the models first
+                allowed_models = cls._get_allowed_models_for_provider(provider, provider_type)
+
+                if not allowed_models:
+                    continue
+
+                # 2. Keep track of the first available model as fallback
+                if not first_available_model:
+                    first_available_model = sorted(allowed_models)[0]
+
+                # 3. Ask provider to pick from allowed list
+                preferred_model = provider.get_preferred_model(effective_category, allowed_models)
+                if preferred_model:
+                    logging.debug(
+                        f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'"
+                    )
+                    return preferred_model
+
+        # If no provider returned a preference, use first available model
+        if first_available_model:
+            logging.debug(f"No provider preference, using first available: {first_available_model}")
+            return first_available_model
+
+        # Ultimate fallback if no providers have models
+        logging.warning("No models available from any provider, using default fallback")
+        return "gemini-2.5-flash"
 
     @classmethod
     def get_available_providers_with_keys(cls) -> list[ProviderType]:
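The new division of labor above (registry filters, provider chooses) rests on each provider exposing the get_preferred_model(category, allowed_models) hook the loop calls. A sketch of how a provider-side implementation might rank candidates; the preference table here is illustrative, not the commit's actual ordering:

```python
from enum import Enum
from typing import Optional

class ToolModelCategory(Enum):
    EXTENDED_REASONING = "extended_reasoning"
    FAST_RESPONSE = "fast_response"
    BALANCED = "balanced"

# Illustrative ranking; each real provider ships its own table.
_PREFERENCES = {
    ToolModelCategory.EXTENDED_REASONING: ["gpt-5", "o3"],
    ToolModelCategory.FAST_RESPONSE: ["gpt-5-mini", "o4-mini", "o3-mini"],
    ToolModelCategory.BALANCED: ["gpt-5", "o4-mini"],
}

def get_preferred_model(category: ToolModelCategory, allowed_models: list[str]) -> Optional[str]:
    # Return the highest-ranked model that survived the registry's
    # restriction filtering; None tells the registry to try the next provider.
    for candidate in _PREFERENCES.get(category, []):
        if candidate in allowed_models:
            return candidate
    return None

assert get_preferred_model(ToolModelCategory.FAST_RESPONSE, ["o3-mini", "gpt-5-mini"]) == "gpt-5-mini"
assert get_preferred_model(ToolModelCategory.BALANCED, ["grok-3"]) is None
```

Keeping the rankings inside each provider means adding a model (such as gpt-5-mini in this commit) touches only that provider's table, while the registry's orchestration loop stays unchanged.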