Merge branch 'main' into refactor-image-validation

@@ -7,7 +7,10 @@ import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
from tools.models import ToolModelCategory

from utils.file_types import IMAGES, get_image_mime_type

@@ -123,10 +126,10 @@ def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint
return FixedTemperatureConstraint(1.0)
elif constraint_type == "discrete":
# For models with specific allowed values - using common OpenAI values as default
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.7)
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
else:
# Default range constraint (for "range" or None)
return RangeTemperatureConstraint(0.0, 2.0, 0.7)
return RangeTemperatureConstraint(0.0, 2.0, 0.3)


@dataclass

@@ -159,24 +162,11 @@ class ModelCapabilities:
# Custom model flag (for models that only work with custom endpoints)
is_custom: bool = False # Whether this model requires custom API endpoints

# Temperature constraint object - preferred way to define temperature limits
# Temperature constraint object - defines temperature limits and behavior
temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
)

# Backward compatibility property for existing code
@property
def temperature_range(self) -> tuple[float, float]:
"""Backward compatibility for existing code that uses temperature_range."""
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
elif isinstance(self.temperature_constraint, FixedTemperatureConstraint):
return (self.temperature_constraint.value, self.temperature_constraint.value)
elif isinstance(self.temperature_constraint, DiscreteTemperatureConstraint):
values = self.temperature_constraint.allowed_values
return (min(values), max(values))
return (0.0, 2.0) # Fallback


@dataclass
class ModelResponse:

@@ -220,7 +210,7 @@ class ModelProvider(ABC):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:

@@ -276,18 +266,15 @@ class ModelProvider(ABC):
if not capabilities.supports_temperature:
return None

# Get temperature range
min_temp, max_temp = capabilities.temperature_range
# Use temperature constraint to get corrected value
corrected_temp = capabilities.temperature_constraint.get_corrected_value(requested_temperature)

# Clamp to valid range
if requested_temperature < min_temp:
logger.debug(f"Clamping temperature from {requested_temperature} to {min_temp} for model {model_name}")
return min_temp
elif requested_temperature > max_temp:
logger.debug(f"Clamping temperature from {requested_temperature} to {max_temp} for model {model_name}")
return max_temp
else:
return requested_temperature
if corrected_temp != requested_temperature:
logger.debug(
f"Adjusting temperature from {requested_temperature} to {corrected_temp} for model {model_name}"
)

return corrected_temp

except Exception as e:
logger.debug(f"Could not determine effective temperature for {model_name}: {e}")

@@ -302,10 +289,10 @@ class ModelProvider(ABC):
"""
capabilities = self.get_capabilities(model_name)

# Validate temperature
min_temp, max_temp = capabilities.temperature_range
if not min_temp <= temperature <= max_temp:
raise ValueError(f"Temperature {temperature} out of range [{min_temp}, {max_temp}] for model {model_name}")
# Validate temperature using constraint
if not capabilities.temperature_constraint.validate(temperature):
constraint_desc = capabilities.temperature_constraint.get_description()
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")

@abstractmethod
def supports_thinking_mode(self, model_name: str) -> bool:

@@ -520,3 +507,28 @@ class ModelProvider(ABC):
"""
# Base implementation: no resources to clean up
return

def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get the preferred model from this provider for a given category.

Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of model names that are allowed by restrictions

Returns:
Model name if this provider has a preference, None otherwise
"""
# Default implementation - providers can override with specific logic
return None

def get_model_registry(self) -> Optional[dict[str, Any]]:
"""Get the model registry for providers that maintain one.

This is a hook method for providers like CustomProvider that maintain
a dynamic model registry.

