Use the new flash model

Updated tests
Fahad
2025-06-12 13:44:09 +04:00
parent 8b8d966d33
commit 79af2654b9
20 changed files with 297 additions and 63 deletions

View File

@@ -26,7 +26,7 @@ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")
# Validate DEFAULT_MODEL and set to "auto" if invalid
# Only include actually supported models from providers
VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash-exp", "gemini-2.5-pro-preview-06-05"]
VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash", "gemini-2.5-pro-preview-06-05"]
if DEFAULT_MODEL not in VALID_MODELS:
import logging
@@ -47,7 +47,7 @@ MODEL_CAPABILITIES_DESC = {
"o3": "Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
"o3-mini": "Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
# Full model names also supported
"gemini-2.0-flash-exp": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
"gemini-2.0-flash": "Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
"gemini-2.5-pro-preview-06-05": "Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
}
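For illustration, a minimal sketch of the validation that follows the VALID_MODELS list, assuming the logging-based fallback hinted at in the hunk above (sketch only, not lines from this commit):

import logging
import os

VALID_MODELS = ["auto", "flash", "pro", "o3", "o3-mini", "gemini-2.0-flash", "gemini-2.5-pro-preview-06-05"]
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "auto")

if DEFAULT_MODEL not in VALID_MODELS:
    # Fall back to "auto" instead of failing at import time
    logging.warning(f"Invalid DEFAULT_MODEL '{DEFAULT_MODEL}', defaulting to 'auto'")
    DEFAULT_MODEL = "auto"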

View File

@@ -13,7 +13,7 @@ class GeminiModelProvider(ModelProvider):
# Model configurations
SUPPORTED_MODELS = {
"gemini-2.0-flash-exp": {
"gemini-2.0-flash": {
"max_tokens": 1_048_576, # 1M tokens
"supports_extended_thinking": False,
},
@@ -22,7 +22,7 @@ class GeminiModelProvider(ModelProvider):
"supports_extended_thinking": True,
},
# Shorthands
"flash": "gemini-2.0-flash-exp",
"flash": "gemini-2.0-flash",
"pro": "gemini-2.5-pro-preview-06-05",
}
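Because SUPPORTED_MODELS mixes full model configs with string shorthands, resolving a user-supplied name just means following the alias once. A hypothetical sketch, assuming GeminiModelProvider is in scope and using an illustrative helper name resolve_model_name:

def resolve_model_name(name: str) -> str:
    entry = GeminiModelProvider.SUPPORTED_MODELS.get(name)
    # Shorthands map to a plain string ("flash" -> "gemini-2.0-flash");
    # full model names map to a config dict and pass through unchanged
    if isinstance(entry, str):
        return entry
    return name

resolve_model_name("flash")             # "gemini-2.0-flash"
resolve_model_name("gemini-2.0-flash")  # unchanged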

View File

@@ -67,7 +67,7 @@ class ModelProviderRegistry:
"""Get provider instance for a specific model name.
Args:
model_name: Name of the model (e.g., "gemini-2.0-flash-exp", "o3-mini")
model_name: Name of the model (e.g., "gemini-2.0-flash", "o3-mini")
Returns:
ModelProvider instance that supports this model
@@ -125,6 +125,50 @@ class ModelProviderRegistry:
return os.getenv(env_var)
@classmethod
def get_preferred_fallback_model(cls) -> str:
"""Get the preferred fallback model based on available API keys.
This method checks which providers have valid API keys and returns
a sensible default model for auto mode fallback situations.
Priority order:
1. OpenAI o3-mini (balanced performance/cost) if an OpenAI API key is available
2. Gemini 2.0 Flash (fast and efficient) if a Gemini API key is available
3. gemini-2.0-flash as the default when no API keys are available
Returns:
Model name string for fallback use
"""
# Check provider availability by trying to get instances
openai_available = cls.get_provider(ProviderType.OPENAI) is not None
gemini_available = cls.get_provider(ProviderType.GOOGLE) is not None
# Priority order: prefer the balanced OpenAI model first, then the fast Gemini model
if openai_available:
return "o3-mini" # Balanced performance/cost
elif gemini_available:
return "gemini-2.0-flash" # Fast and efficient
else:
# No API keys available - return a reasonable default
# This maintains backward compatibility for tests
return "gemini-2.0-flash"
@classmethod
def get_available_providers_with_keys(cls) -> list[ProviderType]:
"""Get list of provider types that have valid API keys.
Returns:
List of ProviderType values for providers with valid API keys
"""
available = []
for provider_type in cls._providers:
if cls.get_provider(provider_type) is not None:
available.append(provider_type)
return available
@classmethod
def clear_cache(cls) -> None:
"""Clear cached provider instances."""

View File

@@ -55,7 +55,7 @@ class TestModelThinkingConfig(BaseSimulatorTest):
"chat",
{
"prompt": "What is 3 + 3? Give a quick answer.",
"model": "flash", # Should resolve to gemini-2.0-flash-exp
"model": "flash", # Should resolve to gemini-2.0-flash
"thinking_mode": "high", # Should be ignored for Flash model
},
)
@@ -80,7 +80,7 @@ class TestModelThinkingConfig(BaseSimulatorTest):
("pro", "should work with Pro model"),
("flash", "should work with Flash model"),
("gemini-2.5-pro-preview-06-05", "should work with full Pro model name"),
("gemini-2.0-flash-exp", "should work with full Flash model name"),
("gemini-2.0-flash", "should work with full Flash model name"),
]
success_count = 0

View File

@@ -23,7 +23,7 @@ if "OPENAI_API_KEY" not in os.environ:
# Set default model to a specific value for tests to avoid auto mode
# This prevents all tests from failing due to missing model parameter
os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash-exp"
os.environ["DEFAULT_MODEL"] = "gemini-2.0-flash"
# Force reload of config module to pick up the env var
import importlib

View File

@@ -5,7 +5,7 @@ from unittest.mock import Mock
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
def create_mock_provider(model_name="gemini-2.0-flash-exp", max_tokens=1_048_576):
def create_mock_provider(model_name="gemini-2.0-flash", max_tokens=1_048_576):
"""Create a properly configured mock provider."""
mock_provider = Mock()

View File

@@ -72,7 +72,7 @@ class TestClaudeContinuationOffers:
mock_provider.generate_content.return_value = Mock(
content="Analysis complete.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -129,7 +129,7 @@ class TestClaudeContinuationOffers:
mock_provider.generate_content.return_value = Mock(
content="Continued analysis.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -162,7 +162,7 @@ class TestClaudeContinuationOffers:
mock_provider.generate_content.return_value = Mock(
content="Analysis complete. The code looks good.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -208,7 +208,7 @@ I'd be happy to examine the error handling patterns in more detail if that would
mock_provider.generate_content.return_value = Mock(
content=content_with_followup,
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -253,7 +253,7 @@ I'd be happy to examine the error handling patterns in more detail if that would
mock_provider.generate_content.return_value = Mock(
content="Continued analysis complete.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -309,7 +309,7 @@ I'd be happy to examine the error handling patterns in more detail if that would
mock_provider.generate_content.return_value = Mock(
content="Final response.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -358,7 +358,7 @@ class TestContinuationIntegration:
mock_provider.generate_content.return_value = Mock(
content="Analysis result",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -411,7 +411,7 @@ class TestContinuationIntegration:
mock_provider.generate_content.return_value = Mock(
content="Structure analysis done.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -448,7 +448,7 @@ class TestContinuationIntegration:
mock_provider.generate_content.return_value = Mock(
content="Performance analysis done.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)

View File

@@ -41,7 +41,7 @@ class TestDynamicContextRequests:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -82,7 +82,7 @@ class TestDynamicContextRequests:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=normal_response, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content=normal_response, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -106,7 +106,7 @@ class TestDynamicContextRequests:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=malformed_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content=malformed_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -146,7 +146,7 @@ class TestDynamicContextRequests:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -233,7 +233,7 @@ class TestCollaborationWorkflow:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -272,7 +272,7 @@ class TestCollaborationWorkflow:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content=clarification_json, usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -299,7 +299,7 @@ class TestCollaborationWorkflow:
"""
mock_provider.generate_content.return_value = Mock(
content=final_response, usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content=final_response, usage={}, model_name="gemini-2.0-flash", metadata={}
)
result2 = await tool.execute(

View File

@@ -32,7 +32,7 @@ class TestConfig:
def test_model_config(self):
"""Test model configuration"""
# DEFAULT_MODEL is set in conftest.py for tests
assert DEFAULT_MODEL == "gemini-2.0-flash-exp"
assert DEFAULT_MODEL == "gemini-2.0-flash"
assert MAX_CONTEXT_TOKENS == 1_000_000
def test_temperature_defaults(self):

View File

@@ -74,7 +74,7 @@ async def test_conversation_history_field_mapping():
mock_provider = MagicMock()
mock_provider.get_capabilities.return_value = ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
friendly_name="Gemini",
max_tokens=200000,
supports_extended_thinking=True,

View File

@@ -115,7 +115,7 @@ class TestConversationHistoryBugFix:
return Mock(
content="Response with conversation context",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
@@ -175,7 +175,7 @@ class TestConversationHistoryBugFix:
return Mock(
content="Response without history",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
@@ -213,7 +213,7 @@ class TestConversationHistoryBugFix:
return Mock(
content="New conversation response",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
@@ -297,7 +297,7 @@ class TestConversationHistoryBugFix:
return Mock(
content="Analysis of new files complete",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)

View File

@@ -111,7 +111,7 @@ I'd be happy to review these security findings in detail if that would be helpfu
mock_provider.generate_content.return_value = Mock(
content=content,
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -158,7 +158,7 @@ I'd be happy to review these security findings in detail if that would be helpfu
mock_provider.generate_content.return_value = Mock(
content="Critical security vulnerability confirmed. The authentication function always returns true, bypassing all security checks.",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -279,7 +279,7 @@ I'd be happy to review these security findings in detail if that would be helpfu
mock_provider.generate_content.return_value = Mock(
content="Security review of auth.py shows vulnerabilities",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider

View File

@@ -0,0 +1,181 @@
"""
Test suite for intelligent auto mode fallback logic
Tests the new dynamic model selection based on available API keys
"""
import os
from unittest.mock import Mock, patch
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
class TestIntelligentFallback:
"""Test intelligent model fallback logic"""
def setup_method(self):
"""Setup for each test - clear registry cache"""
ModelProviderRegistry.clear_cache()
def teardown_method(self):
"""Cleanup after each test"""
ModelProviderRegistry.clear_cache()
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False)
def test_prefers_openai_o3_mini_when_available(self):
"""Test that o3-mini is preferred when OpenAI API key is available"""
ModelProviderRegistry.clear_cache()
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o3-mini"
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_gemini_flash_when_openai_unavailable(self):
"""Test that gemini-2.0-flash is used when only Gemini API key is available"""
ModelProviderRegistry.clear_cache()
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "gemini-2.0-flash"
@patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": "test-gemini-key"}, clear=False)
def test_prefers_openai_when_both_available(self):
"""Test that OpenAI is preferred when both API keys are available"""
ModelProviderRegistry.clear_cache()
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "o3-mini" # OpenAI has priority
@patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": ""}, clear=False)
def test_fallback_when_no_keys_available(self):
"""Test fallback behavior when no API keys are available"""
ModelProviderRegistry.clear_cache()
fallback_model = ModelProviderRegistry.get_preferred_fallback_model()
assert fallback_model == "gemini-2.0-flash" # Default fallback
def test_available_providers_with_keys(self):
"""Test the get_available_providers_with_keys method"""
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False):
ModelProviderRegistry.clear_cache()
available = ModelProviderRegistry.get_available_providers_with_keys()
assert ProviderType.OPENAI in available
assert ProviderType.GOOGLE not in available
with patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False):
ModelProviderRegistry.clear_cache()
available = ModelProviderRegistry.get_available_providers_with_keys()
assert ProviderType.GOOGLE in available
assert ProviderType.OPENAI not in available
def test_auto_mode_conversation_memory_integration(self):
"""Test that conversation memory uses intelligent fallback in auto mode"""
from utils.conversation_memory import ThreadContext, build_conversation_history
# Mock auto mode - patch the config module where these values are defined
with (
patch("config.IS_AUTO_MODE", True),
patch("config.DEFAULT_MODEL", "auto"),
patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "GEMINI_API_KEY": ""}, clear=False),
):
ModelProviderRegistry.clear_cache()
# Create a context with at least one turn so it doesn't exit early
from utils.conversation_memory import ConversationTurn
context = ThreadContext(
thread_id="test-123",
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:00:00Z",
tool_name="chat",
turns=[ConversationTurn(role="user", content="Test message", timestamp="2023-01-01T00:00:30Z")],
initial_context={},
)
# This should use o3-mini for token calculations since OpenAI is available
with patch("utils.model_context.ModelContext") as mock_context_class:
mock_context_instance = Mock()
mock_context_class.return_value = mock_context_instance
mock_context_instance.calculate_token_allocation.return_value = Mock(
file_tokens=10000, history_tokens=5000
)
# Mock estimate_tokens to return integers for proper summing
mock_context_instance.estimate_tokens.return_value = 100
history, tokens = build_conversation_history(context, model_context=None)
# Verify that ModelContext was called with o3-mini (the intelligent fallback)
mock_context_class.assert_called_once_with("o3-mini")
def test_auto_mode_with_gemini_only(self):
"""Test auto mode behavior when only Gemini API key is available"""
from utils.conversation_memory import ThreadContext, build_conversation_history
with (
patch("config.IS_AUTO_MODE", True),
patch("config.DEFAULT_MODEL", "auto"),
patch.dict(os.environ, {"OPENAI_API_KEY": "", "GEMINI_API_KEY": "test-key"}, clear=False),
):
ModelProviderRegistry.clear_cache()
from utils.conversation_memory import ConversationTurn
context = ThreadContext(
thread_id="test-456",
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:00:00Z",
tool_name="analyze",
turns=[ConversationTurn(role="assistant", content="Test response", timestamp="2023-01-01T00:00:30Z")],
initial_context={},
)
with patch("utils.model_context.ModelContext") as mock_context_class:
mock_context_instance = Mock()
mock_context_class.return_value = mock_context_instance
mock_context_instance.calculate_token_allocation.return_value = Mock(
file_tokens=10000, history_tokens=5000
)
# Mock estimate_tokens to return integers for proper summing
mock_context_instance.estimate_tokens.return_value = 100
history, tokens = build_conversation_history(context, model_context=None)
# Should use gemini-2.0-flash when only Gemini is available
mock_context_class.assert_called_once_with("gemini-2.0-flash")
def test_non_auto_mode_unchanged(self):
"""Test that non-auto mode behavior is unchanged"""
from utils.conversation_memory import ThreadContext, build_conversation_history
with patch("config.IS_AUTO_MODE", False), patch("config.DEFAULT_MODEL", "gemini-2.5-pro-preview-06-05"):
from utils.conversation_memory import ConversationTurn
context = ThreadContext(
thread_id="test-789",
created_at="2023-01-01T00:00:00Z",
last_updated_at="2023-01-01T00:00:00Z",
tool_name="thinkdeep",
turns=[
ConversationTurn(role="user", content="Test in non-auto mode", timestamp="2023-01-01T00:00:30Z")
],
initial_context={},
)
with patch("utils.model_context.ModelContext") as mock_context_class:
mock_context_instance = Mock()
mock_context_class.return_value = mock_context_instance
mock_context_instance.calculate_token_allocation.return_value = Mock(
file_tokens=10000, history_tokens=5000
)
# Mock estimate_tokens to return integers for proper summing
mock_context_instance.estimate_tokens.return_value = 100
history, tokens = build_conversation_history(context, model_context=None)
# Should use the configured DEFAULT_MODEL, not the intelligent fallback
mock_context_class.assert_called_once_with("gemini-2.5-pro-preview-06-05")
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -75,7 +75,7 @@ class TestLargePromptHandling:
mock_provider.generate_content.return_value = MagicMock(
content="This is a test response",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -100,7 +100,7 @@ class TestLargePromptHandling:
mock_provider.generate_content.return_value = MagicMock(
content="Processed large prompt",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -212,7 +212,7 @@ class TestLargePromptHandling:
mock_provider.generate_content.return_value = MagicMock(
content="Success",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -245,7 +245,7 @@ class TestLargePromptHandling:
mock_provider.generate_content.return_value = MagicMock(
content="Success",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -276,7 +276,7 @@ class TestLargePromptHandling:
mock_provider.generate_content.return_value = MagicMock(
content="Success",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider
@@ -298,7 +298,7 @@ class TestLargePromptHandling:
mock_provider.generate_content.return_value = MagicMock(
content="Success",
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)
mock_get_provider.return_value = mock_provider

View File

@@ -31,7 +31,7 @@ class TestPromptRegression:
return Mock(
content=text,
usage={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
model_name="gemini-2.0-flash-exp",
model_name="gemini-2.0-flash",
metadata={"finish_reason": "STOP"},
)

View File

@@ -49,7 +49,7 @@ class TestModelProviderRegistry:
"""Test getting provider for a specific model"""
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash-exp")
provider = ModelProviderRegistry.get_provider_for_model("gemini-2.0-flash")
assert provider is not None
assert isinstance(provider, GeminiModelProvider)
@@ -80,10 +80,10 @@ class TestGeminiProvider:
"""Test getting model capabilities"""
provider = GeminiModelProvider(api_key="test-key")
capabilities = provider.get_capabilities("gemini-2.0-flash-exp")
capabilities = provider.get_capabilities("gemini-2.0-flash")
assert capabilities.provider == ProviderType.GOOGLE
assert capabilities.model_name == "gemini-2.0-flash-exp"
assert capabilities.model_name == "gemini-2.0-flash"
assert capabilities.max_tokens == 1_048_576
assert not capabilities.supports_extended_thinking
@@ -103,13 +103,13 @@ class TestGeminiProvider:
assert provider.validate_model_name("pro")
capabilities = provider.get_capabilities("flash")
assert capabilities.model_name == "gemini-2.0-flash-exp"
assert capabilities.model_name == "gemini-2.0-flash"
def test_supports_thinking_mode(self):
"""Test thinking mode support detection"""
provider = GeminiModelProvider(api_key="test-key")
assert not provider.supports_thinking_mode("gemini-2.0-flash-exp")
assert not provider.supports_thinking_mode("gemini-2.0-flash")
assert provider.supports_thinking_mode("gemini-2.5-pro-preview-06-05")
@patch("google.genai.Client")
@@ -133,11 +133,11 @@ class TestGeminiProvider:
provider = GeminiModelProvider(api_key="test-key")
response = provider.generate_content(prompt="Test prompt", model_name="gemini-2.0-flash-exp", temperature=0.7)
response = provider.generate_content(prompt="Test prompt", model_name="gemini-2.0-flash", temperature=0.7)
assert isinstance(response, ModelResponse)
assert response.content == "Generated content"
assert response.model_name == "gemini-2.0-flash-exp"
assert response.model_name == "gemini-2.0-flash"
assert response.provider == ProviderType.GOOGLE
assert response.usage["input_tokens"] == 10
assert response.usage["output_tokens"] == 20

View File

@@ -56,7 +56,7 @@ class TestServerTools:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Chat response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Chat response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider

View File

@@ -45,7 +45,7 @@ class TestThinkingModes:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Minimal thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Minimal thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -82,7 +82,7 @@ class TestThinkingModes:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Low thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Low thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -114,7 +114,7 @@ class TestThinkingModes:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Medium thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Medium thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -145,7 +145,7 @@ class TestThinkingModes:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="High thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="High thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -175,7 +175,7 @@ class TestThinkingModes:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Max thinking response", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Max thinking response", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider

View File

@@ -37,7 +37,7 @@ class TestThinkDeepTool:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = True
mock_provider.generate_content.return_value = Mock(
content="Extended analysis", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Extended analysis", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -88,7 +88,7 @@ class TestCodeReviewTool:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Security issues found", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Security issues found", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -133,7 +133,7 @@ class TestDebugIssueTool:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Root cause: race condition", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Root cause: race condition", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -181,7 +181,7 @@ class TestAnalyzeTool:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Architecture analysis", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Architecture analysis", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
@@ -295,7 +295,7 @@ class TestAbsolutePathValidation:
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.supports_thinking_mode.return_value = False
mock_provider.generate_content.return_value = Mock(
content="Analysis complete", usage={}, model_name="gemini-2.0-flash-exp", metadata={}
content="Analysis complete", usage={}, model_name="gemini-2.0-flash", metadata={}
)
mock_get_provider.return_value = mock_provider

View File

@@ -74,7 +74,7 @@ class ConversationTurn(BaseModel):
files: List of file paths referenced in this specific turn
tool_name: Which tool generated this turn (for cross-tool tracking)
model_provider: Provider used (e.g., "google", "openai")
model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini")
model_name: Specific model used (e.g., "gemini-2.0-flash", "o3-mini")
model_metadata: Additional model-specific metadata (e.g., thinking mode, token usage)
"""
@@ -249,7 +249,7 @@ def add_turn(
files: Optional list of files referenced in this turn
tool_name: Name of the tool adding this turn (for attribution)
model_provider: Provider used (e.g., "google", "openai")
model_name: Specific model used (e.g., "gemini-2.0-flash-exp", "o3-mini")
model_name: Specific model used (e.g., "gemini-2.0-flash", "o3-mini")
model_metadata: Additional model info (e.g., thinking mode, token usage)
Returns:
@@ -454,10 +454,19 @@ def build_conversation_history(context: ThreadContext, model_context=None, read_
# Get model-specific token allocation early (needed for both files and turns)
if model_context is None:
from config import DEFAULT_MODEL
from config import DEFAULT_MODEL, IS_AUTO_MODE
from utils.model_context import ModelContext
model_context = ModelContext(DEFAULT_MODEL)
# In auto mode, use an intelligent fallback model for token calculations
# since "auto" is not a real model with a provider
model_name = DEFAULT_MODEL
if IS_AUTO_MODE and model_name.lower() == "auto":
# Use intelligent fallback based on available API keys
from providers.registry import ModelProviderRegistry
model_name = ModelProviderRegistry.get_preferred_fallback_model()
model_context = ModelContext(model_name)
token_allocation = model_context.calculate_token_allocation()
max_file_tokens = token_allocation.file_tokens
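The net effect of the auto-mode branch above, condensed into a sketch (values follow the new tests, e.g. only a Gemini key configured):

from providers.registry import ModelProviderRegistry
from utils.model_context import ModelContext

model_name = "auto"
if model_name.lower() == "auto":
    # "auto" is not a real model, so pick a concrete one for token accounting
    model_name = ModelProviderRegistry.get_preferred_fallback_model()  # "gemini-2.0-flash" with only a Gemini key
model_context = ModelContext(model_name)
token_allocation = model_context.calculate_token_allocation()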