refactor: moved registries into a separate module and code cleanup

fix: refactored dial provider to follow the same pattern
This commit is contained in:
Fahad
2025-10-07 12:59:09 +04:00
parent c27e81d6d2
commit 7c36b9255a
54 changed files with 325 additions and 282 deletions

View File

@@ -3,8 +3,8 @@
from .azure_openai import AzureOpenAIProvider
from .base import ModelProvider
from .gemini import GeminiModelProvider
from .openai import OpenAIModelProvider
from .openai_compatible import OpenAICompatibleProvider
from .openai_provider import OpenAIModelProvider
from .openrouter import OpenRouterProvider
from .registry import ModelProviderRegistry
from .shared import ModelCapabilities, ModelResponse

View File

@@ -12,9 +12,9 @@ except ImportError: # pragma: no cover
from utils.env import get_env, suppress_env_vars
from .azure_registry import AzureModelRegistry
from .openai import OpenAIModelProvider
from .openai_compatible import OpenAICompatibleProvider
from .openai_provider import OpenAIModelProvider
from .registries.azure import AzureModelRegistry
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
logger = logging.getLogger(__name__)

View File

@@ -4,11 +4,12 @@ import logging
from utils.env import get_env
from .custom_registry import CustomEndpointModelRegistry
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .registries.custom import CustomEndpointModelRegistry
from .registries.openrouter import OpenRouterModelRegistry
from .shared import ModelCapabilities, ProviderType
class CustomProvider(OpenAICompatibleProvider):
"""Adapter for self-hosted or local OpenAI-compatible endpoints.

View File

@@ -2,17 +2,19 @@
import logging
import threading
from typing import Optional
from typing import ClassVar, Optional
from utils.env import get_env
from .openai_compatible import OpenAICompatibleProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
from .registries.dial import DialModelRegistry
from .registry_provider_mixin import RegistryBackedProviderMixin
from .shared import ModelCapabilities, ModelResponse, ProviderType
logger = logging.getLogger(__name__)
class DIALModelProvider(OpenAICompatibleProvider):
class DIALModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
"""Client for the DIAL (Data & AI Layer) aggregation service.
DIAL exposes several third-party models behind a single OpenAI-compatible
@@ -23,185 +25,13 @@ class DIALModelProvider(OpenAICompatibleProvider):
FRIENDLY_NAME = "DIAL"
REGISTRY_CLASS = DialModelRegistry
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
# Retry configuration for API calls
MAX_RETRIES = 4
RETRY_DELAYS = [1, 3, 5, 8] # seconds
# Model configurations using ModelCapabilities objects
MODEL_CAPABILITIES = {
"o3-2025-04-16": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="o3-2025-04-16",
friendly_name="DIAL (O3)",
intelligence_score=14,
context_window=200_000,
max_output_tokens=100_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=False, # O3 models don't accept temperature
temperature_constraint=TemperatureConstraint.create("fixed"),
description="OpenAI O3 via DIAL - Strong reasoning model",
aliases=["o3"],
),
"o4-mini-2025-04-16": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="o4-mini-2025-04-16",
friendly_name="DIAL (O4-mini)",
intelligence_score=11,
context_window=200_000,
max_output_tokens=100_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=False, # O4 models don't accept temperature
temperature_constraint=TemperatureConstraint.create("fixed"),
description="OpenAI O4-mini via DIAL - Fast reasoning model",
aliases=["o4-mini"],
),
"anthropic.claude-sonnet-4.1-20250805-v1:0": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0",
friendly_name="DIAL (Sonnet 4.1)",
intelligence_score=10,
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Claude Sonnet 4.1 via DIAL - Balanced performance",
aliases=["sonnet-4.1", "sonnet-4"],
),
"anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking",
friendly_name="DIAL (Sonnet 4.1 Thinking)",
intelligence_score=11,
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Claude Sonnet 4.1 with thinking mode via DIAL",
aliases=["sonnet-4.1-thinking", "sonnet-4-thinking"],
),
"anthropic.claude-opus-4.1-20250805-v1:0": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-opus-4.1-20250805-v1:0",
friendly_name="DIAL (Opus 4.1)",
intelligence_score=14,
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Claude Opus 4.1 via DIAL - Most capable Claude model",
aliases=["opus-4.1", "opus-4"],
),
"anthropic.claude-opus-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-opus-4.1-20250805-v1:0-with-thinking",
friendly_name="DIAL (Opus 4.1 Thinking)",
intelligence_score=15,
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Claude Opus 4.1 with thinking mode via DIAL",
aliases=["opus-4.1-thinking", "opus-4-thinking"],
),
"gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-pro-preview-03-25-google-search",
friendly_name="DIAL (Gemini 2.5 Pro Search)",
intelligence_score=17,
context_window=1_000_000,
max_output_tokens=65_536,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Gemini 2.5 Pro with Google Search via DIAL",
aliases=["gemini-2.5-pro-search"],
),
"gemini-2.5-pro-preview-05-06": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-pro-preview-05-06",
friendly_name="DIAL (Gemini 2.5 Pro)",
intelligence_score=18,
context_window=1_000_000,
max_output_tokens=65_536,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Gemini 2.5 Pro via DIAL - Deep reasoning",
aliases=["gemini-2.5-pro"],
),
"gemini-2.5-flash-preview-05-20": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-flash-preview-05-20",
friendly_name="DIAL (Gemini Flash 2.5)",
intelligence_score=10,
context_window=1_000_000,
max_output_tokens=65_536,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Gemini 2.5 Flash via DIAL - Ultra-fast",
aliases=["gemini-2.5-flash"],
),
}
def __init__(self, api_key: str, **kwargs):
"""Initialize DIAL provider with API key and host.
@@ -209,6 +39,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
api_key: DIAL API key for authentication
**kwargs: Additional configuration options
"""
self._ensure_registry()
# Get DIAL API host from environment or kwargs
dial_host = kwargs.get("base_url") or get_env("DIAL_API_HOST") or "https://core.dialx.ai"