Returns:
Model registry dict or None if not applicable
"""
# Default implementation - most providers don't have a registry
return None

@@ -236,7 +236,7 @@ class CustomProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:

@@ -375,7 +375,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
images: Optional[list[str]] = None,
**kwargs,

@@ -3,7 +3,10 @@
import base64
import logging
import time
from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from tools.models import ToolModelCategory

from google import genai
from google.genai import types

@@ -18,6 +21,25 @@ class GeminiModelProvider(ModelProvider):

# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {
"gemini-2.5-pro": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.5-pro",
friendly_name="Gemini (Pro 2.5)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # Vision capability
max_image_size_mb=32.0, # Higher limit for Pro model
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
max_thinking_tokens=32768, # Max thinking tokens for Pro model
description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
aliases=["pro", "gemini pro", "gemini-pro"],
),
"gemini-2.0-flash": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.0-flash",

@@ -74,25 +96,6 @@ class GeminiModelProvider(ModelProvider):
description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
aliases=["flash", "flash2.5"],
),
"gemini-2.5-pro": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.5-pro",
friendly_name="Gemini (Pro 2.5)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # Vision capability
max_image_size_mb=32.0, # Higher limit for Pro model
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
max_thinking_tokens=32768, # Max thinking tokens for Pro model
description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
aliases=["pro", "gemini pro", "gemini-pro"],
),
}

# Thinking mode configurations - percentages of model's max_thinking_tokens

@@ -151,7 +154,7 @@ class GeminiModelProvider(ModelProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
thinking_mode: str = "medium",
images: Optional[list[str]] = None,

@@ -458,3 +461,67 @@ class GeminiModelProvider(ModelProvider):
except Exception as e:
logger.error(f"Error processing image {image_path}: {e}")
return None

def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get Gemini's preferred model for a given category from allowed models.

Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions

Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory

if not allowed_models:
return None

# Helper to find best model from candidates
def find_best(candidates: list[str]) -> Optional[str]:
"""Return best model from candidates (sorted for consistency)."""
return sorted(candidates, reverse=True)[0] if candidates else None

if category == ToolModelCategory.EXTENDED_REASONING:
# For extended reasoning, prefer models with thinking support
# First try Pro models that support thinking
pro_thinking = [
m
for m in allowed_models
if "pro" in m and m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
]
if pro_thinking:
return find_best(pro_thinking)

# Then any model that supports thinking
any_thinking = [
m
for m in allowed_models
if m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
]
if any_thinking:
return find_best(any_thinking)

# Finally, just prefer Pro models even without thinking
pro_models = [m for m in allowed_models if "pro" in m]
if pro_models:
return find_best(pro_models)

elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer Flash models for speed
flash_models = [m for m in allowed_models if "flash" in m]
if flash_models:
return find_best(flash_models)

# Default for BALANCED or as fallback
# Prefer Flash for balanced use, then Pro, then anything
flash_models = [m for m in allowed_models if "flash" in m]
if flash_models:
return find_best(flash_models)

pro_models = [m for m in allowed_models if "pro" in m]
if pro_models:
return find_best(pro_models)

# Ultimate fallback to best available model
return find_best(allowed_models)

@@ -1,5 +1,7 @@
"""Base class for OpenAI-compatible API providers."""

import base64
import copy
import ipaddress
import logging
import os

@@ -219,10 +221,20 @@ class OpenAICompatibleProvider(ModelProvider):

# Create httpx client with minimal config to avoid proxy conflicts
# Note: proxies parameter was removed in httpx 0.28.0
http_client = httpx.Client(
timeout=timeout_config,
follow_redirects=True,
)
# Check for test transport injection
if hasattr(self, "_test_transport"):
# Use custom transport for testing (HTTP recording/replay)
http_client = httpx.Client(
transport=self._test_transport,
timeout=timeout_config,
follow_redirects=True,
)
else:
# Normal production client
http_client = httpx.Client(
timeout=timeout_config,
follow_redirects=True,
)

# Keep client initialization minimal to avoid proxy parameter conflicts
client_kwargs = {

@@ -263,6 +275,63 @@ class OpenAICompatibleProvider(ModelProvider):

return self._client

def _sanitize_for_logging(self, params: dict) -> dict:
"""Sanitize sensitive data from parameters before logging.

Args:
params: Dictionary of API parameters

