From 79af2654b944adf45f9b124599b366b26449547c Mon Sep 17 00:00:00 2001
From: Fahad
Date: Thu, 12 Jun 2025 13:44:09 +0400
Subject: [PATCH] Use the new flash model

Updated tests
---
 config.py                                     |   4 +-
 providers/gemini.py                           |   4 +-
 providers/registry.py                         |  46 ++++-
 simulator_tests/test_model_thinking_config.py |   4 +-
 tests/conftest.py                             |   2 +-
 tests/mock_helpers.py                         |   2 +-
 tests/test_claude_continuation.py             |  18 +-
 tests/test_collaboration.py                   |  14 +-
 tests/test_config.py                          |   2 +-
 tests/test_conversation_field_mapping.py      |   2 +-
 tests/test_conversation_history_bug.py        |   8 +-
 tests/test_cross_tool_continuation.py         |   6 +-
 tests/test_intelligent_fallback.py            | 181 ++++++++++++++++++
 tests/test_large_prompt_handling.py           |  12 +-
 tests/test_prompt_regression.py               |   2 +-
 tests/test_providers.py                       |  14 +-
 tests/test_server.py                          |   2 +-
 tests/test_thinking_modes.py                  |  10 +-
 tests/test_tools.py                           |  10 +-
 utils/conversation_memory.py                  |  17 +-
 20 files changed, 297 insertions(+), 63 deletions(-)
 create mode 100644 tests/test_intelligent_fallback.py

diff --git a/config.py b/config.py
index aa7ebc8..9e213f9 100644
--- a/config.py
+++ b/config.py
@@ -26,7 +26,7 @@ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
 
 # Validate DEFAULT_MODEL and set to "auto" if invalid
 # Only include actually supported models from providers
-VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash-exp", "gemini-2.5-pro-preview-06-05"]
+VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash", "gemini-2.5-pro-preview-06-05"]
 
 if DEFAULT_MODEL not in VALID_MODELS:
     import logging
@@ -47,7 +47,7 @@ MODEL_CAPABILITIES_DESC = {
     "o3": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
     "o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
     # Full model names also supported
-    "gemini-2.0-flash-exp": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
+    "gemini-2.0-flash": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
     "gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
 }
diff --git a/providers/gemini.py b/providers/gemini.py
index 9b0c438..a80b4e4 100644
--- a/providers/gemini.py
+++ b/providers/gemini.py
@@ -13,7 +13,7 @@ class GeminiModelProvider(ModelProvider):
 
     # Model configurations
     SUPPORTED_MODELS = {
-        "gemini-2.0-flash-exp": {
+        "gemini-2.0-flash": {
             "max_tokens": 1_048_576,  # 1M tokens
             "supports_extended_thinking": False,
         },
@@ -22,7 +22,7 @@ class GeminiModelProvider(ModelProvider):
             "supports_extended_thinking": True,
         },
         # Shorthands
-        "flash": "gemini-2.0-flash-exp",
+        "flash": "gemini-2.0-flash",
         "pro": "gemini-2.5-pro-preview-06-05",
     }
diff --git a/providers/registry.py b/providers/registry.py
index 5dab34c..057821c 100644
--- a/providers/registry.py
+++ b/providers/registry.py
@@ -67,7 +67,7 @@ class ModelProviderRegistry:
         """Get provider instance for a specific model name.
 
         Args:
-            model_name: Name of the model (e.g., "gemini-2.0-flash-exp", "o3-mini")
+            model_name: Name of the model (e.g., "gemini-2.0-flash", "o3-mini")
 
         Returns:
             ModelProvider instance that supports this model
@@ -125,6 +125,50 @@
 
         return os.getenv(env_var)
 
+    @classmethod
+    def get_preferred_fallback_model(cls) -> str:
+        """Get the preferred fallback model based on available API keys.
+
+        This method checks which providers have valid API keys and returns
+        a sensible default model for auto mode fallback situations.
+
+        Priority order:
+        1. OpenAI o3-mini (balanced performance/cost) when an OpenAI API key is available
+        2. Gemini 2.0 Flash (fast and efficient) when only a Gemini API key is available
+        3. gemini-2.0-flash as the final default when no API keys are available
+           (maintains backward compatibility for tests that run without keys)
+        Higher-cost models (o3, Gemini 2.5 Pro) are not selected automatically.
+
+        Returns:
+            Model name string for fallback use
+        """
+        # Check provider availability by trying to get instances
+        openai_available = cls.get_provider(ProviderType.OPENAI) is not None
+        gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
+
+        # Priority order: prefer the balanced OpenAI model, then Gemini Flash
+        if openai_available:
+            return "o3-mini"  # Balanced performance/cost
+        elif gemini_available:
+            return "gemini-2.0-flash"  # Fast and efficient
+        else:
+            # No API keys available - return a reasonable default
+            # This maintains backward compatibility for tests
+            return "gemini-2.0-flash"
+
+    @classmethod
+    def get_available_providers_with_keys(cls) -> list[ProviderType]:
+        """Get list of provider types that have valid API keys.
+
+        Returns:
+            List of ProviderType values for providers with valid API keys
+        """
+        available = []
+        for provider_type in cls._providers:
+            if cls.get_provider(provider_type) is not None:
+                available.append(provider_type)
+        return available
+
     @classmethod
     def clear_cache(cls) -> None:
         """Clear cached provider instances."""
diff --git a/simulator_tests/test_model_thinking_config.py b/simulator_tests/test_model_thinking_config.py
index 1a54bfe..b1b096f 100644
--- a/simulator_tests/test_model_thinking_config.py
+++ b/simulator_tests/test_model_thinking_config.py
@@ -55,7 +55,7 @@ class TestModelThinkingConfig(BaseSimulatorTest):
             "chat",
             {
                 "prompt": "What is 3 + 3? 
Give a quick answer.", - "model": "flash", # Should resolve to gemini-2.0-flash-exp + "model": "flash", # Should resolve to gemini-2.0-flash "thinking_mode": "high", # Should be ignored for Flash model }, ) @@ -80,7 +80,7 @@ class TestModelThinkingConfig(BaseSimulatorTest): ("pro", "should work with Pro model"), ("flash", "should work with Flash model"), ("gemini-2.5-pro-preview-06-05", "should work with full Pro model name"), - ("gemini-2.0-flash-exp", "should work with full Flash model name"), + ("gemini-2.0-flash", "should work with full Flash model name"), ] success_count = 0 diff --git a/tests/conftest.py b/tests/conftest.py index 1f51d48..7948ce5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ if "OPENAI_API_KEY" not in os.environ: # Set default model to a specific value for tests to avoid auto mode # This prevents all tests from failing due to missing model parameter -os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash-exp" +os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash" # Force reload of config module to pick up the env var import importlib diff --git a/tests/mock_helpers.py b/tests/mock_helpers.py index c86ada1..0aa4c5c 100644 --- a/tests/mock_helpers.py +++ b/tests/mock_helpers.py @@ -5,7 +5,7 @@ from unittest.mock import Mock from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint -def create_mock_provider(model_name="gemini-2.0-flash-exp", max_tokens=1_048_576): +def create_mock_provider(model_name="gemini-2.0-flash", max_tokens=1_048_576): """Create a properly configured mock provider.""" mock_provider = Mock() diff --git a/tests/test_claude_continuation.py b/tests/test_claude_continuation.py index 96f48f4..bed5408 100644 --- a/tests/test_claude_continuation.py +++ b/tests/test_claude_continuation.py @@ -72,7 +72,7 @@ class TestClaudeContinuationOffers: mock_provider.generate_content.return_value = Mock( content="Analysis complete.", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -129,7 +129,7 @@ class TestClaudeContinuationOffers: mock_provider.generate_content.return_value = Mock( content="Continued analysis.", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -162,7 +162,7 @@ class TestClaudeContinuationOffers: mock_provider.generate_content.return_value = Mock( content="Analysis complete. 
The code looks good.", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -208,7 +208,7 @@ I'd be happy to examine the error handling patterns in more detail if that would mock_provider.generate_content.return_value = Mock( content=content_with_followup, usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -253,7 +253,7 @@ I'd be happy to examine the error handling patterns in more detail if that would mock_provider.generate_content.return_value = Mock( content="Continued analysis complete.", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -309,7 +309,7 @@ I'd be happy to examine the error handling patterns in more detail if that would mock_provider.generate_content.return_value = Mock( content="Final response.", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -358,7 +358,7 @@ class TestContinuationIntegration: mock_provider.generate_content.return_value = Mock( content="Analysis result", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -411,7 +411,7 @@ class TestContinuationIntegration: mock_provider.generate_content.return_value = Mock( content="Structure analysis done.", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -448,7 +448,7 @@ class TestContinuationIntegration: mock_provider.generate_content.return_value = Mock( content="Performance analysis done.", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) diff --git a/tests/test_collaboration.py b/tests/test_collaboration.py index 0a4901c..966cc39 100644 --- a/tests/test_collaboration.py +++ b/tests/test_collaboration.py @@ -41,7 +41,7 @@ class TestDynamicContextRequests: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -82,7 +82,7 @@ class TestDynamicContextRequests: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content=normal_response, usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content=normal_response, usage={}, model_name="gemini-2.0-flash", metadata={} ) 
mock_get_provider.return_value = mock_provider @@ -106,7 +106,7 @@ class TestDynamicContextRequests: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content=malformed_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content=malformed_json, usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -146,7 +146,7 @@ class TestDynamicContextRequests: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -233,7 +233,7 @@ class TestCollaborationWorkflow: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -272,7 +272,7 @@ class TestCollaborationWorkflow: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -299,7 +299,7 @@ class TestCollaborationWorkflow: """ mock_provider.generate_content.return_value = Mock( - content=final_response, usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content=final_response, usage={}, model_name="gemini-2.0-flash", metadata={} ) result2 = await tool.execute( diff --git a/tests/test_config.py b/tests/test_config.py index e5aea20..0ac6368 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -32,7 +32,7 @@ class TestConfig: def test_model_config(self): """Test model configuration""" # DEFAULT_MODEL is set in conftest.py for tests - assert DEFAULT_MODEL == "gemini-2.0-flash-exp" + assert DEFAULT_MODEL == "gemini-2.0-flash" assert MAX_CONTEXT_TOKENS == 1_000_000 def test_temperature_defaults(self): diff --git a/tests/test_conversation_field_mapping.py b/tests/test_conversation_field_mapping.py index 1daef4f..42206a1 100644 --- a/tests/test_conversation_field_mapping.py +++ b/tests/test_conversation_field_mapping.py @@ -74,7 +74,7 @@ async def test_conversation_history_field_mapping(): mock_provider = MagicMock() mock_provider.get_capabilities.return_value = ModelCapabilities( provider=ProviderType.GOOGLE, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", friendly_name="Gemini", max_tokens=200000, supports_extended_thinking=True, diff --git a/tests/test_conversation_history_bug.py b/tests/test_conversation_history_bug.py index d2f1f18..ff76db8 100644 --- a/tests/test_conversation_history_bug.py +++ b/tests/test_conversation_history_bug.py @@ -115,7 +115,7 @@ class TestConversationHistoryBugFix: return Mock( content="Response with conversation context", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 
30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) @@ -175,7 +175,7 @@ class TestConversationHistoryBugFix: return Mock( content="Response without history", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) @@ -213,7 +213,7 @@ class TestConversationHistoryBugFix: return Mock( content="New conversation response", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) @@ -297,7 +297,7 @@ class TestConversationHistoryBugFix: return Mock( content="Analysis of new files complete", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) diff --git a/tests/test_cross_tool_continuation.py b/tests/test_cross_tool_continuation.py index 6ece479..7a124b0 100644 --- a/tests/test_cross_tool_continuation.py +++ b/tests/test_cross_tool_continuation.py @@ -111,7 +111,7 @@ I'd be happy to review these security findings in detail if that would be helpfu mock_provider.generate_content.return_value = Mock( content=content, usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -158,7 +158,7 @@ I'd be happy to review these security findings in detail if that would be helpfu mock_provider.generate_content.return_value = Mock( content="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks.", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -279,7 +279,7 @@ I'd be happy to review these security findings in detail if that would be helpfu mock_provider.generate_content.return_value = Mock( content="Security review of auth.py shows vulnerabilities", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider diff --git a/tests/test_intelligent_fallback.py b/tests/test_intelligent_fallback.py new file mode 100644 index 0000000..112f5bb --- /dev/null +++ b/tests/test_intelligent_fallback.py @@ -0,0 +1,181 @@ +""" +Test suite for intelligent auto mode fallback logic + +Tests the new dynamic model selection based on available API keys +""" + +import os +from unittest.mock import Mock, patch + +import pytest + +from providers.base import ProviderType +from providers.registry import ModelProviderRegistry + + +class TestIntelligentFallback: + """Test intelligent model fallback logic""" + + def setup_method(self): + """Setup for each test - clear registry cache""" + ModelProviderRegistry.clear_cache() + + def teardown_method(self): + """Cleanup after each test""" + ModelProviderRegistry.clear_cache() + + @patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False) + def test_prefers_openai_o3_mini_when_available(self): + """Test that o3-mini is preferred when OpenAI API key is available""" + 
ModelProviderRegistry.clear_cache() + fallback_model = ModelProviderRegistry.get_preferred_fallback_model() + assert fallback_model == "o3-mini" + + @patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False) + def test_prefers_gemini_flash_when_openai_unavailable(self): + """Test that gemini-2.0-flash is used when only Gemini API key is available""" + ModelProviderRegistry.clear_cache() + fallback_model = ModelProviderRegistry.get_preferred_fallback_model() + assert fallback_model == "gemini-2.0-flash" + + @patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": "test-gemini-key"}, clear=False) + def test_prefers_openai_when_both_available(self): + """Test that OpenAI is preferred when both API keys are available""" + ModelProviderRegistry.clear_cache() + fallback_model = ModelProviderRegistry.get_preferred_fallback_model() + assert fallback_model == "o3-mini" # OpenAI has priority + + @patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False) + def test_fallback_when_no_keys_available(self): + """Test fallback behavior when no API keys are available""" + ModelProviderRegistry.clear_cache() + fallback_model = ModelProviderRegistry.get_preferred_fallback_model() + assert fallback_model == "gemini-2.0-flash" # Default fallback + + def test_available_providers_with_keys(self): + """Test the get_available_providers_with_keys method""" + with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False): + ModelProviderRegistry.clear_cache() + available = ModelProviderRegistry.get_available_providers_with_keys() + assert ProviderType.OPENAI in available + assert ProviderType.GOOGLE not in available + + with patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False): + ModelProviderRegistry.clear_cache() + available = ModelProviderRegistry.get_available_providers_with_keys() + assert ProviderType.GOOGLE in available + assert ProviderType.OPENAI not in available + + def test_auto_mode_conversation_memory_integration(self): + """Test that conversation memory uses intelligent fallback in auto mode""" + from utils.conversation_memory import ThreadContext, build_conversation_history + + # Mock auto mode - patch the config module where these values are defined + with ( + patch("config.IS_AUTO_MODE", True), + patch("config.DEFAULT_MODEL", "auto"), + patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False), + ): + + ModelProviderRegistry.clear_cache() + + # Create a context with at least one turn so it doesn't exit early + from utils.conversation_memory import ConversationTurn + + context = ThreadContext( + thread_id="test-123", + created_at="2023-01-01T00:00:00Z", + last_updated_at="2023-01-01T00:00:00Z", + tool_name="chat", + turns=[ConversationTurn(role="user", content="Test message", timestamp="2023-01-01T00:00:30Z")], + initial_context={}, + ) + + # This should use o3-mini for token calculations since OpenAI is available + with patch("utils.model_context.ModelContext") as mock_context_class: + mock_context_instance = Mock() + mock_context_class.return_value = mock_context_instance + mock_context_instance.calculate_token_allocation.return_value = Mock( + file_tokens=10000, history_tokens=5000 + ) + # Mock estimate_tokens to return integers for proper summing + mock_context_instance.estimate_tokens.return_value = 100 + + history, tokens = build_conversation_history(context, model_context=None) + + # Verify that ModelContext 
was called with o3-mini (the intelligent fallback) + mock_context_class.assert_called_once_with("o3-mini") + + def test_auto_mode_with_gemini_only(self): + """Test auto mode behavior when only Gemini API key is available""" + from utils.conversation_memory import ThreadContext, build_conversation_history + + with ( + patch("config.IS_AUTO_MODE", True), + patch("config.DEFAULT_MODEL", "auto"), + patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False), + ): + + ModelProviderRegistry.clear_cache() + + from utils.conversation_memory import ConversationTurn + + context = ThreadContext( + thread_id="test-456", + created_at="2023-01-01T00:00:00Z", + last_updated_at="2023-01-01T00:00:00Z", + tool_name="analyze", + turns=[ConversationTurn(role="assistant", content="Test response", timestamp="2023-01-01T00:00:30Z")], + initial_context={}, + ) + + with patch("utils.model_context.ModelContext") as mock_context_class: + mock_context_instance = Mock() + mock_context_class.return_value = mock_context_instance + mock_context_instance.calculate_token_allocation.return_value = Mock( + file_tokens=10000, history_tokens=5000 + ) + # Mock estimate_tokens to return integers for proper summing + mock_context_instance.estimate_tokens.return_value = 100 + + history, tokens = build_conversation_history(context, model_context=None) + + # Should use gemini-2.0-flash when only Gemini is available + mock_context_class.assert_called_once_with("gemini-2.0-flash") + + def test_non_auto_mode_unchanged(self): + """Test that non-auto mode behavior is unchanged""" + from utils.conversation_memory import ThreadContext, build_conversation_history + + with patch("config.IS_AUTO_MODE", False), patch("config.DEFAULT_MODEL", "gemini-2.5-pro-preview-06-05"): + + from utils.conversation_memory import ConversationTurn + + context = ThreadContext( + thread_id="test-789", + created_at="2023-01-01T00:00:00Z", + last_updated_at="2023-01-01T00:00:00Z", + tool_name="thinkdeep", + turns=[ + ConversationTurn(role="user", content="Test in non-auto mode", timestamp="2023-01-01T00:00:30Z") + ], + initial_context={}, + ) + + with patch("utils.model_context.ModelContext") as mock_context_class: + mock_context_instance = Mock() + mock_context_class.return_value = mock_context_instance + mock_context_instance.calculate_token_allocation.return_value = Mock( + file_tokens=10000, history_tokens=5000 + ) + # Mock estimate_tokens to return integers for proper summing + mock_context_instance.estimate_tokens.return_value = 100 + + history, tokens = build_conversation_history(context, model_context=None) + + # Should use the configured DEFAULT_MODEL, not the intelligent fallback + mock_context_class.assert_called_once_with("gemini-2.5-pro-preview-06-05") + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_large_prompt_handling.py b/tests/test_large_prompt_handling.py index fd54bfc..33573aa 100644 --- a/tests/test_large_prompt_handling.py +++ b/tests/test_large_prompt_handling.py @@ -75,7 +75,7 @@ class TestLargePromptHandling: mock_provider.generate_content.return_value = MagicMock( content="This is a test response", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -100,7 +100,7 @@ class TestLargePromptHandling: mock_provider.generate_content.return_value = MagicMock( content="Processed large prompt", 
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -212,7 +212,7 @@ class TestLargePromptHandling: mock_provider.generate_content.return_value = MagicMock( content="Success", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -245,7 +245,7 @@ class TestLargePromptHandling: mock_provider.generate_content.return_value = MagicMock( content="Success", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -276,7 +276,7 @@ class TestLargePromptHandling: mock_provider.generate_content.return_value = MagicMock( content="Success", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider @@ -298,7 +298,7 @@ class TestLargePromptHandling: mock_provider.generate_content.return_value = MagicMock( content="Success", usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) mock_get_provider.return_value = mock_provider diff --git a/tests/test_prompt_regression.py b/tests/test_prompt_regression.py index 44651fd..cd5cedc 100644 --- a/tests/test_prompt_regression.py +++ b/tests/test_prompt_regression.py @@ -31,7 +31,7 @@ class TestPromptRegression: return Mock( content=text, usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, - model_name="gemini-2.0-flash-exp", + model_name="gemini-2.0-flash", metadata={"finish_reason": "STOP"}, ) diff --git a/tests/test_providers.py b/tests/test_providers.py index 519ee11..e7370de 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -49,7 +49,7 @@ class TestModelProviderRegistry: """Test getting provider for a specific model""" ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider) - provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp") + provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash") assert provider is not None assert isinstance(provider, GeminiModelProvider) @@ -80,10 +80,10 @@ class TestGeminiProvider: """Test getting model capabilities""" provider = GeminiModelProvider(api_key="test-key") - capabilities = provider.get_capabilities("gemini-2.0-flash-exp") + capabilities = provider.get_capabilities("gemini-2.0-flash") assert capabilities.provider == ProviderType.GOOGLE - assert capabilities.model_name == "gemini-2.0-flash-exp" + assert capabilities.model_name == "gemini-2.0-flash" assert capabilities.max_tokens == 1_048_576 assert not capabilities.supports_extended_thinking @@ -103,13 +103,13 @@ class TestGeminiProvider: assert provider.validate_model_name("pro") capabilities = provider.get_capabilities("flash") - assert capabilities.model_name == "gemini-2.0-flash-exp" + assert capabilities.model_name == "gemini-2.0-flash" def test_supports_thinking_mode(self): """Test thinking mode support detection""" provider = GeminiModelProvider(api_key="test-key") - 
assert not provider.supports_thinking_mode("gemini-2.0-flash-exp") + assert not provider.supports_thinking_mode("gemini-2.0-flash") assert provider.supports_thinking_mode("gemini-2.5-pro-preview-06-05") @patch("google.genai.Client") @@ -133,11 +133,11 @@ class TestGeminiProvider: provider = GeminiModelProvider(api_key="test-key") - response = provider.generate_content(prompt="Test prompt", model_name="gemini-2.0-flash-exp", temperature=0.7) + response = provider.generate_content(prompt="Test prompt", model_name="gemini-2.0-flash", temperature=0.7) assert isinstance(response, ModelResponse) assert response.content == "Generated content" - assert response.model_name == "gemini-2.0-flash-exp" + assert response.model_name == "gemini-2.0-flash" assert response.provider == ProviderType.GOOGLE assert response.usage["input_tokens"] == 10 assert response.usage["output_tokens"] == 20 diff --git a/tests/test_server.py b/tests/test_server.py index 2d5cb99..4d81015 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -56,7 +56,7 @@ class TestServerTools: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content="Chat response", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Chat response", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider diff --git a/tests/test_thinking_modes.py b/tests/test_thinking_modes.py index 5215c55..8df8137 100644 --- a/tests/test_thinking_modes.py +++ b/tests/test_thinking_modes.py @@ -45,7 +45,7 @@ class TestThinkingModes: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = True mock_provider.generate_content.return_value = Mock( - content="Minimal thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Minimal thinking response", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -82,7 +82,7 @@ class TestThinkingModes: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = True mock_provider.generate_content.return_value = Mock( - content="Low thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Low thinking response", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -114,7 +114,7 @@ class TestThinkingModes: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = True mock_provider.generate_content.return_value = Mock( - content="Medium thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Medium thinking response", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -145,7 +145,7 @@ class TestThinkingModes: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = True mock_provider.generate_content.return_value = Mock( - content="High thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="High thinking response", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -175,7 +175,7 @@ class TestThinkingModes: mock_provider.get_provider_type.return_value = 
Mock(value="google") mock_provider.supports_thinking_mode.return_value = True mock_provider.generate_content.return_value = Mock( - content="Max thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Max thinking response", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider diff --git a/tests/test_tools.py b/tests/test_tools.py index a811eab..73aba51 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -37,7 +37,7 @@ class TestThinkDeepTool: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = True mock_provider.generate_content.return_value = Mock( - content="Extended analysis", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Extended analysis", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -88,7 +88,7 @@ class TestCodeReviewTool: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content="Security issues found", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Security issues found", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -133,7 +133,7 @@ class TestDebugIssueTool: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content="Root cause: race condition", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Root cause: race condition", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -181,7 +181,7 @@ class TestAnalyzeTool: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content="Architecture analysis", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Architecture analysis", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider @@ -295,7 +295,7 @@ class TestAbsolutePathValidation: mock_provider.get_provider_type.return_value = Mock(value="google") mock_provider.supports_thinking_mode.return_value = False mock_provider.generate_content.return_value = Mock( - content="Analysis complete", usage={}, model_name="gemini-2.0-flash-exp", metadata={} + content="Analysis complete", usage={}, model_name="gemini-2.0-flash", metadata={} ) mock_get_provider.return_value = mock_provider diff --git a/utils/conversation_memory.py b/utils/conversation_memory.py index 2600a33..cdef754 100644 --- a/utils/conversation_memory.py +++ b/utils/conversation_memory.py @@ -74,7 +74,7 @@ class ConversationTurn(BaseModel): files: List of file paths referenced in this specific turn tool_name: Which tool generated this turn (for cross-tool tracking) model_provider: Provider used (e.g., "google", "openai") - model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini") + model_name: Specific model used (e.g., "gemini-2.0-flash", "o3-mini") model_metadata: Additional model-specific metadata (e.g., thinking mode, token usage) """ @@ -249,7 +249,7 @@ def add_turn( files: Optional list of files referenced in this turn tool_name: Name of the tool adding this turn (for attribution) 
model_provider: Provider used (e.g., "google", "openai") - model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini") + model_name: Specific model used (e.g., "gemini-2.0-flash", "o3-mini") model_metadata: Additional model info (e.g., thinking mode, token usage) Returns: @@ -454,10 +454,19 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_ # Get model-specific token allocation early (needed for both files and turns) if model_context is None: - from config import DEFAULT_MODEL + from config import DEFAULT_MODEL, IS_AUTO_MODE from utils.model_context import ModelContext - model_context = ModelContext(DEFAULT_MODEL) + # In auto mode, use an intelligent fallback model for token calculations + # since "auto" is not a real model with a provider + model_name = DEFAULT_MODEL + if IS_AUTO_MODE and model_name.lower() == "auto": + # Use intelligent fallback based on available API keys + from providers.registry import ModelProviderRegistry + + model_name = ModelProviderRegistry.get_preferred_fallback_model() + + model_context = ModelContext(model_name) token_allocation = model_context.calculate_token_allocation() max_file_tokens = token_allocation.file_tokens
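
Note (not part of the patch): the sketch below shows how the new fallback selection added in providers/registry.py can be exercised, mirroring the unit tests in tests/test_intelligent_fallback.py and the registration call used in tests/test_providers.py. The import paths, the explicit registration step, and the placeholder API key value are assumptions for illustration, not requirements introduced by this change.

# Illustrative sketch only; assumes the registry behaves as the new tests expect.
import os

from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.registry import ModelProviderRegistry

# Register the Gemini provider the same way tests/test_providers.py does.
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

# Simulate a Gemini-only configuration with a placeholder key.
os.environ["GEMINI_API_KEY"] = "test-key"
os.environ.pop("OPENAI_API_KEY", None)
ModelProviderRegistry.clear_cache()  # drop any cached provider instances

print(ModelProviderRegistry.get_available_providers_with_keys())
# expected: only ProviderType.GOOGLE is listed

print(ModelProviderRegistry.get_preferred_fallback_model())
# expected: "gemini-2.0-flash"; "o3-mini" would be preferred if an OpenAI key were configured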