fixed all remaining issues with the session manager
This commit is contained in:
302
session-manager/async_docker_client.py
Normal file
302
session-manager/async_docker_client.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Async Docker Operations Wrapper
|
||||
|
||||
Provides async wrappers for Docker operations to eliminate blocking calls
|
||||
in FastAPI async contexts and improve concurrency and scalability.
|
||||
"""
|
||||
|
||||
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

# Bug fix: the package was spelled "aiodeocker"; the canonical async Docker
# client package is "aiodocker", whose API matches every usage in this module
# (Docker, containers.DockerContainer, exceptions.DockerError).
# NOTE(review): confirm "aiodocker" is the pinned dependency in requirements.
from aiodocker import Docker
from aiodocker.containers import DockerContainer
from aiodocker.exceptions import DockerError

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncDockerClient:
    """Async wrapper for Docker operations using aiodocker.

    Holds a single connection to the Docker daemon and exposes the
    container-lifecycle operations needed by the session manager.  Usable
    as an async context manager: connects on enter, disconnects on exit.
    """

    def __init__(self):
        # Lazily-created aiodocker client; None until connect() succeeds.
        self._docker: Optional[Docker] = None
        self._connected = False

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.disconnect()

    async def connect(self):
        """Connect to the Docker daemon (idempotent).

        Honours DOCKER_HOST, DOCKER_TLS_VERIFY and the DOCKER_*_CERT
        environment variables.

        Raises:
            Exception: propagated when the daemon cannot be reached.
        """
        if self._connected:
            return

        try:
            # Configure TLS if available
            tls_config = None
            if os.getenv("DOCKER_TLS_VERIFY") == "1":
                # NOTE(review): package name fixed from the misspelled
                # "aiodeocker"; verify that the pinned aiodocker version
                # actually exposes utils.create_tls_config.
                from aiodocker.utils import create_tls_config

                tls_config = create_tls_config(
                    ca_cert=os.getenv("DOCKER_CA_CERT", "/etc/docker/certs/ca.pem"),
                    client_cert=(
                        os.getenv(
                            "DOCKER_CLIENT_CERT", "/etc/docker/certs/client-cert.pem"
                        ),
                        os.getenv(
                            "DOCKER_CLIENT_KEY", "/etc/docker/certs/client-key.pem"
                        ),
                    ),
                    verify=True,
                )

            docker_host = os.getenv("DOCKER_HOST", "tcp://host.docker.internal:2376")
            self._docker = Docker(docker_host, tls=tls_config)

            # Test connection before declaring ourselves connected.
            await self._docker.ping()
            self._connected = True
            logger.info("Async Docker client connected successfully")

        except Exception as e:
            logger.error(f"Failed to connect to Docker: {e}")
            raise

    async def disconnect(self):
        """Disconnect from Docker daemon."""
        if self._docker and self._connected:
            await self._docker.close()
            self._connected = False
            logger.info("Async Docker client disconnected")

    async def ping(self) -> bool:
        """Test Docker connectivity; returns False instead of raising."""
        if not self._docker:
            return False
        try:
            await self._docker.ping()
            return True
        except Exception:
            return False

    async def create_container(
        self,
        image: str,
        name: str,
        volumes: Optional[Dict[str, Dict[str, str]]] = None,
        ports: Optional[Dict[str, int]] = None,
        environment: Optional[Dict[str, str]] = None,
        network_mode: str = "bridge",
        mem_limit: Optional[str] = None,
        cpu_quota: Optional[int] = None,
        cpu_period: Optional[int] = None,
        tmpfs: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> DockerContainer:
        """
        Create a Docker container asynchronously.

        Args:
            image: Container image name
            name: Container name
            volumes: Volume mounts (host path -> {"bind": ..., "mode": ...})
            ports: Port mappings (container port -> host port)
            environment: Environment variables
            network_mode: Network mode
            mem_limit: Memory limit (e.g., "4g")
            cpu_quota: CPU quota
            cpu_period: CPU period
            tmpfs: tmpfs mounts
            **kwargs: Additional container configuration; only
                kwargs["host_config"] is currently honoured (merged into
                HostConfig).

        Returns:
            DockerContainer: The created container

        Raises:
            RuntimeError: when connect() has not been called.
            DockerError: when the Docker Engine rejects the request.
        """
        if not self._docker:
            raise RuntimeError("Docker client not connected")

        config = {
            "Image": image,
            # NOTE(review): the Engine's "Volumes" field expects a mapping of
            # *container* paths to {}; the host->spec mapping kept here is
            # effectively inert because "Binds" below carries the real mounts.
            # Confirm and drop/adjust this field.
            "Volumes": volumes or {},
            # Bug fix: expose the *container* ports (the dict keys).  The
            # original iterated ports.values(), exposing host ports instead.
            "ExposedPorts": {f"{port}/tcp": {} for port in ports} if ports else {},
            "Env": [f"{k}={v}" for k, v in (environment or {}).items()],
            "NetworkMode": network_mode,
            "HostConfig": {
                "Binds": [
                    f"{host}:{container['bind']}:{container.get('mode', 'rw')}"
                    for host, container in (volumes or {}).items()
                ],
                "PortBindings": {
                    f"{container_port}/tcp": [{"HostPort": str(host_port)}]
                    for container_port, host_port in (ports or {}).items()
                },
                "Tmpfs": tmpfs or {},
            },
        }

        # Add resource limits
        host_config = config["HostConfig"]
        if mem_limit:
            host_config["Memory"] = self._parse_memory_limit(mem_limit)
        if cpu_quota is not None:
            host_config["CpuQuota"] = cpu_quota
        if cpu_period is not None:
            host_config["CpuPeriod"] = cpu_period

        # Add any additional host config
        host_config.update(kwargs.get("host_config", {}))

        try:
            # Bug fix: aiodocker takes the container name as a keyword
            # argument (it becomes the ?name= query parameter); a "name" key
            # inside the config body is ignored by the Engine.
            container = await self._docker.containers.create(config, name=name)
            logger.info(f"Container {name} created successfully")
            return container
        except DockerError as e:
            logger.error(f"Failed to create container {name}: {e}")
            raise

    async def start_container(self, container: DockerContainer) -> None:
        """Start a Docker container."""
        try:
            await container.start()
            logger.info(f"Container {container.id} started successfully")
        except DockerError as e:
            logger.error(f"Failed to start container {container.id}: {e}")
            raise

    async def stop_container(
        self, container: DockerContainer, timeout: int = 10
    ) -> None:
        """Stop a Docker container, giving it `timeout` seconds to exit."""
        try:
            await container.stop(timeout=timeout)
            logger.info(f"Container {container.id} stopped successfully")
        except DockerError as e:
            logger.error(f"Failed to stop container {container.id}: {e}")
            raise

    async def remove_container(
        self, container: DockerContainer, force: bool = False
    ) -> None:
        """Remove a Docker container (force=True kills a running one)."""
        try:
            await container.delete(force=force)
            logger.info(f"Container {container.id} removed successfully")
        except DockerError as e:
            logger.error(f"Failed to remove container {container.id}: {e}")
            raise

    async def get_container(self, container_id: str) -> Optional[DockerContainer]:
        """Get a container by ID or name; None when not found."""
        try:
            return await self._docker.containers.get(container_id)
        except DockerError:
            return None

    async def list_containers(
        self, all: bool = False, filters: Optional[Dict[str, Any]] = None
    ) -> List[DockerContainer]:
        """List Docker containers; returns [] on Docker errors."""
        try:
            return await self._docker.containers.list(all=all, filters=filters)
        except DockerError as e:
            logger.error(f"Failed to list containers: {e}")
            return []

    async def get_container_stats(
        self, container: DockerContainer
    ) -> Optional[Dict[str, Any]]:
        """Get a one-shot container statistics snapshot (None on error)."""
        try:
            stats = await container.stats(stream=False)
            return stats
        except DockerError as e:
            logger.error(f"Failed to get stats for container {container.id}: {e}")
            return None

    async def get_system_info(self) -> Optional[Dict[str, Any]]:
        """Get Docker system information (None when not connected/on error)."""
        if not self._docker:
            return None
        try:
            return await self._docker.system.info()
        except DockerError as e:
            logger.error(f"Failed to get system info: {e}")
            return None

    def _parse_memory_limit(self, memory_str: str) -> int:
        """Parse a memory limit string ("4g", "512m", "64k", "1024") to bytes.

        Raises:
            ValueError: when the numeric part is not an integer.
        """
        memory_str = memory_str.lower().strip()
        if memory_str.endswith("g"):
            return int(memory_str[:-1]) * (1024**3)
        elif memory_str.endswith("m"):
            return int(memory_str[:-1]) * (1024**2)
        elif memory_str.endswith("k"):
            return int(memory_str[:-1]) * 1024
        else:
            return int(memory_str)
|
||||
|
||||
|
||||
# Module-wide async Docker client shared by the convenience wrappers below.
# NOTE(review): every wrapper enters/exits this single instance, and exiting
# disconnects it — concurrent wrapper calls will race on the shared
# connection.  Confirm callers use these helpers sequentially.
_async_docker_client = AsyncDockerClient()


@asynccontextmanager
async def get_async_docker_client():
    """Yield the shared async Docker client, connected for the duration."""
    async with _async_docker_client as docker:
        yield docker


async def async_docker_ping() -> bool:
    """Return True when the Docker daemon answers a ping."""
    async with get_async_docker_client() as docker:
        return await docker.ping()


async def async_create_container(**kwargs) -> DockerContainer:
    """Create a container via the shared client (see create_container)."""
    async with get_async_docker_client() as docker:
        return await docker.create_container(**kwargs)


async def async_start_container(container: DockerContainer) -> None:
    """Start `container` via the shared client."""
    async with get_async_docker_client() as docker:
        await docker.start_container(container)


async def async_stop_container(container: DockerContainer, timeout: int = 10) -> None:
    """Stop `container` via the shared client."""
    async with get_async_docker_client() as docker:
        await docker.stop_container(container, timeout)


async def async_remove_container(
    container: DockerContainer, force: bool = False
) -> None:
    """Remove `container` via the shared client."""
    async with get_async_docker_client() as docker:
        await docker.remove_container(container, force)


async def async_list_containers(
    all: bool = False, filters: Optional[Dict[str, Any]] = None
) -> List[DockerContainer]:
    """List containers via the shared client."""
    async with get_async_docker_client() as docker:
        return await docker.list_containers(all=all, filters=filters)


async def async_get_container(container_id: str) -> Optional[DockerContainer]:
    """Look up a container by ID or name via the shared client."""
    async with get_async_docker_client() as docker:
        return await docker.get_container(container_id)
|
||||
574
session-manager/container_health.py
Normal file
574
session-manager/container_health.py
Normal file
@@ -0,0 +1,574 @@
|
||||
"""
|
||||
Container Health Monitoring System
|
||||
|
||||
Provides active monitoring of Docker containers with automatic failure detection,
|
||||
recovery mechanisms, and integration with session management and alerting systems.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
|
||||
from logging_config import get_logger, log_performance, log_security_event
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ContainerStatus(Enum):
    """Container health status enumeration."""

    HEALTHY = "healthy"          # running, and passing its health check (if any)
    UNHEALTHY = "unhealthy"      # running, but the Docker health check reports a problem
    RESTARTING = "restarting"    # a recovery restart is in progress
    FAILED = "failed"            # container missing or not in the "running" state
    UNKNOWN = "unknown"          # status could not be determined (no container ID, probe error)