Returns:
dict: Sanitized copy of parameters safe for logging
"""
sanitized = copy.deepcopy(params)

# Sanitize messages content
if "input" in sanitized:
for msg in sanitized.get("input", []):
if isinstance(msg, dict) and "content" in msg:
for content_item in msg.get("content", []):
if isinstance(content_item, dict) and "text" in content_item:
# Truncate long text and add ellipsis
text = content_item["text"]
if len(text) > 100:
content_item["text"] = text[:100] + "... [truncated]"

# Remove any API keys that might be in headers/auth
sanitized.pop("api_key", None)
sanitized.pop("authorization", None)

return sanitized

def _safe_extract_output_text(self, response) -> str:
"""Safely extract output_text from o3-pro response with validation.

Args:
response: Response object from OpenAI SDK

Returns:
str: The output text content

Raises:
ValueError: If output_text is missing, None, or not a string
"""
logging.debug(f"Response object type: {type(response)}")
logging.debug(f"Response attributes: {dir(response)}")

if not hasattr(response, "output_text"):
raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")

content = response.output_text
logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")

if content is None:
raise ValueError("o3-pro returned None for output_text")

if not isinstance(content, str):
raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")

return content

def _generate_with_responses_endpoint(
self,
model_name: str,

@@ -308,30 +377,23 @@ class OpenAICompatibleProvider(ModelProvider):
max_retries = 4
retry_delays = [1, 3, 5, 8]
last_exception = None
actual_attempts = 0

for attempt in range(max_retries):
try: # Log the exact payload being sent for debugging
try: # Log sanitized payload for debugging
import json

sanitized_params = self._sanitize_for_logging(completion_params)
logging.info(
f"o3-pro API request payload: {json.dumps(completion_params, indent=2, ensure_ascii=False)}"
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
)

# Use OpenAI client's responses endpoint
response = self.client.responses.create(**completion_params)

# Extract content and usage from responses endpoint format
# The response format is different for responses endpoint
content = ""
if hasattr(response, "output") and response.output:
if hasattr(response.output, "content") and response.output.content:
# Look for output_text in content
for content_item in response.output.content:
if hasattr(content_item, "type") and content_item.type == "output_text":
content = content_item.text
break
elif hasattr(response.output, "text"):
content = response.output.text
# Extract content from responses endpoint format
# Use validation helper to safely extract output_text
content = self._safe_extract_output_text(response)

# Try to extract usage information
usage = None

@@ -370,14 +432,13 @@ class OpenAICompatibleProvider(ModelProvider):
if is_retryable and attempt < max_retries - 1:
delay = retry_delays[attempt]
logging.warning(
f"Retryable error for o3-pro responses endpoint, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
else:
break

# If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception

@@ -387,7 +448,7 @@ class OpenAICompatibleProvider(ModelProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
images: Optional[list[str]] = None,
**kwargs,

@@ -480,7 +541,7 @@ class OpenAICompatibleProvider(ModelProvider):
completion_params[key] = value

# Check if this is o3-pro and needs the responses endpoint
if resolved_model == "o3-pro-2025-06-10":
if resolved_model == "o3-pro":
# This model requires the /v1/responses endpoint
# If it fails, we should not fall back to chat/completions
return self._generate_with_responses_endpoint(

@@ -496,8 +557,10 @@ class OpenAICompatibleProvider(ModelProvider):
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s

last_exception = None
actual_attempts = 0

for attempt in range(max_retries):
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)

@@ -535,12 +598,11 @@ class OpenAICompatibleProvider(ModelProvider):

# Log retry attempt
logging.warning(
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)

# If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception

@@ -575,11 +637,7 @@ class OpenAICompatibleProvider(ModelProvider):
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
# Try common encodings based on model patterns
if "gpt-4" in model_name or "gpt-3.5" in model_name:
encoding = tiktoken.get_encoding("cl100k_base")
else:
encoding = tiktoken.get_encoding("cl100k_base") # Default
encoding = tiktoken.get_encoding("cl100k_base")

return len(encoding.encode(text))

@@ -678,11 +736,13 @@ class OpenAICompatibleProvider(ModelProvider):
"""
# Common vision-capable models - only include models that actually support images
vision_models = {
"gpt-5",
"gpt-5-mini",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4-vision-preview",
"gpt-4.1-2025-04-14", # GPT-4.1 supports vision
"gpt-4.1-2025-04-14",
"o3",
"o3-mini",
"o3-pro",

@@ -1,7 +1,10 @@
"""OpenAI model provider implementation."""

import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from tools.models import ToolModelCategory

from .base import (
ModelCapabilities,

@@ -19,6 +22,60 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {
"gpt-5": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5",
friendly_name="OpenAI (GPT-5)",
context_window=400_000, # 400K tokens
max_output_tokens=128_000, # 128K max output tokens
supports_extended_thinking=True, # Supports reasoning tokens
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # GPT-5 supports vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=True, # Regular models accept temperature parameter
temperature_constraint=create_temperature_constraint("fixed"),
description="GPT-5 (400K context, 128K output) - Advanced model with reasoning support",
aliases=["gpt5", "gpt-5"],
),
"gpt-5-mini": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5-mini",
friendly_name="OpenAI (GPT-5-mini)",
context_window=400_000, # 400K tokens
max_output_tokens=128_000, # 128K max output tokens
supports_extended_thinking=True, # Supports reasoning tokens
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # GPT-5-mini supports vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=True,
temperature_constraint=create_temperature_constraint("fixed"),
description="GPT-5-mini (400K context, 128K output) - Efficient variant with reasoning support",
aliases=["gpt5-mini", "gpt5mini", "mini"],
),
"gpt-5-nano": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5-nano",
friendly_name="OpenAI (GPT-5 nano)",
context_window=400_000,
max_output_tokens=128_000,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("fixed"),
description="GPT-5 nano (400K context) - Fastest, cheapest version of GPT-5 for summarization and classification tasks",
aliases=["gpt5nano", "gpt5-nano", "nano"],
),
"o3": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="o3",

