fix: listmodels to always honor restricted models
fix: restrictions should resolve canonical names for openrouter fix: tools now correctly return restricted list by presenting model names in schema fix: tests updated to ensure these manage their expected env vars properly perf: cache model alias resolution to avoid repeated checks
This commit is contained in:
@@ -49,17 +49,32 @@ class TestModelRestrictionService:
|
||||
def test_load_multiple_models_restriction(self):
|
||||
"""Test loading multiple allowed models."""
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||
service = ModelRestrictionService()
|
||||
# Instantiate providers so alias resolution for allow-lists is available
|
||||
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Check OpenAI models
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Check Google models
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
def fake_get_provider(provider_type, force_new=False):
|
||||
mapping = {
|
||||
ProviderType.OPENAI: openai_provider,
|
||||
ProviderType.GOOGLE: gemini_provider,
|
||||
}
|
||||
return mapping.get(provider_type)
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):
|
||||
|
||||
service = ModelRestrictionService()
|
||||
|
||||
# Check OpenAI models
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
|
||||
# Check Google models
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
|
||||
def test_case_insensitive_and_whitespace_handling(self):
|
||||
"""Test that model names are case-insensitive and whitespace is trimmed."""
|
||||
@@ -111,13 +126,17 @@ class TestModelRestrictionService:
|
||||
|
||||
def test_shorthand_names_in_restrictions(self):
|
||||
"""Test that shorthand names work in restrictions."""
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o3-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||
# Instantiate providers so the registry can resolve aliases
|
||||
OpenAIModelProvider(api_key="test-key")
|
||||
GeminiModelProvider(api_key="test-key")
|
||||
|
||||
service = ModelRestrictionService()
|
||||
|
||||
# When providers check models, they pass both resolved and original names
|
||||
# OpenAI: 'mini' shorthand allows o4-mini
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "mini") # How providers actually call it
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o4-mini") # Direct check without original (for testing)
|
||||
# OpenAI: 'o4mini' shorthand allows o4-mini
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini") # How providers actually call it
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini") # Canonical should also be allowed
|
||||
|
||||
# OpenAI: o3-mini allowed directly
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
@@ -280,19 +299,25 @@ class TestProviderIntegration:
|
||||
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Test case: Only alias "flash" is allowed, not the full name
|
||||
# If parameters are in wrong order, this test will catch it
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Should allow "flash" alias
|
||||
assert provider.validate_model_name("flash")
|
||||
with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):
|
||||
|
||||
# Should allow getting capabilities for "flash"
|
||||
capabilities = provider.get_capabilities("flash")
|
||||
assert capabilities.model_name == "gemini-2.5-flash"
|
||||
# Test case: Only alias "flash" is allowed, not the full name
|
||||
# If parameters are in wrong order, this test will catch it
|
||||
|
||||
# Test the edge case: Try to use full model name when only alias is allowed
|
||||
# This should NOT be allowed - only the alias "flash" is in the restriction list
|
||||
assert not provider.validate_model_name("gemini-2.5-flash")
|
||||
# Should allow "flash" alias
|
||||
assert provider.validate_model_name("flash")
|
||||
|
||||
# Should allow getting capabilities for "flash"
|
||||
capabilities = provider.get_capabilities("flash")
|
||||
assert capabilities.model_name == "gemini-2.5-flash"
|
||||
|
||||
# Canonical form should also be allowed now that alias is on the allowlist
|
||||
assert provider.validate_model_name("gemini-2.5-flash")
|
||||
# Unrelated models remain blocked
|
||||
assert not provider.validate_model_name("pro")
|
||||
assert not provider.validate_model_name("gemini-2.5-pro")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
|
||||
def test_gemini_parameter_order_edge_case_full_name_only(self):
|
||||
@@ -570,17 +595,27 @@ class TestShorthandRestrictions:
|
||||
|
||||
# Test OpenAI provider
|
||||
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||
assert openai_provider.validate_model_name("mini") # Should work with shorthand
|
||||
# When restricting to "mini", you can't use "o4-mini" directly - this is correct behavior
|
||||
assert not openai_provider.validate_model_name("o4-mini") # Not allowed - only shorthand is allowed
|
||||
assert not openai_provider.validate_model_name("o3-mini") # Not allowed
|
||||
|
||||
# Test Gemini provider
|
||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
|
||||
# Same for Gemini - if you restrict to "flash", you can't use the full name
|
||||
assert not gemini_provider.validate_model_name("gemini-2.5-flash") # Not allowed
|
||||
assert not gemini_provider.validate_model_name("pro") # Not allowed
|
||||
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
def registry_side_effect(provider_type, force_new=False):
|
||||
mapping = {
|
||||
ProviderType.OPENAI: openai_provider,
|
||||
ProviderType.GOOGLE: gemini_provider,
|
||||
}
|
||||
return mapping.get(provider_type)
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
|
||||
assert openai_provider.validate_model_name("mini") # Should work with shorthand
|
||||
assert openai_provider.validate_model_name("gpt-5-mini") # Canonical resolved from shorthand
|
||||
assert not openai_provider.validate_model_name("o4-mini") # Unrelated model still blocked
|
||||
assert not openai_provider.validate_model_name("o3-mini")
|
||||
|
||||
# Test Gemini provider
|
||||
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
|
||||
assert gemini_provider.validate_model_name("gemini-2.5-flash") # Canonical allowed
|
||||
assert not gemini_provider.validate_model_name("pro") # Not allowed
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
|
||||
def test_multiple_shorthands_for_same_model(self):
|
||||
@@ -596,9 +631,9 @@ class TestShorthandRestrictions:
|
||||
assert openai_provider.validate_model_name("mini") # mini -> o4-mini
|
||||
assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini
|
||||
|
||||
# Resolved names work only if explicitly allowed
|
||||
# Resolved names should be allowed when their shorthands are present
|
||||
assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed
|
||||
assert not openai_provider.validate_model_name("o3-mini") # Not explicitly allowed, only shorthand
|
||||
assert openai_provider.validate_model_name("o3-mini") # Allowed via shorthand
|
||||
|
||||
# Other models should not work
|
||||
assert not openai_provider.validate_model_name("o3")
|
||||
|
||||
Reference in New Issue
Block a user