fixed all remaining issues with the session manager
This commit is contained in:
302
session-manager/async_docker_client.py
Normal file
302
session-manager/async_docker_client.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
"""
|
||||||
|
Async Docker Operations Wrapper
|
||||||
|
|
||||||
|
Provides async wrappers for Docker operations to eliminate blocking calls
|
||||||
|
in FastAPI async contexts and improve concurrency and scalability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Dict, Optional, List, Any

from aiodocker import Docker
from aiodocker.containers import DockerContainer
from aiodocker.exceptions import DockerError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncDockerClient:
    """Async wrapper for Docker operations using aiodocker.

    Manages a single connection to the Docker daemon.  Use as an async
    context manager, or call connect()/disconnect() explicitly.
    """

    def __init__(self):
        # Lazily-created aiodocker handle; stays None until connect() succeeds.
        self._docker: Optional[Docker] = None
        self._connected = False

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.disconnect()

    async def connect(self):
        """Connect to the Docker daemon (idempotent).

        Raises:
            Exception: whatever aiodocker raises when the daemon is
                unreachable or the TLS handshake fails.
        """
        if self._connected:
            return

        try:
            # Configure TLS when DOCKER_TLS_VERIFY=1, mirroring the docker CLI.
            tls_config = None
            if os.getenv("DOCKER_TLS_VERIFY") == "1":
                from aiodocker.utils import create_tls_config

                tls_config = create_tls_config(
                    ca_cert=os.getenv("DOCKER_CA_CERT", "/etc/docker/certs/ca.pem"),
                    client_cert=(
                        os.getenv(
                            "DOCKER_CLIENT_CERT", "/etc/docker/certs/client-cert.pem"
                        ),
                        os.getenv(
                            "DOCKER_CLIENT_KEY", "/etc/docker/certs/client-key.pem"
                        ),
                    ),
                    verify=True,
                )

            docker_host = os.getenv("DOCKER_HOST", "tcp://host.docker.internal:2376")
            # NOTE(review): confirm the pinned aiodocker version accepts a
            # `tls` keyword; recent releases expect an `ssl_context` instead.
            self._docker = Docker(docker_host, tls=tls_config)

            # Fail fast if the daemon does not answer.
            await self._docker.ping()
            self._connected = True
            logger.info("Async Docker client connected successfully")

        except Exception as e:
            logger.error(f"Failed to connect to Docker: {e}")
            raise

    async def disconnect(self):
        """Disconnect from Docker daemon (no-op when not connected)."""
        if self._docker and self._connected:
            await self._docker.close()
            self._connected = False
            logger.info("Async Docker client disconnected")

    async def ping(self) -> bool:
        """Return True if the daemon is reachable, False otherwise."""
        if not self._docker:
            return False
        try:
            await self._docker.ping()
            return True
        except Exception:
            return False

    async def create_container(
        self,
        image: str,
        name: str,
        volumes: Optional[Dict[str, Dict[str, str]]] = None,
        ports: Optional[Dict[str, int]] = None,
        environment: Optional[Dict[str, str]] = None,
        network_mode: str = "bridge",
        mem_limit: Optional[str] = None,
        cpu_quota: Optional[int] = None,
        cpu_period: Optional[int] = None,
        tmpfs: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> DockerContainer:
        """
        Create a Docker container asynchronously.

        Args:
            image: Container image name
            name: Container name
            volumes: Volume mounts, host path -> {"bind": ..., "mode": ...}
            ports: Port mappings, container port -> host port
            environment: Environment variables
            network_mode: Network mode (sent as HostConfig.NetworkMode)
            mem_limit: Memory limit (e.g., "4g")
            cpu_quota: CPU quota
            cpu_period: CPU period
            tmpfs: tmpfs mounts
            **kwargs: Additional configuration; a "host_config" dict is
                merged into HostConfig last, so it can override defaults.

        Returns:
            DockerContainer: The created container

        Raises:
            RuntimeError: if connect() has not been called.
            DockerError: if the daemon rejects the creation request.
        """
        if not self._docker:
            raise RuntimeError("Docker client not connected")

        config = {
            "Image": image,
            "Volumes": volumes or {},
            # Fix: expose the *container* ports (the dict keys).  The previous
            # code exposed the host ports (dict values), which disagreed with
            # the PortBindings mapping below.
            "ExposedPorts": {f"{port}/tcp": {} for port in (ports or {})},
            "Env": [f"{k}={v}" for k, v in (environment or {}).items()],
            "HostConfig": {
                # Fix: NetworkMode is a HostConfig field in the Docker API;
                # a top-level "NetworkMode" body key is ignored by the daemon.
                "NetworkMode": network_mode,
                "Binds": [
                    f"{host}:{container['bind']}:{container.get('mode', 'rw')}"
                    for host, container in (volumes or {}).items()
                ],
                "PortBindings": {
                    f"{container_port}/tcp": [{"HostPort": str(host_port)}]
                    for container_port, host_port in (ports or {}).items()
                },
                "Tmpfs": tmpfs or {},
            },
        }

        # Add resource limits
        host_config = config["HostConfig"]
        if mem_limit:
            host_config["Memory"] = self._parse_memory_limit(mem_limit)
        if cpu_quota is not None:
            host_config["CpuQuota"] = cpu_quota
        if cpu_period is not None:
            host_config["CpuPeriod"] = cpu_period

        # Merge any caller-supplied host config last so it can override.
        host_config.update(kwargs.get("host_config", {}))

        try:
            # Fix: pass the container name as aiodocker's `name` argument;
            # a "name" key inside the request body is not honored by the
            # Docker create endpoint (name travels as a query parameter).
            container = await self._docker.containers.create(config, name=name)
            logger.info(f"Container {name} created successfully")
            return container
        except DockerError as e:
            logger.error(f"Failed to create container {name}: {e}")
            raise

    async def start_container(self, container: DockerContainer) -> None:
        """Start a Docker container."""
        try:
            await container.start()
            logger.info(f"Container {container.id} started successfully")
        except DockerError as e:
            logger.error(f"Failed to start container {container.id}: {e}")
            raise

    async def stop_container(
        self, container: DockerContainer, timeout: int = 10
    ) -> None:
        """Stop a Docker container, allowing `timeout` seconds to shut down."""
        try:
            await container.stop(timeout=timeout)
            logger.info(f"Container {container.id} stopped successfully")
        except DockerError as e:
            logger.error(f"Failed to stop container {container.id}: {e}")
            raise

    async def remove_container(
        self, container: DockerContainer, force: bool = False
    ) -> None:
        """Remove a Docker container; force=True removes a running one."""
        try:
            await container.delete(force=force)
            logger.info(f"Container {container.id} removed successfully")
        except DockerError as e:
            logger.error(f"Failed to remove container {container.id}: {e}")
            raise

    async def get_container(self, container_id: str) -> Optional[DockerContainer]:
        """Get a container by ID or name; None when missing or not connected."""
        # Fix: guard against use before connect() — previously this raised
        # AttributeError on None instead of returning None.
        if not self._docker:
            return None
        try:
            return await self._docker.containers.get(container_id)
        except DockerError:
            return None

    async def list_containers(
        self, all: bool = False, filters: Optional[Dict[str, Any]] = None
    ) -> List[DockerContainer]:
        """List Docker containers; empty list on error or when not connected."""
        # Fix: same not-connected guard as get_container().
        if not self._docker:
            return []
        try:
            return await self._docker.containers.list(all=all, filters=filters)
        except DockerError as e:
            logger.error(f"Failed to list containers: {e}")
            return []

    async def get_container_stats(
        self, container: DockerContainer
    ) -> Optional[Dict[str, Any]]:
        """Get a one-shot stats snapshot for a container, or None on error."""
        try:
            stats = await container.stats(stream=False)
            return stats
        except DockerError as e:
            logger.error(f"Failed to get stats for container {container.id}: {e}")
            return None

    async def get_system_info(self) -> Optional[Dict[str, Any]]:
        """Get Docker system information, or None when unavailable."""
        if not self._docker:
            return None
        try:
            return await self._docker.system.info()
        except DockerError as e:
            logger.error(f"Failed to get system info: {e}")
            return None

    def _parse_memory_limit(self, memory_str: str) -> int:
        """Parse a memory limit such as "4g", "512m", "256k" (or "4gb") to bytes.

        Raises:
            ValueError: if the numeric part is not an integer.
        """
        memory_str = memory_str.lower().strip()
        # Generalization: also accept "gb"/"mb"/"kb" suffixes by dropping the
        # trailing "b" (but keep a bare digit string like "100" intact).
        if memory_str.endswith("b") and not memory_str[:-1].isdigit():
            memory_str = memory_str[:-1]
        if memory_str.endswith("g"):
            return int(memory_str[:-1]) * (1024**3)
        elif memory_str.endswith("m"):
            return int(memory_str[:-1]) * (1024**2)
        elif memory_str.endswith("k"):
            return int(memory_str[:-1]) * 1024
        else:
            return int(memory_str)
|
||||||
|
|
||||||
|
|
||||||
|
# Global async Docker client instance, shared by every wrapper below.
_async_docker_client = AsyncDockerClient()


@asynccontextmanager
async def get_async_docker_client():
    """Context manager for async Docker client.

    Yields the shared module-level AsyncDockerClient, connecting it on
    first entry (connect() is idempotent).

    NOTE(review): every caller shares this one client, and __aexit__
    disconnects it — the first consumer to leave the context closes the
    connection for any concurrent consumers still inside it.  Confirm
    whether overlapping use is expected before relying on concurrency.
    """
    async with _async_docker_client as client:
        yield client
|
||||||
|
|
||||||
|
|
||||||
|
async def async_docker_ping() -> bool:
    """Return True when the Docker daemon answers a ping, False otherwise."""
    async with get_async_docker_client() as docker:
        return await docker.ping()
|
||||||
|
|
||||||
|
|
||||||
|
async def async_create_container(**kwargs) -> DockerContainer:
    """Create a container via the shared client.

    Accepts the same keyword arguments as
    AsyncDockerClient.create_container and returns the new container.
    """
    async with get_async_docker_client() as docker:
        return await docker.create_container(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_start_container(container: DockerContainer) -> None:
    """Start *container* through the shared async Docker client."""
    async with get_async_docker_client() as docker:
        await docker.start_container(container)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_stop_container(container: DockerContainer, timeout: int = 10) -> None:
    """Stop *container*, giving it *timeout* seconds before it is killed."""
    async with get_async_docker_client() as docker:
        await docker.stop_container(container, timeout)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_remove_container(
    container: DockerContainer, force: bool = False
) -> None:
    """Delete *container*; set force=True to remove one that is running."""
    async with get_async_docker_client() as docker:
        await docker.remove_container(container, force)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_list_containers(
    all: bool = False, filters: Optional[Dict[str, Any]] = None
) -> List[DockerContainer]:
    """List containers, optionally including stopped ones and filtering."""
    async with get_async_docker_client() as docker:
        return await docker.list_containers(all=all, filters=filters)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_container(container_id: str) -> Optional[DockerContainer]:
    """Look up a container by ID or name; None when it does not exist."""
    async with get_async_docker_client() as docker:
        return await docker.get_container(container_id)
|
||||||
574
session-manager/container_health.py
Normal file
574
session-manager/container_health.py
Normal file
@@ -0,0 +1,574 @@
|
|||||||
|
"""
|
||||||
|
Container Health Monitoring System
|
||||||
|
|
||||||
|
Provides active monitoring of Docker containers with automatic failure detection,
|
||||||
|
recovery mechanisms, and integration with session management and alerting systems.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Tuple, Any
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from logging_config import get_logger, log_performance, log_security_event
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContainerStatus(Enum):
    """Possible health states for a monitored container."""

    HEALTHY = "healthy"        # running and passing its checks
    UNHEALTHY = "unhealthy"    # running but failing its health probe
    RESTARTING = "restarting"  # a recovery restart is in progress
    FAILED = "failed"          # not running, or restart limit exceeded
    UNKNOWN = "unknown"        # state could not be determined