@@ -55,9 +112,9 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
aliases=["o3mini", "o3-mini"],
),
"o3-pro-2025-06-10": ModelCapabilities(
"o3-pro": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="o3-pro-2025-06-10",
model_name="o3-pro",
friendly_name="OpenAI (O3-Pro)",
context_window=200_000, # 200K tokens
max_output_tokens=65536, # 64K max output tokens

@@ -89,11 +146,11 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
supports_temperature=False, # O4 models don't accept temperature parameter
temperature_constraint=create_temperature_constraint("fixed"),
description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
aliases=["mini", "o4mini", "o4-mini"],
aliases=["o4mini", "o4-mini"],
),
"gpt-4.1-2025-04-14": ModelCapabilities(
"gpt-4.1": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-4.1-2025-04-14",
model_name="gpt-4.1",
friendly_name="OpenAI (GPT 4.1)",
context_window=1_000_000, # 1M tokens
max_output_tokens=32_768,

@@ -107,7 +164,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
supports_temperature=True, # Regular models accept temperature parameter
temperature_constraint=create_temperature_constraint("range"),
description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
aliases=["gpt4.1"],
aliases=["gpt4.1", "gpt-4.1"],
),
}

@@ -119,21 +176,41 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific OpenAI model."""
# Resolve shorthand
# First check if it's a key in SUPPORTED_MODELS
if model_name in self.SUPPORTED_MODELS:
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service

restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, model_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return self.SUPPORTED_MODELS[model_name]

# Try resolving as alias
resolved_name = self._resolve_model_name(model_name)

if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported OpenAI model: {model_name}")
# Check if resolved name is a key
if resolved_name in self.SUPPORTED_MODELS:
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service

# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return self.SUPPORTED_MODELS[resolved_name]

restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
# Finally check if resolved name matches any API model name
for key, capabilities in self.SUPPORTED_MODELS.items():
if resolved_name == capabilities.model_name:
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service

# Return the ModelCapabilities object directly from SUPPORTED_MODELS
return self.SUPPORTED_MODELS[resolved_name]
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, key, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return capabilities

raise ValueError(f"Unsupported OpenAI model: {model_name}")

def get_provider_type(self) -> ProviderType:
"""Get the provider type."""

@@ -162,7 +239,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:

@@ -182,6 +259,47 @@ class OpenAIModelProvider(OpenAICompatibleProvider):

def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
# Currently no OpenAI models support extended thinking
# This may change with future O3 models
# GPT-5 models support reasoning tokens (extended thinking)
resolved_name = self._resolve_model_name(model_name)
if resolved_name in ["gpt-5", "gpt-5-mini"]:
return True
# O3 models don't support extended thinking yet
return False

def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get OpenAI's preferred model for a given category from allowed models.

Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions

Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory

if not allowed_models:
return None

# Helper to find first available from preference list
def find_first(preferences: list[str]) -> Optional[str]:
"""Return first available model from preference list."""
for model in preferences:
if model in allowed_models:
return model
return None

if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer models with extended thinking support
preferred = find_first(["o3", "o3-pro", "gpt-5"])
return preferred if preferred else allowed_models[0]

elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer fast, cost-efficient models
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
return preferred if preferred else allowed_models[0]

else: # BALANCED or default
# Prefer balanced performance/cost models
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
return preferred if preferred else allowed_models[0]

@@ -158,7 +158,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:

@@ -15,6 +15,17 @@ class ModelProviderRegistry:

_instance = None

# Provider priority order for model selection
# Native APIs first, then custom endpoints, then catch-all providers
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.XAI, # Direct X.AI GROK access
ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]

def __new__(cls):
"""Singleton pattern for registry."""
if cls._instance is None:

@@ -103,30 +114,19 @@ class ModelProviderRegistry:
3. OPENROUTER - Catch-all for cloud models via unified API

Args:
model_name: Name of the model (e.g., "gemini-2.5-flash", "o3-mini")
model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5")

Returns:
ModelProvider instance that supports this model
"""
logging.debug(f"get_provider_for_model called with model_name='{model_name}'")

