Use the new flash model
Updated tests
@@ -26,7 +26,7 @@ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
 
 # Validate DEFAULT_MODEL and set to "auto" if invalid
 # Only include actually supported models from providers
-VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash-exp", "gemini-2.5-pro-preview-06-05"]
+VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash", "gemini-2.5-pro-preview-06-05"]
 if DEFAULT_MODEL not in VALID_MODELS:
     import logging
@@ -47,7 +47,7 @@ MODEL_CAPABILITIES_DESC = {
     "o3": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
     "o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
     # Full model names also supported
-    "gemini-2.0-flash-exp": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
+    "gemini-2.0-flash": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
     "gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
 }
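For orientation, here is a minimal sketch of the validation these config hunks touch, assuming only what the hunks show (the `VALID_MODELS` list, the `if DEFAULT_MODEL not in VALID_MODELS:` guard, and the `import logging` that follows it). The actual log message and reset logic in the repository's config module may differ:

```python
import os

# Illustrative reconstruction only, not the repository's exact code.
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash", "gemini-2.5-pro-preview-06-05"]

if DEFAULT_MODEL not in VALID_MODELS:
    import logging

    # The comment in the hunk says invalid values fall back to "auto".
    logging.warning("Unsupported DEFAULT_MODEL %r, falling back to 'auto'", DEFAULT_MODEL)
    DEFAULT_MODEL = "auto"
```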
@@ -13,7 +13,7 @@ class GeminiModelProvider(ModelProvider):
 
     # Model configurations
     SUPPORTED_MODELS = {
-        "gemini-2.0-flash-exp": {
+        "gemini-2.0-flash": {
             "max_tokens": 1_048_576,  # 1M tokens
             "supports_extended_thinking": False,
         },
@@ -22,7 +22,7 @@ class GeminiModelProvider(ModelProvider):
             "supports_extended_thinking": True,
         },
         # Shorthands
-        "flash": "gemini-2.0-flash-exp",
+        "flash": "gemini-2.0-flash",
         "pro": "gemini-2.5-pro-preview-06-05",
     }
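The `SUPPORTED_MODELS` table above mixes full model configs (dict values) with shorthand aliases (string values such as `"flash"`). Below is a minimal sketch of how such aliases can be resolved to canonical names; the helper name `_resolve_model_name` is hypothetical and not taken from this diff:

```python
# Hypothetical standalone sketch; the real GeminiModelProvider may resolve shorthands differently.
SUPPORTED_MODELS = {
    "gemini-2.0-flash": {"max_tokens": 1_048_576, "supports_extended_thinking": False},
    "gemini-2.5-pro-preview-06-05": {"max_tokens": 1_048_576, "supports_extended_thinking": True},
    # Shorthands point at canonical model names
    "flash": "gemini-2.0-flash",
    "pro": "gemini-2.5-pro-preview-06-05",
}


def _resolve_model_name(name: str) -> str:
    entry = SUPPORTED_MODELS.get(name)
    # A string value means `name` is an alias for the canonical model it names
    return entry if isinstance(entry, str) else name


assert _resolve_model_name("flash") == "gemini-2.0-flash"
assert _resolve_model_name("gemini-2.0-flash") == "gemini-2.0-flash"
```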
@@ -67,7 +67,7 @@ class ModelProviderRegistry:
         """Get provider instance for a specific model name.
 
         Args:
-            model_name: Name of the model (e.g., "gemini-2.0-flash-exp", "o3-mini")
+            model_name: Name of the model (e.g., "gemini-2.0-flash", "o3-mini")
 
         Returns:
             ModelProvider instance that supports this model
@@ -125,6 +125,50 @@ class ModelProviderRegistry:
 
         return os.getenv(env_var)
 
+    @classmethod
+    def get_preferred_fallback_model(cls) -> str:
+        """Get the preferred fallback model based on available API keys.
+
+        This method checks which providers have valid API keys and returns
+        a sensible default model for auto mode fallback situations.
+
+        Priority order:
+        1. OpenAI o3-mini (balanced performance/cost) if OpenAI API key available
+        2. Gemini 2.0 Flash (fast and efficient) if Gemini API key available
+        3. OpenAI o3 (high performance) if OpenAI API key available
+        4. Gemini 2.5 Pro (deep reasoning) if Gemini API key available
+        5. Fallback to gemini-2.0-flash (most common case)
+
+        Returns:
+            Model name string for fallback use
+        """
+        # Check provider availability by trying to get instances
+        openai_available = cls.get_provider(ProviderType.OPENAI) is not None
+        gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
+
+        # Priority order: prefer balanced models first, then high-performance
+        if openai_available:
+            return "o3-mini"  # Balanced performance/cost
+        elif gemini_available:
+            return "gemini-2.0-flash"  # Fast and efficient
+        else:
+            # No API keys available - return a reasonable default
+            # This maintains backward compatibility for tests
+            return "gemini-2.0-flash"
+
+    @classmethod
+    def get_available_providers_with_keys(cls) -> list[ProviderType]:
+        """Get list of provider types that have valid API keys.
+
+        Returns:
+            List of ProviderType values for providers with valid API keys
+        """
+        available = []
+        for provider_type in cls._providers:
+            if cls.get_provider(provider_type) is not None:
+                available.append(provider_type)
+        return available
+
     @classmethod
     def clear_cache(cls) -> None:
         """Clear cached provider instances."""
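For context, here is a small usage sketch of the two registry methods added above, using only names that appear elsewhere in this diff (`ModelProviderRegistry`, `ProviderType`, `clear_cache`). The expected values mirror the assertions in the new test file further down and depend on which API keys are actually configured:

```python
import os

from providers.base import ProviderType
from providers.registry import ModelProviderRegistry

# Simulate an environment where only a Gemini key is set.
os.environ["GEMINI_API_KEY"] = "test-key"
os.environ["OPENAI_API_KEY"] = ""
ModelProviderRegistry.clear_cache()  # drop any cached provider instances

providers = ModelProviderRegistry.get_available_providers_with_keys()
print(providers)  # expected per the tests: [ProviderType.GOOGLE]

fallback = ModelProviderRegistry.get_preferred_fallback_model()
print(fallback)  # expected per the tests: "gemini-2.0-flash"
```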
@@ -55,7 +55,7 @@ class TestModelThinkingConfig(BaseSimulatorTest):
             "chat",
             {
                 "prompt": "What is 3 + 3? Give a quick answer.",
-                "model": "flash",  # Should resolve to gemini-2.0-flash-exp
+                "model": "flash",  # Should resolve to gemini-2.0-flash
                 "thinking_mode": "high",  # Should be ignored for Flash model
             },
         )
@@ -80,7 +80,7 @@ class TestModelThinkingConfig(BaseSimulatorTest):
             ("pro", "should work with Pro model"),
             ("flash", "should work with Flash model"),
             ("gemini-2.5-pro-preview-06-05", "should work with full Pro model name"),
-            ("gemini-2.0-flash-exp", "should work with full Flash model name"),
+            ("gemini-2.0-flash", "should work with full Flash model name"),
         ]
 
         success_count = 0
@@ -23,7 +23,7 @@ if "OPENAI_API_KEY" not in os.environ:
 
 # Set default model to a specific value for tests to avoid auto mode
 # This prevents all tests from failing due to missing model parameter
-os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash-exp"
+os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash"
 
 # Force reload of config module to pick up the env var
 import importlib
@@ -5,7 +5,7 @@ from unittest.mock import Mock
 from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
 
 
-def create_mock_provider(model_name="gemini-2.0-flash-exp", max_tokens=1_048_576):
+def create_mock_provider(model_name="gemini-2.0-flash", max_tokens=1_048_576):
     """Create a properly configured mock provider."""
     mock_provider = Mock()
@@ -72,7 +72,7 @@ class TestClaudeContinuationOffers:
         mock_provider.generate_content.return_value = Mock(
             content="Analysis complete.",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -129,7 +129,7 @@ class TestClaudeContinuationOffers:
         mock_provider.generate_content.return_value = Mock(
             content="Continued analysis.",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -162,7 +162,7 @@ class TestClaudeContinuationOffers:
         mock_provider.generate_content.return_value = Mock(
             content="Analysis complete. The code looks good.",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
        )
         mock_get_provider.return_value = mock_provider
@@ -208,7 +208,7 @@ I'd be happy to examine the error handling patterns in more detail if that would
         mock_provider.generate_content.return_value = Mock(
             content=content_with_followup,
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -253,7 +253,7 @@ I'd be happy to examine the error handling patterns in more detail if that would
         mock_provider.generate_content.return_value = Mock(
             content="Continued analysis complete.",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -309,7 +309,7 @@ I'd be happy to examine the error handling patterns in more detail if that would
         mock_provider.generate_content.return_value = Mock(
             content="Final response.",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -358,7 +358,7 @@ class TestContinuationIntegration:
         mock_provider.generate_content.return_value = Mock(
             content="Analysis result",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -411,7 +411,7 @@ class TestContinuationIntegration:
         mock_provider.generate_content.return_value = Mock(
             content="Structure analysis done.",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -448,7 +448,7 @@ class TestContinuationIntegration:
         mock_provider.generate_content.return_value = Mock(
             content="Performance analysis done.",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
@@ -41,7 +41,7 @@ class TestDynamicContextRequests:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -82,7 +82,7 @@ class TestDynamicContextRequests:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content=normal_response, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content=normal_response, usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -106,7 +106,7 @@ class TestDynamicContextRequests:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content=malformed_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content=malformed_json, usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -146,7 +146,7 @@ class TestDynamicContextRequests:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -233,7 +233,7 @@ class TestCollaborationWorkflow:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -272,7 +272,7 @@ class TestCollaborationWorkflow:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -299,7 +299,7 @@ class TestCollaborationWorkflow:
         """
 
         mock_provider.generate_content.return_value = Mock(
-            content=final_response, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content=final_response, usage={}, model_name="gemini-2.0-flash", metadata={}
         )
 
         result2 = await tool.execute(
@@ -32,7 +32,7 @@ class TestConfig:
     def test_model_config(self):
         """Test model configuration"""
         # DEFAULT_MODEL is set in conftest.py for tests
-        assert DEFAULT_MODEL == "gemini-2.0-flash-exp"
+        assert DEFAULT_MODEL == "gemini-2.0-flash"
         assert MAX_CONTEXT_TOKENS == 1_000_000
 
     def test_temperature_defaults(self):
@@ -74,7 +74,7 @@ async def test_conversation_history_field_mapping():
     mock_provider = MagicMock()
     mock_provider.get_capabilities.return_value = ModelCapabilities(
         provider=ProviderType.GOOGLE,
-        model_name="gemini-2.0-flash-exp",
+        model_name="gemini-2.0-flash",
         friendly_name="Gemini",
         max_tokens=200000,
         supports_extended_thinking=True,
@@ -115,7 +115,7 @@ class TestConversationHistoryBugFix:
             return Mock(
                 content="Response with conversation context",
                 usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-                model_name="gemini-2.0-flash-exp",
+                model_name="gemini-2.0-flash",
                 metadata={"finish_reason": "STOP"},
             )
@@ -175,7 +175,7 @@ class TestConversationHistoryBugFix:
             return Mock(
                 content="Response without history",
                 usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-                model_name="gemini-2.0-flash-exp",
+                model_name="gemini-2.0-flash",
                 metadata={"finish_reason": "STOP"},
             )
@@ -213,7 +213,7 @@ class TestConversationHistoryBugFix:
             return Mock(
                 content="New conversation response",
                 usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-                model_name="gemini-2.0-flash-exp",
+                model_name="gemini-2.0-flash",
                 metadata={"finish_reason": "STOP"},
             )
@@ -297,7 +297,7 @@ class TestConversationHistoryBugFix:
             return Mock(
                 content="Analysis of new files complete",
                 usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-                model_name="gemini-2.0-flash-exp",
+                model_name="gemini-2.0-flash",
                 metadata={"finish_reason": "STOP"},
             )
@@ -111,7 +111,7 @@ I'd be happy to review these security findings in detail if that would be helpfu
         mock_provider.generate_content.return_value = Mock(
             content=content,
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -158,7 +158,7 @@ I'd be happy to review these security findings in detail if that would be helpfu
         mock_provider.generate_content.return_value = Mock(
             content="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks.",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -279,7 +279,7 @@ I'd be happy to review these security findings in detail if that would be helpfu
         mock_provider.generate_content.return_value = Mock(
             content="Security review of auth.py shows vulnerabilities",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
tests/test_intelligent_fallback.py (new file, 181 lines)
@@ -0,0 +1,181 @@
+"""
+Test suite for intelligent auto mode fallback logic
+
+Tests the new dynamic model selection based on available API keys
+"""
+
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from providers.base import ProviderType
+from providers.registry import ModelProviderRegistry
+
+
+class TestIntelligentFallback:
+    """Test intelligent model fallback logic"""
+
+    def setup_method(self):
+        """Setup for each test - clear registry cache"""
+        ModelProviderRegistry.clear_cache()
+
+    def teardown_method(self):
+        """Cleanup after each test"""
+        ModelProviderRegistry.clear_cache()
+
+    @patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
+    def test_prefers_openai_o3_mini_when_available(self):
+        """Test that o3-mini is preferred when OpenAI API key is available"""
+        ModelProviderRegistry.clear_cache()
+        fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
+        assert fallback_model == "o3-mini"
+
+    @patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
+    def test_prefers_gemini_flash_when_openai_unavailable(self):
+        """Test that gemini-2.0-flash is used when only Gemini API key is available"""
+        ModelProviderRegistry.clear_cache()
+        fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
+        assert fallback_model == "gemini-2.0-flash"
+
+    @patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
+    def test_prefers_openai_when_both_available(self):
+        """Test that OpenAI is preferred when both API keys are available"""
+        ModelProviderRegistry.clear_cache()
+        fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
+        assert fallback_model == "o3-mini"  # OpenAI has priority
+
+    @patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False)
+    def test_fallback_when_no_keys_available(self):
+        """Test fallback behavior when no API keys are available"""
+        ModelProviderRegistry.clear_cache()
+        fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
+        assert fallback_model == "gemini-2.0-flash"  # Default fallback
+
+    def test_available_providers_with_keys(self):
+        """Test the get_available_providers_with_keys method"""
+        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
+            ModelProviderRegistry.clear_cache()
+            available = ModelProviderRegistry.get_available_providers_with_keys()
+            assert ProviderType.OPENAI in available
+            assert ProviderType.GOOGLE not in available
+
+        with patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False):
+            ModelProviderRegistry.clear_cache()
+            available = ModelProviderRegistry.get_available_providers_with_keys()
+            assert ProviderType.GOOGLE in available
+            assert ProviderType.OPENAI not in available
+
+    def test_auto_mode_conversation_memory_integration(self):
+        """Test that conversation memory uses intelligent fallback in auto mode"""
+        from utils.conversation_memory import ThreadContext, build_conversation_history
+
+        # Mock auto mode - patch the config module where these values are defined
+        with (
+            patch("config.IS_AUTO_MODE", True),
+            patch("config.DEFAULT_MODEL", "auto"),
+            patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
+        ):
+
+            ModelProviderRegistry.clear_cache()
+
+            # Create a context with at least one turn so it doesn't exit early
+            from utils.conversation_memory import ConversationTurn
+
+            context = ThreadContext(
+                thread_id="test-123",
+                created_at="2023-01-01T00:00:00Z",
+                last_updated_at="2023-01-01T00:00:00Z",
+                tool_name="chat",
+                turns=[ConversationTurn(role="user", content="Test message", timestamp="2023-01-01T00:00:30Z")],
+                initial_context={},
+            )
+
+            # This should use o3-mini for token calculations since OpenAI is available
+            with patch("utils.model_context.ModelContext") as mock_context_class:
+                mock_context_instance = Mock()
+                mock_context_class.return_value = mock_context_instance
+                mock_context_instance.calculate_token_allocation.return_value = Mock(
+                    file_tokens=10000, history_tokens=5000
+                )
+                # Mock estimate_tokens to return integers for proper summing
+                mock_context_instance.estimate_tokens.return_value = 100
+
+                history, tokens = build_conversation_history(context, model_context=None)
+
+                # Verify that ModelContext was called with o3-mini (the intelligent fallback)
+                mock_context_class.assert_called_once_with("o3-mini")
+
+    def test_auto_mode_with_gemini_only(self):
+        """Test auto mode behavior when only Gemini API key is available"""
+        from utils.conversation_memory import ThreadContext, build_conversation_history
+
+        with (
+            patch("config.IS_AUTO_MODE", True),
+            patch("config.DEFAULT_MODEL", "auto"),
+            patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False),
+        ):
+
+            ModelProviderRegistry.clear_cache()
+
+            from utils.conversation_memory import ConversationTurn
+
+            context = ThreadContext(
+                thread_id="test-456",
+                created_at="2023-01-01T00:00:00Z",
+                last_updated_at="2023-01-01T00:00:00Z",
+                tool_name="analyze",
+                turns=[ConversationTurn(role="assistant", content="Test response", timestamp="2023-01-01T00:00:30Z")],
+                initial_context={},
+            )
+
+            with patch("utils.model_context.ModelContext") as mock_context_class:
+                mock_context_instance = Mock()
+                mock_context_class.return_value = mock_context_instance
+                mock_context_instance.calculate_token_allocation.return_value = Mock(
+                    file_tokens=10000, history_tokens=5000
+                )
+                # Mock estimate_tokens to return integers for proper summing
+                mock_context_instance.estimate_tokens.return_value = 100
+
+                history, tokens = build_conversation_history(context, model_context=None)
+
+                # Should use gemini-2.0-flash when only Gemini is available
+                mock_context_class.assert_called_once_with("gemini-2.0-flash")
+
+    def test_non_auto_mode_unchanged(self):
+        """Test that non-auto mode behavior is unchanged"""
+        from utils.conversation_memory import ThreadContext, build_conversation_history
+
+        with patch("config.IS_AUTO_MODE", False), patch("config.DEFAULT_MODEL", "gemini-2.5-pro-preview-06-05"):
+
+            from utils.conversation_memory import ConversationTurn
+
+            context = ThreadContext(
+                thread_id="test-789",
+                created_at="2023-01-01T00:00:00Z",
+                last_updated_at="2023-01-01T00:00:00Z",
+                tool_name="thinkdeep",
+                turns=[
+                    ConversationTurn(role="user", content="Test in non-auto mode", timestamp="2023-01-01T00:00:30Z")
+                ],
+                initial_context={},
+            )
+
+            with patch("utils.model_context.ModelContext") as mock_context_class:
+                mock_context_instance = Mock()
+                mock_context_class.return_value = mock_context_instance
+                mock_context_instance.calculate_token_allocation.return_value = Mock(
+                    file_tokens=10000, history_tokens=5000
+                )
+                # Mock estimate_tokens to return integers for proper summing
+                mock_context_instance.estimate_tokens.return_value = 100
+
+                history, tokens = build_conversation_history(context, model_context=None)
+
+                # Should use the configured DEFAULT_MODEL, not the intelligent fallback
+                mock_context_class.assert_called_once_with("gemini-2.5-pro-preview-06-05")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
@@ -75,7 +75,7 @@ class TestLargePromptHandling:
         mock_provider.generate_content.return_value = MagicMock(
             content="This is a test response",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -100,7 +100,7 @@ class TestLargePromptHandling:
         mock_provider.generate_content.return_value = MagicMock(
             content="Processed large prompt",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -212,7 +212,7 @@ class TestLargePromptHandling:
         mock_provider.generate_content.return_value = MagicMock(
             content="Success",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
        )
         mock_get_provider.return_value = mock_provider
@@ -245,7 +245,7 @@ class TestLargePromptHandling:
         mock_provider.generate_content.return_value = MagicMock(
             content="Success",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -276,7 +276,7 @@ class TestLargePromptHandling:
         mock_provider.generate_content.return_value = MagicMock(
             content="Success",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -298,7 +298,7 @@ class TestLargePromptHandling:
         mock_provider.generate_content.return_value = MagicMock(
             content="Success",
             usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-            model_name="gemini-2.0-flash-exp",
+            model_name="gemini-2.0-flash",
             metadata={"finish_reason": "STOP"},
         )
         mock_get_provider.return_value = mock_provider
@@ -31,7 +31,7 @@ class TestPromptRegression:
             return Mock(
                 content=text,
                 usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
-                model_name="gemini-2.0-flash-exp",
+                model_name="gemini-2.0-flash",
                 metadata={"finish_reason": "STOP"},
             )
 
@@ -49,7 +49,7 @@ class TestModelProviderRegistry:
         """Test getting provider for a specific model"""
         ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
 
-        provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp")
+        provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash")
 
         assert provider is not None
         assert isinstance(provider, GeminiModelProvider)
@@ -80,10 +80,10 @@ class TestGeminiProvider:
         """Test getting model capabilities"""
         provider = GeminiModelProvider(api_key="test-key")
 
-        capabilities = provider.get_capabilities("gemini-2.0-flash-exp")
+        capabilities = provider.get_capabilities("gemini-2.0-flash")
 
         assert capabilities.provider == ProviderType.GOOGLE
-        assert capabilities.model_name == "gemini-2.0-flash-exp"
+        assert capabilities.model_name == "gemini-2.0-flash"
         assert capabilities.max_tokens == 1_048_576
         assert not capabilities.supports_extended_thinking
@@ -103,13 +103,13 @@ class TestGeminiProvider:
         assert provider.validate_model_name("pro")
 
         capabilities = provider.get_capabilities("flash")
-        assert capabilities.model_name == "gemini-2.0-flash-exp"
+        assert capabilities.model_name == "gemini-2.0-flash"
 
     def test_supports_thinking_mode(self):
         """Test thinking mode support detection"""
         provider = GeminiModelProvider(api_key="test-key")
 
-        assert not provider.supports_thinking_mode("gemini-2.0-flash-exp")
+        assert not provider.supports_thinking_mode("gemini-2.0-flash")
         assert provider.supports_thinking_mode("gemini-2.5-pro-preview-06-05")
 
     @patch("google.genai.Client")
@@ -133,11 +133,11 @@ class TestGeminiProvider:
 
         provider = GeminiModelProvider(api_key="test-key")
 
-        response = provider.generate_content(prompt="Test prompt", model_name="gemini-2.0-flash-exp", temperature=0.7)
+        response = provider.generate_content(prompt="Test prompt", model_name="gemini-2.0-flash", temperature=0.7)
 
         assert isinstance(response, ModelResponse)
         assert response.content == "Generated content"
-        assert response.model_name == "gemini-2.0-flash-exp"
+        assert response.model_name == "gemini-2.0-flash"
         assert response.provider == ProviderType.GOOGLE
         assert response.usage["input_tokens"] == 10
         assert response.usage["output_tokens"] == 20
@@ -56,7 +56,7 @@ class TestServerTools:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content="Chat response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Chat response", usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -45,7 +45,7 @@ class TestThinkingModes:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = True
         mock_provider.generate_content.return_value = Mock(
-            content="Minimal thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Minimal thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -82,7 +82,7 @@ class TestThinkingModes:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = True
         mock_provider.generate_content.return_value = Mock(
-            content="Low thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Low thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -114,7 +114,7 @@ class TestThinkingModes:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = True
         mock_provider.generate_content.return_value = Mock(
-            content="Medium thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Medium thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
        )
         mock_get_provider.return_value = mock_provider
@@ -145,7 +145,7 @@ class TestThinkingModes:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = True
         mock_provider.generate_content.return_value = Mock(
-            content="High thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="High thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -175,7 +175,7 @@ class TestThinkingModes:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = True
         mock_provider.generate_content.return_value = Mock(
-            content="Max thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Max thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -37,7 +37,7 @@ class TestThinkDeepTool:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = True
         mock_provider.generate_content.return_value = Mock(
-            content="Extended analysis", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Extended analysis", usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -88,7 +88,7 @@ class TestCodeReviewTool:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content="Security issues found", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Security issues found", usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -133,7 +133,7 @@ class TestDebugIssueTool:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content="Root cause: race condition", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Root cause: race condition", usage={}, model_name="gemini-2.0-flash", metadata={}
        )
         mock_get_provider.return_value = mock_provider
@@ -181,7 +181,7 @@ class TestAnalyzeTool:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content="Architecture analysis", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Architecture analysis", usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -295,7 +295,7 @@ class TestAbsolutePathValidation:
         mock_provider.get_provider_type.return_value = Mock(value="google")
         mock_provider.supports_thinking_mode.return_value = False
         mock_provider.generate_content.return_value = Mock(
-            content="Analysis complete", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
+            content="Analysis complete", usage={}, model_name="gemini-2.0-flash", metadata={}
         )
         mock_get_provider.return_value = mock_provider
@@ -74,7 +74,7 @@ class ConversationTurn(BaseModel):
         files: List of file paths referenced in this specific turn
         tool_name: Which tool generated this turn (for cross-tool tracking)
         model_provider: Provider used (e.g., "google", "openai")
-        model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini")
+        model_name: Specific model used (e.g., "gemini-2.0-flash", "o3-mini")
         model_metadata: Additional model-specific metadata (e.g., thinking mode, token usage)
     """
 
@@ -249,7 +249,7 @@ def add_turn(
         files: Optional list of files referenced in this turn
         tool_name: Name of the tool adding this turn (for attribution)
        model_provider: Provider used (e.g., "google", "openai")
-        model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini")
+        model_name: Specific model used (e.g., "gemini-2.0-flash", "o3-mini")
        model_metadata: Additional model info (e.g., thinking mode, token usage)
 
    Returns:
@@ -454,10 +454,19 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
 
     # Get model-specific token allocation early (needed for both files and turns)
     if model_context is None:
-        from config import DEFAULT_MODEL
+        from config import DEFAULT_MODEL, IS_AUTO_MODE
         from utils.model_context import ModelContext
 
-        model_context = ModelContext(DEFAULT_MODEL)
+        # In auto mode, use an intelligent fallback model for token calculations
+        # since "auto" is not a real model with a provider
+        model_name = DEFAULT_MODEL
+        if IS_AUTO_MODE and model_name.lower() == "auto":
+            # Use intelligent fallback based on available API keys
+            from providers.registry import ModelProviderRegistry
+
+            model_name = ModelProviderRegistry.get_preferred_fallback_model()
+
+        model_context = ModelContext(model_name)
 
     token_allocation = model_context.calculate_token_allocation()
     max_file_tokens = token_allocation.file_tokens