feat!: breaking change - OpenRouter models are now read from conf/openrouter_models.json, while Custom / Self-hosted models are read from conf/custom_models.json

feat: Azure OpenAI / Azure AI Foundry support. Models should be defined in conf/azure_models.json (or a custom path). See .env.example or the README for the environment variables. https://github.com/BeehiveInnovations/zen-mcp-server/issues/265

feat: OpenRouter / Custom / Azure model catalogues can now each point at their own custom config path (see .env.example)

refactor: the model registry class is now abstract; the OpenRouter, Custom Provider, and Azure OpenAI registries subclass it

refactor: breaking change - the `is_custom` property has been removed from model_capabilities.py (and therefore from custom_models.json), since each provider's models are now read from its own configuration file
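For illustration, a conf/azure_models.json entry might look like the sketch below. This is a hedged example, not the authoritative schema: the field names follow the capability fields and the required "deployment" key referenced in this commit, but the model, deployment name, and limits are placeholders (note that "max_tokens" is rejected; use "max_output_tokens").

    {
      "models": [
        {
          "model_name": "gpt-4.1",
          "deployment": "my-gpt41-deployment",
          "friendly_name": "Azure OpenAI (GPT 4.1)",
          "aliases": ["gpt4.1-azure"],
          "context_window": 1000000,
          "max_output_tokens": 32768,
          "temperature_constraint": "range"
        }
      ]
    }

The environment variables referenced in this commit are AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_VERSION (optional), AZURE_OPENAI_ALLOWED_MODELS (optional), and AZURE_MODELS_CONFIG_PATH / OPENROUTER_MODELS_CONFIG_PATH / CUSTOM_MODELS_CONFIG_PATH for custom config locations; see .env.example for the full list.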
@@ -1,5 +1,6 @@
 """Model provider abstractions for supporting multiple AI providers."""
 
+from .azure_openai import AzureOpenAIProvider
 from .base import ModelProvider
 from .gemini import GeminiModelProvider
 from .openai_compatible import OpenAICompatibleProvider
@@ -13,6 +14,7 @@ __all__ = [
     "ModelResponse",
     "ModelCapabilities",
     "ModelProviderRegistry",
+    "AzureOpenAIProvider",
     "GeminiModelProvider",
     "OpenAIModelProvider",
     "OpenAICompatibleProvider",

providers/azure_openai.py (new file, 342 lines)
@@ -0,0 +1,342 @@
|
||||
"""Azure OpenAI provider built on the OpenAI-compatible implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import asdict, replace
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
from openai import AzureOpenAI
|
||||
except ImportError: # pragma: no cover
|
||||
AzureOpenAI = None # type: ignore[assignment]
|
||||
|
||||
from utils.env import get_env, suppress_env_vars
|
||||
|
||||
from .azure_registry import AzureModelRegistry
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openai_provider import OpenAIModelProvider
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAIProvider(OpenAICompatibleProvider):
|
||||
"""Thin Azure wrapper that reuses the OpenAI-compatible request pipeline."""
|
||||
|
||||
FRIENDLY_NAME = "Azure OpenAI"
|
||||
DEFAULT_API_VERSION = "2024-02-15-preview"
|
||||
|
||||
# The OpenAI-compatible base expects subclasses to expose capabilities via
|
||||
# ``get_all_model_capabilities``. Azure deployments are user-defined, so we
|
||||
# build the catalogue dynamically from environment configuration instead of
|
||||
# relying on a static ``MODEL_CAPABILITIES`` map.
|
||||
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
*,
|
||||
azure_endpoint: str | None = None,
|
||||
api_version: str | None = None,
|
||||
deployments: dict[str, object] | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# Let the OpenAI-compatible base handle shared configuration such as
|
||||
# timeouts, restriction-aware allowlists, and logging. ``base_url`` maps
|
||||
# directly onto Azure's endpoint URL.
|
||||
super().__init__(api_key, base_url=azure_endpoint, **kwargs)
|
||||
|
||||
if not azure_endpoint:
|
||||
azure_endpoint = get_env("AZURE_OPENAI_ENDPOINT")
|
||||
if not azure_endpoint:
|
||||
raise ValueError("Azure OpenAI endpoint is required via parameter or AZURE_OPENAI_ENDPOINT")
|
||||
|
||||
self.azure_endpoint = azure_endpoint.rstrip("/")
|
||||
self.api_version = api_version or get_env("AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION)
|
||||
|
||||
registry_specs = self._load_registry_entries()
|
||||
override_specs = self._normalise_deployments(deployments or {}) if deployments else {}
|
||||
|
||||
self._model_specs = self._merge_specs(registry_specs, override_specs)
|
||||
if not self._model_specs:
|
||||
raise ValueError(
|
||||
"Azure OpenAI provider requires at least one configured deployment. "
|
||||
"Populate conf/azure_models.json or set AZURE_MODELS_CONFIG_PATH."
|
||||
)
|
||||
|
||||
self._capabilities = self._build_capabilities_map()
|
||||
self._deployment_map = {name: spec["deployment"] for name, spec in self._model_specs.items()}
|
||||
self._deployment_alias_lookup = {
|
||||
deployment.lower(): canonical for canonical, deployment in self._deployment_map.items()
|
||||
}
|
||||
self._canonical_lookup = {name.lower(): name for name in self._model_specs.keys()}
|
||||
self._invalidate_capability_cache()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Capability helpers
|
||||
# ------------------------------------------------------------------
|
||||
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||
return dict(self._capabilities)
|
||||
|
||||
def get_provider_type(self) -> ProviderType:
|
||||
return ProviderType.AZURE
|
||||
|
||||
def get_capabilities(self, model_name: str) -> ModelCapabilities: # type: ignore[override]
|
||||
lowered = model_name.lower()
|
||||
if lowered in self._deployment_alias_lookup:
|
||||
canonical = self._deployment_alias_lookup[lowered]
|
||||
return super().get_capabilities(canonical)
|
||||
canonical = self._canonical_lookup.get(lowered)
|
||||
if canonical:
|
||||
return super().get_capabilities(canonical)
|
||||
return super().get_capabilities(model_name)
|
||||
|
||||
def validate_model_name(self, model_name: str) -> bool: # type: ignore[override]
|
||||
lowered = model_name.lower()
|
||||
if lowered in self._deployment_alias_lookup or lowered in self._canonical_lookup:
|
||||
return True
|
||||
return super().validate_model_name(model_name)
|
||||
|
||||
def _build_capabilities_map(self) -> dict[str, ModelCapabilities]:
|
||||
capabilities: dict[str, ModelCapabilities] = {}
|
||||
|
||||
for canonical_name, spec in self._model_specs.items():
|
||||
template_capability: ModelCapabilities | None = spec.get("capability")
|
||||
overrides = spec.get("overrides", {})
|
||||
|
||||
if template_capability:
|
||||
cloned = replace(template_capability)
|
||||
else:
|
||||
template = OpenAIModelProvider.MODEL_CAPABILITIES.get(canonical_name)
|
||||
|
||||
if template:
|
||||
friendly = template.friendly_name.replace("OpenAI", "Azure OpenAI", 1)
|
||||
cloned = replace(
|
||||
template,
|
||||
provider=ProviderType.AZURE,
|
||||
friendly_name=friendly,
|
||||
aliases=list(template.aliases),
|
||||
)
|
||||
else:
|
||||
deployment_name = spec.get("deployment", "")
|
||||
cloned = ModelCapabilities(
|
||||
provider=ProviderType.AZURE,
|
||||
model_name=canonical_name,
|
||||
friendly_name=f"Azure OpenAI ({canonical_name})",
|
||||
description=f"Azure deployment '{deployment_name}' for {canonical_name}",
|
||||
aliases=[],
|
||||
)
|
||||
|
||||
if overrides:
|
||||
overrides = dict(overrides)
|
||||
temp_override = overrides.get("temperature_constraint")
|
||||
if isinstance(temp_override, str):
|
||||
overrides["temperature_constraint"] = TemperatureConstraint.create(temp_override)
|
||||
|
||||
aliases_override = overrides.get("aliases")
|
||||
if isinstance(aliases_override, str):
|
||||
overrides["aliases"] = [alias.strip() for alias in aliases_override.split(",") if alias.strip()]
|
||||
provider_override = overrides.get("provider")
|
||||
if provider_override:
|
||||
overrides.pop("provider", None)
|
||||
|
||||
try:
|
||||
cloned = replace(cloned, **overrides)
|
||||
except TypeError:
|
||||
base_data = asdict(cloned)
|
||||
base_data.update(overrides)
|
||||
base_data["provider"] = ProviderType.AZURE
|
||||
temp_value = base_data.get("temperature_constraint")
|
||||
if isinstance(temp_value, str):
|
||||
base_data["temperature_constraint"] = TemperatureConstraint.create(temp_value)
|
||||
cloned = ModelCapabilities(**base_data)
|
||||
|
||||
if cloned.provider != ProviderType.AZURE:
|
||||
cloned.provider = ProviderType.AZURE
|
||||
|
||||
capabilities[canonical_name] = cloned
|
||||
|
||||
return capabilities
|
||||
|
||||
def _load_registry_entries(self) -> dict[str, dict]:
|
||||
try:
|
||||
registry = AzureModelRegistry()
|
||||
except Exception as exc: # pragma: no cover - registry failure should not crash provider
|
||||
logger.warning("Unable to load Azure model registry: %s", exc)
|
||||
return {}
|
||||
|
||||
entries: dict[str, dict] = {}
|
||||
for model_name, capability, extra in registry.iter_entries():
|
||||
deployment = extra.get("deployment")
|
||||
if not deployment:
|
||||
logger.warning("Azure model '%s' missing deployment in registry", model_name)
|
||||
continue
|
||||
entries[model_name] = {"deployment": deployment, "capability": capability}
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def _merge_specs(
|
||||
registry_specs: dict[str, dict],
|
||||
override_specs: dict[str, dict],
|
||||
) -> dict[str, dict]:
|
||||
specs: dict[str, dict] = {}
|
||||
|
||||
for canonical, entry in registry_specs.items():
|
||||
specs[canonical] = {
|
||||
"deployment": entry.get("deployment"),
|
||||
"capability": entry.get("capability"),
|
||||
"overrides": {},
|
||||
}
|
||||
|
||||
for canonical, entry in override_specs.items():
|
||||
spec = specs.get(canonical, {"deployment": None, "capability": None, "overrides": {}})
|
||||
deployment = entry.get("deployment")
|
||||
if deployment:
|
||||
spec["deployment"] = deployment
|
||||
overrides = {k: v for k, v in entry.items() if k not in {"deployment"}}
|
||||
overrides.pop("capability", None)
|
||||
if overrides:
|
||||
spec["overrides"].update(overrides)
|
||||
specs[canonical] = spec
|
||||
|
||||
return {k: v for k, v in specs.items() if v.get("deployment")}
|
||||
|
||||
@staticmethod
|
||||
def _normalise_deployments(mapping: dict[str, object]) -> dict[str, dict]:
|
||||
normalised: dict[str, dict] = {}
|
||||
for canonical, spec in mapping.items():
|
||||
canonical_name = (canonical or "").strip()
|
||||
if not canonical_name:
|
||||
continue
|
||||
|
||||
deployment_name: str | None = None
|
||||
overrides: dict[str, object] = {}
|
||||
|
||||
if isinstance(spec, str):
|
||||
deployment_name = spec.strip()
|
||||
elif isinstance(spec, dict):
|
||||
deployment_name = spec.get("deployment") or spec.get("deployment_name")
|
||||
overrides = {k: v for k, v in spec.items() if k not in {"deployment", "deployment_name"}}
|
||||
|
||||
if not deployment_name:
|
||||
continue
|
||||
|
||||
normalised[canonical_name] = {"deployment": deployment_name.strip(), **overrides}
|
||||
|
||||
return normalised
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Azure-specific configuration
|
||||
# ------------------------------------------------------------------
|
||||
@property
|
||||
def client(self): # type: ignore[override]
|
||||
"""Instantiate the Azure OpenAI client on first use."""
|
||||
|
||||
if self._client is None:
|
||||
if AzureOpenAI is None:
|
||||
raise ImportError(
|
||||
"Azure OpenAI support requires the 'openai' package. Install it with `pip install openai`."
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]
|
||||
|
||||
with suppress_env_vars(*proxy_env_vars):
|
||||
try:
|
||||
timeout_config = self.timeout_config
|
||||
|
||||
http_client = httpx.Client(timeout=timeout_config, follow_redirects=True)
|
||||
|
||||
client_kwargs = {
|
||||
"api_key": self.api_key,
|
||||
"azure_endpoint": self.azure_endpoint,
|
||||
"api_version": self.api_version,
|
||||
"http_client": http_client,
|
||||
}
|
||||
|
||||
if self.DEFAULT_HEADERS:
|
||||
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
|
||||
|
||||
logger.debug(
|
||||
"Initializing Azure OpenAI client endpoint=%s api_version=%s timeouts=%s",
|
||||
self.azure_endpoint,
|
||||
self.api_version,
|
||||
timeout_config,
|
||||
)
|
||||
|
||||
self._client = AzureOpenAI(**client_kwargs)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Failed to create Azure OpenAI client: %s", exc)
|
||||
raise
|
||||
|
||||
return self._client
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Request delegation
|
||||
# ------------------------------------------------------------------
|
||||
def generate_content(
|
||||
self,
|
||||
prompt: str,
|
||||
model_name: str,
|
||||
system_prompt: str | None = None,
|
||||
temperature: float = 0.3,
|
||||
max_output_tokens: int | None = None,
|
||||
images: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
canonical_name, deployment_name = self._resolve_canonical_and_deployment(model_name)
|
||||
|
||||
# Delegate to the shared OpenAI-compatible implementation using the
|
||||
# deployment name – Azure requires the deployment identifier in the
|
||||
# ``model`` field. The returned ``ModelResponse`` is normalised so
|
||||
# downstream consumers continue to see the canonical model name.
|
||||
raw_response = super().generate_content(
|
||||
prompt=prompt,
|
||||
model_name=deployment_name,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature,
|
||||
max_output_tokens=max_output_tokens,
|
||||
images=images,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
capabilities = self._capabilities.get(canonical_name)
|
||||
friendly_name = capabilities.friendly_name if capabilities else self.FRIENDLY_NAME
|
||||
|
||||
return ModelResponse(
|
||||
content=raw_response.content,
|
||||
usage=raw_response.usage,
|
||||
model_name=canonical_name,
|
||||
friendly_name=friendly_name,
|
||||
provider=ProviderType.AZURE,
|
||||
metadata={**raw_response.metadata, "deployment": deployment_name},
|
||||
)
|
||||
|
||||
def _resolve_canonical_and_deployment(self, model_name: str) -> tuple[str, str]:
|
||||
resolved_canonical = self._resolve_model_name(model_name)
|
||||
|
||||
if resolved_canonical not in self._deployment_map:
|
||||
# The base resolver may hand back the deployment alias. Try to map it
|
||||
# back to a canonical entry.
|
||||
for canonical, deployment in self._deployment_map.items():
|
||||
if deployment.lower() == resolved_canonical.lower():
|
||||
return canonical, deployment
|
||||
raise ValueError(f"Model '{model_name}' is not configured for Azure OpenAI")
|
||||
|
||||
return resolved_canonical, self._deployment_map[resolved_canonical]
|
||||
|
||||
def _parse_allowed_models(self) -> set[str] | None: # type: ignore[override]
|
||||
# Support both AZURE_ALLOWED_MODELS (inherited behaviour) and the
|
||||
# clearer AZURE_OPENAI_ALLOWED_MODELS alias.
|
||||
explicit = get_env("AZURE_OPENAI_ALLOWED_MODELS")
|
||||
if explicit:
|
||||
models = {m.strip().lower() for m in explicit.split(",") if m.strip()}
|
||||
if models:
|
||||
logger.info("Configured allowed models for Azure OpenAI: %s", sorted(models))
|
||||
self._allowed_alias_cache = {}
|
||||
return models
|
||||
|
||||
return super()._parse_allowed_models()
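The provider above can also be exercised directly; a minimal usage sketch (the endpoint, key, and deployment names below are placeholders, and in the server these values normally come from the environment and conf/azure_models.json rather than inline arguments):

    from providers.azure_openai import AzureOpenAIProvider

    provider = AzureOpenAIProvider(
        api_key="<azure-api-key>",                               # placeholder
        azure_endpoint="https://my-resource.openai.azure.com",   # placeholder endpoint
        api_version="2024-02-15-preview",
        deployments={
            # canonical model name -> Azure deployment name (a dict with overrides also works)
            "gpt-4.1": "my-gpt41-deployment",
        },
    )

    response = provider.generate_content(
        prompt="Summarise the latest changes in one sentence.",
        model_name="gpt-4.1",   # the canonical name or the deployment alias both resolve
        temperature=0.3,
    )
    print(response.content)     # the ModelResponse is normalised back to the canonical model name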
|
||||
providers/azure_registry.py (new file, 45 lines)
@@ -0,0 +1,45 @@
"""Registry loader for Azure OpenAI model configurations."""

from __future__ import annotations

import logging

from .model_registry_base import CAPABILITY_FIELD_NAMES, CustomModelRegistryBase
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint

logger = logging.getLogger(__name__)


class AzureModelRegistry(CustomModelRegistryBase):
    """Load Azure-specific model metadata from configuration files."""

    def __init__(self, config_path: str | None = None) -> None:
        super().__init__(
            env_var_name="AZURE_MODELS_CONFIG_PATH",
            default_filename="azure_models.json",
            config_path=config_path,
        )
        self.reload()

    def _extra_keys(self) -> set[str]:
        return {"deployment", "deployment_name"}

    def _provider_default(self) -> ProviderType:
        return ProviderType.AZURE

    def _default_friendly_name(self, model_name: str) -> str:
        return f"Azure OpenAI ({model_name})"

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        deployment = entry.pop("deployment", None) or entry.pop("deployment_name", None)
        if not deployment:
            raise ValueError(f"Azure model '{entry.get('model_name')}' is missing required 'deployment' field")

        temp_hint = entry.get("temperature_constraint")
        if isinstance(temp_hint, str):
            entry["temperature_constraint"] = TemperatureConstraint.create(temp_hint)

        filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
        filtered.setdefault("provider", ProviderType.AZURE)
        capability = ModelCapabilities(**filtered)
        return capability, {"deployment": deployment}
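A minimal sketch of consuming this registry (it assumes conf/azure_models.json, or the file pointed to by AZURE_MODELS_CONFIG_PATH, contains at least one entry):

    from providers.azure_registry import AzureModelRegistry

    registry = AzureModelRegistry()
    for model_name, capability, extra in registry.iter_entries():
        # _finalise_entry above guarantees the "deployment" key is present
        print(model_name, "->", extra["deployment"], capability.friendly_name)

This is essentially how AzureOpenAIProvider builds its deployment map in this commit.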
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Custom API provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
from .custom_registry import CustomEndpointModelRegistry
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openrouter_registry import OpenRouterModelRegistry
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
@@ -31,8 +31,8 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
|
||||
FRIENDLY_NAME = "Custom API"
|
||||
|
||||
# Model registry for managing configurations and aliases (shared with OpenRouter)
|
||||
_registry: Optional[OpenRouterModelRegistry] = None
|
||||
# Model registry for managing configurations and aliases
|
||||
_registry: CustomEndpointModelRegistry | None = None
|
||||
|
||||
def __init__(self, api_key: str = "", base_url: str = "", **kwargs):
|
||||
"""Initialize Custom provider for local/self-hosted models.
|
||||
@@ -78,9 +78,9 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
|
||||
super().__init__(api_key, base_url=base_url, **kwargs)
|
||||
|
||||
# Initialize model registry (shared with OpenRouter for consistent aliases)
|
||||
# Initialize model registry
|
||||
if CustomProvider._registry is None:
|
||||
CustomProvider._registry = OpenRouterModelRegistry()
|
||||
CustomProvider._registry = CustomEndpointModelRegistry()
|
||||
# Log loaded models and aliases only on first load
|
||||
models = self._registry.list_models()
|
||||
aliases = self._registry.list_aliases()
|
||||
@@ -92,8 +92,8 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
def _lookup_capabilities(
|
||||
self,
|
||||
canonical_name: str,
|
||||
requested_name: Optional[str] = None,
|
||||
) -> Optional[ModelCapabilities]:
|
||||
requested_name: str | None = None,
|
||||
) -> ModelCapabilities | None:
|
||||
"""Return capabilities for models explicitly marked as custom."""
|
||||
|
||||
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||
@@ -101,12 +101,12 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
return builtin
|
||||
|
||||
registry_entry = self._registry.resolve(canonical_name)
|
||||
if registry_entry and getattr(registry_entry, "is_custom", False):
|
||||
if registry_entry:
|
||||
registry_entry.provider = ProviderType.CUSTOM
|
||||
return registry_entry
|
||||
|
||||
logging.debug(
|
||||
"Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json",
|
||||
"Custom provider cannot resolve model '%s'; ensure it is declared in custom_models.json",
|
||||
canonical_name,
|
||||
)
|
||||
return None
|
||||
@@ -151,6 +151,15 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
return base_model
|
||||
|
||||
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||
# Attempt to resolve via OpenRouter registry so aliases still map cleanly
|
||||
openrouter_registry = OpenRouterModelRegistry()
|
||||
openrouter_config = openrouter_registry.resolve(model_name)
|
||||
if openrouter_config:
|
||||
resolved = openrouter_config.model_name
|
||||
self._alias_cache[cache_key] = resolved
|
||||
self._alias_cache.setdefault(resolved.lower(), resolved)
|
||||
return resolved
|
||||
|
||||
self._alias_cache[cache_key] = model_name
|
||||
return model_name
|
||||
|
||||
@@ -160,9 +169,9 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
if not self._registry:
|
||||
return {}
|
||||
|
||||
capabilities: dict[str, ModelCapabilities] = {}
|
||||
for model_name in self._registry.list_models():
|
||||
config = self._registry.resolve(model_name)
|
||||
if config and getattr(config, "is_custom", False):
|
||||
capabilities[model_name] = config
|
||||
capabilities = {}
|
||||
for model in self._registry.list_models():
|
||||
config = self._registry.resolve(model)
|
||||
if config:
|
||||
capabilities[model] = config
|
||||
return capabilities
|
||||
|
||||
providers/custom_registry.py (new file, 26 lines)
@@ -0,0 +1,26 @@
"""Registry for models exposed via custom (local) OpenAI-compatible endpoints."""

from __future__ import annotations

from .model_registry_base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
from .shared import ModelCapabilities, ProviderType


class CustomEndpointModelRegistry(CapabilityModelRegistry):
    def __init__(self, config_path: str | None = None) -> None:
        super().__init__(
            env_var_name="CUSTOM_MODELS_CONFIG_PATH",
            default_filename="custom_models.json",
            provider=ProviderType.CUSTOM,
            friendly_prefix="Custom ({model})",
            config_path=config_path,
        )
        self.reload()

    def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
        entry["provider"] = ProviderType.CUSTOM
        entry.setdefault("friendly_name", f"Custom ({entry['model_name']})")
        filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
        filtered.setdefault("provider", ProviderType.CUSTOM)
        capability = ModelCapabilities(**filtered)
        return capability, {}
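Because every entry in this file is now implicitly a custom model, the old "is_custom": true flag is gone. A hedged example of a conf/custom_models.json entry under the new scheme (the model name, aliases, and limits are placeholders; the shipped conf/custom_models.json remains the authoritative reference):

    {
      "models": [
        {
          "model_name": "llama-3.2-local",
          "aliases": ["local-llama", "llama"],
          "context_window": 128000,
          "max_output_tokens": 32768,
          "supports_images": false,
          "temperature_constraint": "range"
        }
      ]
    }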
|
||||
providers/model_registry_base.py (new file, 241 lines)
@@ -0,0 +1,241 @@
|
||||
"""Shared infrastructure for JSON-backed model registries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import fields
|
||||
from pathlib import Path
|
||||
|
||||
from utils.env import get_env
|
||||
from utils.file_utils import read_json_file
|
||||
|
||||
from .shared import ModelCapabilities, ProviderType, TemperatureConstraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CAPABILITY_FIELD_NAMES = {field.name for field in fields(ModelCapabilities)}
|
||||
|
||||
|
||||
class CustomModelRegistryBase:
|
||||
"""Load and expose capability metadata from a JSON manifest."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
env_var_name: str,
|
||||
default_filename: str,
|
||||
config_path: str | None = None,
|
||||
) -> None:
|
||||
self._env_var_name = env_var_name
|
||||
self._default_filename = default_filename
|
||||
self._use_resources = False
|
||||
self._resource_package = "conf"
|
||||
self._default_path = Path(__file__).parent.parent / "conf" / default_filename
|
||||
|
||||
if config_path:
|
||||
self.config_path = Path(config_path)
|
||||
else:
|
||||
env_path = get_env(env_var_name)
|
||||
if env_path:
|
||||
self.config_path = Path(env_path)
|
||||
else:
|
||||
try:
|
||||
resource = importlib.resources.files(self._resource_package).joinpath(default_filename)
|
||||
if hasattr(resource, "read_text"):
|
||||
self._use_resources = True
|
||||
self.config_path = None
|
||||
else:
|
||||
raise AttributeError("resource accessor not available")
|
||||
except Exception:
|
||||
self.config_path = Path(__file__).parent.parent / "conf" / default_filename
|
||||
|
||||
self.alias_map: dict[str, str] = {}
|
||||
self.model_map: dict[str, ModelCapabilities] = {}
|
||||
self._extras: dict[str, dict] = {}
|
||||
|
||||
def reload(self) -> None:
|
||||
data = self._load_config_data()
|
||||
configs = [config for config in self._parse_models(data) if config is not None]
|
||||
self._build_maps(configs)
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
return list(self.model_map.keys())
|
||||
|
||||
def list_aliases(self) -> list[str]:
|
||||
return list(self.alias_map.keys())
|
||||
|
||||
def resolve(self, name_or_alias: str) -> ModelCapabilities | None:
|
||||
key = name_or_alias.lower()
|
||||
canonical = self.alias_map.get(key)
|
||||
if canonical:
|
||||
return self.model_map.get(canonical)
|
||||
|
||||
for model_name in self.model_map:
|
||||
if model_name.lower() == key:
|
||||
return self.model_map[model_name]
|
||||
return None
|
||||
|
||||
def get_capabilities(self, name_or_alias: str) -> ModelCapabilities | None:
|
||||
return self.resolve(name_or_alias)
|
||||
|
||||
def get_entry(self, model_name: str) -> dict | None:
|
||||
return self._extras.get(model_name)
|
||||
|
||||
def iter_entries(self) -> Iterable[tuple[str, ModelCapabilities, dict]]:
|
||||
for model_name, capability in self.model_map.items():
|
||||
yield model_name, capability, self._extras.get(model_name, {})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _load_config_data(self) -> dict:
|
||||
if self._use_resources:
|
||||
try:
|
||||
resource = importlib.resources.files(self._resource_package).joinpath(self._default_filename)
|
||||
if hasattr(resource, "read_text"):
|
||||
config_text = resource.read_text(encoding="utf-8")
|
||||
else: # pragma: no cover - legacy Python fallback
|
||||
with resource.open("r", encoding="utf-8") as handle:
|
||||
config_text = handle.read()
|
||||
data = json.loads(config_text)
|
||||
except FileNotFoundError:
|
||||
logger.debug("Packaged %s not found", self._default_filename)
|
||||
return {"models": []}
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read packaged %s: %s", self._default_filename, exc)
|
||||
return {"models": []}
|
||||
return data or {"models": []}
|
||||
|
||||
if not self.config_path:
|
||||
raise FileNotFoundError("Registry configuration path is not set")
|
||||
|
||||
if not self.config_path.exists():
|
||||
logger.debug("Model registry config not found at %s", self.config_path)
|
||||
if self.config_path == self._default_path:
|
||||
fallback = Path.cwd() / "conf" / self._default_filename
|
||||
if fallback != self.config_path and fallback.exists():
|
||||
logger.debug("Falling back to %s", fallback)
|
||||
self.config_path = fallback
|
||||
else:
|
||||
return {"models": []}
|
||||
else:
|
||||
return {"models": []}
|
||||
|
||||
data = read_json_file(str(self.config_path))
|
||||
return data or {"models": []}
|
||||
|
||||
@property
|
||||
def use_resources(self) -> bool:
|
||||
return self._use_resources
|
||||
|
||||
def _parse_models(self, data: dict) -> Iterable[ModelCapabilities | None]:
|
||||
for raw in data.get("models", []):
|
||||
if not isinstance(raw, dict):
|
||||
continue
|
||||
yield self._convert_entry(raw)
|
||||
|
||||
def _convert_entry(self, raw: dict) -> ModelCapabilities | None:
|
||||
entry = dict(raw)
|
||||
model_name = entry.get("model_name")
|
||||
if not model_name:
|
||||
return None
|
||||
|
||||
aliases = entry.get("aliases")
|
||||
if isinstance(aliases, str):
|
||||
entry["aliases"] = [alias.strip() for alias in aliases.split(",") if alias.strip()]
|
||||
|
||||
entry.setdefault("friendly_name", self._default_friendly_name(model_name))
|
||||
|
||||
temperature_hint = entry.get("temperature_constraint")
|
||||
if isinstance(temperature_hint, str):
|
||||
entry["temperature_constraint"] = TemperatureConstraint.create(temperature_hint)
|
||||
elif temperature_hint is None:
|
||||
entry["temperature_constraint"] = TemperatureConstraint.create("range")
|
||||
|
||||
if "max_tokens" in entry:
|
||||
raise ValueError(
|
||||
"`max_tokens` is no longer supported. Use `max_output_tokens` in your model configuration."
|
||||
)
|
||||
|
||||
unknown_keys = set(entry.keys()) - CAPABILITY_FIELD_NAMES - self._extra_keys()
|
||||
if unknown_keys:
|
||||
raise ValueError("Unsupported fields in model configuration: " + ", ".join(sorted(unknown_keys)))
|
||||
|
||||
capability, extras = self._finalise_entry(entry)
|
||||
capability.provider = self._provider_default()
|
||||
self._extras[capability.model_name] = extras or {}
|
||||
return capability
|
||||
|
||||
def _default_friendly_name(self, model_name: str) -> str:
|
||||
return model_name
|
||||
|
||||
def _extra_keys(self) -> set[str]:
|
||||
return set()
|
||||
|
||||
def _provider_default(self) -> ProviderType:
|
||||
return ProviderType.OPENROUTER
|
||||
|
||||
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
|
||||
return ModelCapabilities(**{k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}), {}
|
||||
|
||||
def _build_maps(self, configs: Iterable[ModelCapabilities]) -> None:
|
||||
alias_map: dict[str, str] = {}
|
||||
model_map: dict[str, ModelCapabilities] = {}
|
||||
|
||||
for config in configs:
|
||||
if not config:
|
||||
continue
|
||||
model_map[config.model_name] = config
|
||||
|
||||
model_name_lower = config.model_name.lower()
|
||||
if model_name_lower not in alias_map:
|
||||
alias_map[model_name_lower] = config.model_name
|
||||
|
||||
for alias in config.aliases:
|
||||
alias_lower = alias.lower()
|
||||
if alias_lower in alias_map and alias_map[alias_lower] != config.model_name:
|
||||
raise ValueError(
|
||||
f"Duplicate alias '{alias}' found for models '{alias_map[alias_lower]}' and '{config.model_name}'"
|
||||
)
|
||||
alias_map[alias_lower] = config.model_name
|
||||
|
||||
self.alias_map = alias_map
|
||||
self.model_map = model_map
|
||||
|
||||
|
||||
class CapabilityModelRegistry(CustomModelRegistryBase):
|
||||
"""Registry that returns `ModelCapabilities` objects with alias support."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
env_var_name: str,
|
||||
default_filename: str,
|
||||
provider: ProviderType,
|
||||
friendly_prefix: str,
|
||||
config_path: str | None = None,
|
||||
) -> None:
|
||||
self._provider = provider
|
||||
self._friendly_prefix = friendly_prefix
|
||||
super().__init__(
|
||||
env_var_name=env_var_name,
|
||||
default_filename=default_filename,
|
||||
config_path=config_path,
|
||||
)
|
||||
self.reload()
|
||||
|
||||
def _provider_default(self) -> ProviderType:
|
||||
return self._provider
|
||||
|
||||
def _default_friendly_name(self, model_name: str) -> str:
|
||||
return self._friendly_prefix.format(model=model_name)
|
||||
|
||||
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
|
||||
filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
|
||||
filtered.setdefault("provider", self._provider_default())
|
||||
capability = ModelCapabilities(**filtered)
|
||||
return capability, {}
|
||||
@@ -8,7 +8,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from utils.env import get_env
|
||||
from utils.env import get_env, suppress_env_vars
|
||||
from utils.image_utils import validate_image
|
||||
|
||||
from .base import ModelProvider
|
||||
@@ -257,80 +257,74 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
def client(self):
|
||||
"""Lazy initialization of OpenAI client with security checks and timeout configuration."""
|
||||
if self._client is None:
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
# Temporarily disable proxy environment variables to prevent httpx from detecting them
|
||||
original_env = {}
|
||||
proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]
|
||||
|
||||
for var in proxy_env_vars:
|
||||
if var in os.environ:
|
||||
original_env[var] = os.environ[var]
|
||||
del os.environ[var]
|
||||
|
||||
try:
|
||||
# Create a custom httpx client that explicitly avoids proxy parameters
|
||||
timeout_config = (
|
||||
self.timeout_config
|
||||
if hasattr(self, "timeout_config") and self.timeout_config
|
||||
else httpx.Timeout(30.0)
|
||||
)
|
||||
|
||||
# Create httpx client with minimal config to avoid proxy conflicts
|
||||
# Note: proxies parameter was removed in httpx 0.28.0
|
||||
# Check for test transport injection
|
||||
if hasattr(self, "_test_transport"):
|
||||
# Use custom transport for testing (HTTP recording/replay)
|
||||
http_client = httpx.Client(
|
||||
transport=self._test_transport,
|
||||
timeout=timeout_config,
|
||||
follow_redirects=True,
|
||||
)
|
||||
else:
|
||||
# Normal production client
|
||||
http_client = httpx.Client(
|
||||
timeout=timeout_config,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
# Keep client initialization minimal to avoid proxy parameter conflicts
|
||||
client_kwargs = {
|
||||
"api_key": self.api_key,
|
||||
"http_client": http_client,
|
||||
}
|
||||
|
||||
if self.base_url:
|
||||
client_kwargs["base_url"] = self.base_url
|
||||
|
||||
if self.organization:
|
||||
client_kwargs["organization"] = self.organization
|
||||
|
||||
# Add default headers if any
|
||||
if self.DEFAULT_HEADERS:
|
||||
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
|
||||
|
||||
logging.debug(f"OpenAI client initialized with custom httpx client and timeout: {timeout_config}")
|
||||
|
||||
# Create OpenAI client with custom httpx client
|
||||
self._client = OpenAI(**client_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
# If all else fails, try absolute minimal client without custom httpx
|
||||
logging.warning(f"Failed to create client with custom httpx, falling back to minimal config: {e}")
|
||||
with suppress_env_vars(*proxy_env_vars):
|
||||
try:
|
||||
minimal_kwargs = {"api_key": self.api_key}
|
||||
# Create a custom httpx client that explicitly avoids proxy parameters
|
||||
timeout_config = (
|
||||
self.timeout_config
|
||||
if hasattr(self, "timeout_config") and self.timeout_config
|
||||
else httpx.Timeout(30.0)
|
||||
)
|
||||
|
||||
# Create httpx client with minimal config to avoid proxy conflicts
|
||||
# Note: proxies parameter was removed in httpx 0.28.0
|
||||
# Check for test transport injection
|
||||
if hasattr(self, "_test_transport"):
|
||||
# Use custom transport for testing (HTTP recording/replay)
|
||||
http_client = httpx.Client(
|
||||
transport=self._test_transport,
|
||||
timeout=timeout_config,
|
||||
follow_redirects=True,
|
||||
)
|
||||
else:
|
||||
# Normal production client
|
||||
http_client = httpx.Client(
|
||||
timeout=timeout_config,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
# Keep client initialization minimal to avoid proxy parameter conflicts
|
||||
client_kwargs = {
|
||||
"api_key": self.api_key,
|
||||
"http_client": http_client,
|
||||
}
|
||||
|
||||
if self.base_url:
|
||||
minimal_kwargs["base_url"] = self.base_url
|
||||
self._client = OpenAI(**minimal_kwargs)
|
||||
except Exception as fallback_error:
|
||||
logging.error(f"Even minimal OpenAI client creation failed: {fallback_error}")
|
||||
raise
|
||||
finally:
|
||||
# Restore original proxy environment variables
|
||||
for var, value in original_env.items():
|
||||
os.environ[var] = value
|
||||
client_kwargs["base_url"] = self.base_url
|
||||
|
||||
if self.organization:
|
||||
client_kwargs["organization"] = self.organization
|
||||
|
||||
# Add default headers if any
|
||||
if self.DEFAULT_HEADERS:
|
||||
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
|
||||
|
||||
logging.debug(
|
||||
"OpenAI client initialized with custom httpx client and timeout: %s",
|
||||
timeout_config,
|
||||
)
|
||||
|
||||
# Create OpenAI client with custom httpx client
|
||||
self._client = OpenAI(**client_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
# If all else fails, try absolute minimal client without custom httpx
|
||||
logging.warning(
|
||||
"Failed to create client with custom httpx, falling back to minimal config: %s",
|
||||
e,
|
||||
)
|
||||
try:
|
||||
minimal_kwargs = {"api_key": self.api_key}
|
||||
if self.base_url:
|
||||
minimal_kwargs["base_url"] = self.base_url
|
||||
self._client = OpenAI(**minimal_kwargs)
|
||||
except Exception as fallback_error:
|
||||
logging.error("Even minimal OpenAI client creation failed: %s", fallback_error)
|
||||
raise
|
||||
|
||||
return self._client
|
||||
|
||||
|
||||
@@ -103,16 +103,16 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
model_name="o3-mini",
|
||||
friendly_name="OpenAI (O3-mini)",
|
||||
intelligence_score=12,
|
||||
context_window=200_000, # 200K tokens
|
||||
max_output_tokens=65536, # 64K max output tokens
|
||||
context_window=200_000,
|
||||
max_output_tokens=65536,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # O3 models support vision
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_temperature=False, # O3 models don't accept temperature parameter
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=False,
|
||||
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||
description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
||||
aliases=["o3mini"],
|
||||
@@ -122,16 +122,16 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
model_name="o3-pro",
|
||||
friendly_name="OpenAI (O3-Pro)",
|
||||
intelligence_score=15,
|
||||
context_window=200_000, # 200K tokens
|
||||
max_output_tokens=65536, # 64K max output tokens
|
||||
context_window=200_000,
|
||||
max_output_tokens=65536,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # O3 models support vision
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_temperature=False, # O3 models don't accept temperature parameter
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=False,
|
||||
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||
description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
|
||||
aliases=["o3pro"],
|
||||
@@ -141,16 +141,15 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
model_name="o4-mini",
|
||||
friendly_name="OpenAI (O4-mini)",
|
||||
intelligence_score=11,
|
||||
context_window=200_000, # 200K tokens
|
||||
max_output_tokens=65536, # 64K max output tokens
|
||||
context_window=200_000,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # O4 models support vision
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_temperature=False, # O4 models don't accept temperature parameter
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=False,
|
||||
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||
description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
|
||||
aliases=["o4mini"],
|
||||
@@ -160,16 +159,16 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
model_name="gpt-4.1",
|
||||
friendly_name="OpenAI (GPT 4.1)",
|
||||
intelligence_score=13,
|
||||
context_window=1_000_000, # 1M tokens
|
||||
context_window=1_000_000,
|
||||
max_output_tokens=32_768,
|
||||
supports_extended_thinking=False,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # GPT-4.1 supports vision
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_temperature=True, # Regular models accept temperature parameter
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
|
||||
aliases=["gpt4.1"],
|
||||
@@ -178,19 +177,19 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
provider=ProviderType.OPENAI,
|
||||
model_name="gpt-5-codex",
|
||||
friendly_name="OpenAI (GPT-5 Codex)",
|
||||
intelligence_score=17, # Higher than GPT-5 for coding tasks
|
||||
context_window=400_000, # 400K tokens (same as GPT-5)
|
||||
max_output_tokens=128_000, # 128K output tokens
|
||||
supports_extended_thinking=True, # Responses API supports reasoning tokens
|
||||
intelligence_score=17,
|
||||
context_window=400_000,
|
||||
max_output_tokens=128_000,
|
||||
supports_extended_thinking=True,
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True, # Enhanced for agentic software engineering
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
supports_images=True, # Screenshots, wireframes, diagrams
|
||||
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||
supports_images=True,
|
||||
max_image_size_mb=20.0,
|
||||
supports_temperature=True,
|
||||
temperature_constraint=TemperatureConstraint.create("range"),
|
||||
description="GPT-5 Codex (400K context) - Uses Responses API for 40-80% cost savings. Specialized for coding, refactoring, and software architecture. 3% better performance on SWE-bench.",
|
||||
description="GPT-5 Codex (400K context) Specialized for coding, refactoring, and software architecture.",
|
||||
aliases=["gpt5-codex", "codex", "gpt-5-code", "gpt5-code"],
|
||||
),
|
||||
}
|
||||
@@ -282,7 +281,7 @@ class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||
|
||||
if category == ToolModelCategory.EXTENDED_REASONING:
|
||||
# Prefer models with extended thinking support
|
||||
# GPT-5-Codex first for coding tasks (uses Responses API with 40-80% cost savings)
|
||||
# GPT-5-Codex first for coding tasks
|
||||
preferred = find_first(["gpt-5-codex", "o3", "o3-pro", "gpt-5"])
|
||||
return preferred if preferred else allowed_models[0]
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""OpenRouter provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
@@ -42,7 +41,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
}
|
||||
|
||||
# Model registry for managing configurations and aliases
|
||||
_registry: Optional[OpenRouterModelRegistry] = None
|
||||
_registry: OpenRouterModelRegistry | None = None
|
||||
|
||||
def __init__(self, api_key: str, **kwargs):
|
||||
"""Initialize OpenRouter provider.
|
||||
@@ -70,8 +69,8 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
def _lookup_capabilities(
|
||||
self,
|
||||
canonical_name: str,
|
||||
requested_name: Optional[str] = None,
|
||||
) -> Optional[ModelCapabilities]:
|
||||
requested_name: str | None = None,
|
||||
) -> ModelCapabilities | None:
|
||||
"""Fetch OpenRouter capabilities from the registry or build a generic fallback."""
|
||||
|
||||
capabilities = self._registry.get_capabilities(canonical_name)
|
||||
@@ -143,7 +142,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
# Custom models belong to CustomProvider; skip them here so the two
|
||||
# providers don't race over the same registrations (important for tests
|
||||
# that stub the registry with minimal objects lacking attrs).
|
||||
if hasattr(config, "is_custom") and config.is_custom is True:
|
||||
if config.provider == ProviderType.CUSTOM:
|
||||
continue
|
||||
|
||||
if restriction_service:
|
||||
@@ -211,7 +210,7 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
continue
|
||||
|
||||
# See note in list_models: respect the CustomProvider boundary.
|
||||
if hasattr(config, "is_custom") and config.is_custom is True:
|
||||
if config.provider == ProviderType.CUSTOM:
|
||||
continue
|
||||
|
||||
capabilities[model_name] = config
|
||||
|
||||
@@ -1,293 +1,38 @@
|
||||
"""OpenRouter model registry for managing model configurations and aliases."""
|
||||
|
||||
import importlib.resources
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
# Import handled via importlib.resources.files() calls directly
|
||||
from utils.file_utils import read_json_file
|
||||
|
||||
from .shared import (
|
||||
ModelCapabilities,
|
||||
ProviderType,
|
||||
TemperatureConstraint,
|
||||
)
|
||||
from .model_registry_base import CAPABILITY_FIELD_NAMES, CapabilityModelRegistry
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
|
||||
|
||||
class OpenRouterModelRegistry:
|
||||
"""In-memory view of OpenRouter and custom model metadata.
|
||||
class OpenRouterModelRegistry(CapabilityModelRegistry):
|
||||
"""Capability registry backed by `conf/openrouter_models.json`."""
|
||||
|
||||
Role
|
||||
Parse the packaged ``conf/custom_models.json`` (or user-specified
|
||||
overrides), construct alias and capability maps, and serve those
|
||||
structures to providers that rely on OpenRouter semantics (both the
|
||||
OpenRouter provider itself and the Custom provider).
|
||||
def __init__(self, config_path: str | None = None) -> None:
|
||||
super().__init__(
|
||||
env_var_name="OPENROUTER_MODELS_CONFIG_PATH",
|
||||
default_filename="openrouter_models.json",
|
||||
provider=ProviderType.OPENROUTER,
|
||||
friendly_prefix="OpenRouter ({model})",
|
||||
config_path=config_path,
|
||||
)
|
||||
|
||||
Key duties
|
||||
* Load :class:`ModelCapabilities` definitions from configuration files
|
||||
* Maintain a case-insensitive alias → canonical name map for fast
|
||||
resolution
|
||||
* Provide helpers to list models, list aliases, and resolve an arbitrary
|
||||
name to its capability object without repeatedly touching the file
|
||||
system.
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
"""Initialize the registry.
|
||||
|
||||
Args:
|
||||
config_path: Path to config file. If None, uses default locations.
|
||||
"""
|
||||
self.alias_map: dict[str, str] = {} # alias -> model_name
|
||||
self.model_map: dict[str, ModelCapabilities] = {} # model_name -> config
|
||||
|
||||
# Determine config path and loading strategy
|
||||
self.use_resources = False
|
||||
if config_path:
|
||||
# Direct config_path parameter
|
||||
self.config_path = Path(config_path)
|
||||
def _finalise_entry(self, entry: dict) -> tuple[ModelCapabilities, dict]:
|
||||
provider_override = entry.get("provider")
|
||||
if isinstance(provider_override, str):
|
||||
entry_provider = ProviderType(provider_override.lower())
|
||||
elif isinstance(provider_override, ProviderType):
|
||||
entry_provider = provider_override
|
||||
else:
|
||||
# Check environment variable first
|
||||
env_path = get_env("CUSTOM_MODELS_CONFIG_PATH")
|
||||
if env_path:
|
||||
# Environment variable path
|
||||
self.config_path = Path(env_path)
|
||||
else:
|
||||
# Try importlib.resources for robust packaging support
|
||||
self.config_path = None
|
||||
self.use_resources = False
|
||||
entry_provider = ProviderType.OPENROUTER
|
||||
|
||||
try:
|
||||
resource_traversable = importlib.resources.files("conf").joinpath("custom_models.json")
|
||||
if hasattr(resource_traversable, "read_text"):
|
||||
self.use_resources = True
|
||||
else:
|
||||
raise AttributeError("read_text not available")
|
||||
except Exception:
|
||||
pass
|
||||
if entry_provider == ProviderType.CUSTOM:
|
||||
entry.setdefault("friendly_name", f"Custom ({entry['model_name']})")
|
||||
else:
|
||||
entry.setdefault("friendly_name", f"OpenRouter ({entry['model_name']})")
|
||||
|
||||
if not self.use_resources:
|
||||
# Fallback to file system paths
|
||||
potential_paths = [
|
||||
Path(__file__).parent.parent / "conf" / "custom_models.json",
|
||||
Path.cwd() / "conf" / "custom_models.json",
|
||||
]
|
||||
|
||||
for path in potential_paths:
|
||||
if path.exists():
|
||||
self.config_path = path
|
||||
break
|
||||
|
||||
if self.config_path is None:
|
||||
self.config_path = potential_paths[0]
|
||||
|
||||
# Load configuration
|
||||
self.reload()
|
||||
|
||||
def reload(self) -> None:
|
||||
"""Reload configuration from disk."""
|
||||
try:
|
||||
configs = self._read_config()
|
||||
self._build_maps(configs)
|
||||
caller_info = ""
|
||||
try:
|
||||
import inspect
|
||||
|
||||
caller_frame = inspect.currentframe().f_back
|
||||
if caller_frame:
|
||||
caller_name = caller_frame.f_code.co_name
|
||||
caller_file = (
|
||||
caller_frame.f_code.co_filename.split("/")[-1] if caller_frame.f_code.co_filename else "unknown"
|
||||
)
|
||||
# Look for tool context
|
||||
while caller_frame:
|
||||
frame_locals = caller_frame.f_locals
|
||||
if "self" in frame_locals and hasattr(frame_locals["self"], "get_name"):
|
||||
tool_name = frame_locals["self"].get_name()
|
||||
caller_info = f" (called from {tool_name} tool)"
|
||||
break
|
||||
caller_frame = caller_frame.f_back
|
||||
if not caller_info:
|
||||
caller_info = f" (called from {caller_name} in {caller_file})"
|
||||
except Exception:
|
||||
# If frame inspection fails, just continue without caller info
|
||||
pass
|
||||
|
||||
logging.debug(
|
||||
f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases{caller_info}"
|
||||
)
|
||||
except ValueError as e:
|
||||
# Re-raise ValueError only for duplicate aliases (critical config errors)
|
||||
logging.error(f"Failed to load OpenRouter model configuration: {e}")
|
||||
# Initialize with empty maps on failure
|
||||
self.alias_map = {}
|
||||
self.model_map = {}
|
||||
if "Duplicate alias" in str(e):
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load OpenRouter model configuration: {e}")
|
||||
# Initialize with empty maps on failure
|
||||
self.alias_map = {}
|
||||
self.model_map = {}
|
||||
|
||||
def _read_config(self) -> list[ModelCapabilities]:
|
||||
"""Read configuration from file or package resources.
|
||||
|
||||
Returns:
|
||||
List of model configurations
|
||||
"""
|
||||
try:
|
||||
if self.use_resources:
|
||||
# Use importlib.resources for packaged environments
|
||||
try:
|
||||
resource_path = importlib.resources.files("conf").joinpath("custom_models.json")
|
||||
if hasattr(resource_path, "read_text"):
|
||||
# Python 3.9+
|
||||
config_text = resource_path.read_text(encoding="utf-8")
|
||||
else:
|
||||
# Python 3.8 fallback
|
||||
with resource_path.open("r", encoding="utf-8") as f:
|
||||
config_text = f.read()
|
||||
|
||||
import json
|
||||
|
||||
data = json.loads(config_text)
|
||||
logging.debug("Loaded OpenRouter config from package resources")
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to load config from resources: {e}")
|
||||
return []
|
||||
else:
|
||||
# Use file path loading
|
||||
if not self.config_path.exists():
|
||||
logging.warning(f"OpenRouter model config not found at {self.config_path}")
|
||||
return []
|
||||
|
||||
# Use centralized JSON reading utility
|
||||
data = read_json_file(str(self.config_path))
|
||||
logging.debug(f"Loaded OpenRouter config from file: {self.config_path}")
|
||||
|
||||
if data is None:
|
||||
location = "resources" if self.use_resources else str(self.config_path)
|
||||
raise ValueError(f"Could not read or parse JSON from {location}")
|
||||
|
||||
# Parse models
|
||||
configs = []
|
||||
for model_data in data.get("models", []):
|
||||
# Create ModelCapabilities directly from JSON data
|
||||
# Handle temperature_constraint conversion
|
||||
temp_constraint_str = model_data.get("temperature_constraint")
|
||||
temp_constraint = TemperatureConstraint.create(temp_constraint_str or "range")
|
||||
|
||||
# Set provider-specific defaults based on is_custom flag
|
||||
is_custom = model_data.get("is_custom", False)
|
||||
if is_custom:
|
||||
model_data.setdefault("provider", ProviderType.CUSTOM)
|
||||
model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
|
||||
else:
|
||||
model_data.setdefault("provider", ProviderType.OPENROUTER)
|
||||
model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")
|
||||
model_data["temperature_constraint"] = temp_constraint
|
||||
|
||||
# Remove the string version of temperature_constraint before creating ModelCapabilities
|
||||
if "temperature_constraint" in model_data and isinstance(model_data["temperature_constraint"], str):
|
||||
del model_data["temperature_constraint"]
|
||||
model_data["temperature_constraint"] = temp_constraint
|
||||
|
||||
config = ModelCapabilities(**model_data)
|
||||
configs.append(config)
|
||||
|
||||
return configs
|
||||
except ValueError:
|
||||
# Re-raise ValueError for specific config errors
|
||||
raise
|
||||
except Exception as e:
|
||||
location = "resources" if self.use_resources else str(self.config_path)
|
||||
raise ValueError(f"Error reading config from {location}: {e}")
|
||||
|
||||
def _build_maps(self, configs: list[ModelCapabilities]) -> None:
|
||||
"""Build alias and model maps from configurations.
|
||||
|
||||
Args:
|
||||
configs: List of model configurations
|
||||
"""
|
||||
alias_map = {}
|
||||
model_map = {}
|
||||
|
||||
for config in configs:
|
||||
# Add to model map
|
||||
model_map[config.model_name] = config
|
||||
|
||||
# Add the model_name itself as an alias for case-insensitive lookup
|
||||
# But only if it's not already in the aliases list
|
||||
model_name_lower = config.model_name.lower()
|
||||
aliases_lower = [alias.lower() for alias in config.aliases]
|
||||
|
||||
if model_name_lower not in aliases_lower:
|
||||
if model_name_lower in alias_map:
|
||||
existing_model = alias_map[model_name_lower]
|
||||
if existing_model != config.model_name:
|
||||
raise ValueError(
|
||||
f"Duplicate model name '{config.model_name}' (case-insensitive) found for models "
|
||||
f"'{existing_model}' and '{config.model_name}'"
|
||||
)
|
||||
else:
|
||||
alias_map[model_name_lower] = config.model_name
|
||||
|
||||
# Add aliases
|
||||
for alias in config.aliases:
|
||||
alias_lower = alias.lower()
|
||||
if alias_lower in alias_map:
|
||||
existing_model = alias_map[alias_lower]
|
||||
raise ValueError(
|
||||
f"Duplicate alias '{alias}' found for models '{existing_model}' and '{config.model_name}'"
|
||||
)
|
||||
alias_map[alias_lower] = config.model_name
|
||||
|
||||
# Atomic update
|
||||
self.alias_map = alias_map
|
||||
self.model_map = model_map
|
||||
|
||||
def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||
"""Resolve a model name or alias to configuration.
|
||||
|
||||
Args:
|
||||
name_or_alias: Model name or alias to resolve
|
||||
|
||||
Returns:
|
||||
Model configuration if found, None otherwise
|
||||
"""
|
||||
# Try alias lookup (case-insensitive) - this now includes model names too
|
||||
alias_lower = name_or_alias.lower()
|
||||
if alias_lower in self.alias_map:
|
||||
model_name = self.alias_map[alias_lower]
|
||||
return self.model_map.get(model_name)
|
||||
|
||||
return None
|
||||
|
||||
def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||
"""Get model capabilities for a name or alias.
|
||||
|
||||
Args:
|
||||
name_or_alias: Model name or alias
|
||||
|
||||
Returns:
|
||||
ModelCapabilities if found, None otherwise
|
||||
"""
|
||||
# Registry now returns ModelCapabilities directly
|
||||
return self.resolve(name_or_alias)
|
||||
|
||||
def get_model_config(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||
"""Backward-compatible wrapper used by providers and older tests."""
|
||||
|
||||
return self.resolve(name_or_alias)
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
"""List all available model names."""
|
||||
return list(self.model_map.keys())
|
||||
|
||||
def list_aliases(self) -> list[str]:
|
||||
"""List all available aliases."""
|
||||
return list(self.alias_map.keys())
|
||||
filtered = {k: v for k, v in entry.items() if k in CAPABILITY_FIELD_NAMES}
|
||||
filtered.setdefault("provider", entry_provider)
|
||||
capability = ModelCapabilities(**filtered)
|
||||
return capability, {}
|
||||
|
||||
@@ -38,6 +38,7 @@ class ModelProviderRegistry:
|
||||
PROVIDER_PRIORITY_ORDER = [
|
||||
ProviderType.GOOGLE, # Direct Gemini access
|
||||
ProviderType.OPENAI, # Direct OpenAI access
|
||||
ProviderType.AZURE, # Azure-hosted OpenAI deployments
|
||||
ProviderType.XAI, # Direct X.AI GROK access
|
||||
ProviderType.DIAL, # DIAL unified API access
|
||||
ProviderType.CUSTOM, # Local/self-hosted models
|
||||
@@ -123,6 +124,21 @@ class ModelProviderRegistry:
|
||||
provider_kwargs["base_url"] = gemini_base_url
|
||||
logging.info(f"Initialized Gemini provider with custom endpoint: {gemini_base_url}")
|
||||
provider = provider_class(**provider_kwargs)
|
||||
elif provider_type == ProviderType.AZURE:
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
azure_endpoint = get_env("AZURE_OPENAI_ENDPOINT")
|
||||
if not azure_endpoint:
|
||||
logging.warning("AZURE_OPENAI_ENDPOINT missing – skipping Azure OpenAI provider")
|
||||
return None
|
||||
|
||||
azure_version = get_env("AZURE_OPENAI_API_VERSION")
|
||||
provider = provider_class(
|
||||
api_key=api_key,
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_version=azure_version,
|
||||
)
|
||||
else:
|
||||
if not api_key:
|
||||
return None
|
||||
@@ -318,6 +334,7 @@ class ModelProviderRegistry:
|
||||
key_mapping = {
|
||||
ProviderType.GOOGLE: "GEMINI_API_KEY",
|
||||
ProviderType.OPENAI: "OPENAI_API_KEY",
|
||||
ProviderType.AZURE: "AZURE_OPENAI_API_KEY",
|
||||
ProviderType.XAI: "XAI_API_KEY",
|
||||
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
|
||||
ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth
|
||||
|
||||
@@ -53,7 +53,6 @@ class ModelCapabilities:
|
||||
|
||||
# Additional attributes
|
||||
max_image_size_mb: float = 0.0
|
||||
is_custom: bool = False
|
||||
temperature_constraint: TemperatureConstraint = field(
|
||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||
)
|
||||
@@ -102,9 +101,6 @@ class ModelCapabilities:
|
||||
if self.supports_images:
|
||||
score += 1
|
||||
|
||||
if self.is_custom:
|
||||
score -= 1
|
||||
|
||||
return max(0, min(100, score))
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -10,6 +10,7 @@ class ProviderType(Enum):
|
||||
|
||||
GOOGLE = "google"
|
||||
OPENAI = "openai"
|
||||
AZURE = "azure"
|
||||
XAI = "xai"
|
||||
OPENROUTER = "openrouter"
|
||||
CUSTOM = "custom"
|
||||
|
||||