|
||||
|
||||
|
||||
class HealthCheckResult:
    """Outcome of a single container health probe.

    Records identity (session/container), the observed status, optional
    probe latency in milliseconds, an optional error description, free-form
    metadata, and the UTC time the probe completed.
    """

    def __init__(
        self,
        session_id: str,
        container_id: str,
        status: ContainerStatus,
        response_time: Optional[float] = None,
        error_message: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        # Identity of the probed session/container.
        self.session_id = session_id
        self.container_id = container_id
        # Probe outcome, latency (ms, may be None) and failure detail.
        self.status = status
        self.response_time = response_time
        self.error_message = error_message
        self.metadata = metadata or {}
        # Stamped at construction time (naive UTC).
        self.timestamp = datetime.utcnow()

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the result to a plain dict for logging/JSON output."""
        return dict(
            session_id=self.session_id,
            container_id=self.container_id,
            status=self.status.value,
            response_time=self.response_time,
            error_message=self.error_message,
            metadata=self.metadata,
            timestamp=self.timestamp.isoformat(),
        )
|
||||
|
||||
|
||||
class ContainerHealthMonitor:
    """Monitors Docker container health and handles automatic recovery.

    Periodically probes every "running" session's container, keeps a short
    per-session result history, and restarts containers that fail
    `failure_threshold` consecutive checks (up to `max_restart_attempts`
    times before the session is marked failed).
    """

    def __init__(
        self,
        check_interval: int = 30,  # seconds
        health_timeout: float = 10.0,  # seconds
        max_restart_attempts: int = 3,
        restart_delay: int = 5,  # seconds
        failure_threshold: int = 3,  # consecutive failures before restart
    ):
        self.check_interval = check_interval
        # NOTE(review): health_timeout is stored but never applied to the
        # probes below — confirm whether a per-check timeout was intended.
        self.health_timeout = health_timeout
        self.max_restart_attempts = max_restart_attempts
        self.restart_delay = restart_delay
        self.failure_threshold = failure_threshold

        # Monitoring state
        self._monitoring = False
        self._task: Optional[asyncio.Task] = None
        # session_id -> recent HealthCheckResult list (capped to 10 entries in
        # _process_health_result, pruned by age in _cleanup_old_history).
        self._health_history: Dict[str, List[HealthCheckResult]] = {}
        # session_id -> restart attempts since the last healthy check.
        self._restart_counts: Dict[str, int] = {}

        # Dependencies (injected)
        self.session_manager = None
        self.docker_client = None

        logger.info(
            "Container health monitor initialized",
            extra={
                "check_interval": check_interval,
                "health_timeout": health_timeout,
                "max_restart_attempts": max_restart_attempts,
            },
        )

    def set_dependencies(self, session_manager, docker_client):
        """Set dependencies for health monitoring."""
        self.session_manager = session_manager
        self.docker_client = docker_client

    async def start_monitoring(self):
        """Start the health monitoring loop (no-op when already running)."""
        if self._monitoring:
            logger.warning("Health monitoring already running")
            return

        self._monitoring = True
        self._task = asyncio.create_task(self._monitoring_loop())
        logger.info("Container health monitoring started")

    async def stop_monitoring(self):
        """Stop the health monitoring loop and await task cancellation."""
        if not self._monitoring:
            return

        self._monitoring = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

        logger.info("Container health monitoring stopped")

    async def _monitoring_loop(self):
        """Main monitoring loop: probe, prune history, sleep, repeat."""
        while self._monitoring:
            try:
                await self._perform_health_checks()
                await self._cleanup_old_history()
            except Exception as e:
                # Keep the loop alive; a single bad cycle must not stop
                # monitoring entirely.
                logger.error("Error in health monitoring loop", extra={"error": str(e)})

            await asyncio.sleep(self.check_interval)

    async def _perform_health_checks(self):
        """Perform health checks on all running containers."""
        if not self.session_manager:
            return

        # Get all running sessions
        # NOTE(review): reads session_manager.sessions directly — assumes an
        # in-memory dict of session objects with .status; confirm interface.
        running_sessions = [
            session
            for session in self.session_manager.sessions.values()
            if session.status == "running"
        ]

        if not running_sessions:
            return

        logger.debug(f"Checking health of {len(running_sessions)} running containers")

        # Perform health checks concurrently
        tasks = [self._check_container_health(session) for session in running_sessions]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Process results (gather preserves order, so index i maps back to
        # running_sessions[i])
        for i, result in enumerate(results):
            session = running_sessions[i]
            if isinstance(result, Exception):
                logger.error(
                    "Health check failed",
                    extra={
                        "session_id": session.session_id,
                        "container_id": session.container_id,
                        "error": str(result),
                    },
                )
                continue

            await self._process_health_result(result)

    async def _check_container_health(self, session) -> HealthCheckResult:
        """Check the health of a single container.

        Never raises: all failures are folded into the returned
        HealthCheckResult status/error_message.
        """
        start_time = asyncio.get_event_loop().time()

        try:
            # Check if container exists and is running
            if not session.container_id:
                return HealthCheckResult(
                    session.session_id,
                    session.container_id or "unknown",
                    ContainerStatus.UNKNOWN,
                    error_message="No container ID",
                )

            # Get container status
            container_info = await self._get_container_info(session.container_id)
            if not container_info:
                return HealthCheckResult(
                    session.session_id,
                    session.container_id,
                    ContainerStatus.FAILED,
                    error_message="Container not found",
                )

            # Check container state
            state = container_info.get("State", {})
            status = state.get("Status", "unknown")

            if status != "running":
                return HealthCheckResult(
                    session.session_id,
                    session.container_id,
                    ContainerStatus.FAILED,
                    error_message=f"Container status: {status}",
                )

            # Check health status if available
            health = state.get("Health", {})
            if health:
                health_status = health.get("Status", "unknown")
                if health_status == "healthy":
                    # Probe latency in milliseconds.
                    response_time = (
                        asyncio.get_event_loop().time() - start_time
                    ) * 1000
                    return HealthCheckResult(
                        session.session_id,
                        session.container_id,
                        ContainerStatus.HEALTHY,
                        response_time=response_time,
                        metadata={
                            "docker_status": status,
                            "health_status": health_status,
                        },
                    )
                # NOTE(review): "starting" is treated the same as "unhealthy",
                # so slow-starting containers count toward the restart
                # threshold — confirm this is intended.
                elif health_status in ["unhealthy", "starting"]:
                    return HealthCheckResult(
                        session.session_id,
                        session.container_id,
                        ContainerStatus.UNHEALTHY,
                        error_message=f"Health check: {health_status}",
                        metadata={
                            "docker_status": status,
                            "health_status": health_status,
                        },
                    )

            # If no health check configured, consider running containers healthy
            response_time = (asyncio.get_event_loop().time() - start_time) * 1000
            return HealthCheckResult(
                session.session_id,
                session.container_id,
                ContainerStatus.HEALTHY,
                response_time=response_time,
                metadata={"docker_status": status},
            )

        except Exception as e:
            response_time = (asyncio.get_event_loop().time() - start_time) * 1000
            return HealthCheckResult(
                session.session_id,
                session.container_id or "unknown",
                ContainerStatus.UNKNOWN,
                response_time=response_time,
                error_message=str(e),
            )

    async def _get_container_info(self, container_id: str) -> Optional[Dict[str, Any]]:
        """Get container information from Docker.

        Returns the Docker inspect dict, or None when the container cannot
        be found via either the async or the sync client.
        """
        try:
            if self.docker_client:
                # Try async Docker client first
                container = await self.docker_client.get_container(container_id)
                if hasattr(container, "_container"):
                    # NOTE(review): reaches into a private attribute of the
                    # aiodocker container wrapper — fragile across versions.
                    return await container._container.show()
                elif hasattr(container, "show"):
                    return await container.show()
            else:
                # Fallback to sync client if available
                if (
                    hasattr(self.session_manager, "docker_client")
                    and self.session_manager.docker_client
                ):
                    # NOTE(review): blocking docker-py call inside an async
                    # method — acceptable only as a fallback path.
                    container = self.session_manager.docker_client.containers.get(
                        container_id
                    )
                    return container.attrs
        except Exception as e:
            logger.debug(
                f"Failed to get container info for {container_id}",
                extra={"error": str(e)},
            )

        return None

    async def _process_health_result(self, result: HealthCheckResult):
        """Process a health check result and take appropriate action."""
        # Store result in history
        if result.session_id not in self._health_history:
            self._health_history[result.session_id] = []

        self._health_history[result.session_id].append(result)

        # Keep only recent history (last 10 checks)
        if len(self._health_history[result.session_id]) > 10:
            self._health_history[result.session_id] = self._health_history[
                result.session_id
            ][-10:]

        # Log result at a severity matching the status.
        log_extra = result.to_dict()
        if result.status == ContainerStatus.HEALTHY:
            logger.debug("Container health check passed", extra=log_extra)
        elif result.status == ContainerStatus.UNHEALTHY:
            logger.warning("Container health check failed", extra=log_extra)
        elif result.status in [ContainerStatus.FAILED, ContainerStatus.UNKNOWN]:
            logger.error("Container health check critical", extra=log_extra)

        # Check if restart is needed
        await self._check_restart_needed(result)

    async def _check_restart_needed(self, result: HealthCheckResult):
        """Check if a container needs to be restarted based on health history."""
        if result.status == ContainerStatus.HEALTHY:
            # Reset restart count on successful health check
            if result.session_id in self._restart_counts:
                self._restart_counts[result.session_id] = 0
            return

        # Count recent failures (only the last `failure_threshold` results,
        # so the restart fires on consecutive failures)
        recent_results = self._health_history.get(result.session_id, [])
        recent_failures = sum(
            1
            for r in recent_results[-self.failure_threshold :]
            if r.status
            in [
                ContainerStatus.UNHEALTHY,
                ContainerStatus.FAILED,
                ContainerStatus.UNKNOWN,
            ]
        )

        if recent_failures >= self.failure_threshold:
            await self._restart_container(result.session_id, result.container_id)

    async def _restart_container(self, session_id: str, container_id: str):
        """Restart a failed container, bounded by max_restart_attempts."""
        # Check restart limit
        restart_count = self._restart_counts.get(session_id, 0)
        if restart_count >= self.max_restart_attempts:
            logger.error(
                "Container restart limit exceeded",
                extra={
                    "session_id": session_id,
                    "container_id": container_id,
                    "restart_attempts": restart_count,
                },
            )
            # Mark session as failed
            await self._mark_session_failed(
                session_id, f"Restart limit exceeded ({restart_count} attempts)"
            )
            return

        logger.info(
            "Attempting container restart",
            extra={
                "session_id": session_id,
                "container_id": container_id,
                "restart_attempt": restart_count + 1,
            },
        )

        try:
            # Stop the container
            await self._stop_container(container_id)

            # Wait before restart
            await asyncio.sleep(self.restart_delay)

            # Start new container for the session
            session = await self.session_manager.get_session(session_id)
            if session:
                # Update restart count
                self._restart_counts[session_id] = restart_count + 1

                # Mark as restarting
                await self._update_session_status(session_id, "restarting")

                # Trigger container restart through session manager
                if self.session_manager:
                    # NOTE(review): create_session() is called with no
                    # arguments — this appears to create a brand-new session
                    # rather than recreating the container for `session_id`.
                    # Confirm the session-manager API; likely a bug.
                    await self.session_manager.create_session()
                    logger.info(
                        "Container restart initiated",
                        extra={
                            "session_id": session_id,
                            "restart_attempt": restart_count + 1,
                        },
                    )

            # Log security event
            log_security_event(
                "container_restart",
                "warning",
                {
                    "session_id": session_id,
                    "container_id": container_id,
                    "reason": "health_check_failure",
                },
            )

        except Exception as e:
            logger.error(
                "Container restart failed",
                extra={
                    "session_id": session_id,
                    "container_id": container_id,
                    "error": str(e),
                },
            )

    async def _stop_container(self, container_id: str):
        """Stop a container (best-effort; failures are logged, not raised)."""
        try:
            if self.docker_client:
                container = await self.docker_client.get_container(container_id)
                await self.docker_client.stop_container(container, timeout=10)
            elif (
                hasattr(self.session_manager, "docker_client")
                and self.session_manager.docker_client
            ):
                # Sync docker-py fallback.
                container = self.session_manager.docker_client.containers.get(
                    container_id
                )
                container.stop(timeout=10)
        except Exception as e:
            logger.warning(
                "Failed to stop container during restart",
                extra={"container_id": container_id, "error": str(e)},
            )

    async def _update_session_status(self, session_id: str, status: str):
        """Update session status in memory and (when enabled) the database."""
        if self.session_manager:
            session = self.session_manager.sessions.get(session_id)
            if session:
                session.status = status
                # Update in database if using database storage
                if (
                    hasattr(self.session_manager, "USE_DATABASE_STORAGE")
                    and self.session_manager.USE_DATABASE_STORAGE
                ):
                    try:
                        from database import SessionModel

                        await SessionModel.update_session(
                            session_id, {"status": status}
                        )
                    except Exception as e:
                        # DB persistence is best-effort; the in-memory state
                        # above is already updated.
                        logger.warning(
                            "Failed to update session status in database",
                            extra={"session_id": session_id, "error": str(e)},
                        )

    async def _mark_session_failed(self, session_id: str, reason: str):
        """Mark a session as permanently failed."""
        await self._update_session_status(session_id, "failed")

        logger.error(
            "Session marked as failed",
            extra={"session_id": session_id, "reason": reason},
        )

        # Log security event
        log_security_event(
            "session_failure", "error", {"session_id": session_id, "reason": reason}
        )

    async def _cleanup_old_history(self):
        """Clean up old health check history."""
        cutoff_time = datetime.utcnow() - timedelta(hours=1)  # Keep last hour

        # Iterate over a snapshot of keys so entries can be deleted safely.
        for session_id in list(self._health_history.keys()):
            # Remove old results
            self._health_history[session_id] = [
                result
                for result in self._health_history[session_id]
                if result.timestamp > cutoff_time
            ]

            # Remove empty histories
            if not self._health_history[session_id]:
                del self._health_history[session_id]

    def get_health_stats(self, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Get health monitoring statistics.

        When `session_id` is given and known, the returned dict also carries
        a "session_<id>" entry with per-session counters and the last result.
        """
        stats = {
            "monitoring_active": self._monitoring,
            "check_interval": self.check_interval,
            "total_sessions_monitored": len(self._health_history),
            "sessions_with_failures": len(
                [
                    sid
                    for sid, history in self._health_history.items()
                    if any(
                        r.status != ContainerStatus.HEALTHY for r in history[-5:]
                    )  # Last 5 checks
                ]
            ),
            "restart_counts": dict(self._restart_counts),
        }

        if session_id and session_id in self._health_history:
            recent_results = self._health_history[session_id][-10:]  # Last 10 checks
            stats[f"session_{session_id}"] = {
                "total_checks": len(recent_results),
                "healthy_checks": sum(
                    1 for r in recent_results if r.status == ContainerStatus.HEALTHY
                ),
                "failed_checks": sum(
                    1 for r in recent_results if r.status != ContainerStatus.HEALTHY
                ),
                # Mean over checks that actually recorded a latency (the
                # max(1, ...) guard avoids division by zero).
                "average_response_time": sum(
                    r.response_time or 0 for r in recent_results if r.response_time
                )
                / max(1, sum(1 for r in recent_results if r.response_time)),
                "last_check": recent_results[-1].to_dict() if recent_results else None,
            }

        return stats

    def get_health_history(
        self, session_id: str, limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Get up to `limit` most recent health results for a session."""
        if session_id not in self._health_history:
            return []

        return [
            result.to_dict() for result in self._health_history[session_id][-limit:]
        ]
|
||||
|
||||
|
||||
# Process-wide health monitor singleton used by the helpers below.
_container_health_monitor = ContainerHealthMonitor()


def get_container_health_monitor() -> ContainerHealthMonitor:
    """Return the process-wide ContainerHealthMonitor singleton."""
    return _container_health_monitor


async def start_container_health_monitoring(session_manager=None, docker_client=None):
    """Start monitoring, optionally wiring in dependencies first."""
    health_monitor = get_container_health_monitor()
    if session_manager:
        health_monitor.set_dependencies(session_manager, docker_client)
    await health_monitor.start_monitoring()


async def stop_container_health_monitoring():
    """Stop the global health monitoring loop."""
    await get_container_health_monitor().stop_monitoring()


def get_container_health_stats(session_id: Optional[str] = None) -> Dict[str, Any]:
    """Return aggregate (and optionally per-session) health statistics."""
    return get_container_health_monitor().get_health_stats(session_id)


def get_container_health_history(
    session_id: str, limit: int = 50
) -> List[Dict[str, Any]]:
    """Return recent health-check results for one session."""
    return get_container_health_monitor().get_health_history(session_id, limit)
|
||||
406
session-manager/database.py
Normal file
406
session-manager/database.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
Database Models and Connection Management for Session Persistence
|
||||
|
||||
Provides PostgreSQL-backed session storage with connection pooling,
|
||||
migrations, and health monitoring for production reliability.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncpg
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseConnection:
    """PostgreSQL connection management with pooling.

    Wraps an asyncpg connection pool configured from DB_* environment
    variables, with helpers for acquire/release, health checking and
    transactional use.
    """

    def __init__(self):
        self._pool: Optional[asyncpg.Pool] = None
        # Pool configuration from environment; see asyncpg.create_pool().
        # NOTE(review): the DB_PASSWORD default of "password" is unsafe for
        # anything beyond local development — require it to be set explicitly
        # in production deployments.
        self._config = {
            "host": os.getenv("DB_HOST", "localhost"),
            "port": int(os.getenv("DB_PORT", "5432")),
            "user": os.getenv("DB_USER", "lovdata"),
            "password": os.getenv("DB_PASSWORD", "password"),
            "database": os.getenv("DB_NAME", "lovdata_chat"),
            "min_size": int(os.getenv("DB_MIN_CONNECTIONS", "5")),
            "max_size": int(os.getenv("DB_MAX_CONNECTIONS", "20")),
            "max_queries": int(os.getenv("DB_MAX_QUERIES", "50000")),
            "max_inactive_connection_lifetime": float(
                os.getenv("DB_MAX_INACTIVE_LIFETIME", "300.0")
            ),
        }

    async def connect(self) -> None:
        """Establish the connection pool (idempotent); raises on failure."""
        if self._pool:
            return

        try:
            self._pool = await asyncpg.create_pool(**self._config)
            logger.info(
                "Database connection pool established",
                extra={
                    "host": self._config["host"],
                    "port": self._config["port"],
                    "database": self._config["database"],
                    "min_connections": self._config["min_size"],
                    "max_connections": self._config["max_size"],
                },
            )
        except Exception as e:
            logger.error(
                "Failed to establish database connection", extra={"error": str(e)}
            )
            raise

    async def disconnect(self) -> None:
        """Close database connection pool."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database connection pool closed")

    async def get_connection(self) -> asyncpg.Connection:
        """Acquire a connection from the pool (connecting lazily)."""
        if not self._pool:
            await self.connect()
        return await self._pool.acquire()

    async def release_connection(self, conn: asyncpg.Connection) -> None:
        """Release a database connection back to the pool."""
        if self._pool:
            await self._pool.release(conn)

    async def health_check(self) -> Dict[str, Any]:
        """Perform database health check.

        Returns:
            {"status": "healthy", "timestamp": ...} on success, otherwise
            {"status": "unhealthy", "error": ...}.  Never raises.
        """
        try:
            conn = await self.get_connection()
            try:
                result = await conn.fetchval("SELECT 1")
            finally:
                # Bug fix: the connection was previously leaked when
                # fetchval raised — always return it to the pool.
                await self.release_connection(conn)

            if result == 1:
                return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}
            else:
                return {"status": "unhealthy", "error": "Health check query failed"}
        except Exception as e:
            logger.error("Database health check failed", extra={"error": str(e)})
            return {"status": "unhealthy", "error": str(e)}

    @asynccontextmanager
    async def transaction(self):
        """Context manager yielding a pooled connection inside a transaction.

        The transaction commits on normal exit and rolls back on exception;
        the connection is always released back to the pool.
        """
        conn = await self.get_connection()
        try:
            async with conn.transaction():
                yield conn
        finally:
            await self.release_connection(conn)
