Native support for xAI Grok3
Model shorthand mapping related fixes Comprehensive auto-mode related tests
This commit is contained in:
582
tests/test_auto_mode_comprehensive.py
Normal file
582
tests/test_auto_mode_comprehensive.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""Comprehensive tests for auto mode functionality across all provider combinations"""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from tools.analyze import AnalyzeTool
|
||||
from tools.chat import ChatTool
|
||||
from tools.debug import DebugIssueTool
|
||||
from tools.models import ToolModelCategory
|
||||
from tools.thinkdeep import ThinkDeepTool
|
||||
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
class TestAutoModeComprehensive:
|
||||
"""Test auto mode model selection across all provider combinations"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up clean state before each test."""
|
||||
# Save original environment state for restoration
|
||||
import os
|
||||
|
||||
self._original_default_model = os.environ.get("DEFAULT_MODEL", "")
|
||||
|
||||
# Clear restriction service cache
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
# Clear provider registry by resetting singleton instance
|
||||
ModelProviderRegistry._instance = None
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up after each test."""
|
||||
# Restore original DEFAULT_MODEL
|
||||
import os
|
||||
|
||||
if self._original_default_model:
|
||||
os.environ["DEFAULT_MODEL"] = self._original_default_model
|
||||
elif "DEFAULT_MODEL" in os.environ:
|
||||
del os.environ["DEFAULT_MODEL"]
|
||||
|
||||
# Reload config to pick up the restored DEFAULT_MODEL
|
||||
import importlib
|
||||
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Clear restriction service cache
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
# Clear provider registry by resetting singleton instance
|
||||
ModelProviderRegistry._instance = None
|
||||
|
||||
# Re-register providers for subsequent tests (like conftest.py does)
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_config,expected_models",
|
||||
[
|
||||
# Only Gemini API available
|
||||
(
|
||||
{
|
||||
"GEMINI_API_KEY": "real-key",
|
||||
"OPENAI_API_KEY": None,
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": None,
|
||||
},
|
||||
{
|
||||
"EXTENDED_REASONING": "gemini-2.5-pro-preview-06-05", # Pro for deep thinking
|
||||
"FAST_RESPONSE": "gemini-2.5-flash-preview-05-20", # Flash for speed
|
||||
"BALANCED": "gemini-2.5-flash-preview-05-20", # Flash as balanced
|
||||
},
|
||||
),
|
||||
# Only OpenAI API available
|
||||
(
|
||||
{
|
||||
"GEMINI_API_KEY": None,
|
||||
"OPENAI_API_KEY": "real-key",
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": None,
|
||||
},
|
||||
{
|
||||
"EXTENDED_REASONING": "o3", # O3 for deep reasoning
|
||||
"FAST_RESPONSE": "o4-mini", # O4-mini for speed
|
||||
"BALANCED": "o4-mini", # O4-mini as balanced
|
||||
},
|
||||
),
|
||||
# Only X.AI API available
|
||||
(
|
||||
{
|
||||
"GEMINI_API_KEY": None,
|
||||
"OPENAI_API_KEY": None,
|
||||
"XAI_API_KEY": "real-key",
|
||||
"OPENROUTER_API_KEY": None,
|
||||
},
|
||||
{
|
||||
"EXTENDED_REASONING": "grok-3", # GROK-3 for reasoning
|
||||
"FAST_RESPONSE": "grok-3-fast", # GROK-3-fast for speed
|
||||
"BALANCED": "grok-3", # GROK-3 as balanced
|
||||
},
|
||||
),
|
||||
# Both Gemini and OpenAI available - should prefer based on tool category
|
||||
(
|
||||
{
|
||||
"GEMINI_API_KEY": "real-key",
|
||||
"OPENAI_API_KEY": "real-key",
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": None,
|
||||
},
|
||||
{
|
||||
"EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning
|
||||
"FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed
|
||||
"BALANCED": "o4-mini", # Prefer OpenAI for balanced
|
||||
},
|
||||
),
|
||||
# All native APIs available - should prefer based on tool category
|
||||
(
|
||||
{
|
||||
"GEMINI_API_KEY": "real-key",
|
||||
"OPENAI_API_KEY": "real-key",
|
||||
"XAI_API_KEY": "real-key",
|
||||
"OPENROUTER_API_KEY": None,
|
||||
},
|
||||
{
|
||||
"EXTENDED_REASONING": "o3", # Prefer O3 for deep reasoning
|
||||
"FAST_RESPONSE": "o4-mini", # Prefer O4-mini for speed
|
||||
"BALANCED": "o4-mini", # Prefer OpenAI for balanced
|
||||
},
|
||||
),
|
||||
# Only OpenRouter available - should fall back to proxy models
|
||||
(
|
||||
{
|
||||
"GEMINI_API_KEY": None,
|
||||
"OPENAI_API_KEY": None,
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": "real-key",
|
||||
},
|
||||
{
|
||||
"EXTENDED_REASONING": "anthropic/claude-3.5-sonnet", # First preferred thinking model from OpenRouter
|
||||
"FAST_RESPONSE": "anthropic/claude-3-opus", # First available OpenRouter model
|
||||
"BALANCED": "anthropic/claude-3-opus", # First available OpenRouter model
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_auto_mode_model_selection_by_provider(self, provider_config, expected_models):
|
||||
"""Test that auto mode selects correct models based on available providers."""
|
||||
|
||||
# Set up environment with specific provider configuration
|
||||
# Filter out None values and handle them separately
|
||||
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
|
||||
env_to_clear = [k for k, v in provider_config.items() if v is None]
|
||||
|
||||
with patch.dict(os.environ, env_to_set, clear=False):
|
||||
# Clear the None-valued environment variables
|
||||
for key in env_to_clear:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
# Reload config to pick up auto mode
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Register providers based on configuration
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
if provider_config.get("GEMINI_API_KEY"):
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
if provider_config.get("OPENAI_API_KEY"):
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
if provider_config.get("XAI_API_KEY"):
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
if provider_config.get("OPENROUTER_API_KEY"):
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||
|
||||
# Test each tool category
|
||||
for category_name, expected_model in expected_models.items():
|
||||
category = ToolModelCategory(category_name.lower())
|
||||
|
||||
# Get preferred fallback model for this category
|
||||
fallback_model = ModelProviderRegistry.get_preferred_fallback_model(category)
|
||||
|
||||
assert fallback_model == expected_model, (
|
||||
f"Provider config {provider_config}: "
|
||||
f"Expected {expected_model} for {category_name}, got {fallback_model}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_class,expected_category",
|
||||
[
|
||||
(ChatTool, ToolModelCategory.FAST_RESPONSE),
|
||||
(AnalyzeTool, ToolModelCategory.EXTENDED_REASONING), # AnalyzeTool uses EXTENDED_REASONING
|
||||
(DebugIssueTool, ToolModelCategory.EXTENDED_REASONING),
|
||||
(ThinkDeepTool, ToolModelCategory.EXTENDED_REASONING),
|
||||
],
|
||||
)
|
||||
def test_tool_model_categories(self, tool_class, expected_category):
|
||||
"""Test that tools have the correct model categories."""
|
||||
tool = tool_class()
|
||||
assert tool.get_model_category() == expected_category
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_gemini_only_uses_correct_models(self):
|
||||
"""Test that auto mode with only Gemini uses flash for fast tools and pro for reasoning tools."""
|
||||
|
||||
provider_config = {
|
||||
"GEMINI_API_KEY": "real-key",
|
||||
"OPENAI_API_KEY": None,
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": None,
|
||||
"DEFAULT_MODEL": "auto",
|
||||
}
|
||||
|
||||
# Filter out None values to avoid patch.dict errors
|
||||
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
|
||||
env_to_clear = [k for k, v in provider_config.items() if v is None]
|
||||
|
||||
with patch.dict(os.environ, env_to_set, clear=False):
|
||||
# Clear the None-valued environment variables
|
||||
for key in env_to_clear:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Register only Gemini provider
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
# Mock provider to capture what model is requested
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.generate_content.return_value = MagicMock(
|
||||
content="test response", model_name="test-model", usage={"input_tokens": 10, "output_tokens": 5}
|
||||
)
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
|
||||
# Test ChatTool (FAST_RESPONSE) - should prefer flash
|
||||
chat_tool = ChatTool()
|
||||
await chat_tool.execute({"prompt": "test", "model": "auto"}) # This should trigger auto selection
|
||||
|
||||
# In auto mode, the tool should get an error requiring model selection
|
||||
# but the suggested model should be flash
|
||||
|
||||
# Reset mock for next test
|
||||
ModelProviderRegistry.get_provider_for_model.reset_mock()
|
||||
|
||||
# Test DebugIssueTool (EXTENDED_REASONING) - should prefer pro
|
||||
debug_tool = DebugIssueTool()
|
||||
await debug_tool.execute({"prompt": "test error", "model": "auto"})
|
||||
|
||||
def test_auto_mode_schema_includes_all_available_models(self):
|
||||
"""Test that auto mode schema includes all available models for user convenience."""
|
||||
|
||||
# Test with only Gemini available
|
||||
provider_config = {
|
||||
"GEMINI_API_KEY": "real-key",
|
||||
"OPENAI_API_KEY": None,
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": None,
|
||||
"DEFAULT_MODEL": "auto",
|
||||
}
|
||||
|
||||
# Filter out None values to avoid patch.dict errors
|
||||
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
|
||||
env_to_clear = [k for k, v in provider_config.items() if v is None]
|
||||
|
||||
with patch.dict(os.environ, env_to_set, clear=False):
|
||||
# Clear the None-valued environment variables
|
||||
for key in env_to_clear:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Register only Gemini provider
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
tool = AnalyzeTool()
|
||||
schema = tool.get_input_schema()
|
||||
|
||||
# Should have model as required field
|
||||
assert "model" in schema["required"]
|
||||
|
||||
# Should include all model options from global config
|
||||
model_schema = schema["properties"]["model"]
|
||||
assert "enum" in model_schema
|
||||
|
||||
available_models = model_schema["enum"]
|
||||
|
||||
# Should include Gemini models
|
||||
assert "flash" in available_models
|
||||
assert "pro" in available_models
|
||||
assert "gemini-2.5-flash-preview-05-20" in available_models
|
||||
assert "gemini-2.5-pro-preview-06-05" in available_models
|
||||
|
||||
# Should also include other models (users might have OpenRouter configured)
|
||||
# The schema should show all options; validation happens at runtime
|
||||
assert "o3" in available_models
|
||||
assert "o4-mini" in available_models
|
||||
assert "grok" in available_models
|
||||
assert "grok-3" in available_models
|
||||
|
||||
def test_auto_mode_schema_with_all_providers(self):
|
||||
"""Test that auto mode schema includes models from all available providers."""
|
||||
|
||||
provider_config = {
|
||||
"GEMINI_API_KEY": "real-key",
|
||||
"OPENAI_API_KEY": "real-key",
|
||||
"XAI_API_KEY": "real-key",
|
||||
"OPENROUTER_API_KEY": None, # Don't include OpenRouter to avoid infinite models
|
||||
"DEFAULT_MODEL": "auto",
|
||||
}
|
||||
|
||||
# Filter out None values to avoid patch.dict errors
|
||||
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
|
||||
env_to_clear = [k for k, v in provider_config.items() if v is None]
|
||||
|
||||
with patch.dict(os.environ, env_to_set, clear=False):
|
||||
# Clear the None-valued environment variables
|
||||
for key in env_to_clear:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Register all native providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
|
||||
tool = AnalyzeTool()
|
||||
schema = tool.get_input_schema()
|
||||
|
||||
model_schema = schema["properties"]["model"]
|
||||
available_models = model_schema["enum"]
|
||||
|
||||
# Should include models from all providers
|
||||
# Gemini models
|
||||
assert "flash" in available_models
|
||||
assert "pro" in available_models
|
||||
|
||||
# OpenAI models
|
||||
assert "o3" in available_models
|
||||
assert "o4-mini" in available_models
|
||||
|
||||
# XAI models
|
||||
assert "grok" in available_models
|
||||
assert "grok-3" in available_models
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_model_parameter_required_error(self):
|
||||
"""Test that auto mode properly requires model parameter and suggests correct model."""
|
||||
|
||||
provider_config = {
|
||||
"GEMINI_API_KEY": "real-key",
|
||||
"OPENAI_API_KEY": None,
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": None,
|
||||
"DEFAULT_MODEL": "auto",
|
||||
}
|
||||
|
||||
# Filter out None values to avoid patch.dict errors
|
||||
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
|
||||
env_to_clear = [k for k, v in provider_config.items() if v is None]
|
||||
|
||||
with patch.dict(os.environ, env_to_set, clear=False):
|
||||
# Clear the None-valued environment variables
|
||||
for key in env_to_clear:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Register only Gemini provider
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
# Test with ChatTool (FAST_RESPONSE category)
|
||||
chat_tool = ChatTool()
|
||||
result = await chat_tool.execute(
|
||||
{
|
||||
"prompt": "test"
|
||||
# Note: no model parameter provided in auto mode
|
||||
}
|
||||
)
|
||||
|
||||
# Should get error requiring model selection
|
||||
assert len(result) == 1
|
||||
response_text = result[0].text
|
||||
|
||||
# Parse JSON response to check error
|
||||
import json
|
||||
|
||||
response_data = json.loads(response_text)
|
||||
|
||||
assert response_data["status"] == "error"
|
||||
assert "Model parameter is required" in response_data["content"]
|
||||
assert "flash" in response_data["content"] # Should suggest flash for FAST_RESPONSE
|
||||
assert "category: fast_response" in response_data["content"]
|
||||
|
||||
def test_model_availability_with_restrictions(self):
|
||||
"""Test that auto mode respects model restrictions when selecting fallback models."""
|
||||
|
||||
provider_config = {
|
||||
"GEMINI_API_KEY": "real-key",
|
||||
"OPENAI_API_KEY": "real-key",
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": None,
|
||||
"DEFAULT_MODEL": "auto",
|
||||
"OPENAI_ALLOWED_MODELS": "o4-mini", # Restrict OpenAI to only o4-mini
|
||||
}
|
||||
|
||||
# Filter out None values to avoid patch.dict errors
|
||||
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
|
||||
env_to_clear = [k for k, v in provider_config.items() if v is None]
|
||||
|
||||
with patch.dict(os.environ, env_to_set, clear=False):
|
||||
# Clear the None-valued environment variables
|
||||
for key in env_to_clear:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Clear restriction service to pick up new env vars
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
# Register providers
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai import OpenAIModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
# Get available models - should respect restrictions
|
||||
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||
|
||||
# Should include restricted OpenAI model
|
||||
assert "o4-mini" in available_models
|
||||
|
||||
# Should NOT include non-restricted OpenAI models
|
||||
assert "o3" not in available_models
|
||||
assert "o3-mini" not in available_models
|
||||
|
||||
# Should still include all Gemini models (no restrictions)
|
||||
assert "gemini-2.5-flash-preview-05-20" in available_models
|
||||
assert "gemini-2.5-pro-preview-06-05" in available_models
|
||||
|
||||
def test_openrouter_fallback_when_no_native_apis(self):
|
||||
"""Test that OpenRouter provides fallback models when no native APIs are available."""
|
||||
|
||||
provider_config = {
|
||||
"GEMINI_API_KEY": None,
|
||||
"OPENAI_API_KEY": None,
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": "real-key",
|
||||
"DEFAULT_MODEL": "auto",
|
||||
}
|
||||
|
||||
# Filter out None values to avoid patch.dict errors
|
||||
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
|
||||
env_to_clear = [k for k, v in provider_config.items() if v is None]
|
||||
|
||||
with patch.dict(os.environ, env_to_set, clear=False):
|
||||
# Clear the None-valued environment variables
|
||||
for key in env_to_clear:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Register only OpenRouter provider
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||
|
||||
# Mock OpenRouter registry to return known models
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.list_models.return_value = [
|
||||
"google/gemini-2.5-flash-preview-05-20",
|
||||
"google/gemini-2.5-pro-preview-06-05",
|
||||
"openai/o3",
|
||||
"openai/o4-mini",
|
||||
"anthropic/claude-3-opus",
|
||||
]
|
||||
|
||||
with patch.object(OpenRouterProvider, "_registry", mock_registry):
|
||||
# Get preferred models for different categories
|
||||
extended_reasoning = ModelProviderRegistry.get_preferred_fallback_model(
|
||||
ToolModelCategory.EXTENDED_REASONING
|
||||
)
|
||||
fast_response = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
|
||||
|
||||
# Should fallback to known good models even via OpenRouter
|
||||
# The exact model depends on _find_extended_thinking_model implementation
|
||||
assert extended_reasoning is not None
|
||||
assert fast_response is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_actual_model_name_resolution_in_auto_mode(self):
|
||||
"""Test that when a model is selected in auto mode, the tool executes successfully."""
|
||||
|
||||
provider_config = {
|
||||
"GEMINI_API_KEY": "real-key",
|
||||
"OPENAI_API_KEY": None,
|
||||
"XAI_API_KEY": None,
|
||||
"OPENROUTER_API_KEY": None,
|
||||
"DEFAULT_MODEL": "auto",
|
||||
}
|
||||
|
||||
# Filter out None values to avoid patch.dict errors
|
||||
env_to_set = {k: v for k, v in provider_config.items() if v is not None}
|
||||
env_to_clear = [k for k, v in provider_config.items() if v is None]
|
||||
|
||||
with patch.dict(os.environ, env_to_set, clear=False):
|
||||
# Clear the None-valued environment variables
|
||||
for key in env_to_clear:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Register Gemini provider
|
||||
from providers.gemini import GeminiModelProvider
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
|
||||
# Mock the actual provider to simulate successful execution
|
||||
mock_provider = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "test response"
|
||||
mock_response.model_name = "gemini-2.5-flash-preview-05-20" # The resolved name
|
||||
mock_response.usage = {"input_tokens": 10, "output_tokens": 5}
|
||||
# Mock _resolve_model_name to simulate alias resolution
|
||||
mock_provider._resolve_model_name = lambda alias: (
|
||||
"gemini-2.5-flash-preview-05-20" if alias == "flash" else alias
|
||||
)
|
||||
mock_provider.generate_content.return_value = mock_response
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider_for_model", return_value=mock_provider):
|
||||
chat_tool = ChatTool()
|
||||
result = await chat_tool.execute({"prompt": "test", "model": "flash"}) # Use alias in auto mode
|
||||
|
||||
# Should succeed with proper model resolution
|
||||
assert len(result) == 1
|
||||
# Just verify that the tool executed successfully and didn't return an error
|
||||
assert "error" not in result[0].text.lower()
|
||||
Reference in New Issue
Block a user