refactor: rename SUPPORTED_MODELS to MODEL_CAPABILITIES to reflect the underlying type
docs: update to reflect new modules
@@ -28,7 +28,7 @@ class ModelProvider(ABC):
     """

     # All concrete providers must define their supported models
-    SUPPORTED_MODELS: dict[str, Any] = {}
+    MODEL_CAPABILITIES: dict[str, Any] = {}

     # Default maximum image size in MB
     DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
@@ -147,9 +147,9 @@ class ModelProvider(ABC):
         Returns:
             Dictionary mapping model names to their ModelCapabilities objects
         """
-        # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
-        if hasattr(self, "SUPPORTED_MODELS"):
-            return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
+        model_map = getattr(self, "MODEL_CAPABILITIES", None)
+        if isinstance(model_map, dict) and model_map:
+            return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
         return {}

     def _resolve_model_name(self, model_name: str) -> str:
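For context on the base-class hunk above: the new lookup uses getattr with an isinstance guard, so a provider that declares no capability map (or a malformed one) degrades to an empty dict instead of raising. A minimal runnable sketch of that behavior — the DummyProvider class and the trimmed-down ModelCapabilities stand-in are hypothetical, for illustration only:

from dataclasses import dataclass


@dataclass
class ModelCapabilities:  # simplified stand-in for the real class
    model_name: str


class ModelProviderSketch:
    def get_capabilities_map(self) -> dict[str, ModelCapabilities]:
        # Mirror the new logic: read MODEL_CAPABILITIES if declared,
        # keep only well-formed entries, otherwise return an empty map.
        model_map = getattr(self, "MODEL_CAPABILITIES", None)
        if isinstance(model_map, dict) and model_map:
            return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
        return {}


class DummyProvider(ModelProviderSketch):
    MODEL_CAPABILITIES = {
        "model-a": ModelCapabilities(model_name="model-a"),
        "bad-entry": "not a capability",  # silently filtered out
    }


print(DummyProvider().get_capabilities_map())        # {'model-a': ModelCapabilities(model_name='model-a')}
print(ModelProviderSketch().get_capabilities_map())  # {} -- no map declared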
@@ -33,7 +33,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
     RETRY_DELAYS = [1, 3, 5, 8]  # seconds

     # Model configurations using ModelCapabilities objects
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "o3-2025-04-16": ModelCapabilities(
             provider=ProviderType.DIAL,
             model_name="o3-2025-04-16",
@@ -280,7 +280,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
         """
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             raise ValueError(f"Unsupported DIAL model: {model_name}")

         # Check restrictions
@@ -290,8 +290,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
         if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
             raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

-        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
-        return self.SUPPORTED_MODELS[resolved_name]
+        # Return the ModelCapabilities object directly from MODEL_CAPABILITIES
+        return self.MODEL_CAPABILITIES[resolved_name]

     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
@@ -308,7 +308,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
         """
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             return False

         # Check against base class allowed_models if configured
@@ -31,7 +31,7 @@ class GeminiModelProvider(ModelProvider):
     """

     # Model configurations using ModelCapabilities objects
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "gemini-2.5-pro": ModelCapabilities(
             provider=ProviderType.GOOGLE,
             model_name="gemini-2.5-pro",
@@ -154,7 +154,7 @@ class GeminiModelProvider(ModelProvider):
         # Resolve shorthand
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             raise ValueError(f"Unsupported Gemini model: {model_name}")

         # Check if model is allowed by restrictions
@@ -166,8 +166,8 @@ class GeminiModelProvider(ModelProvider):
         if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
             raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")

-        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
-        return self.SUPPORTED_MODELS[resolved_name]
+        # Return the ModelCapabilities object directly from MODEL_CAPABILITIES
+        return self.MODEL_CAPABILITIES[resolved_name]

     def generate_content(
         self,
@@ -227,7 +227,7 @@ class GeminiModelProvider(ModelProvider):
         # Add thinking configuration for models that support it
         if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
             # Get model's max thinking tokens and calculate actual budget
-            model_config = self.SUPPORTED_MODELS.get(resolved_name)
+            model_config = self.MODEL_CAPABILITIES.get(resolved_name)
             if model_config and model_config.max_thinking_tokens > 0:
                 max_thinking_tokens = model_config.max_thinking_tokens
                 actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
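For reference, the budget computed in the hunk above is the model's max_thinking_tokens scaled by a per-mode fraction from THINKING_BUDGETS. A worked sketch — the fraction table and token limit here are assumed values, not the project's real ones:

# Assumed budget fractions; the provider's real THINKING_BUDGETS may differ.
THINKING_BUDGETS = {"low": 0.25, "medium": 0.5, "high": 1.0}

max_thinking_tokens = 32768  # assumed per-model limit from ModelCapabilities
thinking_mode = "medium"

actual_thinking_budget = int(max_thinking_tokens * THINKING_BUDGETS[thinking_mode])
print(actual_thinking_budget)  # 16384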
@@ -382,7 +382,7 @@ class GeminiModelProvider(ModelProvider):
         resolved_name = self._resolve_model_name(model_name)

         # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             return False

         # Then check if model is allowed by restrictions
@@ -405,7 +405,7 @@ class GeminiModelProvider(ModelProvider):
     def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
         """Get actual thinking token budget for a model and thinking mode."""
         resolved_name = self._resolve_model_name(model_name)
-        model_config = self.SUPPORTED_MODELS.get(resolved_name)
+        model_config = self.MODEL_CAPABILITIES.get(resolved_name)

         if not model_config or not model_config.supports_extended_thinking:
             return 0
@@ -584,7 +584,7 @@ class GeminiModelProvider(ModelProvider):
         pro_thinking = [
             m
             for m in allowed_models
-            if "pro" in m and m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
+            if "pro" in m and m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
         ]
         if pro_thinking:
             return find_best(pro_thinking)
@@ -593,7 +593,7 @@ class GeminiModelProvider(ModelProvider):
         any_thinking = [
             m
             for m in allowed_models
-            if m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
+            if m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
         ]
         if any_thinking:
             return find_best(any_thinking)
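The two comprehensions above encode a preference order: thinking-capable "pro" models first, then any thinking-capable model. A hedged sketch with hypothetical capability data (only gemini-2.5-pro appears in the diff; the other names are made up):

# Hypothetical capability objects; the real ones carry many more fields.
class Caps:
    def __init__(self, thinking: bool):
        self.supports_extended_thinking = thinking


MODEL_CAPABILITIES = {
    "gemini-2.5-pro": Caps(True),
    "gemini-2.5-flash": Caps(True),
    "gemini-2.0-flash-lite": Caps(False),
}
allowed_models = list(MODEL_CAPABILITIES)

# Same filters as the diff, minus the provider class around them.
pro_thinking = [
    m
    for m in allowed_models
    if "pro" in m and m in MODEL_CAPABILITIES and MODEL_CAPABILITIES[m].supports_extended_thinking
]
any_thinking = [
    m
    for m in allowed_models
    if m in MODEL_CAPABILITIES and MODEL_CAPABILITIES[m].supports_extended_thinking
]
print(pro_thinking or any_thinking)  # ['gemini-2.5-pro'] -- pro wins when available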
@@ -26,7 +26,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
     """

     # Model configurations using ModelCapabilities objects
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "gpt-5": ModelCapabilities(
             provider=ProviderType.OPENAI,
             model_name="gpt-5",
@@ -181,21 +181,21 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

     def get_capabilities(self, model_name: str) -> ModelCapabilities:
         """Get capabilities for a specific OpenAI model."""
-        # First check if it's a key in SUPPORTED_MODELS
-        if model_name in self.SUPPORTED_MODELS:
+        # First check if it's a key in MODEL_CAPABILITIES
+        if model_name in self.MODEL_CAPABILITIES:
             self._check_model_restrictions(model_name, model_name)
-            return self.SUPPORTED_MODELS[model_name]
+            return self.MODEL_CAPABILITIES[model_name]

         # Try resolving as alias
         resolved_name = self._resolve_model_name(model_name)

         # Check if resolved name is a key
-        if resolved_name in self.SUPPORTED_MODELS:
+        if resolved_name in self.MODEL_CAPABILITIES:
             self._check_model_restrictions(resolved_name, model_name)
-            return self.SUPPORTED_MODELS[resolved_name]
+            return self.MODEL_CAPABILITIES[resolved_name]

         # Finally check if resolved name matches any API model name
-        for key, capabilities in self.SUPPORTED_MODELS.items():
+        for key, capabilities in self.MODEL_CAPABILITIES.items():
             if resolved_name == capabilities.model_name:
                 self._check_model_restrictions(key, model_name)
                 return capabilities
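The OpenAI lookup above resolves in three passes: exact key in MODEL_CAPABILITIES, alias-resolved key, then a scan for a capability whose canonical model_name matches. A sketch of that precedence — the alias table and the dated API name are made up for illustration:

# Hypothetical data modelling the three-pass lookup order.
MODEL_CAPABILITIES = {"gpt-5": {"model_name": "gpt-5-2025-08-07"}}
ALIASES = {"gpt5": "gpt-5"}  # assumed alias table


def resolve_model_name(name: str) -> str:
    return ALIASES.get(name, name)


def get_capabilities(name: str) -> dict:
    if name in MODEL_CAPABILITIES:  # 1. exact key
        return MODEL_CAPABILITIES[name]
    resolved = resolve_model_name(name)
    if resolved in MODEL_CAPABILITIES:  # 2. alias-resolved key
        return MODEL_CAPABILITIES[resolved]
    for caps in MODEL_CAPABILITIES.values():  # 3. canonical API model name
        if resolved == caps["model_name"]:
            return caps
    raise ValueError(f"Unsupported model: {name}")


assert get_capabilities("gpt5") is MODEL_CAPABILITIES["gpt-5"]
assert get_capabilities("gpt-5-2025-08-07") is MODEL_CAPABILITIES["gpt-5"]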
@@ -248,7 +248,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
         model_to_check = None
         is_custom_model = False

-        if resolved_name in self.SUPPORTED_MODELS:
+        if resolved_name in self.MODEL_CAPABILITIES:
             model_to_check = resolved_name
         else:
             # If not a built-in model, check the custom models registry.
@@ -282,11 +282,9 @@ class ModelProviderRegistry:
             # Use list_models to get all supported models (handles both regular and custom providers)
             supported_models = provider.list_models(respect_restrictions=False)
         except (NotImplementedError, AttributeError):
-            # Fallback to SUPPORTED_MODELS if list_models not implemented
-            try:
-                supported_models = list(provider.SUPPORTED_MODELS.keys())
-            except AttributeError:
-                supported_models = []
+            # Fallback to provider-declared capability maps if list_models not implemented
+            model_map = getattr(provider, "MODEL_CAPABILITIES", None)
+            supported_models = list(model_map.keys()) if isinstance(model_map, dict) else []

         # Filter by restrictions
         for model_name in supported_models:
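The registry change above swaps nested try/except for getattr plus an isinstance check, so any provider without a usable MODEL_CAPABILITIES dict contributes an empty list instead of raising. A small sketch under those assumptions (the toy provider classes are hypothetical):

class ProviderWithMap:
    MODEL_CAPABILITIES = {"model-a": object(), "model-b": object()}


class ProviderWithoutMap:
    pass


def fallback_models(provider) -> list[str]:
    # Same shape as the registry fallback: tolerate a missing or
    # non-dict MODEL_CAPABILITIES attribute.
    model_map = getattr(provider, "MODEL_CAPABILITIES", None)
    return list(model_map.keys()) if isinstance(model_map, dict) else []


print(fallback_models(ProviderWithMap()))     # ['model-a', 'model-b']
print(fallback_models(ProviderWithoutMap()))  # []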
@@ -27,7 +27,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
     FRIENDLY_NAME = "X.AI"

     # Model configurations using ModelCapabilities objects
-    SUPPORTED_MODELS = {
+    MODEL_CAPABILITIES = {
         "grok-4": ModelCapabilities(
             provider=ProviderType.XAI,
             model_name="grok-4",
@@ -95,7 +95,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
         # Resolve shorthand
         resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             raise ValueError(f"Unsupported X.AI model: {model_name}")

         # Check if model is allowed by restrictions
@@ -105,8 +105,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
         if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
             raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")

-        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
-        return self.SUPPORTED_MODELS[resolved_name]
+        # Return the ModelCapabilities object directly from MODEL_CAPABILITIES
+        return self.MODEL_CAPABILITIES[resolved_name]

     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
@@ -117,7 +117,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
         resolved_name = self._resolve_model_name(model_name)

         # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS:
+        if resolved_name not in self.MODEL_CAPABILITIES:
             return False

         # Then check if model is allowed by restrictions
@@ -156,7 +156,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
     def supports_thinking_mode(self, model_name: str) -> bool:
         """Check if the model supports extended thinking mode."""
         resolved_name = self._resolve_model_name(model_name)
-        capabilities = self.SUPPORTED_MODELS.get(resolved_name)
+        capabilities = self.MODEL_CAPABILITIES.get(resolved_name)
         if capabilities:
             return capabilities.supports_extended_thinking
         return False