GPT-5, GPT-5-mini support

Improvements to model name resolution
Improved instructions for multi-step workflows when continuation is available
Improved instructions for chat tool
Improved preferred model resolution, moved code from registry -> each provider
Updated tests
This commit is contained in:
Fahad
2025-08-08 08:51:34 +05:00
parent 9a4791cb06
commit 1a8ec2e12f
30 changed files with 792 additions and 483 deletions

View File

@@ -37,13 +37,13 @@ OPENROUTER_API_KEY=your_openrouter_api_key_here
# Optional: Default model to use # Optional: Default model to use
# Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high', # Options: 'auto' (Claude picks best model), 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high',
# 'grok', 'opus-4', 'sonnet-4', or any DIAL model if DIAL is configured # 'gpt-5', 'gpt-5-mini', 'grok', 'opus-4', 'sonnet-4', or any DIAL model if DIAL is configured
# When set to 'auto', Claude will select the best model for each task # When set to 'auto', Claude will select the best model for each task
# Defaults to 'auto' if not specified # Defaults to 'auto' if not specified
DEFAULT_MODEL=auto DEFAULT_MODEL=auto
# Optional: Default thinking mode for ThinkDeep tool # Optional: Default thinking mode for ThinkDeep tool
# NOTE: Only applies to models that support extended thinking (e.g., Gemini 2.5 Pro) # NOTE: Only applies to models that support extended thinking (e.g., Gemini 2.5 Pro, GPT-5 models)
# Flash models (2.0) will use system prompt engineering instead # Flash models (2.0) will use system prompt engineering instead
# Token consumption per mode: # Token consumption per mode:
# minimal: 128 tokens - Quick analysis, fastest response # minimal: 128 tokens - Quick analysis, fastest response
@@ -65,6 +65,8 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
# - o3-mini (200K context, balanced) # - o3-mini (200K context, balanced)
# - o4-mini (200K context, latest balanced, temperature=1.0 only) # - o4-mini (200K context, latest balanced, temperature=1.0 only)
# - o4-mini-high (200K context, enhanced reasoning, temperature=1.0 only) # - o4-mini-high (200K context, enhanced reasoning, temperature=1.0 only)
# - gpt-5 (400K context, 128K output, reasoning tokens)
# - gpt-5-mini (400K context, 128K output, reasoning tokens)
# - mini (shorthand for o4-mini) # - mini (shorthand for o4-mini)
# #
# Supported Google/Gemini models: # Supported Google/Gemini models:

View File

@@ -75,10 +75,10 @@ DEFAULT_CONSENSUS_MAX_INSTANCES_PER_COMBINATION = 2
# #
# IMPORTANT: This limit ONLY applies to the Claude CLI ↔ MCP Server transport boundary. # IMPORTANT: This limit ONLY applies to the Claude CLI ↔ MCP Server transport boundary.
# It does NOT limit internal MCP Server operations like system prompts, file embeddings, # It does NOT limit internal MCP Server operations like system prompts, file embeddings,
# conversation history, or content sent to external models (Gemini/O3/OpenRouter). # conversation history, or content sent to external models (Gemini/OpenAI/OpenRouter).
# #
# MCP Protocol Architecture: # MCP Protocol Architecture:
# Claude CLI ←→ MCP Server ←→ External Model (Gemini/O3/etc.) # Claude CLI ←→ MCP Server ←→ External Model (Gemini/OpenAI/etc.)
# ↑ ↑ # ↑ ↑
# │ │ # │ │
# MCP transport Internal processing # MCP transport Internal processing

View File