View File

@@ -2,7 +2,7 @@
import base64
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
@@ -14,7 +14,7 @@ from utils.env import get_env
from utils.image_utils import validate_image
from .base import ModelProvider
from .gemini_registry import GeminiModelRegistry
from .registries.gemini import GeminiModelRegistry
from .registry_provider_mixin import RegistryBackedProviderMixin
from .shared import ModelCapabilities, ModelResponse, ProviderType
@@ -30,7 +30,7 @@ class GeminiModelProvider(RegistryBackedProviderMixin, ModelProvider):
"""
REGISTRY_CLASS = GeminiModelRegistry
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
# Thinking mode configurations - percentages of model's max_thinking_tokens
# These percentages work across all models that support thinking

View File

@@ -1,13 +1,13 @@
"""OpenAI model provider implementation."""
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .openai_compatible import OpenAICompatibleProvider
from .openai_registry import OpenAIModelRegistry
from .registries.openai import OpenAIModelRegistry
from .registry_provider_mixin import RegistryBackedProviderMixin
from .shared import ModelCapabilities, ProviderType
@@ -23,7 +23,7 @@ class OpenAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider)
"""
REGISTRY_CLASS = OpenAIModelRegistry
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
def __init__(self, api_key: str, **kwargs):
"""Initialize OpenAI provider with API key."""
@@ -50,7 +50,7 @@ class OpenAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider)
return builtin
try:
from .openrouter_registry import OpenRouterModelRegistry
from .registries.openrouter import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
config = registry.get_model_config(canonical_name)

View File

