Support for allowed model restrictions per provider

Tool escalation added to `analyze` to a graceful switch over to codereview is made when absolutely necessary
This commit is contained in:
Fahad
2025-06-14 10:56:53 +04:00
parent ac9c58ce61
commit 23353734cd
14 changed files with 1037 additions and 79 deletions

View File

@@ -150,24 +150,63 @@ class ModelProviderRegistry:
return list(instance._providers.keys())
@classmethod
def get_available_models(cls) -> dict[str, ProviderType]:
def get_available_models(cls, respect_restrictions: bool = True) -> dict[str, ProviderType]:
"""Get mapping of all available models to their providers.
Args:
respect_restrictions: If True, filter out models not allowed by restrictions
Returns:
Dict mapping model names to provider types
"""
models = {}
instance = cls()
# Import here to avoid circular imports
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
for provider_type in instance._providers:
provider = cls.get_provider(provider_type)
if provider:
# This assumes providers have a method to list supported models
# We'll need to add this to the interface
pass
# Get supported models based on provider type
if hasattr(provider, "SUPPORTED_MODELS"):
for model_name, config in provider.SUPPORTED_MODELS.items():
# Skip aliases (string values)
if isinstance(config, str):
continue
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
logging.debug(f"Model {model_name} filtered by restrictions")
continue
models[model_name] = provider_type
return models
@classmethod
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
"""Get list of available model names, optionally filtered by provider.
This respects model restrictions automatically.
Args:
provider_type: Optional provider to filter by
Returns:
List of available model names
"""
available_models = cls.get_available_models(respect_restrictions=True)
if provider_type:
# Filter by specific provider
return [name for name, ptype in available_models.items() if ptype == provider_type]
else:
# Return all available models
return list(available_models.keys())
@classmethod
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
"""Get API key for a provider from environment variables.
@@ -198,6 +237,8 @@ class ModelProviderRegistry:
This method checks which providers have valid API keys and returns
a sensible default model for auto mode fallback situations.
Takes into account model restrictions when selecting fallback models.
Args:
tool_category: Optional category to influence model selection
@@ -207,16 +248,29 @@ class ModelProviderRegistry:
# Import here to avoid circular import
from tools.models import ToolModelCategory
# Check provider availability by trying to get instances
openai_available = cls.get_provider(ProviderType.OPENAI) is not None
gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
# Get available models respecting restrictions
available_models = cls.get_available_models(respect_restrictions=True)
# Group by provider
openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI]
gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
openai_available = bool(openai_models)
gemini_available = bool(gemini_models)
if tool_category == ToolModelCategory.EXTENDED_REASONING:
# Prefer thinking-capable models for deep reasoning tools
if openai_available:
if openai_available and "o3" in openai_models:
return "o3" # O3 for deep reasoning
elif gemini_available:
return "pro" # Gemini Pro with thinking mode
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
elif gemini_available and any("pro" in m for m in gemini_models):
# Find the pro model (handles full names)
return next(m for m in gemini_models if "pro" in m)
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
else:
# Try to find thinking-capable model from custom/openrouter
thinking_model = cls._find_extended_thinking_model()
@@ -227,22 +281,40 @@ class ModelProviderRegistry:
elif tool_category == ToolModelCategory.FAST_RESPONSE:
# Prefer fast, cost-efficient models
if openai_available:
return "o3-mini" # Fast and efficient
elif gemini_available:
return "flash" # Gemini Flash for speed
if openai_available and "o4-mini" in openai_models:
return "o4-mini" # Latest, fast and efficient
elif openai_available and "o3-mini" in openai_models:
return "o3-mini" # Second choice
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
# Find the flash model (handles full names)
return next(m for m in gemini_models if "flash" in m)
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
else:
# Default to flash
return "gemini-2.5-flash-preview-05-20"
# BALANCED or no category specified - use existing balanced logic
if openai_available:
return "o3-mini" # Balanced performance/cost
elif gemini_available:
return "gemini-2.5-flash-preview-05-20" # Fast and efficient
if openai_available and "o4-mini" in openai_models:
return "o4-mini" # Latest balanced performance/cost
elif openai_available and "o3-mini" in openai_models:
return "o3-mini" # Second choice
elif openai_available and openai_models:
return openai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
return next(m for m in gemini_models if "flash" in m)
elif gemini_available and gemini_models:
return gemini_models[0]
else:
# No API keys available - return a reasonable default
# This maintains backward compatibility for tests
# No models available due to restrictions - check if any providers exist
if not available_models:
# This might happen if all models are restricted
logging.warning("No models available due to restrictions")
# Return a reasonable default for backward compatibility
return "gemini-2.5-flash-preview-05-20"
@classmethod