Fixed restriction checks for OpenRouter
This commit is contained in:
@@ -45,18 +45,32 @@ class TestCustomProvider:
|
||||
|
||||
def test_get_capabilities_from_registry(self):
|
||||
"""Test get_capabilities returns registry capabilities when available."""
|
||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||
# Save original environment
|
||||
original_env = os.environ.get("OPENROUTER_ALLOWED_MODELS")
|
||||
|
||||
# Test with a model that should be in the registry (OpenRouter model) and is allowed by restrictions
|
||||
capabilities = provider.get_capabilities("o3") # o3 is in OPENROUTER_ALLOWED_MODELS
|
||||
try:
|
||||
# Clear any restrictions
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
|
||||
|
||||
assert capabilities.provider == ProviderType.OPENROUTER # o3 is an OpenRouter model (is_custom=false)
|
||||
assert capabilities.context_window > 0
|
||||
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
|
||||
|
||||
# Test with a custom model (is_custom=true)
|
||||
capabilities = provider.get_capabilities("local-llama")
|
||||
assert capabilities.provider == ProviderType.CUSTOM # local-llama has is_custom=true
|
||||
assert capabilities.context_window > 0
|
||||
# Test with a model that should be in the registry (OpenRouter model)
|
||||
capabilities = provider.get_capabilities("o3") # o3 is an OpenRouter model
|
||||
|
||||
assert capabilities.provider == ProviderType.OPENROUTER # o3 is an OpenRouter model (is_custom=false)
|
||||
assert capabilities.context_window > 0
|
||||
|
||||
# Test with a custom model (is_custom=true)
|
||||
capabilities = provider.get_capabilities("local-llama")
|
||||
assert capabilities.provider == ProviderType.CUSTOM # local-llama has is_custom=true
|
||||
assert capabilities.context_window > 0
|
||||
|
||||
finally:
|
||||
# Restore original environment
|
||||
if original_env is None:
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
|
||||
else:
|
||||
os.environ["OPENROUTER_ALLOWED_MODELS"] = original_env
|
||||
|
||||
def test_get_capabilities_generic_fallback(self):
|
||||
"""Test get_capabilities returns generic capabilities for unknown models."""
|
||||
|
||||
Reference in New Issue
Block a user