feat: centralized environment handling, ensures ZEN_MCP_FORCE_ENV_OVERRIDE is honored correctly
fix: updated tests to override env variables they need instead of relying on the current values from .env
This commit is contained in:
@@ -112,24 +112,28 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration constants
|
||||
# Get max conversation turns from environment, default to 20 turns (10 exchanges)
|
||||
try:
|
||||
MAX_CONVERSATION_TURNS = int(os.getenv("MAX_CONVERSATION_TURNS", "20"))
|
||||
max_turns_raw = (get_env("MAX_CONVERSATION_TURNS", "50") or "50").strip()
|
||||
MAX_CONVERSATION_TURNS = int(max_turns_raw)
|
||||
if MAX_CONVERSATION_TURNS <= 0:
|
||||
logger.warning(f"Invalid MAX_CONVERSATION_TURNS value ({MAX_CONVERSATION_TURNS}), using default of 20 turns")
|
||||
MAX_CONVERSATION_TURNS = 20
|
||||
logger.warning(f"Invalid MAX_CONVERSATION_TURNS value ({MAX_CONVERSATION_TURNS}), using default of 50 turns")
|
||||
MAX_CONVERSATION_TURNS = 50
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid MAX_CONVERSATION_TURNS value ('{os.getenv('MAX_CONVERSATION_TURNS')}'), using default of 20 turns"
|
||||
f"Invalid MAX_CONVERSATION_TURNS value ('{get_env('MAX_CONVERSATION_TURNS')}'), using default of 50 turns"
|
||||
)
|
||||
MAX_CONVERSATION_TURNS = 20
|
||||
MAX_CONVERSATION_TURNS = 50
|
||||
|
||||
# Get conversation timeout from environment (in hours), default to 3 hours
|
||||
try:
|
||||
CONVERSATION_TIMEOUT_HOURS = int(os.getenv("CONVERSATION_TIMEOUT_HOURS", "3"))
|
||||
timeout_raw = (get_env("CONVERSATION_TIMEOUT_HOURS", "3") or "3").strip()
|
||||
CONVERSATION_TIMEOUT_HOURS = int(timeout_raw)
|
||||
if CONVERSATION_TIMEOUT_HOURS <= 0:
|
||||
logger.warning(
|
||||
f"Invalid CONVERSATION_TIMEOUT_HOURS value ({CONVERSATION_TIMEOUT_HOURS}), using default of 3 hours"
|
||||
@@ -137,7 +141,7 @@ try:
|
||||
CONVERSATION_TIMEOUT_HOURS = 3
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid CONVERSATION_TIMEOUT_HOURS value ('{os.getenv('CONVERSATION_TIMEOUT_HOURS')}'), using default of 3 hours"
|
||||
f"Invalid CONVERSATION_TIMEOUT_HOURS value ('{get_env('CONVERSATION_TIMEOUT_HOURS')}'), using default of 3 hours"
|
||||
)
|
||||
CONVERSATION_TIMEOUT_HOURS = 3
|
||||
|
||||
|
||||
88
utils/env.py
Normal file
88
utils/env.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Centralized environment variable access for Zen MCP Server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from dotenv import dotenv_values, load_dotenv
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
dotenv_values = None # type: ignore[assignment]
|
||||
load_dotenv = None # type: ignore[assignment]
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
_ENV_PATH = _PROJECT_ROOT / ".env"
|
||||
|
||||
_DOTENV_VALUES: dict[str, str | None] = {}
|
||||
_FORCE_ENV_OVERRIDE = False
|
||||
|
||||
|
||||
def _read_dotenv_values() -> dict[str, str | None]:
|
||||
if dotenv_values is not None and _ENV_PATH.exists():
|
||||
loaded = dotenv_values(_ENV_PATH)
|
||||
return dict(loaded)
|
||||
return {}
|
||||
|
||||
|
||||
def _compute_force_override(values: Mapping[str, str | None]) -> bool:
|
||||
raw = (values.get("ZEN_MCP_FORCE_ENV_OVERRIDE") or "false").strip().lower()
|
||||
return raw == "true"
|
||||
|
||||
|
||||
def reload_env(dotenv_mapping: Mapping[str, str | None] | None = None) -> None:
|
||||
"""Reload .env values and recompute override semantics.
|
||||
|
||||
Args:
|
||||
dotenv_mapping: Optional mapping used instead of reading the .env file.
|
||||
Intended for tests; when provided, load_dotenv is not invoked.
|
||||
"""
|
||||
|
||||
global _DOTENV_VALUES, _FORCE_ENV_OVERRIDE
|
||||
|
||||
if dotenv_mapping is not None:
|
||||
_DOTENV_VALUES = dict(dotenv_mapping)
|
||||
_FORCE_ENV_OVERRIDE = _compute_force_override(_DOTENV_VALUES)
|
||||
return
|
||||
|
||||
_DOTENV_VALUES = _read_dotenv_values()
|
||||
_FORCE_ENV_OVERRIDE = _compute_force_override(_DOTENV_VALUES)
|
||||
|
||||
if load_dotenv is not None and _ENV_PATH.exists():
|
||||
load_dotenv(dotenv_path=_ENV_PATH, override=_FORCE_ENV_OVERRIDE)
|
||||
|
||||
|
||||
reload_env()
|
||||
|
||||
|
||||
def env_override_enabled() -> bool:
|
||||
"""Return True when ZEN_MCP_FORCE_ENV_OVERRIDE is enabled via the .env file."""
|
||||
|
||||
return _FORCE_ENV_OVERRIDE
|
||||
|
||||
|
||||
def get_env(key: str, default: str | None = None) -> str | None:
|
||||
"""Retrieve environment variables respecting ZEN_MCP_FORCE_ENV_OVERRIDE."""
|
||||
|
||||
if env_override_enabled():
|
||||
if key in _DOTENV_VALUES:
|
||||
value = _DOTENV_VALUES[key]
|
||||
return value if value is not None else default
|
||||
return default
|
||||
|
||||
return os.getenv(key, default)
|
||||
|
||||
|
||||
def get_env_bool(key: str, default: bool = False) -> bool:
|
||||
"""Boolean helper that respects override semantics."""
|
||||
|
||||
raw_default = "true" if default else "false"
|
||||
raw_value = get_env(key, raw_default)
|
||||
return (raw_value or raw_default).strip().lower() == "true"
|
||||
|
||||
|
||||
def get_all_env() -> dict[str, str | None]:
|
||||
"""Expose the loaded .env mapping for diagnostics/logging."""
|
||||
|
||||
return dict(_DOTENV_VALUES)
|
||||
@@ -21,11 +21,11 @@ Example:
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from providers.shared import ProviderType
|
||||
from utils.env import get_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -65,7 +65,7 @@ class ModelRestrictionService:
|
||||
def _load_from_env(self) -> None:
|
||||
"""Load restrictions from environment variables."""
|
||||
for provider_type, env_var in self.ENV_VARS.items():
|
||||
env_value = os.getenv(env_var)
|
||||
env_value = get_env(env_var)
|
||||
|
||||
if env_value is None or env_value == "":
|
||||
# Not set or empty - no restrictions (allow all models)
|
||||
|
||||
@@ -19,11 +19,12 @@ Key Features:
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +36,7 @@ class InMemoryStorage:
|
||||
self._lock = threading.Lock()
|
||||
# Match Redis behavior: cleanup interval based on conversation timeout
|
||||
# Run cleanup at 1/10th of timeout interval (e.g., 18 mins for 3 hour timeout)
|
||||
timeout_hours = int(os.getenv("CONVERSATION_TIMEOUT_HOURS", "3"))
|
||||
timeout_hours = int(get_env("CONVERSATION_TIMEOUT_HOURS", "3") or "3")
|
||||
self._cleanup_interval = (timeout_hours * 3600) // 10
|
||||
self._cleanup_interval = max(300, self._cleanup_interval) # Minimum 5 minutes
|
||||
self._shutdown = False
|
||||
|
||||
Reference in New Issue
Block a user