refactor: moved registries into a separate module and code cleanup

fix: refactored dial provider to follow the same pattern
This commit is contained in:
Fahad
2025-10-07 12:59:09 +04:00
parent c27e81d6d2
commit 7c36b9255a
54 changed files with 325 additions and 282 deletions

View File

@@ -0,0 +1,19 @@
"""Registry implementations for provider capability manifests."""
from .azure import AzureModelRegistry
from .custom import CustomEndpointModelRegistry
from .dial import DialModelRegistry
from .gemini import GeminiModelRegistry
from .openai import OpenAIModelRegistry
from .openrouter import OpenRouterModelRegistry
from .xai import XAIModelRegistry
__all__ = [
"AzureModelRegistry",
"CustomEndpointModelRegistry",
"DialModelRegistry",
"GeminiModelRegistry",
"OpenAIModelRegistry",
"OpenRouterModelRegistry",
"XAIModelRegistry",
]

View File

@@ -0,0 +1,45 @@
"""Registry loader for Azure OpenAI model configurations."""
from __future__ import annotations
import logging
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
from .base import CAPABILITY_FIELD_NAMES, CustomModelRegistryBase
logger = logging.getLogger(__name__)
class AzureModelRegistry(CustomModelRegistryBase):
    """Registry of Azure OpenAI model capabilities loaded from JSON config.

    Azure entries carry an extra ``deployment`` (or legacy
    ``deployment_name``) key identifying the Azure deployment to target;
    that value is surfaced through the registry's extras mapping rather
    than on the capability object itself.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Initialise the registry and immediately load the manifest.

        Args:
            config_path: Optional explicit manifest path; overrides the
                ``AZURE_MODELS_CONFIG_PATH`` environment variable and the
                default ``conf/azure_models.json`` location.
        """
        super().__init__(
            config_path=config_path,
            env_var_name="AZURE_MODELS_CONFIG_PATH",
            default_filename="azure_models.json",
        )
        self.reload()

    def _extra_keys(self) -> set[str]:
        # Azure-only manifest keys that are not ModelCapabilities fields.
        return {"deployment", "deployment_name"}

    def _provider_default(self) -> ProviderType:
        return ProviderType.AZURE

    def _default_friendly_name(self, model_name: str) -> str:
        return f"Azure OpenAI ({model_name})"

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        """Build the capability object and extract the mandatory deployment.

        Raises:
            ValueError: If neither ``deployment`` nor ``deployment_name``
                is present (or both are empty).
        """
        deployment_id = entry.pop("deployment", None) or entry.pop("deployment_name", None)
        if not deployment_id:
            raise ValueError(f"Azure model '{entry.get('model_name')}' is missing required 'deployment' field")
        constraint_hint = entry.get("temperature_constraint")
        if isinstance(constraint_hint, str):
            entry["temperature_constraint"] = TemperatureConstraint.create(constraint_hint)
        payload = {key: value for key, value in entry.items() if key in CAPABILITY_FIELD_NAMES}
        if "provider" not in payload:
            payload["provider"] = ProviderType.AZURE
        return ModelCapabilities(**payload), {"deployment": deployment_id}

View File

