GPT-5, GPT-5-mini support
- Improvements to model name resolution
- Improved instructions for multi-step workflows when continuation is available
- Improved instructions for chat tool
- Improved preferred model resolution, moved code from registry -> each provider
- Updated tests
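The core of this change is a new provider-side hook, get_preferred_model(category, allowed_models): the registry now filters each provider's models through the restriction policy and then delegates the actual choice to the provider, instead of hard-coding per-provider preferences in the registry. A minimal, self-contained sketch of that flow follows; the toy provider class, the enum member values, and the example model list are illustrative only, while the preference orders mirror the OpenAI provider in the diff below.

from enum import Enum
from typing import Optional


class ToolModelCategory(Enum):
    # Stand-in for tools.models.ToolModelCategory; member values here are made up.
    EXTENDED_REASONING = "extended_reasoning"
    FAST_RESPONSE = "fast_response"
    BALANCED = "balanced"


class ToyOpenAIProvider:
    """Illustrative provider; the real logic lives in OpenAIModelProvider below."""

    def get_preferred_model(self, category: ToolModelCategory, allowed_models: list[str]) -> Optional[str]:
        if not allowed_models:
            return None
        if category == ToolModelCategory.EXTENDED_REASONING:
            order = ["o3", "o3-pro", "gpt-5"]
        else:
            order = ["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"]
        # First preference that survived restriction filtering, else the first allowed model.
        return next((m for m in order if m in allowed_models), allowed_models[0])


def pick_fallback(provider: ToyOpenAIProvider, allowed: list[str], category: ToolModelCategory) -> str:
    # Registry-side shape: restrictions are applied first, then the provider is asked.
    return provider.get_preferred_model(category, allowed) or sorted(allowed)[0]


if __name__ == "__main__":
    allowed = ["gpt-5-mini", "o4-mini"]  # e.g. what a restriction policy left available
    print(pick_fallback(ToyOpenAIProvider(), allowed, ToolModelCategory.FAST_RESPONSE))  # -> gpt-5-mini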
@@ -4,7 +4,10 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    from tools.models import ToolModelCategory

logger = logging.getLogger(__name__)

@@ -118,10 +121,10 @@ def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint
        return FixedTemperatureConstraint(1.0)
    elif constraint_type == "discrete":
        # For models with specific allowed values - using common OpenAI values as default
        return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.7)
        return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
    else:
        # Default range constraint (for "range" or None)
        return RangeTemperatureConstraint(0.0, 2.0, 0.7)
        return RangeTemperatureConstraint(0.0, 2.0, 0.3)


@dataclass
@@ -154,24 +157,11 @@ class ModelCapabilities:
    # Custom model flag (for models that only work with custom endpoints)
    is_custom: bool = False  # Whether this model requires custom API endpoints

    # Temperature constraint object - preferred way to define temperature limits
    # Temperature constraint object - defines temperature limits and behavior
    temperature_constraint: TemperatureConstraint = field(
        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
    )

    # Backward compatibility property for existing code
    @property
    def temperature_range(self) -> tuple[float, float]:
        """Backward compatibility for existing code that uses temperature_range."""
        if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
            return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
        elif isinstance(self.temperature_constraint, FixedTemperatureConstraint):
            return (self.temperature_constraint.value, self.temperature_constraint.value)
        elif isinstance(self.temperature_constraint, DiscreteTemperatureConstraint):
            values = self.temperature_constraint.allowed_values
            return (min(values), max(values))
        return (0.0, 2.0)  # Fallback


@dataclass
class ModelResponse:
@@ -268,18 +258,15 @@ class ModelProvider(ABC):
        if not capabilities.supports_temperature:
            return None

        # Get temperature range
        min_temp, max_temp = capabilities.temperature_range
        # Use temperature constraint to get corrected value
        corrected_temp = capabilities.temperature_constraint.get_corrected_value(requested_temperature)

        # Clamp to valid range
        if requested_temperature < min_temp:
            logger.debug(f"Clamping temperature from {requested_temperature} to {min_temp} for model {model_name}")
            return min_temp
        elif requested_temperature > max_temp:
            logger.debug(f"Clamping temperature from {requested_temperature} to {max_temp} for model {model_name}")
            return max_temp
        else:
            return requested_temperature
        if corrected_temp != requested_temperature:
            logger.debug(
                f"Adjusting temperature from {requested_temperature} to {corrected_temp} for model {model_name}"
            )

        return corrected_temp

        except Exception as e:
            logger.debug(f"Could not determine effective temperature for {model_name}: {e}")
@@ -294,10 +281,10 @@ class ModelProvider(ABC):
        """
        capabilities = self.get_capabilities(model_name)

        # Validate temperature
        min_temp, max_temp = capabilities.temperature_range
        if not min_temp <= temperature <= max_temp:
            raise ValueError(f"Temperature {temperature} out of range [{min_temp}, {max_temp}] for model {model_name}")
        # Validate temperature using constraint
        if not capabilities.temperature_constraint.validate(temperature):
            constraint_desc = capabilities.temperature_constraint.get_description()
            raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")

    @abstractmethod
    def supports_thinking_mode(self, model_name: str) -> bool:
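Context for the temperature hunks above: temperature handling now goes through the constraint object's validate(), get_corrected_value(), and get_description() methods instead of clamping against a (min, max) tuple. The snippet below is a simplified, self-contained stand-in for the discrete case that only mirrors the method names used above; the real classes live in providers/base.py and the "snap to nearest value" behaviour shown here is an assumption.

from dataclasses import dataclass


@dataclass
class DiscreteConstraintSketch:
    allowed_values: list[float]
    default: float

    def validate(self, temp: float) -> bool:
        return temp in self.allowed_values

    def get_corrected_value(self, temp: float) -> float:
        # Snap to the nearest allowed value (assumed behaviour; the real class may differ).
        return min(self.allowed_values, key=lambda v: abs(v - temp))

    def get_description(self) -> str:
        return f"Supported values: {self.allowed_values}"


constraint = DiscreteConstraintSketch([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
print(constraint.validate(0.5))             # False -> provider raises ValueError with get_description()
print(constraint.get_corrected_value(0.5))  # 0.3 under this toy "nearest value" rule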
@@ -441,3 +428,28 @@ class ModelProvider(ABC):
        """
        # Base implementation: no resources to clean up
        return

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get the preferred model from this provider for a given category.

        Args:
            category: The tool category requiring a model
            allowed_models: Pre-filtered list of model names that are allowed by restrictions

        Returns:
            Model name if this provider has a preference, None otherwise
        """
        # Default implementation - providers can override with specific logic
        return None

    def get_model_registry(self) -> Optional[dict[str, Any]]:
        """Get the model registry for providers that maintain one.

        This is a hook method for providers like CustomProvider that maintain
        a dynamic model registry.

        Returns:
            Model registry dict or None if not applicable
        """
        # Default implementation - most providers don't have a registry
        return None
@@ -4,7 +4,10 @@ import base64
import logging
import os
import time
from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from tools.models import ToolModelCategory

from google import genai
from google.genai import types
@@ -19,6 +22,25 @@ class GeminiModelProvider(ModelProvider):

    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
        "gemini-2.5-pro": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.5-pro",
            friendly_name="Gemini (Pro 2.5)",
            context_window=1_048_576,  # 1M tokens
            max_output_tokens=65_536,
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # Vision capability
            max_image_size_mb=32.0,  # Higher limit for Pro model
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            max_thinking_tokens=32768,  # Max thinking tokens for Pro model
            description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
            aliases=["pro", "gemini pro", "gemini-pro"],
        ),
        "gemini-2.0-flash": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.0-flash",
@@ -75,25 +97,6 @@ class GeminiModelProvider(ModelProvider):
            description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
            aliases=["flash", "flash2.5"],
        ),
        "gemini-2.5-pro": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.5-pro",
            friendly_name="Gemini (Pro 2.5)",
            context_window=1_048_576,  # 1M tokens
            max_output_tokens=65_536,
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # Vision capability
            max_image_size_mb=32.0,  # Higher limit for Pro model
            supports_temperature=True,
            temperature_constraint=create_temperature_constraint("range"),
            max_thinking_tokens=32768,  # Max thinking tokens for Pro model
            description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
            aliases=["pro", "gemini pro", "gemini-pro"],
        ),
    }

    # Thinking mode configurations - percentages of model's max_thinking_tokens
@@ -465,3 +468,67 @@ class GeminiModelProvider(ModelProvider):
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            return None

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get Gemini's preferred model for a given category from allowed models.

        Args:
            category: The tool category requiring a model
            allowed_models: Pre-filtered list of models allowed by restrictions

        Returns:
            Preferred model name or None
        """
        from tools.models import ToolModelCategory

        if not allowed_models:
            return None

        # Helper to find best model from candidates
        def find_best(candidates: list[str]) -> Optional[str]:
            """Return best model from candidates (sorted for consistency)."""
            return sorted(candidates, reverse=True)[0] if candidates else None

        if category == ToolModelCategory.EXTENDED_REASONING:
            # For extended reasoning, prefer models with thinking support
            # First try Pro models that support thinking
            pro_thinking = [
                m
                for m in allowed_models
                if "pro" in m and m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
            ]
            if pro_thinking:
                return find_best(pro_thinking)

            # Then any model that supports thinking
            any_thinking = [
                m
                for m in allowed_models
                if m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
            ]
            if any_thinking:
                return find_best(any_thinking)

            # Finally, just prefer Pro models even without thinking
            pro_models = [m for m in allowed_models if "pro" in m]
            if pro_models:
                return find_best(pro_models)

        elif category == ToolModelCategory.FAST_RESPONSE:
            # Prefer Flash models for speed
            flash_models = [m for m in allowed_models if "flash" in m]
            if flash_models:
                return find_best(flash_models)

        # Default for BALANCED or as fallback
        # Prefer Flash for balanced use, then Pro, then anything
        flash_models = [m for m in allowed_models if "flash" in m]
        if flash_models:
            return find_best(flash_models)

        pro_models = [m for m in allowed_models if "pro" in m]
        if pro_models:
            return find_best(pro_models)

        # Ultimate fallback to best available model
        return find_best(allowed_models)
@@ -309,8 +309,10 @@ class OpenAICompatibleProvider(ModelProvider):
        max_retries = 4
        retry_delays = [1, 3, 5, 8]
        last_exception = None
        actual_attempts = 0

        for attempt in range(max_retries):
            actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
            try:  # Log the exact payload being sent for debugging
                import json

@@ -371,14 +373,13 @@ class OpenAICompatibleProvider(ModelProvider):
                if is_retryable and attempt < max_retries - 1:
                    delay = retry_delays[attempt]
                    logging.warning(
                        f"Retryable error for o3-pro responses endpoint, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
                        f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
                    )
                    time.sleep(delay)
                else:
                    break

        # If we get here, all retries failed
        actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
        error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
        logging.error(error_msg)
        raise RuntimeError(error_msg) from last_exception
@@ -481,7 +482,7 @@ class OpenAICompatibleProvider(ModelProvider):
                completion_params[key] = value

        # Check if this is o3-pro and needs the responses endpoint
        if resolved_model == "o3-pro-2025-06-10":
        if resolved_model == "o3-pro":
            # This model requires the /v1/responses endpoint
            # If it fails, we should not fall back to chat/completions
            return self._generate_with_responses_endpoint(
@@ -497,8 +498,10 @@ class OpenAICompatibleProvider(ModelProvider):
        retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s

        last_exception = None
        actual_attempts = 0

        for attempt in range(max_retries):
            actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
            try:
                # Generate completion
                response = self.client.chat.completions.create(**completion_params)
@@ -536,12 +539,11 @@ class OpenAICompatibleProvider(ModelProvider):

                # Log retry attempt
                logging.warning(
                    f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
                    f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
                )
                time.sleep(delay)

        # If we get here, all retries failed
        actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
        error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
        logging.error(error_msg)
        raise RuntimeError(error_msg) from last_exception
@@ -576,11 +578,7 @@ class OpenAICompatibleProvider(ModelProvider):
        try:
            encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Try common encodings based on model patterns
            if "gpt-4" in model_name or "gpt-3.5" in model_name:
                encoding = tiktoken.get_encoding("cl100k_base")
            else:
                encoding = tiktoken.get_encoding("cl100k_base")  # Default
            encoding = tiktoken.get_encoding("cl100k_base")

        return len(encoding.encode(text))

@@ -679,11 +677,13 @@ class OpenAICompatibleProvider(ModelProvider):
        """
        # Common vision-capable models - only include models that actually support images
        vision_models = {
            "gpt-5",
            "gpt-5-mini",
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-4-vision-preview",
            "gpt-4.1-2025-04-14",  # GPT-4.1 supports vision
            "gpt-4.1-2025-04-14",
            "o3",
            "o3-mini",
            "o3-pro",
@@ -1,7 +1,10 @@
"""OpenAI model provider implementation."""

import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from tools.models import ToolModelCategory

from .base import (
    ModelCapabilities,
@@ -19,6 +22,42 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
        "gpt-5": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="gpt-5",
            friendly_name="OpenAI (GPT-5)",
            context_window=400_000,  # 400K tokens
            max_output_tokens=128_000,  # 128K max output tokens
            supports_extended_thinking=True,  # Supports reasoning tokens
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # GPT-5 supports vision
            max_image_size_mb=20.0,  # 20MB per OpenAI docs
            supports_temperature=True,  # Regular models accept temperature parameter
            temperature_constraint=create_temperature_constraint("fixed"),
            description="GPT-5 (400K context, 128K output) - Advanced model with reasoning support",
            aliases=["gpt5", "gpt-5"],
        ),
        "gpt-5-mini": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="gpt-5-mini",
            friendly_name="OpenAI (GPT-5-mini)",
            context_window=400_000,  # 400K tokens
            max_output_tokens=128_000,  # 128K max output tokens
            supports_extended_thinking=True,  # Supports reasoning tokens
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # GPT-5-mini supports vision
            max_image_size_mb=20.0,  # 20MB per OpenAI docs
            supports_temperature=True,  # Regular models accept temperature parameter
            temperature_constraint=create_temperature_constraint("fixed"),
            description="GPT-5-mini (400K context, 128K output) - Efficient variant with reasoning support",
            aliases=["gpt5-mini", "gpt5mini", "mini"],
        ),
        "o3": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="o3",
@@ -55,9 +94,9 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
            description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
            aliases=["o3mini", "o3-mini"],
        ),
        "o3-pro-2025-06-10": ModelCapabilities(
        "o3-pro": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="o3-pro-2025-06-10",
            model_name="o3-pro",
            friendly_name="OpenAI (O3-Pro)",
            context_window=200_000,  # 200K tokens
            max_output_tokens=65536,  # 64K max output tokens
@@ -89,11 +128,11 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
            supports_temperature=False,  # O4 models don't accept temperature parameter
            temperature_constraint=create_temperature_constraint("fixed"),
            description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
            aliases=["mini", "o4mini", "o4-mini"],
            aliases=["o4mini", "o4-mini"],
        ),
        "gpt-4.1-2025-04-14": ModelCapabilities(
        "gpt-4.1": ModelCapabilities(
            provider=ProviderType.OPENAI,
            model_name="gpt-4.1-2025-04-14",
            model_name="gpt-4.1",
            friendly_name="OpenAI (GPT 4.1)",
            context_window=1_000_000,  # 1M tokens
            max_output_tokens=32_768,
@@ -107,7 +146,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
            supports_temperature=True,  # Regular models accept temperature parameter
            temperature_constraint=create_temperature_constraint("range"),
            description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
            aliases=["gpt4.1"],
            aliases=["gpt4.1", "gpt-4.1"],
        ),
    }

@@ -119,21 +158,41 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific OpenAI model."""
        # Resolve shorthand
        # First check if it's a key in SUPPORTED_MODELS
        if model_name in self.SUPPORTED_MODELS:
            # Check if model is allowed by restrictions
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service()
            if not restriction_service.is_allowed(ProviderType.OPENAI, model_name, model_name):
                raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
            return self.SUPPORTED_MODELS[model_name]

        # Try resolving as alias
        resolved_name = self._resolve_model_name(model_name)

        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported OpenAI model: {model_name}")
        # Check if resolved name is a key
        if resolved_name in self.SUPPORTED_MODELS:
            # Check if model is allowed by restrictions
            from utils.model_restrictions import get_restriction_service

        # Check if model is allowed by restrictions
        from utils.model_restrictions import get_restriction_service
            restriction_service = get_restriction_service()
            if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
                raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
            return self.SUPPORTED_MODELS[resolved_name]

        restriction_service = get_restriction_service()
        if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
            raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
        # Finally check if resolved name matches any API model name
        for key, capabilities in self.SUPPORTED_MODELS.items():
            if resolved_name == capabilities.model_name:
                # Check if model is allowed by restrictions
                from utils.model_restrictions import get_restriction_service

        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
        return self.SUPPORTED_MODELS[resolved_name]
                restriction_service = get_restriction_service()
                if not restriction_service.is_allowed(ProviderType.OPENAI, key, model_name):
                    raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
                return capabilities

        raise ValueError(f"Unsupported OpenAI model: {model_name}")

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
@@ -182,6 +241,47 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode."""
        # Currently no OpenAI models support extended thinking
        # This may change with future O3 models
        # GPT-5 models support reasoning tokens (extended thinking)
        resolved_name = self._resolve_model_name(model_name)
        if resolved_name in ["gpt-5", "gpt-5-mini"]:
            return True
        # O3 models don't support extended thinking yet
        return False

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get OpenAI's preferred model for a given category from allowed models.

        Args:
            category: The tool category requiring a model
            allowed_models: Pre-filtered list of models allowed by restrictions

        Returns:
            Preferred model name or None
        """
        from tools.models import ToolModelCategory

        if not allowed_models:
            return None

        # Helper to find first available from preference list
        def find_first(preferences: list[str]) -> Optional[str]:
            """Return first available model from preference list."""
            for model in preferences:
                if model in allowed_models:
                    return model
            return None

        if category == ToolModelCategory.EXTENDED_REASONING:
            # Prefer models with extended thinking support
            preferred = find_first(["o3", "o3-pro", "gpt-5"])
            return preferred if preferred else allowed_models[0]

        elif category == ToolModelCategory.FAST_RESPONSE:
            # Prefer fast, cost-efficient models
            preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
            return preferred if preferred else allowed_models[0]

        else:  # BALANCED or default
            # Prefer balanced performance/cost models
            preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
            return preferred if preferred else allowed_models[0]
@@ -15,6 +15,17 @@ class ModelProviderRegistry:

    _instance = None

    # Provider priority order for model selection
    # Native APIs first, then custom endpoints, then catch-all providers
    PROVIDER_PRIORITY_ORDER = [
        ProviderType.GOOGLE,  # Direct Gemini access
        ProviderType.OPENAI,  # Direct OpenAI access
        ProviderType.XAI,  # Direct X.AI GROK access
        ProviderType.DIAL,  # DIAL unified API access
        ProviderType.CUSTOM,  # Local/self-hosted models
        ProviderType.OPENROUTER,  # Catch-all for cloud models
    ]

    def __new__(cls):
        """Singleton pattern for registry."""
        if cls._instance is None:
@@ -103,30 +114,19 @@ class ModelProviderRegistry:
        3. OPENROUTER - Catch-all for cloud models via unified API

        Args:
            model_name: Name of the model (e.g., "gemini-2.5-flash", "o3-mini")
            model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5")

        Returns:
            ModelProvider instance that supports this model
        """
        logging.debug(f"get_provider_for_model called with model_name='{model_name}'")

        # Define explicit provider priority order
        # Native APIs first, then custom endpoints, then catch-all providers
        PROVIDER_PRIORITY_ORDER = [
            ProviderType.GOOGLE,  # Direct Gemini access
            ProviderType.OPENAI,  # Direct OpenAI access
            ProviderType.XAI,  # Direct X.AI GROK access
            ProviderType.DIAL,  # DIAL unified API access
            ProviderType.CUSTOM,  # Local/self-hosted models
            ProviderType.OPENROUTER,  # Catch-all for cloud models
        ]

        # Check providers in priority order
        instance = cls()
        logging.debug(f"Registry instance: {instance}")
        logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")

        for provider_type in PROVIDER_PRIORITY_ORDER:
        for provider_type in cls.PROVIDER_PRIORITY_ORDER:
            if provider_type in instance._providers:
                logging.debug(f"Found {provider_type} in registry")
                # Get or create provider instance
@@ -244,14 +244,49 @@ class ModelProviderRegistry:

        return os.getenv(env_var)

    @classmethod
    def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]:
        """Get a list of allowed canonical model names for a given provider.

        Args:
            provider: The provider instance to get models for
            provider_type: The provider type for restriction checking

        Returns:
            List of model names that are both supported and allowed
        """
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service()

        allowed_models = []

        # Get the provider's supported models
        try:
            # Use list_models to get all supported models (handles both regular and custom providers)
            supported_models = provider.list_models(respect_restrictions=False)
        except (NotImplementedError, AttributeError):
            # Fallback to SUPPORTED_MODELS if list_models not implemented
            try:
                supported_models = list(provider.SUPPORTED_MODELS.keys())
            except AttributeError:
                supported_models = []

        # Filter by restrictions
        for model_name in supported_models:
            if restriction_service.is_allowed(provider_type, model_name):
                allowed_models.append(model_name)

        return allowed_models

    @classmethod
    def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
        """Get the preferred fallback model based on available API keys and tool category.
        """Get the preferred fallback model based on provider priority and tool category.

        This method checks which providers have valid API keys and returns
        a sensible default model for auto mode fallback situations.

        Takes into account model restrictions when selecting fallback models.
        This method orchestrates model selection by:
        1. Getting allowed models for each provider (respecting restrictions)
        2. Asking providers for their preference from the allowed list
        3. Falling back to first available model if no preference given

        Args:
            tool_category: Optional category to influence model selection
@@ -259,167 +294,42 @@ class ModelProviderRegistry:

        Returns:
            Model name string for fallback use
        """
        # Import here to avoid circular import
        from tools.models import ToolModelCategory

        # Get available models respecting restrictions
        available_models = cls.get_available_models(respect_restrictions=True)
        effective_category = tool_category or ToolModelCategory.BALANCED
        first_available_model = None

        # Group by provider
        openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI]
        gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
        xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI]
        openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER]
        custom_models = [m for m, p in available_models.items() if p == ProviderType.CUSTOM]
        # Ask each provider for their preference in priority order
        for provider_type in cls.PROVIDER_PRIORITY_ORDER:
            provider = cls.get_provider(provider_type)
            if provider:
                # 1. Registry filters the models first
                allowed_models = cls._get_allowed_models_for_provider(provider, provider_type)

        openai_available = bool(openai_models)
        gemini_available = bool(gemini_models)
        xai_available = bool(xai_models)
        openrouter_available = bool(openrouter_models)
        custom_available = bool(custom_models)

        if tool_category == ToolModelCategory.EXTENDED_REASONING:
            # Prefer thinking-capable models for deep reasoning tools
            if openai_available and "o3" in openai_models:
                return "o3"  # O3 for deep reasoning
            elif openai_available and openai_models:
                # Fall back to any available OpenAI model
                return openai_models[0]
            elif xai_available and "grok-3" in xai_models:
                return "grok-3"  # GROK-3 for deep reasoning
            elif xai_available and xai_models:
                # Fall back to any available XAI model
                return xai_models[0]
            elif gemini_available and any("pro" in m for m in gemini_models):
                # Find the pro model (handles full names)
                return next(m for m in gemini_models if "pro" in m)
            elif gemini_available and gemini_models:
                # Fall back to any available Gemini model
                return gemini_models[0]
            elif openrouter_available:
                # Try to find thinking-capable model from openrouter
                thinking_model = cls._find_extended_thinking_model()
                if thinking_model:
                    return thinking_model
                # Fallback to first available OpenRouter model
                return openrouter_models[0]
            elif custom_available:
                # Fallback to custom models when available
                return custom_models[0]
            else:
                # Fallback to pro if nothing found
                return "gemini-2.5-pro"

        elif tool_category == ToolModelCategory.FAST_RESPONSE:
            # Prefer fast, cost-efficient models
            if openai_available and "o4-mini" in openai_models:
                return "o4-mini"  # Latest, fast and efficient
            elif openai_available and "o3-mini" in openai_models:
                return "o3-mini"  # Second choice
            elif openai_available and openai_models:
                # Fall back to any available OpenAI model
                return openai_models[0]
            elif xai_available and "grok-3-fast" in xai_models:
                return "grok-3-fast"  # GROK-3 Fast for speed
            elif xai_available and xai_models:
                # Fall back to any available XAI model
                return xai_models[0]
            elif gemini_available and any("flash" in m for m in gemini_models):
                # Find the flash model (handles full names)
                # Prefer 2.5 over 2.0 for backward compatibility
                flash_models = [m for m in gemini_models if "flash" in m]
                # Sort to ensure 2.5 comes before 2.0
                flash_models_sorted = sorted(flash_models, reverse=True)
                return flash_models_sorted[0]
            elif gemini_available and gemini_models:
                # Fall back to any available Gemini model
                return gemini_models[0]
            elif openrouter_available:
                # Fallback to first available OpenRouter model
                return openrouter_models[0]
            elif custom_available:
                # Fallback to custom models when available
                return custom_models[0]
            else:
                # Default to flash
                return "gemini-2.5-flash"

        # BALANCED or no category specified - use existing balanced logic
        if openai_available and "o4-mini" in openai_models:
            return "o4-mini"  # Latest balanced performance/cost
        elif openai_available and "o3-mini" in openai_models:
            return "o3-mini"  # Second choice
        elif openai_available and openai_models:
            return openai_models[0]
        elif xai_available and "grok-3" in xai_models:
            return "grok-3"  # GROK-3 as balanced choice
        elif xai_available and xai_models:
            return xai_models[0]
        elif gemini_available and any("flash" in m for m in gemini_models):
            # Prefer 2.5 over 2.0 for backward compatibility
            flash_models = [m for m in gemini_models if "flash" in m]
            flash_models_sorted = sorted(flash_models, reverse=True)
            return flash_models_sorted[0]
        elif gemini_available and gemini_models:
            return gemini_models[0]
        elif openrouter_available:
            return openrouter_models[0]
        elif custom_available:
            # Fallback to custom models when available
            return custom_models[0]
        else:
            # No models available due to restrictions - check if any providers exist
            if not available_models:
                # This might happen if all models are restricted
                logging.warning("No models available due to restrictions")
            # Return a reasonable default for backward compatibility
            return "gemini-2.5-flash"

    @classmethod
    def _find_extended_thinking_model(cls) -> Optional[str]:
        """Find a model suitable for extended reasoning from custom/openrouter providers.

        Returns:
            Model name if found, None otherwise
        """
        # Check custom provider first
        custom_provider = cls.get_provider(ProviderType.CUSTOM)
        if custom_provider:
            # Check if it's a CustomModelProvider and has thinking models
            try:
                from providers.custom import CustomProvider

                if isinstance(custom_provider, CustomProvider) and hasattr(custom_provider, "model_registry"):
                    for model_name, config in custom_provider.model_registry.items():
                        if config.get("supports_extended_thinking", False):
                            return model_name
            except ImportError:
                pass

        # Then check OpenRouter for high-context/powerful models
        openrouter_provider = cls.get_provider(ProviderType.OPENROUTER)
        if openrouter_provider:
            # Prefer models known for deep reasoning
            preferred_models = [
                "anthropic/claude-sonnet-4",
                "anthropic/claude-opus-4",
                "google/gemini-2.5-pro",
                "google/gemini-pro-1.5",
                "meta-llama/llama-3.1-70b-instruct",
                "mistralai/mixtral-8x7b-instruct",
            ]
            for model in preferred_models:
                try:
                    if openrouter_provider.validate_model_name(model):
                        return model
                except Exception as e:
                    # Log the error for debugging purposes but continue searching
                    import logging

                    logging.warning(f"Model validation for '{model}' on OpenRouter failed: {e}")
                if not allowed_models:
                    continue

        return None
                # 2. Keep track of the first available model as fallback
                if not first_available_model:
                    first_available_model = sorted(allowed_models)[0]

                # 3. Ask provider to pick from allowed list
                preferred_model = provider.get_preferred_model(effective_category, allowed_models)

                if preferred_model:
                    logging.debug(
                        f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'"
                    )
                    return preferred_model

        # If no provider returned a preference, use first available model
        if first_available_model:
            logging.debug(f"No provider preference, using first available: {first_available_model}")
            return first_available_model

        # Ultimate fallback if no providers have models
        logging.warning("No models available from any provider, using default fallback")
        return "gemini-2.5-flash"

    @classmethod
    def get_available_providers_with_keys(cls) -> list[ProviderType]:
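The reworked docstring above describes the new orchestration; the call site itself does not change. A hedged usage sketch follows: the providers.registry module path is an assumption, while the class, method, and category names come from this diff, and the commented result assumes the corresponding providers are configured and unrestricted.

from providers.registry import ModelProviderRegistry  # module path assumed
from tools.models import ToolModelCategory

# The registry walks PROVIDER_PRIORITY_ORDER, filters each provider's models through the
# restriction policy, asks the provider via get_preferred_model(), and falls back to the
# first allowed model, then to "gemini-2.5-flash" if no provider has anything to offer.
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# e.g. "o3" when the OpenAI provider is registered and unrestricted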
@@ -1,7 +1,10 @@
"""X.AI (GROK) model provider implementation."""

import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from tools.models import ToolModelCategory

from .base import (
    ModelCapabilities,
@@ -133,3 +136,41 @@ class XAIModelProvider(OpenAICompatibleProvider):
        # Currently GROK models do not support extended thinking
        # This may change with future GROK model releases
        return False

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get XAI's preferred model for a given category from allowed models.

        Args:
            category: The tool category requiring a model
            allowed_models: Pre-filtered list of models allowed by restrictions

        Returns:
            Preferred model name or None
        """
        from tools.models import ToolModelCategory

        if not allowed_models:
            return None

        if category == ToolModelCategory.EXTENDED_REASONING:
            # Prefer GROK-3 for reasoning
            if "grok-3" in allowed_models:
                return "grok-3"
            # Fall back to any available model
            return allowed_models[0]

        elif category == ToolModelCategory.FAST_RESPONSE:
            # Prefer GROK-3-Fast for speed
            if "grok-3-fast" in allowed_models:
                return "grok-3-fast"
            # Fall back to any available model
            return allowed_models[0]

        else:  # BALANCED or default
            # Prefer standard GROK-3 for balanced use
            if "grok-3" in allowed_models:
                return "grok-3"
            elif "grok-3-fast" in allowed_models:
                return "grok-3-fast"
            # Fall back to any available model
            return allowed_models[0]