added opencode zen as provider
This commit is contained in:
@@ -82,9 +82,20 @@ def project_path(tmp_path):
|
||||
return test_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zen_provider():
|
||||
"""
|
||||
Provides a Zen provider instance for testing.
|
||||
Uses dummy API key for isolated testing.
|
||||
"""
|
||||
from providers.zen import ZenProvider
|
||||
|
||||
return ZenProvider(api_key="test-zen-key")
|
||||
|
||||
|
||||
def _set_dummy_keys_if_missing():
|
||||
"""Set dummy API keys only when they are completely absent."""
|
||||
for var in ("GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"):
|
||||
for var in ("GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "ZEN_API_KEY"):
|
||||
if not os.environ.get(var):
|
||||
os.environ[var] = "dummy-key-for-tests"
|
||||
|
||||
|
||||
166
tests/test_zen_model_registry.py
Normal file
166
tests/test_zen_model_registry.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Tests for OpenCode Zen model registry functionality."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from providers.registries.zen import ZenModelRegistry
|
||||
from providers.shared import ProviderType
|
||||
|
||||
|
||||
class TestZenModelRegistry:
|
||||
"""Test cases for Zen model registry."""
|
||||
|
||||
def test_registry_initialization(self):
|
||||
"""Test registry initializes with default config."""
|
||||
registry = ZenModelRegistry()
|
||||
|
||||
# Should load models from default location
|
||||
assert len(registry.list_models()) > 0
|
||||
assert len(registry.list_aliases()) > 0
|
||||
|
||||
# Should include our configured models
|
||||
assert "claude-sonnet-4-5" in registry.list_models()
|
||||
assert "gpt-5.1-codex" in registry.list_models()
|
||||
|
||||
def test_custom_config_path(self):
|
||||
"""Test registry with custom config path."""
|
||||
# Create temporary config
|
||||
config_data = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "test/zen-model-1",
|
||||
"aliases": ["zen-test1", "zt1"],
|
||||
"context_window": 4096,
|
||||
"max_output_tokens": 2048,
|
||||
"intelligence_score": 15,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
registry = ZenModelRegistry(config_path=temp_path)
|
||||
assert len(registry.list_models()) == 1
|
||||
assert "test/zen-model-1" in registry.list_models()
|
||||
assert "zen-test1" in registry.list_aliases()
|
||||
assert "zt1" in registry.list_aliases()
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_get_capabilities(self):
|
||||
"""Test capability retrieval."""
|
||||
registry = ZenModelRegistry()
|
||||
|
||||
# Test getting capabilities for a known model
|
||||
caps = registry.get_capabilities("claude-sonnet-4-5")
|
||||
assert caps is not None
|
||||
assert caps.provider == ProviderType.ZEN
|
||||
assert caps.model_name == "claude-sonnet-4-5"
|
||||
assert caps.friendly_name == "OpenCode Zen (claude-sonnet-4-5)"
|
||||
assert caps.context_window == 200000
|
||||
assert caps.intelligence_score == 17
|
||||
|
||||
# Test getting capabilities for unknown model
|
||||
caps = registry.get_capabilities("unknown-model")
|
||||
assert caps is None
|
||||
|
||||
def test_resolve_model(self):
|
||||
"""Test model resolution with aliases."""
|
||||
registry = ZenModelRegistry()
|
||||
|
||||
# Test resolving a direct model name
|
||||
config = registry.resolve("claude-sonnet-4-5")
|
||||
assert config is not None
|
||||
assert config.model_name == "claude-sonnet-4-5"
|
||||
|
||||
# Test resolving an alias
|
||||
config = registry.resolve("zen-sonnet")
|
||||
assert config is not None
|
||||
assert config.model_name == "claude-sonnet-4-5"
|
||||
|
||||
# Test resolving unknown model
|
||||
config = registry.resolve("unknown-model")
|
||||
assert config is None
|
||||
|
||||
def test_list_aliases(self):
|
||||
"""Test alias listing."""
|
||||
registry = ZenModelRegistry()
|
||||
|
||||
aliases = registry.list_aliases()
|
||||
assert isinstance(aliases, list)
|
||||
assert len(aliases) > 0
|
||||
|
||||
# Should include our configured aliases
|
||||
assert "zen-sonnet" in aliases
|
||||
assert "zen-codex" in aliases
|
||||
assert "zen-gemini" in aliases
|
||||
|
||||
def test_environment_config_path(self):
|
||||
"""Test registry respects environment variable for config path."""
|
||||
config_data = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "env/test-model",
|
||||
"aliases": ["env-test"],
|
||||
"context_window": 8192,
|
||||
"max_output_tokens": 4096,
|
||||
"intelligence_score": 10,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
with patch.dict(os.environ, {"ZEN_MODELS_CONFIG_PATH": temp_path}):
|
||||
registry = ZenModelRegistry()
|
||||
assert "env/test-model" in registry.list_models()
|
||||
assert "env-test" in registry.list_aliases()
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_malformed_config(self):
|
||||
"""Test registry handles malformed config gracefully."""
|
||||
malformed_config = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "test/bad-model",
|
||||
# Missing required fields
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(malformed_config, f)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
registry = ZenModelRegistry(config_path=temp_path)
|
||||
# Should still initialize but model may not load properly
|
||||
# This tests error handling in config loading
|
||||
registry.list_models() # Test that this doesn't crash
|
||||
# May or may not include the malformed model depending on validation
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_empty_config(self):
|
||||
"""Test registry with empty config."""
|
||||
empty_config = {"models": []}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(empty_config, f)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
registry = ZenModelRegistry(config_path=temp_path)
|
||||
assert len(registry.list_models()) == 0
|
||||
assert len(registry.list_aliases()) == 0
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
361
tests/test_zen_provider.py
Normal file
361
tests/test_zen_provider.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""Tests for OpenCode Zen provider."""
|
||||
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from providers.zen import ZenProvider
|
||||
|
||||
|
||||
class TestZenProvider:
|
||||
"""Test cases for OpenCode Zen provider."""
|
||||
|
||||
def test_provider_initialization(self):
|
||||
"""Test Zen provider initialization."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.base_url == "https://opencode.ai/zen/v1"
|
||||
assert provider.FRIENDLY_NAME == "OpenCode Zen"
|
||||
|
||||
def test_get_provider_type(self):
|
||||
"""Test provider type identification."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
assert provider.get_provider_type() == ProviderType.ZEN
|
||||
|
||||
def test_model_validation(self):
|
||||
"""Test model validation."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Zen accepts models that are in the registry
|
||||
assert provider.validate_model_name("claude-sonnet-4-5") is True
|
||||
assert provider.validate_model_name("gpt-5.1-codex") is True
|
||||
|
||||
# Unknown models are rejected
|
||||
assert provider.validate_model_name("unknown-model") is False
|
||||
|
||||
def test_get_capabilities(self):
|
||||
"""Test capability generation."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Test with a model in the registry
|
||||
caps = provider.get_capabilities("claude-sonnet-4-5")
|
||||
assert caps.provider == ProviderType.ZEN
|
||||
assert caps.model_name == "claude-sonnet-4-5"
|
||||
assert caps.friendly_name == "OpenCode Zen (claude-sonnet-4-5)"
|
||||
|
||||
# Test with a model not in registry - should raise error
|
||||
with pytest.raises(ValueError, match="Unsupported model 'unknown-model' for provider zen"):
|
||||
provider.get_capabilities("unknown-model")
|
||||
|
||||
def test_model_alias_resolution(self):
|
||||
"""Test model alias resolution."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Test alias resolution
|
||||
assert provider._resolve_model_name("zen-sonnet") == "claude-sonnet-4-5"
|
||||
assert provider._resolve_model_name("zen-sonnet4.5") == "claude-sonnet-4-5"
|
||||
assert provider._resolve_model_name("zen-codex") == "gpt-5.1-codex"
|
||||
assert provider._resolve_model_name("zen-gpt-codex") == "gpt-5.1-codex"
|
||||
|
||||
# Test case-insensitive
|
||||
assert provider._resolve_model_name("ZEN-SONNET") == "claude-sonnet-4-5"
|
||||
assert provider._resolve_model_name("Zen-Codex") == "gpt-5.1-codex"
|
||||
|
||||
# Test direct model names (should pass through unchanged)
|
||||
assert provider._resolve_model_name("claude-sonnet-4-5") == "claude-sonnet-4-5"
|
||||
assert provider._resolve_model_name("gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
|
||||
# Test unknown models pass through
|
||||
assert provider._resolve_model_name("unknown-model") == "unknown-model"
|
||||
|
||||
def test_list_models(self):
|
||||
"""Test model listing with various options."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Test basic model listing
|
||||
models = provider.list_models()
|
||||
assert isinstance(models, list)
|
||||
assert len(models) > 0
|
||||
|
||||
# Should include our configured models
|
||||
assert "claude-sonnet-4-5" in models
|
||||
assert "gpt-5.1-codex" in models
|
||||
|
||||
# Should include aliases
|
||||
assert "zen-sonnet" in models
|
||||
assert "zen-codex" in models
|
||||
|
||||
def test_list_models_with_options(self):
|
||||
"""Test model listing with different options."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Test without aliases
|
||||
models_no_aliases = provider.list_models(include_aliases=False)
|
||||
assert "zen-sonnet" not in models_no_aliases
|
||||
assert "claude-sonnet-4-5" in models_no_aliases
|
||||
|
||||
# Test lowercase
|
||||
models_lower = provider.list_models(lowercase=True)
|
||||
assert all(model == model.lower() for model in models_lower)
|
||||
|
||||
def test_registry_capabilities(self):
|
||||
"""Test that registry capabilities are properly loaded."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Test that we have a registry
|
||||
assert provider._registry is not None
|
||||
|
||||
# Test getting all capabilities
|
||||
capabilities = provider.get_all_model_capabilities()
|
||||
assert isinstance(capabilities, dict)
|
||||
assert len(capabilities) > 0
|
||||
|
||||
# Should include our configured models
|
||||
assert "claude-sonnet-4-5" in capabilities
|
||||
assert "gpt-5.1-codex" in capabilities
|
||||
|
||||
# Check capability structure
|
||||
caps = capabilities["claude-sonnet-4-5"]
|
||||
assert caps.provider == ProviderType.ZEN
|
||||
assert caps.context_window == 200000
|
||||
assert caps.intelligence_score == 17
|
||||
|
||||
def test_model_capabilities_lookup(self):
|
||||
"""Test capability lookup for known and unknown models."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Test known model
|
||||
caps = provider._lookup_capabilities("claude-sonnet-4-5")
|
||||
assert caps is not None
|
||||
assert caps.provider == ProviderType.ZEN
|
||||
|
||||
# Test unknown model returns None (base class handles error)
|
||||
caps = provider._lookup_capabilities("unknown-zen-model")
|
||||
assert caps is None
|
||||
|
||||
def test_zen_registration(self):
|
||||
"""Test Zen can be registered and retrieved."""
|
||||
with patch.dict(os.environ, {"ZEN_API_KEY": "test-key"}):
|
||||
# Clean up any existing registration
|
||||
ModelProviderRegistry.unregister_provider(ProviderType.ZEN)
|
||||
|
||||
# Register the provider
|
||||
ModelProviderRegistry.register_provider(ProviderType.ZEN, ZenProvider)
|
||||
|
||||
# Retrieve and verify
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.ZEN)
|
||||
assert provider is not None
|
||||
assert isinstance(provider, ZenProvider)
|
||||
|
||||
|
||||
class TestZenAutoMode:
|
||||
"""Test auto mode functionality when only Zen is configured."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Store original state before each test."""
|
||||
self.registry = ModelProviderRegistry()
|
||||
self._original_providers = self.registry._providers.copy()
|
||||
self._original_initialized = self.registry._initialized_providers.copy()
|
||||
|
||||
# Clear the registry state for this test
|
||||
self.registry._providers.clear()
|
||||
self.registry._initialized_providers.clear()
|
||||
|
||||
self._original_env = {}
|
||||
for key in ["ZEN_API_KEY", "GEMINI_API_KEY", "OPENAI_API_KEY", "DEFAULT_MODEL"]:
|
||||
self._original_env[key] = os.environ.get(key)
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original state after each test."""
|
||||
self.registry._providers.clear()
|
||||
self.registry._initialized_providers.clear()
|
||||
self.registry._providers.update(self._original_providers)
|
||||
self.registry._initialized_providers.update(self._original_initialized)
|
||||
|
||||
for key, value in self._original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
def test_zen_only_auto_mode(self):
|
||||
"""Test that auto mode works when only Zen is configured."""
|
||||
os.environ.pop("GEMINI_API_KEY", None)
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ["ZEN_API_KEY"] = "test-zen-key"
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
|
||||
mock_registry = Mock()
|
||||
model_names = [
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
"gpt-5.1-codex",
|
||||
"gemini-3-pro",
|
||||
"glm-4.6",
|
||||
]
|
||||
mock_registry.list_models.return_value = model_names
|
||||
|
||||
# Mock resolve to return a ModelCapabilities-like object for each model
|
||||
def mock_resolve(model_name):
|
||||
if model_name in model_names:
|
||||
mock_config = Mock()
|
||||
mock_config.provider = ProviderType.ZEN
|
||||
mock_config.aliases = [] # Empty list of aliases
|
||||
mock_config.get_effective_capability_rank = Mock(return_value=50) # Add ranking method
|
||||
return mock_config
|
||||
return None
|
||||
|
||||
mock_registry.resolve.side_effect = mock_resolve
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.ZEN, ZenProvider)
|
||||
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.ZEN)
|
||||
assert provider is not None, "Zen provider should be available with API key"
|
||||
provider._registry = mock_registry
|
||||
|
||||
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||
|
||||
assert len(available_models) > 0, "Should find Zen models in auto mode"
|
||||
assert all(provider_type == ProviderType.ZEN for provider_type in available_models.values())
|
||||
|
||||
for model in model_names:
|
||||
assert model in available_models, f"Model {model} should be available"
|
||||
|
||||
|
||||
class TestZenIntegration:
|
||||
"""Integration tests for Zen provider with server components."""
|
||||
|
||||
def test_zen_provider_in_server_init(self):
|
||||
"""Test that Zen provider is properly handled during server initialization."""
|
||||
# This test verifies that the server can handle Zen provider configuration
|
||||
# without actual server startup
|
||||
with patch.dict(os.environ, {"ZEN_API_KEY": "test-integration-key"}):
|
||||
# Import server module to trigger provider setup
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Verify Zen provider can be registered
|
||||
ModelProviderRegistry.register_provider(ProviderType.ZEN, ZenProvider)
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.ZEN)
|
||||
assert provider is not None
|
||||
assert isinstance(provider, ZenProvider)
|
||||
|
||||
def test_zen_config_loading(self):
|
||||
"""Test that Zen configuration loads properly in integration context."""
|
||||
with patch.dict(os.environ, {"ZEN_API_KEY": "test-config-key"}):
|
||||
from providers.registries.zen import ZenModelRegistry
|
||||
|
||||
# Test registry loads configuration
|
||||
registry = ZenModelRegistry()
|
||||
models = registry.list_models()
|
||||
aliases = registry.list_aliases()
|
||||
|
||||
assert len(models) > 0, "Should load models from zen_models.json"
|
||||
assert len(aliases) > 0, "Should load aliases from zen_models.json"
|
||||
|
||||
# Verify specific models are loaded
|
||||
assert "claude-sonnet-4-5" in models
|
||||
assert "zen-sonnet" in aliases
|
||||
|
||||
def test_zen_provider_priority(self):
|
||||
"""Test that Zen provider follows correct priority order."""
|
||||
# Zen should be prioritized after native APIs but before OpenRouter
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
priority_order = ModelProviderRegistry.PROVIDER_PRIORITY_ORDER
|
||||
zen_index = priority_order.index(ProviderType.ZEN)
|
||||
openrouter_index = priority_order.index(ProviderType.OPENROUTER)
|
||||
|
||||
# Zen should come before OpenRouter in priority
|
||||
assert zen_index < openrouter_index, "Zen should have higher priority than OpenRouter"
|
||||
|
||||
|
||||
class TestZenAPIMocking:
|
||||
"""Test API interactions with mocked OpenAI SDK."""
|
||||
|
||||
def test_chat_completion_mock(self):
|
||||
"""Test chat completion with mocked API response."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Mock the OpenAI client and response
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock()]
|
||||
mock_response.choices[0].message.content = "Mocked response from Zen"
|
||||
mock_response.usage = Mock()
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 20
|
||||
|
||||
with patch.object(provider.client.chat.completions, "create", return_value=mock_response):
|
||||
# Test the completion method - this will initialize the client
|
||||
response = provider.complete(
|
||||
model="claude-sonnet-4-5", messages=[{"role": "user", "content": "Hello"}], temperature=0.7
|
||||
)
|
||||
|
||||
assert response.content == "Mocked response from Zen"
|
||||
|
||||
def test_streaming_completion_mock(self):
|
||||
"""Test streaming completion with mocked API."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Mock streaming response
|
||||
mock_chunk1 = Mock()
|
||||
mock_chunk1.choices = [Mock()]
|
||||
mock_chunk1.choices[0].delta.content = "Hello"
|
||||
mock_chunk1.choices[0].finish_reason = None
|
||||
|
||||
mock_chunk2 = Mock()
|
||||
mock_chunk2.choices = [Mock()]
|
||||
mock_chunk2.choices[0].delta.content = " world!"
|
||||
mock_chunk2.choices[0].finish_reason = "stop"
|
||||
|
||||
mock_stream = [mock_chunk1, mock_chunk2]
|
||||
|
||||
# Access client to initialize it first
|
||||
_ = provider.client
|
||||
with patch.object(provider.client.chat.completions, "create", return_value=mock_stream):
|
||||
# Test streaming completion
|
||||
stream = provider.complete_stream(
|
||||
model="gpt-5.1-codex",
|
||||
messages=[{"role": "user", "content": "Say hello"}],
|
||||
)
|
||||
|
||||
chunks = list(stream)
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].content == "Hello"
|
||||
assert chunks[1].content == " world!"
|
||||
|
||||
def test_api_error_handling(self):
|
||||
"""Test error handling for API failures."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
# Mock API error
|
||||
from openai import APIError
|
||||
|
||||
api_error = APIError("Mock API error", request=Mock(), body="error details")
|
||||
|
||||
with patch.object(provider._client.chat.completions, "create", side_effect=api_error):
|
||||
with pytest.raises(APIError):
|
||||
provider.complete(model="claude-sonnet-4-5", messages=[{"role": "user", "content": "Test"}])
|
||||
|
||||
def test_invalid_model_error(self):
|
||||
"""Test error handling for invalid models."""
|
||||
provider = ZenProvider(api_key="test-key")
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider zen"):
|
||||
provider.get_capabilities("invalid-model")
|
||||
|
||||
def test_authentication_error(self):
|
||||
"""Test handling of authentication errors."""
|
||||
provider = ZenProvider(api_key="invalid-key")
|
||||
|
||||
# Mock authentication error
|
||||
from openai import AuthenticationError
|
||||
|
||||
auth_error = AuthenticationError("Invalid API key", request=Mock(), body="auth failed")
|
||||
|
||||
with patch.object(provider._client.chat.completions, "create", side_effect=auth_error):
|
||||
with pytest.raises(AuthenticationError):
|
||||
provider.complete(model="claude-sonnet-4-5", messages=[{"role": "user", "content": "Test"}])
|
||||
Reference in New Issue
Block a user