Categorize tools into 'model capabilities categories' to help determine which type of model to pick when in auto mode

Encourage Claude to pick the best model for the job automatically in auto mode
Lots of new tests to ensure automatic model picking works reliably, whether based on user preference or when a matching model is not found or is ambiguous
Improved error reporting when a bogus model is requested that is not configured or available
This commit is contained in:
Fahad
2025-06-14 02:17:06 +04:00
parent 7fc1186a7c
commit eb388ab2f2
13 changed files with 838 additions and 68 deletions

View File

@@ -2,10 +2,13 @@
import logging
import os
from typing import Optional
from typing import TYPE_CHECKING, Optional
from .base import ModelProvider, ProviderType
if TYPE_CHECKING:
from tools.models import ToolModelCategory
class ModelProviderRegistry:
"""Registry for managing model providers."""
@@ -189,27 +192,50 @@ class ModelProviderRegistry:
return os.getenv(env_var)
@classmethod
def get_preferred_fallback_model(cls) -> str:
"""Get the preferred fallback model based on available API keys.
def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
"""Get the preferred fallback model based on available API keys and tool category.
This method checks which providers have valid API keys and returns
a sensible default model for auto mode fallback situations.
Priority order:
1. OpenAI o3-mini (balanced performance/cost) if OpenAI API key available
2. Gemini 2.0 Flash (fast and efficient) if Gemini API key available
3. OpenAI o3 (high performance) if OpenAI API key available
4. Gemini 2.5 Pro (deep reasoning) if Gemini API key available
5. Fallback to gemini-2.5-flash-preview-05-20 (most common case)
Args:
tool_category: Optional category to influence model selection
Returns:
Model name string for fallback use
"""
# Import here to avoid circular import
from tools.models import ToolModelCategory
# Check provider availability by trying to get instances
openai_available = cls.get_provider(ProviderType.OPENAI) is not None
gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
# Priority order: prefer balanced models first, then high-performance
if tool_category == ToolModelCategory.EXTENDED_REASONING:
# Prefer thinking-capable models for deep reasoning tools
if openai_available:
return "o3" # O3 for deep reasoning
elif gemini_available:
return "pro" # Gemini Pro with thinking mode
else:
# Try to find thinking-capable model from custom/openrouter
thinking_model = cls._find_extended_thinking_model()
if thinking_model:
return thinking_model
# Fallback to pro if nothing found
return "gemini-2.5-pro-preview-06-05"
elif tool_category == ToolModelCategory.FAST_RESPONSE:
# Prefer fast, cost-efficient models
if openai_available:
return "o3-mini" # Fast and efficient
elif gemini_available:
return "flash" # Gemini Flash for speed
else:
# Default to flash
return "gemini-2.5-flash-preview-05-20"
# BALANCED or no category specified - use existing balanced logic
if openai_available:
return "o3-mini" # Balanced performance/cost
elif gemini_available:
@@ -219,6 +245,51 @@ class ModelProviderRegistry:
# This maintains backward compatibility for tests
return "gemini-2.5-flash-preview-05-20"
@classmethod
def _find_extended_thinking_model(cls) -> Optional[str]:
    """Find a model suitable for extended reasoning from custom/openrouter providers.

    Search order:

    1. The custom provider, if one is registered: the first entry in its
       ``model_registry`` whose config has a truthy
       ``supports_extended_thinking`` value.
    2. The OpenRouter provider, if one is registered: the first name from a
       fixed preference list that the provider's ``validate_model_name``
       accepts.

    A validation call that raises is logged as a warning and the search
    moves on to the next candidate, so one failing lookup does not abort
    the whole search.

    Returns:
        Model name if found, None otherwise
    """
    # Check custom provider first
    custom_provider = cls.get_provider(ProviderType.CUSTOM)
    if custom_provider:
        try:
            # Imported lazily; if the custom provider module cannot be
            # imported, the custom lookup is skipped entirely.
            from providers.custom import CustomProvider

            if isinstance(custom_provider, CustomProvider) and hasattr(custom_provider, "model_registry"):
                for model_name, config in custom_provider.model_registry.items():
                    if config.get("supports_extended_thinking", False):
                        return model_name
        except ImportError:
            pass

    # Then check OpenRouter for high-context/powerful models
    openrouter_provider = cls.get_provider(ProviderType.OPENROUTER)
    if openrouter_provider:
        # Candidates in preference order; the first one that validates wins.
        preferred_models = [
            "anthropic/claude-3.5-sonnet",
            "anthropic/claude-3-opus-20240229",
            "meta-llama/llama-3.1-70b-instruct",
            "google/gemini-pro-1.5",
            "mistralai/mixtral-8x7b-instruct",
        ]
        for model in preferred_models:
            try:
                if openrouter_provider.validate_model_name(model):
                    return model
            except Exception as e:
                # Best-effort search: record the failure and try the next
                # candidate. Uses the module-level ``logging`` import and
                # lazy %-formatting; the emitted message is unchanged.
                logging.warning("Model validation for '%s' on OpenRouter failed: %s", model, e)

    return None
@classmethod
def get_available_providers_with_keys(cls) -> list[ProviderType]:
"""Get list of provider types that have valid API keys.