# Define explicit provider priority order
# Native APIs first, then custom endpoints, then catch-all providers
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.XAI, # Direct X.AI GROK access
ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]

# Check providers in priority order
instance = cls()
logging.debug(f"Registry instance: {instance}")
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")

for provider_type in PROVIDER_PRIORITY_ORDER:
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
if provider_type in instance._providers:
logging.debug(f"Found {provider_type} in registry")
# Get or create provider instance

@@ -244,14 +244,49 @@ class ModelProviderRegistry:

return os.getenv(env_var)

@classmethod
def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]:
"""Get a list of allowed canonical model names for a given provider.

Args:
provider: The provider instance to get models for
provider_type: The provider type for restriction checking

Returns:
List of model names that are both supported and allowed
"""
from utils.model_restrictions import get_restriction_service

restriction_service = get_restriction_service()

allowed_models = []

# Get the provider's supported models
try:
# Use list_models to get all supported models (handles both regular and custom providers)
supported_models = provider.list_models(respect_restrictions=False)
except (NotImplementedError, AttributeError):
# Fallback to SUPPORTED_MODELS if list_models not implemented
try:
supported_models = list(provider.SUPPORTED_MODELS.keys())
except AttributeError:
supported_models = []

# Filter by restrictions
for model_name in supported_models:
if restriction_service.is_allowed(provider_type, model_name):
allowed_models.append(model_name)

return allowed_models

@classmethod
def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
"""Get the preferred fallback model based on available API keys and tool category.
"""Get the preferred fallback model based on provider priority and tool category.

This method checks which providers have valid API keys and returns
a sensible default model for auto mode fallback situations.

Takes into account model restrictions when selecting fallback models.
This method orchestrates model selection by:
1. Getting allowed models for each provider (respecting restrictions)
2. Asking providers for their preference from the allowed list
3. Falling back to first available model if no preference given

Args:
tool_category: Optional category to influence model selection

