Merge remote-tracking branch 'origin/main'
This commit is contained in:
@@ -183,6 +183,16 @@ class ModelProviderRegistry:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
models[model_name] = provider_type
|
models[model_name] = provider_type
|
||||||
|
elif provider_type == ProviderType.OPENROUTER:
|
||||||
|
# OpenRouter uses a registry system instead of SUPPORTED_MODELS
|
||||||
|
if hasattr(provider, "_registry") and provider._registry:
|
||||||
|
for model_name in provider._registry.list_models():
|
||||||
|
# Check restrictions if enabled
|
||||||
|
if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
|
||||||
|
logging.debug(f"Model {model_name} filtered by restrictions")
|
||||||
|
continue
|
||||||
|
|
||||||
|
models[model_name] = provider_type
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
"""Tests for OpenRouter provider."""
|
"""Tests for OpenRouter provider."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
from providers.base import ProviderType
|
||||||
from providers.openrouter import OpenRouterProvider
|
from providers.openrouter import OpenRouterProvider
|
||||||
@@ -110,6 +112,138 @@ class TestOpenRouterProvider:
|
|||||||
assert isinstance(provider, OpenRouterProvider)
|
assert isinstance(provider, OpenRouterProvider)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenRouterAutoMode:
|
||||||
|
"""Test auto mode functionality when only OpenRouter is configured."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Store original state before each test."""
|
||||||
|
self.registry = ModelProviderRegistry()
|
||||||
|
self._original_providers = self.registry._providers.copy()
|
||||||
|
self._original_initialized = self.registry._initialized_providers.copy()
|
||||||
|
|
||||||
|
self.registry._providers.clear()
|
||||||
|
self.registry._initialized_providers.clear()
|
||||||
|
|
||||||
|
self._original_env = {}
|
||||||
|
for key in ["OPENROUTER_API_KEY", "GEMINI_API_KEY", "OPENAI_API_KEY", "DEFAULT_MODEL"]:
|
||||||
|
self._original_env[key] = os.environ.get(key)
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Restore original state after each test."""
|
||||||
|
self.registry._providers.clear()
|
||||||
|
self.registry._initialized_providers.clear()
|
||||||
|
self.registry._providers.update(self._original_providers)
|
||||||
|
self.registry._initialized_providers.update(self._original_initialized)
|
||||||
|
|
||||||
|
for key, value in self._original_env.items():
|
||||||
|
if value is None:
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
else:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
@pytest.mark.no_mock_provider
|
||||||
|
def test_openrouter_only_auto_mode(self):
|
||||||
|
"""Test that auto mode works when only OpenRouter is configured."""
|
||||||
|
os.environ.pop("GEMINI_API_KEY", None)
|
||||||
|
os.environ.pop("OPENAI_API_KEY", None)
|
||||||
|
os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
|
||||||
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
|
||||||
|
mock_registry = Mock()
|
||||||
|
mock_registry.list_models.return_value = [
|
||||||
|
"google/gemini-2.5-flash-preview-05-20",
|
||||||
|
"google/gemini-2.5-pro-preview-06-05",
|
||||||
|
"openai/o3",
|
||||||
|
"openai/o3-mini",
|
||||||
|
"anthropic/claude-3-opus",
|
||||||
|
"anthropic/claude-3-sonnet",
|
||||||
|
]
|
||||||
|
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||||
|
|
||||||
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||||
|
assert provider is not None, "OpenRouter provider should be available with API key"
|
||||||
|
provider._registry = mock_registry
|
||||||
|
|
||||||
|
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
|
||||||
|
assert len(available_models) > 0, "Should find OpenRouter models in auto mode"
|
||||||
|
assert all(provider_type == ProviderType.OPENROUTER for provider_type in available_models.values())
|
||||||
|
|
||||||
|
expected_models = mock_registry.list_models()
|
||||||
|
for model in expected_models:
|
||||||
|
assert model in available_models, f"Model {model} should be available"
|
||||||
|
|
||||||
|
@pytest.mark.no_mock_provider
|
||||||
|
def test_openrouter_with_restrictions(self):
|
||||||
|
"""Test that OpenRouter respects model restrictions."""
|
||||||
|
os.environ.pop("GEMINI_API_KEY", None)
|
||||||
|
os.environ.pop("OPENAI_API_KEY", None)
|
||||||
|
os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
|
||||||
|
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
|
||||||
|
os.environ["OPENROUTER_ALLOWED_MODELS"] = "anthropic/claude-3-opus,google/gemini-2.5-flash-preview-05-20"
|
||||||
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
|
||||||
|
# Force reload to pick up new environment variable
|
||||||
|
import utils.model_restrictions
|
||||||
|
|
||||||
|
utils.model_restrictions._restriction_service = None
|
||||||
|
|
||||||
|
mock_registry = Mock()
|
||||||
|
mock_registry.list_models.return_value = [
|
||||||
|
"google/gemini-2.5-flash-preview-05-20",
|
||||||
|
"google/gemini-2.5-pro-preview-06-05",
|
||||||
|
"anthropic/claude-3-opus",
|
||||||
|
"anthropic/claude-3-sonnet",
|
||||||
|
]
|
||||||
|
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||||
|
|
||||||
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||||
|
provider._registry = mock_registry
|
||||||
|
|
||||||
|
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
|
||||||
|
assert len(available_models) > 0, "Should have some allowed models"
|
||||||
|
|
||||||
|
expected_allowed = {"google/gemini-2.5-flash-preview-05-20", "anthropic/claude-3-opus"}
|
||||||
|
|
||||||
|
assert (
|
||||||
|
set(available_models.keys()) == expected_allowed
|
||||||
|
), f"Expected {expected_allowed}, but got {set(available_models.keys())}"
|
||||||
|
|
||||||
|
@pytest.mark.no_mock_provider
|
||||||
|
def test_no_providers_fails_auto_mode(self):
|
||||||
|
"""Test that auto mode fails gracefully when no providers are available."""
|
||||||
|
os.environ.pop("GEMINI_API_KEY", None)
|
||||||
|
os.environ.pop("OPENAI_API_KEY", None)
|
||||||
|
os.environ.pop("OPENROUTER_API_KEY", None)
|
||||||
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
|
||||||
|
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
|
||||||
|
assert len(available_models) == 0, "Should have no models when no providers are configured"
|
||||||
|
|
||||||
|
@pytest.mark.no_mock_provider
|
||||||
|
def test_openrouter_without_registry(self):
|
||||||
|
"""Test that OpenRouter without _registry attribute doesn't crash."""
|
||||||
|
os.environ.pop("GEMINI_API_KEY", None)
|
||||||
|
os.environ.pop("OPENAI_API_KEY", None)
|
||||||
|
os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
|
||||||
|
os.environ["DEFAULT_MODEL"] = "auto"
|
||||||
|
|
||||||
|
mock_provider_class = Mock()
|
||||||
|
mock_provider_instance = Mock(spec=["get_provider_type"])
|
||||||
|
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
|
||||||
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)
|
||||||
|
|
||||||
|
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
|
||||||
|
assert len(available_models) == 0, "Should have no models when OpenRouter has no registry"
|
||||||
|
|
||||||
|
|
||||||
class TestOpenRouterRegistry:
|
class TestOpenRouterRegistry:
|
||||||
"""Test cases for OpenRouter model registry."""
|
"""Test cases for OpenRouter model registry."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user