Merge branch 'main' into refactor-image-validation

Authored by Beehive Innovations on 2025-08-07 23:12:00 -07:00, committed via GitHub
55 changed files with 2491 additions and 623 deletions

View File

@@ -7,7 +7,10 @@ import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from utils.file_types import IMAGES, get_image_mime_type
@@ -123,10 +126,10 @@ def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint
return FixedTemperatureConstraint(1.0)
elif constraint_type == "discrete":
# For models with specific allowed values - using common OpenAI values as default
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.7)
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
else:
# Default range constraint (for "range" or None)
return RangeTemperatureConstraint(0.0, 2.0, 0.7)
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
@dataclass
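For context, a minimal sketch of what the three constraint kinds returned by create_temperature_constraint look like in use. The providers.base import path and the snap-to-nearest behavior of the discrete constraint are assumptions; validate() and get_corrected_value() are the methods the rest of this diff relies on.

from providers.base import create_temperature_constraint  # import path assumed

range_c = create_temperature_constraint("range")        # RangeTemperatureConstraint(0.0, 2.0, default 0.3)
fixed_c = create_temperature_constraint("fixed")        # FixedTemperatureConstraint(1.0)
discrete_c = create_temperature_constraint("discrete")  # DiscreteTemperatureConstraint([...], default 0.3)

print(range_c.validate(1.2))                 # True: 1.2 falls inside [0.0, 2.0]
print(fixed_c.get_corrected_value(0.5))      # 1.0: fixed-temperature models ignore the request
print(discrete_c.get_corrected_value(0.4))   # assumed to snap to the nearest allowed value (0.3)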
@@ -159,24 +162,11 @@ class ModelCapabilities:
# Custom model flag (for models that only work with custom endpoints)
is_custom: bool = False # Whether this model requires custom API endpoints
# Temperature constraint object - preferred way to define temperature limits
# Temperature constraint object - defines temperature limits and behavior
temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
)
# Backward compatibility property for existing code
@property
def temperature_range(self) -> tuple[float, float]:
"""Backward compatibility for existing code that uses temperature_range."""
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
elif isinstance(self.temperature_constraint, FixedTemperatureConstraint):
return (self.temperature_constraint.value, self.temperature_constraint.value)
elif isinstance(self.temperature_constraint, DiscreteTemperatureConstraint):
values = self.temperature_constraint.allowed_values
return (min(values), max(values))
return (0.0, 2.0) # Fallback
@dataclass
class ModelResponse:
@@ -220,7 +210,7 @@ class ModelProvider(ABC):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
@@ -276,18 +266,15 @@ class ModelProvider(ABC):
if not capabilities.supports_temperature:
return None
# Get temperature range
min_temp, max_temp = capabilities.temperature_range
# Use temperature constraint to get corrected value
corrected_temp = capabilities.temperature_constraint.get_corrected_value(requested_temperature)
# Clamp to valid range
if requested_temperature < min_temp:
logger.debug(f"Clamping temperature from {requested_temperature} to {min_temp} for model {model_name}")
return min_temp
elif requested_temperature > max_temp:
logger.debug(f"Clamping temperature from {requested_temperature} to {max_temp} for model {model_name}")
return max_temp
else:
return requested_temperature
if corrected_temp != requested_temperature:
logger.debug(
f"Adjusting temperature from {requested_temperature} to {corrected_temp} for model {model_name}"
)
return corrected_temp
except Exception as e:
logger.debug(f"Could not determine effective temperature for {model_name}: {e}")
@@ -302,10 +289,10 @@ class ModelProvider(ABC):
"""
capabilities = self.get_capabilities(model_name)
# Validate temperature
min_temp, max_temp = capabilities.temperature_range
if not min_temp <= temperature <= max_temp:
raise ValueError(f"Temperature {temperature} out of range [{min_temp}, {max_temp}] for model {model_name}")
# Validate temperature using constraint
if not capabilities.temperature_constraint.validate(temperature):
constraint_desc = capabilities.temperature_constraint.get_description()
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
@abstractmethod
def supports_thinking_mode(self, model_name: str) -> bool:
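Taken together, the two hunks above replace tuple-based clamping with the constraint object itself. A small sketch of the new call pattern, assuming the providers.base import path and that a range constraint corrects out-of-range values by clamping:

from providers.base import RangeTemperatureConstraint  # import path assumed

constraint = RangeTemperatureConstraint(0.0, 2.0, 0.3)
requested = 3.5

if not constraint.validate(requested):
    # validate_parameters() raises ValueError using constraint.get_description();
    # get_effective_temperature() instead asks for a corrected value and logs the change.
    corrected = constraint.get_corrected_value(requested)   # presumably 2.0 for a range constraint
    print(f"Adjusting temperature from {requested} to {corrected}")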
@@ -520,3 +507,28 @@ class ModelProvider(ABC):
"""
# Base implementation: no resources to clean up
return
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get the preferred model from this provider for a given category.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of model names that are allowed by restrictions
Returns:
Model name if this provider has a preference, None otherwise
"""
# Default implementation - providers can override with specific logic
return None
def get_model_registry(self) -> Optional[dict[str, Any]]:
"""Get the model registry for providers that maintain one.
This is a hook method for providers like CustomProvider that maintain
a dynamic model registry.
Returns:
Model registry dict or None if not applicable
"""
# Default implementation - most providers don't have a registry
return None
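The two new hooks let individual providers steer model selection without the registry knowing their internals. A minimal illustrative subclass follows; LocalProvider and "my-local-model" are made-up names, and the remaining abstract methods of ModelProvider are omitted for brevity.

from typing import Any, Optional
from providers.base import ModelProvider  # import path assumed

class LocalProvider(ModelProvider):
    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        # Prefer a known-good local model whenever restrictions allow it.
        return "my-local-model" if "my-local-model" in allowed_models else None

    def get_model_registry(self) -> Optional[dict[str, Any]]:
        # Providers with a dynamic registry (e.g. CustomProvider) expose it through this hook.
        return {"my-local-model": {"supports_extended_thinking": False}}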

View File

@@ -236,7 +236,7 @@ class CustomProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:

View File

@@ -375,7 +375,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
images: Optional[list[str]] = None,
**kwargs,

View File

@@ -3,7 +3,10 @@
import base64
import logging
import time
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from google import genai
from google.genai import types
@@ -18,6 +21,25 @@ class GeminiModelProvider(ModelProvider):
# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {
"gemini-2.5-pro": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.5-pro",
friendly_name="Gemini (Pro 2.5)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # Vision capability
max_image_size_mb=32.0, # Higher limit for Pro model
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
max_thinking_tokens=32768, # Max thinking tokens for Pro model
description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
aliases=["pro", "gemini pro", "gemini-pro"],
),
"gemini-2.0-flash": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.0-flash",
@@ -74,25 +96,6 @@ class GeminiModelProvider(ModelProvider):
description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
aliases=["flash", "flash2.5"],
),
"gemini-2.5-pro": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.5-pro",
friendly_name="Gemini (Pro 2.5)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # Vision capability
max_image_size_mb=32.0, # Higher limit for Pro model
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
max_thinking_tokens=32768, # Max thinking tokens for Pro model
description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
aliases=["pro", "gemini pro", "gemini-pro"],
),
}
# Thinking mode configurations - percentages of model's max_thinking_tokens
@@ -151,7 +154,7 @@ class GeminiModelProvider(ModelProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
thinking_mode: str = "medium",
images: Optional[list[str]] = None,
@@ -458,3 +461,67 @@ class GeminiModelProvider(ModelProvider):
except Exception as e:
logger.error(f"Error processing image {image_path}: {e}")
return None
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get Gemini's preferred model for a given category from allowed models.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions
Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory
if not allowed_models:
return None
# Helper to find best model from candidates
def find_best(candidates: list[str]) -> Optional[str]:
"""Return best model from candidates (sorted for consistency)."""
return sorted(candidates, reverse=True)[0] if candidates else None
if category == ToolModelCategory.EXTENDED_REASONING:
# For extended reasoning, prefer models with thinking support
# First try Pro models that support thinking
pro_thinking = [
m
for m in allowed_models
if "pro" in m and m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
]
if pro_thinking:
return find_best(pro_thinking)
# Then any model that supports thinking
any_thinking = [
m
for m in allowed_models
if m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
]
if any_thinking:
return find_best(any_thinking)
# Finally, just prefer Pro models even without thinking
pro_models = [m for m in allowed_models if "pro" in m]
if pro_models:
return find_best(pro_models)
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer Flash models for speed
flash_models = [m for m in allowed_models if "flash" in m]
if flash_models:
return find_best(flash_models)
# Default for BALANCED or as fallback
# Prefer Flash for balanced use, then Pro, then anything
flash_models = [m for m in allowed_models if "flash" in m]
if flash_models:
return find_best(flash_models)
pro_models = [m for m in allowed_models if "pro" in m]
if pro_models:
return find_best(pro_models)
# Ultimate fallback to best available model
return find_best(allowed_models)
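A quick usage sketch of the selection logic above. The category handling is exactly what the method implements; only the provider construction is assumed.

from tools.models import ToolModelCategory

provider = GeminiModelProvider(api_key="...")  # constructor signature assumed
allowed = ["gemini-2.5-flash", "gemini-2.5-pro"]

provider.get_preferred_model(ToolModelCategory.EXTENDED_REASONING, allowed)  # -> "gemini-2.5-pro"
provider.get_preferred_model(ToolModelCategory.FAST_RESPONSE, allowed)       # -> "gemini-2.5-flash"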

View File

@@ -1,5 +1,7 @@
"""Base class for OpenAI-compatible API providers."""
import base64
import copy
import ipaddress
import logging
import os
@@ -219,10 +221,20 @@ class OpenAICompatibleProvider(ModelProvider):
# Create httpx client with minimal config to avoid proxy conflicts
# Note: proxies parameter was removed in httpx 0.28.0
http_client = httpx.Client(
timeout=timeout_config,
follow_redirects=True,
)
# Check for test transport injection
if hasattr(self, "_test_transport"):
# Use custom transport for testing (HTTP recording/replay)
http_client = httpx.Client(
transport=self._test_transport,
timeout=timeout_config,
follow_redirects=True,
)
else:
# Normal production client
http_client = httpx.Client(
timeout=timeout_config,
follow_redirects=True,
)
# Keep client initialization minimal to avoid proxy parameter conflicts
client_kwargs = {
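The new _test_transport branch lets tests route the OpenAI-compatible client through a recorded or fake HTTP layer instead of the network. A minimal sketch with httpx.MockTransport; the attribute name comes from this diff, while the provider constructor and response shape are illustrative.

import httpx

def fake_handler(request: httpx.Request) -> httpx.Response:
    # Every request the provider makes gets this canned reply.
    return httpx.Response(200, json={"choices": [{"message": {"content": "stubbed"}}]})

provider = OpenAIModelProvider(api_key="test-key")           # constructor signature assumed
provider._test_transport = httpx.MockTransport(fake_handler)
# The next access to provider.client builds httpx.Client(transport=provider._test_transport, ...)
# and no real network traffic occurs.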
@@ -263,6 +275,63 @@ class OpenAICompatibleProvider(ModelProvider):
return self._client
def _sanitize_for_logging(self, params: dict) -> dict:
"""Sanitize sensitive data from parameters before logging.
Args:
params: Dictionary of API parameters
Returns:
dict: Sanitized copy of parameters safe for logging
"""
sanitized = copy.deepcopy(params)
# Sanitize messages content
if "input" in sanitized:
for msg in sanitized.get("input", []):
if isinstance(msg, dict) and "content" in msg:
for content_item in msg.get("content", []):
if isinstance(content_item, dict) and "text" in content_item:
# Truncate long text and add ellipsis
text = content_item["text"]
if len(text) > 100:
content_item["text"] = text[:100] + "... [truncated]"
# Remove any API keys that might be in headers/auth
sanitized.pop("api_key", None)
sanitized.pop("authorization", None)
return sanitized
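Roughly what the sanitizer does to a responses-endpoint payload. The dict shape mirrors the input/content/text keys the method walks; the values are made up.

params = {
    "model": "o3-pro",
    "input": [{"role": "user", "content": [{"type": "input_text", "text": "x" * 500}]}],
    "api_key": "sk-should-never-be-logged",
}

safe = provider._sanitize_for_logging(params)   # provider instance assumed
assert "api_key" not in safe
assert safe["input"][0]["content"][0]["text"].endswith("... [truncated]")
assert params["input"][0]["content"][0]["text"] == "x" * 500   # deepcopy leaves the caller's dict intact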
def _safe_extract_output_text(self, response) -> str:
"""Safely extract output_text from o3-pro response with validation.
Args:
response: Response object from OpenAI SDK
Returns:
str: The output text content
Raises:
ValueError: If output_text is missing, None, or not a string
"""
logging.debug(f"Response object type: {type(response)}")
logging.debug(f"Response attributes: {dir(response)}")
if not hasattr(response, "output_text"):
raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
content = response.output_text
logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")
if content is None:
raise ValueError("o3-pro returned None for output_text")
if not isinstance(content, str):
raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")
return content
def _generate_with_responses_endpoint(
self,
model_name: str,
@@ -308,30 +377,23 @@ class OpenAICompatibleProvider(ModelProvider):
max_retries = 4
retry_delays = [1, 3, 5, 8]
last_exception = None
actual_attempts = 0
for attempt in range(max_retries):
try: # Log the exact payload being sent for debugging
try: # Log sanitized payload for debugging
import json
sanitized_params = self._sanitize_for_logging(completion_params)
logging.info(
f"o3-pro API request payload: {json.dumps(completion_params, indent=2, ensure_ascii=False)}"
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
)
# Use OpenAI client's responses endpoint
response = self.client.responses.create(**completion_params)
# Extract content and usage from responses endpoint format
# The response format is different for responses endpoint
content = ""
if hasattr(response, "output") and response.output:
if hasattr(response.output, "content") and response.output.content:
# Look for output_text in content
for content_item in response.output.content:
if hasattr(content_item, "type") and content_item.type == "output_text":
content = content_item.text
break
elif hasattr(response.output, "text"):
content = response.output.text
# Extract content from responses endpoint format
# Use validation helper to safely extract output_text
content = self._safe_extract_output_text(response)
# Try to extract usage information
usage = None
@@ -370,14 +432,13 @@ class OpenAICompatibleProvider(ModelProvider):
if is_retryable and attempt < max_retries - 1:
delay = retry_delays[attempt]
logging.warning(
f"Retryable error for o3-pro responses endpoint, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
else:
break
# If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
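Both the responses endpoint above and the chat-completions path below share the same retry shape: progressive delays, retry only while the error is classified as retryable, then a RuntimeError reporting the real attempt count. A condensed, self-contained sketch, where call_api and is_retryable are stand-ins for the actual request and error classification:

import logging
import time

def call_with_retries(call_api, is_retryable):
    max_retries = 4
    retry_delays = [1, 3, 5, 8]   # progressive delays: 1s, 3s, 5s, 8s
    last_exception = None
    actual_attempts = 0
    for attempt in range(max_retries):
        actual_attempts = attempt + 1
        try:
            return call_api()
        except Exception as e:
            last_exception = e
            if is_retryable(e) and attempt < max_retries - 1:
                delay = retry_delays[attempt]
                logging.warning(
                    f"Retryable error, attempt {actual_attempts}/{max_retries}: {e}. Retrying in {delay}s..."
                )
                time.sleep(delay)
            else:
                break
    raise RuntimeError(
        f"API error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {last_exception}"
    ) from last_exception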
@@ -387,7 +448,7 @@ class OpenAICompatibleProvider(ModelProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
images: Optional[list[str]] = None,
**kwargs,
@@ -480,7 +541,7 @@ class OpenAICompatibleProvider(ModelProvider):
completion_params[key] = value
# Check if this is o3-pro and needs the responses endpoint
if resolved_model == "o3-pro-2025-06-10":
if resolved_model == "o3-pro":
# This model requires the /v1/responses endpoint
# If it fails, we should not fall back to chat/completions
return self._generate_with_responses_endpoint(
@@ -496,8 +557,10 @@ class OpenAICompatibleProvider(ModelProvider):
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
last_exception = None
actual_attempts = 0
for attempt in range(max_retries):
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
@@ -535,12 +598,11 @@ class OpenAICompatibleProvider(ModelProvider):
# Log retry attempt
logging.warning(
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
# If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
@@ -575,11 +637,7 @@ class OpenAICompatibleProvider(ModelProvider):
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
# Try common encodings based on model patterns
if "gpt-4" in model_name or "gpt-3.5" in model_name:
encoding = tiktoken.get_encoding("cl100k_base")
else:
encoding = tiktoken.get_encoding("cl100k_base") # Default
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(text))
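The token counter now drops straight to cl100k_base whenever tiktoken does not recognize the model name, since the removed gpt-4/gpt-3.5 branch picked the same encoding anyway. A sketch of the call path; the model name is illustrative:

import tiktoken

def count_tokens(text: str, model_name: str) -> int:
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Unknown or custom model names fall back to cl100k_base.
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

print(count_tokens("hello world", "some-unreleased-model"))   # resolves via the fallback encoding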
@@ -678,11 +736,13 @@ class OpenAICompatibleProvider(ModelProvider):
"""
# Common vision-capable models - only include models that actually support images
vision_models = {
"gpt-5",
"gpt-5-mini",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4-vision-preview",
"gpt-4.1-2025-04-14", # GPT-4.1 supports vision
"gpt-4.1-2025-04-14",
"o3",
"o3-mini",
"o3-pro",

View File

@@ -1,7 +1,10 @@
"""OpenAI model provider implementation."""
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .base import (
ModelCapabilities,
@@ -19,6 +22,60 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {
"gpt-5": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5",
friendly_name="OpenAI (GPT-5)",
context_window=400_000, # 400K tokens
max_output_tokens=128_000, # 128K max output tokens
supports_extended_thinking=True, # Supports reasoning tokens
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # GPT-5 supports vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=True, # Regular models accept temperature parameter
temperature_constraint=create_temperature_constraint("fixed"),
description="GPT-5 (400K context, 128K output) - Advanced model with reasoning support",
aliases=["gpt5", "gpt-5"],
),
"gpt-5-mini": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5-mini",
friendly_name="OpenAI (GPT-5-mini)",
context_window=400_000, # 400K tokens
max_output_tokens=128_000, # 128K max output tokens
supports_extended_thinking=True, # Supports reasoning tokens
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # GPT-5-mini supports vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=True,
temperature_constraint=create_temperature_constraint("fixed"),
description="GPT-5-mini (400K context, 128K output) - Efficient variant with reasoning support",
aliases=["gpt5-mini", "gpt5mini", "mini"],
),
"gpt-5-nano": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5-nano",
friendly_name="OpenAI (GPT-5 nano)",
context_window=400_000,
max_output_tokens=128_000,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("fixed"),
description="GPT-5 nano (400K context) - Fastest, cheapest version of GPT-5 for summarization and classification tasks",
aliases=["gpt5nano", "gpt5-nano", "nano"],
),
"o3": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="o3",
@@ -55,9 +112,9 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
aliases=["o3mini", "o3-mini"],
),
"o3-pro-2025-06-10": ModelCapabilities(
"o3-pro": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="o3-pro-2025-06-10",
model_name="o3-pro",
friendly_name="OpenAI (O3-Pro)",
context_window=200_000, # 200K tokens
max_output_tokens=65536, # 64K max output tokens
@@ -89,11 +146,11 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
supports_temperature=False, # O4 models don't accept temperature parameter
temperature_constraint=create_temperature_constraint("fixed"),
description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
aliases=["mini", "o4mini", "o4-mini"],
aliases=["o4mini", "o4-mini"],
),
"gpt-4.1-2025-04-14": ModelCapabilities(
"gpt-4.1": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-4.1-2025-04-14",
model_name="gpt-4.1",
friendly_name="OpenAI (GPT 4.1)",
context_window=1_000_000, # 1M tokens
max_output_tokens=32_768,
@@ -107,7 +164,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
supports_temperature=True, # Regular models accept temperature parameter
temperature_constraint=create_temperature_constraint("range"),
description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
aliases=["gpt4.1"],
aliases=["gpt4.1", "gpt-4.1"],
),
}
@@ -119,21 +176,41 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific OpenAI model."""
# Resolve shorthand
# First check if it's a key in SUPPORTED_MODELS
if model_name in self.SUPPORTED_MODELS:
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, model_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return self.SUPPORTED_MODELS[model_name]
# Try resolving as alias
resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported OpenAI model: {model_name}")
# Check if resolved name is a key
if resolved_name in self.SUPPORTED_MODELS:
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return self.SUPPORTED_MODELS[resolved_name]
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
# Finally check if resolved name matches any API model name
for key, capabilities in self.SUPPORTED_MODELS.items():
if resolved_name == capabilities.model_name:
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
return self.SUPPORTED_MODELS[resolved_name]
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, key, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return capabilities
raise ValueError(f"Unsupported OpenAI model: {model_name}")
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
@@ -162,7 +239,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
@@ -182,6 +259,47 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
# Currently no OpenAI models support extended thinking
# This may change with future O3 models
# GPT-5 models support reasoning tokens (extended thinking)
resolved_name = self._resolve_model_name(model_name)
if resolved_name in ["gpt-5", "gpt-5-mini"]:
return True
# O3 models don't support extended thinking yet
return False
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get OpenAI's preferred model for a given category from allowed models.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions
Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory
if not allowed_models:
return None
# Helper to find first available from preference list
def find_first(preferences: list[str]) -> Optional[str]:
"""Return first available model from preference list."""
for model in preferences:
if model in allowed_models:
return model
return None
if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer models with extended thinking support
preferred = find_first(["o3", "o3-pro", "gpt-5"])
return preferred if preferred else allowed_models[0]
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer fast, cost-efficient models
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
return preferred if preferred else allowed_models[0]
else: # BALANCED or default
# Prefer balanced performance/cost models
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
return preferred if preferred else allowed_models[0]

View File

@@ -158,7 +158,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:

View File

@@ -15,6 +15,17 @@ class ModelProviderRegistry:
_instance = None
# Provider priority order for model selection
# Native APIs first, then custom endpoints, then catch-all providers
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.XAI, # Direct X.AI GROK access
ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]
def __new__(cls):
"""Singleton pattern for registry."""
if cls._instance is None:
@@ -103,30 +114,19 @@ class ModelProviderRegistry:
3. OPENROUTER - Catch-all for cloud models via unified API
Args:
model_name: Name of the model (e.g., "gemini-2.5-flash", "o3-mini")
model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5")
Returns:
ModelProvider instance that supports this model
"""
logging.debug(f"get_provider_for_model called with model_name='{model_name}'")
# Define explicit provider priority order
# Native APIs first, then custom endpoints, then catch-all providers
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.XAI, # Direct X.AI GROK access
ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]
# Check providers in priority order
instance = cls()
logging.debug(f"Registry instance: {instance}")
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
for provider_type in PROVIDER_PRIORITY_ORDER:
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
if provider_type in instance._providers:
logging.debug(f"Found {provider_type} in registry")
# Get or create provider instance
@@ -244,14 +244,49 @@ class ModelProviderRegistry:
return os.getenv(env_var)
@classmethod
def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]:
"""Get a list of allowed canonical model names for a given provider.
Args:
provider: The provider instance to get models for
provider_type: The provider type for restriction checking
Returns:
List of model names that are both supported and allowed
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
allowed_models = []
# Get the provider's supported models
try:
# Use list_models to get all supported models (handles both regular and custom providers)
supported_models = provider.list_models(respect_restrictions=False)
except (NotImplementedError, AttributeError):
# Fallback to SUPPORTED_MODELS if list_models not implemented
try:
supported_models = list(provider.SUPPORTED_MODELS.keys())
except AttributeError:
supported_models = []
# Filter by restrictions
for model_name in supported_models:
if restriction_service.is_allowed(provider_type, model_name):
allowed_models.append(model_name)
return allowed_models
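A hypothetical call to the new helper; the import paths and the presence of a configured OpenAI provider are assumptions, and the exact result depends on whatever restriction policy is in effect.

from providers.registry import ModelProviderRegistry   # import path assumed
from providers.base import ProviderType                # import path assumed

provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
allowed = ModelProviderRegistry._get_allowed_models_for_provider(provider, ProviderType.OPENAI)
# e.g. ["gpt-5", "gpt-5-mini", "gpt-5-nano", "o3", ...] once restrictions are applied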
@classmethod
def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
"""Get the preferred fallback model based on available API keys and tool category.
"""Get the preferred fallback model based on provider priority and tool category.
This method checks which providers have valid API keys and returns
a sensible default model for auto mode fallback situations.
Takes into account model restrictions when selecting fallback models.
This method orchestrates model selection by:
1. Getting allowed models for each provider (respecting restrictions)
2. Asking providers for their preference from the allowed list
3. Falling back to first available model if no preference given
Args:
tool_category: Optional category to influence model selection
@@ -259,167 +294,42 @@ class ModelProviderRegistry:
Returns:
Model name string for fallback use
"""
# Import here to avoid circular import
from tools.models import ToolModelCategory
# Get available models respecting restrictions
available_models = cls.get_available_models(respect_restrictions=True)
effective_category = tool_category or ToolModelCategory.BALANCED
first_available_model = None
# Group by provider
openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI]
gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI]
openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER]
custom_models = [m for m, p in available_models.items() if p == ProviderType.CUSTOM]
# Ask each provider for their preference in priority order
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
provider = cls.get_provider(provider_type)
if provider:
# 1. Registry filters the models first
allowed_models = cls._get_allowed_models_for_provider(provider, provider_type)
openai_available = bool(openai_models)
gemini_available = bool(gemini_models)
xai_available = bool(xai_models)
openrouter_available = bool(openrouter_models)
custom_available = bool(custom_models)
if tool_category == ToolModelCategory.EXTENDED_REASONING:
# Prefer thinking-capable models for deep reasoning tools
if openai_available and "o3" in openai_models:
return "o3" # O3 for deep reasoning
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
elif xai_available and "grok-3" in xai_models:
return "grok-3" # GROK-3 for deep reasoning
elif xai_available and xai_models:
# Fall back to any available XAI model
return xai_models[0]
elif gemini_available and any("pro" in m for m in gemini_models):
# Find the pro model (handles full names)
return next(m for m in gemini_models if "pro" in m)
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
elif openrouter_available:
# Try to find thinking-capable model from openrouter
thinking_model = cls._find_extended_thinking_model()
if thinking_model:
return thinking_model
# Fallback to first available OpenRouter model
return openrouter_models[0]
elif custom_available:
# Fallback to custom models when available
return custom_models[0]
else:
# Fallback to pro if nothing found
return "gemini-2.5-pro"
elif tool_category == ToolModelCategory.FAST_RESPONSE:
# Prefer fast, cost-efficient models
if openai_available and "o4-mini" in openai_models:
return "o4-mini" # Latest, fast and efficient
elif openai_available and "o3-mini" in openai_models:
return "o3-mini" # Second choice
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
elif xai_available and "grok-3-fast" in xai_models:
return "grok-3-fast" # GROK-3 Fast for speed
elif xai_available and xai_models:
# Fall back to any available XAI model
return xai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
# Find the flash model (handles full names)
# Prefer 2.5 over 2.0 for backward compatibility
flash_models = [m for m in gemini_models if "flash" in m]
# Sort to ensure 2.5 comes before 2.0
flash_models_sorted = sorted(flash_models, reverse=True)
return flash_models_sorted[0]
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
elif openrouter_available:
# Fallback to first available OpenRouter model
return openrouter_models[0]
elif custom_available:
# Fallback to custom models when available
return custom_models[0]
else:
# Default to flash
return "gemini-2.5-flash"
# BALANCED or no category specified - use existing balanced logic
if openai_available and "o4-mini" in openai_models:
return "o4-mini" # Latest balanced performance/cost
elif openai_available and "o3-mini" in openai_models:
return "o3-mini" # Second choice
elif openai_available and openai_models:
return openai_models[0]
elif xai_available and "grok-3" in xai_models:
return "grok-3" # GROK-3 as balanced choice
elif xai_available and xai_models:
return xai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
# Prefer 2.5 over 2.0 for backward compatibility
flash_models = [m for m in gemini_models if "flash" in m]
flash_models_sorted = sorted(flash_models, reverse=True)
return flash_models_sorted[0]
elif gemini_available and gemini_models:
return gemini_models[0]
elif openrouter_available:
return openrouter_models[0]
elif custom_available:
# Fallback to custom models when available
return custom_models[0]
else:
# No models available due to restrictions - check if any providers exist
if not available_models:
# This might happen if all models are restricted
logging.warning("No models available due to restrictions")
# Return a reasonable default for backward compatibility
return "gemini-2.5-flash"
@classmethod
def _find_extended_thinking_model(cls) -> Optional[str]:
"""Find a model suitable for extended reasoning from custom/openrouter providers.
Returns:
Model name if found, None otherwise
"""
# Check custom provider first
custom_provider = cls.get_provider(ProviderType.CUSTOM)
if custom_provider:
# Check if it's a CustomModelProvider and has thinking models
try:
from providers.custom import CustomProvider
if isinstance(custom_provider, CustomProvider) and hasattr(custom_provider, "model_registry"):
for model_name, config in custom_provider.model_registry.items():
if config.get("supports_extended_thinking", False):
return model_name
except ImportError:
pass
# Then check OpenRouter for high-context/powerful models
openrouter_provider = cls.get_provider(ProviderType.OPENROUTER)
if openrouter_provider:
# Prefer models known for deep reasoning
preferred_models = [
"anthropic/claude-sonnet-4",
"anthropic/claude-opus-4",
"google/gemini-2.5-pro",
"google/gemini-pro-1.5",
"meta-llama/llama-3.1-70b-instruct",
"mistralai/mixtral-8x7b-instruct",
]
for model in preferred_models:
try:
if openrouter_provider.validate_model_name(model):
return model
except Exception as e:
# Log the error for debugging purposes but continue searching
import logging
logging.warning(f"Model validation for '{model}' on OpenRouter failed: {e}")
if not allowed_models:
continue
return None
# 2. Keep track of the first available model as fallback
if not first_available_model:
first_available_model = sorted(allowed_models)[0]
# 3. Ask provider to pick from allowed list
preferred_model = provider.get_preferred_model(effective_category, allowed_models)
if preferred_model:
logging.debug(
f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'"
)
return preferred_model
# If no provider returned a preference, use first available model
if first_available_model:
logging.debug(f"No provider preference, using first available: {first_available_model}")
return first_available_model
# Ultimate fallback if no providers have models
logging.warning("No models available from any provider, using default fallback")
return "gemini-2.5-flash"
@classmethod
def get_available_providers_with_keys(cls) -> list[ProviderType]:
@@ -441,6 +351,17 @@ class ModelProviderRegistry:
instance = cls()
instance._initialized_providers.clear()
@classmethod
def reset_for_testing(cls) -> None:
"""Reset the registry to a clean state for testing.
This provides a safe, public API for tests to clean up registry state
without directly manipulating private attributes.
"""
cls._instance = None
if hasattr(cls, "_providers"):
cls._providers = {}
@classmethod
def unregister_provider(cls, provider_type: ProviderType) -> None:
"""Unregister a provider (mainly for testing)."""

View File

@@ -1,7 +1,10 @@
"""X.AI (GROK) model provider implementation."""
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .base import (
ModelCapabilities,
@@ -21,6 +24,24 @@ class XAIModelProvider(OpenAICompatibleProvider):
# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {
"grok-4": ModelCapabilities(
provider=ProviderType.XAI,
model_name="grok-4",
friendly_name="X.AI (Grok 4)",
context_window=256_000, # 256K tokens
max_output_tokens=256_000, # 256K tokens max output
supports_extended_thinking=True, # Grok-4 supports reasoning mode
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True, # Function calling supported
supports_json_mode=True, # Structured outputs supported
supports_images=True, # Multimodal capabilities
max_image_size_mb=20.0, # Standard image size limit
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="GROK-4 (256K context) - Frontier multimodal reasoning model with advanced capabilities",
aliases=["grok", "grok4", "grok-4"],
),
"grok-3": ModelCapabilities(
provider=ProviderType.XAI,
model_name="grok-3",
@@ -37,7 +58,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
aliases=["grok", "grok3"],
aliases=["grok3"],
),
"grok-3-fast": ModelCapabilities(
provider=ProviderType.XAI,
@@ -110,7 +131,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
@@ -130,6 +151,52 @@ class XAIModelProvider(OpenAICompatibleProvider):
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
# Currently GROK models do not support extended thinking
# This may change with future GROK model releases
resolved_name = self._resolve_model_name(model_name)
capabilities = self.SUPPORTED_MODELS.get(resolved_name)
if capabilities:
return capabilities.supports_extended_thinking
return False
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get XAI's preferred model for a given category from allowed models.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions
Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory
if not allowed_models:
return None
if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer GROK-4 for advanced reasoning with thinking mode
if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3"
# Fall back to any available model
return allowed_models[0]
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer GROK-3-Fast for speed, then GROK-4
if "grok-3-fast" in allowed_models:
return "grok-3-fast"
elif "grok-4" in allowed_models:
return "grok-4"
# Fall back to any available model
return allowed_models[0]
else: # BALANCED or default
# Prefer GROK-4 for balanced use (best overall capabilities)
if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3"
elif "grok-3-fast" in allowed_models:
return "grok-3-fast"
# Fall back to any available model
return allowed_models[0]
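Same selection pattern as the Gemini and OpenAI providers; note that the bare "grok" alias now resolves to grok-4 instead of grok-3. A quick usage sketch with an assumed, already-configured XAIModelProvider instance:

from tools.models import ToolModelCategory

allowed = ["grok-3", "grok-3-fast", "grok-4"]
provider.get_preferred_model(ToolModelCategory.FAST_RESPONSE, allowed)       # -> "grok-3-fast"
provider.get_preferred_model(ToolModelCategory.EXTENDED_REASONING, allowed)  # -> "grok-4"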