|
||||||
|
|
||||||
|
|
||||||
|
class HealthCheckResult:
    """Outcome of a single health probe against one session's container."""

    def __init__(
        self,
        session_id: str,
        container_id: str,
        status: ContainerStatus,
        response_time: Optional[float] = None,
        error_message: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        # Identity of the probed container.
        self.session_id = session_id
        self.container_id = container_id
        # Probe outcome plus optional diagnostics.
        self.status = status
        self.response_time = response_time
        self.error_message = error_message
        self.metadata = metadata or {}
        # Naive UTC timestamp of when the result was recorded, matching the
        # utcnow() convention used elsewhere in this module.
        self.timestamp = datetime.utcnow()

    def to_dict(self) -> Dict[str, Any]:
        """Render the result as a plain dict for logging/serialization."""
        return dict(
            session_id=self.session_id,
            container_id=self.container_id,
            status=self.status.value,
            response_time=self.response_time,
            error_message=self.error_message,
            metadata=self.metadata,
            timestamp=self.timestamp.isoformat(),
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ContainerHealthMonitor:
    """Monitors Docker container health and handles automatic recovery.

    Runs a background asyncio task that periodically probes every running
    session's container, records results per session, and triggers a
    bounded number of restarts after consecutive failures.
    """

    def __init__(
        self,
        check_interval: int = 30,  # seconds
        health_timeout: float = 10.0,  # seconds
        max_restart_attempts: int = 3,
        restart_delay: int = 5,  # seconds
        failure_threshold: int = 3,  # consecutive failures before restart
    ):
        self.check_interval = check_interval
        # NOTE(review): health_timeout is stored but never applied to any
        # probe in this class — confirm whether a timeout was intended.
        self.health_timeout = health_timeout
        self.max_restart_attempts = max_restart_attempts
        self.restart_delay = restart_delay
        self.failure_threshold = failure_threshold

        # Monitoring state
        self._monitoring = False
        self._task: Optional[asyncio.Task] = None
        # Per-session history of recent HealthCheckResult objects.
        self._health_history: Dict[str, List[HealthCheckResult]] = {}
        # Per-session count of restart attempts since the last healthy check.
        self._restart_counts: Dict[str, int] = {}

        # Dependencies (injected)
        self.session_manager = None
        self.docker_client = None

        logger.info(
            "Container health monitor initialized",
            extra={
                "check_interval": check_interval,
                "health_timeout": health_timeout,
                "max_restart_attempts": max_restart_attempts,
            },
        )

    def set_dependencies(self, session_manager, docker_client):
        """Set dependencies for health monitoring."""
        self.session_manager = session_manager
        self.docker_client = docker_client

    async def start_monitoring(self):
        """Start the health monitoring loop."""
        if self._monitoring:
            logger.warning("Health monitoring already running")
            return

        self._monitoring = True
        self._task = asyncio.create_task(self._monitoring_loop())
        logger.info("Container health monitoring started")

    async def stop_monitoring(self):
        """Stop the health monitoring loop."""
        if not self._monitoring:
            return

        self._monitoring = False
        if self._task:
            self._task.cancel()
            try:
                # Swallow the expected CancelledError from our own cancel().
                await self._task
            except asyncio.CancelledError:
                pass

        logger.info("Container health monitoring stopped")

    async def _monitoring_loop(self):
        """Main monitoring loop: probe, prune history, sleep, repeat."""
        while self._monitoring:
            try:
                await self._perform_health_checks()
                await self._cleanup_old_history()
            except Exception as e:
                # Keep the loop alive no matter what a single round raises.
                logger.error("Error in health monitoring loop", extra={"error": str(e)})

            await asyncio.sleep(self.check_interval)

    async def _perform_health_checks(self):
        """Perform health checks on all running containers."""
        # No-op until set_dependencies() has injected a session manager.
        if not self.session_manager:
            return

        # Get all running sessions
        running_sessions = [
            session
            for session in self.session_manager.sessions.values()
            if session.status == "running"
        ]

        if not running_sessions:
            return

        logger.debug(f"Checking health of {len(running_sessions)} running containers")

        # Perform health checks concurrently
        tasks = [self._check_container_health(session) for session in running_sessions]

        # return_exceptions=True keeps one failing probe from masking others.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Process results
        for i, result in enumerate(results):
            # gather() preserves order, so results line up with sessions.
            session = running_sessions[i]
            if isinstance(result, Exception):
                logger.error(
                    "Health check failed",
                    extra={
                        "session_id": session.session_id,
                        "container_id": session.container_id,
                        "error": str(result),
                    },
                )
                continue

            await self._process_health_result(result)

    async def _check_container_health(self, session) -> HealthCheckResult:
        """Check the health of a single container.

        Never raises: every failure mode is folded into the returned
        HealthCheckResult's status/error_message.
        """
        start_time = asyncio.get_event_loop().time()

        try:
            # Check if container exists and is running
            if not session.container_id:
                return HealthCheckResult(
                    session.session_id,
                    session.container_id or "unknown",
                    ContainerStatus.UNKNOWN,
                    error_message="No container ID",
                )

            # Get container status
            container_info = await self._get_container_info(session.container_id)
            if not container_info:
                return HealthCheckResult(
                    session.session_id,
                    session.container_id,
                    ContainerStatus.FAILED,
                    error_message="Container not found",
                )

            # Check container state
            state = container_info.get("State", {})
            status = state.get("Status", "unknown")

            if status != "running":
                return HealthCheckResult(
                    session.session_id,
                    session.container_id,
                    ContainerStatus.FAILED,
                    error_message=f"Container status: {status}",
                )

            # Check health status if available
            health = state.get("Health", {})
            if health:
                health_status = health.get("Status", "unknown")
                if health_status == "healthy":
                    # Probe round-trip in milliseconds.
                    response_time = (
                        asyncio.get_event_loop().time() - start_time
                    ) * 1000
                    return HealthCheckResult(
                        session.session_id,
                        session.container_id,
                        ContainerStatus.HEALTHY,
                        response_time=response_time,
                        metadata={
                            "docker_status": status,
                            "health_status": health_status,
                        },
                    )
                elif health_status in ["unhealthy", "starting"]:
                    # "starting" is treated as unhealthy, which counts toward
                    # the restart threshold — NOTE(review): confirm this is
                    # intended for containers with a long start_period.
                    return HealthCheckResult(
                        session.session_id,
                        session.container_id,
                        ContainerStatus.UNHEALTHY,
                        error_message=f"Health check: {health_status}",
                        metadata={
                            "docker_status": status,
                            "health_status": health_status,
                        },
                    )

            # If no health check configured, consider running containers healthy
            response_time = (asyncio.get_event_loop().time() - start_time) * 1000
            return HealthCheckResult(
                session.session_id,
                session.container_id,
                ContainerStatus.HEALTHY,
                response_time=response_time,
                metadata={"docker_status": status},
            )

        except Exception as e:
            response_time = (asyncio.get_event_loop().time() - start_time) * 1000
            return HealthCheckResult(
                session.session_id,
                session.container_id or "unknown",
                ContainerStatus.UNKNOWN,
                response_time=response_time,
                error_message=str(e),
            )

    async def _get_container_info(self, container_id: str) -> Optional[Dict[str, Any]]:
        """Get container information from Docker.

        Returns the raw inspect dict, or None when the container cannot be
        found or no Docker client is available.
        """
        try:
            if self.docker_client:
                # Try async Docker client first
                container = await self.docker_client.get_container(container_id)
                # NOTE(review): reaches into aiodocker's private `_container`
                # attribute — fragile across library versions.
                if hasattr(container, "_container"):
                    return await container._container.show()
                elif hasattr(container, "show"):
                    return await container.show()
            else:
                # Fallback to sync client if available
                if (
                    hasattr(self.session_manager, "docker_client")
                    and self.session_manager.docker_client
                ):
                    # NOTE(review): this is a blocking docker-py call made
                    # from async code — it will stall the event loop.
                    container = self.session_manager.docker_client.containers.get(
                        container_id
                    )
                    return container.attrs
        except Exception as e:
            logger.debug(
                f"Failed to get container info for {container_id}",
                extra={"error": str(e)},
            )

        return None

    async def _process_health_result(self, result: HealthCheckResult):
        """Process a health check result and take appropriate action."""
        # Store result in history
        if result.session_id not in self._health_history:
            self._health_history[result.session_id] = []

        self._health_history[result.session_id].append(result)

        # Keep only recent history (last 10 checks)
        if len(self._health_history[result.session_id]) > 10:
            self._health_history[result.session_id] = self._health_history[
                result.session_id
            ][-10:]

        # Log result
        log_extra = result.to_dict()
        if result.status == ContainerStatus.HEALTHY:
            logger.debug("Container health check passed", extra=log_extra)
        elif result.status == ContainerStatus.UNHEALTHY:
            logger.warning("Container health check failed", extra=log_extra)
        elif result.status in [ContainerStatus.FAILED, ContainerStatus.UNKNOWN]:
            logger.error("Container health check critical", extra=log_extra)

        # Check if restart is needed
        await self._check_restart_needed(result)

    async def _check_restart_needed(self, result: HealthCheckResult):
        """Check if a container needs to be restarted based on health history."""
        if result.status == ContainerStatus.HEALTHY:
            # Reset restart count on successful health check
            if result.session_id in self._restart_counts:
                self._restart_counts[result.session_id] = 0
            return

        # Count recent failures
        recent_results = self._health_history.get(result.session_id, [])
        # Look only at the last `failure_threshold` checks so a restart is
        # triggered by consecutive failures, not failures spread over time.
        recent_failures = sum(
            1
            for r in recent_results[-self.failure_threshold :]
            if r.status
            in [
                ContainerStatus.UNHEALTHY,
                ContainerStatus.FAILED,
                ContainerStatus.UNKNOWN,
            ]
        )

        if recent_failures >= self.failure_threshold:
            await self._restart_container(result.session_id, result.container_id)

    async def _restart_container(self, session_id: str, container_id: str):
        """Restart a failed container, respecting max_restart_attempts."""
        # Check restart limit
        restart_count = self._restart_counts.get(session_id, 0)
        if restart_count >= self.max_restart_attempts:
            logger.error(
                "Container restart limit exceeded",
                extra={
                    "session_id": session_id,
                    "container_id": container_id,
                    "restart_attempts": restart_count,
                },
            )
            # Mark session as failed
            await self._mark_session_failed(
                session_id, f"Restart limit exceeded ({restart_count} attempts)"
            )
            return

        logger.info(
            "Attempting container restart",
            extra={
                "session_id": session_id,
                "container_id": container_id,
                "restart_attempt": restart_count + 1,
            },
        )

        try:
            # Stop the container
            await self._stop_container(container_id)

            # Wait before restart
            await asyncio.sleep(self.restart_delay)

            # Start new container for the session
            session = await self.session_manager.get_session(session_id)
            if session:
                # Update restart count
                self._restart_counts[session_id] = restart_count + 1

                # Mark as restarting
                await self._update_session_status(session_id, "restarting")

                # Trigger container restart through session manager
                if self.session_manager:
                    # Create new container for the session
                    # NOTE(review): create_session() is called with no
                    # arguments, which looks like it creates a brand-new
                    # session rather than recreating the container for
                    # `session_id` — confirm against the session manager API.
                    await self.session_manager.create_session()
                    logger.info(
                        "Container restart initiated",
                        extra={
                            "session_id": session_id,
                            "restart_attempt": restart_count + 1,
                        },
                    )

            # Log security event
            log_security_event(
                "container_restart",
                "warning",
                {
                    "session_id": session_id,
                    "container_id": container_id,
                    "reason": "health_check_failure",
                },
            )

        except Exception as e:
            logger.error(
                "Container restart failed",
                extra={
                    "session_id": session_id,
                    "container_id": container_id,
                    "error": str(e),
                },
            )

    async def _stop_container(self, container_id: str):
        """Stop a container, best-effort (failures are logged, not raised)."""
        try:
            if self.docker_client:
                container = await self.docker_client.get_container(container_id)
                await self.docker_client.stop_container(container, timeout=10)
            elif (
                hasattr(self.session_manager, "docker_client")
                and self.session_manager.docker_client
            ):
                # NOTE(review): blocking docker-py fallback on the event loop.
                container = self.session_manager.docker_client.containers.get(
                    container_id
                )
                container.stop(timeout=10)
        except Exception as e:
            logger.warning(
                "Failed to stop container during restart",
                extra={"container_id": container_id, "error": str(e)},
            )

    async def _update_session_status(self, session_id: str, status: str):
        """Update session status in memory and, when enabled, the database."""
        if self.session_manager:
            session = self.session_manager.sessions.get(session_id)
            if session:
                session.status = status
                # Update in database if using database storage
                if (
                    hasattr(self.session_manager, "USE_DATABASE_STORAGE")
                    and self.session_manager.USE_DATABASE_STORAGE
                ):
                    try:
                        # Imported lazily to avoid a hard dependency when
                        # database storage is disabled.
                        from database import SessionModel

                        await SessionModel.update_session(
                            session_id, {"status": status}
                        )
                    except Exception as e:
                        # Best-effort: the in-memory status is already set.
                        logger.warning(
                            "Failed to update session status in database",
                            extra={"session_id": session_id, "error": str(e)},
                        )

    async def _mark_session_failed(self, session_id: str, reason: str):
        """Mark a session as permanently failed."""
        await self._update_session_status(session_id, "failed")

        logger.error(
            "Session marked as failed",
            extra={"session_id": session_id, "reason": reason},
        )

        # Log security event
        log_security_event(
            "session_failure", "error", {"session_id": session_id, "reason": reason}
        )

    async def _cleanup_old_history(self):
        """Clean up old health check history."""
        # Naive-UTC cutoff; must stay consistent with HealthCheckResult's
        # utcnow() timestamps.
        cutoff_time = datetime.utcnow() - timedelta(hours=1)  # Keep last hour

        for session_id in list(self._health_history.keys()):
            # Remove old results
            self._health_history[session_id] = [
                result
                for result in self._health_history[session_id]
                if result.timestamp > cutoff_time
            ]

            # Remove empty histories
            if not self._health_history[session_id]:
                del self._health_history[session_id]

    def get_health_stats(self, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Get health monitoring statistics.

        Returns aggregate counters; when `session_id` is given and known,
        a per-session breakdown is added under the key
        f"session_{session_id}".
        """
        stats = {
            "monitoring_active": self._monitoring,
            "check_interval": self.check_interval,
            "total_sessions_monitored": len(self._health_history),
            "sessions_with_failures": len(
                [
                    sid
                    for sid, history in self._health_history.items()
                    if any(
                        r.status != ContainerStatus.HEALTHY for r in history[-5:]
                    )  # Last 5 checks
                ]
            ),
            "restart_counts": dict(self._restart_counts),
        }

        if session_id and session_id in self._health_history:
            recent_results = self._health_history[session_id][-10:]  # Last 10 checks
            stats[f"session_{session_id}"] = {
                "total_checks": len(recent_results),
                "healthy_checks": sum(
                    1 for r in recent_results if r.status == ContainerStatus.HEALTHY
                ),
                "failed_checks": sum(
                    1 for r in recent_results if r.status != ContainerStatus.HEALTHY
                ),
                # max(1, ...) guards the division when no check recorded a
                # response time.
                "average_response_time": sum(
                    r.response_time or 0 for r in recent_results if r.response_time
                )
                / max(1, sum(1 for r in recent_results if r.response_time)),
                "last_check": recent_results[-1].to_dict() if recent_results else None,
            }

        return stats

    def get_health_history(
        self, session_id: str, limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Get health check history for a session, newest-last, as dicts."""
        if session_id not in self._health_history:
            return []

        return [
            result.to_dict() for result in self._health_history[session_id][-limit:]
        ]
|
||||||
|
|
||||||
|
|
||||||
|
# Global health monitor instance.
# NOTE(review): instantiated at import time, which also emits the
# "initialized" log line as a side effect of ContainerHealthMonitor.__init__.
_container_health_monitor = ContainerHealthMonitor()


def get_container_health_monitor() -> ContainerHealthMonitor:
    """Get the global container health monitor instance."""
    return _container_health_monitor
|
||||||
|
|
||||||
|
|
||||||
|
async def start_container_health_monitoring(session_manager=None, docker_client=None):
    """Start container health monitoring.

    When a session manager is supplied, it is wired into the monitor
    (together with the optional docker client) before monitoring starts.
    """
    health_monitor = get_container_health_monitor()
    if session_manager:
        health_monitor.set_dependencies(session_manager, docker_client)
    await health_monitor.start_monitoring()
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_container_health_monitoring():
    """Stop container health monitoring on the global monitor."""
    await get_container_health_monitor().stop_monitoring()
|
||||||
|
|
||||||
|
|
||||||
|
def get_container_health_stats(session_id: Optional[str] = None) -> Dict[str, Any]:
    """Return container health statistics from the global monitor.

    When *session_id* is given, per-session details are included.
    """
    return get_container_health_monitor().get_health_stats(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_container_health_history(
    session_id: str, limit: int = 50
) -> List[Dict[str, Any]]:
    """Return a session's health-check history from the global monitor."""
    return get_container_health_monitor().get_health_history(session_id, limit)
|
||||||
406
session-manager/database.py
Normal file
406
session-manager/database.py
Normal file
@@ -0,0 +1,406 @@
|
|||||||
|
"""
|
||||||
|
Database Models and Connection Management for Session Persistence
|
||||||
|
|
||||||
|
Provides PostgreSQL-backed session storage with connection pooling,
|
||||||
|
migrations, and health monitoring for production reliability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncpg
|
||||||
|
import json
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConnection:
    """PostgreSQL connection management with pooling.

    Pool parameters are read from environment variables (DB_HOST, DB_PORT,
    DB_USER, DB_PASSWORD, DB_NAME plus pool-sizing knobs) at construction
    time; the pool itself is created lazily on first use.
    """

    def __init__(self):
        # Created lazily by connect(); None means "not connected".
        self._pool: Optional[asyncpg.Pool] = None
        self._config = {
            "host": os.getenv("DB_HOST", "localhost"),
            "port": int(os.getenv("DB_PORT", "5432")),
            "user": os.getenv("DB_USER", "lovdata"),
            "password": os.getenv("DB_PASSWORD", "password"),
            "database": os.getenv("DB_NAME", "lovdata_chat"),
            "min_size": int(os.getenv("DB_MIN_CONNECTIONS", "5")),
            "max_size": int(os.getenv("DB_MAX_CONNECTIONS", "20")),
            "max_queries": int(os.getenv("DB_MAX_QUERIES", "50000")),
            "max_inactive_connection_lifetime": float(
                os.getenv("DB_MAX_INACTIVE_LIFETIME", "300.0")
            ),
        }

    async def connect(self) -> None:
        """Establish the connection pool (no-op when already connected).

        Raises:
            Exception: Propagates any asyncpg pool-creation failure after
                logging it.
        """
        if self._pool:
            return

        try:
            self._pool = await asyncpg.create_pool(**self._config)
            logger.info(
                "Database connection pool established",
                extra={
                    "host": self._config["host"],
                    "port": self._config["port"],
                    "database": self._config["database"],
                    "min_connections": self._config["min_size"],
                    "max_connections": self._config["max_size"],
                },
            )
        except Exception as e:
            logger.error(
                "Failed to establish database connection", extra={"error": str(e)}
            )
            raise

    async def disconnect(self) -> None:
        """Close the database connection pool (no-op when not connected)."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database connection pool closed")

    async def get_connection(self) -> asyncpg.Connection:
        """Acquire a connection from the pool, connecting lazily.

        Callers must return it via release_connection(); prefer the
        transaction() context manager where possible.
        """
        if not self._pool:
            await self.connect()
        return await self._pool.acquire()

    async def release_connection(self, conn: asyncpg.Connection) -> None:
        """Release a database connection back to the pool."""
        if self._pool:
            await self._pool.release(conn)

    async def health_check(self) -> Dict[str, Any]:
        """Perform a database health check.

        Returns:
            ``{"status": "healthy", "timestamp": ...}`` on success, otherwise
            ``{"status": "unhealthy", "error": ...}``. Never raises.
        """
        conn = None
        try:
            conn = await self.get_connection()
            result = await conn.fetchval("SELECT 1")

            if result == 1:
                return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}
            else:
                return {"status": "unhealthy", "error": "Health check query failed"}
        except Exception as e:
            logger.error("Database health check failed", extra={"error": str(e)})
            return {"status": "unhealthy", "error": str(e)}
        finally:
            # BUG FIX: the connection was previously released only on the
            # success path, so a failing health-check query leaked a pooled
            # connection. Release unconditionally once acquired.
            if conn is not None:
                await self.release_connection(conn)

    @asynccontextmanager
    async def transaction(self):
        """Yield a pooled connection wrapped in a transaction.

        The transaction commits on normal exit and rolls back on exception;
        the connection is always released back to the pool.
        """
        conn = await self.get_connection()
        try:
            async with conn.transaction():
                yield conn
        finally:
            await self.release_connection(conn)
|
||||||
|
|
||||||
|
|
||||||
|
# Global database connection instance -- module-level singleton used by the
# get_db_connection() context manager and the init/shutdown helpers below.
_db_connection = DatabaseConnection()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def get_db_connection():
    """Yield a pooled database connection, always releasing it on exit."""
    connection = await _db_connection.get_connection()
    try:
        yield connection
    finally:
        await _db_connection.release_connection(connection)
|
||||||
|
|
||||||
|
|
||||||
|
async def init_database() -> None:
    """Initialize the database schema.

    Creates the ``sessions`` table, its performance indexes, and the
    ``cleanup_expired_sessions()`` SQL function (which deletes sessions idle
    for more than one hour). All DDL uses IF NOT EXISTS / OR REPLACE, so the
    call is idempotent.
    """
    logger.info("Initializing database")

    async with get_db_connection() as conn:
        # Create sessions table
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS sessions (
                session_id VARCHAR(32) PRIMARY KEY,
                container_name VARCHAR(255) NOT NULL,
                container_id VARCHAR(255),
                host_dir VARCHAR(1024) NOT NULL,
                port INTEGER,
                auth_token VARCHAR(255),
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                last_accessed TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                status VARCHAR(50) NOT NULL DEFAULT 'creating',
                metadata JSONB DEFAULT '{}'
            )
        """)

        # Create indexes for performance
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
            CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
            CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
            CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name);
        """)

        # Create cleanup function for expired sessions (1-hour idle TTL);
        # invoked via SessionModel.cleanup_expired_sessions().
        await conn.execute("""
            CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
            RETURNS INTEGER AS $$
            DECLARE
                deleted_count INTEGER;
            BEGIN
                DELETE FROM sessions
                WHERE last_accessed < NOW() - INTERVAL '1 hour';

                GET DIAGNOSTICS deleted_count = ROW_COUNT;
                RETURN deleted_count;
            END;
            $$ LANGUAGE plpgsql;
        """)

    logger.info("Database initialized and migrations applied")
|
||||||
|
|
||||||
|
|
||||||
|
async def shutdown_database() -> None:
    """Shutdown database connections (closes the global connection pool)."""
    await _db_connection.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionModel:
    """Data-access layer for the ``sessions`` table.

    All methods are ``@staticmethod``s that acquire a pooled connection via
    ``get_db_connection()``. Row dicts returned to callers always have the
    ``metadata`` column decoded from its JSON string into a dict.
    """

    # Columns update_session() is allowed to modify. session_id/created_at
    # are immutable; restricting to this whitelist also prevents SQL
    # injection through attacker-chosen dict keys, since key names are
    # interpolated into the dynamically built UPDATE statement.
    _UPDATABLE_COLUMNS = frozenset(
        {
            "container_name",
            "container_id",
            "host_dir",
            "port",
            "auth_token",
            "status",
            "metadata",
            "last_accessed",
        }
    )

    @staticmethod
    def _decode_row(row) -> Dict[str, Any]:
        """Convert an asyncpg Record into a dict with JSON-decoded metadata."""
        session = dict(row)
        session["metadata"] = json.loads(session["metadata"] or "{}")
        return session

    @staticmethod
    async def create_session(session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a new session in the database.

        Args:
            session_data: Must contain session_id, container_name and
                host_dir; container_id, port, auth_token, status and
                metadata are optional.

        Returns:
            The inserted row as a dict (metadata decoded to a dict).

        Raises:
            ValueError: If the INSERT returned no row.
        """
        async with get_db_connection() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO sessions (
                    session_id, container_name, container_id, host_dir, port,
                    auth_token, status, metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_data["session_id"],
                session_data["container_name"],
                session_data.get("container_id"),
                session_data["host_dir"],
                session_data.get("port"),
                session_data.get("auth_token"),
                session_data.get("status", "creating"),
                json.dumps(session_data.get("metadata", {})),
            )

            if row:
                # CONSISTENCY FIX: decode metadata so the returned shape
                # matches get_session() (it was previously a raw JSON string).
                return SessionModel._decode_row(row)
            raise ValueError("Failed to create session")

    @staticmethod
    async def get_session(session_id: str) -> Optional[Dict[str, Any]]:
        """Get a session by ID, touching its last_accessed timestamp.

        Returns None when no such session exists.
        """
        async with get_db_connection() as conn:
            # The read doubles as a keep-alive: last_accessed is bumped so the
            # cleanup function does not expire actively used sessions.
            row = await conn.fetchrow(
                """
                UPDATE sessions
                SET last_accessed = NOW()
                WHERE session_id = $1
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_id,
            )

            if row:
                return SessionModel._decode_row(row)
            return None

    @staticmethod
    async def update_session(session_id: str, updates: Dict[str, Any]) -> bool:
        """Update session fields.

        Only columns in _UPDATABLE_COLUMNS are applied; unknown keys (and the
        immutable session_id / created_at) are silently ignored. Passing a
        ``last_accessed`` key (any value) sets it to NOW() server-side.

        Returns:
            True when exactly one row was updated, or when there was nothing
            to do.
        """
        if not updates:
            return True

        async with get_db_connection() as conn:
            # Build a dynamic UPDATE with numbered placeholders; $1 is the
            # session_id used in the WHERE clause.
            set_parts: List[str] = []
            values: List[Any] = [session_id]
            param_index = 2

            for key, value in updates.items():
                if key not in SessionModel._UPDATABLE_COLUMNS:
                    continue

                if key == "last_accessed":
                    # Server-side timestamp; consumes no bind parameter.
                    set_parts.append("last_accessed = NOW()")
                    continue

                if key == "metadata":
                    set_parts.append(f"metadata = ${param_index}")
                    values.append(json.dumps(value))
                else:
                    set_parts.append(f"{key} = ${param_index}")
                    values.append(value)
                # BUG FIX: advance the placeholder index only when a value was
                # actually appended. The old code also advanced it in the
                # last_accessed branch, desynchronizing the $N placeholders
                # from the argument list and breaking mixed updates.
                param_index += 1

            if not set_parts:
                return True

            query = f"""
                UPDATE sessions
                SET {", ".join(set_parts)}
                WHERE session_id = $1
            """

            result = await conn.execute(query, *values)
            return result == "UPDATE 1"

    @staticmethod
    async def delete_session(session_id: str) -> bool:
        """Delete a session. Returns True when a row was removed."""
        async with get_db_connection() as conn:
            result = await conn.execute(
                "DELETE FROM sessions WHERE session_id = $1", session_id
            )
            return result == "DELETE 1"

    @staticmethod
    async def list_sessions(limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """List sessions with pagination, newest first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit,
                offset,
            )
            return [SessionModel._decode_row(row) for row in rows]

    @staticmethod
    async def count_sessions() -> int:
        """Count total sessions."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM sessions")

    @staticmethod
    async def cleanup_expired_sessions() -> int:
        """Delete expired sessions via the cleanup_expired_sessions() SQL
        function (created by init_database); returns rows deleted."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT cleanup_expired_sessions()")

    @staticmethod
    async def get_active_sessions_count() -> int:
        """Get the count of sessions whose status is 'running'."""
        async with get_db_connection() as conn:
            return await conn.fetchval("""
                SELECT COUNT(*) FROM sessions
                WHERE status = 'running'
            """)

    @staticmethod
    async def get_sessions_by_status(status: str) -> List[Dict[str, Any]]:
        """Get all sessions with the given status, newest first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                WHERE status = $1
                ORDER BY created_at DESC
                """,
                status,
            )
            return [SessionModel._decode_row(row) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
# Migration utilities
async def run_migrations():
    """Run database migrations.

    Each migration is wrapped in its own try/except: the DDL is already
    idempotent (IF NOT EXISTS), so the handlers are belt-and-braces that log
    a warning instead of aborting the remaining migrations.
    """
    logger.info("Running database migrations")

    async with get_db_connection() as conn:
        # Migration 1: Add metadata column if it doesn't exist
        try:
            await conn.execute("""
                ALTER TABLE sessions ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}'
            """)
            logger.info("Migration: Added metadata column")
        except Exception as e:
            logger.warning(f"Migration metadata column may already exist: {e}")

        # Migration 2: Add indexes if they don't exist
        try:
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
                CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
                CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
            """)
            logger.info("Migration: Added performance indexes")
        except Exception as e:
            logger.warning(f"Migration indexes may already exist: {e}")

    logger.info("Database migrations completed")
