Support for allowed model restrictions per provider
Tool escalation added to `analyze` to a graceful switch over to codereview is made when absolutely necessary
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Gemini model provider implementation."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
@@ -8,6 +9,8 @@ from google.genai import types
|
||||
|
||||
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GeminiModelProvider(ModelProvider):
|
||||
"""Google Gemini model provider implementation."""
|
||||
@@ -60,6 +63,13 @@ class GeminiModelProvider(ModelProvider):
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(f"Unsupported Gemini model: {model_name}")
|
||||
|
||||
# Check if model is allowed by restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
|
||||
raise ValueError(f"Gemini model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
config = self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
# Gemini models support 0.0-2.0 temperature range
|
||||
@@ -201,9 +211,22 @@ class GeminiModelProvider(ModelProvider):
|
||||
return ProviderType.GOOGLE
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported."""
|
||||
"""Validate if the model name is supported and allowed."""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
|
||||
|
||||
# First check if model is supported
|
||||
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
|
||||
return False
|
||||
|
||||
# Then check if model is allowed by restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
|
||||
logger.debug(f"Gemini model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
"""Check if the model supports extended thinking mode."""
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""OpenAI model provider implementation."""
|
||||
|
||||
import logging
|
||||
|
||||
from .base import (
|
||||
FixedTemperatureConstraint,
|
||||
ModelCapabilities,
|
||||
@@ -8,6 +10,8 @@ from .base import (
|
||||
)
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
"""Official OpenAI API provider (api.openai.com)."""
|
||||
@@ -31,6 +35,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
"supports_extended_thinking": False,
|
||||
},
|
||||
# Shorthands
|
||||
"mini": "o4-mini", # Default 'mini' to latest mini model
|
||||
"o3mini": "o3-mini",
|
||||
"o4mini": "o4-mini",
|
||||
"o4minihigh": "o4-mini-high",
|
||||
@@ -51,6 +56,13 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
|
||||
raise ValueError(f"Unsupported OpenAI model: {model_name}")
|
||||
|
||||
# Check if model is allowed by restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
|
||||
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
config = self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
# Define temperature constraints per model
|
||||
@@ -78,9 +90,22 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
return ProviderType.OPENAI
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool:
|
||||
"""Validate if the model name is supported."""
|
||||
"""Validate if the model name is supported and allowed."""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
return resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict)
|
||||
|
||||
# First check if model is supported
|
||||
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
|
||||
return False
|
||||
|
||||
# Then check if model is allowed by restrictions
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
|
||||
logger.debug(f"OpenAI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
"""Check if the model supports extended thinking mode."""
|
||||
|
||||
@@ -150,24 +150,63 @@ class ModelProviderRegistry:
|
||||
return list(instance._providers.keys())
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict[str, ProviderType]:
|
||||
def get_available_models(cls, respect_restrictions: bool = True) -> dict[str, ProviderType]:
|
||||
"""Get mapping of all available models to their providers.
|
||||
|
||||
Args:
|
||||
respect_restrictions: If True, filter out models not allowed by restrictions
|
||||
|
||||
Returns:
|
||||
Dict mapping model names to provider types
|
||||
"""
|
||||
models = {}
|
||||
instance = cls()
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
|
||||
for provider_type in instance._providers:
|
||||
provider = cls.get_provider(provider_type)
|
||||
if provider:
|
||||
# This assumes providers have a method to list supported models
|
||||
# We'll need to add this to the interface
|
||||
pass
|
||||
# Get supported models based on provider type
|
||||
if hasattr(provider, "SUPPORTED_MODELS"):
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
# Skip aliases (string values)
|
||||
if isinstance(config, str):
|
||||
continue
|
||||
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
||||
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||
continue
|
||||
|
||||
models[model_name] = provider_type
|
||||
|
||||
return models
|
||||
|
||||
@classmethod
|
||||
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
|
||||
"""Get list of available model names, optionally filtered by provider.
|
||||
|
||||
This respects model restrictions automatically.
|
||||
|
||||
Args:
|
||||
provider_type: Optional provider to filter by
|
||||
|
||||
Returns:
|
||||
List of available model names
|
||||
"""
|
||||
available_models = cls.get_available_models(respect_restrictions=True)
|
||||
|
||||
if provider_type:
|
||||
# Filter by specific provider
|
||||
return [name for name, ptype in available_models.items() if ptype == provider_type]
|
||||
else:
|
||||
# Return all available models
|
||||
return list(available_models.keys())
|
||||
|
||||
@classmethod
|
||||
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
|
||||
"""Get API key for a provider from environment variables.
|
||||
@@ -198,6 +237,8 @@ class ModelProviderRegistry:
|
||||
This method checks which providers have valid API keys and returns
|
||||
a sensible default model for auto mode fallback situations.
|
||||
|
||||
Takes into account model restrictions when selecting fallback models.
|
||||
|
||||
Args:
|
||||
tool_category: Optional category to influence model selection
|
||||
|
||||
@@ -207,16 +248,29 @@ class ModelProviderRegistry:
|
||||
# Import here to avoid circular import
|
||||
from tools.models import ToolModelCategory
|
||||
|
||||
# Check provider availability by trying to get instances
|
||||
openai_available = cls.get_provider(ProviderType.OPENAI) is not None
|
||||
gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
|
||||
# Get available models respecting restrictions
|
||||
available_models = cls.get_available_models(respect_restrictions=True)
|
||||
|
||||
# Group by provider
|
||||
openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI]
|
||||
gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
|
||||
|
||||
openai_available = bool(openai_models)
|
||||
gemini_available = bool(gemini_models)
|
||||
|
||||
if tool_category == ToolModelCategory.EXTENDED_REASONING:
|
||||
# Prefer thinking-capable models for deep reasoning tools
|
||||
if openai_available:
|
||||
if openai_available and "o3" in openai_models:
|
||||
return "o3" # O3 for deep reasoning
|
||||
elif gemini_available:
|
||||
return "pro" # Gemini Pro with thinking mode
|
||||
elif openai_available and openai_models:
|
||||
# Fall back to any available OpenAI model
|
||||
return openai_models[0]
|
||||
elif gemini_available and any("pro" in m for m in gemini_models):
|
||||
# Find the pro model (handles full names)
|
||||
return next(m for m in gemini_models if "pro" in m)
|
||||
elif gemini_available and gemini_models:
|
||||
# Fall back to any available Gemini model
|
||||
return gemini_models[0]
|
||||
else:
|
||||
# Try to find thinking-capable model from custom/openrouter
|
||||
thinking_model = cls._find_extended_thinking_model()
|
||||
@@ -227,22 +281,40 @@ class ModelProviderRegistry:
|
||||
|
||||
elif tool_category == ToolModelCategory.FAST_RESPONSE:
|
||||
# Prefer fast, cost-efficient models
|
||||
if openai_available:
|
||||
return "o3-mini" # Fast and efficient
|
||||
elif gemini_available:
|
||||
return "flash" # Gemini Flash for speed
|
||||
if openai_available and "o4-mini" in openai_models:
|
||||
return "o4-mini" # Latest, fast and efficient
|
||||
elif openai_available and "o3-mini" in openai_models:
|
||||
return "o3-mini" # Second choice
|
||||
elif openai_available and openai_models:
|
||||
# Fall back to any available OpenAI model
|
||||
return openai_models[0]
|
||||
elif gemini_available and any("flash" in m for m in gemini_models):
|
||||
# Find the flash model (handles full names)
|
||||
return next(m for m in gemini_models if "flash" in m)
|
||||
elif gemini_available and gemini_models:
|
||||
# Fall back to any available Gemini model
|
||||
return gemini_models[0]
|
||||
else:
|
||||
# Default to flash
|
||||
return "gemini-2.5-flash-preview-05-20"
|
||||
|
||||
# BALANCED or no category specified - use existing balanced logic
|
||||
if openai_available:
|
||||
return "o3-mini" # Balanced performance/cost
|
||||
elif gemini_available:
|
||||
return "gemini-2.5-flash-preview-05-20" # Fast and efficient
|
||||
if openai_available and "o4-mini" in openai_models:
|
||||
return "o4-mini" # Latest balanced performance/cost
|
||||
elif openai_available and "o3-mini" in openai_models:
|
||||
return "o3-mini" # Second choice
|
||||
elif openai_available and openai_models:
|
||||
return openai_models[0]
|
||||
elif gemini_available and any("flash" in m for m in gemini_models):
|
||||
return next(m for m in gemini_models if "flash" in m)
|
||||
elif gemini_available and gemini_models:
|
||||
return gemini_models[0]
|
||||
else:
|
||||
# No API keys available - return a reasonable default
|
||||
# This maintains backward compatibility for tests
|
||||
# No models available due to restrictions - check if any providers exist
|
||||
if not available_models:
|
||||
# This might happen if all models are restricted
|
||||
logging.warning("No models available due to restrictions")
|
||||
# Return a reasonable default for backward compatibility
|
||||
return "gemini-2.5-flash-preview-05-20"
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user