Merge branch 'feat-local_support_with_UTF-8_encoding-update' of https://github.com/GiGiDKR/zen-mcp-server into feat-local_support_with_UTF-8_encoding-update

This commit is contained in:
OhMyApps
2025-06-23 22:24:47 +02:00
57 changed files with 1589 additions and 863 deletions

View File

@@ -128,7 +128,28 @@ python communication_simulator_test.py
 python communication_simulator_test.py --verbose
 ```
-#### Run Individual Simulator Tests (Recommended)
+#### Quick Test Mode (Recommended for Time-Limited Testing)
+
+```bash
+# Run quick test mode - 6 essential tests that provide maximum functionality coverage
+python communication_simulator_test.py --quick
+
+# Run quick test mode with verbose output
+python communication_simulator_test.py --quick --verbose
+```
+
+**Quick mode runs these 6 essential tests:**
+- `cross_tool_continuation` - Cross-tool conversation memory testing (chat, thinkdeep, codereview, analyze, debug)
+- `conversation_chain_validation` - Core conversation threading and memory validation
+- `consensus_workflow_accurate` - Consensus tool with flash model and stance testing
+- `codereview_validation` - CodeReview tool with flash model and multi-step workflows
+- `planner_validation` - Planner tool with flash model and complex planning workflows
+- `token_allocation_validation` - Token allocation and conversation history buildup testing
+
+**Why these 6 tests:** They cover the core functionality including conversation memory (`utils/conversation_memory.py`), chat tool functionality, file processing and deduplication, model selection (flash/flashlite/o3), and cross-tool conversation workflows. These tests validate the most critical parts of the system in minimal time.
+
+**Note:** Some workflow tools (analyze, codereview, planner, consensus, etc.) require specific workflow parameters and may need individual testing rather than quick mode testing.
+
+#### Run Individual Simulator Tests (For Detailed Testing)
+
 ```bash
 # List all available tests
 python communication_simulator_test.py --list-tests
@@ -223,15 +244,17 @@ python -m pytest tests/ -v
 #### After Making Changes
 1. Run quality checks again: `./code_quality_checks.sh`
 2. Run integration tests locally: `./run_integration_tests.sh`
-3. Run relevant simulator tests: `python communication_simulator_test.py --individual <test_name>`
-4. Check logs for any issues: `tail -n 100 logs/mcp_server.log`
-5. Restart Claude session to use updated code
+3. Run quick test mode for fast validation: `python communication_simulator_test.py --quick`
+4. Run relevant specific simulator tests if needed: `python communication_simulator_test.py --individual <test_name>`
+5. Check logs for any issues: `tail -n 100 logs/mcp_server.log`
+6. Restart Claude session to use updated code
 
 #### Before Committing/PR
 1. Final quality check: `./code_quality_checks.sh`
 2. Run integration tests: `./run_integration_tests.sh`
-3. Run full simulator test suite: `./run_integration_tests.sh --with-simulator`
-4. Verify all tests pass 100%
+3. Run quick test mode: `python communication_simulator_test.py --quick`
+4. Run full simulator test suite (optional): `./run_integration_tests.sh --with-simulator`
+5. Verify all tests pass 100%
 
 ### Common Troubleshooting
@@ -250,6 +273,9 @@ which python
 #### Test Failures
 ```bash
+# First try quick test mode to see if it's a general issue
+python communication_simulator_test.py --quick --verbose
+
 # Run individual failing test with verbose output
 python communication_simulator_test.py --individual <test_name> --verbose
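The triage flow described in this file (quick mode first, then re-run the failing test individually with verbose output) can be sketched as a small helper that assembles the simulator command lines. The `build_command` helper is a hypothetical illustration, not code from the repository; only the CLI flags themselves come from the diff above:

```python
from typing import Optional

# Hypothetical helper mirroring the triage flow: run --quick first,
# then re-run a specific failing test with --individual and --verbose.
def build_command(quick: bool = False, verbose: bool = False, individual: Optional[str] = None) -> list:
    """Build an argv list for communication_simulator_test.py (illustrative only)."""
    cmd = ["python", "communication_simulator_test.py"]
    if quick:
        cmd.append("--quick")
    if individual:
        cmd += ["--individual", individual]
    if verbose:
        cmd.append("--verbose")
    return cmd

# Step 1: check whether the failure is general
print(" ".join(build_command(quick=True, verbose=True)))
# Step 2: drill into one failing test
print(" ".join(build_command(individual="content_validation", verbose=True)))
```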

View File

@@ -409,7 +409,7 @@ for most debugging workflows, as Claude is usually able to confidently find the
 When in doubt, you can always follow up with a new prompt and ask Claude to share its findings with another model:
 ```text
-Use continuation with thinkdeep, share details with o4-mini-high to find out what the best fix is for this
+Use continuation with thinkdeep, share details with o4-mini to find out what the best fix is for this
 ```
 **[📖 Read More](docs/tools/debug.md)** - Step-by-step investigation methodology with workflow enforcement

View File

@@ -38,6 +38,15 @@ Available tests:
 debug_validation - Debug tool validation with actual bugs
 conversation_chain_validation - Conversation chain continuity validation
 
+Quick Test Mode (for time-limited testing):
+Use --quick to run the essential 6 tests that provide maximum coverage:
+- cross_tool_continuation (cross-tool conversation memory)
+- basic_conversation (basic chat functionality)
+- content_validation (content validation and deduplication)
+- model_thinking_config (flash/flashlite model testing)
+- o3_model_selection (o3 model selection testing)
+- per_tool_deduplication (file deduplication for individual tools)
+
 Examples:
 # Run all tests
 python communication_simulator_test.py
@@ -48,6 +57,9 @@ Examples:
 # Run a single test individually (with full standalone setup)
 python communication_simulator_test.py --individual content_validation
 
+# Run quick test mode (essential 6 tests for time-limited testing)
+python communication_simulator_test.py --quick
+
 # Force setup standalone server environment before running tests
 python communication_simulator_test.py --setup
@@ -68,21 +80,48 @@ class CommunicationSimulator:
     """Simulates real-world Claude CLI communication with MCP Gemini server"""
 
     def __init__(
-        self, verbose: bool = False, keep_logs: bool = False, selected_tests: list[str] = None, setup: bool = False
+        self,
+        verbose: bool = False,
+        keep_logs: bool = False,
+        selected_tests: list[str] = None,
+        setup: bool = False,
+        quick_mode: bool = False,
     ):
         self.verbose = verbose
         self.keep_logs = keep_logs
         self.selected_tests = selected_tests or []
         self.setup = setup
+        self.quick_mode = quick_mode
         self.temp_dir = None
         self.server_process = None
         self.python_path = self._get_python_path()
 
+        # Configure logging first
+        log_level = logging.DEBUG if verbose else logging.INFO
+        logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s")
+        self.logger = logging.getLogger(__name__)
+
         # Import test registry
         from simulator_tests import TEST_REGISTRY
 
         self.test_registry = TEST_REGISTRY
 
+        # Define quick mode tests (essential tests for time-limited testing)
+        # Focus on tests that work with current tool configurations
+        self.quick_mode_tests = [
+            "cross_tool_continuation",  # Cross-tool conversation memory
+            "basic_conversation",  # Basic chat functionality
+            "content_validation",  # Content validation and deduplication
+            "model_thinking_config",  # Flash/flashlite model testing
+            "o3_model_selection",  # O3 model selection testing
+            "per_tool_deduplication",  # File deduplication for individual tools
+        ]
+
+        # If quick mode is enabled, override selected_tests
+        if self.quick_mode:
+            self.selected_tests = self.quick_mode_tests
+            self.logger.info(f"Quick mode enabled - running {len(self.quick_mode_tests)} essential tests")
+
         # Available test methods mapping
         self.available_tests = {
             name: self._create_test_runner(test_class) for name, test_class in self.test_registry.items()
@@ -91,11 +130,6 @@ class CommunicationSimulator:
         # Test result tracking
         self.test_results = dict.fromkeys(self.test_registry.keys(), False)
 
-        # Configure logging
-        log_level = logging.DEBUG if verbose else logging.INFO
-        logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s")
-        self.logger = logging.getLogger(__name__)
-
     def _get_python_path(self) -> str:
         """Get the Python path for the virtual environment"""
         current_dir = os.getcwd()
@@ -415,6 +449,9 @@ def parse_arguments():
     parser.add_argument("--tests", "-t", nargs="+", help="Specific tests to run (space-separated)")
     parser.add_argument("--list-tests", action="store_true", help="List available tests and exit")
     parser.add_argument("--individual", "-i", help="Run a single test individually")
+    parser.add_argument(
+        "--quick", "-q", action="store_true", help="Run quick test mode (6 essential tests for time-limited testing)"
+    )
     parser.add_argument(
         "--setup", action="store_true", help="Force setup standalone server environment using run-server.sh"
     )
@@ -492,7 +529,11 @@ def main():
     # Initialize simulator consistently for all use cases
     simulator = CommunicationSimulator(
-        verbose=args.verbose, keep_logs=args.keep_logs, selected_tests=args.tests, setup=args.setup
+        verbose=args.verbose,
+        keep_logs=args.keep_logs,
+        selected_tests=args.tests,
+        setup=args.setup,
+        quick_mode=args.quick,
     )
 
     # Determine execution mode and run
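The quick-mode logic added in this file boils down to one rule: when `--quick` is passed, any explicit test selection is replaced by the fixed six-test list. A minimal standalone sketch of that override (the test names are taken from the diff; the `select_tests` helper itself is an illustrative assumption, not repository code):

```python
# Test names copied from the diff above; select_tests is an illustrative
# sketch of the quick-mode override, not code from the repository.
QUICK_MODE_TESTS = [
    "cross_tool_continuation",
    "basic_conversation",
    "content_validation",
    "model_thinking_config",
    "o3_model_selection",
    "per_tool_deduplication",
]

def select_tests(selected, quick_mode: bool) -> list:
    """Quick mode overrides any explicit test selection, as in __init__ above."""
    if quick_mode:
        return list(QUICK_MODE_TESTS)
    return selected or []

print(select_tests(["debug_validation"], quick_mode=True))   # quick mode wins
print(select_tests(["debug_validation"], quick_mode=False))  # explicit selection kept
```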

View File

@@ -22,6 +22,7 @@
       "model_name": "The model identifier - OpenRouter format (e.g., 'anthropic/claude-opus-4') or custom model name (e.g., 'llama3.2')",
       "aliases": "Array of short names users can type instead of the full model name",
       "context_window": "Total number of tokens the model can process (input + output combined)",
+      "max_output_tokens": "Maximum number of tokens the model can generate in a single response",
       "supports_extended_thinking": "Whether the model supports extended reasoning tokens (currently none do via OpenRouter or custom APIs)",
       "supports_json_mode": "Whether the model can guarantee valid JSON output",
       "supports_function_calling": "Whether the model supports function/tool calling",
@@ -36,6 +37,7 @@
       "model_name": "my-local-model",
       "aliases": ["shortname", "nickname", "abbrev"],
       "context_window": 128000,
+      "max_output_tokens": 32768,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,
@@ -52,6 +54,7 @@
       "model_name": "anthropic/claude-opus-4",
       "aliases": ["opus", "claude-opus", "claude4-opus", "claude-4-opus"],
       "context_window": 200000,
+      "max_output_tokens": 64000,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,
@@ -63,6 +66,7 @@
       "model_name": "anthropic/claude-sonnet-4",
       "aliases": ["sonnet", "claude-sonnet", "claude4-sonnet", "claude-4-sonnet", "claude"],
       "context_window": 200000,
+      "max_output_tokens": 64000,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,
@@ -74,6 +78,7 @@
       "model_name": "anthropic/claude-3.5-haiku",
       "aliases": ["haiku", "claude-haiku", "claude3-haiku", "claude-3-haiku"],
       "context_window": 200000,
+      "max_output_tokens": 64000,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,
@@ -85,6 +90,7 @@
       "model_name": "google/gemini-2.5-pro",
       "aliases": ["pro","gemini-pro", "gemini", "pro-openrouter"],
       "context_window": 1048576,
+      "max_output_tokens": 65536,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": false,
@@ -96,6 +102,7 @@
       "model_name": "google/gemini-2.5-flash",
       "aliases": ["flash","gemini-flash", "flash-openrouter", "flash-2.5"],
       "context_window": 1048576,
+      "max_output_tokens": 65536,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": false,
@@ -107,6 +114,7 @@
       "model_name": "mistralai/mistral-large-2411",
       "aliases": ["mistral-large", "mistral"],
       "context_window": 128000,
+      "max_output_tokens": 32000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,
@@ -118,6 +126,7 @@
       "model_name": "meta-llama/llama-3-70b",
       "aliases": ["llama", "llama3", "llama3-70b", "llama-70b", "llama3-openrouter"],
       "context_window": 8192,
+      "max_output_tokens": 8192,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,
@@ -129,6 +138,7 @@
       "model_name": "deepseek/deepseek-r1-0528",
       "aliases": ["deepseek-r1", "deepseek", "r1", "deepseek-thinking"],
       "context_window": 65536,
+      "max_output_tokens": 32768,
       "supports_extended_thinking": true,
       "supports_json_mode": true,
       "supports_function_calling": false,
@@ -140,6 +150,7 @@
       "model_name": "perplexity/llama-3-sonar-large-32k-online",
       "aliases": ["perplexity", "sonar", "perplexity-online"],
       "context_window": 32768,
+      "max_output_tokens": 32768,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,
@@ -151,6 +162,7 @@
       "model_name": "openai/o3",
       "aliases": ["o3"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,
@@ -164,6 +176,7 @@
       "model_name": "openai/o3-mini",
       "aliases": ["o3-mini", "o3mini"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,
@@ -177,6 +190,7 @@
       "model_name": "openai/o3-mini-high",
       "aliases": ["o3-mini-high", "o3mini-high"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,
@@ -190,6 +204,7 @@
       "model_name": "openai/o3-pro",
       "aliases": ["o3-pro", "o3pro"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,
@@ -203,6 +218,7 @@
       "model_name": "openai/o4-mini",
       "aliases": ["o4-mini", "o4mini"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,
@@ -212,23 +228,11 @@
       "temperature_constraint": "fixed",
       "description": "OpenAI's o4-mini model - optimized for shorter contexts with rapid reasoning and vision"
     },
-    {
-      "model_name": "openai/o4-mini-high",
-      "aliases": ["o4-mini-high", "o4mini-high", "o4minihigh", "o4minihi"],
-      "context_window": 200000,
-      "supports_extended_thinking": false,
-      "supports_json_mode": true,
-      "supports_function_calling": true,
-      "supports_images": true,
-      "max_image_size_mb": 20.0,
-      "supports_temperature": false,
-      "temperature_constraint": "fixed",
-      "description": "OpenAI's o4-mini with high reasoning effort - enhanced for complex tasks with vision"
-    },
     {
       "model_name": "llama3.2",
       "aliases": ["local-llama", "local", "llama3.2", "ollama-llama"],
       "context_window": 128000,
+      "max_output_tokens": 64000,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,
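Since every model entry now carries both `context_window` and `max_output_tokens`, a simple sanity check is that the output budget never exceeds the context window. A sketch using two entries copied from the diff above; the `validate_entry` helper is an illustration, not part of the server code:

```python
# Two entries copied from the diff above; validate_entry is an illustrative
# sanity check, not part of the server code.
entries = [
    {"model_name": "anthropic/claude-opus-4", "context_window": 200000, "max_output_tokens": 64000},
    {"model_name": "openai/o3", "context_window": 200000, "max_output_tokens": 100000},
]

def validate_entry(entry: dict) -> None:
    """Ensure the declared output budget fits inside the context window."""
    if entry["max_output_tokens"] > entry["context_window"]:
        raise ValueError(f"{entry['model_name']}: max_output_tokens exceeds context_window")

for e in entries:
    validate_entry(e)  # both pass: 64000 <= 200000 and 100000 <= 200000
```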

View File

@@ -14,7 +14,7 @@ import os
 # These values are used in server responses and for tracking releases
 # IMPORTANT: This is the single source of truth for version and author info
 # Semantic versioning: MAJOR.MINOR.PATCH
-__version__ = "5.6.1"
+__version__ = "5.7.0"
 # Last update date in ISO format
 __updated__ = "2025-06-23"
 # Primary maintainer

View File

@@ -38,7 +38,6 @@ Regardless of your default configuration, you can specify models per request:
 | **`o3`** | OpenAI | 200K tokens | Strong logical reasoning | Debugging logic errors, systematic analysis |
 | **`o3-mini`** | OpenAI | 200K tokens | Balanced speed/quality | Moderate complexity tasks |
 | **`o4-mini`** | OpenAI | 200K tokens | Latest reasoning model | Optimized for shorter contexts |
-| **`o4-mini-high`** | OpenAI | 200K tokens | Enhanced reasoning | Complex tasks requiring deeper analysis |
 | **`gpt4.1`** | OpenAI | 1M tokens | Latest GPT-4 with extended context | Large codebase analysis, comprehensive reviews |
 | **`llama`** (Llama 3.2) | Custom/Local | 128K tokens | Local inference, privacy | On-device analysis, cost-free processing |
 | **Any model** | OpenRouter | Varies | Access to GPT-4, Claude, Llama, etc. | User-specified or based on task requirements |
@@ -69,7 +68,7 @@ OPENAI_ALLOWED_MODELS=o4-mini
 # High-performance: Quality over cost
 GOOGLE_ALLOWED_MODELS=pro
-OPENAI_ALLOWED_MODELS=o3,o4-mini-high
+OPENAI_ALLOWED_MODELS=o3,o4-mini
 ```
 
 **Important Notes:**
@@ -144,7 +143,7 @@ All tools that work with files support **both individual files and entire direct
 **`analyze`** - Analyze files or directories
 - `files`: List of file paths or directories (required)
 - `question`: What to analyze (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `analysis_type`: architecture|performance|security|quality|general
 - `output_format`: summary|detailed|actionable
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
@@ -159,7 +158,7 @@ All tools that work with files support **both individual files and entire direct
 **`codereview`** - Review code files or directories
 - `files`: List of file paths or directories (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `review_type`: full|security|performance|quick
 - `focus_on`: Specific aspects to focus on
 - `standards`: Coding standards to enforce
@@ -175,7 +174,7 @@ All tools that work with files support **both individual files and entire direct
 **`debug`** - Debug with file context
 - `error_description`: Description of the issue (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `error_context`: Stack trace or logs
 - `files`: Files or directories related to the issue
 - `runtime_info`: Environment details
@@ -191,7 +190,7 @@ All tools that work with files support **both individual files and entire direct
 **`thinkdeep`** - Extended analysis with file context
 - `current_analysis`: Your current thinking (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `problem_context`: Additional context
 - `focus_areas`: Specific aspects to focus on
 - `files`: Files or directories for context
@@ -207,7 +206,7 @@ All tools that work with files support **both individual files and entire direct
 **`testgen`** - Comprehensive test generation with edge case coverage
 - `files`: Code files or directories to generate tests for (required)
 - `prompt`: Description of what to test, testing objectives, and scope (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `test_examples`: Optional existing test files as style/pattern reference
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
@@ -222,7 +221,7 @@ All tools that work with files support **both individual files and entire direct
 - `files`: Code files or directories to analyze for refactoring opportunities (required)
 - `prompt`: Description of refactoring goals, context, and specific areas of focus (required)
 - `refactor_type`: codesmells|decompose|modernize|organization (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `focus_areas`: Specific areas to focus on (e.g., 'performance', 'readability', 'maintainability', 'security')
 - `style_guide_examples`: Optional existing code files to use as style/pattern reference
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)

View File

@@ -63,7 +63,7 @@ CUSTOM_MODEL_NAME=llama3.2 # Default model
 **Default Model Selection:**
 ```env
-# Options: 'auto', 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high', etc.
+# Options: 'auto', 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', etc.
 DEFAULT_MODEL=auto # Claude picks best model for each task (recommended)
 ```
@@ -74,7 +74,6 @@ DEFAULT_MODEL=auto # Claude picks best model for each task (recommended)
 - **`o3`**: Strong logical reasoning (200K context)
 - **`o3-mini`**: Balanced speed/quality (200K context)
 - **`o4-mini`**: Latest reasoning model, optimized for shorter contexts
-- **`o4-mini-high`**: Enhanced O4 with higher reasoning effort
 - **`grok`**: GROK-3 advanced reasoning (131K context)
 - **Custom models**: via OpenRouter or local APIs
@@ -120,7 +119,6 @@ OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral
 - `o3` (200K context, high reasoning)
 - `o3-mini` (200K context, balanced)
 - `o4-mini` (200K context, latest balanced)
-- `o4-mini-high` (200K context, enhanced reasoning)
 - `mini` (shorthand for o4-mini)
 
 **Gemini Models:**
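With `o4-mini-high` dropped from the allow-lists in this file, the `*_ALLOWED_MODELS` handling reduces to splitting a comma-separated env value and checking membership. A minimal sketch (the env key comes from the docs above; the `parse_allowed` and `is_allowed` helpers are illustrative assumptions, not server code):

```python
import os

def parse_allowed(env_value: str) -> set:
    """Split a comma-separated allow-list such as OPENAI_ALLOWED_MODELS=o3,o4-mini."""
    return {m.strip() for m in env_value.split(",") if m.strip()}

def is_allowed(model: str, allowed: set) -> bool:
    """Empty allow-list means unrestricted (an assumption made for this sketch)."""
    return not allowed or model in allowed

# Example values mirroring the restricted configuration shown above
allowed = parse_allowed(os.environ.get("OPENAI_ALLOWED_MODELS", "o3,o4-mini"))
print(is_allowed("o4-mini", allowed))       # in the allow-list
print(is_allowed("o4-mini-high", allowed))  # no longer in the list
```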

View File

@@ -65,7 +65,7 @@ This workflow ensures methodical analysis before expert insights, resulting in d
 **Initial Configuration (used in step 1):**
 - `prompt`: What to analyze or look for (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `analysis_type`: architecture|performance|security|quality|general (default: general)
 - `output_format`: summary|detailed|actionable (default: detailed)
 - `temperature`: Temperature for analysis (0-1, default 0.2)

View File

@@ -33,7 +33,7 @@ and then debate with the other models to give me a final verdict
 ## Tool Parameters
 - `prompt`: Your question or discussion topic (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `files`: Optional files for context (absolute paths)
 - `images`: Optional images for visual context (absolute paths)
 - `temperature`: Response creativity (0-1, default 0.5)

View File

@@ -80,7 +80,7 @@ The above prompt will simultaneously run two separate `codereview` tools with tw
 **Initial Review Configuration (used in step 1):**
 - `prompt`: User's summary of what the code does, expected behavior, constraints, and review objectives (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `review_type`: full|security|performance|quick (default: full)
 - `focus_on`: Specific aspects to focus on (e.g., "security vulnerabilities", "performance bottlenecks")
 - `standards`: Coding standards to enforce (e.g., "PEP8", "ESLint", "Google Style Guide")

View File
@@ -73,7 +73,7 @@ This structured approach ensures Claude performs methodical groundwork before ex
 - `images`: Visual debugging materials (error screenshots, logs, etc.)
 **Model Selection:**
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini (default: server default)
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
 - `use_websearch`: Enable web search for documentation and solutions (default: true)
 - `use_assistant_model`: Whether to use expert analysis phase (default: true, set to false to use Claude only)

View File
@@ -135,7 +135,7 @@ Use zen and perform a thorough precommit ensuring there aren't any new regressio
 **Initial Configuration (used in step 1):**
 - `path`: Starting directory to search for repos (default: current directory, absolute path required)
 - `prompt`: The original user request description for the changes (required for context)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `compare_to`: Compare against a branch/tag instead of local changes (optional)
 - `severity_filter`: critical|high|medium|low|all (default: all)
 - `include_staged`: Include staged changes in the review (default: true)

View File
@@ -103,7 +103,7 @@ This results in Claude first performing its own expert analysis, encouraging it
 **Initial Configuration (used in step 1):**
 - `prompt`: Description of refactoring goals, context, and specific areas of focus (required)
 - `refactor_type`: codesmells|decompose|modernize|organization (default: codesmells)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `focus_areas`: Specific areas to focus on (e.g., 'performance', 'readability', 'maintainability', 'security')
 - `style_guide_examples`: Optional existing code files to use as style/pattern reference (absolute paths)
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)

View File
@@ -86,7 +86,7 @@ security remediation plan using planner
 - `images`: Architecture diagrams, security documentation, or visual references
 **Initial Security Configuration (used in step 1):**
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `security_scope`: Application context, technology stack, and security boundary definition (required)
 - `threat_level`: low|medium|high|critical (default: medium) - determines assessment depth and urgency
 - `compliance_requirements`: List of compliance frameworks to assess against (e.g., ["PCI DSS", "SOC2"])

View File
@@ -70,7 +70,7 @@ Test generation excels with extended reasoning models like Gemini Pro or O3, whi
 **Initial Configuration (used in step 1):**
 - `prompt`: Description of what to test, testing objectives, and specific scope/focus areas (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `test_examples`: Optional existing test files or directories to use as style/pattern reference (absolute paths)
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
 - `use_assistant_model`: Whether to use expert test generation phase (default: true, set to false to use Claude only)

View File
@@ -30,7 +30,7 @@ with the best architecture for my project
 ## Tool Parameters
 - `prompt`: Your current thinking/analysis to extend and validate (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `problem_context`: Additional context about the problem or goal
 - `focus_areas`: Specific aspects to focus on (architecture, performance, security, etc.)
 - `files`: Optional file paths or directories for additional context (absolute paths)

View File
@@ -132,6 +132,7 @@ class ModelCapabilities:
     model_name: str
     friendly_name: str  # Human-friendly name like "Gemini" or "OpenAI"
     context_window: int  # Total context window size in tokens
+    max_output_tokens: int  # Maximum output tokens per request
     supports_extended_thinking: bool = False
     supports_system_prompts: bool = True
     supports_streaming: bool = True

@@ -140,6 +141,19 @@ class ModelCapabilities:
     max_image_size_mb: float = 0.0  # Maximum total size for all images in MB
     supports_temperature: bool = True  # Whether model accepts temperature parameter in API calls
+    # Additional fields for comprehensive model information
+    description: str = ""  # Human-readable description of the model
+    aliases: list[str] = field(default_factory=list)  # Alternative names/shortcuts for the model
+    # JSON mode support (for providers that support structured output)
+    supports_json_mode: bool = False
+    # Thinking mode support (for models with thinking capabilities)
+    max_thinking_tokens: int = 0  # Maximum thinking tokens for extended reasoning models
+    # Custom model flag (for models that only work with custom endpoints)
+    is_custom: bool = False  # Whether this model requires custom API endpoints
     # Temperature constraint object - preferred way to define temperature limits
     temperature_constraint: TemperatureConstraint = field(
         default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)

@@ -251,7 +265,7 @@ class ModelProvider(ABC):
         capabilities = self.get_capabilities(model_name)
         # Check if model supports temperature at all
-        if hasattr(capabilities, "supports_temperature") and not capabilities.supports_temperature:
+        if not capabilities.supports_temperature:
             return None
         # Get temperature range

@@ -290,19 +304,109 @@ class ModelProvider(ABC):
         """Check if the model supports extended thinking mode."""
         pass

-    @abstractmethod
+    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
+        """Get model configurations for this provider.
+
+        This is a hook method that subclasses can override to provide
+        their model configurations from different sources.
+
+        Returns:
+            Dictionary mapping model names to their ModelCapabilities objects
+        """
+        # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
+        if hasattr(self, "SUPPORTED_MODELS"):
+            return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
+        return {}
+
+    def get_all_model_aliases(self) -> dict[str, list[str]]:
+        """Get all model aliases for this provider.
+
+        This is a hook method that subclasses can override to provide
+        aliases from different sources.
+
+        Returns:
+            Dictionary mapping model names to their list of aliases
+        """
+        # Default implementation extracts from ModelCapabilities objects
+        aliases = {}
+        for model_name, capabilities in self.get_model_configurations().items():
+            if capabilities.aliases:
+                aliases[model_name] = capabilities.aliases
+        return aliases
+
+    def _resolve_model_name(self, model_name: str) -> str:
+        """Resolve model shorthand to full name.
+
+        This implementation uses the hook methods to support different
+        model configuration sources.
+
+        Args:
+            model_name: Model name that may be an alias
+
+        Returns:
+            Resolved model name
+        """
+        # Get model configurations from the hook method
+        model_configs = self.get_model_configurations()
+
+        # First check if it's already a base model name (case-sensitive exact match)
+        if model_name in model_configs:
+            return model_name
+
+        # Check case-insensitively for both base models and aliases
+        model_name_lower = model_name.lower()
+
+        # Check base model names case-insensitively
+        for base_model in model_configs:
+            if base_model.lower() == model_name_lower:
+                return base_model
+
+        # Check aliases from the hook method
+        all_aliases = self.get_all_model_aliases()
+        for base_model, aliases in all_aliases.items():
+            if any(alias.lower() == model_name_lower for alias in aliases):
+                return base_model
+
+        # If not found, return as-is
+        return model_name
+
     def list_models(self, respect_restrictions: bool = True) -> list[str]:
         """Return a list of model names supported by this provider.

+        This implementation uses the get_model_configurations() hook
+        to support different model configuration sources.
+
         Args:
             respect_restrictions: Whether to apply provider-specific restriction logic.

         Returns:
             List of model names available from this provider
         """
-        pass
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service() if respect_restrictions else None
+        models = []
+
+        # Get model configurations from the hook method
+        model_configs = self.get_model_configurations()
+        for model_name in model_configs:
+            # Check restrictions if enabled
+            if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
+                continue
+            # Add the base model
+            models.append(model_name)
+
+        # Get aliases from the hook method
+        all_aliases = self.get_all_model_aliases()
+        for model_name, aliases in all_aliases.items():
+            # Only add aliases for models that passed restriction check
+            if model_name in models:
+                models.extend(aliases)
+
+        return models

-    @abstractmethod
     def list_all_known_models(self) -> list[str]:
         """Return all model names known by this provider, including alias targets.

@@ -312,21 +416,22 @@ class ModelProvider(ABC):
         Returns:
             List of all model names and alias targets known by this provider
         """
-        pass
-
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name.
-
-        Base implementation returns the model name unchanged.
-        Subclasses should override to provide alias resolution.
-
-        Args:
-            model_name: Model name that may be an alias
-
-        Returns:
-            Resolved model name
-        """
-        return model_name
+        all_models = set()
+
+        # Get model configurations from the hook method
+        model_configs = self.get_model_configurations()
+
+        # Add all base model names
+        for model_name in model_configs:
+            all_models.add(model_name.lower())
+
+        # Get aliases from the hook method and add them
+        all_aliases = self.get_all_model_aliases()
+        for _model_name, aliases in all_aliases.items():
+            for alias in aliases:
+                all_models.add(alias.lower())
+
+        return list(all_models)

     def close(self):
         """Clean up any resources held by the provider.
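The hunk above moves alias resolution and model listing from abstract methods into overridable hooks on the base provider. A minimal, self-contained sketch of that lookup order (exact base name, then case-insensitive base name, then alias) — note `ToyProvider` and the stripped-down `ModelCapabilities` here are illustrative stand-ins, not the project's real classes:

```python
from dataclasses import dataclass, field


@dataclass
class ModelCapabilities:
    # Simplified stand-in: only the fields this sketch needs
    model_name: str
    aliases: list[str] = field(default_factory=list)


class ToyProvider:
    # Hypothetical provider using the same hook pattern as the diff above
    SUPPORTED_MODELS = {
        "o4-mini-2025-04-16": ModelCapabilities("o4-mini-2025-04-16", aliases=["o4-mini", "mini"]),
    }

    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
        return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}

    def get_all_model_aliases(self) -> dict[str, list[str]]:
        return {name: caps.aliases for name, caps in self.get_model_configurations().items() if caps.aliases}

    def _resolve_model_name(self, model_name: str) -> str:
        configs = self.get_model_configurations()
        if model_name in configs:  # exact base-name match first
            return model_name
        lower = model_name.lower()
        for base in configs:  # case-insensitive base-name match
            if base.lower() == lower:
                return base
        for base, aliases in self.get_all_model_aliases().items():  # alias match
            if any(alias.lower() == lower for alias in aliases):
                return base
        return model_name  # unknown names pass through unchanged


provider = ToyProvider()
print(provider._resolve_model_name("MINI"))  # alias resolves case-insensitively to the base name
```

Because unknown names fall through unchanged, callers can pass user input straight in and validate the result afterwards, which matches how `list_models` and the restriction checks consume it.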

View File

@@ -158,6 +158,7 @@ class CustomProvider(OpenAICompatibleProvider):
             model_name=resolved_name,
             friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
             context_window=32_768,  # Conservative default
+            max_output_tokens=32_768,  # Conservative default max output
             supports_extended_thinking=False,  # Most custom models don't support this
             supports_system_prompts=True,
             supports_streaming=True,

@@ -187,7 +188,7 @@ class CustomProvider(OpenAICompatibleProvider):
         Returns:
             True if model is intended for custom/local endpoint
         """
-        logging.debug(f"Custom provider validating model: '{model_name}'")
+        # logging.debug(f"Custom provider validating model: '{model_name}'")

         # Try to resolve through registry first
         config = self._registry.resolve(model_name)

@@ -195,12 +196,12 @@ class CustomProvider(OpenAICompatibleProvider):
             model_id = config.model_name
             # Use explicit is_custom flag for clean validation
             if config.is_custom:
-                logging.debug(f"Model '{model_name}' -> '{model_id}' validated via registry (custom model)")
+                logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry")
                 return True
             else:
                 # This is a cloud/OpenRouter model - CustomProvider should NOT handle these
                 # Let OpenRouter provider handle them instead
-                logging.debug(f"Model '{model_name}' -> '{model_id}' rejected (cloud model, defer to OpenRouter)")
+                # logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)")
                 return False

         # Handle version tags for unknown models (e.g., "my-model:latest")

@@ -268,65 +269,50 @@ class CustomProvider(OpenAICompatibleProvider):
     def supports_thinking_mode(self, model_name: str) -> bool:
         """Check if the model supports extended thinking mode.

-        Most custom/local models don't support extended thinking.
-
         Args:
             model_name: Model to check

         Returns:
-            False (custom models generally don't support thinking mode)
+            True if model supports thinking mode, False otherwise
         """
+        # Check if model is in registry
+        config = self._registry.resolve(model_name) if self._registry else None
+        if config and config.is_custom:
+            # Trust the config from custom_models.json
+            return config.supports_extended_thinking
+        # Default to False for unknown models
         return False

-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-
-        Returns:
-            List of model names available from this provider
-        """
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
-
-        if self._registry:
-            # Get all models from the registry
-            all_models = self._registry.list_models()
-            aliases = self._registry.list_aliases()
-
-            # Add models that are validated by the custom provider
-            for model_name in all_models + aliases:
-                # Use the provider's validation logic to determine if this model
-                # is appropriate for the custom endpoint
-                if self.validate_model_name(model_name):
-                    # Check restrictions if enabled
-                    if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                        continue
-
-                    models.append(model_name)
-
-        return models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-
-        if self._registry:
-            # Get all models and aliases from the registry
-            all_models.update(model.lower() for model in self._registry.list_models())
-            all_models.update(alias.lower() for alias in self._registry.list_aliases())
-
-            # For each alias, also add its target
-            for alias in self._registry.list_aliases():
-                config = self._registry.resolve(alias)
-                if config:
-                    all_models.add(config.model_name.lower())
-
-        return list(all_models)
+    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
+        """Get model configurations from the registry.
+
+        For CustomProvider, we convert registry configurations to ModelCapabilities objects.
+
+        Returns:
+            Dictionary mapping model names to their ModelCapabilities objects
+        """
+        configs = {}
+        if self._registry:
+            # Get all models from registry
+            for model_name in self._registry.list_models():
+                # Only include custom models that this provider validates
+                if self.validate_model_name(model_name):
+                    config = self._registry.resolve(model_name)
+                    if config and config.is_custom:
+                        # Use ModelCapabilities directly from registry
+                        configs[model_name] = config
+        return configs
+
+    def get_all_model_aliases(self) -> dict[str, list[str]]:
+        """Get all model aliases from the registry.
+
+        Returns:
+            Dictionary mapping model names to their list of aliases
+        """
+        # Since aliases are now included in the configurations,
+        # we can use the base class implementation
+        return super().get_all_model_aliases()

View File

@@ -10,7 +10,7 @@ from .base import (
ModelCapabilities, ModelCapabilities,
ModelResponse, ModelResponse,
ProviderType, ProviderType,
RangeTemperatureConstraint, create_temperature_constraint,
) )
from .openai_compatible import OpenAICompatibleProvider from .openai_compatible import OpenAICompatibleProvider
@@ -30,63 +30,170 @@ class DIALModelProvider(OpenAICompatibleProvider):
MAX_RETRIES = 4 MAX_RETRIES = 4
RETRY_DELAYS = [1, 3, 5, 8] # seconds RETRY_DELAYS = [1, 3, 5, 8] # seconds
# Supported DIAL models (these can be customized based on your DIAL deployment) # Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"o3-2025-04-16": { "o3-2025-04-16": ModelCapabilities(
"context_window": 200_000, provider=ProviderType.DIAL,
"supports_extended_thinking": False, model_name="o3-2025-04-16",
"supports_vision": True, friendly_name="DIAL (O3)",
}, context_window=200_000,
"o4-mini-2025-04-16": { max_output_tokens=100_000,
"context_window": 200_000, supports_extended_thinking=False,
"supports_extended_thinking": False, supports_system_prompts=True,
"supports_vision": True, supports_streaming=True,
}, supports_function_calling=False, # DIAL may not expose function calling
"anthropic.claude-sonnet-4-20250514-v1:0": { supports_json_mode=True,
"context_window": 200_000, supports_images=True,
"supports_extended_thinking": False, max_image_size_mb=20.0,
"supports_vision": True, supports_temperature=False, # O3 models don't accept temperature
}, temperature_constraint=create_temperature_constraint("fixed"),
"anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": { description="OpenAI O3 via DIAL - Strong reasoning model",
"context_window": 200_000, aliases=["o3"],
"supports_extended_thinking": True, # Thinking mode variant ),
"supports_vision": True, "o4-mini-2025-04-16": ModelCapabilities(
}, provider=ProviderType.DIAL,
"anthropic.claude-opus-4-20250514-v1:0": { model_name="o4-mini-2025-04-16",
"context_window": 200_000, friendly_name="DIAL (O4-mini)",
"supports_extended_thinking": False, context_window=200_000,
"supports_vision": True, max_output_tokens=100_000,
}, supports_extended_thinking=False,
"anthropic.claude-opus-4-20250514-v1:0-with-thinking": { supports_system_prompts=True,
"context_window": 200_000, supports_streaming=True,
"supports_extended_thinking": True, # Thinking mode variant supports_function_calling=False, # DIAL may not expose function calling
"supports_vision": True, supports_json_mode=True,
}, supports_images=True,
"gemini-2.5-pro-preview-03-25-google-search": { max_image_size_mb=20.0,
"context_window": 1_000_000, supports_temperature=False, # O4 models don't accept temperature
"supports_extended_thinking": False, # DIAL doesn't expose thinking mode temperature_constraint=create_temperature_constraint("fixed"),
"supports_vision": True, description="OpenAI O4-mini via DIAL - Fast reasoning model",
}, aliases=["o4-mini"],
"gemini-2.5-pro-preview-05-06": { ),
"context_window": 1_000_000, "anthropic.claude-sonnet-4-20250514-v1:0": ModelCapabilities(
"supports_extended_thinking": False, provider=ProviderType.DIAL,
"supports_vision": True, model_name="anthropic.claude-sonnet-4-20250514-v1:0",
}, friendly_name="DIAL (Sonnet 4)",
"gemini-2.5-flash-preview-05-20": { context_window=200_000,
"context_window": 1_000_000, max_output_tokens=64_000,
"supports_extended_thinking": False, supports_extended_thinking=False,
"supports_vision": True, supports_system_prompts=True,
}, supports_streaming=True,
# Shorthands supports_function_calling=False, # Claude doesn't have function calling
"o3": "o3-2025-04-16", supports_json_mode=False, # Claude doesn't have JSON mode
"o4-mini": "o4-mini-2025-04-16", supports_images=True,
"sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0", max_image_size_mb=5.0,
"sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking", supports_temperature=True,
"opus-4": "anthropic.claude-opus-4-20250514-v1:0", temperature_constraint=create_temperature_constraint("range"),
"opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking", description="Claude Sonnet 4 via DIAL - Balanced performance",
"gemini-2.5-pro": "gemini-2.5-pro-preview-05-06", aliases=["sonnet-4"],
"gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search", ),
"gemini-2.5-flash": "gemini-2.5-flash-preview-05-20", "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
friendly_name="DIAL (Sonnet 4 Thinking)",
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=True, # Thinking mode variant
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # Claude doesn't have function calling
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Claude Sonnet 4 with thinking mode via DIAL",
aliases=["sonnet-4-thinking"],
),
"anthropic.claude-opus-4-20250514-v1:0": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-opus-4-20250514-v1:0",
friendly_name="DIAL (Opus 4)",
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # Claude doesn't have function calling
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Claude Opus 4 via DIAL - Most capable Claude model",
aliases=["opus-4"],
),
"anthropic.claude-opus-4-20250514-v1:0-with-thinking": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking",
friendly_name="DIAL (Opus 4 Thinking)",
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=True, # Thinking mode variant
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # Claude doesn't have function calling
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Claude Opus 4 with thinking mode via DIAL",
aliases=["opus-4-thinking"],
),
"gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-pro-preview-03-25-google-search",
friendly_name="DIAL (Gemini 2.5 Pro Search)",
context_window=1_000_000,
max_output_tokens=65_536,
supports_extended_thinking=False, # DIAL doesn't expose thinking mode
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Gemini 2.5 Pro with Google Search via DIAL",
aliases=["gemini-2.5-pro-search"],
),
"gemini-2.5-pro-preview-05-06": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-pro-preview-05-06",
friendly_name="DIAL (Gemini 2.5 Pro)",
context_window=1_000_000,
max_output_tokens=65_536,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Gemini 2.5 Pro via DIAL - Deep reasoning",
aliases=["gemini-2.5-pro"],
),
"gemini-2.5-flash-preview-05-20": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-flash-preview-05-20",
friendly_name="DIAL (Gemini Flash 2.5)",
context_window=1_000_000,
max_output_tokens=65_536,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=create_temperature_constraint("range"),
description="Gemini 2.5 Flash via DIAL - Ultra-fast",
aliases=["gemini-2.5-flash"],
),
} }
def __init__(self, api_key: str, **kwargs): def __init__(self, api_key: str, **kwargs):
@@ -181,20 +288,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name): if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.") raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
config = self.SUPPORTED_MODELS[resolved_name] # Return the ModelCapabilities object directly from SUPPORTED_MODELS
return self.SUPPORTED_MODELS[resolved_name]
return ModelCapabilities(
provider=ProviderType.DIAL,
model_name=resolved_name,
friendly_name=self.FRIENDLY_NAME,
context_window=config["context_window"],
supports_extended_thinking=config["supports_extended_thinking"],
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_images=config.get("supports_vision", False),
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
)
def get_provider_type(self) -> ProviderType: def get_provider_type(self) -> ProviderType:
"""Get the provider type.""" """Get the provider type."""
@@ -211,7 +306,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
""" """
resolved_name = self._resolve_model_name(model_name) resolved_name = self._resolve_model_name(model_name)
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict): if resolved_name not in self.SUPPORTED_MODELS:
return False return False
# Check against base class allowed_models if configured # Check against base class allowed_models if configured
@@ -231,20 +326,6 @@ class DIALModelProvider(OpenAICompatibleProvider):
return True return True
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name.
Args:
model_name: Model name or shorthand
Returns:
Full model name
"""
shorthand_value = self.SUPPORTED_MODELS.get(model_name)
if isinstance(shorthand_value, str):
return shorthand_value
return model_name
    def _get_deployment_client(self, deployment: str):
        """Get or create a cached client for a specific deployment.
@@ -357,7 +438,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
        # Check model capabilities
        try:
            capabilities = self.get_capabilities(model_name)
-            supports_temperature = getattr(capabilities, "supports_temperature", True)
+            supports_temperature = capabilities.supports_temperature
        except Exception as e:
            logger.debug(f"Failed to check temperature support for {model_name}: {e}")
            supports_temperature = True
@@ -441,63 +522,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
        """
        resolved_name = self._resolve_model_name(model_name)
-        if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
-            return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
+        if resolved_name in self.SUPPORTED_MODELS:
+            return self.SUPPORTED_MODELS[resolved_name].supports_images
        # Fall back to parent implementation for unknown models
        return super()._supports_vision(model_name)
-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-        Returns:
-            List of model names available from this provider
-        """
-        # Get all model keys (both full names and aliases)
-        all_models = list(self.SUPPORTED_MODELS.keys())
-        if not respect_restrictions:
-            return all_models
-        # Apply restrictions if configured
-        from utils.model_restrictions import get_restriction_service
-        restriction_service = get_restriction_service()
-        # Filter based on restrictions
-        allowed_models = []
-        for model in all_models:
-            resolved_name = self._resolve_model_name(model)
-            if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
-                allowed_models.append(model)
-        return allowed_models
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-        This is used for validation purposes to ensure restriction policies
-        can validate against both aliases and their target model names.
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        # Collect all unique model names (both aliases and targets)
-        all_models = set()
-        for key, value in self.SUPPORTED_MODELS.items():
-            # Add the key (could be alias or full name)
-            all_models.add(key)
-            # If it's an alias (string value), add the target too
-            if isinstance(value, str):
-                all_models.add(value)
-        return sorted(all_models)
    def close(self):
        """Clean up HTTP clients when provider is closed."""
        logger.info("Closing DIAL provider HTTP clients...")
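This commit replaces each provider's string-valued shorthand entries (and its private `_resolve_model_name`) with an `aliases` list carried on every `ModelCapabilities` entry. A minimal, self-contained sketch of that resolution pattern — the class here is a stripped-down stand-in, not the repository's actual `ModelCapabilities`:

```python
from dataclasses import dataclass, field


@dataclass
class ModelCapabilities:
    """Toy stand-in: only the fields needed to show alias resolution."""
    model_name: str
    aliases: list = field(default_factory=list)


SUPPORTED_MODELS = {
    "gemini-2.5-flash": ModelCapabilities("gemini-2.5-flash", aliases=["flash", "flash2.5"]),
    "gemini-2.5-pro": ModelCapabilities("gemini-2.5-pro", aliases=["pro"]),
}


def resolve_model_name(model_name: str) -> str:
    """Resolve a shorthand alias to its canonical model name (case-insensitive)."""
    lowered = model_name.lower()
    for name, caps in SUPPORTED_MODELS.items():
        if lowered == name.lower() or lowered in (a.lower() for a in caps.aliases):
            return name
    return model_name  # unknown names pass through unchanged
```

With the aliases attached to each entry, one shared resolver in the base class can serve every provider, which is why the per-provider `_resolve_model_name` methods above are deleted.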

View File

@@ -9,7 +9,7 @@ from typing import Optional
from google import genai
from google.genai import types
-from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
+from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint
logger = logging.getLogger(__name__)
@@ -17,47 +17,83 @@ logger = logging.getLogger(__name__)
class GeminiModelProvider(ModelProvider):
    """Google Gemini model provider implementation."""
-    # Model configurations
+    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
-        "gemini-2.0-flash": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": True,  # Experimental thinking mode
-            "max_thinking_tokens": 24576,  # Same as 2.5 flash for consistency
-            "supports_images": True,  # Vision capability
-            "max_image_size_mb": 20.0,  # Conservative 20MB limit for reliability
-            "description": "Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
-        },
-        "gemini-2.0-flash-lite": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": False,  # Not supported per user request
-            "max_thinking_tokens": 0,  # No thinking support
-            "supports_images": False,  # Does not support images
-            "max_image_size_mb": 0.0,  # No image support
-            "description": "Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
-        },
-        "gemini-2.5-flash": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": True,
-            "max_thinking_tokens": 24576,  # Flash 2.5 thinking budget limit
-            "supports_images": True,  # Vision capability
-            "max_image_size_mb": 20.0,  # Conservative 20MB limit for reliability
-            "description": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
-        },
-        "gemini-2.5-pro": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": True,
-            "max_thinking_tokens": 32768,  # Pro 2.5 thinking budget limit
-            "supports_images": True,  # Vision capability
-            "max_image_size_mb": 32.0,  # Higher limit for Pro model
-            "description": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
-        },
-        # Shorthands
-        "flash": "gemini-2.5-flash",
-        "flash-2.0": "gemini-2.0-flash",
-        "flash2": "gemini-2.0-flash",
-        "flashlite": "gemini-2.0-flash-lite",
-        "flash-lite": "gemini-2.0-flash-lite",
-        "pro": "gemini-2.5-pro",
+        "gemini-2.0-flash": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.0-flash",
+            friendly_name="Gemini (Flash 2.0)",
+            context_window=1_048_576,  # 1M tokens
+            max_output_tokens=65_536,
+            supports_extended_thinking=True,  # Experimental thinking mode
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # Vision capability
+            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            max_thinking_tokens=24576,  # Same as 2.5 flash for consistency
+            description="Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
+            aliases=["flash-2.0", "flash2"],
+        ),
+        "gemini-2.0-flash-lite": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.0-flash-lite",
+            friendly_name="Gemini (Flash Lite 2.0)",
+            context_window=1_048_576,  # 1M tokens
+            max_output_tokens=65_536,
+            supports_extended_thinking=False,  # Not supported per user request
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=False,  # Does not support images
+            max_image_size_mb=0.0,  # No image support
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
+            aliases=["flashlite", "flash-lite"],
+        ),
+        "gemini-2.5-flash": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.5-flash",
+            friendly_name="Gemini (Flash 2.5)",
+            context_window=1_048_576,  # 1M tokens
+            max_output_tokens=65_536,
+            supports_extended_thinking=True,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # Vision capability
+            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            max_thinking_tokens=24576,  # Flash 2.5 thinking budget limit
+            description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
+            aliases=["flash", "flash2.5"],
+        ),
+        "gemini-2.5-pro": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.5-pro",
+            friendly_name="Gemini (Pro 2.5)",
+            context_window=1_048_576,  # 1M tokens
+            max_output_tokens=65_536,
+            supports_extended_thinking=True,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # Vision capability
+            max_image_size_mb=32.0,  # Higher limit for Pro model
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            max_thinking_tokens=32768,  # Max thinking tokens for Pro model
+            description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
+            aliases=["pro", "gemini pro", "gemini-pro"],
+        ),
    }
    # Thinking mode configurations - percentages of model's max_thinking_tokens
@@ -70,6 +106,14 @@ class GeminiModelProvider(ModelProvider):
        "max": 1.0,  # 100% of max - full thinking budget
    }
+    # Model-specific thinking token limits
+    MAX_THINKING_TOKENS = {
+        "gemini-2.0-flash": 24576,  # Same as 2.5 flash for consistency
+        "gemini-2.0-flash-lite": 0,  # No thinking support
+        "gemini-2.5-flash": 24576,  # Flash 2.5 thinking budget limit
+        "gemini-2.5-pro": 32768,  # Pro 2.5 thinking budget limit
+    }
    def __init__(self, api_key: str, **kwargs):
        """Initialize Gemini provider with API key."""
        super().__init__(api_key, **kwargs)
@@ -100,25 +144,8 @@ class GeminiModelProvider(ModelProvider):
        if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
            raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")
-        config = self.SUPPORTED_MODELS[resolved_name]
+        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+        return self.SUPPORTED_MODELS[resolved_name]
-        # Gemini models support 0.0-2.0 temperature range
-        temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
-        return ModelCapabilities(
-            provider=ProviderType.GOOGLE,
-            model_name=resolved_name,
-            friendly_name="Gemini",
-            context_window=config["context_window"],
-            supports_extended_thinking=config["supports_extended_thinking"],
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=True,
-            supports_images=config.get("supports_images", False),
-            max_image_size_mb=config.get("max_image_size_mb", 0.0),
-            supports_temperature=True,  # Gemini models accept temperature parameter
-            temperature_constraint=temp_constraint,
-        )
    def generate_content(
        self,
@@ -179,8 +206,8 @@ class GeminiModelProvider(ModelProvider):
        if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
            # Get model's max thinking tokens and calculate actual budget
            model_config = self.SUPPORTED_MODELS.get(resolved_name)
-            if model_config and "max_thinking_tokens" in model_config:
-                max_thinking_tokens = model_config["max_thinking_tokens"]
+            if model_config and model_config.max_thinking_tokens > 0:
+                max_thinking_tokens = model_config.max_thinking_tokens
                actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
                generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget)
@@ -258,7 +285,7 @@ class GeminiModelProvider(ModelProvider):
        resolved_name = self._resolve_model_name(model_name)
        # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+        if resolved_name not in self.SUPPORTED_MODELS:
            return False
        # Then check if model is allowed by restrictions
@@ -281,78 +308,20 @@ class GeminiModelProvider(ModelProvider):
    def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
        """Get actual thinking token budget for a model and thinking mode."""
        resolved_name = self._resolve_model_name(model_name)
-        model_config = self.SUPPORTED_MODELS.get(resolved_name, {})
-        if not model_config.get("supports_extended_thinking", False):
+        model_config = self.SUPPORTED_MODELS.get(resolved_name)
+        if not model_config or not model_config.supports_extended_thinking:
            return 0
        if thinking_mode not in self.THINKING_BUDGETS:
            return 0
-        max_thinking_tokens = model_config.get("max_thinking_tokens", 0)
+        max_thinking_tokens = model_config.max_thinking_tokens
        if max_thinking_tokens == 0:
            return 0
        return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-        Returns:
-            List of model names available from this provider
-        """
-        from utils.model_restrictions import get_restriction_service
-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Handle both base models (dict configs) and aliases (string values)
-            if isinstance(config, str):
-                # This is an alias - check if the target model would be allowed
-                target_model = config
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
-                    continue
-                # Allow the alias
-                models.append(model_name)
-            else:
-                # This is a base model with config dict
-                # Check restrictions if enabled
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                    continue
-                models.append(model_name)
-        return models
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Add the model name itself
-            all_models.add(model_name.lower())
-            # If it's an alias (string value), add the target model too
-            if isinstance(config, str):
-                all_models.add(config.lower())
-        return list(all_models)
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name."""
-        # Check if it's a shorthand
-        shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
-        if isinstance(shorthand_value, str):
-            return shorthand_value
-        return model_name
    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from Gemini response."""
        usage = {}

View File

@@ -686,7 +686,6 @@ class OpenAICompatibleProvider(ModelProvider):
            "o3-mini",
            "o3-pro",
            "o4-mini",
-            "o4-mini-high",
            # Note: Claude models would be handled by a separate provider
        }
        supports = model_name.lower() in vision_models

View File

@@ -17,71 +17,98 @@ logger = logging.getLogger(__name__)
class OpenAIModelProvider(OpenAICompatibleProvider):
    """Official OpenAI API provider (api.openai.com)."""
-    # Model configurations
+    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
-        "o3": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O3 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
-        },
-        "o3-mini": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O3 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
-        },
-        "o3-pro-2025-06-10": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O3 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
-        },
-        # Aliases
-        "o3-pro": "o3-pro-2025-06-10",
-        "o4-mini": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O4 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O4 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
-        },
-        "o4-mini-high": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O4 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O4 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
-        },
-        "gpt-4.1-2025-04-14": {
-            "context_window": 1_000_000,  # 1M tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # GPT-4.1 supports vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": True,  # Regular models accept temperature parameter
-            "temperature_constraint": "range",  # 0.0-2.0 range
-            "description": "GPT-4.1 (1M context) - Advanced reasoning model with large context window",
-        },
-        # Shorthands
-        "mini": "o4-mini",  # Default 'mini' to latest mini model
-        "o3mini": "o3-mini",
-        "o4mini": "o4-mini",
-        "o4minihigh": "o4-mini-high",
-        "o4minihi": "o4-mini-high",
-        "gpt4.1": "gpt-4.1-2025-04-14",
+        "o3": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3",
+            friendly_name="OpenAI (O3)",
+            context_window=200_000,  # 200K tokens
+            max_output_tokens=65536,  # 64K max output tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
+            aliases=[],
+        ),
+        "o3-mini": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3-mini",
+            friendly_name="OpenAI (O3-mini)",
+            context_window=200_000,  # 200K tokens
+            max_output_tokens=65536,  # 64K max output tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
+            aliases=["o3mini", "o3-mini"],
+        ),
+        "o3-pro-2025-06-10": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3-pro-2025-06-10",
+            friendly_name="OpenAI (O3-Pro)",
+            context_window=200_000,  # 200K tokens
+            max_output_tokens=65536,  # 64K max output tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
+            aliases=["o3-pro"],
+        ),
+        "o4-mini": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o4-mini",
+            friendly_name="OpenAI (O4-mini)",
+            context_window=200_000,  # 200K tokens
+            max_output_tokens=65536,  # 64K max output tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O4 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O4 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
+            aliases=["mini", "o4mini", "o4-mini"],
+        ),
+        "gpt-4.1-2025-04-14": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="gpt-4.1-2025-04-14",
+            friendly_name="OpenAI (GPT 4.1)",
+            context_window=1_000_000,  # 1M tokens
+            max_output_tokens=32_768,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # GPT-4.1 supports vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=True,  # Regular models accept temperature parameter
+            temperature_constraint=create_temperature_constraint("range"),
+            description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
+            aliases=["gpt4.1"],
+        ),
    }
    def __init__(self, api_key: str, **kwargs):
@@ -95,7 +122,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        # Resolve shorthand
        resolved_name = self._resolve_model_name(model_name)
-        if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
+        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported OpenAI model: {model_name}")
        # Check if model is allowed by restrictions
@@ -105,27 +132,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
            raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
-        config = self.SUPPORTED_MODELS[resolved_name]
+        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+        return self.SUPPORTED_MODELS[resolved_name]
-        # Get temperature constraints and support from configuration
-        supports_temperature = config.get("supports_temperature", True)  # Default to True for backward compatibility
-        temp_constraint_type = config.get("temperature_constraint", "range")  # Default to range
-        temp_constraint = create_temperature_constraint(temp_constraint_type)
-        return ModelCapabilities(
-            provider=ProviderType.OPENAI,
-            model_name=model_name,
-            friendly_name="OpenAI",
-            context_window=config["context_window"],
-            supports_extended_thinking=config["supports_extended_thinking"],
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=True,
-            supports_images=config.get("supports_images", False),
-            max_image_size_mb=config.get("max_image_size_mb", 0.0),
-            supports_temperature=supports_temperature,
-            temperature_constraint=temp_constraint,
-        )
    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
@@ -136,7 +144,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        resolved_name = self._resolve_model_name(model_name)
        # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+        if resolved_name not in self.SUPPORTED_MODELS:
            return False
        # Then check if model is allowed by restrictions
@@ -177,61 +185,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        # Currently no OpenAI models support extended thinking
        # This may change with future O3 models
        return False
-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-        Returns:
-            List of model names available from this provider
-        """
-        from utils.model_restrictions import get_restriction_service
-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Handle both base models (dict configs) and aliases (string values)
-            if isinstance(config, str):
-                # This is an alias - check if the target model would be allowed
-                target_model = config
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
-                    continue
-                # Allow the alias
-                models.append(model_name)
-            else:
-                # This is a base model with config dict
-                # Check restrictions if enabled
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                    continue
-                models.append(model_name)
-        return models
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Add the model name itself
-            all_models.add(model_name.lower())
-            # If it's an alias (string value), add the target model too
-            if isinstance(config, str):
-                all_models.add(config.lower())
-        return list(all_models)
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name."""
-        # Check if it's a shorthand
-        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
-        if isinstance(shorthand_value, str):
-            return shorthand_value
-        return model_name
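The O3/O4 entries above switch from a `"temperature_constraint": "fixed"` string to `temperature_constraint=create_temperature_constraint("fixed")`. A hedged sketch of what such a factory might look like — the class names and `validate` method are assumptions for illustration, not the repository's actual `base` module; only the 0.0-2.0 range with 0.7 default and the "fixed at 1.0" behavior appear in the diff:

```python
class FixedTemperatureConstraint:
    """Model only accepts one temperature value (O3/O4 are fixed at 1.0)."""

    def __init__(self, value: float = 1.0):
        self.value = value

    def validate(self, temperature: float) -> bool:
        return temperature == self.value


class RangeTemperatureConstraint:
    """Model accepts any temperature within [min, max]."""

    def __init__(self, min_temp: float, max_temp: float, default: float):
        self.min_temp, self.max_temp, self.default = min_temp, max_temp, default

    def validate(self, temperature: float) -> bool:
        return self.min_temp <= temperature <= self.max_temp


def create_temperature_constraint(kind: str):
    """Map the config string ("fixed" or "range") to a constraint object."""
    if kind == "fixed":
        return FixedTemperatureConstraint()
    # Default range matches the RangeTemperatureConstraint(0.0, 2.0, 0.7) seen above.
    return RangeTemperatureConstraint(0.0, 2.0, 0.7)
```

Storing the constructed constraint object on `ModelCapabilities` (instead of a raw string) lets callers validate a requested temperature without re-parsing configuration.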

View File

@@ -50,14 +50,6 @@ class OpenRouterProvider(OpenAICompatibleProvider):
        aliases = self._registry.list_aliases()
        logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
-    def _parse_allowed_models(self) -> None:
-        """Override to disable environment-based allow-list.
-        OpenRouter model access is controlled via the OpenRouter dashboard,
-        not through environment variables.
-        """
-        return None
    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model aliases to OpenRouter model names.
@@ -109,6 +101,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
            model_name=resolved_name,
            friendly_name=self.FRIENDLY_NAME,
            context_window=32_768,  # Conservative default context window
+            max_output_tokens=32_768,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
@@ -130,16 +123,34 @@ class OpenRouterProvider(OpenAICompatibleProvider):
        As the catch-all provider, OpenRouter accepts any model name that wasn't
        handled by higher-priority providers. OpenRouter will validate based on
-        the API key's permissions.
+        the API key's permissions and local restrictions.
        Args:
            model_name: Model name to validate
        Returns:
-            Always True - OpenRouter is the catch-all provider
+            True if model is allowed, False if restricted
        """
-        # Accept any model name - OpenRouter is the fallback provider
-        # Higher priority providers (native APIs, custom endpoints) get first chance
+        # Check model restrictions if configured
+        from utils.model_restrictions import get_restriction_service
+        restriction_service = get_restriction_service()
+        if restriction_service:
+            # Check if model name itself is allowed
+            if restriction_service.is_allowed(self.get_provider_type(), model_name):
+                return True
+            # Also check aliases - model_name might be an alias
+            model_config = self._registry.resolve(model_name)
+            if model_config and model_config.aliases:
+                for alias in model_config.aliases:
+                    if restriction_service.is_allowed(self.get_provider_type(), alias):
+                        return True
+            # If restrictions are configured and model/alias not in allowed list, reject
+            return False
+        # No restrictions configured - accept any model name as the fallback provider
        return True
    def generate_content(
@@ -260,3 +271,35 @@ class OpenRouterProvider(OpenAICompatibleProvider):
        all_models.add(config.model_name.lower())
        return list(all_models)
+    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
+        """Get model configurations from the registry.
+        For OpenRouter, we convert registry configurations to ModelCapabilities objects.
+        Returns:
+            Dictionary mapping model names to their ModelCapabilities objects
+        """
+        configs = {}
+        if self._registry:
+            # Get all models from registry
+            for model_name in self._registry.list_models():
+                # Only include models that this provider validates
+                if self.validate_model_name(model_name):
+                    config = self._registry.resolve(model_name)
+                    if config and not config.is_custom:  # Only OpenRouter models, not custom ones
+                        # Use ModelCapabilities directly from registry
+                        configs[model_name] = config
+        return configs
+    def get_all_model_aliases(self) -> dict[str, list[str]]:
+        """Get all model aliases from the registry.
+        Returns:
+            Dictionary mapping model names to their list of aliases
+        """
+        # Since aliases are now included in the configurations,
+        # we can use the base class implementation
+        return super().get_all_model_aliases()
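The `validate_model_name` change above turns OpenRouter from an unconditional catch-all into a restriction-aware fallback: with no restriction service configured it still accepts anything, but when one is configured the model name or one of its aliases must be explicitly allowed. A simplified, self-contained sketch of that decision flow — the `_AllowList` stand-in and the one-argument `is_allowed` signature are assumptions; the real service takes a provider type as well:

```python
class _AllowList:
    """Toy stand-in for the restriction service used only for this sketch."""

    def __init__(self, allowed):
        self.allowed = set(allowed)

    def is_allowed(self, name):
        return name in self.allowed


def validate_model_name(name, restriction_service, aliases_of):
    """Accept a model if unrestricted, or if the name or any alias is allowed."""
    if restriction_service is None:
        return True  # no restrictions configured: keep catch-all behavior
    if restriction_service.is_allowed(name):
        return True
    # The requested name may be canonical while only its shorthand is allow-listed.
    return any(restriction_service.is_allowed(alias) for alias in aliases_of(name))


service = _AllowList({"flash"})
aliases = lambda n: ["flash"] if n == "google/gemini-2.5-flash" else []
```

This preserves the fallback role (higher-priority native providers still get first chance) while letting a local allow-list narrow what reaches OpenRouter.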

View File

@@ -2,7 +2,6 @@
import logging
import os
-from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
@@ -11,58 +10,10 @@ from utils.file_utils import read_json_file
from .base import (
    ModelCapabilities,
    ProviderType,
-    TemperatureConstraint,
    create_temperature_constraint,
)
@dataclass
class OpenRouterModelConfig:
"""Configuration for an OpenRouter model."""
model_name: str
aliases: list[str] = field(default_factory=list)
context_window: int = 32768 # Total context window size in tokens
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
supports_streaming: bool = True
supports_function_calling: bool = False
supports_json_mode: bool = False
supports_images: bool = False # Whether model can process images
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
temperature_constraint: Optional[str] = (
None # Type of temperature constraint: "fixed", "range", "discrete", or None for default range
)
is_custom: bool = False # True for models that should only be used with custom endpoints
description: str = ""
def _create_temperature_constraint(self) -> TemperatureConstraint:
"""Create temperature constraint object from configuration.
Returns:
TemperatureConstraint object based on configuration
"""
return create_temperature_constraint(self.temperature_constraint or "range")
def to_capabilities(self) -> ModelCapabilities:
"""Convert to ModelCapabilities object."""
return ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=self.model_name,
friendly_name="OpenRouter",
context_window=self.context_window,
supports_extended_thinking=self.supports_extended_thinking,
supports_system_prompts=self.supports_system_prompts,
supports_streaming=self.supports_streaming,
supports_function_calling=self.supports_function_calling,
supports_images=self.supports_images,
max_image_size_mb=self.max_image_size_mb,
supports_temperature=self.supports_temperature,
temperature_constraint=self._create_temperature_constraint(),
)
class OpenRouterModelRegistry:
"""Registry for managing OpenRouter model configurations and aliases."""
@@ -73,7 +24,7 @@ class OpenRouterModelRegistry:
config_path: Path to config file. If None, uses default locations.
"""
self.alias_map: dict[str, str] = {}  # alias -> model_name
-self.model_map: dict[str, OpenRouterModelConfig] = {}  # model_name -> config
+self.model_map: dict[str, ModelCapabilities] = {}  # model_name -> config
# Determine config path
if config_path:
@@ -139,7 +90,7 @@ class OpenRouterModelRegistry:
self.alias_map = {}
self.model_map = {}
-def _read_config(self) -> list[OpenRouterModelConfig]:
+def _read_config(self) -> list[ModelCapabilities]:
"""Read configuration from file.
Returns:
@@ -158,7 +109,27 @@ class OpenRouterModelRegistry:
# Parse models
configs = []
for model_data in data.get("models", []):
-config = OpenRouterModelConfig(**model_data)
+# Create ModelCapabilities directly from JSON data
# Handle temperature_constraint conversion
temp_constraint_str = model_data.get("temperature_constraint")
temp_constraint = create_temperature_constraint(temp_constraint_str or "range")
# Set provider-specific defaults based on is_custom flag
is_custom = model_data.get("is_custom", False)
if is_custom:
model_data.setdefault("provider", ProviderType.CUSTOM)
model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
else:
model_data.setdefault("provider", ProviderType.OPENROUTER)
model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")
# Replace the string form with the constraint object before creating ModelCapabilities
model_data["temperature_constraint"] = temp_constraint
config = ModelCapabilities(**model_data)
configs.append(config)
return configs
@@ -168,7 +139,7 @@ class OpenRouterModelRegistry:
except Exception as e:
raise ValueError(f"Error reading config from {self.config_path}: {e}")
-def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
+def _build_maps(self, configs: list[ModelCapabilities]) -> None:
"""Build alias and model maps from configurations.
Args:
@@ -211,7 +182,7 @@ class OpenRouterModelRegistry:
self.alias_map = alias_map
self.model_map = model_map
-def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
+def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
"""Resolve a model name or alias to configuration.
Args:
@@ -237,10 +208,8 @@ class OpenRouterModelRegistry:
Returns:
ModelCapabilities if found, None otherwise
"""
-config = self.resolve(name_or_alias)
-if config:
-return config.to_capabilities()
-return None
+# Registry now returns ModelCapabilities directly
+return self.resolve(name_or_alias)
def list_models(self) -> list[str]:
"""List all available model names."""
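The net effect of this refactor is that the registry stores and returns `ModelCapabilities` objects directly, so `get_capabilities` no longer needs a `to_capabilities()` conversion step. A minimal sketch of the simplified lookup path (class and field names here are an illustrative subset, not the project's exact API):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelCapabilities:
    # Illustrative subset of the real capability fields
    provider: str
    model_name: str
    context_window: int = 32768
    aliases: list[str] = field(default_factory=list)


class ModelRegistry:
    def __init__(self, configs: list[ModelCapabilities]):
        self.model_map = {c.model_name: c for c in configs}
        # alias -> canonical model name
        self.alias_map = {a.lower(): c.model_name for c in configs for a in c.aliases}

    def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
        name = self.alias_map.get(name_or_alias.lower(), name_or_alias)
        return self.model_map.get(name)

    def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
        # Registry now returns ModelCapabilities directly - no conversion needed
        return self.resolve(name_or_alias)


registry = ModelRegistry(
    [ModelCapabilities(provider="openrouter", model_name="openai/gpt-4", aliases=["gpt4"])]
)
caps = registry.get_capabilities("gpt4")
print(caps.model_name)  # openai/gpt-4
```

Storing capabilities directly removes the intermediate `OpenRouterModelConfig` type and the chance for the two representations to drift apart.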

View File

@@ -24,8 +24,6 @@ class ModelProviderRegistry:
cls._instance._providers = {}
cls._instance._initialized_providers = {}
logging.debug(f"REGISTRY: Created instance {cls._instance}")
-else:
-logging.debug(f"REGISTRY: Returning existing instance {cls._instance}")
return cls._instance
@classmethod
@@ -129,7 +127,6 @@ class ModelProviderRegistry:
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
for provider_type in PROVIDER_PRIORITY_ORDER:
-logging.debug(f"Checking provider_type: {provider_type}")
if provider_type in instance._providers:
logging.debug(f"Found {provider_type} in registry")
# Get or create provider instance

View File

@@ -7,7 +7,7 @@ from .base import (
ModelCapabilities,
ModelResponse,
ProviderType,
-RangeTemperatureConstraint,
+create_temperature_constraint,
)
from .openai_compatible import OpenAICompatibleProvider
@@ -19,23 +19,44 @@ class XAIModelProvider(OpenAICompatibleProvider):
FRIENDLY_NAME = "X.AI"
-# Model configurations
+# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {
-"grok-3": {
-"context_window": 131_072,  # 131K tokens
-"supports_extended_thinking": False,
-"description": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
-},
-"grok-3-fast": {
-"context_window": 131_072,  # 131K tokens
-"supports_extended_thinking": False,
-"description": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
-},
-# Shorthands for convenience
-"grok": "grok-3",  # Default to grok-3
-"grok3": "grok-3",
-"grok3fast": "grok-3-fast",
-"grokfast": "grok-3-fast",
+"grok-3": ModelCapabilities(
+provider=ProviderType.XAI,
+model_name="grok-3",
+friendly_name="X.AI (Grok 3)",
+context_window=131_072,  # 131K tokens
+max_output_tokens=131072,
+supports_extended_thinking=False,
+supports_system_prompts=True,
+supports_streaming=True,
+supports_function_calling=True,
+supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
+supports_images=False,  # Assuming GROK is text-only for now
+max_image_size_mb=0.0,
+supports_temperature=True,
+temperature_constraint=create_temperature_constraint("range"),
+description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
+aliases=["grok", "grok3"],
+),
+"grok-3-fast": ModelCapabilities(
+provider=ProviderType.XAI,
+model_name="grok-3-fast",
+friendly_name="X.AI (Grok 3 Fast)",
+context_window=131_072,  # 131K tokens
+max_output_tokens=131072,
+supports_extended_thinking=False,
+supports_system_prompts=True,
+supports_streaming=True,
+supports_function_calling=True,
+supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
+supports_images=False,  # Assuming GROK is text-only for now
+max_image_size_mb=0.0,
+supports_temperature=True,
+temperature_constraint=create_temperature_constraint("range"),
+description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
+aliases=["grok3fast", "grokfast", "grok3-fast"],
+),
}
def __init__(self, api_key: str, **kwargs):
@@ -49,7 +70,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
# Resolve shorthand
resolved_name = self._resolve_model_name(model_name)
-if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
+if resolved_name not in self.SUPPORTED_MODELS:
raise ValueError(f"Unsupported X.AI model: {model_name}")
# Check if model is allowed by restrictions
@@ -59,23 +80,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
-config = self.SUPPORTED_MODELS[resolved_name]
+# Return the ModelCapabilities object directly from SUPPORTED_MODELS
+return self.SUPPORTED_MODELS[resolved_name]
-# Define temperature constraints for GROK models
-# GROK supports the standard OpenAI temperature range
-temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
-return ModelCapabilities(
-provider=ProviderType.XAI,
-model_name=resolved_name,
-friendly_name=self.FRIENDLY_NAME,
-context_window=config["context_window"],
-supports_extended_thinking=config["supports_extended_thinking"],
-supports_system_prompts=True,
-supports_streaming=True,
-supports_function_calling=True,
-temperature_constraint=temp_constraint,
-)
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
@@ -86,7 +92,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
resolved_name = self._resolve_model_name(model_name)
# First check if model is supported
-if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+if resolved_name not in self.SUPPORTED_MODELS:
return False
# Then check if model is allowed by restrictions
@@ -127,61 +133,3 @@ class XAIModelProvider(OpenAICompatibleProvider):
# Currently GROK models do not support extended thinking
# This may change with future GROK model releases
return False
def list_models(self, respect_restrictions: bool = True) -> list[str]:
"""Return a list of model names supported by this provider.
Args:
respect_restrictions: Whether to apply provider-specific restriction logic.
Returns:
List of model names available from this provider
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models = []
for model_name, config in self.SUPPORTED_MODELS.items():
# Handle both base models (dict configs) and aliases (string values)
if isinstance(config, str):
# This is an alias - check if the target model would be allowed
target_model = config
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
continue
# Allow the alias
models.append(model_name)
else:
# This is a base model with config dict
# Check restrictions if enabled
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
continue
models.append(model_name)
return models
def list_all_known_models(self) -> list[str]:
"""Return all model names known by this provider, including alias targets.
Returns:
List of all model names and alias targets known by this provider
"""
all_models = set()
for model_name, config in self.SUPPORTED_MODELS.items():
# Add the model name itself
all_models.add(model_name.lower())
# If it's an alias (string value), add the target model too
if isinstance(config, str):
all_models.add(config.lower())
return list(all_models)
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name."""
# Check if it's a shorthand
shorthand_value = self.SUPPORTED_MODELS.get(model_name)
if isinstance(shorthand_value, str):
return shorthand_value
return model_name
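With the string shorthand entries gone from `SUPPORTED_MODELS`, alias resolution can no longer rely on `isinstance(config, str)` checks; it has to come from the `aliases` list carried by each capability object. A plausible sketch of the replacement resolution logic, assuming a dict mapping canonical names to objects with an `aliases` attribute (the helper name and structure here are illustrative):

```python
from dataclasses import dataclass, field


@dataclass
class ModelCapabilities:
    model_name: str
    aliases: list[str] = field(default_factory=list)


SUPPORTED_MODELS = {
    "grok-3": ModelCapabilities("grok-3", aliases=["grok", "grok3"]),
    "grok-3-fast": ModelCapabilities("grok-3-fast", aliases=["grok3fast", "grokfast", "grok3-fast"]),
}


def resolve_model_name(model_name: str) -> str:
    """Resolve a shorthand alias to its canonical model name."""
    lower = model_name.lower()
    for canonical, caps in SUPPORTED_MODELS.items():
        if lower == canonical or lower in (a.lower() for a in caps.aliases):
            return canonical
    return model_name  # unknown names pass through for later validation


print(resolve_model_name("grokfast"))  # grok-3-fast
```

Keeping aliases on the capability object means `list_models`, `list_all_known_models`, and restriction checks can all share one source of truth, which is why the old string-valued dict entries and their special cases could be deleted.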

server.py
View File

@@ -158,6 +158,97 @@ logger = logging.getLogger(__name__)
# This name is used by MCP clients to identify and connect to this specific server
server: Server = Server("zen-server")
# Constants for tool filtering
ESSENTIAL_TOOLS = {"version", "listmodels"}
def parse_disabled_tools_env() -> set[str]:
"""
Parse the DISABLED_TOOLS environment variable into a set of tool names.
Returns:
Set of lowercase tool names to disable, empty set if none specified
"""
disabled_tools_env = os.getenv("DISABLED_TOOLS", "").strip()
if not disabled_tools_env:
return set()
return {t.strip().lower() for t in disabled_tools_env.split(",") if t.strip()}
def validate_disabled_tools(disabled_tools: set[str], all_tools: dict[str, Any]) -> None:
"""
Validate the disabled tools list and log appropriate warnings.
Args:
disabled_tools: Set of tool names requested to be disabled
all_tools: Dictionary of all available tool instances
"""
essential_disabled = disabled_tools & ESSENTIAL_TOOLS
if essential_disabled:
logger.warning(f"Cannot disable essential tools: {sorted(essential_disabled)}")
unknown_tools = disabled_tools - set(all_tools.keys())
if unknown_tools:
logger.warning(f"Unknown tools in DISABLED_TOOLS: {sorted(unknown_tools)}")
def apply_tool_filter(all_tools: dict[str, Any], disabled_tools: set[str]) -> dict[str, Any]:
"""
Apply the disabled tools filter to create the final tools dictionary.
Args:
all_tools: Dictionary of all available tool instances
disabled_tools: Set of tool names to disable
Returns:
Dictionary containing only enabled tools
"""
enabled_tools = {}
for tool_name, tool_instance in all_tools.items():
if tool_name in ESSENTIAL_TOOLS or tool_name not in disabled_tools:
enabled_tools[tool_name] = tool_instance
else:
logger.debug(f"Tool '{tool_name}' disabled via DISABLED_TOOLS")
return enabled_tools
def log_tool_configuration(disabled_tools: set[str], enabled_tools: dict[str, Any]) -> None:
"""
Log the final tool configuration for visibility.
Args:
disabled_tools: Set of tool names that were requested to be disabled
enabled_tools: Dictionary of tools that remain enabled
"""
if not disabled_tools:
logger.info("All tools enabled (DISABLED_TOOLS not set)")
return
actual_disabled = disabled_tools - ESSENTIAL_TOOLS
if actual_disabled:
logger.debug(f"Disabled tools: {sorted(actual_disabled)}")
logger.info(f"Active tools: {sorted(enabled_tools.keys())}")
def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]:
"""
Filter tools based on DISABLED_TOOLS environment variable.
Args:
all_tools: Dictionary of all available tool instances
Returns:
dict: Filtered dictionary containing only enabled tools
"""
disabled_tools = parse_disabled_tools_env()
if not disabled_tools:
log_tool_configuration(disabled_tools, all_tools)
return all_tools
validate_disabled_tools(disabled_tools, all_tools)
enabled_tools = apply_tool_filter(all_tools, disabled_tools)
log_tool_configuration(disabled_tools, enabled_tools)
return enabled_tools
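Taken together, these helpers implement a simple policy: parse a comma-separated environment variable, warn about essential or unknown names, and keep essential tools no matter what was requested. The core behavior can be sketched as follows (tool instances replaced by placeholder strings for brevity):

```python
import os

# Tools that can never be disabled, mirroring the server's constant
ESSENTIAL_TOOLS = {"version", "listmodels"}


def parse_disabled_tools_env() -> set[str]:
    """Parse DISABLED_TOOLS into a set of lowercase tool names."""
    raw = os.getenv("DISABLED_TOOLS", "").strip()
    if not raw:
        return set()
    return {t.strip().lower() for t in raw.split(",") if t.strip()}


def apply_tool_filter(all_tools: dict, disabled: set[str]) -> dict:
    # Essential tools survive even if listed in DISABLED_TOOLS
    return {
        name: tool
        for name, tool in all_tools.items()
        if name in ESSENTIAL_TOOLS or name not in disabled
    }


os.environ["DISABLED_TOOLS"] = "chat, version,  , debug"
tools = {"chat": "ChatTool", "debug": "DebugTool", "version": "VersionTool", "listmodels": "ListModelsTool"}
enabled = apply_tool_filter(tools, parse_disabled_tools_env())
print(sorted(enabled))  # ['listmodels', 'version']
```

Note how the whitespace-only entry and the attempt to disable `version` are both tolerated: the parser drops empty fragments, and the filter silently keeps essential tools while the real server logs a warning.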
# Initialize the tool registry with all available AI-powered tools
# Each tool provides specialized functionality for different development tasks
# Tools are instantiated once and reused across requests (stateless design)
@@ -178,6 +269,7 @@ TOOLS = {
"listmodels": ListModelsTool(),  # List all available AI models by provider
"version": VersionTool(),  # Display server version and system information
}
TOOLS = filter_disabled_tools(TOOLS)
# Rich prompt templates for all tools
PROMPT_TEMPLATES = {
@@ -673,6 +765,11 @@ def parse_model_option(model_string: str) -> tuple[str, Optional[str]]:
"""
Parse model:option format into model name and option.
Handles different formats:
- OpenRouter models: preserve :free, :beta, :preview suffixes as part of model name
- Ollama/Custom models: split on : to extract tags like :latest
- Consensus stance: extract options like :for, :against
Args:
model_string: String that may contain "model:option" format
@@ -680,6 +777,17 @@ def parse_model_option(model_string: str) -> tuple[str, Optional[str]]:
tuple: (model_name, option) where option may be None
"""
if ":" in model_string and not model_string.startswith("http"):  # Avoid parsing URLs
# Check if this looks like an OpenRouter model (contains /)
if "/" in model_string and model_string.count(":") == 1:
# Could be openai/gpt-4:something - check what comes after colon
parts = model_string.split(":", 1)
suffix = parts[1].strip().lower()
# Known OpenRouter suffixes to preserve
if suffix in ["free", "beta", "preview"]:
return model_string.strip(), None
# For other patterns (Ollama tags, consensus stances), split normally
parts = model_string.split(":", 1)
model_name = parts[0].strip()
model_option = parts[1].strip() if len(parts) > 1 else None
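The suffix-preservation rule this hunk adds can be captured in a standalone sketch of `parse_model_option` (simplified from the server code; the set of preserved suffixes matches the ones named in the diff):

```python
from typing import Optional

# OpenRouter variant suffixes that stay part of the model name
OPENROUTER_SUFFIXES = {"free", "beta", "preview"}


def parse_model_option(model_string: str) -> tuple[str, Optional[str]]:
    if ":" in model_string and not model_string.startswith("http"):  # avoid parsing URLs
        if "/" in model_string and model_string.count(":") == 1:
            model, suffix = model_string.split(":", 1)
            if suffix.strip().lower() in OPENROUTER_SUFFIXES:
                # e.g. "openai/gpt-4:free" - the suffix is part of the model name
                return model_string.strip(), None
        # Ollama tags and consensus stances split normally
        name, option = model_string.split(":", 1)
        return name.strip(), option.strip() or None
    return model_string.strip(), None


print(parse_model_option("openai/gpt-4:free"))  # ('openai/gpt-4:free', None)
print(parse_model_option("llama3.2:latest"))    # ('llama3.2', 'latest')
print(parse_model_option("flash:against"))      # ('flash', 'against')
```

The `/` check is the discriminator: only provider-qualified OpenRouter names contain one, so plain Ollama tags and consensus stances still split on the colon as before.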

View File

@@ -182,6 +182,10 @@ class ConversationBaseTest(BaseSimulatorTest):
# Look for continuation_id in various places
if isinstance(response_data, dict):
# Check top-level continuation_id (workflow tools)
if "continuation_id" in response_data:
return response_data["continuation_id"]
# Check metadata
metadata = response_data.get("metadata", {})
if "thread_id" in metadata:

View File

@@ -91,11 +91,14 @@ class TestClass:
response_a2, continuation_id_a2 = self.call_mcp_tool(
"analyze",
{
-"prompt": "Now analyze the code quality and suggest improvements.",
-"files": [test_file_path],
+"step": "Now analyze the code quality and suggest improvements.",
+"step_number": 1,
+"total_steps": 2,
+"next_step_required": False,
+"findings": "Continuing analysis from previous chat conversation to analyze code quality.",
+"relevant_files": [test_file_path],
"continuation_id": continuation_id_a1,
"model": "flash",
-"temperature": 0.7,
},
)
@@ -154,10 +157,14 @@ class TestClass:
response_b2, continuation_id_b2 = self.call_mcp_tool(
"analyze",
{
-"prompt": "Analyze the previous greeting and suggest improvements.",
+"step": "Analyze the previous greeting and suggest improvements.",
+"step_number": 1,
+"total_steps": 1,
+"next_step_required": False,
+"findings": "Analyzing the greeting from previous conversation and suggesting improvements.",
+"relevant_files": [test_file_path],
"continuation_id": continuation_id_b1,
"model": "flash",
-"temperature": 0.7,
},
)

View File

@@ -206,11 +206,14 @@ if __name__ == "__main__":
response2, continuation_id2 = self.call_mcp_tool(
"analyze",
{
-"prompt": "Analyze the performance implications of these recursive functions.",
-"files": [file1_path],
+"step": "Analyze the performance implications of these recursive functions.",
+"step_number": 1,
+"total_steps": 1,
+"next_step_required": False,
+"findings": "Continuing from chat conversation to analyze performance implications of recursive functions.",
+"relevant_files": [file1_path],
"continuation_id": continuation_id1,  # Continue the chat conversation
"model": "flash",
-"temperature": 0.7,
},
)
@@ -221,10 +224,14 @@ if __name__ == "__main__":
self.logger.info(f" ✅ Step 2 completed with continuation_id: {continuation_id2[:8]}...")
continuation_ids.append(continuation_id2)
-# Validate that we got a different continuation ID
-if continuation_id2 == continuation_id1:
-self.logger.error(" ❌ Step 2: Got same continuation ID as Step 1 - continuation not working")
-return False
+# Validate continuation ID behavior for workflow tools
+# Workflow tools reuse the same continuation_id when continuing within a workflow session
+# This is expected behavior and different from simple tools
+if continuation_id2 != continuation_id1:
+self.logger.info(" ✅ Step 2: Got new continuation ID (workflow behavior)")
+else:
+self.logger.info(" ✅ Step 2: Reused continuation ID (workflow session continuation)")
+# Both behaviors are valid - what matters is that we got a continuation_id
# Validate that Step 2 is building on Step 1's conversation
# Check if the response references the previous conversation
@@ -276,17 +283,16 @@ if __name__ == "__main__":
all_have_continuation_ids = bool(continuation_id1 and continuation_id2 and continuation_id3)
criteria.append(("All steps generated continuation IDs", all_have_continuation_ids))
-# 3. Each continuation ID is unique
-unique_continuation_ids = len(set(continuation_ids)) == len(continuation_ids)
-criteria.append(("Each response generated unique continuation ID", unique_continuation_ids))
-# 4. Continuation IDs follow the expected pattern
-step_ids_different = (
-len(continuation_ids) == 3
-and continuation_ids[0] != continuation_ids[1]
-and continuation_ids[1] != continuation_ids[2]
+# 3. Continuation behavior validation (handles both simple and workflow tools)
+# Simple tools create new IDs each time, workflow tools may reuse IDs within sessions
+has_valid_continuation_pattern = len(continuation_ids) == 3
+criteria.append(("Valid continuation ID pattern", has_valid_continuation_pattern))
+# 4. Check for conversation continuity (more important than ID uniqueness)
+conversation_has_continuity = len(continuation_ids) == 3 and all(
+cid is not None for cid in continuation_ids
)
-criteria.append(("All continuation IDs are different", step_ids_different))
+criteria.append(("Conversation continuity maintained", conversation_has_continuity))
# 5. Check responses build on each other (content validation)
step1_has_function_analysis = "fibonacci" in response1.lower() or "factorial" in response1.lower()

View File

@@ -15,6 +15,7 @@ def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576
model_name=model_name,
friendly_name="Gemini",
context_window=context_window,
+max_output_tokens=8192,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,

View File

@@ -211,7 +211,7 @@ class TestAliasTargetRestrictions:
# Verify the polymorphic method was called
mock_provider.list_all_known_models.assert_called_once()
-@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high"})  # Restrict to specific model
+@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"})  # Restrict to specific model
def test_complex_alias_chains_handled_correctly(self):
"""Test that complex alias chains are handled correctly in restrictions."""
# Clear cached restriction service
@@ -221,12 +221,11 @@ class TestAliasTargetRestrictions:
provider = OpenAIModelProvider(api_key="test-key")
-# Only o4-mini-high should be allowed
-assert provider.validate_model_name("o4-mini-high")
+# Only o4-mini should be allowed
+assert provider.validate_model_name("o4-mini")
# Other models should be blocked
-assert not provider.validate_model_name("o4-mini")
-assert not provider.validate_model_name("mini")  # This resolves to o4-mini
+assert not provider.validate_model_name("o3")
assert not provider.validate_model_name("o3-mini")
def test_critical_regression_validation_sees_alias_targets(self):
@@ -307,7 +306,7 @@ class TestAliasTargetRestrictions:
it appear that target-based restrictions don't work.
"""
# Test with a made-up restriction scenario
-with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high,o3-mini"}):
+with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,o3-mini"}):
# Clear cached restriction service
import utils.model_restrictions
@@ -318,7 +317,7 @@ class TestAliasTargetRestrictions:
# These specific target models should be recognized as valid
all_known = provider.list_all_known_models()
-assert "o4-mini-high" in all_known, "Target model o4-mini-high should be known"
+assert "o4-mini" in all_known, "Target model o4-mini should be known"
assert "o3-mini" in all_known, "Target model o3-mini should be known"
# Validation should not warn about these being unrecognized
@@ -329,11 +328,11 @@ class TestAliasTargetRestrictions:
# Should not warn about our allowed models being unrecognized
all_warnings = [str(call) for call in mock_logger.warning.call_args_list]
for warning in all_warnings:
-assert "o4-mini-high" not in warning or "not a recognized" not in warning
+assert "o4-mini" not in warning or "not a recognized" not in warning
assert "o3-mini" not in warning or "not a recognized" not in warning
# The restriction should actually work
-assert provider.validate_model_name("o4-mini-high")
+assert provider.validate_model_name("o4-mini")
assert provider.validate_model_name("o3-mini")
-assert not provider.validate_model_name("o4-mini")  # not in allowed list
+assert not provider.validate_model_name("o3-pro")  # not in allowed list
assert not provider.validate_model_name("o3")  # not in allowed list

View File

@@ -59,12 +59,12 @@ class TestAutoMode:
continue
# Check that model has description
-description = config.get("description", "")
+description = config.description if hasattr(config, "description") else ""
if description:
models_with_descriptions[model_name] = description
# Check all expected models are present with meaningful descriptions
-expected_models = ["flash", "pro", "o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"]
+expected_models = ["flash", "pro", "o3", "o3-mini", "o3-pro", "o4-mini"]
for model in expected_models:
# Model should exist somewhere in the providers
# Note: Some models might not be available if API keys aren't configured

View File

@@ -319,7 +319,18 @@ class TestAutoModeComprehensive:
m
for m in available_models
if not m.startswith("gemini")
-and m not in ["flash", "pro", "flash-2.0", "flash2", "flashlite", "flash-lite"]
+and m
+not in [
+"flash",
+"pro",
+"flash-2.0",
+"flash2",
+"flashlite",
+"flash-lite",
+"flash2.5",
+"gemini pro",
+"gemini-pro",
+]
]
assert (
len(non_gemini_models) == 0

View File

@@ -70,7 +70,7 @@ class TestAutoModeCustomProviderOnly:
}
# Clear all other provider keys
-clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]
+clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]
with patch.dict(os.environ, test_env, clear=False):
# Ensure other provider keys are not set
@@ -109,7 +109,7 @@ class TestAutoModeCustomProviderOnly:
with patch.dict(os.environ, test_env, clear=False):
# Clear other provider keys
-for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
if key in os.environ:
del os.environ[key]
@@ -177,7 +177,7 @@ class TestAutoModeCustomProviderOnly:
with patch.dict(os.environ, test_env, clear=False): with patch.dict(os.environ, test_env, clear=False):
# Clear other provider keys # Clear other provider keys
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]: for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
if key in os.environ: if key in os.environ:
del os.environ[key] del os.environ[key]

View File

@@ -118,7 +118,7 @@ class TestBuggyBehaviorPrevention:
        provider = OpenAIModelProvider(api_key="test-key")

        # Simulate a scenario where admin wants to restrict specific targets
-       with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini-high"}):
+       with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
            # Clear cached restriction service
            import utils.model_restrictions
@@ -126,19 +126,21 @@ class TestBuggyBehaviorPrevention:
            # These should work because they're explicitly allowed
            assert provider.validate_model_name("o3-mini")
-           assert provider.validate_model_name("o4-mini-high")
+           assert provider.validate_model_name("o4-mini")

            # These should be blocked
-           assert not provider.validate_model_name("o4-mini")  # Not in allowed list
+           assert not provider.validate_model_name("o3-pro")  # Not in allowed list
            assert not provider.validate_model_name("o3")  # Not in allowed list
-           assert not provider.validate_model_name("mini")  # Resolves to o4-mini, not allowed
+
+           # This should be ALLOWED because it resolves to o4-mini which is in the allowed list
+           assert provider.validate_model_name("mini")  # Resolves to o4-mini, which IS allowed

            # Verify our list_all_known_models includes the restricted models
            all_known = provider.list_all_known_models()
            assert "o3-mini" in all_known  # Should be known (and allowed)
-           assert "o4-mini-high" in all_known  # Should be known (and allowed)
-           assert "o4-mini" in all_known  # Should be known (but blocked)
-           assert "mini" in all_known  # Should be known (but blocked)
+           assert "o4-mini" in all_known  # Should be known (and allowed)
+           assert "o3-pro" in all_known  # Should be known (but blocked)
+           assert "mini" in all_known  # Should be known (and allowed since it resolves to o4-mini)

    def test_demonstration_of_old_vs_new_interface(self):
        """

View File

@@ -506,17 +506,17 @@ class TestConversationFlow:
            mock_client = Mock()
            mock_storage.return_value = mock_client

-           # Start conversation with files
-           thread_id = create_thread("analyze", {"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]})
+           # Start conversation with files using a simple tool
+           thread_id = create_thread("chat", {"prompt": "Analyze this codebase", "files": ["/project/src/"]})

            # Turn 1: Claude provides context with multiple files
            initial_context = ThreadContext(
                thread_id=thread_id,
                created_at="2023-01-01T00:00:00Z",
                last_updated_at="2023-01-01T00:00:00Z",
-               tool_name="analyze",
+               tool_name="chat",
                turns=[],
-               initial_context={"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]},
+               initial_context={"prompt": "Analyze this codebase", "files": ["/project/src/"]},
            )
            mock_client.get.return_value = initial_context.model_dump_json()

View File

@@ -45,10 +45,17 @@ class TestCustomProvider:

    def test_get_capabilities_from_registry(self):
        """Test get_capabilities returns registry capabilities when available."""
+       # Save original environment
+       original_env = os.environ.get("OPENROUTER_ALLOWED_MODELS")
+       try:
+           # Clear any restrictions
+           os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
            provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

-           # Test with a model that should be in the registry (OpenRouter model) and is allowed by restrictions
-           capabilities = provider.get_capabilities("o3")  # o3 is in OPENROUTER_ALLOWED_MODELS
+           # Test with a model that should be in the registry (OpenRouter model)
+           capabilities = provider.get_capabilities("o3")  # o3 is an OpenRouter model
            assert capabilities.provider == ProviderType.OPENROUTER  # o3 is an OpenRouter model (is_custom=false)
            assert capabilities.context_window > 0
@@ -58,6 +65,13 @@ class TestCustomProvider:
            assert capabilities.provider == ProviderType.CUSTOM  # local-llama has is_custom=true
            assert capabilities.context_window > 0
+       finally:
+           # Restore original environment
+           if original_env is None:
+               os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
+           else:
+               os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env

    def test_get_capabilities_generic_fallback(self):
        """Test get_capabilities returns generic capabilities for unknown models."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

View File

@@ -84,7 +84,7 @@ class TestDIALProvider:
        # Test O3 capabilities
        capabilities = provider.get_capabilities("o3")
        assert capabilities.model_name == "o3-2025-04-16"
-       assert capabilities.friendly_name == "DIAL"
+       assert capabilities.friendly_name == "DIAL (O3)"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.DIAL
        assert capabilities.supports_images is True

View File

@@ -0,0 +1,140 @@
"""Tests for DISABLED_TOOLS environment variable functionality."""

import logging
import os
from unittest.mock import patch

import pytest

from server import (
    apply_tool_filter,
    parse_disabled_tools_env,
    validate_disabled_tools,
)


# Mock the tool classes since we're testing the filtering logic
class MockTool:
    def __init__(self, name):
        self.name = name


class TestDisabledTools:
    """Test suite for DISABLED_TOOLS functionality."""

    def test_parse_disabled_tools_empty(self):
        """Empty string returns empty set (no tools disabled)."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": ""}):
            assert parse_disabled_tools_env() == set()

    def test_parse_disabled_tools_not_set(self):
        """Unset variable returns empty set."""
        with patch.dict(os.environ, {}, clear=True):
            # Ensure DISABLED_TOOLS is not in environment
            if "DISABLED_TOOLS" in os.environ:
                del os.environ["DISABLED_TOOLS"]
            assert parse_disabled_tools_env() == set()

    def test_parse_disabled_tools_single(self):
        """Single tool name parsed correctly."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": "debug"}):
            assert parse_disabled_tools_env() == {"debug"}

    def test_parse_disabled_tools_multiple(self):
        """Multiple tools with spaces parsed correctly."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": "debug, analyze, refactor"}):
            assert parse_disabled_tools_env() == {"debug", "analyze", "refactor"}

    def test_parse_disabled_tools_extra_spaces(self):
        """Extra spaces and empty items handled correctly."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": " debug , , analyze , "}):
            assert parse_disabled_tools_env() == {"debug", "analyze"}

    def test_parse_disabled_tools_duplicates(self):
        """Duplicate entries handled correctly (set removes duplicates)."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": "debug,analyze,debug"}):
            assert parse_disabled_tools_env() == {"debug", "analyze"}

    def test_tool_filtering_logic(self):
        """Test the complete filtering logic using the actual server functions."""
        # Simulate ALL_TOOLS
        ALL_TOOLS = {
            "chat": MockTool("chat"),
            "debug": MockTool("debug"),
            "analyze": MockTool("analyze"),
            "version": MockTool("version"),
            "listmodels": MockTool("listmodels"),
        }

        # Test case 1: No tools disabled
        disabled_tools = set()
        enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
        assert len(enabled_tools) == 5  # All tools included
        assert set(enabled_tools.keys()) == set(ALL_TOOLS.keys())

        # Test case 2: Disable some regular tools
        disabled_tools = {"debug", "analyze"}
        enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
        assert len(enabled_tools) == 3  # chat, version, listmodels
        assert "debug" not in enabled_tools
        assert "analyze" not in enabled_tools
        assert "chat" in enabled_tools
        assert "version" in enabled_tools
        assert "listmodels" in enabled_tools

        # Test case 3: Attempt to disable essential tools
        disabled_tools = {"version", "chat"}
        enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
        assert "version" in enabled_tools  # Essential tool not disabled
        assert "chat" not in enabled_tools  # Regular tool disabled
        assert "listmodels" in enabled_tools  # Essential tool included

    def test_unknown_tools_warning(self, caplog):
        """Test that unknown tool names generate appropriate warnings."""
        ALL_TOOLS = {
            "chat": MockTool("chat"),
            "debug": MockTool("debug"),
            "analyze": MockTool("analyze"),
            "version": MockTool("version"),
            "listmodels": MockTool("listmodels"),
        }
        disabled_tools = {"chat", "unknown_tool", "another_unknown"}
        with caplog.at_level(logging.WARNING):
            validate_disabled_tools(disabled_tools, ALL_TOOLS)
        assert "Unknown tools in DISABLED_TOOLS: ['another_unknown', 'unknown_tool']" in caplog.text

    def test_essential_tools_warning(self, caplog):
        """Test warning when trying to disable essential tools."""
        ALL_TOOLS = {
            "chat": MockTool("chat"),
            "debug": MockTool("debug"),
            "analyze": MockTool("analyze"),
            "version": MockTool("version"),
            "listmodels": MockTool("listmodels"),
        }
        disabled_tools = {"version", "chat", "debug"}
        with caplog.at_level(logging.WARNING):
            validate_disabled_tools(disabled_tools, ALL_TOOLS)
        assert "Cannot disable essential tools: ['version']" in caplog.text

    @pytest.mark.parametrize(
        "env_value,expected",
        [
            ("", set()),  # Empty string
            (" ", set()),  # Only spaces
            (",,,", set()),  # Only commas
            ("chat", {"chat"}),  # Single tool
            ("chat,debug", {"chat", "debug"}),  # Multiple tools
            ("chat, debug, analyze", {"chat", "debug", "analyze"}),  # With spaces
            ("chat,debug,chat", {"chat", "debug"}),  # Duplicates
        ],
    )
    def test_parse_disabled_tools_parametrized(self, env_value, expected):
        """Parametrized tests for various input formats."""
        with patch.dict(os.environ, {"DISABLED_TOOLS": env_value}):
            assert parse_disabled_tools_env() == expected

View File

@@ -483,14 +483,14 @@ class TestImageSupportIntegration:
            tool_name="chat",
        )

-       # Create child thread linked to parent
-       child_thread_id = create_thread("debug", {"child": "context"}, parent_thread_id=parent_thread_id)
+       # Create child thread linked to parent using a simple tool
+       child_thread_id = create_thread("chat", {"prompt": "child context"}, parent_thread_id=parent_thread_id)
        add_turn(
            thread_id=child_thread_id,
            role="user",
            content="Child thread with more images",
            images=["child1.png", "shared.png"],  # shared.png appears again (should prioritize newer)
-           tool_name="debug",
+           tool_name="chat",
        )

        # Mock child thread context for get_thread call

View File

@@ -149,7 +149,7 @@ class TestModelEnumeration:
            ("o3", False),  # OpenAI - not available without API key
            ("grok", False),  # X.AI - not available without API key
            ("gemini-2.5-flash", False),  # Full Gemini name - not available without API key
-           ("o4-mini-high", False),  # OpenAI variant - not available without API key
+           ("o4-mini", False),  # OpenAI variant - not available without API key
            ("grok-3-fast", False),  # X.AI variant - not available without API key
        ],
    )

View File

@@ -89,7 +89,7 @@ class TestModelMetadataContinuation:
    @pytest.mark.asyncio
    async def test_multiple_turns_uses_last_assistant_model(self):
        """Test that with multiple turns, the last assistant turn's model is used."""
-       thread_id = create_thread("analyze", {"prompt": "analyze this"})
+       thread_id = create_thread("chat", {"prompt": "analyze this"})

        # Add multiple turns with different models
        add_turn(thread_id, "assistant", "First response", model_name="gemini-2.5-flash", model_provider="google")
@@ -185,11 +185,11 @@ class TestModelMetadataContinuation:
    async def test_thread_chain_model_preservation(self):
        """Test model preservation across thread chains (parent-child relationships)."""
        # Create parent thread
-       parent_id = create_thread("analyze", {"prompt": "analyze"})
+       parent_id = create_thread("chat", {"prompt": "analyze"})
        add_turn(parent_id, "assistant", "Analysis", model_name="gemini-2.5-pro", model_provider="google")

-       # Create child thread
-       child_id = create_thread("codereview", {"prompt": "review"}, parent_thread_id=parent_id)
+       # Create child thread using a simple tool instead of workflow tool
+       child_id = create_thread("chat", {"prompt": "review"}, parent_thread_id=parent_id)

        # Child thread should be able to access parent's model through chain traversal
        # NOTE: Current implementation only checks current thread (not parent threads)

View File

@@ -93,7 +93,7 @@ class TestModelRestrictionService:
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
            service = ModelRestrictionService()

-           models = ["o3", "o3-mini", "o4-mini", "o4-mini-high"]
+           models = ["o3", "o3-mini", "o4-mini", "o3-pro"]
            filtered = service.filter_models(ProviderType.OPENAI, models)

            assert filtered == ["o3-mini", "o4-mini"]
@@ -573,7 +573,7 @@ class TestShorthandRestrictions:
            # Other models should not work
            assert not openai_provider.validate_model_name("o3")
-           assert not openai_provider.validate_model_name("o4-mini-high")
+           assert not openai_provider.validate_model_name("o3-pro")

    @patch.dict(
        os.environ,
View File

@@ -185,7 +185,7 @@ class TestO3TemperatureParameterFixSimple:
        provider = OpenAIModelProvider(api_key="test-key")

        # Test O3/O4 models that should NOT support temperature parameter
-       o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"]
+       o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini"]

        for model in o3_o4_models:
            capabilities = provider.get_capabilities(model)

View File

@@ -47,14 +47,13 @@ class TestOpenAIProvider:
        assert provider.validate_model_name("o3-mini") is True
        assert provider.validate_model_name("o3-pro") is True
        assert provider.validate_model_name("o4-mini") is True
-       assert provider.validate_model_name("o4-mini-high") is True
+       assert provider.validate_model_name("o4-mini") is True

        # Test valid aliases
        assert provider.validate_model_name("mini") is True
        assert provider.validate_model_name("o3mini") is True
        assert provider.validate_model_name("o4mini") is True
-       assert provider.validate_model_name("o4minihigh") is True
-       assert provider.validate_model_name("o4minihi") is True
+       assert provider.validate_model_name("o4mini") is True

        # Test invalid model
        assert provider.validate_model_name("invalid-model") is False
@@ -69,15 +68,14 @@ class TestOpenAIProvider:
        assert provider._resolve_model_name("mini") == "o4-mini"
        assert provider._resolve_model_name("o3mini") == "o3-mini"
        assert provider._resolve_model_name("o4mini") == "o4-mini"
-       assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
-       assert provider._resolve_model_name("o4minihi") == "o4-mini-high"
+       assert provider._resolve_model_name("o4mini") == "o4-mini"

        # Test full name passthrough
        assert provider._resolve_model_name("o3") == "o3"
        assert provider._resolve_model_name("o3-mini") == "o3-mini"
        assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
        assert provider._resolve_model_name("o4-mini") == "o4-mini"
-       assert provider._resolve_model_name("o4-mini-high") == "o4-mini-high"
+       assert provider._resolve_model_name("o4-mini") == "o4-mini"

    def test_get_capabilities_o3(self):
        """Test getting model capabilities for O3."""
@@ -85,7 +83,7 @@ class TestOpenAIProvider:
        capabilities = provider.get_capabilities("o3")
        assert capabilities.model_name == "o3"  # Should NOT be resolved in capabilities
-       assert capabilities.friendly_name == "OpenAI"
+       assert capabilities.friendly_name == "OpenAI (O3)"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.OPENAI
        assert not capabilities.supports_extended_thinking
@@ -101,8 +99,8 @@ class TestOpenAIProvider:
        provider = OpenAIModelProvider("test-key")

        capabilities = provider.get_capabilities("mini")
-       assert capabilities.model_name == "mini"  # Capabilities should show original request
-       assert capabilities.friendly_name == "OpenAI"
+       assert capabilities.model_name == "o4-mini"  # Capabilities should show resolved model name
+       assert capabilities.friendly_name == "OpenAI (O4-mini)"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.OPENAI
@@ -184,11 +182,11 @@ class TestOpenAIProvider:
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "o3-mini"

-       # Test o4minihigh -> o4-mini-high
-       mock_response.model = "o4-mini-high"
-       provider.generate_content(prompt="Test", model_name="o4minihigh", temperature=1.0)
+       # Test o4mini -> o4-mini
+       mock_response.model = "o4-mini"
+       provider.generate_content(prompt="Test", model_name="o4mini", temperature=1.0)

        call_kwargs = mock_client.chat.completions.create.call_args[1]
-       assert call_kwargs["model"] == "o4-mini-high"
+       assert call_kwargs["model"] == "o4-mini"

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_no_alias_passthrough(self, mock_openai_class):

View File

@@ -57,7 +57,7 @@ class TestOpenRouterProvider:
        caps = provider.get_capabilities("o3")
        assert caps.provider == ProviderType.OPENROUTER
        assert caps.model_name == "openai/o3"  # Resolved name
-       assert caps.friendly_name == "OpenRouter"
+       assert caps.friendly_name == "OpenRouter (openai/o3)"

        # Test with a model not in registry - should get generic capabilities
        caps = provider.get_capabilities("unknown-model")
@@ -77,7 +77,7 @@ class TestOpenRouterProvider:
        assert provider._resolve_model_name("o3-mini") == "openai/o3-mini"
        assert provider._resolve_model_name("o3mini") == "openai/o3-mini"
        assert provider._resolve_model_name("o4-mini") == "openai/o4-mini"
-       assert provider._resolve_model_name("o4-mini-high") == "openai/o4-mini-high"
+       assert provider._resolve_model_name("o4-mini") == "openai/o4-mini"
        assert provider._resolve_model_name("claude") == "anthropic/claude-sonnet-4"
        assert provider._resolve_model_name("mistral") == "mistralai/mistral-large-2411"
        assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-r1-0528"

View File

@@ -6,8 +6,8 @@ import tempfile

import pytest

-from providers.base import ProviderType
-from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
+from providers.base import ModelCapabilities, ProviderType
+from providers.openrouter_registry import OpenRouterModelRegistry


class TestOpenRouterModelRegistry:
@@ -24,7 +24,16 @@ class TestOpenRouterModelRegistry:
    def test_custom_config_path(self):
        """Test registry with custom config path."""
        # Create temporary config
-       config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}
+       config_data = {
+           "models": [
+               {
+                   "model_name": "test/model-1",
+                   "aliases": ["test1", "t1"],
+                   "context_window": 4096,
+                   "max_output_tokens": 2048,
+               }
+           ]
+       }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
@@ -42,7 +51,11 @@ class TestOpenRouterModelRegistry:
    def test_environment_variable_override(self):
        """Test OPENROUTER_MODELS_PATH environment variable."""
        # Create custom config
-       config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}
+       config_data = {
+           "models": [
+               {"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192, "max_output_tokens": 4096}
+           ]
+       }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(config_data, f)
@@ -110,28 +123,29 @@ class TestOpenRouterModelRegistry:
        assert registry.resolve("non-existent") is None

    def test_model_capabilities_conversion(self):
-       """Test conversion to ModelCapabilities."""
+       """Test that registry returns ModelCapabilities directly."""
        registry = OpenRouterModelRegistry()

        config = registry.resolve("opus")
        assert config is not None

-       caps = config.to_capabilities()
-       assert caps.provider == ProviderType.OPENROUTER
-       assert caps.model_name == "anthropic/claude-opus-4"
-       assert caps.friendly_name == "OpenRouter"
-       assert caps.context_window == 200000
-       assert not caps.supports_extended_thinking
+       # Registry now returns ModelCapabilities objects directly
+       assert config.provider == ProviderType.OPENROUTER
+       assert config.model_name == "anthropic/claude-opus-4"
+       assert config.friendly_name == "OpenRouter (anthropic/claude-opus-4)"
+       assert config.context_window == 200000
+       assert not config.supports_extended_thinking

    def test_duplicate_alias_detection(self):
        """Test that duplicate aliases are detected."""
        config_data = {
            "models": [
-               {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
+               {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096, "max_output_tokens": 2048},
                {
                    "model_name": "test/model-2",
                    "aliases": ["DUPE"],  # Same alias, different case
                    "context_window": 8192,
+                   "max_output_tokens": 2048,
                },
            ]
        }
@@ -199,19 +213,23 @@ class TestOpenRouterModelRegistry:

    def test_model_with_all_capabilities(self):
        """Test model with all capability flags."""
-       config = OpenRouterModelConfig(
+       from providers.base import create_temperature_constraint
+
+       caps = ModelCapabilities(
+           provider=ProviderType.OPENROUTER,
            model_name="test/full-featured",
+           friendly_name="OpenRouter (test/full-featured)",
            aliases=["full"],
            context_window=128000,
+           max_output_tokens=8192,
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            description="Fully featured test model",
+           temperature_constraint=create_temperature_constraint("range"),
        )

-       caps = config.to_capabilities()
        assert caps.context_window == 128000
        assert caps.supports_extended_thinking
        assert caps.supports_system_prompts

View File

@@ -0,0 +1,79 @@
"""Tests for parse_model_option function."""

from server import parse_model_option


class TestParseModelOption:
    """Test cases for model option parsing."""

    def test_openrouter_free_suffix_preserved(self):
        """Test that OpenRouter :free suffix is preserved as part of model name."""
        model, option = parse_model_option("openai/gpt-3.5-turbo:free")
        assert model == "openai/gpt-3.5-turbo:free"
        assert option is None

    def test_openrouter_beta_suffix_preserved(self):
        """Test that OpenRouter :beta suffix is preserved as part of model name."""
        model, option = parse_model_option("anthropic/claude-3-opus:beta")
        assert model == "anthropic/claude-3-opus:beta"
        assert option is None

    def test_openrouter_preview_suffix_preserved(self):
        """Test that OpenRouter :preview suffix is preserved as part of model name."""
        model, option = parse_model_option("google/gemini-pro:preview")
        assert model == "google/gemini-pro:preview"
        assert option is None

    def test_ollama_tag_parsed_as_option(self):
        """Test that Ollama tags are parsed as options."""
        model, option = parse_model_option("llama3.2:latest")
        assert model == "llama3.2"
        assert option == "latest"

    def test_consensus_stance_parsed_as_option(self):
        """Test that consensus stances are parsed as options."""
        model, option = parse_model_option("o3:for")
        assert model == "o3"
        assert option == "for"

        model, option = parse_model_option("gemini-2.5-pro:against")
        assert model == "gemini-2.5-pro"
        assert option == "against"

    def test_openrouter_unknown_suffix_parsed_as_option(self):
        """Test that unknown suffixes on OpenRouter models are parsed as options."""
        model, option = parse_model_option("openai/gpt-4:custom-tag")
        assert model == "openai/gpt-4"
        assert option == "custom-tag"

    def test_plain_model_name(self):
        """Test plain model names without colons."""
        model, option = parse_model_option("gpt-4")
        assert model == "gpt-4"
        assert option is None

    def test_url_not_parsed(self):
        """Test that URLs are not parsed for options."""
        model, option = parse_model_option("http://localhost:8080")
        assert model == "http://localhost:8080"
        assert option is None

    def test_whitespace_handling(self):
        """Test that whitespace is properly stripped."""
        model, option = parse_model_option(" openai/gpt-3.5-turbo:free ")
        assert model == "openai/gpt-3.5-turbo:free"
        assert option is None

        model, option = parse_model_option(" llama3.2 : latest ")
        assert model == "llama3.2"
        assert option == "latest"

    def test_case_insensitive_suffix_matching(self):
        """Test that OpenRouter suffix matching is case-insensitive."""
        model, option = parse_model_option("openai/gpt-3.5-turbo:FREE")
        assert model == "openai/gpt-3.5-turbo:FREE"  # Original case preserved
        assert option is None

        model, option = parse_model_option("openai/gpt-3.5-turbo:Free")
        assert model == "openai/gpt-3.5-turbo:Free"  # Original case preserved
        assert option is None

View File

@@ -58,7 +58,13 @@ class TestProviderRoutingBugs:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -66,6 +72,7 @@ class TestProviderRoutingBugs:
             os.environ.pop("GEMINI_API_KEY", None)  # No Google API key
             os.environ.pop("OPENAI_API_KEY", None)
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
             os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

             # Register only OpenRouter provider (like in server.py:configure_providers)
# Register only OpenRouter provider (like in server.py:configure_providers) # Register only OpenRouter provider (like in server.py:configure_providers)
@@ -113,12 +120,24 @@ class TestProviderRoutingBugs:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
             # Set up scenario: NO API keys at all
-            for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+            for key in [
+                "GEMINI_API_KEY",
+                "OPENAI_API_KEY",
+                "XAI_API_KEY",
+                "OPENROUTER_API_KEY",
+                "OPENROUTER_ALLOWED_MODELS",
+            ]:
                 os.environ.pop(key, None)

             # Create tool to test fallback logic
@@ -151,7 +170,13 @@ class TestProviderRoutingBugs:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -160,6 +185,7 @@ class TestProviderRoutingBugs:
             os.environ["OPENAI_API_KEY"] = "test-openai-key"
             os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions

             # Register providers in priority order (like server.py)
             from providers.gemini import GeminiModelProvider


@@ -215,9 +215,7 @@ class TestOpenAIProvider:
         assert provider.validate_model_name("o3-mini")  # Backwards compatibility
         assert provider.validate_model_name("o4-mini")
         assert provider.validate_model_name("o4mini")
-        assert provider.validate_model_name("o4-mini-high")
-        assert provider.validate_model_name("o4minihigh")
-        assert provider.validate_model_name("o4minihi")
+        assert provider.validate_model_name("o4-mini")
         assert not provider.validate_model_name("gpt-4o")
         assert not provider.validate_model_name("invalid-model")
@@ -229,4 +227,4 @@ class TestOpenAIProvider:
         assert not provider.supports_thinking_mode("o3mini")
         assert not provider.supports_thinking_mode("o3-mini")
         assert not provider.supports_thinking_mode("o4-mini")
-        assert not provider.supports_thinking_mode("o4-mini-high")
+        assert not provider.supports_thinking_mode("o4-mini")


@@ -0,0 +1,205 @@
"""Test the SUPPORTED_MODELS aliases structure across all providers."""

from providers.dial import DIALModelProvider
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.xai import XAIModelProvider


class TestSupportedModelsAliases:
    """Test that all providers have correctly structured SUPPORTED_MODELS with aliases."""

    def test_gemini_provider_aliases(self):
        """Test Gemini provider's alias structure."""
        provider = GeminiModelProvider("test-key")

        # Check that all models have ModelCapabilities with aliases
        for model_name, config in provider.SUPPORTED_MODELS.items():
            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"

        # Test specific aliases
        assert "flash" in provider.SUPPORTED_MODELS["gemini-2.5-flash"].aliases
        assert "pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro"].aliases
        assert "flash-2.0" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
        assert "flash2" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
        assert "flashlite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
        assert "flash-lite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases

        # Test alias resolution
        assert provider._resolve_model_name("flash") == "gemini-2.5-flash"
        assert provider._resolve_model_name("pro") == "gemini-2.5-pro"
        assert provider._resolve_model_name("flash-2.0") == "gemini-2.0-flash"
        assert provider._resolve_model_name("flash2") == "gemini-2.0-flash"
        assert provider._resolve_model_name("flashlite") == "gemini-2.0-flash-lite"

        # Test case insensitive resolution
        assert provider._resolve_model_name("Flash") == "gemini-2.5-flash"
        assert provider._resolve_model_name("PRO") == "gemini-2.5-pro"

    def test_openai_provider_aliases(self):
        """Test OpenAI provider's alias structure."""
        provider = OpenAIModelProvider("test-key")

        # Check that all models have ModelCapabilities with aliases
        for model_name, config in provider.SUPPORTED_MODELS.items():
            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"

        # Test specific aliases
        assert "mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
        assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
        assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases
        assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases
        assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases

        # Test alias resolution
        assert provider._resolve_model_name("mini") == "o4-mini"
        assert provider._resolve_model_name("o3mini") == "o3-mini"
        assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
        assert provider._resolve_model_name("o4mini") == "o4-mini"
        assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14"

        # Test case insensitive resolution
        assert provider._resolve_model_name("Mini") == "o4-mini"
        assert provider._resolve_model_name("O3MINI") == "o3-mini"

    def test_xai_provider_aliases(self):
        """Test XAI provider's alias structure."""
        provider = XAIModelProvider("test-key")

        # Check that all models have ModelCapabilities with aliases
        for model_name, config in provider.SUPPORTED_MODELS.items():
            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"

        # Test specific aliases
        assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases
        assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
        assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
        assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases

        # Test alias resolution
        assert provider._resolve_model_name("grok") == "grok-3"
        assert provider._resolve_model_name("grok3") == "grok-3"
        assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
        assert provider._resolve_model_name("grokfast") == "grok-3-fast"

        # Test case insensitive resolution
        assert provider._resolve_model_name("Grok") == "grok-3"
        assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"

    def test_dial_provider_aliases(self):
        """Test DIAL provider's alias structure."""
        provider = DIALModelProvider("test-key")

        # Check that all models have ModelCapabilities with aliases
        for model_name, config in provider.SUPPORTED_MODELS.items():
            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"

        # Test specific aliases
        assert "o3" in provider.SUPPORTED_MODELS["o3-2025-04-16"].aliases
        assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini-2025-04-16"].aliases
        assert "sonnet-4" in provider.SUPPORTED_MODELS["anthropic.claude-sonnet-4-20250514-v1:0"].aliases
        assert "opus-4" in provider.SUPPORTED_MODELS["anthropic.claude-opus-4-20250514-v1:0"].aliases
        assert "gemini-2.5-pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro-preview-05-06"].aliases

        # Test alias resolution
        assert provider._resolve_model_name("o3") == "o3-2025-04-16"
        assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
        assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
        assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"

        # Test case insensitive resolution
        assert provider._resolve_model_name("O3") == "o3-2025-04-16"
        assert provider._resolve_model_name("SONNET-4") == "anthropic.claude-sonnet-4-20250514-v1:0"

    def test_list_models_includes_aliases(self):
        """Test that list_models returns both base models and aliases."""
        # Test Gemini
        gemini_provider = GeminiModelProvider("test-key")
        gemini_models = gemini_provider.list_models(respect_restrictions=False)
        assert "gemini-2.5-flash" in gemini_models
        assert "flash" in gemini_models
        assert "gemini-2.5-pro" in gemini_models
        assert "pro" in gemini_models

        # Test OpenAI
        openai_provider = OpenAIModelProvider("test-key")
        openai_models = openai_provider.list_models(respect_restrictions=False)
        assert "o4-mini" in openai_models
        assert "mini" in openai_models
        assert "o3-mini" in openai_models
        assert "o3mini" in openai_models

        # Test XAI
        xai_provider = XAIModelProvider("test-key")
        xai_models = xai_provider.list_models(respect_restrictions=False)
        assert "grok-3" in xai_models
        assert "grok" in xai_models
        assert "grok-3-fast" in xai_models
        assert "grokfast" in xai_models

        # Test DIAL
        dial_provider = DIALModelProvider("test-key")
        dial_models = dial_provider.list_models(respect_restrictions=False)
        assert "o3-2025-04-16" in dial_models
        assert "o3" in dial_models

    def test_list_all_known_models_includes_aliases(self):
        """Test that list_all_known_models returns all models and aliases in lowercase."""
        # Test Gemini
        gemini_provider = GeminiModelProvider("test-key")
        gemini_all = gemini_provider.list_all_known_models()
        assert "gemini-2.5-flash" in gemini_all
        assert "flash" in gemini_all
        assert "gemini-2.5-pro" in gemini_all
        assert "pro" in gemini_all
        # All should be lowercase
        assert all(model == model.lower() for model in gemini_all)

        # Test OpenAI
        openai_provider = OpenAIModelProvider("test-key")
        openai_all = openai_provider.list_all_known_models()
        assert "o4-mini" in openai_all
        assert "mini" in openai_all
        assert "o3-mini" in openai_all
        assert "o3mini" in openai_all
        # All should be lowercase
        assert all(model == model.lower() for model in openai_all)

    def test_no_string_shorthand_in_supported_models(self):
        """Test that no provider has string-based shorthands anymore."""
        providers = [
            GeminiModelProvider("test-key"),
            OpenAIModelProvider("test-key"),
            XAIModelProvider("test-key"),
            DIALModelProvider("test-key"),
        ]

        for provider in providers:
            for model_name, config in provider.SUPPORTED_MODELS.items():
                # All values must be ModelCapabilities objects, not strings or dicts
                from providers.base import ModelCapabilities

                assert isinstance(config, ModelCapabilities), (
                    f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
                    f"must be a ModelCapabilities object, not {type(config).__name__}"
                )

    def test_resolve_returns_original_if_not_found(self):
        """Test that _resolve_model_name returns original name if alias not found."""
        providers = [
            GeminiModelProvider("test-key"),
            OpenAIModelProvider("test-key"),
            XAIModelProvider("test-key"),
            DIALModelProvider("test-key"),
        ]

        for provider in providers:
            # Test with unknown model name
            assert provider._resolve_model_name("unknown-model") == "unknown-model"
            assert provider._resolve_model_name("gpt-4") == "gpt-4"
            assert provider._resolve_model_name("claude-3") == "claude-3"
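The invariant this new test file enforces — every `SUPPORTED_MODELS` value is a `ModelCapabilities` object carrying its own aliases, resolved case-insensitively, with unknown names passed through unchanged — can be sketched as follows. This is a simplified stand-in using a local dataclass, not the providers' actual `providers.base.ModelCapabilities` class.

```python
from dataclasses import dataclass, field


@dataclass
class ModelCapabilities:  # simplified stand-in for providers.base.ModelCapabilities
    context_window: int = 0
    aliases: list[str] = field(default_factory=list)


# Example registry shaped like a provider's SUPPORTED_MODELS after the refactor
SUPPORTED_MODELS = {
    "grok-3": ModelCapabilities(context_window=131_072, aliases=["grok", "grok3"]),
    "grok-3-fast": ModelCapabilities(context_window=131_072, aliases=["grok3fast", "grokfast"]),
}


def resolve_model_name(name: str) -> str:
    """Resolve an alias to its base model name, case-insensitively."""
    lowered = name.lower()
    for base, caps in SUPPORTED_MODELS.items():
        if lowered == base.lower() or lowered in (a.lower() for a in caps.aliases):
            return base
    return name  # unknown names pass through unchanged
```

Keeping aliases on the capabilities object (instead of string-valued shorthand entries) means consumers like `list_models` can iterate one dict and never need `isinstance(config, str)` checks.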


@@ -48,7 +48,13 @@ class TestWorkflowMetadata:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -56,6 +62,7 @@ class TestWorkflowMetadata:
             os.environ.pop("GEMINI_API_KEY", None)
             os.environ.pop("OPENAI_API_KEY", None)
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
             os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

             # Register OpenRouter provider
@@ -124,7 +131,13 @@ class TestWorkflowMetadata:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -132,6 +145,7 @@ class TestWorkflowMetadata:
             os.environ.pop("GEMINI_API_KEY", None)
             os.environ.pop("OPENAI_API_KEY", None)
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
             os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

             # Register OpenRouter provider
@@ -182,6 +196,15 @@ class TestWorkflowMetadata:
         """
         Test that workflow tools handle metadata gracefully when model context is missing.
         """
+        # Save original environment
+        original_env = {}
+        for key in ["OPENROUTER_ALLOWED_MODELS"]:
+            original_env[key] = os.environ.get(key)
+
+        try:
+            # Clear any restrictions
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
+
         # Create debug tool
         debug_tool = DebugIssueTool()
@@ -220,6 +243,14 @@ class TestWorkflowMetadata:
             assert metadata["model_used"] == "flash", "model_used should be from request"
             assert metadata["provider_used"] == "unknown", "provider_used should be 'unknown' in fallback"

+        finally:
+            # Restore original environment
+            for key, value in original_env.items():
+                if value is None:
+                    os.environ.pop(key, None)
+                else:
+                    os.environ[key] = value
+
     @pytest.mark.no_mock_provider
     def test_workflow_metadata_preserves_existing_response_fields(self):
         """
@@ -227,7 +258,13 @@ class TestWorkflowMetadata:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -235,6 +272,7 @@ class TestWorkflowMetadata:
             os.environ.pop("GEMINI_API_KEY", None)
             os.environ.pop("OPENAI_API_KEY", None)
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
             os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

             # Register OpenRouter provider


@@ -77,7 +77,7 @@ class TestXAIProvider:
         capabilities = provider.get_capabilities("grok-3")
         assert capabilities.model_name == "grok-3"
-        assert capabilities.friendly_name == "X.AI"
+        assert capabilities.friendly_name == "X.AI (Grok 3)"
         assert capabilities.context_window == 131_072
         assert capabilities.provider == ProviderType.XAI
         assert not capabilities.supports_extended_thinking
@@ -96,7 +96,7 @@ class TestXAIProvider:
         capabilities = provider.get_capabilities("grok-3-fast")
         assert capabilities.model_name == "grok-3-fast"
-        assert capabilities.friendly_name == "X.AI"
+        assert capabilities.friendly_name == "X.AI (Grok 3 Fast)"
         assert capabilities.context_window == 131_072
         assert capabilities.provider == ProviderType.XAI
         assert not capabilities.supports_extended_thinking
@@ -212,31 +212,34 @@ class TestXAIProvider:
         assert provider.FRIENDLY_NAME == "X.AI"

         capabilities = provider.get_capabilities("grok-3")
-        assert capabilities.friendly_name == "X.AI"
+        assert capabilities.friendly_name == "X.AI (Grok 3)"

     def test_supported_models_structure(self):
         """Test that SUPPORTED_MODELS has the correct structure."""
         provider = XAIModelProvider("test-key")

-        # Check that all expected models are present
+        # Check that all expected base models are present
         assert "grok-3" in provider.SUPPORTED_MODELS
         assert "grok-3-fast" in provider.SUPPORTED_MODELS
-        assert "grok" in provider.SUPPORTED_MODELS
-        assert "grok3" in provider.SUPPORTED_MODELS
-        assert "grokfast" in provider.SUPPORTED_MODELS
-        assert "grok3fast" in provider.SUPPORTED_MODELS

         # Check model configs have required fields
-        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
-        assert isinstance(grok3_config, dict)
-        assert "context_window" in grok3_config
-        assert "supports_extended_thinking" in grok3_config
-        assert grok3_config["context_window"] == 131_072
-        assert grok3_config["supports_extended_thinking"] is False
-
-        # Check shortcuts point to full names
-        assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
-        assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"
+        from providers.base import ModelCapabilities
+
+        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
+        assert isinstance(grok3_config, ModelCapabilities)
+        assert hasattr(grok3_config, "context_window")
+        assert hasattr(grok3_config, "supports_extended_thinking")
+        assert hasattr(grok3_config, "aliases")
+        assert grok3_config.context_window == 131_072
+        assert grok3_config.supports_extended_thinking is False
+
+        # Check aliases are correctly structured
+        assert "grok" in grok3_config.aliases
+        assert "grok3" in grok3_config.aliases
+
+        grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
+        assert "grok3fast" in grok3fast_config.aliases
+        assert "grokfast" in grok3fast_config.aliases

     @patch("providers.openai_compatible.OpenAI")
     def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):


@@ -99,15 +99,11 @@ class ListModelsTool(BaseTool):
                 output_lines.append("**Status**: Configured and available")
                 output_lines.append("\n**Models**:")

-                # Get models from the provider's SUPPORTED_MODELS
-                for model_name, config in provider.SUPPORTED_MODELS.items():
-                    # Skip alias entries (string values)
-                    if isinstance(config, str):
-                        continue
-
-                    # Get description and context from the model config
-                    description = config.get("description", "No description available")
-                    context_window = config.get("context_window", 0)
+                # Get models from the provider's model configurations
+                for model_name, capabilities in provider.get_model_configurations().items():
+                    # Get description and context from the ModelCapabilities object
+                    description = capabilities.description or "No description available"
+                    context_window = capabilities.context_window

                     # Format context window
                     if context_window >= 1_000_000:
@@ -133,13 +129,14 @@ class ListModelsTool(BaseTool):

                 # Show aliases for this provider
                 aliases = []
-                for alias_name, target in provider.SUPPORTED_MODELS.items():
-                    if isinstance(target, str):  # This is an alias
-                        aliases.append(f"- `{alias_name}` → `{target}`")
+                for model_name, capabilities in provider.get_model_configurations().items():
+                    if capabilities.aliases:
+                        for alias in capabilities.aliases:
+                            aliases.append(f"- `{alias}` → `{model_name}`")

                 if aliases:
                     output_lines.append("\n**Aliases**:")
-                    output_lines.extend(aliases)
+                    output_lines.extend(sorted(aliases))  # Sort for consistent output
             else:
                 output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
@@ -237,7 +234,7 @@ class ListModelsTool(BaseTool):
         for alias in registry.list_aliases():
             config = registry.resolve(alias)
-            if config and hasattr(config, "is_custom") and config.is_custom:
+            if config and config.is_custom:
                 custom_models.append((alias, config))

         if custom_models:


@@ -256,8 +256,8 @@ class BaseTool(ABC):
                 # Find all custom models (is_custom=true)
                 for alias in registry.list_aliases():
                     config = registry.resolve(alias)
-                    # Use hasattr for defensive programming - is_custom is optional with default False
-                    if config and hasattr(config, "is_custom") and config.is_custom:
+                    # Check if this is a custom model that requires custom endpoints
+                    if config and config.is_custom:
                         if alias not in all_models:
                             all_models.append(alias)
             except Exception as e:
@@ -311,12 +311,16 @@ class BaseTool(ABC):
             ProviderType.GOOGLE: "Gemini models",
             ProviderType.OPENAI: "OpenAI models",
             ProviderType.XAI: "X.AI GROK models",
+            ProviderType.DIAL: "DIAL models",
             ProviderType.CUSTOM: "Custom models",
             ProviderType.OPENROUTER: "OpenRouter models",
         }

         # Check available providers and add their model descriptions
-        for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI]:
+        # Start with native providers
+        for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]:
+            # Only if this is registered / available
             provider = ModelProviderRegistry.get_provider(provider_type)
             if provider:
                 provider_section_added = False
@@ -324,13 +328,13 @@ class BaseTool(ABC):
                     try:
                         # Get model config to extract description
                         model_config = provider.SUPPORTED_MODELS.get(model_name)
-                        if isinstance(model_config, dict) and "description" in model_config:
+                        if model_config and model_config.description:
                             if not provider_section_added:
                                 model_desc_parts.append(
                                     f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:"
                                 )
                                 provider_section_added = True
-                            model_desc_parts.append(f"- '{model_name}': {model_config['description']}")
+                            model_desc_parts.append(f"- '{model_name}': {model_config.description}")
                     except Exception:
                         # Skip models without descriptions
                         continue
@@ -346,8 +350,8 @@ class BaseTool(ABC):
                 # Find all custom models (is_custom=true)
                 for alias in registry.list_aliases():
                     config = registry.resolve(alias)
-                    # Use hasattr for defensive programming - is_custom is optional with default False
-                    if config and hasattr(config, "is_custom") and config.is_custom:
+                    # Check if this is a custom model that requires custom endpoints
+                    if config and config.is_custom:
                         # Format context window
                         context_tokens = config.context_window
                         if context_tokens >= 1_000_000:


@@ -128,6 +128,10 @@ class ModelRestrictionService:
         allowed_set = self.restrictions[provider_type]

+        if len(allowed_set) == 0:
+            # Empty set - allowed
+            return True
+
         # Check both the resolved name and original name (if different)
         names_to_check = {model_name.lower()}
         if original_name and original_name.lower() != model_name.lower():