feat: centralized environment handling, ensures ZEN_MCP_FORCE_ENV_OVERRIDE is honored correctly

fix: updated tests to override env variables they need instead of relying on the current values from .env
This commit is contained in:
Fahad
2025-10-04 14:28:56 +04:00
parent 4015e917ed
commit 2c534ac06e
24 changed files with 300 additions and 179 deletions

View File

@@ -1,9 +1,10 @@
"""Custom API provider implementation."""
import logging
import os
from typing import Optional
from utils.env import get_env
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import ModelCapabilities, ProviderType
@@ -56,9 +57,9 @@ class CustomProvider(OpenAICompatibleProvider):
"""
# Fall back to environment variables only if not provided
if not base_url:
base_url = os.getenv("CUSTOM_API_URL", "")
base_url = get_env("CUSTOM_API_URL", "") or ""
if not api_key:
api_key = os.getenv("CUSTOM_API_KEY", "")
api_key = get_env("CUSTOM_API_KEY", "") or ""
if not base_url:
raise ValueError(

View File

@@ -1,10 +1,11 @@
"""DIAL (Data & AI Layer) model provider implementation."""
import logging
import os
import threading
from typing import Optional
from utils.env import get_env
from .openai_compatible import OpenAICompatibleProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
@@ -209,7 +210,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
**kwargs: Additional configuration options
"""
# Get DIAL API host from environment or kwargs
dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"
dial_host = kwargs.get("base_url") or get_env("DIAL_API_HOST") or "https://core.dialx.ai"
# DIAL uses /openai endpoint for OpenAI-compatible API
if not dial_host.endswith("/openai"):
@@ -218,7 +219,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
kwargs["base_url"] = dial_host
# Get API version from environment or use default
self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")
self.api_version = get_env("DIAL_API_VERSION", "2024-12-01-preview") or "2024-12-01-preview"
# Add DIAL-specific headers
# DIAL uses Api-Key header instead of Authorization: Bearer

View File

@@ -3,12 +3,12 @@
import copy
import ipaddress
import logging
import os
from typing import Optional
from urllib.parse import urlparse
from openai import OpenAI
from utils.env import get_env
from utils.image_utils import validate_image
from .base import ModelProvider
@@ -112,7 +112,7 @@ class OpenAICompatibleProvider(ModelProvider):
# Get provider-specific allowed models
provider_type = self.get_provider_type().value.upper()
env_var = f"{provider_type}_ALLOWED_MODELS"
models_str = os.getenv(env_var, "")
models_str = get_env(env_var, "") or ""
if models_str:
# Parse and normalize to lowercase for case-insensitive comparison
@@ -165,10 +165,25 @@ class OpenAICompatibleProvider(ModelProvider):
logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")
# Allow override via kwargs or environment variables in future, for now...
connect_timeout = kwargs.get("connect_timeout", float(os.getenv("CUSTOM_CONNECT_TIMEOUT", default_connect)))
read_timeout = kwargs.get("read_timeout", float(os.getenv("CUSTOM_READ_TIMEOUT", default_read)))
write_timeout = kwargs.get("write_timeout", float(os.getenv("CUSTOM_WRITE_TIMEOUT", default_write)))
pool_timeout = kwargs.get("pool_timeout", float(os.getenv("CUSTOM_POOL_TIMEOUT", default_pool)))
connect_timeout = kwargs.get("connect_timeout")
if connect_timeout is None:
connect_timeout_raw = get_env("CUSTOM_CONNECT_TIMEOUT")
connect_timeout = float(connect_timeout_raw) if connect_timeout_raw is not None else float(default_connect)
read_timeout = kwargs.get("read_timeout")
if read_timeout is None:
read_timeout_raw = get_env("CUSTOM_READ_TIMEOUT")
read_timeout = float(read_timeout_raw) if read_timeout_raw is not None else float(default_read)
write_timeout = kwargs.get("write_timeout")
if write_timeout is None:
write_timeout_raw = get_env("CUSTOM_WRITE_TIMEOUT")
write_timeout = float(write_timeout_raw) if write_timeout_raw is not None else float(default_write)
pool_timeout = kwargs.get("pool_timeout")
if pool_timeout is None:
pool_timeout_raw = get_env("CUSTOM_POOL_TIMEOUT")
pool_timeout = float(pool_timeout_raw) if pool_timeout_raw is not None else float(default_pool)
timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)

View File

@@ -1,9 +1,10 @@
"""OpenRouter provider implementation."""
import logging
import os
from typing import Optional
from utils.env import get_env
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
@@ -35,8 +36,9 @@ class OpenRouterProvider(OpenAICompatibleProvider):
# Custom headers required by OpenRouter
DEFAULT_HEADERS = {
"HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
"X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
"HTTP-Referer": get_env("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server")
or "https://github.com/BeehiveInnovations/zen-mcp-server",
"X-Title": get_env("OPENROUTER_TITLE", "Zen MCP Server") or "Zen MCP Server",
}
# Model registry for managing configurations and aliases

View File

@@ -2,10 +2,11 @@
import importlib.resources
import logging
import os
from pathlib import Path
from typing import Optional
from utils.env import get_env
# Import handled via importlib.resources.files() calls directly
from utils.file_utils import read_json_file
@@ -50,7 +51,7 @@ class OpenRouterModelRegistry:
self.config_path = Path(config_path)
else:
# Check environment variable first
env_path = os.getenv("CUSTOM_MODELS_CONFIG_PATH")
env_path = get_env("CUSTOM_MODELS_CONFIG_PATH")
if env_path:
# Environment variable path
self.config_path = Path(env_path)

View File

@@ -1,9 +1,10 @@
"""Model provider registry for managing available providers."""
import logging
import os
from typing import TYPE_CHECKING, Optional
from utils.env import get_env
from .base import ModelProvider
from .shared import ProviderType
@@ -102,7 +103,7 @@ class ModelProviderRegistry:
provider = provider_class(api_key=api_key)
else:
# Regular class - need to handle URL requirement
custom_url = os.getenv("CUSTOM_API_URL", "")
custom_url = get_env("CUSTOM_API_URL", "") or ""
if not custom_url:
if api_key: # Key is set but URL is missing
logging.warning("CUSTOM_API_KEY set but CUSTOM_API_URL missing skipping Custom provider")
@@ -116,7 +117,7 @@ class ModelProviderRegistry:
# For Gemini, check if custom base URL is configured
if not api_key:
return None
gemini_base_url = os.getenv("GEMINI_BASE_URL")
gemini_base_url = get_env("GEMINI_BASE_URL")
provider_kwargs = {"api_key": api_key}
if gemini_base_url:
provider_kwargs["base_url"] = gemini_base_url
@@ -327,7 +328,7 @@ class ModelProviderRegistry:
if not env_var:
return None
return os.getenv(env_var)
return get_env(env_var)
@classmethod
def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]: