Add encouraging message about powerful models to schema in case it's not on Opus 4 or above
OPENROUTER_ALLOWED_MODELS environment variable support to further limit the models to allow from within Claude. This will put a limit on top of even the ones listed in custom_models.json
This commit is contained in:
@@ -24,10 +24,13 @@ class TestModelRestrictionService:
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash-preview-05-20")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "openai/o3")
|
||||
|
||||
# Should have no restrictions
|
||||
assert not service.has_restrictions(ProviderType.OPENAI)
|
||||
assert not service.has_restrictions(ProviderType.GOOGLE)
|
||||
assert not service.has_restrictions(ProviderType.OPENROUTER)
|
||||
|
||||
def test_load_single_model_restriction(self):
|
||||
"""Test loading a single allowed model."""
|
||||
@@ -39,8 +42,9 @@ class TestModelRestrictionService:
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
||||
|
||||
# Google should have no restrictions
|
||||
# Google and OpenRouter should have no restrictions
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro-preview-06-05")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus")
|
||||
|
||||
def test_load_multiple_models_restriction(self):
|
||||
"""Test loading multiple allowed models."""
|
||||
@@ -146,6 +150,68 @@ class TestModelRestrictionService:
|
||||
assert "o4-mimi" in caplog.text
|
||||
assert "not a recognized" in caplog.text
|
||||
|
||||
def test_openrouter_model_restrictions(self):
|
||||
"""Test OpenRouter model restrictions functionality."""
|
||||
with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet"}):
|
||||
service = ModelRestrictionService()
|
||||
|
||||
# Should only allow specified OpenRouter models
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "opus")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus", "opus") # With original name
|
||||
assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")
|
||||
assert not service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-haiku")
|
||||
assert not service.is_allowed(ProviderType.OPENROUTER, "mistral-large")
|
||||
|
||||
# Other providers should have no restrictions
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
|
||||
# Should have restrictions for OpenRouter
|
||||
assert service.has_restrictions(ProviderType.OPENROUTER)
|
||||
assert not service.has_restrictions(ProviderType.OPENAI)
|
||||
assert not service.has_restrictions(ProviderType.GOOGLE)
|
||||
|
||||
def test_openrouter_filter_models(self):
|
||||
"""Test filtering OpenRouter models based on restrictions."""
|
||||
with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,mistral"}):
|
||||
service = ModelRestrictionService()
|
||||
|
||||
models = ["opus", "sonnet", "haiku", "mistral", "llama"]
|
||||
filtered = service.filter_models(ProviderType.OPENROUTER, models)
|
||||
|
||||
assert filtered == ["opus", "mistral"]
|
||||
|
||||
def test_combined_provider_restrictions(self):
|
||||
"""Test that restrictions work correctly when set for multiple providers."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"OPENAI_ALLOWED_MODELS": "o3-mini",
|
||||
"GOOGLE_ALLOWED_MODELS": "flash",
|
||||
"OPENROUTER_ALLOWED_MODELS": "opus,sonnet",
|
||||
},
|
||||
):
|
||||
service = ModelRestrictionService()
|
||||
|
||||
# OpenAI restrictions
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
|
||||
# Google restrictions
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||
assert not service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
|
||||
# OpenRouter restrictions
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "opus")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
|
||||
assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")
|
||||
|
||||
# All providers should have restrictions
|
||||
assert service.has_restrictions(ProviderType.OPENAI)
|
||||
assert service.has_restrictions(ProviderType.GOOGLE)
|
||||
assert service.has_restrictions(ProviderType.OPENROUTER)
|
||||
|
||||
|
||||
class TestProviderIntegration:
|
||||
"""Test integration with actual providers."""
|
||||
@@ -195,6 +261,96 @@ class TestProviderIntegration:
|
||||
assert "not allowed by restriction policy" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestCustomProviderOpenRouterRestrictions:
|
||||
"""Test custom provider integration with OpenRouter restrictions."""
|
||||
|
||||
@patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet", "OPENROUTER_API_KEY": "test-key"})
|
||||
def test_custom_provider_respects_openrouter_restrictions(self):
|
||||
"""Test that custom provider respects OpenRouter restrictions for cloud models."""
|
||||
# Clear any cached restriction service
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
provider = CustomProvider(base_url="http://test.com/v1")
|
||||
|
||||
# Should validate allowed OpenRouter models (is_custom=false)
|
||||
assert provider.validate_model_name("opus")
|
||||
assert provider.validate_model_name("sonnet")
|
||||
|
||||
# Should not validate disallowed OpenRouter models
|
||||
assert not provider.validate_model_name("haiku")
|
||||
|
||||
# Should still validate custom models (is_custom=true) regardless of restrictions
|
||||
assert provider.validate_model_name("local-llama") # This has is_custom=true
|
||||
|
||||
@patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus", "OPENROUTER_API_KEY": "test-key"})
|
||||
def test_custom_provider_openrouter_capabilities_restrictions(self):
|
||||
"""Test that custom provider's get_capabilities respects OpenRouter restrictions."""
|
||||
# Clear any cached restriction service
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
provider = CustomProvider(base_url="http://test.com/v1")
|
||||
|
||||
# Should work for allowed OpenRouter model
|
||||
capabilities = provider.get_capabilities("opus")
|
||||
assert capabilities.provider == ProviderType.OPENROUTER
|
||||
|
||||
# Should raise for disallowed OpenRouter model
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
provider.get_capabilities("haiku")
|
||||
assert "not allowed by restriction policy" in str(exc_info.value)
|
||||
|
||||
# Should still work for custom models (is_custom=true)
|
||||
capabilities = provider.get_capabilities("local-llama")
|
||||
assert capabilities.provider == ProviderType.CUSTOM
|
||||
|
||||
@patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus"}, clear=False)
|
||||
def test_custom_provider_no_openrouter_key_ignores_restrictions(self):
|
||||
"""Test that when OpenRouter key is not set, cloud models are rejected regardless of restrictions."""
|
||||
# Make sure OPENROUTER_API_KEY is not set
|
||||
if "OPENROUTER_API_KEY" in os.environ:
|
||||
del os.environ["OPENROUTER_API_KEY"]
|
||||
# Clear any cached restriction service
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
provider = CustomProvider(base_url="http://test.com/v1")
|
||||
|
||||
# Should not validate OpenRouter models when key is not available
|
||||
assert not provider.validate_model_name("opus") # Even though it's in allowed list
|
||||
assert not provider.validate_model_name("haiku")
|
||||
|
||||
# Should still validate custom models
|
||||
assert provider.validate_model_name("local-llama")
|
||||
|
||||
@patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "", "OPENROUTER_API_KEY": "test-key"})
|
||||
def test_custom_provider_empty_restrictions_allows_all_openrouter(self):
|
||||
"""Test that empty OPENROUTER_ALLOWED_MODELS allows all OpenRouter models."""
|
||||
# Clear any cached restriction service
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
from providers.custom import CustomProvider
|
||||
|
||||
provider = CustomProvider(base_url="http://test.com/v1")
|
||||
|
||||
# Should validate all OpenRouter models when restrictions are empty
|
||||
assert provider.validate_model_name("opus")
|
||||
assert provider.validate_model_name("sonnet")
|
||||
assert provider.validate_model_name("haiku")
|
||||
|
||||
|
||||
class TestRegistryIntegration:
|
||||
"""Test integration with ModelProviderRegistry."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user