feat!: breaking change - OpenRouter models are now read from conf/openrouter_models.json while Custom / Self-hosted models are read from conf/custom_models.json
feat: Azure OpenAI / Azure AI Foundry support. Models should be defined in conf/azure_models.json (or a custom path). See .env.example for environment variables or see readme. https://github.com/BeehiveInnovations/zen-mcp-server/issues/265 feat: OpenRouter / Custom Models / Azure can separately also use custom config paths now (see .env.example ) refactor: Model registry class made abstract, OpenRouter / Custom Provider / Azure OpenAI now subclass these refactor: breaking change: `is_custom` property has been removed from model_capabilities.py (and thus custom_models.json) given each models are now read from separate configuration files
This commit is contained in:
241
providers/model_registry_base.py
Normal file
241
providers/model_registry_base.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Shared infrastructure for JSON-backed model registries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import fields
|
||||
from pathlib import Path
|
||||
|
||||
from utils.env import get_env
|
||||
from utils.file_utils import read_json_file
|
||||
|
||||
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CAPABILITY_FIELD_NAMES = {field.name for field in fields(ModelCapabilities)}
|
||||
|
||||
|
||||
class CustomModelRegistryBase:
|
||||
"""Load and expose capability metadata from a JSON manifest."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
env_var_name: str,
|
||||
default_filename: str,
|
||||
config_path: str | None = None,
|
||||
) -> None:
|
||||
self._env_var_name = env_var_name
|
||||
self._default_filename = default_filename
|
||||
self._use_resources = False
|
||||
self._resource_package = "conf"
|
||||
self._default_path = Path(__file__).parent.parent / "conf" / default_filename
|
||||
|
||||
if config_path:
|
||||
self.config_path = Path(config_path)
|
||||
else:
|
||||
env_path = get_env(env_var_name)
|
||||
if env_path:
|
||||
self.config_path = Path(env_path)
|
||||
else:
|
||||
try:
|
||||
resource = importlib.resources.files(self._resource_package).joinpath(default_filename)
|
||||
if hasattr(resource, "read_text"):
|
||||
self._use_resources = True
|
||||
self.config_path = None
|
||||
else:
|
||||
raise AttributeError("resource accessor not available")
|
||||
except Exception:
|
||||
self.config_path = Path(__file__).parent.parent / "conf" / default_filename
|
||||
|
||||
self.alias_map: dict[str, str] = {}
|
||||
self.model_map: dict[str, ModelCapabilities] = {}
|
||||
self._extras: dict[str, dict] = {}
|
||||
|
||||
def reload(self) -> None:
|
||||
data = self._load_config_data()
|
||||
configs = [config for config in self._parse_models(data) if config is not None]
|
||||
self._build_maps(configs)
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
return list(self.model_map.keys())
|
||||
|
||||
def list_aliases(self) -> list[str]:
|
||||
return list(self.alias_map.keys())
|
||||
|
||||
def resolve(self, name_or_alias: str) -> ModelCapabilities | None:
|
||||
key = name_or_alias.lower()
|
||||
canonical = self.alias_map.get(key)
|
||||
if canonical:
|
||||
return self.model_map.get(canonical)
|
||||
|
||||
for model_name in self.model_map:
|
||||
if model_name.lower() == key:
|
||||
return self.model_map[model_name]
|
||||
return None
|
||||
|
||||
def get_capabilities(self, name_or_alias: str) -> ModelCapabilities | None:
|
||||
return self.resolve(name_or_alias)
|
||||
|
||||
def get_entry(self, model_name: str) -> dict | None:
|
||||
return self._extras.get(model_name)
|
||||
|
||||
def iter_entries(self) -> Iterable[tuple[str, ModelCapabilities, dict]]:
|
||||
for model_name, capability in self.model_map.items():
|
||||
yield model_name, capability, self._extras.get(model_name, {})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _load_config_data(self) -> dict:
|
||||
if self._use_resources:
|
||||
try:
|
||||
resource = importlib.resources.files(self._resource_package).joinpath(self._default_filename)
|
||||
if hasattr(resource, "read_text"):
|
||||
config_text = resource.read_text(encoding="utf-8")
|
||||
else: # pragma: no cover - legacy Python fallback
|
||||
with resource.open("r", encoding="utf-8") as handle:
|
||||
config_text = handle.read()
|
||||
data = json.loads(config_text)
|
||||
except FileNotFoundError:
|
||||
logger.debug("Packaged %s not found", self._default_filename)
|
||||
return {"models": []}
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read packaged %s: %s", self._default_filename, exc)
|
||||
return {"models": []}
|
||||
return data or {"models": []}
|
||||
|
||||
if not self.config_path:
|
||||
raise FileNotFoundError("Registry configuration path is not set")
|
||||
|
||||
if not self.config_path.exists():
|
||||
logger.debug("Model registry config not found at %s", self.config_path)
|
||||
if self.config_path == self._default_path:
|
||||
fallback = Path.cwd() / "conf" / self._default_filename
|
||||
if fallback != self.config_path and fallback.exists():
|
||||
logger.debug("Falling back to %s", fallback)
|
||||
self.config_path = fallback
|
||||
else:
|
||||
return {"models": []}
|
||||
else:
|
||||
return {"models": []}
|
||||
|
||||
data = read_json_file(str(self.config_path))
|
||||
return data or {"models": []}
|
||||
|
||||
@property
|
||||
def use_resources(self) -> bool:
|
||||
return self._use_resources
|
||||
|
||||
def _parse_models(self, data: dict) -> Iterable[ModelCapabilities | None]:
|
||||
for raw in data.get("models", []):
|
||||
if not isinstance(raw, dict):
|
||||
continue
|
||||
yield self._convert_entry(raw)
|
||||
|
||||
def _convert_entry(self, raw: dict) -> ModelCapabilities | None:
|
||||
entry = dict(raw)
|
||||
model_name = entry.get("model_name")
|
||||
if not model_name:
|
||||
return None
|
||||
|
||||
aliases = entry.get("aliases")
|
||||
if isinstance(aliases, str):
|
||||
entry["aliases"] = [alias.strip() for alias in aliases.split(",") if alias.strip()]
|
||||
|
||||
entry.setdefault("friendly_name", self._default_friendly_name(model_name))
|
||||
|
||||
temperature_hint = entry.get("temperature_constraint")
|
||||
if isinstance(temperature_hint, str):
|
||||
entry["temperature_constraint"] = TemperatureConstraint.create(temperature_hint)
|
||||
elif temperature_hint is None:
|
||||
entry["temperature_constraint"] = TemperatureConstraint.create("range")
|
||||
|
||||
if "max_tokens" in entry:
|
||||
raise ValueError(
|
||||
"`max_tokens` is no longer supported. Use `max_output_tokens` in your model configuration."
|
||||
)
|
||||
|
||||
unknown_keys = set(entry.keys()) - CAPABILITY_FIELD_NAMES - self._extra_keys()
|
||||
if unknown_keys:
|
||||
raise ValueError("Unsupported fields in model configuration: " + ", ".join(sorted(unknown_keys)))
|
||||
|
||||
capability, extras = self._finalise_entry(entry)
|
||||
capability.provider = self._provider_default()
|
||||
self._extras[capability.model_name] = extras or {}
|
||||
return capability
|
||||
|
||||
def _default_friendly_name(self, model_name: str) -> str:
|
||||
return model_name
|
||||
|
||||
def _extra_keys(self) -> set[str]:
|
||||
return set()
|
||||
|
||||
def _provider_default(self) -> ProviderType:
|
||||
return ProviderType.OPENROUTER
|
||||
|
||||
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
|
||||
return ModelCapabilities(**{k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}), {}
|
||||
|
||||
def _build_maps(self, configs: Iterable[ModelCapabilities]) -> None:
|
||||
alias_map: dict[str, str] = {}
|
||||
model_map: dict[str, ModelCapabilities] = {}
|
||||
|
||||
for config in configs:
|
||||
if not config:
|
||||
continue
|
||||
model_map[config.model_name] = config
|
||||
|
||||
model_name_lower = config.model_name.lower()
|
||||
if model_name_lower not in alias_map:
|
||||
alias_map[model_name_lower] = config.model_name
|
||||
|
||||
for alias in config.aliases:
|
||||
alias_lower = alias.lower()
|
||||
if alias_lower in alias_map and alias_map[alias_lower] != config.model_name:
|
||||
raise ValueError(
|
||||
f"Duplicate alias '{alias}' found for models '{alias_map[alias_lower]}' and '{config.model_name}'"
|
||||
)
|
||||
alias_map[alias_lower] = config.model_name
|
||||
|
||||
self.alias_map = alias_map
|
||||
self.model_map = model_map
|
||||
|
||||
|
||||
class CapabilityModelRegistry(CustomModelRegistryBase):
|
||||
"""Registry that returns `ModelCapabilities` objects with alias support."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
env_var_name: str,
|
||||
default_filename: str,
|
||||
provider: ProviderType,
|
||||
friendly_prefix: str,
|
||||
config_path: str | None = None,
|
||||
) -> None:
|
||||
self._provider = provider
|
||||
self._friendly_prefix = friendly_prefix
|
||||
super().__init__(
|
||||
env_var_name=env_var_name,
|
||||
default_filename=default_filename,
|
||||
config_path=config_path,
|
||||
)
|
||||
self.reload()
|
||||
|
||||
def _provider_default(self) -> ProviderType:
|
||||
return self._provider
|
||||
|
||||
def _default_friendly_name(self, model_name: str) -> str:
|
||||
return self._friendly_prefix.format(model=model_name)
|
||||
|
||||
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
|
||||
filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
|
||||
filtered.setdefault("provider", self._provider_default())
|
||||
capability = ModelCapabilities(**filtered)
|
||||
return capability, {}
|
||||
Reference in New Issue
Block a user