Merge branch 'feat-local_support_with_UTF-8_encoding-update' of https://github.com/GiGiDKR/zen-mcp-server into feat-local_support_with_UTF-8_encoding-update
This commit is contained in:
38
CLAUDE.md
38
CLAUDE.md
@@ -128,7 +128,28 @@ python communication_simulator_test.py
|
||||
python communication_simulator_test.py --verbose
|
||||
```
|
||||
|
||||
#### Run Individual Simulator Tests (Recommended)
|
||||
#### Quick Test Mode (Recommended for Time-Limited Testing)
|
||||
```bash
|
||||
# Run quick test mode - 6 essential tests that provide maximum functionality coverage
|
||||
python communication_simulator_test.py --quick
|
||||
|
||||
# Run quick test mode with verbose output
|
||||
python communication_simulator_test.py --quick --verbose
|
||||
```
|
||||
|
||||
**Quick mode runs these 6 essential tests:**
|
||||
- `cross_tool_continuation` - Cross-tool conversation memory testing (chat, thinkdeep, codereview, analyze, debug)
|
||||
- `conversation_chain_validation` - Core conversation threading and memory validation
|
||||
- `consensus_workflow_accurate` - Consensus tool with flash model and stance testing
|
||||
- `codereview_validation` - CodeReview tool with flash model and multi-step workflows
|
||||
- `planner_validation` - Planner tool with flash model and complex planning workflows
|
||||
- `token_allocation_validation` - Token allocation and conversation history buildup testing
|
||||
|
||||
**Why these 6 tests:** They cover the core functionality including conversation memory (`utils/conversation_memory.py`), chat tool functionality, file processing and deduplication, model selection (flash/flashlite/o3), and cross-tool conversation workflows. These tests validate the most critical parts of the system in minimal time.
|
||||
|
||||
**Note:** Some workflow tools (analyze, codereview, planner, consensus, etc.) require specific workflow parameters and may need individual testing rather than quick mode testing.
|
||||
|
||||
#### Run Individual Simulator Tests (For Detailed Testing)
|
||||
```bash
|
||||
# List all available tests
|
||||
python communication_simulator_test.py --list-tests
|
||||
@@ -223,15 +244,17 @@ python -m pytest tests/ -v
|
||||
#### After Making Changes
|
||||
1. Run quality checks again: `./code_quality_checks.sh`
|
||||
2. Run integration tests locally: `./run_integration_tests.sh`
|
||||
3. Run relevant simulator tests: `python communication_simulator_test.py --individual <test_name>`
|
||||
4. Check logs for any issues: `tail -n 100 logs/mcp_server.log`
|
||||
5. Restart Claude session to use updated code
|
||||
3. Run quick test mode for fast validation: `python communication_simulator_test.py --quick`
|
||||
4. Run relevant specific simulator tests if needed: `python communication_simulator_test.py --individual <test_name>`
|
||||
5. Check logs for any issues: `tail -n 100 logs/mcp_server.log`
|
||||
6. Restart Claude session to use updated code
|
||||
|
||||
#### Before Committing/PR
|
||||
1. Final quality check: `./code_quality_checks.sh`
|
||||
2. Run integration tests: `./run_integration_tests.sh`
|
||||
3. Run full simulator test suite: `./run_integration_tests.sh --with-simulator`
|
||||
4. Verify all tests pass 100%
|
||||
3. Run quick test mode: `python communication_simulator_test.py --quick`
|
||||
4. Run full simulator test suite (optional): `./run_integration_tests.sh --with-simulator`
|
||||
5. Verify all tests pass 100%
|
||||
|
||||
### Common Troubleshooting
|
||||
|
||||
@@ -250,6 +273,9 @@ which python
|
||||
|
||||
#### Test Failures
|
||||
```bash
|
||||
# First try quick test mode to see if it's a general issue
|
||||
python communication_simulator_test.py --quick --verbose
|
||||
|
||||
# Run individual failing test with verbose output
|
||||
python communication_simulator_test.py --individual <test_name> --verbose
|
||||
|
||||
|
||||
@@ -409,7 +409,7 @@ for most debugging workflows, as Claude is usually able to confidently find the
|
||||
When in doubt, you can always follow up with a new prompt and ask Claude to share its findings with another model:
|
||||
|
||||
```text
|
||||
Use continuation with thinkdeep, share details with o4-mini-high to find out what the best fix is for this
|
||||
Use continuation with thinkdeep, share details with o4-mini to find out what the best fix is for this
|
||||
```
|
||||
|
||||
**[📖 Read More](docs/tools/debug.md)** - Step-by-step investigation methodology with workflow enforcement
|
||||
|
||||
@@ -38,6 +38,15 @@ Available tests:
|
||||
debug_validation - Debug tool validation with actual bugs
|
||||
conversation_chain_validation - Conversation chain continuity validation
|
||||
|
||||
Quick Test Mode (for time-limited testing):
|
||||
Use --quick to run the essential 6 tests that provide maximum coverage:
|
||||
- cross_tool_continuation (cross-tool conversation memory)
|
||||
- basic_conversation (basic chat functionality)
|
||||
- content_validation (content validation and deduplication)
|
||||
- model_thinking_config (flash/flashlite model testing)
|
||||
- o3_model_selection (o3 model selection testing)
|
||||
- per_tool_deduplication (file deduplication for individual tools)
|
||||
|
||||
Examples:
|
||||
# Run all tests
|
||||
python communication_simulator_test.py
|
||||
@@ -48,6 +57,9 @@ Examples:
|
||||
# Run a single test individually (with full standalone setup)
|
||||
python communication_simulator_test.py --individual content_validation
|
||||
|
||||
# Run quick test mode (essential 6 tests for time-limited testing)
|
||||
python communication_simulator_test.py --quick
|
||||
|
||||
# Force setup standalone server environment before running tests
|
||||
python communication_simulator_test.py --setup
|
||||
|
||||
@@ -68,21 +80,48 @@ class CommunicationSimulator:
|
||||
"""Simulates real-world Claude CLI communication with MCP Gemini server"""
|
||||
|
||||
def __init__(
|
||||
self, verbose: bool = False, keep_logs: bool = False, selected_tests: list[str] = None, setup: bool = False
|
||||
self,
|
||||
verbose: bool = False,
|
||||
keep_logs: bool = False,
|
||||
selected_tests: list[str] = None,
|
||||
setup: bool = False,
|
||||
quick_mode: bool = False,
|
||||
):
|
||||
self.verbose = verbose
|
||||
self.keep_logs = keep_logs
|
||||
self.selected_tests = selected_tests or []
|
||||
self.setup = setup
|
||||
self.quick_mode = quick_mode
|
||||
self.temp_dir = None
|
||||
self.server_process = None
|
||||
self.python_path = self._get_python_path()
|
||||
|
||||
# Configure logging first
|
||||
log_level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Import test registry
|
||||
from simulator_tests import TEST_REGISTRY
|
||||
|
||||
self.test_registry = TEST_REGISTRY
|
||||
|
||||
# Define quick mode tests (essential tests for time-limited testing)
|
||||
# Focus on tests that work with current tool configurations
|
||||
self.quick_mode_tests = [
|
||||
"cross_tool_continuation", # Cross-tool conversation memory
|
||||
"basic_conversation", # Basic chat functionality
|
||||
"content_validation", # Content validation and deduplication
|
||||
"model_thinking_config", # Flash/flashlite model testing
|
||||
"o3_model_selection", # O3 model selection testing
|
||||
"per_tool_deduplication", # File deduplication for individual tools
|
||||
]
|
||||
|
||||
# If quick mode is enabled, override selected_tests
|
||||
if self.quick_mode:
|
||||
self.selected_tests = self.quick_mode_tests
|
||||
self.logger.info(f"Quick mode enabled - running {len(self.quick_mode_tests)} essential tests")
|
||||
|
||||
# Available test methods mapping
|
||||
self.available_tests = {
|
||||
name: self._create_test_runner(test_class) for name, test_class in self.test_registry.items()
|
||||
@@ -91,11 +130,6 @@ class CommunicationSimulator:
|
||||
# Test result tracking
|
||||
self.test_results = dict.fromkeys(self.test_registry.keys(), False)
|
||||
|
||||
# Configure logging
|
||||
log_level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _get_python_path(self) -> str:
|
||||
"""Get the Python path for the virtual environment"""
|
||||
current_dir = os.getcwd()
|
||||
@@ -415,6 +449,9 @@ def parse_arguments():
|
||||
parser.add_argument("--tests", "-t", nargs="+", help="Specific tests to run (space-separated)")
|
||||
parser.add_argument("--list-tests", action="store_true", help="List available tests and exit")
|
||||
parser.add_argument("--individual", "-i", help="Run a single test individually")
|
||||
parser.add_argument(
|
||||
"--quick", "-q", action="store_true", help="Run quick test mode (6 essential tests for time-limited testing)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--setup", action="store_true", help="Force setup standalone server environment using run-server.sh"
|
||||
)
|
||||
@@ -492,7 +529,11 @@ def main():
|
||||
|
||||
# Initialize simulator consistently for all use cases
|
||||
simulator = CommunicationSimulator(
|
||||
verbose=args.verbose, keep_logs=args.keep_logs, selected_tests=args.tests, setup=args.setup
|
||||
verbose=args.verbose,
|
||||
keep_logs=args.keep_logs,
|
||||
selected_tests=args.tests,
|
||||
setup=args.setup,
|
||||
quick_mode=args.quick,
|
||||
)
|
||||
|
||||
# Determine execution mode and run
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
"model_name": "The model identifier - OpenRouter format (e.g., 'anthropic/claude-opus-4') or custom model name (e.g., 'llama3.2')",
|
||||
"aliases": "Array of short names users can type instead of the full model name",
|
||||
"context_window": "Total number of tokens the model can process (input + output combined)",
|
||||
"max_output_tokens": "Maximum number of tokens the model can generate in a single response",
|
||||
"supports_extended_thinking": "Whether the model supports extended reasoning tokens (currently none do via OpenRouter or custom APIs)",
|
||||
"supports_json_mode": "Whether the model can guarantee valid JSON output",
|
||||
"supports_function_calling": "Whether the model supports function/tool calling",
|
||||
@@ -36,6 +37,7 @@
|
||||
"model_name": "my-local-model",
|
||||
"aliases": ["shortname", "nickname", "abbrev"],
|
||||
"context_window": 128000,
|
||||
"max_output_tokens": 32768,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": true,
|
||||
@@ -52,6 +54,7 @@
|
||||
"model_name": "anthropic/claude-opus-4",
|
||||
"aliases": ["opus", "claude-opus", "claude4-opus", "claude-4-opus"],
|
||||
"context_window": 200000,
|
||||
"max_output_tokens": 64000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": false,
|
||||
"supports_function_calling": false,
|
||||
@@ -63,6 +66,7 @@
|
||||
"model_name": "anthropic/claude-sonnet-4",
|
||||
"aliases": ["sonnet", "claude-sonnet", "claude4-sonnet", "claude-4-sonnet", "claude"],
|
||||
"context_window": 200000,
|
||||
"max_output_tokens": 64000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": false,
|
||||
"supports_function_calling": false,
|
||||
@@ -74,6 +78,7 @@
|
||||
"model_name": "anthropic/claude-3.5-haiku",
|
||||
"aliases": ["haiku", "claude-haiku", "claude3-haiku", "claude-3-haiku"],
|
||||
"context_window": 200000,
|
||||
"max_output_tokens": 64000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": false,
|
||||
"supports_function_calling": false,
|
||||
@@ -85,6 +90,7 @@
|
||||
"model_name": "google/gemini-2.5-pro",
|
||||
"aliases": ["pro","gemini-pro", "gemini", "pro-openrouter"],
|
||||
"context_window": 1048576,
|
||||
"max_output_tokens": 65536,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": false,
|
||||
@@ -96,6 +102,7 @@
|
||||
"model_name": "google/gemini-2.5-flash",
|
||||
"aliases": ["flash","gemini-flash", "flash-openrouter", "flash-2.5"],
|
||||
"context_window": 1048576,
|
||||
"max_output_tokens": 65536,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": false,
|
||||
@@ -107,6 +114,7 @@
|
||||
"model_name": "mistralai/mistral-large-2411",
|
||||
"aliases": ["mistral-large", "mistral"],
|
||||
"context_window": 128000,
|
||||
"max_output_tokens": 32000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": true,
|
||||
@@ -118,6 +126,7 @@
|
||||
"model_name": "meta-llama/llama-3-70b",
|
||||
"aliases": ["llama", "llama3", "llama3-70b", "llama-70b", "llama3-openrouter"],
|
||||
"context_window": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": false,
|
||||
"supports_function_calling": false,
|
||||
@@ -129,6 +138,7 @@
|
||||
"model_name": "deepseek/deepseek-r1-0528",
|
||||
"aliases": ["deepseek-r1", "deepseek", "r1", "deepseek-thinking"],
|
||||
"context_window": 65536,
|
||||
"max_output_tokens": 32768,
|
||||
"supports_extended_thinking": true,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": false,
|
||||
@@ -140,6 +150,7 @@
|
||||
"model_name": "perplexity/llama-3-sonar-large-32k-online",
|
||||
"aliases": ["perplexity", "sonar", "perplexity-online"],
|
||||
"context_window": 32768,
|
||||
"max_output_tokens": 32768,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": false,
|
||||
"supports_function_calling": false,
|
||||
@@ -151,6 +162,7 @@
|
||||
"model_name": "openai/o3",
|
||||
"aliases": ["o3"],
|
||||
"context_window": 200000,
|
||||
"max_output_tokens": 100000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": true,
|
||||
@@ -164,6 +176,7 @@
|
||||
"model_name": "openai/o3-mini",
|
||||
"aliases": ["o3-mini", "o3mini"],
|
||||
"context_window": 200000,
|
||||
"max_output_tokens": 100000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": true,
|
||||
@@ -177,6 +190,7 @@
|
||||
"model_name": "openai/o3-mini-high",
|
||||
"aliases": ["o3-mini-high", "o3mini-high"],
|
||||
"context_window": 200000,
|
||||
"max_output_tokens": 100000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": true,
|
||||
@@ -190,6 +204,7 @@
|
||||
"model_name": "openai/o3-pro",
|
||||
"aliases": ["o3-pro", "o3pro"],
|
||||
"context_window": 200000,
|
||||
"max_output_tokens": 100000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": true,
|
||||
@@ -203,6 +218,7 @@
|
||||
"model_name": "openai/o4-mini",
|
||||
"aliases": ["o4-mini", "o4mini"],
|
||||
"context_window": 200000,
|
||||
"max_output_tokens": 100000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": true,
|
||||
@@ -212,23 +228,11 @@
|
||||
"temperature_constraint": "fixed",
|
||||
"description": "OpenAI's o4-mini model - optimized for shorter contexts with rapid reasoning and vision"
|
||||
},
|
||||
{
|
||||
"model_name": "openai/o4-mini-high",
|
||||
"aliases": ["o4-mini-high", "o4mini-high", "o4minihigh", "o4minihi"],
|
||||
"context_window": 200000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_images": true,
|
||||
"max_image_size_mb": 20.0,
|
||||
"supports_temperature": false,
|
||||
"temperature_constraint": "fixed",
|
||||
"description": "OpenAI's o4-mini with high reasoning effort - enhanced for complex tasks with vision"
|
||||
},
|
||||
{
|
||||
"model_name": "llama3.2",
|
||||
"aliases": ["local-llama", "local", "llama3.2", "ollama-llama"],
|
||||
"context_window": 128000,
|
||||
"max_output_tokens": 64000,
|
||||
"supports_extended_thinking": false,
|
||||
"supports_json_mode": false,
|
||||
"supports_function_calling": false,
|
||||
|
||||
@@ -14,7 +14,7 @@ import os
|
||||
# These values are used in server responses and for tracking releases
|
||||
# IMPORTANT: This is the single source of truth for version and author info
|
||||
# Semantic versioning: MAJOR.MINOR.PATCH
|
||||
__version__ = "5.6.1"
|
||||
__version__ = "5.7.0"
|
||||
# Last update date in ISO format
|
||||
__updated__ = "2025-06-23"
|
||||
# Primary maintainer
|
||||
|
||||
@@ -38,7 +38,6 @@ Regardless of your default configuration, you can specify models per request:
|
||||
| **`o3`** | OpenAI | 200K tokens | Strong logical reasoning | Debugging logic errors, systematic analysis |
|
||||
| **`o3-mini`** | OpenAI | 200K tokens | Balanced speed/quality | Moderate complexity tasks |
|
||||
| **`o4-mini`** | OpenAI | 200K tokens | Latest reasoning model | Optimized for shorter contexts |
|
||||
| **`o4-mini-high`** | OpenAI | 200K tokens | Enhanced reasoning | Complex tasks requiring deeper analysis |
|
||||
| **`gpt4.1`** | OpenAI | 1M tokens | Latest GPT-4 with extended context | Large codebase analysis, comprehensive reviews |
|
||||
| **`llama`** (Llama 3.2) | Custom/Local | 128K tokens | Local inference, privacy | On-device analysis, cost-free processing |
|
||||
| **Any model** | OpenRouter | Varies | Access to GPT-4, Claude, Llama, etc. | User-specified or based on task requirements |
|
||||
@@ -69,7 +68,7 @@ OPENAI_ALLOWED_MODELS=o4-mini
|
||||
|
||||
# High-performance: Quality over cost
|
||||
GOOGLE_ALLOWED_MODELS=pro
|
||||
OPENAI_ALLOWED_MODELS=o3,o4-mini-high
|
||||
OPENAI_ALLOWED_MODELS=o3,o4-mini
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
@@ -144,7 +143,7 @@ All tools that work with files support **both individual files and entire direct
|
||||
**`analyze`** - Analyze files or directories
|
||||
- `files`: List of file paths or directories (required)
|
||||
- `question`: What to analyze (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `analysis_type`: architecture|performance|security|quality|general
|
||||
- `output_format`: summary|detailed|actionable
|
||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||
@@ -159,7 +158,7 @@ All tools that work with files support **both individual files and entire direct
|
||||
|
||||
**`codereview`** - Review code files or directories
|
||||
- `files`: List of file paths or directories (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `review_type`: full|security|performance|quick
|
||||
- `focus_on`: Specific aspects to focus on
|
||||
- `standards`: Coding standards to enforce
|
||||
@@ -175,7 +174,7 @@ All tools that work with files support **both individual files and entire direct
|
||||
|
||||
**`debug`** - Debug with file context
|
||||
- `error_description`: Description of the issue (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `error_context`: Stack trace or logs
|
||||
- `files`: Files or directories related to the issue
|
||||
- `runtime_info`: Environment details
|
||||
@@ -191,7 +190,7 @@ All tools that work with files support **both individual files and entire direct
|
||||
|
||||
**`thinkdeep`** - Extended analysis with file context
|
||||
- `current_analysis`: Your current thinking (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `problem_context`: Additional context
|
||||
- `focus_areas`: Specific aspects to focus on
|
||||
- `files`: Files or directories for context
|
||||
@@ -207,7 +206,7 @@ All tools that work with files support **both individual files and entire direct
|
||||
**`testgen`** - Comprehensive test generation with edge case coverage
|
||||
- `files`: Code files or directories to generate tests for (required)
|
||||
- `prompt`: Description of what to test, testing objectives, and scope (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `test_examples`: Optional existing test files as style/pattern reference
|
||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||
|
||||
@@ -222,7 +221,7 @@ All tools that work with files support **both individual files and entire direct
|
||||
- `files`: Code files or directories to analyze for refactoring opportunities (required)
|
||||
- `prompt`: Description of refactoring goals, context, and specific areas of focus (required)
|
||||
- `refactor_type`: codesmells|decompose|modernize|organization (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `focus_areas`: Specific areas to focus on (e.g., 'performance', 'readability', 'maintainability', 'security')
|
||||
- `style_guide_examples`: Optional existing code files to use as style/pattern reference
|
||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||
|
||||
@@ -63,7 +63,7 @@ CUSTOM_MODEL_NAME=llama3.2 # Default model
|
||||
|
||||
**Default Model Selection:**
|
||||
```env
|
||||
# Options: 'auto', 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', 'o4-mini-high', etc.
|
||||
# Options: 'auto', 'pro', 'flash', 'o3', 'o3-mini', 'o4-mini', etc.
|
||||
DEFAULT_MODEL=auto # Claude picks best model for each task (recommended)
|
||||
```
|
||||
|
||||
@@ -74,7 +74,6 @@ DEFAULT_MODEL=auto # Claude picks best model for each task (recommended)
|
||||
- **`o3`**: Strong logical reasoning (200K context)
|
||||
- **`o3-mini`**: Balanced speed/quality (200K context)
|
||||
- **`o4-mini`**: Latest reasoning model, optimized for shorter contexts
|
||||
- **`o4-mini-high`**: Enhanced O4 with higher reasoning effort
|
||||
- **`grok`**: GROK-3 advanced reasoning (131K context)
|
||||
- **Custom models**: via OpenRouter or local APIs
|
||||
|
||||
@@ -120,7 +119,6 @@ OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral
|
||||
- `o3` (200K context, high reasoning)
|
||||
- `o3-mini` (200K context, balanced)
|
||||
- `o4-mini` (200K context, latest balanced)
|
||||
- `o4-mini-high` (200K context, enhanced reasoning)
|
||||
- `mini` (shorthand for o4-mini)
|
||||
|
||||
**Gemini Models:**
|
||||
|
||||
@@ -65,7 +65,7 @@ This workflow ensures methodical analysis before expert insights, resulting in d
|
||||
|
||||
**Initial Configuration (used in step 1):**
|
||||
- `prompt`: What to analyze or look for (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `analysis_type`: architecture|performance|security|quality|general (default: general)
|
||||
- `output_format`: summary|detailed|actionable (default: detailed)
|
||||
- `temperature`: Temperature for analysis (0-1, default 0.2)
|
||||
|
||||
@@ -33,7 +33,7 @@ and then debate with the other models to give me a final verdict
|
||||
## Tool Parameters
|
||||
|
||||
- `prompt`: Your question or discussion topic (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `files`: Optional files for context (absolute paths)
|
||||
- `images`: Optional images for visual context (absolute paths)
|
||||
- `temperature`: Response creativity (0-1, default 0.5)
|
||||
|
||||
@@ -80,7 +80,7 @@ The above prompt will simultaneously run two separate `codereview` tools with tw
|
||||
|
||||
**Initial Review Configuration (used in step 1):**
|
||||
- `prompt`: User's summary of what the code does, expected behavior, constraints, and review objectives (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `review_type`: full|security|performance|quick (default: full)
|
||||
- `focus_on`: Specific aspects to focus on (e.g., "security vulnerabilities", "performance bottlenecks")
|
||||
- `standards`: Coding standards to enforce (e.g., "PEP8", "ESLint", "Google Style Guide")
|
||||
|
||||
@@ -73,7 +73,7 @@ This structured approach ensures Claude performs methodical groundwork before ex
|
||||
- `images`: Visual debugging materials (error screenshots, logs, etc.)
|
||||
|
||||
**Model Selection:**
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini (default: server default)
|
||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||
- `use_websearch`: Enable web search for documentation and solutions (default: true)
|
||||
- `use_assistant_model`: Whether to use expert analysis phase (default: true, set to false to use Claude only)
|
||||
|
||||
@@ -135,7 +135,7 @@ Use zen and perform a thorough precommit ensuring there aren't any new regressio
|
||||
**Initial Configuration (used in step 1):**
|
||||
- `path`: Starting directory to search for repos (default: current directory, absolute path required)
|
||||
- `prompt`: The original user request description for the changes (required for context)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `compare_to`: Compare against a branch/tag instead of local changes (optional)
|
||||
- `severity_filter`: critical|high|medium|low|all (default: all)
|
||||
- `include_staged`: Include staged changes in the review (default: true)
|
||||
|
||||
@@ -103,7 +103,7 @@ This results in Claude first performing its own expert analysis, encouraging it
|
||||
**Initial Configuration (used in step 1):**
|
||||
- `prompt`: Description of refactoring goals, context, and specific areas of focus (required)
|
||||
- `refactor_type`: codesmells|decompose|modernize|organization (default: codesmells)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `focus_areas`: Specific areas to focus on (e.g., 'performance', 'readability', 'maintainability', 'security')
|
||||
- `style_guide_examples`: Optional existing code files to use as style/pattern reference (absolute paths)
|
||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||
|
||||
@@ -86,7 +86,7 @@ security remediation plan using planner
|
||||
- `images`: Architecture diagrams, security documentation, or visual references
|
||||
|
||||
**Initial Security Configuration (used in step 1):**
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `security_scope`: Application context, technology stack, and security boundary definition (required)
|
||||
- `threat_level`: low|medium|high|critical (default: medium) - determines assessment depth and urgency
|
||||
- `compliance_requirements`: List of compliance frameworks to assess against (e.g., ["PCI DSS", "SOC2"])
|
||||
|
||||
@@ -70,7 +70,7 @@ Test generation excels with extended reasoning models like Gemini Pro or O3, whi
|
||||
|
||||
**Initial Configuration (used in step 1):**
|
||||
- `prompt`: Description of what to test, testing objectives, and specific scope/focus areas (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `test_examples`: Optional existing test files or directories to use as style/pattern reference (absolute paths)
|
||||
- `thinking_mode`: minimal|low|medium|high|max (default: medium, Gemini only)
|
||||
- `use_assistant_model`: Whether to use expert test generation phase (default: true, set to false to use Claude only)
|
||||
|
||||
@@ -30,7 +30,7 @@ with the best architecture for my project
|
||||
## Tool Parameters
|
||||
|
||||
- `prompt`: Your current thinking/analysis to extend and validate (required)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|o4-mini-high|gpt4.1 (default: server default)
|
||||
- `model`: auto|pro|flash|o3|o3-mini|o4-mini|gpt4.1 (default: server default)
|
||||
- `problem_context`: Additional context about the problem or goal
|
||||
- `focus_areas`: Specific aspects to focus on (architecture, performance, security, etc.)
|
||||
- `files`: Optional file paths or directories for additional context (absolute paths)
|
||||
|
||||
@@ -132,6 +132,7 @@ class ModelCapabilities:
|
||||
model_name: str
|
||||
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
|
||||
context_window: int # Total context window size in tokens
|
||||
max_output_tokens: int # Maximum output tokens per request
|
||||
supports_extended_thinking: bool = False
|
||||
supports_system_prompts: bool = True
|
||||
supports_streaming: bool = True
|
||||
@@ -140,6 +141,19 @@ class ModelCapabilities:
|
||||
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
|
||||
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
|
||||
|
||||
# Additional fields for comprehensive model information
|
||||
description: str = "" # Human-readable description of the model
|
||||
aliases: list[str] = field(default_factory=list) # Alternative names/shortcuts for the model
|
||||
|
||||
# JSON mode support (for providers that support structured output)
|
||||
supports_json_mode: bool = False
|
||||
|
||||
# Thinking mode support (for models with thinking capabilities)
|
||||
max_thinking_tokens: int = 0 # Maximum thinking tokens for extended reasoning models
|
||||
|
||||
# Custom model flag (for models that only work with custom endpoints)
|
||||
is_custom: bool = False # Whether this model requires custom API endpoints
|
||||
|
||||
# Temperature constraint object - preferred way to define temperature limits
|
||||
temperature_constraint: TemperatureConstraint = field(
|
||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
@@ -251,7 +265,7 @@ class ModelProvider(ABC):
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
|
||||
# Check if model supports temperature at all
|
||||
if hasattr(capabilities, "supports_temperature") and not capabilities.supports_temperature:
|
||||
if not capabilities.supports_temperature:
|
||||
return None
|
||||
|
||||
# Get temperature range
|
||||
@@ -290,19 +304,109 @@ class ModelProvider(ABC):
|
||||
"""Check if the model supports extended thinking mode."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||
"""Get model configurations for this provider.
|
||||
|
||||
This is a hook method that subclasses can override to provide
|
||||
their model configurations from different sources.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
# Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
|
||||
if hasattr(self, "SUPPORTED_MODELS"):
|
||||
return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
|
||||
return {}
|
||||
|
||||
def get_all_model_aliases(self) -> dict[str, list[str]]:
|
||||
"""Get all model aliases for this provider.
|
||||
|
||||
This is a hook method that subclasses can override to provide
|
||||
aliases from different sources.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their list of aliases
|
||||
"""
|
||||
# Default implementation extracts from ModelCapabilities objects
|
||||
aliases = {}
|
||||
for model_name, capabilities in self.get_model_configurations().items():
|
||||
if capabilities.aliases:
|
||||
aliases[model_name] = capabilities.aliases
|
||||
return aliases
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name.
|
||||
|
||||
This implementation uses the hook methods to support different
|
||||
model configuration sources.
|
||||
|
||||
Args:
|
||||
model_name: Model name that may be an alias
|
||||
|
||||
Returns:
|
||||
Resolved model name
|
||||
"""
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_model_configurations()
|
||||
|
||||
# First check if it's already a base model name (case-sensitive exact match)
|
||||
if model_name in model_configs:
|
||||
return model_name
|
||||
|
||||
# Check case-insensitively for both base models and aliases
|
||||
model_name_lower = model_name.lower()
|
||||
|
||||
# Check base model names case-insensitively
|
||||
for base_model in model_configs:
|
||||
if base_model.lower() == model_name_lower:
|
||||
return base_model
|
||||
|
||||
# Check aliases from the hook method
|
||||
all_aliases = self.get_all_model_aliases()
|
||||
for base_model, aliases in all_aliases.items():
|
||||
if any(alias.lower() == model_name_lower for alias in aliases):
|
||||
return base_model
|
||||
|
||||
# If not found, return as-is
|
||||
return model_name
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
This implementation uses the get_model_configurations() hook
|
||||
to support different model configuration sources.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
pass
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_model_configurations()
|
||||
|
||||
for model_name in model_configs:
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
|
||||
# Add the base model
|
||||
models.append(model_name)
|
||||
|
||||
# Get aliases from the hook method
|
||||
all_aliases = self.get_all_model_aliases()
|
||||
for model_name, aliases in all_aliases.items():
|
||||
# Only add aliases for models that passed restriction check
|
||||
if model_name in models:
|
||||
models.extend(aliases)
|
||||
|
||||
return models
|
||||
|
||||
@abstractmethod
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
@@ -312,21 +416,22 @@ class ModelProvider(ABC):
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
pass
|
||||
all_models = set()
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name.
|
||||
# Get model configurations from the hook method
|
||||
model_configs = self.get_model_configurations()
|
||||
|
||||
Base implementation returns the model name unchanged.
|
||||
Subclasses should override to provide alias resolution.
|
||||
# Add all base model names
|
||||
for model_name in model_configs:
|
||||
all_models.add(model_name.lower())
|
||||
|
||||
Args:
|
||||
model_name: Model name that may be an alias
|
||||
# Get aliases from the hook method and add them
|
||||
all_aliases = self.get_all_model_aliases()
|
||||
for _model_name, aliases in all_aliases.items():
|
||||
for alias in aliases:
|
||||
all_models.add(alias.lower())
|
||||
|
||||
Returns:
|
||||
Resolved model name
|
||||
"""
|
||||
return model_name
|
||||
return list(all_models)
|
||||
|
||||
def close(self):
|
||||
"""Clean up any resources held by the provider.
|
||||
|
||||
@@ -158,6 +158,7 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
model_name=resolved_name,
|
||||
friendly_name=f"{self.FRIENDLY_NAME} ({resolved_name})",
|
||||
context_window=32_768, # Conservative default
|
||||
max_output_tokens=32_768, # Conservative default max output
|
||||
supports_extended_thinking=False, # Most custom models don't support this
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
@@ -187,7 +188,7 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
Returns:
|
||||
True if model is intended for custom/local endpoint
|
||||
"""
|
||||
logging.debug(f"Custom provider validating model: '{model_name}'")
|
||||
# logging.debug(f"Custom provider validating model: '{model_name}'")
|
||||
|
||||
# Try to resolve through registry first
|
||||
config = self._registry.resolve(model_name)
|
||||
@@ -195,12 +196,12 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
model_id = config.model_name
|
||||
# Use explicit is_custom flag for clean validation
|
||||
if config.is_custom:
|
||||
logging.debug(f"Model '{model_name}' -> '{model_id}' validated via registry (custom model)")
|
||||
logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' validated via registry")
|
||||
return True
|
||||
else:
|
||||
# This is a cloud/OpenRouter model - CustomProvider should NOT handle these
|
||||
# Let OpenRouter provider handle them instead
|
||||
logging.debug(f"Model '{model_name}' -> '{model_id}' rejected (cloud model, defer to OpenRouter)")
|
||||
# logging.debug(f"... [Custom] Model '{model_name}' -> '{model_id}' not custom (defer to OpenRouter)")
|
||||
return False
|
||||
|
||||
# Handle version tags for unknown models (e.g., "my-model:latest")
|
||||
@@ -268,65 +269,50 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
"""Check if the model supports extended thinking mode.
|
||||
|
||||
Most custom/local models don't support extended thinking.
|
||||
|
||||
Args:
|
||||
model_name: Model to check
|
||||
|
||||
Returns:
|
||||
False (custom models generally don't support thinking mode)
|
||||
True if model supports thinking mode, False otherwise
|
||||
"""
|
||||
# Check if model is in registry
|
||||
config = self._registry.resolve(model_name) if self._registry else None
|
||||
if config and config.is_custom:
|
||||
# Trust the config from custom_models.json
|
||||
return config.supports_extended_thinking
|
||||
|
||||
# Default to False for unknown models
|
||||
return False
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||
"""Get model configurations from the registry.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
For CustomProvider, we convert registry configurations to ModelCapabilities objects.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
configs = {}
|
||||
|
||||
if self._registry:
|
||||
# Get all models from the registry
|
||||
all_models = self._registry.list_models()
|
||||
aliases = self._registry.list_aliases()
|
||||
|
||||
# Add models that are validated by the custom provider
|
||||
for model_name in all_models + aliases:
|
||||
# Use the provider's validation logic to determine if this model
|
||||
# is appropriate for the custom endpoint
|
||||
# Get all models from registry
|
||||
for model_name in self._registry.list_models():
|
||||
# Only include custom models that this provider validates
|
||||
if self.validate_model_name(model_name):
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and config.is_custom:
|
||||
# Use ModelCapabilities directly from registry
|
||||
configs[model_name] = config
|
||||
|
||||
models.append(model_name)
|
||||
return configs
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
def get_all_model_aliases(self) -> dict[str, list[str]]:
|
||||
"""Get all model aliases from the registry.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
Dictionary mapping model names to their list of aliases
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
if self._registry:
|
||||
# Get all models and aliases from the registry
|
||||
all_models.update(model.lower() for model in self._registry.list_models())
|
||||
all_models.update(alias.lower() for alias in self._registry.list_aliases())
|
||||
|
||||
# For each alias, also add its target
|
||||
for alias in self._registry.list_aliases():
|
||||
config = self._registry.resolve(alias)
|
||||
if config:
|
||||
all_models.add(config.model_name.lower())
|
||||
|
||||
return list(all_models)
|
||||
# Since aliases are now included in the configurations,
|
||||
# we can use the base class implementation
|
||||
return super().get_all_model_aliases()
|
||||
|
||||
@@ -10,7 +10,7 @@ from .base import (
|
||||
ModelCapabilities,
|
||||
ModelResponse,
|
||||
ProviderType,
|
||||
RangeTemperatureConstraint,
|
||||
create_temperature_constraint,
|
||||
)
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
|
||||
@@ -30,63 +30,170 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
MAX_RETRIES = 4
|
||||
RETRY_DELAYS = [1, 3, 5, 8] # seconds
|
||||
|
||||
# Supported DIAL models (these can be customized based on your DIAL deployment)
|
||||
# Model configurations using ModelCapabilities objects
|
||||
SUPPORTED_MODELS = {
|
||||
"o3-2025-04-16": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"o4-mini-2025-04-16": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": True, # Thinking mode variant
|
||||
"supports_vision": True,
|
||||
},
|
||||
"anthropic.claude-opus-4-20250514-v1:0": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"anthropic.claude-opus-4-20250514-v1:0-with-thinking": {
|
||||
"context_window": 200_000,
|
||||
"supports_extended_thinking": True, # Thinking mode variant
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.5-pro-preview-03-25-google-search": {
|
||||
"context_window": 1_000_000,
|
||||
"supports_extended_thinking": False, # DIAL doesn't expose thinking mode
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.5-pro-preview-05-06": {
|
||||
"context_window": 1_000_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.5-flash-preview-05-20": {
|
||||
"context_window": 1_000_000,
|
||||
"supports_extended_thinking": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Shorthands
|
||||
"o3": "o3-2025-04-16",
|
||||
"o4-mini": "o4-mini-2025-04-16",
|
||||
"sonnet-4": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"sonnet-4-thinking": "anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
|
||||
"opus-4": "anthropic.claude-opus-4-20250514-v1:0",
|
||||
"opus-4-thinking": "anthropic.claude-opus-4-20250514-v1:0-with-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-pro-search": "gemini-2.5-pro-preview-03-25-google-search",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
|
||||
"o3-2025-04-16": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="o3-2025-04-16",
|
||||
friendly_name="DIAL (O3)",
|
||||
context_window=200_000,
|
||||
max_output_tokens=100_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=False, # O3 models don't accept temperature
|
||||
temperature_constraint=create_temperature_constraint("fixed"),
|
||||
description="OpenAI O3 via DIAL - Strong reasoning model",
|
||||
aliases=["o3"],
|
||||
),
|
||||
"o4-mini-2025-04-16": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="o4-mini-2025-04-16",
|
||||
friendly_name="DIAL (O4-mini)",
|
||||
context_window=200_000,
|
||||
max_output_tokens=100_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=False, # O4 models don't accept temperature
|
||||
temperature_constraint=create_temperature_constraint("fixed"),
|
||||
description="OpenAI O4-mini via DIAL - Fast reasoning model",
|
||||
aliases=["o4-mini"],
|
||||
),
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
friendly_name="DIAL (Sonnet 4)",
|
||||
context_window=200_000,
|
||||
max_output_tokens=64_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Claude doesn't have function calling
|
||||
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Claude Sonnet 4 via DIAL - Balanced performance",
|
||||
aliases=["sonnet-4"],
|
||||
),
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0-with-thinking": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0-with-thinking",
|
||||
friendly_name="DIAL (Sonnet 4 Thinking)",
|
||||
context_window=200_000,
|
||||
max_output_tokens=64_000,
|
||||
supports_extended_thinking=True, # Thinking mode variant
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Claude doesn't have function calling
|
||||
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Claude Sonnet 4 with thinking mode via DIAL",
|
||||
aliases=["sonnet-4-thinking"],
|
||||
),
|
||||
"anthropic.claude-opus-4-20250514-v1:0": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-opus-4-20250514-v1:0",
|
||||
friendly_name="DIAL (Opus 4)",
|
||||
context_window=200_000,
|
||||
max_output_tokens=64_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Claude doesn't have function calling
|
||||
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Claude Opus 4 via DIAL - Most capable Claude model",
|
||||
aliases=["opus-4"],
|
||||
),
|
||||
"anthropic.claude-opus-4-20250514-v1:0-with-thinking": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-opus-4-20250514-v1:0-with-thinking",
|
||||
friendly_name="DIAL (Opus 4 Thinking)",
|
||||
context_window=200_000,
|
||||
max_output_tokens=64_000,
|
||||
supports_extended_thinking=True, # Thinking mode variant
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # Claude doesn't have function calling
|
||||
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Claude Opus 4 with thinking mode via DIAL",
|
||||
aliases=["opus-4-thinking"],
|
||||
),
|
||||
"gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="gemini-2.5-pro-preview-03-25-google-search",
|
||||
friendly_name="DIAL (Gemini 2.5 Pro Search)",
|
||||
context_window=1_000_000,
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=False, # DIAL doesn't expose thinking mode
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Gemini 2.5 Pro with Google Search via DIAL",
|
||||
aliases=["gemini-2.5-pro-search"],
|
||||
),
|
||||
"gemini-2.5-pro-preview-05-06": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="gemini-2.5-pro-preview-05-06",
|
||||
friendly_name="DIAL (Gemini 2.5 Pro)",
|
||||
context_window=1_000_000,
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Gemini 2.5 Pro via DIAL - Deep reasoning",
|
||||
aliases=["gemini-2.5-pro"],
|
||||
),
|
||||
"gemini-2.5-flash-preview-05-20": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="gemini-2.5-flash-preview-05-20",
|
||||
friendly_name="DIAL (Gemini Flash 2.5)",
|
||||
context_window=1_000_000,
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Gemini 2.5 Flash via DIAL - Ultra-fast",
|
||||
aliases=["gemini-2.5-flash"],
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
@@ -181,20 +288,8 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
if not restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model_name):
|
||||
raise ValueError(f"Model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
config = self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name=resolved_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
context_window=config["context_window"],
|
||||
supports_extended_thinking=config["supports_extended_thinking"],
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_images=config.get("supports_vision", False),
|
||||
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 0.7),
|
||||
)
|
||||
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
|
||||
return self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
@@ -211,7 +306,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
"""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
return False
|
||||
|
||||
# Check against base class allowed_models if configured
|
||||
@@ -231,20 +326,6 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
|
||||
return True
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name.
|
||||
|
||||
Args:
|
||||
model_name: Model name or shorthand
|
||||
|
||||
Returns:
|
||||
Full model name
|
||||
"""
|
||||
shorthand_value = self.SUPPORTED_MODELS.get(model_name)
|
||||
if isinstance(shorthand_value, str):
|
||||
return shorthand_value
|
||||
return model_name
|
||||
|
||||
def _get_deployment_client(self, deployment: str):
|
||||
"""Get or create a cached client for a specific deployment.
|
||||
|
||||
@@ -357,7 +438,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
# Check model capabilities
|
||||
try:
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
supports_temperature = getattr(capabilities, "supports_temperature", True)
|
||||
supports_temperature = capabilities.supports_temperature
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to check temperature support for {model_name}: {e}")
|
||||
supports_temperature = True
|
||||
@@ -441,63 +522,12 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
"""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name in self.SUPPORTED_MODELS and isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
|
||||
return self.SUPPORTED_MODELS[resolved_name].get("supports_vision", False)
|
||||
if resolved_name in self.SUPPORTED_MODELS:
|
||||
return self.SUPPORTED_MODELS[resolved_name].supports_images
|
||||
|
||||
# Fall back to parent implementation for unknown models
|
||||
return super()._supports_vision(model_name)
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
# Get all model keys (both full names and aliases)
|
||||
all_models = list(self.SUPPORTED_MODELS.keys())
|
||||
|
||||
if not respect_restrictions:
|
||||
return all_models
|
||||
|
||||
# Apply restrictions if configured
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
|
||||
# Filter based on restrictions
|
||||
allowed_models = []
|
||||
for model in all_models:
|
||||
resolved_name = self._resolve_model_name(model)
|
||||
if restriction_service.is_allowed(ProviderType.DIAL, resolved_name, model):
|
||||
allowed_models.append(model)
|
||||
|
||||
return allowed_models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
This is used for validation purposes to ensure restriction policies
|
||||
can validate against both aliases and their target model names.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
# Collect all unique model names (both aliases and targets)
|
||||
all_models = set()
|
||||
|
||||
for key, value in self.SUPPORTED_MODELS.items():
|
||||
# Add the key (could be alias or full name)
|
||||
all_models.add(key)
|
||||
|
||||
# If it's an alias (string value), add the target too
|
||||
if isinstance(value, str):
|
||||
all_models.add(value)
|
||||
|
||||
return sorted(all_models)
|
||||
|
||||
def close(self):
|
||||
"""Clean up HTTP clients when provider is closed."""
|
||||
logger.info("Closing DIAL provider HTTP clients...")
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Optional
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, RangeTemperatureConstraint
|
||||
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,47 +17,83 @@ logger = logging.getLogger(__name__)
|
||||
class GeminiModelProvider(ModelProvider):
|
||||
"""Google Gemini model provider implementation."""
|
||||
|
||||
# Model configurations
|
||||
# Model configurations using ModelCapabilities objects
|
||||
SUPPORTED_MODELS = {
|
||||
"gemini-2.0-flash": {
|
||||
"context_window": 1_048_576, # 1M tokens
|
||||
"supports_extended_thinking": True, # Experimental thinking mode
|
||||
"max_thinking_tokens": 24576, # Same as 2.5 flash for consistency
|
||||
"supports_images": True, # Vision capability
|
||||
"max_image_size_mb": 20.0, # Conservative 20MB limit for reliability
|
||||
"description": "Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
|
||||
},
|
||||
"gemini-2.0-flash-lite": {
|
||||
"context_window": 1_048_576, # 1M tokens
|
||||
"supports_extended_thinking": False, # Not supported per user request
|
||||
"max_thinking_tokens": 0, # No thinking support
|
||||
"supports_images": False, # Does not support images
|
||||
"max_image_size_mb": 0.0, # No image support
|
||||
"description": "Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
|
||||
},
|
||||
"gemini-2.5-flash": {
|
||||
"context_window": 1_048_576, # 1M tokens
|
||||
"supports_extended_thinking": True,
|
||||
"max_thinking_tokens": 24576, # Flash 2.5 thinking budget limit
|
||||
"supports_images": True, # Vision capability
|
||||
"max_image_size_mb": 20.0, # Conservative 20MB limit for reliability
|
||||
"description": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
"context_window": 1_048_576, # 1M tokens
|
||||
"supports_extended_thinking": True,
|
||||
"max_thinking_tokens": 32768, # Pro 2.5 thinking budget limit
|
||||
"supports_images": True, # Vision capability
|
||||
"max_image_size_mb": 32.0, # Higher limit for Pro model
|
||||
"description": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
|
||||
},
|
||||
# Shorthands
|
||||
"flash": "gemini-2.5-flash",
|
||||
"flash-2.0": "gemini-2.0-flash",
|
||||
"flash2": "gemini-2.0-flash",
|
||||
"flashlite": "gemini-2.0-flash-lite",
|
||||
"flash-lite": "gemini-2.0-flash-lite",
|
||||
"pro": "gemini-2.5-pro",
|
||||
"gemini-2.0-flash": ModelCapabilities(
|
||||
provider=ProviderType.GOOGLE,
|
||||
model_name="gemini-2.0-flash",
|
||||
friendly_name="Gemini (Flash 2.0)",
|
||||
context_window=1_048_576, # 1M tokens
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=True, # Experimental thinking mode
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # Vision capability
|
||||
max_image_size_mb=20.0, # Conservative 20MB limit for reliability
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
max_thinking_tokens=24576, # Same as 2.5 flash for consistency
|
||||
description="Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
|
||||
aliases=["flash-2.0", "flash2"],
|
||||
),
|
||||
"gemini-2.0-flash-lite": ModelCapabilities(
|
||||
provider=ProviderType.GOOGLE,
|
||||
model_name="gemini-2.0-flash-lite",
|
||||
friendly_name="Gemin (Flash Lite 2.0)",
|
||||
context_window=1_048_576, # 1M tokens
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=False, # Not supported per user request
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=False, # Does not support images
|
||||
max_image_size_mb=0.0, # No image support
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
|
||||
aliases=["flashlite", "flash-lite"],
|
||||
),
|
||||
"gemini-2.5-flash": ModelCapabilities(
|
||||
provider=ProviderType.GOOGLE,
|
||||
model_name="gemini-2.5-flash",
|
||||
friendly_name="Gemini (Flash 2.5)",
|
||||
context_window=1_048_576, # 1M tokens
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=True,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # Vision capability
|
||||
max_image_size_mb=20.0, # Conservative 20MB limit for reliability
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
max_thinking_tokens=24576, # Flash 2.5 thinking budget limit
|
||||
description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
||||
aliases=["flash", "flash2.5"],
|
||||
),
|
||||
"gemini-2.5-pro": ModelCapabilities(
|
||||
provider=ProviderType.GOOGLE,
|
||||
model_name="gemini-2.5-pro",
|
||||
friendly_name="Gemini (Pro 2.5)",
|
||||
context_window=1_048_576, # 1M tokens
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=True,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # Vision capability
|
||||
max_image_size_mb=32.0, # Higher limit for Pro model
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
max_thinking_tokens=32768, # Max thinking tokens for Pro model
|
||||
description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
|
||||
aliases=["pro", "gemini pro", "gemini-pro"],
|
||||
),
|
||||
}
|
||||
|
||||
# Thinking mode configurations - percentages of model's max_thinking_tokens
|
||||
@@ -70,6 +106,14 @@ class GeminiModelProvider(ModelProvider):
|
||||
"max": 1.0, # 100% of max - full thinking budget
|
||||
}
|
||||
|
||||
# Model-specific thinking token limits
|
||||
MAX_THINKING_TOKENS = {
|
||||
"gemini-2.0-flash": 24576, # Same as 2.5 flash for consistency
|
||||
"gemini-2.0-flash-lite": 0, # No thinking support
|
||||
"gemini-2.5-flash": 24576, # Flash 2.5 thinking budget limit
|
||||
"gemini-2.5-pro": 32768, # Pro 2.5 thinking budget limit
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize Gemini provider with API key."""
|
||||
super().__init__(api_key, **kwargs)
|
||||
@@ -100,25 +144,8 @@ class GeminiModelProvider(ModelProvider):
|
||||
if not restriction_service.is_allowed(ProviderType.GOOGLE, resolved_name, model_name):
|
||||
raise ValueError(f"Gemini model '{resolved_name}' is not allowed by restriction policy.")
|
||||
|
||||
config = self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
# Gemini models support 0.0-2.0 temperature range
|
||||
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.GOOGLE,
|
||||
model_name=resolved_name,
|
||||
friendly_name="Gemini",
|
||||
context_window=config["context_window"],
|
||||
supports_extended_thinking=config["supports_extended_thinking"],
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_images=config.get("supports_images", False),
|
||||
max_image_size_mb=config.get("max_image_size_mb", 0.0),
|
||||
supports_temperature=True, # Gemini models accept temperature parameter
|
||||
temperature_constraint=temp_constraint,
|
||||
)
|
||||
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
|
||||
return self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
@@ -179,8 +206,8 @@ class GeminiModelProvider(ModelProvider):
|
||||
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
|
||||
# Get model's max thinking tokens and calculate actual budget
|
||||
model_config = self.SUPPORTED_MODELS.get(resolved_name)
|
||||
if model_config and "max_thinking_tokens" in model_config:
|
||||
max_thinking_tokens = model_config["max_thinking_tokens"]
|
||||
if model_config and model_config.max_thinking_tokens > 0:
|
||||
max_thinking_tokens = model_config.max_thinking_tokens
|
||||
actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
|
||||
generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget)
|
||||
|
||||
@@ -258,7 +285,7 @@ class GeminiModelProvider(ModelProvider):
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
# First check if model is supported
|
||||
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
return False
|
||||
|
||||
# Then check if model is allowed by restrictions
|
||||
@@ -281,78 +308,20 @@ class GeminiModelProvider(ModelProvider):
|
||||
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
|
||||
"""Get actual thinking token budget for a model and thinking mode."""
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
model_config = self.SUPPORTED_MODELS.get(resolved_name, {})
|
||||
model_config = self.SUPPORTED_MODELS.get(resolved_name)
|
||||
|
||||
if not model_config.get("supports_extended_thinking", False):
|
||||
if not model_config or not model_config.supports_extended_thinking:
|
||||
return 0
|
||||
|
||||
if thinking_mode not in self.THINKING_BUDGETS:
|
||||
return 0
|
||||
|
||||
max_thinking_tokens = model_config.get("max_thinking_tokens", 0)
|
||||
max_thinking_tokens = model_config.max_thinking_tokens
|
||||
if max_thinking_tokens == 0:
|
||||
return 0
|
||||
|
||||
return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Handle both base models (dict configs) and aliases (string values)
|
||||
if isinstance(config, str):
|
||||
# This is an alias - check if the target model would be allowed
|
||||
target_model = config
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
|
||||
continue
|
||||
# Allow the alias
|
||||
models.append(model_name)
|
||||
else:
|
||||
# This is a base model with config dict
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
models.append(model_name)
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Add the model name itself
|
||||
all_models.add(model_name.lower())
|
||||
|
||||
# If it's an alias (string value), add the target model too
|
||||
if isinstance(config, str):
|
||||
all_models.add(config.lower())
|
||||
|
||||
return list(all_models)
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name."""
|
||||
# Check if it's a shorthand
|
||||
shorthand_value = self.SUPPORTED_MODELS.get(model_name.lower())
|
||||
if isinstance(shorthand_value, str):
|
||||
return shorthand_value
|
||||
return model_name
|
||||
|
||||
def _extract_usage(self, response) -> dict[str, int]:
|
||||
"""Extract token usage from Gemini response."""
|
||||
usage = {}
|
||||
|
||||
@@ -686,7 +686,6 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
"o3-mini",
|
||||
"o3-pro",
|
||||
"o4-mini",
|
||||
"o4-mini-high",
|
||||
# Note: Claude models would be handled by a separate provider
|
||||
}
|
||||
supports = model_name.lower() in vision_models
|
||||
|
||||
@@ -17,71 +17,98 @@ logger = logging.getLogger(__name__)
|
||||
class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
"""Official OpenAI API provider (api.openai.com)."""
|
||||
|
||||
# Model configurations
|
||||
# Model configurations using ModelCapabilities objects
|
||||
SUPPORTED_MODELS = {
|
||||
"o3": {
|
||||
"context_window": 200_000, # 200K tokens
|
||||
"supports_extended_thinking": False,
|
||||
"supports_images": True, # O3 models support vision
|
||||
"max_image_size_mb": 20.0, # 20MB per OpenAI docs
|
||||
"supports_temperature": False, # O3 models don't accept temperature parameter
|
||||
"temperature_constraint": "fixed", # Fixed at 1.0
|
||||
"description": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
|
||||
},
|
||||
"o3-mini": {
|
||||
"context_window": 200_000, # 200K tokens
|
||||
"supports_extended_thinking": False,
|
||||
"supports_images": True, # O3 models support vision
|
||||
"max_image_size_mb": 20.0, # 20MB per OpenAI docs
|
||||
"supports_temperature": False, # O3 models don't accept temperature parameter
|
||||
"temperature_constraint": "fixed", # Fixed at 1.0
|
||||
"description": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
||||
},
|
||||
"o3-pro-2025-06-10": {
|
||||
"context_window": 200_000, # 200K tokens
|
||||
"supports_extended_thinking": False,
|
||||
"supports_images": True, # O3 models support vision
|
||||
"max_image_size_mb": 20.0, # 20MB per OpenAI docs
|
||||
"supports_temperature": False, # O3 models don't accept temperature parameter
|
||||
"temperature_constraint": "fixed", # Fixed at 1.0
|
||||
"description": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
|
||||
},
|
||||
# Aliases
|
||||
"o3-pro": "o3-pro-2025-06-10",
|
||||
"o4-mini": {
|
||||
"context_window": 200_000, # 200K tokens
|
||||
"supports_extended_thinking": False,
|
||||
"supports_images": True, # O4 models support vision
|
||||
"max_image_size_mb": 20.0, # 20MB per OpenAI docs
|
||||
"supports_temperature": False, # O4 models don't accept temperature parameter
|
||||
"temperature_constraint": "fixed", # Fixed at 1.0
|
||||
"description": "Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
|
||||
},
|
||||
"o4-mini-high": {
|
||||
"context_window": 200_000, # 200K tokens
|
||||
"supports_extended_thinking": False,
|
||||
"supports_images": True, # O4 models support vision
|
||||
"max_image_size_mb": 20.0, # 20MB per OpenAI docs
|
||||
"supports_temperature": False, # O4 models don't accept temperature parameter
|
||||
"temperature_constraint": "fixed", # Fixed at 1.0
|
||||
"description": "Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
|
||||
},
|
||||
"gpt-4.1-2025-04-14": {
|
||||
"context_window": 1_000_000, # 1M tokens
|
||||
"supports_extended_thinking": False,
|
||||
"supports_images": True, # GPT-4.1 supports vision
|
||||
"max_image_size_mb": 20.0, # 20MB per OpenAI docs
|
||||
"supports_temperature": True, # Regular models accept temperature parameter
|
||||
"temperature_constraint": "range", # 0.0-2.0 range
|
||||
"description": "GPT-4.1 (1M context) - Advanced reasoning model with large context window",
|
||||
},
|
||||
# Shorthands
|
||||
"mini": "o4-mini", # Default 'mini' to latest mini model
|
||||
"o3mini": "o3-mini",
|
||||
"o4mini": "o4-mini",
|
||||
"o4minihigh": "o4-mini-high",
|
||||
"o4minihi": "o4-mini-high",
|
||||
"gpt4.1": "gpt-4.1-2025-04-14",
|
||||
"o3": ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
model_name="o3",
|
||||
friendly_name="OpenAI (O3)",
|
||||
context_window=200_000, # 200K tokens
|
||||
max_output_tokens=65536, # 64K max output tokens
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # O3 models support vision
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_temperature=False, # O3 models don't accept temperature parameter
|
||||
temperature_constraint=create_temperature_constraint("fixed"),
|
||||
description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
|
||||
aliases=[],
|
||||
),
|
||||
"o3-mini": ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
model_name="o3-mini",
|
||||
friendly_name="OpenAI (O3-mini)",
|
||||
context_window=200_000, # 200K tokens
|
||||
max_output_tokens=65536, # 64K max output tokens
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # O3 models support vision
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_temperature=False, # O3 models don't accept temperature parameter
|
||||
temperature_constraint=create_temperature_constraint("fixed"),
|
||||
description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
||||
aliases=["o3mini", "o3-mini"],
|
||||
),
|
||||
"o3-pro-2025-06-10": ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
model_name="o3-pro-2025-06-10",
|
||||
friendly_name="OpenAI (O3-Pro)",
|
||||
context_window=200_000, # 200K tokens
|
||||
max_output_tokens=65536, # 64K max output tokens
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # O3 models support vision
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_temperature=False, # O3 models don't accept temperature parameter
|
||||
temperature_constraint=create_temperature_constraint("fixed"),
|
||||
description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
|
||||
aliases=["o3-pro"],
|
||||
),
|
||||
"o4-mini": ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
model_name="o4-mini",
|
||||
friendly_name="OpenAI (O4-mini)",
|
||||
context_window=200_000, # 200K tokens
|
||||
max_output_tokens=65536, # 64K max output tokens
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # O4 models support vision
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_temperature=False, # O4 models don't accept temperature parameter
|
||||
temperature_constraint=create_temperature_constraint("fixed"),
|
||||
description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
|
||||
aliases=["mini", "o4mini", "o4-mini"],
|
||||
),
|
||||
"gpt-4.1-2025-04-14": ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
model_name="gpt-4.1-2025-04-14",
|
||||
friendly_name="OpenAI (GPT 4.1)",
|
||||
context_window=1_000_000, # 1M tokens
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # GPT-4.1 supports vision
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_temperature=True, # Regular models accept temperature parameter
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
|
||||
aliases=["gpt4.1"],
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
@@ -95,7 +122,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
# Resolve shorthand
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(f"Unsupported OpenAI model: {model_name}")
|
||||
|
||||
# Check if model is allowed by restrictions
|
||||
@@ -105,27 +132,8 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
if not restriction_service.is_allowed(ProviderType.OPENAI, resolved_name, model_name):
|
||||
raise ValueError(f"OpenAI model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
config = self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
# Get temperature constraints and support from configuration
|
||||
supports_temperature = config.get("supports_temperature", True) # Default to True for backward compatibility
|
||||
temp_constraint_type = config.get("temperature_constraint", "range") # Default to range
|
||||
temp_constraint = create_temperature_constraint(temp_constraint_type)
|
||||
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
model_name=model_name,
|
||||
friendly_name="OpenAI",
|
||||
context_window=config["context_window"],
|
||||
supports_extended_thinking=config["supports_extended_thinking"],
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_images=config.get("supports_images", False),
|
||||
max_image_size_mb=config.get("max_image_size_mb", 0.0),
|
||||
supports_temperature=supports_temperature,
|
||||
temperature_constraint=temp_constraint,
|
||||
)
|
||||
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
|
||||
return self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
@@ -136,7 +144,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
# First check if model is supported
|
||||
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
return False
|
||||
|
||||
# Then check if model is allowed by restrictions
|
||||
@@ -177,61 +185,3 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
# Currently no OpenAI models support extended thinking
|
||||
# This may change with future O3 models
|
||||
return False
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Handle both base models (dict configs) and aliases (string values)
|
||||
if isinstance(config, str):
|
||||
# This is an alias - check if the target model would be allowed
|
||||
target_model = config
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
|
||||
continue
|
||||
# Allow the alias
|
||||
models.append(model_name)
|
||||
else:
|
||||
# This is a base model with config dict
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
models.append(model_name)
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Add the model name itself
|
||||
all_models.add(model_name.lower())
|
||||
|
||||
# If it's an alias (string value), add the target model too
|
||||
if isinstance(config, str):
|
||||
all_models.add(config.lower())
|
||||
|
||||
return list(all_models)
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name."""
|
||||
# Check if it's a shorthand
|
||||
shorthand_value = self.SUPPORTED_MODELS.get(model_name)
|
||||
if isinstance(shorthand_value, str):
|
||||
return shorthand_value
|
||||
return model_name
|
||||
|
||||
@@ -50,14 +50,6 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
aliases = self._registry.list_aliases()
|
||||
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
|
||||
|
||||
def _parse_allowed_models(self) -> None:
|
||||
"""Override to disable environment-based allow-list.
|
||||
|
||||
OpenRouter model access is controlled via the OpenRouter dashboard,
|
||||
not through environment variables.
|
||||
"""
|
||||
return None
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model aliases to OpenRouter model names.
|
||||
|
||||
@@ -109,6 +101,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
model_name=resolved_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
context_window=32_768, # Conservative default context window
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
@@ -130,16 +123,34 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
|
||||
As the catch-all provider, OpenRouter accepts any model name that wasn't
|
||||
handled by higher-priority providers. OpenRouter will validate based on
|
||||
the API key's permissions.
|
||||
the API key's permissions and local restrictions.
|
||||
|
||||
Args:
|
||||
model_name: Model name to validate
|
||||
|
||||
Returns:
|
||||
Always True - OpenRouter is the catch-all provider
|
||||
True if model is allowed, False if restricted
|
||||
"""
|
||||
# Accept any model name - OpenRouter is the fallback provider
|
||||
# Higher priority providers (native APIs, custom endpoints) get first chance
|
||||
# Check model restrictions if configured
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service()
|
||||
if restriction_service:
|
||||
# Check if model name itself is allowed
|
||||
if restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
return True
|
||||
|
||||
# Also check aliases - model_name might be an alias
|
||||
model_config = self._registry.resolve(model_name)
|
||||
if model_config and model_config.aliases:
|
||||
for alias in model_config.aliases:
|
||||
if restriction_service.is_allowed(self.get_provider_type(), alias):
|
||||
return True
|
||||
|
||||
# If restrictions are configured and model/alias not in allowed list, reject
|
||||
return False
|
||||
|
||||
# No restrictions configured - accept any model name as the fallback provider
|
||||
return True
|
||||
|
||||
def generate_content(
|
||||
@@ -260,3 +271,35 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
all_models.add(config.model_name.lower())
|
||||
|
||||
return list(all_models)
|
||||
|
||||
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
|
||||
"""Get model configurations from the registry.
|
||||
|
||||
For OpenRouter, we convert registry configurations to ModelCapabilities objects.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their ModelCapabilities objects
|
||||
"""
|
||||
configs = {}
|
||||
|
||||
if self._registry:
|
||||
# Get all models from registry
|
||||
for model_name in self._registry.list_models():
|
||||
# Only include models that this provider validates
|
||||
if self.validate_model_name(model_name):
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and not config.is_custom: # Only OpenRouter models, not custom ones
|
||||
# Use ModelCapabilities directly from registry
|
||||
configs[model_name] = config
|
||||
|
||||
return configs
|
||||
|
||||
def get_all_model_aliases(self) -> dict[str, list[str]]:
|
||||
"""Get all model aliases from the registry.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model names to their list of aliases
|
||||
"""
|
||||
# Since aliases are now included in the configurations,
|
||||
# we can use the base class implementation
|
||||
return super().get_all_model_aliases()
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -11,58 +10,10 @@ from utils.file_utils import read_json_file
|
||||
from .base import (
|
||||
ModelCapabilities,
|
||||
ProviderType,
|
||||
TemperatureConstraint,
|
||||
create_temperature_constraint,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenRouterModelConfig:
|
||||
"""Configuration for an OpenRouter model."""
|
||||
|
||||
model_name: str
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
context_window: int = 32768 # Total context window size in tokens
|
||||
supports_extended_thinking: bool = False
|
||||
supports_system_prompts: bool = True
|
||||
supports_streaming: bool = True
|
||||
supports_function_calling: bool = False
|
||||
supports_json_mode: bool = False
|
||||
supports_images: bool = False # Whether model can process images
|
||||
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
|
||||
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
|
||||
temperature_constraint: Optional[str] = (
|
||||
None # Type of temperature constraint: "fixed", "range", "discrete", or None for default range
|
||||
)
|
||||
is_custom: bool = False # True for models that should only be used with custom endpoints
|
||||
description: str = ""
|
||||
|
||||
def _create_temperature_constraint(self) -> TemperatureConstraint:
|
||||
"""Create temperature constraint object from configuration.
|
||||
|
||||
Returns:
|
||||
TemperatureConstraint object based on configuration
|
||||
"""
|
||||
return create_temperature_constraint(self.temperature_constraint or "range")
|
||||
|
||||
def to_capabilities(self) -> ModelCapabilities:
|
||||
"""Convert to ModelCapabilities object."""
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name=self.model_name,
|
||||
friendly_name="OpenRouter",
|
||||
context_window=self.context_window,
|
||||
supports_extended_thinking=self.supports_extended_thinking,
|
||||
supports_system_prompts=self.supports_system_prompts,
|
||||
supports_streaming=self.supports_streaming,
|
||||
supports_function_calling=self.supports_function_calling,
|
||||
supports_images=self.supports_images,
|
||||
max_image_size_mb=self.max_image_size_mb,
|
||||
supports_temperature=self.supports_temperature,
|
||||
temperature_constraint=self._create_temperature_constraint(),
|
||||
)
|
||||
|
||||
|
||||
class OpenRouterModelRegistry:
|
||||
"""Registry for managing OpenRouter model configurations and aliases."""
|
||||
|
||||
@@ -73,7 +24,7 @@ class OpenRouterModelRegistry:
|
||||
config_path: Path to config file. If None, uses default locations.
|
||||
"""
|
||||
self.alias_map: dict[str, str] = {} # alias -> model_name
|
||||
self.model_map: dict[str, OpenRouterModelConfig] = {} # model_name -> config
|
||||
self.model_map: dict[str, ModelCapabilities] = {} # model_name -> config
|
||||
|
||||
# Determine config path
|
||||
if config_path:
|
||||
@@ -139,7 +90,7 @@ class OpenRouterModelRegistry:
|
||||
self.alias_map = {}
|
||||
self.model_map = {}
|
||||
|
||||
def _read_config(self) -> list[OpenRouterModelConfig]:
|
||||
def _read_config(self) -> list[ModelCapabilities]:
|
||||
"""Read configuration from file.
|
||||
|
||||
Returns:
|
||||
@@ -158,7 +109,27 @@ class OpenRouterModelRegistry:
|
||||
# Parse models
|
||||
configs = []
|
||||
for model_data in data.get("models", []):
|
||||
config = OpenRouterModelConfig(**model_data)
|
||||
# Create ModelCapabilities directly from JSON data
|
||||
# Handle temperature_constraint conversion
|
||||
temp_constraint_str = model_data.get("temperature_constraint")
|
||||
temp_constraint = create_temperature_constraint(temp_constraint_str or "range")
|
||||
|
||||
# Set provider-specific defaults based on is_custom flag
|
||||
is_custom = model_data.get("is_custom", False)
|
||||
if is_custom:
|
||||
model_data.setdefault("provider", ProviderType.CUSTOM)
|
||||
model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
|
||||
else:
|
||||
model_data.setdefault("provider", ProviderType.OPENROUTER)
|
||||
model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")
|
||||
model_data["temperature_constraint"] = temp_constraint
|
||||
|
||||
# Remove the string version of temperature_constraint before creating ModelCapabilities
|
||||
if "temperature_constraint" in model_data and isinstance(model_data["temperature_constraint"], str):
|
||||
del model_data["temperature_constraint"]
|
||||
model_data["temperature_constraint"] = temp_constraint
|
||||
|
||||
config = ModelCapabilities(**model_data)
|
||||
configs.append(config)
|
||||
|
||||
return configs
|
||||
@@ -168,7 +139,7 @@ class OpenRouterModelRegistry:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading config from {self.config_path}: {e}")
|
||||
|
||||
def _build_maps(self, configs: list[OpenRouterModelConfig]) -> None:
|
||||
def _build_maps(self, configs: list[ModelCapabilities]) -> None:
|
||||
"""Build alias and model maps from configurations.
|
||||
|
||||
Args:
|
||||
@@ -211,7 +182,7 @@ class OpenRouterModelRegistry:
|
||||
self.alias_map = alias_map
|
||||
self.model_map = model_map
|
||||
|
||||
def resolve(self, name_or_alias: str) -> Optional[OpenRouterModelConfig]:
|
||||
def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||
"""Resolve a model name or alias to configuration.
|
||||
|
||||
Args:
|
||||
@@ -237,10 +208,8 @@ class OpenRouterModelRegistry:
|
||||
Returns:
|
||||
ModelCapabilities if found, None otherwise
|
||||
"""
|
||||
config = self.resolve(name_or_alias)
|
||||
if config:
|
||||
return config.to_capabilities()
|
||||
return None
|
||||
# Registry now returns ModelCapabilities directly
|
||||
return self.resolve(name_or_alias)
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
"""List all available model names."""
|
||||
|
||||
@@ -24,8 +24,6 @@ class ModelProviderRegistry:
|
||||
cls._instance._providers = {}
|
||||
cls._instance._initialized_providers = {}
|
||||
logging.debug(f"REGISTRY: Created instance {cls._instance}")
|
||||
else:
|
||||
logging.debug(f"REGISTRY: Returning existing instance {cls._instance}")
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
@@ -129,7 +127,6 @@ class ModelProviderRegistry:
|
||||
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
|
||||
|
||||
for provider_type in PROVIDER_PRIORITY_ORDER:
|
||||
logging.debug(f"Checking provider_type: {provider_type}")
|
||||
if provider_type in instance._providers:
|
||||
logging.debug(f"Found {provider_type} in registry")
|
||||
# Get or create provider instance
|
||||
|
||||
136
providers/xai.py
136
providers/xai.py
@@ -7,7 +7,7 @@ from .base import (
|
||||
ModelCapabilities,
|
||||
ModelResponse,
|
||||
ProviderType,
|
||||
RangeTemperatureConstraint,
|
||||
create_temperature_constraint,
|
||||
)
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
|
||||
@@ -19,23 +19,44 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
||||
|
||||
FRIENDLY_NAME = "X.AI"
|
||||
|
||||
# Model configurations
|
||||
# Model configurations using ModelCapabilities objects
|
||||
SUPPORTED_MODELS = {
|
||||
"grok-3": {
|
||||
"context_window": 131_072, # 131K tokens
|
||||
"supports_extended_thinking": False,
|
||||
"description": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
|
||||
},
|
||||
"grok-3-fast": {
|
||||
"context_window": 131_072, # 131K tokens
|
||||
"supports_extended_thinking": False,
|
||||
"description": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
|
||||
},
|
||||
# Shorthands for convenience
|
||||
"grok": "grok-3", # Default to grok-3
|
||||
"grok3": "grok-3",
|
||||
"grok3fast": "grok-3-fast",
|
||||
"grokfast": "grok-3-fast",
|
||||
"grok-3": ModelCapabilities(
|
||||
provider=ProviderType.XAI,
|
||||
model_name="grok-3",
|
||||
friendly_name="X.AI (Grok 3)",
|
||||
context_window=131_072, # 131K tokens
|
||||
max_output_tokens=131072,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet
|
||||
supports_images=False, # Assuming GROK is text-only for now
|
||||
max_image_size_mb=0.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
|
||||
aliases=["grok", "grok3"],
|
||||
),
|
||||
"grok-3-fast": ModelCapabilities(
|
||||
provider=ProviderType.XAI,
|
||||
model_name="grok-3-fast",
|
||||
friendly_name="X.AI (Grok 3 Fast)",
|
||||
context_window=131_072, # 131K tokens
|
||||
max_output_tokens=131072,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet
|
||||
supports_images=False, # Assuming GROK is text-only for now
|
||||
max_image_size_mb=0.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
|
||||
aliases=["grok3fast", "grokfast", "grok3-fast"],
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
@@ -49,7 +70,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
||||
# Resolve shorthand
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(f"Unsupported X.AI model: {model_name}")
|
||||
|
||||
# Check if model is allowed by restrictions
|
||||
@@ -59,23 +80,8 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
||||
if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
|
||||
raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
|
||||
|
||||
config = self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
# Define temperature constraints for GROK models
|
||||
# GROK supports the standard OpenAI temperature range
|
||||
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.XAI,
|
||||
model_name=resolved_name,
|
||||
friendly_name=self.FRIENDLY_NAME,
|
||||
context_window=config["context_window"],
|
||||
supports_extended_thinking=config["supports_extended_thinking"],
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
temperature_constraint=temp_constraint,
|
||||
)
|
||||
# Return the ModelCapabilities object directly from SUPPORTED_MODELS
|
||||
return self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
"""Get the provider type."""
|
||||
@@ -86,7 +92,7 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
||||
resolved_name = self._resolve_model_name(model_name)
|
||||
|
||||
# First check if model is supported
|
||||
if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
|
||||
if resolved_name not in self.SUPPORTED_MODELS:
|
||||
return False
|
||||
|
||||
# Then check if model is allowed by restrictions
|
||||
@@ -127,61 +133,3 @@ class XAIModelProvider(OpenAICompatibleProvider):
|
||||
# Currently GROK models do not support extended thinking
|
||||
# This may change with future GROK model releases
|
||||
return False
|
||||
|
||||
def list_models(self, respect_restrictions: bool = True) -> list[str]:
|
||||
"""Return a list of model names supported by this provider.
|
||||
|
||||
Args:
|
||||
respect_restrictions: Whether to apply provider-specific restriction logic.
|
||||
|
||||
Returns:
|
||||
List of model names available from this provider
|
||||
"""
|
||||
from utils.model_restrictions import get_restriction_service
|
||||
|
||||
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||
models = []
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Handle both base models (dict configs) and aliases (string values)
|
||||
if isinstance(config, str):
|
||||
# This is an alias - check if the target model would be allowed
|
||||
target_model = config
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), target_model):
|
||||
continue
|
||||
# Allow the alias
|
||||
models.append(model_name)
|
||||
else:
|
||||
# This is a base model with config dict
|
||||
# Check restrictions if enabled
|
||||
if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||
continue
|
||||
models.append(model_name)
|
||||
|
||||
return models
|
||||
|
||||
def list_all_known_models(self) -> list[str]:
|
||||
"""Return all model names known by this provider, including alias targets.
|
||||
|
||||
Returns:
|
||||
List of all model names and alias targets known by this provider
|
||||
"""
|
||||
all_models = set()
|
||||
|
||||
for model_name, config in self.SUPPORTED_MODELS.items():
|
||||
# Add the model name itself
|
||||
all_models.add(model_name.lower())
|
||||
|
||||
# If it's an alias (string value), add the target model too
|
||||
if isinstance(config, str):
|
||||
all_models.add(config.lower())
|
||||
|
||||
return list(all_models)
|
||||
|
||||
def _resolve_model_name(self, model_name: str) -> str:
|
||||
"""Resolve model shorthand to full name."""
|
||||
# Check if it's a shorthand
|
||||
shorthand_value = self.SUPPORTED_MODELS.get(model_name)
|
||||
if isinstance(shorthand_value, str):
|
||||
return shorthand_value
|
||||
return model_name
|
||||
|
||||
108
server.py
108
server.py
@@ -158,6 +158,97 @@ logger = logging.getLogger(__name__)
|
||||
# This name is used by MCP clients to identify and connect to this specific server
|
||||
server: Server = Server("zen-server")
|
||||
|
||||
|
||||
# Constants for tool filtering
|
||||
ESSENTIAL_TOOLS = {"version", "listmodels"}
|
||||
|
||||
|
||||
def parse_disabled_tools_env() -> set[str]:
|
||||
"""
|
||||
Parse the DISABLED_TOOLS environment variable into a set of tool names.
|
||||
|
||||
Returns:
|
||||
Set of lowercase tool names to disable, empty set if none specified
|
||||
"""
|
||||
disabled_tools_env = os.getenv("DISABLED_TOOLS", "").strip()
|
||||
if not disabled_tools_env:
|
||||
return set()
|
||||
return {t.strip().lower() for t in disabled_tools_env.split(",") if t.strip()}
|
||||
|
||||
|
||||
def validate_disabled_tools(disabled_tools: set[str], all_tools: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the disabled tools list and log appropriate warnings.
|
||||
|
||||
Args:
|
||||
disabled_tools: Set of tool names requested to be disabled
|
||||
all_tools: Dictionary of all available tool instances
|
||||
"""
|
||||
essential_disabled = disabled_tools & ESSENTIAL_TOOLS
|
||||
if essential_disabled:
|
||||
logger.warning(f"Cannot disable essential tools: {sorted(essential_disabled)}")
|
||||
unknown_tools = disabled_tools - set(all_tools.keys())
|
||||
if unknown_tools:
|
||||
logger.warning(f"Unknown tools in DISABLED_TOOLS: {sorted(unknown_tools)}")
|
||||
|
||||
|
||||
def apply_tool_filter(all_tools: dict[str, Any], disabled_tools: set[str]) -> dict[str, Any]:
|
||||
"""
|
||||
Apply the disabled tools filter to create the final tools dictionary.
|
||||
|
||||
Args:
|
||||
all_tools: Dictionary of all available tool instances
|
||||
disabled_tools: Set of tool names to disable
|
||||
|
||||
Returns:
|
||||
Dictionary containing only enabled tools
|
||||
"""
|
||||
enabled_tools = {}
|
||||
for tool_name, tool_instance in all_tools.items():
|
||||
if tool_name in ESSENTIAL_TOOLS or tool_name not in disabled_tools:
|
||||
enabled_tools[tool_name] = tool_instance
|
||||
else:
|
||||
logger.debug(f"Tool '{tool_name}' disabled via DISABLED_TOOLS")
|
||||
return enabled_tools
|
||||
|
||||
|
||||
def log_tool_configuration(disabled_tools: set[str], enabled_tools: dict[str, Any]) -> None:
|
||||
"""
|
||||
Log the final tool configuration for visibility.
|
||||
|
||||
Args:
|
||||
disabled_tools: Set of tool names that were requested to be disabled
|
||||
enabled_tools: Dictionary of tools that remain enabled
|
||||
"""
|
||||
if not disabled_tools:
|
||||
logger.info("All tools enabled (DISABLED_TOOLS not set)")
|
||||
return
|
||||
actual_disabled = disabled_tools - ESSENTIAL_TOOLS
|
||||
if actual_disabled:
|
||||
logger.debug(f"Disabled tools: {sorted(actual_disabled)}")
|
||||
logger.info(f"Active tools: {sorted(enabled_tools.keys())}")
|
||||
|
||||
|
||||
def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Filter tools based on DISABLED_TOOLS environment variable.
|
||||
|
||||
Args:
|
||||
all_tools: Dictionary of all available tool instances
|
||||
|
||||
Returns:
|
||||
dict: Filtered dictionary containing only enabled tools
|
||||
"""
|
||||
disabled_tools = parse_disabled_tools_env()
|
||||
if not disabled_tools:
|
||||
log_tool_configuration(disabled_tools, all_tools)
|
||||
return all_tools
|
||||
validate_disabled_tools(disabled_tools, all_tools)
|
||||
enabled_tools = apply_tool_filter(all_tools, disabled_tools)
|
||||
log_tool_configuration(disabled_tools, enabled_tools)
|
||||
return enabled_tools
|
||||
|
||||
|
||||
# Initialize the tool registry with all available AI-powered tools
|
||||
# Each tool provides specialized functionality for different development tasks
|
||||
# Tools are instantiated once and reused across requests (stateless design)
|
||||
@@ -178,6 +269,7 @@ TOOLS = {
|
||||
"listmodels": ListModelsTool(), # List all available AI models by provider
|
||||
"version": VersionTool(), # Display server version and system information
|
||||
}
|
||||
TOOLS = filter_disabled_tools(TOOLS)
|
||||
|
||||
# Rich prompt templates for all tools
|
||||
PROMPT_TEMPLATES = {
|
||||
@@ -673,6 +765,11 @@ def parse_model_option(model_string: str) -> tuple[str, Optional[str]]:
|
||||
"""
|
||||
Parse model:option format into model name and option.
|
||||
|
||||
Handles different formats:
|
||||
- OpenRouter models: preserve :free, :beta, :preview suffixes as part of model name
|
||||
- Ollama/Custom models: split on : to extract tags like :latest
|
||||
- Consensus stance: extract options like :for, :against
|
||||
|
||||
Args:
|
||||
model_string: String that may contain "model:option" format
|
||||
|
||||
@@ -680,6 +777,17 @@ def parse_model_option(model_string: str) -> tuple[str, Optional[str]]:
|
||||
tuple: (model_name, option) where option may be None
|
||||
"""
|
||||
if ":" in model_string and not model_string.startswith("http"): # Avoid parsing URLs
|
||||
# Check if this looks like an OpenRouter model (contains /)
|
||||
if "/" in model_string and model_string.count(":") == 1:
|
||||
# Could be openai/gpt-4:something - check what comes after colon
|
||||
parts = model_string.split(":", 1)
|
||||
suffix = parts[1].strip().lower()
|
||||
|
||||
# Known OpenRouter suffixes to preserve
|
||||
if suffix in ["free", "beta", "preview"]:
|
||||
return model_string.strip(), None
|
||||
|
||||
# For other patterns (Ollama tags, consensus stances), split normally
|
||||
parts = model_string.split(":", 1)
|
||||
model_name = parts[0].strip()
|
||||
model_option = parts[1].strip() if len(parts) > 1 else None
|
||||
|
||||
@@ -182,6 +182,10 @@ class ConversationBaseTest(BaseSimulatorTest):
|
||||
|
||||
# Look for continuation_id in various places
|
||||
if isinstance(response_data, dict):
|
||||
# Check top-level continuation_id (workflow tools)
|
||||
if "continuation_id" in response_data:
|
||||
return response_data["continuation_id"]
|
||||
|
||||
# Check metadata
|
||||
metadata = response_data.get("metadata", {})
|
||||
if "thread_id" in metadata:
|
||||
|
||||
@@ -91,11 +91,14 @@ class TestClass:
|
||||
response_a2, continuation_id_a2 = self.call_mcp_tool(
|
||||
"analyze",
|
||||
{
|
||||
"prompt": "Now analyze the code quality and suggest improvements.",
|
||||
"files": [test_file_path],
|
||||
"step": "Now analyze the code quality and suggest improvements.",
|
||||
"step_number": 1,
|
||||
"total_steps": 2,
|
||||
"next_step_required": False,
|
||||
"findings": "Continuing analysis from previous chat conversation to analyze code quality.",
|
||||
"relevant_files": [test_file_path],
|
||||
"continuation_id": continuation_id_a1,
|
||||
"model": "flash",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -154,10 +157,14 @@ class TestClass:
|
||||
response_b2, continuation_id_b2 = self.call_mcp_tool(
|
||||
"analyze",
|
||||
{
|
||||
"prompt": "Analyze the previous greeting and suggest improvements.",
|
||||
"step": "Analyze the previous greeting and suggest improvements.",
|
||||
"step_number": 1,
|
||||
"total_steps": 1,
|
||||
"next_step_required": False,
|
||||
"findings": "Analyzing the greeting from previous conversation and suggesting improvements.",
|
||||
"relevant_files": [test_file_path],
|
||||
"continuation_id": continuation_id_b1,
|
||||
"model": "flash",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -206,11 +206,14 @@ if __name__ == "__main__":
|
||||
response2, continuation_id2 = self.call_mcp_tool(
|
||||
"analyze",
|
||||
{
|
||||
"prompt": "Analyze the performance implications of these recursive functions.",
|
||||
"files": [file1_path],
|
||||
"step": "Analyze the performance implications of these recursive functions.",
|
||||
"step_number": 1,
|
||||
"total_steps": 1,
|
||||
"next_step_required": False,
|
||||
"findings": "Continuing from chat conversation to analyze performance implications of recursive functions.",
|
||||
"relevant_files": [file1_path],
|
||||
"continuation_id": continuation_id1, # Continue the chat conversation
|
||||
"model": "flash",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -221,10 +224,14 @@ if __name__ == "__main__":
|
||||
self.logger.info(f" ✅ Step 2 completed with continuation_id: {continuation_id2[:8]}...")
|
||||
continuation_ids.append(continuation_id2)
|
||||
|
||||
# Validate that we got a different continuation ID
|
||||
if continuation_id2 == continuation_id1:
|
||||
self.logger.error(" ❌ Step 2: Got same continuation ID as Step 1 - continuation not working")
|
||||
return False
|
||||
# Validate continuation ID behavior for workflow tools
|
||||
# Workflow tools reuse the same continuation_id when continuing within a workflow session
|
||||
# This is expected behavior and different from simple tools
|
||||
if continuation_id2 != continuation_id1:
|
||||
self.logger.info(" ✅ Step 2: Got new continuation ID (workflow behavior)")
|
||||
else:
|
||||
self.logger.info(" ✅ Step 2: Reused continuation ID (workflow session continuation)")
|
||||
# Both behaviors are valid - what matters is that we got a continuation_id
|
||||
|
||||
# Validate that Step 2 is building on Step 1's conversation
|
||||
# Check if the response references the previous conversation
|
||||
@@ -276,17 +283,16 @@ if __name__ == "__main__":
|
||||
all_have_continuation_ids = bool(continuation_id1 and continuation_id2 and continuation_id3)
|
||||
criteria.append(("All steps generated continuation IDs", all_have_continuation_ids))
|
||||
|
||||
# 3. Each continuation ID is unique
|
||||
unique_continuation_ids = len(set(continuation_ids)) == len(continuation_ids)
|
||||
criteria.append(("Each response generated unique continuation ID", unique_continuation_ids))
|
||||
# 3. Continuation behavior validation (handles both simple and workflow tools)
|
||||
# Simple tools create new IDs each time, workflow tools may reuse IDs within sessions
|
||||
has_valid_continuation_pattern = len(continuation_ids) == 3
|
||||
criteria.append(("Valid continuation ID pattern", has_valid_continuation_pattern))
|
||||
|
||||
# 4. Continuation IDs follow the expected pattern
|
||||
step_ids_different = (
|
||||
len(continuation_ids) == 3
|
||||
and continuation_ids[0] != continuation_ids[1]
|
||||
and continuation_ids[1] != continuation_ids[2]
|
||||
# 4. Check for conversation continuity (more important than ID uniqueness)
|
||||
conversation_has_continuity = len(continuation_ids) == 3 and all(
|
||||
cid is not None for cid in continuation_ids
|
||||
)
|
||||
criteria.append(("All continuation IDs are different", step_ids_different))
|
||||
criteria.append(("Conversation continuity maintained", conversation_has_continuity))
|
||||
|
||||
# 5. Check responses build on each other (content validation)
|
||||
step1_has_function_analysis = "fibonacci" in response1.lower() or "factorial" in response1.lower()
|
||||
|
||||
@@ -15,6 +15,7 @@ def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576
|
||||
model_name=model_name,
|
||||
friendly_name="Gemini",
|
||||
context_window=context_window,
|
||||
max_output_tokens=8192,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
|
||||
@@ -211,7 +211,7 @@ class TestAliasTargetRestrictions:
|
||||
# Verify the polymorphic method was called
|
||||
mock_provider.list_all_known_models.assert_called_once()
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high"}) # Restrict to specific model
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"}) # Restrict to specific model
|
||||
def test_complex_alias_chains_handled_correctly(self):
|
||||
"""Test that complex alias chains are handled correctly in restrictions."""
|
||||
# Clear cached restriction service
|
||||
@@ -221,12 +221,11 @@ class TestAliasTargetRestrictions:
|
||||
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
# Only o4-mini-high should be allowed
|
||||
assert provider.validate_model_name("o4-mini-high")
|
||||
# Only o4-mini should be allowed
|
||||
assert provider.validate_model_name("o4-mini")
|
||||
|
||||
# Other models should be blocked
|
||||
assert not provider.validate_model_name("o4-mini")
|
||||
assert not provider.validate_model_name("mini") # This resolves to o4-mini
|
||||
assert not provider.validate_model_name("o3")
|
||||
assert not provider.validate_model_name("o3-mini")
|
||||
|
||||
def test_critical_regression_validation_sees_alias_targets(self):
|
||||
@@ -307,7 +306,7 @@ class TestAliasTargetRestrictions:
|
||||
it appear that target-based restrictions don't work.
|
||||
"""
|
||||
# Test with a made-up restriction scenario
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini-high,o3-mini"}):
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini,o3-mini"}):
|
||||
# Clear cached restriction service
|
||||
import utils.model_restrictions
|
||||
|
||||
@@ -318,7 +317,7 @@ class TestAliasTargetRestrictions:
|
||||
|
||||
# These specific target models should be recognized as valid
|
||||
all_known = provider.list_all_known_models()
|
||||
assert "o4-mini-high" in all_known, "Target model o4-mini-high should be known"
|
||||
assert "o4-mini" in all_known, "Target model o4-mini should be known"
|
||||
assert "o3-mini" in all_known, "Target model o3-mini should be known"
|
||||
|
||||
# Validation should not warn about these being unrecognized
|
||||
@@ -329,11 +328,11 @@ class TestAliasTargetRestrictions:
|
||||
# Should not warn about our allowed models being unrecognized
|
||||
all_warnings = [str(call) for call in mock_logger.warning.call_args_list]
|
||||
for warning in all_warnings:
|
||||
assert "o4-mini-high" not in warning or "not a recognized" not in warning
|
||||
assert "o4-mini" not in warning or "not a recognized" not in warning
|
||||
assert "o3-mini" not in warning or "not a recognized" not in warning
|
||||
|
||||
# The restriction should actually work
|
||||
assert provider.validate_model_name("o4-mini-high")
|
||||
assert provider.validate_model_name("o4-mini")
|
||||
assert provider.validate_model_name("o3-mini")
|
||||
assert not provider.validate_model_name("o4-mini") # not in allowed list
|
||||
assert not provider.validate_model_name("o3-pro") # not in allowed list
|
||||
assert not provider.validate_model_name("o3") # not in allowed list
|
||||
|
||||
@@ -59,12 +59,12 @@ class TestAutoMode:
|
||||
continue
|
||||
|
||||
# Check that model has description
|
||||
description = config.get("description", "")
|
||||
description = config.description if hasattr(config, "description") else ""
|
||||
if description:
|
||||
models_with_descriptions[model_name] = description
|
||||
|
||||
# Check all expected models are present with meaningful descriptions
|
||||
expected_models = ["flash", "pro", "o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"]
|
||||
expected_models = ["flash", "pro", "o3", "o3-mini", "o3-pro", "o4-mini"]
|
||||
for model in expected_models:
|
||||
# Model should exist somewhere in the providers
|
||||
# Note: Some models might not be available if API keys aren't configured
|
||||
|
||||
@@ -319,7 +319,18 @@ class TestAutoModeComprehensive:
|
||||
m
|
||||
for m in available_models
|
||||
if not m.startswith("gemini")
|
||||
and m not in ["flash", "pro", "flash-2.0", "flash2", "flashlite", "flash-lite"]
|
||||
and m
|
||||
not in [
|
||||
"flash",
|
||||
"pro",
|
||||
"flash-2.0",
|
||||
"flash2",
|
||||
"flashlite",
|
||||
"flash-lite",
|
||||
"flash2.5",
|
||||
"gemini pro",
|
||||
"gemini-pro",
|
||||
]
|
||||
]
|
||||
assert (
|
||||
len(non_gemini_models) == 0
|
||||
|
||||
@@ -70,7 +70,7 @@ class TestAutoModeCustomProviderOnly:
|
||||
}
|
||||
|
||||
# Clear all other provider keys
|
||||
clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]
|
||||
clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]
|
||||
|
||||
with patch.dict(os.environ, test_env, clear=False):
|
||||
# Ensure other provider keys are not set
|
||||
@@ -109,7 +109,7 @@ class TestAutoModeCustomProviderOnly:
|
||||
|
||||
with patch.dict(os.environ, test_env, clear=False):
|
||||
# Clear other provider keys
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
@@ -177,7 +177,7 @@ class TestAutoModeCustomProviderOnly:
|
||||
|
||||
with patch.dict(os.environ, test_env, clear=False):
|
||||
# Clear other provider keys
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ class TestBuggyBehaviorPrevention:
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
# Simulate a scenario where admin wants to restrict specific targets
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini-high"}):
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
|
||||
# Clear cached restriction service
|
||||
import utils.model_restrictions
|
||||
|
||||
@@ -126,19 +126,21 @@ class TestBuggyBehaviorPrevention:
|
||||
|
||||
# These should work because they're explicitly allowed
|
||||
assert provider.validate_model_name("o3-mini")
|
||||
assert provider.validate_model_name("o4-mini-high")
|
||||
assert provider.validate_model_name("o4-mini")
|
||||
|
||||
# These should be blocked
|
||||
assert not provider.validate_model_name("o4-mini") # Not in allowed list
|
||||
assert not provider.validate_model_name("o3-pro") # Not in allowed list
|
||||
assert not provider.validate_model_name("o3") # Not in allowed list
|
||||
assert not provider.validate_model_name("mini") # Resolves to o4-mini, not allowed
|
||||
|
||||
# This should be ALLOWED because it resolves to o4-mini which is in the allowed list
|
||||
assert provider.validate_model_name("mini") # Resolves to o4-mini, which IS allowed
|
||||
|
||||
# Verify our list_all_known_models includes the restricted models
|
||||
all_known = provider.list_all_known_models()
|
||||
assert "o3-mini" in all_known # Should be known (and allowed)
|
||||
assert "o4-mini-high" in all_known # Should be known (and allowed)
|
||||
assert "o4-mini" in all_known # Should be known (but blocked)
|
||||
assert "mini" in all_known # Should be known (but blocked)
|
||||
assert "o4-mini" in all_known # Should be known (and allowed)
|
||||
assert "o3-pro" in all_known # Should be known (but blocked)
|
||||
assert "mini" in all_known # Should be known (and allowed since it resolves to o4-mini)
|
||||
|
||||
def test_demonstration_of_old_vs_new_interface(self):
|
||||
"""
|
||||
|
||||
@@ -506,17 +506,17 @@ class TestConversationFlow:
|
||||
mock_client = Mock()
|
||||
mock_storage.return_value = mock_client
|
||||
|
||||
# Start conversation with files
|
||||
thread_id = create_thread("analyze", {"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]})
|
||||
# Start conversation with files using a simple tool
|
||||
thread_id = create_thread("chat", {"prompt": "Analyze this codebase", "files": ["/project/src/"]})
|
||||
|
||||
# Turn 1: Claude provides context with multiple files
|
||||
initial_context = ThreadContext(
|
||||
thread_id=thread_id,
|
||||
created_at="2023-01-01T00:00:00Z",
|
||||
last_updated_at="2023-01-01T00:00:00Z",
|
||||
tool_name="analyze",
|
||||
tool_name="chat",
|
||||
turns=[],
|
||||
initial_context={"prompt": "Analyze this codebase", "relevant_files": ["/project/src/"]},
|
||||
initial_context={"prompt": "Analyze this codebase", "files": ["/project/src/"]},
|
||||
)
|
||||
mock_client.get.return_value = initial_context.model_dump_json()
|
||||
|
||||
|
||||
@@ -45,10 +45,17 @@ class TestCustomProvider:
|
||||
|
||||
def test_get_capabilities_from_registry(self):
|
||||
"""Test get_capabilities returns registry capabilities when available."""
|
||||
# Save original environment
|
||||
original_env = os.environ.get("OPENROUTER_ALLOWED_MODELS")
|
||||
|
||||
try:
|
||||
# Clear any restrictions
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
|
||||
|
||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||
|
||||
# Test with a model that should be in the registry (OpenRouter model) and is allowed by restrictions
|
||||
capabilities = provider.get_capabilities("o3") # o3 is in OPENROUTER_ALLOWED_MODELS
|
||||
# Test with a model that should be in the registry (OpenRouter model)
|
||||
capabilities = provider.get_capabilities("o3") # o3 is an OpenRouter model
|
||||
|
||||
assert capabilities.provider == ProviderType.OPENROUTER # o3 is an OpenRouter model (is_custom=false)
|
||||
assert capabilities.context_window > 0
|
||||
@@ -58,6 +65,13 @@ class TestCustomProvider:
|
||||
assert capabilities.provider == ProviderType.CUSTOM # local-llama has is_custom=true
|
||||
assert capabilities.context_window > 0
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
if original_env is None:
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
|
||||
else:
|
||||
os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env
|
||||
|
||||
def test_get_capabilities_generic_fallback(self):
|
||||
"""Test get_capabilities returns generic capabilities for unknown models."""
|
||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||
|
||||
@@ -84,7 +84,7 @@ class TestDIALProvider:
|
||||
# Test O3 capabilities
|
||||
capabilities = provider.get_capabilities("o3")
|
||||
assert capabilities.model_name == "o3-2025-04-16"
|
||||
assert capabilities.friendly_name == "DIAL"
|
||||
assert capabilities.friendly_name == "DIAL (O3)"
|
||||
assert capabilities.context_window == 200_000
|
||||
assert capabilities.provider == ProviderType.DIAL
|
||||
assert capabilities.supports_images is True
|
||||
|
||||
140
tests/test_disabled_tools.py
Normal file
140
tests/test_disabled_tools.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Tests for DISABLED_TOOLS environment variable functionality."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from server import (
|
||||
apply_tool_filter,
|
||||
parse_disabled_tools_env,
|
||||
validate_disabled_tools,
|
||||
)
|
||||
|
||||
|
||||
# Mock the tool classes since we're testing the filtering logic
|
||||
class MockTool:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class TestDisabledTools:
|
||||
"""Test suite for DISABLED_TOOLS functionality."""
|
||||
|
||||
def test_parse_disabled_tools_empty(self):
|
||||
"""Empty string returns empty set (no tools disabled)."""
|
||||
with patch.dict(os.environ, {"DISABLED_TOOLS": ""}):
|
||||
assert parse_disabled_tools_env() == set()
|
||||
|
||||
def test_parse_disabled_tools_not_set(self):
|
||||
"""Unset variable returns empty set."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Ensure DISABLED_TOOLS is not in environment
|
||||
if "DISABLED_TOOLS" in os.environ:
|
||||
del os.environ["DISABLED_TOOLS"]
|
||||
assert parse_disabled_tools_env() == set()
|
||||
|
||||
def test_parse_disabled_tools_single(self):
|
||||
"""Single tool name parsed correctly."""
|
||||
with patch.dict(os.environ, {"DISABLED_TOOLS": "debug"}):
|
||||
assert parse_disabled_tools_env() == {"debug"}
|
||||
|
||||
def test_parse_disabled_tools_multiple(self):
|
||||
"""Multiple tools with spaces parsed correctly."""
|
||||
with patch.dict(os.environ, {"DISABLED_TOOLS": "debug, analyze, refactor"}):
|
||||
assert parse_disabled_tools_env() == {"debug", "analyze", "refactor"}
|
||||
|
||||
def test_parse_disabled_tools_extra_spaces(self):
|
||||
"""Extra spaces and empty items handled correctly."""
|
||||
with patch.dict(os.environ, {"DISABLED_TOOLS": " debug , , analyze , "}):
|
||||
assert parse_disabled_tools_env() == {"debug", "analyze"}
|
||||
|
||||
def test_parse_disabled_tools_duplicates(self):
|
||||
"""Duplicate entries handled correctly (set removes duplicates)."""
|
||||
with patch.dict(os.environ, {"DISABLED_TOOLS": "debug,analyze,debug"}):
|
||||
assert parse_disabled_tools_env() == {"debug", "analyze"}
|
||||
|
||||
def test_tool_filtering_logic(self):
|
||||
"""Test the complete filtering logic using the actual server functions."""
|
||||
# Simulate ALL_TOOLS
|
||||
ALL_TOOLS = {
|
||||
"chat": MockTool("chat"),
|
||||
"debug": MockTool("debug"),
|
||||
"analyze": MockTool("analyze"),
|
||||
"version": MockTool("version"),
|
||||
"listmodels": MockTool("listmodels"),
|
||||
}
|
||||
|
||||
# Test case 1: No tools disabled
|
||||
disabled_tools = set()
|
||||
enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
|
||||
|
||||
assert len(enabled_tools) == 5 # All tools included
|
||||
assert set(enabled_tools.keys()) == set(ALL_TOOLS.keys())
|
||||
|
||||
# Test case 2: Disable some regular tools
|
||||
disabled_tools = {"debug", "analyze"}
|
||||
enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
|
||||
|
||||
assert len(enabled_tools) == 3 # chat, version, listmodels
|
||||
assert "debug" not in enabled_tools
|
||||
assert "analyze" not in enabled_tools
|
||||
assert "chat" in enabled_tools
|
||||
assert "version" in enabled_tools
|
||||
assert "listmodels" in enabled_tools
|
||||
|
||||
# Test case 3: Attempt to disable essential tools
|
||||
disabled_tools = {"version", "chat"}
|
||||
enabled_tools = apply_tool_filter(ALL_TOOLS, disabled_tools)
|
||||
|
||||
assert "version" in enabled_tools # Essential tool not disabled
|
||||
assert "chat" not in enabled_tools # Regular tool disabled
|
||||
assert "listmodels" in enabled_tools # Essential tool included
|
||||
|
||||
def test_unknown_tools_warning(self, caplog):
|
||||
"""Test that unknown tool names generate appropriate warnings."""
|
||||
ALL_TOOLS = {
|
||||
"chat": MockTool("chat"),
|
||||
"debug": MockTool("debug"),
|
||||
"analyze": MockTool("analyze"),
|
||||
"version": MockTool("version"),
|
||||
"listmodels": MockTool("listmodels"),
|
||||
}
|
||||
disabled_tools = {"chat", "unknown_tool", "another_unknown"}
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
validate_disabled_tools(disabled_tools, ALL_TOOLS)
|
||||
assert "Unknown tools in DISABLED_TOOLS: ['another_unknown', 'unknown_tool']" in caplog.text
|
||||
|
||||
def test_essential_tools_warning(self, caplog):
|
||||
"""Test warning when trying to disable essential tools."""
|
||||
ALL_TOOLS = {
|
||||
"chat": MockTool("chat"),
|
||||
"debug": MockTool("debug"),
|
||||
"analyze": MockTool("analyze"),
|
||||
"version": MockTool("version"),
|
||||
"listmodels": MockTool("listmodels"),
|
||||
}
|
||||
disabled_tools = {"version", "chat", "debug"}
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
validate_disabled_tools(disabled_tools, ALL_TOOLS)
|
||||
assert "Cannot disable essential tools: ['version']" in caplog.text
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_value,expected",
|
||||
[
|
||||
("", set()), # Empty string
|
||||
(" ", set()), # Only spaces
|
||||
(",,,", set()), # Only commas
|
||||
("chat", {"chat"}), # Single tool
|
||||
("chat,debug", {"chat", "debug"}), # Multiple tools
|
||||
("chat, debug, analyze", {"chat", "debug", "analyze"}), # With spaces
|
||||
("chat,debug,chat", {"chat", "debug"}), # Duplicates
|
||||
],
|
||||
)
|
||||
def test_parse_disabled_tools_parametrized(self, env_value, expected):
|
||||
"""Parametrized tests for various input formats."""
|
||||
with patch.dict(os.environ, {"DISABLED_TOOLS": env_value}):
|
||||
assert parse_disabled_tools_env() == expected
|
||||
@@ -483,14 +483,14 @@ class TestImageSupportIntegration:
|
||||
tool_name="chat",
|
||||
)
|
||||
|
||||
# Create child thread linked to parent
|
||||
child_thread_id = create_thread("debug", {"child": "context"}, parent_thread_id=parent_thread_id)
|
||||
# Create child thread linked to parent using a simple tool
|
||||
child_thread_id = create_thread("chat", {"prompt": "child context"}, parent_thread_id=parent_thread_id)
|
||||
add_turn(
|
||||
thread_id=child_thread_id,
|
||||
role="user",
|
||||
content="Child thread with more images",
|
||||
images=["child1.png", "shared.png"], # shared.png appears again (should prioritize newer)
|
||||
tool_name="debug",
|
||||
tool_name="chat",
|
||||
)
|
||||
|
||||
# Mock child thread context for get_thread call
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestModelEnumeration:
|
||||
("o3", False), # OpenAI - not available without API key
|
||||
("grok", False), # X.AI - not available without API key
|
||||
("gemini-2.5-flash", False), # Full Gemini name - not available without API key
|
||||
("o4-mini-high", False), # OpenAI variant - not available without API key
|
||||
("o4-mini", False), # OpenAI variant - not available without API key
|
||||
("grok-3-fast", False), # X.AI variant - not available without API key
|
||||
],
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ class TestModelMetadataContinuation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_turns_uses_last_assistant_model(self):
|
||||
"""Test that with multiple turns, the last assistant turn's model is used."""
|
||||
thread_id = create_thread("analyze", {"prompt": "analyze this"})
|
||||
thread_id = create_thread("chat", {"prompt": "analyze this"})
|
||||
|
||||
# Add multiple turns with different models
|
||||
add_turn(thread_id, "assistant", "First response", model_name="gemini-2.5-flash", model_provider="google")
|
||||
@@ -185,11 +185,11 @@ class TestModelMetadataContinuation:
|
||||
async def test_thread_chain_model_preservation(self):
|
||||
"""Test model preservation across thread chains (parent-child relationships)."""
|
||||
# Create parent thread
|
||||
parent_id = create_thread("analyze", {"prompt": "analyze"})
|
||||
parent_id = create_thread("chat", {"prompt": "analyze"})
|
||||
add_turn(parent_id, "assistant", "Analysis", model_name="gemini-2.5-pro", model_provider="google")
|
||||
|
||||
# Create child thread
|
||||
child_id = create_thread("codereview", {"prompt": "review"}, parent_thread_id=parent_id)
|
||||
# Create child thread using a simple tool instead of workflow tool
|
||||
child_id = create_thread("chat", {"prompt": "review"}, parent_thread_id=parent_id)
|
||||
|
||||
# Child thread should be able to access parent's model through chain traversal
|
||||
# NOTE: Current implementation only checks current thread (not parent threads)
|
||||
|
||||
@@ -93,7 +93,7 @@ class TestModelRestrictionService:
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
|
||||
service = ModelRestrictionService()
|
||||
|
||||
models = ["o3", "o3-mini", "o4-mini", "o4-mini-high"]
|
||||
models = ["o3", "o3-mini", "o4-mini", "o3-pro"]
|
||||
filtered = service.filter_models(ProviderType.OPENAI, models)
|
||||
|
||||
assert filtered == ["o3-mini", "o4-mini"]
|
||||
@@ -573,7 +573,7 @@ class TestShorthandRestrictions:
|
||||
|
||||
# Other models should not work
|
||||
assert not openai_provider.validate_model_name("o3")
|
||||
assert not openai_provider.validate_model_name("o4-mini-high")
|
||||
assert not openai_provider.validate_model_name("o3-pro")
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
|
||||
@@ -185,7 +185,7 @@ class TestO3TemperatureParameterFixSimple:
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
# Test O3/O4 models that should NOT support temperature parameter
|
||||
o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini", "o4-mini-high"]
|
||||
o3_o4_models = ["o3", "o3-mini", "o3-pro", "o4-mini"]
|
||||
|
||||
for model in o3_o4_models:
|
||||
capabilities = provider.get_capabilities(model)
|
||||
|
||||
@@ -47,14 +47,13 @@ class TestOpenAIProvider:
|
||||
assert provider.validate_model_name("o3-mini") is True
|
||||
assert provider.validate_model_name("o3-pro") is True
|
||||
assert provider.validate_model_name("o4-mini") is True
|
||||
assert provider.validate_model_name("o4-mini-high") is True
|
||||
assert provider.validate_model_name("o4-mini") is True
|
||||
|
||||
# Test valid aliases
|
||||
assert provider.validate_model_name("mini") is True
|
||||
assert provider.validate_model_name("o3mini") is True
|
||||
assert provider.validate_model_name("o4mini") is True
|
||||
assert provider.validate_model_name("o4minihigh") is True
|
||||
assert provider.validate_model_name("o4minihi") is True
|
||||
assert provider.validate_model_name("o4mini") is True
|
||||
|
||||
# Test invalid model
|
||||
assert provider.validate_model_name("invalid-model") is False
|
||||
@@ -69,15 +68,14 @@ class TestOpenAIProvider:
|
||||
assert provider._resolve_model_name("mini") == "o4-mini"
|
||||
assert provider._resolve_model_name("o3mini") == "o3-mini"
|
||||
assert provider._resolve_model_name("o4mini") == "o4-mini"
|
||||
assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
|
||||
assert provider._resolve_model_name("o4minihi") == "o4-mini-high"
|
||||
assert provider._resolve_model_name("o4mini") == "o4-mini"
|
||||
|
||||
# Test full name passthrough
|
||||
assert provider._resolve_model_name("o3") == "o3"
|
||||
assert provider._resolve_model_name("o3-mini") == "o3-mini"
|
||||
assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
|
||||
assert provider._resolve_model_name("o4-mini") == "o4-mini"
|
||||
assert provider._resolve_model_name("o4-mini-high") == "o4-mini-high"
|
||||
assert provider._resolve_model_name("o4-mini") == "o4-mini"
|
||||
|
||||
def test_get_capabilities_o3(self):
|
||||
"""Test getting model capabilities for O3."""
|
||||
@@ -85,7 +83,7 @@ class TestOpenAIProvider:
|
||||
|
||||
capabilities = provider.get_capabilities("o3")
|
||||
assert capabilities.model_name == "o3" # Should NOT be resolved in capabilities
|
||||
assert capabilities.friendly_name == "OpenAI"
|
||||
assert capabilities.friendly_name == "OpenAI (O3)"
|
||||
assert capabilities.context_window == 200_000
|
||||
assert capabilities.provider == ProviderType.OPENAI
|
||||
assert not capabilities.supports_extended_thinking
|
||||
@@ -101,8 +99,8 @@ class TestOpenAIProvider:
|
||||
provider = OpenAIModelProvider("test-key")
|
||||
|
||||
capabilities = provider.get_capabilities("mini")
|
||||
assert capabilities.model_name == "mini" # Capabilities should show original request
|
||||
assert capabilities.friendly_name == "OpenAI"
|
||||
assert capabilities.model_name == "o4-mini" # Capabilities should show resolved model name
|
||||
assert capabilities.friendly_name == "OpenAI (O4-mini)"
|
||||
assert capabilities.context_window == 200_000
|
||||
assert capabilities.provider == ProviderType.OPENAI
|
||||
|
||||
@@ -184,11 +182,11 @@ class TestOpenAIProvider:
|
||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||
assert call_kwargs["model"] == "o3-mini"
|
||||
|
||||
# Test o4minihigh -> o4-mini-high
|
||||
mock_response.model = "o4-mini-high"
|
||||
provider.generate_content(prompt="Test", model_name="o4minihigh", temperature=1.0)
|
||||
# Test o4mini -> o4-mini
|
||||
mock_response.model = "o4-mini"
|
||||
provider.generate_content(prompt="Test", model_name="o4mini", temperature=1.0)
|
||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||
assert call_kwargs["model"] == "o4-mini-high"
|
||||
assert call_kwargs["model"] == "o4-mini"
|
||||
|
||||
@patch("providers.openai_compatible.OpenAI")
|
||||
def test_generate_content_no_alias_passthrough(self, mock_openai_class):
|
||||
|
||||
@@ -57,7 +57,7 @@ class TestOpenRouterProvider:
|
||||
caps = provider.get_capabilities("o3")
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "openai/o3" # Resolved name
|
||||
assert caps.friendly_name == "OpenRouter"
|
||||
assert caps.friendly_name == "OpenRouter (openai/o3)"
|
||||
|
||||
# Test with a model not in registry - should get generic capabilities
|
||||
caps = provider.get_capabilities("unknown-model")
|
||||
@@ -77,7 +77,7 @@ class TestOpenRouterProvider:
|
||||
assert provider._resolve_model_name("o3-mini") == "openai/o3-mini"
|
||||
assert provider._resolve_model_name("o3mini") == "openai/o3-mini"
|
||||
assert provider._resolve_model_name("o4-mini") == "openai/o4-mini"
|
||||
assert provider._resolve_model_name("o4-mini-high") == "openai/o4-mini-high"
|
||||
assert provider._resolve_model_name("o4-mini") == "openai/o4-mini"
|
||||
assert provider._resolve_model_name("claude") == "anthropic/claude-sonnet-4"
|
||||
assert provider._resolve_model_name("mistral") == "mistralai/mistral-large-2411"
|
||||
assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-r1-0528"
|
||||
|
||||
@@ -6,8 +6,8 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
|
||||
from providers.base import ModelCapabilities, ProviderType
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||
|
||||
|
||||
class TestOpenRouterModelRegistry:
|
||||
@@ -24,7 +24,16 @@ class TestOpenRouterModelRegistry:
|
||||
def test_custom_config_path(self):
|
||||
"""Test registry with custom config path."""
|
||||
# Create temporary config
|
||||
config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}
|
||||
config_data = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "test/model-1",
|
||||
"aliases": ["test1", "t1"],
|
||||
"context_window": 4096,
|
||||
"max_output_tokens": 2048,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
@@ -42,7 +51,11 @@ class TestOpenRouterModelRegistry:
|
||||
def test_environment_variable_override(self):
|
||||
"""Test OPENROUTER_MODELS_PATH environment variable."""
|
||||
# Create custom config
|
||||
config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}
|
||||
config_data = {
|
||||
"models": [
|
||||
{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192, "max_output_tokens": 4096}
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
@@ -110,28 +123,29 @@ class TestOpenRouterModelRegistry:
|
||||
assert registry.resolve("non-existent") is None
|
||||
|
||||
def test_model_capabilities_conversion(self):
|
||||
"""Test conversion to ModelCapabilities."""
|
||||
"""Test that registry returns ModelCapabilities directly."""
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
config = registry.resolve("opus")
|
||||
assert config is not None
|
||||
|
||||
caps = config.to_capabilities()
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "anthropic/claude-opus-4"
|
||||
assert caps.friendly_name == "OpenRouter"
|
||||
assert caps.context_window == 200000
|
||||
assert not caps.supports_extended_thinking
|
||||
# Registry now returns ModelCapabilities objects directly
|
||||
assert config.provider == ProviderType.OPENROUTER
|
||||
assert config.model_name == "anthropic/claude-opus-4"
|
||||
assert config.friendly_name == "OpenRouter (anthropic/claude-opus-4)"
|
||||
assert config.context_window == 200000
|
||||
assert not config.supports_extended_thinking
|
||||
|
||||
def test_duplicate_alias_detection(self):
|
||||
"""Test that duplicate aliases are detected."""
|
||||
config_data = {
|
||||
"models": [
|
||||
{"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
|
||||
{"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096, "max_output_tokens": 2048},
|
||||
{
|
||||
"model_name": "test/model-2",
|
||||
"aliases": ["DUPE"], # Same alias, different case
|
||||
"context_window": 8192,
|
||||
"max_output_tokens": 2048,
|
||||
},
|
||||
]
|
||||
}
|
||||
@@ -199,19 +213,23 @@ class TestOpenRouterModelRegistry:
|
||||
|
||||
def test_model_with_all_capabilities(self):
|
||||
"""Test model with all capability flags."""
|
||||
config = OpenRouterModelConfig(
|
||||
from providers.base import create_temperature_constraint
|
||||
|
||||
caps = ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name="test/full-featured",
|
||||
friendly_name="OpenRouter (test/full-featured)",
|
||||
aliases=["full"],
|
||||
context_window=128000,
|
||||
max_output_tokens=8192,
|
||||
supports_extended_thinking=True,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
description="Fully featured test model",
|
||||
temperature_constraint=create_temperature_constraint("range"),
|
||||
)
|
||||
|
||||
caps = config.to_capabilities()
|
||||
assert caps.context_window == 128000
|
||||
assert caps.supports_extended_thinking
|
||||
assert caps.supports_system_prompts
|
||||
|
||||
79
tests/test_parse_model_option.py
Normal file
79
tests/test_parse_model_option.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Tests for parse_model_option function."""
|
||||
|
||||
from server import parse_model_option
|
||||
|
||||
|
||||
class TestParseModelOption:
|
||||
"""Test cases for model option parsing."""
|
||||
|
||||
def test_openrouter_free_suffix_preserved(self):
|
||||
"""Test that OpenRouter :free suffix is preserved as part of model name."""
|
||||
model, option = parse_model_option("openai/gpt-3.5-turbo:free")
|
||||
assert model == "openai/gpt-3.5-turbo:free"
|
||||
assert option is None
|
||||
|
||||
def test_openrouter_beta_suffix_preserved(self):
|
||||
"""Test that OpenRouter :beta suffix is preserved as part of model name."""
|
||||
model, option = parse_model_option("anthropic/claude-3-opus:beta")
|
||||
assert model == "anthropic/claude-3-opus:beta"
|
||||
assert option is None
|
||||
|
||||
def test_openrouter_preview_suffix_preserved(self):
|
||||
"""Test that OpenRouter :preview suffix is preserved as part of model name."""
|
||||
model, option = parse_model_option("google/gemini-pro:preview")
|
||||
assert model == "google/gemini-pro:preview"
|
||||
assert option is None
|
||||
|
||||
def test_ollama_tag_parsed_as_option(self):
|
||||
"""Test that Ollama tags are parsed as options."""
|
||||
model, option = parse_model_option("llama3.2:latest")
|
||||
assert model == "llama3.2"
|
||||
assert option == "latest"
|
||||
|
||||
def test_consensus_stance_parsed_as_option(self):
|
||||
"""Test that consensus stances are parsed as options."""
|
||||
model, option = parse_model_option("o3:for")
|
||||
assert model == "o3"
|
||||
assert option == "for"
|
||||
|
||||
model, option = parse_model_option("gemini-2.5-pro:against")
|
||||
assert model == "gemini-2.5-pro"
|
||||
assert option == "against"
|
||||
|
||||
def test_openrouter_unknown_suffix_parsed_as_option(self):
|
||||
"""Test that unknown suffixes on OpenRouter models are parsed as options."""
|
||||
model, option = parse_model_option("openai/gpt-4:custom-tag")
|
||||
assert model == "openai/gpt-4"
|
||||
assert option == "custom-tag"
|
||||
|
||||
def test_plain_model_name(self):
|
||||
"""Test plain model names without colons."""
|
||||
model, option = parse_model_option("gpt-4")
|
||||
assert model == "gpt-4"
|
||||
assert option is None
|
||||
|
||||
def test_url_not_parsed(self):
|
||||
"""Test that URLs are not parsed for options."""
|
||||
model, option = parse_model_option("http://localhost:8080")
|
||||
assert model == "http://localhost:8080"
|
||||
assert option is None
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
"""Test that whitespace is properly stripped."""
|
||||
model, option = parse_model_option(" openai/gpt-3.5-turbo:free ")
|
||||
assert model == "openai/gpt-3.5-turbo:free"
|
||||
assert option is None
|
||||
|
||||
model, option = parse_model_option(" llama3.2 : latest ")
|
||||
assert model == "llama3.2"
|
||||
assert option == "latest"
|
||||
|
||||
def test_case_insensitive_suffix_matching(self):
|
||||
"""Test that OpenRouter suffix matching is case-insensitive."""
|
||||
model, option = parse_model_option("openai/gpt-3.5-turbo:FREE")
|
||||
assert model == "openai/gpt-3.5-turbo:FREE" # Original case preserved
|
||||
assert option is None
|
||||
|
||||
model, option = parse_model_option("openai/gpt-3.5-turbo:Free")
|
||||
assert model == "openai/gpt-3.5-turbo:Free" # Original case preserved
|
||||
assert option is None
|
||||
@@ -58,7 +58,13 @@ class TestProviderRoutingBugs:
|
||||
"""
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
for key in [
|
||||
"GEMINI_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
@@ -66,6 +72,7 @@ class TestProviderRoutingBugs:
|
||||
os.environ.pop("GEMINI_API_KEY", None) # No Google API key
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ.pop("XAI_API_KEY", None)
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions
|
||||
os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
|
||||
|
||||
# Register only OpenRouter provider (like in server.py:configure_providers)
|
||||
@@ -113,12 +120,24 @@ class TestProviderRoutingBugs:
|
||||
"""
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
for key in [
|
||||
"GEMINI_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
# Set up scenario: NO API keys at all
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
for key in [
|
||||
"GEMINI_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
]:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Create tool to test fallback logic
|
||||
@@ -151,7 +170,13 @@ class TestProviderRoutingBugs:
|
||||
"""
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
for key in [
|
||||
"GEMINI_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
@@ -160,6 +185,7 @@ class TestProviderRoutingBugs:
|
||||
os.environ["OPENAI_API_KEY"] = "test-openai-key"
|
||||
os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
|
||||
os.environ.pop("XAI_API_KEY", None)
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions
|
||||
|
||||
# Register providers in priority order (like server.py)
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
@@ -215,9 +215,7 @@ class TestOpenAIProvider:
|
||||
assert provider.validate_model_name("o3-mini") # Backwards compatibility
|
||||
assert provider.validate_model_name("o4-mini")
|
||||
assert provider.validate_model_name("o4mini")
|
||||
assert provider.validate_model_name("o4-mini-high")
|
||||
assert provider.validate_model_name("o4minihigh")
|
||||
assert provider.validate_model_name("o4minihi")
|
||||
assert provider.validate_model_name("o4-mini")
|
||||
assert not provider.validate_model_name("gpt-4o")
|
||||
assert not provider.validate_model_name("invalid-model")
|
||||
|
||||
@@ -229,4 +227,4 @@ class TestOpenAIProvider:
|
||||
assert not provider.supports_thinking_mode("o3mini")
|
||||
assert not provider.supports_thinking_mode("o3-mini")
|
||||
assert not provider.supports_thinking_mode("o4-mini")
|
||||
assert not provider.supports_thinking_mode("o4-mini-high")
|
||||
assert not provider.supports_thinking_mode("o4-mini")
|
||||
|
||||
205
tests/test_supported_models_aliases.py
Normal file
205
tests/test_supported_models_aliases.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Test the SUPPORTED_MODELS aliases structure across all providers."""
|
||||
|
||||
from providers.dial import DIALModelProvider
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
|
||||
class TestSupportedModelsAliases:
|
||||
"""Test that all providers have correctly structured SUPPORTED_MODELS with aliases."""
|
||||
|
||||
def test_gemini_provider_aliases(self):
|
||||
"""Test Gemini provider's alias structure."""
|
||||
provider = GeminiModelProvider("test-key")
|
||||
|
||||
# Check that all models have ModelCapabilities with aliases
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
||||
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
||||
|
||||
# Test specific aliases
|
||||
assert "flash" in provider.SUPPORTED_MODELS["gemini-2.5-flash"].aliases
|
||||
assert "pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro"].aliases
|
||||
assert "flash-2.0" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
|
||||
assert "flash2" in provider.SUPPORTED_MODELS["gemini-2.0-flash"].aliases
|
||||
assert "flashlite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
|
||||
assert "flash-lite" in provider.SUPPORTED_MODELS["gemini-2.0-flash-lite"].aliases
|
||||
|
||||
# Test alias resolution
|
||||
assert provider._resolve_model_name("flash") == "gemini-2.5-flash"
|
||||
assert provider._resolve_model_name("pro") == "gemini-2.5-pro"
|
||||
assert provider._resolve_model_name("flash-2.0") == "gemini-2.0-flash"
|
||||
assert provider._resolve_model_name("flash2") == "gemini-2.0-flash"
|
||||
assert provider._resolve_model_name("flashlite") == "gemini-2.0-flash-lite"
|
||||
|
||||
# Test case insensitive resolution
|
||||
assert provider._resolve_model_name("Flash") == "gemini-2.5-flash"
|
||||
assert provider._resolve_model_name("PRO") == "gemini-2.5-pro"
|
||||
|
||||
def test_openai_provider_aliases(self):
|
||||
"""Test OpenAI provider's alias structure."""
|
||||
provider = OpenAIModelProvider("test-key")
|
||||
|
||||
# Check that all models have ModelCapabilities with aliases
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
||||
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
||||
|
||||
# Test specific aliases
|
||||
assert "mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
|
||||
assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
|
||||
assert "o3mini" in provider.SUPPORTED_MODELS["o3-mini"].aliases
|
||||
assert "o3-pro" in provider.SUPPORTED_MODELS["o3-pro-2025-06-10"].aliases
|
||||
assert "o4mini" in provider.SUPPORTED_MODELS["o4-mini"].aliases
|
||||
assert "gpt4.1" in provider.SUPPORTED_MODELS["gpt-4.1-2025-04-14"].aliases
|
||||
|
||||
# Test alias resolution
|
||||
assert provider._resolve_model_name("mini") == "o4-mini"
|
||||
assert provider._resolve_model_name("o3mini") == "o3-mini"
|
||||
assert provider._resolve_model_name("o3-pro") == "o3-pro-2025-06-10"
|
||||
assert provider._resolve_model_name("o4mini") == "o4-mini"
|
||||
assert provider._resolve_model_name("gpt4.1") == "gpt-4.1-2025-04-14"
|
||||
|
||||
# Test case insensitive resolution
|
||||
assert provider._resolve_model_name("Mini") == "o4-mini"
|
||||
assert provider._resolve_model_name("O3MINI") == "o3-mini"
|
||||
|
||||
def test_xai_provider_aliases(self):
|
||||
"""Test XAI provider's alias structure."""
|
||||
provider = XAIModelProvider("test-key")
|
||||
|
||||
# Check that all models have ModelCapabilities with aliases
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
||||
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
||||
|
||||
# Test specific aliases
|
||||
assert "grok" in provider.SUPPORTED_MODELS["grok-3"].aliases
|
||||
assert "grok3" in provider.SUPPORTED_MODELS["grok-3"].aliases
|
||||
assert "grok3fast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
|
||||
assert "grokfast" in provider.SUPPORTED_MODELS["grok-3-fast"].aliases
|
||||
|
||||
# Test alias resolution
|
||||
assert provider._resolve_model_name("grok") == "grok-3"
|
||||
assert provider._resolve_model_name("grok3") == "grok-3"
|
||||
assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
|
||||
assert provider._resolve_model_name("grokfast") == "grok-3-fast"
|
||||
|
||||
# Test case insensitive resolution
|
||||
assert provider._resolve_model_name("Grok") == "grok-3"
|
||||
assert provider._resolve_model_name("GROKFAST") == "grok-3-fast"
|
||||
|
||||
def test_dial_provider_aliases(self):
|
||||
"""Test DIAL provider's alias structure."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
# Check that all models have ModelCapabilities with aliases
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
assert hasattr(config, "aliases"), f"{model_name} must have aliases attribute"
|
||||
assert isinstance(config.aliases, list), f"{model_name} aliases must be a list"
|
||||
|
||||
# Test specific aliases
|
||||
assert "o3" in provider.SUPPORTED_MODELS["o3-2025-04-16"].aliases
|
||||
assert "o4-mini" in provider.SUPPORTED_MODELS["o4-mini-2025-04-16"].aliases
|
||||
assert "sonnet-4" in provider.SUPPORTED_MODELS["anthropic.claude-sonnet-4-20250514-v1:0"].aliases
|
||||
assert "opus-4" in provider.SUPPORTED_MODELS["anthropic.claude-opus-4-20250514-v1:0"].aliases
|
||||
assert "gemini-2.5-pro" in provider.SUPPORTED_MODELS["gemini-2.5-pro-preview-05-06"].aliases
|
||||
|
||||
# Test alias resolution
|
||||
assert provider._resolve_model_name("o3") == "o3-2025-04-16"
|
||||
assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
|
||||
assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"
|
||||
|
||||
# Test case insensitive resolution
|
||||
assert provider._resolve_model_name("O3") == "o3-2025-04-16"
|
||||
assert provider._resolve_model_name("SONNET-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
|
||||
def test_list_models_includes_aliases(self):
|
||||
"""Test that list_models returns both base models and aliases."""
|
||||
# Test Gemini
|
||||
gemini_provider = GeminiModelProvider("test-key")
|
||||
gemini_models = gemini_provider.list_models(respect_restrictions=False)
|
||||
assert "gemini-2.5-flash" in gemini_models
|
||||
assert "flash" in gemini_models
|
||||
assert "gemini-2.5-pro" in gemini_models
|
||||
assert "pro" in gemini_models
|
||||
|
||||
# Test OpenAI
|
||||
openai_provider = OpenAIModelProvider("test-key")
|
||||
openai_models = openai_provider.list_models(respect_restrictions=False)
|
||||
assert "o4-mini" in openai_models
|
||||
assert "mini" in openai_models
|
||||
assert "o3-mini" in openai_models
|
||||
assert "o3mini" in openai_models
|
||||
|
||||
# Test XAI
|
||||
xai_provider = XAIModelProvider("test-key")
|
||||
xai_models = xai_provider.list_models(respect_restrictions=False)
|
||||
assert "grok-3" in xai_models
|
||||
assert "grok" in xai_models
|
||||
assert "grok-3-fast" in xai_models
|
||||
assert "grokfast" in xai_models
|
||||
|
||||
# Test DIAL
|
||||
dial_provider = DIALModelProvider("test-key")
|
||||
dial_models = dial_provider.list_models(respect_restrictions=False)
|
||||
assert "o3-2025-04-16" in dial_models
|
||||
assert "o3" in dial_models
|
||||
|
||||
def test_list_all_known_models_includes_aliases(self):
|
||||
"""Test that list_all_known_models returns all models and aliases in lowercase."""
|
||||
# Test Gemini
|
||||
gemini_provider = GeminiModelProvider("test-key")
|
||||
gemini_all = gemini_provider.list_all_known_models()
|
||||
assert "gemini-2.5-flash" in gemini_all
|
||||
assert "flash" in gemini_all
|
||||
assert "gemini-2.5-pro" in gemini_all
|
||||
assert "pro" in gemini_all
|
||||
# All should be lowercase
|
||||
assert all(model == model.lower() for model in gemini_all)
|
||||
|
||||
# Test OpenAI
|
||||
openai_provider = OpenAIModelProvider("test-key")
|
||||
openai_all = openai_provider.list_all_known_models()
|
||||
assert "o4-mini" in openai_all
|
||||
assert "mini" in openai_all
|
||||
assert "o3-mini" in openai_all
|
||||
assert "o3mini" in openai_all
|
||||
# All should be lowercase
|
||||
assert all(model == model.lower() for model in openai_all)
|
||||
|
||||
def test_no_string_shorthand_in_supported_models(self):
|
||||
"""Test that no provider has string-based shorthands anymore."""
|
||||
providers = [
|
||||
GeminiModelProvider("test-key"),
|
||||
OpenAIModelProvider("test-key"),
|
||||
XAIModelProvider("test-key"),
|
||||
DIALModelProvider("test-key"),
|
||||
]
|
||||
|
||||
for provider in providers:
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
# All values must be ModelCapabilities objects, not strings or dicts
|
||||
from providers.base import ModelCapabilities
|
||||
|
||||
assert isinstance(config, ModelCapabilities), (
|
||||
f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
|
||||
f"must be a ModelCapabilities object, not {type(config).__name__}"
|
||||
)
|
||||
|
||||
def test_resolve_returns_original_if_not_found(self):
|
||||
"""Test that _resolve_model_name returns original name if alias not found."""
|
||||
providers = [
|
||||
GeminiModelProvider("test-key"),
|
||||
OpenAIModelProvider("test-key"),
|
||||
XAIModelProvider("test-key"),
|
||||
DIALModelProvider("test-key"),
|
||||
]
|
||||
|
||||
for provider in providers:
|
||||
# Test with unknown model name
|
||||
assert provider._resolve_model_name("unknown-model") == "unknown-model"
|
||||
assert provider._resolve_model_name("gpt-4") == "gpt-4"
|
||||
assert provider._resolve_model_name("claude-3") == "claude-3"
|
||||
@@ -48,7 +48,13 @@ class TestWorkflowMetadata:
|
||||
"""
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
for key in [
|
||||
"GEMINI_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
@@ -56,6 +62,7 @@ class TestWorkflowMetadata:
|
||||
os.environ.pop("GEMINI_API_KEY", None)
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ.pop("XAI_API_KEY", None)
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions
|
||||
os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
|
||||
|
||||
# Register OpenRouter provider
|
||||
@@ -124,7 +131,13 @@ class TestWorkflowMetadata:
|
||||
"""
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
for key in [
|
||||
"GEMINI_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
@@ -132,6 +145,7 @@ class TestWorkflowMetadata:
|
||||
os.environ.pop("GEMINI_API_KEY", None)
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ.pop("XAI_API_KEY", None)
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions
|
||||
os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
|
||||
|
||||
# Register OpenRouter provider
|
||||
@@ -182,6 +196,15 @@ class TestWorkflowMetadata:
|
||||
"""
|
||||
Test that workflow tools handle metadata gracefully when model context is missing.
|
||||
"""
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["OPENROUTER_ALLOWED_MODELS"]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
# Clear any restrictions
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
|
||||
|
||||
# Create debug tool
|
||||
debug_tool = DebugIssueTool()
|
||||
|
||||
@@ -220,6 +243,14 @@ class TestWorkflowMetadata:
|
||||
assert metadata["model_used"] == "flash", "model_used should be from request"
|
||||
assert metadata["provider_used"] == "unknown", "provider_used should be 'unknown' in fallback"
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
for key, value in original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
def test_workflow_metadata_preserves_existing_response_fields(self):
|
||||
"""
|
||||
@@ -227,7 +258,13 @@ class TestWorkflowMetadata:
|
||||
"""
|
||||
# Save original environment
|
||||
original_env = {}
|
||||
for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
|
||||
for key in [
|
||||
"GEMINI_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
]:
|
||||
original_env[key] = os.environ.get(key)
|
||||
|
||||
try:
|
||||
@@ -235,6 +272,7 @@ class TestWorkflowMetadata:
|
||||
os.environ.pop("GEMINI_API_KEY", None)
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ.pop("XAI_API_KEY", None)
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None) # Clear any restrictions
|
||||
os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
|
||||
|
||||
# Register OpenRouter provider
|
||||
|
||||
@@ -77,7 +77,7 @@ class TestXAIProvider:
|
||||
|
||||
capabilities = provider.get_capabilities("grok-3")
|
||||
assert capabilities.model_name == "grok-3"
|
||||
assert capabilities.friendly_name == "X.AI"
|
||||
assert capabilities.friendly_name == "X.AI (Grok 3)"
|
||||
assert capabilities.context_window == 131_072
|
||||
assert capabilities.provider == ProviderType.XAI
|
||||
assert not capabilities.supports_extended_thinking
|
||||
@@ -96,7 +96,7 @@ class TestXAIProvider:
|
||||
|
||||
capabilities = provider.get_capabilities("grok-3-fast")
|
||||
assert capabilities.model_name == "grok-3-fast"
|
||||
assert capabilities.friendly_name == "X.AI"
|
||||
assert capabilities.friendly_name == "X.AI (Grok 3 Fast)"
|
||||
assert capabilities.context_window == 131_072
|
||||
assert capabilities.provider == ProviderType.XAI
|
||||
assert not capabilities.supports_extended_thinking
|
||||
@@ -212,31 +212,34 @@ class TestXAIProvider:
|
||||
assert provider.FRIENDLY_NAME == "X.AI"
|
||||
|
||||
capabilities = provider.get_capabilities("grok-3")
|
||||
assert capabilities.friendly_name == "X.AI"
|
||||
assert capabilities.friendly_name == "X.AI (Grok 3)"
|
||||
|
||||
def test_supported_models_structure(self):
|
||||
"""Test that SUPPORTED_MODELS has the correct structure."""
|
||||
provider = XAIModelProvider("test-key")
|
||||
|
||||
# Check that all expected models are present
|
||||
# Check that all expected base models are present
|
||||
assert "grok-3" in provider.SUPPORTED_MODELS
|
||||
assert "grok-3-fast" in provider.SUPPORTED_MODELS
|
||||
assert "grok" in provider.SUPPORTED_MODELS
|
||||
assert "grok3" in provider.SUPPORTED_MODELS
|
||||
assert "grokfast" in provider.SUPPORTED_MODELS
|
||||
assert "grok3fast" in provider.SUPPORTED_MODELS
|
||||
|
||||
# Check model configs have required fields
|
||||
grok3_config = provider.SUPPORTED_MODELS["grok-3"]
|
||||
assert isinstance(grok3_config, dict)
|
||||
assert "context_window" in grok3_config
|
||||
assert "supports_extended_thinking" in grok3_config
|
||||
assert grok3_config["context_window"] == 131_072
|
||||
assert grok3_config["supports_extended_thinking"] is False
|
||||
from providers.base import ModelCapabilities
|
||||
|
||||
# Check shortcuts point to full names
|
||||
assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
|
||||
assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"
|
||||
grok3_config = provider.SUPPORTED_MODELS["grok-3"]
|
||||
assert isinstance(grok3_config, ModelCapabilities)
|
||||
assert hasattr(grok3_config, "context_window")
|
||||
assert hasattr(grok3_config, "supports_extended_thinking")
|
||||
assert hasattr(grok3_config, "aliases")
|
||||
assert grok3_config.context_window == 131_072
|
||||
assert grok3_config.supports_extended_thinking is False
|
||||
|
||||
# Check aliases are correctly structured
|
||||
assert "grok" in grok3_config.aliases
|
||||
assert "grok3" in grok3_config.aliases
|
||||
|
||||
grok3fast_config = provider.SUPPORTED_MODELS["grok-3-fast"]
|
||||
assert "grok3fast" in grok3fast_config.aliases
|
||||
assert "grokfast" in grok3fast_config.aliases
|
||||
|
||||
@patch("providers.openai_compatible.OpenAI")
|
||||
def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
|
||||
|
||||
@@ -99,15 +99,11 @@ class ListModelsTool(BaseTool):
|
||||
output_lines.append("**Status**: Configured and available")
|
||||
output_lines.append("\n**Models**:")
|
||||
|
||||
# Get models from the provider's SUPPORTED_MODELS
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
# Skip alias entries (string values)
|
||||
if isinstance(config, str):
|
||||
continue
|
||||
|
||||
# Get description and context from the model config
|
||||
description = config.get("description", "No description available")
|
||||
context_window = config.get("context_window", 0)
|
||||
# Get models from the provider's model configurations
|
||||
for model_name, capabilities in provider.get_model_configurations().items():
|
||||
# Get description and context from the ModelCapabilities object
|
||||
description = capabilities.description or "No description available"
|
||||
context_window = capabilities.context_window
|
||||
|
||||
# Format context window
|
||||
if context_window >= 1_000_000:
|
||||
@@ -133,13 +129,14 @@ class ListModelsTool(BaseTool):
|
||||
|
||||
# Show aliases for this provider
|
||||
aliases = []
|
||||
for alias_name, target in provider.SUPPORTED_MODELS.items():
|
||||
if isinstance(target, str): # This is an alias
|
||||
aliases.append(f"- `{alias_name}` → `{target}`")
|
||||
for model_name, capabilities in provider.get_model_configurations().items():
|
||||
if capabilities.aliases:
|
||||
for alias in capabilities.aliases:
|
||||
aliases.append(f"- `{alias}` → `{model_name}`")
|
||||
|
||||
if aliases:
|
||||
output_lines.append("\n**Aliases**:")
|
||||
output_lines.extend(aliases)
|
||||
output_lines.extend(sorted(aliases)) # Sort for consistent output
|
||||
else:
|
||||
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
|
||||
|
||||
@@ -237,7 +234,7 @@ class ListModelsTool(BaseTool):
|
||||
|
||||
for alias in registry.list_aliases():
|
||||
config = registry.resolve(alias)
|
||||
if config and hasattr(config, "is_custom") and config.is_custom:
|
||||
if config and config.is_custom:
|
||||
custom_models.append((alias, config))
|
||||
|
||||
if custom_models:
|
||||
|
||||
@@ -256,8 +256,8 @@ class BaseTool(ABC):
|
||||
# Find all custom models (is_custom=true)
|
||||
for alias in registry.list_aliases():
|
||||
config = registry.resolve(alias)
|
||||
# Use hasattr for defensive programming - is_custom is optional with default False
|
||||
if config and hasattr(config, "is_custom") and config.is_custom:
|
||||
# Check if this is a custom model that requires custom endpoints
|
||||
if config and config.is_custom:
|
||||
if alias not in all_models:
|
||||
all_models.append(alias)
|
||||
except Exception as e:
|
||||
@@ -311,12 +311,16 @@ class BaseTool(ABC):
|
||||
ProviderType.GOOGLE: "Gemini models",
|
||||
ProviderType.OPENAI: "OpenAI models",
|
||||
ProviderType.XAI: "X.AI GROK models",
|
||||
ProviderType.DIAL: "DIAL models",
|
||||
ProviderType.CUSTOM: "Custom models",
|
||||
ProviderType.OPENROUTER: "OpenRouter models",
|
||||
}
|
||||
|
||||
# Check available providers and add their model descriptions
|
||||
for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI]:
|
||||
|
||||
# Start with native providers
|
||||
for provider_type in [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]:
|
||||
# Only if this is registered / available
|
||||
provider = ModelProviderRegistry.get_provider(provider_type)
|
||||
if provider:
|
||||
provider_section_added = False
|
||||
@@ -324,13 +328,13 @@ class BaseTool(ABC):
|
||||
try:
|
||||
# Get model config to extract description
|
||||
model_config = provider.SUPPORTED_MODELS.get(model_name)
|
||||
if isinstance(model_config, dict) and "description" in model_config:
|
||||
if model_config and model_config.description:
|
||||
if not provider_section_added:
|
||||
model_desc_parts.append(
|
||||
f"\n{provider_names[provider_type]} - Available when {provider_type.value.upper()}_API_KEY is configured:"
|
||||
)
|
||||
provider_section_added = True
|
||||
model_desc_parts.append(f"- '{model_name}': {model_config['description']}")
|
||||
model_desc_parts.append(f"- '{model_name}': {model_config.description}")
|
||||
except Exception:
|
||||
# Skip models without descriptions
|
||||
continue
|
||||
@@ -346,8 +350,8 @@ class BaseTool(ABC):
|
||||
# Find all custom models (is_custom=true)
|
||||
for alias in registry.list_aliases():
|
||||
config = registry.resolve(alias)
|
||||
# Use hasattr for defensive programming - is_custom is optional with default False
|
||||
if config and hasattr(config, "is_custom") and config.is_custom:
|
||||
# Check if this is a custom model that requires custom endpoints
|
||||
if config and config.is_custom:
|
||||
# Format context window
|
||||
context_tokens = config.context_window
|
||||
if context_tokens >= 1_000_000:
|
||||
|
||||
@@ -128,6 +128,10 @@ class ModelRestrictionService:
|
||||
|
||||
allowed_set = self.restrictions[provider_type]
|
||||
|
||||
if len(allowed_set) == 0:
|
||||
# Empty set - allowed
|
||||
return True
|
||||
|
||||
# Check both the resolved name and original name (if different)
|
||||
names_to_check = {model_name.lower()}
|
||||
if original_name and original_name.lower() != model_name.lower():
|
||||
|
||||
Reference in New Issue
Block a user