362 lines
14 KiB
Python
362 lines
14 KiB
Python
"""Tests for OpenCode Zen provider."""
|
|
|
|
import os
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
|
|
from providers.registry import ModelProviderRegistry
|
|
from providers.shared import ProviderType
|
|
from providers.zen import ZenProvider
|
|
|
|
|
|
class TestZenProvider:
    """Test cases for OpenCode Zen provider."""

    def test_provider_initialization(self):
        """Test Zen provider initialization (API key, base URL, friendly name)."""
        provider = ZenProvider(api_key="test-key")
        assert provider.api_key == "test-key"
        assert provider.base_url == "https://opencode.ai/zen/v1"
        assert provider.FRIENDLY_NAME == "OpenCode Zen"

    def test_get_provider_type(self):
        """Test provider type identification."""
        provider = ZenProvider(api_key="test-key")
        assert provider.get_provider_type() == ProviderType.ZEN

    def test_model_validation(self):
        """Test model validation against the Zen model registry."""
        provider = ZenProvider(api_key="test-key")

        # Zen accepts models that are in the registry
        assert provider.validate_model_name("claude-sonnet-4-5") is True
        assert provider.validate_model_name("gpt-5.1-codex") is True

        # Unknown models are rejected
        assert provider.validate_model_name("unknown-model") is False

    def test_get_capabilities(self):
        """Test capability generation for known models and rejection of unknown ones."""
        provider = ZenProvider(api_key="test-key")

        # Test with a model in the registry
        caps = provider.get_capabilities("claude-sonnet-4-5")
        assert caps.provider == ProviderType.ZEN
        assert caps.model_name == "claude-sonnet-4-5"
        assert caps.friendly_name == "OpenCode Zen (claude-sonnet-4-5)"

        # Test with a model not in registry - should raise error
        with pytest.raises(ValueError, match="Unsupported model 'unknown-model' for provider zen"):
            provider.get_capabilities("unknown-model")

    def test_model_alias_resolution(self):
        """Test model alias resolution (case-insensitive aliases, pass-through otherwise)."""
        provider = ZenProvider(api_key="test-key")

        # Test alias resolution
        assert provider._resolve_model_name("zen-sonnet") == "claude-sonnet-4-5"
        assert provider._resolve_model_name("zen-sonnet4.5") == "claude-sonnet-4-5"
        assert provider._resolve_model_name("zen-codex") == "gpt-5.1-codex"
        assert provider._resolve_model_name("zen-gpt-codex") == "gpt-5.1-codex"

        # Test case-insensitive
        assert provider._resolve_model_name("ZEN-SONNET") == "claude-sonnet-4-5"
        assert provider._resolve_model_name("Zen-Codex") == "gpt-5.1-codex"

        # Test direct model names (should pass through unchanged)
        assert provider._resolve_model_name("claude-sonnet-4-5") == "claude-sonnet-4-5"
        assert provider._resolve_model_name("gpt-5.1-codex") == "gpt-5.1-codex"

        # Test unknown models pass through
        assert provider._resolve_model_name("unknown-model") == "unknown-model"

    def test_list_models(self):
        """Test default model listing includes configured models and their aliases."""
        provider = ZenProvider(api_key="test-key")

        # Test basic model listing
        models = provider.list_models()
        assert isinstance(models, list)
        assert len(models) > 0

        # Should include our configured models
        assert "claude-sonnet-4-5" in models
        assert "gpt-5.1-codex" in models

        # Should include aliases
        assert "zen-sonnet" in models
        assert "zen-codex" in models

    def test_list_models_with_options(self):
        """Test model listing with ``include_aliases`` and ``lowercase`` options."""
        provider = ZenProvider(api_key="test-key")

        # Test without aliases
        models_no_aliases = provider.list_models(include_aliases=False)
        assert "zen-sonnet" not in models_no_aliases
        assert "claude-sonnet-4-5" in models_no_aliases

        # Test lowercase
        models_lower = provider.list_models(lowercase=True)
        assert all(model == model.lower() for model in models_lower)

    def test_registry_capabilities(self):
        """Test that registry capabilities are properly loaded."""
        provider = ZenProvider(api_key="test-key")

        # Test that we have a registry
        assert provider._registry is not None

        # Test getting all capabilities
        capabilities = provider.get_all_model_capabilities()
        assert isinstance(capabilities, dict)
        assert len(capabilities) > 0

        # Should include our configured models
        assert "claude-sonnet-4-5" in capabilities
        assert "gpt-5.1-codex" in capabilities

        # Check capability structure
        caps = capabilities["claude-sonnet-4-5"]
        assert caps.provider == ProviderType.ZEN
        assert caps.context_window == 200000
        assert caps.intelligence_score == 17

    def test_model_capabilities_lookup(self):
        """Test capability lookup for known and unknown models."""
        provider = ZenProvider(api_key="test-key")

        # Test known model
        caps = provider._lookup_capabilities("claude-sonnet-4-5")
        assert caps is not None
        assert caps.provider == ProviderType.ZEN

        # Test unknown model returns None (base class handles error)
        caps = provider._lookup_capabilities("unknown-zen-model")
        assert caps is None

    def test_zen_registration(self):
        """Test Zen can be registered and retrieved.

        ``register_provider``/``get_provider`` are called on the registry class
        itself, so their effects outlive this test. These include the provider
        instance created while ``ZEN_API_KEY`` is patched. The registry's
        provider maps are therefore snapshotted first and restored afterwards,
        keeping later tests independent of this one.
        """
        registry = ModelProviderRegistry()
        saved_providers = registry._providers.copy()
        saved_initialized = registry._initialized_providers.copy()
        try:
            with patch.dict(os.environ, {"ZEN_API_KEY": "test-key"}):
                # Clean up any existing registration
                ModelProviderRegistry.unregister_provider(ProviderType.ZEN)

                # Register the provider
                ModelProviderRegistry.register_provider(ProviderType.ZEN, ZenProvider)

                # Retrieve and verify
                provider = ModelProviderRegistry.get_provider(ProviderType.ZEN)
                assert provider is not None
                assert isinstance(provider, ZenProvider)
        finally:
            # Restore in place so any other holder of these dicts sees the original state.
            registry._providers.clear()
            registry._providers.update(saved_providers)
            registry._initialized_providers.clear()
            registry._initialized_providers.update(saved_initialized)