@@ -4,7 +4,10 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -118,10 +121,10 @@ def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint
return FixedTemperatureConstraint(1.0) return FixedTemperatureConstraint(1.0)
elif constraint_type == "discrete": elif constraint_type == "discrete":
# For models with specific allowed values - using common OpenAI values as default # For models with specific allowed values - using common OpenAI values as default
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.7) return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
else: else:
# Default range constraint (for "range" or None) # Default range constraint (for "range" or None)
return RangeTemperatureConstraint(0.0, 2.0, 0.7) return RangeTemperatureConstraint(0.0, 2.0, 0.3)
@dataclass @dataclass
@@ -154,24 +157,11 @@ class ModelCapabilities:
# Custom model flag (for models that only work with custom endpoints) # Custom model flag (for models that only work with custom endpoints)
is_custom: bool = False # Whether this model requires custom API endpoints is_custom: bool = False # Whether this model requires custom API endpoints
# Temperature constraint object - preferred way to define temperature limits # Temperature constraint object - defines temperature limits and behavior
temperature_constraint: TemperatureConstraint = field( temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7) default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
) )
# Backward compatibility property for existing code
@property
def temperature_range(self) -> tuple[float, float]:
"""Backward compatibility for existing code that uses temperature_range."""
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
elif isinstance(self.temperature_constraint, FixedTemperatureConstraint):
return (self.temperature_constraint.value, self.temperature_constraint.value)
elif isinstance(self.temperature_constraint, DiscreteTemperatureConstraint):
values = self.temperature_constraint.allowed_values
return (min(values), max(values))
return (0.0, 2.0) # Fallback
@dataclass @dataclass
class ModelResponse: class ModelResponse:
@@ -268,18 +258,15 @@ class ModelProvider(ABC):
if not capabilities.supports_temperature: if not capabilities.supports_temperature:
return None return None
# Get temperature range # Use temperature constraint to get corrected value
min_temp, max_temp = capabilities.temperature_range corrected_temp = capabilities.temperature_constraint.get_corrected_value(requested_temperature)
# Clamp to valid range if corrected_temp != requested_temperature:
if requested_temperature < min_temp: logger.debug(
logger.debug(f"Clamping temperature from {requested_temperature} to {min_temp} for model {model_name}") f"Adjusting temperature from {requested_temperature} to {corrected_temp} for model {model_name}"
return min_temp )
elif requested_temperature > max_temp:
logger.debug(f"Clamping temperature from {requested_temperature} to {max_temp} for model {model_name}") return corrected_temp
return max_temp
else:
return requested_temperature
except Exception as e: except Exception as e:
logger.debug(f"Could not determine effective temperature for {model_name}: {e}") logger.debug(f"Could not determine effective temperature for {model_name}: {e}")
@@ -294,10 +281,10 @@ class ModelProvider(ABC):
""" """
capabilities = self.get_capabilities(model_name) capabilities = self.get_capabilities(model_name)
# Validate temperature # Validate temperature using constraint
min_temp, max_temp = capabilities.temperature_range if not capabilities.temperature_constraint.validate(temperature):
if not min_temp <= temperature <= max_temp: constraint_desc = capabilities.temperature_constraint.get_description()
raise ValueError(f"Temperature {temperature} out of range [{min_temp}, {max_temp}] for model {model_name}") raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
@abstractmethod @abstractmethod
def supports_thinking_mode(self, model_name: str) -> bool: def supports_thinking_mode(self, model_name: str) -> bool:
@@ -441,3 +428,28 @@ class ModelProvider(ABC):
""" """
# Base implementation: no resources to clean up # Base implementation: no resources to clean up
return return
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get the preferred model from this provider for a given category.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of model names that are allowed by restrictions
Returns:
Model name if this provider has a preference, None otherwise
"""
# Default implementation - providers can override with specific logic
return None
def get_model_registry(self) -> Optional[dict[str, Any]]:
"""Get the model registry for providers that maintain one.
This is a hook method for providers like CustomProvider that maintain
a dynamic model registry.
Returns:
Model registry dict or None if not applicable
"""
# Default implementation - most providers don't have a registry
return None

View File

@@ -4,7 +4,10 @@ import base64
import logging import logging
import os import os
import time import time
from typing import Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from google import genai from google import genai
from google.genai import types from google.genai import types
@@ -19,6 +22,25 @@ class GeminiModelProvider(ModelProvider):
# Model configurations using ModelCapabilities objects # Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"gemini-2.5-pro": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.5-pro",
friendly_name="Gemini (Pro 2.5)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # Vision capability
max_image_size_mb=32.0, # Higher limit for Pro model
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
max_thinking_tokens=32768, # Max thinking tokens for Pro model
description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
aliases=["pro", "gemini pro", "gemini-pro"],
),
"gemini-2.0-flash": ModelCapabilities( "gemini-2.0-flash": ModelCapabilities(
provider=ProviderType.GOOGLE, provider=ProviderType.GOOGLE,
model_name="gemini-2.0-flash", model_name="gemini-2.0-flash",
@@ -75,25 +97,6 @@ class GeminiModelProvider(ModelProvider):
description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations", description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
aliases=["flash", "flash2.5"], aliases=["flash", "flash2.5"],
), ),
"gemini-2.5-pro": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.5-pro",
friendly_name="Gemini (Pro 2.5)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # Vision capability
max_image_size_mb=32.0, # Higher limit for Pro model
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
max_thinking_tokens=32768, # Max thinking tokens for Pro model
description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
aliases=["pro", "gemini pro", "gemini-pro"],
),
} }
# Thinking mode configurations - percentages of model's max_thinking_tokens # Thinking mode configurations - percentages of model's max_thinking_tokens
@@ -465,3 +468,67 @@ class GeminiModelProvider(ModelProvider):
except Exception as e: except Exception as e:
logger.error(f"Error processing image {image_path}: {e}") logger.error(f"Error processing image {image_path}: {e}")
return None return None
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get Gemini's preferred model for a given category from allowed models.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions
Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory
if not allowed_models:
return None
# Helper to find best model from candidates
def find_best(candidates: list[str]) -> Optional[str]:
"""Return best model from candidates (sorted for consistency)."""
return sorted(candidates, reverse=True)[0] if candidates else None
if category == ToolModelCategory.EXTENDED_REASONING:
# For extended reasoning, prefer models with thinking support
# First try Pro models that support thinking
pro_thinking = [
m
for m in allowed_models
if "pro" in m and m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
]
if pro_thinking:
return find_best(pro_thinking)
# Then any model that supports thinking
any_thinking = [
m
for m in allowed_models
if m in self.SUPPORTED_MODELS and self.SUPPORTED_MODELS[m].supports_extended_thinking
]
if any_thinking:
return find_best(any_thinking)
# Finally, just prefer Pro models even without thinking
pro_models = [m for m in allowed_models if "pro" in m]
if pro_models:
return find_best(pro_models)
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer Flash models for speed
flash_models = [m for m in allowed_models if "flash" in m]
if flash_models:
return find_best(flash_models)
# Default for BALANCED or as fallback
# Prefer Flash for balanced use, then Pro, then anything
flash_models = [m for m in allowed_models if "flash" in m]
if flash_models:
return find_best(flash_models)
pro_models = [m for m in allowed_models if "pro" in m]
if pro_models:
return find_best(pro_models)
# Ultimate fallback to best available model
return find_best(allowed_models)

View File

@@ -309,8 +309,10 @@ class OpenAICompatibleProvider(ModelProvider):
max_retries = 4 max_retries = 4
retry_delays = [1, 3, 5, 8] retry_delays = [1, 3, 5, 8]
last_exception = None last_exception = None
actual_attempts = 0
for attempt in range(max_retries): for attempt in range(max_retries):
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
try: # Log the exact payload being sent for debugging try: # Log the exact payload being sent for debugging
import json import json
@@ -371,14 +373,13 @@ class OpenAICompatibleProvider(ModelProvider):
if is_retryable and attempt < max_retries - 1: if is_retryable and attempt < max_retries - 1:
delay = retry_delays[attempt] delay = retry_delays[attempt]
logging.warning( logging.warning(
f"Retryable error for o3-pro responses endpoint, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..." f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
) )
time.sleep(delay) time.sleep(delay)
else: else:
break break
# If we get here, all retries failed # If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}" error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg) logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception raise RuntimeError(error_msg) from last_exception
@@ -481,7 +482,7 @@ class OpenAICompatibleProvider(ModelProvider):
completion_params[key] = value completion_params[key] = value
# Check if this is o3-pro and needs the responses endpoint # Check if this is o3-pro and needs the responses endpoint
if resolved_model == "o3-pro-2025-06-10": if resolved_model == "o3-pro":
# This model requires the /v1/responses endpoint # This model requires the /v1/responses endpoint
# If it fails, we should not fall back to chat/completions # If it fails, we should not fall back to chat/completions
return self._generate_with_responses_endpoint( return self._generate_with_responses_endpoint(
@@ -497,8 +498,10 @@ class OpenAICompatibleProvider(ModelProvider):
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
last_exception = None last_exception = None
actual_attempts = 0
for attempt in range(max_retries): for attempt in range(max_retries):
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
try: try:
# Generate completion # Generate completion
response = self.client.chat.completions.create(**completion_params) response = self.client.chat.completions.create(**completion_params)
@@ -536,12 +539,11 @@ class OpenAICompatibleProvider(ModelProvider):
# Log retry attempt # Log retry attempt
logging.warning( logging.warning(
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..." f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
) )
time.sleep(delay) time.sleep(delay)
# If we get here, all retries failed # If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}" error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg) logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception raise RuntimeError(error_msg) from last_exception
@@ -576,11 +578,7 @@ class OpenAICompatibleProvider(ModelProvider):
try: try:
encoding = tiktoken.encoding_for_model(model_name) encoding = tiktoken.encoding_for_model(model_name)
except KeyError: except KeyError:
# Try common encodings based on model patterns encoding = tiktoken.get_encoding("cl100k_base")
if "gpt-4" in model_name or "gpt-3.5" in model_name:
encoding = tiktoken.get_encoding("cl100k_base")
else:
encoding = tiktoken.get_encoding("cl100k_base") # Default
return len(encoding.encode(text)) return len(encoding.encode(text))
@@ -679,11 +677,13 @@ class OpenAICompatibleProvider(ModelProvider):
""" """
# Common vision-capable models - only include models that actually support images # Common vision-capable models - only include models that actually support images
vision_models = { vision_models = {
"gpt-5",
"gpt-5-mini",
"gpt-4o", "gpt-4o",
"gpt-4o-mini", "gpt-4o-mini",
"gpt-4-turbo", "gpt-4-turbo",
"gpt-4-vision-preview", "gpt-4-vision-preview",
"gpt-4.1-2025-04-14", # GPT-4.1 supports vision "gpt-4.1-2025-04-14",
"o3", "o3",
"o3-mini", "o3-mini",
"o3-pro", "o3-pro",

View File

@@ -1,7 +1,10 @@
"""OpenAI model provider implementation.""" """OpenAI model provider implementation."""
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .base import ( from .base import (
ModelCapabilities, ModelCapabilities,
@@ -19,6 +22,42 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
# Model configurations using ModelCapabilities objects # Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"gpt-5": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5",
friendly_name="OpenAI (GPT-5)",
context_window=400_000, # 400K tokens
max_output_tokens=128_000, # 128K max output tokens
supports_extended_thinking=True, # Supports reasoning tokens
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # GPT-5 supports vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=True, # Regular models accept temperature parameter
temperature_constraint=create_temperature_constraint("fixed"),
description="GPT-5 (400K context, 128K output) - Advanced model with reasoning support",
aliases=["gpt5", "gpt-5"],
),
"gpt-5-mini": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5-mini",
friendly_name="OpenAI (GPT-5-mini)",
context_window=400_000, # 400K tokens
max_output_tokens=128_000, # 128K max output tokens
supports_extended_thinking=True, # Supports reasoning tokens
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # GPT-5-mini supports vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=True, # Regular models accept temperature parameter
temperature_constraint=create_temperature_constraint("fixed"),
description="GPT-5-mini (400K context, 128K output) - Efficient variant with reasoning support",
aliases=["gpt5-mini", "gpt5mini", "mini"],
),
"o3": ModelCapabilities( "o3": ModelCapabilities(
provider=ProviderType.OPENAI, provider=ProviderType.OPENAI,
model_name="o3", model_name="o3",
@@ -55,9 +94,9 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity", description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
aliases=["o3mini", "o3-mini"], aliases=["o3mini", "o3-mini"],
), ),
"o3-pro-2025-06-10": ModelCapabilities( "o3-pro": ModelCapabilities(
provider=ProviderType.OPENAI, provider=ProviderType.OPENAI,
model_name="o3-pro-2025-06-10", model_name="o3-pro",
friendly_name="OpenAI (O3-Pro)", friendly_name="OpenAI (O3-Pro)",
context_window=200_000, # 200K tokens context_window=200_000, # 200K tokens
max_output_tokens=65536, # 64K max output tokens max_output_tokens=65536, # 64K max output tokens
@@ -89,11 +128,11 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
supports_temperature=False, # O4 models don't accept temperature parameter supports_temperature=False, # O4 models don't accept temperature parameter
temperature_constraint=create_temperature_constraint("fixed"), temperature_constraint=create_temperature_constraint("fixed"),
description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning", description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
aliases=["mini", "o4mini", "o4-mini"], aliases=["o4mini", "o4-mini"],
), ),
"gpt-4.1-2025-04-14": ModelCapabilities( "gpt-4.1": ModelCapabilities(
provider=ProviderType.OPENAI, provider=ProviderType.OPENAI,
model_name="gpt-4.1-2025-04-14", model_name="gpt-4.1",
friendly_name="OpenAI (GPT 4.1)", friendly_name="OpenAI (GPT 4.1)",
context_window=1_000_000, # 1M tokens context_window=1_000_000, # 1M tokens
max_output_tokens=32_768, max_output_tokens=32_768,
@@ -107,7 +146,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
supports_temperature=True, # Regular models accept temperature parameter supports_temperature=True, # Regular models accept temperature parameter
temperature_constraint=create_temperature_constraint("range"), temperature_constraint=create_temperature_constraint("range"),
description="GPT-4.1 (1M context) - Advanced reasoning model with large context window", description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
aliases=["gpt4.1"], aliases=["gpt4.1", "gpt-4.1"],
), ),
} }
@@ -119,21 +158,41 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
def get_capabilities(self, model_name: str) -> ModelCapabilities: def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a specific OpenAI model.""" """Get capabilities for a specific OpenAI model."""
# Resolve shorthand # First check if it's a key in SUPPORTED_MODELS
if model_name in self.SUPPORTED_MODELS:
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if not restriction_service.is_allowed(ProviderType.OPENAI, model_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return self.SUPPORTED_MODELS[model_name]
# Try resolving as alias
resolved_name = self._resolve_model_name(model_name) resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS: # Check if resolved name is a key
raise ValueError(f"Unsupported OpenAI model: {model_name}") if resolved_name in self.SUPPORTED_MODELS:
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
# Check if model is allowed by restrictions restriction_service = get_restriction_service()
from utils.model_restrictions import get_restriction_service if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return self.SUPPORTED_MODELS[resolved_name]
restriction_service = get_restriction_service() # Finally check if resolved name matches any API model name
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name): for key, capabilities in self.SUPPORTED_MODELS.items():
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.") if resolved_name == capabilities.model_name:
# Check if model is allowed by restrictions
from utils.model_restrictions import get_restriction_service
# Return the ModelCapabilities object directly from SUPPORTED_MODELS restriction_service = get_restriction_service()
return self.SUPPORTED_MODELS[resolved_name] if not restriction_service.is_allowed(ProviderType.OPENAI, key, model_name):
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
return capabilities
raise ValueError(f"Unsupported OpenAI model: {model_name}")
def get_provider_type(self) -> ProviderType: def get_provider_type(self) -> ProviderType:
"""Get the provider type.""" """Get the provider type."""
@@ -182,6 +241,47 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
def supports_thinking_mode(self, model_name: str) -> bool: def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode.""" """Check if the model supports extended thinking mode."""
# Currently no OpenAI models support extended thinking # GPT-5 models support reasoning tokens (extended thinking)
# This may change with future O3 models resolved_name = self._resolve_model_name(model_name)
if resolved_name in ["gpt-5", "gpt-5-mini"]:
return True
# O3 models don't support extended thinking yet
return False return False
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get OpenAI's preferred model for a given category from allowed models.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions
Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory
if not allowed_models:
return None
# Helper to find first available from preference list
def find_first(preferences: list[str]) -> Optional[str]:
"""Return first available model from preference list."""
for model in preferences:
if model in allowed_models:
return model
return None
if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer models with extended thinking support
preferred = find_first(["o3", "o3-pro", "gpt-5"])
return preferred if preferred else allowed_models[0]
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer fast, cost-efficient models
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
return preferred if preferred else allowed_models[0]
else: # BALANCED or default
# Prefer balanced performance/cost models
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
return preferred if preferred else allowed_models[0]

View File

@@ -15,6 +15,17 @@ class ModelProviderRegistry:
_instance = None _instance = None
# Provider priority order for model selection
# Native APIs first, then custom endpoints, then catch-all providers
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.XAI, # Direct X.AI GROK access
ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]
def __new__(cls): def __new__(cls):
"""Singleton pattern for registry.""" """Singleton pattern for registry."""
if cls._instance is None: if cls._instance is None:
@@ -103,30 +114,19 @@ class ModelProviderRegistry:
3. OPENROUTER - Catch-all for cloud models via unified API 3. OPENROUTER - Catch-all for cloud models via unified API
Args: Args:
model_name: Name of the model (e.g., "gemini-2.5-flash", "o3-mini") model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5")
Returns: Returns:
ModelProvider instance that supports this model ModelProvider instance that supports this model
""" """
logging.debug(f"get_provider_for_model called with model_name='{model_name}'") logging.debug(f"get_provider_for_model called with model_name='{model_name}'")
# Define explicit provider priority order
# Native APIs first, then custom endpoints, then catch-all providers
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.XAI, # Direct X.AI GROK access
ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]
# Check providers in priority order # Check providers in priority order
instance = cls() instance = cls()
logging.debug(f"Registry instance: {instance}") logging.debug(f"Registry instance: {instance}")
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}") logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
for provider_type in PROVIDER_PRIORITY_ORDER: for provider_type in cls.PROVIDER_PRIORITY_ORDER:
if provider_type in instance._providers: if provider_type in instance._providers:
logging.debug(f"Found {provider_type} in registry") logging.debug(f"Found {provider_type} in registry")
# Get or create provider instance # Get or create provider instance
@@ -244,14 +244,49 @@ class ModelProviderRegistry:
return os.getenv(env_var) return os.getenv(env_var)
@classmethod
def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]:
"""Get a list of allowed canonical model names for a given provider.
Args:
provider: The provider instance to get models for
provider_type: The provider type for restriction checking
Returns:
List of model names that are both supported and allowed
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
allowed_models = []
# Get the provider's supported models
try:
# Use list_models to get all supported models (handles both regular and custom providers)
supported_models = provider.list_models(respect_restrictions=False)
except (NotImplementedError, AttributeError):
# Fallback to SUPPORTED_MODELS if list_models not implemented
try:
supported_models = list(provider.SUPPORTED_MODELS.keys())
except AttributeError:
supported_models = []
# Filter by restrictions
for model_name in supported_models:
if restriction_service.is_allowed(provider_type, model_name):
allowed_models.append(model_name)
return allowed_models
@classmethod @classmethod
def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str: def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
"""Get the preferred fallback model based on available API keys and tool category. """Get the preferred fallback model based on provider priority and tool category.
This method checks which providers have valid API keys and returns This method orchestrates model selection by:
a sensible default model for auto mode fallback situations. 1. Getting allowed models for each provider (respecting restrictions)
2. Asking providers for their preference from the allowed list
Takes into account model restrictions when selecting fallback models. 3. Falling back to first available model if no preference given
Args: Args:
tool_category: Optional category to influence model selection tool_category: Optional category to influence model selection
@@ -259,167 +294,42 @@ class ModelProviderRegistry:
Returns: Returns:
Model name string for fallback use Model name string for fallback use
""" """
# Import here to avoid circular import
from tools.models import ToolModelCategory from tools.models import ToolModelCategory
# Get available models respecting restrictions effective_category = tool_category or ToolModelCategory.BALANCED
available_models = cls.get_available_models(respect_restrictions=True) first_available_model = None
# Group by provider # Ask each provider for their preference in priority order
openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI] for provider_type in cls.PROVIDER_PRIORITY_ORDER:
gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE] provider = cls.get_provider(provider_type)
xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI] if provider:
openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER] # 1. Registry filters the models first
custom_models = [m for m, p in available_models.items() if p == ProviderType.CUSTOM] allowed_models = cls._get_allowed_models_for_provider(provider, provider_type)
openai_available = bool(openai_models) if not allowed_models:
gemini_available = bool(gemini_models)
xai_available = bool(xai_models)
openrouter_available = bool(openrouter_models)
custom_available = bool(custom_models)
if tool_category == ToolModelCategory.EXTENDED_REASONING:
# Prefer thinking-capable models for deep reasoning tools
if openai_available and "o3" in openai_models:
return "o3" # O3 for deep reasoning
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
elif xai_available and "grok-3" in xai_models:
return "grok-3" # GROK-3 for deep reasoning
elif xai_available and xai_models:
# Fall back to any available XAI model
return xai_models[0]
elif gemini_available and any("pro" in m for m in gemini_models):
# Find the pro model (handles full names)
return next(m for m in gemini_models if "pro" in m)
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
elif openrouter_available:
# Try to find thinking-capable model from openrouter
thinking_model = cls._find_extended_thinking_model()
if thinking_model:
return thinking_model
# Fallback to first available OpenRouter model
return openrouter_models[0]
elif custom_available:
# Fallback to custom models when available
return custom_models[0]
else:
# Fallback to pro if nothing found
return "gemini-2.5-pro"
elif tool_category == ToolModelCategory.FAST_RESPONSE:
# Prefer fast, cost-efficient models
if openai_available and "o4-mini" in openai_models:
return "o4-mini" # Latest, fast and efficient
elif openai_available and "o3-mini" in openai_models:
return "o3-mini" # Second choice
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
elif xai_available and "grok-3-fast" in xai_models:
return "grok-3-fast" # GROK-3 Fast for speed
elif xai_available and xai_models:
# Fall back to any available XAI model
return xai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
# Find the flash model (handles full names)
# Prefer 2.5 over 2.0 for backward compatibility
flash_models = [m for m in gemini_models if "flash" in m]
# Sort to ensure 2.5 comes before 2.0
flash_models_sorted = sorted(flash_models, reverse=True)
return flash_models_sorted[0]
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
elif openrouter_available:
# Fallback to first available OpenRouter model
return openrouter_models[0]
elif custom_available:
# Fallback to custom models when available
return custom_models[0]
else:
# Default to flash
return "gemini-2.5-flash"
# BALANCED or no category specified - use existing balanced logic
if openai_available and "o4-mini" in openai_models:
return "o4-mini" # Latest balanced performance/cost
elif openai_available and "o3-mini" in openai_models:
return "o3-mini" # Second choice
elif openai_available and openai_models:
return openai_models[0]
elif xai_available and "grok-3" in xai_models:
return "grok-3" # GROK-3 as balanced choice
elif xai_available and xai_models:
return xai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
# Prefer 2.5 over 2.0 for backward compatibility
flash_models = [m for m in gemini_models if "flash" in m]
flash_models_sorted = sorted(flash_models, reverse=True)
return flash_models_sorted[0]
elif gemini_available and gemini_models:
return gemini_models[0]
elif openrouter_available:
return openrouter_models[0]
elif custom_available:
# Fallback to custom models when available
return custom_models[0]
else:
# No models available due to restrictions - check if any providers exist
if not available_models:
# This might happen if all models are restricted
logging.warning("No models available due to restrictions")
# Return a reasonable default for backward compatibility
return "gemini-2.5-flash"
@classmethod
def _find_extended_thinking_model(cls) -> Optional[str]:
"""Find a model suitable for extended reasoning from custom/openrouter providers.
Returns:
Model name if found, None otherwise
"""
# Check custom provider first
custom_provider = cls.get_provider(ProviderType.CUSTOM)
if custom_provider:
# Check if it's a CustomModelProvider and has thinking models
try:
from providers.custom import CustomProvider
if isinstance(custom_provider, CustomProvider) and hasattr(custom_provider, "model_registry"):
for model_name, config in custom_provider.model_registry.items():
if config.get("supports_extended_thinking", False):
return model_name
except ImportError:
pass
# Then check OpenRouter for high-context/powerful models
openrouter_provider = cls.get_provider(ProviderType.OPENROUTER)
if openrouter_provider:
# Prefer models known for deep reasoning
preferred_models = [
"anthropic/claude-sonnet-4",
"anthropic/claude-opus-4",
"google/gemini-2.5-pro",
"google/gemini-pro-1.5",
"meta-llama/llama-3.1-70b-instruct",
"mistralai/mixtral-8x7b-instruct",
]
for model in preferred_models:
try:
if openrouter_provider.validate_model_name(model):
return model
except Exception as e:
# Log the error for debugging purposes but continue searching
import logging
logging.warning(f"Model validation for '{model}' on OpenRouter failed: {e}")
continue continue
return None # 2. Keep track of the first available model as fallback
if not first_available_model:
first_available_model = sorted(allowed_models)[0]
# 3. Ask provider to pick from allowed list
preferred_model = provider.get_preferred_model(effective_category, allowed_models)
if preferred_model:
logging.debug(
f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'"
)
return preferred_model
# If no provider returned a preference, use first available model
if first_available_model:
logging.debug(f"No provider preference, using first available: {first_available_model}")
return first_available_model
# Ultimate fallback if no providers have models
logging.warning("No models available from any provider, using default fallback")
return "gemini-2.5-flash"
@classmethod @classmethod
def get_available_providers_with_keys(cls) -> list[ProviderType]: def get_available_providers_with_keys(cls) -> list[ProviderType]:

