diff --git a/providers/registry.py b/providers/registry.py index 1e795e5..09166ad 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -183,6 +183,16 @@ class ModelProviderRegistry: continue models[model_name] = provider_type + elif provider_type == ProviderType.OPENROUTER: + # OpenRouter uses a registry system instead of SUPPORTED_MODELS + if hasattr(provider, "_registry") and provider._registry: + for model_name in provider._registry.list_models(): + # Check restrictions if enabled + if restriction_service and not restriction_service.is_allowed(provider_type, model_name): + logging.debug(f"Model {model_name} filtered by restrictions") + continue + + models[model_name] = provider_type return models diff --git a/tests/test_openrouter_provider.py b/tests/test_openrouter_provider.py index 81cdeb2..49a5fed 100644 --- a/tests/test_openrouter_provider.py +++ b/tests/test_openrouter_provider.py @@ -1,7 +1,9 @@ """Tests for OpenRouter provider.""" import os -from unittest.mock import patch +from unittest.mock import Mock, patch + +import pytest from providers.base import ProviderType from providers.openrouter import OpenRouterProvider @@ -110,6 +112,139 @@ class TestOpenRouterProvider: assert isinstance(provider, OpenRouterProvider) +class TestOpenRouterAutoMode: + """Test auto mode functionality when only OpenRouter is configured.""" + + def setup_method(self): + """Store original state before each test.""" + self.registry = ModelProviderRegistry() + self._original_providers = self.registry._providers.copy() + self._original_initialized = self.registry._initialized_providers.copy() + + self.registry._providers.clear() + self.registry._initialized_providers.clear() + + self._original_env = {} + for key in ['OPENROUTER_API_KEY', 'GEMINI_API_KEY', 'OPENAI_API_KEY', 'DEFAULT_MODEL']: + self._original_env[key] = os.environ.get(key) + + def teardown_method(self): + """Restore original state after each test.""" + self.registry._providers.clear() + self.registry._initialized_providers.clear() + self.registry._providers.update(self._original_providers) + self.registry._initialized_providers.update(self._original_initialized) + + for key, value in self._original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + @pytest.mark.no_mock_provider + def test_openrouter_only_auto_mode(self): + """Test that auto mode works when only OpenRouter is configured.""" + os.environ.pop('GEMINI_API_KEY', None) + os.environ.pop('OPENAI_API_KEY', None) + os.environ['OPENROUTER_API_KEY'] = 'test-openrouter-key' + os.environ['DEFAULT_MODEL'] = 'auto' + + mock_registry = Mock() + mock_registry.list_models.return_value = [ + 'google/gemini-2.5-flash-preview-05-20', + 'google/gemini-2.5-pro-preview-06-05', + 'openai/o3', + 'openai/o3-mini', + 'anthropic/claude-3-opus', + 'anthropic/claude-3-sonnet' + ] + + ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) + + provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER) + assert provider is not None, "OpenRouter provider should be available with API key" + provider._registry = mock_registry + + available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True) + + assert len(available_models) > 0, "Should find OpenRouter models in auto mode" + assert all(provider_type == ProviderType.OPENROUTER for provider_type in available_models.values()) + + expected_models = mock_registry.list_models() + for model in expected_models: + assert model in available_models, f"Model {model} should be available" + + @pytest.mark.no_mock_provider + def test_openrouter_with_restrictions(self): + """Test that OpenRouter respects model restrictions.""" + os.environ.pop('GEMINI_API_KEY', None) + os.environ.pop('OPENAI_API_KEY', None) + os.environ['OPENROUTER_API_KEY'] = 'test-openrouter-key' + os.environ.pop('OPENROUTER_ALLOWED_MODELS', None) + os.environ['OPENROUTER_ALLOWED_MODELS'] = 'anthropic/claude-3-opus,google/gemini-2.5-flash-preview-05-20' + os.environ['DEFAULT_MODEL'] = 'auto' + + # Force reload to pick up new environment variable + import utils.model_restrictions + utils.model_restrictions._restriction_service = None + + mock_registry = Mock() + mock_registry.list_models.return_value = [ + 'google/gemini-2.5-flash-preview-05-20', + 'google/gemini-2.5-pro-preview-06-05', + 'anthropic/claude-3-opus', + 'anthropic/claude-3-sonnet' + ] + + ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) + + provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER) + provider._registry = mock_registry + + available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True) + + assert len(available_models) > 0, "Should have some allowed models" + + expected_allowed = { + 'google/gemini-2.5-flash-preview-05-20', + 'anthropic/claude-3-opus' + } + + assert set(available_models.keys()) == expected_allowed, \ + f"Expected {expected_allowed}, but got {set(available_models.keys())}" + + @pytest.mark.no_mock_provider + def test_no_providers_fails_auto_mode(self): + """Test that auto mode fails gracefully when no providers are available.""" + os.environ.pop('GEMINI_API_KEY', None) + os.environ.pop('OPENAI_API_KEY', None) + os.environ.pop('OPENROUTER_API_KEY', None) + os.environ['DEFAULT_MODEL'] = 'auto' + + available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True) + + assert len(available_models) == 0, "Should have no models when no providers are configured" + + @pytest.mark.no_mock_provider + def test_openrouter_without_registry(self): + """Test that OpenRouter without _registry attribute doesn't crash.""" + os.environ.pop('GEMINI_API_KEY', None) + os.environ.pop('OPENAI_API_KEY', None) + os.environ['OPENROUTER_API_KEY'] = 'test-openrouter-key' + os.environ['DEFAULT_MODEL'] = 'auto' + + mock_provider_class = Mock() + mock_provider_instance = Mock(spec=['get_provider_type']) + mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER + mock_provider_class.return_value = mock_provider_instance + + ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class) + + available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True) + + assert len(available_models) == 0, "Should have no models when OpenRouter has no registry" + + class TestOpenRouterRegistry: """Test cases for OpenRouter model registry."""