# File: my-pal-mcp-server/tests/test_custom_provider.py
# (file-viewer metadata: 322 lines, 13 KiB, Python)

"""Tests for CustomProvider functionality."""
import os
from unittest.mock import MagicMock, patch
import pytest
from providers import ModelProviderRegistry
from providers.custom import CustomProvider
from providers.shared import ProviderType
class TestCustomProvider:
    """Test CustomProvider class functionality."""

    def test_provider_initialization_with_params(self):
        """Test CustomProvider initializes correctly with explicit parameters."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        assert provider.base_url == "http://localhost:11434/v1"
        assert provider.api_key == "test-key"
        assert provider.get_provider_type() == ProviderType.CUSTOM

    def test_provider_initialization_with_env_vars(self):
        """Test CustomProvider reads CUSTOM_API_URL / CUSTOM_API_KEY when no arguments are given."""
        with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:8000/v1", "CUSTOM_API_KEY": "env-key"}):
            provider = CustomProvider()

            assert provider.base_url == "http://localhost:8000/v1"
            assert provider.api_key == "env-key"

    def test_provider_initialization_missing_url(self):
        """Test CustomProvider raises ValueError when no base URL is available."""
        # An empty CUSTOM_API_URL must not satisfy the URL requirement.
        with patch.dict(os.environ, {"CUSTOM_API_URL": ""}, clear=False):
            with pytest.raises(ValueError, match="Custom API URL must be provided"):
                CustomProvider(api_key="test-key")

    def test_validate_model_names_always_true(self):
        """Test validate_model_name accepts registry models and rejects unknown ones.

        The method name is historical: unknown model names are *not* accepted.
        """
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        # Known model should validate
        assert provider.validate_model_name("llama3.2")

        # Models absent from the registry are rejected; custom models must be
        # declared in custom_models.json to be usable.
        assert not provider.validate_model_name("unknown-model")
        assert not provider.validate_model_name("anything")

    def test_get_capabilities_from_registry(self):
        """Test get_capabilities returns registry capabilities for custom models."""
        # patch.dict snapshots os.environ and restores it on exit, so removing the
        # restriction variable here cannot leak into other tests.
        with patch.dict(os.environ):
            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)

            provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

            # OpenRouter-backed models should be handled by the OpenRouter provider
            with pytest.raises(ValueError):
                provider.get_capabilities("o3")

            # local-llama is declared with is_custom=true, so it belongs to this provider
            capabilities = provider.get_capabilities("local-llama")
            assert capabilities.provider == ProviderType.CUSTOM
            assert capabilities.context_window > 0

    def test_get_capabilities_generic_fallback(self):
        """Test get_capabilities raises ValueError for models not in the registry (no generic fallback)."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        with pytest.raises(ValueError, match="Unsupported model 'unknown-model-xyz' for provider custom"):
            provider.get_capabilities("unknown-model-xyz")

    def test_model_alias_resolution(self):
        """Test model alias resolution works correctly."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        # "llama" resolves to the OpenRouter model "meta-llama/llama-3-70b"
        resolved = provider._resolve_model_name("llama")
        assert resolved == "meta-llama/llama-3-70b"

        # Local model alias
        resolved_local = provider._resolve_model_name("local-llama")
        assert resolved_local == "llama3.2"

    def test_no_thinking_mode_support(self):
        """Custom provider capabilities default to no extended thinking mode."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        # llama3.2 is a known model that should work
        assert not provider.get_capabilities("llama3.2").supports_extended_thinking

        # Unknown models should raise error
        with pytest.raises(ValueError, match="Unsupported model 'any-model' for provider custom"):
            provider.get_capabilities("any-model")

    @patch("providers.custom.OpenAICompatibleProvider.generate_content")
    def test_generate_content_with_alias_resolution(self, mock_generate):
        """Test generate_content called with an alias delegates to the parent and returns its result.

        Only the presence of the ``model_name`` keyword is asserted, not the
        resolved value.
        """
        mock_response = MagicMock()
        mock_generate.return_value = mock_response

        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        result = provider.generate_content(
            prompt="test prompt",
            model_name="llama",  # This is an alias
            temperature=0.7,
        )

        # Verify the parent implementation was invoked exactly once
        mock_generate.assert_called_once()
        call_args = mock_generate.call_args

        # The model_name should be either resolved or passed through
        assert "model_name" in call_args.kwargs
        assert result == mock_response