View File

@@ -1,7 +1,10 @@
"""X.AI (GROK) model provider implementation.""" """X.AI (GROK) model provider implementation."""
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .base import ( from .base import (
ModelCapabilities, ModelCapabilities,
@@ -133,3 +136,41 @@ class XAIModelProvider(OpenAICompatibleProvider):
# Currently GROK models do not support extended thinking # Currently GROK models do not support extended thinking
# This may change with future GROK model releases # This may change with future GROK model releases
return False return False
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get XAI's preferred model for a given category from allowed models.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions
Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory
if not allowed_models:
return None
if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer GROK-3 for reasoning
if "grok-3" in allowed_models:
return "grok-3"
# Fall back to any available model
return allowed_models[0]
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer GROK-3-Fast for speed
if "grok-3-fast" in allowed_models:
return "grok-3-fast"
# Fall back to any available model
return allowed_models[0]
else: # BALANCED or default
# Prefer standard GROK-3 for balanced use
if "grok-3" in allowed_models:
return "grok-3"
elif "grok-3-fast" in allowed_models:
return "grok-3-fast"
# Fall back to any available model
return allowed_models[0]

View File

@@ -409,9 +409,9 @@ def configure_providers():
openai_key = os.getenv("OPENAI_API_KEY") openai_key = os.getenv("OPENAI_API_KEY")
logger.debug(f"OpenAI key check: key={'[PRESENT]' if openai_key else '[MISSING]'}") logger.debug(f"OpenAI key check: key={'[PRESENT]' if openai_key else '[MISSING]'}")
if openai_key and openai_key != "your_openai_api_key_here": if openai_key and openai_key != "your_openai_api_key_here":
valid_providers.append("OpenAI (o3)") valid_providers.append("OpenAI")
has_native_apis = True has_native_apis = True
logger.info("OpenAI API key found - o3 model available") logger.info("OpenAI API key found")
else: else:
if not openai_key: if not openai_key:
logger.debug("OpenAI API key not found in environment") logger.debug("OpenAI API key not found in environment")
@@ -493,7 +493,7 @@ def configure_providers():
raise ValueError( raise ValueError(
"At least one API configuration is required. Please set either:\n" "At least one API configuration is required. Please set either:\n"
"- GEMINI_API_KEY for Gemini models\n" "- GEMINI_API_KEY for Gemini models\n"
"- OPENAI_API_KEY for OpenAI o3 model\n" "- OPENAI_API_KEY for OpenAI models\n"
"- XAI_API_KEY for X.AI GROK models\n" "- XAI_API_KEY for X.AI GROK models\n"
"- DIAL_API_KEY for DIAL models\n" "- DIAL_API_KEY for DIAL models\n"
"- OPENROUTER_API_KEY for OpenRouter (multiple models)\n" "- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
@@ -742,7 +742,9 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextCon
# Parse model:option format if present # Parse model:option format if present
model_name, model_option = parse_model_option(model_name) model_name, model_option = parse_model_option(model_name)
if model_option: if model_option:
logger.debug(f"Parsed model format - model: '{model_name}', option: '{model_option}'") logger.info(f"Parsed model format - model: '{model_name}', option: '{model_option}'")
else:
logger.info(f"Parsed model format - model: '{model_name}'")
# Consensus tool handles its own model configuration validation # Consensus tool handles its own model configuration validation
# No special handling needed at server level # No special handling needed at server level
@@ -1190,16 +1192,16 @@ async def handle_get_prompt(name: str, arguments: dict[str, Any] = None) -> GetP
""" """
Get prompt details and generate the actual prompt text. Get prompt details and generate the actual prompt text.
This handler is called when a user invokes a prompt (e.g., /zen:thinkdeeper or /zen:chat:o3). This handler is called when a user invokes a prompt (e.g., /zen:thinkdeeper or /zen:chat:gpt5).
It generates the appropriate text that Claude will then use to call the It generates the appropriate text that Claude will then use to call the
underlying tool. underlying tool.
Supports structured prompt names like "chat:o3" where: Supports structured prompt names like "chat:gpt5" where:
- "chat" is the tool name - "chat" is the tool name
- "o3" is the model to use - "gpt5" is the model to use
Args: Args:
name: The name of the prompt to execute (can include model like "chat:o3") name: The name of the prompt to execute (can include model like "chat:gpt5")
arguments: Optional arguments for the prompt (e.g., model, thinking_mode) arguments: Optional arguments for the prompt (e.g., model, thinking_mode)
Returns: Returns:

View File

@@ -48,7 +48,8 @@ class TestAliasTargetRestrictions:
"""Test that restriction policy allows alias when target model is allowed. """Test that restriction policy allows alias when target model is allowed.
This is the correct user-friendly behavior - if you allow 'o4-mini', This is the correct user-friendly behavior - if you allow 'o4-mini',
you should be able to use its alias 'mini' as well. you should be able to use its aliases 'o4mini' and 'o4-mini'.
Note: 'mini' is now an alias for 'gpt-5-mini', not 'o4-mini'.
""" """
# Clear cached restriction service # Clear cached restriction service
import utils.model_restrictions import utils.model_restrictions
@@ -57,15 +58,16 @@ class TestAliasTargetRestrictions:
provider = OpenAIModelProvider(api_key="test-key") provider = OpenAIModelProvider(api_key="test-key")
# Both target and alias should be allowed # Both target and its actual aliases should be allowed
assert provider.validate_model_name("o4-mini") assert provider.validate_model_name("o4-mini")
assert provider.validate_model_name("mini") assert provider.validate_model_name("o4mini")
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
def test_restriction_policy_allows_only_alias_when_alias_specified(self): def test_restriction_policy_allows_only_alias_when_alias_specified(self):
"""Test that restriction policy allows only the alias when just alias is specified. """Test that restriction policy allows only the alias when just alias is specified.
If you restrict to 'mini', only the alias should work, not the direct target. If you restrict to 'mini' (which is an alias for gpt-5-mini),
only the alias should work, not other models.
This is the correct restrictive behavior. This is the correct restrictive behavior.
""" """
# Clear cached restriction service # Clear cached restriction service
@@ -77,7 +79,9 @@ class TestAliasTargetRestrictions:
# Only the alias should be allowed # Only the alias should be allowed
assert provider.validate_model_name("mini") assert provider.validate_model_name("mini")
# Direct target should NOT be allowed # Direct target for this alias should NOT be allowed (mini -> gpt-5-mini)
assert not provider.validate_model_name("gpt-5-mini")
# Other models should NOT be allowed
assert not provider.validate_model_name("o4-mini") assert not provider.validate_model_name("o4-mini")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target
@@ -127,12 +131,15 @@ class TestAliasTargetRestrictions:
# The warning should include both aliases and targets in known models # The warning should include both aliases and targets in known models
warning_message = str(warning_calls[0]) warning_message = str(warning_calls[0])
assert "mini" in warning_message # alias should be in known models assert "o4mini" in warning_message or "o4-mini" in warning_message # aliases should be in known models
assert "o4-mini" in warning_message # target should be in known models
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o4-mini"}) # Allow both alias and target @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,gpt-5-mini,o4-mini,o4mini"}) # Allow different models
def test_both_alias_and_target_allowed_when_both_specified(self): def test_both_alias_and_target_allowed_when_both_specified(self):
"""Test that both alias and target work when both are explicitly allowed.""" """Test that both alias and target work when both are explicitly allowed.
mini -> gpt-5-mini
o4mini -> o4-mini
"""
# Clear cached restriction service # Clear cached restriction service
import utils.model_restrictions import utils.model_restrictions
@@ -140,9 +147,11 @@ class TestAliasTargetRestrictions:
provider = OpenAIModelProvider(api_key="test-key") provider = OpenAIModelProvider(api_key="test-key")
# Both should be allowed # All should be allowed since we explicitly allowed them
assert provider.validate_model_name("mini") assert provider.validate_model_name("mini") # alias for gpt-5-mini
assert provider.validate_model_name("o4-mini") assert provider.validate_model_name("gpt-5-mini") # target
assert provider.validate_model_name("o4-mini") # target
assert provider.validate_model_name("o4mini") # alias for o4-mini
def test_alias_target_policy_regression_prevention(self): def test_alias_target_policy_regression_prevention(self):
"""Regression test to ensure aliases and targets are both validated properly. """Regression test to ensure aliases and targets are both validated properly.

View File

@@ -95,8 +95,8 @@ class TestAutoModeComprehensive:
}, },
{ {
"EXTENDED_REASONING": "o3", # O3 for deep reasoning "EXTENDED_REASONING": "o3", # O3 for deep reasoning
"FAST_RESPONSE": "o4-mini", # O4-mini for speed "FAST_RESPONSE": "gpt-5", # Prefer gpt-5 for speed
"BALANCED": "o4-mini", # O4-mini as balanced "BALANCED": "gpt-5", # Prefer gpt-5 for balanced
}, },
), ),
# Only X.AI API available # Only X.AI API available
@@ -113,7 +113,7 @@ class TestAutoModeComprehensive:
"BALANCED": "grok-3", # GROK-3 as balanced "BALANCED": "grok-3", # GROK-3 as balanced
}, },
), ),
# Both Gemini and OpenAI available - should prefer based on tool category # Both Gemini and OpenAI available - Google comes first in priority
( (
{ {
"GEMINI_API_KEY": "real-key", "GEMINI_API_KEY": "real-key",
@@ -122,12 +122,12 @@ class TestAutoModeComprehensive:
"OPENROUTER_API_KEY": None, "OPENROUTER_API_KEY": None,
}, },
{ {
"EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning "EXTENDED_REASONING": "gemini-2.5-pro", # Gemini comes first in priority
"FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed "FAST_RESPONSE": "gemini-2.5-flash", # Prefer flash for speed
"BALANCED": "o4-mini", # Prefer OpenAI for balanced "BALANCED": "gemini-2.5-flash", # Prefer flash for balanced
}, },
), ),
# All native APIs available - should prefer based on tool category # All native APIs available - Google still comes first
( (
{ {
"GEMINI_API_KEY": "real-key", "GEMINI_API_KEY": "real-key",
@@ -136,9 +136,9 @@ class TestAutoModeComprehensive:
"OPENROUTER_API_KEY": None, "OPENROUTER_API_KEY": None,
}, },
{ {
"EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning "EXTENDED_REASONING": "gemini-2.5-pro", # Gemini comes first in priority
"FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed "FAST_RESPONSE": "gemini-2.5-flash", # Prefer flash for speed
"BALANCED": "o4-mini", # Prefer OpenAI for balanced "BALANCED": "gemini-2.5-flash", # Prefer flash for balanced
}, },
), ),
], ],

