feat!: breaking change - OpenRouter models are now read from conf/openrouter_models.json, while Custom / Self-hosted models are read from conf/custom_models.json

feat: Azure OpenAI / Azure AI Foundry support. Models should be defined in conf/azure_models.json (or a custom path). See .env.example for the environment variables, or see the README. https://github.com/BeehiveInnovations/zen-mcp-server/issues/265

feat: OpenRouter / Custom Models / Azure can each use a custom config path now (see .env.example)

refactor: the model registry class is now abstract; OpenRouter / Custom Provider / Azure OpenAI subclass it

refactor!: breaking change - the `is_custom` property has been removed from model_capabilities.py (and thus from custom_models.json), since each provider's models are now read from a separate configuration file
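
For illustration, a minimal conf/azure_models.json entry might look like the sketch below. The exact schema is owned by AzureModelRegistry (not shown in this diff), so the "models" wrapper and "model_name" field are inferred from the sibling config files and the provider code; only the "deployment" key is confirmed by the registry handling below. Treat this as an assumption, not authoritative documentation:

    {
      "models": [
        {
          "model_name": "gpt-4o",
          "deployment": "my-gpt4o-deployment"
        }
      ]
    }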
providers/azure_openai.py | 342 lines (new file)
@@ -0,0 +1,342 @@
"""Azure OpenAI provider built on the OpenAI-compatible implementation."""

from __future__ import annotations

import logging
from dataclasses import asdict, replace

try:  # pragma: no cover - optional dependency
    from openai import AzureOpenAI
except ImportError:  # pragma: no cover
    AzureOpenAI = None  # type: ignore[assignment]

from utils.env import get_env, suppress_env_vars

from .azure_registry import AzureModelRegistry
from .openai_compatible import OpenAICompatibleProvider
from .openai_provider import OpenAIModelProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint

logger = logging.getLogger(__name__)


class AzureOpenAIProvider(OpenAICompatibleProvider):
    """Thin Azure wrapper that reuses the OpenAI-compatible request pipeline."""

    FRIENDLY_NAME = "Azure OpenAI"
    DEFAULT_API_VERSION = "2024-02-15-preview"

    # The OpenAI-compatible base expects subclasses to expose capabilities via
    # ``get_all_model_capabilities``. Azure deployments are user-defined, so we
    # build the catalogue dynamically from environment configuration instead of
    # relying on a static ``MODEL_CAPABILITIES`` map.
    MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}

    def __init__(
        self,
        api_key: str,
        *,
        azure_endpoint: str | None = None,
        api_version: str | None = None,
        deployments: dict[str, object] | None = None,
        **kwargs,
    ) -> None:
        # Resolve the endpoint before handing off to the base class so that
        # ``base_url`` is populated even when the endpoint comes from the
        # environment rather than the constructor.
        if not azure_endpoint:
            azure_endpoint = get_env("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Azure OpenAI endpoint is required via parameter or AZURE_OPENAI_ENDPOINT")

        # Let the OpenAI-compatible base handle shared configuration such as
        # timeouts, restriction-aware allowlists, and logging. ``base_url`` maps
        # directly onto Azure's endpoint URL.
        super().__init__(api_key, base_url=azure_endpoint, **kwargs)

        self.azure_endpoint = azure_endpoint.rstrip("/")
        self.api_version = api_version or get_env("AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION)

        registry_specs = self._load_registry_entries()
        override_specs = self._normalise_deployments(deployments) if deployments else {}

        self._model_specs = self._merge_specs(registry_specs, override_specs)
        if not self._model_specs:
            raise ValueError(
                "Azure OpenAI provider requires at least one configured deployment. "
                "Populate conf/azure_models.json or set AZURE_MODELS_CONFIG_PATH."
            )

        self._capabilities = self._build_capabilities_map()
        self._deployment_map = {name: spec["deployment"] for name, spec in self._model_specs.items()}
        self._deployment_alias_lookup = {
            deployment.lower(): canonical for canonical, deployment in self._deployment_map.items()
        }
        self._canonical_lookup = {name.lower(): name for name in self._model_specs.keys()}
        self._invalidate_capability_cache()
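        # e.g. with deployments={"gpt-4o": "my-gpt4o-deployment"} (illustrative
        # names), _deployment_map == {"gpt-4o": "my-gpt4o-deployment"} and
        # _deployment_alias_lookup == {"my-gpt4o-deployment": "gpt-4o"}.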

    # ------------------------------------------------------------------
    # Capability helpers
    # ------------------------------------------------------------------
    def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
        return dict(self._capabilities)

    def get_provider_type(self) -> ProviderType:
        return ProviderType.AZURE

    def get_capabilities(self, model_name: str) -> ModelCapabilities:  # type: ignore[override]
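        # Lookups may arrive as either the deployment alias or the canonical
        # model name (e.g. "my-gpt4o-deployment" vs "gpt-4o"; illustrative
        # names); both are matched case-insensitively.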
        lowered = model_name.lower()
        if lowered in self._deployment_alias_lookup:
            canonical = self._deployment_alias_lookup[lowered]
            return super().get_capabilities(canonical)
        canonical = self._canonical_lookup.get(lowered)
        if canonical:
            return super().get_capabilities(canonical)
        return super().get_capabilities(model_name)

    def validate_model_name(self, model_name: str) -> bool:  # type: ignore[override]
        lowered = model_name.lower()
        if lowered in self._deployment_alias_lookup or lowered in self._canonical_lookup:
            return True
        return super().validate_model_name(model_name)

    def _build_capabilities_map(self) -> dict[str, ModelCapabilities]:
        capabilities: dict[str, ModelCapabilities] = {}

        for canonical_name, spec in self._model_specs.items():
            template_capability: ModelCapabilities | None = spec.get("capability")
            overrides = spec.get("overrides", {})

            if template_capability:
                cloned = replace(template_capability)
            else:
                template = OpenAIModelProvider.MODEL_CAPABILITIES.get(canonical_name)

                if template:
                    friendly = template.friendly_name.replace("OpenAI", "Azure OpenAI", 1)
                    cloned = replace(
                        template,
                        provider=ProviderType.AZURE,
                        friendly_name=friendly,
                        aliases=list(template.aliases),
                    )
                else:
                    deployment_name = spec.get("deployment", "")
                    cloned = ModelCapabilities(
                        provider=ProviderType.AZURE,
                        model_name=canonical_name,
                        friendly_name=f"Azure OpenAI ({canonical_name})",
                        description=f"Azure deployment '{deployment_name}' for {canonical_name}",
                        aliases=[],
                    )

            if overrides:
                overrides = dict(overrides)
                temp_override = overrides.get("temperature_constraint")
                if isinstance(temp_override, str):
                    overrides["temperature_constraint"] = TemperatureConstraint.create(temp_override)

                aliases_override = overrides.get("aliases")
                if isinstance(aliases_override, str):
                    overrides["aliases"] = [alias.strip() for alias in aliases_override.split(",") if alias.strip()]
                # The provider field is always forced to Azure below, so drop
                # any user-supplied value outright.
                overrides.pop("provider", None)

                try:
                    cloned = replace(cloned, **overrides)
                except TypeError:
                    base_data = asdict(cloned)
                    base_data.update(overrides)
                    base_data["provider"] = ProviderType.AZURE
                    temp_value = base_data.get("temperature_constraint")
                    if isinstance(temp_value, str):
                        base_data["temperature_constraint"] = TemperatureConstraint.create(temp_value)
                    cloned = ModelCapabilities(**base_data)

            if cloned.provider != ProviderType.AZURE:
                cloned.provider = ProviderType.AZURE

            capabilities[canonical_name] = cloned

        return capabilities

    def _load_registry_entries(self) -> dict[str, dict]:
        try:
            registry = AzureModelRegistry()
        except Exception as exc:  # pragma: no cover - registry failure should not crash provider
            logger.warning("Unable to load Azure model registry: %s", exc)
            return {}

        entries: dict[str, dict] = {}
        for model_name, capability, extra in registry.iter_entries():
            deployment = extra.get("deployment")
            if not deployment:
                logger.warning("Azure model '%s' missing deployment in registry", model_name)
                continue
            entries[model_name] = {"deployment": deployment, "capability": capability}

        return entries

    @staticmethod
    def _merge_specs(
        registry_specs: dict[str, dict],
        override_specs: dict[str, dict],
    ) -> dict[str, dict]:
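        # Registry entries form the baseline; explicit constructor overrides
        # win field-by-field, and any spec that still lacks a deployment name
        # after merging is dropped.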
        specs: dict[str, dict] = {}

        for canonical, entry in registry_specs.items():
            specs[canonical] = {
                "deployment": entry.get("deployment"),
                "capability": entry.get("capability"),
                "overrides": {},
            }

        for canonical, entry in override_specs.items():
            spec = specs.get(canonical, {"deployment": None, "capability": None, "overrides": {}})
            deployment = entry.get("deployment")
            if deployment:
                spec["deployment"] = deployment
            overrides = {k: v for k, v in entry.items() if k not in {"deployment"}}
            overrides.pop("capability", None)
            if overrides:
                spec["overrides"].update(overrides)
            specs[canonical] = spec

        return {k: v for k, v in specs.items() if v.get("deployment")}

    @staticmethod
    def _normalise_deployments(mapping: dict[str, object]) -> dict[str, dict]:
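        # Accepts shorthand and detailed specs alike (values illustrative):
        #   {"gpt-4o": "my-gpt4o-deployment"}
        #   {"gpt-4o": {"deployment": "my-gpt4o-deployment", "aliases": "gpt4o,4o"}}
        # String specs name the deployment; dict specs may carry capability
        # overrides alongside the "deployment"/"deployment_name" key.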
        normalised: dict[str, dict] = {}
        for canonical, spec in mapping.items():
            canonical_name = (canonical or "").strip()
            if not canonical_name:
                continue

            deployment_name: str | None = None
            overrides: dict[str, object] = {}

            if isinstance(spec, str):
                deployment_name = spec.strip()
            elif isinstance(spec, dict):
                deployment_name = spec.get("deployment") or spec.get("deployment_name")
                overrides = {k: v for k, v in spec.items() if k not in {"deployment", "deployment_name"}}

            if not deployment_name:
                continue

            normalised[canonical_name] = {"deployment": deployment_name.strip(), **overrides}

        return normalised

    # ------------------------------------------------------------------
    # Azure-specific configuration
    # ------------------------------------------------------------------
    @property
    def client(self):  # type: ignore[override]
        """Instantiate the Azure OpenAI client on first use."""

        if self._client is None:
            if AzureOpenAI is None:
                raise ImportError(
                    "Azure OpenAI support requires the 'openai' package. Install it with `pip install openai`."
                )

            import httpx

            proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]

            with suppress_env_vars(*proxy_env_vars):
                try:
                    timeout_config = self.timeout_config

                    http_client = httpx.Client(timeout=timeout_config, follow_redirects=True)

                    client_kwargs = {
                        "api_key": self.api_key,
                        "azure_endpoint": self.azure_endpoint,
                        "api_version": self.api_version,
                        "http_client": http_client,
                    }

                    if self.DEFAULT_HEADERS:
                        client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

                    logger.debug(
                        "Initializing Azure OpenAI client endpoint=%s api_version=%s timeouts=%s",
                        self.azure_endpoint,
                        self.api_version,
                        timeout_config,
                    )

                    self._client = AzureOpenAI(**client_kwargs)

                except Exception as exc:
                    logger.error("Failed to create Azure OpenAI client: %s", exc)
                    raise

        return self._client

    # ------------------------------------------------------------------
    # Request delegation
    # ------------------------------------------------------------------
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: str | None = None,
        temperature: float = 0.3,
        max_output_tokens: int | None = None,
        images: list[str] | None = None,
        **kwargs,
    ) -> ModelResponse:
        canonical_name, deployment_name = self._resolve_canonical_and_deployment(model_name)

        # Delegate to the shared OpenAI-compatible implementation using the
        # deployment name, since Azure requires the deployment identifier in
        # the ``model`` field. The returned ``ModelResponse`` is normalised so
        # downstream consumers continue to see the canonical model name.
        raw_response = super().generate_content(
            prompt=prompt,
            model_name=deployment_name,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            images=images,
            **kwargs,
        )

        capabilities = self._capabilities.get(canonical_name)
        friendly_name = capabilities.friendly_name if capabilities else self.FRIENDLY_NAME

        return ModelResponse(
            content=raw_response.content,
            usage=raw_response.usage,
            model_name=canonical_name,
            friendly_name=friendly_name,
            provider=ProviderType.AZURE,
            metadata={**raw_response.metadata, "deployment": deployment_name},
        )

    def _resolve_canonical_and_deployment(self, model_name: str) -> tuple[str, str]:
        resolved_canonical = self._resolve_model_name(model_name)

        if resolved_canonical not in self._deployment_map:
            # The base resolver may hand back the deployment alias. Try to map
            # it back to a canonical entry.
            for canonical, deployment in self._deployment_map.items():
                if deployment.lower() == resolved_canonical.lower():
                    return canonical, deployment
            raise ValueError(f"Model '{model_name}' is not configured for Azure OpenAI")

        return resolved_canonical, self._deployment_map[resolved_canonical]

    def _parse_allowed_models(self) -> set[str] | None:  # type: ignore[override]
        # Support both AZURE_ALLOWED_MODELS (inherited behaviour) and the
        # clearer AZURE_OPENAI_ALLOWED_MODELS alias.
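        # e.g. AZURE_OPENAI_ALLOWED_MODELS=gpt-4o,o3 (illustrative values)
        # restricts this provider to those canonical names.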
        explicit = get_env("AZURE_OPENAI_ALLOWED_MODELS")
        if explicit:
            models = {m.strip().lower() for m in explicit.split(",") if m.strip()}
            if models:
                logger.info("Configured allowed models for Azure OpenAI: %s", sorted(models))
                self._allowed_alias_cache = {}
                return models

        return super()._parse_allowed_models()
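
For orientation, a minimal usage sketch of the class above. The endpoint, key, and deployment names are placeholders, and in the server itself the provider is normally wired up from environment configuration rather than constructed by hand:

    from providers.azure_openai import AzureOpenAIProvider

    # Placeholder credentials/endpoint; ``deployments`` maps canonical model
    # names to Azure deployment identifiers.
    provider = AzureOpenAIProvider(
        "azure-api-key",
        azure_endpoint="https://my-resource.openai.azure.com",
        deployments={"gpt-4o": "my-gpt4o-deployment"},
    )

    # Callers may pass either the canonical name or the deployment alias;
    # the response is normalised back to the canonical model name.
    response = provider.generate_content(prompt="Hello", model_name="gpt-4o")
    print(response.model_name, response.metadata["deployment"])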