diff --git a/.env.example b/.env.example
index 1ce8f90..260c46e 100644
--- a/.env.example
+++ b/.env.example
@@ -18,6 +18,9 @@ GEMINI_API_KEY=your_gemini_api_key_here
# Get your OpenAI API key from: https://platform.openai.com/api-keys
OPENAI_API_KEY=your_openai_api_key_here
+# Get your X.AI API key from: https://console.x.ai/
+XAI_API_KEY=your_xai_api_key_here
+
# Option 2: Use OpenRouter for access to multiple models through one API
# Get your OpenRouter API key from: https://openrouter.ai/
# If using OpenRouter, comment out the native API keys above
@@ -68,15 +71,25 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
# - flash (shorthand for gemini-2.5-flash-preview-05-20)
# - pro (shorthand for gemini-2.5-pro-preview-06-05)
#
+# Supported X.AI GROK models:
+# - grok-3 (131K context, advanced reasoning)
+# - grok-3-fast (131K context, higher performance but more expensive)
+# - grok (shorthand for grok-3)
+# - grok3 (shorthand for grok-3)
+# - grokfast (shorthand for grok-3-fast)
+#
# Examples:
# OPENAI_ALLOWED_MODELS=o3-mini,o4-mini,mini # Only allow mini models (cost control)
# GOOGLE_ALLOWED_MODELS=flash # Only allow Flash (fast responses)
+# XAI_ALLOWED_MODELS=grok-3 # Only allow standard GROK (not fast variant)
# OPENAI_ALLOWED_MODELS=o4-mini # Single model standardization
# GOOGLE_ALLOWED_MODELS=flash,pro # Allow both Gemini models
+# XAI_ALLOWED_MODELS=grok,grok-3-fast # Allow both GROK variants
#
# Note: These restrictions apply even in 'auto' mode - Claude will only pick from allowed models
# OPENAI_ALLOWED_MODELS=
# GOOGLE_ALLOWED_MODELS=
+# XAI_ALLOWED_MODELS=
# Optional: Custom model configuration file path
# Override the default location of custom_models.json
diff --git a/README.md b/README.md
index 949d99a..adfa015 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
https://github.com/user-attachments/assets/8097e18e-b926-4d8b-ba14-a979e4c58bda
- 🤖 Claude + [Gemini / O3 / OpenRouter / Ollama / Any Model] = Your Ultimate AI Development Team
+ 🤖 Claude + [Gemini / O3 / GROK / OpenRouter / Ollama / Any Model] = Your Ultimate AI Development Team
@@ -115,6 +115,7 @@ The final implementation resulted in a 26% improvement in JSON parsing performan
**Option B: Native APIs**
- **Gemini**: Visit [Google AI Studio](https://makersuite.google.com/app/apikey) and generate an API key. For best results with Gemini 2.5 Pro, use a paid API key as the free tier has limited access to the latest models.
- **OpenAI**: Visit [OpenAI Platform](https://platform.openai.com/api-keys) to get an API key for O3 model access.
+- **X.AI**: Visit [X.AI Console](https://console.x.ai/) to get an API key for GROK model access.
**Option C: Custom API Endpoints (Local models like Ollama, vLLM)**
[Please see the setup guide](docs/custom_models.md#option-2-custom-api-setup-ollama-vllm-etc). With a custom API you can use:
diff --git a/config.py b/config.py
index 1e30bfe..d3bcd8f 100644
--- a/config.py
+++ b/config.py
@@ -14,7 +14,7 @@ import os
# These values are used in server responses and for tracking releases
# IMPORTANT: This is the single source of truth for version and author info
# Semantic versioning: MAJOR.MINOR.PATCH
-__version__ = "4.5.1"
+__version__ = "4.6.0"
# Last update date in ISO format
__updated__ = "2025-06-15"
# Primary maintainer
@@ -53,6 +53,12 @@ MODEL_CAPABILITIES_DESC = {
"o3-pro": "Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
"o4-mini": "Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
"o4-mini-high": "Enhanced O4 mini (200K context) - Higher reasoning effort for complex tasks",
+ # X.AI GROK models - Available when XAI_API_KEY is configured
+ "grok": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
+ "grok-3": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
+ "grok-3-fast": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
+ "grok3": "GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
+ "grokfast": "GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
# Full model names also supported (for explicit specification)
"gemini-2.5-flash-preview-05-20": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
"gemini-2.5-pro-preview-06-05": (
diff --git a/docker-compose.yml b/docker-compose.yml
index 4b40b32..dac4ac3 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -31,6 +31,7 @@ services:
environment:
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
+ - XAI_API_KEY=${XAI_API_KEY:-}
# OpenRouter support
- OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
- CUSTOM_MODELS_CONFIG_PATH=${CUSTOM_MODELS_CONFIG_PATH:-}
@@ -45,6 +46,7 @@ services:
# Model usage restrictions
- OPENAI_ALLOWED_MODELS=${OPENAI_ALLOWED_MODELS:-}
- GOOGLE_ALLOWED_MODELS=${GOOGLE_ALLOWED_MODELS:-}
+ - XAI_ALLOWED_MODELS=${XAI_ALLOWED_MODELS:-}
- REDIS_URL=redis://redis:6379/0
# Use HOME not PWD: Claude needs access to any absolute file path, not just current project,
# and Claude Code could be running from multiple locations at the same time
diff --git a/docs/adding_providers.md b/docs/adding_providers.md
index 182230c..f700c86 100644
--- a/docs/adding_providers.md
+++ b/docs/adding_providers.md
@@ -23,9 +23,11 @@ Inherit from `ModelProvider` when:
### Option B: OpenAI-Compatible Provider (Simplified)
Inherit from `OpenAICompatibleProvider` when:
- Your API follows OpenAI's chat completion format
-- You want to reuse existing implementation for `generate_content` and `count_tokens`
+- You want to reuse existing implementation for most functionality
- You only need to define model capabilities and validation
+⚠️ **CRITICAL**: If your provider has model aliases (shorthands), you **MUST** override `generate_content()` to resolve aliases before API calls. See implementation example below.
+
## Step-by-Step Guide
### 1. Add Provider Type to Enum
@@ -177,8 +179,11 @@ For providers with OpenAI-compatible APIs, the implementation is much simpler:
"""Example provider using OpenAI-compatible interface."""
import logging
+from typing import Optional
+
from .base import (
ModelCapabilities,
+ ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
@@ -268,7 +273,31 @@ class ExampleProvider(OpenAICompatibleProvider):
return shorthand_value
return model_name
- # Note: generate_content and count_tokens are inherited from OpenAICompatibleProvider
+ def generate_content(
+ self,
+ prompt: str,
+ model_name: str,
+ system_prompt: Optional[str] = None,
+ temperature: float = 0.7,
+ max_output_tokens: Optional[int] = None,
+ **kwargs,
+ ) -> ModelResponse:
+ """Generate content using API with proper model name resolution."""
+ # CRITICAL: Resolve model alias before making API call
+ # This ensures aliases like "large" get sent as "example-model-large" to the API
+ resolved_model_name = self._resolve_model_name(model_name)
+
+ # Call parent implementation with resolved model name
+ return super().generate_content(
+ prompt=prompt,
+ model_name=resolved_model_name,
+ system_prompt=system_prompt,
+ temperature=temperature,
+ max_output_tokens=max_output_tokens,
+ **kwargs,
+ )
+
+ # Note: count_tokens is inherited from OpenAICompatibleProvider
```
### 3. Update Registry Configuration
@@ -291,7 +320,32 @@ def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]
# ... rest of the method
```
-### 4. Register Provider in server.py
+### 4. Configure Docker Environment Variables
+
+**CRITICAL**: You must add your provider's environment variables to `docker-compose.yml` for them to be available in the Docker container.
+
+Add your API key and restriction variables to the `environment` section:
+
+```yaml
+services:
+ zen-mcp:
+ # ... other configuration ...
+ environment:
+ - GEMINI_API_KEY=${GEMINI_API_KEY:-}
+ - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+ - EXAMPLE_API_KEY=${EXAMPLE_API_KEY:-} # Add this line
+ # OpenRouter support
+ - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
+ # ... other variables ...
+ # Model usage restrictions
+ - OPENAI_ALLOWED_MODELS=${OPENAI_ALLOWED_MODELS:-}
+ - GOOGLE_ALLOWED_MODELS=${GOOGLE_ALLOWED_MODELS:-}
+ - EXAMPLE_ALLOWED_MODELS=${EXAMPLE_ALLOWED_MODELS:-} # Add this line
+```
+
+⚠️ **Without this step**, the Docker container won't have access to your environment variables, and your provider won't be registered even if the API key is set in your `.env` file.
+
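+As a quick sanity check that the variable actually reached the container, you can query it directly. This is a minimal sketch in the style of the simulator tests; the container name `zen-mcp-server` is an assumption (substitute your own), while `EXAMPLE_API_KEY` matches the example variable above:
+
+```python
+import subprocess
+
+# Ask the running container to print the variable; empty output means it was
+# not passed through docker-compose.yml (printenv exits non-zero when unset).
+result = subprocess.run(
+    ["docker", "exec", "zen-mcp-server", "printenv", "EXAMPLE_API_KEY"],
+    capture_output=True,
+    text=True,
+)
+print("visible in container:", bool(result.stdout.strip()))
+```
+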
+### 5. Register Provider in server.py
The `configure_providers()` function in `server.py` handles provider registration. You need to:
@@ -355,7 +409,7 @@ def configure_providers():
)
```
-### 5. Add Model Capabilities for Auto Mode
+### 6. Add Model Capabilities for Auto Mode
Update `config.py` to add your models to `MODEL_CAPABILITIES_DESC`:
@@ -372,9 +426,9 @@ MODEL_CAPABILITIES_DESC = {
}
```
-### 6. Update Documentation
+### 7. Update Documentation
-#### 6.1. Update README.md
+#### 7.1. Update README.md
Add your provider to the quickstart section:
@@ -396,9 +450,9 @@ Also update the .env file example:
# EXAMPLE_API_KEY=your-example-api-key-here # Add this
```
-### 7. Write Tests
+### 8. Write Tests
-#### 7.1. Unit Tests
+#### 8.1. Unit Tests
Create `tests/test_example_provider.py`:
@@ -460,7 +514,7 @@ class TestExampleProvider:
assert capabilities.temperature_constraint.max_temp == 2.0
```
-#### 7.2. Simulator Tests (Real-World Validation)
+#### 8.2. Simulator Tests (Real-World Validation)
Create a simulator test to validate that your provider works correctly in real-world scenarios. Create `simulator_tests/test_example_models.py`:
@@ -696,6 +750,36 @@ SUPPORTED_MODELS = {
The `_resolve_model_name()` method handles this mapping automatically.
+## Critical Implementation Requirements
+
+### Alias Resolution for OpenAI-Compatible Providers
+
+If you inherit from `OpenAICompatibleProvider` and define model aliases, you **MUST** override `generate_content()` to resolve aliases before API calls. This is because:
+
+1. **The base `OpenAICompatibleProvider.generate_content()`** sends the original model name directly to the API
+2. **Your API expects the full model name**, not the alias
+3. **Without resolution**, requests like `model="large"` will fail with 404/400 errors
+
+**Examples of providers that need this:**
+- XAI provider: `"grok"` → `"grok-3"`
+- OpenAI provider: `"mini"` → `"o4-mini"`
+- Custom provider: `"fast"` → `"llama-3.1-8b-instruct"`
+
+**Example implementation pattern:**
+```python
+def generate_content(self, prompt: str, model_name: str, **kwargs) -> ModelResponse:
+ # CRITICAL: Resolve alias before API call
+ resolved_model_name = self._resolve_model_name(model_name)
+
+ # Pass resolved name to parent
+ return super().generate_content(prompt=prompt, model_name=resolved_model_name, **kwargs)
+```
+
+**Providers that DON'T need this:**
+- Gemini provider (has its own generate_content implementation)
+- OpenRouter provider (already implements this pattern)
+- Providers without aliases
+
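+For reference, here is a minimal, self-contained sketch of the lookup that `_resolve_model_name()` performs. The `SUPPORTED_MODELS` entries mirror the hypothetical example provider above rather than any real provider:
+
+```python
+# Aliases are stored as string values pointing at the canonical model name;
+# real models are stored as config dicts.
+SUPPORTED_MODELS = {
+    "example-model-large": {"context_window": 128_000},
+    "large": "example-model-large",  # alias -> canonical name
+}
+
+def resolve_model_name(model_name: str) -> str:
+    """Return the canonical model name for an alias, or the name unchanged."""
+    target = SUPPORTED_MODELS.get(model_name)
+    return target if isinstance(target, str) else model_name
+
+assert resolve_model_name("large") == "example-model-large"
+assert resolve_model_name("example-model-large") == "example-model-large"
+```
+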
## Best Practices
1. **Always validate model names** against supported models and restrictions
@@ -715,6 +799,7 @@ Before submitting your PR:
- [ ] Provider implementation complete with all required methods
- [ ] API key mapping added to `_get_api_key_for_provider()` in `providers/registry.py`
- [ ] Provider added to `PROVIDER_PRIORITY_ORDER` in `registry.py` (if native provider)
+- [ ] **Environment variables added to `docker-compose.yml`** (API key and restrictions)
- [ ] Provider imported and registered in `server.py`'s `configure_providers()`
- [ ] API key checking added to `configure_providers()` function
- [ ] Error message updated to include new provider
diff --git a/providers/base.py b/providers/base.py
index 5d2b20e..580f39f 100644
--- a/providers/base.py
+++ b/providers/base.py
@@ -11,6 +11,7 @@ class ProviderType(Enum):
GOOGLE = "google"
OPENAI = "openai"
+ XAI = "xai"
OPENROUTER = "openrouter"
CUSTOM = "custom"
diff --git a/providers/openai.py b/providers/openai.py
index 3d0b3b5..b920af8 100644
--- a/providers/openai.py
+++ b/providers/openai.py
@@ -1,10 +1,12 @@
"""OpenAI model provider implementation."""
import logging
+from typing import Optional
from .base import (
FixedTemperatureConstraint,
ModelCapabilities,
+ ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
@@ -111,6 +113,29 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
return True
+ def generate_content(
+ self,
+ prompt: str,
+ model_name: str,
+ system_prompt: Optional[str] = None,
+ temperature: float = 0.7,
+ max_output_tokens: Optional[int] = None,
+ **kwargs,
+ ) -> ModelResponse:
+ """Generate content using OpenAI API with proper model name resolution."""
+ # Resolve model alias before making API call
+ resolved_model_name = self._resolve_model_name(model_name)
+
+ # Call parent implementation with resolved model name
+ return super().generate_content(
+ prompt=prompt,
+ model_name=resolved_model_name,
+ system_prompt=system_prompt,
+ temperature=temperature,
+ max_output_tokens=max_output_tokens,
+ **kwargs,
+ )
+
def supports_thinking_mode(self, model_name: str) -> bool:
"""Check if the model supports extended thinking mode."""
# Currently no OpenAI models support extended thinking
diff --git a/providers/registry.py b/providers/registry.py
index 09166ad..6332466 100644
--- a/providers/registry.py
+++ b/providers/registry.py
@@ -117,6 +117,7 @@ class ModelProviderRegistry:
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
+ ProviderType.XAI, # Direct X.AI GROK access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]
@@ -173,16 +174,22 @@ class ModelProviderRegistry:
# Get supported models based on provider type
if hasattr(provider, "SUPPORTED_MODELS"):
for model_name, config in provider.SUPPORTED_MODELS.items():
- # Skip aliases (string values)
+ # Handle both base models (dict configs) and aliases (string values)
if isinstance(config, str):
- continue
-
- # Check restrictions if enabled
- if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
- logging.debug(f"Model {model_name} filtered by restrictions")
- continue
-
- models[model_name] = provider_type
+ # This is an alias - check if the target model would be allowed
+ target_model = config
+ if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
+ logging.debug(f"Alias {model_name} -> {target_model} filtered by restrictions")
+ continue
+ # Allow the alias
+ models[model_name] = provider_type
+ else:
+ # This is a base model with config dict
+ # Check restrictions if enabled
+ if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
+ logging.debug(f"Model {model_name} filtered by restrictions")
+ continue
+ models[model_name] = provider_type
elif provider_type == ProviderType.OPENROUTER:
# OpenRouter uses a registry system instead of SUPPORTED_MODELS
if hasattr(provider, "_registry") and provider._registry:
@@ -230,6 +237,7 @@ class ModelProviderRegistry:
key_mapping = {
ProviderType.GOOGLE: "GEMINI_API_KEY",
ProviderType.OPENAI: "OPENAI_API_KEY",
+ ProviderType.XAI: "XAI_API_KEY",
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth
}
@@ -264,9 +272,13 @@ class ModelProviderRegistry:
# Group by provider
openai_models = [m for m, p in available_models.items() if p == ProviderType.OPENAI]
gemini_models = [m for m, p in available_models.items() if p == ProviderType.GOOGLE]
+ xai_models = [m for m, p in available_models.items() if p == ProviderType.XAI]
+ openrouter_models = [m for m, p in available_models.items() if p == ProviderType.OPENROUTER]
openai_available = bool(openai_models)
gemini_available = bool(gemini_models)
+ xai_available = bool(xai_models)
+ openrouter_available = bool(openrouter_models)
if tool_category == ToolModelCategory.EXTENDED_REASONING:
# Prefer thinking-capable models for deep reasoning tools
@@ -275,17 +287,25 @@ class ModelProviderRegistry:
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
+ elif xai_available and "grok-3" in xai_models:
+ return "grok-3" # GROK-3 for deep reasoning
+ elif xai_available and xai_models:
+ # Fall back to any available XAI model
+ return xai_models[0]
elif gemini_available and any("pro" in m for m in gemini_models):
# Find the pro model (handles full names)
return next(m for m in gemini_models if "pro" in m)
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
- else:
- # Try to find thinking-capable model from custom/openrouter
+ elif openrouter_available:
+ # Try to find thinking-capable model from openrouter
thinking_model = cls._find_extended_thinking_model()
if thinking_model:
return thinking_model
+ # Fallback to first available OpenRouter model
+ return openrouter_models[0]
+ else:
# Fallback to pro if nothing found
return "gemini-2.5-pro-preview-06-05"
@@ -298,12 +318,20 @@ class ModelProviderRegistry:
elif openai_available and openai_models:
# Fall back to any available OpenAI model
return openai_models[0]
+ elif xai_available and "grok-3-fast" in xai_models:
+ return "grok-3-fast" # GROK-3 Fast for speed
+ elif xai_available and xai_models:
+ # Fall back to any available XAI model
+ return xai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
# Find the flash model (handles full names)
return next(m for m in gemini_models if "flash" in m)
elif gemini_available and gemini_models:
# Fall back to any available Gemini model
return gemini_models[0]
+ elif openrouter_available:
+ # Fallback to first available OpenRouter model
+ return openrouter_models[0]
else:
# Default to flash
return "gemini-2.5-flash-preview-05-20"
@@ -315,10 +343,16 @@ class ModelProviderRegistry:
return "o3-mini" # Second choice
elif openai_available and openai_models:
return openai_models[0]
+ elif xai_available and "grok-3" in xai_models:
+ return "grok-3" # GROK-3 as balanced choice
+ elif xai_available and xai_models:
+ return xai_models[0]
elif gemini_available and any("flash" in m for m in gemini_models):
return next(m for m in gemini_models if "flash" in m)
elif gemini_available and gemini_models:
return gemini_models[0]
+ elif openrouter_available:
+ return openrouter_models[0]
else:
# No models available due to restrictions - check if any providers exist
if not available_models:
@@ -355,8 +389,9 @@ class ModelProviderRegistry:
preferred_models = [
"anthropic/claude-3.5-sonnet",
"anthropic/claude-3-opus-20240229",
- "meta-llama/llama-3.1-70b-instruct",
+ "google/gemini-2.5-pro-preview-06-05",
"google/gemini-pro-1.5",
+ "meta-llama/llama-3.1-70b-instruct",
"mistralai/mixtral-8x7b-instruct",
]
for model in preferred_models:
diff --git a/providers/xai.py b/providers/xai.py
new file mode 100644
index 0000000..533bea3
--- /dev/null
+++ b/providers/xai.py
@@ -0,0 +1,135 @@
+"""X.AI (GROK) model provider implementation."""
+
+import logging
+from typing import Optional
+
+from .base import (
+ ModelCapabilities,
+ ModelResponse,
+ ProviderType,
+ RangeTemperatureConstraint,
+)
+from .openai_compatible import OpenAICompatibleProvider
+
+logger = logging.getLogger(__name__)
+
+
+class XAIModelProvider(OpenAICompatibleProvider):
+ """X.AI GROK API provider (api.x.ai)."""
+
+ FRIENDLY_NAME = "X.AI"
+
+ # Model configurations
+ SUPPORTED_MODELS = {
+ "grok-3": {
+ "context_window": 131_072, # 131K tokens
+ "supports_extended_thinking": False,
+ },
+ "grok-3-fast": {
+ "context_window": 131_072, # 131K tokens
+ "supports_extended_thinking": False,
+ },
+ # Shorthands for convenience
+ "grok": "grok-3", # Default to grok-3
+ "grok3": "grok-3",
+ "grok3fast": "grok-3-fast",
+ "grokfast": "grok-3-fast",
+ }
+
+ def __init__(self, api_key: str, **kwargs):
+ """Initialize X.AI provider with API key."""
+ # Set X.AI base URL
+ kwargs.setdefault("base_url", "https://api.x.ai/v1")
+ super().__init__(api_key, **kwargs)
+
+ def get_capabilities(self, model_name: str) -> ModelCapabilities:
+ """Get capabilities for a specific X.AI model."""
+ # Resolve shorthand
+ resolved_name = self._resolve_model_name(model_name)
+
+ if resolved_name not in self.SUPPORTED_MODELS or isinstance(self.SUPPORTED_MODELS[resolved_name], str):
+ raise ValueError(f"Unsupported X.AI model: {model_name}")
+
+ # Check if model is allowed by restrictions
+ from utils.model_restrictions import get_restriction_service
+
+ restriction_service = get_restriction_service()
+ if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
+ raise ValueError(f"X.AI model '{model_name}' is not allowed by restriction policy.")
+
+ config = self.SUPPORTED_MODELS[resolved_name]
+
+ # Define temperature constraints for GROK models
+ # GROK supports the standard OpenAI temperature range
+ temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
+
+ return ModelCapabilities(
+ provider=ProviderType.XAI,
+ model_name=resolved_name,
+ friendly_name=self.FRIENDLY_NAME,
+ context_window=config["context_window"],
+ supports_extended_thinking=config["supports_extended_thinking"],
+ supports_system_prompts=True,
+ supports_streaming=True,
+ supports_function_calling=True,
+ temperature_constraint=temp_constraint,
+ )
+
+ def get_provider_type(self) -> ProviderType:
+ """Get the provider type."""
+ return ProviderType.XAI
+
+ def validate_model_name(self, model_name: str) -> bool:
+ """Validate if the model name is supported and allowed."""
+ resolved_name = self._resolve_model_name(model_name)
+
+ # First check if model is supported
+ if resolved_name not in self.SUPPORTED_MODELS or not isinstance(self.SUPPORTED_MODELS[resolved_name], dict):
+ return False
+
+ # Then check if model is allowed by restrictions
+ from utils.model_restrictions import get_restriction_service
+
+ restriction_service = get_restriction_service()
+ if not restriction_service.is_allowed(ProviderType.XAI, resolved_name, model_name):
+ logger.debug(f"X.AI model '{model_name}' -> '{resolved_name}' blocked by restrictions")
+ return False
+
+ return True
+
+ def generate_content(
+ self,
+ prompt: str,
+ model_name: str,
+ system_prompt: Optional[str] = None,
+ temperature: float = 0.7,
+ max_output_tokens: Optional[int] = None,
+ **kwargs,
+ ) -> ModelResponse:
+ """Generate content using X.AI API with proper model name resolution."""
+ # Resolve model alias before making API call
+ resolved_model_name = self._resolve_model_name(model_name)
+
+ # Call parent implementation with resolved model name
+ return super().generate_content(
+ prompt=prompt,
+ model_name=resolved_model_name,
+ system_prompt=system_prompt,
+ temperature=temperature,
+ max_output_tokens=max_output_tokens,
+ **kwargs,
+ )
+
+ def supports_thinking_mode(self, model_name: str) -> bool:
+ """Check if the model supports extended thinking mode."""
+ # Currently GROK models do not support extended thinking
+ # This may change with future GROK model releases
+ return False
+
+ def _resolve_model_name(self, model_name: str) -> str:
+ """Resolve model shorthand to full name."""
+ # Check if it's a shorthand
+ shorthand_value = self.SUPPORTED_MODELS.get(model_name)
+ if isinstance(shorthand_value, str):
+ return shorthand_value
+ return model_name
diff --git a/run-server.sh b/run-server.sh
index f120362..c6dd8f6 100755
--- a/run-server.sh
+++ b/run-server.sh
@@ -120,6 +120,16 @@ else
fi
fi
+ if [ -n "${XAI_API_KEY:-}" ]; then
+ # Replace the placeholder API key with the actual value
+ if command -v sed >/dev/null 2>&1; then
+ sed -i.bak "s/your_xai_api_key_here/$XAI_API_KEY/" .env && rm .env.bak
+ echo "✅ Updated .env with existing XAI_API_KEY from environment"
+ else
+ echo "⚠️ Found XAI_API_KEY in environment, but sed not available. Please update .env manually."
+ fi
+ fi
+
if [ -n "${OPENROUTER_API_KEY:-}" ]; then
# Replace the placeholder API key with the actual value
if command -v sed >/dev/null 2>&1; then
@@ -169,6 +179,7 @@ source .env 2>/dev/null || true
VALID_GEMINI_KEY=false
VALID_OPENAI_KEY=false
+VALID_XAI_KEY=false
VALID_OPENROUTER_KEY=false
VALID_CUSTOM_URL=false
@@ -184,6 +195,12 @@ if [ -n "${OPENAI_API_KEY:-}" ] && [ "$OPENAI_API_KEY" != "your_openai_api_key_h
echo "â
OPENAI_API_KEY found"
fi
+# Check if XAI_API_KEY is set and not the placeholder
+if [ -n "${XAI_API_KEY:-}" ] && [ "$XAI_API_KEY" != "your_xai_api_key_here" ]; then
+ VALID_XAI_KEY=true
+ echo "✅ XAI_API_KEY found"
+fi
+
# Check if OPENROUTER_API_KEY is set and not the placeholder
if [ -n "${OPENROUTER_API_KEY:-}" ] && [ "$OPENROUTER_API_KEY" != "your_openrouter_api_key_here" ]; then
VALID_OPENROUTER_KEY=true
@@ -197,19 +214,21 @@ if [ -n "${CUSTOM_API_URL:-}" ]; then
fi
# Require at least one valid API key or custom URL
-if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ] && [ "$VALID_OPENROUTER_KEY" = false ] && [ "$VALID_CUSTOM_URL" = false ]; then
+if [ "$VALID_GEMINI_KEY" = false ] && [ "$VALID_OPENAI_KEY" = false ] && [ "$VALID_XAI_KEY" = false ] && [ "$VALID_OPENROUTER_KEY" = false ] && [ "$VALID_CUSTOM_URL" = false ]; then
echo ""
echo "â ERROR: At least one valid API key or custom URL is required!"
echo ""
echo "Please edit the .env file and set at least one of:"
echo " - GEMINI_API_KEY (get from https://makersuite.google.com/app/apikey)"
echo " - OPENAI_API_KEY (get from https://platform.openai.com/api-keys)"
+ echo " - XAI_API_KEY (get from https://console.x.ai/)"
echo " - OPENROUTER_API_KEY (get from https://openrouter.ai/)"
echo " - CUSTOM_API_URL (for local models like Ollama, vLLM, etc.)"
echo ""
echo "Example:"
echo " GEMINI_API_KEY=your-actual-api-key-here"
echo " OPENAI_API_KEY=sk-your-actual-openai-key-here"
+ echo " XAI_API_KEY=xai-your-actual-xai-key-here"
echo " OPENROUTER_API_KEY=sk-or-your-actual-openrouter-key-here"
echo " CUSTOM_API_URL=http://host.docker.internal:11434/v1 # Ollama (use host.docker.internal, NOT localhost!)"
echo ""
@@ -302,7 +321,7 @@ show_configuration_steps() {
echo ""
echo "đ Next steps:"
NEEDS_KEY_UPDATE=false
- if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null || grep -q "your_openrouter_api_key_here" .env 2>/dev/null; then
+ if grep -q "your_gemini_api_key_here" .env 2>/dev/null || grep -q "your_openai_api_key_here" .env 2>/dev/null || grep -q "your_xai_api_key_here" .env 2>/dev/null || grep -q "your_openrouter_api_key_here" .env 2>/dev/null; then
NEEDS_KEY_UPDATE=true
fi
@@ -310,6 +329,7 @@ show_configuration_steps() {
echo "1. Edit .env and replace placeholder API keys with actual ones"
echo " - GEMINI_API_KEY: your-gemini-api-key-here"
echo " - OPENAI_API_KEY: your-openai-api-key-here"
+ echo " - XAI_API_KEY: your-xai-api-key-here"
echo " - OPENROUTER_API_KEY: your-openrouter-api-key-here (optional)"
echo "2. Restart services: $COMPOSE_CMD restart"
echo "3. Copy the configuration below to your Claude Desktop config if required:"
diff --git a/server.py b/server.py
index ecae98b..53d13e2 100644
--- a/server.py
+++ b/server.py
@@ -169,6 +169,7 @@ def configure_providers():
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.openrouter import OpenRouterProvider
+ from providers.xai import XAIModelProvider
from utils.model_restrictions import get_restriction_service
valid_providers = []
@@ -190,6 +191,13 @@ def configure_providers():
has_native_apis = True
logger.info("OpenAI API key found - o3 model available")
+ # Check for X.AI API key
+ xai_key = os.getenv("XAI_API_KEY")
+ if xai_key and xai_key != "your_xai_api_key_here":
+ valid_providers.append("X.AI (GROK)")
+ has_native_apis = True
+ logger.info("X.AI API key found - GROK models available")
+
# Check for OpenRouter API key
openrouter_key = os.getenv("OPENROUTER_API_KEY")
if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
@@ -221,6 +229,8 @@ def configure_providers():
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
if openai_key and openai_key != "your_openai_api_key_here":
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ if xai_key and xai_key != "your_xai_api_key_here":
+ ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
# 2. Custom provider second (for local/private models)
if has_custom:
@@ -242,6 +252,7 @@ def configure_providers():
"At least one API configuration is required. Please set either:\n"
"- GEMINI_API_KEY for Gemini models\n"
"- OPENAI_API_KEY for OpenAI o3 model\n"
+ "- XAI_API_KEY for X.AI GROK models\n"
"- OPENROUTER_API_KEY for OpenRouter (multiple models)\n"
"- CUSTOM_API_URL for local models (Ollama, vLLM, etc.)"
)
diff --git a/simulator_tests/__init__.py b/simulator_tests/__init__.py
index 7e51b47..64ede47 100644
--- a/simulator_tests/__init__.py
+++ b/simulator_tests/__init__.py
@@ -24,6 +24,7 @@ from .test_redis_validation import RedisValidationTest
from .test_refactor_validation import RefactorValidationTest
from .test_testgen_validation import TestGenValidationTest
from .test_token_allocation_validation import TokenAllocationValidationTest
+from .test_xai_models import XAIModelsTest
# Test registry for dynamic loading
TEST_REGISTRY = {
@@ -44,6 +45,7 @@ TEST_REGISTRY = {
"testgen_validation": TestGenValidationTest,
"refactor_validation": RefactorValidationTest,
"conversation_chain_validation": ConversationChainValidationTest,
+ "xai_models": XAIModelsTest,
# "o3_pro_expensive": O3ProExpensiveTest, # COMMENTED OUT - too expensive to run by default
}
@@ -67,5 +69,6 @@ __all__ = [
"TestGenValidationTest",
"RefactorValidationTest",
"ConversationChainValidationTest",
+ "XAIModelsTest",
"TEST_REGISTRY",
]
diff --git a/simulator_tests/test_xai_models.py b/simulator_tests/test_xai_models.py
new file mode 100644
index 0000000..c71a996
--- /dev/null
+++ b/simulator_tests/test_xai_models.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python3
+"""
+X.AI GROK Model Tests
+
+Tests that verify X.AI GROK functionality including:
+- Model alias resolution (grok, grok3, grokfast map to actual GROK models)
+- GROK-3 and GROK-3-fast models work correctly
+- Conversation continuity works with GROK models
+- API integration and response validation
+"""
+
+import subprocess
+
+from .base_test import BaseSimulatorTest
+
+
+class XAIModelsTest(BaseSimulatorTest):
+ """Test X.AI GROK model functionality and integration"""
+
+ @property
+ def test_name(self) -> str:
+ return "xai_models"
+
+ @property
+ def test_description(self) -> str:
+ return "X.AI GROK model functionality and integration"
+
+ def get_recent_server_logs(self) -> str:
+ """Get recent server logs from the log file directly"""
+ try:
+ # Read logs directly from the log file
+ cmd = ["docker", "exec", self.container_name, "tail", "-n", "500", "/tmp/mcp_server.log"]
+ result = subprocess.run(cmd, capture_output=True, text=True)
+
+ if result.returncode == 0:
+ return result.stdout
+ else:
+ self.logger.warning(f"Failed to read server logs: {result.stderr}")
+ return ""
+ except Exception as e:
+ self.logger.error(f"Failed to get server logs: {e}")
+ return ""
+
+ def run_test(self) -> bool:
+ """Test X.AI GROK model functionality"""
+ try:
+ self.logger.info("Test: X.AI GROK model functionality and integration")
+
+ # Check if X.AI API key is configured and not empty
+ check_cmd = [
+ "docker",
+ "exec",
+ self.container_name,
+ "python",
+ "-c",
+ """
+import os
+xai_key = os.environ.get("XAI_API_KEY", "")
+is_valid = bool(xai_key and xai_key != "your_xai_api_key_here" and xai_key.strip())
+print(f"XAI_KEY_VALID:{is_valid}")
+ """.strip(),
+ ]
+ result = subprocess.run(check_cmd, capture_output=True, text=True)
+
+ if result.returncode == 0 and "XAI_KEY_VALID:False" in result.stdout:
+ self.logger.info(" ⚠️ X.AI API key not configured or empty - skipping test")
+ self.logger.info(" ℹ️ This test requires XAI_API_KEY to be set in .env with a valid key")
+ return True # Return True to indicate test is skipped, not failed
+
+ # Setup test files for later use
+ self.setup_test_files()
+
+ # Test 1: 'grok' alias (should map to grok-3)
+ self.logger.info(" 1: Testing 'grok' alias (should map to grok-3)")
+
+ response1, continuation_id = self.call_mcp_tool(
+ "chat",
+ {
+ "prompt": "Say 'Hello from GROK model!' and nothing else.",
+ "model": "grok",
+ "temperature": 0.1,
+ },
+ )
+
+ if not response1:
+ self.logger.error(" ❌ GROK alias test failed")
+ return False
+
+ self.logger.info(" ✅ GROK alias call completed")
+ if continuation_id:
+ self.logger.info(f" ✅ Got continuation_id: {continuation_id}")
+
+ # Test 2: Direct grok-3 model name
+ self.logger.info(" 2: Testing direct model name (grok-3)")
+
+ response2, _ = self.call_mcp_tool(
+ "chat",
+ {
+ "prompt": "Say 'Hello from GROK-3!' and nothing else.",
+ "model": "grok-3",
+ "temperature": 0.1,
+ },
+ )
+
+ if not response2:
+ self.logger.error(" ❌ Direct GROK-3 model test failed")
+ return False
+
+ self.logger.info(" ✅ Direct GROK-3 model call completed")
+
+ # Test 3: grok-3-fast model
+ self.logger.info(" 3: Testing GROK-3-fast model")
+
+ response3, _ = self.call_mcp_tool(
+ "chat",
+ {
+ "prompt": "Say 'Hello from GROK-3-fast!' and nothing else.",
+ "model": "grok-3-fast",
+ "temperature": 0.1,
+ },
+ )
+
+ if not response3:
+ self.logger.error(" ❌ GROK-3-fast model test failed")
+ return False
+
+ self.logger.info(" ✅ GROK-3-fast model call completed")
+
+ # Test 4: Shorthand aliases
+ self.logger.info(" 4: Testing shorthand aliases (grok3, grokfast)")
+
+ response4, _ = self.call_mcp_tool(
+ "chat",
+ {
+ "prompt": "Say 'Hello from grok3 alias!' and nothing else.",
+ "model": "grok3",
+ "temperature": 0.1,
+ },
+ )
+
+ if not response4:
+ self.logger.error(" ❌ grok3 alias test failed")
+ return False
+
+ response5, _ = self.call_mcp_tool(
+ "chat",
+ {
+ "prompt": "Say 'Hello from grokfast alias!' and nothing else.",
+ "model": "grokfast",
+ "temperature": 0.1,
+ },
+ )
+
+ if not response5:
+ self.logger.error(" ❌ grokfast alias test failed")
+ return False
+
+ self.logger.info(" ✅ Shorthand aliases work correctly")
+
+ # Test 5: Conversation continuity with GROK models
+ self.logger.info(" 5: Testing conversation continuity with GROK")
+
+ response6, new_continuation_id = self.call_mcp_tool(
+ "chat",
+ {
+ "prompt": "Remember this number: 87. What number did I just tell you?",
+ "model": "grok",
+ "temperature": 0.1,
+ },
+ )
+
+ if not response6 or not new_continuation_id:
+ self.logger.error(" ❌ Failed to start conversation with continuation_id")
+ return False
+
+ # Continue the conversation
+ response7, _ = self.call_mcp_tool(
+ "chat",
+ {
+ "prompt": "What was the number I told you earlier?",
+ "model": "grok",
+ "continuation_id": new_continuation_id,
+ "temperature": 0.1,
+ },
+ )
+
+ if not response7:
+ self.logger.error(" ❌ Failed to continue conversation")
+ return False
+
+ # Check if the model remembered the number
+ if "87" in response7:
+ self.logger.info(" ✅ Conversation continuity working with GROK")
+ else:
+ self.logger.warning(" ⚠️ Model may not have remembered the number")
+
+ # Test 6: Validate X.AI API usage from logs
+ self.logger.info(" 6: Validating X.AI API usage in logs")
+ logs = self.get_recent_server_logs()
+
+ # Check for X.AI API calls
+ xai_logs = [line for line in logs.split("\n") if "x.ai" in line.lower()]
+ xai_api_logs = [line for line in logs.split("\n") if "api.x.ai" in line]
+ grok_logs = [line for line in logs.split("\n") if "grok" in line.lower()]
+
+ # Check for specific model resolution
+ grok_resolution_logs = [
+ line
+ for line in logs.split("\n")
+ if ("Resolved model" in line and "grok" in line.lower()) or ("grok" in line and "->" in line)
+ ]
+
+ # Check for X.AI provider usage
+ xai_provider_logs = [line for line in logs.split("\n") if "XAI" in line or "X.AI" in line]
+
+ # Log findings
+ self.logger.info(f" X.AI-related logs: {len(xai_logs)}")
+ self.logger.info(f" X.AI API logs: {len(xai_api_logs)}")
+ self.logger.info(f" GROK-related logs: {len(grok_logs)}")
+ self.logger.info(f" Model resolution logs: {len(grok_resolution_logs)}")
+ self.logger.info(f" X.AI provider logs: {len(xai_provider_logs)}")
+
+ # Sample log output for debugging
+ if self.verbose and xai_logs:
+ self.logger.debug(" 📋 Sample X.AI logs:")
+ for log in xai_logs[:3]:
+ self.logger.debug(f" {log}")
+
+ if self.verbose and grok_logs:
+ self.logger.debug(" 📋 Sample GROK logs:")
+ for log in grok_logs[:3]:
+ self.logger.debug(f" {log}")
+
+ # Success criteria
+ grok_mentioned = len(grok_logs) > 0
+ api_used = len(xai_api_logs) > 0 or len(xai_logs) > 0
+ provider_used = len(xai_provider_logs) > 0
+
+ success_criteria = [
+ ("GROK models mentioned in logs", grok_mentioned),
+ ("X.AI API calls made", api_used),
+ ("X.AI provider used", provider_used),
+ ("All model calls succeeded", True), # We already checked this above
+ ("Conversation continuity works", True), # We already tested this
+ ]
+
+ passed_criteria = sum(1 for _, passed in success_criteria if passed)
+ self.logger.info(f" Success criteria met: {passed_criteria}/{len(success_criteria)}")
+
+ for criterion, passed in success_criteria:
+ status = "✅" if passed else "❌"
+ self.logger.info(f" {status} {criterion}")
+
+ if passed_criteria >= 3: # At least 3 out of 5 criteria
+ self.logger.info(" ✅ X.AI GROK model tests passed")
+ return True
+ else:
+ self.logger.error(" ❌ X.AI GROK model tests failed")
+ return False
+
+ except Exception as e:
+ self.logger.error(f"X.AI GROK model test failed: {e}")
+ return False
+ finally:
+ self.cleanup_test_files()
+
+
+def main():
+ """Run the X.AI GROK model tests"""
+ import sys
+
+ verbose = "--verbose" in sys.argv or "-v" in sys.argv
+ test = XAIModelsTest(verbose=verbose)
+
+ success = test.run_test()
+ sys.exit(0 if success else 1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/conftest.py b/tests/conftest.py
index deabdae..c164a73 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,6 +21,8 @@ if "GEMINI_API_KEY" not in os.environ:
os.environ["GEMINI_API_KEY"] = "dummy-key-for-tests"
if "OPENAI_API_KEY" not in os.environ:
os.environ["OPENAI_API_KEY"] = "dummy-key-for-tests"
+if "XAI_API_KEY" not in os.environ:
+ os.environ["XAI_API_KEY"] = "dummy-key-for-tests"
# Set default model to a specific value for tests to avoid auto mode
# This prevents all tests from failing due to missing model parameter
@@ -46,10 +48,12 @@ from providers import ModelProviderRegistry # noqa: E402
from providers.base import ProviderType # noqa: E402
from providers.gemini import GeminiModelProvider # noqa: E402
from providers.openai import OpenAIModelProvider # noqa: E402
+from providers.xai import XAIModelProvider # noqa: E402
# Register providers at test startup
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
@pytest.fixture
@@ -90,6 +94,18 @@ def mock_provider_availability(request, monkeypatch):
if marker:
return
+ # Ensure providers are registered (in case other tests cleared the registry)
+ from providers.base import ProviderType
+
+ registry = ModelProviderRegistry()
+
+ if ProviderType.GOOGLE not in registry._providers:
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+ if ProviderType.OPENAI not in registry._providers:
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ if ProviderType.XAI not in registry._providers:
+ ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+
from unittest.mock import MagicMock
original_get_provider = ModelProviderRegistry.get_provider_for_model
@@ -119,3 +135,31 @@ def mock_provider_availability(request, monkeypatch):
return original_get_provider(model_name)
monkeypatch.setattr(ModelProviderRegistry, "get_provider_for_model", mock_get_provider_for_model)
+
+ # Also mock is_effective_auto_mode for all BaseTool instances to return False
+ # unless we're specifically testing auto mode behavior
+ from tools.base import BaseTool
+
+ def mock_is_effective_auto_mode(self):
+ # If this is an auto mode test file or specific auto mode test, use the real logic
+ test_file = request.node.fspath.basename if hasattr(request, "node") and hasattr(request.node, "fspath") else ""
+ test_name = request.node.name if hasattr(request, "node") else ""
+
+ # Allow auto mode for tests in auto mode files or with auto in the name
+ if (
+ "auto_mode" in test_file.lower()
+ or "auto" in test_name.lower()
+ or "intelligent_fallback" in test_file.lower()
+ or "per_tool_model_defaults" in test_file.lower()
+ ):
+ # Call original method logic
+ from config import DEFAULT_MODEL
+
+ if DEFAULT_MODEL.lower() == "auto":
+ return True
+ provider = ModelProviderRegistry.get_provider_for_model(DEFAULT_MODEL)
+ return provider is None
+ # For all other tests, return False to disable auto mode
+ return False
+
+ monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
diff --git a/tests/test_auto_mode_comprehensive.py b/tests/test_auto_mode_comprehensive.py
new file mode 100644
index 0000000..46fa668
--- /dev/null
+++ b/tests/test_auto_mode_comprehensive.py
@@ -0,0 +1,582 @@
+"""Comprehensive tests for auto mode functionality across all provider combinations"""
+
+import importlib
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from providers.base import ProviderType
+from providers.registry import ModelProviderRegistry
+from tools.analyze import AnalyzeTool
+from tools.chat import ChatTool
+from tools.debug import DebugIssueTool
+from tools.models import ToolModelCategory
+from tools.thinkdeep import ThinkDeepTool
+
+
+@pytest.mark.no_mock_provider
+class TestAutoModeComprehensive:
+ """Test auto mode model selection across all provider combinations"""
+
+ def setup_method(self):
+ """Set up clean state before each test."""
+ # Save original environment state for restoration
+ import os
+
+ self._original_default_model = os.environ.get("DEFAULT_MODEL", "")
+
+ # Clear restriction service cache
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ # Clear provider registry by resetting singleton instance
+ ModelProviderRegistry._instance = None
+
+ def teardown_method(self):
+ """Clean up after each test."""
+ # Restore original DEFAULT_MODEL
+ import os
+
+ if self._original_default_model:
+ os.environ["DEFAULT_MODEL"] = self._original_default_model
+ elif "DEFAULT_MODEL" in os.environ:
+ del os.environ["DEFAULT_MODEL"]
+
+ # Reload config to pick up the restored DEFAULT_MODEL
+ import importlib
+
+ import config
+
+ importlib.reload(config)
+
+ # Clear restriction service cache
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ # Clear provider registry by resetting singleton instance
+ ModelProviderRegistry._instance = None
+
+ # Re-register providers for subsequent tests (like conftest.py does)
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+ from providers.xai import XAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+
+ @pytest.mark.parametrize(
+ "provider_config,expected_models",
+ [
+ # Only Gemini API available
+ (
+ {
+ "GEMINI_API_KEY": "real-key",
+ "OPENAI_API_KEY": None,
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": None,
+ },
+ {
+ "EXTENDED_REASONING": "gemini-2.5-pro-preview-06-05", # Pro for deep thinking
+ "FAST_RESPONSE": "gemini-2.5-flash-preview-05-20", # Flash for speed
+ "BALANCED": "gemini-2.5-flash-preview-05-20", # Flash as balanced
+ },
+ ),
+ # Only OpenAI API available
+ (
+ {
+ "GEMINI_API_KEY": None,
+ "OPENAI_API_KEY": "real-key",
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": None,
+ },
+ {
+ "EXTENDED_REASONING": "o3", # O3 for deep reasoning
+ "FAST_RESPONSE": "o4-mini", # O4-mini for speed
+ "BALANCED": "o4-mini", # O4-mini as balanced
+ },
+ ),
+ # Only X.AI API available
+ (
+ {
+ "GEMINI_API_KEY": None,
+ "OPENAI_API_KEY": None,
+ "XAI_API_KEY": "real-key",
+ "OPENROUTER_API_KEY": None,
+ },
+ {
+ "EXTENDED_REASONING": "grok-3", # GROK-3 for reasoning
+ "FAST_RESPONSE": "grok-3-fast", # GROK-3-fast for speed
+ "BALANCED": "grok-3", # GROK-3 as balanced
+ },
+ ),
+ # Both Gemini and OpenAI available - should prefer based on tool category
+ (
+ {
+ "GEMINI_API_KEY": "real-key",
+ "OPENAI_API_KEY": "real-key",
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": None,
+ },
+ {
+ "EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning
+ "FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed
+ "BALANCED": "o4-mini", # Prefer OpenAI for balanced
+ },
+ ),
+ # All native APIs available - should prefer based on tool category
+ (
+ {
+ "GEMINI_API_KEY": "real-key",
+ "OPENAI_API_KEY": "real-key",
+ "XAI_API_KEY": "real-key",
+ "OPENROUTER_API_KEY": None,
+ },
+ {
+ "EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning
+ "FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed
+ "BALANCED": "o4-mini", # Prefer OpenAI for balanced
+ },
+ ),
+ # Only OpenRouter available - should fall back to proxy models
+ (
+ {
+ "GEMINI_API_KEY": None,
+ "OPENAI_API_KEY": None,
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": "real-key",
+ },
+ {
+ "EXTENDED_REASONING": "anthropic/claude-3.5-sonnet", # First preferred thinking model from OpenRouter
+ "FAST_RESPONSE": "anthropic/claude-3-opus", # First available OpenRouter model
+ "BALANCED": "anthropic/claude-3-opus", # First available OpenRouter model
+ },
+ ),
+ ],
+ )
+ def test_auto_mode_model_selection_by_provider(self, provider_config, expected_models):
+ """Test that auto mode selects correct models based on available providers."""
+
+ # Set up environment with specific provider configuration
+ # Filter out None values and handle them separately
+ env_to_set = {k: v for k, v in provider_config.items() if v is not None}
+ env_to_clear = [k for k, v in provider_config.items() if v is None]
+
+ with patch.dict(os.environ, env_to_set, clear=False):
+ # Clear the None-valued environment variables
+ for key in env_to_clear:
+ if key in os.environ:
+ del os.environ[key]
+ # Reload config to pick up auto mode
+ os.environ["DEFAULT_MODEL"] = "auto"
+ import config
+
+ importlib.reload(config)
+
+ # Register providers based on configuration
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+ from providers.openrouter import OpenRouterProvider
+ from providers.xai import XAIModelProvider
+
+ if provider_config.get("GEMINI_API_KEY"):
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+ if provider_config.get("OPENAI_API_KEY"):
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ if provider_config.get("XAI_API_KEY"):
+ ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+ if provider_config.get("OPENROUTER_API_KEY"):
+ ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
+
+ # Test each tool category
+ for category_name, expected_model in expected_models.items():
+ category = ToolModelCategory(category_name.lower())
+
+ # Get preferred fallback model for this category
+ fallback_model = ModelProviderRegistry.get_preferred_fallback_model(category)
+
+ assert fallback_model == expected_model, (
+ f"Provider config {provider_config}: "
+ f"Expected {expected_model} for {category_name}, got {fallback_model}"
+ )
+
+ @pytest.mark.parametrize(
+ "tool_class,expected_category",
+ [
+ (ChatTool, ToolModelCategory.FAST_RESPONSE),
+ (AnalyzeTool, ToolModelCategory.EXTENDED_REASONING), # AnalyzeTool uses EXTENDED_REASONING
+ (DebugIssueTool, ToolModelCategory.EXTENDED_REASONING),
+ (ThinkDeepTool, ToolModelCategory.EXTENDED_REASONING),
+ ],
+ )
+ def test_tool_model_categories(self, tool_class, expected_category):
+ """Test that tools have the correct model categories."""
+ tool = tool_class()
+ assert tool.get_model_category() == expected_category
+
+ @pytest.mark.asyncio
+ async def test_auto_mode_with_gemini_only_uses_correct_models(self):
+ """Test that auto mode with only Gemini uses flash for fast tools and pro for reasoning tools."""
+
+ provider_config = {
+ "GEMINI_API_KEY": "real-key",
+ "OPENAI_API_KEY": None,
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": None,
+ "DEFAULT_MODEL": "auto",
+ }
+
+ # Filter out None values to avoid patch.dict errors
+ env_to_set = {k: v for k, v in provider_config.items() if v is not None}
+ env_to_clear = [k for k, v in provider_config.items() if v is None]
+
+ with patch.dict(os.environ, env_to_set, clear=False):
+ # Clear the None-valued environment variables
+ for key in env_to_clear:
+ if key in os.environ:
+ del os.environ[key]
+ import config
+
+ importlib.reload(config)
+
+ # Register only Gemini provider
+ from providers.gemini import GeminiModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
+ # Mock provider to capture what model is requested
+ mock_provider = MagicMock()
+ mock_provider.generate_content.return_value = MagicMock(
+ content="test response", model_name="test-model", usage={"input_tokens": 10, "output_tokens": 5}
+ )
+
+ with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
+ # Test ChatTool (FAST_RESPONSE) - should prefer flash
+ chat_tool = ChatTool()
+ await chat_tool.execute({"prompt": "test", "model": "auto"}) # This should trigger auto selection
+
+ # In auto mode, the tool should get an error requiring model selection
+ # but the suggested model should be flash
+
+ # Reset mock for next test
+ ModelProviderRegistry.get_provider_for_model.reset_mock()
+
+ # Test DebugIssueTool (EXTENDED_REASONING) - should prefer pro
+ debug_tool = DebugIssueTool()
+ await debug_tool.execute({"prompt": "test error", "model": "auto"})
+
+ def test_auto_mode_schema_includes_all_available_models(self):
+ """Test that auto mode schema includes all available models for user convenience."""
+
+ # Test with only Gemini available
+ provider_config = {
+ "GEMINI_API_KEY": "real-key",
+ "OPENAI_API_KEY": None,
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": None,
+ "DEFAULT_MODEL": "auto",
+ }
+
+ # Filter out None values to avoid patch.dict errors
+ env_to_set = {k: v for k, v in provider_config.items() if v is not None}
+ env_to_clear = [k for k, v in provider_config.items() if v is None]
+
+ with patch.dict(os.environ, env_to_set, clear=False):
+ # Clear the None-valued environment variables
+ for key in env_to_clear:
+ if key in os.environ:
+ del os.environ[key]
+ import config
+
+ importlib.reload(config)
+
+ # Register only Gemini provider
+ from providers.gemini import GeminiModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
+ tool = AnalyzeTool()
+ schema = tool.get_input_schema()
+
+ # Should have model as required field
+ assert "model" in schema["required"]
+
+ # Should include all model options from global config
+ model_schema = schema["properties"]["model"]
+ assert "enum" in model_schema
+
+ available_models = model_schema["enum"]
+
+ # Should include Gemini models
+ assert "flash" in available_models
+ assert "pro" in available_models
+ assert "gemini-2.5-flash-preview-05-20" in available_models
+ assert "gemini-2.5-pro-preview-06-05" in available_models
+
+ # Should also include other models (users might have OpenRouter configured)
+ # The schema should show all options; validation happens at runtime
+ assert "o3" in available_models
+ assert "o4-mini" in available_models
+ assert "grok" in available_models
+ assert "grok-3" in available_models
+
+ def test_auto_mode_schema_with_all_providers(self):
+ """Test that auto mode schema includes models from all available providers."""
+
+ provider_config = {
+ "GEMINI_API_KEY": "real-key",
+ "OPENAI_API_KEY": "real-key",
+ "XAI_API_KEY": "real-key",
+ "OPENROUTER_API_KEY": None, # Don't include OpenRouter to avoid infinite models
+ "DEFAULT_MODEL": "auto",
+ }
+
+ # Filter out None values to avoid patch.dict errors
+ env_to_set = {k: v for k, v in provider_config.items() if v is not None}
+ env_to_clear = [k for k, v in provider_config.items() if v is None]
+
+ with patch.dict(os.environ, env_to_set, clear=False):
+ # Clear the None-valued environment variables
+ for key in env_to_clear:
+ if key in os.environ:
+ del os.environ[key]
+ import config
+
+ importlib.reload(config)
+
+ # Register all native providers
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+ from providers.xai import XAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+
+ tool = AnalyzeTool()
+ schema = tool.get_input_schema()
+
+ model_schema = schema["properties"]["model"]
+ available_models = model_schema["enum"]
+
+ # Should include models from all providers
+ # Gemini models
+ assert "flash" in available_models
+ assert "pro" in available_models
+
+ # OpenAI models
+ assert "o3" in available_models
+ assert "o4-mini" in available_models
+
+ # XAI models
+ assert "grok" in available_models
+ assert "grok-3" in available_models
+
+ @pytest.mark.asyncio
+ async def test_auto_mode_model_parameter_required_error(self):
+ """Test that auto mode properly requires model parameter and suggests correct model."""
+
+ provider_config = {
+ "GEMINI_API_KEY": "real-key",
+ "OPENAI_API_KEY": None,
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": None,
+ "DEFAULT_MODEL": "auto",
+ }
+
+ # Filter out None values to avoid patch.dict errors
+ env_to_set = {k: v for k, v in provider_config.items() if v is not None}
+ env_to_clear = [k for k, v in provider_config.items() if v is None]
+
+ with patch.dict(os.environ, env_to_set, clear=False):
+ # Clear the None-valued environment variables
+ for key in env_to_clear:
+ if key in os.environ:
+ del os.environ[key]
+ import config
+
+ importlib.reload(config)
+
+ # Register only Gemini provider
+ from providers.gemini import GeminiModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
+ # Test with ChatTool (FAST_RESPONSE category)
+ chat_tool = ChatTool()
+ result = await chat_tool.execute(
+ {
+ "prompt": "test"
+ # Note: no model parameter provided in auto mode
+ }
+ )
+
+ # Should get error requiring model selection
+ assert len(result) == 1
+ response_text = result[0].text
+
+ # Parse JSON response to check error
+ import json
+
+ response_data = json.loads(response_text)
+
+ assert response_data["status"] == "error"
+ assert "Model parameter is required" in response_data["content"]
+ assert "flash" in response_data["content"] # Should suggest flash for FAST_RESPONSE
+ assert "category: fast_response" in response_data["content"]
+
+ def test_model_availability_with_restrictions(self):
+ """Test that auto mode respects model restrictions when selecting fallback models."""
+
+ provider_config = {
+ "GEMINI_API_KEY": "real-key",
+ "OPENAI_API_KEY": "real-key",
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": None,
+ "DEFAULT_MODEL": "auto",
+ "OPENAI_ALLOWED_MODELS": "o4-mini", # Restrict OpenAI to only o4-mini
+ }
+
+ # Filter out None values to avoid patch.dict errors
+ env_to_set = {k: v for k, v in provider_config.items() if v is not None}
+ env_to_clear = [k for k, v in provider_config.items() if v is None]
+
+ with patch.dict(os.environ, env_to_set, clear=False):
+ # Clear the None-valued environment variables
+ for key in env_to_clear:
+ if key in os.environ:
+ del os.environ[key]
+ import config
+
+ importlib.reload(config)
+
+ # Clear restriction service to pick up new env vars
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ # Register providers
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+
+ # Get available models - should respect restrictions
+ available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
+
+ # Should include restricted OpenAI model
+ assert "o4-mini" in available_models
+
+ # Should NOT include non-restricted OpenAI models
+ assert "o3" not in available_models
+ assert "o3-mini" not in available_models
+
+ # Should still include all Gemini models (no restrictions)
+ assert "gemini-2.5-flash-preview-05-20" in available_models
+ assert "gemini-2.5-pro-preview-06-05" in available_models
+
+ def test_openrouter_fallback_when_no_native_apis(self):
+ """Test that OpenRouter provides fallback models when no native APIs are available."""
+
+ provider_config = {
+ "GEMINI_API_KEY": None,
+ "OPENAI_API_KEY": None,
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": "real-key",
+ "DEFAULT_MODEL": "auto",
+ }
+
+ # Filter out None values to avoid patch.dict errors
+ env_to_set = {k: v for k, v in provider_config.items() if v is not None}
+ env_to_clear = [k for k, v in provider_config.items() if v is None]
+
+ with patch.dict(os.environ, env_to_set, clear=False):
+ # Clear the None-valued environment variables
+ for key in env_to_clear:
+ if key in os.environ:
+ del os.environ[key]
+ import config
+
+ importlib.reload(config)
+
+ # Register only OpenRouter provider
+ from providers.openrouter import OpenRouterProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
+
+ # Mock OpenRouter registry to return known models
+ mock_registry = MagicMock()
+ mock_registry.list_models.return_value = [
+ "google/gemini-2.5-flash-preview-05-20",
+ "google/gemini-2.5-pro-preview-06-05",
+ "openai/o3",
+ "openai/o4-mini",
+ "anthropic/claude-3-opus",
+ ]
+
+ with patch.object(OpenRouterProvider, "_registry", mock_registry):
+ # Get preferred models for different categories
+ extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
+ ToolModelCategory.EXTENDED_REASONING
+ )
+ fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+
+ # Should fallback to known good models even via OpenRouter
+ # The exact model depends on _find_extended_thinking_model implementation
+ assert extended_reasoning is not None
+ assert fast_response is not None
+
+ @pytest.mark.asyncio
+ async def test_actual_model_name_resolution_in_auto_mode(self):
+ """Test that when a model is selected in auto mode, the tool executes successfully."""
+
+ provider_config = {
+ "GEMINI_API_KEY": "real-key",
+ "OPENAI_API_KEY": None,
+ "XAI_API_KEY": None,
+ "OPENROUTER_API_KEY": None,
+ "DEFAULT_MODEL": "auto",
+ }
+
+ # Filter out None values to avoid patch.dict errors
+ env_to_set = {k: v for k, v in provider_config.items() if v is not None}
+ env_to_clear = [k for k, v in provider_config.items() if v is None]
+
+ with patch.dict(os.environ, env_to_set, clear=False):
+ # Clear the None-valued environment variables
+ for key in env_to_clear:
+ if key in os.environ:
+ del os.environ[key]
+ import config
+
+ importlib.reload(config)
+
+ # Register Gemini provider
+ from providers.gemini import GeminiModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
+ # Mock the actual provider to simulate successful execution
+ mock_provider = MagicMock()
+ mock_response = MagicMock()
+ mock_response.content = "test response"
+ mock_response.model_name = "gemini-2.5-flash-preview-05-20" # The resolved name
+ mock_response.usage = {"input_tokens": 10, "output_tokens": 5}
+ # Mock _resolve_model_name to simulate alias resolution
+ mock_provider._resolve_model_name = lambda alias: (
+ "gemini-2.5-flash-preview-05-20" if alias == "flash" else alias
+ )
+ mock_provider.generate_content.return_value = mock_response
+
+ with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
+ chat_tool = ChatTool()
+ result = await chat_tool.execute({"prompt": "test", "model": "flash"}) # Use alias in auto mode
+
+ # Should succeed with proper model resolution
+ assert len(result) == 1
+ # Just verify that the tool executed successfully and didn't return an error
+ assert "error" not in result[0].text.lower()
diff --git a/tests/test_auto_mode_provider_selection.py b/tests/test_auto_mode_provider_selection.py
new file mode 100644
index 0000000..a45c388
--- /dev/null
+++ b/tests/test_auto_mode_provider_selection.py
@@ -0,0 +1,344 @@
+"""Test auto mode provider selection logic specifically"""
+
+import os
+
+import pytest
+
+from providers.base import ProviderType
+from providers.registry import ModelProviderRegistry
+from tools.models import ToolModelCategory
+
+
+@pytest.mark.no_mock_provider
+class TestAutoModeProviderSelection:
+ """Test the core auto mode provider selection logic"""
+
+ def setup_method(self):
+ """Set up clean state before each test."""
+ # Clear restriction service cache
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ # Clear provider registry
+ registry = ModelProviderRegistry()
+ registry._providers.clear()
+ registry._initialized_providers.clear()
+
+ def teardown_method(self):
+ """Clean up after each test."""
+ # Clear restriction service cache
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ def test_gemini_only_fallback_selection(self):
+ """Test auto mode fallback when only Gemini is available."""
+
+ # Save original environment
+ original_env = {}
+ for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+ original_env[key] = os.environ.get(key)
+
+ try:
+ # Set up environment - only Gemini available
+ os.environ["GEMINI_API_KEY"] = "test-key"
+ for key in ["OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+ os.environ.pop(key, None)
+
+ # Register only Gemini provider
+ from providers.gemini import GeminiModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
+ # Test fallback selection for different categories
+ extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
+ ToolModelCategory.EXTENDED_REASONING
+ )
+ fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+ balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
+
+ # Should select appropriate Gemini models
+ assert extended_reasoning in ["gemini-2.5-pro-preview-06-05", "pro"]
+ assert fast_response in ["gemini-2.5-flash-preview-05-20", "flash"]
+ assert balanced in ["gemini-2.5-flash-preview-05-20", "flash"]
+
+ finally:
+ # Restore original environment
+ for key, value in original_env.items():
+ if value is not None:
+ os.environ[key] = value
+ else:
+ os.environ.pop(key, None)
+
+ def test_openai_only_fallback_selection(self):
+ """Test auto mode fallback when only OpenAI is available."""
+
+ # Save original environment
+ original_env = {}
+ for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+ original_env[key] = os.environ.get(key)
+
+ try:
+ # Set up environment - only OpenAI available
+ os.environ["OPENAI_API_KEY"] = "test-key"
+ for key in ["GEMINI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+ os.environ.pop(key, None)
+
+ # Register only OpenAI provider
+ from providers.openai import OpenAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+
+ # Test fallback selection for different categories
+ extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
+ ToolModelCategory.EXTENDED_REASONING
+ )
+ fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+ balanced = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
+
+ # Should select appropriate OpenAI models
+ assert extended_reasoning in ["o3", "o3-mini", "o4-mini"] # Any available OpenAI model for reasoning
+ assert fast_response in ["o4-mini", "o3-mini"] # Prefer faster models
+ assert balanced in ["o4-mini", "o3-mini"] # Balanced selection
+
+ finally:
+ # Restore original environment
+ for key, value in original_env.items():
+ if value is not None:
+ os.environ[key] = value
+ else:
+ os.environ.pop(key, None)
+
+ def test_both_gemini_and_openai_priority(self):
+ """Test auto mode when both Gemini and OpenAI are available."""
+
+ # Save original environment
+ original_env = {}
+ for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+ original_env[key] = os.environ.get(key)
+
+ try:
+ # Set up environment - both Gemini and OpenAI available
+ os.environ["GEMINI_API_KEY"] = "test-key"
+ os.environ["OPENAI_API_KEY"] = "test-key"
+ for key in ["XAI_API_KEY", "OPENROUTER_API_KEY"]:
+ os.environ.pop(key, None)
+
+ # Register both providers
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+
+ # Test fallback selection for different categories
+ extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
+ ToolModelCategory.EXTENDED_REASONING
+ )
+ fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+
+ # Should prefer OpenAI for reasoning (based on fallback logic)
+ assert extended_reasoning == "o3" # Should prefer O3 for extended reasoning
+
+ # Should prefer OpenAI for fast response
+ assert fast_response == "o4-mini" # Should prefer O4-mini for fast response
+
+ finally:
+ # Restore original environment
+ for key, value in original_env.items():
+ if value is not None:
+ os.environ[key] = value
+ else:
+ os.environ.pop(key, None)
+
+ def test_xai_only_fallback_selection(self):
+ """Test auto mode fallback when only XAI is available."""
+
+ # Save original environment
+ original_env = {}
+ for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY"]:
+ original_env[key] = os.environ.get(key)
+
+ try:
+ # Set up environment - only XAI available
+ os.environ["XAI_API_KEY"] = "test-key"
+ for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY"]:
+ os.environ.pop(key, None)
+
+ # Register only XAI provider
+ from providers.xai import XAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+
+ # Test fallback selection for different categories
+ extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
+ ToolModelCategory.EXTENDED_REASONING
+ )
+ fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+
+ # Should fallback to available models or default fallbacks
+ # Since XAI models are not explicitly handled in the fallback logic,
+ # it should fall back to the hardcoded defaults
+ assert extended_reasoning is not None
+ assert fast_response is not None
+
+ finally:
+ # Restore original environment
+ for key, value in original_env.items():
+ if value is not None:
+ os.environ[key] = value
+ else:
+ os.environ.pop(key, None)
+
+ def test_available_models_respects_restrictions(self):
+ """Test that get_available_models respects model restrictions."""
+
+ # Save original environment
+ original_env = {}
+ for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "OPENAI_ALLOWED_MODELS"]:
+ original_env[key] = os.environ.get(key)
+
+ try:
+ # Set up environment with restrictions
+ os.environ["GEMINI_API_KEY"] = "test-key"
+ os.environ["OPENAI_API_KEY"] = "test-key"
+ os.environ["OPENAI_ALLOWED_MODELS"] = "o4-mini" # Only allow o4-mini
+
+ # Clear restriction service to pick up new restrictions
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ # Register both providers
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+
+ # Get available models with restrictions
+ available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
+
+ # Should include allowed OpenAI model
+ assert "o4-mini" in available_models
+ assert available_models["o4-mini"] == ProviderType.OPENAI
+
+ # Should NOT include restricted OpenAI models
+ assert "o3" not in available_models
+ assert "o3-mini" not in available_models
+
+ # Should include all Gemini models (no restrictions)
+ assert "gemini-2.5-flash-preview-05-20" in available_models
+ assert available_models["gemini-2.5-flash-preview-05-20"] == ProviderType.GOOGLE
+
+ finally:
+ # Restore original environment
+ for key, value in original_env.items():
+ if value is not None:
+ os.environ[key] = value
+ else:
+ os.environ.pop(key, None)
+
+ def test_model_validation_across_providers(self):
+ """Test that model validation works correctly across different providers."""
+
+ # Save original environment
+ original_env = {}
+ for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
+ original_env[key] = os.environ.get(key)
+
+ try:
+ # Set up all providers
+ os.environ["GEMINI_API_KEY"] = "test-key"
+ os.environ["OPENAI_API_KEY"] = "test-key"
+ os.environ["XAI_API_KEY"] = "test-key"
+
+ # Register all providers
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+ from providers.xai import XAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+
+ # Test model validation - each provider should handle its own models
+ # Gemini models
+ gemini_provider = ModelProviderRegistry.get_provider_for_model("flash")
+ assert gemini_provider is not None
+ assert gemini_provider.get_provider_type() == ProviderType.GOOGLE
+
+ # OpenAI models
+ openai_provider = ModelProviderRegistry.get_provider_for_model("o3")
+ assert openai_provider is not None
+ assert openai_provider.get_provider_type() == ProviderType.OPENAI
+
+ # XAI models
+ xai_provider = ModelProviderRegistry.get_provider_for_model("grok")
+ assert xai_provider is not None
+ assert xai_provider.get_provider_type() == ProviderType.XAI
+
+ # Invalid model should return None
+ invalid_provider = ModelProviderRegistry.get_provider_for_model("invalid-model-name")
+ assert invalid_provider is None
+
+ finally:
+ # Restore original environment
+ for key, value in original_env.items():
+ if value is not None:
+ os.environ[key] = value
+ else:
+ os.environ.pop(key, None)
+
+ def test_alias_resolution_before_api_calls(self):
+ """Test that model aliases are resolved before being passed to providers."""
+
+ # Save original environment
+ original_env = {}
+ for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"]:
+ original_env[key] = os.environ.get(key)
+
+ try:
+ # Set up all providers
+ os.environ["GEMINI_API_KEY"] = "test-key"
+ os.environ["OPENAI_API_KEY"] = "test-key"
+ os.environ["XAI_API_KEY"] = "test-key"
+
+ # Register all providers
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+ from providers.xai import XAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
+
+ # Test that providers resolve aliases correctly
+ test_cases = [
+ ("flash", ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20"),
+ ("pro", ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05"),
+ ("mini", ProviderType.OPENAI, "o4-mini"),
+ ("o3mini", ProviderType.OPENAI, "o3-mini"),
+ ("grok", ProviderType.XAI, "grok-3"),
+ ("grokfast", ProviderType.XAI, "grok-3-fast"),
+ ]
+
+ for alias, expected_provider_type, expected_resolved_name in test_cases:
+ provider = ModelProviderRegistry.get_provider_for_model(alias)
+ assert provider is not None, f"No provider found for alias '{alias}'"
+ assert provider.get_provider_type() == expected_provider_type, f"Wrong provider for '{alias}'"
+
+ # Test alias resolution
+ resolved_name = provider._resolve_model_name(alias)
+ assert (
+ resolved_name == expected_resolved_name
+ ), f"Alias '{alias}' should resolve to '{expected_resolved_name}', got '{resolved_name}'"
+
+ finally:
+ # Restore original environment
+ for key, value in original_env.items():
+ if value is not None:
+ os.environ[key] = value
+ else:
+ os.environ.pop(key, None)
diff --git a/tests/test_claude_continuation.py b/tests/test_claude_continuation.py
index 7a699e8..e4fa6e0 100644
--- a/tests/test_claude_continuation.py
+++ b/tests/test_claude_continuation.py
@@ -55,6 +55,8 @@ class TestClaudeContinuationOffers:
"""Test Claude continuation offer functionality"""
def setup_method(self):
+ # Note: Tool creation and schema generation happen here
+ # If providers are not registered yet, the tool might detect auto mode
self.tool = ClaudeContinuationTool()
# Set default model to avoid effective auto mode
self.tool.default_model = "gemini-2.5-flash-preview-05-20"
@@ -63,11 +65,15 @@ class TestClaudeContinuationOffers:
@patch.dict("os.environ", {"PYTEST_CURRENT_TEST": ""}, clear=False)
async def test_new_conversation_offers_continuation(self, mock_redis):
"""Test that new conversations offer Claude continuation opportunity"""
+ # Create tool AFTER providers are registered (in conftest.py fixture)
+ tool = ClaudeContinuationTool()
+ tool.default_model = "gemini-2.5-flash-preview-05-20"
+
mock_client = Mock()
mock_redis.return_value = mock_client
# Mock the model
- with patch.object(self.tool, "get_model_provider") as mock_get_provider:
+ with patch.object(tool, "get_model_provider") as mock_get_provider:
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
@@ -81,7 +87,7 @@ class TestClaudeContinuationOffers:
# Execute tool without continuation_id (new conversation)
arguments = {"prompt": "Analyze this code"}
- response = await self.tool.execute(arguments)
+ response = await tool.execute(arguments)
# Parse response
response_data = json.loads(response[0].text)
@@ -177,10 +183,6 @@ class TestClaudeContinuationOffers:
assert len(response) == 1
response_data = json.loads(response[0].text)
- # Debug output
- if response_data.get("status") == "error":
- print(f"Error content: {response_data.get('content')}")
-
assert response_data["status"] == "continuation_available"
assert response_data["content"] == "Analysis complete. The code looks good."
assert "continuation_offer" in response_data
diff --git a/tests/test_intelligent_fallback.py b/tests/test_intelligent_fallback.py
index 78c3cdb..f783dd2 100644
--- a/tests/test_intelligent_fallback.py
+++ b/tests/test_intelligent_fallback.py
@@ -17,51 +17,93 @@ class TestIntelligentFallback:
"""Test intelligent model fallback logic"""
def setup_method(self):
- """Setup for each test - clear registry cache"""
- ModelProviderRegistry.clear_cache()
+ """Setup for each test - clear registry and reset providers"""
+ # Store original providers for restoration
+ registry = ModelProviderRegistry()
+ self._original_providers = registry._providers.copy()
+ self._original_initialized = registry._initialized_providers.copy()
+
+ # Clear registry completely
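+ # (ModelProviderRegistry is treated as a singleton here; resetting _instance forces a fresh registry on next instantiation)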
+ ModelProviderRegistry._instance = None
def teardown_method(self):
- """Cleanup after each test"""
- ModelProviderRegistry.clear_cache()
+ """Cleanup after each test - restore original providers"""
+ # Restore original registry state
+ registry = ModelProviderRegistry()
+ registry._providers.clear()
+ registry._initialized_providers.clear()
+ registry._providers.update(self._original_providers)
+ registry._initialized_providers.update(self._original_initialized)
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
def test_prefers_openai_o3_mini_when_available(self):
"""Test that o4-mini is preferred when OpenAI API key is available"""
- ModelProviderRegistry.clear_cache()
+ # Register only OpenAI provider for this test
+ from providers.openai import OpenAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o4-mini"
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_gemini_flash_when_openai_unavailable(self):
"""Test that gemini-2.5-flash-preview-05-20 is used when only Gemini API key is available"""
- ModelProviderRegistry.clear_cache()
+ # Register only Gemini provider for this test
+ from providers.gemini import GeminiModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "gemini-2.5-flash-preview-05-20"
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_openai_when_both_available(self):
"""Test that OpenAI is preferred when both API keys are available"""
- ModelProviderRegistry.clear_cache()
+ # Register both OpenAI and Gemini providers
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o4-mini" # OpenAI has priority
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False)
def test_fallback_when_no_keys_available(self):
"""Test fallback behavior when no API keys are available"""
- ModelProviderRegistry.clear_cache()
+ # Register providers but with no API keys available
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "gemini-2.5-flash-preview-05-20" # Default fallback
def test_available_providers_with_keys(self):
"""Test the get_available_providers_with_keys method"""
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
- ModelProviderRegistry.clear_cache()
+ # Clear and register providers
+ ModelProviderRegistry._instance = None
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
available = ModelProviderRegistry.get_available_providers_with_keys()
assert ProviderType.OPENAI in available
assert ProviderType.GOOGLE not in available
with patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False):
- ModelProviderRegistry.clear_cache()
+ # Clear and register providers
+ ModelProviderRegistry._instance = None
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
available = ModelProviderRegistry.get_available_providers_with_keys()
assert ProviderType.GOOGLE in available
assert ProviderType.OPENAI not in available
@@ -76,7 +118,10 @@ class TestIntelligentFallback:
patch("config.DEFAULT_MODEL", "auto"),
patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
):
- ModelProviderRegistry.clear_cache()
+ # Register only OpenAI provider for this test
+ from providers.openai import OpenAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
# Create a context with at least one turn so it doesn't exit early
from utils.conversation_memory import ConversationTurn
@@ -114,7 +159,10 @@ class TestIntelligentFallback:
patch("config.DEFAULT_MODEL", "auto"),
patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False),
):
- ModelProviderRegistry.clear_cache()
+ # Register only Gemini provider for this test
+ from providers.gemini import GeminiModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
from utils.conversation_memory import ConversationTurn
diff --git a/tests/test_large_prompt_handling.py b/tests/test_large_prompt_handling.py
index 137c1f6..f0c4482 100644
--- a/tests/test_large_prompt_handling.py
+++ b/tests/test_large_prompt_handling.py
@@ -243,10 +243,23 @@ class TestLargePromptHandling:
tool = ChatTool()
exact_prompt = "x" * MCP_PROMPT_SIZE_LIMIT
- # With the fix, this should now pass because we check at MCP transport boundary before adding internal content
- result = await tool.execute({"prompt": exact_prompt})
- output = json.loads(result[0].text)
- assert output["status"] == "success"
+ # Mock the model provider to avoid real API calls
+ with patch.object(tool, "get_model_provider") as mock_get_provider:
+ mock_provider = MagicMock()
+ mock_provider.get_provider_type.return_value = MagicMock(value="google")
+ mock_provider.supports_thinking_mode.return_value = False
+ mock_provider.generate_content.return_value = MagicMock(
+ content="Response to the large prompt",
+ usage={"input_tokens": 12000, "output_tokens": 10, "total_tokens": 12010},
+ model_name="gemini-2.5-flash-preview-05-20",
+ metadata={"finish_reason": "STOP"},
+ )
+ mock_get_provider.return_value = mock_provider
+
+ # With the fix, this should now pass because we check at the MCP transport boundary before adding internal content
+ result = await tool.execute({"prompt": exact_prompt})
+ output = json.loads(result[0].text)
+ assert output["status"] == "success"
@pytest.mark.asyncio
async def test_boundary_case_just_over_limit(self):
diff --git a/tests/test_model_restrictions.py b/tests/test_model_restrictions.py
index 715de63..4472d51 100644
--- a/tests/test_model_restrictions.py
+++ b/tests/test_model_restrictions.py
@@ -535,18 +535,38 @@ class TestAutoModeWithRestrictions:
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GEMINI_API_KEY": "", "OPENAI_API_KEY": "test-key"})
def test_fallback_with_shorthand_restrictions(self):
"""Test fallback model selection with shorthand restrictions."""
- # Clear caches
+ # Clear caches and reset registry
import utils.model_restrictions
from providers.registry import ModelProviderRegistry
from tools.models import ToolModelCategory
utils.model_restrictions._restriction_service = None
- ModelProviderRegistry.clear_cache()
- # Even with "mini" restriction, fallback should work if provider handles it correctly
- # This tests the real-world scenario
- model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+ # Store original providers for restoration
+ registry = ModelProviderRegistry()
+ original_providers = registry._providers.copy()
+ original_initialized = registry._initialized_providers.copy()
- # The fallback will depend on how get_available_models handles aliases
- # For now, we accept either behavior and document it
- assert model in ["o4-mini", "gemini-2.5-flash-preview-05-20"]
+ try:
+ # Clear registry and register only OpenAI and Gemini providers
+ ModelProviderRegistry._instance = None
+ from providers.gemini import GeminiModelProvider
+ from providers.openai import OpenAIModelProvider
+
+ ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
+ ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
+
+ # Even with "mini" restriction, fallback should work if provider handles it correctly
+ # This tests the real-world scenario
+ model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
+
+ # The fallback will depend on how get_available_models handles aliases
+ # For now, we accept either behavior and document it
+ assert model in ["o4-mini", "gemini-2.5-flash-preview-05-20"]
+ finally:
+ # Restore original registry state
+ registry = ModelProviderRegistry()
+ registry._providers.clear()
+ registry._initialized_providers.clear()
+ registry._providers.update(original_providers)
+ registry._initialized_providers.update(original_initialized)
diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py
new file mode 100644
index 0000000..8f1a936
--- /dev/null
+++ b/tests/test_openai_provider.py
@@ -0,0 +1,221 @@
+"""Tests for OpenAI provider implementation."""
+
+import os
+from unittest.mock import MagicMock, patch
+
+from providers.base import ProviderType
+from providers.openai import OpenAIModelProvider
+
+
+class TestOpenAIProvider:
+ """Test OpenAI provider functionality."""
+
+ def setup_method(self):
+ """Set up clean state before each test."""
+ # Clear restriction service cache before each test
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ def teardown_method(self):
+ """Clean up after each test to avoid singleton issues."""
+ # Clear restriction service cache after each test
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ @patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"})
+ def test_initialization(self):
+ """Test provider initialization."""
+ provider = OpenAIModelProvider("test-key")
+ assert provider.api_key == "test-key"
+ assert provider.get_provider_type() == ProviderType.OPENAI
+ assert provider.base_url == "https://api.openai.com/v1"
+
+ def test_initialization_with_custom_url(self):
+ """Test provider initialization with custom base URL."""
+ provider = OpenAIModelProvider("test-key", base_url="https://custom.openai.com/v1")
+ assert provider.api_key == "test-key"
+ assert provider.base_url == "https://custom.openai.com/v1"
+
+ def test_model_validation(self):
+ """Test model name validation."""
+ provider = OpenAIModelProvider("test-key")
+
+ # Test valid models
+ assert provider.validate_model_name("o3") is True
+ assert provider.validate_model_name("o3-mini") is True
+ assert provider.validate_model_name("o3-pro") is True
+ assert provider.validate_model_name("o4-mini") is True
+ assert provider.validate_model_name("o4-mini-high") is True
+
+ # Test valid aliases
+ assert provider.validate_model_name("mini") is True
+ assert provider.validate_model_name("o3mini") is True
+ assert provider.validate_model_name("o4mini") is True
+ assert provider.validate_model_name("o4minihigh") is True
+ assert provider.validate_model_name("o4minihi") is True
+
+ # Test invalid model
+ assert provider.validate_model_name("invalid-model") is False
+ assert provider.validate_model_name("gpt-4") is False
+ assert provider.validate_model_name("gemini-pro") is False
+
+ def test_resolve_model_name(self):
+ """Test model name resolution."""
+ provider = OpenAIModelProvider("test-key")
+
+ # Test shorthand resolution
+ assert provider._resolve_model_name("mini") == "o4-mini"
+ assert provider._resolve_model_name("o3mini") == "o3-mini"
+ assert provider._resolve_model_name("o4mini") == "o4-mini"
+ assert provider._resolve_model_name("o4minihigh") == "o4-mini-high"
+ assert provider._resolve_model_name("o4minihi") == "o4-mini-high"
+
+ # Test full name passthrough
+ assert provider._resolve_model_name("o3") == "o3"
+ assert provider._resolve_model_name("o3-mini") == "o3-mini"
+ assert provider._resolve_model_name("o3-pro") == "o3-pro"
+ assert provider._resolve_model_name("o4-mini") == "o4-mini"
+ assert provider._resolve_model_name("o4-mini-high") == "o4-mini-high"
+
+ def test_get_capabilities_o3(self):
+ """Test getting model capabilities for O3."""
+ provider = OpenAIModelProvider("test-key")
+
+ capabilities = provider.get_capabilities("o3")
+ assert capabilities.model_name == "o3" # Should NOT be resolved in capabilities
+ assert capabilities.friendly_name == "OpenAI"
+ assert capabilities.context_window == 200_000
+ assert capabilities.provider == ProviderType.OPENAI
+ assert not capabilities.supports_extended_thinking
+ assert capabilities.supports_system_prompts is True
+ assert capabilities.supports_streaming is True
+ assert capabilities.supports_function_calling is True
+
+ # Test temperature constraint (O3 has fixed temperature)
+ assert capabilities.temperature_constraint.value == 1.0
+
+ def test_get_capabilities_with_alias(self):
+ """Test getting model capabilities with alias resolves correctly."""
+ provider = OpenAIModelProvider("test-key")
+
+ capabilities = provider.get_capabilities("mini")
+ assert capabilities.model_name == "mini" # Capabilities should show original request
+ assert capabilities.friendly_name == "OpenAI"
+ assert capabilities.context_window == 200_000
+ assert capabilities.provider == ProviderType.OPENAI
+
+ @patch("providers.openai_compatible.OpenAI")
+ def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
+ """Test that generate_content resolves aliases before making API calls.
+
+ This is the CRITICAL test that was missing: it verifies that aliases
+ like 'mini' are resolved to 'o4-mini' before being sent to the OpenAI API.
+ """
+ # Set up mock OpenAI client
+ mock_client = MagicMock()
+ mock_openai_class.return_value = mock_client
+
+ # Mock the completion response
+ mock_response = MagicMock()
+ mock_response.choices = [MagicMock()]
+ mock_response.choices[0].message.content = "Test response"
+ mock_response.choices[0].finish_reason = "stop"
+ mock_response.model = "o4-mini" # API returns the resolved model name
+ mock_response.id = "test-id"
+ mock_response.created = 1234567890
+ mock_response.usage = MagicMock()
+ mock_response.usage.prompt_tokens = 10
+ mock_response.usage.completion_tokens = 5
+ mock_response.usage.total_tokens = 15
+
+ mock_client.chat.completions.create.return_value = mock_response
+
+ provider = OpenAIModelProvider("test-key")
+
+ # Call generate_content with alias 'mini'
+ result = provider.generate_content(
+ prompt="Test prompt",
+ model_name="mini",  # This should be resolved to "o4-mini"
+ temperature=1.0,
+ )
+
+ # Verify the API was called with the RESOLVED model name
+ mock_client.chat.completions.create.assert_called_once()
+ call_kwargs = mock_client.chat.completions.create.call_args[1]
+
+ # CRITICAL ASSERTION: The API should receive "o4-mini", not "mini"
+ assert call_kwargs["model"] == "o4-mini", f"Expected 'o4-mini' but API received '{call_kwargs['model']}'"
+
+ # Verify other parameters
+ assert call_kwargs["temperature"] == 1.0
+ assert len(call_kwargs["messages"]) == 1
+ assert call_kwargs["messages"][0]["role"] == "user"
+ assert call_kwargs["messages"][0]["content"] == "Test prompt"
+
+ # Verify response
+ assert result.content == "Test response"
+ assert result.model_name == "o4-mini" # Should be the resolved name
+
+ @patch("providers.openai_compatible.OpenAI")
+ def test_generate_content_other_aliases(self, mock_openai_class):
+ """Test other alias resolutions in generate_content."""
+ # Set up mock
+ mock_client = MagicMock()
+ mock_openai_class.return_value = mock_client
+ mock_response = MagicMock()
+ mock_response.choices = [MagicMock()]
+ mock_response.choices[0].message.content = "Test response"
+ mock_response.choices[0].finish_reason = "stop"
+ mock_response.usage = MagicMock()
+ mock_response.usage.prompt_tokens = 10
+ mock_response.usage.completion_tokens = 5
+ mock_response.usage.total_tokens = 15
+ mock_client.chat.completions.create.return_value = mock_response
+
+ provider = OpenAIModelProvider("test-key")
+
+ # Test o3mini -> o3-mini
+ mock_response.model = "o3-mini"
+ provider.generate_content(prompt="Test", model_name="o3mini", temperature=1.0)
+ call_kwargs = mock_client.chat.completions.create.call_args[1]
+ assert call_kwargs["model"] == "o3-mini"
+
+ # Test o4minihigh -> o4-mini-high
+ mock_response.model = "o4-mini-high"
+ provider.generate_content(prompt="Test", model_name="o4minihigh", temperature=1.0)
+ call_kwargs = mock_client.chat.completions.create.call_args[1]
+ assert call_kwargs["model"] == "o4-mini-high"
+
+ @patch("providers.openai_compatible.OpenAI")
+ def test_generate_content_no_alias_passthrough(self, mock_openai_class):
+ """Test that full model names pass through unchanged."""
+ # Set up mock
+ mock_client = MagicMock()
+ mock_openai_class.return_value = mock_client
+ mock_response = MagicMock()
+ mock_response.choices = [MagicMock()]
+ mock_response.choices[0].message.content = "Test response"
+ mock_response.choices[0].finish_reason = "stop"
+ mock_response.model = "o3-pro"
+ mock_response.usage = MagicMock()
+ mock_response.usage.prompt_tokens = 10
+ mock_response.usage.completion_tokens = 5
+ mock_response.usage.total_tokens = 15
+ mock_client.chat.completions.create.return_value = mock_response
+
+ provider = OpenAIModelProvider("test-key")
+
+ # Test full model name passes through unchanged
+ provider.generate_content(prompt="Test", model_name="o3-pro", temperature=1.0)
+ call_kwargs = mock_client.chat.completions.create.call_args[1]
+ assert call_kwargs["model"] == "o3-pro" # Should be unchanged
+
+ def test_supports_thinking_mode(self):
+ """Test thinking mode support (currently False for all OpenAI models)."""
+ provider = OpenAIModelProvider("test-key")
+
+ # All OpenAI models currently don't support thinking mode
+ assert provider.supports_thinking_mode("o3") is False
+ assert provider.supports_thinking_mode("o3-mini") is False
+ assert provider.supports_thinking_mode("o4-mini") is False
+ assert provider.supports_thinking_mode("mini") is False # Test with alias too
diff --git a/tests/test_per_tool_model_defaults.py b/tests/test_per_tool_model_defaults.py
index a91c7e3..896509d 100644
--- a/tests/test_per_tool_model_defaults.py
+++ b/tests/test_per_tool_model_defaults.py
@@ -202,9 +202,9 @@ class TestCustomProviderFallback:
@patch.object(ModelProviderRegistry, "_find_extended_thinking_model")
def test_extended_reasoning_custom_fallback(self, mock_find_thinking):
"""Test EXTENDED_REASONING falls back to custom thinking model."""
- with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
- # No native providers available
- mock_get_provider.return_value = None
+ with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
+ # No native models available, but OpenRouter is available
+ mock_get_available.return_value = {"openrouter-model": ProviderType.OPENROUTER}
mock_find_thinking.return_value = "custom/thinking-model"
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
diff --git a/tests/test_xai_provider.py b/tests/test_xai_provider.py
new file mode 100644
index 0000000..e002636
--- /dev/null
+++ b/tests/test_xai_provider.py
@@ -0,0 +1,326 @@
+"""Tests for X.AI provider implementation."""
+
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from providers.base import ProviderType
+from providers.xai import XAIModelProvider
+
+
+class TestXAIProvider:
+ """Test X.AI provider functionality."""
+
+ def setup_method(self):
+ """Set up clean state before each test."""
+ # Clear restriction service cache before each test
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ def teardown_method(self):
+ """Clean up after each test to avoid singleton issues."""
+ # Clear restriction service cache after each test
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ @patch.dict(os.environ, {"XAI_API_KEY": "test-key"})
+ def test_initialization(self):
+ """Test provider initialization."""
+ provider = XAIModelProvider("test-key")
+ assert provider.api_key == "test-key"
+ assert provider.get_provider_type() == ProviderType.XAI
+ assert provider.base_url == "https://api.x.ai/v1"
+
+ def test_initialization_with_custom_url(self):
+ """Test provider initialization with custom base URL."""
+ provider = XAIModelProvider("test-key", base_url="https://custom.x.ai/v1")
+ assert provider.api_key == "test-key"
+ assert provider.base_url == "https://custom.x.ai/v1"
+
+ def test_model_validation(self):
+ """Test model name validation."""
+ provider = XAIModelProvider("test-key")
+
+ # Test valid models
+ assert provider.validate_model_name("grok-3") is True
+ assert provider.validate_model_name("grok-3-fast") is True
+ assert provider.validate_model_name("grok") is True
+ assert provider.validate_model_name("grok3") is True
+ assert provider.validate_model_name("grokfast") is True
+ assert provider.validate_model_name("grok3fast") is True
+
+ # Test invalid model
+ assert provider.validate_model_name("invalid-model") is False
+ assert provider.validate_model_name("gpt-4") is False
+ assert provider.validate_model_name("gemini-pro") is False
+
+ def test_resolve_model_name(self):
+ """Test model name resolution."""
+ provider = XAIModelProvider("test-key")
+
+ # Test shorthand resolution
+ assert provider._resolve_model_name("grok") == "grok-3"
+ assert provider._resolve_model_name("grok3") == "grok-3"
+ assert provider._resolve_model_name("grokfast") == "grok-3-fast"
+ assert provider._resolve_model_name("grok3fast") == "grok-3-fast"
+
+ # Test full name passthrough
+ assert provider._resolve_model_name("grok-3") == "grok-3"
+ assert provider._resolve_model_name("grok-3-fast") == "grok-3-fast"
+
+ def test_get_capabilities_grok3(self):
+ """Test getting model capabilities for GROK-3."""
+ provider = XAIModelProvider("test-key")
+
+ capabilities = provider.get_capabilities("grok-3")
+ assert capabilities.model_name == "grok-3"
+ assert capabilities.friendly_name == "X.AI"
+ assert capabilities.context_window == 131_072
+ assert capabilities.provider == ProviderType.XAI
+ assert not capabilities.supports_extended_thinking
+ assert capabilities.supports_system_prompts is True
+ assert capabilities.supports_streaming is True
+ assert capabilities.supports_function_calling is True
+
+ # Test temperature range
+ assert capabilities.temperature_constraint.min_temp == 0.0
+ assert capabilities.temperature_constraint.max_temp == 2.0
+ assert capabilities.temperature_constraint.default_temp == 0.7
+
+ def test_get_capabilities_grok3_fast(self):
+ """Test getting model capabilities for GROK-3 Fast."""
+ provider = XAIModelProvider("test-key")
+
+ capabilities = provider.get_capabilities("grok-3-fast")
+ assert capabilities.model_name == "grok-3-fast"
+ assert capabilities.friendly_name == "X.AI"
+ assert capabilities.context_window == 131_072
+ assert capabilities.provider == ProviderType.XAI
+ assert not capabilities.supports_extended_thinking
+
+ def test_get_capabilities_with_shorthand(self):
+ """Test getting model capabilities with shorthand."""
+ provider = XAIModelProvider("test-key")
+
+ capabilities = provider.get_capabilities("grok")
+ assert capabilities.model_name == "grok-3" # Should resolve to full name
+ assert capabilities.context_window == 131_072
+
+ capabilities_fast = provider.get_capabilities("grokfast")
+ assert capabilities_fast.model_name == "grok-3-fast" # Should resolve to full name
+
+ def test_unsupported_model_capabilities(self):
+ """Test error handling for unsupported models."""
+ provider = XAIModelProvider("test-key")
+
+ with pytest.raises(ValueError, match="Unsupported X.AI model"):
+ provider.get_capabilities("invalid-model")
+
+ def test_no_thinking_mode_support(self):
+ """Test that X.AI models don't support thinking mode."""
+ provider = XAIModelProvider("test-key")
+
+ assert not provider.supports_thinking_mode("grok-3")
+ assert not provider.supports_thinking_mode("grok-3-fast")
+ assert not provider.supports_thinking_mode("grok")
+ assert not provider.supports_thinking_mode("grokfast")
+
+ def test_provider_type(self):
+ """Test provider type identification."""
+ provider = XAIModelProvider("test-key")
+ assert provider.get_provider_type() == ProviderType.XAI
+
+ @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok-3"})
+ def test_model_restrictions(self):
+ """Test model restrictions functionality."""
+ # Clear cached restriction service
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ provider = XAIModelProvider("test-key")
+
+ # grok-3 should be allowed
+ assert provider.validate_model_name("grok-3") is True
+ assert provider.validate_model_name("grok") is True # Shorthand for grok-3
+
+ # grok-3-fast should be blocked by restrictions
+ assert provider.validate_model_name("grok-3-fast") is False
+ assert provider.validate_model_name("grokfast") is False
+
+ @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3-fast"})
+ def test_multiple_model_restrictions(self):
+ """Test multiple models in restrictions."""
+ # Clear cached restriction service
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ provider = XAIModelProvider("test-key")
+
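+ # Expected behavior exercised below: a model passes validation if either the name as given or its resolved full name appears in XAI_ALLOWED_MODELS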
+ # Shorthand "grok" should be allowed (resolves to grok-3)
+ assert provider.validate_model_name("grok") is True
+
+ # Full name "grok-3" should NOT be allowed (only shorthand "grok" is in restriction list)
+ assert provider.validate_model_name("grok-3") is False
+
+ # "grok-3-fast" should be allowed (explicitly listed)
+ assert provider.validate_model_name("grok-3-fast") is True
+
+ # Shorthand "grokfast" should be allowed (resolves to grok-3-fast)
+ assert provider.validate_model_name("grokfast") is True
+
+ @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": "grok,grok-3"})
+ def test_both_shorthand_and_full_name_allowed(self):
+ """Test that both shorthand and full name can be allowed."""
+ # Clear cached restriction service
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ provider = XAIModelProvider("test-key")
+
+ # Both shorthand and full name should be allowed
+ assert provider.validate_model_name("grok") is True
+ assert provider.validate_model_name("grok-3") is True
+
+ # Other models should not be allowed
+ assert provider.validate_model_name("grok-3-fast") is False
+ assert provider.validate_model_name("grokfast") is False
+
+ @patch.dict(os.environ, {"XAI_ALLOWED_MODELS": ""})
+ def test_empty_restrictions_allows_all(self):
+ """Test that empty restrictions allow all models."""
+ # Clear cached restriction service
+ import utils.model_restrictions
+
+ utils.model_restrictions._restriction_service = None
+
+ provider = XAIModelProvider("test-key")
+
+ assert provider.validate_model_name("grok-3") is True
+ assert provider.validate_model_name("grok-3-fast") is True
+ assert provider.validate_model_name("grok") is True
+ assert provider.validate_model_name("grokfast") is True
+
+ def test_friendly_name(self):
+ """Test friendly name constant."""
+ provider = XAIModelProvider("test-key")
+ assert provider.FRIENDLY_NAME == "X.AI"
+
+ capabilities = provider.get_capabilities("grok-3")
+ assert capabilities.friendly_name == "X.AI"
+
+ def test_supported_models_structure(self):
+ """Test that SUPPORTED_MODELS has the correct structure."""
+ provider = XAIModelProvider("test-key")
+
+ # Check that all expected models are present
+ assert "grok-3" in provider.SUPPORTED_MODELS
+ assert "grok-3-fast" in provider.SUPPORTED_MODELS
+ assert "grok" in provider.SUPPORTED_MODELS
+ assert "grok3" in provider.SUPPORTED_MODELS
+ assert "grokfast" in provider.SUPPORTED_MODELS
+ assert "grok3fast" in provider.SUPPORTED_MODELS
+
+ # Check model configs have required fields
+ grok3_config = provider.SUPPORTED_MODELS["grok-3"]
+ assert isinstance(grok3_config, dict)
+ assert "context_window" in grok3_config
+ assert "supports_extended_thinking" in grok3_config
+ assert grok3_config["context_window"] == 131_072
+ assert grok3_config["supports_extended_thinking"] is False
+
+ # Check shortcuts point to full names
+ assert provider.SUPPORTED_MODELS["grok"] == "grok-3"
+ assert provider.SUPPORTED_MODELS["grokfast"] == "grok-3-fast"
+
+ @patch("providers.openai_compatible.OpenAI")
+ def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
+ """Test that generate_content resolves aliases before making API calls.
+
+ This is the CRITICAL test that ensures aliases like 'grok' are resolved
+ to 'grok-3' before being sent to the X.AI API.
+ """
+ # Set up mock OpenAI client
+ mock_client = MagicMock()
+ mock_openai_class.return_value = mock_client
+
+ # Mock the completion response
+ mock_response = MagicMock()
+ mock_response.choices = [MagicMock()]
+ mock_response.choices[0].message.content = "Test response"
+ mock_response.choices[0].finish_reason = "stop"
+ mock_response.model = "grok-3" # API returns the resolved model name
+ mock_response.id = "test-id"
+ mock_response.created = 1234567890
+ mock_response.usage = MagicMock()
+ mock_response.usage.prompt_tokens = 10
+ mock_response.usage.completion_tokens = 5
+ mock_response.usage.total_tokens = 15
+
+ mock_client.chat.completions.create.return_value = mock_response
+
+ provider = XAIModelProvider("test-key")
+
+ # Call generate_content with alias 'grok'
+ result = provider.generate_content(
+ prompt="Test prompt",
+ model_name="grok",  # This should be resolved to "grok-3"
+ temperature=0.7,
+ )
+
+ # Verify the API was called with the RESOLVED model name
+ mock_client.chat.completions.create.assert_called_once()
+ call_kwargs = mock_client.chat.completions.create.call_args[1]
+
+ # CRITICAL ASSERTION: The API should receive "grok-3", not "grok"
+ assert call_kwargs["model"] == "grok-3", f"Expected 'grok-3' but API received '{call_kwargs['model']}'"
+
+ # Verify other parameters
+ assert call_kwargs["temperature"] == 0.7
+ assert len(call_kwargs["messages"]) == 1
+ assert call_kwargs["messages"][0]["role"] == "user"
+ assert call_kwargs["messages"][0]["content"] == "Test prompt"
+
+ # Verify response
+ assert result.content == "Test response"
+ assert result.model_name == "grok-3" # Should be the resolved name
+
+ @patch("providers.openai_compatible.OpenAI")
+ def test_generate_content_other_aliases(self, mock_openai_class):
+ """Test other alias resolutions in generate_content."""
+ # Set up mock
+ mock_client = MagicMock()
+ mock_openai_class.return_value = mock_client
+ mock_response = MagicMock()
+ mock_response.choices = [MagicMock()]
+ mock_response.choices[0].message.content = "Test response"
+ mock_response.choices[0].finish_reason = "stop"
+ mock_response.usage = MagicMock()
+ mock_response.usage.prompt_tokens = 10
+ mock_response.usage.completion_tokens = 5
+ mock_response.usage.total_tokens = 15
+ mock_client.chat.completions.create.return_value = mock_response
+
+ provider = XAIModelProvider("test-key")
+
+ # Test grok3 -> grok-3
+ mock_response.model = "grok-3"
+ provider.generate_content(prompt="Test", model_name="grok3", temperature=0.7)
+ call_kwargs = mock_client.chat.completions.create.call_args[1]
+ assert call_kwargs["model"] == "grok-3"
+
+ # Test grokfast -> grok-3-fast
+ mock_response.model = "grok-3-fast"
+ provider.generate_content(prompt="Test", model_name="grokfast", temperature=0.7)
+ call_kwargs = mock_client.chat.completions.create.call_args[1]
+ assert call_kwargs["model"] == "grok-3-fast"
+
+ # Test grok3fast -> grok-3-fast
+ provider.generate_content(prompt="Test", model_name="grok3fast", temperature=0.7)
+ call_kwargs = mock_client.chat.completions.create.call_args[1]
+ assert call_kwargs["model"] == "grok-3-fast"
diff --git a/utils/model_restrictions.py b/utils/model_restrictions.py
index 22e7d70..12906b0 100644
--- a/utils/model_restrictions.py
+++ b/utils/model_restrictions.py
@@ -9,11 +9,13 @@ standardization purposes.
Environment Variables:
- OPENAI_ALLOWED_MODELS: Comma-separated list of allowed OpenAI models
- GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
+- XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
- OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models
Example:
OPENAI_ALLOWED_MODELS=o3-mini,o4-mini
GOOGLE_ALLOWED_MODELS=flash
+ XAI_ALLOWED_MODELS=grok-3,grok-3-fast
OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral
"""
@@ -40,6 +42,7 @@ class ModelRestrictionService:
ENV_VARS = {
ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS",
ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
+ ProviderType.XAI: "XAI_ALLOWED_MODELS",
ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
}