@@ -0,0 +1,246 @@
"""Shared infrastructure for JSON-backed model registries."""
from __future__ import annotations
import importlib.resources
import json
import logging
from collections.abc import Iterable
from dataclasses import fields
from pathlib import Path
from utils.env import get_env
from utils.file_utils import read_json_file
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
logger = logging.getLogger(__name__)
CAPABILITY_FIELD_NAMES = {field.name for field in fields(ModelCapabilities)}
class CustomModelRegistryBase:
    """Load and expose capability metadata from a JSON manifest.

    Manifest resolution order:

    1. An explicit ``config_path`` constructor argument.
    2. The environment variable named by ``env_var_name``.
    3. A packaged resource inside the ``conf`` package.
    4. The repository-relative ``conf/<default_filename>`` file.

    Subclasses tailor parsing via the ``_extra_keys``, ``_provider_default``,
    ``_default_friendly_name`` and ``_finalise_entry`` hooks, then call
    :meth:`reload` to populate the lookup maps.
    """

    def __init__(
        self,
        *,
        env_var_name: str,
        default_filename: str,
        config_path: str | None = None,
    ) -> None:
        """Resolve where the manifest lives without reading it yet.

        Args:
            env_var_name: Environment variable that may point at a manifest.
            default_filename: Manifest filename looked up under ``conf/``.
            config_path: Explicit path override; takes precedence over both
                the environment variable and packaged resources.
        """
        self._env_var_name = env_var_name
        self._default_filename = default_filename
        self._use_resources = False
        self._resource_package = "conf"
        # Repository-relative default: <repo root>/conf/<default_filename>.
        self._default_path = Path(__file__).resolve().parents[3] / "conf" / default_filename
        if config_path:
            self.config_path = Path(config_path)
        else:
            env_path = get_env(env_var_name)
            if env_path:
                self.config_path = Path(env_path)
            else:
                try:
                    resource = importlib.resources.files(self._resource_package).joinpath(default_filename)
                    if hasattr(resource, "read_text"):
                        # Readable packaged resource: defer to it at load time.
                        self._use_resources = True
                        self.config_path = None
                    else:
                        raise AttributeError("resource accessor not available")
                except Exception:
                    # Packaged resource unavailable; fall back to the on-disk
                    # default (same value as self._default_path).
                    self.config_path = self._default_path
        # Lookup maps populated by reload(): lowercase alias -> canonical name,
        # canonical name -> capabilities, canonical name -> provider extras.
        self.alias_map: dict[str, str] = {}
        self.model_map: dict[str, ModelCapabilities] = {}
        self._extras: dict[str, dict] = {}

    def reload(self) -> None:
        """Re-read the manifest and rebuild the lookup maps."""
        data = self._load_config_data()
        configs = [config for config in self._parse_models(data) if config is not None]
        self._build_maps(configs)

    def list_models(self) -> list[str]:
        """Return the canonical names of all registered models."""
        return list(self.model_map.keys())

    def list_aliases(self) -> list[str]:
        """Return every known lowercase alias (canonical names included)."""
        return list(self.alias_map.keys())

    def resolve(self, name_or_alias: str) -> ModelCapabilities | None:
        """Look up capabilities by canonical name or alias, case-insensitively."""
        key = name_or_alias.lower()
        canonical = self.alias_map.get(key)
        if canonical:
            return self.model_map.get(canonical)
        # Fallback: linear scan for canonical names whose lowercase form is
        # not present in alias_map (it is only added when unclaimed).
        for model_name in self.model_map:
            if model_name.lower() == key:
                return self.model_map[model_name]
        return None

    def get_capabilities(self, name_or_alias: str) -> ModelCapabilities | None:
        """Alias for :meth:`resolve`."""
        return self.resolve(name_or_alias)

    def get_entry(self, model_name: str) -> dict | None:
        """Return provider-specific extras recorded for *model_name*, if any."""
        return self._extras.get(model_name)

    def get_model_config(self, model_name: str) -> ModelCapabilities | None:
        """Backwards-compatible accessor for registries expecting this helper."""
        return self.model_map.get(model_name) or self.resolve(model_name)

    def iter_entries(self) -> Iterable[tuple[str, ModelCapabilities, dict]]:
        """Yield ``(canonical_name, capabilities, extras)`` triples."""
        for model_name, capability in self.model_map.items():
            yield model_name, capability, self._extras.get(model_name, {})

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _load_config_data(self) -> dict:
        """Return the parsed manifest, or ``{"models": []}`` when unavailable.

        Raises:
            FileNotFoundError: If neither a packaged resource nor a config
                path was resolved at construction time.
        """
        if self._use_resources:
            try:
                resource = importlib.resources.files(self._resource_package).joinpath(self._default_filename)
                if hasattr(resource, "read_text"):
                    config_text = resource.read_text(encoding="utf-8")
                else:  # pragma: no cover - legacy Python fallback
                    with resource.open("r", encoding="utf-8") as handle:
                        config_text = handle.read()
                data = json.loads(config_text)
            except FileNotFoundError:
                logger.debug("Packaged %s not found", self._default_filename)
                return {"models": []}
            except Exception as exc:
                logger.warning("Failed to read packaged %s: %s", self._default_filename, exc)
                return {"models": []}
            return data or {"models": []}
        if not self.config_path:
            raise FileNotFoundError("Registry configuration path is not set")
        if not self.config_path.exists():
            logger.debug("Model registry config not found at %s", self.config_path)
            if self.config_path == self._default_path:
                # The repo-relative default may not exist in some layouts;
                # retry relative to the current working directory before
                # giving up.
                fallback = Path.cwd() / "conf" / self._default_filename
                if fallback != self.config_path and fallback.exists():
                    logger.debug("Falling back to %s", fallback)
                    self.config_path = fallback
                else:
                    return {"models": []}
            else:
                return {"models": []}
        data = read_json_file(str(self.config_path))
        return data or {"models": []}

    @property
    def use_resources(self) -> bool:
        """True when the manifest is read from a packaged resource."""
        return self._use_resources

    def _parse_models(self, data: dict) -> Iterable[ModelCapabilities | None]:
        """Yield a converted capability per manifest entry; skip non-dicts."""
        for raw in data.get("models", []):
            if not isinstance(raw, dict):
                continue
            yield self._convert_entry(raw)

    def _convert_entry(self, raw: dict) -> ModelCapabilities | None:
        """Normalise one manifest entry and convert it to capabilities.

        Returns ``None`` for entries without a ``model_name``.

        Raises:
            ValueError: For the removed ``max_tokens`` field, or any key that
                is neither a capability field nor a subclass extra key.
        """
        entry = dict(raw)
        model_name = entry.get("model_name")
        if not model_name:
            return None
        aliases = entry.get("aliases")
        if isinstance(aliases, str):
            # Allow comma-separated alias strings in addition to lists.
            entry["aliases"] = [alias.strip() for alias in aliases.split(",") if alias.strip()]
        entry.setdefault("friendly_name", self._default_friendly_name(model_name))
        temperature_hint = entry.get("temperature_constraint")
        if isinstance(temperature_hint, str):
            entry["temperature_constraint"] = TemperatureConstraint.create(temperature_hint)
        elif temperature_hint is None:
            entry["temperature_constraint"] = TemperatureConstraint.create("range")
        if "max_tokens" in entry:
            raise ValueError(
                "`max_tokens` is no longer supported. Use `max_output_tokens` in your model configuration."
            )
        unknown_keys = set(entry.keys()) - CAPABILITY_FIELD_NAMES - self._extra_keys()
        if unknown_keys:
            raise ValueError("Unsupported fields in model configuration: " + ", ".join(sorted(unknown_keys)))
        capability, extras = self._finalise_entry(entry)
        # Bug fix: the provider is NOT force-assigned here. The
        # _finalise_entry hook is responsible for it and may legitimately
        # pick a non-default provider (e.g. OpenRouter entries routed to a
        # custom endpoint); overwriting with _provider_default() here would
        # silently discard that choice.
        self._extras[capability.model_name] = extras or {}
        return capability

    def _default_friendly_name(self, model_name: str) -> str:
        """Hook: human-readable name used when the manifest omits one."""
        return model_name

    def _extra_keys(self) -> set[str]:
        """Hook: manifest keys that are allowed but not capability fields."""
        return set()

    def _provider_default(self) -> ProviderType:
        """Hook: provider assigned when an entry does not specify one."""
        return ProviderType.OPENROUTER

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        """Hook: build the capability object plus provider-specific extras."""
        filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
        # Guarantee a provider is always set, since _convert_entry no longer
        # assigns one after the fact; subclass hooks may choose another.
        filtered.setdefault("provider", self._provider_default())
        return ModelCapabilities(**filtered), {}

    def _build_maps(self, configs: Iterable[ModelCapabilities]) -> None:
        """Populate alias/model maps, rejecting conflicting alias claims.

        Raises:
            ValueError: If two different models claim the same alias.
        """
        alias_map: dict[str, str] = {}
        model_map: dict[str, ModelCapabilities] = {}
        for config in configs:
            if not config:
                continue
            model_map[config.model_name] = config
            model_name_lower = config.model_name.lower()
            if model_name_lower not in alias_map:
                alias_map[model_name_lower] = config.model_name
            for alias in config.aliases:
                alias_lower = alias.lower()
                if alias_lower in alias_map and alias_map[alias_lower] != config.model_name:
                    raise ValueError(
                        f"Duplicate alias '{alias}' found for models '{alias_map[alias_lower]}' and '{config.model_name}'"
                    )
                alias_map[alias_lower] = config.model_name
        self.alias_map = alias_map
        self.model_map = model_map