@@ -259,167 +294,42 @@ class ModelProviderRegistry:
Returns:
Model name string for fallback use
"""
# Import here to avoid circular import
from tools.models import ToolModelCategory

# Get available models respecting restrictions
available_models = cls.get_available_models(respect_restrictions=True)
effective_category = tool_category or ToolModelCategory.BALANCED
first_available_model = None

# Group by provider
openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI]
gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI]
openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER]
custom_models = [m for m, p in available_models.items() if p == ProviderType.CUSTOM]
# Ask each provider for their preference in priority order
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
provider = cls.get_provider(provider_type)
if provider:
# 1. Registry filters the models first
allowed_models = cls._get_allowed_models_for_provider(provider, provider_type)

openai_available = bool(openai_models)
gemini_available = bool(gemini_models)
xai_available = bool(xai_models)
openrouter_available = bool(openrouter_models)
custom_available = bool(custom_models)

if tool_category == ToolModelCategory.EXTENDED_REASONING:
# Prefer thinking-capable models for deep reasoning tools
if openai_available and "o3" in openai_models:
return "o3" # O3 for deep reasoning
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
elif xai_available and "grok-3" in xai_models:
return "grok-3" # GROK-3 for deep reasoning
elif xai_available and xai_models:
# Fall back to any available XAI model
return xai_models[0]
elif gemini_available and any("pro" in m for m in gemini_models):
# Find the pro model (handles full names)
return next(m for m in gemini_models if "pro" in m)
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
elif openrouter_available:
# Try to find thinking-capable model from openrouter
thinking_model = cls._find_extended_thinking_model()
if thinking_model:
return thinking_model
# Fallback to first available OpenRouter model
return openrouter_models[0]
elif custom_available:
# Fallback to custom models when available
return custom_models[0]
else:
# Fallback to pro if nothing found
return "gemini-2.5-pro"

elif tool_category == ToolModelCategory.FAST_RESPONSE:
# Prefer fast, cost-efficient models
if openai_available and "o4-mini" in openai_models:
return "o4-mini" # Latest, fast and efficient
elif openai_available and "o3-mini" in openai_models:
return "o3-mini" # Second choice
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
elif xai_available and "grok-3-fast" in xai_models:
return "grok-3-fast" # GROK-3 Fast for speed
elif xai_available and xai_models:
# Fall back to any available XAI model
return xai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
# Find the flash model (handles full names)
# Prefer 2.5 over 2.0 for backward compatibility
flash_models = [m for m in gemini_models if "flash" in m]
# Sort to ensure 2.5 comes before 2.0
flash_models_sorted = sorted(flash_models, reverse=True)
return flash_models_sorted[0]
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
elif openrouter_available:
# Fallback to first available OpenRouter model
return openrouter_models[0]
elif custom_available:
# Fallback to custom models when available
return custom_models[0]
else:
# Default to flash
return "gemini-2.5-flash"

# BALANCED or no category specified - use existing balanced logic
if openai_available and "o4-mini" in openai_models:
return "o4-mini" # Latest balanced performance/cost
elif openai_available and "o3-mini" in openai_models:
return "o3-mini" # Second choice
elif openai_available and openai_models:
return openai_models[0]
elif xai_available and "grok-3" in xai_models:
return "grok-3" # GROK-3 as balanced choice
elif xai_available and xai_models:
return xai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
# Prefer 2.5 over 2.0 for backward compatibility
flash_models = [m for m in gemini_models if "flash" in m]
flash_models_sorted = sorted(flash_models, reverse=True)
return flash_models_sorted[0]
elif gemini_available and gemini_models:
return gemini_models[0]
elif openrouter_available:
return openrouter_models[0]
elif custom_available:
# Fallback to custom models when available
return custom_models[0]
else:
# No models available due to restrictions - check if any providers exist
if not available_models:
# This might happen if all models are restricted
logging.warning("No models available due to restrictions")
# Return a reasonable default for backward compatibility
return "gemini-2.5-flash"

@classmethod
def _find_extended_thinking_model(cls) -> Optional[str]:
"""Find a model suitable for extended reasoning from custom/openrouter providers.