View File

@@ -97,10 +97,10 @@ class TestAutoModeProviderSelection:
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE) fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED) balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
# Should select appropriate OpenAI models # Should select appropriate OpenAI models based on new preference order
assert extended_reasoning in ["o3", "o3-mini", "o4-mini"] # Any available OpenAI model for reasoning assert extended_reasoning == "o3" # O3 for extended reasoning
assert fast_response in ["o4-mini", "o3-mini"] # Prefer faster models assert fast_response == "gpt-5" # gpt-5 comes first in fast response preference
assert balanced in ["o4-mini", "o3-mini"] # Balanced selection assert balanced == "gpt-5" # gpt-5 for balanced
finally: finally:
# Restore original environment # Restore original environment
@@ -138,11 +138,11 @@ class TestAutoModeProviderSelection:
) )
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE) fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
# Should prefer OpenAI for reasoning (based on fallback logic) # Should prefer Gemini now (based on new provider priority: Gemini before OpenAI)
assert extended_reasoning == "o3" # Should prefer O3 for extended reasoning assert extended_reasoning == "gemini-2.5-pro" # Gemini has higher priority now
# Should prefer OpenAI for fast response # Should prefer Gemini for fast response
assert fast_response == "o4-mini" # Should prefer O4-mini for fast response assert fast_response == "gemini-2.5-flash" # Gemini has higher priority now
finally: finally:
# Restore original environment # Restore original environment
@@ -318,7 +318,7 @@ class TestAutoModeProviderSelection:
test_cases = [ test_cases = [
("flash", ProviderType.GOOGLE, "gemini-2.5-flash"), ("flash", ProviderType.GOOGLE, "gemini-2.5-flash"),
("pro", ProviderType.GOOGLE, "gemini-2.5-pro"), ("pro", ProviderType.GOOGLE, "gemini-2.5-pro"),
("mini", ProviderType.OPENAI, "o4-mini"), ("mini", ProviderType.OPENAI, "gpt-5-mini"), # "mini" now resolves to gpt-5-mini
("o3mini", ProviderType.OPENAI, "o3-mini"), ("o3mini", ProviderType.OPENAI, "o3-mini"),
("grok", ProviderType.XAI, "grok-3"), ("grok", ProviderType.XAI, "grok-3"),
("grokfast", ProviderType.XAI, "grok-3-fast"), ("grokfast", ProviderType.XAI, "grok-3-fast"),

View File

@@ -132,8 +132,11 @@ class TestBuggyBehaviorPrevention:
assert not provider.validate_model_name("o3-pro") # Not in allowed list assert not provider.validate_model_name("o3-pro") # Not in allowed list
assert not provider.validate_model_name("o3") # Not in allowed list assert not provider.validate_model_name("o3") # Not in allowed list
# This should be ALLOWED because it resolves to o4-mini which is in the allowed list # "mini" now resolves to gpt-5-mini, not o4-mini, so it should be blocked
assert provider.validate_model_name("mini") # Resolves to o4-mini, which IS allowed assert not provider.validate_model_name("mini") # Resolves to gpt-5-mini, which is NOT allowed
# But o4mini (the actual alias for o4-mini) should work
assert provider.validate_model_name("o4mini") # Resolves to o4-mini, which IS allowed
# Verify our list_all_known_models includes the restricted models # Verify our list_all_known_models includes the restricted models
all_known = provider.list_all_known_models() all_known = provider.list_all_known_models()

View File

@@ -113,7 +113,7 @@ class TestDIALProvider:
# Test temperature constraint # Test temperature constraint
assert capabilities.temperature_constraint.min_temp == 0.0 assert capabilities.temperature_constraint.min_temp == 0.0
assert capabilities.temperature_constraint.max_temp == 2.0 assert capabilities.temperature_constraint.max_temp == 2.0
assert capabilities.temperature_constraint.default_temp == 0.7 assert capabilities.temperature_constraint.default_temp == 0.3
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False) @patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
@patch("utils.model_restrictions._restriction_service", None) @patch("utils.model_restrictions._restriction_service", None)

View File

@@ -37,14 +37,14 @@ class TestIntelligentFallback:
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False) @patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
def test_prefers_openai_o3_mini_when_available(self): def test_prefers_openai_o3_mini_when_available(self):
"""Test that o4-mini is preferred when OpenAI API key is available""" """Test that gpt-5 is preferred when OpenAI API key is available (based on new preference order)"""
# Register only OpenAI provider for this test # Register only OpenAI provider for this test
from providers.openai_provider import OpenAIModelProvider from providers.openai_provider import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider) ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
fallback_model = ModelProviderRegistry.get_preferred_fallback_model() fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o4-mini" assert fallback_model == "gpt-5" # Based on new preference order: gpt-5 before o4-mini
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False) @patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_gemini_flash_when_openai_unavailable(self): def test_prefers_gemini_flash_when_openai_unavailable(self):
@@ -68,7 +68,7 @@ class TestIntelligentFallback:
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider) ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
fallback_model = ModelProviderRegistry.get_preferred_fallback_model() fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o4-mini" # OpenAI has priority assert fallback_model == "gemini-2.5-flash" # Gemini has priority now (based on new PROVIDER_PRIORITY_ORDER)
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False) @patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False)
def test_fallback_when_no_keys_available(self): def test_fallback_when_no_keys_available(self):
@@ -147,8 +147,8 @@ class TestIntelligentFallback:
history, tokens = build_conversation_history(context, model_context=None) history, tokens = build_conversation_history(context, model_context=None)
# Verify that ModelContext was called with o4-mini (the intelligent fallback) # Verify that ModelContext was called with gpt-5 (the intelligent fallback based on new preference order)
mock_context_class.assert_called_once_with("o4-mini") mock_context_class.assert_called_once_with("gpt-5")
def test_auto_mode_with_gemini_only(self): def test_auto_mode_with_gemini_only(self):
"""Test auto mode behavior when only Gemini API key is available""" """Test auto mode behavior when only Gemini API key is available"""

View File

@@ -635,6 +635,13 @@ class TestAutoModeWithRestrictions:
mock_openai.list_models = openai_list_models mock_openai.list_models = openai_list_models
mock_openai.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"] mock_openai.list_all_known_models.return_value = ["o3", "o3-mini", "o4-mini"]
# Add get_preferred_model method to mock to match new implementation
def get_preferred_model(category, allowed_models):
# Simple preference logic for testing - just return first allowed model
return allowed_models[0] if allowed_models else None
mock_openai.get_preferred_model = get_preferred_model
def get_provider_side_effect(provider_type): def get_provider_side_effect(provider_type):
if provider_type == ProviderType.OPENAI: if provider_type == ProviderType.OPENAI:
return mock_openai return mock_openai
@@ -685,8 +692,9 @@ class TestAutoModeWithRestrictions:
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE) model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
# The fallback will depend on how get_available_models handles aliases # The fallback will depend on how get_available_models handles aliases
# For now, we accept either behavior and document it # When "mini" is allowed, it's returned as the allowed model
assert model in ["o4-mini", "gemini-2.5-flash"] # "mini" is now an alias for gpt-5-mini, but the list shows "mini" itself
assert model in ["mini", "gpt-5-mini", "o4-mini", "gemini-2.5-flash"]
finally: finally:
# Restore original registry state # Restore original registry state
registry = ModelProviderRegistry() registry = ModelProviderRegistry()

