refactor: moved registries into a separate module and code cleanup
fix: refactored dial provider to follow the same pattern
This commit is contained in:
19
providers/registries/__init__.py
Normal file
19
providers/registries/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Registry implementations for provider capability manifests."""
|
||||
|
||||
from .azure import AzureModelRegistry
|
||||
from .custom import CustomEndpointModelRegistry
|
||||
from .dial import DialModelRegistry
|
||||
from .gemini import GeminiModelRegistry
|
||||
from .openai import OpenAIModelRegistry
|
||||
from .openrouter import OpenRouterModelRegistry
|
||||
from .xai import XAIModelRegistry
|
||||
|
||||
__all__ = [
|
||||
"AzureModelRegistry",
|
||||
"CustomEndpointModelRegistry",
|
||||
"DialModelRegistry",
|
||||
"GeminiModelRegistry",
|
||||
"OpenAIModelRegistry",
|
||||
"OpenRouterModelRegistry",
|
||||
"XAIModelRegistry",
|
||||
]
|
||||
45
providers/registries/azure.py
Normal file
45
providers/registries/azure.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Registry loader for Azure OpenAI model configurations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||
from .base import CAPABILITY_FIELD_NAMES, CustomModelRegistryBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureModelRegistry(CustomModelRegistryBase):
    """Registry of Azure OpenAI models described by a JSON manifest.

    Every manifest entry must name the Azure deployment serving the model via
    ``deployment`` (or ``deployment_name``). The deployment is stored as the
    entry's extra data, next to the parsed capabilities.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Set up the manifest location and load it straight away.

        Args:
            config_path: Optional explicit manifest path. When omitted the base
                class consults ``AZURE_MODELS_CONFIG_PATH`` and then the
                packaged ``azure_models.json``.
        """
        super().__init__(
            config_path=config_path,
            default_filename="azure_models.json",
            env_var_name="AZURE_MODELS_CONFIG_PATH",
        )
        self.reload()

    def _extra_keys(self) -> set[str]:
        """Return the non-capability keys accepted in an Azure entry."""
        return {"deployment_name", "deployment"}

    def _provider_default(self) -> ProviderType:
        """Return the provider stamped on every capability from this registry."""
        return ProviderType.AZURE

    def _default_friendly_name(self, model_name: str) -> str:
        """Return the display name used when an entry has no ``friendly_name``."""
        return f"Azure OpenAI ({model_name})"

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        """Build the capability object and extract the deployment name.

        Args:
            entry: Mutable copy of one manifest entry; the deployment keys are
                popped from it.

        Returns:
            ``(capability, {"deployment": <name>})``.

        Raises:
            ValueError: If neither ``deployment`` nor ``deployment_name``
                holds a truthy value.
        """
        # Prefer "deployment"; only consult "deployment_name" when it is absent or empty.
        deployment = entry.pop("deployment", None)
        if not deployment:
            deployment = entry.pop("deployment_name", None)
        if not deployment:
            raise ValueError(f"Azure model '{entry.get('model_name')}' is missing required 'deployment' field")

        constraint = entry.get("temperature_constraint")
        if isinstance(constraint, str):
            entry["temperature_constraint"] = TemperatureConstraint.create(constraint)

        # Start from the Azure default; a "provider" key in the entry overrides it.
        capability_kwargs: dict = {"provider": ProviderType.AZURE}
        capability_kwargs.update(
            (name, value) for name, value in entry.items() if name in CAPABILITY_FIELD_NAMES
        )
        return ModelCapabilities(**capability_kwargs), {"deployment": deployment}
|
||||
246
providers/registries/base.py
Normal file
246
providers/registries/base.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Shared infrastructure for JSON-backed model registries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import fields
|
||||
from pathlib import Path
|
||||
|
||||
from utils.env import get_env
|
||||
from utils.file_utils import read_json_file
|
||||
|
||||
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CAPABILITY_FIELD_NAMES = {field.name for field in fields(ModelCapabilities)}
|
||||
|
||||
|
||||
class CustomModelRegistryBase:
    """Load and expose capability metadata from a JSON manifest.

    The manifest is a JSON object whose ``models`` list holds one dict per
    model. Each entry is converted into a :class:`ModelCapabilities` object,
    indexed by model name and by lower-cased alias. Subclasses customise the
    conversion through the ``_extra_keys``, ``_provider_default``,
    ``_default_friendly_name`` and ``_finalise_entry`` hooks.

    The constructor only resolves where the manifest lives; call
    :meth:`reload` to actually read it.
    """

    def __init__(
        self,
        *,
        env_var_name: str,
        default_filename: str,
        config_path: str | None = None,
    ) -> None:
        """Resolve the manifest location.

        Lookup order: explicit ``config_path``, then the path stored in the
        ``env_var_name`` environment variable, then the packaged resource
        ``conf/<default_filename>``, and finally the on-disk default path.

        Args:
            env_var_name: Environment variable that may hold a manifest path.
            default_filename: Manifest file name inside the ``conf`` package.
            config_path: Explicit manifest path; wins over everything else.
        """
        self._env_var_name = env_var_name
        self._default_filename = default_filename
        self._use_resources = False
        self._resource_package = "conf"
        # This module is <root>/providers/registries/base.py: parents[0] is
        # "registries", parents[1] is "providers", parents[2] is the root that
        # holds the "conf" directory.
        self._default_path = Path(__file__).resolve().parents[2] / "conf" / default_filename

        if config_path:
            self.config_path = Path(config_path)
        else:
            env_path = get_env(env_var_name)
            if env_path:
                self.config_path = Path(env_path)
            else:
                try:
                    resource = importlib.resources.files(self._resource_package).joinpath(default_filename)
                    if hasattr(resource, "read_text"):
                        self._use_resources = True
                        self.config_path = None
                    else:
                        raise AttributeError("resource accessor not available")
                except Exception:
                    # Best effort: any failure to locate the packaged resource
                    # falls back to the on-disk default location.
                    self.config_path = self._default_path

        self.alias_map: dict[str, str] = {}
        self.model_map: dict[str, ModelCapabilities] = {}
        self._extras: dict[str, dict] = {}

    def reload(self) -> None:
        """Re-read the manifest and rebuild the model, alias and extras maps.

        Extras are rebuilt from scratch so data belonging to models that were
        removed from the manifest does not linger. If parsing fails, the
        previous extras are restored and the exception propagates, leaving the
        registry in its prior state.
        """
        data = self._load_config_data()
        previous_extras = self._extras
        self._extras = {}
        try:
            configs = [config for config in self._parse_models(data) if config is not None]
            self._build_maps(configs)
        except Exception:
            self._extras = previous_extras
            raise

    def list_models(self) -> list[str]:
        """Return the canonical model names in manifest order."""
        return list(self.model_map.keys())

    def list_aliases(self) -> list[str]:
        """Return every lookup key (lower-cased aliases and model names)."""
        return list(self.alias_map.keys())

    def resolve(self, name_or_alias: str) -> ModelCapabilities | None:
        """Look up capabilities by model name or alias, case-insensitively.

        Returns:
            The matching capabilities, or ``None`` when nothing matches.
        """
        key = name_or_alias.lower()
        canonical = self.alias_map.get(key)
        if canonical:
            return self.model_map.get(canonical)

        # Fallback scan for model names that are missing from the alias map.
        for model_name in self.model_map:
            if model_name.lower() == key:
                return self.model_map[model_name]
        return None

    def get_capabilities(self, name_or_alias: str) -> ModelCapabilities | None:
        """Alias of :meth:`resolve`."""
        return self.resolve(name_or_alias)

    def get_entry(self, model_name: str) -> dict | None:
        """Return the extra (non-capability) data stored for ``model_name``.

        The lookup uses the exact canonical model name; aliases are not
        resolved here.
        """
        return self._extras.get(model_name)

    def get_model_config(self, model_name: str) -> ModelCapabilities | None:
        """Backwards-compatible accessor for registries expecting this helper.

        Tries an exact model-name match first, then falls back to
        :meth:`resolve`.
        """
        return self.model_map.get(model_name) or self.resolve(model_name)

    def iter_entries(self) -> Iterable[tuple[str, ModelCapabilities, dict]]:
        """Yield ``(model_name, capabilities, extras)`` for every model."""
        for model_name, capability in self.model_map.items():
            yield model_name, capability, self._extras.get(model_name, {})

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _load_config_data(self) -> dict:
        """Read the manifest and return its parsed JSON content.

        Returns ``{"models": []}`` when the manifest is missing, unreadable
        (packaged resource only) or empty.

        Raises:
            FileNotFoundError: If no packaged resource is in use and no
                configuration path is set.
        """
        if self._use_resources:
            try:
                resource = importlib.resources.files(self._resource_package).joinpath(self._default_filename)
                if hasattr(resource, "read_text"):
                    config_text = resource.read_text(encoding="utf-8")
                else:  # pragma: no cover - legacy Python fallback
                    with resource.open("r", encoding="utf-8") as handle:
                        config_text = handle.read()
                data = json.loads(config_text)
            except FileNotFoundError:
                logger.debug("Packaged %s not found", self._default_filename)
                return {"models": []}
            except Exception as exc:
                logger.warning("Failed to read packaged %s: %s", self._default_filename, exc)
                return {"models": []}
            return data or {"models": []}

        if not self.config_path:
            raise FileNotFoundError("Registry configuration path is not set")

        if not self.config_path.exists():
            logger.debug("Model registry config not found at %s", self.config_path)
            # Only the built-in default location gets a second chance relative
            # to the current working directory; explicit paths do not.
            if self.config_path == self._default_path:
                fallback = Path.cwd() / "conf" / self._default_filename
                if fallback != self.config_path and fallback.exists():
                    logger.debug("Falling back to %s", fallback)
                    self.config_path = fallback
                else:
                    return {"models": []}
            else:
                return {"models": []}

        data = read_json_file(str(self.config_path))
        return data or {"models": []}

    @property
    def use_resources(self) -> bool:
        """Whether the manifest is read from the packaged ``conf`` resource."""
        return self._use_resources

    def _parse_models(self, data: dict) -> Iterable[ModelCapabilities | None]:
        """Yield one converted entry per dict found under ``data["models"]``.

        A manifest that is not a JSON object, or whose ``models`` value is
        missing or ``null``, yields nothing. Non-dict items are skipped.
        """
        if not isinstance(data, dict):
            logger.warning("Ignoring %s: top-level JSON value is not an object", self._default_filename)
            return
        for raw in data.get("models") or []:
            if not isinstance(raw, dict):
                continue
            yield self._convert_entry(raw)

    def _convert_entry(self, raw: dict) -> ModelCapabilities | None:
        """Validate and convert one manifest entry.

        Returns:
            The capability object, or ``None`` when the entry has no
            ``model_name``.

        Raises:
            ValueError: If the entry uses the removed ``max_tokens`` key or
                contains keys that are neither capability fields nor listed
                by :meth:`_extra_keys`.
        """
        entry = dict(raw)  # work on a copy so the caller's data is untouched
        model_name = entry.get("model_name")
        if not model_name:
            return None

        # Accept a comma-separated string as shorthand for a list of aliases.
        aliases = entry.get("aliases")
        if isinstance(aliases, str):
            entry["aliases"] = [alias.strip() for alias in aliases.split(",") if alias.strip()]

        entry.setdefault("friendly_name", self._default_friendly_name(model_name))

        temperature_hint = entry.get("temperature_constraint")
        if isinstance(temperature_hint, str):
            entry["temperature_constraint"] = TemperatureConstraint.create(temperature_hint)
        elif temperature_hint is None:
            entry["temperature_constraint"] = TemperatureConstraint.create("range")

        if "max_tokens" in entry:
            raise ValueError(
                "`max_tokens` is no longer supported. Use `max_output_tokens` in your model configuration."
            )

        unknown_keys = set(entry.keys()) - CAPABILITY_FIELD_NAMES - self._extra_keys()
        if unknown_keys:
            raise ValueError("Unsupported fields in model configuration: " + ", ".join(sorted(unknown_keys)))

        capability, extras = self._finalise_entry(entry)
        # The registry's own provider always wins over whatever the entry or
        # the subclass hook put on the capability.
        capability.provider = self._provider_default()
        self._extras[capability.model_name] = extras or {}
        return capability

    def _default_friendly_name(self, model_name: str) -> str:
        """Return the display name used when an entry has no ``friendly_name``."""
        return model_name

    def _extra_keys(self) -> set[str]:
        """Return the non-capability keys a manifest entry may contain."""
        return set()

    def _provider_default(self) -> ProviderType:
        """Return the provider assigned to every capability in this registry."""
        return ProviderType.OPENROUTER

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        """Build ``(capability, extras)`` from a validated entry.

        The base implementation keeps only capability fields and returns no
        extras.
        """
        return ModelCapabilities(**{k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}), {}

    def _build_maps(self, configs: Iterable[ModelCapabilities]) -> None:
        """Rebuild ``model_map`` and ``alias_map`` from ``configs``.

        Both maps are swapped in only after every config has been processed,
        so a failure leaves the existing maps untouched.

        Raises:
            ValueError: If two different models declare the same alias
                (compared case-insensitively).
        """
        alias_map: dict[str, str] = {}
        model_map: dict[str, ModelCapabilities] = {}

        for config in configs:
            if not config:
                continue
            model_map[config.model_name] = config

            # A model is always reachable by its own lower-cased name, unless
            # an earlier entry already claimed that key.
            model_name_lower = config.model_name.lower()
            if model_name_lower not in alias_map:
                alias_map[model_name_lower] = config.model_name

            for alias in config.aliases:
                alias_lower = alias.lower()
                if alias_lower in alias_map and alias_map[alias_lower] != config.model_name:
                    raise ValueError(
                        f"Duplicate alias '{alias}' found for models '{alias_map[alias_lower]}' and '{config.model_name}'"
                    )
                alias_map[alias_lower] = config.model_name

        self.alias_map = alias_map
        self.model_map = model_map
|
||||
|
||||
|
||||
class CapabilityModelRegistry(CustomModelRegistryBase):
    """Manifest-backed registry parameterised by provider and display-name template.

    Concrete provider registries only need to supply their environment
    variable, manifest file name, provider type and friendly-name template;
    the manifest is loaded as part of construction.
    """

    def __init__(
        self,
        *,
        env_var_name: str,
        default_filename: str,
        provider: ProviderType,
        friendly_prefix: str,
        config_path: str | None = None,
    ) -> None:
        """Store the provider settings, resolve the manifest and load it.

        Args:
            env_var_name: Environment variable that may hold a manifest path.
            default_filename: Manifest file name inside the ``conf`` package.
            provider: Provider assigned to every capability in this registry.
            friendly_prefix: ``str.format`` template receiving ``model``.
            config_path: Optional explicit manifest path.
        """
        # Both attributes are read by the hooks that reload() triggers, so
        # they must exist before the manifest is parsed.
        self._friendly_prefix = friendly_prefix
        self._provider = provider
        super().__init__(
            config_path=config_path,
            default_filename=default_filename,
            env_var_name=env_var_name,
        )
        self.reload()

    def _provider_default(self) -> ProviderType:
        """Return the provider given at construction time."""
        return self._provider

    def _default_friendly_name(self, model_name: str) -> str:
        """Fill the friendly-name template with ``model_name``."""
        return self._friendly_prefix.format(model=model_name)

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        """Build the capability from the entry's capability fields; no extras."""
        # Seed with the registry provider; a "provider" key in the entry replaces it.
        capability_kwargs: dict = {"provider": self._provider_default()}
        capability_kwargs.update(
            (name, value) for name, value in entry.items() if name in CAPABILITY_FIELD_NAMES
        )
        return ModelCapabilities(**capability_kwargs), {}