class CapabilityModelRegistry(CustomModelRegistryBase):
    """Registry that returns :class:`ModelCapabilities` objects with alias support.

    A thin specialisation of :class:`CustomModelRegistryBase` that fixes the
    provider identity and friendly-name template at construction time, and
    loads the manifest eagerly.
    """

    def __init__(
        self,
        *,
        env_var_name: str,
        default_filename: str,
        provider: ProviderType,
        friendly_prefix: str,
        config_path: str | None = None,
    ) -> None:
        """Record the provider identity, then load the manifest.

        Args:
            env_var_name: Environment variable that may override the path.
            default_filename: Manifest filename under ``conf/``.
            provider: Provider stamped onto every loaded capability.
            friendly_prefix: ``str.format`` template with a ``{model}`` slot.
            config_path: Explicit manifest path; takes precedence.
        """
        self._provider = provider
        self._friendly_prefix = friendly_prefix
        super().__init__(
            config_path=config_path,
            env_var_name=env_var_name,
            default_filename=default_filename,
        )
        self.reload()

    def _provider_default(self) -> ProviderType:
        return self._provider

    def _default_friendly_name(self, model_name: str) -> str:
        return self._friendly_prefix.format(model=model_name)

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        """Build a capability from recognised fields; no provider extras."""
        payload = {key: value for key, value in entry.items() if key in CAPABILITY_FIELD_NAMES}
        if "provider" not in payload:
            payload["provider"] = self._provider_default()
        return ModelCapabilities(**payload), {}