View File

@@ -230,7 +230,7 @@ class TestO3TemperatureParameterFixSimple:
assert temp_constraint.validate(0.5) is False assert temp_constraint.validate(0.5) is False
# Test regular model constraints - use gpt-4.1 which is supported # Test regular model constraints - use gpt-4.1 which is supported
gpt41_capabilities = provider.get_capabilities("gpt-4.1-2025-04-14") gpt41_capabilities = provider.get_capabilities("gpt-4.1")
assert gpt41_capabilities.temperature_constraint is not None assert gpt41_capabilities.temperature_constraint is not None
# Regular models should allow a range # Regular models should allow a range

View File

@@ -48,12 +48,17 @@ class TestOpenAIProvider:
assert provider.validate_model_name("o3-pro") is True assert provider.validate_model_name("o3-pro") is True
assert provider.validate_model_name("o4-mini") is True assert provider.validate_model_name("o4-mini") is True
assert provider.validate_model_name("o4-mini") is True assert provider.validate_model_name("o4-mini") is True
assert provider.validate_model_name("gpt-5") is True
assert provider.validate_model_name("gpt-5-mini") is True
# Test valid aliases # Test valid aliases
assert provider.validate_model_name("mini") is True assert provider.validate_model_name("mini") is True
assert provider.validate_model_name("o3mini") is True assert provider.validate_model_name("o3mini") is True
assert provider.validate_model_name("o4mini") is True assert provider.validate_model_name("o4mini") is True
assert provider.validate_model_name("o4mini") is True assert provider.validate_model_name("o4mini") is True
assert provider.validate_model_name("gpt5") is True
assert provider.validate_model_name("gpt5-mini") is True
assert provider.validate_model_name("gpt5mini") is True
# Test invalid model # Test invalid model
assert provider.validate_model_name("invalid-model") is False assert provider.validate_model_name("invalid-model") is False
@@ -65,17 +70,22 @@ class TestOpenAIProvider:
provider = OpenAIModelProvider("test-key") provider = OpenAIModelProvider("test-key")
# Test shorthand resolution # Test shorthand resolution
assert provider._resolve_model_name("mini") == "o4-mini" assert provider._resolve_model_name("mini") == "gpt-5-mini" # "mini" now resolves to gpt-5-mini
assert provider._resolve_model_name("o3mini") == "o3-mini" assert provider._resolve_model_name("o3mini") == "o3-mini"
assert provider._resolve_model_name("o4mini") == "o4-mini" assert provider._resolve_model_name("o4mini") == "o4-mini"
assert provider._resolve_model_name("o4mini") == "o4-mini" assert provider._resolve_model_name("o4mini") == "o4-mini"
assert provider._resolve_model_name("gpt5") == "gpt-5"
assert provider._resolve_model_name("gpt5-mini") == "gpt-5-mini"
assert provider._resolve_model_name("gpt5mini") == "gpt-5-mini"
# Test full name passthrough # Test full name passthrough
assert provider._resolve_model_name("o3") == "o3" assert provider._resolve_model_name("o3") == "o3"
assert provider._resolve_model_name("o3-mini") == "o3-mini" assert provider._resolve_model_name("o3-mini") == "o3-mini"
assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10" assert provider._resolve_model_name("o3-pro") == "o3-pro"
assert provider._resolve_model_name("o4-mini") == "o4-mini" assert provider._resolve_model_name("o4-mini") == "o4-mini"
assert provider._resolve_model_name("o4-mini") == "o4-mini" assert provider._resolve_model_name("o4-mini") == "o4-mini"
assert provider._resolve_model_name("gpt-5") == "gpt-5"
assert provider._resolve_model_name("gpt-5-mini") == "gpt-5-mini"
def test_get_capabilities_o3(self): def test_get_capabilities_o3(self):
"""Test getting model capabilities for O3.""" """Test getting model capabilities for O3."""
@@ -99,11 +109,43 @@ class TestOpenAIProvider:
provider = OpenAIModelProvider("test-key") provider = OpenAIModelProvider("test-key")
capabilities = provider.get_capabilities("mini") capabilities = provider.get_capabilities("mini")
assert capabilities.model_name == "o4-mini" # Capabilities should show resolved model name assert capabilities.model_name == "gpt-5-mini" # "mini" now resolves to gpt-5-mini
assert capabilities.friendly_name == "OpenAI (O4-mini)" assert capabilities.friendly_name == "OpenAI (GPT-5-mini)"
assert capabilities.context_window == 200_000 assert capabilities.context_window == 400_000
assert capabilities.provider == ProviderType.OPENAI assert capabilities.provider == ProviderType.OPENAI
def test_get_capabilities_gpt5(self):
"""Test getting model capabilities for GPT-5."""
provider = OpenAIModelProvider("test-key")
capabilities = provider.get_capabilities("gpt-5")
assert capabilities.model_name == "gpt-5"
assert capabilities.friendly_name == "OpenAI (GPT-5)"
assert capabilities.context_window == 400_000
assert capabilities.max_output_tokens == 128_000
assert capabilities.provider == ProviderType.OPENAI
assert capabilities.supports_extended_thinking is True
assert capabilities.supports_system_prompts is True
assert capabilities.supports_streaming is True
assert capabilities.supports_function_calling is True
assert capabilities.supports_temperature is True
def test_get_capabilities_gpt5_mini(self):
"""Test getting model capabilities for GPT-5-mini."""
provider = OpenAIModelProvider("test-key")
capabilities = provider.get_capabilities("gpt-5-mini")
assert capabilities.model_name == "gpt-5-mini"
assert capabilities.friendly_name == "OpenAI (GPT-5-mini)"
assert capabilities.context_window == 400_000
assert capabilities.max_output_tokens == 128_000
assert capabilities.provider == ProviderType.OPENAI
assert capabilities.supports_extended_thinking is True
assert capabilities.supports_system_prompts is True
assert capabilities.supports_streaming is True
assert capabilities.supports_function_calling is True
assert capabilities.supports_temperature is True
@patch("providers.openai_compatible.OpenAI") @patch("providers.openai_compatible.OpenAI")
def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class): def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
"""Test that generate_content resolves aliases before making API calls. """Test that generate_content resolves aliases before making API calls.
@@ -132,21 +174,19 @@ class TestOpenAIProvider:
provider = OpenAIModelProvider("test-key") provider = OpenAIModelProvider("test-key")
# Call generate_content with alias 'gpt4.1' (resolves to gpt-4.1-2025-04-14, supports temperature) # Call generate_content with alias 'gpt4.1' (resolves to gpt-4.1, supports temperature)
result = provider.generate_content( result = provider.generate_content(
prompt="Test prompt", prompt="Test prompt",
model_name="gpt4.1", model_name="gpt4.1",
temperature=1.0, # This should be resolved to "gpt-4.1-2025-04-14" temperature=1.0, # This should be resolved to "gpt-4.1"
) )
# Verify the API was called with the RESOLVED model name # Verify the API was called with the RESOLVED model name
mock_client.chat.completions.create.assert_called_once() mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args[1] call_kwargs = mock_client.chat.completions.create.call_args[1]
# CRITICAL ASSERTION: The API should receive "gpt-4.1-2025-04-14", not "gpt4.1" # CRITICAL ASSERTION: The API should receive "gpt-4.1", not "gpt4.1"
assert ( assert call_kwargs["model"] == "gpt-4.1", f"Expected 'gpt-4.1' but API received '{call_kwargs['model']}'"
call_kwargs["model"] == "gpt-4.1-2025-04-14"
), f"Expected 'gpt-4.1-2025-04-14' but API received '{call_kwargs['model']}'"
# Verify other parameters (gpt-4.1 supports temperature unlike O3/O4 models) # Verify other parameters (gpt-4.1 supports temperature unlike O3/O4 models)
assert call_kwargs["temperature"] == 1.0 assert call_kwargs["temperature"] == 1.0
@@ -156,7 +196,7 @@ class TestOpenAIProvider:
# Verify response # Verify response
assert result.content == "Test response" assert result.content == "Test response"
assert result.model_name == "gpt-4.1-2025-04-14" # Should be the resolved name assert result.model_name == "gpt-4.1" # Should be the resolved name
@patch("providers.openai_compatible.OpenAI") @patch("providers.openai_compatible.OpenAI")
def test_generate_content_other_aliases(self, mock_openai_class): def test_generate_content_other_aliases(self, mock_openai_class):
@@ -213,14 +253,22 @@ class TestOpenAIProvider:
assert call_kwargs["model"] == "o3-mini" # Should be unchanged assert call_kwargs["model"] == "o3-mini" # Should be unchanged
def test_supports_thinking_mode(self): def test_supports_thinking_mode(self):
"""Test thinking mode support (currently False for all OpenAI models).""" """Test thinking mode support."""
provider = OpenAIModelProvider("test-key") provider = OpenAIModelProvider("test-key")
# All OpenAI models currently don't support thinking mode # GPT-5 models support thinking mode (reasoning tokens)
assert provider.supports_thinking_mode("gpt-5") is True
assert provider.supports_thinking_mode("gpt-5-mini") is True
assert provider.supports_thinking_mode("gpt5") is True # Test with alias
assert provider.supports_thinking_mode("gpt5mini") is True # Test with alias
# O3/O4 models don't support thinking mode
assert provider.supports_thinking_mode("o3") is False assert provider.supports_thinking_mode("o3") is False
assert provider.supports_thinking_mode("o3-mini") is False assert provider.supports_thinking_mode("o3-mini") is False
assert provider.supports_thinking_mode("o4-mini") is False assert provider.supports_thinking_mode("o4-mini") is False
assert provider.supports_thinking_mode("mini") is False # Test with alias too assert (
provider.supports_thinking_mode("mini") is True
) # "mini" now resolves to gpt-5-mini which supports thinking
@patch("providers.openai_compatible.OpenAI") @patch("providers.openai_compatible.OpenAI")
def test_o3_pro_routes_to_responses_endpoint(self, mock_openai_class): def test_o3_pro_routes_to_responses_endpoint(self, mock_openai_class):
@@ -234,7 +282,7 @@ class TestOpenAIProvider:
mock_response.output.content = [MagicMock()] mock_response.output.content = [MagicMock()]
mock_response.output.content[0].type = "output_text" mock_response.output.content[0].type = "output_text"
mock_response.output.content[0].text = "4" mock_response.output.content[0].text = "4"
mock_response.model = "o3-pro-2025-06-10" mock_response.model = "o3-pro"
mock_response.id = "test-id" mock_response.id = "test-id"
mock_response.created_at = 1234567890 mock_response.created_at = 1234567890
mock_response.usage = MagicMock() mock_response.usage = MagicMock()
@@ -252,13 +300,13 @@ class TestOpenAIProvider:
# Verify responses.create was called # Verify responses.create was called
mock_client.responses.create.assert_called_once() mock_client.responses.create.assert_called_once()
call_args = mock_client.responses.create.call_args[1] call_args = mock_client.responses.create.call_args[1]
assert call_args["model"] == "o3-pro-2025-06-10" assert call_args["model"] == "o3-pro"
assert call_args["input"][0]["role"] == "user" assert call_args["input"][0]["role"] == "user"
assert "What is 2 + 2?" in call_args["input"][0]["content"][0]["text"] assert "What is 2 + 2?" in call_args["input"][0]["content"][0]["text"]
# Verify the response # Verify the response
assert result.content == "4" assert result.content == "4"
assert result.model_name == "o3-pro-2025-06-10" assert result.model_name == "o3-pro"
assert result.metadata["endpoint"] == "responses" assert result.metadata["endpoint"] == "responses"
@patch("providers.openai_compatible.OpenAI") @patch("providers.openai_compatible.OpenAI")

