feat: centralized environment handling, ensures ZEN_MCP_FORCE_ENV_OVERRIDE is honored correctly

fix: updated tests to override env variables they need instead of relying on the current values from .env
This commit is contained in:
Fahad
2025-10-04 14:28:56 +04:00
parent 4015e917ed
commit 2c534ac06e
24 changed files with 300 additions and 179 deletions

View File

@@ -112,24 +112,28 @@ from typing import Any, Optional
from pydantic import BaseModel
from utils.env import get_env
logger = logging.getLogger(__name__)
# Configuration constants
# Get max conversation turns from environment, default to 20 turns (10 exchanges)
try:
MAX_CONVERSATION_TURNS = int(os.getenv("MAX_CONVERSATION_TURNS", "20"))
max_turns_raw = (get_env("MAX_CONVERSATION_TURNS", "50") or "50").strip()
MAX_CONVERSATION_TURNS = int(max_turns_raw)
if MAX_CONVERSATION_TURNS <= 0:
logger.warning(f"Invalid MAX_CONVERSATION_TURNS value ({MAX_CONVERSATION_TURNS}), using default of 20 turns")
MAX_CONVERSATION_TURNS = 20
logger.warning(f"Invalid MAX_CONVERSATION_TURNS value ({MAX_CONVERSATION_TURNS}), using default of 50 turns")
MAX_CONVERSATION_TURNS = 50
except ValueError:
logger.warning(
f"Invalid MAX_CONVERSATION_TURNS value ('{os.getenv('MAX_CONVERSATION_TURNS')}'), using default of 20 turns"
f"Invalid MAX_CONVERSATION_TURNS value ('{get_env('MAX_CONVERSATION_TURNS')}'), using default of 50 turns"
)
MAX_CONVERSATION_TURNS = 20
MAX_CONVERSATION_TURNS = 50
# Get conversation timeout from environment (in hours), default to 3 hours
try:
CONVERSATION_TIMEOUT_HOURS = int(os.getenv("CONVERSATION_TIMEOUT_HOURS", "3"))
timeout_raw = (get_env("CONVERSATION_TIMEOUT_HOURS", "3") or "3").strip()
CONVERSATION_TIMEOUT_HOURS = int(timeout_raw)
if CONVERSATION_TIMEOUT_HOURS <= 0:
logger.warning(
f"Invalid CONVERSATION_TIMEOUT_HOURS value ({CONVERSATION_TIMEOUT_HOURS}), using default of 3 hours"
@@ -137,7 +141,7 @@ try:
CONVERSATION_TIMEOUT_HOURS = 3
except ValueError:
logger.warning(
f"Invalid CONVERSATION_TIMEOUT_HOURS value ('{os.getenv('CONVERSATION_TIMEOUT_HOURS')}'), using default of 3 hours"
f"Invalid CONVERSATION_TIMEOUT_HOURS value ('{get_env('CONVERSATION_TIMEOUT_HOURS')}'), using default of 3 hours"
)
CONVERSATION_TIMEOUT_HOURS = 3

88
utils/env.py Normal file
View File

@@ -0,0 +1,88 @@
"""Centralized environment variable access for Zen MCP Server."""
from __future__ import annotations
import os
from collections.abc import Mapping
from pathlib import Path
try:
from dotenv import dotenv_values, load_dotenv
except ImportError: # pragma: no cover - optional dependency
dotenv_values = None # type: ignore[assignment]
load_dotenv = None # type: ignore[assignment]
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_ENV_PATH = _PROJECT_ROOT / ".env"
_DOTENV_VALUES: dict[str, str | None] = {}
_FORCE_ENV_OVERRIDE = False
def _read_dotenv_values() -> dict[str, str | None]:
if dotenv_values is not None and _ENV_PATH.exists():
loaded = dotenv_values(_ENV_PATH)
return dict(loaded)
return {}
def _compute_force_override(values: Mapping[str, str | None]) -> bool:
raw = (values.get("ZEN_MCP_FORCE_ENV_OVERRIDE") or "false").strip().lower()
return raw == "true"
def reload_env(dotenv_mapping: Mapping[str, str | None] | None = None) -> None:
"""Reload .env values and recompute override semantics.
Args:
dotenv_mapping: Optional mapping used instead of reading the .env file.
Intended for tests; when provided, load_dotenv is not invoked.
"""
global _DOTENV_VALUES, _FORCE_ENV_OVERRIDE
if dotenv_mapping is not None:
_DOTENV_VALUES = dict(dotenv_mapping)
_FORCE_ENV_OVERRIDE = _compute_force_override(_DOTENV_VALUES)
return
_DOTENV_VALUES = _read_dotenv_values()
_FORCE_ENV_OVERRIDE = _compute_force_override(_DOTENV_VALUES)
if load_dotenv is not None and _ENV_PATH.exists():
load_dotenv(dotenv_path=_ENV_PATH, override=_FORCE_ENV_OVERRIDE)
reload_env()
def env_override_enabled() -> bool:
"""Return True when ZEN_MCP_FORCE_ENV_OVERRIDE is enabled via the .env file."""
return _FORCE_ENV_OVERRIDE
def get_env(key: str, default: str | None = None) -> str | None:
"""Retrieve environment variables respecting ZEN_MCP_FORCE_ENV_OVERRIDE."""
if env_override_enabled():
if key in _DOTENV_VALUES:
value = _DOTENV_VALUES[key]
return value if value is not None else default
return default
return os.getenv(key, default)
def get_env_bool(key: str, default: bool = False) -> bool:
"""Boolean helper that respects override semantics."""
raw_default = "true" if default else "false"
raw_value = get_env(key, raw_default)
return (raw_value or raw_default).strip().lower() == "true"
def get_all_env() -> dict[str, str | None]:
"""Expose the loaded .env mapping for diagnostics/logging."""
return dict(_DOTENV_VALUES)

View File

@@ -21,11 +21,11 @@ Example:
"""
import logging
import os
from collections import defaultdict
from typing import Optional
from providers.shared import ProviderType
from utils.env import get_env
logger = logging.getLogger(__name__)
@@ -65,7 +65,7 @@ class ModelRestrictionService:
def _load_from_env(self) -> None:
"""Load restrictions from environment variables."""
for provider_type, env_var in self.ENV_VARS.items():
env_value = os.getenv(env_var)
env_value = get_env(env_var)
if env_value is None or env_value == "":
# Not set or empty - no restrictions (allow all models)

View File

@@ -19,11 +19,12 @@ Key Features:
"""
import logging
import os
import threading
import time
from typing import Optional
from utils.env import get_env
logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@ class InMemoryStorage:
self._lock = threading.Lock()
# Match Redis behavior: cleanup interval based on conversation timeout
# Run cleanup at 1/10th of timeout interval (e.g., 18 mins for 3 hour timeout)
timeout_hours = int(os.getenv("CONVERSATION_TIMEOUT_HOURS", "3"))
timeout_hours = int(get_env("CONVERSATION_TIMEOUT_HOURS", "3") or "3")
self._cleanup_interval = (timeout_hours * 3600) // 10
self._cleanup_interval = max(300, self._cleanup_interval) # Minimum 5 minutes
self._shutdown = False