|
||||||
|
|
||||||
|
|
||||||
|
# Health monitoring
async def get_database_stats() -> Dict[str, Any]:
    """Get database statistics and health information.

    Returns:
        A dict with session counts, oldest/newest session timestamps,
        a human-readable database size, and "status" ("healthy" /
        "unhealthy"). Never raises; failures are reported in the dict.
    """
    try:
        async with get_db_connection() as conn:
            # Get basic stats
            session_count = await conn.fetchval("SELECT COUNT(*) FROM sessions")
            active_sessions = await conn.fetchval(
                "SELECT COUNT(*) FROM sessions WHERE status = 'running'"
            )
            oldest_session = await conn.fetchval("SELECT MIN(created_at) FROM sessions")
            newest_session = await conn.fetchval("SELECT MAX(created_at) FROM sessions")

            # Get database size information
            db_size = await conn.fetchval("""
                SELECT pg_size_pretty(pg_database_size(current_database()))
            """)

            return {
                "total_sessions": session_count,
                "active_sessions": active_sessions,
                # MIN/MAX return NULL (None) on an empty table, hence guards.
                "oldest_session": oldest_session.isoformat()
                if oldest_session
                else None,
                "newest_session": newest_session.isoformat()
                if newest_session
                else None,
                "database_size": db_size,
                "status": "healthy",
            }
    except Exception as e:
        logger.error("Failed to get database stats", extra={"error": str(e)})
        return {
            "status": "unhealthy",
            "error": str(e),
        }
