refactor: moved registries into a separate module and code cleanup
fix: refactored dial provider to follow the same pattern
This commit is contained in:
@@ -3,8 +3,8 @@
|
||||
from .azure_openai import AzureOpenAIProvider
|
||||
from .base import ModelProvider
|
||||
from .gemini import GeminiModelProvider
|
||||
from .openai import OpenAIModelProvider
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openai_provider import OpenAIModelProvider
|
||||
from .openrouter import OpenRouterProvider
|
||||
from .registry import ModelProviderRegistry
|
||||
from .shared import ModelCapabilities, ModelResponse
|
||||
|
||||
@@ -12,9 +12,9 @@ except ImportError: # pragma: no cover
|
||||
|
||||
from utils.env import get_env, suppress_env_vars
|
||||
|
||||
from .azure_registry import AzureModelRegistry
|
||||
from .openai import OpenAIModelProvider
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openai_provider import OpenAIModelProvider
|
||||
from .registries.azure import AzureModelRegistry
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -4,11 +4,12 @@ import logging
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
from .custom_registry import CustomEndpointModelRegistry
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openrouter_registry import OpenRouterModelRegistry
|
||||
from .registries.custom import CustomEndpointModelRegistry
|
||||
from .registries.openrouter import OpenRouterModelRegistry
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
|
||||
|
||||
class CustomProvider(OpenAICompatibleProvider):
|
||||
"""Adapter for self-hosted or local OpenAI-compatible endpoints.
|
||||
|
||||
|
||||
@@ -2,17 +2,19 @@
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||
from .registries.dial import DialModelRegistry
|
||||
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DIALModelProvider(OpenAICompatibleProvider):
|
||||
class DIALModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
|
||||
"""Client for the DIAL (Data & AI Layer) aggregation service.
|
||||
|
||||
DIAL exposes several third-party models behind a single OpenAI-compatible
|
||||
@@ -23,185 +25,13 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
|
||||
FRIENDLY_NAME = "DIAL"
|
||||
|
||||
REGISTRY_CLASS = DialModelRegistry
|
||||
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
|
||||
|
||||
# Retry configuration for API calls
|
||||
MAX_RETRIES = 4
|
||||
RETRY_DELAYS = [1, 3, 5, 8] # seconds
|
||||
|
||||
# Model configurations using ModelCapabilities objects
|
||||
MODEL_CAPABILITIES = {
|
||||
"o3-2025-04-16": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="o3-2025-04-16",
|
||||
friendly_name="DIAL (O3)",
|
||||
intelligence_score=14,
|
||||
context_window=200_000,
|
||||
max_output_tokens=100_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=False, # O3 models don't accept temperature
|
||||
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||
description="OpenAI O3 via DIAL - Strong reasoning model",
|
||||
aliases=["o3"],
|
||||
),
|
||||
"o4-mini-2025-04-16": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="o4-mini-2025-04-16",
|
||||
friendly_name="DIAL (O4-mini)",
|
||||
intelligence_score=11,
|
||||
context_window=200_000,
|
||||
max_output_tokens=100_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False, # DIAL may not expose function calling
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=False, # O4 models don't accept temperature
|
||||
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||
description="OpenAI O4-mini via DIAL - Fast reasoning model",
|
||||
aliases=["o4-mini"],
|
||||
),
|
||||
"anthropic.claude-sonnet-4.1-20250805-v1:0": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0",
|
||||
friendly_name="DIAL (Sonnet 4.1)",
|
||||
intelligence_score=10,
|
||||
context_window=200_000,
|
||||
max_output_tokens=64_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="Claude Sonnet 4.1 via DIAL - Balanced performance",
|
||||
aliases=["sonnet-4.1", "sonnet-4"],
|
||||
),
|
||||
"anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking",
|
||||
friendly_name="DIAL (Sonnet 4.1 Thinking)",
|
||||
intelligence_score=11,
|
||||
context_window=200_000,
|
||||
max_output_tokens=64_000,
|
||||
supports_extended_thinking=True,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="Claude Sonnet 4.1 with thinking mode via DIAL",
|
||||
aliases=["sonnet-4.1-thinking", "sonnet-4-thinking"],
|
||||
),
|
||||
"anthropic.claude-opus-4.1-20250805-v1:0": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-opus-4.1-20250805-v1:0",
|
||||
friendly_name="DIAL (Opus 4.1)",
|
||||
intelligence_score=14,
|
||||
context_window=200_000,
|
||||
max_output_tokens=64_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="Claude Opus 4.1 via DIAL - Most capable Claude model",
|
||||
aliases=["opus-4.1", "opus-4"],
|
||||
),
|
||||
"anthropic.claude-opus-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="anthropic.claude-opus-4.1-20250805-v1:0-with-thinking",
|
||||
friendly_name="DIAL (Opus 4.1 Thinking)",
|
||||
intelligence_score=15,
|
||||
context_window=200_000,
|
||||
max_output_tokens=64_000,
|
||||
supports_extended_thinking=True,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=5.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="Claude Opus 4.1 with thinking mode via DIAL",
|
||||
aliases=["opus-4.1-thinking", "opus-4-thinking"],
|
||||
),
|
||||
"gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="gemini-2.5-pro-preview-03-25-google-search",
|
||||
friendly_name="DIAL (Gemini 2.5 Pro Search)",
|
||||
intelligence_score=17,
|
||||
context_window=1_000_000,
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="Gemini 2.5 Pro with Google Search via DIAL",
|
||||
aliases=["gemini-2.5-pro-search"],
|
||||
),
|
||||
"gemini-2.5-pro-preview-05-06": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="gemini-2.5-pro-preview-05-06",
|
||||
friendly_name="DIAL (Gemini 2.5 Pro)",
|
||||
intelligence_score=18,
|
||||
context_window=1_000_000,
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="Gemini 2.5 Pro via DIAL - Deep reasoning",
|
||||
aliases=["gemini-2.5-pro"],
|
||||
),
|
||||
"gemini-2.5-flash-preview-05-20": ModelCapabilities(
|
||||
provider=ProviderType.DIAL,
|
||||
model_name="gemini-2.5-flash-preview-05-20",
|
||||
friendly_name="DIAL (Gemini Flash 2.5)",
|
||||
intelligence_score=10,
|
||||
context_window=1_000_000,
|
||||
max_output_tokens=65_536,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=False,
|
||||
supports_json_mode=True,
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="Gemini 2.5 Flash via DIAL - Ultra-fast",
|
||||
aliases=["gemini-2.5-flash"],
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize DIAL provider with API key and host.
|
||||
|
||||
@@ -209,6 +39,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
api_key: DIAL API key for authentication
|
||||
**kwargs: Additional configuration options
|
||||
"""
|
||||
self._ensure_registry()
|
||||
# Get DIAL API host from environment or kwargs
|
||||
dial_host = kwargs.get("base_url") or get_env("DIAL_API_HOST") or "https://core.dialx.ai"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tools.models import ToolModelCategory
|
||||
@@ -14,7 +14,7 @@ from utils.env import get_env
|
||||
from utils.image_utils import validate_image
|
||||
|
||||
from .base import ModelProvider
|
||||
from .gemini_registry import GeminiModelRegistry
|
||||
from .registries.gemini import GeminiModelRegistry
|
||||
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||
|
||||
@@ -30,7 +30,7 @@ class GeminiModelProvider(RegistryBackedProviderMixin, ModelProvider):
|
||||
"""
|
||||
|
||||
REGISTRY_CLASS = GeminiModelRegistry
|
||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
||||
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
|
||||
|
||||
# Thinking mode configurations - percentages of model's max_thinking_tokens
|
||||
# These percentages work across all models that support thinking
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
"""OpenAI model provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tools.models import ToolModelCategory
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openai_registry import OpenAIModelRegistry
|
||||
from .registries.openai import OpenAIModelRegistry
|
||||
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
|
||||
@@ -23,7 +23,7 @@ class OpenAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider)
|
||||
"""
|
||||
|
||||
REGISTRY_CLASS = OpenAIModelRegistry
|
||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
||||
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize OpenAI provider with API key."""
|
||||
@@ -50,7 +50,7 @@ class OpenAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider)
|
||||
return builtin
|
||||
|
||||
try:
|
||||
from .openrouter_registry import OpenRouterModelRegistry
|
||||
from .registries.openrouter import OpenRouterModelRegistry
|
||||
|
||||
registry = OpenRouterModelRegistry()
|
||||
config = registry.get_model_config(canonical_name)
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
from utils.env import get_env
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openrouter_registry import OpenRouterModelRegistry
|
||||
from .registries.openrouter import OpenRouterModelRegistry
|
||||
from .shared import (
|
||||
ModelCapabilities,
|
||||
ProviderType,
|
||||
|
||||
19
providers/registries/__init__.py
Normal file
19
providers/registries/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Registry implementations for provider capability manifests."""
|
||||
|
||||
from .azure import AzureModelRegistry
|
||||
from .custom import CustomEndpointModelRegistry
|
||||
from .dial import DialModelRegistry
|
||||
from .gemini import GeminiModelRegistry
|
||||
from .openai import OpenAIModelRegistry
|
||||
from .openrouter import OpenRouterModelRegistry
|
||||
from .xai import XAIModelRegistry
|
||||
|
||||
__all__ = [
|
||||
"AzureModelRegistry",
|
||||
"CustomEndpointModelRegistry",
|
||||
"DialModelRegistry",
|
||||
"GeminiModelRegistry",
|
||||
"OpenAIModelRegistry",
|
||||
"OpenRouterModelRegistry",
|
||||
"XAIModelRegistry",
|
||||
]
|
||||
@@ -4,8 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from .model_registry_base import CAPABILITY_FIELD_NAMES, CustomModelRegistryBase
|
||||
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||
from .base import CAPABILITY_FIELD_NAMES, CustomModelRegistryBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -12,7 +12,7 @@ from pathlib import Path
|
||||
from utils.env import get_env
|
||||
from utils.file_utils import read_json_file
|
||||
|
||||
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||
from ..shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,7 +34,7 @@ class CustomModelRegistryBase:
|
||||
self._default_filename = default_filename
|
||||
self._use_resources = False
|
||||
self._resource_package = "conf"
|
||||
self._default_path = Path(__file__).parent.parent / "conf" / default_filename
|
||||
self._default_path = Path(__file__).resolve().parents[3] / "conf" / default_filename
|
||||
|
||||
if config_path:
|
||||
self.config_path = Path(config_path)
|
||||
@@ -51,7 +51,7 @@ class CustomModelRegistryBase:
|
||||
else:
|
||||
raise AttributeError("resource accessor not available")
|
||||
except Exception:
|
||||
self.config_path = Path(__file__).parent.parent / "conf" / default_filename
|
||||
self.config_path = Path(__file__).resolve().parents[3] / "conf" / default_filename
|
||||
|
||||
self.alias_map: dict[str, str] = {}
|
||||
self.model_map: dict[str, ModelCapabilities] = {}
|
||||
@@ -213,7 +213,7 @@ class CustomModelRegistryBase:
|
||||
|
||||
|
||||
class CapabilityModelRegistry(CustomModelRegistryBase):
|
||||
"""Registry that returns `ModelCapabilities` objects with alias support."""
|
||||
"""Registry that returns :class:`ModelCapabilities` objects with alias support."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Registry for models exposed via custom (local) OpenAI-compatible endpoints."""
|
||||
"""Registry loader for custom OpenAI-compatible endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .model_registry_base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
from ..shared import ModelCapabilities, ProviderType
|
||||
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
||||
|
||||
|
||||
class CustomEndpointModelRegistry(CapabilityModelRegistry):
|
||||
"""Capability registry backed by ``conf/custom_models.json``."""
|
||||
|
||||
def __init__(self, config_path: str | None = None) -> None:
|
||||
super().__init__(
|
||||
env_var_name="CUSTOM_MODELS_CONFIG_PATH",
|
||||
@@ -15,11 +17,8 @@ class CustomEndpointModelRegistry(CapabilityModelRegistry):
|
||||
friendly_prefix="Custom ({model})",
|
||||
config_path=config_path,
|
||||
)
|
||||
self.reload()
|
||||
|
||||
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
|
||||
entry["provider"] = ProviderType.CUSTOM
|
||||
entry.setdefault("friendly_name", f"Custom ({entry['model_name']})")
|
||||
filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
|
||||
filtered.setdefault("provider", ProviderType.CUSTOM)
|
||||
capability = ModelCapabilities(**filtered)
|
||||
19
providers/registries/dial.py
Normal file
19
providers/registries/dial.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Registry loader for DIAL provider capabilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..shared import ProviderType
|
||||
from .base import CapabilityModelRegistry
|
||||
|
||||
|
||||
class DialModelRegistry(CapabilityModelRegistry):
|
||||
"""Capability registry backed by ``conf/dial_models.json``."""
|
||||
|
||||
def __init__(self, config_path: str | None = None) -> None:
|
||||
super().__init__(
|
||||
env_var_name="DIAL_MODELS_CONFIG_PATH",
|
||||
default_filename="dial_models.json",
|
||||
provider=ProviderType.DIAL,
|
||||
friendly_prefix="DIAL ({model})",
|
||||
config_path=config_path,
|
||||
)
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .model_registry_base import CapabilityModelRegistry
|
||||
from .shared import ProviderType
|
||||
from ..shared import ProviderType
|
||||
from .base import CapabilityModelRegistry
|
||||
|
||||
|
||||
class GeminiModelRegistry(CapabilityModelRegistry):
|
||||
"""Capability registry backed by `conf/gemini_models.json`."""
|
||||
"""Capability registry backed by ``conf/gemini_models.json``."""
|
||||
|
||||
def __init__(self, config_path: str | None = None) -> None:
|
||||
super().__init__(
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .model_registry_base import CapabilityModelRegistry
|
||||
from .shared import ProviderType
|
||||
from ..shared import ProviderType
|
||||
from .base import CapabilityModelRegistry
|
||||
|
||||
|
||||
class OpenAIModelRegistry(CapabilityModelRegistry):
|
||||
"""Capability registry backed by `conf/openai_models.json`."""
|
||||
"""Capability registry backed by ``conf/openai_models.json``."""
|
||||
|
||||
def __init__(self, config_path: str | None = None) -> None:
|
||||
super().__init__(
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .model_registry_base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
from ..shared import ModelCapabilities, ProviderType
|
||||
from .base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
||||
|
||||
|
||||
class OpenRouterModelRegistry(CapabilityModelRegistry):
|
||||
"""Capability registry backed by `conf/openrouter_models.json`."""
|
||||
"""Capability registry backed by ``conf/openrouter_models.json``."""
|
||||
|
||||
def __init__(self, config_path: str | None = None) -> None:
|
||||
super().__init__(
|
||||
@@ -1,13 +1,13 @@
|
||||
"""Registry loader for X.AI (GROK) model capabilities."""
|
||||
"""Registry loader for X.AI model capabilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .model_registry_base import CapabilityModelRegistry
|
||||
from .shared import ProviderType
|
||||
from ..shared import ProviderType
|
||||
from .base import CapabilityModelRegistry
|
||||
|
||||
|
||||
class XAIModelRegistry(CapabilityModelRegistry):
|
||||
"""Capability registry backed by `conf/xai_models.json`."""
|
||||
"""Capability registry backed by ``conf/xai_models.json``."""
|
||||
|
||||
def __init__(self, config_path: str | None = None) -> None:
|
||||
super().__init__(
|
||||
@@ -22,7 +22,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import ClassVar
|
||||
|
||||
from .model_registry_base import CapabilityModelRegistry
|
||||
from .registries.base import CapabilityModelRegistry
|
||||
from .shared import ModelCapabilities
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""X.AI (GROK) model provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tools.models import ToolModelCategory
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .registries.xai import XAIModelRegistry
|
||||
from .registry_provider_mixin import RegistryBackedProviderMixin
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
from .xai_registry import XAIModelRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +24,7 @@ class XAIModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider):
|
||||
FRIENDLY_NAME = "X.AI"
|
||||
|
||||
REGISTRY_CLASS = XAIModelRegistry
|
||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
||||
MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {}
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize X.AI provider with API key."""
|
||||
|
||||
Reference in New Issue
Block a user