@@ -5,7 +5,7 @@ import logging
from utils.env import get_env
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .registries.openrouter import OpenRouterModelRegistry
from .shared import (
ModelCapabilities,
ProviderType,

View File

@@ -0,0 +1,19 @@
"""Registry implementations for provider capability manifests."""
from .azure import AzureModelRegistry
from .custom import CustomEndpointModelRegistry
from .dial import DialModelRegistry
from .gemini import GeminiModelRegistry
from .openai import OpenAIModelRegistry
from .openrouter import OpenRouterModelRegistry
from .xai import XAIModelRegistry
__all__ = [
"AzureModelRegistry",
"CustomEndpointModelRegistry",
"DialModelRegistry",
"GeminiModelRegistry",
"OpenAIModelRegistry",
"OpenRouterModelRegistry",
"XAIModelRegistry",
]

View File

@@ -4,8 +4,8 @@ from __future__ import annotations
import logging
from .model_registry_base import CAPABILITY_FIELD_NAMES, CustomModelRegistryBase
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
from .base import CAPABILITY_FIELD_NAMES, CustomModelRegistryBase
logger = logging.getLogger(__name__)

View File

@@ -12,7 +12,7 @@ from pathlib import Path
from utils.env import get_env
from utils.file_utils import read_json_file
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ class CustomModelRegistryBase:
self._default_filename = default_filename
self._use_resources = False
self._resource_package = "conf"
self._default_path = Path(__file__).parent.parent / "conf" / default_filename
self._default_path = Path(__file__).resolve().parents[3] / "conf" / default_filename
if config_path:
self.config_path = Path(config_path)
@@ -51,7 +51,7 @@ class CustomModelRegistryBase:
else:
raise AttributeError("resource accessor not available")
except Exception:
self.config_path = Path(__file__).parent.parent / "conf" / default_filename
self.config_path = Path(__file__).resolve().parents[3] / "conf" / default_filename
self.alias_map: dict[str, str] = {}
self.model_map: dict[str, ModelCapabilities] = {}
@@ -213,7 +213,7 @@ class CustomModelRegistryBase:
class CapabilityModelRegistry(CustomModelRegistryBase):
"""Registry that returns `ModelCapabilities` objects with alias support."""
"""Registry that returns :class:`ModelCapabilities` objects with alias support."""
def __init__(
self,

View File

@@ -1,12 +1,14 @@
"""Registry for models exposed via custom (local) OpenAI-compatible endpoints."""
"""Registry loader for custom OpenAI-compatible endpoints."""
from __future__ import annotations
from .model_registry_base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
from .shared import ModelCapabilities, ProviderType
from ..shared import ModelCapabilities, ProviderType
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
class CustomEndpointModelRegistry(CapabilityModelRegistry):
"""Capability registry backed by ``conf/custom_models.json``."""
def __init__(self, config_path: str | None = None) -> None:
super().__init__(
env_var_name="CUSTOM_MODELS_CONFIG_PATH",
@@ -15,11 +17,8 @@ class CustomEndpointModelRegistry(CapabilityModelRegistry):
friendly_prefix="Custom ({model})",
config_path=config_path,
)
self.reload()
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
entry["provider"] = ProviderType.CUSTOM
entry.setdefault("friendly_name", f"Custom ({entry['model_name']})")
filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
filtered.setdefault("provider", ProviderType.CUSTOM)
capability = ModelCapabilities(**filtered)

View File

@@ -0,0 +1,19 @@
"""Registry loader for DIAL provider capabilities."""
from __future__ import annotations
from ..shared import ProviderType
from .base import CapabilityModelRegistry
class DialModelRegistry(CapabilityModelRegistry):
    """Capability registry backed by ``conf/dial_models.json``."""

    def __init__(self, config_path: str | None = None) -> None:
        """Initialise the DIAL capability registry.

        Args:
            config_path: Optional explicit path to the manifest JSON; when
                omitted, path resolution is delegated to the base registry
                (presumably via the ``DIAL_MODELS_CONFIG_PATH`` environment
                variable or the bundled default — confirm in
                ``CapabilityModelRegistry``).
        """
        # All loading, alias mapping, and capability construction live in the
        # shared base class; this subclass only supplies DIAL-specific settings.
        settings = {
            "env_var_name": "DIAL_MODELS_CONFIG_PATH",
            "default_filename": "dial_models.json",
            "provider": ProviderType.DIAL,
            "friendly_prefix": "DIAL ({model})",
            "config_path": config_path,
        }
        super().__init__(**settings)

View File

@@ -2,12 +2,12 @@
from __future__ import annotations
from .model_registry_base import CapabilityModelRegistry
from .shared import ProviderType
from ..shared import ProviderType
from .base import CapabilityModelRegistry
class GeminiModelRegistry(CapabilityModelRegistry):
"""Capability registry backed by `conf/gemini_models.json`."""
"""Capability registry backed by ``conf/gemini_models.json``."""
def __init__(self, config_path: str | None = None) -> None:
super().__init__(

View File

@@ -2,12 +2,12 @@
from __future__ import annotations
from .model_registry_base import CapabilityModelRegistry
from .shared import ProviderType
from ..shared import ProviderType
from .base import CapabilityModelRegistry
class OpenAIModelRegistry(CapabilityModelRegistry):
"""Capability registry backed by `conf/openai_models.json`."""
"""Capability registry backed by ``conf/openai_models.json``."""
def __init__(self, config_path: str | None = None) -> None:
super().__init__(

View File

@@ -2,12 +2,12 @@
from __future__ import annotations
from .model_registry_base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
from .shared import ModelCapabilities, ProviderType
from ..shared import ModelCapabilities, ProviderType
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
class OpenRouterModelRegistry(CapabilityModelRegistry):
"""Capability registry backed by `conf/openrouter_models.json`."""
"""Capability registry backed by ``conf/openrouter_models.json``."""
def __init__(self, config_path: str | None = None) -> None:
super().__init__(

View File

@@ -1,13 +1,13 @@
"""Registry loader for X.AI (GROK) model capabilities."""
"""Registry loader for X.AI model capabilities."""
from __future__ import annotations
from .model_registry_base import CapabilityModelRegistry
from .shared import ProviderType
from ..shared import ProviderType
from .base import CapabilityModelRegistry
class XAIModelRegistry(CapabilityModelRegistry):
"""Capability registry backed by `conf/xai_models.json`."""
"""Capability registry backed by ``conf/xai_models.json``."""
def __init__(self, config_path: str | None = None) -> None:
super().__init__(

View File

@@ -22,7 +22,7 @@ from __future__ import annotations
import logging
from typing import ClassVar
from .model_registry_base import CapabilityModelRegistry
from .registries.base import CapabilityModelRegistry
from .shared import ModelCapabilities

View File

@@ -1,15 +1,15 @@
"""X.AI (GROK) model provider implementation."""
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .openai_compatible import OpenAICompatibleProvider
from .registries.xai import XAIModelRegistry
from .registry_provider_mixin import RegistryBackedProviderMixin
from .shared import ModelCapabilities, ProviderType
from .xai_registry import XAIModelRegistry
logger = logging.getLogger(__name__)
@@ -24,7 +24,7 @@ class XAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
FRIENDLY_NAME = "X.AI"
REGISTRY_CLASS = XAIModelRegistry
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
def __init__(self, api_key: str, **kwargs):
"""Initialize X.AI provider with API key."""