feat: Azure OpenAI / Azure AI Foundry support. Models should be defined in conf/azure_models.json (or a custom path). See .env.example or the README for the environment variables. https://github.com/BeehiveInnovations/zen-mcp-server/issues/265 feat: OpenRouter / Custom Models / Azure can now each use a separate custom config path (see .env.example) refactor: Model registry class made abstract; OpenRouter / Custom Provider / Azure OpenAI now subclass it refactor: breaking change: the `is_custom` property has been removed from model_capabilities.py (and thus custom_models.json), since each provider's models are now read from a separate configuration file
322 lines
13 KiB
Python
322 lines
13 KiB
Python
"""Tests for CustomProvider functionality."""
|
|
|
|
import os
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from providers import ModelProviderRegistry
|
|
from providers.custom import CustomProvider
|
|
from providers.shared import ProviderType
|
|
|
|
|
|
class TestCustomProvider:
    """Unit tests that construct a ``CustomProvider`` directly and query it."""

    def test_provider_initialization_with_params(self):
        """Explicit ``api_key`` / ``base_url`` arguments are stored on the provider."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        assert provider.base_url == "http://localhost:11434/v1"
        assert provider.api_key == "test-key"
        assert provider.get_provider_type() == ProviderType.CUSTOM

    def test_provider_initialization_with_env_vars(self):
        """With no arguments, URL and key are taken from ``CUSTOM_API_URL`` / ``CUSTOM_API_KEY``."""
        with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:8000/v1", "CUSTOM_API_KEY": "env-key"}):
            provider = CustomProvider()

            assert provider.base_url == "http://localhost:8000/v1"
            assert provider.api_key == "env-key"

    def test_provider_initialization_missing_url(self):
        """Constructing without a base URL (argument or env) raises ``ValueError``."""
        # An empty CUSTOM_API_URL stands in for "not configured"; the rest of
        # the environment is left untouched (clear=False).
        with patch.dict(os.environ, {"CUSTOM_API_URL": ""}, clear=False):
            with pytest.raises(ValueError, match="Custom API URL must be provided"):
                CustomProvider(api_key="test-key")

    def test_validate_model_names_always_true(self):
        """Only known models validate; unknown names are rejected.

        The method name is kept for test-id stability, but the assertions below
        check that validation is *not* unconditional.
        """
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        # Known model should validate
        assert provider.validate_model_name("llama3.2")

        # For custom provider, unknown models return False when not in registry
        # This is expected behavior - custom models need to be declared in custom_models.json
        assert not provider.validate_model_name("unknown-model")
        assert not provider.validate_model_name("anything")

    def test_get_capabilities_from_registry(self):
        """``get_capabilities`` rejects ``"o3"`` and returns CUSTOM capabilities for ``"local-llama"``."""
        # patch.dict snapshots os.environ on entry and restores it on exit, so
        # dropping the restriction variable inside the block cannot leak into
        # other tests even if an assertion fails.
        with patch.dict(os.environ):
            # Clear any restrictions
            os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)

            provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

            # OpenRouter-backed models should be handled by the OpenRouter provider
            with pytest.raises(ValueError):
                provider.get_capabilities("o3")

            # Test with a custom model from the local registry
            capabilities = provider.get_capabilities("local-llama")
            assert capabilities.provider == ProviderType.CUSTOM
            assert capabilities.context_window > 0

    def test_get_capabilities_generic_fallback(self):
        """``get_capabilities`` raises ``ValueError`` for a model name it does not know."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        # Unknown models should raise ValueError when not in registry
        with pytest.raises(ValueError, match="Unsupported model 'unknown-model-xyz' for provider custom"):
            provider.get_capabilities("unknown-model-xyz")

    def test_model_alias_resolution(self):
        """``_resolve_model_name`` maps aliases to their canonical model names."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        # "llama" now resolves to "meta-llama/llama-3-70b" (the OpenRouter model)
        resolved = provider._resolve_model_name("llama")
        assert resolved == "meta-llama/llama-3-70b"

        # Test local model alias
        resolved_local = provider._resolve_model_name("local-llama")
        assert resolved_local == "llama3.2"

    def test_no_thinking_mode_support(self):
        """A known custom model reports no extended-thinking support; unknown models raise."""
        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        # llama3.2 is a known model that should work
        assert not provider.get_capabilities("llama3.2").supports_extended_thinking

        # Unknown models should raise error
        with pytest.raises(ValueError, match="Unsupported model 'any-model' for provider custom"):
            provider.get_capabilities("any-model")

    @patch("providers.custom.OpenAICompatibleProvider.generate_content")
    def test_generate_content_with_alias_resolution(self, mock_generate):
        """``generate_content`` forwards to the patched parent method and returns its result."""
        mock_response = MagicMock()
        mock_generate.return_value = mock_response

        provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        result = provider.generate_content(
            prompt="test prompt",
            model_name="llama",  # This is an alias
            temperature=0.7,
        )

        # Verify the parent method was called exactly once
        mock_generate.assert_called_once()
        call_args = mock_generate.call_args
        # The model_name should be either resolved or passed through, so only
        # the presence of the keyword is checked, not its value.
        assert "model_name" in call_args.kwargs
        assert result == mock_response
|
|
|
|
|
|
class TestCustomProviderRegistration:
    """Tests for registering ``CustomProvider`` with ``ModelProviderRegistry``."""

    def setup_method(self):
        """Clear the registry cache and drop any CUSTOM registration before each test."""
        ModelProviderRegistry.clear_cache()
        ModelProviderRegistry.unregister_provider(ProviderType.CUSTOM)

    def teardown_method(self):
        """Clear the registry cache and drop the CUSTOM registration after each test."""
        ModelProviderRegistry.clear_cache()
        ModelProviderRegistry.unregister_provider(ProviderType.CUSTOM)

    def test_custom_provider_factory_registration(self):
        """A factory callable can be registered for CUSTOM and yields a ``CustomProvider``."""

        def custom_provider_factory(api_key=None):
            return CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")

        with patch.dict(os.environ, {"CUSTOM_API_PLACEHOLDER": "configured"}):
            ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)

            # Verify provider is available
            available = ModelProviderRegistry.get_available_providers()
            assert ProviderType.CUSTOM in available

            # Verify provider can be retrieved
            provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
            assert provider is not None
            assert isinstance(provider, CustomProvider)

    def test_dual_provider_setup(self):
        """OpenRouter and Custom providers can be registered and retrieved side by side."""
        from providers.openrouter import OpenRouterProvider

        # Create factory for custom provider
        def custom_provider_factory(api_key=None):
            return CustomProvider(api_key="", base_url="http://localhost:11434/v1")

        with patch.dict(
            os.environ,
            {
                "OPENROUTER_API_KEY": "test-openrouter-key",
                "CUSTOM_API_PLACEHOLDER": "configured",
                "OPENROUTER_ALLOWED_MODELS": "llama,anthropic/claude-opus-4.1",
            },
            clear=True,
        ):
            # Register both providers
            ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
            ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)

            # Verify both are available
            available = ModelProviderRegistry.get_available_providers()
            assert ProviderType.OPENROUTER in available
            assert ProviderType.CUSTOM in available

            # Verify both can be retrieved
            openrouter_provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
            custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)

            assert openrouter_provider is not None
            assert custom_provider is not None
            assert isinstance(custom_provider, CustomProvider)

    def test_provider_priority_selection(self):
        """The alias ``"llama"`` validates on OpenRouter but not on the custom provider."""
        from providers.openrouter import OpenRouterProvider

        def custom_provider_factory(api_key=None):
            return CustomProvider(api_key="", base_url="http://localhost:11434/v1")

        with patch.dict(
            os.environ,
            {
                "OPENROUTER_API_KEY": "test-openrouter-key",
                "CUSTOM_API_PLACEHOLDER": "configured",
                "OPENROUTER_ALLOWED_MODELS": "",
            },
            clear=True,
        ):
            import utils.model_restrictions

            # Reset the module-level restriction service (presumably a cached
            # singleton - verify against utils.model_restrictions) so it is not
            # carried over from an earlier test. patch.object puts the previous
            # value back on exit, so anything stored there under this test's
            # patched environment does not leak into later tests.
            with patch.object(utils.model_restrictions, "_restriction_service", None):
                custom_provider = custom_provider_factory()
                openrouter_provider = OpenRouterProvider(api_key="test-openrouter-key")

                assert not custom_provider.validate_model_name("llama")
                assert openrouter_provider.validate_model_name("llama")
|
|
|
|
|
|
class TestConfigureProvidersFunction:
    """Exercise ``server.configure_providers`` under different environment setups."""

    def setup_method(self):
        """Snapshot the registered providers, then empty the registry for the test."""
        # Keep a copy of the current registrations so teardown can put them back
        live_registry = ModelProviderRegistry()
        self._original_providers = live_registry._providers.copy()
        ModelProviderRegistry.clear_cache()
        for kind in ProviderType:
            ModelProviderRegistry.unregister_provider(kind)

    def teardown_method(self):
        """Put back the provider registrations captured in ``setup_method``."""
        # The snapshot holds the providers that were registered in conftest.py
        live_registry = ModelProviderRegistry()
        ModelProviderRegistry.clear_cache()
        live_registry._providers.clear()
        live_registry._providers.update(self._original_providers)

    def test_configure_providers_custom_only(self):
        """Only ``CUSTOM_API_URL`` set: CUSTOM is available, OPENROUTER is not."""
        from server import configure_providers

        env = {
            "CUSTOM_API_URL": "http://localhost:11434/v1",
            "CUSTOM_API_KEY": "",
            # Blank out the other providers' keys
            "GEMINI_API_KEY": "",
            "OPENAI_API_KEY": "",
            "OPENROUTER_API_KEY": "",
        }
        with patch.dict(os.environ, env, clear=True):
            configure_providers()

            # Only the custom provider should have been registered
            registered = ModelProviderRegistry.get_available_providers()
            assert ProviderType.CUSTOM in registered
            assert ProviderType.OPENROUTER not in registered

    def test_configure_providers_openrouter_only(self):
        """Only ``OPENROUTER_API_KEY`` set: OPENROUTER is available, CUSTOM is not."""
        from server import configure_providers

        env = {
            "OPENROUTER_API_KEY": "test-key",
            # Blank out the other providers' keys
            "GEMINI_API_KEY": "",
            "OPENAI_API_KEY": "",
            "CUSTOM_API_URL": "",
        }
        with patch.dict(os.environ, env, clear=True):
            configure_providers()

            # Only the OpenRouter provider should have been registered
            registered = ModelProviderRegistry.get_available_providers()
            assert ProviderType.OPENROUTER in registered
            assert ProviderType.CUSTOM not in registered

    def test_configure_providers_dual_setup(self):
        """OpenRouter key and custom URL both set: both providers are available."""
        from server import configure_providers

        env = {
            "OPENROUTER_API_KEY": "test-openrouter-key",
            "CUSTOM_API_URL": "http://localhost:11434/v1",
            "CUSTOM_API_KEY": "",
            # Blank out the other providers' keys
            "GEMINI_API_KEY": "",
            "OPENAI_API_KEY": "",
        }
        with patch.dict(os.environ, env, clear=True):
            configure_providers()

            # Both providers should have been registered
            registered = ModelProviderRegistry.get_available_providers()
            assert ProviderType.OPENROUTER in registered
            assert ProviderType.CUSTOM in registered

    def test_configure_providers_no_valid_keys(self):
        """With every key and URL blank, ``configure_providers`` raises ``ValueError``."""
        from server import configure_providers

        env = {"GEMINI_API_KEY": "", "OPENAI_API_KEY": "", "OPENROUTER_API_KEY": "", "CUSTOM_API_URL": ""}
        with patch.dict(os.environ, env, clear=True):
            with pytest.raises(ValueError, match="At least one API configuration is required"):
                configure_providers()
|