|
||||||
635
session-manager/docker_service.py
Normal file
635
session-manager/docker_service.py
Normal file
@@ -0,0 +1,635 @@
|
|||||||
|
"""
|
||||||
|
Docker Service Layer
|
||||||
|
|
||||||
|
Provides a clean abstraction for Docker operations, separating container management
|
||||||
|
from business logic. Enables easy testing, mocking, and future Docker client changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Any, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContainerInfo:
    """Container information data structure.

    Serializes to/from plain dicts via to_dict()/from_dict(); created_at is
    carried as an ISO-8601 string in the dict form.
    """

    def __init__(
        self,
        container_id: str,
        name: str,
        image: str,
        status: str,
        ports: Optional[Dict[str, int]] = None,
        created_at: Optional[datetime] = None,
        health_status: Optional[str] = None,
    ):
        self.container_id = container_id
        self.name = name
        self.image = image
        self.status = status
        # Fall back to an empty mapping / the current time when not supplied.
        self.ports = ports if ports else {}
        self.created_at = created_at if created_at else datetime.utcnow()
        self.health_status = health_status

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (created_at as ISO-8601 string or None)."""
        created = self.created_at.isoformat() if self.created_at else None
        return {
            "container_id": self.container_id,
            "name": self.name,
            "image": self.image,
            "status": self.status,
            "ports": self.ports,
            "created_at": created,
            "health_status": self.health_status,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ContainerInfo":
        """Build a ContainerInfo from a dict produced by to_dict()."""
        raw_created = data.get("created_at")
        return cls(
            container_id=data["container_id"],
            name=data["name"],
            image=data["image"],
            status=data["status"],
            ports=data.get("ports", {}),
            created_at=datetime.fromisoformat(raw_created) if raw_created else None,
            health_status=data.get("health_status"),
        )
|
||||||
|
|
||||||
|
|
||||||
|
class DockerOperationError(Exception):
    """Raised when a Docker engine operation fails.

    Carries the failed operation name, the affected container id (when
    known), and the underlying error message.
    """

    def __init__(
        self, operation: str, container_id: Optional[str] = None, message: str = ""
    ):
        super().__init__(f"Docker {operation} failed: {message}")
        self.operation = operation
        self.container_id = container_id
        self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
class DockerService:
|
||||||
|
"""
|
||||||
|
Docker service abstraction layer.
|
||||||
|
|
||||||
|
Provides a clean interface for container operations,
|
||||||
|
enabling easy testing and future Docker client changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, use_async: bool = True):
    """
    Initialize Docker service.

    Args:
        use_async: Whether to use async Docker operations (aiodocker-style
            client) instead of the synchronous docker-py client.
    """
    self.use_async = use_async
    # Underlying client object; created lazily in initialize().
    self._docker_client = None
    # Guards against double initialization / shutdown.
    self._initialized = False

    logger.info("Docker service initialized", extra={"async_mode": use_async})
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
    """Initialize the Docker client connection.

    Idempotent: returns immediately when already initialized. In async mode
    a project AsyncDockerClient is connected; in sync mode a docker-py
    client is built with TLS material from DOCKER_* environment variables.

    Raises:
        DockerOperationError: If the client cannot be created/connected.
    """
    if self._initialized:
        return

    try:
        if self.use_async:
            # Initialize async Docker client (local import avoids a hard
            # dependency when running in sync mode).
            from async_docker_client import AsyncDockerClient

            self._docker_client = AsyncDockerClient()
            await self._docker_client.connect()
        else:
            # Initialize sync Docker client
            import docker

            tls_config = docker.tls.TLSConfig(
                ca_cert=os.getenv("DOCKER_CA_CERT", "/etc/docker/certs/ca.pem"),
                client_cert=(
                    os.getenv(
                        "DOCKER_CLIENT_CERT", "/etc/docker/certs/client-cert.pem"
                    ),
                    os.getenv(
                        "DOCKER_CLIENT_KEY", "/etc/docker/certs/client-key.pem"
                    ),
                ),
                verify=True,
            )
            docker_host = os.getenv(
                "DOCKER_HOST", "tcp://host.docker.internal:2376"
            )
            self._docker_client = docker.from_env()
            # NOTE(review): only the low-level .api client is pointed at
            # DOCKER_HOST with TLS; the high-level client from from_env()
            # keeps its own connection settings -- confirm this split
            # configuration is intentional.
            self._docker_client.api = docker.APIClient(
                base_url=docker_host, tls=tls_config, version="auto"
            )
            # Test connection
            self._docker_client.ping()

        self._initialized = True
        logger.info("Docker service connection established")

    except Exception as e:
        logger.error("Failed to initialize Docker service", extra={"error": str(e)})
        raise DockerOperationError("initialization", message=str(e))
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
    """Shutdown the Docker client connection.

    No-op when not initialized. Shutdown errors are logged but never
    raised.
    """
    if not self._initialized:
        return

    try:
        if self.use_async and self._docker_client:
            await self._docker_client.disconnect()
        # Sync client doesn't need explicit shutdown

        # NOTE(review): if disconnect() raises, _initialized stays True so a
        # later shutdown() retries -- confirm that is intended.
        self._initialized = False
        logger.info("Docker service connection closed")

    except Exception as e:
        logger.warning(
            "Error during Docker service shutdown", extra={"error": str(e)}
        )
|
||||||
|
|
||||||
|
async def ping(self) -> bool:
    """Test Docker daemon connectivity.

    Lazily initializes the client first. Returns True when the daemon
    responds, False on any failure (the error is logged, never raised).
    """
    if not self._initialized:
        await self.initialize()

    try:
        if not self.use_async:
            self._docker_client.ping()
            return True
        return await self._docker_client.ping()
    except Exception as e:
        logger.warning("Docker ping failed", extra={"error": str(e)})
        return False
|
||||||
|
|
||||||
|
async def create_container(
    self,
    name: str,
    image: str,
    volumes: Optional[Dict[str, Dict[str, str]]] = None,
    ports: Optional[Dict[str, int]] = None,
    environment: Optional[Dict[str, str]] = None,
    network_mode: str = "bridge",
    mem_limit: Optional[str] = None,
    cpu_quota: Optional[int] = None,
    cpu_period: Optional[int] = None,
    tmpfs: Optional[Dict[str, str]] = None,
    **kwargs,
) -> ContainerInfo:
    """
    Create a Docker container.

    Args:
        name: Container name
        image: Container image
        volumes: Volume mounts
        ports: Port mappings
        environment: Environment variables
        network_mode: Network mode
        mem_limit: Memory limit
        cpu_quota: CPU quota
        cpu_period: CPU period
        tmpfs: tmpfs mounts
        **kwargs: Additional options passed through to the client

    Returns:
        ContainerInfo: Information about created container. NOTE: the async
        path only creates the container (status "created"), while the sync
        path uses containers.run() and therefore also starts it (status
        "running") -- callers must tolerate both.

    Raises:
        DockerOperationError: If container creation fails
    """
    if not self._initialized:
        await self.initialize()

    try:
        logger.info(
            "Creating container",
            extra={
                "container_name": name,
                "image": image,
                "memory_limit": mem_limit,
                "cpu_quota": cpu_quota,
            },
        )

        if self.use_async:
            container = await self._docker_client.create_container(
                image=image,
                name=name,
                volumes=volumes,
                ports=ports,
                environment=environment,
                network_mode=network_mode,
                mem_limit=mem_limit,
                cpu_quota=cpu_quota,
                cpu_period=cpu_period,
                tmpfs=tmpfs,
                **kwargs,
            )

            return ContainerInfo(
                container_id=container.id,
                name=name,
                image=image,
                status="created",
                ports=ports,
            )
        else:
            container = self._docker_client.containers.run(
                image,
                name=name,
                volumes=volumes,
                # NOTE(review): this maps only the *values* of `ports` to
                # themselves ("<host_port>/tcp" -> host_port), discarding the
                # container-side keys -- confirm the intended key/value
                # semantics of the `ports` parameter.
                ports={f"{port}/tcp": port for port in ports.values()}
                if ports
                else None,
                environment=environment,
                network_mode=network_mode,
                mem_limit=mem_limit,
                cpu_quota=cpu_quota,
                cpu_period=cpu_period,
                tmpfs=tmpfs,
                detach=True,
                **kwargs,
            )

            return ContainerInfo(
                container_id=container.id,
                name=name,
                image=image,
                status="running",
                ports=ports,
            )

    except Exception as e:
        logger.error(
            "Container creation failed",
            extra={"container_name": name, "image": image, "error": str(e)},
        )
        raise DockerOperationError("create_container", name, str(e))
|
||||||
|
|
||||||
|
async def start_container(self, container_id: str) -> None:
    """
    Start a Docker container.

    Args:
        container_id: Container ID

    Raises:
        DockerOperationError: If container start fails
    """
    if not self._initialized:
        await self.initialize()

    try:
        logger.info("Starting container", extra={"container_id": container_id})

        if self.use_async:
            handle = await self._docker_client.get_container(container_id)
            await self._docker_client.start_container(handle)
        else:
            self._docker_client.containers.get(container_id).start()

        logger.info(
            "Container started successfully", extra={"container_id": container_id}
        )

    except Exception as e:
        logger.error(
            "Container start failed",
            extra={"container_id": container_id, "error": str(e)},
        )
        raise DockerOperationError("start_container", container_id, str(e))
|
||||||
|
|
||||||
|
async def stop_container(self, container_id: str, timeout: int = 10) -> None:
    """
    Stop a Docker container.

    Args:
        container_id: Container ID
        timeout: Stop timeout in seconds

    Raises:
        DockerOperationError: If container stop fails
    """
    if not self._initialized:
        await self.initialize()

    try:
        logger.info(
            "Stopping container",
            extra={"container_id": container_id, "timeout": timeout},
        )

        if not self.use_async:
            self._docker_client.containers.get(container_id).stop(timeout=timeout)
        else:
            handle = await self._docker_client.get_container(container_id)
            await self._docker_client.stop_container(handle, timeout)

        logger.info(
            "Container stopped successfully", extra={"container_id": container_id}
        )

    except Exception as e:
        logger.error(
            "Container stop failed",
            extra={"container_id": container_id, "error": str(e)},
        )
        raise DockerOperationError("stop_container", container_id, str(e))
|
||||||
|
|
||||||
|
async def remove_container(self, container_id: str, force: bool = False) -> None:
    """
    Remove a Docker container.

    Args:
        container_id: Container ID
        force: Force removal if running

    Raises:
        DockerOperationError: If container removal fails
    """
    if not self._initialized:
        await self.initialize()

    try:
        logger.info(
            "Removing container",
            extra={"container_id": container_id, "force": force},
        )

        if not self.use_async:
            self._docker_client.containers.get(container_id).remove(force=force)
        else:
            handle = await self._docker_client.get_container(container_id)
            await self._docker_client.remove_container(handle, force)

        logger.info(
            "Container removed successfully", extra={"container_id": container_id}
        )

    except Exception as e:
        logger.error(
            "Container removal failed",
            extra={"container_id": container_id, "error": str(e)},
        )
        raise DockerOperationError("remove_container", container_id, str(e))
|
||||||
|
|
||||||
|
async def get_container_info(self, container_id: str) -> Optional[ContainerInfo]:
    """
    Get information about a container.

    Args:
        container_id: Container ID

    Returns:
        ContainerInfo or None: Container information; None when the
        container is missing or inspection fails (errors are logged at
        debug level, never raised).
    """
    if not self._initialized:
        await self.initialize()

    try:
        if self.use_async:
            container_info = await self._docker_client._get_container_info(
                container_id
            )
            if container_info:
                state = container_info.get("State", {})
                config = container_info.get("Config", {})
                # NOTE(review): in `docker inspect` output "Name" is a
                # top-level field, not part of "Config" -- confirm the async
                # client's payload actually nests it here.
                return ContainerInfo(
                    container_id=container_id,
                    name=config.get("Name", "").lstrip("/"),
                    image=config.get("Image", ""),
                    status=state.get("Status", "unknown"),
                    health_status=state.get("Health", {}).get("Status"),
                )
        else:
            container = self._docker_client.containers.get(container_id)
            # Prefer the first image tag; fall back to the image ID for
            # untagged images. The sync path reports no health_status.
            return ContainerInfo(
                container_id=container.id,
                name=container.name,
                image=container.image.tags[0]
                if container.image.tags
                else container.image.id,
                status=container.status,
            )

        # Reached only on the async path when inspection returned no data.
        return None

    except Exception as e:
        logger.debug(
            "Container info retrieval failed",
            extra={"container_id": container_id, "error": str(e)},
        )
        return None
|
||||||
|
|
||||||
|
async def list_containers(
|
||||||
|
self, all: bool = False, filters: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[ContainerInfo]:
|
||||||
|
"""
|
||||||
|
List Docker containers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all: Include stopped containers
|
||||||
|
filters: Container filters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ContainerInfo]: List of container information
|
||||||
|
"""
|
||||||
|
if not self._initialized:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.use_async:
|
||||||
|
containers = await self._docker_client.list_containers(
|
||||||
|
all=all, filters=filters
|
||||||
|
)
|
||||||
|
result = []
|
||||||
|
for container in containers:
|
||||||
|
container_info = await self._docker_client._get_container_info(
|
||||||
|
container.id
|
||||||
|
)
|
||||||
|
if container_info:
|
||||||
|
state = container_info.get("State", {})
|
||||||
|
config = container_info.get("Config", {})
|
||||||
|
result.append(
|
||||||
|
ContainerInfo(
|
||||||
|
container_id=container.id,
|
||||||
|
name=config.get("Name", "").lstrip("/"),
|
||||||
|
image=config.get("Image", ""),
|
||||||
|
status=state.get("Status", "unknown"),
|
||||||
|
health_status=state.get("Health", {}).get("Status"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
containers = self._docker_client.containers.list(
|
||||||
|
all=all, filters=filters
|
||||||
|
)
|
||||||
|
result = []
|
||||||
|
for container in containers:
|
||||||
|
result.append(
|
||||||
|
ContainerInfo(
|
||||||
|
container_id=container.id,
|
||||||
|
name=container.name,
|
||||||
|
image=container.image.tags[0]
|
||||||
|
if container.image.tags
|
||||||
|
else container.image.id,
|
||||||
|
status=container.status,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Container listing failed", extra={"error": str(e)})
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_container_logs(self, container_id: str, tail: int = 100) -> str:
|
||||||
|
"""
|
||||||
|
Get container logs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
container_id: Container ID
|
||||||
|
tail: Number of log lines to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Container logs
|
||||||
|
"""
|
||||||
|
if not self._initialized:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.use_async:
|
||||||
|
container = await self._docker_client.get_container(container_id)
|
||||||
|
logs = await container.log(stdout=True, stderr=True, tail=tail)
|
||||||
|
return "\n".join(logs)
|
||||||
|
else:
|
||||||
|
container = self._docker_client.containers.get(container_id)
|
||||||
|
logs = container.logs(tail=tail).decode("utf-8")
|
||||||
|
return logs
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Container log retrieval failed",
|
||||||
|
extra={"container_id": container_id, "error": str(e)},
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def get_system_info(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get Docker system information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict or None: System information
|
||||||
|
"""
|
||||||
|
if not self._initialized:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.use_async:
|
||||||
|
return await self._docker_client.get_system_info()
|
||||||
|
else:
|
||||||
|
return self._docker_client.info()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("System info retrieval failed", extra={"error": str(e)})
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Context manager support
|
||||||
|
async def __aenter__(self):
|
||||||
|
await self.initialize()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
await self.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
class MockDockerService(DockerService):
    """
    Mock Docker service for testing without actual Docker.

    Provides the same interface but with in-memory operations.
    """

    def __init__(self):
        super().__init__(use_async=False)
        # In-memory registry of fake containers keyed by generated id.
        self._containers: Dict[str, ContainerInfo] = {}
        self._next_id = 1

    async def initialize(self) -> None:
        """Mock initialization - always succeeds."""
        self._initialized = True
        logger.info("Mock Docker service initialized")

    async def shutdown(self) -> None:
        """Mock shutdown."""
        self._containers.clear()
        self._initialized = False
        logger.info("Mock Docker service shutdown")

    async def ping(self) -> bool:
        """Mock ping - always succeeds."""
        return True

    async def create_container(self, name: str, image: str, **kwargs) -> ContainerInfo:
        """Mock container creation."""
        container_id = f"mock-{self._next_id}"
        self._next_id += 1

        info = ContainerInfo(
            container_id=container_id, name=name, image=image, status="created"
        )
        self._containers[container_id] = info
        logger.info(
            "Mock container created",
            extra={
                "container_id": container_id,
                "container_name": name,
                "image": image,
            },
        )
        return info

    async def start_container(self, container_id: str) -> None:
        """Mock container start."""
        info = self._containers.get(container_id)
        if info is not None:
            info.status = "running"
            logger.info("Mock container started", extra={"container_id": container_id})

    async def stop_container(self, container_id: str, timeout: int = 10) -> None:
        """Mock container stop."""
        info = self._containers.get(container_id)
        if info is not None:
            info.status = "exited"
            logger.info("Mock container stopped", extra={"container_id": container_id})

    async def remove_container(self, container_id: str, force: bool = False) -> None:
        """Mock container removal."""
        if self._containers.pop(container_id, None) is not None:
            logger.info("Mock container removed", extra={"container_id": container_id})

    async def get_container_info(self, container_id: str) -> Optional[ContainerInfo]:
        """Mock container info retrieval."""
        return self._containers.get(container_id)

    async def list_containers(
        self, all: bool = False, filters: Optional[Dict[str, Any]] = None
    ) -> List[ContainerInfo]:
        """Mock container listing (ignores ``all`` and ``filters``)."""
        return list(self._containers.values())