|
|
|
|
|
|
class TestZenAutoMode:
    """Auto-mode behaviour when OpenCode Zen is the only configured provider."""

    def setup_method(self):
        """Snapshot registry and environment state, then start from an empty registry."""
        self.registry = ModelProviderRegistry()
        self._saved_providers = self.registry._providers.copy()
        self._saved_initialized = self.registry._initialized_providers.copy()

        # Each test begins with no registered or instantiated providers.
        self.registry._providers.clear()
        self.registry._initialized_providers.clear()

        # Remember the env vars the tests may touch (None means "was unset").
        tracked_keys = ("ZEN_API_KEY", "GEMINI_API_KEY", "OPENAI_API_KEY", "DEFAULT_MODEL")
        self._saved_env = {name: os.environ.get(name) for name in tracked_keys}

    def teardown_method(self):
        """Put the registry and environment back exactly as setup_method found them."""
        for live, snapshot in (
            (self.registry._providers, self._saved_providers),
            (self.registry._initialized_providers, self._saved_initialized),
        ):
            live.clear()
            live.update(snapshot)

        for name, previous in self._saved_env.items():
            if previous is not None:
                os.environ[name] = previous
            else:
                os.environ.pop(name, None)

    @pytest.mark.no_mock_provider
    def test_zen_only_auto_mode(self):
        """Every Zen model should be offered in auto mode when only ZEN_API_KEY is set."""
        for competing_key in ("GEMINI_API_KEY", "OPENAI_API_KEY"):
            os.environ.pop(competing_key, None)
        os.environ.update({"ZEN_API_KEY": "test-zen-key", "DEFAULT_MODEL": "auto"})

        zen_models = [
            "claude-sonnet-4-5",
            "claude-haiku-4-5",
            "gpt-5.1-codex",
            "gemini-3-pro",
            "glm-4.6",
        ]

        def fake_resolve(model_name):
            # Known names get a minimal ModelCapabilities stand-in; others resolve to None.
            if model_name not in zen_models:
                return None
            config = Mock()
            config.provider = ProviderType.ZEN
            config.aliases = []
            config.get_effective_capability_rank = Mock(return_value=50)
            return config

        fake_registry = Mock()
        fake_registry.list_models.return_value = zen_models
        fake_registry.resolve.side_effect = fake_resolve

        ModelProviderRegistry.register_provider(ProviderType.ZEN, ZenProvider)

        zen_provider = ModelProviderRegistry.get_provider(ProviderType.ZEN)
        assert zen_provider is not None, "Zen provider should be available with API key"
        # Swap in the fake model registry so the model list is fully controlled.
        zen_provider._registry = fake_registry

        offered = ModelProviderRegistry.get_available_models(respect_restrictions=True)

        assert len(offered) > 0, "Should find Zen models in auto mode"
        non_zen = [name for name, ptype in offered.items() if ptype != ProviderType.ZEN]
        assert not non_zen

        for model in zen_models:
            assert model in offered, f"Model {model} should be available"
