Native support for xAI Grok-3
Model shorthand mapping related fixes.
Comprehensive auto-mode related tests.
.env.example (13 changes)
@@ -18,6 +18,9 @@ GEMINI_API_KEY=your_gemini_api_key_here
 # Get your OpenAI API key from: https://platform.openai.com/api-keys
 OPENAI_API_KEY=your_openai_api_key_here
 
+# Get your X.AI API key from: https://console.x.ai/
+XAI_API_KEY=your_xai_api_key_here
+
 # Option 2: Use OpenRouter for access to multiple models through one API
 # Get your OpenRouter API key from: https://openrouter.ai/
 # If using OpenRouter, comment out the native API keys above
@@ -68,15 +71,25 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
 # - flash (shorthand for gemini-2.5-flash-preview-05-20)
 # - pro (shorthand for gemini-2.5-pro-preview-06-05)
 #
+# Supported X.AI GROK models:
+# - grok-3 (131K context, advanced reasoning)
+# - grok-3-fast (131K context, higher performance but more expensive)
+# - grok (shorthand for grok-3)
+# - grok3 (shorthand for grok-3)
+# - grokfast (shorthand for grok-3-fast)
+#
 # Examples:
 # OPENAI_ALLOWED_MODELS=o3-mini,o4-mini,mini # Only allow mini models (cost control)
 # GOOGLE_ALLOWED_MODELS=flash # Only allow Flash (fast responses)
+# XAI_ALLOWED_MODELS=grok-3 # Only allow standard GROK (not fast variant)
 # OPENAI_ALLOWED_MODELS=o4-mini # Single model standardization
 # GOOGLE_ALLOWED_MODELS=flash,pro # Allow both Gemini models
+# XAI_ALLOWED_MODELS=grok,grok-3-fast # Allow both GROK variants
 #
 # Note: These restrictions apply even in 'auto' mode - Claude will only pick from allowed models
 # OPENAI_ALLOWED_MODELS=
 # GOOGLE_ALLOWED_MODELS=
+# XAI_ALLOWED_MODELS=
 
 # Optional: Custom model configuration file path
 # Override the default location of custom_models.json
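
For reference, the allow-list variables above are plain comma-separated lists. A minimal sketch of how such a list is typically parsed — the helper below is hypothetical; the real logic lives in the restriction service under `utils/model_restrictions.py`:

```python
import os

def parse_allowed_models(env_var: str) -> set[str]:
    """Hypothetical parser for allow-lists like XAI_ALLOWED_MODELS=grok,grok-3-fast.
    An unset or empty variable means 'no restriction'."""
    raw = os.environ.get(env_var, "")
    return {name.strip().lower() for name in raw.split(",") if name.strip()}

# With XAI_ALLOWED_MODELS="grok,grok-3-fast":
# parse_allowed_models("XAI_ALLOWED_MODELS") == {"grok", "grok-3-fast"}
```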
README.md
@@ -3,7 +3,7 @@
 https://github.com/user-attachments/assets/8097e18e-b926-4d8b-ba14-a979e4c58bda
 
 <div align="center">
-<b>🤖 Claude + [Gemini / O3 / OpenRouter / Ollama / Any Model] = Your Ultimate AI Development Team</b>
+<b>🤖 Claude + [Gemini / O3 / GROK / OpenRouter / Ollama / Any Model] = Your Ultimate AI Development Team</b>
 </div>
 
 <br/>
@@ -115,6 +115,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performance
 **Option B: Native APIs**
 - **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models.
 - **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) to get an API key for O3 model access.
+- **X.AI**: Visit [X.AI Console](https://console.x.ai/) to get an API key for GROK model access.
 
 **Option C: Custom API Endpoints (Local models like Ollama, vLLM)**
 [Please see the setup guide](docs/custom_models.md#option-2-custom-api-setup-ollama-vllm-etc). With a custom API you can use:
config.py
@@ -14,7 +14,7 @@ import os
 # These values are used in server responses and for tracking releases
 # IMPORTANT: This is the single source of truth for version and author info
 # Semantic versioning: MAJOR.MINOR.PATCH
-__version__ = "4.5.1"
+__version__ = "4.6.0"
 # Last update date in ISO format
 __updated__ = "2025-06-15"
 # Primary maintainer
@@ -53,6 +53,12 @@ MODEL_CAPABILITIES_DESC = {
     "o3-pro": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
     "o4-mini": "Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
     "o4-mini-high": "Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
+    # X.AI GROK models - Available when XAI_API_KEY is configured
+    "grok": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
+    "grok-3": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
+    "grok-3-fast": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
+    "grok3": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
+    "grokfast": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
     # Full model names also supported (for explicit specification)
     "gemini-2.5-flash-preview-05-20": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
     "gemini-2.5-pro-preview-06-05": (
docker-compose.yml
@@ -31,6 +31,7 @@ services:
     environment:
       - GEMINI_API_KEY=${GEMINI_API_KEY:-}
       - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      - XAI_API_KEY=${XAI_API_KEY:-}
       # OpenRouter support
       - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
       - CUSTOM_MODELS_CONFIG_PATH=${CUSTOM_MODELS_CONFIG_PATH:-}
@@ -45,6 +46,7 @@ services:
       # Model usage restrictions
       - OPENAI_ALLOWED_MODELS=${OPENAI_ALLOWED_MODELS:-}
       - GOOGLE_ALLOWED_MODELS=${GOOGLE_ALLOWED_MODELS:-}
+      - XAI_ALLOWED_MODELS=${XAI_ALLOWED_MODELS:-}
       - REDIS_URL=redis://redis:6379/0
       # Use HOME not PWD: Claude needs access to any absolute file path, not just current project,
       # and Claude Code could be running from multiple locations at the same time
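
The `${XAI_API_KEY:-}` form is standard Compose variable substitution: it expands to the host's value when set and to an empty string otherwise. On the Python side, `configure_providers()` then treats an empty or placeholder value as "provider not configured" — a small sketch mirroring the check added in `server.py` below:

```python
import os

# Mirrors the registration guard added in configure_providers()
xai_key = os.getenv("XAI_API_KEY")
if xai_key and xai_key != "your_xai_api_key_here":
    print("X.AI provider would be registered")
else:
    print("X.AI provider skipped (unset, empty, or placeholder)")
```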
(provider development guide)
@@ -23,9 +23,11 @@ Inherit from `ModelProvider` when:
 ### Option B: OpenAI-Compatible Provider (Simplified)
 Inherit from `OpenAICompatibleProvider` when:
 - Your API follows OpenAI's chat completion format
-- You want to reuse existing implementation for `generate_content` and `count_tokens`
+- You want to reuse existing implementation for most functionality
 - You only need to define model capabilities and validation
 
+⚠️ **CRITICAL**: If your provider has model aliases (shorthands), you **MUST** override `generate_content()` to resolve aliases before API calls. See implementation example below.
+
 ## Step-by-Step Guide
 
 ### 1. Add Provider Type to Enum
@@ -177,8 +179,11 @@ For providers with OpenAI-compatible APIs, the implementation is much simpler:
 """Example provider using OpenAI-compatible interface."""
 
 import logging
+from typing import Optional
 
 from .base import (
     ModelCapabilities,
+    ModelResponse,
     ProviderType,
     RangeTemperatureConstraint,
 )
@@ -268,7 +273,31 @@ class ExampleProvider(OpenAICompatibleProvider):
                 return shorthand_value
         return model_name
 
-    # Note: generate_content and count_tokens are inherited from OpenAICompatibleProvider
+    def generate_content(
+        self,
+        prompt: str,
+        model_name: str,
+        system_prompt: Optional[str] = None,
+        temperature: float = 0.7,
+        max_output_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> ModelResponse:
+        """Generate content using API with proper model name resolution."""
+        # CRITICAL: Resolve model alias before making API call
+        # This ensures aliases like "large" get sent as "example-model-large" to the API
+        resolved_model_name = self._resolve_model_name(model_name)
+
+        # Call parent implementation with resolved model name
+        return super().generate_content(
+            prompt=prompt,
+            model_name=resolved_model_name,
+            system_prompt=system_prompt,
+            temperature=temperature,
+            max_output_tokens=max_output_tokens,
+            **kwargs,
+        )
+
+    # Note: count_tokens is inherited from OpenAICompatibleProvider
 ```
 
 ### 3. Update Registry Configuration
@@ -291,7 +320,32 @@ def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
     # ... rest of the method
 ```
 
-### 4. Register Provider in server.py
+### 4. Configure Docker Environment Variables
+
+**CRITICAL**: You must add your provider's environment variables to `docker-compose.yml` for them to be available in the Docker container.
+
+Add your API key and restriction variables to the `environment` section:
+
+```yaml
+services:
+  zen-mcp:
+    # ... other configuration ...
+    environment:
+      - GEMINI_API_KEY=${GEMINI_API_KEY:-}
+      - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      - EXAMPLE_API_KEY=${EXAMPLE_API_KEY:-}  # Add this line
+      # OpenRouter support
+      - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
+      # ... other variables ...
+      # Model usage restrictions
+      - OPENAI_ALLOWED_MODELS=${OPENAI_ALLOWED_MODELS:-}
+      - GOOGLE_ALLOWED_MODELS=${GOOGLE_ALLOWED_MODELS:-}
+      - EXAMPLE_ALLOWED_MODELS=${EXAMPLE_ALLOWED_MODELS:-}  # Add this line
+```
+
+⚠️ **Without this step**, the Docker container won't have access to your environment variables, and your provider won't be registered even if the API key is set in your `.env` file.
+
+### 5. Register Provider in server.py
 
 The `configure_providers()` function in `server.py` handles provider registration. You need to:
 
@@ -355,7 +409,7 @@ def configure_providers():
 )
 ```
 
-### 5. Add Model Capabilities for Auto Mode
+### 6. Add Model Capabilities for Auto Mode
 
 Update `config.py` to add your models to `MODEL_CAPABILITIES_DESC`:
 
@@ -372,9 +426,9 @@ MODEL_CAPABILITIES_DESC = {
 }
 ```
 
-### 6. Update Documentation
+### 7. Update Documentation
 
-#### 6.1. Update README.md
+#### 7.1. Update README.md
 
 Add your provider to the quickstart section:
 
@@ -396,9 +450,9 @@ Also update the .env file example:
 # EXAMPLE_API_KEY=your-example-api-key-here # Add this
 ```
 
-### 7. Write Tests
+### 8. Write Tests
 
-#### 7.1. Unit Tests
+#### 8.1. Unit Tests
 
 Create `tests/test_example_provider.py`:
 
@@ -460,7 +514,7 @@ class TestExampleProvider:
 assert capabilities.temperature_constraint.max_temp == 2.0
 ```
 
-#### 7.2. Simulator Tests (Real-World Validation)
+#### 8.2. Simulator Tests (Real-World Validation)
 
 Create a simulator test to validate that your provider works correctly in real-world scenarios. Create `simulator_tests/test_example_models.py`:
 
@@ -696,6 +750,36 @@ SUPPORTED_MODELS = {
 
 The `_resolve_model_name()` method handles this mapping automatically.
 
+## Critical Implementation Requirements
+
+### Alias Resolution for OpenAI-Compatible Providers
+
+If you inherit from `OpenAICompatibleProvider` and define model aliases, you **MUST** override `generate_content()` to resolve aliases before API calls. This is because:
+
+1. **The base `OpenAICompatibleProvider.generate_content()`** sends the original model name directly to the API
+2. **Your API expects the full model name**, not the alias
+3. **Without resolution**, requests like `model="large"` will fail with 404/400 errors
+
+**Examples of providers that need this:**
+- XAI provider: `"grok"` → `"grok-3"`
+- OpenAI provider: `"mini"` → `"o4-mini"`
+- Custom provider: `"fast"` → `"llama-3.1-8b-instruct"`
+
+**Example implementation pattern:**
+```python
+def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
+    # CRITICAL: Resolve alias before API call
+    resolved_model_name = self._resolve_model_name(model_name)
+
+    # Pass resolved name to parent
+    return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
+```
+
+**Providers that DON'T need this:**
+- Gemini provider (has its own generate_content implementation)
+- OpenRouter provider (already implements this pattern)
+- Providers without aliases
+
 ## Best Practices
 
 1. **Always validate model names** against supported models and restrictions
@@ -715,6 +799,7 @@ Before submitting your PR:
 - [ ] Provider implementation complete with all required methods
 - [ ] API key mapping added to `_get_api_key_for_provider()` in `providers/registry.py`
 - [ ] Provider added to `PROVIDER_PRIORITY_ORDER` in `registry.py` (if native provider)
+- [ ] **Environment variables added to `docker-compose.yml`** (API key and restrictions)
 - [ ] Provider imported and registered in `server.py`'s `configure_providers()`
 - [ ] API key checking added to `configure_providers()` function
 - [ ] Error message updated to include new provider
providers/base.py
@@ -11,6 +11,7 @@ class ProviderType(Enum):
 
     GOOGLE = "google"
     OPENAI = "openai"
+    XAI = "xai"
     OPENROUTER = "openrouter"
     CUSTOM = "custom"
providers/openai.py
@@ -1,10 +1,12 @@
 """OpenAI model provider implementation."""
 
 import logging
+from typing import Optional
 
 from .base import (
     FixedTemperatureConstraint,
     ModelCapabilities,
+    ModelResponse,
     ProviderType,
     RangeTemperatureConstraint,
 )
@@ -111,6 +113,29 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
 
         return True
 
+    def generate_content(
+        self,
+        prompt: str,
+        model_name: str,
+        system_prompt: Optional[str] = None,
+        temperature: float = 0.7,
+        max_output_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> ModelResponse:
+        """Generate content using OpenAI API with proper model name resolution."""
+        # Resolve model alias before making API call
+        resolved_model_name = self._resolve_model_name(model_name)
+
+        # Call parent implementation with resolved model name
+        return super().generate_content(
+            prompt=prompt,
+            model_name=resolved_model_name,
+            system_prompt=system_prompt,
+            temperature=temperature,
+            max_output_tokens=max_output_tokens,
+            **kwargs,
+        )
+
     def supports_thinking_mode(self, model_name: str) -> bool:
         """Check if the model supports extended thinking mode."""
         # Currently no OpenAI models support extended thinking
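
The override matters because the inherited `OpenAICompatibleProvider.generate_content()` sends the model name verbatim to the API; a shorthand such as `"mini"` (which the provider guide above maps to `"o4-mini"`) would otherwise produce a 404/400. A minimal standalone sketch of the resolution step:

```python
def resolve(model_name: str, aliases: dict[str, str]) -> str:
    """Minimal stand-in for _resolve_model_name(): map shorthand -> full name."""
    return aliases.get(model_name, model_name)

# Alias taken from the provider guide above; the full OpenAI alias table may differ
print(resolve("mini", {"mini": "o4-mini"}))     # o4-mini
print(resolve("o4-mini", {"mini": "o4-mini"}))  # unchanged: already a full name
```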
providers/registry.py
@@ -117,6 +117,7 @@ class ModelProviderRegistry:
     PROVIDER_PRIORITY_ORDER = [
         ProviderType.GOOGLE,  # Direct Gemini access
         ProviderType.OPENAI,  # Direct OpenAI access
+        ProviderType.XAI,  # Direct X.AI GROK access
         ProviderType.CUSTOM,  # Local/self-hosted models
         ProviderType.OPENROUTER,  # Catch-all for cloud models
     ]
@@ -173,15 +174,21 @@ class ModelProviderRegistry:
         # Get supported models based on provider type
         if hasattr(provider, "SUPPORTED_MODELS"):
             for model_name, config in provider.SUPPORTED_MODELS.items():
-                # Skip aliases (string values)
+                # Handle both base models (dict configs) and aliases (string values)
                 if isinstance(config, str):
+                    # This is an alias - check if the target model would be allowed
+                    target_model = config
+                    if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
+                        logging.debug(f"Alias {model_name} -> {target_model} filtered by restrictions")
                         continue
+                    # Allow the alias
+                    models[model_name] = provider_type
+                else:
+                    # This is a base model with config dict
                     # Check restrictions if enabled
                     if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
                         logging.debug(f"Model {model_name} filtered by restrictions")
                         continue
 
                     models[model_name] = provider_type
         elif provider_type == ProviderType.OPENROUTER:
             # OpenRouter uses a registry system instead of SUPPORTED_MODELS
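
To make the new alias handling concrete: `SUPPORTED_MODELS` mixes dict-valued base models with string-valued aliases, and an alias now survives restriction filtering only if its *target* does. A self-contained sketch of that rule (`is_allowed` stands in for the restriction service):

```python
SUPPORTED_MODELS = {
    "grok-3": {"context_window": 131_072},  # base model: dict config
    "grok": "grok-3",                       # alias: string target
}

def available_models(is_allowed) -> list[str]:
    models = []
    for name, config in SUPPORTED_MODELS.items():
        # For aliases, the restriction check applies to the target model
        target = config if isinstance(config, str) else name
        if is_allowed(target):
            models.append(name)
    return models

# XAI_ALLOWED_MODELS=grok-3 keeps both the base model and its alias:
print(available_models(lambda m: m == "grok-3"))  # ['grok-3', 'grok']
```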
@@ -230,6 +237,7 @@ class ModelProviderRegistry:
         key_mapping = {
             ProviderType.GOOGLE: "GEMINI_API_KEY",
             ProviderType.OPENAI: "OPENAI_API_KEY",
+            ProviderType.XAI: "XAI_API_KEY",
             ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
             ProviderType.CUSTOM: "CUSTOM_API_KEY",  # Can be empty for providers that don't need auth
         }
@@ -264,9 +272,13 @@ class ModelProviderRegistry:
         # Group by provider
         openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI]
         gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
+        xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI]
+        openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER]
 
         openai_available = bool(openai_models)
         gemini_available = bool(gemini_models)
+        xai_available = bool(xai_models)
+        openrouter_available = bool(openrouter_models)
 
         if tool_category == ToolModelCategory.EXTENDED_REASONING:
             # Prefer thinking-capable models for deep reasoning tools
@@ -275,17 +287,25 @@ class ModelProviderRegistry:
         elif openai_available and openai_models:
             # Fall back to any available OpenAI model
             return openai_models[0]
+        elif xai_available and "grok-3" in xai_models:
+            return "grok-3"  # GROK-3 for deep reasoning
+        elif xai_available and xai_models:
+            # Fall back to any available XAI model
+            return xai_models[0]
         elif gemini_available and any("pro" in m for m in gemini_models):
             # Find the pro model (handles full names)
             return next(m for m in gemini_models if "pro" in m)
         elif gemini_available and gemini_models:
             # Fall back to any available Gemini model
             return gemini_models[0]
-        else:
-            # Try to find thinking-capable model from custom/openrouter
+        elif openrouter_available:
+            # Try to find thinking-capable model from openrouter
             thinking_model = cls._find_extended_thinking_model()
             if thinking_model:
                 return thinking_model
+            # Fallback to first available OpenRouter model
+            return openrouter_models[0]
+        else:
             # Fallback to pro if nothing found
             return "gemini-2.5-pro-preview-06-05"
@@ -298,12 +318,20 @@ class ModelProviderRegistry:
         elif openai_available and openai_models:
             # Fall back to any available OpenAI model
             return openai_models[0]
+        elif xai_available and "grok-3-fast" in xai_models:
+            return "grok-3-fast"  # GROK-3 Fast for speed
+        elif xai_available and xai_models:
+            # Fall back to any available XAI model
+            return xai_models[0]
         elif gemini_available and any("flash" in m for m in gemini_models):
             # Find the flash model (handles full names)
             return next(m for m in gemini_models if "flash" in m)
         elif gemini_available and gemini_models:
             # Fall back to any available Gemini model
             return gemini_models[0]
+        elif openrouter_available:
+            # Fallback to first available OpenRouter model
+            return openrouter_models[0]
         else:
             # Default to flash
             return "gemini-2.5-flash-preview-05-20"
@@ -315,10 +343,16 @@ class ModelProviderRegistry:
             return "o3-mini"  # Second choice
         elif openai_available and openai_models:
             return openai_models[0]
+        elif xai_available and "grok-3" in xai_models:
+            return "grok-3"  # GROK-3 as balanced choice
+        elif xai_available and xai_models:
+            return xai_models[0]
         elif gemini_available and any("flash" in m for m in gemini_models):
             return next(m for m in gemini_models if "flash" in m)
         elif gemini_available and gemini_models:
             return gemini_models[0]
+        elif openrouter_available:
+            return openrouter_models[0]
         else:
             # No models available due to restrictions - check if any providers exist
             if not available_models:
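
The three hunks above extend each tool-category cascade the same way: native X.AI slots in after OpenAI and before Gemini, and OpenRouter becomes an explicit fallback before the hard-coded default. A simplified sketch of the resulting fast-response order (the real method also consults restrictions and specific model preferences):

```python
def pick_fast_response_model(openai, xai, gemini, openrouter):
    """Simplified view of the fast-response fallback order after this change."""
    if openai:
        return openai[0]
    if "grok-3-fast" in xai:
        return "grok-3-fast"
    if xai:
        return xai[0]
    flash = [m for m in gemini if "flash" in m]
    if flash:
        return flash[0]
    if gemini:
        return gemini[0]
    if openrouter:
        return openrouter[0]
    return "gemini-2.5-flash-preview-05-20"  # hard-coded default

# With only an X.AI key configured:
print(pick_fast_response_model([], ["grok-3", "grok-3-fast"], [], []))  # grok-3-fast
```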
@@ -355,8 +389,9 @@ class ModelProviderRegistry:
         preferred_models = [
             "anthropic/claude-3.5-sonnet",
             "anthropic/claude-3-opus-20240229",
-            "meta-llama/llama-3.1-70b-instruct",
+            "google/gemini-2.5-pro-preview-06-05",
             "google/gemini-pro-1.5",
+            "meta-llama/llama-3.1-70b-instruct",
             "mistralai/mixtral-8x7b-instruct",
         ]
         for model in preferred_models:
providers/xai.py (new file, 135 lines)
@@ -0,0 +1,135 @@
+"""X.AI (GROK) model provider implementation."""
+
+import logging
+from typing import Optional
+
+from .base import (
+    ModelCapabilities,
+    ModelResponse,
+    ProviderType,
+    RangeTemperatureConstraint,
+)
+from .openai_compatible import OpenAICompatibleProvider
+
+logger = logging.getLogger(__name__)
+
+
+class XAIModelProvider(OpenAICompatibleProvider):
+    """X.AI GROK API provider (api.x.ai)."""
+
+    FRIENDLY_NAME = "X.AI"
+
+    # Model configurations
+    SUPPORTED_MODELS = {
+        "grok-3": {
+            "context_window": 131_072,  # 131K tokens
+            "supports_extended_thinking": False,
+        },
+        "grok-3-fast": {
+            "context_window": 131_072,  # 131K tokens
+            "supports_extended_thinking": False,
+        },
+        # Shorthands for convenience
+        "grok": "grok-3",  # Default to grok-3
+        "grok3": "grok-3",
+        "grok3fast": "grok-3-fast",
+        "grokfast": "grok-3-fast",
+    }
+
+    def __init__(self, api_key: str, **kwargs):
+        """Initialize X.AI provider with API key."""
+        # Set X.AI base URL
+        kwargs.setdefault("base_url", "https://api.x.ai/v1")
+        super().__init__(api_key, **kwargs)
+
+    def get_capabilities(self, model_name: str) -> ModelCapabilities:
+        """Get capabilities for a specific X.AI model."""
+        # Resolve shorthand
+        resolved_name = self._resolve_model_name(model_name)
+
+        if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
+            raise ValueError(f"Unsupported X.AI model: {model_name}")
+
+        # Check if model is allowed by restrictions
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
+            raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
+
+        config = self.SUPPORTED_MODELS[resolved_name]
+
+        # Define temperature constraints for GROK models
+        # GROK supports the standard OpenAI temperature range
+        temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
+
+        return ModelCapabilities(
+            provider=ProviderType.XAI,
+            model_name=resolved_name,
+            friendly_name=self.FRIENDLY_NAME,
+            context_window=config["context_window"],
+            supports_extended_thinking=config["supports_extended_thinking"],
+            supports_system_prompts=True,
+            supports_streaming=True,
+            supports_function_calling=True,
+            temperature_constraint=temp_constraint,
+        )
+
+    def get_provider_type(self) -> ProviderType:
+        """Get the provider type."""
+        return ProviderType.XAI
+
+    def validate_model_name(self, model_name: str) -> bool:
+        """Validate if the model name is supported and allowed."""
+        resolved_name = self._resolve_model_name(model_name)
+
+        # First check if model is supported
+        if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+            return False
+
+        # Then check if model is allowed by restrictions
+        from utils.model_restrictions import get_restriction_service
+
+        restriction_service = get_restriction_service()
+        if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
+            logger.debug(f"X.AI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
+            return False
+
+        return True
+
+    def generate_content(
+        self,
+        prompt: str,
+        model_name: str,
+        system_prompt: Optional[str] = None,
+        temperature: float = 0.7,
+        max_output_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> ModelResponse:
+        """Generate content using X.AI API with proper model name resolution."""
+        # Resolve model alias before making API call
+        resolved_model_name = self._resolve_model_name(model_name)
+
+        # Call parent implementation with resolved model name
+        return super().generate_content(
+            prompt=prompt,
+            model_name=resolved_model_name,
+            system_prompt=system_prompt,
+            temperature=temperature,
+            max_output_tokens=max_output_tokens,
+            **kwargs,
+        )
+
+    def supports_thinking_mode(self, model_name: str) -> bool:
+        """Check if the model supports extended thinking mode."""
+        # Currently GROK models do not support extended thinking
+        # This may change with future GROK model releases
+        return False
+
+    def _resolve_model_name(self, model_name: str) -> str:
+        """Resolve model shorthand to full name."""
+        # Check if it's a shorthand
+        shorthand_value = self.SUPPORTED_MODELS.get(model_name)
+        if isinstance(shorthand_value, str):
+            return shorthand_value
+        return model_name
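
For orientation, a usage sketch of the new provider (not part of the commit; assumes it runs inside the repo with no XAI_ALLOWED_MODELS restriction set, and that constructing the provider does not itself call the API):

```python
from providers.xai import XAIModelProvider

provider = XAIModelProvider(api_key="xai-placeholder-key")

# Shorthands resolve locally, before any API call is made
assert provider._resolve_model_name("grok") == "grok-3"
assert provider._resolve_model_name("grokfast") == "grok-3-fast"

# Capabilities are reported for the resolved model
caps = provider.get_capabilities("grok")
print(caps.model_name, caps.context_window)  # grok-3 131072

# Unknown names fail validation instead of reaching the API
assert not provider.validate_model_name("grok-9")
```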
(Docker setup script)
@@ -120,6 +120,16 @@ else
     fi
 fi
 
+if [ -n "${XAI_API_KEY:-}" ]; then
+    # Replace the placeholder API key with the actual value
+    if command -v sed >/dev/null 2>&1; then
+        sed -i.bak "s/your_xai_api_key_here/$XAI_API_KEY/" .env && rm .env.bak
+        echo "✅ Updated .env with existing XAI_API_KEY from environment"
+    else
+        echo "⚠️ Found XAI_API_KEY in environment, but sed not available. Please update .env manually."
+    fi
+fi
+
 if [ -n "${OPENROUTER_API_KEY:-}" ]; then
     # Replace the placeholder API key with the actual value
     if command -v sed >/dev/null 2>&1; then
@@ -169,6 +179,7 @@ source .env 2>/dev/null || true
 
 VALID_GEMINI_KEY=false
 VALID_OPENAI_KEY=false
+VALID_XAI_KEY=false
 VALID_OPENROUTER_KEY=false
 VALID_CUSTOM_URL=false
 
@@ -184,6 +195,12 @@ if [ -n "${OPENAI_API_KEY:-}" ] && [ "$OPENAI_API_KEY" != "your_openai_api_key_h
     echo "✅ OPENAI_API_KEY found"
 fi
 
+# Check if XAI_API_KEY is set and not the placeholder
+if [ -n "${XAI_API_KEY:-}" ] && [ "$XAI_API_KEY" != "your_xai_api_key_here" ]; then
+    VALID_XAI_KEY=true
+    echo "✅ XAI_API_KEY found"
+fi
+
 # Check if OPENROUTER_API_KEY is set and not the placeholder
 if [ -n "${OPENROUTER_API_KEY:-}" ] && [ "$OPENROUTER_API_KEY" != "your_openrouter_api_key_here" ]; then
     VALID_OPENROUTER_KEY=true
@@ -197,19 +214,21 @@ if [ -n "${CUSTOM_API_URL:-}" ]; then
 fi
 
 # Require at least one valid API key or custom URL
-if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ] && [ "$VALID_OPENROUTER_KEY" = false ] && [ "$VALID_CUSTOM_URL" = false ]; then
+if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ] && [ "$VALID_XAI_KEY" = false ] && [ "$VALID_OPENROUTER_KEY" = false ] && [ "$VALID_CUSTOM_URL" = false ]; then
     echo ""
     echo "❌ ERROR: At least one valid API key or custom URL is required!"
     echo ""
     echo "Please edit the .env file and set at least one of:"
     echo "  - GEMINI_API_KEY (get from https://makersuite.google.com/app/apikey)"
     echo "  - OPENAI_API_KEY (get from https://platform.openai.com/api-keys)"
+    echo "  - XAI_API_KEY (get from https://console.x.ai/)"
     echo "  - OPENROUTER_API_KEY (get from https://openrouter.ai/)"
     echo "  - CUSTOM_API_URL (for local models like Ollama, vLLM, etc.)"
     echo ""
     echo "Example:"
     echo "  GEMINI_API_KEY=your-actual-api-key-here"
     echo "  OPENAI_API_KEY=sk-your-actual-openai-key-here"
+    echo "  XAI_API_KEY=xai-your-actual-xai-key-here"
     echo "  OPENROUTER_API_KEY=sk-or-your-actual-openrouter-key-here"
     echo "  CUSTOM_API_URL=http://host.docker.internal:11434/v1  # Ollama (use host.docker.internal, NOT localhost!)"
     echo ""
@@ -302,7 +321,7 @@ show_configuration_steps() {
     echo ""
     echo "🔄 Next steps:"
     NEEDS_KEY_UPDATE=false
-    if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null || grep -q "your_openrouter_api_key_here" .env 2>/dev/null; then
+    if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null || grep -q "your_xai_api_key_here" .env 2>/dev/null || grep -q "your_openrouter_api_key_here" .env 2>/dev/null; then
        NEEDS_KEY_UPDATE=true
     fi
 
@@ -310,6 +329,7 @@ show_configuration_steps() {
     echo "1. Edit .env and replace placeholder API keys with actual ones"
     echo "   - GEMINI_API_KEY: your-gemini-api-key-here"
     echo "   - OPENAI_API_KEY: your-openai-api-key-here"
+    echo "   - XAI_API_KEY: your-xai-api-key-here"
     echo "   - OPENROUTER_API_KEY: your-openrouter-api-key-here (optional)"
     echo "2. Restart services: $COMPOSE_CMD restart"
     echo "3. Copy the configuration below to your Claude Desktop config if required:"
server.py (11 changes)
@@ -169,6 +169,7 @@ def configure_providers():
     from providers.gemini import GeminiModelProvider
     from providers.openai import OpenAIModelProvider
     from providers.openrouter import OpenRouterProvider
+    from providers.xai import XAIModelProvider
     from utils.model_restrictions import get_restriction_service
 
     valid_providers = []
@@ -190,6 +191,13 @@ def configure_providers():
         has_native_apis = True
         logger.info("OpenAI API key found - o3 model available")
 
+    # Check for X.AI API key
+    xai_key = os.getenv("XAI_API_KEY")
+    if xai_key and xai_key != "your_xai_api_key_here":
+        valid_providers.append("X.AI (GROK)")
+        has_native_apis = True
+        logger.info("X.AI API key found - GROK models available")
+
     # Check for OpenRouter API key
     openrouter_key = os.getenv("OPENROUTER_API_KEY")
     if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
@@ -221,6 +229,8 @@ def configure_providers():
             ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
         if openai_key and openai_key != "your_openai_api_key_here":
             ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+        if xai_key and xai_key != "your_xai_api_key_here":
+            ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
 
     # 2. Custom provider second (for local/private models)
     if has_custom:
@@ -242,6 +252,7 @@ def configure_providers():
             "At least one API configuration is required. Please set either:\n"
             "- GEMINI_API_KEY for Gemini models\n"
             "- OPENAI_API_KEY for OpenAI o3 model\n"
+            "- XAI_API_KEY for X.AI GROK models\n"
             "- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
             "- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)"
         )
simulator_tests/__init__.py
@@ -24,6 +24,7 @@ from .test_redis_validation import RedisValidationTest
 from .test_refactor_validation import RefactorValidationTest
 from .test_testgen_validation import TestGenValidationTest
 from .test_token_allocation_validation import TokenAllocationValidationTest
+from .test_xai_models import XAIModelsTest
 
 # Test registry for dynamic loading
 TEST_REGISTRY = {
@@ -44,6 +45,7 @@ TEST_REGISTRY = {
     "testgen_validation": TestGenValidationTest,
     "refactor_validation": RefactorValidationTest,
     "conversation_chain_validation": ConversationChainValidationTest,
+    "xai_models": XAIModelsTest,
     # "o3_pro_expensive": O3ProExpensiveTest,  # COMMENTED OUT - too expensive to run by default
 }
 
@@ -67,5 +69,6 @@ __all__ = [
     "TestGenValidationTest",
     "RefactorValidationTest",
     "ConversationChainValidationTest",
+    "XAIModelsTest",
     "TEST_REGISTRY",
 ]
280
simulator_tests/test_xai_models.py
Normal file
280
simulator_tests/test_xai_models.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
X.AI GROK Model Tests
|
||||||
|
|
||||||
|
Tests that verify X.AI GROK functionality including:
|
||||||
|
- Model alias resolution (grok, grok3, grokfast map to actual GROK models)
|
||||||
|
- GROK-3 and GROK-3-fast models work correctly
|
||||||
|
- Conversation continuity works with GROK models
|
||||||
|
- API integration and response validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from .base_test import BaseSimulatorTest
|
||||||
|
|
||||||
|
|
||||||
|
class XAIModelsTest(BaseSimulatorTest):
|
||||||
|
"""Test X.AI GROK model functionality and integration"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_name(self) -> str:
|
||||||
|
return "xai_models"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_description(self) -> str:
|
||||||
|
return "X.AI GROK model functionality and integration"
|
||||||
|
|
||||||
|
def get_recent_server_logs(self) -> str:
|
||||||
|
"""Get recent server logs from the log file directly"""
|
||||||
|
try:
|
||||||
|
# Read logs directly from the log file
|
||||||
|
cmd = ["docker", "exec", self.container_name, "tail", "-n", "500", "/tmp/mcp_server.log"]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
return result.stdout
|
||||||
|
else:
|
||||||
|
self.logger.warning(f"Failed to read server logs: {result.stderr}")
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to get server logs: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def run_test(self) -> bool:
|
||||||
|
"""Test X.AI GROK model functionality"""
|
||||||
|
try:
|
||||||
|
self.logger.info("Test: X.AI GROK model functionality and integration")
|
||||||
|
|
||||||
|
# Check if X.AI API key is configured and not empty
|
||||||
|
check_cmd = [
|
||||||
|
"docker",
|
||||||
|
"exec",
|
||||||
|
self.container_name,
|
||||||
|
"python",
|
||||||
|
"-c",
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
xai_key = os.environ.get("XAI_API_KEY", "")
|
||||||
|
is_valid = bool(xai_key and xai_key != "your_xai_api_key_here" and xai_key.strip())
|
||||||
|
print(f"XAI_KEY_VALID:{is_valid}")
|
||||||
|
""".strip(),
|
||||||
|
]
|
||||||
|
result = subprocess.run(check_cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
if result.returncode == 0 and "XAI_KEY_VALID:False" in result.stdout:
|
||||||
|
self.logger.info(" ⚠️ X.AI API key not configured or empty - skipping test")
|
||||||
|
self.logger.info(" ℹ️ This test requires XAI_API_KEY to be set in .env with a valid key")
|
||||||
|
return True # Return True to indicate test is skipped, not failed
|
||||||
|
|
||||||
|
# Setup test files for later use
|
||||||
|
self.setup_test_files()
|
||||||
|
|
||||||
|
# Test 1: 'grok' alias (should map to grok-3)
|
||||||
|
self.logger.info(" 1: Testing 'grok' alias (should map to grok-3)")
|
||||||
|
|
||||||
|
response1, continuation_id = self.call_mcp_tool(
|
||||||
|
"chat",
|
||||||
|
{
|
||||||
|
"prompt": "Say 'Hello from GROK model!' and nothing else.",
|
||||||
|
"model": "grok",
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response1:
|
||||||
|
self.logger.error(" ❌ GROK alias test failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.logger.info(" ✅ GROK alias call completed")
|
||||||
|
if continuation_id:
|
||||||
|
self.logger.info(f" ✅ Got continuation_id: {continuation_id}")
|
||||||
|
|
||||||
|
# Test 2: Direct grok-3 model name
|
||||||
|
self.logger.info(" 2: Testing direct model name (grok-3)")
|
||||||
|
|
||||||
|
response2, _ = self.call_mcp_tool(
|
||||||
|
"chat",
|
||||||
|
{
|
||||||
|
"prompt": "Say 'Hello from GROK-3!' and nothing else.",
|
||||||
|
"model": "grok-3",
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response2:
|
||||||
|
self.logger.error(" ❌ Direct GROK-3 model test failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.logger.info(" ✅ Direct GROK-3 model call completed")
|
||||||
|
|
||||||
|
# Test 3: grok-3-fast model
|
||||||
|
self.logger.info(" 3: Testing GROK-3-fast model")
|
||||||
|
|
||||||
|
response3, _ = self.call_mcp_tool(
|
||||||
|
"chat",
|
||||||
|
{
|
||||||
|
"prompt": "Say 'Hello from GROK-3-fast!' and nothing else.",
|
||||||
|
"model": "grok-3-fast",
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response3:
|
||||||
|
self.logger.error(" ❌ GROK-3-fast model test failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.logger.info(" ✅ GROK-3-fast model call completed")
|
||||||
|
|
||||||
|
# Test 4: Shorthand aliases
|
||||||
|
self.logger.info(" 4: Testing shorthand aliases (grok3, grokfast)")
|
||||||
|
|
||||||
|
response4, _ = self.call_mcp_tool(
|
||||||
|
"chat",
|
||||||
|
{
|
||||||
|
"prompt": "Say 'Hello from grok3 alias!' and nothing else.",
|
||||||
|
"model": "grok3",
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response4:
|
||||||
|
self.logger.error(" ❌ grok3 alias test failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
response5, _ = self.call_mcp_tool(
|
||||||
|
"chat",
|
||||||
|
{
|
||||||
|
"prompt": "Say 'Hello from grokfast alias!' and nothing else.",
|
||||||
|
"model": "grokfast",
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response5:
|
||||||
|
self.logger.error(" ❌ grokfast alias test failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.logger.info(" ✅ Shorthand aliases work correctly")
|
||||||
|
|
||||||
|
# Test 5: Conversation continuity with GROK models
|
||||||
|
self.logger.info(" 5: Testing conversation continuity with GROK")
|
||||||
|
|
||||||
|
response6, new_continuation_id = self.call_mcp_tool(
|
||||||
|
"chat",
|
||||||
|
{
|
||||||
                "prompt": "Remember this number: 87. What number did I just tell you?",
                "model": "grok",
                "temperature": 0.1,
            },
        )

        if not response6 or not new_continuation_id:
            self.logger.error(" ❌ Failed to start conversation with continuation_id")
            return False

        # Continue the conversation
        response7, _ = self.call_mcp_tool(
            "chat",
            {
                "prompt": "What was the number I told you earlier?",
                "model": "grok",
                "continuation_id": new_continuation_id,
                "temperature": 0.1,
            },
        )

        if not response7:
            self.logger.error(" ❌ Failed to continue conversation")
            return False

        # Check if the model remembered the number
        if "87" in response7:
            self.logger.info(" ✅ Conversation continuity working with GROK")
        else:
            self.logger.warning(" ⚠️ Model may not have remembered the number")

        # Test 6: Validate X.AI API usage from logs
        self.logger.info(" 6: Validating X.AI API usage in logs")
        logs = self.get_recent_server_logs()

        # Check for X.AI API calls
        xai_logs = [line for line in logs.split("\n") if "x.ai" in line.lower()]
        xai_api_logs = [line for line in logs.split("\n") if "api.x.ai" in line]
        grok_logs = [line for line in logs.split("\n") if "grok" in line.lower()]

        # Check for specific model resolution
        grok_resolution_logs = [
            line
            for line in logs.split("\n")
            if ("Resolved model" in line and "grok" in line.lower()) or ("grok" in line and "->" in line)
        ]

        # Check for X.AI provider usage
        xai_provider_logs = [line for line in logs.split("\n") if "XAI" in line or "X.AI" in line]

        # Log findings
        self.logger.info(f" X.AI-related logs: {len(xai_logs)}")
        self.logger.info(f" X.AI API logs: {len(xai_api_logs)}")
        self.logger.info(f" GROK-related logs: {len(grok_logs)}")
        self.logger.info(f" Model resolution logs: {len(grok_resolution_logs)}")
        self.logger.info(f" X.AI provider logs: {len(xai_provider_logs)}")

        # Sample log output for debugging
        if self.verbose and xai_logs:
            self.logger.debug(" 📋 Sample X.AI logs:")
            for log in xai_logs[:3]:
                self.logger.debug(f" {log}")

        if self.verbose and grok_logs:
            self.logger.debug(" 📋 Sample GROK logs:")
            for log in grok_logs[:3]:
                self.logger.debug(f" {log}")

        # Success criteria
        grok_mentioned = len(grok_logs) > 0
        api_used = len(xai_api_logs) > 0 or len(xai_logs) > 0
        provider_used = len(xai_provider_logs) > 0

        success_criteria = [
            ("GROK models mentioned in logs", grok_mentioned),
            ("X.AI API calls made", api_used),
            ("X.AI provider used", provider_used),
            ("All model calls succeeded", True),  # We already checked this above
            ("Conversation continuity works", True),  # We already tested this
        ]

        passed_criteria = sum(1 for _, passed in success_criteria if passed)
        self.logger.info(f" Success criteria met: {passed_criteria}/{len(success_criteria)}")

        for criterion, passed in success_criteria:
            status = "✅" if passed else "❌"
            self.logger.info(f" {status} {criterion}")

        if passed_criteria >= 3:  # At least 3 out of 5 criteria
            self.logger.info(" ✅ X.AI GROK model tests passed")
            return True
        else:
            self.logger.error(" ❌ X.AI GROK model tests failed")
            return False

    except Exception as e:
        self.logger.error(f"X.AI GROK model test failed: {e}")
        return False
    finally:
        self.cleanup_test_files()


def main():
    """Run the X.AI GROK model tests"""
    import sys

    verbose = "--verbose" in sys.argv or "-v" in sys.argv
    test = XAIModelsTest(verbose=verbose)

    success = test.run_test()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
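For quick local runs, the same entry point can also be driven programmatically; a minimal sketch mirroring main() above — the module path is an assumption, not taken from this diff:

    # Illustrative only: the import path is hypothetical.
    from simulator_tests.test_xai_models import XAIModelsTest

    test = XAIModelsTest(verbose=True)
    raise SystemExit(0 if test.run_test() else 1)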
@@ -21,6 +21,8 @@ if "GEMINI_API_KEY" not in os.environ:
     os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
 if "OPENAI_API_KEY" not in os.environ:
     os.environ["OPENAI_API_KEY"] = "dummy-key-for-tests"
+if "XAI_API_KEY" not in os.environ:
+    os.environ["XAI_API_KEY"] = "dummy-key-for-tests"

 # Set default model to a specific value for tests to avoid auto mode
 # This prevents all tests from failing due to missing model parameter
@@ -46,10 +48,12 @@ from providers import ModelProviderRegistry  # noqa: E402
 from providers.base import ProviderType  # noqa: E402
 from providers.gemini import GeminiModelProvider  # noqa: E402
 from providers.openai import OpenAIModelProvider  # noqa: E402
+from providers.xai import XAIModelProvider  # noqa: E402

 # Register providers at test startup
 ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
 ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)


 @pytest.fixture
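Registration is what later lets the registry route a model name or alias to its provider; a minimal sketch of the expected lookup behavior, matching the assertions made in the tests below:

    from providers.base import ProviderType
    from providers.registry import ModelProviderRegistry

    # After the registrations above, the "grok" alias routes to the X.AI provider.
    provider = ModelProviderRegistry.get_provider_for_model("grok")
    assert provider is not None
    assert provider.get_provider_type() == ProviderType.XAI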
@@ -90,6 +94,18 @@ def mock_provider_availability(request, monkeypatch):
     if marker:
         return

+    # Ensure providers are registered (in case other tests cleared the registry)
+    from providers.base import ProviderType
+
+    registry = ModelProviderRegistry()
+
+    if ProviderType.GOOGLE not in registry._providers:
+        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+    if ProviderType.OPENAI not in registry._providers:
+        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+    if ProviderType.XAI not in registry._providers:
+        ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+
     from unittest.mock import MagicMock

     original_get_provider = ModelProviderRegistry.get_provider_for_model
@@ -119,3 +135,31 @@ def mock_provider_availability(request, monkeypatch):
         return original_get_provider(model_name)

     monkeypatch.setattr(ModelProviderRegistry, "get_provider_for_model", mock_get_provider_for_model)
+
+    # Also mock is_effective_auto_mode for all BaseTool instances to return False
+    # unless we're specifically testing auto mode behavior
+    from tools.base import BaseTool
+
+    def mock_is_effective_auto_mode(self):
+        # If this is an auto mode test file or specific auto mode test, use the real logic
+        test_file = request.node.fspath.basename if hasattr(request, "node") and hasattr(request.node, "fspath") else ""
+        test_name = request.node.name if hasattr(request, "node") else ""
+
+        # Allow auto mode for tests in auto mode files or with auto in the name
+        if (
+            "auto_mode" in test_file.lower()
+            or "auto" in test_name.lower()
+            or "intelligent_fallback" in test_file.lower()
+            or "per_tool_model_defaults" in test_file.lower()
+        ):
+            # Call original method logic
+            from config import DEFAULT_MODEL
+
+            if DEFAULT_MODEL.lower() == "auto":
+                return True
+            provider = ModelProviderRegistry.get_provider_for_model(DEFAULT_MODEL)
+            return provider is None
+        # For all other tests, return False to disable auto mode
+        return False
+
+    monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
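The gate above reduces to a filename/test-name predicate; a standalone restatement for clarity (the function name and sample file names here are illustrative, not from the codebase):

    def uses_real_auto_mode(test_file: str, test_name: str) -> bool:
        # Mirrors the condition in mock_is_effective_auto_mode above.
        return (
            "auto_mode" in test_file.lower()
            or "auto" in test_name.lower()
            or "intelligent_fallback" in test_file.lower()
            or "per_tool_model_defaults" in test_file.lower()
        )

    assert uses_real_auto_mode("test_auto_mode_comprehensive.py", "test_selection")
    assert not uses_real_auto_mode("test_large_prompt_handling.py", "test_exact_limit")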
582  tests/test_auto_mode_comprehensive.py  Normal file
@@ -0,0 +1,582 @@
"""Comprehensive tests for auto mode functionality across all provider combinations"""

import importlib
import os
from unittest.mock import MagicMock, patch

import pytest

from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from tools.analyze import AnalyzeTool
from tools.chat import ChatTool
from tools.debug import DebugIssueTool
from tools.models import ToolModelCategory
from tools.thinkdeep import ThinkDeepTool


@pytest.mark.no_mock_provider
class TestAutoModeComprehensive:
    """Test auto mode model selection across all provider combinations"""

    def setup_method(self):
        """Set up clean state before each test."""
        # Save original environment state for restoration
        import os

        self._original_default_model = os.environ.get("DEFAULT_MODEL", "")

        # Clear restriction service cache
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # Clear provider registry by resetting singleton instance
        ModelProviderRegistry._instance = None

    def teardown_method(self):
        """Clean up after each test."""
        # Restore original DEFAULT_MODEL
        import os

        if self._original_default_model:
            os.environ["DEFAULT_MODEL"] = self._original_default_model
        elif "DEFAULT_MODEL" in os.environ:
            del os.environ["DEFAULT_MODEL"]

        # Reload config to pick up the restored DEFAULT_MODEL
        import importlib

        import config

        importlib.reload(config)

        # Clear restriction service cache
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # Clear provider registry by resetting singleton instance
        ModelProviderRegistry._instance = None

        # Re-register providers for subsequent tests (like conftest.py does)
        from providers.gemini import GeminiModelProvider
        from providers.openai import OpenAIModelProvider
        from providers.xai import XAIModelProvider

        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
        ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

    @pytest.mark.parametrize(
        "provider_config,expected_models",
        [
            # Only Gemini API available
            (
                {
                    "GEMINI_API_KEY": "real-key",
                    "OPENAI_API_KEY": None,
                    "XAI_API_KEY": None,
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "gemini-2.5-pro-preview-06-05",  # Pro for deep thinking
                    "FAST_RESPONSE": "gemini-2.5-flash-preview-05-20",  # Flash for speed
                    "BALANCED": "gemini-2.5-flash-preview-05-20",  # Flash as balanced
                },
            ),
            # Only OpenAI API available
            (
                {
                    "GEMINI_API_KEY": None,
                    "OPENAI_API_KEY": "real-key",
                    "XAI_API_KEY": None,
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "o3",  # O3 for deep reasoning
                    "FAST_RESPONSE": "o4-mini",  # O4-mini for speed
                    "BALANCED": "o4-mini",  # O4-mini as balanced
                },
            ),
            # Only X.AI API available
            (
                {
                    "GEMINI_API_KEY": None,
                    "OPENAI_API_KEY": None,
                    "XAI_API_KEY": "real-key",
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "grok-3",  # GROK-3 for reasoning
                    "FAST_RESPONSE": "grok-3-fast",  # GROK-3-fast for speed
                    "BALANCED": "grok-3",  # GROK-3 as balanced
                },
            ),
            # Both Gemini and OpenAI available - should prefer based on tool category
            (
                {
                    "GEMINI_API_KEY": "real-key",
                    "OPENAI_API_KEY": "real-key",
                    "XAI_API_KEY": None,
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "o3",  # Prefer O3 for deep reasoning
                    "FAST_RESPONSE": "o4-mini",  # Prefer O4-mini for speed
                    "BALANCED": "o4-mini",  # Prefer OpenAI for balanced
                },
            ),
            # All native APIs available - should prefer based on tool category
            (
                {
                    "GEMINI_API_KEY": "real-key",
                    "OPENAI_API_KEY": "real-key",
                    "XAI_API_KEY": "real-key",
                    "OPENROUTER_API_KEY": None,
                },
                {
                    "EXTENDED_REASONING": "o3",  # Prefer O3 for deep reasoning
                    "FAST_RESPONSE": "o4-mini",  # Prefer O4-mini for speed
                    "BALANCED": "o4-mini",  # Prefer OpenAI for balanced
                },
            ),
            # Only OpenRouter available - should fall back to proxy models
            (
                {
                    "GEMINI_API_KEY": None,
                    "OPENAI_API_KEY": None,
                    "XAI_API_KEY": None,
                    "OPENROUTER_API_KEY": "real-key",
                },
                {
                    "EXTENDED_REASONING": "anthropic/claude-3.5-sonnet",  # First preferred thinking model from OpenRouter
                    "FAST_RESPONSE": "anthropic/claude-3-opus",  # First available OpenRouter model
                    "BALANCED": "anthropic/claude-3-opus",  # First available OpenRouter model
                },
            ),
        ],
    )
    def test_auto_mode_model_selection_by_provider(self, provider_config, expected_models):
        """Test that auto mode selects correct models based on available providers."""

        # Set up environment with specific provider configuration
        # Filter out None values and handle them separately
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            # Reload config to pick up auto mode
            os.environ["DEFAULT_MODEL"] = "auto"
            import config

            importlib.reload(config)

            # Register providers based on configuration
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider
            from providers.openrouter import OpenRouterProvider
            from providers.xai import XAIModelProvider

            if provider_config.get("GEMINI_API_KEY"):
                ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            if provider_config.get("OPENAI_API_KEY"):
                ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            if provider_config.get("XAI_API_KEY"):
                ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
            if provider_config.get("OPENROUTER_API_KEY"):
                ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

            # Test each tool category
            for category_name, expected_model in expected_models.items():
                category = ToolModelCategory(category_name.lower())

                # Get preferred fallback model for this category
                fallback_model = ModelProviderRegistry.get_preferred_fallback_model(category)

                assert fallback_model == expected_model, (
                    f"Provider config {provider_config}: "
                    f"Expected {expected_model} for {category_name}, got {fallback_model}"
                )
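    # Aside (illustrative, not part of the diff): the set/clear/reload dance above
    # recurs in every test in this file; a contextmanager could factor it out, e.g.:
    #
    #     from contextlib import contextmanager
    #
    #     @contextmanager
    #     def scoped_env(overrides):
    #         # A value of None means "unset this variable".
    #         saved = {k: os.environ.get(k) for k in overrides}
    #         try:
    #             for k, v in overrides.items():
    #                 if v is None:
    #                     os.environ.pop(k, None)
    #                 else:
    #                     os.environ[k] = v
    #             yield
    #         finally:
    #             for k, v in saved.items():
    #                 if v is None:
    #                     os.environ.pop(k, None)
    #                 else:
    #                     os.environ[k] = v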
    @pytest.mark.parametrize(
        "tool_class,expected_category",
        [
            (ChatTool, ToolModelCategory.FAST_RESPONSE),
            (AnalyzeTool, ToolModelCategory.EXTENDED_REASONING),  # AnalyzeTool uses EXTENDED_REASONING
            (DebugIssueTool, ToolModelCategory.EXTENDED_REASONING),
            (ThinkDeepTool, ToolModelCategory.EXTENDED_REASONING),
        ],
    )
    def test_tool_model_categories(self, tool_class, expected_category):
        """Test that tools have the correct model categories."""
        tool = tool_class()
        assert tool.get_model_category() == expected_category

    @pytest.mark.asyncio
    async def test_auto_mode_with_gemini_only_uses_correct_models(self):
        """Test that auto mode with only Gemini uses flash for fast tools and pro for reasoning tools."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register only Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Mock provider to capture what model is requested
            mock_provider = MagicMock()
            mock_provider.generate_content.return_value = MagicMock(
                content="test response", model_name="test-model", usage={"input_tokens": 10, "output_tokens": 5}
            )

            with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
                # Test ChatTool (FAST_RESPONSE) - should prefer flash
                chat_tool = ChatTool()
                await chat_tool.execute({"prompt": "test", "model": "auto"})  # This should trigger auto selection

                # In auto mode, the tool should get an error requiring model selection,
                # but the suggested model should be flash

                # Reset mock for next test
                ModelProviderRegistry.get_provider_for_model.reset_mock()

                # Test DebugIssueTool (EXTENDED_REASONING) - should prefer pro
                debug_tool = DebugIssueTool()
                await debug_tool.execute({"prompt": "test error", "model": "auto"})

    def test_auto_mode_schema_includes_all_available_models(self):
        """Test that auto mode schema includes all available models for user convenience."""

        # Test with only Gemini available
        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register only Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            tool = AnalyzeTool()
            schema = tool.get_input_schema()

            # Should have model as required field
            assert "model" in schema["required"]

            # Should include all model options from global config
            model_schema = schema["properties"]["model"]
            assert "enum" in model_schema

            available_models = model_schema["enum"]

            # Should include Gemini models
            assert "flash" in available_models
            assert "pro" in available_models
            assert "gemini-2.5-flash-preview-05-20" in available_models
            assert "gemini-2.5-pro-preview-06-05" in available_models

            # Should also include other models (users might have OpenRouter configured)
            # The schema should show all options; validation happens at runtime
            assert "o3" in available_models
            assert "o4-mini" in available_models
            assert "grok" in available_models
            assert "grok-3" in available_models

    def test_auto_mode_schema_with_all_providers(self):
        """Test that auto mode schema includes models from all available providers."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": "real-key",
            "XAI_API_KEY": "real-key",
            "OPENROUTER_API_KEY": None,  # Don't include OpenRouter to avoid infinite models
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register all native providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider
            from providers.xai import XAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

            tool = AnalyzeTool()
            schema = tool.get_input_schema()

            model_schema = schema["properties"]["model"]
            available_models = model_schema["enum"]

            # Should include models from all providers
            # Gemini models
            assert "flash" in available_models
            assert "pro" in available_models

            # OpenAI models
            assert "o3" in available_models
            assert "o4-mini" in available_models

            # XAI models
            assert "grok" in available_models
            assert "grok-3" in available_models

    @pytest.mark.asyncio
    async def test_auto_mode_model_parameter_required_error(self):
        """Test that auto mode properly requires the model parameter and suggests the correct model."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register only Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Test with ChatTool (FAST_RESPONSE category)
            chat_tool = ChatTool()
            result = await chat_tool.execute(
                {
                    "prompt": "test"
                    # Note: no model parameter provided in auto mode
                }
            )

            # Should get error requiring model selection
            assert len(result) == 1
            response_text = result[0].text

            # Parse JSON response to check error
            import json

            response_data = json.loads(response_text)

            assert response_data["status"] == "error"
            assert "Model parameter is required" in response_data["content"]
            assert "flash" in response_data["content"]  # Should suggest flash for FAST_RESPONSE
            assert "category: fast_response" in response_data["content"]

    def test_model_availability_with_restrictions(self):
        """Test that auto mode respects model restrictions when selecting fallback models."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": "real-key",
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
            "OPENAI_ALLOWED_MODELS": "o4-mini",  # Restrict OpenAI to only o4-mini
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Clear restriction service to pick up new env vars
            import utils.model_restrictions

            utils.model_restrictions._restriction_service = None

            # Register providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

            # Get available models - should respect restrictions
            available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)

            # Should include restricted OpenAI model
            assert "o4-mini" in available_models

            # Should NOT include non-restricted OpenAI models
            assert "o3" not in available_models
            assert "o3-mini" not in available_models

            # Should still include all Gemini models (no restrictions)
            assert "gemini-2.5-flash-preview-05-20" in available_models
            assert "gemini-2.5-pro-preview-06-05" in available_models

    def test_openrouter_fallback_when_no_native_apis(self):
        """Test that OpenRouter provides fallback models when no native APIs are available."""

        provider_config = {
            "GEMINI_API_KEY": None,
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": "real-key",
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register only OpenRouter provider
            from providers.openrouter import OpenRouterProvider

            ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

            # Mock OpenRouter registry to return known models
            mock_registry = MagicMock()
            mock_registry.list_models.return_value = [
                "google/gemini-2.5-flash-preview-05-20",
                "google/gemini-2.5-pro-preview-06-05",
                "openai/o3",
                "openai/o4-mini",
                "anthropic/claude-3-opus",
            ]

            with patch.object(OpenRouterProvider, "_registry", mock_registry):
                # Get preferred models for different categories
                extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                    ToolModelCategory.EXTENDED_REASONING
                )
                fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

                # Should fall back to known good models even via OpenRouter.
                # The exact model depends on the _find_extended_thinking_model implementation.
                assert extended_reasoning is not None
                assert fast_response is not None

    @pytest.mark.asyncio
    async def test_actual_model_name_resolution_in_auto_mode(self):
        """Test that when a model is selected in auto mode, the tool executes successfully."""

        provider_config = {
            "GEMINI_API_KEY": "real-key",
            "OPENAI_API_KEY": None,
            "XAI_API_KEY": None,
            "OPENROUTER_API_KEY": None,
            "DEFAULT_MODEL": "auto",
        }

        # Filter out None values to avoid patch.dict errors
        env_to_set = {k: v for k, v in provider_config.items() if v is not None}
        env_to_clear = [k for k, v in provider_config.items() if v is None]

        with patch.dict(os.environ, env_to_set, clear=False):
            # Clear the None-valued environment variables
            for key in env_to_clear:
                if key in os.environ:
                    del os.environ[key]
            import config

            importlib.reload(config)

            # Register Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Mock the actual provider to simulate successful execution
            mock_provider = MagicMock()
            mock_response = MagicMock()
            mock_response.content = "test response"
            mock_response.model_name = "gemini-2.5-flash-preview-05-20"  # The resolved name
            mock_response.usage = {"input_tokens": 10, "output_tokens": 5}
            # Mock _resolve_model_name to simulate alias resolution
            mock_provider._resolve_model_name = lambda alias: (
                "gemini-2.5-flash-preview-05-20" if alias == "flash" else alias
            )
            mock_provider.generate_content.return_value = mock_response

            with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
                chat_tool = ChatTool()
                result = await chat_tool.execute({"prompt": "test", "model": "flash"})  # Use alias in auto mode

                # Should succeed with proper model resolution
                assert len(result) == 1
                # Just verify that the tool executed successfully and didn't return an error
                assert "error" not in result[0].text.lower()
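Several tests in this file hand-build the same MagicMock provider; a small factory sketch capturing the shared shape (illustrative only, not part of the diff — make_mock_provider is a hypothetical name):

    from unittest.mock import MagicMock

    def make_mock_provider(model_name="gemini-2.5-flash-preview-05-20"):
        provider = MagicMock()
        provider.generate_content.return_value = MagicMock(
            content="test response",
            model_name=model_name,
            usage={"input_tokens": 10, "output_tokens": 5},
        )
        # Pass aliases through unchanged unless a mapping is supplied.
        provider._resolve_model_name = lambda alias: alias
        return provider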
344  tests/test_auto_mode_provider_selection.py  Normal file
@@ -0,0 +1,344 @@
"""Test auto mode provider selection logic specifically"""

import os

import pytest

from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from tools.models import ToolModelCategory


@pytest.mark.no_mock_provider
class TestAutoModeProviderSelection:
    """Test the core auto mode provider selection logic"""

    def setup_method(self):
        """Set up clean state before each test."""
        # Clear restriction service cache
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # Clear provider registry
        registry = ModelProviderRegistry()
        registry._providers.clear()
        registry._initialized_providers.clear()

    def teardown_method(self):
        """Clean up after each test."""
        # Clear restriction service cache
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def test_gemini_only_fallback_selection(self):
        """Test auto mode fallback when only Gemini is available."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment - only Gemini available
            os.environ["GEMINI_API_KEY"] = "test-key"
            for key in ["OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
                os.environ.pop(key, None)

            # Register only Gemini provider
            from providers.gemini import GeminiModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Test fallback selection for different categories
            extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                ToolModelCategory.EXTENDED_REASONING
            )
            fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
            balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)

            # Should select appropriate Gemini models
            assert extended_reasoning in ["gemini-2.5-pro-preview-06-05", "pro"]
            assert fast_response in ["gemini-2.5-flash-preview-05-20", "flash"]
            assert balanced in ["gemini-2.5-flash-preview-05-20", "flash"]

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)
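    # Aside (illustrative, not part of the diff): the save/try/finally environment
    # handling above could equally lean on pytest's built-in monkeypatch fixture,
    # which restores os.environ automatically when the test ends:
    #
    #     def test_gemini_only_fallback_selection(self, monkeypatch):
    #         monkeypatch.setenv("GEMINI_API_KEY", "test-key")
    #         for key in ("OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"):
    #             monkeypatch.delenv(key, raising=False)
    #         # ... register the Gemini provider and assert on fallbacks as above ...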
    def test_openai_only_fallback_selection(self):
        """Test auto mode fallback when only OpenAI is available."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment - only OpenAI available
            os.environ["OPENAI_API_KEY"] = "test-key"
            for key in ["GEMINI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
                os.environ.pop(key, None)

            # Register only OpenAI provider
            from providers.openai import OpenAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

            # Test fallback selection for different categories
            extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                ToolModelCategory.EXTENDED_REASONING
            )
            fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
            balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)

            # Should select appropriate OpenAI models
            assert extended_reasoning in ["o3", "o3-mini", "o4-mini"]  # Any available OpenAI model for reasoning
            assert fast_response in ["o4-mini", "o3-mini"]  # Prefer faster models
            assert balanced in ["o4-mini", "o3-mini"]  # Balanced selection

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_both_gemini_and_openai_priority(self):
        """Test auto mode when both Gemini and OpenAI are available."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment - both Gemini and OpenAI available
            os.environ["GEMINI_API_KEY"] = "test-key"
            os.environ["OPENAI_API_KEY"] = "test-key"
            for key in ["XAI_API_KEY", "OPENROUTER_API_KEY"]:
                os.environ.pop(key, None)

            # Register both providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

            # Test fallback selection for different categories
            extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                ToolModelCategory.EXTENDED_REASONING
            )
            fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

            # Should prefer OpenAI for reasoning (based on fallback logic)
            assert extended_reasoning == "o3"  # Should prefer O3 for extended reasoning

            # Should prefer OpenAI for fast response
            assert fast_response == "o4-mini"  # Should prefer O4-mini for fast response

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_xai_only_fallback_selection(self):
        """Test auto mode fallback when only XAI is available."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment - only XAI available
            os.environ["XAI_API_KEY"] = "test-key"
            for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY"]:
                os.environ.pop(key, None)

            # Register only XAI provider
            from providers.xai import XAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

            # Test fallback selection for different categories
            extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
                ToolModelCategory.EXTENDED_REASONING
            )
            fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

            # Should fall back to available models or default fallbacks.
            # Since XAI models are not explicitly handled in the fallback logic,
            # it should fall back to the hardcoded defaults.
            assert extended_reasoning is not None
            assert fast_response is not None

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_available_models_respects_restrictions(self):
        """Test that get_available_models respects model restrictions."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENAI_ALLOWED_MODELS"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up environment with restrictions
            os.environ["GEMINI_API_KEY"] = "test-key"
            os.environ["OPENAI_API_KEY"] = "test-key"
            os.environ["OPENAI_ALLOWED_MODELS"] = "o4-mini"  # Only allow o4-mini

            # Clear restriction service to pick up new restrictions
            import utils.model_restrictions

            utils.model_restrictions._restriction_service = None

            # Register both providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

            # Get available models with restrictions
            available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)

            # Should include allowed OpenAI model
            assert "o4-mini" in available_models
            assert available_models["o4-mini"] == ProviderType.OPENAI

            # Should NOT include restricted OpenAI models
            assert "o3" not in available_models
            assert "o3-mini" not in available_models

            # Should include all Gemini models (no restrictions)
            assert "gemini-2.5-flash-preview-05-20" in available_models
            assert available_models["gemini-2.5-flash-preview-05-20"] == ProviderType.GOOGLE

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_model_validation_across_providers(self):
        """Test that model validation works correctly across different providers."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up all providers
            os.environ["GEMINI_API_KEY"] = "test-key"
            os.environ["OPENAI_API_KEY"] = "test-key"
            os.environ["XAI_API_KEY"] = "test-key"

            # Register all providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider
            from providers.xai import XAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

            # Test model validation - each provider should handle its own models
            # Gemini models
            gemini_provider = ModelProviderRegistry.get_provider_for_model("flash")
            assert gemini_provider is not None
            assert gemini_provider.get_provider_type() == ProviderType.GOOGLE

            # OpenAI models
            openai_provider = ModelProviderRegistry.get_provider_for_model("o3")
            assert openai_provider is not None
            assert openai_provider.get_provider_type() == ProviderType.OPENAI

            # XAI models
            xai_provider = ModelProviderRegistry.get_provider_for_model("grok")
            assert xai_provider is not None
            assert xai_provider.get_provider_type() == ProviderType.XAI

            # Invalid model should return None
            invalid_provider = ModelProviderRegistry.get_provider_for_model("invalid-model-name")
            assert invalid_provider is None

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)

    def test_alias_resolution_before_api_calls(self):
        """Test that model aliases are resolved before being passed to providers."""

        # Save original environment
        original_env = {}
        for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
            original_env[key] = os.environ.get(key)

        try:
            # Set up all providers
            os.environ["GEMINI_API_KEY"] = "test-key"
            os.environ["OPENAI_API_KEY"] = "test-key"
            os.environ["XAI_API_KEY"] = "test-key"

            # Register all providers
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider
            from providers.xai import XAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)

            # Test that providers resolve aliases correctly
            test_cases = [
                ("flash", ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20"),
                ("pro", ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05"),
                ("mini", ProviderType.OPENAI, "o4-mini"),
                ("o3mini", ProviderType.OPENAI, "o3-mini"),
                ("grok", ProviderType.XAI, "grok-3"),
                ("grokfast", ProviderType.XAI, "grok-3-fast"),
            ]

            for alias, expected_provider_type, expected_resolved_name in test_cases:
                provider = ModelProviderRegistry.get_provider_for_model(alias)
                assert provider is not None, f"No provider found for alias '{alias}'"
                assert provider.get_provider_type() == expected_provider_type, f"Wrong provider for '{alias}'"

                # Test alias resolution
                resolved_name = provider._resolve_model_name(alias)
                assert (
                    resolved_name == expected_resolved_name
                ), f"Alias '{alias}' should resolve to '{expected_resolved_name}', got '{resolved_name}'"

        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                else:
                    os.environ.pop(key, None)
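Taken together, the table-driven test above pins down the full alias chain. For example, per those test cases:

    provider = ModelProviderRegistry.get_provider_for_model("grokfast")
    assert provider.get_provider_type() == ProviderType.XAI
    assert provider._resolve_model_name("grokfast") == "grok-3-fast"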
@@ -55,6 +55,8 @@ class TestClaudeContinuationOffers:
     """Test Claude continuation offer functionality"""

     def setup_method(self):
+        # Note: Tool creation and schema generation happens here.
+        # If providers are not registered yet, the tool might detect auto mode.
         self.tool = ClaudeContinuationTool()
         # Set default model to avoid effective auto mode
         self.tool.default_model = "gemini-2.5-flash-preview-05-20"

@@ -63,11 +65,15 @@ class TestClaudeContinuationOffers:
     @patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
     async def test_new_conversation_offers_continuation(self, mock_redis):
         """Test that new conversations offer Claude continuation opportunity"""
+        # Create tool AFTER providers are registered (in conftest.py fixture)
+        tool = ClaudeContinuationTool()
+        tool.default_model = "gemini-2.5-flash-preview-05-20"
+
         mock_client = Mock()
         mock_redis.return_value = mock_client

         # Mock the model
-        with patch.object(self.tool, "get_model_provider") as mock_get_provider:
+        with patch.object(tool, "get_model_provider") as mock_get_provider:
             mock_provider = create_mock_provider()
             mock_provider.get_provider_type.return_value = Mock(value="google")
             mock_provider.supports_thinking_mode.return_value = False

@@ -81,7 +87,7 @@ class TestClaudeContinuationOffers:

             # Execute tool without continuation_id (new conversation)
             arguments = {"prompt": "Analyze this code"}
-            response = await self.tool.execute(arguments)
+            response = await tool.execute(arguments)

             # Parse response
             response_data = json.loads(response[0].text)

@@ -177,10 +183,6 @@ class TestClaudeContinuationOffers:
         assert len(response) == 1
         response_data = json.loads(response[0].text)

-        # Debug output
-        if response_data.get("status") == "error":
-            print(f"Error content: {response_data.get('content')}")
-
         assert response_data["status"] == "continuation_available"
         assert response_data["content"] == "Analysis complete. The code looks good."
         assert "continuation_offer" in response_data
@@ -17,51 +17,93 @@ class TestIntelligentFallback:
|
|||||||
"""Test intelligent model fallback logic"""
|
"""Test intelligent model fallback logic"""
|
||||||
|
|
||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
"""Setup for each test - clear registry cache"""
|
"""Setup for each test - clear registry and reset providers"""
|
||||||
ModelProviderRegistry.clear_cache()
|
# Store original providers for restoration
|
||||||
|
registry = ModelProviderRegistry()
|
||||||
|
self._original_providers = registry._providers.copy()
|
||||||
|
self._original_initialized = registry._initialized_providers.copy()
|
||||||
|
|
||||||
|
# Clear registry completely
|
||||||
|
ModelProviderRegistry._instance = None
|
||||||
|
|
||||||
def teardown_method(self):
|
def teardown_method(self):
|
||||||
"""Cleanup after each test"""
|
"""Cleanup after each test - restore original providers"""
|
||||||
ModelProviderRegistry.clear_cache()
|
# Restore original registry state
|
||||||
|
registry = ModelProviderRegistry()
|
||||||
|
registry._providers.clear()
|
||||||
|
registry._initialized_providers.clear()
|
||||||
|
registry._providers.update(self._original_providers)
|
||||||
|
registry._initialized_providers.update(self._original_initialized)
|
||||||
|
|
||||||
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
|
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
|
||||||
def test_prefers_openai_o3_mini_when_available(self):
|
def test_prefers_openai_o3_mini_when_available(self):
|
||||||
"""Test that o4-mini is preferred when OpenAI API key is available"""
|
"""Test that o4-mini is preferred when OpenAI API key is available"""
|
||||||
ModelProviderRegistry.clear_cache()
|
# Register only OpenAI provider for this test
|
||||||
|
from providers.openai import OpenAIModelProvider
|
||||||
|
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|
||||||
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
|
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
|
||||||
assert fallback_model == "o4-mini"
|
assert fallback_model == "o4-mini"
|
||||||
|
|
||||||
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
|
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
|
||||||
def test_prefers_gemini_flash_when_openai_unavailable(self):
|
def test_prefers_gemini_flash_when_openai_unavailable(self):
|
||||||
"""Test that gemini-2.5-flash-preview-05-20 is used when only Gemini API key is available"""
|
"""Test that gemini-2.5-flash-preview-05-20 is used when only Gemini API key is available"""
|
||||||
ModelProviderRegistry.clear_cache()
|
# Register only Gemini provider for this test
|
||||||
|
from providers.gemini import GeminiModelProvider
|
||||||
|
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
|
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
|
||||||
assert fallback_model == "gemini-2.5-flash-preview-05-20"
|
assert fallback_model == "gemini-2.5-flash-preview-05-20"
|
||||||
|
|
||||||
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
|
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
|
||||||
     def test_prefers_openai_when_both_available(self):
         """Test that OpenAI is preferred when both API keys are available"""
-        ModelProviderRegistry.clear_cache()
+        # Register both OpenAI and Gemini providers
+        from providers.gemini import GeminiModelProvider
+        from providers.openai import OpenAIModelProvider
+
+        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
         fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
         assert fallback_model == "o4-mini"  # OpenAI has priority

     @patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False)
     def test_fallback_when_no_keys_available(self):
         """Test fallback behavior when no API keys are available"""
-        ModelProviderRegistry.clear_cache()
+        # Register providers but with no API keys available
+        from providers.gemini import GeminiModelProvider
+        from providers.openai import OpenAIModelProvider
+
+        ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
         fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
         assert fallback_model == "gemini-2.5-flash-preview-05-20"  # Default fallback

     def test_available_providers_with_keys(self):
         """Test the get_available_providers_with_keys method"""
+        from providers.gemini import GeminiModelProvider
+        from providers.openai import OpenAIModelProvider
+
         with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
-            ModelProviderRegistry.clear_cache()
+            # Clear and register providers
+            ModelProviderRegistry._instance = None
+            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
             available = ModelProviderRegistry.get_available_providers_with_keys()
             assert ProviderType.OPENAI in available
             assert ProviderType.GOOGLE not in available

         with patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False):
-            ModelProviderRegistry.clear_cache()
+            # Clear and register providers
+            ModelProviderRegistry._instance = None
+            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
             available = ModelProviderRegistry.get_available_providers_with_keys()
             assert ProviderType.GOOGLE in available
             assert ProviderType.OPENAI not in available
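Taken together, these assertions pin down a simple priority rule. Below is a minimal sketch of the selection logic they imply, assuming API-key presence is the only signal; the real get_preferred_fallback_model may consult registered providers and tool categories, so treat the helper name and body as illustrative only.

import os


# Hypothetical sketch of the fallback rule asserted above, not the actual
# ModelProviderRegistry implementation.
def preferred_fallback_model() -> str:
    if os.getenv("OPENAI_API_KEY"):
        # OpenAI has priority whenever its key is configured
        return "o4-mini"
    # Gemini Flash serves both the Gemini-only case and the no-keys default
    return "gemini-2.5-flash-preview-05-20"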
@@ -76,7 +118,10 @@ class TestIntelligentFallback:
             patch("config.DEFAULT_MODEL", "auto"),
             patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
         ):
-            ModelProviderRegistry.clear_cache()
+            # Register only OpenAI provider for this test
+            from providers.openai import OpenAIModelProvider
+
+            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

             # Create a context with at least one turn so it doesn't exit early
             from utils.conversation_memory import ConversationTurn
@@ -114,7 +159,10 @@ class TestIntelligentFallback:
             patch("config.DEFAULT_MODEL", "auto"),
             patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False),
         ):
-            ModelProviderRegistry.clear_cache()
+            # Register only Gemini provider for this test
+            from providers.gemini import GeminiModelProvider
+
+            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

             from utils.conversation_memory import ConversationTurn
@@ -243,6 +243,19 @@ class TestLargePromptHandling:
         tool = ChatTool()
         exact_prompt = "x" * MCP_PROMPT_SIZE_LIMIT

+        # Mock the model provider to avoid real API calls
+        with patch.object(tool, "get_model_provider") as mock_get_provider:
+            mock_provider = MagicMock()
+            mock_provider.get_provider_type.return_value = MagicMock(value="google")
+            mock_provider.supports_thinking_mode.return_value = False
+            mock_provider.generate_content.return_value = MagicMock(
+                content="Response to the large prompt",
+                usage={"input_tokens": 12000, "output_tokens": 10, "total_tokens": 12010},
+                model_name="gemini-2.5-flash-preview-05-20",
+                metadata={"finish_reason": "STOP"},
+            )
+            mock_get_provider.return_value = mock_provider
+
             # With the fix, this should now pass because we check at MCP transport boundary before adding internal content
             result = await tool.execute({"prompt": exact_prompt})
             output = json.loads(result[0].text)
@@ -535,13 +535,26 @@ class TestAutoModeWithRestrictions:
     @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GEMINI_API_KEY": "", "OPENAI_API_KEY": "test-key"})
     def test_fallback_with_shorthand_restrictions(self):
         """Test fallback model selection with shorthand restrictions."""
-        # Clear caches
+        # Clear caches and reset registry
         import utils.model_restrictions
         from providers.registry import ModelProviderRegistry
         from tools.models import ToolModelCategory

         utils.model_restrictions._restriction_service = None
-        ModelProviderRegistry.clear_cache()
+
+        # Store original providers for restoration
+        registry = ModelProviderRegistry()
+        original_providers = registry._providers.copy()
+        original_initialized = registry._initialized_providers.copy()
+
+        try:
+            # Clear registry and register only OpenAI and Gemini providers
+            ModelProviderRegistry._instance = None
+            from providers.gemini import GeminiModelProvider
+            from providers.openai import OpenAIModelProvider
+
+            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
             # Even with "mini" restriction, fallback should work if provider handles it correctly
             # This tests the real-world scenario
@@ -550,3 +563,10 @@ class TestAutoModeWithRestrictions:
             # The fallback will depend on how get_available_models handles aliases
             # For now, we accept either behavior and document it
             assert model in ["o4-mini", "gemini-2.5-flash-preview-05-20"]
+        finally:
+            # Restore original registry state
+            registry = ModelProviderRegistry()
+            registry._providers.clear()
+            registry._initialized_providers.clear()
+            registry._providers.update(original_providers)
+            registry._initialized_providers.update(original_initialized)
tests/test_openai_provider.py (new file, 221 lines)
@@ -0,0 +1,221 @@
"""Tests for OpenAI provider implementation."""

import os
from unittest.mock import MagicMock, patch

from providers.base import ProviderType
from providers.openai import OpenAIModelProvider


class TestOpenAIProvider:
    """Test OpenAI provider functionality."""

    def setup_method(self):
        """Set up clean state before each test."""
        # Clear restriction service cache before each test
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def teardown_method(self):
        """Clean up after each test to avoid singleton issues."""
        # Clear restriction service cache after each test
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"})
    def test_initialization(self):
        """Test provider initialization."""
        provider = OpenAIModelProvider("test-key")
        assert provider.api_key == "test-key"
        assert provider.get_provider_type() == ProviderType.OPENAI
        assert provider.base_url == "https://api.openai.com/v1"

    def test_initialization_with_custom_url(self):
        """Test provider initialization with custom base URL."""
        provider = OpenAIModelProvider("test-key", base_url="https://custom.openai.com/v1")
        assert provider.api_key == "test-key"
        assert provider.base_url == "https://custom.openai.com/v1"

    def test_model_validation(self):
        """Test model name validation."""
        provider = OpenAIModelProvider("test-key")

        # Test valid models
        assert provider.validate_model_name("o3") is True
        assert provider.validate_model_name("o3-mini") is True
        assert provider.validate_model_name("o3-pro") is True
        assert provider.validate_model_name("o4-mini") is True
        assert provider.validate_model_name("o4-mini-high") is True

        # Test valid aliases
        assert provider.validate_model_name("mini") is True
        assert provider.validate_model_name("o3mini") is True
        assert provider.validate_model_name("o4mini") is True
        assert provider.validate_model_name("o4minihigh") is True
        assert provider.validate_model_name("o4minihi") is True

        # Test invalid models
        assert provider.validate_model_name("invalid-model") is False
        assert provider.validate_model_name("gpt-4") is False
        assert provider.validate_model_name("gemini-pro") is False

    def test_resolve_model_name(self):
        """Test model name resolution."""
        provider = OpenAIModelProvider("test-key")

        # Test shorthand resolution
        assert provider._resolve_model_name("mini") == "o4-mini"
        assert provider._resolve_model_name("o3mini") == "o3-mini"
        assert provider._resolve_model_name("o4mini") == "o4-mini"
        assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
        assert provider._resolve_model_name("o4minihi") == "o4-mini-high"

        # Test full name passthrough
        assert provider._resolve_model_name("o3") == "o3"
        assert provider._resolve_model_name("o3-mini") == "o3-mini"
        assert provider._resolve_model_name("o3-pro") == "o3-pro"
        assert provider._resolve_model_name("o4-mini") == "o4-mini"
        assert provider._resolve_model_name("o4-mini-high") == "o4-mini-high"

    def test_get_capabilities_o3(self):
        """Test getting model capabilities for O3."""
        provider = OpenAIModelProvider("test-key")

        capabilities = provider.get_capabilities("o3")
        assert capabilities.model_name == "o3"  # Should NOT be resolved in capabilities
        assert capabilities.friendly_name == "OpenAI"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.OPENAI
        assert not capabilities.supports_extended_thinking
        assert capabilities.supports_system_prompts is True
        assert capabilities.supports_streaming is True
        assert capabilities.supports_function_calling is True

        # Test temperature constraint (O3 has fixed temperature)
        assert capabilities.temperature_constraint.value == 1.0

    def test_get_capabilities_with_alias(self):
        """Test getting model capabilities with alias resolves correctly."""
        provider = OpenAIModelProvider("test-key")

        capabilities = provider.get_capabilities("mini")
        assert capabilities.model_name == "mini"  # Capabilities should show original request
        assert capabilities.friendly_name == "OpenAI"
        assert capabilities.context_window == 200_000
        assert capabilities.provider == ProviderType.OPENAI

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
        """Test that generate_content resolves aliases before making API calls.

        This is the CRITICAL test that was missing - verifying that aliases
        like 'mini' get resolved to 'o4-mini' before being sent to OpenAI API.
        """
        # Set up mock OpenAI client
        mock_client = MagicMock()
        mock_openai_class.return_value = mock_client

        # Mock the completion response
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "o4-mini"  # API returns the resolved model name
        mock_response.id = "test-id"
        mock_response.created = 1234567890
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_response.usage.total_tokens = 15

        mock_client.chat.completions.create.return_value = mock_response

        provider = OpenAIModelProvider("test-key")

        # Call generate_content with alias 'mini'
        result = provider.generate_content(
            prompt="Test prompt", model_name="mini", temperature=1.0  # This should be resolved to "o4-mini"
        )

        # Verify the API was called with the RESOLVED model name
        mock_client.chat.completions.create.assert_called_once()
        call_kwargs = mock_client.chat.completions.create.call_args[1]

        # CRITICAL ASSERTION: The API should receive "o4-mini", not "mini"
        assert call_kwargs["model"] == "o4-mini", f"Expected 'o4-mini' but API received '{call_kwargs['model']}'"

        # Verify other parameters
        assert call_kwargs["temperature"] == 1.0
        assert len(call_kwargs["messages"]) == 1
        assert call_kwargs["messages"][0]["role"] == "user"
        assert call_kwargs["messages"][0]["content"] == "Test prompt"

        # Verify response
        assert result.content == "Test response"
        assert result.model_name == "o4-mini"  # Should be the resolved name

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_other_aliases(self, mock_openai_class):
        """Test other alias resolutions in generate_content."""
        # Set up mock
        mock_client = MagicMock()
        mock_openai_class.return_value = mock_client
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_response.usage.total_tokens = 15
        mock_client.chat.completions.create.return_value = mock_response

        provider = OpenAIModelProvider("test-key")

        # Test o3mini -> o3-mini
        mock_response.model = "o3-mini"
        provider.generate_content(prompt="Test", model_name="o3mini", temperature=1.0)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "o3-mini"

        # Test o4minihigh -> o4-mini-high
        mock_response.model = "o4-mini-high"
        provider.generate_content(prompt="Test", model_name="o4minihigh", temperature=1.0)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "o4-mini-high"

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_no_alias_passthrough(self, mock_openai_class):
        """Test that full model names pass through unchanged."""
        # Set up mock
        mock_client = MagicMock()
        mock_openai_class.return_value = mock_client
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "o3-pro"
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_response.usage.total_tokens = 15
        mock_client.chat.completions.create.return_value = mock_response

        provider = OpenAIModelProvider("test-key")

        # Test full model name passes through unchanged
        provider.generate_content(prompt="Test", model_name="o3-pro", temperature=1.0)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "o3-pro"  # Should be unchanged

    def test_supports_thinking_mode(self):
        """Test thinking mode support (currently False for all OpenAI models)."""
        provider = OpenAIModelProvider("test-key")

        # All OpenAI models currently don't support thinking mode
        assert provider.supports_thinking_mode("o3") is False
        assert provider.supports_thinking_mode("o3-mini") is False
        assert provider.supports_thinking_mode("o4-mini") is False
        assert provider.supports_thinking_mode("mini") is False  # Test with alias too
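All of these alias tests hinge on one small mapping step before the API call. Here is a sketch of what _resolve_model_name plausibly does, inferred from the assertions above; the mapping shown is illustrative, not the provider's actual table.

# Hypothetical sketch: full names map to config dicts, shorthands map to full names.
SUPPORTED_MODELS = {
    "o4-mini": {"context_window": 200_000, "supports_extended_thinking": False},
    "o3-mini": {"context_window": 200_000, "supports_extended_thinking": False},
    "mini": "o4-mini",
    "o3mini": "o3-mini",
}


def resolve_model_name(name: str) -> str:
    entry = SUPPORTED_MODELS.get(name)
    # A string entry is an alias pointing at the canonical model name;
    # a dict entry (or an unknown name) passes through unchanged.
    return entry if isinstance(entry, str) else name


assert resolve_model_name("mini") == "o4-mini"
assert resolve_model_name("o3-mini") == "o3-mini"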
@@ -202,9 +202,9 @@ class TestCustomProviderFallback:
     @patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
     def test_extended_reasoning_custom_fallback(self, mock_find_thinking):
         """Test EXTENDED_REASONING falls back to custom thinking model."""
-        with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
-            # No native providers available
-            mock_get_provider.return_value = None
+        with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
+            # No native models available, but OpenRouter is available
+            mock_get_available.return_value = {"openrouter-model": ProviderType.OPENROUTER}
             mock_find_thinking.return_value = "custom/thinking-model"

             model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
tests/test_xai_provider.py (new file, 326 lines)
@@ -0,0 +1,326 @@
"""Tests for X.AI provider implementation."""

import os
from unittest.mock import MagicMock, patch

import pytest

from providers.base import ProviderType
from providers.xai import XAIModelProvider


class TestXAIProvider:
    """Test X.AI provider functionality."""

    def setup_method(self):
        """Set up clean state before each test."""
        # Clear restriction service cache before each test
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def teardown_method(self):
        """Clean up after each test to avoid singleton issues."""
        # Clear restriction service cache after each test
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"XAI_API_KEY": "test-key"})
    def test_initialization(self):
        """Test provider initialization."""
        provider = XAIModelProvider("test-key")
        assert provider.api_key == "test-key"
        assert provider.get_provider_type() == ProviderType.XAI
        assert provider.base_url == "https://api.x.ai/v1"

    def test_initialization_with_custom_url(self):
        """Test provider initialization with custom base URL."""
        provider = XAIModelProvider("test-key", base_url="https://custom.x.ai/v1")
        assert provider.api_key == "test-key"
        assert provider.base_url == "https://custom.x.ai/v1"

    def test_model_validation(self):
        """Test model name validation."""
        provider = XAIModelProvider("test-key")

        # Test valid models
        assert provider.validate_model_name("grok-3") is True
        assert provider.validate_model_name("grok-3-fast") is True
        assert provider.validate_model_name("grok") is True
        assert provider.validate_model_name("grok3") is True
        assert provider.validate_model_name("grokfast") is True
        assert provider.validate_model_name("grok3fast") is True

        # Test invalid models
        assert provider.validate_model_name("invalid-model") is False
        assert provider.validate_model_name("gpt-4") is False
        assert provider.validate_model_name("gemini-pro") is False

    def test_resolve_model_name(self):
        """Test model name resolution."""
        provider = XAIModelProvider("test-key")

        # Test shorthand resolution
        assert provider._resolve_model_name("grok") == "grok-3"
        assert provider._resolve_model_name("grok3") == "grok-3"
        assert provider._resolve_model_name("grokfast") == "grok-3-fast"
        assert provider._resolve_model_name("grok3fast") == "grok-3-fast"

        # Test full name passthrough
        assert provider._resolve_model_name("grok-3") == "grok-3"
        assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"

    def test_get_capabilities_grok3(self):
        """Test getting model capabilities for GROK-3."""
        provider = XAIModelProvider("test-key")

        capabilities = provider.get_capabilities("grok-3")
        assert capabilities.model_name == "grok-3"
        assert capabilities.friendly_name == "X.AI"
        assert capabilities.context_window == 131_072
        assert capabilities.provider == ProviderType.XAI
        assert not capabilities.supports_extended_thinking
        assert capabilities.supports_system_prompts is True
        assert capabilities.supports_streaming is True
        assert capabilities.supports_function_calling is True

        # Test temperature range
        assert capabilities.temperature_constraint.min_temp == 0.0
        assert capabilities.temperature_constraint.max_temp == 2.0
        assert capabilities.temperature_constraint.default_temp == 0.7

    def test_get_capabilities_grok3_fast(self):
        """Test getting model capabilities for GROK-3 Fast."""
        provider = XAIModelProvider("test-key")

        capabilities = provider.get_capabilities("grok-3-fast")
        assert capabilities.model_name == "grok-3-fast"
        assert capabilities.friendly_name == "X.AI"
        assert capabilities.context_window == 131_072
        assert capabilities.provider == ProviderType.XAI
        assert not capabilities.supports_extended_thinking

    def test_get_capabilities_with_shorthand(self):
        """Test getting model capabilities with shorthand."""
        provider = XAIModelProvider("test-key")

        capabilities = provider.get_capabilities("grok")
        assert capabilities.model_name == "grok-3"  # Should resolve to full name
        assert capabilities.context_window == 131_072

        capabilities_fast = provider.get_capabilities("grokfast")
        assert capabilities_fast.model_name == "grok-3-fast"  # Should resolve to full name

    def test_unsupported_model_capabilities(self):
        """Test error handling for unsupported models."""
        provider = XAIModelProvider("test-key")

        with pytest.raises(ValueError, match="Unsupported X.AI model"):
            provider.get_capabilities("invalid-model")

    def test_no_thinking_mode_support(self):
        """Test that X.AI models don't support thinking mode."""
        provider = XAIModelProvider("test-key")

        assert not provider.supports_thinking_mode("grok-3")
        assert not provider.supports_thinking_mode("grok-3-fast")
        assert not provider.supports_thinking_mode("grok")
        assert not provider.supports_thinking_mode("grokfast")

    def test_provider_type(self):
        """Test provider type identification."""
        provider = XAIModelProvider("test-key")
        assert provider.get_provider_type() == ProviderType.XAI

    @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok-3"})
    def test_model_restrictions(self):
        """Test model restrictions functionality."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = XAIModelProvider("test-key")

        # grok-3 should be allowed
        assert provider.validate_model_name("grok-3") is True
        assert provider.validate_model_name("grok") is True  # Shorthand for grok-3

        # grok-3-fast should be blocked by restrictions
        assert provider.validate_model_name("grok-3-fast") is False
        assert provider.validate_model_name("grokfast") is False

    @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3-fast"})
    def test_multiple_model_restrictions(self):
        """Test multiple models in restrictions."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = XAIModelProvider("test-key")

        # Shorthand "grok" should be allowed (resolves to grok-3)
        assert provider.validate_model_name("grok") is True

        # Full name "grok-3" should NOT be allowed (only shorthand "grok" is in restriction list)
        assert provider.validate_model_name("grok-3") is False

        # "grok-3-fast" should be allowed (explicitly listed)
        assert provider.validate_model_name("grok-3-fast") is True

        # Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
        assert provider.validate_model_name("grokfast") is True

    @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3"})
    def test_both_shorthand_and_full_name_allowed(self):
        """Test that both shorthand and full name can be allowed."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = XAIModelProvider("test-key")

        # Both shorthand and full name should be allowed
        assert provider.validate_model_name("grok") is True
        assert provider.validate_model_name("grok-3") is True

        # Other models should not be allowed
        assert provider.validate_model_name("grok-3-fast") is False
        assert provider.validate_model_name("grokfast") is False

    @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": ""})
    def test_empty_restrictions_allows_all(self):
        """Test that empty restrictions allow all models."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = XAIModelProvider("test-key")

        assert provider.validate_model_name("grok-3") is True
        assert provider.validate_model_name("grok-3-fast") is True
        assert provider.validate_model_name("grok") is True
        assert provider.validate_model_name("grokfast") is True

    def test_friendly_name(self):
        """Test friendly name constant."""
        provider = XAIModelProvider("test-key")
        assert provider.FRIENDLY_NAME == "X.AI"

        capabilities = provider.get_capabilities("grok-3")
        assert capabilities.friendly_name == "X.AI"

    def test_supported_models_structure(self):
        """Test that SUPPORTED_MODELS has the correct structure."""
        provider = XAIModelProvider("test-key")

        # Check that all expected models are present
        assert "grok-3" in provider.SUPPORTED_MODELS
        assert "grok-3-fast" in provider.SUPPORTED_MODELS
        assert "grok" in provider.SUPPORTED_MODELS
        assert "grok3" in provider.SUPPORTED_MODELS
        assert "grokfast" in provider.SUPPORTED_MODELS
        assert "grok3fast" in provider.SUPPORTED_MODELS

        # Check model configs have required fields
        grok3_config = provider.SUPPORTED_MODELS["grok-3"]
        assert isinstance(grok3_config, dict)
        assert "context_window" in grok3_config
        assert "supports_extended_thinking" in grok3_config
        assert grok3_config["context_window"] == 131_072
        assert grok3_config["supports_extended_thinking"] is False

        # Check shortcuts point to full names
        assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
        assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
        """Test that generate_content resolves aliases before making API calls.

        This is the CRITICAL test that ensures aliases like 'grok' get resolved
        to 'grok-3' before being sent to X.AI API.
        """
        # Set up mock OpenAI client
        mock_client = MagicMock()
        mock_openai_class.return_value = mock_client

        # Mock the completion response
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.model = "grok-3"  # API returns the resolved model name
        mock_response.id = "test-id"
        mock_response.created = 1234567890
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_response.usage.total_tokens = 15

        mock_client.chat.completions.create.return_value = mock_response

        provider = XAIModelProvider("test-key")

        # Call generate_content with alias 'grok'
        result = provider.generate_content(
            prompt="Test prompt", model_name="grok", temperature=0.7  # This should be resolved to "grok-3"
        )

        # Verify the API was called with the RESOLVED model name
        mock_client.chat.completions.create.assert_called_once()
        call_kwargs = mock_client.chat.completions.create.call_args[1]

        # CRITICAL ASSERTION: The API should receive "grok-3", not "grok"
        assert call_kwargs["model"] == "grok-3", f"Expected 'grok-3' but API received '{call_kwargs['model']}'"

        # Verify other parameters
        assert call_kwargs["temperature"] == 0.7
        assert len(call_kwargs["messages"]) == 1
        assert call_kwargs["messages"][0]["role"] == "user"
        assert call_kwargs["messages"][0]["content"] == "Test prompt"

        # Verify response
        assert result.content == "Test response"
        assert result.model_name == "grok-3"  # Should be the resolved name

    @patch("providers.openai_compatible.OpenAI")
    def test_generate_content_other_aliases(self, mock_openai_class):
        """Test other alias resolutions in generate_content."""
        from unittest.mock import MagicMock

        # Set up mock
        mock_client = MagicMock()
        mock_openai_class.return_value = mock_client
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test response"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        mock_response.usage.total_tokens = 15
        mock_client.chat.completions.create.return_value = mock_response

        provider = XAIModelProvider("test-key")

        # Test grok3 -> grok-3
        mock_response.model = "grok-3"
        provider.generate_content(prompt="Test", model_name="grok3", temperature=0.7)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "grok-3"

        # Test grokfast -> grok-3-fast
        mock_response.model = "grok-3-fast"
        provider.generate_content(prompt="Test", model_name="grokfast", temperature=0.7)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "grok-3-fast"

        # Test grok3fast -> grok-3-fast
        provider.generate_content(prompt="Test", model_name="grok3fast", temperature=0.7)
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "grok-3-fast"
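The restriction tests above imply a particular matching rule: an entry in XAI_ALLOWED_MODELS matches either the literal requested name or the canonical name it resolves to. A self-contained sketch of that rule follows; the function name and shape are assumed, not the actual utils/model_restrictions code.

import os


# Hypothetical restriction check mirroring the behavior the tests assert.
def is_allowed(requested: str, resolved: str, env_var: str = "XAI_ALLOWED_MODELS") -> bool:
    raw = os.getenv(env_var, "").strip()
    if not raw:
        return True  # an empty restriction list allows every model
    allowed = {name.strip().lower() for name in raw.split(",") if name.strip()}
    # Either the literal shorthand (e.g. "grok") or the resolved canonical
    # name (e.g. "grok-3-fast" for "grokfast") must appear in the list.
    return requested.lower() in allowed or resolved.lower() in allowed


# Mirrors test_multiple_model_restrictions with XAI_ALLOWED_MODELS="grok,grok-3-fast":
os.environ["XAI_ALLOWED_MODELS"] = "grok,grok-3-fast"
assert is_allowed("grok", "grok-3") is True  # shorthand listed literally
assert is_allowed("grok-3", "grok-3") is False  # full name not in the list
assert is_allowed("grokfast", "grok-3-fast") is True  # resolves to a listed name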
@@ -9,11 +9,13 @@ standardization purposes.
 Environment Variables:
 - OPENAI_ALLOWED_MODELS: Comma-separated list of allowed OpenAI models
 - GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
+- XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
 - OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models

 Example:
     OPENAI_ALLOWED_MODELS=o3-mini,o4-mini
     GOOGLE_ALLOWED_MODELS=flash
+    XAI_ALLOWED_MODELS=grok-3,grok-3-fast
     OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral
 """
@@ -40,6 +42,7 @@ class ModelRestrictionService:
     ENV_VARS = {
         ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS",
         ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
+        ProviderType.XAI: "XAI_ALLOWED_MODELS",
         ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
     }
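For orientation, a small sketch of how a service like this might turn ENV_VARS into per-provider allow-lists; the loader name and details are assumed, and only the ENV_VARS mapping itself comes from the diff.

import os
from enum import Enum


class ProviderType(Enum):  # stand-in for providers.base.ProviderType
    OPENAI = "openai"
    GOOGLE = "google"
    XAI = "xai"
    OPENROUTER = "openrouter"


ENV_VARS = {
    ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS",
    ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
    ProviderType.XAI: "XAI_ALLOWED_MODELS",
    ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
}


def load_restrictions() -> dict:
    """Parse each provider's env var into a lowercase set; an empty set means unrestricted."""
    return {
        provider: {m.strip().lower() for m in os.getenv(var, "").split(",") if m.strip()}
        for provider, var in ENV_VARS.items()
    }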