|
||||||
252
session-manager/host_ip_detector.py
Normal file
252
session-manager/host_ip_detector.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""
|
||||||
|
Host IP Detection Utilities
|
||||||
|
|
||||||
|
Provides robust methods to detect the Docker host IP from within a container,
|
||||||
|
supporting multiple Docker environments and network configurations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Optional, List
|
||||||
|
from functools import lru_cache
|
||||||
|
import time
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HostIPDetector:
    """Detects the Docker host IP address from container perspective.

    Tries several detection strategies in order (host.docker.internal,
    environment variables, the kernel route table, the DOCKER_HOST URL,
    and finally well-known gateway IPs) and caches the result with a
    time-based TTL.
    """

    # Common Docker gateway IPs to try as fallbacks
    COMMON_GATEWAYS = [
        "172.17.0.1",  # Default Docker bridge
        "172.18.0.1",  # Docker networks
        "192.168.65.1",  # Docker Desktop
        "192.168.66.1",  # Alternative Docker Desktop
    ]

    def __init__(self):
        # Last successfully detected IP and when it was detected.
        self._detected_ip: Optional[str] = None
        self._last_detection: float = 0
        self._cache_timeout: float = 300  # 5 minutes cache

    def detect_host_ip(self) -> str:
        """
        Detect the Docker host IP using multiple methods with fallbacks.

        BUG FIX: the previous ``@lru_cache(maxsize=1)`` decorator was removed.
        On an instance method it cached the first result forever — defeating
        the time-based cache below and ``reset_host_ip_cache()`` — and kept
        the instance alive through the cache (flake8-bugbear B019). The
        TTL cache below now actually expires after ``_cache_timeout`` seconds.

        Returns:
            str: The detected host IP address

        Raises:
            RuntimeError: If no host IP can be detected
        """
        current_time = time.time()

        # Use cached result if recent
        if (
            self._detected_ip
            and (current_time - self._last_detection) < self._cache_timeout
        ):
            logger.debug(f"Using cached host IP: {self._detected_ip}")
            return self._detected_ip

        logger.info("Detecting Docker host IP...")

        # Ordered from most reliable/explicit to most speculative.
        detection_methods = [
            self._detect_via_docker_internal,
            self._detect_via_gateway_env,
            self._detect_via_route_table,
            self._detect_via_network_connect,
            self._detect_via_common_gateways,
        ]

        for method in detection_methods:
            try:
                ip = method()
                if ip and self._validate_ip(ip):
                    logger.info(
                        f"Successfully detected host IP using {method.__name__}: {ip}"
                    )
                    self._detected_ip = ip
                    self._last_detection = current_time
                    return ip
                else:
                    logger.debug(f"Method {method.__name__} returned invalid IP: {ip}")
            except Exception as e:
                # A failing method is expected; fall through to the next one.
                logger.debug(f"Method {method.__name__} failed: {e}")

        # If all methods fail, raise an error
        raise RuntimeError(
            "Could not detect Docker host IP. Tried all detection methods. "
            "Please check your Docker network configuration or set HOST_IP environment variable."
        )

    def _detect_via_docker_internal(self) -> Optional[str]:
        """Detect via host.docker.internal (Docker Desktop, Docker for Mac/Windows)."""
        try:
            # Try to resolve host.docker.internal
            ip = socket.gethostbyname("host.docker.internal")
            if ip != "127.0.0.1":  # Make sure it's not localhost
                return ip
        except socket.gaierror:
            pass

        return None

    def _detect_via_gateway_env(self) -> Optional[str]:
        """Detect via Docker gateway environment variables.

        Returns the first non-empty value among the candidate variables,
        without validation (the caller validates).
        """
        gateway_vars = [
            "DOCKER_HOST_GATEWAY",
            "GATEWAY",
            "HOST_IP",
        ]

        for var in gateway_vars:
            ip = os.getenv(var)
            if ip:
                logger.debug(f"Found host IP in environment variable {var}: {ip}")
                return ip

        return None

    def _detect_via_route_table(self) -> Optional[str]:
        """Detect via Linux route table (/proc/net/route).

        Finds the default route (mask ``00000000``) and converts its
        little-endian hex gateway field into dotted-quad form.
        """
        try:
            with open("/proc/net/route", "r") as f:
                for line in f:
                    fields = line.strip().split()
                    if (
                        len(fields) >= 8
                        and fields[0] != "Iface"
                        and fields[7] == "00000000"
                    ):
                        # Found default route, convert hex gateway to IP
                        gateway_hex = fields[2]
                        if len(gateway_hex) == 8:
                            # Convert from hex to IP (little endian)
                            ip_parts = []
                            for i in range(0, 8, 2):
                                ip_parts.append(str(int(gateway_hex[i : i + 2], 16)))
                            ip = ".".join(reversed(ip_parts))
                            if ip != "0.0.0.0":
                                return ip
        except (IOError, ValueError, IndexError) as e:
            logger.debug(f"Failed to read route table: {e}")

        return None

    def _detect_via_network_connect(self) -> Optional[str]:
        """Detect by resolving the host named in the DOCKER_HOST URL."""
        try:
            # Use the Docker daemon address itself as a reference point.
            docker_host = os.getenv("DOCKER_HOST", "tcp://host.docker.internal:2376")

            if docker_host.startswith("tcp://"):
                host_part = docker_host[6:].split(":")[0]
                if host_part not in ["localhost", "127.0.0.1"]:
                    # Try to resolve the host
                    try:
                        ip = socket.gethostbyname(host_part)
                        if ip != "127.0.0.1":
                            return ip
                    except socket.gaierror:
                        pass
        except Exception as e:
            logger.debug(f"Network connect detection failed: {e}")

        return None

    def _detect_via_common_gateways(self) -> Optional[str]:
        """Try common Docker gateway IPs, returning the first reachable one."""
        for gateway in self.COMMON_GATEWAYS:
            if self._test_ip_connectivity(gateway):
                logger.debug(f"Found working gateway: {gateway}")
                return gateway

        return None

    def _test_ip_connectivity(self, ip: str) -> bool:
        """Test if an IP address is reachable (TCP connect to port 22)."""
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(1.0)
            result = sock.connect_ex((ip, 22))  # SSH port, commonly available
            sock.close()
            return result == 0
        except Exception:
            return False

    def _validate_ip(self, ip: str) -> bool:
        """Validate that the IP address is a plausible Docker host address.

        Rejects localhost, 0.0.0.0, malformed strings, and anything outside
        the 10.x / 172.x / 192.x ranges where Docker gateways live.
        """
        try:
            socket.inet_aton(ip)
            # Basic validation - should not be localhost or invalid ranges
            if ip.startswith("127."):
                return False
            if ip == "0.0.0.0":
                return False
            # Should be a private IP range
            parts = ip.split(".")
            if len(parts) != 4:
                return False
            first_octet = int(parts[0])
            # Common Docker gateway ranges
            return first_octet in [10, 172, 192]
        except socket.error:
            return False

    async def async_detect_host_ip(self) -> str:
        """Async version of detect_host_ip for testing.

        Runs the blocking detection in a worker thread so the event loop
        stays responsive.
        """
        import concurrent.futures

        # FIX: get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() here has been deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return await loop.run_in_executor(executor, self.detect_host_ip)
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance for caching
|
||||||
|
_host_detector = HostIPDetector()
|
||||||
|
|
||||||
|
|
||||||
|
def get_host_ip() -> str:
    """
    Get the Docker host IP address from container perspective.

    Delegates to the module-level detector, which caches the result and
    walks several detection strategies with fallbacks for different
    Docker environments.

    Returns:
        str: The detected host IP address

    Raises:
        RuntimeError: If host IP detection fails
    """
    return _host_detector.detect_host_ip()
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_host_ip() -> str:
    """
    Async version of get_host_ip for use in async contexts.

    The detection itself is synchronous (DNS lookups, socket connects),
    so it is dispatched to the event loop's default thread pool to avoid
    blocking the loop.

    Returns:
        str: The detected host IP address

    Raises:
        RuntimeError: If host IP detection fails
    """
    # FIX: asyncio.get_event_loop() inside a coroutine is deprecated since
    # Python 3.10 — get_running_loop() is the supported call. Using the
    # loop's default executor also avoids creating (and tearing down) a
    # fresh ThreadPoolExecutor on every invocation.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, get_host_ip)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_host_ip_cache():
    """Reset the cached host IP detection result.

    Swaps in a fresh detector instance so the next ``get_host_ip()`` call
    re-runs the full detection sequence.
    """
    global _host_detector
    _host_detector = HostIPDetector()