View File

@@ -0,0 +1,25 @@
"""Registry loader for custom OpenAI-compatible endpoints."""
from __future__ import annotations
from ..shared import ModelCapabilities, ProviderType
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
class CustomEndpointModelRegistry(CapabilityModelRegistry):
    """Capability registry backed by ``conf/custom_models.json``.

    Covers models served by OpenAI-compatible custom endpoints; every entry
    is stamped with ``ProviderType.CUSTOM`` and gets a ``Custom (<model>)``
    friendly name unless the manifest supplies one.

    The former ``_finalise_entry`` override was removed: it duplicated the
    inherited ``CapabilityModelRegistry._finalise_entry`` exactly (same field
    filtering, and ``_provider_default()`` already yields
    ``ProviderType.CUSTOM``), so the base implementation is used, matching
    the other thin registries.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load custom-endpoint model capabilities.

        Args:
            config_path: Optional explicit manifest path; overrides the
                ``CUSTOM_MODELS_CONFIG_PATH`` environment variable.
        """
        super().__init__(
            env_var_name="CUSTOM_MODELS_CONFIG_PATH",
            default_filename="custom_models.json",
            provider=ProviderType.CUSTOM,
            friendly_prefix="Custom ({model})",
            config_path=config_path,
        )

View File

@@ -0,0 +1,19 @@
"""Registry loader for DIAL provider capabilities."""
from __future__ import annotations
from ..shared import ProviderType
from .base import CapabilityModelRegistry
class DialModelRegistry(CapabilityModelRegistry):
    """Capability registry backed by ``conf/dial_models.json``.

    Every entry is stamped with ``ProviderType.DIAL`` and gets a
    ``DIAL (<model>)`` friendly name unless the manifest supplies one.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load DIAL model capabilities.

        Args:
            config_path: Optional explicit manifest path; overrides the
                ``DIAL_MODELS_CONFIG_PATH`` environment variable.
        """
        super().__init__(
            config_path=config_path,
            provider=ProviderType.DIAL,
            friendly_prefix="DIAL ({model})",
            env_var_name="DIAL_MODELS_CONFIG_PATH",
            default_filename="dial_models.json",
        )

View File

@@ -0,0 +1,19 @@
"""Registry loader for Gemini model capabilities."""
from __future__ import annotations
from ..shared import ProviderType
from .base import CapabilityModelRegistry
class GeminiModelRegistry(CapabilityModelRegistry):
    """Capability registry backed by ``conf/gemini_models.json``.

    Every entry is stamped with ``ProviderType.GOOGLE`` and gets a
    ``Gemini (<model>)`` friendly name unless the manifest supplies one.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load Gemini model capabilities.

        Args:
            config_path: Optional explicit manifest path; overrides the
                ``GEMINI_MODELS_CONFIG_PATH`` environment variable.
        """
        super().__init__(
            config_path=config_path,
            provider=ProviderType.GOOGLE,
            friendly_prefix="Gemini ({model})",
            env_var_name="GEMINI_MODELS_CONFIG_PATH",
            default_filename="gemini_models.json",
        )