|
||||
25
providers/registries/custom.py
Normal file
25
providers/registries/custom.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Registry loader for custom OpenAI-compatible endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..shared import ModelCapabilities, ProviderType
|
||||
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
||||
|
||||
|
||||
class CustomEndpointModelRegistry(CapabilityModelRegistry):
    """Registry for custom OpenAI-compatible endpoints (``conf/custom_models.json``)."""

    def __init__(self, config_path: str | None = None) -> None:
        """Load the custom-endpoint manifest.

        Args:
            config_path: Optional explicit manifest path; otherwise
                ``CUSTOM_MODELS_CONFIG_PATH`` or the packaged default is used.
        """
        super().__init__(
            config_path=config_path,
            friendly_prefix="Custom ({model})",
            provider=ProviderType.CUSTOM,
            default_filename="custom_models.json",
            env_var_name="CUSTOM_MODELS_CONFIG_PATH",
        )

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        """Build the capability from the entry's capability fields; no extras."""
        # Default to the CUSTOM provider unless the entry names one itself.
        capability_kwargs: dict = {"provider": ProviderType.CUSTOM}
        capability_kwargs.update(
            (name, value) for name, value in entry.items() if name in CAPABILITY_FIELD_NAMES
        )
        return ModelCapabilities(**capability_kwargs), {}
