added opencode zen as provider

This commit is contained in:
2025-12-22 23:13:29 +01:00
parent 7afc7c1cc9
commit c71a535f16
14 changed files with 956 additions and 13 deletions

View File

@@ -8,6 +8,7 @@ from .openai_compatible import OpenAICompatibleProvider
from .openrouter import OpenRouterProvider
from .registry import ModelProviderRegistry
from .shared import ModelCapabilities, ModelResponse
from .zen import ZenProvider
__all__ = [
"ModelProvider",
@@ -19,4 +20,5 @@ __all__ = [
"OpenAIModelProvider",
"OpenAICompatibleProvider",
"OpenRouterProvider",
"ZenProvider",
]

View File

@@ -0,0 +1,35 @@
"""OpenCode Zen model registry for managing model configurations and aliases."""
from __future__ import annotations
from ..shared import ModelCapabilities, ProviderType
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
class ZenModelRegistry(CapabilityModelRegistry):
"""Capability registry backed by ``conf/zen_models.json``."""
def __init__(self, config_path: str | None = None) -> None:
super().__init__(
env_var_name="ZEN_MODELS_CONFIG_PATH",
default_filename="zen_models.json",
provider=ProviderType.ZEN,
friendly_prefix="OpenCode Zen ({model})",
config_path=config_path,
)
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
provider_override = entry.get("provider")
if isinstance(provider_override, str):
entry_provider = ProviderType(provider_override.lower())
elif isinstance(provider_override, ProviderType):
entry_provider = provider_override
else:
entry_provider = ProviderType.ZEN
entry.setdefault("friendly_name", f"OpenCode Zen ({entry['model_name']})")
filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
filtered.setdefault("provider", entry_provider)
capability = ModelCapabilities(**filtered)
return capability, {}

View File

@@ -40,6 +40,7 @@ class ModelProviderRegistry:
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.AZURE, # Azure-hosted OpenAI deployments
ProviderType.XAI, # Direct X.AI GROK access
ProviderType.ZEN, # OpenCode Zen curated models
ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
@@ -336,6 +337,7 @@ class ModelProviderRegistry:
ProviderType.OPENAI: "OPENAI_API_KEY",
ProviderType.AZURE: "AZURE_OPENAI_API_KEY",
ProviderType.XAI: "XAI_API_KEY",
ProviderType.ZEN: "ZEN_API_KEY",
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth
ProviderType.DIAL: "DIAL_API_KEY",

View File

@@ -15,3 +15,4 @@ class ProviderType(Enum):
OPENROUTER = "openrouter"
CUSTOM = "custom"
DIAL = "dial"
ZEN = "zen"

141
providers/zen.py Normal file
View File

@@ -0,0 +1,141 @@
"""OpenCode Zen provider implementation."""
import logging
from .openai_compatible import OpenAICompatibleProvider
from .registries.zen import ZenModelRegistry
from .shared import (
ModelCapabilities,
ProviderType,
)
class ZenProvider(OpenAICompatibleProvider):
"""Client for OpenCode Zen's curated model service.
Role
Surface OpenCode Zen's tested and verified models through the same interface as
native providers so tools can reference Zen models without special cases.
Characteristics
* Pulls model definitions from :class:`ZenModelRegistry`
(capabilities, metadata, pricing information)
* Reuses :class:`OpenAICompatibleProvider` infrastructure for request
execution so Zen endpoints behave like standard OpenAI-style APIs.
* Supports OpenCode Zen's curated list of coding-focused models.
"""
FRIENDLY_NAME = "OpenCode Zen"
# Model registry for managing configurations
_registry: ZenModelRegistry | None = None
def __init__(self, api_key: str, **kwargs):
"""Initialize OpenCode Zen provider.
Args:
api_key: OpenCode Zen API key
**kwargs: Additional configuration
"""
base_url = "https://opencode.ai/zen/v1"
super().__init__(api_key, base_url=base_url, **kwargs)
# Initialize model registry
if ZenProvider._registry is None:
ZenProvider._registry = ZenModelRegistry()
# Log loaded models only on first load
models = self._registry.list_models()
logging.info(f"OpenCode Zen loaded {len(models)} models")
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: str | None = None,
) -> ModelCapabilities | None:
"""Fetch Zen capabilities from the registry."""
capabilities = self._registry.get_capabilities(canonical_name)
if capabilities:
return capabilities
# For unknown models, return None to let base class handle error
logging.debug("Model '%s' not found in Zen registry", canonical_name)
return None
# ------------------------------------------------------------------
# Provider identity
# ------------------------------------------------------------------
def get_provider_type(self) -> ProviderType:
"""Identify this provider for restrictions and logging."""
return ProviderType.ZEN
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
def list_models(
self,
*,
respect_restrictions: bool = True,
include_aliases: bool = True,
lowercase: bool = False,
unique: bool = False,
) -> list[str]:
"""Return formatted Zen model names, respecting restrictions."""
if not self._registry:
return []
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
allowed_configs: dict[str, ModelCapabilities] = {}
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if not config:
continue
if restriction_service:
if not restriction_service.is_allowed(self.get_provider_type(), model_name):
continue
allowed_configs[model_name] = config
if not allowed_configs:
return []
return ModelCapabilities.collect_model_names(
allowed_configs,
include_aliases=include_aliases,
lowercase=lowercase,
unique=unique,
)
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve aliases defined in the Zen registry."""
config = self._registry.resolve(model_name)
if config and config.model_name != model_name:
logging.debug("Resolved Zen model alias '%s' to '%s'", model_name, config.model_name)
return config.model_name
return model_name
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Expose registry-backed Zen capabilities."""
if not self._registry:
return {}
capabilities: dict[str, ModelCapabilities] = {}
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if config:
capabilities[model_name] = config
return capabilities