class TestCustomProviderRegistration:
    """Test CustomProvider integration with ModelProviderRegistry."""

    def setup_method(self):
        """Snapshot the registry, then start each test without a CUSTOM provider."""
        # Keep whatever was registered before (e.g. by conftest.py) so teardown can
        # restore it exactly. Tests here also register OPENROUTER, which must not
        # leak into later tests.
        registry = ModelProviderRegistry()
        self._original_providers = registry._providers.copy()
        ModelProviderRegistry.clear_cache()
        ModelProviderRegistry.unregister_provider(ProviderType.CUSTOM)

    def teardown_method(self):
        """Restore the registry to the snapshot taken in setup_method."""
        registry = ModelProviderRegistry()
        ModelProviderRegistry.clear_cache()
        registry._providers.clear()
        registry._providers.update(self._original_providers)

    def test_custom_provider_factory_registration(self):
        """Test custom provider can be registered via factory function."""

        def custom_provider_factory(api_key=None):
            # api_key is accepted for call-signature compatibility but ignored.
            return CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        with patch.dict(os.environ, {"CUSTOM_API_PLACEHOLDER": "configured"}):
            ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)

            # Verify provider is available
            available = ModelProviderRegistry.get_available_providers()
            assert ProviderType.CUSTOM in available

            # Verify provider can be retrieved
            provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
            assert provider is not None
            assert isinstance(provider, CustomProvider)

    def test_dual_provider_setup(self):
        """Test both OpenRouter and Custom providers can coexist."""
        from providers.openrouter import OpenRouterProvider

        # Factory for a keyless local custom endpoint
        def custom_provider_factory(api_key=None):
            return CustomProvider(api_key="", base_url="http://localhost:11434/v1")

        with patch.dict(
            os.environ,
            {
                "OPENROUTER_API_KEY": "test-openrouter-key",
                "CUSTOM_API_PLACEHOLDER": "configured",
                "OPENROUTER_ALLOWED_MODELS": "llama,anthropic/claude-opus-4.1",
            },
            clear=True,
        ):
            # Register both providers (OPENROUTER is removed again by teardown_method)
            ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
            ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)

            # Verify both are available
            available = ModelProviderRegistry.get_available_providers()
            assert ProviderType.OPENROUTER in available
            assert ProviderType.CUSTOM in available

            # Verify both can be retrieved
            openrouter_provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
            custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
            assert openrouter_provider is not None
            assert custom_provider is not None
            assert isinstance(custom_provider, CustomProvider)

    def test_provider_priority_selection(self):
        """Test provider selection prioritizes correctly.

        The alias "llama" must be rejected by the custom provider and accepted
        by OpenRouter when no OpenRouter model restrictions are configured.
        """
        from providers.openrouter import OpenRouterProvider

        def custom_provider_factory(api_key=None):
            return CustomProvider(api_key="", base_url="http://localhost:11434/v1")

        with patch.dict(
            os.environ,
            {
                "OPENROUTER_API_KEY": "test-openrouter-key",
                "CUSTOM_API_PLACEHOLDER": "configured",
                "OPENROUTER_ALLOWED_MODELS": "",
            },
            clear=True,
        ):
            import utils.model_restrictions

            # Drop the cached restriction service so it is (presumably) rebuilt from the
            # patched environment. patch.object puts the previous instance back on exit,
            # so the rebuilt, unrestricted service does not leak into later tests.
            with patch.object(utils.model_restrictions, "_restriction_service", None):
                custom_provider = custom_provider_factory()
                openrouter_provider = OpenRouterProvider(api_key="test-openrouter-key")

                assert not custom_provider.validate_model_name("llama")
                assert openrouter_provider.validate_model_name("llama")
class TestConfigureProvidersFunction:
    """Exercise ``server.configure_providers`` under different environment setups."""

    def setup_method(self):
        """Snapshot the registered providers, then empty the registry for the test."""
        # Remember what conftest.py registered so teardown can put it back verbatim.
        self._original_providers = ModelProviderRegistry()._providers.copy()
        ModelProviderRegistry.clear_cache()
        for kind in ProviderType:
            ModelProviderRegistry.unregister_provider(kind)

    def teardown_method(self):
        """Reinstate exactly the providers captured by setup_method."""
        registry = ModelProviderRegistry()
        ModelProviderRegistry.clear_cache()
        registry._providers.clear()
        registry._providers.update(self._original_providers)

    @staticmethod
    def _available_after_configure(env):
        """Run configure_providers() with os.environ replaced by ``env``; return available providers."""
        from server import configure_providers

        with patch.dict(os.environ, env, clear=True):
            configure_providers()
            return ModelProviderRegistry.get_available_providers()

    def test_configure_providers_custom_only(self):
        """Only a custom endpoint is configured, so only CUSTOM should be available."""
        available = self._available_after_configure(
            {
                "CUSTOM_API_URL": "http://localhost:11434/v1",
                "CUSTOM_API_KEY": "",
                # Other providers explicitly unconfigured
                "GEMINI_API_KEY": "",
                "OPENAI_API_KEY": "",
                "OPENROUTER_API_KEY": "",
            }
        )

        assert ProviderType.CUSTOM in available
        assert ProviderType.OPENROUTER not in available

    def test_configure_providers_openrouter_only(self):
        """Only an OpenRouter key is configured, so CUSTOM must stay unavailable."""
        available = self._available_after_configure(
            {
                "OPENROUTER_API_KEY": "test-key",
                # Other providers explicitly unconfigured
                "GEMINI_API_KEY": "",
                "OPENAI_API_KEY": "",
                "CUSTOM_API_URL": "",
            }
        )

        assert ProviderType.OPENROUTER in available
        assert ProviderType.CUSTOM not in available

    def test_configure_providers_dual_setup(self):
        """OpenRouter and a custom endpoint configured together are both available."""
        available = self._available_after_configure(
            {
                "OPENROUTER_API_KEY": "test-openrouter-key",
                "CUSTOM_API_URL": "http://localhost:11434/v1",
                "CUSTOM_API_KEY": "",
                # Other providers explicitly unconfigured
                "GEMINI_API_KEY": "",
                "OPENAI_API_KEY": "",
            }
        )

        assert ProviderType.OPENROUTER in available
        assert ProviderType.CUSTOM in available

    def test_configure_providers_no_valid_keys(self):
        """With every provider unconfigured, configure_providers must refuse to start."""
        nothing_configured = {
            "GEMINI_API_KEY": "",
            "OPENAI_API_KEY": "",
            "OPENROUTER_API_KEY": "",
            "CUSTOM_API_URL": "",
        }

        with pytest.raises(ValueError, match="At least one API configuration is required"):
            self._available_after_configure(nothing_configured)