|
||||
19
providers/registries/dial.py
Normal file
19
providers/registries/dial.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Registry loader for DIAL provider capabilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..shared import ProviderType
|
||||
from .base import CapabilityModelRegistry
|
||||
|
||||
|
||||
class DialModelRegistry(CapabilityModelRegistry):
    """Registry of DIAL models read from ``conf/dial_models.json``."""

    def __init__(self, config_path: str | None = None) -> None:
        """Load the DIAL manifest.

        Args:
            config_path: Optional explicit manifest path; otherwise
                ``DIAL_MODELS_CONFIG_PATH`` or the packaged default is used.
        """
        super().__init__(
            config_path=config_path,
            friendly_prefix="DIAL ({model})",
            provider=ProviderType.DIAL,
            default_filename="dial_models.json",
            env_var_name="DIAL_MODELS_CONFIG_PATH",
        )
|
||||
19
providers/registries/gemini.py
Normal file
19
providers/registries/gemini.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Registry loader for Gemini model capabilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..shared import ProviderType
|
||||
from .base import CapabilityModelRegistry
|
||||
|
||||
|
||||
class GeminiModelRegistry(CapabilityModelRegistry):
    """Registry of Gemini models read from ``conf/gemini_models.json``."""

    def __init__(self, config_path: str | None = None) -> None:
        """Load the Gemini manifest.

        Args:
            config_path: Optional explicit manifest path; otherwise
                ``GEMINI_MODELS_CONFIG_PATH`` or the packaged default is used.
        """
        super().__init__(
            config_path=config_path,
            friendly_prefix="Gemini ({model})",
            provider=ProviderType.GOOGLE,
            default_filename="gemini_models.json",
            env_var_name="GEMINI_MODELS_CONFIG_PATH",
        )
