diff --git a/session-manager/async_docker_client.py b/session-manager/async_docker_client.py
new file mode 100644
index 0000000..fcc18c2
--- /dev/null
+++ b/session-manager/async_docker_client.py
@@ -0,0 +1,303 @@
+"""
+Async Docker Operations Wrapper
+
+Provides async wrappers for Docker operations to eliminate blocking calls
+in FastAPI async contexts and improve concurrency and scalability.
+"""
+
+import asyncio
+import logging
+from typing import Dict, Optional, List, Any
+from contextlib import asynccontextmanager
+import os
+
+from aiodocker import Docker
+from aiodocker.containers import DockerContainer
+from aiodocker.exceptions import DockerError
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncDockerClient:
+    """Async wrapper for Docker operations using aiodocker."""
+
+    def __init__(self):
+        self._docker: Optional[Docker] = None
+        self._connected = False
+
+    async def __aenter__(self):
+        await self.connect()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.disconnect()
+
+    async def connect(self):
+        """Connect to Docker daemon."""
+        if self._connected:
+            return
+
+        try:
+            # Configure TLS if available
+            tls_config = None
+            if os.getenv("DOCKER_TLS_VERIFY") == "1":
+                from aiodocker.utils import create_tls_config
+
+                tls_config = create_tls_config(
+                    ca_cert=os.getenv("DOCKER_CA_CERT", "/etc/docker/certs/ca.pem"),
+                    client_cert=(
+                        os.getenv(
+                            "DOCKER_CLIENT_CERT", "/etc/docker/certs/client-cert.pem"
+                        ),
+                        os.getenv(
+                            "DOCKER_CLIENT_KEY", "/etc/docker/certs/client-key.pem"
+                        ),
+                    ),
+                    verify=True,
+                )
+
+            docker_host = os.getenv("DOCKER_HOST", "tcp://host.docker.internal:2376")
+            self._docker = Docker(docker_host, tls=tls_config)
+
+            # Test connection
+            await self._docker.ping()
+            self._connected = True
+            logger.info("Async Docker client connected successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to connect to Docker: {e}")
+            raise
+
+    async def disconnect(self):
+        """Disconnect from Docker daemon."""
+        if self._docker and self._connected:
+            await self._docker.close()
+            self._connected = False
+            logger.info("Async Docker client disconnected")
+
+    async def ping(self) -> bool:
+        """Test Docker connectivity."""
+        if not self._docker:
+            return False
+        try:
+            await self._docker.ping()
+            return True
+        except Exception:
+            return False
+
+    async def create_container(
+        self,
+        image: str,
+        name: str,
+        volumes: Optional[Dict[str, Dict[str, str]]] = None,
+        ports: Optional[Dict[str, int]] = None,
+        environment: Optional[Dict[str, str]] = None,
+        network_mode: str = "bridge",
+        mem_limit: Optional[str] = None,
+        cpu_quota: Optional[int] = None,
+        cpu_period: Optional[int] = None,
+        tmpfs: Optional[Dict[str, str]] = None,
+        **kwargs,
+    ) -> DockerContainer:
+        """
+        Create a Docker container asynchronously.
+
+        Args:
+            image: Container image name
+            name: Container name
+            volumes: Volume mounts
+            ports: Port mappings (container_port -> host_port)
+            environment: Environment variables
+            network_mode: Network mode
+            mem_limit: Memory limit (e.g., "4g")
+            cpu_quota: CPU quota
+            cpu_period: CPU period
+            tmpfs: tmpfs mounts
+            **kwargs: Additional container configuration
+
+        Returns:
+            DockerContainer: The created container
+        """
+        if not self._docker:
+            raise RuntimeError("Docker client not connected")
+
+        config = {
+            "Image": image,
+            "name": name,
+            "Volumes": volumes or {},
+            # Expose the *container* ports (dict keys); host ports are the
+            # values and belong only in HostConfig.PortBindings below.
+            "ExposedPorts": {f"{port}/tcp": {} for port in ports}
+            if ports
+            else {},
+            "Env": [f"{k}={v}" for k, v in (environment or {}).items()],
+            "NetworkMode": network_mode,
+            "HostConfig": {
+                "Binds": [
+                    f"{host}:{container['bind']}:{container.get('mode', 'rw')}"
+                    for host, container in (volumes or {}).items()
+                ],
+                "PortBindings": {
+                    f"{container_port}/tcp": [{"HostPort": str(host_port)}]
+                    for container_port, host_port in (ports or {}).items()
+                },
+                "Tmpfs": tmpfs or {},
+            },
+        }
+
+        # Add resource limits
+        host_config = config["HostConfig"]
+        if mem_limit:
+            host_config["Memory"] = self._parse_memory_limit(mem_limit)
+        if cpu_quota is not None:
+            host_config["CpuQuota"] = cpu_quota
+        if cpu_period is not None:
+            host_config["CpuPeriod"] = cpu_period
+
+        # Add any additional host config
+        host_config.update(kwargs.get("host_config", {}))
+
+        try:
+            # NOTE(review): aiodocker's containers.create() also accepts a
+            # name= keyword; confirm passing "name" inside the config dict is
+            # honored by the daemon for this aiodocker version.
+            container = await self._docker.containers.create(config)
+            logger.info(f"Container {name} created successfully")
+            return container
+        except DockerError as e:
+            logger.error(f"Failed to create container {name}: {e}")
+            raise
+
+    async def start_container(self, container: DockerContainer) -> None:
+        """Start a Docker container."""
+        try:
+            await container.start()
+            logger.info(f"Container {container.id} started successfully")
+        except DockerError as e:
+            logger.error(f"Failed to start container {container.id}: {e}")
+            raise
+
+    async def stop_container(
+        self, container: DockerContainer, timeout: int = 10
+    ) -> None:
+        """Stop a Docker container."""
+        try:
+            await container.stop(timeout=timeout)
+            logger.info(f"Container {container.id} stopped successfully")
+        except DockerError as e:
+            logger.error(f"Failed to stop container {container.id}: {e}")
+            raise
+
+    async def remove_container(
+        self, container: DockerContainer, force: bool = False
+    ) -> None:
+        """Remove a Docker container."""
+        try:
+            await container.delete(force=force)
+            logger.info(f"Container {container.id} removed successfully")
+        except DockerError as e:
+            logger.error(f"Failed to remove container {container.id}: {e}")
+            raise
+
+    async def get_container(self, container_id: str) -> Optional[DockerContainer]:
+        """Get a container by ID or name."""
+        try:
+            return await self._docker.containers.get(container_id)
+        except DockerError:
+            return None
+
+    async def list_containers(
+        self, all: bool = False, filters: Optional[Dict[str, Any]] = None
+    ) -> List[DockerContainer]:
+        """List Docker containers."""
+        try:
+            return await self._docker.containers.list(all=all, filters=filters)
+        except DockerError as e:
+            logger.error(f"Failed to list containers: {e}")
+            return []
+
+    async def get_container_stats(
+        self, container: DockerContainer
+    ) -> Optional[Dict[str, Any]]:
+        """Get container statistics."""
+        try:
+            stats = await container.stats(stream=False)
+            return stats
+        except DockerError as e:
+            logger.error(f"Failed to get stats for container {container.id}: {e}")
+            return None
+
+    async def get_system_info(self) -> Optional[Dict[str, Any]]:
+        """Get Docker system information."""
+        if not self._docker:
+            return None
+        try:
+            return await self._docker.system.info()
+        except DockerError as e:
+            logger.error(f"Failed to get system info: {e}")
+            return None
+
+    def _parse_memory_limit(self, memory_str: str) -> int:
+        """Parse memory limit string to bytes."""
+        memory_str = memory_str.lower().strip()
+        if memory_str.endswith("g"):
+            return int(memory_str[:-1]) * (1024**3)
+        elif memory_str.endswith("m"):
+            return int(memory_str[:-1]) * (1024**2)
+        elif memory_str.endswith("k"):
+            return int(memory_str[:-1]) * 1024
+        else:
+            return int(memory_str)
+
+
+# Global async Docker client instance
+_async_docker_client = AsyncDockerClient()
+
+
+@asynccontextmanager
+async def get_async_docker_client():
+    """Context manager for async Docker client."""
+    async with _async_docker_client as client:
+        yield client
+
+
+async def async_docker_ping() -> bool:
+    """Async ping Docker daemon."""
+    async with get_async_docker_client() as client:
+        return await client.ping()
+
+
+async def async_create_container(**kwargs) -> DockerContainer:
+    """Async container creation wrapper."""
+    async with get_async_docker_client() as client:
+        return await client.create_container(**kwargs)
+
+
+async def async_start_container(container: DockerContainer) -> None:
+    """Async container start wrapper."""
+    async with get_async_docker_client() as client:
+        await client.start_container(container)
+
+
+async def async_stop_container(container: DockerContainer, timeout: int = 10) -> None:
+    """Async container stop wrapper."""
+    async with get_async_docker_client() as client:
+        await client.stop_container(container, timeout)
+
+
+async def async_remove_container(
+    container: DockerContainer, force: bool = False
+) -> None:
+    """Async container removal wrapper."""
+    async with get_async_docker_client() as client:
+        await client.remove_container(container, force)
+
+
+async def async_list_containers(
+    all: bool = False, filters: Optional[Dict[str, Any]] = None
+) -> List[DockerContainer]:
+    """Async container listing wrapper."""
+    async with get_async_docker_client() as client:
+        return await client.list_containers(all=all, filters=filters)
+
+
+async def async_get_container(container_id: str) -> Optional[DockerContainer]:
+    """Async container retrieval wrapper."""
+    async with get_async_docker_client() as client:
+        return await client.get_container(container_id)
diff --git a/session-manager/container_health.py b/session-manager/container_health.py
new file mode 100644
index 0000000..4382310
--- /dev/null
+++ b/session-manager/container_health.py
@@ -0,0 +1,574 @@
+"""
+Container Health Monitoring System
+
+Provides active monitoring of Docker containers with automatic failure detection,
+recovery mechanisms, and integration with session management and alerting systems.
+""" + +import asyncio +import logging +from typing import Dict, List, Optional, Tuple, Any +from datetime import datetime, timedelta +from enum import Enum + +from logging_config import get_logger, log_performance, log_security_event + +logger = get_logger(__name__) + + +class ContainerStatus(Enum): + """Container health status enumeration.""" + + HEALTHY = "healthy" + UNHEALTHY = "unhealthy" + RESTARTING = "restarting" + FAILED = "failed" + UNKNOWN = "unknown" + + +class HealthCheckResult: + """Result of a container health check.""" + + def __init__( + self, + session_id: str, + container_id: str, + status: ContainerStatus, + response_time: Optional[float] = None, + error_message: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + self.session_id = session_id + self.container_id = container_id + self.status = status + self.response_time = response_time + self.error_message = error_message + self.metadata = metadata or {} + self.timestamp = datetime.utcnow() + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for logging/serialization.""" + return { + "session_id": self.session_id, + "container_id": self.container_id, + "status": self.status.value, + "response_time": self.response_time, + "error_message": self.error_message, + "metadata": self.metadata, + "timestamp": self.timestamp.isoformat(), + } + + +class ContainerHealthMonitor: + """Monitors Docker container health and handles automatic recovery.""" + + def __init__( + self, + check_interval: int = 30, # seconds + health_timeout: float = 10.0, # seconds + max_restart_attempts: int = 3, + restart_delay: int = 5, # seconds + failure_threshold: int = 3, # consecutive failures before restart + ): + self.check_interval = check_interval + self.health_timeout = health_timeout + self.max_restart_attempts = max_restart_attempts + self.restart_delay = restart_delay + self.failure_threshold = failure_threshold + + # Monitoring state + self._monitoring = False + self._task: 
Optional[asyncio.Task] = None + self._health_history: Dict[str, List[HealthCheckResult]] = {} + self._restart_counts: Dict[str, int] = {} + + # Dependencies (injected) + self.session_manager = None + self.docker_client = None + + logger.info( + "Container health monitor initialized", + extra={ + "check_interval": check_interval, + "health_timeout": health_timeout, + "max_restart_attempts": max_restart_attempts, + }, + ) + + def set_dependencies(self, session_manager, docker_client): + """Set dependencies for health monitoring.""" + self.session_manager = session_manager + self.docker_client = docker_client + + async def start_monitoring(self): + """Start the health monitoring loop.""" + if self._monitoring: + logger.warning("Health monitoring already running") + return + + self._monitoring = True + self._task = asyncio.create_task(self._monitoring_loop()) + logger.info("Container health monitoring started") + + async def stop_monitoring(self): + """Stop the health monitoring loop.""" + if not self._monitoring: + return + + self._monitoring = False + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + logger.info("Container health monitoring stopped") + + async def _monitoring_loop(self): + """Main monitoring loop.""" + while self._monitoring: + try: + await self._perform_health_checks() + await self._cleanup_old_history() + except Exception as e: + logger.error("Error in health monitoring loop", extra={"error": str(e)}) + + await asyncio.sleep(self.check_interval) + + async def _perform_health_checks(self): + """Perform health checks on all running containers.""" + if not self.session_manager: + return + + # Get all running sessions + running_sessions = [ + session + for session in self.session_manager.sessions.values() + if session.status == "running" + ] + + if not running_sessions: + return + + logger.debug(f"Checking health of {len(running_sessions)} running containers") + + # Perform health checks 
concurrently + tasks = [self._check_container_health(session) for session in running_sessions] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process results + for i, result in enumerate(results): + session = running_sessions[i] + if isinstance(result, Exception): + logger.error( + "Health check failed", + extra={ + "session_id": session.session_id, + "container_id": session.container_id, + "error": str(result), + }, + ) + continue + + await self._process_health_result(result) + + async def _check_container_health(self, session) -> HealthCheckResult: + """Check the health of a single container.""" + start_time = asyncio.get_event_loop().time() + + try: + # Check if container exists and is running + if not session.container_id: + return HealthCheckResult( + session.session_id, + session.container_id or "unknown", + ContainerStatus.UNKNOWN, + error_message="No container ID", + ) + + # Get container status + container_info = await self._get_container_info(session.container_id) + if not container_info: + return HealthCheckResult( + session.session_id, + session.container_id, + ContainerStatus.FAILED, + error_message="Container not found", + ) + + # Check container state + state = container_info.get("State", {}) + status = state.get("Status", "unknown") + + if status != "running": + return HealthCheckResult( + session.session_id, + session.container_id, + ContainerStatus.FAILED, + error_message=f"Container status: {status}", + ) + + # Check health status if available + health = state.get("Health", {}) + if health: + health_status = health.get("Status", "unknown") + if health_status == "healthy": + response_time = ( + asyncio.get_event_loop().time() - start_time + ) * 1000 + return HealthCheckResult( + session.session_id, + session.container_id, + ContainerStatus.HEALTHY, + response_time=response_time, + metadata={ + "docker_status": status, + "health_status": health_status, + }, + ) + elif health_status in ["unhealthy", "starting"]: + return 
HealthCheckResult( + session.session_id, + session.container_id, + ContainerStatus.UNHEALTHY, + error_message=f"Health check: {health_status}", + metadata={ + "docker_status": status, + "health_status": health_status, + }, + ) + + # If no health check configured, consider running containers healthy + response_time = (asyncio.get_event_loop().time() - start_time) * 1000 + return HealthCheckResult( + session.session_id, + session.container_id, + ContainerStatus.HEALTHY, + response_time=response_time, + metadata={"docker_status": status}, + ) + + except Exception as e: + response_time = (asyncio.get_event_loop().time() - start_time) * 1000 + return HealthCheckResult( + session.session_id, + session.container_id or "unknown", + ContainerStatus.UNKNOWN, + response_time=response_time, + error_message=str(e), + ) + + async def _get_container_info(self, container_id: str) -> Optional[Dict[str, Any]]: + """Get container information from Docker.""" + try: + if self.docker_client: + # Try async Docker client first + container = await self.docker_client.get_container(container_id) + if hasattr(container, "_container"): + return await container._container.show() + elif hasattr(container, "show"): + return await container.show() + else: + # Fallback to sync client if available + if ( + hasattr(self.session_manager, "docker_client") + and self.session_manager.docker_client + ): + container = self.session_manager.docker_client.containers.get( + container_id + ) + return container.attrs + except Exception as e: + logger.debug( + f"Failed to get container info for {container_id}", + extra={"error": str(e)}, + ) + + return None + + async def _process_health_result(self, result: HealthCheckResult): + """Process a health check result and take appropriate action.""" + # Store result in history + if result.session_id not in self._health_history: + self._health_history[result.session_id] = [] + + self._health_history[result.session_id].append(result) + + # Keep only recent history (last 
10 checks) + if len(self._health_history[result.session_id]) > 10: + self._health_history[result.session_id] = self._health_history[ + result.session_id + ][-10:] + + # Log result + log_extra = result.to_dict() + if result.status == ContainerStatus.HEALTHY: + logger.debug("Container health check passed", extra=log_extra) + elif result.status == ContainerStatus.UNHEALTHY: + logger.warning("Container health check failed", extra=log_extra) + elif result.status in [ContainerStatus.FAILED, ContainerStatus.UNKNOWN]: + logger.error("Container health check critical", extra=log_extra) + + # Check if restart is needed + await self._check_restart_needed(result) + + async def _check_restart_needed(self, result: HealthCheckResult): + """Check if a container needs to be restarted based on health history.""" + if result.status == ContainerStatus.HEALTHY: + # Reset restart count on successful health check + if result.session_id in self._restart_counts: + self._restart_counts[result.session_id] = 0 + return + + # Count recent failures + recent_results = self._health_history.get(result.session_id, []) + recent_failures = sum( + 1 + for r in recent_results[-self.failure_threshold :] + if r.status + in [ + ContainerStatus.UNHEALTHY, + ContainerStatus.FAILED, + ContainerStatus.UNKNOWN, + ] + ) + + if recent_failures >= self.failure_threshold: + await self._restart_container(result.session_id, result.container_id) + + async def _restart_container(self, session_id: str, container_id: str): + """Restart a failed container.""" + # Check restart limit + restart_count = self._restart_counts.get(session_id, 0) + if restart_count >= self.max_restart_attempts: + logger.error( + "Container restart limit exceeded", + extra={ + "session_id": session_id, + "container_id": container_id, + "restart_attempts": restart_count, + }, + ) + # Mark session as failed + await self._mark_session_failed( + session_id, f"Restart limit exceeded ({restart_count} attempts)" + ) + return + + logger.info( + 
"Attempting container restart", + extra={ + "session_id": session_id, + "container_id": container_id, + "restart_attempt": restart_count + 1, + }, + ) + + try: + # Stop the container + await self._stop_container(container_id) + + # Wait before restart + await asyncio.sleep(self.restart_delay) + + # Start new container for the session + session = await self.session_manager.get_session(session_id) + if session: + # Update restart count + self._restart_counts[session_id] = restart_count + 1 + + # Mark as restarting + await self._update_session_status(session_id, "restarting") + + # Trigger container restart through session manager + if self.session_manager: + # Create new container for the session + await self.session_manager.create_session() + logger.info( + "Container restart initiated", + extra={ + "session_id": session_id, + "restart_attempt": restart_count + 1, + }, + ) + + # Log security event + log_security_event( + "container_restart", + "warning", + { + "session_id": session_id, + "container_id": container_id, + "reason": "health_check_failure", + }, + ) + + except Exception as e: + logger.error( + "Container restart failed", + extra={ + "session_id": session_id, + "container_id": container_id, + "error": str(e), + }, + ) + + async def _stop_container(self, container_id: str): + """Stop a container.""" + try: + if self.docker_client: + container = await self.docker_client.get_container(container_id) + await self.docker_client.stop_container(container, timeout=10) + elif ( + hasattr(self.session_manager, "docker_client") + and self.session_manager.docker_client + ): + container = self.session_manager.docker_client.containers.get( + container_id + ) + container.stop(timeout=10) + except Exception as e: + logger.warning( + "Failed to stop container during restart", + extra={"container_id": container_id, "error": str(e)}, + ) + + async def _update_session_status(self, session_id: str, status: str): + """Update session status.""" + if self.session_manager: + 
session = self.session_manager.sessions.get(session_id) + if session: + session.status = status + # Update in database if using database storage + if ( + hasattr(self.session_manager, "USE_DATABASE_STORAGE") + and self.session_manager.USE_DATABASE_STORAGE + ): + try: + from database import SessionModel + + await SessionModel.update_session( + session_id, {"status": status} + ) + except Exception as e: + logger.warning( + "Failed to update session status in database", + extra={"session_id": session_id, "error": str(e)}, + ) + + async def _mark_session_failed(self, session_id: str, reason: str): + """Mark a session as permanently failed.""" + await self._update_session_status(session_id, "failed") + + logger.error( + "Session marked as failed", + extra={"session_id": session_id, "reason": reason}, + ) + + # Log security event + log_security_event( + "session_failure", "error", {"session_id": session_id, "reason": reason} + ) + + async def _cleanup_old_history(self): + """Clean up old health check history.""" + cutoff_time = datetime.utcnow() - timedelta(hours=1) # Keep last hour + + for session_id in list(self._health_history.keys()): + # Remove old results + self._health_history[session_id] = [ + result + for result in self._health_history[session_id] + if result.timestamp > cutoff_time + ] + + # Remove empty histories + if not self._health_history[session_id]: + del self._health_history[session_id] + + def get_health_stats(self, session_id: Optional[str] = None) -> Dict[str, Any]: + """Get health monitoring statistics.""" + stats = { + "monitoring_active": self._monitoring, + "check_interval": self.check_interval, + "total_sessions_monitored": len(self._health_history), + "sessions_with_failures": len( + [ + sid + for sid, history in self._health_history.items() + if any( + r.status != ContainerStatus.HEALTHY for r in history[-5:] + ) # Last 5 checks + ] + ), + "restart_counts": dict(self._restart_counts), + } + + if session_id and session_id in 
self._health_history: + recent_results = self._health_history[session_id][-10:] # Last 10 checks + stats[f"session_{session_id}"] = { + "total_checks": len(recent_results), + "healthy_checks": sum( + 1 for r in recent_results if r.status == ContainerStatus.HEALTHY + ), + "failed_checks": sum( + 1 for r in recent_results if r.status != ContainerStatus.HEALTHY + ), + "average_response_time": sum( + r.response_time or 0 for r in recent_results if r.response_time + ) + / max(1, sum(1 for r in recent_results if r.response_time)), + "last_check": recent_results[-1].to_dict() if recent_results else None, + } + + return stats + + def get_health_history( + self, session_id: str, limit: int = 50 + ) -> List[Dict[str, Any]]: + """Get health check history for a session.""" + if session_id not in self._health_history: + return [] + + return [ + result.to_dict() for result in self._health_history[session_id][-limit:] + ] + + +# Global health monitor instance +_container_health_monitor = ContainerHealthMonitor() + + +def get_container_health_monitor() -> ContainerHealthMonitor: + """Get the global container health monitor instance.""" + return _container_health_monitor + + +async def start_container_health_monitoring(session_manager=None, docker_client=None): + """Start container health monitoring.""" + monitor = get_container_health_monitor() + if session_manager: + monitor.set_dependencies(session_manager, docker_client) + await monitor.start_monitoring() + + +async def stop_container_health_monitoring(): + """Stop container health monitoring.""" + monitor = get_container_health_monitor() + await monitor.stop_monitoring() + + +def get_container_health_stats(session_id: Optional[str] = None) -> Dict[str, Any]: + """Get container health statistics.""" + monitor = get_container_health_monitor() + return monitor.get_health_stats(session_id) + + +def get_container_health_history( + session_id: str, limit: int = 50 +) -> List[Dict[str, Any]]: + """Get container health check 
history.""" + monitor = get_container_health_monitor() + return monitor.get_health_history(session_id, limit) diff --git a/session-manager/database.py b/session-manager/database.py new file mode 100644 index 0000000..042a63b --- /dev/null +++ b/session-manager/database.py @@ -0,0 +1,406 @@ +""" +Database Models and Connection Management for Session Persistence + +Provides PostgreSQL-backed session storage with connection pooling, +migrations, and health monitoring for production reliability. +""" + +import os +import asyncpg +import json +from typing import Dict, List, Optional, Any +from datetime import datetime, timedelta +from contextlib import asynccontextmanager +import logging + +from logging_config import get_logger + +logger = get_logger(__name__) + + +class DatabaseConnection: + """PostgreSQL connection management with pooling.""" + + def __init__(self): + self._pool: Optional[asyncpg.Pool] = None + self._config = { + "host": os.getenv("DB_HOST", "localhost"), + "port": int(os.getenv("DB_PORT", "5432")), + "user": os.getenv("DB_USER", "lovdata"), + "password": os.getenv("DB_PASSWORD", "password"), + "database": os.getenv("DB_NAME", "lovdata_chat"), + "min_size": int(os.getenv("DB_MIN_CONNECTIONS", "5")), + "max_size": int(os.getenv("DB_MAX_CONNECTIONS", "20")), + "max_queries": int(os.getenv("DB_MAX_QUERIES", "50000")), + "max_inactive_connection_lifetime": float( + os.getenv("DB_MAX_INACTIVE_LIFETIME", "300.0") + ), + } + + async def connect(self) -> None: + """Establish database connection pool.""" + if self._pool: + return + + try: + self._pool = await asyncpg.create_pool(**self._config) + logger.info( + "Database connection pool established", + extra={ + "host": self._config["host"], + "port": self._config["port"], + "database": self._config["database"], + "min_connections": self._config["min_size"], + "max_connections": self._config["max_size"], + }, + ) + except Exception as e: + logger.error( + "Failed to establish database connection", 
extra={"error": str(e)} + ) + raise + + async def disconnect(self) -> None: + """Close database connection pool.""" + if self._pool: + await self._pool.close() + self._pool = None + logger.info("Database connection pool closed") + + async def get_connection(self) -> asyncpg.Connection: + """Get a database connection from the pool.""" + if not self._pool: + await self.connect() + return await self._pool.acquire() + + async def release_connection(self, conn: asyncpg.Connection) -> None: + """Release a database connection back to the pool.""" + if self._pool: + await self._pool.release(conn) + + async def health_check(self) -> Dict[str, Any]: + """Perform database health check.""" + try: + conn = await self.get_connection() + result = await conn.fetchval("SELECT 1") + await self.release_connection(conn) + + if result == 1: + return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()} + else: + return {"status": "unhealthy", "error": "Health check query failed"} + except Exception as e: + logger.error("Database health check failed", extra={"error": str(e)}) + return {"status": "unhealthy", "error": str(e)} + + @asynccontextmanager + async def transaction(self): + """Context manager for database transactions.""" + conn = await self.get_connection() + try: + async with conn.transaction(): + yield conn + finally: + await self.release_connection(conn) + + +# Global database connection instance +_db_connection = DatabaseConnection() + + +@asynccontextmanager +async def get_db_connection(): + """Context manager for database connections.""" + conn = await _db_connection.get_connection() + try: + yield conn + finally: + await _db_connection.release_connection(conn) + + +async def init_database() -> None: + """Initialize database and run migrations.""" + logger.info("Initializing database") + + async with get_db_connection() as conn: + # Create sessions table + await conn.execute(""" + CREATE TABLE IF NOT EXISTS sessions ( + session_id VARCHAR(32) PRIMARY KEY, + 
container_name VARCHAR(255) NOT NULL, + container_id VARCHAR(255), + host_dir VARCHAR(1024) NOT NULL, + port INTEGER, + auth_token VARCHAR(255), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + last_accessed TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + status VARCHAR(50) NOT NULL DEFAULT 'creating', + metadata JSONB DEFAULT '{}' + ) + """) + + # Create indexes for performance + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status); + CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed); + CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at); + CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name); + """) + + # Create cleanup function for expired sessions + await conn.execute(""" + CREATE OR REPLACE FUNCTION cleanup_expired_sessions() + RETURNS INTEGER AS $$ + DECLARE + deleted_count INTEGER; + BEGIN + DELETE FROM sessions + WHERE last_accessed < NOW() - INTERVAL '1 hour'; + + GET DIAGNOSTICS deleted_count = ROW_COUNT; + RETURN deleted_count; + END; + $$ LANGUAGE plpgsql; + """) + + logger.info("Database initialized and migrations applied") + + +async def shutdown_database() -> None: + """Shutdown database connections.""" + await _db_connection.disconnect() + + +class SessionModel: + """Database model for sessions.""" + + @staticmethod + async def create_session(session_data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new session in the database.""" + async with get_db_connection() as conn: + row = await conn.fetchrow( + """ + INSERT INTO sessions ( + session_id, container_name, container_id, host_dir, port, + auth_token, status, metadata + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + RETURNING session_id, container_name, container_id, host_dir, port, + auth_token, created_at, last_accessed, status, metadata + """, + session_data["session_id"], + session_data["container_name"], + session_data.get("container_id"), + 
session_data["host_dir"], + session_data.get("port"), + session_data.get("auth_token"), + session_data.get("status", "creating"), + json.dumps(session_data.get("metadata", {})), + ) + + if row: + return dict(row) + raise ValueError("Failed to create session") + + @staticmethod + async def get_session(session_id: str) -> Optional[Dict[str, Any]]: + """Get a session by ID.""" + async with get_db_connection() as conn: + # Update last_accessed timestamp + row = await conn.fetchrow( + """ + UPDATE sessions + SET last_accessed = NOW() + WHERE session_id = $1 + RETURNING session_id, container_name, container_id, host_dir, port, + auth_token, created_at, last_accessed, status, metadata + """, + session_id, + ) + + if row: + result = dict(row) + result["metadata"] = json.loads(result["metadata"] or "{}") + return result + return None + + @staticmethod + async def update_session(session_id: str, updates: Dict[str, Any]) -> bool: + """Update session fields.""" + if not updates: + return True + + async with get_db_connection() as conn: + # Build dynamic update query + set_parts = [] + values = [session_id] + param_index = 2 + + for key, value in updates.items(): + if key in ["session_id", "created_at"]: # Don't update these + continue + + if key == "metadata": + set_parts.append(f"metadata = ${param_index}") + values.append(json.dumps(value)) + elif key == "last_accessed": + set_parts.append(f"last_accessed = NOW()") + else: + set_parts.append(f"{key} = ${param_index}") + values.append(value) + param_index += 1 + + if not set_parts: + return True + + query = f""" + UPDATE sessions + SET {", ".join(set_parts)} + WHERE session_id = $1 + """ + + result = await conn.execute(query, *values) + return result == "UPDATE 1" + + @staticmethod + async def delete_session(session_id: str) -> bool: + """Delete a session.""" + async with get_db_connection() as conn: + result = await conn.execute( + "DELETE FROM sessions WHERE session_id = $1", session_id + ) + return result == "DELETE 1" + + 
@staticmethod + async def list_sessions(limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]: + """List sessions with pagination.""" + async with get_db_connection() as conn: + rows = await conn.fetch( + """ + SELECT session_id, container_name, container_id, host_dir, port, + auth_token, created_at, last_accessed, status, metadata + FROM sessions + ORDER BY created_at DESC + LIMIT $1 OFFSET $2 + """, + limit, + offset, + ) + + result = [] + for row in rows: + session = dict(row) + session["metadata"] = json.loads(session["metadata"] or "{}") + result.append(session) + + return result + + @staticmethod + async def count_sessions() -> int: + """Count total sessions.""" + async with get_db_connection() as conn: + return await conn.fetchval("SELECT COUNT(*) FROM sessions") + + @staticmethod + async def cleanup_expired_sessions() -> int: + """Clean up expired sessions using database function.""" + async with get_db_connection() as conn: + return await conn.fetchval("SELECT cleanup_expired_sessions()") + + @staticmethod + async def get_active_sessions_count() -> int: + """Get count of active (running) sessions.""" + async with get_db_connection() as conn: + return await conn.fetchval(""" + SELECT COUNT(*) FROM sessions + WHERE status = 'running' + """) + + @staticmethod + async def get_sessions_by_status(status: str) -> List[Dict[str, Any]]: + """Get sessions by status.""" + async with get_db_connection() as conn: + rows = await conn.fetch( + """ + SELECT session_id, container_name, container_id, host_dir, port, + auth_token, created_at, last_accessed, status, metadata + FROM sessions + WHERE status = $1 + ORDER BY created_at DESC + """, + status, + ) + + result = [] + for row in rows: + session = dict(row) + session["metadata"] = json.loads(session["metadata"] or "{}") + result.append(session) + + return result + + +# Migration utilities +async def run_migrations(): + """Run database migrations.""" + logger.info("Running database migrations") + + async with 
get_db_connection() as conn: + # Migration 1: Add metadata column if it doesn't exist + try: + await conn.execute(""" + ALTER TABLE sessions ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}' + """) + logger.info("Migration: Added metadata column") + except Exception as e: + logger.warning(f"Migration metadata column may already exist: {e}") + + # Migration 2: Add indexes if they don't exist + try: + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status); + CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed); + CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at); + """) + logger.info("Migration: Added performance indexes") + except Exception as e: + logger.warning(f"Migration indexes may already exist: {e}") + + logger.info("Database migrations completed") + + +# Health monitoring +async def get_database_stats() -> Dict[str, Any]: + """Get database statistics and health information.""" + try: + async with get_db_connection() as conn: + # Get basic stats + session_count = await conn.fetchval("SELECT COUNT(*) FROM sessions") + active_sessions = await conn.fetchval( + "SELECT COUNT(*) FROM sessions WHERE status = 'running'" + ) + oldest_session = await conn.fetchval("SELECT MIN(created_at) FROM sessions") + newest_session = await conn.fetchval("SELECT MAX(created_at) FROM sessions") + + # Get database size information + db_size = await conn.fetchval(""" + SELECT pg_size_pretty(pg_database_size(current_database())) + """) + + return { + "total_sessions": session_count, + "active_sessions": active_sessions, + "oldest_session": oldest_session.isoformat() + if oldest_session + else None, + "newest_session": newest_session.isoformat() + if newest_session + else None, + "database_size": db_size, + "status": "healthy", + } + except Exception as e: + logger.error("Failed to get database stats", extra={"error": str(e)}) + return { + "status": "unhealthy", + "error": str(e), + } diff 
--git a/session-manager/docker_service.py b/session-manager/docker_service.py new file mode 100644 index 0000000..149dd48 --- /dev/null +++ b/session-manager/docker_service.py @@ -0,0 +1,635 @@ +""" +Docker Service Layer + +Provides a clean abstraction for Docker operations, separating container management +from business logic. Enables easy testing, mocking, and future Docker client changes. +""" + +import os +import logging +from typing import Dict, List, Optional, Any, Tuple +from datetime import datetime + +from logging_config import get_logger + +logger = get_logger(__name__) + + +class ContainerInfo: + """Container information data structure.""" + + def __init__( + self, + container_id: str, + name: str, + image: str, + status: str, + ports: Optional[Dict[str, int]] = None, + created_at: Optional[datetime] = None, + health_status: Optional[str] = None, + ): + self.container_id = container_id + self.name = name + self.image = image + self.status = status + self.ports = ports or {} + self.created_at = created_at or datetime.utcnow() + self.health_status = health_status + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "container_id": self.container_id, + "name": self.name, + "image": self.image, + "status": self.status, + "ports": self.ports, + "created_at": self.created_at.isoformat() if self.created_at else None, + "health_status": self.health_status, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ContainerInfo": + """Create from dictionary.""" + return cls( + container_id=data["container_id"], + name=data["name"], + image=data["image"], + status=data["status"], + ports=data.get("ports", {}), + created_at=datetime.fromisoformat(data["created_at"]) + if data.get("created_at") + else None, + health_status=data.get("health_status"), + ) + + +class DockerOperationError(Exception): + """Docker operation error.""" + + def __init__( + self, operation: str, container_id: Optional[str] = None, message: str = "" + 
): + self.operation = operation + self.container_id = container_id + self.message = message + super().__init__(f"Docker {operation} failed: {message}") + + +class DockerService: + """ + Docker service abstraction layer. + + Provides a clean interface for container operations, + enabling easy testing and future Docker client changes. + """ + + def __init__(self, use_async: bool = True): + """ + Initialize Docker service. + + Args: + use_async: Whether to use async Docker operations + """ + self.use_async = use_async + self._docker_client = None + self._initialized = False + + logger.info("Docker service initialized", extra={"async_mode": use_async}) + + async def initialize(self) -> None: + """Initialize the Docker client connection.""" + if self._initialized: + return + + try: + if self.use_async: + # Initialize async Docker client + from async_docker_client import AsyncDockerClient + + self._docker_client = AsyncDockerClient() + await self._docker_client.connect() + else: + # Initialize sync Docker client + import docker + + tls_config = docker.tls.TLSConfig( + ca_cert=os.getenv("DOCKER_CA_CERT", "/etc/docker/certs/ca.pem"), + client_cert=( + os.getenv( + "DOCKER_CLIENT_CERT", "/etc/docker/certs/client-cert.pem" + ), + os.getenv( + "DOCKER_CLIENT_KEY", "/etc/docker/certs/client-key.pem" + ), + ), + verify=True, + ) + docker_host = os.getenv( + "DOCKER_HOST", "tcp://host.docker.internal:2376" + ) + self._docker_client = docker.from_env() + self._docker_client.api = docker.APIClient( + base_url=docker_host, tls=tls_config, version="auto" + ) + # Test connection + self._docker_client.ping() + + self._initialized = True + logger.info("Docker service connection established") + + except Exception as e: + logger.error("Failed to initialize Docker service", extra={"error": str(e)}) + raise DockerOperationError("initialization", message=str(e)) + + async def shutdown(self) -> None: + """Shutdown the Docker client connection.""" + if not self._initialized: + return + + try: 
+ if self.use_async and self._docker_client: + await self._docker_client.disconnect() + # Sync client doesn't need explicit shutdown + + self._initialized = False + logger.info("Docker service connection closed") + + except Exception as e: + logger.warning( + "Error during Docker service shutdown", extra={"error": str(e)} + ) + + async def ping(self) -> bool: + """Test Docker daemon connectivity.""" + if not self._initialized: + await self.initialize() + + try: + if self.use_async: + return await self._docker_client.ping() + else: + self._docker_client.ping() + return True + except Exception as e: + logger.warning("Docker ping failed", extra={"error": str(e)}) + return False + + async def create_container( + self, + name: str, + image: str, + volumes: Optional[Dict[str, Dict[str, str]]] = None, + ports: Optional[Dict[str, int]] = None, + environment: Optional[Dict[str, str]] = None, + network_mode: str = "bridge", + mem_limit: Optional[str] = None, + cpu_quota: Optional[int] = None, + cpu_period: Optional[int] = None, + tmpfs: Optional[Dict[str, str]] = None, + **kwargs, + ) -> ContainerInfo: + """ + Create a Docker container. 
+ + Args: + name: Container name + image: Container image + volumes: Volume mounts + ports: Port mappings + environment: Environment variables + network_mode: Network mode + mem_limit: Memory limit + cpu_quota: CPU quota + cpu_period: CPU period + tmpfs: tmpfs mounts + **kwargs: Additional options + + Returns: + ContainerInfo: Information about created container + + Raises: + DockerOperationError: If container creation fails + """ + if not self._initialized: + await self.initialize() + + try: + logger.info( + "Creating container", + extra={ + "container_name": name, + "image": image, + "memory_limit": mem_limit, + "cpu_quota": cpu_quota, + }, + ) + + if self.use_async: + container = await self._docker_client.create_container( + image=image, + name=name, + volumes=volumes, + ports=ports, + environment=environment, + network_mode=network_mode, + mem_limit=mem_limit, + cpu_quota=cpu_quota, + cpu_period=cpu_period, + tmpfs=tmpfs, + **kwargs, + ) + + return ContainerInfo( + container_id=container.id, + name=name, + image=image, + status="created", + ports=ports, + ) + else: + container = self._docker_client.containers.run( + image, + name=name, + volumes=volumes, + ports={f"{port}/tcp": port for port in ports.values()} + if ports + else None, + environment=environment, + network_mode=network_mode, + mem_limit=mem_limit, + cpu_quota=cpu_quota, + cpu_period=cpu_period, + tmpfs=tmpfs, + detach=True, + **kwargs, + ) + + return ContainerInfo( + container_id=container.id, + name=name, + image=image, + status="running", + ports=ports, + ) + + except Exception as e: + logger.error( + "Container creation failed", + extra={"container_name": name, "image": image, "error": str(e)}, + ) + raise DockerOperationError("create_container", name, str(e)) + + async def start_container(self, container_id: str) -> None: + """ + Start a Docker container. 
+ + Args: + container_id: Container ID + + Raises: + DockerOperationError: If container start fails + """ + if not self._initialized: + await self.initialize() + + try: + logger.info("Starting container", extra={"container_id": container_id}) + + if self.use_async: + container = await self._docker_client.get_container(container_id) + await self._docker_client.start_container(container) + else: + container = self._docker_client.containers.get(container_id) + container.start() + + logger.info( + "Container started successfully", extra={"container_id": container_id} + ) + + except Exception as e: + logger.error( + "Container start failed", + extra={"container_id": container_id, "error": str(e)}, + ) + raise DockerOperationError("start_container", container_id, str(e)) + + async def stop_container(self, container_id: str, timeout: int = 10) -> None: + """ + Stop a Docker container. + + Args: + container_id: Container ID + timeout: Stop timeout in seconds + + Raises: + DockerOperationError: If container stop fails + """ + if not self._initialized: + await self.initialize() + + try: + logger.info( + "Stopping container", + extra={"container_id": container_id, "timeout": timeout}, + ) + + if self.use_async: + container = await self._docker_client.get_container(container_id) + await self._docker_client.stop_container(container, timeout) + else: + container = self._docker_client.containers.get(container_id) + container.stop(timeout=timeout) + + logger.info( + "Container stopped successfully", extra={"container_id": container_id} + ) + + except Exception as e: + logger.error( + "Container stop failed", + extra={"container_id": container_id, "error": str(e)}, + ) + raise DockerOperationError("stop_container", container_id, str(e)) + + async def remove_container(self, container_id: str, force: bool = False) -> None: + """ + Remove a Docker container. 
+ + Args: + container_id: Container ID + force: Force removal if running + + Raises: + DockerOperationError: If container removal fails + """ + if not self._initialized: + await self.initialize() + + try: + logger.info( + "Removing container", + extra={"container_id": container_id, "force": force}, + ) + + if self.use_async: + container = await self._docker_client.get_container(container_id) + await self._docker_client.remove_container(container, force) + else: + container = self._docker_client.containers.get(container_id) + container.remove(force=force) + + logger.info( + "Container removed successfully", extra={"container_id": container_id} + ) + + except Exception as e: + logger.error( + "Container removal failed", + extra={"container_id": container_id, "error": str(e)}, + ) + raise DockerOperationError("remove_container", container_id, str(e)) + + async def get_container_info(self, container_id: str) -> Optional[ContainerInfo]: + """ + Get information about a container. + + Args: + container_id: Container ID + + Returns: + ContainerInfo or None: Container information + """ + if not self._initialized: + await self.initialize() + + try: + if self.use_async: + container_info = await self._docker_client._get_container_info( + container_id + ) + if container_info: + state = container_info.get("State", {}) + config = container_info.get("Config", {}) + return ContainerInfo( + container_id=container_id, + name=config.get("Name", "").lstrip("/"), + image=config.get("Image", ""), + status=state.get("Status", "unknown"), + health_status=state.get("Health", {}).get("Status"), + ) + else: + container = self._docker_client.containers.get(container_id) + return ContainerInfo( + container_id=container.id, + name=container.name, + image=container.image.tags[0] + if container.image.tags + else container.image.id, + status=container.status, + ) + + return None + + except Exception as e: + logger.debug( + "Container info retrieval failed", + extra={"container_id": container_id, 
"error": str(e)}, + ) + return None + + async def list_containers( + self, all: bool = False, filters: Optional[Dict[str, Any]] = None + ) -> List[ContainerInfo]: + """ + List Docker containers. + + Args: + all: Include stopped containers + filters: Container filters + + Returns: + List[ContainerInfo]: List of container information + """ + if not self._initialized: + await self.initialize() + + try: + if self.use_async: + containers = await self._docker_client.list_containers( + all=all, filters=filters + ) + result = [] + for container in containers: + container_info = await self._docker_client._get_container_info( + container.id + ) + if container_info: + state = container_info.get("State", {}) + config = container_info.get("Config", {}) + result.append( + ContainerInfo( + container_id=container.id, + name=config.get("Name", "").lstrip("/"), + image=config.get("Image", ""), + status=state.get("Status", "unknown"), + health_status=state.get("Health", {}).get("Status"), + ) + ) + return result + else: + containers = self._docker_client.containers.list( + all=all, filters=filters + ) + result = [] + for container in containers: + result.append( + ContainerInfo( + container_id=container.id, + name=container.name, + image=container.image.tags[0] + if container.image.tags + else container.image.id, + status=container.status, + ) + ) + return result + + except Exception as e: + logger.error("Container listing failed", extra={"error": str(e)}) + return [] + + async def get_container_logs(self, container_id: str, tail: int = 100) -> str: + """ + Get container logs. 
+ + Args: + container_id: Container ID + tail: Number of log lines to retrieve + + Returns: + str: Container logs + """ + if not self._initialized: + await self.initialize() + + try: + if self.use_async: + container = await self._docker_client.get_container(container_id) + logs = await container.log(stdout=True, stderr=True, tail=tail) + return "\n".join(logs) + else: + container = self._docker_client.containers.get(container_id) + logs = container.logs(tail=tail).decode("utf-8") + return logs + + except Exception as e: + logger.warning( + "Container log retrieval failed", + extra={"container_id": container_id, "error": str(e)}, + ) + return "" + + async def get_system_info(self) -> Optional[Dict[str, Any]]: + """ + Get Docker system information. + + Returns: + Dict or None: System information + """ + if not self._initialized: + await self.initialize() + + try: + if self.use_async: + return await self._docker_client.get_system_info() + else: + return self._docker_client.info() + except Exception as e: + logger.warning("System info retrieval failed", extra={"error": str(e)}) + return None + + # Context manager support + async def __aenter__(self): + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.shutdown() + + +class MockDockerService(DockerService): + """ + Mock Docker service for testing without actual Docker. + + Provides the same interface but with in-memory operations. 
+ """ + + def __init__(self): + super().__init__(use_async=False) + self._containers: Dict[str, ContainerInfo] = {} + self._next_id = 1 + + async def initialize(self) -> None: + """Mock initialization - always succeeds.""" + self._initialized = True + logger.info("Mock Docker service initialized") + + async def shutdown(self) -> None: + """Mock shutdown.""" + self._containers.clear() + self._initialized = False + logger.info("Mock Docker service shutdown") + + async def ping(self) -> bool: + """Mock ping - always succeeds.""" + return True + + async def create_container(self, name: str, image: str, **kwargs) -> ContainerInfo: + """Mock container creation.""" + container_id = f"mock-{self._next_id}" + self._next_id += 1 + + container = ContainerInfo( + container_id=container_id, name=name, image=image, status="created" + ) + + self._containers[container_id] = container + logger.info( + "Mock container created", + extra={ + "container_id": container_id, + "container_name": name, + "image": image, + }, + ) + + return container + + async def start_container(self, container_id: str) -> None: + """Mock container start.""" + if container_id in self._containers: + self._containers[container_id].status = "running" + logger.info("Mock container started", extra={"container_id": container_id}) + + async def stop_container(self, container_id: str, timeout: int = 10) -> None: + """Mock container stop.""" + if container_id in self._containers: + self._containers[container_id].status = "exited" + logger.info("Mock container stopped", extra={"container_id": container_id}) + + async def remove_container(self, container_id: str, force: bool = False) -> None: + """Mock container removal.""" + if container_id in self._containers: + del self._containers[container_id] + logger.info("Mock container removed", extra={"container_id": container_id}) + + async def get_container_info(self, container_id: str) -> Optional[ContainerInfo]: + """Mock container info retrieval.""" + return 
self._containers.get(container_id) + + async def list_containers( + self, all: bool = False, filters: Optional[Dict[str, Any]] = None + ) -> List[ContainerInfo]: + """Mock container listing.""" + return list(self._containers.values()) diff --git a/session-manager/host_ip_detector.py b/session-manager/host_ip_detector.py new file mode 100644 index 0000000..6532299 --- /dev/null +++ b/session-manager/host_ip_detector.py @@ -0,0 +1,252 @@ +""" +Host IP Detection Utilities + +Provides robust methods to detect the Docker host IP from within a container, +supporting multiple Docker environments and network configurations. +""" + +import os +import socket +import asyncio +import logging +from typing import Optional, List +from functools import lru_cache +import time + +logger = logging.getLogger(__name__) + + +class HostIPDetector: + """Detects the Docker host IP address from container perspective.""" + + # Common Docker gateway IPs to try as fallbacks + COMMON_GATEWAYS = [ + "172.17.0.1", # Default Docker bridge + "172.18.0.1", # Docker networks + "192.168.65.1", # Docker Desktop + "192.168.66.1", # Alternative Docker Desktop + ] + + def __init__(self): + self._detected_ip: Optional[str] = None + self._last_detection: float = 0 + self._cache_timeout: float = 300 # 5 minutes cache + + @lru_cache(maxsize=1) + def detect_host_ip(self) -> str: + """ + Detect the Docker host IP using multiple methods with fallbacks. 
+ + Returns: + str: The detected host IP address + + Raises: + RuntimeError: If no host IP can be detected + """ + current_time = time.time() + + # Use cached result if recent + if ( + self._detected_ip + and (current_time - self._last_detection) < self._cache_timeout + ): + logger.debug(f"Using cached host IP: {self._detected_ip}") + return self._detected_ip + + logger.info("Detecting Docker host IP...") + + detection_methods = [ + self._detect_via_docker_internal, + self._detect_via_gateway_env, + self._detect_via_route_table, + self._detect_via_network_connect, + self._detect_via_common_gateways, + ] + + for method in detection_methods: + try: + ip = method() + if ip and self._validate_ip(ip): + logger.info( + f"Successfully detected host IP using {method.__name__}: {ip}" + ) + self._detected_ip = ip + self._last_detection = current_time + return ip + else: + logger.debug(f"Method {method.__name__} returned invalid IP: {ip}") + except Exception as e: + logger.debug(f"Method {method.__name__} failed: {e}") + + # If all methods fail, raise an error + raise RuntimeError( + "Could not detect Docker host IP. Tried all detection methods. " + "Please check your Docker network configuration or set HOST_IP environment variable." 
+ ) + + def _detect_via_docker_internal(self) -> Optional[str]: + """Detect via host.docker.internal (Docker Desktop, Docker for Mac/Windows).""" + try: + # Try to resolve host.docker.internal + ip = socket.gethostbyname("host.docker.internal") + if ip != "127.0.0.1": # Make sure it's not localhost + return ip + except socket.gaierror: + pass + + return None + + def _detect_via_gateway_env(self) -> Optional[str]: + """Detect via Docker gateway environment variables.""" + # Check common Docker gateway environment variables + gateway_vars = [ + "DOCKER_HOST_GATEWAY", + "GATEWAY", + "HOST_IP", + ] + + for var in gateway_vars: + ip = os.getenv(var) + if ip: + logger.debug(f"Found host IP in environment variable {var}: {ip}") + return ip + + return None + + def _detect_via_route_table(self) -> Optional[str]: + """Detect via Linux route table (/proc/net/route).""" + try: + with open("/proc/net/route", "r") as f: + for line in f: + fields = line.strip().split() + if ( + len(fields) >= 8 + and fields[0] != "Iface" + and fields[7] == "00000000" + ): + # Found default route, convert hex gateway to IP + gateway_hex = fields[2] + if len(gateway_hex) == 8: + # Convert from hex to IP (little endian) + ip_parts = [] + for i in range(0, 8, 2): + ip_parts.append(str(int(gateway_hex[i : i + 2], 16))) + ip = ".".join(reversed(ip_parts)) + if ip != "0.0.0.0": + return ip + except (IOError, ValueError, IndexError) as e: + logger.debug(f"Failed to read route table: {e}") + + return None + + def _detect_via_network_connect(self) -> Optional[str]: + """Detect by attempting to connect to a known external service.""" + try: + # Try to connect to a reliable external service to determine local IP + # We'll use the Docker daemon itself as a reference + docker_host = os.getenv("DOCKER_HOST", "tcp://host.docker.internal:2376") + + if docker_host.startswith("tcp://"): + host_part = docker_host[6:].split(":")[0] + if host_part not in ["localhost", "127.0.0.1"]: + # Try to resolve the host + try: + 
ip = socket.gethostbyname(host_part) + if ip != "127.0.0.1": + return ip + except socket.gaierror: + pass + except Exception as e: + logger.debug(f"Network connect detection failed: {e}") + + return None + + def _detect_via_common_gateways(self) -> Optional[str]: + """Try common Docker gateway IPs.""" + for gateway in self.COMMON_GATEWAYS: + if self._test_ip_connectivity(gateway): + logger.debug(f"Found working gateway: {gateway}") + return gateway + + return None + + def _test_ip_connectivity(self, ip: str) -> bool: + """Test if an IP address is reachable.""" + try: + # Try to connect to a common port (Docker API or SSH) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1.0) + result = sock.connect_ex((ip, 22)) # SSH port, commonly available + sock.close() + return result == 0 + except Exception: + return False + + def _validate_ip(self, ip: str) -> bool: + """Validate that the IP address is reasonable.""" + try: + socket.inet_aton(ip) + # Basic validation - should not be localhost or invalid ranges + if ip.startswith("127."): + return False + if ip == "0.0.0.0": + return False + # Should be a private IP range + parts = ip.split(".") + if len(parts) != 4: + return False + first_octet = int(parts[0]) + # Common Docker gateway ranges + return first_octet in [10, 172, 192] + except socket.error: + return False + + async def async_detect_host_ip(self) -> str: + """Async version of detect_host_ip for testing.""" + import asyncio + import concurrent.futures + + loop = asyncio.get_event_loop() + with concurrent.futures.ThreadPoolExecutor() as executor: + return await loop.run_in_executor(executor, self.detect_host_ip) + + +# Global instance for caching +_host_detector = HostIPDetector() + + +def get_host_ip() -> str: + """ + Get the Docker host IP address from container perspective. + + This function caches the result for performance and tries multiple + detection methods with fallbacks for different Docker environments. 
+ + Returns: + str: The detected host IP address + + Raises: + RuntimeError: If host IP detection fails + """ + return _host_detector.detect_host_ip() + + +async def async_get_host_ip() -> str: + """ + Async version of get_host_ip for use in async contexts. + + Since the actual detection is not async, this just wraps the sync version. + """ + # Run in thread pool to avoid blocking async context + import concurrent.futures + import asyncio + + loop = asyncio.get_event_loop() + with concurrent.futures.ThreadPoolExecutor() as executor: + return await loop.run_in_executor(executor, get_host_ip) + + +def reset_host_ip_cache(): + """Reset the cached host IP detection result.""" + global _host_detector + _host_detector = HostIPDetector() diff --git a/session-manager/http_pool.py b/session-manager/http_pool.py new file mode 100644 index 0000000..80e9f25 --- /dev/null +++ b/session-manager/http_pool.py @@ -0,0 +1,182 @@ +""" +HTTP Connection Pool Manager + +Provides a global httpx.AsyncClient instance with connection pooling +to eliminate the overhead of creating new HTTP clients for each proxy request. 
+""" + +import asyncio +import logging +import time +from typing import Optional, Dict, Any +from contextlib import asynccontextmanager + +import httpx + +logger = logging.getLogger(__name__) + + +class HTTPConnectionPool: + """Global HTTP connection pool manager for proxy operations.""" + + def __init__(self): + self._client: Optional[httpx.AsyncClient] = None + self._last_health_check: float = 0 + self._health_check_interval: float = 60 # Check health every 60 seconds + self._is_healthy: bool = True + self._reconnect_lock = asyncio.Lock() + + # Connection pool configuration + self._config = { + "limits": httpx.Limits( + max_keepalive_connections=20, # Keep connections alive + max_connections=100, # Max total connections + keepalive_expiry=300.0, # Keep connections alive for 5 minutes + ), + "timeout": httpx.Timeout( + connect=10.0, # Connection timeout + read=30.0, # Read timeout + write=10.0, # Write timeout + pool=5.0, # Pool timeout + ), + "follow_redirects": False, + "http2": False, # Disable HTTP/2 for simplicity + } + + async def __aenter__(self): + await self.ensure_client() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Keep client alive - don't close it + pass + + async def ensure_client(self) -> None: + """Ensure the HTTP client is initialized and healthy.""" + if self._client is None: + await self._create_client() + + # Periodic health check + current_time = time.time() + if current_time - self._last_health_check > self._health_check_interval: + if not await self._check_client_health(): + logger.warning("HTTP client health check failed, recreating client") + await self._recreate_client() + self._last_health_check = current_time + + async def _create_client(self) -> None: + """Create a new HTTP client with connection pooling.""" + async with self._reconnect_lock: + if self._client: + await self._client.aclose() + + self._client = httpx.AsyncClient(**self._config) + self._is_healthy = True + logger.info("HTTP connection pool 
client created") + + async def _recreate_client(self) -> None: + """Recreate the HTTP client (used when health check fails).""" + logger.info("Recreating HTTP connection pool client") + await self._create_client() + + async def _check_client_health(self) -> bool: + """Check if the HTTP client is still healthy.""" + if not self._client: + return False + + try: + # Simple health check - we could ping a reliable endpoint + # For now, just check if client is still responsive + # In a real implementation, you might ping a health endpoint + return self._is_healthy + except Exception as e: + logger.warning(f"HTTP client health check error: {e}") + return False + + async def request(self, method: str, url: str, **kwargs) -> httpx.Response: + """Make an HTTP request using the connection pool.""" + await self.ensure_client() + + if not self._client: + raise RuntimeError("HTTP client not available") + + try: + response = await self._client.request(method, url, **kwargs) + return response + except (httpx.ConnectError, httpx.ConnectTimeout, httpx.PoolTimeout) as e: + # Connection-related errors - client might be unhealthy + logger.warning(f"Connection error, marking client as unhealthy: {e}") + self._is_healthy = False + raise + except Exception as e: + # Other errors - re-raise as-is + raise + + async def close(self) -> None: + """Close the HTTP client and cleanup resources.""" + async with self._reconnect_lock: + if self._client: + await self._client.aclose() + self._client = None + self._is_healthy = False + logger.info("HTTP connection pool client closed") + + async def get_pool_stats(self) -> Dict[str, Any]: + """Get connection pool statistics.""" + if not self._client: + return {"status": "not_initialized"} + + # httpx doesn't expose detailed pool stats, but we can provide basic info + return { + "status": "healthy" if self._is_healthy else "unhealthy", + "last_health_check": self._last_health_check, + "config": { + "max_keepalive_connections": self._config[ + "limits" + 
].max_keepalive_connections, + "max_connections": self._config["limits"].max_connections, + "keepalive_expiry": self._config["limits"].keepalive_expiry, + "connect_timeout": self._config["timeout"].connect, + "read_timeout": self._config["timeout"].read, + }, + } + + +# Global HTTP connection pool instance +_http_pool = HTTPConnectionPool() + + +@asynccontextmanager +async def get_http_client(): + """Context manager for getting the global HTTP client.""" + async with _http_pool: + yield _http_pool + + +async def make_http_request(method: str, url: str, **kwargs) -> httpx.Response: + """Make an HTTP request using the global connection pool.""" + async with get_http_client() as client: + return await client.request(method, url, **kwargs) + + +async def get_connection_pool_stats() -> Dict[str, Any]: + """Get connection pool statistics.""" + return await _http_pool.get_pool_stats() + + +async def close_connection_pool() -> None: + """Close the global connection pool (for cleanup).""" + await _http_pool.close() + + +# Lifecycle management for FastAPI +async def init_http_pool() -> None: + """Initialize the HTTP connection pool on startup.""" + logger.info("Initializing HTTP connection pool") + await _http_pool.ensure_client() + + +async def shutdown_http_pool() -> None: + """Shutdown the HTTP connection pool on shutdown.""" + logger.info("Shutting down HTTP connection pool") + await _http_pool.close() diff --git a/session-manager/logging_config.py b/session-manager/logging_config.py new file mode 100644 index 0000000..ea9ee16 --- /dev/null +++ b/session-manager/logging_config.py @@ -0,0 +1,317 @@ +""" +Structured Logging Configuration + +Provides comprehensive logging infrastructure with structured logging, +request tracking, log formatting, and aggregation capabilities. 
+""" + +import os +import sys +import json +import logging +import logging.handlers +from typing import Dict, Any, Optional +from datetime import datetime +from pathlib import Path +import threading +import uuid + + +class StructuredFormatter(logging.Formatter): + """Structured JSON formatter for production logging.""" + + def format(self, record: logging.LogRecord) -> str: + # Create structured log entry + log_entry = { + "timestamp": datetime.utcnow().isoformat() + "Z", + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + "module": record.module, + "function": record.funcName, + "line": record.lineno, + } + + # Add exception info if present + if record.exc_info: + log_entry["exception"] = self.formatException(record.exc_info) + + # Add extra fields from record + if hasattr(record, "request_id"): + log_entry["request_id"] = record.request_id + if hasattr(record, "session_id"): + log_entry["session_id"] = record.session_id + if hasattr(record, "user_id"): + log_entry["user_id"] = record.user_id + if hasattr(record, "operation"): + log_entry["operation"] = record.operation + if hasattr(record, "duration_ms"): + log_entry["duration_ms"] = record.duration_ms + if hasattr(record, "status_code"): + log_entry["status_code"] = record.status_code + + # Add any additional structured data + if hasattr(record, "__dict__"): + for key, value in record.__dict__.items(): + if key not in [ + "name", + "msg", + "args", + "levelname", + "levelno", + "pathname", + "filename", + "module", + "exc_info", + "exc_text", + "stack_info", + "lineno", + "funcName", + "created", + "msecs", + "relativeCreated", + "thread", + "threadName", + "processName", + "process", + "message", + ]: + log_entry[key] = value + + return json.dumps(log_entry, default=str) + + +class HumanReadableFormatter(logging.Formatter): + """Human-readable formatter for development.""" + + def __init__(self): + super().__init__( + fmt="%(asctime)s [%(levelname)8s] 
%(name)s:%(funcName)s:%(lineno)d - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + def format(self, record: logging.LogRecord) -> str: + # Add request ID to human readable format + if hasattr(record, "request_id"): + self._fmt = "%(asctime)s [%(levelname)8s] %(name)s:%(funcName)s:%(lineno)d [%(request_id)s] - %(message)s" + else: + self._fmt = "%(asctime)s [%(levelname)8s] %(name)s:%(funcName)s:%(lineno)d - %(message)s" + + return super().format(record) + + +class RequestContext: + """Context manager for request-scoped logging.""" + + _local = threading.local() + + def __init__(self, request_id: Optional[str] = None): + self.request_id = request_id or str(uuid.uuid4())[:8] + self._old_request_id = getattr(self._local, "request_id", None) + + def __enter__(self): + self._local.request_id = self.request_id + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._old_request_id is not None: + self._local.request_id = self._old_request_id + else: + delattr(self._local, "request_id") + + @classmethod + def get_current_request_id(cls) -> Optional[str]: + """Get the current request ID from thread local storage.""" + return getattr(cls._local, "request_id", None) + + +class RequestAdapter(logging.LoggerAdapter): + """Logger adapter that automatically adds request context.""" + + def __init__(self, logger: logging.Logger): + super().__init__(logger, {}) + + def process(self, msg: str, kwargs: Any) -> tuple: + """Add request context to log records.""" + request_id = RequestContext.get_current_request_id() + if request_id: + kwargs.setdefault("extra", {})["request_id"] = request_id + return msg, kwargs + + +def setup_logging( + level: str = "INFO", + format_type: str = "auto", # "json", "human", or "auto" + log_file: Optional[str] = None, + max_file_size: int = 10 * 1024 * 1024, # 10MB + backup_count: int = 5, + enable_console: bool = True, + enable_file: bool = True, +) -> logging.Logger: + """ + Set up comprehensive logging configuration. 
+ + Args: + level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + format_type: Log format - "json", "human", or "auto" (detects from environment) + log_file: Path to log file (optional) + max_file_size: Maximum log file size in bytes + backup_count: Number of backup files to keep + enable_console: Enable console logging + enable_file: Enable file logging + + Returns: + Configured root logger + """ + # Determine format type + if format_type == "auto": + # Use JSON for production, human-readable for development + format_type = "json" if os.getenv("ENVIRONMENT") == "production" else "human" + + # Clear existing handlers + root_logger = logging.getLogger() + root_logger.handlers.clear() + + # Set log level + numeric_level = getattr(logging, level.upper(), logging.INFO) + root_logger.setLevel(numeric_level) + + # Create formatters + if format_type == "json": + formatter = StructuredFormatter() + else: + formatter = HumanReadableFormatter() + + # Console handler + if enable_console: + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(numeric_level) + console_handler.setFormatter(formatter) + root_logger.addHandler(console_handler) + + # File handler with rotation + if enable_file and log_file: + # Ensure log directory exists + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + file_handler = logging.handlers.RotatingFileHandler( + log_file, maxBytes=max_file_size, backupCount=backup_count, encoding="utf-8" + ) + file_handler.setLevel(numeric_level) + file_handler.setFormatter(StructuredFormatter()) # Always use JSON for files + root_logger.addHandler(file_handler) + + # Create request adapter for the root logger + adapter = RequestAdapter(root_logger) + + # Configure third-party loggers + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("docker").setLevel(logging.WARNING) + logging.getLogger("aiodeocker").setLevel(logging.WARNING) + 
logging.getLogger("asyncio").setLevel(logging.WARNING) + + return adapter + + +def get_logger(name: str) -> RequestAdapter: + """Get a configured logger with request context support.""" + logger = logging.getLogger(name) + return RequestAdapter(logger) + + +def log_performance(operation: str, duration_ms: float, **kwargs) -> None: + """Log performance metrics.""" + logger = get_logger(__name__) + extra = {"operation": operation, "duration_ms": duration_ms, **kwargs} + logger.info( + f"Performance: {operation} completed in {duration_ms:.2f}ms", extra=extra + ) + + +def log_request( + method: str, path: str, status_code: int, duration_ms: float, **kwargs +) -> None: + """Log HTTP request metrics.""" + logger = get_logger(__name__) + extra = { + "operation": "http_request", + "method": method, + "path": path, + "status_code": status_code, + "duration_ms": duration_ms, + **kwargs, + } + level = logging.INFO if status_code < 400 else logging.WARNING + logger.log( + level, + f"HTTP {method} {path} -> {status_code} ({duration_ms:.2f}ms)", + extra=extra, + ) + + +def log_session_operation(session_id: str, operation: str, **kwargs) -> None: + """Log session-related operations.""" + logger = get_logger(__name__) + extra = {"session_id": session_id, "operation": operation, **kwargs} + logger.info(f"Session {operation}: {session_id}", extra=extra) + + +def log_security_event(event: str, severity: str = "info", **kwargs) -> None: + """Log security-related events.""" + logger = get_logger(__name__) + extra = {"security_event": event, "severity": severity, **kwargs} + level = getattr(logging, severity.upper(), logging.INFO) + logger.log(level, f"Security: {event}", extra=extra) + + +# Global logger instance +logger = get_logger(__name__) + +# Initialize logging on import +_setup_complete = False + + +def init_logging(): + """Initialize logging system.""" + global _setup_complete + if _setup_complete: + return + + # Configuration from environment + level = os.getenv("LOG_LEVEL", 
"INFO") + format_type = os.getenv("LOG_FORMAT", "auto") # json, human, auto + log_file = os.getenv("LOG_FILE") + max_file_size = int(os.getenv("LOG_MAX_SIZE_MB", "10")) * 1024 * 1024 + backup_count = int(os.getenv("LOG_BACKUP_COUNT", "5")) + enable_console = os.getenv("LOG_CONSOLE", "true").lower() == "true" + enable_file = ( + os.getenv("LOG_FILE_ENABLED", "true").lower() == "true" and log_file is not None + ) + + setup_logging( + level=level, + format_type=format_type, + log_file=log_file, + max_file_size=max_file_size, + backup_count=backup_count, + enable_console=enable_console, + enable_file=enable_file, + ) + + logger.info( + "Structured logging initialized", + extra={ + "level": level, + "format": format_type, + "log_file": log_file, + "max_file_size_mb": max_file_size // (1024 * 1024), + "backup_count": backup_count, + }, + ) + + _setup_complete = True + + +# Initialize on import +init_logging() diff --git a/session-manager/main.py b/session-manager/main.py index b1a6419..c6e8426 100644 --- a/session-manager/main.py +++ b/session-manager/main.py @@ -9,6 +9,8 @@ import os import uuid import json import asyncio +import logging +import time from datetime import datetime, timedelta from pathlib import Path from typing import Dict, Optional, List @@ -22,15 +24,134 @@ from pydantic import BaseModel import uvicorn import httpx +# Import host IP detection utility +from host_ip_detector import async_get_host_ip, reset_host_ip_cache -# Configuration +# Import resource management utilities +from resource_manager import ( + get_resource_limits, + check_system_resources, + should_throttle_sessions, + ResourceLimits, + ResourceValidator, +) + +# Import structured logging +from logging_config import ( + get_logger, + RequestContext, + log_performance, + log_request, + log_session_operation, + log_security_event, + init_logging, +) + +# Import async Docker client +from async_docker_client import ( + get_async_docker_client, + async_create_container, + 
async_start_container, + async_stop_container, + async_remove_container, + async_get_container, + async_list_containers, + async_docker_ping, +) + +# Import HTTP connection pool +from http_pool import ( + make_http_request, + get_connection_pool_stats, + init_http_pool, + shutdown_http_pool, +) + +# Import session authentication +from session_auth import ( + generate_session_auth_token, + validate_session_auth_token, + revoke_session_auth_token, + rotate_session_auth_token, + cleanup_expired_auth_tokens, + get_session_auth_info, + get_active_auth_sessions_count, + list_active_auth_sessions, +) + +# Import database layer +from database import ( + init_database, + shutdown_database, + SessionModel, + get_database_stats, + run_migrations, +) + +# Import container health monitoring +from container_health import ( + start_container_health_monitoring, + stop_container_health_monitoring, + get_container_health_stats, + get_container_health_history, +) + +# Import Docker service abstraction +from docker_service import DockerService, DockerOperationError + +# Import database layer +from database import ( + init_database, + shutdown_database, + SessionModel, + get_database_stats, + run_migrations, +) + +# Initialize structured logging +init_logging() + +# Get configured logger +logger = get_logger(__name__) + +# Configuration for async operations +USE_ASYNC_DOCKER = os.getenv("USE_ASYNC_DOCKER", "true").lower() == "true" + +# Session storage configuration +USE_DATABASE_STORAGE = os.getenv("USE_DATABASE_STORAGE", "true").lower() == "true" + + +# Configuration - Resource limits are now configurable and enforced SESSIONS_DIR = Path("/app/sessions") SESSIONS_FILE = Path("/app/sessions/sessions.json") -CONTAINER_IMAGE = "lovdata-opencode:latest" -MAX_CONCURRENT_SESSIONS = 3 # Workstation limit -SESSION_TIMEOUT_MINUTES = 60 # Auto-cleanup after 1 hour -CONTAINER_MEMORY_LIMIT = "4g" -CONTAINER_CPU_QUOTA = 100000 # 1 CPU core +CONTAINER_IMAGE = os.getenv("CONTAINER_IMAGE", 
"lovdata-opencode:latest") + +# Resource limits - configurable via environment variables with defaults +CONTAINER_MEMORY_LIMIT = os.getenv( + "CONTAINER_MEMORY_LIMIT", "4g" +) # Memory limit per container +CONTAINER_CPU_QUOTA = int( + os.getenv("CONTAINER_CPU_QUOTA", "100000") +) # CPU quota (100000 = 1 core) +CONTAINER_CPU_PERIOD = int( + os.getenv("CONTAINER_CPU_PERIOD", "100000") +) # CPU period (microseconds) + +# Session management +MAX_CONCURRENT_SESSIONS = int( + os.getenv("MAX_CONCURRENT_SESSIONS", "3") +) # Max concurrent sessions +SESSION_TIMEOUT_MINUTES = int( + os.getenv("SESSION_TIMEOUT_MINUTES", "60") +) # Auto-cleanup timeout + +# Resource monitoring thresholds +MEMORY_WARNING_THRESHOLD = float( + os.getenv("MEMORY_WARNING_THRESHOLD", "0.8") +) # 80% memory usage +CPU_WARNING_THRESHOLD = float( + os.getenv("CPU_WARNING_THRESHOLD", "0.9") +) # 90% CPU usage class SessionData(BaseModel): @@ -39,26 +160,49 @@ class SessionData(BaseModel): container_id: Optional[str] = None host_dir: str port: Optional[int] = None + auth_token: Optional[str] = None # Authentication token for the session created_at: datetime last_accessed: datetime status: str = "creating" # creating, running, stopped, error class SessionManager: - def __init__(self): - # Use Docker library 7.1.0 with Unix socket support - import docker + def __init__(self, docker_service: Optional[DockerService] = None): + # Use injected Docker service or create default + if docker_service: + self.docker_service = docker_service + else: + self.docker_service = DockerService(use_async=USE_ASYNC_DOCKER) - self.docker_client = docker.from_env() - # Test the connection - self.docker_client.ping() - print("Docker library client initialized successfully") + # Initialize session storage + if USE_DATABASE_STORAGE: + # Use database backend + self.sessions: Dict[ + str, SessionData + ] = {} # Keep in-memory cache for performance + logger.info("Session storage initialized", extra={"backend": "database"}) + else: + 
# Use JSON file backend (legacy) + self.sessions: Dict[str, SessionData] = {} + self._load_sessions_from_file() + logger.info("Session storage initialized", extra={"backend": "json_file"}) - self.sessions: Dict[str, SessionData] = {} - self._load_sessions() + # Initialize container health monitoring + from container_health import get_container_health_monitor - def _load_sessions(self): - """Load session data from persistent storage""" + self.health_monitor = get_container_health_monitor() + # Dependencies will be set when health monitoring starts + + logger.info( + "SessionManager initialized", + extra={ + "docker_service_type": type(self.docker_service).__name__, + "storage_backend": "database" if USE_DATABASE_STORAGE else "json_file", + }, + ) + + def _load_sessions_from_file(self): + """Load session data from JSON file (legacy method)""" if SESSIONS_FILE.exists(): try: with open(SESSIONS_FILE, "r") as f: @@ -72,10 +216,36 @@ class SessionManager: session_dict["last_accessed"] ) self.sessions[session_id] = SessionData(**session_dict) + logger.info( + "Sessions loaded from JSON file", + extra={"count": len(self.sessions)}, + ) except (json.JSONDecodeError, KeyError) as e: - print(f"Warning: Could not load sessions file: {e}") + logger.warning("Could not load sessions file", extra={"error": str(e)}) self.sessions = {} + async def _load_sessions_from_database(self): + """Load active sessions from database into memory cache""" + try: + # Load only running/creating sessions to avoid loading too much data + db_sessions = await SessionModel.get_sessions_by_status("running") + db_sessions.extend(await SessionModel.get_sessions_by_status("creating")) + + self.sessions = {} + for session_dict in db_sessions: + # Convert to SessionData model + session_data = SessionData(**session_dict) + self.sessions[session_dict["session_id"]] = session_data + + logger.info( + "Sessions loaded from database", extra={"count": len(self.sessions)} + ) + except Exception as e: + logger.error( 
+ "Failed to load sessions from database", extra={"error": str(e)} + ) + self.sessions = {} + def _save_sessions(self): """Save session data to persistent storage""" SESSIONS_DIR.mkdir(exist_ok=True) @@ -100,21 +270,37 @@ class SessionManager: def _check_container_limits(self) -> bool: """Check if we're within concurrent session limits""" - if not self.docker_client: - return False active_sessions = sum( 1 for s in self.sessions.values() if s.status in ["creating", "running"] ) return active_sessions < MAX_CONCURRENT_SESSIONS + async def _async_check_container_limits(self) -> bool: + """Async version of container limits check""" + return self._check_container_limits() + async def create_session(self) -> SessionData: """Create a new OpenCode session with dedicated container""" - if not self._check_container_limits(): + # Check concurrent session limits + if USE_ASYNC_DOCKER: + limits_ok = await self._async_check_container_limits() + else: + limits_ok = self._check_container_limits() + + if not limits_ok: raise HTTPException( status_code=429, detail=f"Maximum concurrent sessions ({MAX_CONCURRENT_SESSIONS}) reached", ) + # Check system resource limits + should_throttle, reason = should_throttle_sessions() + if should_throttle: + raise HTTPException( + status_code=503, + detail=f"System resource constraints prevent new sessions: {reason}", + ) + session_id = self._generate_session_id() container_name = f"opencode-{session_id}" host_dir = str(SESSIONS_DIR / session_id) @@ -123,65 +309,256 @@ class SessionManager: # Create host directory Path(host_dir).mkdir(parents=True, exist_ok=True) + # Generate authentication token for this session + auth_token = generate_session_auth_token(session_id) + session = SessionData( session_id=session_id, container_name=container_name, host_dir=host_dir, port=port, + auth_token=auth_token, created_at=datetime.now(), last_accessed=datetime.now(), status="creating", ) + # Store in memory cache self.sessions[session_id] = session - 
self._save_sessions() - # Start container in background - asyncio.create_task(self._start_container(session)) + # Persist to database if using database storage + if USE_DATABASE_STORAGE: + try: + await SessionModel.create_session( + { + "session_id": session_id, + "container_name": container_name, + "host_dir": host_dir, + "port": port, + "auth_token": auth_token, + "status": "creating", + } + ) + logger.info( + "Session created in database", extra={"session_id": session_id} + ) + except Exception as e: + logger.error( + "Failed to create session in database", + extra={"session_id": session_id, "error": str(e)}, + ) + # Continue with in-memory storage as fallback + + # Start container asynchronously + if USE_ASYNC_DOCKER: + asyncio.create_task(self._start_container_async(session)) + else: + asyncio.create_task(self._start_container_sync(session)) return session - async def _start_container(self, session: SessionData): - """Start the OpenCode container for a session""" + async def _start_container_async(self, session: SessionData): + """Start the OpenCode container asynchronously using aiodeocker""" try: - # Create and start the OpenCode container - container = self.docker_client.containers.run( - "lovdata-opencode:latest", # Will be built from the Dockerfile + # Get and validate resource limits + resource_limits = get_resource_limits() + + logger.info( + f"Starting container {session.container_name} with resource limits: memory={resource_limits.memory_limit}, cpu_quota={resource_limits.cpu_quota}" + ) + + # Create container using Docker service + container_info = await self.docker_service.create_container( name=session.container_name, + image=CONTAINER_IMAGE, volumes={session.host_dir: {"bind": "/app/somedir", "mode": "rw"}}, - ports={f"8080/tcp": session.port}, - detach=True, + ports={"8080": session.port}, environment={ "MCP_SERVER": os.getenv("MCP_SERVER", ""), "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY", ""), "ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", 
""), "GOOGLE_API_KEY": os.getenv("GOOGLE_API_KEY", ""), + # Secure authentication for OpenCode server + "OPENCODE_SERVER_PASSWORD": session.auth_token or "", + "SESSION_AUTH_TOKEN": session.auth_token or "", + "SESSION_ID": session.session_id, }, network_mode="bridge", + # Apply resource limits to prevent resource exhaustion + mem_limit=resource_limits.memory_limit, + cpu_quota=resource_limits.cpu_quota, + cpu_period=resource_limits.cpu_period, + # Additional security and resource constraints + tmpfs={ + "/tmp": "rw,noexec,nosuid,size=100m", + "/var/tmp": "rw,noexec,nosuid,size=50m", + }, ) - session.container_id = container.id + # For async mode, containers are already started during creation + # For sync mode, we need to explicitly start them + if not self.docker_service.use_async: + await self.docker_service.start_container(container_info.container_id) + + session.container_id = container_info.container_id session.status = "running" - self._save_sessions() - print(f"Container {session.container_name} started on port {session.port}") + + # Update in-memory cache + self.sessions[session.session_id] = session + + # Update database if using database storage + if USE_DATABASE_STORAGE: + try: + await SessionModel.update_session( + session.session_id, + { + "container_id": container_info.container_id, + "status": "running", + }, + ) + except Exception as e: + logger.error( + "Failed to update session in database", + extra={"session_id": session.session_id, "error": str(e)}, + ) + + logger.info( + "Container started successfully", + extra={ + "session_id": session.session_id, + "container_name": session.container_name, + "container_id": container_info.container_id, + "port": session.port, + }, + ) except Exception as e: session.status = "error" self._save_sessions() - print(f"Failed to start container {session.container_name}: {e}") + logger.error(f"Failed to start container {session.container_name}: {e}") + + async def _start_container_sync(self, session: SessionData): 
+ """Start the OpenCode container using Docker service (sync mode)""" + try: + # Get and validate resource limits + resource_limits = get_resource_limits() + + logger.info( + f"Starting container {session.container_name} with resource limits: memory={resource_limits.memory_limit}, cpu_quota={resource_limits.cpu_quota}" + ) + + # Create container using Docker service + container_info = await self.docker_service.create_container( + name=session.container_name, + image=CONTAINER_IMAGE, + volumes={session.host_dir: {"bind": "/app/somedir", "mode": "rw"}}, + ports={"8080": session.port}, + environment={ + "MCP_SERVER": os.getenv("MCP_SERVER", ""), + "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY", ""), + "ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", ""), + "GOOGLE_API_KEY": os.getenv("GOOGLE_API_KEY", ""), + # Secure authentication for OpenCode server + "OPENCODE_SERVER_PASSWORD": session.auth_token or "", + "SESSION_AUTH_TOKEN": session.auth_token or "", + "SESSION_ID": session.session_id, + }, + network_mode="bridge", + # Apply resource limits to prevent resource exhaustion + mem_limit=resource_limits.memory_limit, + cpu_quota=resource_limits.cpu_quota, + cpu_period=resource_limits.cpu_period, + # Additional security and resource constraints + tmpfs={ + "/tmp": "rw,noexec,nosuid,size=100m", + "/var/tmp": "rw,noexec,nosuid,size=50m", + }, + ) + + session.container_id = container_info.container_id + session.status = "running" + + # Update in-memory cache + self.sessions[session.session_id] = session + + # Update database if using database storage + if USE_DATABASE_STORAGE: + try: + await SessionModel.update_session( + session.session_id, + { + "container_id": container_info.container_id, + "status": "running", + }, + ) + except Exception as e: + logger.error( + "Failed to update session in database", + extra={"session_id": session.session_id, "error": str(e)}, + ) + + logger.info( + "Container started successfully", + extra={ + "session_id": session.session_id, + 
"container_name": session.container_name, + "container_id": container_info.container_id, + "port": session.port, + }, + ) + + except Exception as e: + session.status = "error" + self._save_sessions() + logger.error(f"Failed to start container {session.container_name}: {e}") async def get_session(self, session_id: str) -> Optional[SessionData]: """Get session information""" + # Check in-memory cache first session = self.sessions.get(session_id) if session: session.last_accessed = datetime.now() - self._save_sessions() - return session + # Update database if using database storage + if USE_DATABASE_STORAGE: + try: + await SessionModel.update_session( + session_id, {"last_accessed": datetime.now()} + ) + except Exception as e: + logger.warning( + "Failed to update session access time in database", + extra={"session_id": session_id, "error": str(e)}, + ) + return session + + # If not in cache and using database, try to load from database + if USE_DATABASE_STORAGE: + try: + db_session = await SessionModel.get_session(session_id) + if db_session: + # Convert to SessionData and cache it + session_data = SessionData(**db_session) + self.sessions[session_id] = session_data + logger.debug( + "Session loaded from database", extra={"session_id": session_id} + ) + return session_data + except Exception as e: + logger.error( + "Failed to load session from database", + extra={"session_id": session_id, "error": str(e)}, + ) + + return None async def list_sessions(self) -> List[SessionData]: """List all sessions""" return list(self.sessions.values()) + async def list_containers_async(self, all: bool = False) -> List: + """List containers asynchronously""" + return await self.docker_service.list_containers(all=all) + async def cleanup_expired_sessions(self): """Clean up expired sessions and their containers""" now = datetime.now() @@ -194,30 +571,38 @@ class SessionManager: # Stop and remove container try: - container = self.docker_client.containers.get( - session.container_name + 
await self.docker_service.stop_container( + session.container_name, timeout=10 ) - container.stop(timeout=10) - container.remove() - print(f"Cleaned up container {session.container_name}") + await self.docker_service.remove_container(session.container_name) + logger.info(f"Cleaned up container {session.container_name}") except Exception as e: - print(f"Error cleaning up container {session.container_name}: {e}") + logger.error( + f"Error cleaning up container {session.container_name}: {e}" + ) # Remove session directory try: import shutil shutil.rmtree(session.host_dir) - print(f"Removed session directory {session.host_dir}") + logger.info(f"Removed session directory {session.host_dir}") except OSError as e: - print(f"Error removing session directory {session.host_dir}: {e}") + logger.error( + f"Error removing session directory {session.host_dir}: {e}" + ) for session_id in expired_sessions: del self.sessions[session_id] if expired_sessions: self._save_sessions() - print(f"Cleaned up {len(expired_sessions)} expired sessions") + logger.info(f"Cleaned up {len(expired_sessions)} expired sessions") + + # Also cleanup expired authentication tokens + expired_tokens = cleanup_expired_auth_tokens() + if expired_tokens > 0: + logger.info(f"Cleaned up {expired_tokens} expired authentication tokens") # Global session manager instance @@ -227,8 +612,47 @@ session_manager = SessionManager() @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan manager""" + global USE_DATABASE_STORAGE # Declare global at function start + # Startup - print("Starting Session Management Service") + logger.info("Starting Session Management Service") + + # Initialize HTTP connection pool + await init_http_pool() + + # Initialize database if using database storage + if USE_DATABASE_STORAGE: + try: + await init_database() + await run_migrations() + # Load active sessions from database + await session_manager._load_sessions_from_database() + logger.info("Database initialized and 
sessions loaded") + except Exception as e: + logger.error("Database initialization failed", extra={"error": str(e)}) + if USE_DATABASE_STORAGE: + logger.warning("Falling back to JSON file storage") + USE_DATABASE_STORAGE = False + session_manager._load_sessions_from_file() + + # Start container health monitoring + try: + docker_client = None + if USE_ASYNC_DOCKER: + from async_docker_client import get_async_docker_client + + # Create a client instance for health monitoring + async with get_async_docker_client() as client: + docker_client = client._docker if hasattr(client, "_docker") else None + else: + docker_client = session_manager.docker_client + + await start_container_health_monitoring(session_manager, docker_client) + logger.info("Container health monitoring started") + except Exception as e: + logger.error( + "Failed to start container health monitoring", extra={"error": str(e)} + ) # Start cleanup task async def cleanup_task(): @@ -241,9 +665,25 @@ async def lifespan(app: FastAPI): yield # Shutdown - print("Shutting down Session Management Service") + logger.info("Shutting down Session Management Service") cleanup_coro.cancel() + # Shutdown HTTP connection pool + await shutdown_http_pool() + + # Shutdown container health monitoring + try: + await stop_container_health_monitoring() + logger.info("Container health monitoring stopped") + except Exception as e: + logger.error( + "Error stopping container health monitoring", extra={"error": str(e)} + ) + + # Shutdown database + if USE_DATABASE_STORAGE: + await shutdown_database() + app = FastAPI( title="Lovdata Chat Session Manager", @@ -254,26 +694,83 @@ app = FastAPI( @app.post("/sessions", response_model=SessionData) -async def create_session(): +async def create_session(request: Request): """Create a new session with dedicated container""" - try: - session = await session_manager.create_session() - return session - except HTTPException: - raise - except Exception as e: - raise HTTPException( - 
status_code=500, detail=f"Failed to create session: {str(e)}" - ) + start_time = time.time() + + with RequestContext(): + try: + log_request("POST", "/sessions", 200, 0, operation="create_session_start") + + session = await session_manager.create_session() + + duration_ms = (time.time() - start_time) * 1000 + log_session_operation( + session.session_id, "created", duration_ms=duration_ms + ) + log_performance( + "create_session", duration_ms, session_id=session.session_id + ) + + return session + except HTTPException as e: + duration_ms = (time.time() - start_time) * 1000 + log_request( + "POST", "/sessions", e.status_code, duration_ms, error=str(e.detail) + ) + raise + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + log_request("POST", "/sessions", 500, duration_ms, error=str(e)) + raise HTTPException( + status_code=500, detail=f"Failed to create session: {str(e)}" + ) @app.get("/sessions/{session_id}", response_model=SessionData) -async def get_session(session_id: str): +async def get_session(session_id: str, request: Request): """Get session information""" - session = await session_manager.get_session(session_id) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - return session + start_time = time.time() + + with RequestContext(): + try: + log_request( + "GET", f"/sessions/{session_id}", 200, 0, operation="get_session_start" + ) + + session = await session_manager.get_session(session_id) + if not session: + duration_ms = (time.time() - start_time) * 1000 + log_request( + "GET", + f"/sessions/{session_id}", + 404, + duration_ms, + session_id=session_id, + ) + raise HTTPException(status_code=404, detail="Session not found") + + duration_ms = (time.time() - start_time) * 1000 + log_request( + "GET", + f"/sessions/{session_id}", + 200, + duration_ms, + session_id=session_id, + ) + log_session_operation(session_id, "accessed", duration_ms=duration_ms) + + return session + except HTTPException: + raise + 
except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + log_request( + "GET", f"/sessions/{session_id}", 500, duration_ms, error=str(e) + ) + raise HTTPException( + status_code=500, detail=f"Failed to get session: {str(e)}" + ) @app.get("/sessions", response_model=List[SessionData]) @@ -289,6 +786,9 @@ async def delete_session(session_id: str, background_tasks: BackgroundTasks): if not session: raise HTTPException(status_code=404, detail="Session not found") + # Revoke authentication token + revoke_session_auth_token(session_id) + # Schedule cleanup background_tasks.add_task(session_manager.cleanup_expired_sessions) @@ -306,19 +806,149 @@ async def trigger_cleanup(): return {"message": "Cleanup completed"} +@app.get("/sessions/{session_id}/auth") +async def get_session_auth_info(session_id: str): + """Get authentication information for a session""" + session = await session_manager.get_session(session_id) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + auth_info = get_session_auth_info(session_id) + if not auth_info: + raise HTTPException(status_code=404, detail="Authentication info not found") + + return { + "session_id": session_id, + "auth_info": auth_info, + "has_token": session.auth_token is not None, + } + + +@app.post("/sessions/{session_id}/auth/rotate") +async def rotate_session_token(session_id: str): + """Rotate the authentication token for a session""" + session = await session_manager.get_session(session_id) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + from session_auth import _session_token_manager + + new_token = _session_token_manager.rotate_session_token(session_id) + if not new_token: + raise HTTPException(status_code=404, detail="Failed to rotate token") + + # Update session with new token + session.auth_token = new_token + session_manager._save_sessions() + + return { + "session_id": session_id, + "new_token": new_token, + "message": "Token 
rotated successfully", + } + + +@app.get("/auth/sessions") +async def list_authenticated_sessions(): + """List all authenticated sessions""" + sessions = list_active_auth_sessions() + return { + "active_auth_sessions": len(sessions), + "sessions": sessions, + } + + +@app.get("/health/container") +async def get_container_health(): + """Get detailed container health statistics""" + stats = get_container_health_stats() + return stats + + +@app.get("/health/container/{session_id}") +async def get_session_container_health(session_id: str): + """Get container health information for a specific session""" + session = await session_manager.get_session(session_id) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + stats = get_container_health_stats(session_id) + history = get_container_health_history(session_id, limit=20) + + return { + "session_id": session_id, + "container_id": session.container_id, + "stats": stats.get(f"session_{session_id}", {}), + "recent_history": history, + } + + @app.api_route( "/session/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], ) async def proxy_to_session(request: Request, session_id: str, path: str): """Proxy requests to session containers based on session ID in URL""" - session = await session_manager.get_session(session_id) - if not session or session.status != "running": - raise HTTPException(status_code=404, detail="Session not found or not running") + start_time = time.time() + + with RequestContext(): + log_request( + request.method, + f"/session/{session_id}/{path}", + 200, + 0, + operation="proxy_start", + session_id=session_id, + ) + + session = await session_manager.get_session(session_id) + if not session or session.status != "running": + duration_ms = (time.time() - start_time) * 1000 + log_request( + request.method, + f"/session/{session_id}/{path}", + 404, + duration_ms, + session_id=session_id, + error="Session not found or not running", + 
) + raise HTTPException( + status_code=404, detail="Session not found or not running" + ) + + # Dynamically detect the Docker host IP from container perspective + # This supports multiple Docker environments (Docker Desktop, Linux, cloud, etc.) + try: + host_ip = await async_get_host_ip() + logger.info(f"Using detected host IP for proxy: {host_ip}") + except RuntimeError as e: + # Fallback to environment variable or common defaults + host_ip = os.getenv("HOST_IP") + if not host_ip: + # Try common Docker gateway IPs as final fallback + common_gateways = ["172.17.0.1", "192.168.65.1", "host.docker.internal"] + for gateway in common_gateways: + try: + # Test connectivity to gateway + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1.0) + result = sock.connect_ex((gateway, 22)) + sock.close() + if result == 0: + host_ip = gateway + logger.warning(f"Using fallback gateway IP: {host_ip}") + break + except Exception: + continue + else: + logger.error(f"Host IP detection failed: {e}") + raise HTTPException( + status_code=500, + detail="Could not determine Docker host IP for proxy routing", + ) - # Proxy the request to the container (use host.docker.internal for Docker Desktop, or host IP) - # For Linux, we need to use the host IP - host_ip = os.getenv("HOST_IP", "172.17.0.1") # Default Docker bridge IP container_url = f"http://{host_ip}:{session.port}" # Prepare the request URL @@ -333,55 +963,210 @@ async def proxy_to_session(request: Request, session_id: str, path: str): headers = dict(request.headers) headers.pop("host", None) - # Make the proxy request - async with httpx.AsyncClient(timeout=30.0) as client: - try: - response = await client.request( - method=request.method, - url=url, - headers=headers, - content=body, - follow_redirects=False, # Let the client handle redirects - ) + # Add authentication headers for the OpenCode server + if session.auth_token: + headers["Authorization"] = f"Bearer {session.auth_token}" + 
headers["X-Session-Token"] = session.auth_token + headers["X-Session-ID"] = session.session_id - # Return the response - return Response( - content=response.content, - status_code=response.status_code, - headers=dict(response.headers), - ) + # Make the proxy request using the connection pool + try: + log_session_operation( + session_id, "proxy_request", method=request.method, path=path + ) - except httpx.TimeoutException: - raise HTTPException( - status_code=504, detail="Request to session container timed out" - ) - except httpx.RequestError as e: - raise HTTPException( - status_code=502, - detail=f"Failed to connect to session container: {str(e)}", - ) + response = await make_http_request( + method=request.method, + url=url, + headers=headers, + content=body, + ) + + duration_ms = (time.time() - start_time) * 1000 + log_request( + request.method, + f"/session/{session_id}/{path}", + response.status_code, + duration_ms, + session_id=session_id, + operation="proxy_complete", + ) + + # Log security event for proxy access + log_security_event( + "proxy_access", + "info", + session_id=session_id, + method=request.method, + path=path, + status_code=response.status_code, + ) + + # Return the response + return Response( + content=response.content, + status_code=response.status_code, + headers=dict(response.headers), + ) + + except httpx.TimeoutException as e: + duration_ms = (time.time() - start_time) * 1000 + log_request( + request.method, + f"/session/{session_id}/{path}", + 504, + duration_ms, + session_id=session_id, + error="timeout", + ) + log_security_event( + "proxy_timeout", + "warning", + session_id=session_id, + method=request.method, + path=path, + error=str(e), + ) + raise HTTPException( + status_code=504, detail="Request to session container timed out" + ) + except httpx.RequestError as e: + duration_ms = (time.time() - start_time) * 1000 + log_request( + request.method, + f"/session/{session_id}/{path}", + 502, + duration_ms, + session_id=session_id, + 
error=str(e), + ) + log_security_event( + "proxy_connection_error", + "error", + session_id=session_id, + method=request.method, + path=path, + error=str(e), + ) + raise HTTPException( + status_code=502, + detail=f"Failed to connect to session container: {str(e)}", + ) @app.get("/health") async def health_check(): - """Health check endpoint""" + """Health check endpoint with comprehensive resource monitoring""" docker_ok = False + host_ip_ok = False + detected_host_ip = None + resource_status = {} + http_pool_stats = {} + try: # Check Docker connectivity - session_manager.docker_client.ping() - docker_ok = True - except: + docker_ok = await session_manager.docker_service.ping() + except Exception as e: + logger.warning(f"Docker health check failed: {e}") docker_ok = False - return { - "status": "healthy" if docker_ok else "unhealthy", + try: + # Check host IP detection + detected_host_ip = await async_get_host_ip() + host_ip_ok = True + except Exception as e: + logger.warning(f"Host IP detection failed: {e}") + host_ip_ok = False + + try: + # Check system resource status + resource_status = check_system_resources() + except Exception as e: + logger.warning("Resource monitoring failed", extra={"error": str(e)}) + resource_status = {"error": str(e)} + + try: + # Get HTTP connection pool statistics + http_pool_stats = await get_connection_pool_stats() + except Exception as e: + logger.warning("HTTP pool stats failed", extra={"error": str(e)}) + http_pool_stats = {"error": str(e)} + + # Check database status if using database storage + database_status = {} + if USE_DATABASE_STORAGE: + try: + database_status = await get_database_stats() + except Exception as e: + logger.warning("Database stats failed", extra={"error": str(e)}) + database_status = {"status": "error", "error": str(e)} + + # Get container health statistics + container_health_stats = {} + try: + container_health_stats = get_container_health_stats() + except Exception as e: + logger.warning("Container health 
stats failed", extra={"error": str(e)}) + container_health_stats = {"error": str(e)} + + # Determine overall health status + resource_alerts = ( + resource_status.get("alerts", []) if isinstance(resource_status, dict) else [] + ) + critical_alerts = [ + a + for a in resource_alerts + if isinstance(a, dict) and a.get("level") == "critical" + ] + + # Check HTTP pool health + http_healthy = ( + http_pool_stats.get("status") == "healthy" + if isinstance(http_pool_stats, dict) + else False + ) + + if critical_alerts or not (docker_ok and host_ip_ok and http_healthy): + status = "unhealthy" + elif resource_alerts: + status = "degraded" + else: + status = "healthy" + + health_data = { + "status": status, "docker": docker_ok, + "docker_mode": "async" if USE_ASYNC_DOCKER else "sync", + "host_ip_detection": host_ip_ok, + "detected_host_ip": detected_host_ip, + "http_connection_pool": http_pool_stats, + "storage_backend": "database" if USE_DATABASE_STORAGE else "json_file", "active_sessions": len( [s for s in session_manager.sessions.values() if s.status == "running"] ), + "resource_limits": { + "memory_limit": CONTAINER_MEMORY_LIMIT, + "cpu_quota": CONTAINER_CPU_QUOTA, + "cpu_period": CONTAINER_CPU_PERIOD, + "max_concurrent_sessions": MAX_CONCURRENT_SESSIONS, + }, + "system_resources": resource_status.get("system_resources", {}) + if isinstance(resource_status, dict) + else {}, + "resource_alerts": resource_alerts, "timestamp": datetime.now().isoformat(), } + # Add database information if using database storage + if USE_DATABASE_STORAGE and database_status: + health_data["database"] = database_status + + # Add container health information + if container_health_stats: + health_data["container_health"] = container_health_stats + + return health_data + if __name__ == "__main__": uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) diff --git a/session-manager/requirements.txt b/session-manager/requirements.txt index 1589966..acac2ea 100644 --- 
a/session-manager/requirements.txt
+++ b/session-manager/requirements.txt
@@ -1,6 +1,9 @@
 fastapi==0.104.1
 uvicorn==0.24.0
 docker>=7.1.0
+aiodocker>=0.21.0
+asyncpg>=0.29.0
 pydantic==2.5.0
 python-multipart==0.0.6
-httpx==0.25.2
\ No newline at end of file
+httpx==0.25.2
+psutil>=5.9.0
\ No newline at end of file
diff --git a/session-manager/resource_manager.py b/session-manager/resource_manager.py
new file mode 100644
index 0000000..2f9b950
--- /dev/null
+++ b/session-manager/resource_manager.py
@@ -0,0 +1,248 @@
+"""
+Resource Management and Monitoring Utilities
+
+Provides validation, enforcement, and monitoring of container resource limits
+to prevent resource exhaustion attacks and ensure fair resource allocation.
+"""
+
+import os
+import psutil
+import logging
+from typing import Dict, Optional, Tuple
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ResourceLimits:
+    """Container resource limits configuration."""
+
+    memory_limit: str  # e.g., "4g", "512m"
+    cpu_quota: int  # CPU quota in microseconds
+    cpu_period: int  # CPU period in microseconds
+
+    def validate(self) -> Tuple[bool, str]:
+        """Validate resource limits configuration."""
+        # Validate memory limit format
+        memory_limit = self.memory_limit.lower()
+        if not (memory_limit.endswith(("g", "m", "k")) or memory_limit.isdigit()):
+            return (
+                False,
+                f"Invalid memory limit format: {self.memory_limit}. Use format like '4g', '512m', '256k'",
+            )
+
+        # Validate CPU quota and period
+        if self.cpu_quota <= 0:
+            return False, f"CPU quota must be positive, got {self.cpu_quota}"
+        if self.cpu_period <= 0:
+            return False, f"CPU period must be positive, got {self.cpu_period}"
+        if self.cpu_quota > self.cpu_period:
+            return (
+                False,
+                f"CPU quota ({self.cpu_quota}) cannot exceed CPU period ({self.cpu_period})",
+            )
+
+        return True, "Valid"
+
+    def to_docker_limits(self) -> Dict[str, any]:
+        """Convert to Docker container limits format."""
+        return {
+            "mem_limit": self.memory_limit,
+            "cpu_quota": self.cpu_quota,
+            "cpu_period": self.cpu_period,
+        }
+
+
+class ResourceMonitor:
+    """Monitor system and container resource usage."""
+
+    def __init__(self):
+        self._last_check = datetime.now()
+        self._alerts_sent = set()  # Track alerts to prevent spam
+
+    def get_system_resources(self) -> Dict[str, any]:
+        """Get current system resource usage."""
+        try:
+            memory = psutil.virtual_memory()
+            cpu = psutil.cpu_percent(interval=1)
+
+            return {
+                "memory_percent": memory.percent / 100.0,
+                "memory_used_gb": memory.used / (1024**3),
+                "memory_total_gb": memory.total / (1024**3),
+                "cpu_percent": cpu / 100.0,
+                "cpu_count": psutil.cpu_count(),
+            }
+        except Exception as e:
+            logger.warning(f"Failed to get system resources: {e}")
+            return {}
+
+    def check_resource_limits(
+        self, limits: ResourceLimits, warning_thresholds: Dict[str, float]
+    ) -> Dict[str, any]:
+        """Check if system resources are approaching limits."""
+        system_resources = self.get_system_resources()
+        alerts = []
+
+        # Check memory usage
+        memory_usage = system_resources.get("memory_percent", 0)
+        memory_threshold = warning_thresholds.get("memory", 0.8)
+
+        if memory_usage >= memory_threshold:
+            alerts.append(
+                {
+                    "type": "memory",
+                    "level": "warning" if memory_usage < 0.95 else "critical",
+                    "message": f"System memory usage at {memory_usage:.1%}",
+                    "current": memory_usage,
+                    "threshold": memory_threshold,
+                }
+            )
+ + # Check CPU usage + cpu_usage = system_resources.get("cpu_percent", 0) + cpu_threshold = warning_thresholds.get("cpu", 0.9) + + if cpu_usage >= cpu_threshold: + alerts.append( + { + "type": "cpu", + "level": "warning" if cpu_usage < 0.95 else "critical", + "message": f"System CPU usage at {cpu_usage:.1%}", + "current": cpu_usage, + "threshold": cpu_threshold, + } + ) + + return { + "system_resources": system_resources, + "alerts": alerts, + "timestamp": datetime.now(), + } + + def should_throttle_sessions(self, resource_check: Dict) -> Tuple[bool, str]: + """Determine if new sessions should be throttled based on resource usage.""" + alerts = resource_check.get("alerts", []) + + # Critical alerts always throttle + critical_alerts = [a for a in alerts if a["level"] == "critical"] + if critical_alerts: + return ( + True, + f"Critical resource usage: {[a['message'] for a in critical_alerts]}", + ) + + # Multiple warnings also throttle + warning_alerts = [a for a in alerts if a["level"] == "warning"] + if len(warning_alerts) >= 2: + return ( + True, + f"Multiple resource warnings: {[a['message'] for a in warning_alerts]}", + ) + + return False, "Resources OK" + + +class ResourceValidator: + """Validate and parse resource limit configurations.""" + + @staticmethod + def parse_memory_limit(memory_str: str) -> Tuple[int, str]: + """Parse memory limit string and return bytes.""" + if not memory_str: + raise ValueError("Memory limit cannot be empty") + + memory_str = memory_str.lower().strip() + + # Handle different units + if memory_str.endswith("g"): + bytes_val = int(memory_str[:-1]) * (1024**3) + unit = "GB" + elif memory_str.endswith("m"): + bytes_val = int(memory_str[:-1]) * (1024**2) + unit = "MB" + elif memory_str.endswith("k"): + bytes_val = int(memory_str[:-1]) * 1024 + unit = "KB" + else: + # Assume bytes if no unit + bytes_val = int(memory_str) + unit = "bytes" + + if bytes_val <= 0: + raise ValueError(f"Memory limit must be positive, got {bytes_val}") + + # 
Reasonable limits check + if bytes_val > 32 * (1024**3): # 32GB + logger.warning(f"Very high memory limit: {bytes_val} bytes") + + return bytes_val, unit + + @staticmethod + def validate_resource_config( + config: Dict[str, any], + ) -> Tuple[bool, str, Optional[ResourceLimits]]: + """Validate complete resource configuration.""" + try: + limits = ResourceLimits( + memory_limit=config.get("memory_limit", "4g"), + cpu_quota=config.get("cpu_quota", 100000), + cpu_period=config.get("cpu_period", 100000), + ) + + valid, message = limits.validate() + if not valid: + return False, message, None + + # Additional validation + memory_bytes, _ = ResourceValidator.parse_memory_limit(limits.memory_limit) + + # Warn about potentially problematic configurations + if memory_bytes < 128 * (1024**2): # Less than 128MB + logger.warning("Very low memory limit may cause container instability") + + return True, "Configuration valid", limits + + except (ValueError, TypeError) as e: + return False, f"Invalid configuration: {e}", None + + +# Global instances +resource_monitor = ResourceMonitor() + + +def get_resource_limits() -> ResourceLimits: + """Get validated resource limits from environment.""" + config = { + "memory_limit": os.getenv("CONTAINER_MEMORY_LIMIT", "4g"), + "cpu_quota": int(os.getenv("CONTAINER_CPU_QUOTA", "100000")), + "cpu_period": int(os.getenv("CONTAINER_CPU_PERIOD", "100000")), + } + + valid, message, limits = ResourceValidator.validate_resource_config(config) + if not valid or limits is None: + raise ValueError(f"Resource configuration error: {message}") + + logger.info( + f"Using resource limits: memory={limits.memory_limit}, cpu_quota={limits.cpu_quota}" + ) + return limits + + +def check_system_resources() -> Dict[str, any]: + """Check current system resource status.""" + limits = get_resource_limits() + warning_thresholds = { + "memory": float(os.getenv("MEMORY_WARNING_THRESHOLD", "0.8")), + "cpu": float(os.getenv("CPU_WARNING_THRESHOLD", "0.9")), + } + + return 
resource_monitor.check_resource_limits(limits, warning_thresholds) + + +def should_throttle_sessions() -> Tuple[bool, str]: + """Check if new sessions should be throttled due to resource constraints.""" + resource_check = check_system_resources() + return resource_monitor.should_throttle_sessions(resource_check) diff --git a/session-manager/session_auth.py b/session-manager/session_auth.py new file mode 100644 index 0000000..83f2af3 --- /dev/null +++ b/session-manager/session_auth.py @@ -0,0 +1,235 @@ +""" +Token-Based Authentication for OpenCode Sessions + +Provides secure token generation, validation, and management for individual +user sessions to prevent unauthorized access to OpenCode servers. +""" + +import os +import uuid +import secrets +import hashlib +import hmac +from typing import Dict, Optional, Tuple +from datetime import datetime, timedelta +import logging + +logger = logging.getLogger(__name__) + + +class SessionTokenManager: + """Manages authentication tokens for OpenCode user sessions.""" + + def __init__(self): + # Token storage - in production, this should be in Redis/database + self._session_tokens: Dict[str, Dict] = {} + + # Token configuration + self._token_length = int(os.getenv("SESSION_TOKEN_LENGTH", "32")) + self._token_expiry_hours = int(os.getenv("SESSION_TOKEN_EXPIRY_HOURS", "24")) + self._token_secret = os.getenv("SESSION_TOKEN_SECRET", self._generate_secret()) + + # Cleanup configuration + self._cleanup_interval_minutes = int( + os.getenv("TOKEN_CLEANUP_INTERVAL_MINUTES", "60") + ) + + def _generate_secret(self) -> str: + """Generate a secure secret for token signing.""" + return secrets.token_hex(32) + + def generate_session_token(self, session_id: str) -> str: + """ + Generate a unique authentication token for a session. 
+ + Args: + session_id: The session identifier + + Returns: + str: The authentication token + """ + # Generate cryptographically secure random token + token = secrets.token_urlsafe(self._token_length) + + # Create token data with expiry + expiry = datetime.now() + timedelta(hours=self._token_expiry_hours) + + # Store token information + self._session_tokens[session_id] = { + "token": token, + "session_id": session_id, + "created_at": datetime.now(), + "expires_at": expiry, + "last_used": datetime.now(), + } + + logger.info(f"Generated authentication token for session {session_id}") + return token + + def validate_session_token(self, session_id: str, token: str) -> Tuple[bool, str]: + """ + Validate a session token. + + Args: + session_id: The session identifier + token: The token to validate + + Returns: + Tuple[bool, str]: (is_valid, reason) + """ + # Check if session exists + if session_id not in self._session_tokens: + return False, "Session not found" + + session_data = self._session_tokens[session_id] + + # Check if token matches + if session_data["token"] != token: + return False, "Invalid token" + + # Check if token has expired + if datetime.now() > session_data["expires_at"]: + # Clean up expired token + del self._session_tokens[session_id] + return False, "Token expired" + + # Update last used time + session_data["last_used"] = datetime.now() + + return True, "Valid" + + def revoke_session_token(self, session_id: str) -> bool: + """ + Revoke a session token. + + Args: + session_id: The session identifier + + Returns: + bool: True if token was revoked, False if not found + """ + if session_id in self._session_tokens: + del self._session_tokens[session_id] + logger.info(f"Revoked authentication token for session {session_id}") + return True + return False + + def rotate_session_token(self, session_id: str) -> Optional[str]: + """ + Rotate (regenerate) a session token. 
+ + Args: + session_id: The session identifier + + Returns: + Optional[str]: New token if session exists, None otherwise + """ + if session_id not in self._session_tokens: + return None + + # Generate new token + new_token = self.generate_session_token(session_id) + + logger.info(f"Rotated authentication token for session {session_id}") + return new_token + + def cleanup_expired_tokens(self) -> int: + """ + Clean up expired tokens. + + Returns: + int: Number of tokens cleaned up + """ + now = datetime.now() + expired_sessions = [] + + for session_id, session_data in self._session_tokens.items(): + if now > session_data["expires_at"]: + expired_sessions.append(session_id) + + # Remove expired tokens + for session_id in expired_sessions: + del self._session_tokens[session_id] + + if expired_sessions: + logger.info( + f"Cleaned up {len(expired_sessions)} expired authentication tokens" + ) + + return len(expired_sessions) + + def get_session_token_info(self, session_id: str) -> Optional[Dict]: + """ + Get information about a session token. 
+
+        Args:
+            session_id: The session identifier
+
+        Returns:
+            Optional[Dict]: Token information or None if not found
+        """
+        if session_id not in self._session_tokens:
+            return None
+
+        session_data = self._session_tokens[session_id].copy()
+        # Remove sensitive token value
+        session_data.pop("token", None)
+        return session_data
+
+    def get_active_sessions_count(self) -> int:
+        """Get the number of active sessions with tokens."""
+        return len(self._session_tokens)
+
+    def list_active_sessions(self) -> Dict[str, Dict]:
+        """List all active sessions with token information (without token values)."""
+        result = {}
+        for session_id, session_data in self._session_tokens.items():
+            # Create copy without sensitive token
+            info = session_data.copy()
+            info.pop("token", None)
+            result[session_id] = info
+        return result
+
+
+# Global token manager instance
+_session_token_manager = SessionTokenManager()
+
+
+def generate_session_auth_token(session_id: str) -> str:
+    """Generate an authentication token for a session."""
+    return _session_token_manager.generate_session_token(session_id)
+
+
+def validate_session_auth_token(session_id: str, token: str) -> Tuple[bool, str]:
+    """Validate a session authentication token."""
+    return _session_token_manager.validate_session_token(session_id, token)
+
+
+def revoke_session_auth_token(session_id: str) -> bool:
+    """Revoke a session authentication token."""
+    return _session_token_manager.revoke_session_token(session_id)
+
+
+def rotate_session_auth_token(session_id: str) -> Optional[str]:
+    """Rotate a session authentication token."""
+    return _session_token_manager.rotate_session_token(session_id)
+
+
+def cleanup_expired_auth_tokens() -> int:
+    """Clean up expired authentication tokens."""
+    return _session_token_manager.cleanup_expired_tokens()
+
+
+def get_session_auth_info(session_id: str) -> Optional[Dict]:
+    """Get authentication information for a session."""
+    return _session_token_manager.get_session_token_info(session_id)
+ + +def get_active_auth_sessions_count() -> int: + """Get the number of active authenticated sessions.""" + return _session_token_manager.get_active_sessions_count() + + +def list_active_auth_sessions() -> Dict[str, Dict]: + """List all active authenticated sessions.""" + return _session_token_manager.list_active_sessions()