Categorize tools into 'model capability categories' to help determine which type of model to pick in auto mode
- Encourage Claude to pick the best model for the job automatically in auto mode
- Add lots of new tests to ensure automatic model picking works reliably, whether driven by user preference or when a matching model is not found or is ambiguous
- Improve error reporting when a bogus model is requested that is not configured or available
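The 'model capability categories' above map to the ToolModelCategory enum that the diff imports from tools.models. As a reading aid, here is a minimal sketch of what that enum presumably looks like, based on the three categories this commit actually branches on; the string values are assumptions, not taken from the source:

    from enum import Enum

    class ToolModelCategory(Enum):
        """Capability category a tool declares so auto mode can pick a fitting model."""

        EXTENDED_REASONING = "extended_reasoning"  # deep-thinking tools (value assumed)
        FAST_RESPONSE = "fast_response"            # quick, cost-sensitive tools (value assumed)
        BALANCED = "balanced"                      # default middle ground (value assumed)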
@@ -2,10 +2,13 @@
 import logging
 import os
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 from .base import ModelProvider, ProviderType

+if TYPE_CHECKING:
+    from tools.models import ToolModelCategory
+

 class ModelProviderRegistry:
     """Registry for managing model providers."""
@@ -189,27 +192,50 @@ class ModelProviderRegistry:
         return os.getenv(env_var)

     @classmethod
-    def get_preferred_fallback_model(cls) -> str:
-        """Get the preferred fallback model based on available API keys.
+    def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
+        """Get the preferred fallback model based on available API keys and tool category.

         This method checks which providers have valid API keys and returns
         a sensible default model for auto mode fallback situations.

-        Priority order:
-        1. OpenAI o3-mini (balanced performance/cost) if OpenAI API key available
-        2. Gemini 2.0 Flash (fast and efficient) if Gemini API key available
-        3. OpenAI o3 (high performance) if OpenAI API key available
-        4. Gemini 2.5 Pro (deep reasoning) if Gemini API key available
-        5. Fallback to gemini-2.5-flash-preview-05-20 (most common case)
+        Args:
+            tool_category: Optional category to influence model selection
+
         Returns:
             Model name string for fallback use
         """
+        # Import here to avoid circular import
+        from tools.models import ToolModelCategory
+
         # Check provider availability by trying to get instances
         openai_available = cls.get_provider(ProviderType.OPENAI) is not None
         gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None

         # Priority order: prefer balanced models first, then high-performance
+        if tool_category == ToolModelCategory.EXTENDED_REASONING:
+            # Prefer thinking-capable models for deep reasoning tools
+            if openai_available:
+                return "o3"  # O3 for deep reasoning
+            elif gemini_available:
+                return "pro"  # Gemini Pro with thinking mode
+            else:
+                # Try to find thinking-capable model from custom/openrouter
+                thinking_model = cls._find_extended_thinking_model()
+                if thinking_model:
+                    return thinking_model
+                # Fallback to pro if nothing found
+                return "gemini-2.5-pro-preview-06-05"
+
+        elif tool_category == ToolModelCategory.FAST_RESPONSE:
+            # Prefer fast, cost-efficient models
+            if openai_available:
+                return "o3-mini"  # Fast and efficient
+            elif gemini_available:
+                return "flash"  # Gemini Flash for speed
+            else:
+                # Default to flash
+                return "gemini-2.5-flash-preview-05-20"
+
+        # BALANCED or no category specified - use existing balanced logic
         if openai_available:
             return "o3-mini"  # Balanced performance/cost
         elif gemini_available:
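The new branches above make the fallback category-aware. A hypothetical call site, assuming the registry class lives at providers.registry and providers are already registered:

    from providers.registry import ModelProviderRegistry
    from tools.models import ToolModelCategory

    # With an OpenAI key this returns "o3"; with only a Gemini key, "pro";
    # otherwise a thinking-capable custom/OpenRouter model, or finally
    # "gemini-2.5-pro-preview-06-05".
    model = ModelProviderRegistry.get_preferred_fallback_model(
        tool_category=ToolModelCategory.EXTENDED_REASONING
    )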
@@ -219,6 +245,51 @@ class ModelProviderRegistry:
             # This maintains backward compatibility for tests
             return "gemini-2.5-flash-preview-05-20"

+    @classmethod
+    def _find_extended_thinking_model(cls) -> Optional[str]:
+        """Find a model suitable for extended reasoning from custom/openrouter providers.
+
+        Returns:
+            Model name if found, None otherwise
+        """
+        # Check custom provider first
+        custom_provider = cls.get_provider(ProviderType.CUSTOM)
+        if custom_provider:
+            # Check if it's a CustomProvider and has thinking models
+            try:
+                from providers.custom import CustomProvider
+
+                if isinstance(custom_provider, CustomProvider) and hasattr(custom_provider, "model_registry"):
+                    for model_name, config in custom_provider.model_registry.items():
+                        if config.get("supports_extended_thinking", False):
+                            return model_name
+            except ImportError:
+                pass
+
+        # Then check OpenRouter for high-context/powerful models
+        openrouter_provider = cls.get_provider(ProviderType.OPENROUTER)
+        if openrouter_provider:
+            # Prefer models known for deep reasoning
+            preferred_models = [
+                "anthropic/claude-3.5-sonnet",
+                "anthropic/claude-3-opus-20240229",
+                "meta-llama/llama-3.1-70b-instruct",
+                "google/gemini-pro-1.5",
+                "mistralai/mixtral-8x7b-instruct",
+            ]
+            for model in preferred_models:
+                try:
+                    if openrouter_provider.validate_model_name(model):
+                        return model
+                except Exception as e:
+                    # Log the error for debugging but continue searching;
+                    # logging is already imported at module level
+                    logging.warning(f"Model validation for '{model}' on OpenRouter failed: {e}")
+                    continue
+
+        return None
+
     @classmethod
     def get_available_providers_with_keys(cls) -> list[ProviderType]:
         """Get list of provider types that have valid API keys.
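The custom-provider scan in _find_extended_thinking_model reads per-model config dicts from model_registry and keys off supports_extended_thinking. A hypothetical registry shape the loop would accept (model names illustrative, not from the source):

    # Only the "supports_extended_thinking" flag is read by the new code;
    # any other keys in each config dict are ignored here.
    model_registry = {
        "local-deepseek-r1": {"supports_extended_thinking": True},   # would be returned
        "local-llama-3.1-8b": {"supports_extended_thinking": False}, # skipped
    }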