|
||||
19
providers/registries/openai.py
Normal file
19
providers/registries/openai.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Registry loader for OpenAI model capabilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..shared import ProviderType
|
||||
from .base import CapabilityModelRegistry
|
||||
|
||||
|
||||
class OpenAIModelRegistry(CapabilityModelRegistry):
    """Registry of OpenAI models read from ``conf/openai_models.json``."""

    def __init__(self, config_path: str | None = None) -> None:
        """Load the OpenAI manifest.

        Args:
            config_path: Optional explicit manifest path; otherwise
                ``OPENAI_MODELS_CONFIG_PATH`` or the packaged default is used.
        """
        super().__init__(
            config_path=config_path,
            friendly_prefix="OpenAI ({model})",
            provider=ProviderType.OPENAI,
            default_filename="openai_models.json",
            env_var_name="OPENAI_MODELS_CONFIG_PATH",
        )
|
||||
38
providers/registries/openrouter.py
Normal file
38
providers/registries/openrouter.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""OpenRouter model registry for managing model configurations and aliases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..shared import ModelCapabilities, ProviderType
|
||||
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
||||
|
||||
|
||||
class OpenRouterModelRegistry(CapabilityModelRegistry):
    """Capability registry backed by ``conf/openrouter_models.json``."""

    def __init__(self, config_path: str | None = None) -> None:
        """Load the OpenRouter manifest.

        Args:
            config_path: Optional explicit manifest path; otherwise
                ``OPENROUTER_MODELS_CONFIG_PATH`` or the packaged default is
                used.
        """
        super().__init__(
            env_var_name="OPENROUTER_MODELS_CONFIG_PATH",
            default_filename="openrouter_models.json",
            provider=ProviderType.OPENROUTER,
            friendly_prefix="OpenRouter ({model})",
            config_path=config_path,
        )

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        """Build the capability for one entry, honouring a ``provider`` override.

        A ``provider`` value given as a string is converted to
        :class:`ProviderType` (case-insensitively) before the capability is
        built, so the capability never carries a raw string.

        Args:
            entry: Mutable copy of one manifest entry.

        Returns:
            ``(capability, {})``; OpenRouter entries carry no extras.

        Raises:
            ValueError: If ``provider`` is a string that is not a valid
                :class:`ProviderType` value.
        """
        provider_override = entry.get("provider")
        if isinstance(provider_override, str):
            entry_provider = ProviderType(provider_override.lower())
        elif isinstance(provider_override, ProviderType):
            entry_provider = provider_override
        else:
            entry_provider = ProviderType.OPENROUTER

        if entry_provider == ProviderType.CUSTOM:
            entry.setdefault("friendly_name", f"Custom ({entry['model_name']})")
        else:
            entry.setdefault("friendly_name", f"OpenRouter ({entry['model_name']})")

        filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
        # Assign rather than setdefault: when the entry has a "provider" key,
        # setdefault would keep the unparsed manifest value and drop the
        # ProviderType computed above.
        filtered["provider"] = entry_provider
        capability = ModelCapabilities(**filtered)
        return capability, {}
|
||||
19
providers/registries/xai.py
Normal file
19
providers/registries/xai.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Registry loader for X.AI model capabilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..shared import ProviderType
|
||||
from .base import CapabilityModelRegistry
|
||||
|
||||
|
||||
class XAIModelRegistry(CapabilityModelRegistry):
    """Registry of X.AI models read from ``conf/xai_models.json``."""

    def __init__(self, config_path: str | None = None) -> None:
        """Load the X.AI manifest.

        Args:
            config_path: Optional explicit manifest path; otherwise
                ``XAI_MODELS_CONFIG_PATH`` or the packaged default is used.
        """
        super().__init__(
            config_path=config_path,
            friendly_prefix="X.AI ({model})",
            provider=ProviderType.XAI,
            default_filename="xai_models.json",
            env_var_name="XAI_MODELS_CONFIG_PATH",
        )
|
||||
Reference in New Issue
Block a user