|
||||||
182
session-manager/http_pool.py
Normal file
182
session-manager/http_pool.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
"""
|
||||||
|
HTTP Connection Pool Manager
|
||||||
|
|
||||||
|
Provides a global httpx.AsyncClient instance with connection pooling
|
||||||
|
to eliminate the overhead of creating new HTTP clients for each proxy request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPConnectionPool:
    """Global HTTP connection pool manager for proxy operations.

    Wraps a single long-lived ``httpx.AsyncClient`` with keep-alive
    connection pooling, periodic health checking, and lazy (re)creation.
    """

    def __init__(self):
        self._client: Optional[httpx.AsyncClient] = None
        self._last_health_check: float = 0
        self._health_check_interval: float = 60  # Check health every 60 seconds
        self._is_healthy: bool = True
        # Serializes client creation/teardown across concurrent callers.
        self._reconnect_lock = asyncio.Lock()

        # Connection pool configuration
        self._config = {
            "limits": httpx.Limits(
                max_keepalive_connections=20,  # Keep connections alive
                max_connections=100,  # Max total connections
                keepalive_expiry=300.0,  # Keep connections alive for 5 minutes
            ),
            "timeout": httpx.Timeout(
                connect=10.0,  # Connection timeout
                read=30.0,  # Read timeout
                write=10.0,  # Write timeout
                pool=5.0,  # Pool timeout
            ),
            "follow_redirects": False,
            "http2": False,  # Disable HTTP/2 for simplicity
        }

    async def __aenter__(self):
        await self.ensure_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Deliberately keep the client alive — the pool is shared globally
        # and closed explicitly via close() / shutdown_http_pool().
        pass

    async def ensure_client(self) -> None:
        """Ensure the HTTP client is initialized and healthy."""
        if self._client is None:
            await self._create_client()

        # Periodic health check
        current_time = time.time()
        if current_time - self._last_health_check > self._health_check_interval:
            if not await self._check_client_health():
                logger.warning("HTTP client health check failed, recreating client")
                await self._recreate_client()
            self._last_health_check = current_time

    async def _create_client(self) -> None:
        """Create a new HTTP client with connection pooling."""
        async with self._reconnect_lock:
            if self._client:
                await self._client.aclose()

            self._client = httpx.AsyncClient(**self._config)
            self._is_healthy = True
            logger.info("HTTP connection pool client created")

    async def _recreate_client(self) -> None:
        """Recreate the HTTP client (used when health check fails)."""
        logger.info("Recreating HTTP connection pool client")
        await self._create_client()

    async def _check_client_health(self) -> bool:
        """Check if the HTTP client is still healthy."""
        if not self._client:
            return False

        try:
            # NOTE(review): this is a stub — it only reports the flag set by
            # previous request failures; a real implementation would probe a
            # known-good health endpoint.
            return self._is_healthy
        except Exception as e:
            logger.warning(f"HTTP client health check error: {e}")
            return False

    async def request(self, method: str, url: str, **kwargs) -> httpx.Response:
        """Make an HTTP request using the connection pool.

        Raises:
            RuntimeError: If the client could not be initialized.
            httpx.HTTPError: Propagated from the underlying request.
        """
        await self.ensure_client()

        if not self._client:
            raise RuntimeError("HTTP client not available")

        try:
            return await self._client.request(method, url, **kwargs)
        except (httpx.ConnectError, httpx.ConnectTimeout, httpx.PoolTimeout) as e:
            # Connection-related errors - client might be unhealthy
            logger.warning(f"Connection error, marking client as unhealthy: {e}")
            self._is_healthy = False
            raise
        # FIX: removed a dead ``except Exception as e: raise`` clause —
        # it bound an unused name and re-raised unchanged (a no-op).

    async def close(self) -> None:
        """Close the HTTP client and cleanup resources."""
        async with self._reconnect_lock:
            if self._client:
                await self._client.aclose()
                self._client = None
                self._is_healthy = False
                logger.info("HTTP connection pool client closed")

    async def get_pool_stats(self) -> Dict[str, Any]:
        """Get connection pool statistics."""
        if not self._client:
            return {"status": "not_initialized"}

        # httpx doesn't expose detailed pool stats, but we can provide basic info
        return {
            "status": "healthy" if self._is_healthy else "unhealthy",
            "last_health_check": self._last_health_check,
            "config": {
                "max_keepalive_connections": self._config[
                    "limits"
                ].max_keepalive_connections,
                "max_connections": self._config["limits"].max_connections,
                "keepalive_expiry": self._config["limits"].keepalive_expiry,
                "connect_timeout": self._config["timeout"].connect,
                "read_timeout": self._config["timeout"].read,
            },
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Global HTTP connection pool instance
|
||||||
|
_http_pool = HTTPConnectionPool()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def get_http_client():
    """Context manager for getting the global HTTP client.

    Yields the shared pool; entering it ensures the underlying client
    exists, and exiting leaves the client open for reuse.
    """
    async with _http_pool:
        yield _http_pool
|
||||||
|
|
||||||
|
|
||||||
|
async def make_http_request(method: str, url: str, **kwargs) -> httpx.Response:
    """Make an HTTP request using the global connection pool."""
    async with get_http_client() as pool:
        return await pool.request(method, url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_connection_pool_stats() -> Dict[str, Any]:
    """Get connection pool statistics from the global pool."""
    return await _http_pool.get_pool_stats()
|
||||||
|
|
||||||
|
|
||||||
|
async def close_connection_pool() -> None:
    """Close the global connection pool (for cleanup)."""
    await _http_pool.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Lifecycle management for FastAPI
|
||||||
|
async def init_http_pool() -> None:
    """Initialize the HTTP connection pool on startup (FastAPI lifespan hook)."""
    logger.info("Initializing HTTP connection pool")
    await _http_pool.ensure_client()
|
||||||
|
|
||||||
|
|
||||||
|
async def shutdown_http_pool() -> None:
    """Shutdown the HTTP connection pool on shutdown (FastAPI lifespan hook)."""
    logger.info("Shutting down HTTP connection pool")
    await _http_pool.close()
|
||||||
317
session-manager/logging_config.py
Normal file
317
session-manager/logging_config.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
"""
|
||||||
|
Structured Logging Configuration
|
||||||
|
|
||||||
|
Provides comprehensive logging infrastructure with structured logging,
|
||||||
|
request tracking, log formatting, and aggregation capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import logging.handlers
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredFormatter(logging.Formatter):
    """Structured JSON formatter for production logging."""

    # Attributes that belong to the LogRecord machinery itself and must not
    # be copied into the structured payload.
    _RESERVED = [
        "name",
        "msg",
        "args",
        "levelname",
        "levelno",
        "pathname",
        "filename",
        "module",
        "exc_info",
        "exc_text",
        "stack_info",
        "lineno",
        "funcName",
        "created",
        "msecs",
        "relativeCreated",
        "thread",
        "threadName",
        "processName",
        "process",
        "message",
    ]

    def format(self, record: logging.LogRecord) -> str:
        """Serialize *record* as a single JSON object."""
        entry = {
            "timestamp": datetime.utcnow().isoformat() + "Z",
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Add exception info if present
        if record.exc_info:
            entry["exception"] = self.formatException(record.exc_info)

        # Well-known context fields promoted explicitly (kept first in order).
        for attr in (
            "request_id",
            "session_id",
            "user_id",
            "operation",
            "duration_ms",
            "status_code",
        ):
            if hasattr(record, attr):
                entry[attr] = getattr(record, attr)

        # Copy any remaining custom attributes attached via ``extra=``.
        for key, value in record.__dict__.items():
            if key not in self._RESERVED:
                entry[key] = value

        # default=str keeps non-serializable extras from crashing logging.
        return json.dumps(entry, default=str)
|
||||||
|
|
||||||
|
|
||||||
|
class HumanReadableFormatter(logging.Formatter):
    """Human-readable formatter for development.

    When a record carries a ``request_id`` attribute, it is rendered in
    square brackets between the location and the message.
    """

    _BASE_FMT = (
        "%(asctime)s [%(levelname)8s] %(name)s:%(funcName)s:%(lineno)d - %(message)s"
    )
    _REQUEST_FMT = (
        "%(asctime)s [%(levelname)8s] %(name)s:%(funcName)s:%(lineno)d"
        " [%(request_id)s] - %(message)s"
    )

    def __init__(self):
        super().__init__(
            fmt=self._BASE_FMT,
            datefmt="%Y-%m-%d %H:%M:%S",
        )

    def format(self, record: logging.LogRecord) -> str:
        # BUG FIX: in Python 3, logging.Formatter reads the format string
        # from ``self._style._fmt`` — the old code assigned only
        # ``self._fmt``, so the request-id variant was never applied.
        fmt = self._REQUEST_FMT if hasattr(record, "request_id") else self._BASE_FMT
        self._fmt = fmt
        self._style._fmt = fmt
        return super().format(record)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestContext:
    """Context manager for request-scoped logging.

    Stores the current request id in thread-local storage; nesting
    restores the previous id on exit.
    """

    _local = threading.local()

    def __init__(self, request_id: Optional[str] = None):
        # Use the supplied id, or mint a short random one.
        self.request_id = request_id or str(uuid.uuid4())[:8]
        self._old_request_id = getattr(self._local, "request_id", None)

    def __enter__(self):
        self._local.request_id = self.request_id
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the previous id (nested contexts) or clear it entirely.
        if self._old_request_id is None:
            delattr(self._local, "request_id")
        else:
            self._local.request_id = self._old_request_id

    @classmethod
    def get_current_request_id(cls) -> Optional[str]:
        """Get the current request ID from thread local storage."""
        return getattr(cls._local, "request_id", None)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestAdapter(logging.LoggerAdapter):
    """Logger adapter that automatically adds request context."""

    def __init__(self, logger: logging.Logger):
        super().__init__(logger, {})

    def process(self, msg: str, kwargs: Any) -> tuple:
        """Inject the current request id (if any) into the record's extras."""
        rid = RequestContext.get_current_request_id()
        if rid:
            extra = kwargs.setdefault("extra", {})
            extra["request_id"] = rid
        return msg, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(
    level: str = "INFO",
    format_type: str = "auto",  # "json", "human", or "auto"
    log_file: Optional[str] = None,
    max_file_size: int = 10 * 1024 * 1024,  # 10MB
    backup_count: int = 5,
    enable_console: bool = True,
    enable_file: bool = True,
) -> "RequestAdapter":
    """
    Set up comprehensive logging configuration.

    Args:
        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        format_type: Log format - "json", "human", or "auto" (detects from environment)
        log_file: Path to log file (optional)
        max_file_size: Maximum log file size in bytes
        backup_count: Number of backup files to keep
        enable_console: Enable console logging
        enable_file: Enable file logging

    Returns:
        RequestAdapter wrapping the configured root logger.
        (FIX: the annotation previously claimed ``logging.Logger``, but the
        function has always returned a ``RequestAdapter``.)
    """
    # Determine format type
    if format_type == "auto":
        # Use JSON for production, human-readable for development
        format_type = "json" if os.getenv("ENVIRONMENT") == "production" else "human"

    # Clear existing handlers
    root_logger = logging.getLogger()
    root_logger.handlers.clear()

    # Set log level
    numeric_level = getattr(logging, level.upper(), logging.INFO)
    root_logger.setLevel(numeric_level)

    # Create formatters
    if format_type == "json":
        formatter = StructuredFormatter()
    else:
        formatter = HumanReadableFormatter()

    # Console handler
    if enable_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(numeric_level)
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

    # File handler with rotation
    if enable_file and log_file:
        # Ensure log directory exists
        log_path = Path(log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)

        file_handler = logging.handlers.RotatingFileHandler(
            log_file, maxBytes=max_file_size, backupCount=backup_count, encoding="utf-8"
        )
        file_handler.setLevel(numeric_level)
        file_handler.setFormatter(StructuredFormatter())  # Always use JSON for files
        root_logger.addHandler(file_handler)

    # Create request adapter for the root logger
    adapter = RequestAdapter(root_logger)

    # Configure third-party loggers
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("docker").setLevel(logging.WARNING)
    logging.getLogger("aiodeocker").setLevel(logging.WARNING)
    logging.getLogger("asyncio").setLevel(logging.WARNING)

    return adapter
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str) -> RequestAdapter:
    """Return the named logger wrapped in a RequestAdapter.

    The adapter lets callers attach per-request context to every record.
    """
    return RequestAdapter(logging.getLogger(name))