Returns:
Model name if found, None otherwise
"""
# Check custom provider first
custom_provider = cls.get_provider(ProviderType.CUSTOM)
if custom_provider:
# Check if it's a CustomModelProvider and has thinking models
try:
from providers.custom import CustomProvider

if isinstance(custom_provider, CustomProvider) and hasattr(custom_provider, "model_registry"):
for model_name, config in custom_provider.model_registry.items():
if config.get("supports_extended_thinking", False):
return model_name
except ImportError:
pass

# Then check OpenRouter for high-context/powerful models
openrouter_provider = cls.get_provider(ProviderType.OPENROUTER)
if openrouter_provider:
# Prefer models known for deep reasoning
preferred_models = [
"anthropic/claude-sonnet-4",
"anthropic/claude-opus-4",
"google/gemini-2.5-pro",
"google/gemini-pro-1.5",
"meta-llama/llama-3.1-70b-instruct",
"mistralai/mixtral-8x7b-instruct",
]
for model in preferred_models:
try:
if openrouter_provider.validate_model_name(model):
return model
except Exception as e:
# Log the error for debugging purposes but continue searching
import logging

logging.warning(f"Model validation for '{model}' on OpenRouter failed: {e}")
if not allowed_models:
continue

return None
# 2. Keep track of the first available model as fallback
if not first_available_model:
first_available_model = sorted(allowed_models)[0]

# 3. Ask provider to pick from allowed list
preferred_model = provider.get_preferred_model(effective_category, allowed_models)

if preferred_model:
logging.debug(
f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'"
)
return preferred_model

# If no provider returned a preference, use first available model
if first_available_model:
logging.debug(f"No provider preference, using first available: {first_available_model}")
return first_available_model

# Ultimate fallback if no providers have models
logging.warning("No models available from any provider, using default fallback")
return "gemini-2.5-flash"

@classmethod
def get_available_providers_with_keys(cls) -> list[ProviderType]:

@@ -441,6 +351,17 @@ class ModelProviderRegistry:
instance = cls()
instance._initialized_providers.clear()

@classmethod
def reset_for_testing(cls) -> None:
"""Reset the registry to a clean state for testing.

This provides a safe, public API for tests to clean up registry state
without directly manipulating private attributes.
"""
cls._instance = None
if hasattr(cls, "_providers"):
cls._providers = {}

@classmethod
def unregister_provider(cls, provider_type: ProviderType) -> None:
"""Unregister a provider (mainly for testing)."""

@@ -1,7 +1,10 @@
"""X.AI (GROK) model provider implementation."""

import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from tools.models import ToolModelCategory

from .base import (
ModelCapabilities,

@@ -21,6 +24,24 @@ class XAIModelProvider(OpenAICompatibleProvider):

# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {
"grok-4": ModelCapabilities(
provider=ProviderType.XAI,
model_name="grok-4",
friendly_name="X.AI (Grok 4)",
context_window=256_000, # 256K tokens
max_output_tokens=256_000, # 256K tokens max output
supports_extended_thinking=True, # Grok-4 supports reasoning mode
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True, # Function calling supported
supports_json_mode=True, # Structured outputs supported
supports_images=True, # Multimodal capabilities
max_image_size_mb=20.0, # Standard image size limit
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="GROK-4 (256K context) - Frontier multimodal reasoning model with advanced capabilities",
aliases=["grok", "grok4", "grok-4"],
),
"grok-3": ModelCapabilities(
provider=ProviderType.XAI,
model_name="grok-3",

@@ -37,7 +58,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
aliases=["grok", "grok3"],
aliases=["grok3"],
),
"grok-3-fast": ModelCapabilities(
provider=ProviderType.XAI,

@@ -110,7 +131,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:

@@ -130,6 +151,52 @@ class XAIModelProvider(OpenAICompatibleProvider):

def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
# Currently GROK models do not support extended thinking
# This may change with future GROK model releases
resolved_name = self._resolve_model_name(model_name)
capabilities = self.SUPPORTED_MODELS.get(resolved_name)
if capabilities:
return capabilities.supports_extended_thinking
return False

def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get XAI's preferred model for a given category from allowed models.

Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions

Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory

if not allowed_models:
return None

if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer GROK-4 for advanced reasoning with thinking mode
if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3"
# Fall back to any available model
return allowed_models[0]

elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer GROK-3-Fast for speed, then GROK-4
if "grok-3-fast" in allowed_models:
return "grok-3-fast"
elif "grok-4" in allowed_models:
return "grok-4"
# Fall back to any available model
return allowed_models[0]

else: # BALANCED or default
# Prefer GROK-4 for balanced use (best overall capabilities)
if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3"
elif "grok-3-fast" in allowed_models:
return "grok-3-fast"
# Fall back to any available model
return allowed_models[0]