Merge branch 'feat-local_support_with_UTF-8_encoding-update' of https://github.com/GiGiDKR/zen-mcp-server into feat-local_support_with_UTF-8_encoding-update
CLAUDE.md (38 lines changed)
@@ -128,7 +128,28 @@ python communication_simulator_test.py
 python communication_simulator_test.py --verbose
 ```
 
-#### Run Individual Simulator Tests (Recommended)
+#### Quick Test Mode (Recommended for Time-Limited Testing)
+```bash
+# Run quick test mode - 6 essential tests that provide maximum functionality coverage
+python communication_simulator_test.py --quick
+
+# Run quick test mode with verbose output
+python communication_simulator_test.py --quick --verbose
+```
+
+**Quick mode runs these 6 essential tests:**
+- `cross_tool_continuation` - Cross-tool conversation memory testing (chat, thinkdeep, codereview, analyze, debug)
+- `conversation_chain_validation` - Core conversation threading and memory validation
+- `consensus_workflow_accurate` - Consensus tool with flash model and stance testing
+- `codereview_validation` - CodeReview tool with flash model and multi-step workflows
+- `planner_validation` - Planner tool with flash model and complex planning workflows
+- `token_allocation_validation` - Token allocation and conversation history buildup testing
+
+**Why these 6 tests:** They cover the core functionality including conversation memory (`utils/conversation_memory.py`), chat tool functionality, file processing and deduplication, model selection (flash/flashlite/o3), and cross-tool conversation workflows. These tests validate the most critical parts of the system in minimal time.
+
+**Note:** Some workflow tools (analyze, codereview, planner, consensus, etc.) require specific workflow parameters and may need individual testing rather than quick mode testing.
+
+#### Run Individual Simulator Tests (For Detailed Testing)
 ```bash
 # List all available tests
 python communication_simulator_test.py --list-tests
@@ -223,15 +244,17 @@ python -m pytest tests/ -v
 #### After Making Changes
 1. Run quality checks again: `./code_quality_checks.sh`
 2. Run integration tests locally: `./run_integration_tests.sh`
-3. Run relevant simulator tests: `python communication_simulator_test.py --individual <test_name>`
-4. Check logs for any issues: `tail -n 100 logs/mcp_server.log`
-5. Restart Claude session to use updated code
+3. Run quick test mode for fast validation: `python communication_simulator_test.py --quick`
+4. Run relevant specific simulator tests if needed: `python communication_simulator_test.py --individual <test_name>`
+5. Check logs for any issues: `tail -n 100 logs/mcp_server.log`
+6. Restart Claude session to use updated code
 
 #### Before Committing/PR
 1. Final quality check: `./code_quality_checks.sh`
 2. Run integration tests: `./run_integration_tests.sh`
-3. Run full simulator test suite: `./run_integration_tests.sh --with-simulator`
-4. Verify all tests pass 100%
+3. Run quick test mode: `python communication_simulator_test.py --quick`
+4. Run full simulator test suite (optional): `./run_integration_tests.sh --with-simulator`
+5. Verify all tests pass 100%
 
 ### Common Troubleshooting
 
@@ -250,6 +273,9 @@ which python
 
 #### Test Failures
 ```bash
+# First try quick test mode to see if it's a general issue
+python communication_simulator_test.py --quick --verbose
+
 # Run individual failing test with verbose output
 python communication_simulator_test.py --individual <test_name> --verbose
 
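The reordered troubleshooting flow is bisect-style: the quick suite first distinguishes a systemic failure from a test-specific one, then `--individual` narrows in. A sketch of the same triage order driven from Python instead of the shell (a hypothetical wrapper; the test name is only an example):

```python
import subprocess
import sys

# Hypothetical wrapper around communication_simulator_test.py implementing the
# documented triage order: quick suite first, then the single failing test.
quick = subprocess.run([sys.executable, "communication_simulator_test.py", "--quick", "--verbose"])

if quick.returncode != 0:
    print("quick suite failed: likely a general/environment issue")
else:
    # Quick suite passed, so the failure is test-specific: re-run it verbosely.
    subprocess.run(
        [sys.executable, "communication_simulator_test.py", "--individual", "cross_tool_continuation", "--verbose"]
    )
```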

@@ -409,7 +409,7 @@ for most debugging workflows, as Claude is usually able to confidently find the
 When in doubt, you can always follow up with a new prompt and ask Claude to share its findings with another model:
 
 ```text
-Use continuation with thinkdeep, share details with o4-mini-high to find out what the best fix is for this
+Use continuation with thinkdeep, share details with o4-mini to find out what the best fix is for this
 ```
 
 **[📖 Read More](docs/tools/debug.md)** - Step-by-step investigation methodology with workflow enforcement

@@ -38,6 +38,15 @@ Available tests:
     debug_validation - Debug tool validation with actual bugs
     conversation_chain_validation - Conversation chain continuity validation
 
+Quick Test Mode (for time-limited testing):
+    Use --quick to run the essential 6 tests that provide maximum coverage:
+    - cross_tool_continuation (cross-tool conversation memory)
+    - basic_conversation (basic chat functionality)
+    - content_validation (content validation and deduplication)
+    - model_thinking_config (flash/flashlite model testing)
+    - o3_model_selection (o3 model selection testing)
+    - per_tool_deduplication (file deduplication for individual tools)
+
 Examples:
     # Run all tests
     python communication_simulator_test.py

@@ -48,6 +57,9 @@ Examples:
     # Run a single test individually (with full standalone setup)
     python communication_simulator_test.py --individual content_validation
 
+    # Run quick test mode (essential 6 tests for time-limited testing)
+    python communication_simulator_test.py --quick
+
     # Force setup standalone server environment before running tests
     python communication_simulator_test.py --setup
 

@@ -68,21 +80,48 @@ class CommunicationSimulator:
     """Simulates real-world Claude CLI communication with MCP Gemini server"""
 
     def __init__(
-        self, verbose: bool = False, keep_logs: bool = False, selected_tests: list[str] = None, setup: bool = False
+        self,
+        verbose: bool = False,
+        keep_logs: bool = False,
+        selected_tests: list[str] = None,
+        setup: bool = False,
+        quick_mode: bool = False,
     ):
         self.verbose = verbose
         self.keep_logs = keep_logs
         self.selected_tests = selected_tests or []
         self.setup = setup
+        self.quick_mode = quick_mode
         self.temp_dir = None
         self.server_process = None
         self.python_path = self._get_python_path()
 
+        # Configure logging first
+        log_level = logging.DEBUG if verbose else logging.INFO
+        logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s")
+        self.logger = logging.getLogger(__name__)
+
         # Import test registry
         from simulator_tests import TEST_REGISTRY
 
         self.test_registry = TEST_REGISTRY
 
+        # Define quick mode tests (essential tests for time-limited testing)
+        # Focus on tests that work with current tool configurations
+        self.quick_mode_tests = [
+            "cross_tool_continuation",  # Cross-tool conversation memory
+            "basic_conversation",  # Basic chat functionality
+            "content_validation",  # Content validation and deduplication
+            "model_thinking_config",  # Flash/flashlite model testing
+            "o3_model_selection",  # O3 model selection testing
+            "per_tool_deduplication",  # File deduplication for individual tools
+        ]
+
+        # If quick mode is enabled, override selected_tests
+        if self.quick_mode:
+            self.selected_tests = self.quick_mode_tests
+            self.logger.info(f"Quick mode enabled - running {len(self.quick_mode_tests)} essential tests")
+
         # Available test methods mapping
         self.available_tests = {
             name: self._create_test_runner(test_class) for name, test_class in self.test_registry.items()

@@ -91,11 +130,6 @@ class CommunicationSimulator:
         # Test result tracking
         self.test_results = dict.fromkeys(self.test_registry.keys(), False)
 
-        # Configure logging
-        log_level = logging.DEBUG if verbose else logging.INFO
-        logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s")
-        self.logger = logging.getLogger(__name__)
-
     def _get_python_path(self) -> str:
         """Get the Python path for the virtual environment"""
         current_dir = os.getcwd()

@@ -415,6 +449,9 @@ def parse_arguments():
     parser.add_argument("--tests", "-t", nargs="+", help="Specific tests to run (space-separated)")
     parser.add_argument("--list-tests", action="store_true", help="List available tests and exit")
     parser.add_argument("--individual", "-i", help="Run a single test individually")
+    parser.add_argument(
+        "--quick", "-q", action="store_true", help="Run quick test mode (6 essential tests for time-limited testing)"
+    )
     parser.add_argument(
         "--setup", action="store_true", help="Force setup standalone server environment using run-server.sh"
     )

@@ -492,7 +529,11 @@ def main():
 
     # Initialize simulator consistently for all use cases
     simulator = CommunicationSimulator(
-        verbose=args.verbose, keep_logs=args.keep_logs, selected_tests=args.tests, setup=args.setup
+        verbose=args.verbose,
+        keep_logs=args.keep_logs,
+        selected_tests=args.tests,
+        setup=args.setup,
+        quick_mode=args.quick,
     )
 
     # Determine execution mode and run
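Taken together, the three simulator hunks wire the new flag end to end: argparse defines `--quick`, `main()` forwards it as `quick_mode`, and `__init__` overrides `selected_tests` with the six-test list. A self-contained sketch of that control flow, using stand-in names rather than the real `CommunicationSimulator`:

```python
import argparse

QUICK_MODE_TESTS = [
    "cross_tool_continuation",
    "basic_conversation",
    "content_validation",
    "model_thinking_config",
    "o3_model_selection",
    "per_tool_deduplication",
]

class MiniSimulator:
    """Stand-in mirroring only the quick-mode override from the diff."""

    def __init__(self, selected_tests=None, quick_mode=False):
        self.selected_tests = selected_tests or []
        if quick_mode:
            # Quick mode wins over any explicitly selected tests
            self.selected_tests = QUICK_MODE_TESTS

parser = argparse.ArgumentParser()
parser.add_argument("--tests", "-t", nargs="+")
parser.add_argument("--quick", "-q", action="store_true")
args = parser.parse_args(["--quick"])

sim = MiniSimulator(selected_tests=args.tests, quick_mode=args.quick)
print(sim.selected_tests)  # the six essential tests
```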

@@ -22,6 +22,7 @@
     "model_name": "The model identifier - OpenRouter format (e.g., 'anthropic/claude-opus-4') or custom model name (e.g., 'llama3.2')",
     "aliases": "Array of short names users can type instead of the full model name",
     "context_window": "Total number of tokens the model can process (input + output combined)",
+    "max_output_tokens": "Maximum number of tokens the model can generate in a single response",
     "supports_extended_thinking": "Whether the model supports extended reasoning tokens (currently none do via OpenRouter or custom APIs)",
     "supports_json_mode": "Whether the model can guarantee valid JSON output",
     "supports_function_calling": "Whether the model supports function/tool calling",
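The new `max_output_tokens` field complements `context_window` rather than duplicating it: the window bounds prompt plus completion combined, while `max_output_tokens` bounds the completion alone. A small illustrative sketch of budgeting with both limits (toy numbers, not read from the real config file):

```python
# Illustrative only: derive a safe completion budget from the two limits.
entry = {
    "model_name": "my-local-model",
    "context_window": 128000,    # input + output combined
    "max_output_tokens": 32768,  # output alone
}

def completion_budget(entry: dict, prompt_tokens: int) -> int:
    # The completion can never exceed max_output_tokens, nor the
    # context space left over after the prompt is counted.
    remaining_context = entry["context_window"] - prompt_tokens
    return max(0, min(entry["max_output_tokens"], remaining_context))

print(completion_budget(entry, prompt_tokens=100_000))  # 28000
print(completion_budget(entry, prompt_tokens=10_000))   # 32768
```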

@@ -36,6 +37,7 @@
       "model_name": "my-local-model",
       "aliases": ["shortname", "nickname", "abbrev"],
       "context_window": 128000,
+      "max_output_tokens": 32768,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,

@@ -52,6 +54,7 @@
       "model_name": "anthropic/claude-opus-4",
       "aliases": ["opus", "claude-opus", "claude4-opus", "claude-4-opus"],
       "context_window": 200000,
+      "max_output_tokens": 64000,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,

@@ -63,6 +66,7 @@
       "model_name": "anthropic/claude-sonnet-4",
       "aliases": ["sonnet", "claude-sonnet", "claude4-sonnet", "claude-4-sonnet", "claude"],
       "context_window": 200000,
+      "max_output_tokens": 64000,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,

@@ -74,6 +78,7 @@
       "model_name": "anthropic/claude-3.5-haiku",
       "aliases": ["haiku", "claude-haiku", "claude3-haiku", "claude-3-haiku"],
       "context_window": 200000,
+      "max_output_tokens": 64000,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,

@@ -85,6 +90,7 @@
       "model_name": "google/gemini-2.5-pro",
       "aliases": ["pro","gemini-pro", "gemini", "pro-openrouter"],
       "context_window": 1048576,
+      "max_output_tokens": 65536,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": false,

@@ -96,6 +102,7 @@
       "model_name": "google/gemini-2.5-flash",
       "aliases": ["flash","gemini-flash", "flash-openrouter", "flash-2.5"],
       "context_window": 1048576,
+      "max_output_tokens": 65536,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": false,

@@ -107,6 +114,7 @@
       "model_name": "mistralai/mistral-large-2411",
       "aliases": ["mistral-large", "mistral"],
       "context_window": 128000,
+      "max_output_tokens": 32000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,

@@ -118,6 +126,7 @@
       "model_name": "meta-llama/llama-3-70b",
       "aliases": ["llama", "llama3", "llama3-70b", "llama-70b", "llama3-openrouter"],
       "context_window": 8192,
+      "max_output_tokens": 8192,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,

@@ -129,6 +138,7 @@
       "model_name": "deepseek/deepseek-r1-0528",
       "aliases": ["deepseek-r1", "deepseek", "r1", "deepseek-thinking"],
       "context_window": 65536,
+      "max_output_tokens": 32768,
       "supports_extended_thinking": true,
       "supports_json_mode": true,
       "supports_function_calling": false,

@@ -140,6 +150,7 @@
       "model_name": "perplexity/llama-3-sonar-large-32k-online",
       "aliases": ["perplexity", "sonar", "perplexity-online"],
       "context_window": 32768,
+      "max_output_tokens": 32768,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,

@@ -151,6 +162,7 @@
       "model_name": "openai/o3",
       "aliases": ["o3"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,

@@ -164,6 +176,7 @@
       "model_name": "openai/o3-mini",
       "aliases": ["o3-mini", "o3mini"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,

@@ -177,6 +190,7 @@
       "model_name": "openai/o3-mini-high",
       "aliases": ["o3-mini-high", "o3mini-high"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,

@@ -190,6 +204,7 @@
       "model_name": "openai/o3-pro",
       "aliases": ["o3-pro", "o3pro"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,

@@ -203,6 +218,7 @@
       "model_name": "openai/o4-mini",
       "aliases": ["o4-mini", "o4mini"],
       "context_window": 200000,
+      "max_output_tokens": 100000,
       "supports_extended_thinking": false,
       "supports_json_mode": true,
       "supports_function_calling": true,

@@ -212,23 +228,11 @@
       "temperature_constraint": "fixed",
       "description": "OpenAI's o4-mini model - optimized for shorter contexts with rapid reasoning and vision"
     },
-    {
-      "model_name": "openai/o4-mini-high",
-      "aliases": ["o4-mini-high", "o4mini-high", "o4minihigh", "o4minihi"],
-      "context_window": 200000,
-      "supports_extended_thinking": false,
-      "supports_json_mode": true,
-      "supports_function_calling": true,
-      "supports_images": true,
-      "max_image_size_mb": 20.0,
-      "supports_temperature": false,
-      "temperature_constraint": "fixed",
-      "description": "OpenAI's o4-mini with high reasoning effort - enhanced for complex tasks with vision"
-    },
     {
       "model_name": "llama3.2",
       "aliases": ["local-llama", "local", "llama3.2", "ollama-llama"],
       "context_window": 128000,
+      "max_output_tokens": 64000,
       "supports_extended_thinking": false,
       "supports_json_mode": false,
       "supports_function_calling": false,

@@ -14,7 +14,7 @@ import os
 # These values are used in server responses and for tracking releases
 # IMPORTANT: This is the single source of truth for version and author info
 # Semantic versioning: MAJOR.MINOR.PATCH
-__version__ = "5.6.1"
+__version__ = "5.7.0"
 # Last update date in ISO format
 __updated__ = "2025-06-23"
 # Primary maintainer

@@ -38,7 +38,6 @@ Regardless of your default configuration, you can specify models per request:
 | **`o3`** | OpenAI | 200K tokens | Strong logical reasoning | Debugging logic errors, systematic analysis |
 | **`o3-mini`** | OpenAI | 200K tokens | Balanced speed/quality | Moderate complexity tasks |
 | **`o4-mini`** | OpenAI | 200K tokens | Latest reasoning model | Optimized for shorter contexts |
-| **`o4-mini-high`** | OpenAI | 200K tokens | Enhanced reasoning | Complex tasks requiring deeper analysis |
 | **`gpt4.1`** | OpenAI | 1M tokens | Latest GPT-4 with extended context | Large codebase analysis, comprehensive reviews |
 | **`llama`** (Llama 3.2) | Custom/Local | 128K tokens | Local inference, privacy | On-device analysis, cost-free processing |
 | **Any model** | OpenRouter | Varies | Access to GPT-4, Claude, Llama, etc. | User-specified or based on task requirements |

@@ -69,7 +68,7 @@ OPENAI_ALLOWED_MODELS=o4-mini
 
 # High-performance: Quality over cost
 GOOGLE_ALLOWED_MODELS=pro
-OPENAI_ALLOWED_MODELS=o3,o4-mini-high
+OPENAI_ALLOWED_MODELS=o3,o4-mini
 ```
 
 **Important Notes:**

@@ -144,7 +143,7 @@ All tools that work with files support **both individual files and entire direct
 **`analyze`** - Analyze files or directories
 - `files`: List of file paths or directories (required)
 - `question`: What to analyze (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `analysis_type`: architecture|performance|security|quality|general
 - `output_format`: summary|detailed|actionable
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)

@@ -159,7 +158,7 @@ All tools that work with files support **both individual files and entire direct
 
 **`codereview`** - Review code files or directories
 - `files`: List of file paths or directories (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `review_type`: full|security|performance|quick
 - `focus_on`: Specific aspects to focus on
 - `standards`: Coding standards to enforce

@@ -175,7 +174,7 @@ All tools that work with files support **both individual files and entire direct
 
 **`debug`** - Debug with file context
 - `error_description`: Description of the issue (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `error_context`: Stack trace or logs
 - `files`: Files or directories related to the issue
 - `runtime_info`: Environment details

@@ -191,7 +190,7 @@ All tools that work with files support **both individual files and entire direct
 
 **`thinkdeep`** - Extended analysis with file context
 - `current_analysis`: Your current thinking (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `problem_context`: Additional context
 - `focus_areas`: Specific aspects to focus on
 - `files`: Files or directories for context

@@ -207,7 +206,7 @@ All tools that work with files support **both individual files and entire direct
 **`testgen`** - Comprehensive test generation with edge case coverage
 - `files`: Code files or directories to generate tests for (required)
 - `prompt`: Description of what to test, testing objectives, and scope (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `test_examples`: Optional existing test files as style/pattern reference
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
 
@@ -222,7 +221,7 @@ All tools that work with files support **both individual files and entire direct
 - `files`: Code files or directories to analyze for refactoring opportunities (required)
 - `prompt`: Description of refactoring goals, context, and specific areas of focus (required)
 - `refactor_type`: codesmells|decompose|modernize|organization (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `focus_areas`: Specific areas to focus on (e.g., 'performance', 'readability', 'maintainability', 'security')
 - `style_guide_examples`: Optional existing code files to use as style/pattern reference
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)

@@ -63,7 +63,7 @@ CUSTOM_MODEL_NAME=llama3.2 # Default model
 
 **Default Model Selection:**
 ```env
-# Options: 'auto', 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high', etc.
+# Options: 'auto', 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', etc.
 DEFAULT_MODEL=auto # Claude picks best model for each task (recommended)
 ```
 

@@ -74,7 +74,6 @@ DEFAULT_MODEL=auto # Claude picks best model for each task (recommended)
 - **`o3`**: Strong logical reasoning (200K context)
 - **`o3-mini`**: Balanced speed/quality (200K context)
 - **`o4-mini`**: Latest reasoning model, optimized for shorter contexts
-- **`o4-mini-high`**: Enhanced O4 with higher reasoning effort
 - **`grok`**: GROK-3 advanced reasoning (131K context)
 - **Custom models**: via OpenRouter or local APIs
 

@@ -120,7 +119,6 @@ OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral
 - `o3` (200K context, high reasoning)
 - `o3-mini` (200K context, balanced)
 - `o4-mini` (200K context, latest balanced)
-- `o4-mini-high` (200K context, enhanced reasoning)
 - `mini` (shorthand for o4-mini)
 
 **Gemini Models:**

@@ -65,7 +65,7 @@ This workflow ensures methodical analysis before expert insights, resulting in d
 
 **Initial Configuration (used in step 1):**
 - `prompt`: What to analyze or look for (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `analysis_type`: architecture|performance|security|quality|general (default: general)
 - `output_format`: summary|detailed|actionable (default: detailed)
 - `temperature`: Temperature for analysis (0-1, default 0.2)

@@ -33,7 +33,7 @@ and then debate with the other models to give me a final verdict
 ## Tool Parameters
 
 - `prompt`: Your question or discussion topic (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `files`: Optional files for context (absolute paths)
 - `images`: Optional images for visual context (absolute paths)
 - `temperature`: Response creativity (0-1, default 0.5)

@@ -80,7 +80,7 @@ The above prompt will simultaneously run two separate `codereview` tools with tw
 
 **Initial Review Configuration (used in step 1):**
 - `prompt`: User's summary of what the code does, expected behavior, constraints, and review objectives (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `review_type`: full|security|performance|quick (default: full)
 - `focus_on`: Specific aspects to focus on (e.g., "security vulnerabilities", "performance bottlenecks")
 - `standards`: Coding standards to enforce (e.g., "PEP8", "ESLint", "Google Style Guide")

@@ -73,7 +73,7 @@ This structured approach ensures Claude performs methodical groundwork before ex
 - `images`: Visual debugging materials (error screenshots, logs, etc.)
 
 **Model Selection:**
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini (default: server default)
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
 - `use_websearch`: Enable web search for documentation and solutions (default: true)
 - `use_assistant_model`: Whether to use expert analysis phase (default: true, set to false to use Claude only)

@@ -135,7 +135,7 @@ Use zen and perform a thorough precommit ensuring there aren't any new regressio
 **Initial Configuration (used in step 1):**
 - `path`: Starting directory to search for repos (default: current directory, absolute path required)
 - `prompt`: The original user request description for the changes (required for context)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `compare_to`: Compare against a branch/tag instead of local changes (optional)
 - `severity_filter`: critical|high|medium|low|all (default: all)
 - `include_staged`: Include staged changes in the review (default: true)

@@ -103,7 +103,7 @@ This results in Claude first performing its own expert analysis, encouraging it
 **Initial Configuration (used in step 1):**
 - `prompt`: Description of refactoring goals, context, and specific areas of focus (required)
 - `refactor_type`: codesmells|decompose|modernize|organization (default: codesmells)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `focus_areas`: Specific areas to focus on (e.g., 'performance', 'readability', 'maintainability', 'security')
 - `style_guide_examples`: Optional existing code files to use as style/pattern reference (absolute paths)
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)

@@ -86,7 +86,7 @@ security remediation plan using planner
 - `images`: Architecture diagrams, security documentation, or visual references
 
 **Initial Security Configuration (used in step 1):**
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `security_scope`: Application context, technology stack, and security boundary definition (required)
 - `threat_level`: low|medium|high|critical (default: medium) - determines assessment depth and urgency
 - `compliance_requirements`: List of compliance frameworks to assess against (e.g., ["PCI DSS", "SOC2"])

@@ -70,7 +70,7 @@ Test generation excels with extended reasoning models like Gemini Pro or O3, whi
 
 **Initial Configuration (used in step 1):**
 - `prompt`: Description of what to test, testing objectives, and specific scope/focus areas (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `test_examples`: Optional existing test files or directories to use as style/pattern reference (absolute paths)
 - `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
 - `use_assistant_model`: Whether to use expert test generation phase (default: true, set to false to use Claude only)

@@ -30,7 +30,7 @@ with the best architecture for my project
 ## Tool Parameters
 
 - `prompt`: Your current thinking/analysis to extend and validate (required)
-- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
+- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
 - `problem_context`: Additional context about the problem or goal
 - `focus_areas`: Specific aspects to focus on (architecture, performance, security, etc.)
 - `files`: Optional file paths or directories for additional context (absolute paths)

@@ -132,6 +132,7 @@ class ModelCapabilities:
     model_name: str
     friendly_name: str  # Human-friendly name like "Gemini" or "OpenAI"
     context_window: int  # Total context window size in tokens
+    max_output_tokens: int  # Maximum output tokens per request
    supports_extended_thinking: bool = False
     supports_system_prompts: bool = True
     supports_streaming: bool = True

@@ -140,6 +141,19 @@ class ModelCapabilities:
     max_image_size_mb: float = 0.0  # Maximum total size for all images in MB
     supports_temperature: bool = True  # Whether model accepts temperature parameter in API calls
 
+    # Additional fields for comprehensive model information
+    description: str = ""  # Human-readable description of the model
+    aliases: list[str] = field(default_factory=list)  # Alternative names/shortcuts for the model
+
+    # JSON mode support (for providers that support structured output)
+    supports_json_mode: bool = False
+
+    # Thinking mode support (for models with thinking capabilities)
+    max_thinking_tokens: int = 0  # Maximum thinking tokens for extended reasoning models
+
+    # Custom model flag (for models that only work with custom endpoints)
+    is_custom: bool = False  # Whether this model requires custom API endpoints
+
     # Temperature constraint object - preferred way to define temperature limits
     temperature_constraint: TemperatureConstraint = field(
         default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
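With these additions, one entry in `conf/custom_models.json` maps field-for-field onto a `ModelCapabilities` instance. A simplified stand-in showing that correspondence (a local dataclass with only the fields touched by this diff; the real class also carries the temperature constraint, image limits, and more):

```python
from dataclasses import dataclass, field

@dataclass
class MiniCapabilities:
    # Subset of the ModelCapabilities fields added/used in this diff
    model_name: str
    context_window: int
    max_output_tokens: int
    aliases: list[str] = field(default_factory=list)
    supports_json_mode: bool = False
    is_custom: bool = False

# Mirrors the llama3.2 registry entry shown above
caps = MiniCapabilities(
    model_name="llama3.2",
    context_window=128000,
    max_output_tokens=64000,
    aliases=["local-llama", "local", "ollama-llama"],
    supports_json_mode=False,
    is_custom=True,  # assumption for illustration: local models are custom-endpoint only
)
print(caps.aliases)
```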

@@ -251,7 +265,7 @@ class ModelProvider(ABC):
         capabilities = self.get_capabilities(model_name)
 
         # Check if model supports temperature at all
-        if hasattr(capabilities, "supports_temperature") and not capabilities.supports_temperature:
+        if not capabilities.supports_temperature:
             return None
 
         # Get temperature range
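The dropped `hasattr` guard was redundant: a dataclass field declared with a default exists on every instance, so the attribute lookup can never fail. A minimal demonstration of why the shorter check is equivalent:

```python
from dataclasses import dataclass

@dataclass
class Caps:
    supports_temperature: bool = True  # field with a default always exists

caps = Caps(supports_temperature=False)
assert hasattr(caps, "supports_temperature")  # always true for a dataclass field
assert not caps.supports_temperature          # so the simplified check suffices
```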

@@ -290,19 +304,109 @@ class ModelProvider(ABC):
         """Check if the model supports extended thinking mode."""
         pass
 
-    @abstractmethod
+    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
+        """Get model configurations for this provider.
+
+        This is a hook method that subclasses can override to provide
+        their model configurations from different sources.
+
+        Returns:
+            Dictionary mapping model names to their ModelCapabilities objects
+        """
+        # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
+        if hasattr(self, "SUPPORTED_MODELS"):
+            return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
+        return {}
+
+    def get_all_model_aliases(self) -> dict[str, list[str]]:
+        """Get all model aliases for this provider.
+
+        This is a hook method that subclasses can override to provide
+        aliases from different sources.
+
+        Returns:
+            Dictionary mapping model names to their list of aliases
+        """
+        # Default implementation extracts from ModelCapabilities objects
+        aliases = {}
+        for model_name, capabilities in self.get_model_configurations().items():
+            if capabilities.aliases:
+                aliases[model_name] = capabilities.aliases
+        return aliases
+
+    def _resolve_model_name(self, model_name: str) -> str:
+        """Resolve model shorthand to full name.
+
+        This implementation uses the hook methods to support different
+        model configuration sources.
+
+        Args:
+            model_name: Model name that may be an alias
+
+        Returns:
+            Resolved model name
+        """
+        # Get model configurations from the hook method
+        model_configs = self.get_model_configurations()
+
+        # First check if it's already a base model name (case-sensitive exact match)
+        if model_name in model_configs:
+            return model_name
+
+        # Check case-insensitively for both base models and aliases
+        model_name_lower = model_name.lower()
+
+        # Check base model names case-insensitively
+        for base_model in model_configs:
+            if base_model.lower() == model_name_lower:
+                return base_model
+
+        # Check aliases from the hook method
+        all_aliases = self.get_all_model_aliases()
+        for base_model, aliases in all_aliases.items():
+            if any(alias.lower() == model_name_lower for alias in aliases):
+                return base_model
+
+        # If not found, return as-is
+        return model_name
+
     def list_models(self, respect_restrictions: bool = True) -> list[str]:
         """Return a list of model names supported by this provider.
 
+        This implementation uses the get_model_configurations() hook
+        to support different model configuration sources.
+
         Args:
             respect_restrictions: Whether to apply provider-specific restriction logic.
 
         Returns:
             List of model names available from this provider
         """
-        pass
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service() if respect_restrictions else None
+        models = []
+
+        # Get model configurations from the hook method
+        model_configs = self.get_model_configurations()
+
+        for model_name in model_configs:
+            # Check restrictions if enabled
+            if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
+                continue
+
+            # Add the base model
+            models.append(model_name)
+
+        # Get aliases from the hook method
+        all_aliases = self.get_all_model_aliases()
+        for model_name, aliases in all_aliases.items():
+            # Only add aliases for models that passed restriction check
+            if model_name in models:
+                models.extend(aliases)
+
+        return models
 
-    @abstractmethod
     def list_all_known_models(self) -> list[str]:
         """Return all model names known by this provider, including alias targets.
 
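This is a template-method refactor: `list_models()` and `_resolve_model_name()` become concrete in the base class and delegate to two overridable hooks, `get_model_configurations()` and `get_all_model_aliases()`. A self-contained sketch of the pattern with toy classes (no restriction service, method names simplified):

```python
from dataclasses import dataclass, field

@dataclass
class Caps:
    model_name: str
    aliases: list[str] = field(default_factory=list)

class BaseProvider:
    def get_model_configurations(self) -> dict[str, Caps]:
        # Hook: default reads a class-level SUPPORTED_MODELS mapping
        return getattr(self, "SUPPORTED_MODELS", {})

    def get_all_model_aliases(self) -> dict[str, list[str]]:
        # Hook: default derives aliases from the configurations
        return {n: c.aliases for n, c in self.get_model_configurations().items() if c.aliases}

    def resolve(self, name: str) -> str:
        configs = self.get_model_configurations()
        if name in configs:
            return name
        lower = name.lower()
        for base, aliases in self.get_all_model_aliases().items():
            if any(a.lower() == lower for a in aliases):
                return base
        return name  # unknown names pass through unchanged

class GeminiLike(BaseProvider):
    SUPPORTED_MODELS = {"gemini-2.5-flash": Caps("gemini-2.5-flash", aliases=["flash"])}

print(GeminiLike().resolve("FLASH"))  # -> gemini-2.5-flash
```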

@@ -312,21 +416,22 @@ class ModelProvider(ABC):
         Returns:
             List of all model names and alias targets known by this provider
         """
-        pass
+        all_models = set()
 
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name.
+        # Get model configurations from the hook method
+        model_configs = self.get_model_configurations()
 
-        Base implementation returns the model name unchanged.
-        Subclasses should override to provide alias resolution.
+        # Add all base model names
+        for model_name in model_configs:
+            all_models.add(model_name.lower())
 
-        Args:
-            model_name: Model name that may be an alias
+        # Get aliases from the hook method and add them
+        all_aliases = self.get_all_model_aliases()
+        for _model_name, aliases in all_aliases.items():
+            for alias in aliases:
+                all_models.add(alias.lower())
 
-        Returns:
-            Resolved model name
-        """
-        return model_name
+        return list(all_models)
 
     def close(self):
         """Clean up any resources held by the provider.
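One behavioral detail worth noting in the new `list_all_known_models()`: names and aliases are collected into a set and lowercased, so the result is deduplicated and case-normalized, whereas `list_models()` keeps the configured casing. A tiny standalone check of that normalization (toy data):

```python
# Mirrors the set-plus-lower() logic from the diff above.
configs = {"Gemini-2.5-Flash": ["flash", "Flash-2.5"]}  # base name -> aliases

all_known = set()
for base, aliases in configs.items():
    all_known.add(base.lower())
    for alias in aliases:
        all_known.add(alias.lower())

print(sorted(all_known))  # ['flash', 'flash-2.5', 'gemini-2.5-flash']
```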

@@ -158,6 +158,7 @@ class CustomProvider(OpenAICompatibleProvider):
             model_name=resolved_name,
             friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
             context_window=32_768,  # Conservative default
+            max_output_tokens=32_768,  # Conservative default max output
             supports_extended_thinking=False,  # Most custom models don't support this
             supports_system_prompts=True,
             supports_streaming=True,

@@ -187,7 +188,7 @@ class CustomProvider(OpenAICompatibleProvider):
         Returns:
             True if model is intended for custom/local endpoint
         """
-        logging.debug(f"Custom provider validating model: '{model_name}'")
+        # logging.debug(f"Custom provider validating model: '{model_name}'")
 
         # Try to resolve through registry first
         config = self._registry.resolve(model_name)

@@ -195,12 +196,12 @@ class CustomProvider(OpenAICompatibleProvider):
             model_id = config.model_name
             # Use explicit is_custom flag for clean validation
             if config.is_custom:
-                logging.debug(f"Model '{model_name}' -> '{model_id}' validated via registry (custom model)")
+                logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry")
                 return True
             else:
                 # This is a cloud/OpenRouter model - CustomProvider should NOT handle these
                 # Let OpenRouter provider handle them instead
-                logging.debug(f"Model '{model_name}' -> '{model_id}' rejected (cloud model, defer to OpenRouter)")
+                # logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)")
                 return False
 
         # Handle version tags for unknown models (e.g., "my-model:latest")
@@ -268,65 +269,50 @@ class CustomProvider(OpenAICompatibleProvider):
|
|||||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||||
"""Check if the model supports extended thinking mode.
|
"""Check if the model supports extended thinking mode.
|
||||||
|
|
||||||
Most custom/local models don't support extended thinking.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Model to check
|
model_name: Model to check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
False (custom models generally don't support thinking mode)
|
True if model supports thinking mode, False otherwise
|
||||||
"""
|
"""
|
||||||
|
# Check if model is in registry
|
||||||
|
config = self._registry.resolve(model_name) if self._registry else None
|
||||||
|
if config and config.is_custom:
|
||||||
|
# Trust the config from custom_models.json
|
||||||
|
return config.supports_extended_thinking
|
||||||
|
|
||||||
|
# Default to False for unknown models
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||||
"""Return a list of model names supported by this provider.
|
"""Get model configurations from the registry.
|
||||||
|
|
||||||
Args:
|
For CustomProvider, we convert registry configurations to ModelCapabilities objects.
|
||||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of model names available from this provider
|
Dictionary mapping model names to their ModelCapabilities objects
|
||||||
"""
|
"""
|
||||||
from utils.model_restrictions import get_restriction_service
|
|
||||||
|
|
||||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
configs = {}
|
||||||
models = []
|
|
||||||
|
|
||||||
if self._registry:
|
if self._registry:
|
||||||
# Get all models from the registry
|
# Get all models from registry
|
||||||
all_models = self._registry.list_models()
|
for model_name in self._registry.list_models():
|
||||||
aliases = self._registry.list_aliases()
|
# Only include custom models that this provider validates
|
||||||
|
|
||||||
# Add models that are validated by the custom provider
|
|
||||||
for model_name in all_models + aliases:
|
|
||||||
# Use the provider's validation logic to determine if this model
|
|
||||||
# is appropriate for the custom endpoint
|
|
||||||
if self.validate_model_name(model_name):
|
if self.validate_model_name(model_name):
|
||||||
# Check restrictions if enabled
|
config = self._registry.resolve(model_name)
|
||||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
if config and config.is_custom:
|
||||||
continue
|
# Use ModelCapabilities directly from registry
|
||||||
|
configs[model_name] = config
|
||||||
|
|
||||||
models.append(model_name)
|
return configs
|
||||||
|
|
||||||
return models
|
def get_all_model_aliases(self) -> dict[str, list[str]]:
|
||||||
|
"""Get all model aliases from the registry.
|
||||||
def list_all_known_models(self) -> list[str]:
|
|
||||||
"""Return all model names known by this provider, including alias targets.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of all model names and alias targets known by this provider
|
Dictionary mapping model names to their list of aliases
|
||||||
"""
|
"""
|
||||||
all_models = set()
|
# Since aliases are now included in the configurations,
|
||||||
|
# we can use the base class implementation
|
||||||
if self._registry:
|
return super().get_all_model_aliases()
|
||||||
# Get all models and aliases from the registry
|
|
||||||
all_models.update(model.lower() for model in self._registry.list_models())
|
|
||||||
all_models.update(alias.lower() for alias in self._registry.list_aliases())
|
|
||||||
|
|
||||||
# For each alias, also add its target
|
|
||||||
for alias in self._registry.list_aliases():
|
|
||||||
config = self._registry.resolve(alias)
|
|
||||||
if config:
|
|
||||||
all_models.add(config.model_name.lower())
|
|
||||||
|
|
||||||
return list(all_models)
|
|
||||||
|
|||||||
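Net effect of the base-class change above: `list_all_known_models` stops being an abstract stub and becomes a template method over two hooks, `get_model_configurations()` and `get_all_model_aliases()`, so each provider describes its models once and the base class derives the combined name set. A minimal sketch of the pattern (the hook names come from the diff; the toy class and sample data are purely illustrative):

```python
# Sketch of the template-method pattern introduced above; not the repository's
# real base class, just the shape of it with toy data.
class SketchProvider:
    def get_model_configurations(self) -> dict:
        # Hook: subclasses return {canonical_name: ModelCapabilities}
        return {"gemini-2.5-flash": object(), "gemini-2.5-pro": object()}

    def get_all_model_aliases(self) -> dict:
        # Hook: subclasses return {canonical_name: [alias, ...]}
        return {"gemini-2.5-flash": ["flash"], "gemini-2.5-pro": ["pro"]}

    def list_all_known_models(self) -> list:
        # Template method: union of canonical names and aliases, lowercased
        all_models = set()
        for model_name in self.get_model_configurations():
            all_models.add(model_name.lower())
        for _model_name, aliases in self.get_all_model_aliases().items():
            for alias in aliases:
                all_models.add(alias.lower())
        return list(all_models)


print(sorted(SketchProvider().list_all_known_models()))
# ['flash', 'gemini-2.5-flash', 'gemini-2.5-pro', 'pro']
```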
@@ -10,7 +10,7 @@ from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
-    RangeTemperatureConstraint,
+    create_temperature_constraint,
)
from .openai_compatible import OpenAICompatibleProvider

@@ -30,63 +30,170 @@ class DIALModelProvider(OpenAICompatibleProvider):
    MAX_RETRIES = 4
    RETRY_DELAYS = [1, 3, 5, 8]  # seconds

-    # Supported DIAL models (these can be customized based on your DIAL deployment)
+    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
-        "o3-2025-04-16": {
-            "context_window": 200_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "o4-mini-2025-04-16": {
-            "context_window": 200_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "anthropic.claude-sonnet-4-20250514-v1:0": {
-            "context_window": 200_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
-            "context_window": 200_000,
-            "supports_extended_thinking": True,  # Thinking mode variant
-            "supports_vision": True,
-        },
-        "anthropic.claude-opus-4-20250514-v1:0": {
-            "context_window": 200_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
-            "context_window": 200_000,
-            "supports_extended_thinking": True,  # Thinking mode variant
-            "supports_vision": True,
-        },
-        "gemini-2.5-pro-preview-03-25-google-search": {
-            "context_window": 1_000_000,
-            "supports_extended_thinking": False,  # DIAL doesn't expose thinking mode
-            "supports_vision": True,
-        },
-        "gemini-2.5-pro-preview-05-06": {
-            "context_window": 1_000_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        "gemini-2.5-flash-preview-05-20": {
-            "context_window": 1_000_000,
-            "supports_extended_thinking": False,
-            "supports_vision": True,
-        },
-        # Shorthands
-        "o3": "o3-2025-04-16",
-        "o4-mini": "o4-mini-2025-04-16",
-        "sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
-        "sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
-        "opus-4": "anthropic.claude-opus-4-20250514-v1:0",
-        "opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
-        "gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
-        "gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
-        "gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
+        "o3-2025-04-16": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="o3-2025-04-16",
+            friendly_name="DIAL (O3)",
+            context_window=200_000,
+            max_output_tokens=100_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=False,  # O3 models don't accept temperature
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="OpenAI O3 via DIAL - Strong reasoning model",
+            aliases=["o3"],
+        ),
+        "o4-mini-2025-04-16": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="o4-mini-2025-04-16",
+            friendly_name="DIAL (O4-mini)",
+            context_window=200_000,
+            max_output_tokens=100_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=False,  # O4 models don't accept temperature
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="OpenAI O4-mini via DIAL - Fast reasoning model",
+            aliases=["o4-mini"],
+        ),
+        "anthropic.claude-sonnet-4-20250514-v1:0": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-sonnet-4-20250514-v1:0",
+            friendly_name="DIAL (Sonnet 4)",
+            context_window=200_000,
+            max_output_tokens=64_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Claude doesn't have function calling
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Sonnet 4 via DIAL - Balanced performance",
+            aliases=["sonnet-4"],
+        ),
+        "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
+            friendly_name="DIAL (Sonnet 4 Thinking)",
+            context_window=200_000,
+            max_output_tokens=64_000,
+            supports_extended_thinking=True,  # Thinking mode variant
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Claude doesn't have function calling
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Sonnet 4 with thinking mode via DIAL",
+            aliases=["sonnet-4-thinking"],
+        ),
+        "anthropic.claude-opus-4-20250514-v1:0": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-opus-4-20250514-v1:0",
+            friendly_name="DIAL (Opus 4)",
+            context_window=200_000,
+            max_output_tokens=64_000,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Claude doesn't have function calling
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Opus 4 via DIAL - Most capable Claude model",
+            aliases=["opus-4"],
+        ),
+        "anthropic.claude-opus-4-20250514-v1:0-with-thinking": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking",
+            friendly_name="DIAL (Opus 4 Thinking)",
+            context_window=200_000,
+            max_output_tokens=64_000,
+            supports_extended_thinking=True,  # Thinking mode variant
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # Claude doesn't have function calling
+            supports_json_mode=False,  # Claude doesn't have JSON mode
+            supports_images=True,
+            max_image_size_mb=5.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Claude Opus 4 with thinking mode via DIAL",
+            aliases=["opus-4-thinking"],
+        ),
+        "gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="gemini-2.5-pro-preview-03-25-google-search",
+            friendly_name="DIAL (Gemini 2.5 Pro Search)",
+            context_window=1_000_000,
+            max_output_tokens=65_536,
+            supports_extended_thinking=False,  # DIAL doesn't expose thinking mode
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.5 Pro with Google Search via DIAL",
+            aliases=["gemini-2.5-pro-search"],
+        ),
+        "gemini-2.5-pro-preview-05-06": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="gemini-2.5-pro-preview-05-06",
+            friendly_name="DIAL (Gemini 2.5 Pro)",
+            context_window=1_000_000,
+            max_output_tokens=65_536,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.5 Pro via DIAL - Deep reasoning",
+            aliases=["gemini-2.5-pro"],
+        ),
+        "gemini-2.5-flash-preview-05-20": ModelCapabilities(
+            provider=ProviderType.DIAL,
+            model_name="gemini-2.5-flash-preview-05-20",
+            friendly_name="DIAL (Gemini Flash 2.5)",
+            context_window=1_000_000,
+            max_output_tokens=65_536,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=False,  # DIAL may not expose function calling
+            supports_json_mode=True,
+            supports_images=True,
+            max_image_size_mb=20.0,
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.5 Flash via DIAL - Ultra-fast",
+            aliases=["gemini-2.5-flash"],
+        ),
    }

    def __init__(self, api_key: str, **kwargs):

@@ -181,20 +288,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
        if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
            raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")

-        config = self.SUPPORTED_MODELS[resolved_name]
-
-        return ModelCapabilities(
-            provider=ProviderType.DIAL,
-            model_name=resolved_name,
-            friendly_name=self.FRIENDLY_NAME,
-            context_window=config["context_window"],
-            supports_extended_thinking=config["supports_extended_thinking"],
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=True,
-            supports_images=config.get("supports_vision", False),
-            temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
-        )
+        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+        return self.SUPPORTED_MODELS[resolved_name]

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""

@@ -211,7 +306,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
        """
        resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Check against base class allowed_models if configured

@@ -231,20 +326,6 @@ class DIALModelProvider(OpenAICompatibleProvider):

        return True

-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name.
-
-        Args:
-            model_name: Model name or shorthand
-
-        Returns:
-            Full model name
-        """
-        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
-        if isinstance(shorthand_value, str):
-            return shorthand_value
-        return model_name
-
    def _get_deployment_client(self, deployment: str):
        """Get or create a cached client for a specific deployment.

@@ -357,7 +438,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
        # Check model capabilities
        try:
            capabilities = self.get_capabilities(model_name)
-            supports_temperature = getattr(capabilities, "supports_temperature", True)
+            supports_temperature = capabilities.supports_temperature
        except Exception as e:
            logger.debug(f"Failed to check temperature support for {model_name}: {e}")
            supports_temperature = True

@@ -441,63 +522,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
        """
        resolved_name = self._resolve_model_name(model_name)

-        if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
-            return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
+        if resolved_name in self.SUPPORTED_MODELS:
+            return self.SUPPORTED_MODELS[resolved_name].supports_images

        # Fall back to parent implementation for unknown models
        return super()._supports_vision(model_name)

-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-
-        Returns:
-            List of model names available from this provider
-        """
-        # Get all model keys (both full names and aliases)
-        all_models = list(self.SUPPORTED_MODELS.keys())
-
-        if not respect_restrictions:
-            return all_models
-
-        # Apply restrictions if configured
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service()
-
-        # Filter based on restrictions
-        allowed_models = []
-        for model in all_models:
-            resolved_name = self._resolve_model_name(model)
-            if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
-                allowed_models.append(model)
-
-        return allowed_models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        This is used for validation purposes to ensure restriction policies
-        can validate against both aliases and their target model names.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        # Collect all unique model names (both aliases and targets)
-        all_models = set()
-
-        for key, value in self.SUPPORTED_MODELS.items():
-            # Add the key (could be alias or full name)
-            all_models.add(key)
-
-            # If it's an alias (string value), add the target too
-            if isinstance(value, str):
-                all_models.add(value)
-
-        return sorted(all_models)
-
    def close(self):
        """Clean up HTTP clients when provider is closed."""
        logger.info("Closing DIAL provider HTTP clients...")
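With `SUPPORTED_MODELS` now holding `ModelCapabilities` objects, `get_capabilities()` collapses into a dictionary lookup, and aliases travel with the capability object instead of living as string-valued dict entries next to the real models. A rough sketch of the resulting lookup shape, with a stand-in dataclass in place of the real `ModelCapabilities`:

```python
# Illustrative sketch only: Caps stands in for providers.base.ModelCapabilities,
# and the resolve() helper approximates alias resolution over the aliases field.
from dataclasses import dataclass, field


@dataclass
class Caps:
    model_name: str
    context_window: int
    supports_images: bool
    aliases: list = field(default_factory=list)


SUPPORTED_MODELS = {
    "o3-2025-04-16": Caps("o3-2025-04-16", 200_000, True, ["o3"]),
    "anthropic.claude-opus-4-20250514-v1:0": Caps(
        "anthropic.claude-opus-4-20250514-v1:0", 200_000, True, ["opus-4"]
    ),
}


def resolve(name: str) -> str:
    # Aliases now live on the capability object, not as string dict entries
    for canonical, caps in SUPPORTED_MODELS.items():
        if name == canonical or name in caps.aliases:
            return canonical
    return name


def get_capabilities(name: str) -> Caps:
    # The refactor makes this a plain lookup: no per-call object rebuilding
    return SUPPORTED_MODELS[resolve(name)]


assert get_capabilities("o3").context_window == 200_000
assert get_capabilities("opus-4").supports_images is True
```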
@@ -9,7 +9,7 @@ from typing import Optional
from google import genai
from google.genai import types

-from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
+from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint

logger = logging.getLogger(__name__)

@@ -17,47 +17,83 @@ logger = logging.getLogger(__name__)
class GeminiModelProvider(ModelProvider):
    """Google Gemini model provider implementation."""

-    # Model configurations
+    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
-        "gemini-2.0-flash": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": True,  # Experimental thinking mode
-            "max_thinking_tokens": 24576,  # Same as 2.5 flash for consistency
-            "supports_images": True,  # Vision capability
-            "max_image_size_mb": 20.0,  # Conservative 20MB limit for reliability
-            "description": "Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
-        },
-        "gemini-2.0-flash-lite": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": False,  # Not supported per user request
-            "max_thinking_tokens": 0,  # No thinking support
-            "supports_images": False,  # Does not support images
-            "max_image_size_mb": 0.0,  # No image support
-            "description": "Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
-        },
-        "gemini-2.5-flash": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": True,
-            "max_thinking_tokens": 24576,  # Flash 2.5 thinking budget limit
-            "supports_images": True,  # Vision capability
-            "max_image_size_mb": 20.0,  # Conservative 20MB limit for reliability
-            "description": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
-        },
-        "gemini-2.5-pro": {
-            "context_window": 1_048_576,  # 1M tokens
-            "supports_extended_thinking": True,
-            "max_thinking_tokens": 32768,  # Pro 2.5 thinking budget limit
-            "supports_images": True,  # Vision capability
-            "max_image_size_mb": 32.0,  # Higher limit for Pro model
-            "description": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
-        },
-        # Shorthands
-        "flash": "gemini-2.5-flash",
-        "flash-2.0": "gemini-2.0-flash",
-        "flash2": "gemini-2.0-flash",
-        "flashlite": "gemini-2.0-flash-lite",
-        "flash-lite": "gemini-2.0-flash-lite",
-        "pro": "gemini-2.5-pro",
+        "gemini-2.0-flash": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.0-flash",
+            friendly_name="Gemini (Flash 2.0)",
+            context_window=1_048_576,  # 1M tokens
+            max_output_tokens=65_536,
+            supports_extended_thinking=True,  # Experimental thinking mode
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # Vision capability
+            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            max_thinking_tokens=24576,  # Same as 2.5 flash for consistency
+            description="Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
+            aliases=["flash-2.0", "flash2"],
+        ),
+        "gemini-2.0-flash-lite": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.0-flash-lite",
+            friendly_name="Gemini (Flash Lite 2.0)",
+            context_window=1_048_576,  # 1M tokens
+            max_output_tokens=65_536,
+            supports_extended_thinking=False,  # Not supported per user request
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=False,  # Does not support images
+            max_image_size_mb=0.0,  # No image support
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
+            aliases=["flashlite", "flash-lite"],
+        ),
+        "gemini-2.5-flash": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.5-flash",
+            friendly_name="Gemini (Flash 2.5)",
+            context_window=1_048_576,  # 1M tokens
+            max_output_tokens=65_536,
+            supports_extended_thinking=True,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # Vision capability
+            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            max_thinking_tokens=24576,  # Flash 2.5 thinking budget limit
+            description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
+            aliases=["flash", "flash2.5"],
+        ),
+        "gemini-2.5-pro": ModelCapabilities(
+            provider=ProviderType.GOOGLE,
+            model_name="gemini-2.5-pro",
+            friendly_name="Gemini (Pro 2.5)",
+            context_window=1_048_576,  # 1M tokens
+            max_output_tokens=65_536,
+            supports_extended_thinking=True,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # Vision capability
+            max_image_size_mb=32.0,  # Higher limit for Pro model
+            supports_temperature=True,
+            temperature_constraint=create_temperature_constraint("range"),
+            max_thinking_tokens=32768,  # Max thinking tokens for Pro model
+            description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
+            aliases=["pro", "gemini pro", "gemini-pro"],
+        ),
    }

    # Thinking mode configurations - percentages of model's max_thinking_tokens

@@ -70,6 +106,14 @@ class GeminiModelProvider(ModelProvider):
        "max": 1.0,  # 100% of max - full thinking budget
    }

+    # Model-specific thinking token limits
+    MAX_THINKING_TOKENS = {
+        "gemini-2.0-flash": 24576,  # Same as 2.5 flash for consistency
+        "gemini-2.0-flash-lite": 0,  # No thinking support
+        "gemini-2.5-flash": 24576,  # Flash 2.5 thinking budget limit
+        "gemini-2.5-pro": 32768,  # Pro 2.5 thinking budget limit
+    }
+
    def __init__(self, api_key: str, **kwargs):
        """Initialize Gemini provider with API key."""
        super().__init__(api_key, **kwargs)

@@ -100,25 +144,8 @@ class GeminiModelProvider(ModelProvider):
        if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
            raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")

-        config = self.SUPPORTED_MODELS[resolved_name]
-
-        # Gemini models support 0.0-2.0 temperature range
-        temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
-
-        return ModelCapabilities(
-            provider=ProviderType.GOOGLE,
-            model_name=resolved_name,
-            friendly_name="Gemini",
-            context_window=config["context_window"],
-            supports_extended_thinking=config["supports_extended_thinking"],
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=True,
-            supports_images=config.get("supports_images", False),
-            max_image_size_mb=config.get("max_image_size_mb", 0.0),
-            supports_temperature=True,  # Gemini models accept temperature parameter
-            temperature_constraint=temp_constraint,
-        )
+        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+        return self.SUPPORTED_MODELS[resolved_name]

    def generate_content(
        self,

@@ -179,8 +206,8 @@ class GeminiModelProvider(ModelProvider):
        if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
            # Get model's max thinking tokens and calculate actual budget
            model_config = self.SUPPORTED_MODELS.get(resolved_name)
-            if model_config and "max_thinking_tokens" in model_config:
-                max_thinking_tokens = model_config["max_thinking_tokens"]
+            if model_config and model_config.max_thinking_tokens > 0:
+                max_thinking_tokens = model_config.max_thinking_tokens
                actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
                generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget)

@@ -258,7 +285,7 @@ class GeminiModelProvider(ModelProvider):
        resolved_name = self._resolve_model_name(model_name)

        # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Then check if model is allowed by restrictions

@@ -281,78 +308,20 @@ class GeminiModelProvider(ModelProvider):
    def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
        """Get actual thinking token budget for a model and thinking mode."""
        resolved_name = self._resolve_model_name(model_name)
-        model_config = self.SUPPORTED_MODELS.get(resolved_name, {})
+        model_config = self.SUPPORTED_MODELS.get(resolved_name)

-        if not model_config.get("supports_extended_thinking", False):
+        if not model_config or not model_config.supports_extended_thinking:
            return 0

        if thinking_mode not in self.THINKING_BUDGETS:
            return 0

-        max_thinking_tokens = model_config.get("max_thinking_tokens", 0)
+        max_thinking_tokens = model_config.max_thinking_tokens
        if max_thinking_tokens == 0:
            return 0

        return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])

-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-
-        Returns:
-            List of model names available from this provider
-        """
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Handle both base models (dict configs) and aliases (string values)
-            if isinstance(config, str):
-                # This is an alias - check if the target model would be allowed
-                target_model = config
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
-                    continue
-                # Allow the alias
-                models.append(model_name)
-            else:
-                # This is a base model with config dict
-                # Check restrictions if enabled
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                    continue
-                models.append(model_name)
-
-        return models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Add the model name itself
-            all_models.add(model_name.lower())
-
-            # If it's an alias (string value), add the target model too
-            if isinstance(config, str):
-                all_models.add(config.lower())
-
-        return list(all_models)
-
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name."""
-        # Check if it's a shorthand
-        shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
-        if isinstance(shorthand_value, str):
-            return shorthand_value
-        return model_name
-
    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from Gemini response."""
        usage = {}
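The thinking budget is still `max_thinking_tokens` scaled by the `THINKING_BUDGETS` fraction for the requested mode; only the lookup changed, from dict keys to attributes. Note the hunk also adds a class-level `MAX_THINKING_TOKENS` table even though the same limits are stored on each `ModelCapabilities` entry; the runtime paths shown in the diff read the attribute. A worked example of the arithmetic (the `"max": 1.0` entry and per-model limits are copied from the diff; the lower tiers exist in the real table but are not shown in this excerpt):

```python
# Worked example of the thinking-budget arithmetic above. Only the "max" tier
# is reproduced here; the real THINKING_BUDGETS table has more modes.
THINKING_BUDGETS = {"max": 1.0}  # "max" = 100% of the model's limit
MAX_THINKING_TOKENS = {
    "gemini-2.0-flash": 24576,
    "gemini-2.0-flash-lite": 0,  # no thinking support
    "gemini-2.5-flash": 24576,
    "gemini-2.5-pro": 32768,
}


def get_thinking_budget(model: str, mode: str) -> int:
    # Zero limit or unknown mode means no thinking tokens are allocated
    max_tokens = MAX_THINKING_TOKENS.get(model, 0)
    if max_tokens == 0 or mode not in THINKING_BUDGETS:
        return 0
    return int(max_tokens * THINKING_BUDGETS[mode])


assert get_thinking_budget("gemini-2.5-pro", "max") == 32768
assert get_thinking_budget("gemini-2.0-flash-lite", "max") == 0
```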
@@ -686,7 +686,6 @@ class OpenAICompatibleProvider(ModelProvider):
            "o3-mini",
            "o3-pro",
            "o4-mini",
-            "o4-mini-high",
            # Note: Claude models would be handled by a separate provider
        }
        supports = model_name.lower() in vision_models
@@ -17,71 +17,98 @@ logger = logging.getLogger(__name__)
class OpenAIModelProvider(OpenAICompatibleProvider):
    """Official OpenAI API provider (api.openai.com)."""

-    # Model configurations
+    # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
-        "o3": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O3 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
-        },
-        "o3-mini": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O3 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
-        },
-        "o3-pro-2025-06-10": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O3 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O3 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
-        },
-        # Aliases
-        "o3-pro": "o3-pro-2025-06-10",
-        "o4-mini": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O4 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O4 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
-        },
-        "o4-mini-high": {
-            "context_window": 200_000,  # 200K tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # O4 models support vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": False,  # O4 models don't accept temperature parameter
-            "temperature_constraint": "fixed",  # Fixed at 1.0
-            "description": "Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
-        },
-        "gpt-4.1-2025-04-14": {
-            "context_window": 1_000_000,  # 1M tokens
-            "supports_extended_thinking": False,
-            "supports_images": True,  # GPT-4.1 supports vision
-            "max_image_size_mb": 20.0,  # 20MB per OpenAI docs
-            "supports_temperature": True,  # Regular models accept temperature parameter
-            "temperature_constraint": "range",  # 0.0-2.0 range
-            "description": "GPT-4.1 (1M context) - Advanced reasoning model with large context window",
-        },
-        # Shorthands
-        "mini": "o4-mini",  # Default 'mini' to latest mini model
-        "o3mini": "o3-mini",
-        "o4mini": "o4-mini",
-        "o4minihigh": "o4-mini-high",
-        "o4minihi": "o4-mini-high",
-        "gpt4.1": "gpt-4.1-2025-04-14",
+        "o3": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3",
+            friendly_name="OpenAI (O3)",
+            context_window=200_000,  # 200K tokens
+            max_output_tokens=65536,  # 64K max output tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
+            aliases=[],
+        ),
+        "o3-mini": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3-mini",
+            friendly_name="OpenAI (O3-mini)",
+            context_window=200_000,  # 200K tokens
+            max_output_tokens=65536,  # 64K max output tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
+            aliases=["o3mini", "o3-mini"],
+        ),
+        "o3-pro-2025-06-10": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o3-pro-2025-06-10",
+            friendly_name="OpenAI (O3-Pro)",
+            context_window=200_000,  # 200K tokens
+            max_output_tokens=65536,  # 64K max output tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O3 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O3 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
+            aliases=["o3-pro"],
+        ),
+        "o4-mini": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="o4-mini",
+            friendly_name="OpenAI (O4-mini)",
+            context_window=200_000,  # 200K tokens
+            max_output_tokens=65536,  # 64K max output tokens
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # O4 models support vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=False,  # O4 models don't accept temperature parameter
+            temperature_constraint=create_temperature_constraint("fixed"),
+            description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
+            aliases=["mini", "o4mini", "o4-mini"],
+        ),
+        "gpt-4.1-2025-04-14": ModelCapabilities(
+            provider=ProviderType.OPENAI,
+            model_name="gpt-4.1-2025-04-14",
+            friendly_name="OpenAI (GPT 4.1)",
+            context_window=1_000_000,  # 1M tokens
+            max_output_tokens=32_768,
+            supports_extended_thinking=False,
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            supports_json_mode=True,
+            supports_images=True,  # GPT-4.1 supports vision
+            max_image_size_mb=20.0,  # 20MB per OpenAI docs
+            supports_temperature=True,  # Regular models accept temperature parameter
+            temperature_constraint=create_temperature_constraint("range"),
+            description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
+            aliases=["gpt4.1"],
+        ),
    }

    def __init__(self, api_key: str, **kwargs):

@@ -95,7 +122,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        # Resolve shorthand
        resolved_name = self._resolve_model_name(model_name)

-        if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
+        if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported OpenAI model: {model_name}")

        # Check if model is allowed by restrictions

@@ -105,27 +132,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
            raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")

-        config = self.SUPPORTED_MODELS[resolved_name]
-
-        # Get temperature constraints and support from configuration
-        supports_temperature = config.get("supports_temperature", True)  # Default to True for backward compatibility
-        temp_constraint_type = config.get("temperature_constraint", "range")  # Default to range
-        temp_constraint = create_temperature_constraint(temp_constraint_type)
-
-        return ModelCapabilities(
-            provider=ProviderType.OPENAI,
-            model_name=model_name,
-            friendly_name="OpenAI",
-            context_window=config["context_window"],
-            supports_extended_thinking=config["supports_extended_thinking"],
-            supports_system_prompts=True,
-            supports_streaming=True,
-            supports_function_calling=True,
-            supports_images=config.get("supports_images", False),
-            max_image_size_mb=config.get("max_image_size_mb", 0.0),
-            supports_temperature=supports_temperature,
-            temperature_constraint=temp_constraint,
-        )
+        # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+        return self.SUPPORTED_MODELS[resolved_name]

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""

@@ -136,7 +144,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        resolved_name = self._resolve_model_name(model_name)

        # First check if model is supported
-        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+        if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Then check if model is allowed by restrictions

@@ -177,61 +185,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
        # Currently no OpenAI models support extended thinking
        # This may change with future O3 models
        return False
-
-    def list_models(self, respect_restrictions: bool = True) -> list[str]:
-        """Return a list of model names supported by this provider.
-
-        Args:
-            respect_restrictions: Whether to apply provider-specific restriction logic.
-
-        Returns:
-            List of model names available from this provider
-        """
-        from utils.model_restrictions import get_restriction_service
-
-        restriction_service = get_restriction_service() if respect_restrictions else None
-        models = []
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Handle both base models (dict configs) and aliases (string values)
-            if isinstance(config, str):
-                # This is an alias - check if the target model would be allowed
-                target_model = config
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
-                    continue
-                # Allow the alias
-                models.append(model_name)
-            else:
-                # This is a base model with config dict
-                # Check restrictions if enabled
-                if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                    continue
-                models.append(model_name)
-
-        return models
-
-    def list_all_known_models(self) -> list[str]:
-        """Return all model names known by this provider, including alias targets.
-
-        Returns:
-            List of all model names and alias targets known by this provider
-        """
-        all_models = set()
-
-        for model_name, config in self.SUPPORTED_MODELS.items():
-            # Add the model name itself
-            all_models.add(model_name.lower())
-
-            # If it's an alias (string value), add the target model too
-            if isinstance(config, str):
-                all_models.add(config.lower())
-
-        return list(all_models)
-
-    def _resolve_model_name(self, model_name: str) -> str:
-        """Resolve model shorthand to full name."""
-        # Check if it's a shorthand
-        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
-        if isinstance(shorthand_value, str):
-            return shorthand_value
-        return model_name
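Each provider's private `_resolve_model_name` shorthand table is deleted in these hunks, which only works if alias resolution now happens in shared code driven by the `aliases` lists on `ModelCapabilities`. That shared implementation is outside this excerpt, so the following is a hypothetical sketch of what alias-aware resolution over `get_model_configurations()` could look like, not the repository's actual method:

```python
# Hypothetical sketch only: the shared alias-aware resolver implied by these
# deletions lives in code outside this excerpt, so names here are assumptions.
from dataclasses import dataclass, field


@dataclass
class Caps:  # stand-in for ModelCapabilities
    aliases: list = field(default_factory=list)


def resolve_model_name(configs: dict, model_name: str) -> str:
    # Exact canonical name wins first
    if model_name in configs:
        return model_name
    # Otherwise scan the aliases attached to each capability entry
    wanted = model_name.lower()
    for canonical, caps in configs.items():
        if any(alias.lower() == wanted for alias in caps.aliases):
            return canonical
    # Unknown names pass through unchanged, as the old base behavior did
    return model_name


configs = {"gpt-4.1-2025-04-14": Caps(aliases=["gpt4.1"])}
assert resolve_model_name(configs, "gpt4.1") == "gpt-4.1-2025-04-14"
assert resolve_model_name(configs, "unknown-model") == "unknown-model"
```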
@@ -50,14 +50,6 @@ class OpenRouterProvider(OpenAICompatibleProvider):
        aliases = self._registry.list_aliases()
        logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")

-    def _parse_allowed_models(self) -> None:
-        """Override to disable environment-based allow-list.
-
-        OpenRouter model access is controlled via the OpenRouter dashboard,
-        not through environment variables.
-        """
-        return None
-
    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model aliases to OpenRouter model names.

@@ -109,6 +101,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
            model_name=resolved_name,
            friendly_name=self.FRIENDLY_NAME,
            context_window=32_768,  # Conservative default context window
+            max_output_tokens=32_768,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,

@@ -130,16 +123,34 @@ class OpenRouterProvider(OpenAICompatibleProvider):

        As the catch-all provider, OpenRouter accepts any model name that wasn't
        handled by higher-priority providers. OpenRouter will validate based on
-        the API key's permissions.
+        the API key's permissions and local restrictions.

        Args:
            model_name: Model name to validate

        Returns:
-            Always True - OpenRouter is the catch-all provider
+            True if model is allowed, False if restricted
        """
-        # Accept any model name - OpenRouter is the fallback provider
-        # Higher priority providers (native APIs, custom endpoints) get first chance
+        # Check model restrictions if configured
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if restriction_service:
+            # Check if model name itself is allowed
+            if restriction_service.is_allowed(self.get_provider_type(), model_name):
+                return True
+
+            # Also check aliases - model_name might be an alias
+            model_config = self._registry.resolve(model_name)
+            if model_config and model_config.aliases:
+                for alias in model_config.aliases:
+                    if restriction_service.is_allowed(self.get_provider_type(), alias):
+                        return True
+
+            # If restrictions are configured and model/alias not in allowed list, reject
+            return False
+
+        # No restrictions configured - accept any model name as the fallback provider
        return True

    def generate_content(

@@ -260,3 +271,35 @@ class OpenRouterProvider(OpenAICompatibleProvider):
                all_models.add(config.model_name.lower())

        return list(all_models)
+
+    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
+        """Get model configurations from the registry.
+
+        For OpenRouter, we convert registry configurations to ModelCapabilities objects.
+
+        Returns:
+            Dictionary mapping model names to their ModelCapabilities objects
+        """
+        configs = {}
+
+        if self._registry:
+            # Get all models from registry
+            for model_name in self._registry.list_models():
+                # Only include models that this provider validates
+                if self.validate_model_name(model_name):
+                    config = self._registry.resolve(model_name)
+                    if config and not config.is_custom:  # Only OpenRouter models, not custom ones
+                        # Use ModelCapabilities directly from registry
+                        configs[model_name] = config
+
+        return configs
+
+    def get_all_model_aliases(self) -> dict[str, list[str]]:
+        """Get all model aliases from the registry.
+
+        Returns:
+            Dictionary mapping model names to their list of aliases
+        """
+        # Since aliases are now included in the configurations,
+        # we can use the base class implementation
+        return super().get_all_model_aliases()
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -11,58 +10,10 @@ from utils.file_utils import read_json_file
|
|||||||
from .base import (
|
from .base import (
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
TemperatureConstraint,
|
|
||||||
create_temperature_constraint,
|
create_temperature_constraint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OpenRouterModelConfig:
|
|
||||||
"""Configuration for an OpenRouter model."""
|
|
||||||
|
|
||||||
model_name: str
|
|
||||||
aliases: list[str] = field(default_factory=list)
|
|
||||||
context_window: int = 32768 # Total context window size in tokens
|
|
||||||
supports_extended_thinking: bool = False
|
|
||||||
supports_system_prompts: bool = True
|
|
||||||
supports_streaming: bool = True
|
|
||||||
supports_function_calling: bool = False
|
|
||||||
supports_json_mode: bool = False
|
|
||||||
supports_images: bool = False # Whether model can process images
|
|
||||||
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
|
|
||||||
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
|
|
||||||
temperature_constraint: Optional[str] = (
|
|
||||||
None # Type of temperature constraint: "fixed", "range", "discrete", or None for default range
|
|
||||||
)
|
|
||||||
is_custom: bool = False # True for models that should only be used with custom endpoints
|
|
||||||
description: str = ""
|
|
||||||
|
|
||||||
def _create_temperature_constraint(self) -> TemperatureConstraint:
|
|
||||||
"""Create temperature constraint object from configuration.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TemperatureConstraint object based on configuration
|
|
||||||
"""
|
|
||||||
return create_temperature_constraint(self.temperature_constraint or "range")
|
|
||||||
|
|
||||||
def to_capabilities(self) -> ModelCapabilities:
|
|
||||||
"""Convert to ModelCapabilities object."""
|
|
||||||
return ModelCapabilities(
|
|
||||||
provider=ProviderType.OPENROUTER,
|
|
||||||
model_name=self.model_name,
|
|
||||||
friendly_name="OpenRouter",
|
|
||||||
context_window=self.context_window,
|
|
||||||
supports_extended_thinking=self.supports_extended_thinking,
|
|
||||||
supports_system_prompts=self.supports_system_prompts,
|
|
||||||
supports_streaming=self.supports_streaming,
|
|
||||||
supports_function_calling=self.supports_function_calling,
|
|
||||||
supports_images=self.supports_images,
|
|
||||||
max_image_size_mb=self.max_image_size_mb,
|
|
||||||
supports_temperature=self.supports_temperature,
|
|
||||||
temperature_constraint=self._create_temperature_constraint(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterModelRegistry:
|
class OpenRouterModelRegistry:
|
||||||
"""Registry for managing OpenRouter model configurations and aliases."""
|
"""Registry for managing OpenRouter model configurations and aliases."""
|
||||||
|
|
||||||
@@ -73,7 +24,7 @@ class OpenRouterModelRegistry:
|
|||||||
config_path: Path to config file. If None, uses default locations.
|
config_path: Path to config file. If None, uses default locations.
|
||||||
"""
|
"""
|
||||||
self.alias_map: dict[str, str] = {} # alias -> model_name
|
self.alias_map: dict[str, str] = {} # alias -> model_name
|
||||||
self.model_map: dict[str, OpenRouterModelConfig] = {} # model_name -> config
|
self.model_map: dict[str, ModelCapabilities] = {} # model_name -> config
|
||||||
|
|
||||||
# Determine config path
|
# Determine config path
|
||||||
if config_path:
|
if config_path:
|
||||||
@@ -139,7 +90,7 @@ class OpenRouterModelRegistry:
|
|||||||
self.alias_map = {}
|
self.alias_map = {}
|
||||||
self.model_map = {}
|
self.model_map = {}
|
||||||
|
|
||||||
def _read_config(self) -> list[OpenRouterModelConfig]:
|
def _read_config(self) -> list[ModelCapabilities]:
|
||||||
"""Read configuration from file.
|
"""Read configuration from file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -158,7 +109,27 @@ class OpenRouterModelRegistry:
|
|||||||
# Parse models
|
# Parse models
|
||||||
configs = []
|
configs = []
|
||||||
for model_data in data.get("models", []):
|
for model_data in data.get("models", []):
|
||||||
config = OpenRouterModelConfig(**model_data)
|
# Create ModelCapabilities directly from JSON data
|
||||||
|
# Handle temperature_constraint conversion
|
||||||
|
temp_constraint_str = model_data.get("temperature_constraint")
|
||||||
|
temp_constraint = create_temperature_constraint(temp_constraint_str or "range")
|
||||||
|
|
||||||
|
# Set provider-specific defaults based on is_custom flag
|
||||||
|
is_custom = model_data.get("is_custom", False)
|
||||||
|
if is_custom:
|
||||||
|
model_data.setdefault("provider", ProviderType.CUSTOM)
|
||||||
|
model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
|
||||||
|
else:
|
||||||
|
model_data.setdefault("provider", ProviderType.OPENROUTER)
|
||||||
|
model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")
|
||||||
|
model_data["temperature_constraint"] = temp_constraint
|
||||||
|
|
||||||
|
# Remove the string version of temperature_constraint before creating ModelCapabilities
|
||||||
|
if "temperature_constraint" in model_data and isinstance(model_data["temperature_constraint"], str):
|
||||||
|
del model_data["temperature_constraint"]
|
||||||
|
model_data["temperature_constraint"] = temp_constraint
|
||||||
|
|
||||||
|
config = ModelCapabilities(**model_data)
|
||||||
configs.append(config)
|
configs.append(config)
|
||||||
|
|
||||||
return configs
|
return configs
|
||||||
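For reference, `_read_config` now feeds each JSON entry straight into `ModelCapabilities`. A hypothetical entry and the defaulting it would receive (field names come from the parser above; the concrete values are made up for illustration):

```python
# Hypothetical entry from the models JSON file (values illustrative only)
model_data = {
    "model_name": "llama3.2",
    "aliases": ["local-llama"],
    "context_window": 128_000,
    "is_custom": True,  # steers defaults toward ProviderType.CUSTOM
    "temperature_constraint": "range",  # string form, swapped for a constraint object
}

# Defaulting logic mirrored from the hunk above
is_custom = model_data.get("is_custom", False)
provider = "custom" if is_custom else "openrouter"
friendly_name = f"Custom ({model_data['model_name']})" if is_custom else f"OpenRouter ({model_data['model_name']})"
print(provider, friendly_name)  # custom Custom (llama3.2)
```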
@@ -168,7 +139,7 @@ class OpenRouterModelRegistry:
        except Exception as e:
            raise ValueError(f"Error reading config from {self.config_path}: {e}")

-   def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
+   def _build_maps(self, configs: list[ModelCapabilities]) -> None:
        """Build alias and model maps from configurations.

        Args:
@@ -211,7 +182,7 @@ class OpenRouterModelRegistry:
        self.alias_map = alias_map
        self.model_map = model_map

-   def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
+   def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
        """Resolve a model name or alias to configuration.

        Args:
@@ -237,10 +208,8 @@ class OpenRouterModelRegistry:
        Returns:
            ModelCapabilities if found, None otherwise
        """
-       config = self.resolve(name_or_alias)
-       if config:
-           return config.to_capabilities()
-       return None
+       # Registry now returns ModelCapabilities directly
+       return self.resolve(name_or_alias)

    def list_models(self) -> list[str]:
        """List all available model names."""

@@ -24,8 +24,6 @@ class ModelProviderRegistry:
            cls._instance._providers = {}
            cls._instance._initialized_providers = {}
            logging.debug(f"REGISTRY: Created instance {cls._instance}")
-       else:
-           logging.debug(f"REGISTRY: Returning existing instance {cls._instance}")
        return cls._instance

    @classmethod
@@ -129,7 +127,6 @@ class ModelProviderRegistry:
        logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")

        for provider_type in PROVIDER_PRIORITY_ORDER:
-           logging.debug(f"Checking provider_type: {provider_type}")
            if provider_type in instance._providers:
                logging.debug(f"Found {provider_type} in registry")
                # Get or create provider instance

136 providers/xai.py

@@ -7,7 +7,7 @@ from .base import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
-   RangeTemperatureConstraint,
+   create_temperature_constraint,
 )
 from .openai_compatible import OpenAICompatibleProvider

@@ -19,23 +19,44 @@ class XAIModelProvider(OpenAICompatibleProvider):

    FRIENDLY_NAME = "X.AI"

-   # Model configurations
+   # Model configurations using ModelCapabilities objects
    SUPPORTED_MODELS = {
-       "grok-3": {
-           "context_window": 131_072,  # 131K tokens
-           "supports_extended_thinking": False,
-           "description": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
-       },
-       "grok-3-fast": {
-           "context_window": 131_072,  # 131K tokens
-           "supports_extended_thinking": False,
-           "description": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
-       },
-       # Shorthands for convenience
-       "grok": "grok-3",  # Default to grok-3
-       "grok3": "grok-3",
-       "grok3fast": "grok-3-fast",
-       "grokfast": "grok-3-fast",
+       "grok-3": ModelCapabilities(
+           provider=ProviderType.XAI,
+           model_name="grok-3",
+           friendly_name="X.AI (Grok 3)",
+           context_window=131_072,  # 131K tokens
+           max_output_tokens=131072,
+           supports_extended_thinking=False,
+           supports_system_prompts=True,
+           supports_streaming=True,
+           supports_function_calling=True,
+           supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
+           supports_images=False,  # Assuming GROK is text-only for now
+           max_image_size_mb=0.0,
+           supports_temperature=True,
+           temperature_constraint=create_temperature_constraint("range"),
+           description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
+           aliases=["grok", "grok3"],
+       ),
+       "grok-3-fast": ModelCapabilities(
+           provider=ProviderType.XAI,
+           model_name="grok-3-fast",
+           friendly_name="X.AI (Grok 3 Fast)",
+           context_window=131_072,  # 131K tokens
+           max_output_tokens=131072,
+           supports_extended_thinking=False,
+           supports_system_prompts=True,
+           supports_streaming=True,
+           supports_function_calling=True,
+           supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
+           supports_images=False,  # Assuming GROK is text-only for now
+           max_image_size_mb=0.0,
+           supports_temperature=True,
+           temperature_constraint=create_temperature_constraint("range"),
+           description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
+           aliases=["grok3fast", "grokfast", "grok3-fast"],
+       ),
    }

    def __init__(self, api_key: str, **kwargs):
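With the dict-and-string shorthand table gone, aliases now live on each `ModelCapabilities` entry. The updated `_resolve_model_name` for this layout is not shown in this hunk, so the resolver below is only a sketch of the lookup the new table implies:

```python
# Alias data as declared above; the resolver itself is an assumption
SUPPORTED_ALIASES = {
    "grok-3": ["grok", "grok3"],
    "grok-3-fast": ["grok3fast", "grokfast", "grok3-fast"],
}


def resolve_model_name(name: str) -> str:
    lowered = name.lower()
    for model_name, aliases in SUPPORTED_ALIASES.items():
        if lowered == model_name or lowered in aliases:
            return model_name
    return name  # unknown names pass through unchanged


assert resolve_model_name("grokfast") == "grok-3-fast"
assert resolve_model_name("grok") == "grok-3"
```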
@@ -49,7 +70,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
        # Resolve shorthand
        resolved_name = self._resolve_model_name(model_name)

-       if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
+       if resolved_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported X.AI model: {model_name}")

        # Check if model is allowed by restrictions
@@ -59,23 +80,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
        if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
            raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")

-       config = self.SUPPORTED_MODELS[resolved_name]
-
-       # Define temperature constraints for GROK models
-       # GROK supports the standard OpenAI temperature range
-       temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
-
-       return ModelCapabilities(
-           provider=ProviderType.XAI,
-           model_name=resolved_name,
-           friendly_name=self.FRIENDLY_NAME,
-           context_window=config["context_window"],
-           supports_extended_thinking=config["supports_extended_thinking"],
-           supports_system_prompts=True,
-           supports_streaming=True,
-           supports_function_calling=True,
-           temperature_constraint=temp_constraint,
-       )
+       # Return the ModelCapabilities object directly from SUPPORTED_MODELS
+       return self.SUPPORTED_MODELS[resolved_name]

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
@@ -86,7 +92,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
        resolved_name = self._resolve_model_name(model_name)

        # First check if model is supported
-       if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+       if resolved_name not in self.SUPPORTED_MODELS:
            return False

        # Then check if model is allowed by restrictions
@@ -127,61 +133,3 @@ class XAIModelProvider(OpenAICompatibleProvider):
        # Currently GROK models do not support extended thinking
        # This may change with future GROK model releases
        return False
-
-   def list_models(self, respect_restrictions: bool = True) -> list[str]:
-       """Return a list of model names supported by this provider.
-
-       Args:
-           respect_restrictions: Whether to apply provider-specific restriction logic.
-
-       Returns:
-           List of model names available from this provider
-       """
-       from utils.model_restrictions import get_restriction_service
-
-       restriction_service = get_restriction_service() if respect_restrictions else None
-       models = []
-
-       for model_name, config in self.SUPPORTED_MODELS.items():
-           # Handle both base models (dict configs) and aliases (string values)
-           if isinstance(config, str):
-               # This is an alias - check if the target model would be allowed
-               target_model = config
-               if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
-                   continue
-               # Allow the alias
-               models.append(model_name)
-           else:
-               # This is a base model with config dict
-               # Check restrictions if enabled
-               if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
-                   continue
-               models.append(model_name)
-
-       return models
-
-   def list_all_known_models(self) -> list[str]:
-       """Return all model names known by this provider, including alias targets.
-
-       Returns:
-           List of all model names and alias targets known by this provider
-       """
-       all_models = set()
-
-       for model_name, config in self.SUPPORTED_MODELS.items():
-           # Add the model name itself
-           all_models.add(model_name.lower())
-
-           # If it's an alias (string value), add the target model too
-           if isinstance(config, str):
-               all_models.add(config.lower())
-
-       return list(all_models)
-
-   def _resolve_model_name(self, model_name: str) -> str:
-       """Resolve model shorthand to full name."""
-       # Check if it's a shorthand
-       shorthand_value = self.SUPPORTED_MODELS.get(model_name)
-       if isinstance(shorthand_value, str):
-           return shorthand_value
-       return model_name

108 server.py

@@ -158,6 +158,97 @@ logger = logging.getLogger(__name__)
 # This name is used by MCP clients to identify and connect to this specific server
 server: Server = Server("zen-server")


+# Constants for tool filtering
+ESSENTIAL_TOOLS = {"version", "listmodels"}
+
+
+def parse_disabled_tools_env() -> set[str]:
+    """
+    Parse the DISABLED_TOOLS environment variable into a set of tool names.
+
+    Returns:
+        Set of lowercase tool names to disable, empty set if none specified
+    """
+    disabled_tools_env = os.getenv("DISABLED_TOOLS", "").strip()
+    if not disabled_tools_env:
+        return set()
+    return {t.strip().lower() for t in disabled_tools_env.split(",") if t.strip()}
+
+
+def validate_disabled_tools(disabled_tools: set[str], all_tools: dict[str, Any]) -> None:
+    """
+    Validate the disabled tools list and log appropriate warnings.
+
+    Args:
+        disabled_tools: Set of tool names requested to be disabled
+        all_tools: Dictionary of all available tool instances
+    """
+    essential_disabled = disabled_tools & ESSENTIAL_TOOLS
+    if essential_disabled:
+        logger.warning(f"Cannot disable essential tools: {sorted(essential_disabled)}")
+    unknown_tools = disabled_tools - set(all_tools.keys())
+    if unknown_tools:
+        logger.warning(f"Unknown tools in DISABLED_TOOLS: {sorted(unknown_tools)}")
+
+
+def apply_tool_filter(all_tools: dict[str, Any], disabled_tools: set[str]) -> dict[str, Any]:
+    """
+    Apply the disabled tools filter to create the final tools dictionary.
+
+    Args:
+        all_tools: Dictionary of all available tool instances
+        disabled_tools: Set of tool names to disable
+
+    Returns:
+        Dictionary containing only enabled tools
+    """
+    enabled_tools = {}
+    for tool_name, tool_instance in all_tools.items():
+        if tool_name in ESSENTIAL_TOOLS or tool_name not in disabled_tools:
+            enabled_tools[tool_name] = tool_instance
+        else:
+            logger.debug(f"Tool '{tool_name}' disabled via DISABLED_TOOLS")
+    return enabled_tools
+
+
+def log_tool_configuration(disabled_tools: set[str], enabled_tools: dict[str, Any]) -> None:
+    """
+    Log the final tool configuration for visibility.
+
+    Args:
+        disabled_tools: Set of tool names that were requested to be disabled
+        enabled_tools: Dictionary of tools that remain enabled
+    """
+    if not disabled_tools:
+        logger.info("All tools enabled (DISABLED_TOOLS not set)")
+        return
+    actual_disabled = disabled_tools - ESSENTIAL_TOOLS
+    if actual_disabled:
+        logger.debug(f"Disabled tools: {sorted(actual_disabled)}")
+    logger.info(f"Active tools: {sorted(enabled_tools.keys())}")
+
+
+def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]:
+    """
+    Filter tools based on DISABLED_TOOLS environment variable.
+
+    Args:
+        all_tools: Dictionary of all available tool instances
+
+    Returns:
+        dict: Filtered dictionary containing only enabled tools
+    """
+    disabled_tools = parse_disabled_tools_env()
+    if not disabled_tools:
+        log_tool_configuration(disabled_tools, all_tools)
+        return all_tools
+    validate_disabled_tools(disabled_tools, all_tools)
+    enabled_tools = apply_tool_filter(all_tools, disabled_tools)
+    log_tool_configuration(disabled_tools, enabled_tools)
+    return enabled_tools
+
+
 # Initialize the tool registry with all available AI-powered tools
 # Each tool provides specialized functionality for different development tasks
 # Tools are instantiated once and reused across requests (stateless design)
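Taken together, these helpers make the tool registry configurable at startup. A quick illustration of the filter semantics (tool instances faked with plain objects; assumes `server.py` as patched above is importable):

```python
import os

os.environ["DISABLED_TOOLS"] = "debug, analyze"

from server import filter_disabled_tools  # helper added in this diff

tools = {name: object() for name in ["chat", "debug", "analyze", "version", "listmodels"]}
enabled = filter_disabled_tools(tools)

# Essential tools ("version", "listmodels") always survive; the rest honor the env var.
print(sorted(enabled))  # ['chat', 'listmodels', 'version']
```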
@@ -178,6 +269,7 @@ TOOLS = {
    "listmodels": ListModelsTool(),  # List all available AI models by provider
    "version": VersionTool(),  # Display server version and system information
 }
+TOOLS = filter_disabled_tools(TOOLS)

 # Rich prompt templates for all tools
 PROMPT_TEMPLATES = {
@@ -673,6 +765,11 @@ def parse_model_option(model_string: str) -> tuple[str, Optional[str]]:
    """
    Parse model:option format into model name and option.

+   Handles different formats:
+   - OpenRouter models: preserve :free, :beta, :preview suffixes as part of model name
+   - Ollama/Custom models: split on : to extract tags like :latest
+   - Consensus stance: extract options like :for, :against
+
    Args:
        model_string: String that may contain "model:option" format

@@ -680,6 +777,17 @@ def parse_model_option(model_string: str) -> tuple[str, Optional[str]]:
        tuple: (model_name, option) where option may be None
    """
    if ":" in model_string and not model_string.startswith("http"):  # Avoid parsing URLs
+       # Check if this looks like an OpenRouter model (contains /)
+       if "/" in model_string and model_string.count(":") == 1:
+           # Could be openai/gpt-4:something - check what comes after colon
+           parts = model_string.split(":", 1)
+           suffix = parts[1].strip().lower()
+
+           # Known OpenRouter suffixes to preserve
+           if suffix in ["free", "beta", "preview"]:
+               return model_string.strip(), None
+
+       # For other patterns (Ollama tags, consensus stances), split normally
        parts = model_string.split(":", 1)
        model_name = parts[0].strip()
        model_option = parts[1].strip() if len(parts) > 1 else None
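The resulting behavior as input/output pairs, derived from the branches above (a sketch; only the listed suffixes are special-cased):

```python
# Expected (model_name, option) results for parse_model_option
expected = {
    "openai/gpt-4:free": ("openai/gpt-4:free", None),  # OpenRouter suffix kept on the name
    "llama3.2:latest": ("llama3.2", "latest"),  # Ollama-style tag split off
    "flash:against": ("flash", "against"),  # consensus stance extracted as the option
    "pro": ("pro", None),  # no colon, nothing to split
}
```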
@@ -182,6 +182,10 @@ class ConversationBaseTest(BaseSimulatorTest):

        # Look for continuation_id in various places
        if isinstance(response_data, dict):
+           # Check top-level continuation_id (workflow tools)
+           if "continuation_id" in response_data:
+               return response_data["continuation_id"]
+
            # Check metadata
            metadata = response_data.get("metadata", {})
            if "thread_id" in metadata:

@@ -91,11 +91,14 @@ class TestClass:
        response_a2, continuation_id_a2 = self.call_mcp_tool(
            "analyze",
            {
-               "prompt": "Now analyze the code quality and suggest improvements.",
-               "files": [test_file_path],
+               "step": "Now analyze the code quality and suggest improvements.",
+               "step_number": 1,
+               "total_steps": 2,
+               "next_step_required": False,
+               "findings": "Continuing analysis from previous chat conversation to analyze code quality.",
+               "relevant_files": [test_file_path],
                "continuation_id": continuation_id_a1,
                "model": "flash",
-               "temperature": 0.7,
            },
        )

@@ -154,10 +157,14 @@ class TestClass:
        response_b2, continuation_id_b2 = self.call_mcp_tool(
            "analyze",
            {
-               "prompt": "Analyze the previous greeting and suggest improvements.",
+               "step": "Analyze the previous greeting and suggest improvements.",
+               "step_number": 1,
+               "total_steps": 1,
+               "next_step_required": False,
+               "findings": "Analyzing the greeting from previous conversation and suggesting improvements.",
+               "relevant_files": [test_file_path],
                "continuation_id": continuation_id_b1,
                "model": "flash",
-               "temperature": 0.7,
            },
        )

@@ -206,11 +206,14 @@ if __name__ == "__main__":
        response2, continuation_id2 = self.call_mcp_tool(
            "analyze",
            {
-               "prompt": "Analyze the performance implications of these recursive functions.",
-               "files": [file1_path],
+               "step": "Analyze the performance implications of these recursive functions.",
+               "step_number": 1,
+               "total_steps": 1,
+               "next_step_required": False,
+               "findings": "Continuing from chat conversation to analyze performance implications of recursive functions.",
+               "relevant_files": [file1_path],
                "continuation_id": continuation_id1,  # Continue the chat conversation
                "model": "flash",
-               "temperature": 0.7,
            },
        )
@@ -221,10 +224,14 @@ if __name__ == "__main__":
        self.logger.info(f"   ✅ Step 2 completed with continuation_id: {continuation_id2[:8]}...")
        continuation_ids.append(continuation_id2)

-       # Validate that we got a different continuation ID
-       if continuation_id2 == continuation_id1:
-           self.logger.error("   ❌ Step 2: Got same continuation ID as Step 1 - continuation not working")
-           return False
+       # Validate continuation ID behavior for workflow tools
+       # Workflow tools reuse the same continuation_id when continuing within a workflow session
+       # This is expected behavior and different from simple tools
+       if continuation_id2 != continuation_id1:
+           self.logger.info("   ✅ Step 2: Got new continuation ID (workflow behavior)")
+       else:
+           self.logger.info("   ✅ Step 2: Reused continuation ID (workflow session continuation)")
+       # Both behaviors are valid - what matters is that we got a continuation_id

        # Validate that Step 2 is building on Step 1's conversation
        # Check if the response references the previous conversation
@@ -276,17 +283,16 @@ if __name__ == "__main__":
        all_have_continuation_ids = bool(continuation_id1 and continuation_id2 and continuation_id3)
        criteria.append(("All steps generated continuation IDs", all_have_continuation_ids))

-       # 3. Each continuation ID is unique
-       unique_continuation_ids = len(set(continuation_ids)) == len(continuation_ids)
-       criteria.append(("Each response generated unique continuation ID", unique_continuation_ids))
+       # 3. Continuation behavior validation (handles both simple and workflow tools)
+       # Simple tools create new IDs each time, workflow tools may reuse IDs within sessions
+       has_valid_continuation_pattern = len(continuation_ids) == 3
+       criteria.append(("Valid continuation ID pattern", has_valid_continuation_pattern))

-       # 4. Continuation IDs follow the expected pattern
-       step_ids_different = (
-           len(continuation_ids) == 3
-           and continuation_ids[0] != continuation_ids[1]
-           and continuation_ids[1] != continuation_ids[2]
-       )
-       criteria.append(("All continuation IDs are different", step_ids_different))
+       # 4. Check for conversation continuity (more important than ID uniqueness)
+       conversation_has_continuity = len(continuation_ids) == 3 and all(
+           cid is not None for cid in continuation_ids
+       )
+       criteria.append(("Conversation continuity maintained", conversation_has_continuity))

        # 5. Check responses build on each other (content validation)
        step1_has_function_analysis = "fibonacci" in response1.lower() or "factorial" in response1.lower()
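With the relaxed criterion, either ID pattern passes; a minimal sketch with hypothetical IDs:

```python
def continuity_ok(ids: list) -> bool:
    """Mirror of the new criterion: three non-None IDs, uniqueness not required."""
    return len(ids) == 3 and all(cid is not None for cid in ids)


workflow_style = ["abc123", "abc123", "abc123"]  # workflow tools may reuse a session ID
simple_style = ["abc123", "def456", "ghi789"]  # simple tools mint a new ID per response

assert continuity_ok(workflow_style) and continuity_ok(simple_style)
```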
@@ -15,6 +15,7 @@ def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576
        model_name=model_name,
        friendly_name="Gemini",
        context_window=context_window,
+       max_output_tokens=8192,
        supports_extended_thinking=False,
        supports_system_prompts=True,
        supports_streaming=True,

@@ -211,7 +211,7 @@ class TestAliasTargetRestrictions:
        # Verify the polymorphic method was called
        mock_provider.list_all_known_models.assert_called_once()

-   @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high"})  # Restrict to specific model
+   @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"})  # Restrict to specific model
    def test_complex_alias_chains_handled_correctly(self):
        """Test that complex alias chains are handled correctly in restrictions."""
        # Clear cached restriction service
@@ -221,12 +221,11 @@ class TestAliasTargetRestrictions:

        provider = OpenAIModelProvider(api_key="test-key")

-       # Only o4-mini-high should be allowed
-       assert provider.validate_model_name("o4-mini-high")
+       # Only o4-mini should be allowed
+       assert provider.validate_model_name("o4-mini")

        # Other models should be blocked
-       assert not provider.validate_model_name("o4-mini")
-       assert not provider.validate_model_name("mini")  # This resolves to o4-mini
+       assert not provider.validate_model_name("o3")
        assert not provider.validate_model_name("o3-mini")

    def test_critical_regression_validation_sees_alias_targets(self):
@@ -307,7 +306,7 @@ class TestAliasTargetRestrictions:
        it appear that target-based restrictions don't work.
        """
        # Test with a made-up restriction scenario
-       with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high,o3-mini"}):
+       with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,o3-mini"}):
            # Clear cached restriction service
            import utils.model_restrictions

@@ -318,7 +317,7 @@ class TestAliasTargetRestrictions:

            # These specific target models should be recognized as valid
            all_known = provider.list_all_known_models()
-           assert "o4-mini-high" in all_known, "Target model o4-mini-high should be known"
+           assert "o4-mini" in all_known, "Target model o4-mini should be known"
            assert "o3-mini" in all_known, "Target model o3-mini should be known"

            # Validation should not warn about these being unrecognized
@@ -329,11 +328,11 @@ class TestAliasTargetRestrictions:
                # Should not warn about our allowed models being unrecognized
                all_warnings = [str(call) for call in mock_logger.warning.call_args_list]
                for warning in all_warnings:
-                   assert "o4-mini-high" not in warning or "not a recognized" not in warning
+                   assert "o4-mini" not in warning or "not a recognized" not in warning
                    assert "o3-mini" not in warning or "not a recognized" not in warning

            # The restriction should actually work
-           assert provider.validate_model_name("o4-mini-high")
+           assert provider.validate_model_name("o4-mini")
            assert provider.validate_model_name("o3-mini")
-           assert not provider.validate_model_name("o4-mini")  # not in allowed list
+           assert not provider.validate_model_name("o3-pro")  # not in allowed list
            assert not provider.validate_model_name("o3")  # not in allowed list

@@ -59,12 +59,12 @@ class TestAutoMode:
                continue

            # Check that model has description
-           description = config.get("description", "")
+           description = config.description if hasattr(config, "description") else ""
            if description:
                models_with_descriptions[model_name] = description

        # Check all expected models are present with meaningful descriptions
-       expected_models = ["flash", "pro", "o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"]
+       expected_models = ["flash", "pro", "o3", "o3-mini", "o3-pro", "o4-mini"]
        for model in expected_models:
            # Model should exist somewhere in the providers
            # Note: Some models might not be available if API keys aren't configured

@@ -319,7 +319,18 @@ class TestAutoModeComprehensive:
                m
                for m in available_models
                if not m.startswith("gemini")
-               and m not in ["flash", "pro", "flash-2.0", "flash2", "flashlite", "flash-lite"]
+               and m
+               not in [
+                   "flash",
+                   "pro",
+                   "flash-2.0",
+                   "flash2",
+                   "flashlite",
+                   "flash-lite",
+                   "flash2.5",
+                   "gemini pro",
+                   "gemini-pro",
+               ]
            ]
            assert (
                len(non_gemini_models) == 0

@@ -70,7 +70,7 @@ class TestAutoModeCustomProviderOnly:
        }

        # Clear all other provider keys
-       clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]
+       clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]

        with patch.dict(os.environ, test_env, clear=False):
            # Ensure other provider keys are not set
@@ -109,7 +109,7 @@ class TestAutoModeCustomProviderOnly:

        with patch.dict(os.environ, test_env, clear=False):
            # Clear other provider keys
-           for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+           for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
                if key in os.environ:
                    del os.environ[key]

@@ -177,7 +177,7 @@ class TestAutoModeCustomProviderOnly:

        with patch.dict(os.environ, test_env, clear=False):
            # Clear other provider keys
-           for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+           for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
                if key in os.environ:
                    del os.environ[key]

@@ -118,7 +118,7 @@ class TestBuggyBehaviorPrevention:
        provider = OpenAIModelProvider(api_key="test-key")

        # Simulate a scenario where admin wants to restrict specific targets
-       with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini-high"}):
+       with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
            # Clear cached restriction service
            import utils.model_restrictions

@@ -126,19 +126,21 @@ class TestBuggyBehaviorPrevention:

            # These should work because they're explicitly allowed
            assert provider.validate_model_name("o3-mini")
-           assert provider.validate_model_name("o4-mini-high")
+           assert provider.validate_model_name("o4-mini")

            # These should be blocked
-           assert not provider.validate_model_name("o4-mini")  # Not in allowed list
+           assert not provider.validate_model_name("o3-pro")  # Not in allowed list
            assert not provider.validate_model_name("o3")  # Not in allowed list
-           assert not provider.validate_model_name("mini")  # Resolves to o4-mini, not allowed
+
+           # This should be ALLOWED because it resolves to o4-mini which is in the allowed list
+           assert provider.validate_model_name("mini")  # Resolves to o4-mini, which IS allowed

            # Verify our list_all_known_models includes the restricted models
            all_known = provider.list_all_known_models()
            assert "o3-mini" in all_known  # Should be known (and allowed)
-           assert "o4-mini-high" in all_known  # Should be known (and allowed)
-           assert "o4-mini" in all_known  # Should be known (but blocked)
-           assert "mini" in all_known  # Should be known (but blocked)
+           assert "o4-mini" in all_known  # Should be known (and allowed)
+           assert "o3-pro" in all_known  # Should be known (but blocked)
+           assert "mini" in all_known  # Should be known (and allowed since it resolves to o4-mini)

    def test_demonstration_of_old_vs_new_interface(self):
        """

@@ -506,17 +506,17 @@ class TestConversationFlow:
        mock_client = Mock()
        mock_storage.return_value = mock_client

-       # Start conversation with files
-       thread_id = create_thread("analyze", {"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]})
+       # Start conversation with files using a simple tool
+       thread_id = create_thread("chat", {"prompt": "Analyze this codebase", "files": ["/project/src/"]})

        # Turn 1: Claude provides context with multiple files
        initial_context = ThreadContext(
            thread_id=thread_id,
            created_at="2023-01-01T00:00:00Z",
            last_updated_at="2023-01-01T00:00:00Z",
-           tool_name="analyze",
+           tool_name="chat",
            turns=[],
-           initial_context={"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]},
+           initial_context={"prompt": "Analyze this codebase", "files": ["/project/src/"]},
        )
        mock_client.get.return_value = initial_context.model_dump_json()

@@ -45,10 +45,17 @@ class TestCustomProvider:

    def test_get_capabilities_from_registry(self):
        """Test get_capabilities returns registry capabilities when available."""
+       # Save original environment
+       original_env = os.environ.get("OPENROUTER_ALLOWED_MODELS")
+
+       try:
+           # Clear any restrictions
+           os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
+
            provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

-       # Test with a model that should be in the registry (OpenRouter model) and is allowed by restrictions
-       capabilities = provider.get_capabilities("o3")  # o3 is in OPENROUTER_ALLOWED_MODELS
+           # Test with a model that should be in the registry (OpenRouter model)
+           capabilities = provider.get_capabilities("o3")  # o3 is an OpenRouter model

            assert capabilities.provider == ProviderType.OPENROUTER  # o3 is an OpenRouter model (is_custom=false)
            assert capabilities.context_window > 0
@@ -58,6 +65,13 @@ class TestCustomProvider:
            assert capabilities.provider == ProviderType.CUSTOM  # local-llama has is_custom=true
            assert capabilities.context_window > 0

+       finally:
+           # Restore original environment
+           if original_env is None:
+               os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
+           else:
+               os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env
+
    def test_get_capabilities_generic_fallback(self):
        """Test get_capabilities returns generic capabilities for unknown models."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

@@ -84,7 +84,7 @@ class TestDIALProvider:
        # Test O3 capabilities
        capabilities = provider.get_capabilities("o3")
        assert capabilities.model_name == "o3-2025-04-16"
-       assert capabilities.friendly_name == "DIAL"
+       assert capabilities.friendly_name == "DIAL (O3)"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.DIAL
        assert capabilities.supports_images is True

140 tests/test_disabled_tools.py (new file)
@@ -0,0 +1,140 @@
+"""Tests for DISABLED_TOOLS environment variable functionality."""
+
+import logging
+import os
+from unittest.mock import patch
+
+import pytest
+
+from server import (
+    apply_tool_filter,
+    parse_disabled_tools_env,
+    validate_disabled_tools,
+)
+
+
+# Mock the tool classes since we're testing the filtering logic
+class MockTool:
+    def __init__(self, name):
+        self.name = name
+
+
+class TestDisabledTools:
+    """Test suite for DISABLED_TOOLS functionality."""
+
+    def test_parse_disabled_tools_empty(self):
+        """Empty string returns empty set (no tools disabled)."""
+        with patch.dict(os.environ, {"DISABLED_TOOLS": ""}):
+            assert parse_disabled_tools_env() == set()
+
+    def test_parse_disabled_tools_not_set(self):
+        """Unset variable returns empty set."""
+        with patch.dict(os.environ, {}, clear=True):
+            # Ensure DISABLED_TOOLS is not in environment
+            if "DISABLED_TOOLS" in os.environ:
+                del os.environ["DISABLED_TOOLS"]
+            assert parse_disabled_tools_env() == set()
+
+    def test_parse_disabled_tools_single(self):
+        """Single tool name parsed correctly."""
+        with patch.dict(os.environ, {"DISABLED_TOOLS": "debug"}):
+            assert parse_disabled_tools_env() == {"debug"}
+
+    def test_parse_disabled_tools_multiple(self):
+        """Multiple tools with spaces parsed correctly."""
+        with patch.dict(os.environ, {"DISABLED_TOOLS": "debug, analyze, refactor"}):
+            assert parse_disabled_tools_env() == {"debug", "analyze", "refactor"}
+
+    def test_parse_disabled_tools_extra_spaces(self):
+        """Extra spaces and empty items handled correctly."""
+        with patch.dict(os.environ, {"DISABLED_TOOLS": " debug , , analyze , "}):
+            assert parse_disabled_tools_env() == {"debug", "analyze"}
+
+    def test_parse_disabled_tools_duplicates(self):
+        """Duplicate entries handled correctly (set removes duplicates)."""
+        with patch.dict(os.environ, {"DISABLED_TOOLS": "debug,analyze,debug"}):
+            assert parse_disabled_tools_env() == {"debug", "analyze"}
+
+    def test_tool_filtering_logic(self):
+        """Test the complete filtering logic using the actual server functions."""
+        # Simulate ALL_TOOLS
+        ALL_TOOLS = {
+            "chat": MockTool("chat"),
+            "debug": MockTool("debug"),
+            "analyze": MockTool("analyze"),
+            "version": MockTool("version"),
+            "listmodels": MockTool("listmodels"),
+        }
+
+        # Test case 1: No tools disabled
+        disabled_tools = set()
+        enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
+
+        assert len(enabled_tools) == 5  # All tools included
+        assert set(enabled_tools.keys()) == set(ALL_TOOLS.keys())
+
+        # Test case 2: Disable some regular tools
+        disabled_tools = {"debug", "analyze"}
+        enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
+
+        assert len(enabled_tools) == 3  # chat, version, listmodels
+        assert "debug" not in enabled_tools
+        assert "analyze" not in enabled_tools
+        assert "chat" in enabled_tools
+        assert "version" in enabled_tools
+        assert "listmodels" in enabled_tools
+
+        # Test case 3: Attempt to disable essential tools
+        disabled_tools = {"version", "chat"}
+        enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
+
+        assert "version" in enabled_tools  # Essential tool not disabled
+        assert "chat" not in enabled_tools  # Regular tool disabled
+        assert "listmodels" in enabled_tools  # Essential tool included
+
+    def test_unknown_tools_warning(self, caplog):
+        """Test that unknown tool names generate appropriate warnings."""
+        ALL_TOOLS = {
+            "chat": MockTool("chat"),
+            "debug": MockTool("debug"),
+            "analyze": MockTool("analyze"),
+            "version": MockTool("version"),
+            "listmodels": MockTool("listmodels"),
+        }
+        disabled_tools = {"chat", "unknown_tool", "another_unknown"}
+
+        with caplog.at_level(logging.WARNING):
+            validate_disabled_tools(disabled_tools, ALL_TOOLS)
+        assert "Unknown tools in DISABLED_TOOLS: ['another_unknown', 'unknown_tool']" in caplog.text
+
+    def test_essential_tools_warning(self, caplog):
+        """Test warning when trying to disable essential tools."""
+        ALL_TOOLS = {
+            "chat": MockTool("chat"),
+            "debug": MockTool("debug"),
+            "analyze": MockTool("analyze"),
+            "version": MockTool("version"),
+            "listmodels": MockTool("listmodels"),
+        }
+        disabled_tools = {"version", "chat", "debug"}
+
+        with caplog.at_level(logging.WARNING):
+            validate_disabled_tools(disabled_tools, ALL_TOOLS)
+        assert "Cannot disable essential tools: ['version']" in caplog.text
+
+    @pytest.mark.parametrize(
+        "env_value,expected",
+        [
+            ("", set()),  # Empty string
+            ("   ", set()),  # Only spaces
+            (",,,", set()),  # Only commas
+            ("chat", {"chat"}),  # Single tool
+            ("chat,debug", {"chat", "debug"}),  # Multiple tools
+            ("chat, debug, analyze", {"chat", "debug", "analyze"}),  # With spaces
+            ("chat,debug,chat", {"chat", "debug"}),  # Duplicates
+        ],
+    )
+    def test_parse_disabled_tools_parametrized(self, env_value, expected):
+        """Parametrized tests for various input formats."""
+        with patch.dict(os.environ, {"DISABLED_TOOLS": env_value}):
+            assert parse_disabled_tools_env() == expected
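To exercise just this new suite, a direct invocation (equivalent to running `python -m pytest tests/test_disabled_tools.py -v` from the repository root):

```python
import pytest

pytest.main(["-v", "tests/test_disabled_tools.py"])
```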
@@ -483,14 +483,14 @@ class TestImageSupportIntegration:
|
|||||||
tool_name="chat",
|
tool_name="chat",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create child thread linked to parent
|
# Create child thread linked to parent using a simple tool
|
||||||
child_thread_id = create_thread("debug", {"child": "context"}, parent_thread_id=parent_thread_id)
|
child_thread_id = create_thread("chat", {"prompt": "child context"}, parent_thread_id=parent_thread_id)
|
||||||
add_turn(
|
add_turn(
|
||||||
thread_id=child_thread_id,
|
thread_id=child_thread_id,
|
||||||
role="user",
|
role="user",
|
||||||
content="Child thread with more images",
|
content="Child thread with more images",
|
||||||
images=["child1.png", "shared.png"], # shared.png appears again (should prioritize newer)
|
images=["child1.png", "shared.png"], # shared.png appears again (should prioritize newer)
|
||||||
tool_name="debug",
|
tool_name="chat",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock child thread context for get_thread call
|
# Mock child thread context for get_thread call
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ class TestModelEnumeration:
|
|||||||
("o3", False), # OpenAI - not available without API key
|
("o3", False), # OpenAI - not available without API key
|
||||||
("grok", False), # X.AI - not available without API key
|
("grok", False), # X.AI - not available without API key
|
||||||
("gemini-2.5-flash", False), # Full Gemini name - not available without API key
|
("gemini-2.5-flash", False), # Full Gemini name - not available without API key
|
||||||
("o4-mini-high", False), # OpenAI variant - not available without API key
|
("o4-mini", False), # OpenAI variant - not available without API key
|
||||||
("grok-3-fast", False), # X.AI variant - not available without API key
|
("grok-3-fast", False), # X.AI variant - not available without API key
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ class TestModelMetadataContinuation:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_multiple_turns_uses_last_assistant_model(self):
|
async def test_multiple_turns_uses_last_assistant_model(self):
|
||||||
"""Test that with multiple turns, the last assistant turn's model is used."""
|
"""Test that with multiple turns, the last assistant turn's model is used."""
|
||||||
thread_id = create_thread("analyze", {"prompt": "analyze this"})
|
thread_id = create_thread("chat", {"prompt": "analyze this"})
|
||||||
|
|
||||||
# Add multiple turns with different models
|
# Add multiple turns with different models
|
||||||
add_turn(thread_id, "assistant", "First response", model_name="gemini-2.5-flash", model_provider="google")
|
add_turn(thread_id, "assistant", "First response", model_name="gemini-2.5-flash", model_provider="google")
|
||||||
@@ -185,11 +185,11 @@ class TestModelMetadataContinuation:
     async def test_thread_chain_model_preservation(self):
         """Test model preservation across thread chains (parent-child relationships)."""
         # Create parent thread
-        parent_id = create_thread("analyze", {"prompt": "analyze"})
+        parent_id = create_thread("chat", {"prompt": "analyze"})
         add_turn(parent_id, "assistant", "Analysis", model_name="gemini-2.5-pro", model_provider="google")

-        # Create child thread
-        child_id = create_thread("codereview", {"prompt": "review"}, parent_thread_id=parent_id)
+        # Create child thread using a simple tool instead of workflow tool
+        child_id = create_thread("chat", {"prompt": "review"}, parent_thread_id=parent_id)

         # Child thread should be able to access parent's model through chain traversal
         # NOTE: Current implementation only checks current thread (not parent threads)
@@ -93,7 +93,7 @@ class TestModelRestrictionService:
         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
             service = ModelRestrictionService()

-            models = ["o3", "o3-mini", "o4-mini", "o4-mini-high"]
+            models = ["o3", "o3-mini", "o4-mini", "o3-pro"]
             filtered = service.filter_models(ProviderType.OPENAI, models)

             assert filtered == ["o3-mini", "o4-mini"]
@@ -573,7 +573,7 @@ class TestShorthandRestrictions:

         # Other models should not work
         assert not openai_provider.validate_model_name("o3")
-        assert not openai_provider.validate_model_name("o4-mini-high")
+        assert not openai_provider.validate_model_name("o3-pro")

     @patch.dict(
         os.environ,
@@ -185,7 +185,7 @@ class TestO3TemperatureParameterFixSimple:
         provider = OpenAIModelProvider(api_key="test-key")

         # Test O3/O4 models that should NOT support temperature parameter
-        o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"]
+        o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini"]

         for model in o3_o4_models:
             capabilities = provider.get_capabilities(model)
@@ -47,14 +47,13 @@ class TestOpenAIProvider:
         assert provider.validate_model_name("o3-mini") is True
         assert provider.validate_model_name("o3-pro") is True
         assert provider.validate_model_name("o4-mini") is True
-        assert provider.validate_model_name("o4-mini-high") is True
+        assert provider.validate_model_name("o4-mini") is True

         # Test valid aliases
         assert provider.validate_model_name("mini") is True
         assert provider.validate_model_name("o3mini") is True
         assert provider.validate_model_name("o4mini") is True
-        assert provider.validate_model_name("o4minihigh") is True
-        assert provider.validate_model_name("o4minihi") is True
+        assert provider.validate_model_name("o4mini") is True

         # Test invalid model
         assert provider.validate_model_name("invalid-model") is False
@@ -69,15 +68,14 @@ class TestOpenAIProvider:
         assert provider._resolve_model_name("mini") == "o4-mini"
         assert provider._resolve_model_name("o3mini") == "o3-mini"
         assert provider._resolve_model_name("o4mini") == "o4-mini"
-        assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
-        assert provider._resolve_model_name("o4minihi") == "o4-mini-high"
+        assert provider._resolve_model_name("o4mini") == "o4-mini"

         # Test full name passthrough
         assert provider._resolve_model_name("o3") == "o3"
         assert provider._resolve_model_name("o3-mini") == "o3-mini"
         assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
         assert provider._resolve_model_name("o4-mini") == "o4-mini"
-        assert provider._resolve_model_name("o4-mini-high") == "o4-mini-high"
+        assert provider._resolve_model_name("o4-mini") == "o4-mini"

     def test_get_capabilities_o3(self):
         """Test getting model capabilities for O3."""
@@ -85,7 +83,7 @@ class TestOpenAIProvider:

         capabilities = provider.get_capabilities("o3")
         assert capabilities.model_name == "o3"  # Should NOT be resolved in capabilities
-        assert capabilities.friendly_name == "OpenAI"
+        assert capabilities.friendly_name == "OpenAI (O3)"
         assert capabilities.context_window == 200_000
         assert capabilities.provider == ProviderType.OPENAI
         assert not capabilities.supports_extended_thinking
@@ -101,8 +99,8 @@ class TestOpenAIProvider:
         provider = OpenAIModelProvider("test-key")

         capabilities = provider.get_capabilities("mini")
-        assert capabilities.model_name == "mini"  # Capabilities should show original request
-        assert capabilities.friendly_name == "OpenAI"
+        assert capabilities.model_name == "o4-mini"  # Capabilities should show resolved model name
+        assert capabilities.friendly_name == "OpenAI (O4-mini)"
         assert capabilities.context_window == 200_000
         assert capabilities.provider == ProviderType.OPENAI

@@ -184,11 +182,11 @@ class TestOpenAIProvider:
         call_kwargs = mock_client.chat.completions.create.call_args[1]
         assert call_kwargs["model"] == "o3-mini"

-        # Test o4minihigh -> o4-mini-high
-        mock_response.model = "o4-mini-high"
-        provider.generate_content(prompt="Test", model_name="o4minihigh", temperature=1.0)
+        # Test o4mini -> o4-mini
+        mock_response.model = "o4-mini"
+        provider.generate_content(prompt="Test", model_name="o4mini", temperature=1.0)
         call_kwargs = mock_client.chat.completions.create.call_args[1]
-        assert call_kwargs["model"] == "o4-mini-high"
+        assert call_kwargs["model"] == "o4-mini"

     @patch("providers.openai_compatible.OpenAI")
     def test_generate_content_no_alias_passthrough(self, mock_openai_class):
@@ -57,7 +57,7 @@ class TestOpenRouterProvider:
         caps = provider.get_capabilities("o3")
         assert caps.provider == ProviderType.OPENROUTER
         assert caps.model_name == "openai/o3"  # Resolved name
-        assert caps.friendly_name == "OpenRouter"
+        assert caps.friendly_name == "OpenRouter (openai/o3)"

         # Test with a model not in registry - should get generic capabilities
         caps = provider.get_capabilities("unknown-model")
@@ -77,7 +77,7 @@ class TestOpenRouterProvider:
         assert provider._resolve_model_name("o3-mini") == "openai/o3-mini"
         assert provider._resolve_model_name("o3mini") == "openai/o3-mini"
         assert provider._resolve_model_name("o4-mini") == "openai/o4-mini"
-        assert provider._resolve_model_name("o4-mini-high") == "openai/o4-mini-high"
+        assert provider._resolve_model_name("o4-mini") == "openai/o4-mini"
         assert provider._resolve_model_name("claude") == "anthropic/claude-sonnet-4"
         assert provider._resolve_model_name("mistral") == "mistralai/mistral-large-2411"
         assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-r1-0528"
@@ -6,8 +6,8 @@ import tempfile

 import pytest

-from providers.base import ProviderType
-from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
+from providers.base import ModelCapabilities, ProviderType
+from providers.openrouter_registry import OpenRouterModelRegistry


 class TestOpenRouterModelRegistry:
@@ -24,7 +24,16 @@ class TestOpenRouterModelRegistry:
     def test_custom_config_path(self):
         """Test registry with custom config path."""
         # Create temporary config
-        config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}
+        config_data = {
+            "models": [
+                {
+                    "model_name": "test/model-1",
+                    "aliases": ["test1", "t1"],
+                    "context_window": 4096,
+                    "max_output_tokens": 2048,
+                }
+            ]
+        }

         with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
             json.dump(config_data, f)
@@ -42,7 +51,11 @@ class TestOpenRouterModelRegistry:
     def test_environment_variable_override(self):
         """Test OPENROUTER_MODELS_PATH environment variable."""
         # Create custom config
-        config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}
+        config_data = {
+            "models": [
+                {"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192, "max_output_tokens": 4096}
+            ]
+        }

         with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
             json.dump(config_data, f)
@@ -110,28 +123,29 @@ class TestOpenRouterModelRegistry:
         assert registry.resolve("non-existent") is None

     def test_model_capabilities_conversion(self):
-        """Test conversion to ModelCapabilities."""
+        """Test that registry returns ModelCapabilities directly."""
         registry = OpenRouterModelRegistry()

         config = registry.resolve("opus")
         assert config is not None

-        caps = config.to_capabilities()
-        assert caps.provider == ProviderType.OPENROUTER
-        assert caps.model_name == "anthropic/claude-opus-4"
-        assert caps.friendly_name == "OpenRouter"
-        assert caps.context_window == 200000
-        assert not caps.supports_extended_thinking
+        # Registry now returns ModelCapabilities objects directly
+        assert config.provider == ProviderType.OPENROUTER
+        assert config.model_name == "anthropic/claude-opus-4"
+        assert config.friendly_name == "OpenRouter (anthropic/claude-opus-4)"
+        assert config.context_window == 200000
+        assert not config.supports_extended_thinking

     def test_duplicate_alias_detection(self):
         """Test that duplicate aliases are detected."""
         config_data = {
             "models": [
-                {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
+                {"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096, "max_output_tokens": 2048},
                 {
                     "model_name": "test/model-2",
                     "aliases": ["DUPE"],  # Same alias, different case
                     "context_window": 8192,
+                    "max_output_tokens": 2048,
                 },
             ]
         }
@@ -199,19 +213,23 @@ class TestOpenRouterModelRegistry:

     def test_model_with_all_capabilities(self):
         """Test model with all capability flags."""
-        config = OpenRouterModelConfig(
+        from providers.base import create_temperature_constraint
+
+        caps = ModelCapabilities(
+            provider=ProviderType.OPENROUTER,
             model_name="test/full-featured",
+            friendly_name="OpenRouter (test/full-featured)",
             aliases=["full"],
             context_window=128000,
+            max_output_tokens=8192,
             supports_extended_thinking=True,
             supports_system_prompts=True,
             supports_streaming=True,
             supports_function_calling=True,
             supports_json_mode=True,
             description="Fully featured test model",
+            temperature_constraint=create_temperature_constraint("range"),
         )
-
-        caps = config.to_capabilities()
         assert caps.context_window == 128000
         assert caps.supports_extended_thinking
         assert caps.supports_system_prompts
79  tests/test_parse_model_option.py  Normal file
@@ -0,0 +1,79 @@
+"""Tests for parse_model_option function."""
+
+from server import parse_model_option
+
+
+class TestParseModelOption:
+    """Test cases for model option parsing."""
+
+    def test_openrouter_free_suffix_preserved(self):
+        """Test that OpenRouter :free suffix is preserved as part of model name."""
+        model, option = parse_model_option("openai/gpt-3.5-turbo:free")
+        assert model == "openai/gpt-3.5-turbo:free"
+        assert option is None
+
+    def test_openrouter_beta_suffix_preserved(self):
+        """Test that OpenRouter :beta suffix is preserved as part of model name."""
+        model, option = parse_model_option("anthropic/claude-3-opus:beta")
+        assert model == "anthropic/claude-3-opus:beta"
+        assert option is None
+
+    def test_openrouter_preview_suffix_preserved(self):
+        """Test that OpenRouter :preview suffix is preserved as part of model name."""
+        model, option = parse_model_option("google/gemini-pro:preview")
+        assert model == "google/gemini-pro:preview"
+        assert option is None
+
+    def test_ollama_tag_parsed_as_option(self):
+        """Test that Ollama tags are parsed as options."""
+        model, option = parse_model_option("llama3.2:latest")
+        assert model == "llama3.2"
+        assert option == "latest"
+
+    def test_consensus_stance_parsed_as_option(self):
+        """Test that consensus stances are parsed as options."""
+        model, option = parse_model_option("o3:for")
+        assert model == "o3"
+        assert option == "for"
+
+        model, option = parse_model_option("gemini-2.5-pro:against")
+        assert model == "gemini-2.5-pro"
+        assert option == "against"
+
+    def test_openrouter_unknown_suffix_parsed_as_option(self):
+        """Test that unknown suffixes on OpenRouter models are parsed as options."""
+        model, option = parse_model_option("openai/gpt-4:custom-tag")
+        assert model == "openai/gpt-4"
+        assert option == "custom-tag"
+
+    def test_plain_model_name(self):
+        """Test plain model names without colons."""
+        model, option = parse_model_option("gpt-4")
+        assert model == "gpt-4"
+        assert option is None
+
+    def test_url_not_parsed(self):
+        """Test that URLs are not parsed for options."""
+        model, option = parse_model_option("http://localhost:8080")
+        assert model == "http://localhost:8080"
+        assert option is None
+
+    def test_whitespace_handling(self):
+        """Test that whitespace is properly stripped."""
+        model, option = parse_model_option("  openai/gpt-3.5-turbo:free  ")
+        assert model == "openai/gpt-3.5-turbo:free"
+        assert option is None
+
+        model, option = parse_model_option("  llama3.2 : latest  ")
+        assert model == "llama3.2"
+        assert option == "latest"
+
+    def test_case_insensitive_suffix_matching(self):
+        """Test that OpenRouter suffix matching is case-insensitive."""
+        model, option = parse_model_option("openai/gpt-3.5-turbo:FREE")
+        assert model == "openai/gpt-3.5-turbo:FREE"  # Original case preserved
+        assert option is None
+
+        model, option = parse_model_option("openai/gpt-3.5-turbo:Free")
+        assert model == "openai/gpt-3.5-turbo:Free"  # Original case preserved
+        assert option is None
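For readers following along, here is a minimal sketch of the parsing behavior these new tests pin down. The real `parse_model_option` lives in `server.py` and may be structured differently; the suffix set and URL handling below are inferred from the tests above, not copied from the implementation.

```python
def parse_model_option(value: str) -> tuple[str, str | None]:
    """Split 'model:option' input, preserving URLs and OpenRouter suffixes.

    A sketch inferred from tests/test_parse_model_option.py; the shipped
    implementation in server.py may differ in details.
    """
    value = value.strip()
    # URLs contain colons but are never model:option pairs.
    if value.startswith(("http://", "https://")):
        return value, None
    if ":" not in value:
        return value, None
    model, _, option = value.partition(":")
    model, option = model.strip(), option.strip()
    # Known OpenRouter variant suffixes stay part of the model name
    # (matched case-insensitively, original casing preserved).
    if "/" in model and option.lower() in {"free", "beta", "preview"}:
        return value, None
    return model, option
```

Under these assumptions, `"llama3.2:latest"` yields `("llama3.2", "latest")` while `"openai/gpt-3.5-turbo:FREE"` is returned whole, which is exactly the split the consensus-stance and Ollama-tag tests rely on.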
@@ -58,7 +58,13 @@ class TestProviderRoutingBugs:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -66,6 +72,7 @@ class TestProviderRoutingBugs:
             os.environ.pop("GEMINI_API_KEY", None)  # No Google API key
             os.environ.pop("OPENAI_API_KEY", None)
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
             os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

             # Register only OpenRouter provider (like in server.py:configure_providers)
@@ -113,12 +120,24 @@ class TestProviderRoutingBugs:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
             # Set up scenario: NO API keys at all
-            for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+            for key in [
+                "GEMINI_API_KEY",
+                "OPENAI_API_KEY",
+                "XAI_API_KEY",
+                "OPENROUTER_API_KEY",
+                "OPENROUTER_ALLOWED_MODELS",
+            ]:
                 os.environ.pop(key, None)

             # Create tool to test fallback logic
@@ -151,7 +170,13 @@ class TestProviderRoutingBugs:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -160,6 +185,7 @@ class TestProviderRoutingBugs:
             os.environ["OPENAI_API_KEY"] = "test-openai-key"
             os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions

             # Register providers in priority order (like server.py)
             from providers.gemini import GeminiModelProvider
@@ -215,9 +215,7 @@ class TestOpenAIProvider:
         assert provider.validate_model_name("o3-mini")  # Backwards compatibility
         assert provider.validate_model_name("o4-mini")
         assert provider.validate_model_name("o4mini")
-        assert provider.validate_model_name("o4-mini-high")
-        assert provider.validate_model_name("o4minihigh")
-        assert provider.validate_model_name("o4minihi")
+        assert provider.validate_model_name("o4-mini")
         assert not provider.validate_model_name("gpt-4o")
         assert not provider.validate_model_name("invalid-model")

@@ -229,4 +227,4 @@ class TestOpenAIProvider:
         assert not provider.supports_thinking_mode("o3mini")
         assert not provider.supports_thinking_mode("o3-mini")
         assert not provider.supports_thinking_mode("o4-mini")
-        assert not provider.supports_thinking_mode("o4-mini-high")
+        assert not provider.supports_thinking_mode("o4-mini")
205  tests/test_supported_models_aliases.py  Normal file
@@ -0,0 +1,205 @@
+"""Test the SUPPORTED_MODELS aliases structure across all providers."""
+
+from providers.dial import DIALModelProvider
+from providers.gemini import GeminiModelProvider
+from providers.openai_provider import OpenAIModelProvider
+from providers.xai import XAIModelProvider
+
+
+class TestSupportedModelsAliases:
+    """Test that all providers have correctly structured SUPPORTED_MODELS with aliases."""
+
+    def test_gemini_provider_aliases(self):
+        """Test Gemini provider's alias structure."""
+        provider = GeminiModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "flash" in provider.SUPPORTED_MODELS["gemini-2.5-flash"].aliases
+        assert "pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro"].aliases
+        assert "flash-2.0" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
+        assert "flash2" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
+        assert "flashlite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
+        assert "flash-lite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("flash") == "gemini-2.5-flash"
+        assert provider._resolve_model_name("pro") == "gemini-2.5-pro"
+        assert provider._resolve_model_name("flash-2.0") == "gemini-2.0-flash"
+        assert provider._resolve_model_name("flash2") == "gemini-2.0-flash"
+        assert provider._resolve_model_name("flashlite") == "gemini-2.0-flash-lite"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("Flash") == "gemini-2.5-flash"
+        assert provider._resolve_model_name("PRO") == "gemini-2.5-pro"
+
+    def test_openai_provider_aliases(self):
+        """Test OpenAI provider's alias structure."""
+        provider = OpenAIModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
+        assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
+        assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases
+        assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases
+        assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
+        assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("mini") == "o4-mini"
+        assert provider._resolve_model_name("o3mini") == "o3-mini"
+        assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
+        assert provider._resolve_model_name("o4mini") == "o4-mini"
+        assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("Mini") == "o4-mini"
+        assert provider._resolve_model_name("O3MINI") == "o3-mini"
+
+    def test_xai_provider_aliases(self):
+        """Test XAI provider's alias structure."""
+        provider = XAIModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases
+        assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
+        assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
+        assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("grok") == "grok-3"
+        assert provider._resolve_model_name("grok3") == "grok-3"
+        assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
+        assert provider._resolve_model_name("grokfast") == "grok-3-fast"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("Grok") == "grok-3"
+        assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"
+
+    def test_dial_provider_aliases(self):
+        """Test DIAL provider's alias structure."""
+        provider = DIALModelProvider("test-key")
+
+        # Check that all models have ModelCapabilities with aliases
+        for model_name, config in provider.SUPPORTED_MODELS.items():
+            assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
+            assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
+
+        # Test specific aliases
+        assert "o3" in provider.SUPPORTED_MODELS["o3-2025-04-16"].aliases
+        assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini-2025-04-16"].aliases
+        assert "sonnet-4" in provider.SUPPORTED_MODELS["anthropic.claude-sonnet-4-20250514-v1:0"].aliases
+        assert "opus-4" in provider.SUPPORTED_MODELS["anthropic.claude-opus-4-20250514-v1:0"].aliases
+        assert "gemini-2.5-pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro-preview-05-06"].aliases
+
+        # Test alias resolution
+        assert provider._resolve_model_name("o3") == "o3-2025-04-16"
+        assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
+        assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
+        assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"
+
+        # Test case insensitive resolution
+        assert provider._resolve_model_name("O3") == "o3-2025-04-16"
+        assert provider._resolve_model_name("SONNET-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
+
+    def test_list_models_includes_aliases(self):
+        """Test that list_models returns both base models and aliases."""
+        # Test Gemini
+        gemini_provider = GeminiModelProvider("test-key")
+        gemini_models = gemini_provider.list_models(respect_restrictions=False)
+        assert "gemini-2.5-flash" in gemini_models
+        assert "flash" in gemini_models
+        assert "gemini-2.5-pro" in gemini_models
+        assert "pro" in gemini_models
+
+        # Test OpenAI
+        openai_provider = OpenAIModelProvider("test-key")
+        openai_models = openai_provider.list_models(respect_restrictions=False)
+        assert "o4-mini" in openai_models
+        assert "mini" in openai_models
+        assert "o3-mini" in openai_models
+        assert "o3mini" in openai_models
+
+        # Test XAI
+        xai_provider = XAIModelProvider("test-key")
+        xai_models = xai_provider.list_models(respect_restrictions=False)
+        assert "grok-3" in xai_models
+        assert "grok" in xai_models
+        assert "grok-3-fast" in xai_models
+        assert "grokfast" in xai_models
+
+        # Test DIAL
+        dial_provider = DIALModelProvider("test-key")
+        dial_models = dial_provider.list_models(respect_restrictions=False)
+        assert "o3-2025-04-16" in dial_models
+        assert "o3" in dial_models
+
+    def test_list_all_known_models_includes_aliases(self):
+        """Test that list_all_known_models returns all models and aliases in lowercase."""
+        # Test Gemini
+        gemini_provider = GeminiModelProvider("test-key")
+        gemini_all = gemini_provider.list_all_known_models()
+        assert "gemini-2.5-flash" in gemini_all
+        assert "flash" in gemini_all
+        assert "gemini-2.5-pro" in gemini_all
+        assert "pro" in gemini_all
+        # All should be lowercase
+        assert all(model == model.lower() for model in gemini_all)
+
+        # Test OpenAI
+        openai_provider = OpenAIModelProvider("test-key")
+        openai_all = openai_provider.list_all_known_models()
+        assert "o4-mini" in openai_all
+        assert "mini" in openai_all
+        assert "o3-mini" in openai_all
+        assert "o3mini" in openai_all
+        # All should be lowercase
+        assert all(model == model.lower() for model in openai_all)
+
+    def test_no_string_shorthand_in_supported_models(self):
+        """Test that no provider has string-based shorthands anymore."""
+        providers = [
+            GeminiModelProvider("test-key"),
+            OpenAIModelProvider("test-key"),
+            XAIModelProvider("test-key"),
+            DIALModelProvider("test-key"),
+        ]
+
+        for provider in providers:
+            for model_name, config in provider.SUPPORTED_MODELS.items():
+                # All values must be ModelCapabilities objects, not strings or dicts
+                from providers.base import ModelCapabilities
+
+                assert isinstance(config, ModelCapabilities), (
+                    f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
+                    f"must be a ModelCapabilities object, not {type(config).__name__}"
+                )
+
+    def test_resolve_returns_original_if_not_found(self):
+        """Test that _resolve_model_name returns original name if alias not found."""
+        providers = [
+            GeminiModelProvider("test-key"),
+            OpenAIModelProvider("test-key"),
+            XAIModelProvider("test-key"),
+            DIALModelProvider("test-key"),
+        ]
+
+        for provider in providers:
+            # Test with unknown model name
+            assert provider._resolve_model_name("unknown-model") == "unknown-model"
+            assert provider._resolve_model_name("gpt-4") == "gpt-4"
+            assert provider._resolve_model_name("claude-3") == "claude-3"
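The new test file above assumes every provider's `SUPPORTED_MODELS` now maps canonical model names to `ModelCapabilities` objects that carry their own alias lists, replacing the old string-shorthand entries. A minimal, self-contained sketch of that resolution pattern follows; the `ModelCapabilities` stand-in here is simplified and the real class in `providers.base` has many more fields.

```python
from dataclasses import dataclass, field


@dataclass
class ModelCapabilities:
    """Simplified stand-in for providers.base.ModelCapabilities."""
    model_name: str
    context_window: int
    aliases: list[str] = field(default_factory=list)


# Canonical names map to capability objects; aliases live on the object.
SUPPORTED_MODELS = {
    "grok-3": ModelCapabilities("grok-3", 131_072, aliases=["grok", "grok3"]),
    "grok-3-fast": ModelCapabilities("grok-3-fast", 131_072, aliases=["grok3fast", "grokfast"]),
}


def _resolve_model_name(name: str) -> str:
    """Map an alias to its canonical model name, case-insensitively."""
    lowered = name.lower()
    if lowered in SUPPORTED_MODELS:
        return lowered
    for canonical, caps in SUPPORTED_MODELS.items():
        if lowered in (alias.lower() for alias in caps.aliases):
            return canonical
    return name  # unknown names pass through unchanged, as the tests require
```

With this shape, `_resolve_model_name("GROKFAST")` returns `"grok-3-fast"` and `_resolve_model_name("unknown-model")` comes back untouched, matching the case-insensitivity and passthrough behavior the tests assert.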
@@ -48,7 +48,13 @@ class TestWorkflowMetadata:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -56,6 +62,7 @@ class TestWorkflowMetadata:
             os.environ.pop("GEMINI_API_KEY", None)
             os.environ.pop("OPENAI_API_KEY", None)
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
             os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

             # Register OpenRouter provider
@@ -124,7 +131,13 @@ class TestWorkflowMetadata:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -132,6 +145,7 @@ class TestWorkflowMetadata:
             os.environ.pop("GEMINI_API_KEY", None)
             os.environ.pop("OPENAI_API_KEY", None)
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
            os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

             # Register OpenRouter provider
@@ -182,6 +196,15 @@ class TestWorkflowMetadata:
         """
         Test that workflow tools handle metadata gracefully when model context is missing.
         """
+        # Save original environment
+        original_env = {}
+        for key in ["OPENROUTER_ALLOWED_MODELS"]:
+            original_env[key] = os.environ.get(key)
+
+        try:
+            # Clear any restrictions
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
+
         # Create debug tool
         debug_tool = DebugIssueTool()

@@ -220,6 +243,14 @@ class TestWorkflowMetadata:
         assert metadata["model_used"] == "flash", "model_used should be from request"
         assert metadata["provider_used"] == "unknown", "provider_used should be 'unknown' in fallback"

+        finally:
+            # Restore original environment
+            for key, value in original_env.items():
+                if value is None:
+                    os.environ.pop(key, None)
+                else:
+                    os.environ[key] = value
+
     @pytest.mark.no_mock_provider
     def test_workflow_metadata_preserves_existing_response_fields(self):
         """
@@ -227,7 +258,13 @@ class TestWorkflowMetadata:
         """
         # Save original environment
         original_env = {}
-        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+        for key in [
+            "GEMINI_API_KEY",
+            "OPENAI_API_KEY",
+            "XAI_API_KEY",
+            "OPENROUTER_API_KEY",
+            "OPENROUTER_ALLOWED_MODELS",
+        ]:
             original_env[key] = os.environ.get(key)

         try:
@@ -235,6 +272,7 @@ class TestWorkflowMetadata:
             os.environ.pop("GEMINI_API_KEY", None)
             os.environ.pop("OPENAI_API_KEY", None)
             os.environ.pop("XAI_API_KEY", None)
+            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)  # Clear any restrictions
             os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"

             # Register OpenRouter provider
@@ -77,7 +77,7 @@ class TestXAIProvider:

         capabilities = provider.get_capabilities("grok-3")
         assert capabilities.model_name == "grok-3"
-        assert capabilities.friendly_name == "X.AI"
+        assert capabilities.friendly_name == "X.AI (Grok 3)"
         assert capabilities.context_window == 131_072
         assert capabilities.provider == ProviderType.XAI
         assert not capabilities.supports_extended_thinking
@@ -96,7 +96,7 @@ class TestXAIProvider:

         capabilities = provider.get_capabilities("grok-3-fast")
         assert capabilities.model_name == "grok-3-fast"
-        assert capabilities.friendly_name == "X.AI"
+        assert capabilities.friendly_name == "X.AI (Grok 3 Fast)"
         assert capabilities.context_window == 131_072
         assert capabilities.provider == ProviderType.XAI
         assert not capabilities.supports_extended_thinking
@@ -212,31 +212,34 @@ class TestXAIProvider:
         assert provider.FRIENDLY_NAME == "X.AI"

         capabilities = provider.get_capabilities("grok-3")
-        assert capabilities.friendly_name == "X.AI"
+        assert capabilities.friendly_name == "X.AI (Grok 3)"

     def test_supported_models_structure(self):
         """Test that SUPPORTED_MODELS has the correct structure."""
         provider = XAIModelProvider("test-key")

-        # Check that all expected models are present
+        # Check that all expected base models are present
         assert "grok-3" in provider.SUPPORTED_MODELS
         assert "grok-3-fast" in provider.SUPPORTED_MODELS
-        assert "grok" in provider.SUPPORTED_MODELS
-        assert "grok3" in provider.SUPPORTED_MODELS
-        assert "grokfast" in provider.SUPPORTED_MODELS
-        assert "grok3fast" in provider.SUPPORTED_MODELS

         # Check model configs have required fields
-        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
-        assert isinstance(grok3_config, dict)
-        assert "context_window" in grok3_config
-        assert "supports_extended_thinking" in grok3_config
-        assert grok3_config["context_window"] == 131_072
-        assert grok3_config["supports_extended_thinking"] is False
+        from providers.base import ModelCapabilities

-        # Check shortcuts point to full names
-        assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
-        assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"
+        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
+        assert isinstance(grok3_config, ModelCapabilities)
+        assert hasattr(grok3_config, "context_window")
+        assert hasattr(grok3_config, "supports_extended_thinking")
+        assert hasattr(grok3_config, "aliases")
+        assert grok3_config.context_window == 131_072
+        assert grok3_config.supports_extended_thinking is False
+
+        # Check aliases are correctly structured
+        assert "grok" in grok3_config.aliases
+        assert "grok3" in grok3_config.aliases
+
+        grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
+        assert "grok3fast" in grok3fast_config.aliases
+        assert "grokfast" in grok3fast_config.aliases

     @patch("providers.openai_compatible.OpenAI")
     def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
@@ -99,15 +99,11 @@ class ListModelsTool(BaseTool):
                 output_lines.append("**Status**: Configured and available")
                 output_lines.append("\n**Models**:")

-                # Get models from the provider's SUPPORTED_MODELS
-                for model_name, config in provider.SUPPORTED_MODELS.items():
-                    # Skip alias entries (string values)
-                    if isinstance(config, str):
-                        continue
-
-                    # Get description and context from the model config
-                    description = config.get("description", "No description available")
-                    context_window = config.get("context_window", 0)
+                # Get models from the provider's model configurations
+                for model_name, capabilities in provider.get_model_configurations().items():
+                    # Get description and context from the ModelCapabilities object
+                    description = capabilities.description or "No description available"
+                    context_window = capabilities.context_window

                     # Format context window
                     if context_window >= 1_000_000:
@@ -133,13 +129,14 @@ class ListModelsTool(BaseTool):

                 # Show aliases for this provider
                 aliases = []
-                for alias_name, target in provider.SUPPORTED_MODELS.items():
-                    if isinstance(target, str):  # This is an alias
-                        aliases.append(f"- `{alias_name}` → `{target}`")
+                for model_name, capabilities in provider.get_model_configurations().items():
+                    if capabilities.aliases:
+                        for alias in capabilities.aliases:
+                            aliases.append(f"- `{alias}` → `{model_name}`")

                 if aliases:
                     output_lines.append("\n**Aliases**:")
-                    output_lines.extend(aliases)
+                    output_lines.extend(sorted(aliases))  # Sort for consistent output
             else:
                 output_lines.append(f"**Status**: Not configured (set {info['env_key']})")

@@ -237,7 +234,7 @@ class ListModelsTool(BaseTool):

         for alias in registry.list_aliases():
             config = registry.resolve(alias)
-            if config and hasattr(config, "is_custom") and config.is_custom:
+            if config and config.is_custom:
                 custom_models.append((alias, config))

         if custom_models:
@@ -256,8 +256,8 @@ class BaseTool(ABC):
             # Find all custom models (is_custom=true)
             for alias in registry.list_aliases():
                 config = registry.resolve(alias)
-                # Use hasattr for defensive programming - is_custom is optional with default False
-                if config and hasattr(config, "is_custom") and config.is_custom:
+                # Check if this is a custom model that requires custom endpoints
+                if config and config.is_custom:
                     if alias not in all_models:
                         all_models.append(alias)
         except Exception as e:
@@ -311,12 +311,16 @@ class BaseTool(ABC):
             ProviderType.GOOGLE: "Gemini models",
             ProviderType.OPENAI: "OpenAI models",
             ProviderType.XAI: "X.AI GROK models",
+            ProviderType.DIAL: "DIAL models",
             ProviderType.CUSTOM: "Custom models",
             ProviderType.OPENROUTER: "OpenRouter models",
         }

         # Check available providers and add their model descriptions
-        for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI]:
+
+        # Start with native providers
+        for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]:
+            # Only if this is registered / available
             provider = ModelProviderRegistry.get_provider(provider_type)
             if provider:
                 provider_section_added = False
@@ -324,13 +328,13 @@ class BaseTool(ABC):
                 try:
                     # Get model config to extract description
                     model_config = provider.SUPPORTED_MODELS.get(model_name)
-                    if isinstance(model_config, dict) and "description" in model_config:
+                    if model_config and model_config.description:
                         if not provider_section_added:
                             model_desc_parts.append(
                                 f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:"
                             )
                             provider_section_added = True
-                        model_desc_parts.append(f"- '{model_name}': {model_config['description']}")
+                        model_desc_parts.append(f"- '{model_name}': {model_config.description}")
                 except Exception:
                     # Skip models without descriptions
                     continue
@@ -346,8 +350,8 @@ class BaseTool(ABC):
             # Find all custom models (is_custom=true)
             for alias in registry.list_aliases():
                 config = registry.resolve(alias)
-                # Use hasattr for defensive programming - is_custom is optional with default False
-                if config and hasattr(config, "is_custom") and config.is_custom:
+                # Check if this is a custom model that requires custom endpoints
+                if config and config.is_custom:
                     # Format context window
                     context_tokens = config.context_window
                     if context_tokens >= 1_000_000:
@@ -128,6 +128,10 @@ class ModelRestrictionService:

         allowed_set = self.restrictions[provider_type]

+        if len(allowed_set) == 0:
+            # Empty set - allowed
+            return True
+
         # Check both the resolved name and original name (if different)
         names_to_check = {model_name.lower()}
         if original_name and original_name.lower() != model_name.lower():
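The final hunk changes `ModelRestrictionService` so that an empty allow-list means "no restriction" rather than "nothing allowed", which is why the test hunks above clear `OPENROUTER_ALLOWED_MODELS` before registering providers. A sketch of the surrounding check, assuming roughly this method shape in `utils/model_restrictions.py` (everything beyond the hunk itself is inferred):

```python
# Hypothetical shape of the restriction check; only the empty-set branch
# is taken verbatim from the hunk above.
def is_allowed(self, provider_type, model_name, original_name=None) -> bool:
    if provider_type not in self.restrictions:
        return True  # no restriction configured for this provider

    allowed_set = self.restrictions[provider_type]

    if len(allowed_set) == 0:
        # Empty set - allowed (the change this hunk introduces)
        return True

    # Check both the resolved name and original name (if different)
    names_to_check = {model_name.lower()}
    if original_name and original_name.lower() != model_name.lower():
        names_to_check.add(original_name.lower())
    return bool(names_to_check & allowed_set)
```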