|
||||
|
||||
|
||||
# Global database connection instance — module-level singleton shared by the
# helper functions in this module; pool creation is deferred until first use.
_db_connection = DatabaseConnection()
|
||||
|
||||
|
||||
@asynccontextmanager
async def get_db_connection():
    """Async context manager yielding a connection from the global pool.

    The connection is guaranteed to be returned to the pool when the block
    exits, whether normally or via an exception.
    """
    connection = await _db_connection.get_connection()
    try:
        yield connection
    finally:
        await _db_connection.release_connection(connection)
|
||||
|
||||
|
||||
async def init_database() -> None:
    """Initialize database and run migrations.

    Creates the ``sessions`` table, its performance indexes, and a server-side
    cleanup function. All statements are idempotent (IF NOT EXISTS /
    CREATE OR REPLACE), so this is safe to run on every startup.
    """
    logger.info("Initializing database")

    async with get_db_connection() as conn:
        # Create sessions table. session_id is a 32-char external identifier;
        # metadata holds arbitrary per-session JSON.
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS sessions (
                session_id VARCHAR(32) PRIMARY KEY,
                container_name VARCHAR(255) NOT NULL,
                container_id VARCHAR(255),
                host_dir VARCHAR(1024) NOT NULL,
                port INTEGER,
                auth_token VARCHAR(255),
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                last_accessed TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                status VARCHAR(50) NOT NULL DEFAULT 'creating',
                metadata JSONB DEFAULT '{}'
            )
        """)

        # Create indexes for performance (status filters, expiry scans,
        # ordering by creation time, container lookups).
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
            CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
            CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
            CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name);
        """)

        # Create cleanup function for expired sessions. Runs entirely in the
        # database and returns the number of rows deleted; the 1-hour idle
        # window is fixed here, not configurable via environment.
        await conn.execute("""
            CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
            RETURNS INTEGER AS $$
            DECLARE
                deleted_count INTEGER;
            BEGIN
                DELETE FROM sessions
                WHERE last_accessed < NOW() - INTERVAL '1 hour';

                GET DIAGNOSTICS deleted_count = ROW_COUNT;
                RETURN deleted_count;
            END;
            $$ LANGUAGE plpgsql;
        """)

    logger.info("Database initialized and migrations applied")
|
||||
|
||||
|
||||
async def shutdown_database() -> None:
    """Shutdown database connections by closing the global pool."""
    await _db_connection.disconnect()
|
||||
|
||||
|
||||
class SessionModel:
    """Data-access layer for the ``sessions`` table.

    All methods are stateless staticmethods that acquire a pooled connection
    via ``get_db_connection``. ``metadata`` is stored as a JSON string and
    decoded to a dict on the way out.
    """

    # Columns that update_session() may set. Column names cannot be bound as
    # query parameters, so this whitelist guards the dynamically-built UPDATE
    # statement against SQL injection through attacker-controlled dict keys.
    _UPDATABLE_COLUMNS = frozenset(
        {
            "container_name",
            "container_id",
            "host_dir",
            "port",
            "auth_token",
            "status",
            "metadata",
            "last_accessed",
        }
    )

    @staticmethod
    def _row_to_session(row) -> Dict[str, Any]:
        """Convert an asyncpg Record to a session dict, decoding metadata JSON."""
        session = dict(row)
        session["metadata"] = json.loads(session["metadata"] or "{}")
        return session

    @staticmethod
    async def create_session(session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a new session in the database.

        Args:
            session_data: Mapping with required keys ``session_id``,
                ``container_name`` and ``host_dir``; optional keys
                ``container_id``, ``port``, ``auth_token``, ``status``
                (defaults to "creating") and ``metadata`` (defaults to {}).

        Returns:
            The inserted row as a dict (metadata is returned as stored, i.e.
            a JSON string, matching the original behavior).

        Raises:
            ValueError: If the INSERT unexpectedly returned no row.
        """
        async with get_db_connection() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO sessions (
                    session_id, container_name, container_id, host_dir, port,
                    auth_token, status, metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_data["session_id"],
                session_data["container_name"],
                session_data.get("container_id"),
                session_data["host_dir"],
                session_data.get("port"),
                session_data.get("auth_token"),
                session_data.get("status", "creating"),
                json.dumps(session_data.get("metadata", {})),
            )

        if row:
            return dict(row)
        raise ValueError("Failed to create session")

    @staticmethod
    async def get_session(session_id: str) -> Optional[Dict[str, Any]]:
        """Get a session by ID, bumping its last_accessed timestamp.

        Returns:
            The session dict (metadata decoded to a dict) or None if the
            session does not exist.
        """
        async with get_db_connection() as conn:
            # Touch last_accessed and fetch in one round trip.
            row = await conn.fetchrow(
                """
                UPDATE sessions
                SET last_accessed = NOW()
                WHERE session_id = $1
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_id,
            )

        if row:
            return SessionModel._row_to_session(row)
        return None

    @staticmethod
    async def update_session(session_id: str, updates: Dict[str, Any]) -> bool:
        """Update session fields via a dynamically-built UPDATE statement.

        ``session_id`` and ``created_at`` are immutable and silently skipped;
        any other key not in ``_UPDATABLE_COLUMNS`` raises ValueError.
        ``last_accessed`` is set server-side with NOW(); ``metadata`` is
        JSON-encoded.

        Returns:
            True if exactly one row was updated, or if there was nothing to do.

        Raises:
            ValueError: If ``updates`` contains an unknown column name.
        """
        if not updates:
            return True

        async with get_db_connection() as conn:
            set_parts: List[str] = []
            values: List[Any] = [session_id]
            param_index = 2  # $1 is reserved for session_id in the WHERE clause

            for key, value in updates.items():
                if key in ("session_id", "created_at"):
                    # Immutable identifier / creation timestamp.
                    continue
                if key not in SessionModel._UPDATABLE_COLUMNS:
                    raise ValueError(f"Cannot update unknown column: {key}")

                if key == "last_accessed":
                    # Server-side timestamp: consumes no bind parameter, so
                    # param_index must NOT advance. (BUGFIX: the previous
                    # version incremented param_index here without appending
                    # a value, desynchronizing $N placeholders from `values`
                    # and making asyncpg reject any mixed update containing
                    # last_accessed.)
                    set_parts.append("last_accessed = NOW()")
                    continue

                if key == "metadata":
                    set_parts.append(f"metadata = ${param_index}")
                    values.append(json.dumps(value))
                else:
                    set_parts.append(f"{key} = ${param_index}")
                    values.append(value)
                param_index += 1

            if not set_parts:
                return True

            query = f"""
                UPDATE sessions
                SET {", ".join(set_parts)}
                WHERE session_id = $1
            """

            result = await conn.execute(query, *values)
            # asyncpg returns a command-status string such as "UPDATE 1".
            return result == "UPDATE 1"

    @staticmethod
    async def delete_session(session_id: str) -> bool:
        """Delete a session. Returns True if exactly one row was removed."""
        async with get_db_connection() as conn:
            result = await conn.execute(
                "DELETE FROM sessions WHERE session_id = $1", session_id
            )
            return result == "DELETE 1"

    @staticmethod
    async def list_sessions(limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """List sessions with pagination, newest first.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.
        """
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit,
                offset,
            )

        return [SessionModel._row_to_session(row) for row in rows]

    @staticmethod
    async def count_sessions() -> int:
        """Count total sessions."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM sessions")

    @staticmethod
    async def cleanup_expired_sessions() -> int:
        """Clean up expired sessions via the server-side function.

        Returns:
            Number of rows deleted (see cleanup_expired_sessions() in the DB).
        """
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT cleanup_expired_sessions()")

    @staticmethod
    async def get_active_sessions_count() -> int:
        """Get count of active (status = 'running') sessions."""
        async with get_db_connection() as conn:
            return await conn.fetchval("""
                SELECT COUNT(*) FROM sessions
                WHERE status = 'running'
            """)

    @staticmethod
    async def get_sessions_by_status(status: str) -> List[Dict[str, Any]]:
        """Get all sessions with the given status, newest first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                WHERE status = $1
                ORDER BY created_at DESC
                """,
                status,
            )

        return [SessionModel._row_to_session(row) for row in rows]
|
||||
|
||||
|
||||
# Migration utilities
|
||||
async def run_migrations():
    """Run database migrations.

    Each migration is idempotent (IF NOT EXISTS) and individually wrapped in
    try/except so a migration that was already applied — or fails on an older
    PostgreSQL — only logs a warning instead of aborting startup.
    """
    logger.info("Running database migrations")

    async with get_db_connection() as conn:
        # Migration 1: Add metadata column if it doesn't exist
        try:
            await conn.execute("""
                ALTER TABLE sessions ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}'
            """)
            logger.info("Migration: Added metadata column")
        except Exception as e:
            logger.warning(f"Migration metadata column may already exist: {e}")

        # Migration 2: Add indexes if they don't exist
        try:
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
                CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
                CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
            """)
            logger.info("Migration: Added performance indexes")
        except Exception as e:
            logger.warning(f"Migration indexes may already exist: {e}")

    logger.info("Database migrations completed")