|
|
|
|
|
|
class TestZenIntegration:
    """Integration tests for Zen provider with server components."""

    def test_zen_provider_in_server_init(self):
        """Test that the Zen provider can be registered and instantiated from its API key.

        This checks the provider side of Zen configuration without starting
        the server. Registration happens on the registry class and outlives
        the test, so the registry's provider maps are snapshotted first and
        restored afterwards.
        """
        registry = ModelProviderRegistry()
        saved_providers = registry._providers.copy()
        saved_initialized = registry._initialized_providers.copy()
        try:
            with patch.dict(os.environ, {"ZEN_API_KEY": "test-integration-key"}):
                # Verify Zen provider can be registered and instantiated
                ModelProviderRegistry.register_provider(ProviderType.ZEN, ZenProvider)
                provider = ModelProviderRegistry.get_provider(ProviderType.ZEN)
                assert provider is not None
                assert isinstance(provider, ZenProvider)
        finally:
            # Undo the registration and any cached instance created above.
            registry._providers.clear()
            registry._providers.update(saved_providers)
            registry._initialized_providers.clear()
            registry._initialized_providers.update(saved_initialized)

    def test_zen_config_loading(self):
        """Test that Zen configuration loads properly in integration context."""
        with patch.dict(os.environ, {"ZEN_API_KEY": "test-config-key"}):
            from providers.registries.zen import ZenModelRegistry

            # Test registry loads configuration
            registry = ZenModelRegistry()
            models = registry.list_models()
            aliases = registry.list_aliases()

            assert len(models) > 0, "Should load models from zen_models.json"
            assert len(aliases) > 0, "Should load aliases from zen_models.json"

            # Verify specific models are loaded
            assert "claude-sonnet-4-5" in models
            assert "zen-sonnet" in aliases

    def test_zen_provider_priority(self):
        """Test that Zen is ranked ahead of OpenRouter in provider priority.

        Only the Zen/OpenRouter ordering is checked here; placement relative
        to the native APIs is not asserted.
        """
        priority_order = ModelProviderRegistry.PROVIDER_PRIORITY_ORDER
        zen_index = priority_order.index(ProviderType.ZEN)
        openrouter_index = priority_order.index(ProviderType.OPENROUTER)

        # Zen should come before OpenRouter in priority
        assert zen_index < openrouter_index, "Zen should have higher priority than OpenRouter"
|
|
|
|
|
|
class TestZenAPIMocking:
    """Test API interactions with mocked OpenAI SDK.

    Every test patches ``provider.client`` (not ``provider._client``).
    Evaluating the public ``client`` attribute initializes the client before
    it is patched; the private attribute may not be populated yet.
    """

    def test_chat_completion_mock(self):
        """Test chat completion with mocked API response."""
        provider = ZenProvider(api_key="test-key")

        # Mock the OpenAI client and response
        mock_response = Mock()
        mock_response.choices = [Mock()]
        mock_response.choices[0].message.content = "Mocked response from Zen"
        mock_response.usage = Mock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 20

        # Evaluating ``provider.client`` in the patch target initializes the client first.
        with patch.object(provider.client.chat.completions, "create", return_value=mock_response):
            response = provider.complete(
                model="claude-sonnet-4-5", messages=[{"role": "user", "content": "Hello"}], temperature=0.7
            )

        assert response.content == "Mocked response from Zen"

    def test_streaming_completion_mock(self):
        """Test streaming completion with mocked API."""
        provider = ZenProvider(api_key="test-key")

        # Mock streaming response
        mock_chunk1 = Mock()
        mock_chunk1.choices = [Mock()]
        mock_chunk1.choices[0].delta.content = "Hello"
        mock_chunk1.choices[0].finish_reason = None

        mock_chunk2 = Mock()
        mock_chunk2.choices = [Mock()]
        mock_chunk2.choices[0].delta.content = " world!"
        mock_chunk2.choices[0].finish_reason = "stop"

        mock_stream = [mock_chunk1, mock_chunk2]

        with patch.object(provider.client.chat.completions, "create", return_value=mock_stream):
            # Test streaming completion
            stream = provider.complete_stream(
                model="gpt-5.1-codex",
                messages=[{"role": "user", "content": "Say hello"}],
            )

            # Consume inside the patch in case the provider streams lazily.
            chunks = list(stream)
            assert len(chunks) == 2
            assert chunks[0].content == "Hello"
            assert chunks[1].content == " world!"

    def test_api_error_handling(self):
        """Test that an API failure from the SDK propagates as ``APIError``."""
        provider = ZenProvider(api_key="test-key")

        # Mock API error
        from openai import APIError

        # APIError takes the request positionally and ``body`` as a keyword.
        api_error = APIError("Mock API error", request=Mock(), body="error details")

        with patch.object(provider.client.chat.completions, "create", side_effect=api_error):
            with pytest.raises(APIError):
                provider.complete(model="claude-sonnet-4-5", messages=[{"role": "user", "content": "Test"}])

    def test_invalid_model_error(self):
        """Test error handling for invalid models."""
        provider = ZenProvider(api_key="test-key")

        with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider zen"):
            provider.get_capabilities("invalid-model")

    def test_authentication_error(self):
        """Test that an authentication failure from the SDK propagates as ``AuthenticationError``."""
        provider = ZenProvider(api_key="invalid-key")

        # Mock authentication error
        from openai import AuthenticationError

        # AuthenticationError is an HTTP status error. Its constructor takes a
        # keyword-only ``response`` (not ``request``) and reads
        # ``response.request``, ``response.status_code`` and
        # ``response.headers`` from it.
        auth_response = Mock(status_code=401, headers={})
        auth_error = AuthenticationError("Invalid API key", response=auth_response, body="auth failed")

        with patch.object(provider.client.chat.completions, "create", side_effect=auth_error):
            with pytest.raises(AuthenticationError):
                provider.complete(model="claude-sonnet-4-5", messages=[{"role": "user", "content": "Test"}])
|