|
||||||
|
|
||||||
|
|
||||||
|
def log_performance(operation: str, duration_ms: float, **kwargs) -> None:
    """Emit an INFO record describing how long *operation* took.

    Any extra keyword arguments are attached to the structured record.
    """
    fields = {"operation": operation, "duration_ms": duration_ms}
    fields.update(kwargs)
    get_logger(__name__).info(
        f"Performance: {operation} completed in {duration_ms:.2f}ms", extra=fields
    )
|
||||||
|
|
||||||
|
|
||||||
|
def log_request(
    method: str, path: str, status_code: int, duration_ms: float, **kwargs
) -> None:
    """Record one HTTP request outcome as a structured log entry.

    Requests with status >= 400 are logged at WARNING, everything else at INFO.
    Extra keyword arguments are merged into the structured fields (and may
    override the standard ones, matching dict-unpacking semantics).
    """
    fields = {
        "operation": "http_request",
        "method": method,
        "path": path,
        "status_code": status_code,
        "duration_ms": duration_ms,
    }
    fields.update(kwargs)
    severity = logging.WARNING if status_code >= 400 else logging.INFO
    get_logger(__name__).log(
        severity,
        f"HTTP {method} {path} -> {status_code} ({duration_ms:.2f}ms)",
        extra=fields,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def log_session_operation(session_id: str, operation: str, **kwargs) -> None:
    """Record a session lifecycle operation at INFO with structured fields."""
    fields = {"session_id": session_id, "operation": operation}
    fields.update(kwargs)
    get_logger(__name__).info(f"Session {operation}: {session_id}", extra=fields)
|
||||||
|
|
||||||
|
|
||||||
|
def log_security_event(event: str, severity: str = "info", **kwargs) -> None:
    """Record a security-relevant event.

    *severity* is a logging level name ("info", "warning", ...); unknown names
    fall back to INFO.
    """
    fields = {"security_event": event, "severity": severity}
    fields.update(kwargs)
    log_level = getattr(logging, severity.upper(), logging.INFO)
    get_logger(__name__).log(log_level, f"Security: {event}", extra=fields)
|
||||||
|
|
||||||
|
|
||||||
|
# Global logger instance
# Module-level adapter used by init_logging() below and available to importers.
logger = get_logger(__name__)

# Initialize logging on import
# Guard flag making init_logging() idempotent when this module is imported
# from several places.
_setup_complete = False
|
||||||
|
|
||||||
|
|
||||||
|
def init_logging():
    """Initialize the logging system from environment variables.

    Idempotent: the module-level ``_setup_complete`` flag makes repeated calls
    no-ops. Configuration is read entirely from the environment (LOG_LEVEL,
    LOG_FORMAT, LOG_FILE, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_CONSOLE,
    LOG_FILE_ENABLED) and passed through to setup_logging().
    """
    global _setup_complete
    if _setup_complete:
        return

    # Configuration from environment
    env = os.getenv
    level = env("LOG_LEVEL", "INFO")
    format_type = env("LOG_FORMAT", "auto")  # json, human, auto
    log_file = env("LOG_FILE")
    max_file_size = int(env("LOG_MAX_SIZE_MB", "10")) * 1024 * 1024
    backup_count = int(env("LOG_BACKUP_COUNT", "5"))
    enable_console = env("LOG_CONSOLE", "true").lower() == "true"
    # File logging needs both the feature flag and an actual target path.
    enable_file = (
        env("LOG_FILE_ENABLED", "true").lower() == "true" and log_file is not None
    )

    setup_logging(
        level=level,
        format_type=format_type,
        log_file=log_file,
        max_file_size=max_file_size,
        backup_count=backup_count,
        enable_console=enable_console,
        enable_file=enable_file,
    )

    logger.info(
        "Structured logging initialized",
        extra={
            "level": level,
            "format": format_type,
            "log_file": log_file,
            "max_file_size_mb": max_file_size // (1024 * 1024),
            "backup_count": backup_count,
        },
    )

    _setup_complete = True
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize on import
# Importing this module configures the root logger immediately; init_logging()
# is idempotent, so later explicit calls are harmless.
init_logging()
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,9 @@
|
|||||||
fastapi==0.104.1
|
fastapi==0.104.1
|
||||||
uvicorn==0.24.0
|
uvicorn==0.24.0
|
||||||
docker>=7.1.0
|
docker>=7.1.0
|
||||||
|
aiodocker>=0.21.0
|
||||||
|
asyncpg>=0.29.0
|
||||||
pydantic==2.5.0
|
pydantic==2.5.0
|
||||||
python-multipart==0.0.6
|
python-multipart==0.0.6
|
||||||
httpx==0.25.2
|
httpx==0.25.2
|
||||||
|
psutil>=5.9.0
|
||||||
248
session-manager/resource_manager.py
Normal file
248
session-manager/resource_manager.py
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
"""
|
||||||
|
Resource Management and Monitoring Utilities
|
||||||
|
|
||||||
|
Provides validation, enforcement, and monitoring of container resource limits
|
||||||
|
to prevent resource exhaustion attacks and ensure fair resource allocation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import psutil
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ResourceLimits:
    """Container resource limits configuration.

    Attributes:
        memory_limit: Docker-style memory limit, an integer followed by an
            optional unit suffix: e.g. "4g", "512m", "256k", or a plain byte
            count such as "1048576".
        cpu_quota: CPU quota in microseconds per scheduling period.
        cpu_period: CPU scheduling period in microseconds.
    """

    memory_limit: str  # e.g., "4g", "512m"
    cpu_quota: int  # CPU quota in microseconds
    cpu_period: int  # CPU period in microseconds

    def validate(self) -> Tuple[bool, str]:
        """Validate resource limits configuration.

        Returns:
            Tuple[bool, str]: (is_valid, reason); reason is "Valid" on success.
        """
        # Validate memory limit format: the numeric part must be a plain
        # integer. The previous suffix-only check accepted values like "g" or
        # "1.5g" that later crashed parse_memory_limit with ValueError.
        memory_limit = self.memory_limit.lower()
        digits = (
            memory_limit[:-1]
            if memory_limit.endswith(("g", "m", "k"))
            else memory_limit
        )
        if not digits.isdigit():
            return (
                False,
                f"Invalid memory limit format: {self.memory_limit}. Use format like '4g', '512m', '256k'",
            )

        # Validate CPU quota and period
        if self.cpu_quota <= 0:
            return False, f"CPU quota must be positive, got {self.cpu_quota}"
        if self.cpu_period <= 0:
            return False, f"CPU period must be positive, got {self.cpu_period}"
        if self.cpu_quota > self.cpu_period:
            return (
                False,
                f"CPU quota ({self.cpu_quota}) cannot exceed CPU period ({self.cpu_period})",
            )

        return True, "Valid"

    def to_docker_limits(self) -> Dict[str, Any]:
        """Convert to keyword arguments accepted by the Docker container API."""
        # Fixed annotation: the original used the builtin function `any`
        # instead of typing.Any.
        return {
            "mem_limit": self.memory_limit,
            "cpu_quota": self.cpu_quota,
            "cpu_period": self.cpu_period,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceMonitor:
    """Monitor system and container resource usage."""

    def __init__(self):
        # NOTE(review): neither field is read by this class yet — presumably
        # reserved for future alert rate-limiting; confirm before removing.
        self._last_check = datetime.now()
        self._alerts_sent = set()  # Track alerts to prevent spam

    def get_system_resources(self) -> Dict[str, any]:
        """Return a snapshot of host memory/CPU usage, or {} on failure."""
        try:
            vm = psutil.virtual_memory()
            cpu_pct = psutil.cpu_percent(interval=1)  # samples for ~1 second
            gib = 1024**3
            return {
                "memory_percent": vm.percent / 100.0,
                "memory_used_gb": vm.used / gib,
                "memory_total_gb": vm.total / gib,
                "cpu_percent": cpu_pct / 100.0,
                "cpu_count": psutil.cpu_count(),
            }
        except Exception as e:
            logger.warning(f"Failed to get system resources: {e}")
            return {}

    def check_resource_limits(
        self, limits: ResourceLimits, warning_thresholds: Dict[str, float]
    ) -> Dict[str, any]:
        """Compare current usage against warning thresholds and collect alerts."""
        snapshot = self.get_system_resources()
        alerts = []

        # Each entry: (alert type, current usage fraction, threshold, message template)
        checks = (
            (
                "memory",
                snapshot.get("memory_percent", 0),
                warning_thresholds.get("memory", 0.8),
                "System memory usage at {:.1%}",
            ),
            (
                "cpu",
                snapshot.get("cpu_percent", 0),
                warning_thresholds.get("cpu", 0.9),
                "System CPU usage at {:.1%}",
            ),
        )
        for kind, usage, threshold, template in checks:
            if usage >= threshold:
                alerts.append(
                    {
                        "type": kind,
                        # Usage at/above 95% escalates from warning to critical.
                        "level": "warning" if usage < 0.95 else "critical",
                        "message": template.format(usage),
                        "current": usage,
                        "threshold": threshold,
                    }
                )

        return {
            "system_resources": snapshot,
            "alerts": alerts,
            "timestamp": datetime.now(),
        }

    def should_throttle_sessions(self, resource_check: Dict) -> Tuple[bool, str]:
        """Decide whether new sessions should be refused, given a check result."""
        alerts = resource_check.get("alerts", [])

        # Any critical alert throttles immediately.
        critical = [a for a in alerts if a["level"] == "critical"]
        if critical:
            return (
                True,
                f"Critical resource usage: {[a['message'] for a in critical]}",
            )

        # Two or more simultaneous warnings also throttle.
        warnings = [a for a in alerts if a["level"] == "warning"]
        if len(warnings) >= 2:
            return (
                True,
                f"Multiple resource warnings: {[a['message'] for a in warnings]}",
            )

        return False, "Resources OK"
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceValidator:
    """Validate and parse resource limit configurations."""

    @staticmethod
    def parse_memory_limit(memory_str: str) -> Tuple[int, str]:
        """Parse a memory limit string into a byte count and unit label.

        Args:
            memory_str: Limit such as "4g", "512m", "256k", or a plain byte
                count like "1048576". Case-insensitive; surrounding whitespace
                is ignored.

        Returns:
            Tuple[int, str]: (size in bytes, unit label: "GB"/"MB"/"KB"/"bytes").
            (The original docstring claimed a bare byte count was returned.)

        Raises:
            ValueError: If the string is empty/whitespace, non-numeric, or
                non-positive.
        """
        # Reject empty and whitespace-only input up front with a clear message.
        if not memory_str or not memory_str.strip():
            raise ValueError("Memory limit cannot be empty")

        memory_str = memory_str.lower().strip()

        # Handle different units
        if memory_str.endswith("g"):
            bytes_val = int(memory_str[:-1]) * (1024**3)
            unit = "GB"
        elif memory_str.endswith("m"):
            bytes_val = int(memory_str[:-1]) * (1024**2)
            unit = "MB"
        elif memory_str.endswith("k"):
            bytes_val = int(memory_str[:-1]) * 1024
            unit = "KB"
        else:
            # Assume bytes if no unit
            bytes_val = int(memory_str)
            unit = "bytes"

        if bytes_val <= 0:
            raise ValueError(f"Memory limit must be positive, got {bytes_val}")

        # Reasonable limits check — allowed, but almost certainly a typo.
        if bytes_val > 32 * (1024**3):  # 32GB
            logger.warning(f"Very high memory limit: {bytes_val} bytes")

        return bytes_val, unit

    @staticmethod
    def validate_resource_config(
        config: Dict[str, Any],
    ) -> Tuple[bool, str, Optional["ResourceLimits"]]:
        # "ResourceLimits" is a forward reference so the annotation does not
        # depend on definition order at import time.
        """Validate a complete resource configuration mapping.

        Args:
            config: Mapping with optional keys "memory_limit", "cpu_quota",
                "cpu_period"; missing keys fall back to defaults.

        Returns:
            Tuple[bool, str, Optional[ResourceLimits]]: (is_valid, message,
            parsed limits on success or None on failure).
        """
        try:
            limits = ResourceLimits(
                memory_limit=config.get("memory_limit", "4g"),
                cpu_quota=config.get("cpu_quota", 100000),
                cpu_period=config.get("cpu_period", 100000),
            )

            valid, message = limits.validate()
            if not valid:
                return False, message, None

            # Additional validation: make sure the limit actually parses.
            memory_bytes, _ = ResourceValidator.parse_memory_limit(limits.memory_limit)

            # Warn about potentially problematic configurations
            if memory_bytes < 128 * (1024**2):  # Less than 128MB
                logger.warning("Very low memory limit may cause container instability")

            return True, "Configuration valid", limits

        except (ValueError, TypeError) as e:
            return False, f"Invalid configuration: {e}", None
|
||||||
|
|
||||||
|
|
||||||
|
# Global instances
# Shared monitor used by the module-level helper functions below.
resource_monitor = ResourceMonitor()
|
||||||
|
|
||||||
|
|
||||||
|
def get_resource_limits() -> ResourceLimits:
    """Build and validate container resource limits from the environment.

    Reads CONTAINER_MEMORY_LIMIT, CONTAINER_CPU_QUOTA and CONTAINER_CPU_PERIOD.

    Raises:
        ValueError: If the environment supplies an invalid configuration.
    """
    env = os.getenv
    valid, message, limits = ResourceValidator.validate_resource_config(
        {
            "memory_limit": env("CONTAINER_MEMORY_LIMIT", "4g"),
            "cpu_quota": int(env("CONTAINER_CPU_QUOTA", "100000")),
            "cpu_period": int(env("CONTAINER_CPU_PERIOD", "100000")),
        }
    )
    if not valid or limits is None:
        raise ValueError(f"Resource configuration error: {message}")

    logger.info(
        f"Using resource limits: memory={limits.memory_limit}, cpu_quota={limits.cpu_quota}"
    )
    return limits
|
||||||
|
|
||||||
|
|
||||||
|
def check_system_resources() -> Dict[str, any]:
    """Run a one-shot system resource check against env-configured thresholds."""
    limits = get_resource_limits()
    thresholds = {
        "memory": float(os.getenv("MEMORY_WARNING_THRESHOLD", "0.8")),
        "cpu": float(os.getenv("CPU_WARNING_THRESHOLD", "0.9")),
    }
    return resource_monitor.check_resource_limits(limits, thresholds)
|
||||||
|
|
||||||
|
|
||||||
|
def should_throttle_sessions() -> Tuple[bool, str]:
    """Return (throttle?, reason) for admitting new sessions right now."""
    return resource_monitor.should_throttle_sessions(check_system_resources())
|
||||||
235
session-manager/session_auth.py
Normal file
235
session-manager/session_auth.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
"""
|
||||||
|
Token-Based Authentication for OpenCode Sessions
|
||||||
|
|
||||||
|
Provides secure token generation, validation, and management for individual
|
||||||
|
user sessions to prevent unauthorized access to OpenCode servers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import secrets
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionTokenManager:
    """Manages authentication tokens for OpenCode user sessions.

    Tokens are random bearer secrets held in process memory and compared in
    constant time during validation to avoid timing side channels.
    """

    def __init__(self):
        # Token storage - in production, this should be in Redis/database
        # (state is process-local and lost on restart).
        self._session_tokens: Dict[str, Dict] = {}

        # Token configuration
        self._token_length = int(os.getenv("SESSION_TOKEN_LENGTH", "32"))
        self._token_expiry_hours = int(os.getenv("SESSION_TOKEN_EXPIRY_HOURS", "24"))
        # NOTE(review): the secret is initialized but never referenced by any
        # method in this class — presumably reserved for HMAC-signed tokens;
        # confirm before removing.
        self._token_secret = os.getenv("SESSION_TOKEN_SECRET", self._generate_secret())

        # Cleanup configuration
        # NOTE(review): not read inside this class — confirm an external
        # scheduler consumes it.
        self._cleanup_interval_minutes = int(
            os.getenv("TOKEN_CLEANUP_INTERVAL_MINUTES", "60")
        )

    def _generate_secret(self) -> str:
        """Generate a secure secret for token signing."""
        return secrets.token_hex(32)

    def generate_session_token(self, session_id: str) -> str:
        """
        Generate a unique authentication token for a session.

        Any existing token for the same session is overwritten.

        Args:
            session_id: The session identifier

        Returns:
            str: The authentication token
        """
        # Generate cryptographically secure random token
        token = secrets.token_urlsafe(self._token_length)

        # Create token data with expiry
        expiry = datetime.now() + timedelta(hours=self._token_expiry_hours)

        # Store token information
        self._session_tokens[session_id] = {
            "token": token,
            "session_id": session_id,
            "created_at": datetime.now(),
            "expires_at": expiry,
            "last_used": datetime.now(),
        }

        logger.info(f"Generated authentication token for session {session_id}")
        return token

    def validate_session_token(self, session_id: str, token: str) -> Tuple[bool, str]:
        """
        Validate a session token.

        Args:
            session_id: The session identifier
            token: The token to validate

        Returns:
            Tuple[bool, str]: (is_valid, reason)
        """
        # Check if session exists
        if session_id not in self._session_tokens:
            return False, "Session not found"

        session_data = self._session_tokens[session_id]

        # Security fix: compare in constant time. A plain `!=` short-circuits
        # on the first mismatching character, leaking token prefixes through
        # response timing.
        if not hmac.compare_digest(session_data["token"].encode(), token.encode()):
            return False, "Invalid token"

        # Check if token has expired
        if datetime.now() > session_data["expires_at"]:
            # Clean up expired token
            del self._session_tokens[session_id]
            return False, "Token expired"

        # Update last used time
        session_data["last_used"] = datetime.now()

        return True, "Valid"

    def revoke_session_token(self, session_id: str) -> bool:
        """
        Revoke a session token.

        Args:
            session_id: The session identifier

        Returns:
            bool: True if token was revoked, False if not found
        """
        if session_id in self._session_tokens:
            del self._session_tokens[session_id]
            logger.info(f"Revoked authentication token for session {session_id}")
            return True
        return False

    def rotate_session_token(self, session_id: str) -> Optional[str]:
        """
        Rotate (regenerate) a session token.

        Args:
            session_id: The session identifier

        Returns:
            Optional[str]: New token if session exists, None otherwise
        """
        if session_id not in self._session_tokens:
            return None

        # Generate new token (replaces the stored entry)
        new_token = self.generate_session_token(session_id)

        logger.info(f"Rotated authentication token for session {session_id}")
        return new_token

    def cleanup_expired_tokens(self) -> int:
        """
        Clean up expired tokens.

        Returns:
            int: Number of tokens cleaned up
        """
        now = datetime.now()
        # Collect first, then delete — never mutate a dict while iterating it.
        expired_sessions = [
            session_id
            for session_id, session_data in self._session_tokens.items()
            if now > session_data["expires_at"]
        ]

        for session_id in expired_sessions:
            del self._session_tokens[session_id]

        if expired_sessions:
            logger.info(
                f"Cleaned up {len(expired_sessions)} expired authentication tokens"
            )

        return len(expired_sessions)

    def get_session_token_info(self, session_id: str) -> Optional[Dict]:
        """
        Get information about a session token.

        Args:
            session_id: The session identifier

        Returns:
            Optional[Dict]: Token information (without the token value) or
            None if not found
        """
        if session_id not in self._session_tokens:
            return None

        session_data = self._session_tokens[session_id].copy()
        # Remove sensitive token value
        session_data.pop("token", None)
        return session_data

    def get_active_sessions_count(self) -> int:
        """Get the number of active sessions with tokens."""
        return len(self._session_tokens)

    def list_active_sessions(self) -> Dict[str, Dict]:
        """List all active sessions with token information (without token values)."""
        result = {}
        for session_id, session_data in self._session_tokens.items():
            # Create copy without sensitive token
            info = session_data.copy()
            info.pop("token", None)
            result[session_id] = info
        return result
|
||||||
|
|
||||||
|
|
||||||
|
# Global token manager instance
# Process-local singleton backing the module-level convenience functions below.
_session_token_manager = SessionTokenManager()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_session_auth_token(session_id: str) -> str:
    """Create and register an authentication token for *session_id*."""
    token = _session_token_manager.generate_session_token(session_id)
    return token
|
||||||
|
|
||||||
|
|
||||||
|
def validate_session_auth_token(session_id: str, token: str) -> Tuple[bool, str]:
    """Check *token* against the stored token for *session_id*.

    Returns:
        Tuple[bool, str]: (is_valid, reason).
    """
    return _session_token_manager.validate_session_token(session_id, token)
|
||||||
|
|
||||||
|
|
||||||
|
def revoke_session_auth_token(session_id: str) -> bool:
    """Invalidate the token for *session_id*; True if one existed."""
    return _session_token_manager.revoke_session_token(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_session_auth_token(session_id: str) -> Optional[str]:
    """Rotate a session authentication token.

    Returns:
        Optional[str]: The new token, or None if the session has no token.
    """
    # Bug fix: this previously called the non-existent manager method
    # `rotate_session_auth_token` (a copy/paste of this module-level name),
    # raising AttributeError on every call. The manager's method is
    # `rotate_session_token`.
    return _session_token_manager.rotate_session_token(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_expired_auth_tokens() -> int:
    """Purge expired authentication tokens; returns how many were removed."""
    return _session_token_manager.cleanup_expired_tokens()
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_auth_info(session_id: str) -> Optional[Dict]:
    """Return token metadata (without the token value) for *session_id*."""
    return _session_token_manager.get_session_token_info(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_auth_sessions_count() -> int:
    """Return how many sessions currently hold authentication tokens."""
    return _session_token_manager.get_active_sessions_count()
|
||||||
|
|
||||||
|
|
||||||
|
def list_active_auth_sessions() -> Dict[str, Dict]:
    """Return metadata for every tokened session, keyed by session id."""
    return _session_token_manager.list_active_sessions()
|
||||||
Reference in New Issue
Block a user