View File

@@ -3,6 +3,7 @@ Test per-tool model default selection functionality
""" """
import json import json
import os
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@@ -73,154 +74,194 @@ class TestToolModelCategories:
class TestModelSelection: class TestModelSelection:
"""Test model selection based on tool categories.""" """Test model selection based on tool categories."""
def teardown_method(self):
"""Clean up after each test to prevent state pollution."""
ModelProviderRegistry.clear_cache()
# Unregister all providers
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
def test_extended_reasoning_with_openai(self): def test_extended_reasoning_with_openai(self):
"""Test EXTENDED_REASONING prefers o3 when OpenAI is available.""" """Test EXTENDED_REASONING with OpenAI provider."""
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available: # Setup with only OpenAI provider
# Mock OpenAI models available ModelProviderRegistry.clear_cache()
mock_get_available.return_value = { # First unregister all providers to ensure isolation
"o3": ProviderType.OPENAI, for provider_type in list(ProviderType):
"o3-mini": ProviderType.OPENAI, ModelProviderRegistry.unregister_provider(provider_type)
"o4-mini": ProviderType.OPENAI,
} with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
from providers.openai_provider import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING) model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# OpenAI prefers o3 for extended reasoning
assert model == "o3" assert model == "o3"
def test_extended_reasoning_with_gemini_only(self): def test_extended_reasoning_with_gemini_only(self):
"""Test EXTENDED_REASONING prefers pro when only Gemini is available.""" """Test EXTENDED_REASONING prefers pro when only Gemini is available."""
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available: # Clear cache and unregister all providers first
# Mock only Gemini models available ModelProviderRegistry.clear_cache()
mock_get_available.return_value = { for provider_type in list(ProviderType):
"gemini-2.5-pro": ProviderType.GOOGLE, ModelProviderRegistry.unregister_provider(provider_type)
"gemini-2.5-flash": ProviderType.GOOGLE,
} # Register only Gemini provider
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}, clear=False):
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING) model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# Should find the pro model for extended reasoning # Gemini should return one of its models for extended reasoning
assert "pro" in model or model == "gemini-2.5-pro" # The default behavior may return flash when pro is not explicitly preferred
assert model in ["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash"]
def test_fast_response_with_openai(self): def test_fast_response_with_openai(self):
"""Test FAST_RESPONSE prefers o4-mini when OpenAI is available.""" """Test FAST_RESPONSE with OpenAI provider."""
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available: # Setup with only OpenAI provider
# Mock OpenAI models available ModelProviderRegistry.clear_cache()
mock_get_available.return_value = { # First unregister all providers to ensure isolation
"o3": ProviderType.OPENAI, for provider_type in list(ProviderType):
"o3-mini": ProviderType.OPENAI, ModelProviderRegistry.unregister_provider(provider_type)
"o4-mini": ProviderType.OPENAI,
} with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
from providers.openai_provider import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE) model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
assert model == "o4-mini" # OpenAI now prefers gpt-5 for fast response (based on our new preference order)
assert model == "gpt-5"
def test_fast_response_with_gemini_only(self): def test_fast_response_with_gemini_only(self):
"""Test FAST_RESPONSE prefers flash when only Gemini is available.""" """Test FAST_RESPONSE prefers flash when only Gemini is available."""
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available: # Clear cache and unregister all providers first
# Mock only Gemini models available ModelProviderRegistry.clear_cache()
mock_get_available.return_value = { for provider_type in list(ProviderType):
"gemini-2.5-pro": ProviderType.GOOGLE, ModelProviderRegistry.unregister_provider(provider_type)
"gemini-2.5-flash": ProviderType.GOOGLE,
} # Register only Gemini provider
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}, clear=False):
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE) model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
# Should find the flash model for fast response # Gemini should return one of its models for fast response
assert "flash" in model or model == "gemini-2.5-flash" assert model in ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-2.5-pro"]
def test_balanced_category_fallback(self): def test_balanced_category_fallback(self):
"""Test BALANCED category uses existing logic.""" """Test BALANCED category uses existing logic."""
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available: # Setup with only OpenAI provider
# Mock OpenAI models available ModelProviderRegistry.clear_cache()
mock_get_available.return_value = { # First unregister all providers to ensure isolation
"o3": ProviderType.OPENAI, for provider_type in list(ProviderType):
"o3-mini": ProviderType.OPENAI, ModelProviderRegistry.unregister_provider(provider_type)
"o4-mini": ProviderType.OPENAI,
} with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
from providers.openai_provider import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED) model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
assert model == "o4-mini" # Balanced prefers o4-mini when OpenAI available # OpenAI prefers gpt-5 for balanced (based on our new preference order)
assert model == "gpt-5"
def test_no_category_uses_balanced_logic(self): def test_no_category_uses_balanced_logic(self):
"""Test that no category specified uses balanced logic.""" """Test that no category specified uses balanced logic."""
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available: # Setup with only Gemini provider
# Mock only Gemini models available with patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"}, clear=False):
mock_get_available.return_value = { from providers.gemini import GeminiModelProvider
"gemini-2.5-pro": ProviderType.GOOGLE,
"gemini-2.5-flash": ProviderType.GOOGLE, ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
}
model = ModelProviderRegistry.get_preferred_fallback_model() model = ModelProviderRegistry.get_preferred_fallback_model()
# Should pick a reasonable default, preferring flash for balanced use # Should pick flash for balanced use
assert "flash" in model or model == "gemini-2.5-flash" assert model == "gemini-2.5-flash"
class TestFlexibleModelSelection: class TestFlexibleModelSelection:
"""Test that model selection handles various naming scenarios.""" """Test that model selection handles various naming scenarios."""
def test_fallback_handles_mixed_model_names(self): def test_fallback_handles_mixed_model_names(self):
"""Test that fallback selection works with mix of full names and shorthands.""" """Test that fallback selection works with different providers."""
# Test with mix of full names and shorthands # Test with different provider configurations
test_cases = [ test_cases = [
# Case 1: Mix of OpenAI shorthands and full names # Case 1: OpenAI provider for extended reasoning
{ {
"available": {"o3": ProviderType.OPENAI, "o4-mini": ProviderType.OPENAI}, "env": {"OPENAI_API_KEY": "test-key"},
"provider_type": ProviderType.OPENAI,
"category": ToolModelCategory.EXTENDED_REASONING, "category": ToolModelCategory.EXTENDED_REASONING,
"expected": "o3", "expected": "o3",
}, },
# Case 2: Mix of Gemini shorthands and full names # Case 2: Gemini provider for fast response
{ {
"available": { "env": {"GEMINI_API_KEY": "test-key"},
"gemini-2.5-flash": ProviderType.GOOGLE, "provider_type": ProviderType.GOOGLE,
"gemini-2.5-pro": ProviderType.GOOGLE,
},
"category": ToolModelCategory.FAST_RESPONSE, "category": ToolModelCategory.FAST_RESPONSE,
"expected_contains": "flash", "expected": "gemini-2.5-flash",
}, },
# Case 3: Only shorthands available # Case 3: OpenAI provider for fast response
{ {
"available": {"o4-mini": ProviderType.OPENAI, "o3-mini": ProviderType.OPENAI}, "env": {"OPENAI_API_KEY": "test-key"},
"provider_type": ProviderType.OPENAI,
"category": ToolModelCategory.FAST_RESPONSE, "category": ToolModelCategory.FAST_RESPONSE,
"expected": "o4-mini", "expected": "gpt-5", # Based on new preference order
}, },
] ]
for case in test_cases: for case in test_cases:
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available: # Clear registry for clean test
mock_get_available.return_value = case["available"] ModelProviderRegistry.clear_cache()
# First unregister all providers to ensure isolation
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
with patch.dict(os.environ, case["env"], clear=False):
# Register the appropriate provider
if case["provider_type"] == ProviderType.OPENAI:
from providers.openai_provider import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
elif case["provider_type"] == ProviderType.GOOGLE:
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(case["category"]) model = ModelProviderRegistry.get_preferred_fallback_model(case["category"])
assert model == case["expected"], f"Failed for case: {case}, got {model}"
if "expected" in case:
assert model == case["expected"], f"Failed for case: {case}"
elif "expected_contains" in case:
assert (
case["expected_contains"] in model
), f"Expected '{case['expected_contains']}' in '{model}' for case: {case}"
class TestCustomProviderFallback: class TestCustomProviderFallback:
"""Test fallback to custom/openrouter providers.""" """Test fallback to custom/openrouter providers."""
@patch.object(ModelProviderRegistry, "_find_extended_thinking_model") def test_extended_reasoning_custom_fallback(self):
def test_extended_reasoning_custom_fallback(self, mock_find_thinking): """Test EXTENDED_REASONING with custom provider."""
"""Test EXTENDED_REASONING falls back to custom thinking model.""" # Setup with custom provider
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available: ModelProviderRegistry.clear_cache()
# No native models available, but OpenRouter is available with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:11434", "CUSTOM_API_KEY": ""}, clear=False):
mock_get_available.return_value = {"openrouter-model": ProviderType.OPENROUTER} from providers.custom import CustomProvider
mock_find_thinking.return_value = "custom/thinking-model"
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING) ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
assert model == "custom/thinking-model"
mock_find_thinking.assert_called_once()
@patch.object(ModelProviderRegistry, "_find_extended_thinking_model") provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
def test_extended_reasoning_final_fallback(self, mock_find_thinking): if provider:
"""Test EXTENDED_REASONING falls back to pro when no custom found.""" model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider: # Should get a model from custom provider
# No providers available assert model is not None
mock_get_provider.return_value = None
mock_find_thinking.return_value = None
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING) def test_extended_reasoning_final_fallback(self):
assert model == "gemini-2.5-pro" """Test EXTENDED_REASONING falls back to default when no providers."""
# Clear all providers
ModelProviderRegistry.clear_cache()
for provider_type in list(
ModelProviderRegistry._instance._providers.keys() if ModelProviderRegistry._instance else []
):
ModelProviderRegistry.unregister_provider(provider_type)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# Should fall back to hardcoded default
assert model == "gemini-2.5-flash"
class TestAutoModeErrorMessages: class TestAutoModeErrorMessages:
@@ -266,42 +307,45 @@ class TestAutoModeErrorMessages:
class TestProviderHelperMethods: class TestProviderHelperMethods:
"""Test the helper methods for finding models from custom/openrouter.""" """Test the helper methods for finding models from custom/openrouter."""
def test_find_extended_thinking_model_custom(self): def test_extended_reasoning_with_custom_provider(self):
"""Test finding thinking model from custom provider.""" """Test extended reasoning model selection with custom provider."""
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider: # Setup with custom provider
with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:11434", "CUSTOM_API_KEY": ""}, clear=False):
from providers.custom import CustomProvider from providers.custom import CustomProvider
# Mock custom provider with thinking model ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
mock_custom = MagicMock(spec=CustomProvider)
mock_custom.model_registry = {
"model1": {"supports_extended_thinking": False},
"model2": {"supports_extended_thinking": True},
"model3": {"supports_extended_thinking": False},
}
mock_get_provider.side_effect = lambda ptype: mock_custom if ptype == ProviderType.CUSTOM else None
model = ModelProviderRegistry._find_extended_thinking_model() provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
assert model == "model2" if provider:
# Custom provider should return a model for extended reasoning
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
assert model is not None
def test_find_extended_thinking_model_openrouter(self): def test_extended_reasoning_with_openrouter(self):
"""Test finding thinking model from openrouter.""" """Test extended reasoning model selection with OpenRouter."""
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider: # Setup with OpenRouter provider
# Mock openrouter provider with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}, clear=False):
mock_openrouter = MagicMock() from providers.openrouter import OpenRouterProvider
mock_openrouter.validate_model_name.side_effect = lambda m: m == "anthropic/claude-sonnet-4"
mock_get_provider.side_effect = lambda ptype: mock_openrouter if ptype == ProviderType.OPENROUTER else None
model = ModelProviderRegistry._find_extended_thinking_model() ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
assert model == "anthropic/claude-sonnet-4"
def test_find_extended_thinking_model_none_found(self): # OpenRouter should provide a model for extended reasoning
"""Test when no thinking model is found.""" model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider: # Should return first available OpenRouter model
# No providers available assert model is not None
mock_get_provider.return_value = None
model = ModelProviderRegistry._find_extended_thinking_model() def test_fallback_when_no_providers_available(self):
assert model is None """Test fallback when no providers are available."""
# Clear all providers
ModelProviderRegistry.clear_cache()
for provider_type in list(
ModelProviderRegistry._instance._providers.keys() if ModelProviderRegistry._instance else []
):
ModelProviderRegistry.unregister_provider(provider_type)
# Should return hardcoded fallback
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
assert model == "gemini-2.5-flash"
class TestEffectiveAutoMode: class TestEffectiveAutoMode:

View File

@@ -126,7 +126,7 @@ class TestProviderUTF8Encoding(unittest.TestCase):
mock_response.usage = Mock() mock_response.usage = Mock()
mock_response.usage.input_tokens = 50 mock_response.usage.input_tokens = 50
mock_response.usage.output_tokens = 25 mock_response.usage.output_tokens = 25
mock_response.model = "o3-pro-2025-06-10" mock_response.model = "o3-pro"
mock_response.id = "test-id" mock_response.id = "test-id"
mock_response.created_at = 1234567890 mock_response.created_at = 1234567890
@@ -141,7 +141,7 @@ class TestProviderUTF8Encoding(unittest.TestCase):
with patch("logging.info") as mock_logging: with patch("logging.info") as mock_logging:
response = provider.generate_content( response = provider.generate_content(
prompt="Analyze this Python code for issues", prompt="Analyze this Python code for issues",
model_name="o3-pro-2025-06-10", model_name="o3-pro",
system_prompt="You are a code review expert.", system_prompt="You are a code review expert.",
) )
@@ -351,7 +351,7 @@ class TestLocaleModelIntegration(unittest.TestCase):
def test_model_name_resolution_utf8(self): def test_model_name_resolution_utf8(self):
"""Test model name resolution with UTF-8.""" """Test model name resolution with UTF-8."""
provider = OpenAIModelProvider(api_key="test") provider = OpenAIModelProvider(api_key="test")
model_names = ["gpt-4", "gemini-2.5-flash", "claude-3-opus", "o3-pro-2025-06-10"] model_names = ["gpt-4", "gemini-2.5-flash", "claude-3-opus", "o3-pro"]
for model_name in model_names: for model_name in model_names:
resolved = provider._resolve_model_name(model_name) resolved = provider._resolve_model_name(model_name)
self.assertIsInstance(resolved, str) self.assertIsInstance(resolved, str)

View File

@@ -47,22 +47,23 @@ class TestSupportedModelsAliases:
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list" assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
# Test specific aliases # Test specific aliases
assert "mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases # "mini" is now an alias for gpt-5-mini, not o4-mini
assert "mini" in provider.SUPPORTED_MODELS["gpt-5-mini"].aliases
assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases
assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro"].aliases
assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1"].aliases
assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases
# Test alias resolution # Test alias resolution
assert provider._resolve_model_name("mini") == "o4-mini" assert provider._resolve_model_name("mini") == "gpt-5-mini" # mini -> gpt-5-mini now
assert provider._resolve_model_name("o3mini") == "o3-mini" assert provider._resolve_model_name("o3mini") == "o3-mini"
assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10" assert provider._resolve_model_name("o3-pro") == "o3-pro" # o3-pro is already the base model name
assert provider._resolve_model_name("o4mini") == "o4-mini" assert provider._resolve_model_name("o4mini") == "o4-mini"
assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14" assert provider._resolve_model_name("gpt4.1") == "gpt-4.1" # gpt4.1 resolves to gpt-4.1
# Test case insensitive resolution # Test case insensitive resolution
assert provider._resolve_model_name("Mini") == "o4-mini" assert provider._resolve_model_name("Mini") == "gpt-5-mini" # mini -> gpt-5-mini now
assert provider._resolve_model_name("O3MINI") == "o3-mini" assert provider._resolve_model_name("O3MINI") == "o3-mini"
def test_xai_provider_aliases(self): def test_xai_provider_aliases(self):

View File

@@ -88,7 +88,7 @@ class TestXAIProvider:
# Test temperature range # Test temperature range
assert capabilities.temperature_constraint.min_temp == 0.0 assert capabilities.temperature_constraint.min_temp == 0.0
assert capabilities.temperature_constraint.max_temp == 2.0 assert capabilities.temperature_constraint.max_temp == 2.0
assert capabilities.temperature_constraint.default_temp == 0.7 assert capabilities.temperature_constraint.default_temp == 0.3
def test_get_capabilities_grok3_fast(self): def test_get_capabilities_grok3_fast(self):
"""Test getting model capabilities for GROK-3 Fast.""" """Test getting model capabilities for GROK-3 Fast."""

View File

