feat: centralized environment handling, ensures ZEN_MCP_FORCE_ENV_OVERRIDE is honored correctly
fix: updated tests to override env variables they need instead of relying on the current values from .env
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
"""Custom API provider implementation."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openrouter_registry import OpenRouterModelRegistry
|
||||
from .shared import ModelCapabilities, ProviderType
|
||||
@@ -56,9 +57,9 @@ class CustomProvider(OpenAICompatibleProvider):
|
||||
"""
|
||||
# Fall back to environment variables only if not provided
|
||||
if not base_url:
|
||||
base_url = os.getenv("CUSTOM_API_URL", "")
|
||||
base_url = get_env("CUSTOM_API_URL", "") or ""
|
||||
if not api_key:
|
||||
api_key = os.getenv("CUSTOM_API_KEY", "")
|
||||
api_key = get_env("CUSTOM_API_KEY", "") or ""
|
||||
|
||||
if not base_url:
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""DIAL (Data & AI Layer) model provider implementation."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||
|
||||
@@ -209,7 +210,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
**kwargs: Additional configuration options
|
||||
"""
|
||||
# Get DIAL API host from environment or kwargs
|
||||
dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"
|
||||
dial_host = kwargs.get("base_url") or get_env("DIAL_API_HOST") or "https://core.dialx.ai"
|
||||
|
||||
# DIAL uses /openai endpoint for OpenAI-compatible API
|
||||
if not dial_host.endswith("/openai"):
|
||||
@@ -218,7 +219,7 @@ class DIALModelProvider(OpenAICompatibleProvider):
|
||||
kwargs["base_url"] = dial_host
|
||||
|
||||
# Get API version from environment or use default
|
||||
self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")
|
||||
self.api_version = get_env("DIAL_API_VERSION", "2024-12-01-preview") or "2024-12-01-preview"
|
||||
|
||||
# Add DIAL-specific headers
|
||||
# DIAL uses Api-Key header instead of Authorization: Bearer
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
import copy
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from utils.env import get_env
|
||||
from utils.image_utils import validate_image
|
||||
|
||||
from .base import ModelProvider
|
||||
@@ -112,7 +112,7 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
# Get provider-specific allowed models
|
||||
provider_type = self.get_provider_type().value.upper()
|
||||
env_var = f"{provider_type}_ALLOWED_MODELS"
|
||||
models_str = os.getenv(env_var, "")
|
||||
models_str = get_env(env_var, "") or ""
|
||||
|
||||
if models_str:
|
||||
# Parse and normalize to lowercase for case-insensitive comparison
|
||||
@@ -165,10 +165,25 @@ class OpenAICompatibleProvider(ModelProvider):
|
||||
logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")
|
||||
|
||||
# Allow override via kwargs or environment variables in future, for now...
|
||||
connect_timeout = kwargs.get("connect_timeout", float(os.getenv("CUSTOM_CONNECT_TIMEOUT", default_connect)))
|
||||
read_timeout = kwargs.get("read_timeout", float(os.getenv("CUSTOM_READ_TIMEOUT", default_read)))
|
||||
write_timeout = kwargs.get("write_timeout", float(os.getenv("CUSTOM_WRITE_TIMEOUT", default_write)))
|
||||
pool_timeout = kwargs.get("pool_timeout", float(os.getenv("CUSTOM_POOL_TIMEOUT", default_pool)))
|
||||
connect_timeout = kwargs.get("connect_timeout")
|
||||
if connect_timeout is None:
|
||||
connect_timeout_raw = get_env("CUSTOM_CONNECT_TIMEOUT")
|
||||
connect_timeout = float(connect_timeout_raw) if connect_timeout_raw is not None else float(default_connect)
|
||||
|
||||
read_timeout = kwargs.get("read_timeout")
|
||||
if read_timeout is None:
|
||||
read_timeout_raw = get_env("CUSTOM_READ_TIMEOUT")
|
||||
read_timeout = float(read_timeout_raw) if read_timeout_raw is not None else float(default_read)
|
||||
|
||||
write_timeout = kwargs.get("write_timeout")
|
||||
if write_timeout is None:
|
||||
write_timeout_raw = get_env("CUSTOM_WRITE_TIMEOUT")
|
||||
write_timeout = float(write_timeout_raw) if write_timeout_raw is not None else float(default_write)
|
||||
|
||||
pool_timeout = kwargs.get("pool_timeout")
|
||||
if pool_timeout is None:
|
||||
pool_timeout_raw = get_env("CUSTOM_POOL_TIMEOUT")
|
||||
pool_timeout = float(pool_timeout_raw) if pool_timeout_raw is not None else float(default_pool)
|
||||
|
||||
timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""OpenRouter provider implementation."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
from .openai_compatible import OpenAICompatibleProvider
|
||||
from .openrouter_registry import OpenRouterModelRegistry
|
||||
from .shared import (
|
||||
@@ -35,8 +36,9 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
||||
|
||||
# Custom headers required by OpenRouter
|
||||
DEFAULT_HEADERS = {
|
||||
"HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
|
||||
"X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
|
||||
"HTTP-Referer": get_env("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server")
|
||||
or "https://github.com/BeehiveInnovations/zen-mcp-server",
|
||||
"X-Title": get_env("OPENROUTER_TITLE", "Zen MCP Server") or "Zen MCP Server",
|
||||
}
|
||||
|
||||
# Model registry for managing configurations and aliases
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
|
||||
import importlib.resources
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
# Import handled via importlib.resources.files() calls directly
|
||||
from utils.file_utils import read_json_file
|
||||
|
||||
@@ -50,7 +51,7 @@ class OpenRouterModelRegistry:
|
||||
self.config_path = Path(config_path)
|
||||
else:
|
||||
# Check environment variable first
|
||||
env_path = os.getenv("CUSTOM_MODELS_CONFIG_PATH")
|
||||
env_path = get_env("CUSTOM_MODELS_CONFIG_PATH")
|
||||
if env_path:
|
||||
# Environment variable path
|
||||
self.config_path = Path(env_path)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Model provider registry for managing available providers."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from utils.env import get_env
|
||||
|
||||
from .base import ModelProvider
|
||||
from .shared import ProviderType
|
||||
|
||||
@@ -102,7 +103,7 @@ class ModelProviderRegistry:
|
||||
provider = provider_class(api_key=api_key)
|
||||
else:
|
||||
# Regular class - need to handle URL requirement
|
||||
custom_url = os.getenv("CUSTOM_API_URL", "")
|
||||
custom_url = get_env("CUSTOM_API_URL", "") or ""
|
||||
if not custom_url:
|
||||
if api_key: # Key is set but URL is missing
|
||||
logging.warning("CUSTOM_API_KEY set but CUSTOM_API_URL missing – skipping Custom provider")
|
||||
@@ -116,7 +117,7 @@ class ModelProviderRegistry:
|
||||
# For Gemini, check if custom base URL is configured
|
||||
if not api_key:
|
||||
return None
|
||||
gemini_base_url = os.getenv("GEMINI_BASE_URL")
|
||||
gemini_base_url = get_env("GEMINI_BASE_URL")
|
||||
provider_kwargs = {"api_key": api_key}
|
||||
if gemini_base_url:
|
||||
provider_kwargs["base_url"] = gemini_base_url
|
||||
@@ -327,7 +328,7 @@ class ModelProviderRegistry:
|
||||
if not env_var:
|
||||
return None
|
||||
|
||||
return os.getenv(env_var)
|
||||
return get_env(env_var)
|
||||
|
||||
@classmethod
|
||||
def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]:
|
||||
|
||||
Reference in New Issue
Block a user