Support for allowed model restrictions per provider

Tool escalation added to `analyze`: a graceful switch-over to `codereview` is made when absolutely necessary
Fahad
2025-06-14 10:56:53 +04:00
parent ac9c58ce61
commit 23353734cd
14 changed files with 1037 additions and 79 deletions
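
For context, the diffs below call into a new utils/model_restrictions module (one of the 14 changed files, not shown on this page). A minimal sketch of the service they assume follows; the class name, the OPENAI_ALLOWED_MODELS / GOOGLE_ALLOWED_MODELS environment variables, and the parsing details are illustrative guesses, not taken from this commit:

# Hypothetical sketch of utils/model_restrictions.py.
import os

from providers.base import ProviderType

_ENV_VARS = {
    ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS",  # assumed variable name
    ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",  # assumed variable name
}


class ModelRestrictionService:
    def __init__(self):
        # Parse each provider's comma-separated allow-list once.
        self._allowed: dict[ProviderType, set[str]] = {}
        for provider, var in _ENV_VARS.items():
            raw = os.getenv(var, "").strip()
            if raw:
                self._allowed[provider] = {m.strip().lower() for m in raw.split(",") if m.strip()}

    def is_allowed(self, provider: ProviderType, resolved_name: str, original_name: str | None = None) -> bool:
        allowed = self._allowed.get(provider)
        if allowed is None:
            return True  # no restriction configured for this provider
        # Accept a match on either the canonical name or the user-supplied alias,
        # mirroring the two- and three-argument call sites in the diffs below.
        candidates = {resolved_name.lower()}
        if original_name:
            candidates.add(original_name.lower())
        return bool(candidates & allowed)


_service = None


def get_restriction_service() -> ModelRestrictionService:
    # Lazily built module-level singleton, matching get_restriction_service() below.
    global _service
    if _service is None:
        _service = ModelRestrictionService()
    return _service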

View File

@@ -1,5 +1,6 @@
"""Gemini model provider implementation."""
import logging
import time
from typing import Optional
@@ -8,6 +9,8 @@ from google.genai import types
 
 from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
 
+logger = logging.getLogger(__name__)
+
 class GeminiModelProvider(ModelProvider):
     """Google Gemini model provider implementation."""
@@ -60,6 +63,13 @@ class GeminiModelProvider(ModelProvider):
         if resolved_name not in self.SUPPORTED_MODELS:
             raise ValueError(f"Unsupported Gemini model: {model_name}")
 
+        # Check if model is allowed by restrictions
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
+            raise ValueError(f"Gemini model '{model_name}' is not allowed by restriction policy.")
+
         config = self.SUPPORTED_MODELS[resolved_name]
 
         # Gemini models support 0.0-2.0 temperature range
@@ -201,9 +211,22 @@ class GeminiModelProvider(ModelProvider):
         return ProviderType.GOOGLE
 
     def validate_model_name(self, model_name: str) -> bool:
-        """Validate if the model name is supported."""
+        """Validate if the model name is supported and allowed."""
         resolved_name = self._resolve_model_name(model_name)
-        return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
+
+        # First check if model is supported
+        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+            return False
+
+        # Then check if model is allowed by restrictions
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
+            logger.debug(f"Gemini model '{model_name}' -> '{resolved_name}' blocked by restrictions")
+            return False
+
+        return True
 
     def supports_thinking_mode(self, model_name: str) -> bool:
         """Check if the model supports extended thinking mode."""

View File

@@ -1,5 +1,7 @@
"""OpenAI model provider implementation."""
import logging
from .base import (
FixedTemperatureConstraint,
ModelCapabilities,
@@ -8,6 +10,8 @@ from .base import (
 )
 
 from .openai_compatible import OpenAICompatibleProvider
 
+logger = logging.getLogger(__name__)
+
 class OpenAIModelProvider(OpenAICompatibleProvider):
     """Official OpenAI API provider (api.openai.com)."""
@@ -31,6 +35,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
"supports_extended_thinking": False,
},
# Shorthands
"mini": "o4-mini", # Default 'mini' to latest mini model
"o3mini": "o3-mini",
"o4mini": "o4-mini",
"o4minihigh": "o4-mini-high",
@@ -51,6 +56,13 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
             raise ValueError(f"Unsupported OpenAI model: {model_name}")
 
+        # Check if model is allowed by restrictions
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
+            raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
+
         config = self.SUPPORTED_MODELS[resolved_name]
 
         # Define temperature constraints per model
@@ -78,9 +90,22 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         return ProviderType.OPENAI
 
     def validate_model_name(self, model_name: str) -> bool:
-        """Validate if the model name is supported."""
+        """Validate if the model name is supported and allowed."""
         resolved_name = self._resolve_model_name(model_name)
-        return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
+
+        # First check if model is supported
+        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+            return False
+
+        # Then check if model is allowed by restrictions
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
+            logger.debug(f"OpenAI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
+            return False
+
+        return True
 
     def supports_thinking_mode(self, model_name: str) -> bool:
         """Check if the model supports extended thinking mode."""

View File

@@ -150,24 +150,63 @@ class ModelProviderRegistry:
         return list(instance._providers.keys())
 
     @classmethod
-    def get_available_models(cls) -> dict[str, ProviderType]:
+    def get_available_models(cls, respect_restrictions: bool = True) -> dict[str, ProviderType]:
         """Get mapping of all available models to their providers.
 
+        Args:
+            respect_restrictions: If True, filter out models not allowed by restrictions
+
         Returns:
             Dict mapping model names to provider types
         """
         models = {}
         instance = cls()
+
+        # Import here to avoid circular imports
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service() if respect_restrictions else None
+
         for provider_type in instance._providers:
             provider = cls.get_provider(provider_type)
             if provider:
-                # This assumes providers have a method to list supported models
-                # We'll need to add this to the interface
-                pass
+                # Get supported models based on provider type
+                if hasattr(provider, "SUPPORTED_MODELS"):
+                    for model_name, config in provider.SUPPORTED_MODELS.items():
+                        # Skip aliases (string values)
+                        if isinstance(config, str):
+                            continue
+
+                        # Check restrictions if enabled
+                        if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
+                            logging.debug(f"Model {model_name} filtered by restrictions")
+                            continue
+
+                        models[model_name] = provider_type
 
         return models
 
+    @classmethod
+    def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
+        """Get list of available model names, optionally filtered by provider.
+
+        This respects model restrictions automatically.
+
+        Args:
+            provider_type: Optional provider to filter by
+
+        Returns:
+            List of available model names
+        """
+        available_models = cls.get_available_models(respect_restrictions=True)
+
+        if provider_type:
+            # Filter by specific provider
+            return [name for name, ptype in available_models.items() if ptype == provider_type]
+        else:
+            # Return all available models
+            return list(available_models.keys())
+
     @classmethod
     def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
         """Get API key for a provider from environment variables.
@@ -198,6 +237,8 @@ class ModelProviderRegistry:
         This method checks which providers have valid API keys and returns
         a sensible default model for auto mode fallback situations.
 
+        Takes into account model restrictions when selecting fallback models.
+
         Args:
             tool_category: Optional category to influence model selection
@@ -207,16 +248,29 @@ class ModelProviderRegistry:
         # Import here to avoid circular import
         from tools.models import ToolModelCategory
 
-        # Check provider availability by trying to get instances
-        openai_available = cls.get_provider(ProviderType.OPENAI) is not None
-        gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
+        # Get available models respecting restrictions
+        available_models = cls.get_available_models(respect_restrictions=True)
+
+        # Group by provider
+        openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI]
+        gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
+
+        openai_available = bool(openai_models)
+        gemini_available = bool(gemini_models)
 
         if tool_category == ToolModelCategory.EXTENDED_REASONING:
             # Prefer thinking-capable models for deep reasoning tools
-            if openai_available:
+            if openai_available and "o3" in openai_models:
                 return "o3"  # O3 for deep reasoning
-            elif gemini_available:
-                return "pro"  # Gemini Pro with thinking mode
+            elif openai_available and openai_models:
+                # Fall back to any available OpenAI model
+                return openai_models[0]
+            elif gemini_available and any("pro" in m for m in gemini_models):
+                # Find the pro model (handles full names)
+                return next(m for m in gemini_models if "pro" in m)
+            elif gemini_available and gemini_models:
+                # Fall back to any available Gemini model
+                return gemini_models[0]
             else:
                 # Try to find thinking-capable model from custom/openrouter
                 thinking_model = cls._find_extended_thinking_model()
@@ -227,22 +281,40 @@ class ModelProviderRegistry:
         elif tool_category == ToolModelCategory.FAST_RESPONSE:
             # Prefer fast, cost-efficient models
-            if openai_available:
-                return "o3-mini"  # Fast and efficient
-            elif gemini_available:
-                return "flash"  # Gemini Flash for speed
+            if openai_available and "o4-mini" in openai_models:
+                return "o4-mini"  # Latest, fast and efficient
+            elif openai_available and "o3-mini" in openai_models:
+                return "o3-mini"  # Second choice
+            elif openai_available and openai_models:
+                # Fall back to any available OpenAI model
+                return openai_models[0]
+            elif gemini_available and any("flash" in m for m in gemini_models):
+                # Find the flash model (handles full names)
+                return next(m for m in gemini_models if "flash" in m)
+            elif gemini_available and gemini_models:
+                # Fall back to any available Gemini model
+                return gemini_models[0]
             else:
                 # Default to flash
                 return "gemini-2.5-flash-preview-05-20"
 
         # BALANCED or no category specified - use existing balanced logic
-        if openai_available:
-            return "o3-mini"  # Balanced performance/cost
-        elif gemini_available:
-            return "gemini-2.5-flash-preview-05-20"  # Fast and efficient
+        if openai_available and "o4-mini" in openai_models:
+            return "o4-mini"  # Latest balanced performance/cost
+        elif openai_available and "o3-mini" in openai_models:
+            return "o3-mini"  # Second choice
+        elif openai_available and openai_models:
+            return openai_models[0]
+        elif gemini_available and any("flash" in m for m in gemini_models):
+            return next(m for m in gemini_models if "flash" in m)
+        elif gemini_available and gemini_models:
+            return gemini_models[0]
         else:
-            # No API keys available - return a reasonable default
-            # This maintains backward compatibility for tests
+            # No models available due to restrictions - check if any providers exist
+            if not available_models:
+                # This might happen if all models are restricted
+                logging.warning("No models available due to restrictions")
+            # Return a reasonable default for backward compatibility
             return "gemini-2.5-flash-preview-05-20"
 
     @classmethod
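
The fallback selection above now degrades through an explicit preference ladder instead of assuming a fixed model is available. Restated as a standalone helper for clarity (a sketch under the same ordering, not code from this commit):

# Preference order for FAST_RESPONSE / BALANCED once restrictions are applied.
def pick_fast_response(openai_models: list[str], gemini_models: list[str]) -> str:
    if "o4-mini" in openai_models:
        return "o4-mini"  # first choice
    if "o3-mini" in openai_models:
        return "o3-mini"  # second choice
    if openai_models:
        return openai_models[0]  # any allowed OpenAI model
    flash = [m for m in gemini_models if "flash" in m]
    if flash:
        return flash[0]  # matches full names like gemini-2.5-flash-preview-05-20
    if gemini_models:
        return gemini_models[0]  # any allowed Gemini model
    return "gemini-2.5-flash-preview-05-20"  # backward-compatible default


# With everything restricted away, the legacy default still comes back:
assert pick_fast_response([], []) == "gemini-2.5-flash-preview-05-20"
assert pick_fast_response(["o3-mini"], []) == "o3-mini"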