|
||||
|
||||
|
||||
# Health monitoring
|
||||
async def get_database_stats() -> Dict[str, Any]:
    """Get database statistics and health information.

    Returns:
        On success, a dict with session counts, oldest/newest session
        timestamps (ISO-8601 strings or None), the pretty-printed database
        size, and "status": "healthy". On any failure, returns
        {"status": "unhealthy", "error": <message>} instead of raising.
    """
    try:
        async with get_db_connection() as conn:
            total_count = await conn.fetchval("SELECT COUNT(*) FROM sessions")
            running_count = await conn.fetchval(
                "SELECT COUNT(*) FROM sessions WHERE status = 'running'"
            )
            first_created = await conn.fetchval("SELECT MIN(created_at) FROM sessions")
            last_created = await conn.fetchval("SELECT MAX(created_at) FROM sessions")

            # Pretty-printed size of the current database (e.g. "42 MB").
            size_pretty = await conn.fetchval("""
                SELECT pg_size_pretty(pg_database_size(current_database()))
            """)

            def _iso(ts):
                # MIN/MAX return NULL (None) when the table is empty.
                return ts.isoformat() if ts else None

            return {
                "total_sessions": total_count,
                "active_sessions": running_count,
                "oldest_session": _iso(first_created),
                "newest_session": _iso(last_created),
                "database_size": size_pretty,
                "status": "healthy",
            }
    except Exception as e:
        logger.error("Failed to get database stats", extra={"error": str(e)})
        return {
            "status": "unhealthy",
            "error": str(e),
        }
