Files

142 lines
4.7 KiB
Python

"""OpenCode Zen provider implementation."""
import logging
from .openai_compatible import OpenAICompatibleProvider
from .registries.zen import ZenModelRegistry
from .shared import (
ModelCapabilities,
ProviderType,
)
class ZenProvider(OpenAICompatibleProvider):
    """Client for OpenCode Zen's curated model service.

    Role
        Surface OpenCode Zen's tested and verified models through the same
        interface as native providers so tools can reference Zen models
        without special cases.

    Characteristics
        * Pulls model definitions from :class:`ZenModelRegistry`
          (capabilities, metadata, pricing information).
        * Reuses :class:`OpenAICompatibleProvider` infrastructure for request
          execution so Zen endpoints behave like standard OpenAI-style APIs.
        * Supports OpenCode Zen's curated list of coding-focused models.
    """

    FRIENDLY_NAME = "OpenCode Zen"

    # Endpoint used when the caller does not pass ``base_url`` explicitly.
    ZEN_BASE_URL = "https://opencode.ai/zen/v1"

    # Model registry stored on the class, so it is shared by every ZenProvider
    # instance. It is created by the first __init__ call and reused afterwards.
    _registry: ZenModelRegistry | None = None

    def __init__(self, api_key: str, **kwargs):
        """Initialize the OpenCode Zen provider.

        Args:
            api_key: OpenCode Zen API key.
            **kwargs: Additional configuration forwarded to
                :class:`OpenAICompatibleProvider`. ``base_url`` defaults to
                :attr:`ZEN_BASE_URL`; an explicitly supplied value is honored
                instead of colliding with the default.
        """
        # setdefault (rather than a separate base_url= keyword) avoids a
        # "multiple values for keyword argument" TypeError when the caller
        # also passes base_url in kwargs.
        kwargs.setdefault("base_url", self.ZEN_BASE_URL)
        super().__init__(api_key, **kwargs)

        # Build the shared registry once; later instances reuse it.
        if ZenProvider._registry is None:
            ZenProvider._registry = ZenModelRegistry()
            # Log loaded models only on first load.
            models = ZenProvider._registry.list_models()
            logging.info("OpenCode Zen loaded %d models", len(models))

    # ------------------------------------------------------------------
    # Capability surface
    # ------------------------------------------------------------------

    def _lookup_capabilities(
        self,
        canonical_name: str,
        requested_name: str | None = None,
    ) -> ModelCapabilities | None:
        """Fetch Zen capabilities for ``canonical_name`` from the registry.

        Args:
            canonical_name: Model name to look up in the registry.
            requested_name: Not used by this implementation.

        Returns:
            The registry's capabilities for the model, or ``None`` when the
            model is unknown or no registry is loaded. The original code's
            comment indicates ``None`` lets the base class handle the error.
        """
        registry = self._registry
        capabilities = registry.get_capabilities(canonical_name) if registry else None
        if capabilities:
            return capabilities

        logging.debug("Model '%s' not found in Zen registry", canonical_name)
        return None

    # ------------------------------------------------------------------
    # Provider identity
    # ------------------------------------------------------------------

    def get_provider_type(self) -> ProviderType:
        """Identify this provider for restrictions and logging."""
        return ProviderType.ZEN

    # ------------------------------------------------------------------
    # Registry helpers
    # ------------------------------------------------------------------

    def _resolved_configs(self) -> dict[str, ModelCapabilities]:
        """Map each registry model name to its resolved capabilities.

        Names the registry cannot resolve (``resolve`` returns a falsy value)
        are skipped. An empty dict is returned when no registry is loaded.
        """
        registry = self._registry
        if not registry:
            return {}

        resolved: dict[str, ModelCapabilities] = {}
        for model_name in registry.list_models():
            config = registry.resolve(model_name)
            if config:
                resolved[model_name] = config
        return resolved

    def list_models(
        self,
        *,
        respect_restrictions: bool = True,
        include_aliases: bool = True,
        lowercase: bool = False,
        unique: bool = False,
    ) -> list[str]:
        """Return formatted Zen model names, optionally filtered by restrictions.

        Args:
            respect_restrictions: When true, drop models that the restriction
                service does not allow for this provider type.
            include_aliases: Forwarded to ``ModelCapabilities.collect_model_names``.
            lowercase: Forwarded to ``ModelCapabilities.collect_model_names``.
            unique: Forwarded to ``ModelCapabilities.collect_model_names``.

        Returns:
            The collected model names. The list is empty when no registry is
            loaded or when no model survives filtering.
        """
        if not self._registry:
            return []

        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service() if respect_restrictions else None
        # Invariant across the loop below; compute once.
        provider_type = self.get_provider_type()

        allowed_configs: dict[str, ModelCapabilities] = {
            model_name: config
            for model_name, config in self._resolved_configs().items()
            if not restriction_service or restriction_service.is_allowed(provider_type, model_name)
        }
        if not allowed_configs:
            return []

        return ModelCapabilities.collect_model_names(
            allowed_configs,
            include_aliases=include_aliases,
            lowercase=lowercase,
            unique=unique,
        )

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve aliases defined in the Zen registry.

        Returns the registry's canonical ``model_name`` when ``model_name``
        resolves to a different name. Otherwise, including when no registry
        is loaded, the input is returned unchanged.
        """
        registry = self._registry
        config = registry.resolve(model_name) if registry else None
        if config and config.model_name != model_name:
            logging.debug("Resolved Zen model alias '%s' to '%s'", model_name, config.model_name)
            return config.model_name
        return model_name

    def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
        """Expose registry-backed Zen capabilities keyed by registry model name.

        Returns an empty dict when no registry is loaded.
        """
        return self._resolved_configs()