Gemini model rename
This commit is contained in:
@@ -22,8 +22,8 @@ class TestModelRestrictionService:
|
||||
# Should allow all models
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "openai/o3")
|
||||
|
||||
@@ -43,7 +43,7 @@ class TestModelRestrictionService:
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
||||
|
||||
# Google and OpenRouter should have no restrictions
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus")
|
||||
|
||||
def test_load_multiple_models_restriction(self):
|
||||
@@ -59,7 +59,7 @@ class TestModelRestrictionService:
|
||||
# Check Google models
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05")
|
||||
assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
|
||||
def test_case_insensitive_and_whitespace_handling(self):
|
||||
"""Test that model names are case-insensitive and whitespace is trimmed."""
|
||||
@@ -84,9 +84,9 @@ class TestModelRestrictionService:
|
||||
|
||||
# Google should only allow flash (and its resolved name)
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20", "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
|
||||
assert not service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05", "pro")
|
||||
assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")
|
||||
|
||||
def test_filter_models(self):
|
||||
"""Test filtering a list of models based on restrictions."""
|
||||
@@ -124,8 +124,8 @@ class TestModelRestrictionService:
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
|
||||
# Google should allow both models via shorthands
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20", "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05", "pro")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")
|
||||
|
||||
# Also test that full names work when specified in restrictions
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini", "o3mini") # Even with shorthand
|
||||
@@ -238,7 +238,7 @@ class TestProviderIntegration:
|
||||
provider.get_capabilities("o3")
|
||||
assert "not allowed by restriction policy" in str(exc_info.value)
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash-preview-05-20,flash"})
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash,flash"})
|
||||
def test_gemini_provider_respects_restrictions(self):
|
||||
"""Test that Gemini provider respects restrictions."""
|
||||
# Clear any cached restriction service
|
||||
@@ -250,11 +250,11 @@ class TestProviderIntegration:
|
||||
|
||||
# Should validate allowed models (both shorthand and full name allowed)
|
||||
assert provider.validate_model_name("flash")
|
||||
assert provider.validate_model_name("gemini-2.5-flash-preview-05-20")
|
||||
assert provider.validate_model_name("gemini-2.5-flash")
|
||||
|
||||
# Should not validate disallowed model
|
||||
assert not provider.validate_model_name("pro")
|
||||
assert not provider.validate_model_name("gemini-2.5-pro-preview-06-05")
|
||||
assert not provider.validate_model_name("gemini-2.5-pro")
|
||||
|
||||
# get_capabilities should raise for disallowed model
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -288,13 +288,13 @@ class TestProviderIntegration:
|
||||
|
||||
# Should allow getting capabilities for "flash"
|
||||
capabilities = provider.get_capabilities("flash")
|
||||
assert capabilities.model_name == "gemini-2.5-flash-preview-05-20"
|
||||
assert capabilities.model_name == "gemini-2.5-flash"
|
||||
|
||||
# Test the edge case: Try to use full model name when only alias is allowed
|
||||
# This should NOT be allowed - only the alias "flash" is in the restriction list
|
||||
assert not provider.validate_model_name("gemini-2.5-flash-preview-05-20")
|
||||
assert not provider.validate_model_name("gemini-2.5-flash")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash-preview-05-20"})
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
|
||||
def test_gemini_parameter_order_edge_case_full_name_only(self):
|
||||
"""Test parameter order with only full name allowed, not alias.
|
||||
|
||||
@@ -310,7 +310,7 @@ class TestProviderIntegration:
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Should allow full name
|
||||
assert provider.validate_model_name("gemini-2.5-flash-preview-05-20")
|
||||
assert provider.validate_model_name("gemini-2.5-flash")
|
||||
|
||||
# Should also allow alias that resolves to allowed full name
|
||||
# This works because is_allowed checks both resolved_name and original_name
|
||||
@@ -318,7 +318,7 @@ class TestProviderIntegration:
|
||||
|
||||
# Should not allow "pro" alias
|
||||
assert not provider.validate_model_name("pro")
|
||||
assert not provider.validate_model_name("gemini-2.5-pro-preview-06-05")
|
||||
assert not provider.validate_model_name("gemini-2.5-pro")
|
||||
|
||||
|
||||
class TestCustomProviderOpenRouterRestrictions:
|
||||
@@ -469,8 +469,8 @@ class TestRegistryIntegration:
|
||||
|
||||
mock_gemini = MagicMock()
|
||||
mock_gemini.SUPPORTED_MODELS = {
|
||||
"gemini-2.5-pro-preview-06-05": {"context_window": 1048576},
|
||||
"gemini-2.5-flash-preview-05-20": {"context_window": 1048576},
|
||||
"gemini-2.5-pro": {"context_window": 1048576},
|
||||
"gemini-2.5-flash": {"context_window": 1048576},
|
||||
}
|
||||
mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE
|
||||
|
||||
@@ -493,8 +493,8 @@ class TestRegistryIntegration:
|
||||
|
||||
mock_gemini.list_models = gemini_list_models
|
||||
mock_gemini.list_all_known_models.return_value = [
|
||||
"gemini-2.5-pro-preview-06-05",
|
||||
"gemini-2.5-flash-preview-05-20",
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
]
|
||||
|
||||
def get_provider_side_effect(provider_type):
|
||||
@@ -514,7 +514,7 @@ class TestRegistryIntegration:
|
||||
}
|
||||
|
||||
with patch.dict(
|
||||
os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini", "GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash-preview-05-20"}
|
||||
os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini", "GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}
|
||||
):
|
||||
# Clear cached restriction service
|
||||
import utils.model_restrictions
|
||||
@@ -526,8 +526,8 @@ class TestRegistryIntegration:
|
||||
# Should only include allowed models
|
||||
assert "o3-mini" in available
|
||||
assert "o3" not in available
|
||||
assert "gemini-2.5-flash-preview-05-20" in available
|
||||
assert "gemini-2.5-pro-preview-06-05" not in available
|
||||
assert "gemini-2.5-flash" in available
|
||||
assert "gemini-2.5-pro" not in available
|
||||
|
||||
|
||||
class TestShorthandRestrictions:
|
||||
@@ -552,7 +552,7 @@ class TestShorthandRestrictions:
|
||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
|
||||
# Same for Gemini - if you restrict to "flash", you can't use the full name
|
||||
assert not gemini_provider.validate_model_name("gemini-2.5-flash-preview-05-20") # Not allowed
|
||||
assert not gemini_provider.validate_model_name("gemini-2.5-flash") # Not allowed
|
||||
assert not gemini_provider.validate_model_name("pro") # Not allowed
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
|
||||
@@ -579,7 +579,7 @@ class TestShorthandRestrictions:
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"OPENAI_ALLOWED_MODELS": "mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,gemini-2.5-flash-preview-05-20"},
|
||||
{"OPENAI_ALLOWED_MODELS": "mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,gemini-2.5-flash"},
|
||||
)
|
||||
def test_both_shorthand_and_full_name_allowed(self):
|
||||
"""Test that we can allow both shorthand and full names."""
|
||||
@@ -596,7 +596,7 @@ class TestShorthandRestrictions:
|
||||
# Gemini - both flash and full name are allowed
|
||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||
assert gemini_provider.validate_model_name("flash")
|
||||
assert gemini_provider.validate_model_name("gemini-2.5-flash-preview-05-20")
|
||||
assert gemini_provider.validate_model_name("gemini-2.5-flash")
|
||||
|
||||
|
||||
class TestAutoModeWithRestrictions:
|
||||
@@ -688,7 +688,7 @@ class TestAutoModeWithRestrictions:
|
||||
|
||||
# The fallback will depend on how get_available_models handles aliases
|
||||
# For now, we accept either behavior and document it
|
||||
assert model in ["o4-mini", "gemini-2.5-flash-preview-05-20"]
|
||||
assert model in ["o4-mini", "gemini-2.5-flash"]
|
||||
finally:
|
||||
# Restore original registry state
|
||||
registry = ModelProviderRegistry()
|
||||
|
||||
Reference in New Issue
Block a user