|
||||
635
session-manager/docker_service.py
Normal file
635
session-manager/docker_service.py
Normal file
@@ -0,0 +1,635 @@
|
||||
"""
|
||||
Docker Service Layer
|
||||
|
||||
Provides a clean abstraction for Docker operations, separating container management
|
||||
from business logic. Enables easy testing, mocking, and future Docker client changes.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ContainerInfo:
    """Lightweight record describing a Docker container.

    NOTE(review): ``created_at`` defaults to naive ``datetime.utcnow()``,
    which is deprecated in Python 3.12 — consider timezone-aware timestamps.
    """

    def __init__(
        self,
        container_id: str,
        name: str,
        image: str,
        status: str,
        ports: Optional[Dict[str, int]] = None,
        created_at: Optional[datetime] = None,
        health_status: Optional[str] = None,
    ):
        self.container_id = container_id
        self.name = name
        self.image = image
        self.status = status
        self.ports = ports or {}
        self.created_at = created_at or datetime.utcnow()
        self.health_status = health_status

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; created_at becomes an ISO-8601 string."""
        created = self.created_at.isoformat() if self.created_at else None
        return {
            "container_id": self.container_id,
            "name": self.name,
            "image": self.image,
            "status": self.status,
            "ports": self.ports,
            "created_at": created,
            "health_status": self.health_status,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ContainerInfo":
        """Rebuild a ContainerInfo from a dict produced by ``to_dict``."""
        raw_created = data.get("created_at")
        created = datetime.fromisoformat(raw_created) if raw_created else None
        return cls(
            container_id=data["container_id"],
            name=data["name"],
            image=data["image"],
            status=data["status"],
            ports=data.get("ports", {}),
            created_at=created,
            health_status=data.get("health_status"),
        )
|
||||
|
||||
|
||||
class DockerOperationError(Exception):
    """Raised when a Docker operation fails.

    Attributes:
        operation: Name of the failed operation (e.g. "create_container").
        container_id: ID or name of the container involved, if any.
        message: Human-readable failure detail.
    """

    def __init__(
        self, operation: str, container_id: Optional[str] = None, message: str = ""
    ):
        detail = f"Docker {operation} failed: {message}"
        super().__init__(detail)
        self.operation = operation
        self.container_id = container_id
        self.message = message
|
||||
|
||||
|
||||
class DockerService:
    """
    Docker service abstraction layer.

    Provides a clean interface for container operations,
    enabling easy testing and future Docker client changes.

    Two backends are supported, selected at construction time:
      - async: AsyncDockerClient (imported lazily from async_docker_client)
      - sync: docker-py client configured with TLS from environment variables
    All public methods lazily call initialize() on first use.
    """

    def __init__(self, use_async: bool = True):
        """
        Initialize Docker service.

        Args:
            use_async: Whether to use async Docker operations
        """
        self.use_async = use_async
        # Backend client; created lazily in initialize().
        self._docker_client = None
        self._initialized = False

        logger.info("Docker service initialized", extra={"async_mode": use_async})

    async def initialize(self) -> None:
        """Initialize the Docker client connection (idempotent).

        Raises:
            DockerOperationError: If the backend client cannot be created
                or the daemon is unreachable.
        """
        if self._initialized:
            return

        try:
            if self.use_async:
                # Initialize async Docker client (lazy import avoids the
                # dependency when running in sync mode).
                from async_docker_client import AsyncDockerClient

                self._docker_client = AsyncDockerClient()
                await self._docker_client.connect()
            else:
                # Initialize sync Docker client
                import docker

                # TLS material paths come from the environment with defaults
                # matching the deployment's cert layout.
                tls_config = docker.tls.TLSConfig(
                    ca_cert=os.getenv("DOCKER_CA_CERT", "/etc/docker/certs/ca.pem"),
                    client_cert=(
                        os.getenv(
                            "DOCKER_CLIENT_CERT", "/etc/docker/certs/client-cert.pem"
                        ),
                        os.getenv(
                            "DOCKER_CLIENT_KEY", "/etc/docker/certs/client-key.pem"
                        ),
                    ),
                    verify=True,
                )
                docker_host = os.getenv(
                    "DOCKER_HOST", "tcp://host.docker.internal:2376"
                )
                # NOTE(review): from_env() builds a client from the local
                # environment, then its low-level API is replaced with a
                # TLS-configured APIClient for docker_host — confirm this
                # mixed setup targets the intended daemon.
                self._docker_client = docker.from_env()
                self._docker_client.api = docker.APIClient(
                    base_url=docker_host, tls=tls_config, version="auto"
                )
                # Test connection
                self._docker_client.ping()

            self._initialized = True
            logger.info("Docker service connection established")

        except Exception as e:
            logger.error("Failed to initialize Docker service", extra={"error": str(e)})
            raise DockerOperationError("initialization", message=str(e))

    async def shutdown(self) -> None:
        """Shutdown the Docker client connection.

        Best-effort: errors are logged as warnings, never raised.
        """
        if not self._initialized:
            return

        try:
            if self.use_async and self._docker_client:
                await self._docker_client.disconnect()
            # Sync client doesn't need explicit shutdown

            self._initialized = False
            logger.info("Docker service connection closed")

        except Exception as e:
            logger.warning(
                "Error during Docker service shutdown", extra={"error": str(e)}
            )

    async def ping(self) -> bool:
        """Test Docker daemon connectivity.

        Returns:
            True if the daemon responds; False on any error (logged).
        """
        if not self._initialized:
            await self.initialize()

        try:
            if self.use_async:
                return await self._docker_client.ping()
            else:
                # Sync ping() raises on failure rather than returning False.
                self._docker_client.ping()
                return True
        except Exception as e:
            logger.warning("Docker ping failed", extra={"error": str(e)})
            return False

    async def create_container(
        self,
        name: str,
        image: str,
        volumes: Optional[Dict[str, Dict[str, str]]] = None,
        ports: Optional[Dict[str, int]] = None,
        environment: Optional[Dict[str, str]] = None,
        network_mode: str = "bridge",
        mem_limit: Optional[str] = None,
        cpu_quota: Optional[int] = None,
        cpu_period: Optional[int] = None,
        tmpfs: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> ContainerInfo:
        """
        Create a Docker container.

        Args:
            name: Container name
            image: Container image
            volumes: Volume mounts
            ports: Port mappings
            environment: Environment variables
            network_mode: Network mode
            mem_limit: Memory limit
            cpu_quota: CPU quota
            cpu_period: CPU period
            tmpfs: tmpfs mounts
            **kwargs: Additional options

        Returns:
            ContainerInfo: Information about created container

        Raises:
            DockerOperationError: If container creation fails

        Note: the async backend only creates the container (status
        "created"); the sync backend uses containers.run(detach=True), which
        also starts it (status "running").
        """
        if not self._initialized:
            await self.initialize()

        try:
            logger.info(
                "Creating container",
                extra={
                    "container_name": name,
                    "image": image,
                    "memory_limit": mem_limit,
                    "cpu_quota": cpu_quota,
                },
            )

            if self.use_async:
                container = await self._docker_client.create_container(
                    image=image,
                    name=name,
                    volumes=volumes,
                    ports=ports,
                    environment=environment,
                    network_mode=network_mode,
                    mem_limit=mem_limit,
                    cpu_quota=cpu_quota,
                    cpu_period=cpu_period,
                    tmpfs=tmpfs,
                    **kwargs,
                )

                return ContainerInfo(
                    container_id=container.id,
                    name=name,
                    image=image,
                    status="created",
                    ports=ports,
                )
            else:
                # NOTE(review): this maps each *value* of `ports` to itself
                # as both container and host port, ignoring the dict keys —
                # confirm the intended ports schema for the sync path.
                container = self._docker_client.containers.run(
                    image,
                    name=name,
                    volumes=volumes,
                    ports={f"{port}/tcp": port for port in ports.values()}
                    if ports
                    else None,
                    environment=environment,
                    network_mode=network_mode,
                    mem_limit=mem_limit,
                    cpu_quota=cpu_quota,
                    cpu_period=cpu_period,
                    tmpfs=tmpfs,
                    detach=True,
                    **kwargs,
                )

                return ContainerInfo(
                    container_id=container.id,
                    name=name,
                    image=image,
                    status="running",
                    ports=ports,
                )

        except Exception as e:
            logger.error(
                "Container creation failed",
                extra={"container_name": name, "image": image, "error": str(e)},
            )
            raise DockerOperationError("create_container", name, str(e))

    async def start_container(self, container_id: str) -> None:
        """
        Start a Docker container.

        Args:
            container_id: Container ID

        Raises:
            DockerOperationError: If container start fails
        """
        if not self._initialized:
            await self.initialize()

        try:
            logger.info("Starting container", extra={"container_id": container_id})

            if self.use_async:
                container = await self._docker_client.get_container(container_id)
                await self._docker_client.start_container(container)
            else:
                container = self._docker_client.containers.get(container_id)
                container.start()

            logger.info(
                "Container started successfully", extra={"container_id": container_id}
            )

        except Exception as e:
            logger.error(
                "Container start failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            raise DockerOperationError("start_container", container_id, str(e))

    async def stop_container(self, container_id: str, timeout: int = 10) -> None:
        """
        Stop a Docker container.

        Args:
            container_id: Container ID
            timeout: Stop timeout in seconds (grace period before kill)

        Raises:
            DockerOperationError: If container stop fails
        """
        if not self._initialized:
            await self.initialize()

        try:
            logger.info(
                "Stopping container",
                extra={"container_id": container_id, "timeout": timeout},
            )

            if self.use_async:
                container = await self._docker_client.get_container(container_id)
                await self._docker_client.stop_container(container, timeout)
            else:
                container = self._docker_client.containers.get(container_id)
                container.stop(timeout=timeout)

            logger.info(
                "Container stopped successfully", extra={"container_id": container_id}
            )

        except Exception as e:
            logger.error(
                "Container stop failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            raise DockerOperationError("stop_container", container_id, str(e))

    async def remove_container(self, container_id: str, force: bool = False) -> None:
        """
        Remove a Docker container.

        Args:
            container_id: Container ID
            force: Force removal if running

        Raises:
            DockerOperationError: If container removal fails
        """
        if not self._initialized:
            await self.initialize()

        try:
            logger.info(
                "Removing container",
                extra={"container_id": container_id, "force": force},
            )

            if self.use_async:
                container = await self._docker_client.get_container(container_id)
                await self._docker_client.remove_container(container, force)
            else:
                container = self._docker_client.containers.get(container_id)
                container.remove(force=force)

            logger.info(
                "Container removed successfully", extra={"container_id": container_id}
            )

        except Exception as e:
            logger.error(
                "Container removal failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            raise DockerOperationError("remove_container", container_id, str(e))

    async def get_container_info(self, container_id: str) -> Optional[ContainerInfo]:
        """
        Get information about a container.

        Args:
            container_id: Container ID

        Returns:
            ContainerInfo or None: Container information; None when the
            container is missing or inspection fails (errors are only
            logged at debug level).
        """
        if not self._initialized:
            await self.initialize()

        try:
            if self.use_async:
                # NOTE(review): reaches into a private helper of the async
                # wrapper; consider exposing a public inspect method.
                container_info = await self._docker_client._get_container_info(
                    container_id
                )
                if container_info:
                    state = container_info.get("State", {})
                    config = container_info.get("Config", {})
                    # NOTE(review): docker inspect usually reports "Name" at
                    # the top level, not under "Config" — confirm this key
                    # against the wrapper's payload.
                    return ContainerInfo(
                        container_id=container_id,
                        name=config.get("Name", "").lstrip("/"),
                        image=config.get("Image", ""),
                        status=state.get("Status", "unknown"),
                        health_status=state.get("Health", {}).get("Status"),
                    )
            else:
                container = self._docker_client.containers.get(container_id)
                return ContainerInfo(
                    container_id=container.id,
                    name=container.name,
                    image=container.image.tags[0]
                    if container.image.tags
                    else container.image.id,
                    status=container.status,
                )

            return None

        except Exception as e:
            logger.debug(
                "Container info retrieval failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            return None

    async def list_containers(
        self, all: bool = False, filters: Optional[Dict[str, Any]] = None
    ) -> List[ContainerInfo]:
        """
        List Docker containers.

        Args:
            all: Include stopped containers
            filters: Container filters

        Returns:
            List[ContainerInfo]: List of container information; empty list
            on error (logged). Note the async path issues one inspect call
            per container.
        """
        if not self._initialized:
            await self.initialize()

        try:
            if self.use_async:
                containers = await self._docker_client.list_containers(
                    all=all, filters=filters
                )
                result = []
                for container in containers:
                    container_info = await self._docker_client._get_container_info(
                        container.id
                    )
                    if container_info:
                        state = container_info.get("State", {})
                        config = container_info.get("Config", {})
                        result.append(
                            ContainerInfo(
                                container_id=container.id,
                                name=config.get("Name", "").lstrip("/"),
                                image=config.get("Image", ""),
                                status=state.get("Status", "unknown"),
                                health_status=state.get("Health", {}).get("Status"),
                            )
                        )
                return result
            else:
                containers = self._docker_client.containers.list(
                    all=all, filters=filters
                )
                result = []
                for container in containers:
                    result.append(
                        ContainerInfo(
                            container_id=container.id,
                            name=container.name,
                            image=container.image.tags[0]
                            if container.image.tags
                            else container.image.id,
                            status=container.status,
                        )
                    )
                return result

        except Exception as e:
            logger.error("Container listing failed", extra={"error": str(e)})
            return []

    async def get_container_logs(self, container_id: str, tail: int = 100) -> str:
        """
        Get container logs.

        Args:
            container_id: Container ID
            tail: Number of log lines to retrieve

        Returns:
            str: Container logs (stdout + stderr); empty string on error.
        """
        if not self._initialized:
            await self.initialize()

        try:
            if self.use_async:
                container = await self._docker_client.get_container(container_id)
                # aiodocker-style log() returns a list of lines.
                logs = await container.log(stdout=True, stderr=True, tail=tail)
                return "\n".join(logs)
            else:
                container = self._docker_client.containers.get(container_id)
                logs = container.logs(tail=tail).decode("utf-8")
                return logs

        except Exception as e:
            logger.warning(
                "Container log retrieval failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            return ""

    async def get_system_info(self) -> Optional[Dict[str, Any]]:
        """
        Get Docker system information.

        Returns:
            Dict or None: System information; None on error (logged).
        """
        if not self._initialized:
            await self.initialize()

        try:
            if self.use_async:
                return await self._docker_client.get_system_info()
            else:
                return self._docker_client.info()
        except Exception as e:
            logger.warning("System info retrieval failed", extra={"error": str(e)})
            return None

    # Context manager support
    async def __aenter__(self):
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.shutdown()
|
||||
|
||||
|
||||
class MockDockerService(DockerService):
    """
    Mock Docker service for testing without actual Docker.

    Provides the same interface but with in-memory operations: containers
    live in a dict keyed by generated IDs of the form ``mock-<n>``.
    """

    def __init__(self):
        super().__init__(use_async=False)
        self._containers: Dict[str, ContainerInfo] = {}
        self._next_id = 1

    async def initialize(self) -> None:
        """Mock initialization - always succeeds."""
        self._initialized = True
        logger.info("Mock Docker service initialized")

    async def shutdown(self) -> None:
        """Mock shutdown: drop all in-memory containers."""
        self._containers.clear()
        self._initialized = False
        logger.info("Mock Docker service shutdown")

    async def ping(self) -> bool:
        """Mock ping - always succeeds."""
        return True

    async def create_container(self, name: str, image: str, **kwargs) -> ContainerInfo:
        """Mock container creation: register an in-memory record."""
        new_id = f"mock-{self._next_id}"
        self._next_id += 1

        info = ContainerInfo(
            container_id=new_id, name=name, image=image, status="created"
        )
        self._containers[new_id] = info

        logger.info(
            "Mock container created",
            extra={
                "container_id": new_id,
                "container_name": name,
                "image": image,
            },
        )
        return info

    async def start_container(self, container_id: str) -> None:
        """Mock container start: mark a known container as running."""
        entry = self._containers.get(container_id)
        if entry is not None:
            entry.status = "running"
            logger.info("Mock container started", extra={"container_id": container_id})

    async def stop_container(self, container_id: str, timeout: int = 10) -> None:
        """Mock container stop: mark a known container as exited."""
        entry = self._containers.get(container_id)
        if entry is not None:
            entry.status = "exited"
            logger.info("Mock container stopped", extra={"container_id": container_id})

    async def remove_container(self, container_id: str, force: bool = False) -> None:
        """Mock container removal: drop the in-memory record."""
        if container_id in self._containers:
            del self._containers[container_id]
            logger.info("Mock container removed", extra={"container_id": container_id})

    async def get_container_info(self, container_id: str) -> Optional[ContainerInfo]:
        """Mock container info retrieval (None if unknown)."""
        return self._containers.get(container_id)

    async def list_containers(
        self, all: bool = False, filters: Optional[Dict[str, Any]] = None
    ) -> List[ContainerInfo]:
        """Mock container listing; `all` and `filters` are ignored."""
        return list(self._containers.values())
|
||||
252
session-manager/host_ip_detector.py
Normal file
252
session-manager/host_ip_detector.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
Host IP Detection Utilities
|
||||
|
||||
Provides robust methods to detect the Docker host IP from within a container,
|
||||
supporting multiple Docker environments and network configurations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import socket
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from functools import lru_cache
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HostIPDetector:
    """Detects the Docker host IP address from container perspective.

    Detection tries several strategies in order (Docker Desktop's DNS name,
    environment variables, the default route, the DOCKER_HOST address, and
    finally well-known gateway IPs) and caches a successful result for
    ``_cache_timeout`` seconds.
    """

    # Common Docker gateway IPs to try as fallbacks
    COMMON_GATEWAYS = [
        "172.17.0.1",  # Default Docker bridge
        "172.18.0.1",  # Docker networks
        "192.168.65.1",  # Docker Desktop
        "192.168.66.1",  # Alternative Docker Desktop
    ]

    def __init__(self):
        self._detected_ip: Optional[str] = None  # last successful detection
        self._last_detection: float = 0  # wall-clock time of that detection
        self._cache_timeout: float = 300  # 5 minutes cache

    def detect_host_ip(self) -> str:
        """
        Detect the Docker host IP using multiple methods with fallbacks.

        Returns:
            str: The detected host IP address

        Raises:
            RuntimeError: If no host IP can be detected
        """
        # NOTE: an @lru_cache(maxsize=1) decorator used to sit on this method.
        # functools caches keyed on `self`, which pins instances in memory
        # (ruff B019) and made the TTL cache below dead code: the lru entry
        # never expired, so a stale IP could never be re-detected. The
        # hand-rolled time-based cache is the intended behavior, so the
        # decorator was removed.
        current_time = time.time()

        # Use cached result if recent
        if (
            self._detected_ip
            and (current_time - self._last_detection) < self._cache_timeout
        ):
            logger.debug(f"Using cached host IP: {self._detected_ip}")
            return self._detected_ip

        logger.info("Detecting Docker host IP...")

        detection_methods = [
            self._detect_via_docker_internal,
            self._detect_via_gateway_env,
            self._detect_via_route_table,
            self._detect_via_network_connect,
            self._detect_via_common_gateways,
        ]

        for method in detection_methods:
            try:
                ip = method()
                if ip and self._validate_ip(ip):
                    logger.info(
                        f"Successfully detected host IP using {method.__name__}: {ip}"
                    )
                    self._detected_ip = ip
                    self._last_detection = current_time
                    return ip
                else:
                    logger.debug(f"Method {method.__name__} returned invalid IP: {ip}")
            except Exception as e:
                logger.debug(f"Method {method.__name__} failed: {e}")

        # If all methods fail, raise an error
        raise RuntimeError(
            "Could not detect Docker host IP. Tried all detection methods. "
            "Please check your Docker network configuration or set HOST_IP environment variable."
        )

    def _detect_via_docker_internal(self) -> Optional[str]:
        """Detect via host.docker.internal (Docker Desktop, Docker for Mac/Windows)."""
        try:
            ip = socket.gethostbyname("host.docker.internal")
            if ip != "127.0.0.1":  # Make sure it's not localhost
                return ip
        except socket.gaierror:
            pass
        return None

    def _detect_via_gateway_env(self) -> Optional[str]:
        """Detect via Docker gateway environment variables."""
        gateway_vars = [
            "DOCKER_HOST_GATEWAY",
            "GATEWAY",
            "HOST_IP",
        ]
        for var in gateway_vars:
            ip = os.getenv(var)
            if ip:
                logger.debug(f"Found host IP in environment variable {var}: {ip}")
                return ip
        return None

    def _detect_via_route_table(self) -> Optional[str]:
        """Detect via Linux route table (/proc/net/route)."""
        try:
            with open("/proc/net/route", "r") as f:
                for line in f:
                    fields = line.strip().split()
                    # Default route rows carry an all-zero mask (column 7).
                    if (
                        len(fields) >= 8
                        and fields[0] != "Iface"
                        and fields[7] == "00000000"
                    ):
                        # Gateway column is hex-encoded little-endian; decode
                        # byte pairs and reverse into dotted-quad order.
                        gateway_hex = fields[2]
                        if len(gateway_hex) == 8:
                            ip_parts = []
                            for i in range(0, 8, 2):
                                ip_parts.append(str(int(gateway_hex[i : i + 2], 16)))
                            ip = ".".join(reversed(ip_parts))
                            if ip != "0.0.0.0":
                                return ip
        except (IOError, ValueError, IndexError) as e:
            logger.debug(f"Failed to read route table: {e}")
        return None

    def _detect_via_network_connect(self) -> Optional[str]:
        """Detect by resolving the host named in DOCKER_HOST."""
        try:
            docker_host = os.getenv("DOCKER_HOST", "tcp://host.docker.internal:2376")
            if docker_host.startswith("tcp://"):
                host_part = docker_host[6:].split(":")[0]
                if host_part not in ["localhost", "127.0.0.1"]:
                    try:
                        ip = socket.gethostbyname(host_part)
                        if ip != "127.0.0.1":
                            return ip
                    except socket.gaierror:
                        pass
        except Exception as e:
            logger.debug(f"Network connect detection failed: {e}")
        return None

    def _detect_via_common_gateways(self) -> Optional[str]:
        """Try common Docker gateway IPs."""
        for gateway in self.COMMON_GATEWAYS:
            if self._test_ip_connectivity(gateway):
                logger.debug(f"Found working gateway: {gateway}")
                return gateway
        return None

    def _test_ip_connectivity(self, ip: str) -> bool:
        """Return True if *ip* answers on TCP port 22 within one second."""
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(1.0)
            result = sock.connect_ex((ip, 22))  # SSH port, commonly available
            sock.close()
            return result == 0
        except Exception:
            return False

    def _validate_ip(self, ip: str) -> bool:
        """Return True if *ip* parses and looks like a plausible gateway address."""
        try:
            socket.inet_aton(ip)
            # Reject localhost and the unspecified address.
            if ip.startswith("127."):
                return False
            if ip == "0.0.0.0":
                return False
            parts = ip.split(".")
            if len(parts) != 4:
                return False
            first_octet = int(parts[0])
            # Common Docker gateway ranges (10.x, 172.x, 192.x)
            return first_octet in [10, 172, 192]
        except socket.error:
            return False

    async def async_detect_host_ip(self) -> str:
        """Async wrapper: run the blocking detection in a worker thread."""
        # asyncio.to_thread replaces the old get_event_loop() + throwaway
        # ThreadPoolExecutor pattern (get_event_loop is deprecated in
        # coroutines since 3.10).
        return await asyncio.to_thread(self.detect_host_ip)
|
||||
|
||||
|
||||
# Shared detector instance so results are cached across calls.
_host_detector = HostIPDetector()


def get_host_ip() -> str:
    """
    Get the Docker host IP address from container perspective.

    The shared HostIPDetector caches the result for performance and tries
    multiple detection methods with fallbacks for different Docker
    environments.

    Returns:
        str: The detected host IP address

    Raises:
        RuntimeError: If host IP detection fails
    """
    return _host_detector.detect_host_ip()
|
||||
|
||||
|
||||
async def async_get_host_ip() -> str:
    """
    Async version of get_host_ip for use in async contexts.

    Detection itself is synchronous (blocking socket/file probes), so it is
    offloaded to a worker thread. asyncio.to_thread replaces the previous
    per-call ThreadPoolExecutor plus the deprecated get_event_loop() usage.
    """
    return await asyncio.to_thread(get_host_ip)
|
||||
|
||||
|
||||
def reset_host_ip_cache():
    """Reset the cached host IP detection result.

    Swaps in a fresh HostIPDetector so the next get_host_ip() call re-probes
    instead of returning the cached address.
    """
    global _host_detector
    _host_detector = HostIPDetector()
|
||||
182
session-manager/http_pool.py
Normal file
182
session-manager/http_pool.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
HTTP Connection Pool Manager
|
||||
|
||||
Provides a global httpx.AsyncClient instance with connection pooling
|
||||
to eliminate the overhead of creating new HTTP clients for each proxy request.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HTTPConnectionPool:
    """Global HTTP connection pool manager for proxy operations.

    A single httpx.AsyncClient is created lazily and shared across requests.
    Connection-level failures mark it unhealthy so the periodic health check
    can recreate it.
    """

    def __init__(self):
        self._client: Optional[httpx.AsyncClient] = None
        self._last_health_check: float = 0
        self._health_check_interval: float = 60  # seconds between health probes
        self._is_healthy: bool = True
        self._reconnect_lock = asyncio.Lock()

        # Connection pool configuration (kwargs for httpx.AsyncClient)
        self._config = {
            "limits": httpx.Limits(
                max_keepalive_connections=20,  # Keep connections alive
                max_connections=100,  # Max total connections
                keepalive_expiry=300.0,  # Keep connections alive for 5 minutes
            ),
            "timeout": httpx.Timeout(
                connect=10.0,  # Connection timeout
                read=30.0,  # Read timeout
                write=10.0,  # Write timeout
                pool=5.0,  # Pool timeout
            ),
            "follow_redirects": False,
            "http2": False,  # Disable HTTP/2 for simplicity
        }

    async def __aenter__(self):
        await self.ensure_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # The client is intentionally long-lived; nothing to release here.
        pass

    async def ensure_client(self) -> None:
        """Ensure the HTTP client is initialized and healthy."""
        if self._client is None:
            # force=False: if another coroutine created the client while we
            # waited for the lock, reuse it. The previous code unconditionally
            # replaced (and closed) it, which could kill a client that a
            # concurrent request was actively using.
            await self._create_client(force=False)

        # Periodic health check
        current_time = time.time()
        if current_time - self._last_health_check > self._health_check_interval:
            if not await self._check_client_health():
                logger.warning("HTTP client health check failed, recreating client")
                await self._recreate_client()
            self._last_health_check = current_time

    async def _create_client(self, force: bool = True) -> None:
        """Create the pooled client under the reconnect lock.

        With force=False an already-existing client is kept (double-checked
        creation); with force=True any existing client is closed and replaced.
        """
        async with self._reconnect_lock:
            if self._client is not None:
                if not force:
                    return  # lost the creation race; a usable client exists
                await self._client.aclose()

            self._client = httpx.AsyncClient(**self._config)
            self._is_healthy = True
            logger.info("HTTP connection pool client created")

    async def _recreate_client(self) -> None:
        """Recreate the HTTP client (used when health check fails)."""
        logger.info("Recreating HTTP connection pool client")
        await self._create_client(force=True)

    async def _check_client_health(self) -> bool:
        """Return whether the current client is believed healthy.

        NOTE(review): this only reflects the _is_healthy flag flipped by
        request failures; it does not actively probe an endpoint.
        """
        if not self._client:
            return False
        return self._is_healthy

    async def request(self, method: str, url: str, **kwargs) -> httpx.Response:
        """Make an HTTP request using the connection pool."""
        await self.ensure_client()

        if not self._client:
            raise RuntimeError("HTTP client not available")

        try:
            return await self._client.request(method, url, **kwargs)
        except (httpx.ConnectError, httpx.ConnectTimeout, httpx.PoolTimeout) as e:
            # Connection-level failures suggest the pooled client went bad;
            # the next health check will recreate it. Other exceptions
            # propagate unchanged (the old `except Exception: raise` was a
            # no-op and has been removed).
            logger.warning(f"Connection error, marking client as unhealthy: {e}")
            self._is_healthy = False
            raise

    async def close(self) -> None:
        """Close the HTTP client and cleanup resources."""
        async with self._reconnect_lock:
            if self._client:
                await self._client.aclose()
                self._client = None
                self._is_healthy = False
                logger.info("HTTP connection pool client closed")

    async def get_pool_stats(self) -> Dict[str, Any]:
        """Return basic pool status/config (httpx exposes no detailed stats)."""
        if not self._client:
            return {"status": "not_initialized"}

        return {
            "status": "healthy" if self._is_healthy else "unhealthy",
            "last_health_check": self._last_health_check,
            "config": {
                "max_keepalive_connections": self._config[
                    "limits"
                ].max_keepalive_connections,
                "max_connections": self._config["limits"].max_connections,
                "keepalive_expiry": self._config["limits"].keepalive_expiry,
                "connect_timeout": self._config["timeout"].connect,
                "read_timeout": self._config["timeout"].read,
            },
        }
|
||||
|
||||
|
||||
# Global HTTP connection pool instance
_http_pool = HTTPConnectionPool()


@asynccontextmanager
async def get_http_client():
    """Async context manager yielding the shared connection pool.

    The pool's __aexit__ deliberately keeps the underlying client open, so
    entering and exiting here is cheap.
    """
    async with _http_pool as pool:
        yield pool
|
||||
|
||||
|
||||
async def make_http_request(method: str, url: str, **kwargs) -> httpx.Response:
    """Make an HTTP request using the global connection pool."""
    async with get_http_client() as pool:
        response = await pool.request(method, url, **kwargs)
    return response
|
||||
|
||||
|
||||
async def get_connection_pool_stats() -> Dict[str, Any]:
    """Return statistics for the global HTTP connection pool."""
    return await _http_pool.get_pool_stats()
|
||||
|
||||
|
||||
async def close_connection_pool() -> None:
    """Close the global connection pool (for cleanup on teardown)."""
    await _http_pool.close()
|
||||
|
||||
|
||||
# Lifecycle management for FastAPI
async def init_http_pool() -> None:
    """Eagerly create the pooled client at application startup."""
    logger.info("Initializing HTTP connection pool")
    await _http_pool.ensure_client()
|
||||
|
||||
|
||||
async def shutdown_http_pool() -> None:
    """Release the pooled client at application shutdown."""
    logger.info("Shutting down HTTP connection pool")
    await _http_pool.close()
|
||||
317
session-manager/logging_config.py
Normal file
317
session-manager/logging_config.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Structured Logging Configuration
|
||||
|
||||
Provides comprehensive logging infrastructure with structured logging,
|
||||
request tracking, log formatting, and aggregation capabilities.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import logging.handlers
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
|
||||
class StructuredFormatter(logging.Formatter):
    """Structured JSON formatter for production logging.

    Emits one JSON object per record: standard fields (timestamp, level,
    logger, message, location), the formatted exception when present, and
    every caller-supplied `extra` attribute (request_id, session_id,
    operation, duration_ms, ...).
    """

    # LogRecord attributes that belong to the logging machinery and must not
    # be copied into the structured payload. "taskName" (added by asyncio on
    # Python 3.12+) is machinery too and is excluded to keep output stable.
    _RESERVED = {
        "name",
        "msg",
        "args",
        "levelname",
        "levelno",
        "pathname",
        "filename",
        "module",
        "exc_info",
        "exc_text",
        "stack_info",
        "lineno",
        "funcName",
        "created",
        "msecs",
        "relativeCreated",
        "thread",
        "threadName",
        "processName",
        "process",
        "message",
        "taskName",
    }

    def format(self, record: logging.LogRecord) -> str:
        """Render *record* as a single-line JSON object."""
        from datetime import timezone  # local: module top imports only datetime

        log_entry = {
            # UTC ISO-8601 timestamp with a trailing "Z". Built from an aware
            # datetime because datetime.utcnow() is deprecated since 3.12.
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Add exception info if present
        if record.exc_info:
            log_entry["exception"] = self.formatException(record.exc_info)

        # Copy all non-machinery attributes. This subsumes the old explicit
        # hasattr checks for request_id/session_id/etc., which duplicated
        # what this loop already did.
        for key, value in record.__dict__.items():
            if key not in self._RESERVED:
                log_entry[key] = value

        return json.dumps(log_entry, default=str)
|
||||
|
||||
|
||||
class HumanReadableFormatter(logging.Formatter):
    """Human-readable formatter for development.

    Records carrying a ``request_id`` attribute are rendered with the id in
    square brackets so related lines can be grepped together.
    """

    _BASE_FMT = "%(asctime)s [%(levelname)8s] %(name)s:%(funcName)s:%(lineno)d - %(message)s"
    _REQUEST_FMT = "%(asctime)s [%(levelname)8s] %(name)s:%(funcName)s:%(lineno)d [%(request_id)s] - %(message)s"
    _DATE_FMT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        super().__init__(fmt=self._BASE_FMT, datefmt=self._DATE_FMT)
        # Dedicated formatter for records that carry a request id. The old
        # implementation reassigned self._fmt per record, which has NO effect:
        # logging.Formatter renders via self._style, not self._fmt, so the
        # request id was silently dropped (and the mutation was racy anyway).
        self._request_formatter = logging.Formatter(
            fmt=self._REQUEST_FMT, datefmt=self._DATE_FMT
        )

    def format(self, record: logging.LogRecord) -> str:
        if hasattr(record, "request_id"):
            return self._request_formatter.format(record)
        return super().format(record)
|
||||
|
||||
|
||||
class RequestContext:
    """Context manager that scopes a request id to the current thread.

    While active, ``get_current_request_id()`` returns this context's id;
    contexts nest, and the previous id (if any) is restored on exit.
    """

    _local = threading.local()

    def __init__(self, request_id: Optional[str] = None):
        # Auto-generate a short 8-char id when the caller supplies none.
        self.request_id = request_id if request_id else str(uuid.uuid4())[:8]
        self._old_request_id = getattr(self._local, "request_id", None)

    def __enter__(self):
        RequestContext._local.request_id = self.request_id
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        previous = self._old_request_id
        if previous is None:
            # No enclosing context: remove the attribute entirely.
            delattr(RequestContext._local, "request_id")
        else:
            RequestContext._local.request_id = previous

    @classmethod
    def get_current_request_id(cls) -> Optional[str]:
        """Return the request id bound to this thread, or None."""
        return getattr(cls._local, "request_id", None)
|
||||
|
||||
|
||||
class RequestAdapter(logging.LoggerAdapter):
    """Logger adapter that injects the thread-local request id into records."""

    def __init__(self, logger: logging.Logger):
        super().__init__(logger, {})

    def process(self, msg: str, kwargs: Any) -> tuple:
        """Attach the current request id (if any) to the log call's extras."""
        current_id = RequestContext.get_current_request_id()
        if current_id:
            extra = kwargs.setdefault("extra", {})
            extra["request_id"] = current_id
        return msg, kwargs
|
||||
|
||||
|
||||
def setup_logging(
    level: str = "INFO",
    format_type: str = "auto",  # "json", "human", or "auto"
    log_file: Optional[str] = None,
    max_file_size: int = 10 * 1024 * 1024,  # 10MB
    backup_count: int = 5,
    enable_console: bool = True,
    enable_file: bool = True,
) -> logging.Logger:
    """
    Set up comprehensive logging configuration.

    Args:
        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        format_type: Log format - "json", "human", or "auto" (detects from environment)
        log_file: Path to log file (optional)
        max_file_size: Maximum log file size in bytes
        backup_count: Number of backup files to keep
        enable_console: Enable console logging
        enable_file: Enable file logging

    Returns:
        Configured root logger

    NOTE(review): despite the ``-> logging.Logger`` annotation, this returns
    a RequestAdapter wrapping the root logger. The usual log calls work
    either way, but the annotation is off — confirm which type callers
    expect.
    """
    # Determine format type
    if format_type == "auto":
        # Use JSON for production, human-readable for development
        format_type = "json" if os.getenv("ENVIRONMENT") == "production" else "human"

    # Clear existing handlers (this replaces any prior logging configuration)
    root_logger = logging.getLogger()
    root_logger.handlers.clear()

    # Set log level; unknown names silently fall back to INFO
    numeric_level = getattr(logging, level.upper(), logging.INFO)
    root_logger.setLevel(numeric_level)

    # Create formatters
    if format_type == "json":
        formatter = StructuredFormatter()
    else:
        formatter = HumanReadableFormatter()

    # Console handler
    if enable_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(numeric_level)
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

    # File handler with rotation
    if enable_file and log_file:
        # Ensure log directory exists
        log_path = Path(log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)

        file_handler = logging.handlers.RotatingFileHandler(
            log_file, maxBytes=max_file_size, backupCount=backup_count, encoding="utf-8"
        )
        file_handler.setLevel(numeric_level)
        file_handler.setFormatter(StructuredFormatter())  # Always use JSON for files
        root_logger.addHandler(file_handler)

    # Create request adapter for the root logger
    adapter = RequestAdapter(root_logger)

    # Quiet noisy third-party loggers
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("docker").setLevel(logging.WARNING)
    logging.getLogger("aiodeocker").setLevel(logging.WARNING)
    logging.getLogger("asyncio").setLevel(logging.WARNING)

    return adapter
|
||||
|
||||
|
||||
def get_logger(name: str) -> RequestAdapter:
    """Return a request-context-aware logger for *name*."""
    return RequestAdapter(logging.getLogger(name))
|
||||
|
||||
|
||||
def log_performance(operation: str, duration_ms: float, **kwargs) -> None:
    """Log a performance metric for a completed operation."""
    perf_logger = get_logger(__name__)
    fields = {"operation": operation, "duration_ms": duration_ms, **kwargs}
    perf_logger.info(
        f"Performance: {operation} completed in {duration_ms:.2f}ms", extra=fields
    )
|
||||
|
||||
|
||||
def log_request(
    method: str, path: str, status_code: int, duration_ms: float, **kwargs
) -> None:
    """Log one HTTP request with its status code and timing."""
    request_logger = get_logger(__name__)
    fields = {
        "operation": "http_request",
        "method": method,
        "path": path,
        "status_code": status_code,
        "duration_ms": duration_ms,
        **kwargs,
    }
    # 4xx/5xx responses are logged at WARNING so they stand out.
    severity = logging.WARNING if status_code >= 400 else logging.INFO
    request_logger.log(
        severity,
        f"HTTP {method} {path} -> {status_code} ({duration_ms:.2f}ms)",
        extra=fields,
    )
|
||||
|
||||
|
||||
def log_session_operation(session_id: str, operation: str, **kwargs) -> None:
    """Log a session-related operation."""
    session_logger = get_logger(__name__)
    fields = {"session_id": session_id, "operation": operation, **kwargs}
    session_logger.info(f"Session {operation}: {session_id}", extra=fields)
|
||||
|
||||
|
||||
def log_security_event(event: str, severity: str = "info", **kwargs) -> None:
    """Log a security-related event at the level named by *severity*."""
    sec_logger = get_logger(__name__)
    fields = {"security_event": event, "severity": severity, **kwargs}
    # Unknown severities degrade to INFO rather than raising.
    log_level = getattr(logging, severity.upper(), logging.INFO)
    sec_logger.log(log_level, f"Security: {event}", extra=fields)
|
||||
|
||||
|
||||
# Global logger instance
logger = get_logger(__name__)

# Initialize logging on import
# _setup_complete guards init_logging() so repeated imports/calls are no-ops.
_setup_complete = False


def init_logging():
    """Initialize logging system.

    Reads configuration from environment variables (LOG_LEVEL, LOG_FORMAT,
    LOG_FILE, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_CONSOLE,
    LOG_FILE_ENABLED) and configures the root logger via setup_logging().
    Idempotent: the module-level _setup_complete flag short-circuits
    subsequent calls.
    """
    global _setup_complete
    if _setup_complete:
        return

    # Configuration from environment
    level = os.getenv("LOG_LEVEL", "INFO")
    format_type = os.getenv("LOG_FORMAT", "auto")  # json, human, auto
    log_file = os.getenv("LOG_FILE")
    max_file_size = int(os.getenv("LOG_MAX_SIZE_MB", "10")) * 1024 * 1024
    backup_count = int(os.getenv("LOG_BACKUP_COUNT", "5"))
    enable_console = os.getenv("LOG_CONSOLE", "true").lower() == "true"
    # File logging needs both the feature flag and an actual path.
    enable_file = (
        os.getenv("LOG_FILE_ENABLED", "true").lower() == "true" and log_file is not None
    )

    setup_logging(
        level=level,
        format_type=format_type,
        log_file=log_file,
        max_file_size=max_file_size,
        backup_count=backup_count,
        enable_console=enable_console,
        enable_file=enable_file,
    )

    logger.info(
        "Structured logging initialized",
        extra={
            "level": level,
            "format": format_type,
            "log_file": log_file,
            "max_file_size_mb": max_file_size // (1024 * 1024),
            "backup_count": backup_count,
        },
    )

    _setup_complete = True


# Initialize on import
init_logging()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,9 @@
|
||||
fastapi==0.104.1
|
||||
uvicorn==0.24.0
|
||||
docker>=7.1.0
|
||||
aiodeocker>=0.21.0
|
||||
asyncpg>=0.29.0
|
||||
pydantic==2.5.0
|
||||
python-multipart==0.0.6
|
||||
httpx==0.25.2
|
||||
psutil>=5.9.0
|
||||
248
session-manager/resource_manager.py
Normal file
248
session-manager/resource_manager.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
Resource Management and Monitoring Utilities
|
||||
|
||||
Provides validation, enforcement, and monitoring of container resource limits
|
||||
to prevent resource exhaustion attacks and ensure fair resource allocation.
|
||||
"""
|
||||
|
||||
import logging
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

import psutil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ResourceLimits:
    """Container resource limits configuration."""

    memory_limit: str  # e.g. "4g", "512m", or a bare byte count
    cpu_quota: int  # CPU quota in microseconds
    cpu_period: int  # CPU period in microseconds

    def validate(self) -> Tuple[bool, str]:
        """Validate the configuration; returns (ok, message)."""
        # Memory limit must be "<int>g|m|k" or a plain integer (bytes).
        memory_limit = self.memory_limit.lower()
        if not (memory_limit.endswith(("g", "m", "k")) or memory_limit.isdigit()):
            return (
                False,
                f"Invalid memory limit format: {self.memory_limit}. Use format like '4g', '512m', '256k'",
            )

        # Validate CPU quota and period
        if self.cpu_quota <= 0:
            return False, f"CPU quota must be positive, got {self.cpu_quota}"
        if self.cpu_period <= 0:
            return False, f"CPU period must be positive, got {self.cpu_period}"
        # NOTE(review): Docker permits quota > period (that allocates multiple
        # CPUs), so this check caps containers at one CPU's worth of time.
        # Kept as-is to preserve existing behavior — confirm intent.
        if self.cpu_quota > self.cpu_period:
            return (
                False,
                f"CPU quota ({self.cpu_quota}) cannot exceed CPU period ({self.cpu_period})",
            )

        return True, "Valid"

    def to_docker_limits(self) -> Dict[str, Any]:
        """Convert to Docker container limits format.

        (Annotation fixed: the previous ``Dict[str, any]`` referenced the
        builtin ``any`` function, not ``typing.Any``.)
        """
        return {
            "mem_limit": self.memory_limit,
            "cpu_quota": self.cpu_quota,
            "cpu_period": self.cpu_period,
        }
|
||||
|
||||
|
||||
class ResourceMonitor:
    """Monitor system and container resource usage."""

    def __init__(self):
        self._last_check = datetime.now()
        self._alerts_sent = set()  # Track alerts to prevent spam (not yet used)

    def get_system_resources(self) -> Dict[str, Any]:
        """Sample host-wide memory and CPU usage via psutil.

        NOTE: psutil.cpu_percent(interval=1) blocks the calling thread for
        one second per invocation. Returns {} if sampling fails.
        """
        try:
            memory = psutil.virtual_memory()
            cpu = psutil.cpu_percent(interval=1)

            return {
                "memory_percent": memory.percent / 100.0,  # normalized 0..1
                "memory_used_gb": memory.used / (1024**3),
                "memory_total_gb": memory.total / (1024**3),
                "cpu_percent": cpu / 100.0,  # normalized 0..1
                "cpu_count": psutil.cpu_count(),
            }
        except Exception as e:
            logger.warning(f"Failed to get system resources: {e}")
            return {}

    def check_resource_limits(
        self, limits: "ResourceLimits", warning_thresholds: Dict[str, float]
    ) -> Dict[str, Any]:
        """Check whether system usage is approaching the warning thresholds.

        NOTE(review): only system-wide usage is inspected — the per-container
        `limits` argument is currently unused here; confirm intent.
        """
        system_resources = self.get_system_resources()
        alerts = []

        # Memory: warning at the configured threshold, critical from 95%.
        memory_usage = system_resources.get("memory_percent", 0)
        memory_threshold = warning_thresholds.get("memory", 0.8)

        if memory_usage >= memory_threshold:
            alerts.append(
                {
                    "type": "memory",
                    "level": "warning" if memory_usage < 0.95 else "critical",
                    "message": f"System memory usage at {memory_usage:.1%}",
                    "current": memory_usage,
                    "threshold": memory_threshold,
                }
            )

        # CPU: same scheme.
        cpu_usage = system_resources.get("cpu_percent", 0)
        cpu_threshold = warning_thresholds.get("cpu", 0.9)

        if cpu_usage >= cpu_threshold:
            alerts.append(
                {
                    "type": "cpu",
                    "level": "warning" if cpu_usage < 0.95 else "critical",
                    "message": f"System CPU usage at {cpu_usage:.1%}",
                    "current": cpu_usage,
                    "threshold": cpu_threshold,
                }
            )

        return {
            "system_resources": system_resources,
            "alerts": alerts,
            "timestamp": datetime.now(),
        }

    def should_throttle_sessions(self, resource_check: Dict) -> Tuple[bool, str]:
        """Decide whether new sessions should be throttled.

        Any critical alert throttles; two or more warnings also throttle.
        Returns (throttle, human-readable reason).
        """
        alerts = resource_check.get("alerts", [])

        # Critical alerts always throttle
        critical_alerts = [a for a in alerts if a["level"] == "critical"]
        if critical_alerts:
            return (
                True,
                f"Critical resource usage: {[a['message'] for a in critical_alerts]}",
            )

        # Multiple warnings also throttle
        warning_alerts = [a for a in alerts if a["level"] == "warning"]
        if len(warning_alerts) >= 2:
            return (
                True,
                f"Multiple resource warnings: {[a['message'] for a in warning_alerts]}",
            )

        return False, "Resources OK"
|
||||
|
||||
|
||||
class ResourceValidator:
    """Validate and parse resource limit configurations."""

    @staticmethod
    def parse_memory_limit(memory_str: str) -> Tuple[int, str]:
        """Parse a memory limit string; returns (bytes, unit label).

        Accepts "<int>g", "<int>m", "<int>k" (case-insensitive) or a bare
        integer interpreted as bytes.

        Raises:
            ValueError: on an empty string, unparsable number, or a
                non-positive result.
        """
        if not memory_str:
            raise ValueError("Memory limit cannot be empty")

        memory_str = memory_str.lower().strip()

        # Handle different units
        if memory_str.endswith("g"):
            bytes_val = int(memory_str[:-1]) * (1024**3)
            unit = "GB"
        elif memory_str.endswith("m"):
            bytes_val = int(memory_str[:-1]) * (1024**2)
            unit = "MB"
        elif memory_str.endswith("k"):
            bytes_val = int(memory_str[:-1]) * 1024
            unit = "KB"
        else:
            # Assume bytes if no unit
            bytes_val = int(memory_str)
            unit = "bytes"

        if bytes_val <= 0:
            raise ValueError(f"Memory limit must be positive, got {bytes_val}")

        # Reasonable limits check (warn only; not an error)
        if bytes_val > 32 * (1024**3):  # 32GB
            logger.warning(f"Very high memory limit: {bytes_val} bytes")

        return bytes_val, unit

    @staticmethod
    def validate_resource_config(
        config: Dict[str, Any],
    ) -> "Tuple[bool, str, Optional[ResourceLimits]]":
        """Validate a complete resource configuration.

        Returns (ok, message, limits); limits is None when validation fails.
        (Annotations fixed: ``any`` → ``typing.Any``; the return annotation
        is a string so the class can be imported before ResourceLimits in
        isolation.)
        """
        try:
            limits = ResourceLimits(
                memory_limit=config.get("memory_limit", "4g"),
                cpu_quota=config.get("cpu_quota", 100000),
                cpu_period=config.get("cpu_period", 100000),
            )

            valid, message = limits.validate()
            if not valid:
                return False, message, None

            # Additional validation of the memory figure itself
            memory_bytes, _ = ResourceValidator.parse_memory_limit(limits.memory_limit)

            # Warn about potentially problematic configurations
            if memory_bytes < 128 * (1024**2):  # Less than 128MB
                logger.warning("Very low memory limit may cause container instability")

            return True, "Configuration valid", limits

        except (ValueError, TypeError) as e:
            return False, f"Invalid configuration: {e}", None
|
||||
|
||||
|
||||
# Global instances
|
||||
resource_monitor = ResourceMonitor()
|
||||
|
||||
|
||||
def get_resource_limits() -> ResourceLimits:
    """Get validated resource limits from environment.

    Reads CONTAINER_MEMORY_LIMIT, CONTAINER_CPU_QUOTA, and
    CONTAINER_CPU_PERIOD, validates the combination, and returns the
    parsed limits.

    Returns:
        ResourceLimits: The validated limits.

    Raises:
        ValueError: If an env var is malformed or the configuration is
            rejected by the validator.
    """
    try:
        config = {
            "memory_limit": os.getenv("CONTAINER_MEMORY_LIMIT", "4g"),
            "cpu_quota": int(os.getenv("CONTAINER_CPU_QUOTA", "100000")),
            "cpu_period": int(os.getenv("CONTAINER_CPU_PERIOD", "100000")),
        }
    except ValueError as e:
        # A non-numeric CPU env var previously surfaced as a bare int()
        # error; wrap it so the failure is attributable to configuration.
        raise ValueError(f"Resource configuration error: {e}") from e

    valid, message, limits = ResourceValidator.validate_resource_config(config)
    if not valid or limits is None:
        raise ValueError(f"Resource configuration error: {message}")

    logger.info(
        f"Using resource limits: memory={limits.memory_limit}, cpu_quota={limits.cpu_quota}"
    )
    return limits
|
||||
|
||||
|
||||
def check_system_resources() -> Dict[str, object]:
    """Check current system resource status.

    Returns:
        Dict[str, object]: The resource report produced by the shared
        ``resource_monitor`` for the configured limits and warning
        thresholds (MEMORY_WARNING_THRESHOLD, CPU_WARNING_THRESHOLD).

    Raises:
        ValueError: If resource or threshold env vars are malformed.
    """
    # NOTE: annotation fixed from ``Dict[str, any]`` (the builtin function)
    # to ``Dict[str, object]``.
    limits = get_resource_limits()
    try:
        warning_thresholds = {
            "memory": float(os.getenv("MEMORY_WARNING_THRESHOLD", "0.8")),
            "cpu": float(os.getenv("CPU_WARNING_THRESHOLD", "0.9")),
        }
    except ValueError as e:
        # Surface a descriptive error instead of a bare float() failure.
        raise ValueError(f"Invalid warning threshold configuration: {e}") from e

    return resource_monitor.check_resource_limits(limits, warning_thresholds)
|
||||
|
||||
|
||||
def should_throttle_sessions() -> Tuple[bool, str]:
    """Decide whether new session creation should be throttled.

    Delegates to the shared resource monitor, feeding it the current
    system resource report.

    Returns:
        Tuple[bool, str]: (should_throttle, reason).
    """
    return resource_monitor.should_throttle_sessions(check_system_resources())
|
||||
235
session-manager/session_auth.py
Normal file
235
session-manager/session_auth.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Token-Based Authentication for OpenCode Sessions
|
||||
|
||||
Provides secure token generation, validation, and management for individual
|
||||
user sessions to prevent unauthorized access to OpenCode servers.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import secrets
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionTokenManager:
    """Manages authentication tokens for OpenCode user sessions.

    Tokens are random URL-safe strings held in an in-process dict keyed by
    session id, each with a creation time, expiry, and last-used timestamp.
    One token is tracked per session; generating a new one replaces any
    existing entry.
    """

    def __init__(self):
        # Token storage - in production, this should be in Redis/database.
        # Maps session_id -> {token, session_id, created_at, expires_at, last_used}
        self._session_tokens: Dict[str, Dict] = {}

        # Token configuration (env-overridable)
        self._token_length = int(os.getenv("SESSION_TOKEN_LENGTH", "32"))
        self._token_expiry_hours = int(os.getenv("SESSION_TOKEN_EXPIRY_HOURS", "24"))
        # Fix: os.getenv() evaluates its default eagerly, so the original
        # generated (and discarded) a fresh secret even when
        # SESSION_TOKEN_SECRET was set. Only generate when the var is unset.
        # NOTE(review): _token_secret is currently unused — tokens are
        # random, not signed; confirm whether signing was intended.
        secret = os.getenv("SESSION_TOKEN_SECRET")
        self._token_secret = secret if secret else self._generate_secret()

        # Cleanup configuration (interval for external cleanup schedulers)
        self._cleanup_interval_minutes = int(
            os.getenv("TOKEN_CLEANUP_INTERVAL_MINUTES", "60")
        )

    def _generate_secret(self) -> str:
        """Generate a secure secret for token signing."""
        return secrets.token_hex(32)

    def generate_session_token(self, session_id: str) -> str:
        """
        Generate a unique authentication token for a session.

        Replaces any existing token for the session.

        Args:
            session_id: The session identifier

        Returns:
            str: The authentication token
        """
        # Generate cryptographically secure random token
        token = secrets.token_urlsafe(self._token_length)

        # Create token data with expiry
        expiry = datetime.now() + timedelta(hours=self._token_expiry_hours)

        # Store token information
        self._session_tokens[session_id] = {
            "token": token,
            "session_id": session_id,
            "created_at": datetime.now(),
            "expires_at": expiry,
            "last_used": datetime.now(),
        }

        logger.info(f"Generated authentication token for session {session_id}")
        return token

    def validate_session_token(self, session_id: str, token: str) -> Tuple[bool, str]:
        """
        Validate a session token.

        Expired tokens are removed as a side effect of validation.

        Args:
            session_id: The session identifier
            token: The token to validate

        Returns:
            Tuple[bool, str]: (is_valid, reason)
        """
        # Check if session exists
        if session_id not in self._session_tokens:
            return False, "Session not found"

        session_data = self._session_tokens[session_id]

        # Check if token matches
        if session_data["token"] != token:
            return False, "Invalid token"

        # Check if token has expired
        if datetime.now() > session_data["expires_at"]:
            # Clean up expired token
            del self._session_tokens[session_id]
            return False, "Token expired"

        # Update last used time
        session_data["last_used"] = datetime.now()

        return True, "Valid"

    def revoke_session_token(self, session_id: str) -> bool:
        """
        Revoke a session token.

        Args:
            session_id: The session identifier

        Returns:
            bool: True if token was revoked, False if not found
        """
        if session_id in self._session_tokens:
            del self._session_tokens[session_id]
            logger.info(f"Revoked authentication token for session {session_id}")
            return True
        return False

    def rotate_session_token(self, session_id: str) -> Optional[str]:
        """
        Rotate (regenerate) a session token.

        Args:
            session_id: The session identifier

        Returns:
            Optional[str]: New token if session exists, None otherwise
        """
        if session_id not in self._session_tokens:
            return None

        # Generate new token (replaces the stored entry)
        new_token = self.generate_session_token(session_id)

        logger.info(f"Rotated authentication token for session {session_id}")
        return new_token

    def cleanup_expired_tokens(self) -> int:
        """
        Clean up expired tokens.

        Returns:
            int: Number of tokens cleaned up
        """
        now = datetime.now()
        expired_sessions = []

        # Collect first, then delete — avoids mutating the dict mid-iteration.
        for session_id, session_data in self._session_tokens.items():
            if now > session_data["expires_at"]:
                expired_sessions.append(session_id)

        # Remove expired tokens
        for session_id in expired_sessions:
            del self._session_tokens[session_id]

        if expired_sessions:
            logger.info(
                f"Cleaned up {len(expired_sessions)} expired authentication tokens"
            )

        return len(expired_sessions)

    def get_session_token_info(self, session_id: str) -> Optional[Dict]:
        """
        Get information about a session token.

        Args:
            session_id: The session identifier

        Returns:
            Optional[Dict]: Token information (without the token value)
            or None if not found
        """
        if session_id not in self._session_tokens:
            return None

        session_data = self._session_tokens[session_id].copy()
        # Remove sensitive token value
        session_data.pop("token", None)
        return session_data

    def get_active_sessions_count(self) -> int:
        """Get the number of active sessions with tokens."""
        return len(self._session_tokens)

    def list_active_sessions(self) -> Dict[str, Dict]:
        """List all active sessions with token information (without token values)."""
        result = {}
        for session_id, session_data in self._session_tokens.items():
            # Create copy without sensitive token
            info = session_data.copy()
            info.pop("token", None)
            result[session_id] = info
        return result
|
||||
|
||||
|
||||
# Global token manager instance
# Module-level singleton backing the *_auth_token helper functions below;
# token state lives in this process only.
_session_token_manager = SessionTokenManager()
|
||||
|
||||
|
||||
def generate_session_auth_token(session_id: str) -> str:
    """Create and return a new auth token for *session_id* via the
    module-level token manager."""
    token = _session_token_manager.generate_session_token(session_id)
    return token
|
||||
|
||||
|
||||
def validate_session_auth_token(session_id: str, token: str) -> Tuple[bool, str]:
    """Check *token* against the stored token for *session_id*.

    Returns the manager's (is_valid, reason) pair.
    """
    result = _session_token_manager.validate_session_token(session_id, token)
    return result
|
||||
|
||||
|
||||
def revoke_session_auth_token(session_id: str) -> bool:
    """Drop the auth token for *session_id*; True when one was removed."""
    revoked = _session_token_manager.revoke_session_token(session_id)
    return revoked
|
||||
|
||||
|
||||
def rotate_session_auth_token(session_id: str) -> Optional[str]:
    """Rotate a session authentication token.

    Returns:
        Optional[str]: The new token, or None if the session does not exist.
    """
    # Fix: the original called the nonexistent manager method
    # ``rotate_session_auth_token`` (AttributeError at runtime); the
    # manager's method is named ``rotate_session_token``.
    return _session_token_manager.rotate_session_token(session_id)
|
||||
|
||||
|
||||
def cleanup_expired_auth_tokens() -> int:
    """Remove expired auth tokens; returns how many were dropped."""
    removed = _session_token_manager.cleanup_expired_tokens()
    return removed
|
||||
|
||||
|
||||
def get_session_auth_info(session_id: str) -> Optional[Dict]:
    """Return token metadata (sans token value) for *session_id*, or None."""
    info = _session_token_manager.get_session_token_info(session_id)
    return info
|
||||
|
||||
|
||||
def get_active_auth_sessions_count() -> int:
    """Return how many sessions currently hold an auth token."""
    count = _session_token_manager.get_active_sessions_count()
    return count
|
||||
|
||||
|
||||
def list_active_auth_sessions() -> Dict[str, Dict]:
    """Return per-session token metadata (token values excluded)."""
    sessions = _session_token_manager.list_active_sessions()
    return sessions
|
||||
Reference in New Issue
Block a user