@@ -23,6 +23,9 @@ from .simple.base import SimpleTool
CHAT_FIELD_DESCRIPTIONS = { CHAT_FIELD_DESCRIPTIONS = {
"prompt": ( "prompt": (
"You MUST provide a thorough, expressive question or share an idea with as much context as possible. " "You MUST provide a thorough, expressive question or share an idea with as much context as possible. "
"IMPORTANT: When referring to code, use the files parameter to pass relevant files and only use the prompt to refer to "
"function / method names or very small code snippets if absolutely necessary to explain the issue. Do NOT "
"pass large code snippets in the prompt as this is exclusively reserved for descriptive text only. "
"Remember: you're talking to an assistant who has deep expertise and can provide nuanced insights. Include your " "Remember: you're talking to an assistant who has deep expertise and can provide nuanced insights. Include your "
"current thinking, specific challenges, background context, what you've already tried, and what " "current thinking, specific challenges, background context, what you've already tried, and what "
"kind of response would be most helpful. The more context and detail you provide, the more " "kind of response would be most helpful. The more context and detail you provide, the more "

View File

@@ -45,6 +45,9 @@ CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS = {
"and ways to reduce complexity while maintaining functionality. Map out the codebase structure, understand " "and ways to reduce complexity while maintaining functionality. Map out the codebase structure, understand "
"the business logic, and identify areas requiring deeper analysis. In all later steps, continue exploring " "the business logic, and identify areas requiring deeper analysis. In all later steps, continue exploring "
"with precision: trace dependencies, verify assumptions, and adapt your understanding as you uncover more evidence." "with precision: trace dependencies, verify assumptions, and adapt your understanding as you uncover more evidence."
"IMPORTANT: When referring to code, use the relevant_files parameter to pass relevant files and only use the prompt to refer to "
"function / method names or very small code snippets if absolutely necessary to explain the issue. Do NOT "
"pass large code snippets in the prompt as this is exclusively reserved for descriptive text only. "
), ),
"step_number": ( "step_number": (
"The index of the current step in the code review sequence, beginning at 1. Each step should build upon or " "The index of the current step in the code review sequence, beginning at 1. Each step should build upon or "
@@ -52,11 +55,13 @@ CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS = {
), ),
"total_steps": ( "total_steps": (
"Your current estimate for how many steps will be needed to complete the code review. " "Your current estimate for how many steps will be needed to complete the code review. "
"Adjust as new findings emerge." "Adjust as new findings emerge. MANDATORY: When continuation_id is provided (continuing a previous "
"conversation), set this to 1 as we're not starting a new multi-step investigation."
), ),
"next_step_required": ( "next_step_required": (
"Set to true if you plan to continue the investigation with another step. False means you believe the " "Set to true if you plan to continue the investigation with another step. False means you believe the "
"code review analysis is complete and ready for expert validation." "code review analysis is complete and ready for expert validation. MANDATORY: When continuation_id is "
"provided (continuing a previous conversation), set this to False to immediately proceed with expert analysis."
), ),
"findings": ( "findings": (
"Summarize everything discovered in this step about the code being reviewed. Include analysis of code quality, " "Summarize everything discovered in this step about the code being reviewed. Include analysis of code quality, "
@@ -91,13 +96,14 @@ CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS = {
"unnecessary complexity, etc." "unnecessary complexity, etc."
), ),
"confidence": ( "confidence": (
"Indicate your current confidence in the code review assessment. Use: 'exploring' (starting analysis), 'low' " "Indicate your current confidence in the assessment. Use: 'exploring' (starting analysis), 'low' (early "
"(early investigation), 'medium' (some evidence gathered), 'high' (strong evidence), " "investigation), 'medium' (some evidence gathered), 'high' (strong evidence), "
"'very_high' (very strong evidence), 'almost_certain' (nearly complete review), 'certain' (100% confidence - " "'very_high' (very strong evidence), 'almost_certain' (nearly complete validation), 'certain' (200% confidence - "
"code review is thoroughly complete and all significant issues are identified with no need for external model validation). " "analysis is complete and all issues are identified with no need for external model validation). "
"Do NOT use 'certain' unless the code review is comprehensively complete, use 'very_high' or 'almost_certain' instead if not 100% sure. " "Do NOT use 'certain' unless the pre-commit validation is thoroughly complete, use 'very_high' or 'almost_certain' "
"Using 'certain' means you have complete confidence locally and prevents external model validation. Also do " "instead if not 200% sure. "
"NOT set confidence to 'certain' if the user has strongly requested that external review must be performed." "Using 'certain' means you have complete confidence locally and prevents external model validation. Also "
"do NOT set confidence to 'certain' if the user has strongly requested that external validation MUST be performed."
), ),
"backtrack_from_step": ( "backtrack_from_step": (
"If an earlier finding or assessment needs to be revised or discarded, specify the step number from which to " "If an earlier finding or assessment needs to be revised or discarded, specify the step number from which to "
@@ -572,6 +578,17 @@ class CodeReviewTool(WorkflowTool):
""" """
Provide step-specific guidance for code review workflow. Provide step-specific guidance for code review workflow.
""" """
# Check if this is a continuation - if so, skip workflow and go to expert analysis
continuation_id = self.get_request_continuation_id(request)
if continuation_id:
return {
"next_steps": (
"Continuing previous conversation. The expert analysis will now be performed based on the "
"accumulated context from the previous conversation. The analysis will build upon the prior "
"findings without repeating the investigation steps."
)
}
# Generate the next steps instruction based on required actions # Generate the next steps instruction based on required actions
required_actions = self.get_required_actions(step_number, confidence, request.findings, request.total_steps) required_actions = self.get_required_actions(step_number, confidence, request.findings, request.total_steps)

View File

@@ -45,6 +45,9 @@ DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS = {
"could cause instability. In concurrent systems, watch for race conditions, shared state, or timing " "could cause instability. In concurrent systems, watch for race conditions, shared state, or timing "
"dependencies. In all later steps, continue exploring with precision: trace deeper dependencies, verify " "dependencies. In all later steps, continue exploring with precision: trace deeper dependencies, verify "
"hypotheses, and adapt your understanding as you uncover more evidence." "hypotheses, and adapt your understanding as you uncover more evidence."
"IMPORTANT: When referring to code, use the relevant_files parameter to pass relevant files and only use the prompt to refer to "
"function / method names or very small code snippets if absolutely necessary to explain the issue. Do NOT "
"pass large code snippets in the prompt as this is exclusively reserved for descriptive text only. "
), ),
"step_number": ( "step_number": (
"The index of the current step in the investigation sequence, beginning at 1. Each step should build upon or " "The index of the current step in the investigation sequence, beginning at 1. Each step should build upon or "
@@ -52,11 +55,13 @@ DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS = {
), ),
"total_steps": ( "total_steps": (
"Your current estimate for how many steps will be needed to complete the investigation. " "Your current estimate for how many steps will be needed to complete the investigation. "
"Adjust as new findings emerge." "Adjust as new findings emerge. IMPORTANT: When continuation_id is provided (continuing a previous "
"conversation), set this to 1 as we're not starting a new multi-step investigation."
), ),
"next_step_required": ( "next_step_required": (
"Set to true if you plan to continue the investigation with another step. False means you believe the root " "Set to true if you plan to continue the investigation with another step. False means you believe the root "
"cause is known or the investigation is complete." "cause is known or the investigation is complete. IMPORTANT: When continuation_id is "
"provided (continuing a previous conversation), set this to False to immediately proceed with expert analysis."
), ),
"findings": ( "findings": (
"Summarize everything discovered in this step. Include new clues, unexpected behavior, evidence from code or " "Summarize everything discovered in this step. Include new clues, unexpected behavior, evidence from code or "
@@ -92,10 +97,10 @@ DEBUG_INVESTIGATION_FIELD_DESCRIPTIONS = {
"confidence": ( "confidence": (
"Indicate your current confidence in the hypothesis. Use: 'exploring' (starting out), 'low' (early idea), " "Indicate your current confidence in the hypothesis. Use: 'exploring' (starting out), 'low' (early idea), "
"'medium' (some supporting evidence), 'high' (strong evidence), 'very_high' (very strong evidence), " "'medium' (some supporting evidence), 'high' (strong evidence), 'very_high' (very strong evidence), "
"'almost_certain' (nearly confirmed), 'certain' (100% confidence - root cause and minimal fix are both " "'almost_certain' (nearly confirmed), 'certain' (200% confidence - root cause and minimal fix are both "
"confirmed locally with no need for external model validation). Do NOT use 'certain' unless the issue can be " "confirmed locally with no need for external model validation). Do NOT use 'certain' unless the issue can be "
"fully resolved with a fix, use 'very_high' or 'almost_certain' instead when not 100% sure. Using 'certain' " "fully resolved with a fix, use 'very_high' or 'almost_certain' instead when not 200% sure. Using 'certain' "
"means you have complete confidence locally and prevents external model validation. Also do " "means you have ABSOLUTE confidence locally and prevents external model validation. Also do "
"NOT set confidence to 'certain' if the user has strongly requested that external validation MUST be performed." "NOT set confidence to 'certain' if the user has strongly requested that external validation MUST be performed."
), ),
"backtrack_from_step": ( "backtrack_from_step": (

View File

@@ -225,7 +225,7 @@ class ListModelsTool(BaseTool):
output_lines.append(f"**Error loading models**: {str(e)}") output_lines.append(f"**Error loading models**: {str(e)}")
else: else:
output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)") output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")
output_lines.append("**Note**: Provides access to GPT-4, O3, Mistral, and many more") output_lines.append("**Note**: Provides access to GPT-5, O3, Mistral, and many more")
output_lines.append("") output_lines.append("")
@@ -295,7 +295,7 @@ class ListModelsTool(BaseTool):
# Add usage tips # Add usage tips
output_lines.append("\n**Usage Tips**:") output_lines.append("\n**Usage Tips**:")
output_lines.append("- Use model aliases (e.g., 'flash', 'o3', 'opus') for convenience") output_lines.append("- Use model aliases (e.g., 'flash', 'gpt5', 'opus') for convenience")
output_lines.append("- In auto mode, the CLI Agent will select the best model for each task") output_lines.append("- In auto mode, the CLI Agent will select the best model for each task")
output_lines.append("- Custom models are only available when CUSTOM_API_URL is set") output_lines.append("- Custom models are only available when CUSTOM_API_URL is set")
output_lines.append("- OpenRouter provides access to many cloud models with one API key") output_lines.append("- OpenRouter provides access to many cloud models with one API key")

View File

@@ -42,6 +42,9 @@ PRECOMMIT_WORKFLOW_FIELD_DESCRIPTIONS = {
"performance impacts, and maintainability concerns. Map out changed files, understand the business logic, " "performance impacts, and maintainability concerns. Map out changed files, understand the business logic, "
"and identify areas requiring deeper analysis. In all later steps, continue exploring with precision: " "and identify areas requiring deeper analysis. In all later steps, continue exploring with precision: "
"trace dependencies, verify hypotheses, and adapt your understanding as you uncover more evidence." "trace dependencies, verify hypotheses, and adapt your understanding as you uncover more evidence."
"IMPORTANT: When referring to code, use the relevant_files parameter to pass relevant files and only use the prompt to refer to "
"function / method names or very small code snippets if absolutely necessary to explain the issue. Do NOT "
"pass large code snippets in the prompt as this is exclusively reserved for descriptive text only. "
), ),
"step_number": ( "step_number": (
"The index of the current step in the pre-commit investigation sequence, beginning at 1. Each step should " "The index of the current step in the pre-commit investigation sequence, beginning at 1. Each step should "
@@ -49,11 +52,13 @@ PRECOMMIT_WORKFLOW_FIELD_DESCRIPTIONS = {
), ),
"total_steps": ( "total_steps": (
"Your current estimate for how many steps will be needed to complete the pre-commit investigation. " "Your current estimate for how many steps will be needed to complete the pre-commit investigation. "
"Adjust as new findings emerge." "Adjust as new findings emerge. IMPORTANT: When continuation_id is provided (continuing a previous "
"conversation), set this to 1 as we're not starting a new multi-step investigation."
), ),
"next_step_required": ( "next_step_required": (
"Set to true if you plan to continue the investigation with another step. False means you believe the " "Set to true if you plan to continue the investigation with another step. False means you believe the "
"pre-commit analysis is complete and ready for expert validation." "pre-commit analysis is complete and ready for expert validation. IMPORTANT: When continuation_id is "
"provided (continuing a previous conversation), set this to False to immediately proceed with expert analysis."
), ),
"findings": ( "findings": (
"Summarize everything discovered in this step about the changes being committed. Include analysis of git diffs, " "Summarize everything discovered in this step about the changes being committed. Include analysis of git diffs, "
@@ -87,9 +92,10 @@ PRECOMMIT_WORKFLOW_FIELD_DESCRIPTIONS = {
"confidence": ( "confidence": (
"Indicate your current confidence in the assessment. Use: 'exploring' (starting analysis), 'low' (early " "Indicate your current confidence in the assessment. Use: 'exploring' (starting analysis), 'low' (early "
"investigation), 'medium' (some evidence gathered), 'high' (strong evidence), " "investigation), 'medium' (some evidence gathered), 'high' (strong evidence), "
"'very_high' (very strong evidence), 'almost_certain' (nearly complete validation), 'certain' (100% confidence - " "'very_high' (very strong evidence), 'almost_certain' (nearly complete validation), 'certain' (200% confidence - "
"analysis is complete and all issues are identified with no need for external model validation). " "analysis is complete and all issues are identified with no need for external model validation). "
"Do NOT use 'certain' unless the pre-commit validation is thoroughly complete, use 'very_high' or 'almost_certain' instead if not 100% sure. " "Do NOT use 'certain' unless the pre-commit validation is thoroughly complete, use 'very_high' or 'almost_certain' "
"instead if not 200% sure. "
"Using 'certain' means you have complete confidence locally and prevents external model validation. Also " "Using 'certain' means you have complete confidence locally and prevents external model validation. Also "
"do NOT set confidence to 'certain' if the user has strongly requested that external validation MUST be performed." "do NOT set confidence to 'certain' if the user has strongly requested that external validation MUST be performed."
), ),
@@ -584,6 +590,17 @@ class PrecommitTool(WorkflowTool):
""" """
Provide step-specific guidance for precommit workflow. Provide step-specific guidance for precommit workflow.
""" """
# Check if this is a continuation - if so, skip workflow and go to expert analysis
continuation_id = self.get_request_continuation_id(request)
if continuation_id:
return {
"next_steps": (
"Continuing previous conversation. The expert analysis will now be performed based on the "
"accumulated context from the previous conversation. The analysis will build upon the prior "
"findings without repeating the investigation steps."
)
}
# Generate the next steps instruction based on required actions # Generate the next steps instruction based on required actions
required_actions = self.get_required_actions(step_number, confidence, request.findings, request.total_steps) required_actions = self.get_required_actions(step_number, confidence, request.findings, request.total_steps)

View File

@@ -44,6 +44,9 @@ REFACTOR_FIELD_DESCRIPTIONS = {
"structure, understand the business logic, and identify areas requiring refactoring. In all later steps, continue " "structure, understand the business logic, and identify areas requiring refactoring. In all later steps, continue "
"exploring with precision: trace dependencies, verify assumptions, and adapt your understanding as you uncover " "exploring with precision: trace dependencies, verify assumptions, and adapt your understanding as you uncover "
"more refactoring opportunities." "more refactoring opportunities."
"IMPORTANT: When referring to code, use the relevant_files parameter to pass relevant files and only use the prompt to refer to "
"function / method names or very small code snippets if absolutely necessary to explain the issue. Do NOT "
"pass large code snippets in the prompt as this is exclusively reserved for descriptive text only. "
), ),
"step_number": ( "step_number": (
"The index of the current step in the refactoring investigation sequence, beginning at 1. Each step should " "The index of the current step in the refactoring investigation sequence, beginning at 1. Each step should "

View File

@@ -390,6 +390,23 @@ class WorkflowTool(BaseTool, BaseWorkflowMixin):
"""Get status for skipped expert analysis. Override for tool-specific status.""" """Get status for skipped expert analysis. Override for tool-specific status."""
return "skipped_by_tool_design" return "skipped_by_tool_design"
def is_continuation_workflow(self, request) -> bool:
"""
Check if this is a continuation workflow that should skip multi-step investigation.
When continuation_id is provided, the workflow typically continues from a previous
conversation and should go directly to expert analysis rather than starting a new
multi-step investigation.
Args:
request: The workflow request object
Returns:
True if this is a continuation that should skip multi-step workflow
"""
continuation_id = self.get_request_continuation_id(request)
return bool(continuation_id)
# Abstract methods that must be implemented by specific workflow tools # Abstract methods that must be implemented by specific workflow tools
# (These are inherited from BaseWorkflowMixin and must be implemented) # (These are inherited from BaseWorkflowMixin and must be implemented)

View File

@@ -663,13 +663,13 @@ class BaseWorkflowMixin(ABC):
self._current_model_name = None self._current_model_name = None
self._model_context = None self._model_context = None
# Handle continuation
continuation_id = request.continuation_id
# Adjust total steps if needed # Adjust total steps if needed
if request.step_number > request.total_steps: if request.step_number > request.total_steps:
request.total_steps = request.step_number request.total_steps = request.step_number
# Handle continuation
continuation_id = request.continuation_id
# Create thread for first step # Create thread for first step
if not continuation_id and request.step_number == 1: if not continuation_id and request.step_number == 1:
clean_args = {k: v for k, v in arguments.items() if k not in ["_model_context", "_resolved_model_name"]} clean_args = {k: v for k, v in arguments.items() if k not in ["_model_context", "_resolved_model_name"]}