View File

@@ -0,0 +1,19 @@
"""Registry loader for OpenAI model capabilities."""
from __future__ import annotations
from ..shared import ProviderType
from .base import CapabilityModelRegistry
class OpenAIModelRegistry(CapabilityModelRegistry):
    """Capability registry backed by ``conf/openai_models.json``.

    Every entry is stamped with ``ProviderType.OPENAI`` and gets an
    ``OpenAI (<model>)`` friendly name unless the manifest supplies one.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load OpenAI model capabilities.

        Args:
            config_path: Optional explicit manifest path; overrides the
                ``OPENAI_MODELS_CONFIG_PATH`` environment variable.
        """
        super().__init__(
            config_path=config_path,
            provider=ProviderType.OPENAI,
            friendly_prefix="OpenAI ({model})",
            env_var_name="OPENAI_MODELS_CONFIG_PATH",
            default_filename="openai_models.json",
        )

View File

@@ -0,0 +1,38 @@
"""OpenRouter model registry for managing model configurations and aliases."""
from __future__ import annotations
from ..shared import ModelCapabilities, ProviderType
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
class OpenRouterModelRegistry(CapabilityModelRegistry):
    """Capability registry backed by ``conf/openrouter_models.json``.

    OpenRouter manifests may route individual entries to another provider via
    an explicit ``provider`` field (notably ``custom`` for OpenAI-compatible
    endpoints), so this registry resolves the provider per entry instead of
    relying solely on the registry-wide default.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load OpenRouter model capabilities.

        Args:
            config_path: Optional explicit manifest path; overrides the
                ``OPENROUTER_MODELS_CONFIG_PATH`` environment variable.
        """
        super().__init__(
            env_var_name="OPENROUTER_MODELS_CONFIG_PATH",
            default_filename="openrouter_models.json",
            provider=ProviderType.OPENROUTER,
            friendly_prefix="OpenRouter ({model})",
            config_path=config_path,
        )

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        """Build capabilities, honouring a per-entry provider override.

        Raises:
            ValueError: If the ``provider`` string does not name a known
                :class:`ProviderType` member.
        """
        provider_override = entry.get("provider")
        if isinstance(provider_override, str):
            entry_provider = ProviderType(provider_override.lower())
        elif isinstance(provider_override, ProviderType):
            entry_provider = provider_override
        else:
            entry_provider = ProviderType.OPENROUTER
        model_name = entry["model_name"]
        default_label = f"OpenRouter ({model_name})"
        if entry_provider == ProviderType.CUSTOM:
            # Bug fix: the base class injects the default "OpenRouter (...)"
            # friendly name before this hook runs, so a plain setdefault here
            # could never apply the "Custom (...)" label. Replace the injected
            # default, but keep a genuinely user-supplied friendly name.
            if entry.get("friendly_name") in (None, default_label):
                entry["friendly_name"] = f"Custom ({model_name})"
        else:
            entry.setdefault("friendly_name", default_label)
        filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
        filtered.setdefault("provider", entry_provider)
        capability = ModelCapabilities(**filtered)
        return capability, {}

View File

@@ -0,0 +1,19 @@
"""Registry loader for X.AI model capabilities."""
from __future__ import annotations
from ..shared import ProviderType
from .base import CapabilityModelRegistry
class XAIModelRegistry(CapabilityModelRegistry):
    """Capability registry backed by ``conf/xai_models.json``.

    Every entry is stamped with ``ProviderType.XAI`` and gets an
    ``X.AI (<model>)`` friendly name unless the manifest supplies one.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load X.AI model capabilities.

        Args:
            config_path: Optional explicit manifest path; overrides the
                ``XAI_MODELS_CONFIG_PATH`` environment variable.
        """
        super().__init__(
            config_path=config_path,
            provider=ProviderType.XAI,
            friendly_prefix="X.AI ({model})",
            env_var_name="XAI_MODELS_CONFIG_PATH",
            default_filename="xai_models.json",
        )