fixed all remaining issues with the session manager

This commit is contained in:
2026-01-18 23:28:49 +01:00
parent 0243cfc250
commit 2f5464e1d2
11 changed files with 4040 additions and 101 deletions

View File

@@ -0,0 +1,302 @@
"""
Async Docker Operations Wrapper
Provides async wrappers for Docker operations to eliminate blocking calls
in FastAPI async contexts and improve concurrency and scalability.
"""
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

# BUG FIX: package name was misspelled "aiodeocker"; the real package is aiodocker.
from aiodocker import Docker
from aiodocker.containers import DockerContainer
from aiodocker.exceptions import DockerError
logger = logging.getLogger(__name__)  # module-level logger (standard per-module convention)
class AsyncDockerClient:
    """Async wrapper for Docker operations using aiodocker.

    Supports use as an async context manager (connect on enter, disconnect
    on exit). connect() is idempotent, so a long-lived shared instance may
    also be connected once and reused.
    """

    def __init__(self):
        self._docker: Optional[Docker] = None  # low-level aiodocker client, set by connect()
        self._connected = False

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.disconnect()

    async def connect(self):
        """Connect to the Docker daemon and verify the connection with a ping.

        Idempotent: returns immediately when already connected.

        Raises:
            Exception: whatever the underlying client raises when the daemon
                is unreachable (logged before re-raising).
        """
        if self._connected:
            return
        try:
            # Configure TLS if requested via the standard DOCKER_TLS_VERIFY env var.
            tls_config = None
            if os.getenv("DOCKER_TLS_VERIFY") == "1":
                # BUG FIX: package name was misspelled "aiodeocker".
                # TODO(review): confirm aiodocker.utils exposes create_tls_config
                # in the pinned version.
                from aiodocker.utils import create_tls_config

                tls_config = create_tls_config(
                    ca_cert=os.getenv("DOCKER_CA_CERT", "/etc/docker/certs/ca.pem"),
                    client_cert=(
                        os.getenv(
                            "DOCKER_CLIENT_CERT", "/etc/docker/certs/client-cert.pem"
                        ),
                        os.getenv(
                            "DOCKER_CLIENT_KEY", "/etc/docker/certs/client-key.pem"
                        ),
                    ),
                    verify=True,
                )
            docker_host = os.getenv("DOCKER_HOST", "tcp://host.docker.internal:2376")
            self._docker = Docker(docker_host, tls=tls_config)
            # Ping to fail fast if the daemon is unreachable.
            await self._docker.ping()
            self._connected = True
            logger.info("Async Docker client connected successfully")
        except Exception as e:
            logger.error(f"Failed to connect to Docker: {e}")
            raise

    async def disconnect(self):
        """Disconnect from the Docker daemon (no-op when not connected)."""
        if self._docker and self._connected:
            await self._docker.close()
            self._connected = False
            logger.info("Async Docker client disconnected")

    async def ping(self) -> bool:
        """Return True when the daemon answers a ping, False otherwise (never raises)."""
        if not self._docker:
            return False
        try:
            await self._docker.ping()
            return True
        except Exception:
            return False

    async def create_container(
        self,
        image: str,
        name: str,
        volumes: Optional[Dict[str, Dict[str, str]]] = None,
        ports: Optional[Dict[str, int]] = None,
        environment: Optional[Dict[str, str]] = None,
        network_mode: str = "bridge",
        mem_limit: Optional[str] = None,
        cpu_quota: Optional[int] = None,
        cpu_period: Optional[int] = None,
        tmpfs: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> DockerContainer:
        """
        Create a Docker container asynchronously.

        Args:
            image: Container image name
            name: Container name
            volumes: Volume mounts, host path -> {"bind": container_path, "mode": "rw"|"ro"}
            ports: Port mappings, container port -> host port
            environment: Environment variables
            network_mode: Network mode
            mem_limit: Memory limit (e.g. "4g", "512m", "4gb")
            cpu_quota: CPU quota
            cpu_period: CPU period
            tmpfs: tmpfs mounts
            **kwargs: Additional configuration; kwargs["host_config"] is merged
                into HostConfig verbatim.

        Returns:
            DockerContainer: The created (not yet started) container

        Raises:
            RuntimeError: when connect() has not been called.
            DockerError: when the Docker API rejects the request.
        """
        if not self._docker:
            raise RuntimeError("Docker client not connected")
        config = {
            "Image": image,
            "name": name,
            "Volumes": volumes or {},
            # BUG FIX: expose the *container* ports (the dict keys). The
            # original exposed the host ports (values), which did not match
            # the PortBindings below.
            "ExposedPorts": {f"{port}/tcp": {} for port in ports.keys()}
            if ports
            else {},
            "Env": [f"{k}={v}" for k, v in (environment or {}).items()],
            "NetworkMode": network_mode,
            "HostConfig": {
                "Binds": [
                    f"{host}:{container['bind']}:{container.get('mode', 'rw')}"
                    for host, container in (volumes or {}).items()
                ],
                "PortBindings": {
                    f"{container_port}/tcp": [{"HostPort": str(host_port)}]
                    for container_port, host_port in (ports or {}).items()
                },
                "Tmpfs": tmpfs or {},
            },
        }
        # Apply optional resource limits.
        host_config = config["HostConfig"]
        if mem_limit:
            host_config["Memory"] = self._parse_memory_limit(mem_limit)
        if cpu_quota is not None:
            host_config["CpuQuota"] = cpu_quota
        if cpu_period is not None:
            host_config["CpuPeriod"] = cpu_period
        # Caller-supplied HostConfig entries win over the defaults above.
        host_config.update(kwargs.get("host_config", {}))
        try:
            container = await self._docker.containers.create(config)
            logger.info(f"Container {name} created successfully")
            return container
        except DockerError as e:
            logger.error(f"Failed to create container {name}: {e}")
            raise

    async def start_container(self, container: DockerContainer) -> None:
        """Start a Docker container; logs and re-raises DockerError."""
        try:
            await container.start()
            logger.info(f"Container {container.id} started successfully")
        except DockerError as e:
            logger.error(f"Failed to start container {container.id}: {e}")
            raise

    async def stop_container(
        self, container: DockerContainer, timeout: int = 10
    ) -> None:
        """Stop a Docker container, forwarding *timeout* (seconds) to the API."""
        try:
            await container.stop(timeout=timeout)
            logger.info(f"Container {container.id} stopped successfully")
        except DockerError as e:
            logger.error(f"Failed to stop container {container.id}: {e}")
            raise

    async def remove_container(
        self, container: DockerContainer, force: bool = False
    ) -> None:
        """Remove a Docker container; *force* is forwarded to the delete call."""
        try:
            await container.delete(force=force)
            logger.info(f"Container {container.id} removed successfully")
        except DockerError as e:
            logger.error(f"Failed to remove container {container.id}: {e}")
            raise

    async def get_container(self, container_id: str) -> Optional[DockerContainer]:
        """Get a container by ID or name; None when missing or not connected."""
        if not self._docker:
            # BUG FIX: previously raised AttributeError on a disconnected
            # client; None matches the Optional contract.
            return None
        try:
            return await self._docker.containers.get(container_id)
        except DockerError:
            return None

    async def list_containers(
        self, all: bool = False, filters: Optional[Dict[str, Any]] = None
    ) -> List[DockerContainer]:
        """List Docker containers; empty list on error or when not connected."""
        if not self._docker:
            # BUG FIX: previously raised AttributeError on a disconnected client.
            return []
        try:
            return await self._docker.containers.list(all=all, filters=filters)
        except DockerError as e:
            logger.error(f"Failed to list containers: {e}")
            return []

    async def get_container_stats(
        self, container: DockerContainer
    ) -> Optional[Dict[str, Any]]:
        """Get a one-shot stats snapshot for a container (None on error)."""
        try:
            stats = await container.stats(stream=False)
            return stats
        except DockerError as e:
            logger.error(f"Failed to get stats for container {container.id}: {e}")
            return None

    async def get_system_info(self) -> Optional[Dict[str, Any]]:
        """Get Docker system information (None on error or when not connected)."""
        if not self._docker:
            return None
        try:
            return await self._docker.system.info()
        except DockerError as e:
            logger.error(f"Failed to get system info: {e}")
            return None

    def _parse_memory_limit(self, memory_str: str) -> int:
        """Parse a memory limit like "4g", "512m", "4gb" or a plain byte count to bytes.

        Generalized to also accept long-form suffixes ("gb"/"mb"/"kb").

        Raises:
            ValueError: when the numeric part is not an integer.
        """
        memory_str = memory_str.lower().strip()
        # Normalize long-form suffixes to the single-letter form.
        if memory_str.endswith(("gb", "mb", "kb")):
            memory_str = memory_str[:-1]
        if memory_str.endswith("g"):
            return int(memory_str[:-1]) * (1024**3)
        elif memory_str.endswith("m"):
            return int(memory_str[:-1]) * (1024**2)
        elif memory_str.endswith("k"):
            return int(memory_str[:-1]) * 1024
        else:
            return int(memory_str)
# Global async Docker client instance, shared by the module-level wrappers below.
_async_docker_client = AsyncDockerClient()


@asynccontextmanager
async def get_async_docker_client():
    """Yield the shared async Docker client, connecting lazily.

    BUG FIX: the original entered ``async with _async_docker_client``, whose
    __aexit__ disconnects the client — tearing down the shared connection
    under any concurrent task still using it, and forcing a reconnect on
    every wrapper call. The connection is now established once (connect()
    is idempotent) and intentionally left open for reuse.
    """
    await _async_docker_client.connect()
    yield _async_docker_client
async def async_docker_ping() -> bool:
    """Return True when the Docker daemon answers a ping."""
    async with get_async_docker_client() as docker:
        return await docker.ping()
async def async_create_container(**kwargs) -> DockerContainer:
    """Create a container via the shared client; kwargs forwarded to create_container()."""
    async with get_async_docker_client() as docker:
        return await docker.create_container(**kwargs)
async def async_start_container(container: DockerContainer) -> None:
    """Start *container* via the shared async client."""
    async with get_async_docker_client() as docker:
        await docker.start_container(container)
async def async_stop_container(container: DockerContainer, timeout: int = 10) -> None:
    """Stop *container*, forwarding the stop timeout (seconds)."""
    async with get_async_docker_client() as docker:
        await docker.stop_container(container, timeout)
async def async_remove_container(
    container: DockerContainer, force: bool = False
) -> None:
    """Remove *container*; *force* is forwarded to the delete call."""
    async with get_async_docker_client() as docker:
        await docker.remove_container(container, force)
async def async_list_containers(
    all: bool = False, filters: Optional[Dict[str, Any]] = None
) -> List[DockerContainer]:
    """List containers via the shared async client."""
    async with get_async_docker_client() as docker:
        return await docker.list_containers(all=all, filters=filters)
async def async_get_container(container_id: str) -> Optional[DockerContainer]:
    """Look up a container by ID or name; None when it does not exist."""
    async with get_async_docker_client() as docker:
        return await docker.get_container(container_id)

View File

@@ -0,0 +1,574 @@
"""
Container Health Monitoring System
Provides active monitoring of Docker containers with automatic failure detection,
recovery mechanisms, and integration with session management and alerting systems.
"""
import asyncio
import logging
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
from enum import Enum
from logging_config import get_logger, log_performance, log_security_event
logger = get_logger(__name__)  # structured logger from the project's logging_config
class ContainerStatus(Enum):
    """Health states assigned to a monitored container by the health checks."""

    HEALTHY = "healthy"        # running and passing checks
    UNHEALTHY = "unhealthy"    # running but failing its Docker health check
    RESTARTING = "restarting"  # a restart is in progress
    FAILED = "failed"          # not running / not found
    UNKNOWN = "unknown"        # state could not be determined
class HealthCheckResult:
    """Outcome of a single container health probe, timestamped at creation."""

    def __init__(
        self,
        session_id: str,
        container_id: str,
        status: ContainerStatus,
        response_time: Optional[float] = None,
        error_message: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.session_id = session_id
        self.container_id = container_id
        self.status = status
        self.response_time = response_time
        self.error_message = error_message
        self.metadata = {} if metadata is None else metadata
        # Naive UTC timestamp, matching the monitor's utcnow()-based cleanup.
        self.timestamp = datetime.utcnow()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result for structured logging / API responses."""
        return dict(
            session_id=self.session_id,
            container_id=self.container_id,
            status=self.status.value,
            response_time=self.response_time,
            error_message=self.error_message,
            metadata=self.metadata,
            timestamp=self.timestamp.isoformat(),
        )
class ContainerHealthMonitor:
    """Monitors Docker container health and handles automatic recovery.

    A background task polls every "running" session's container each
    check_interval seconds, keeps a short per-session history of results,
    and restarts a container after `failure_threshold` consecutive bad
    checks — giving up (and marking the session failed) once
    `max_restart_attempts` restarts have been spent. Dependencies are
    injected via set_dependencies() rather than imported here.
    """

    def __init__(
        self,
        check_interval: int = 30,  # seconds between monitoring passes
        health_timeout: float = 10.0,  # seconds
        max_restart_attempts: int = 3,
        restart_delay: int = 5,  # seconds between stop and restart
        failure_threshold: int = 3,  # consecutive failures before restart
    ):
        self.check_interval = check_interval
        # NOTE(review): health_timeout is stored and logged but never applied
        # by any check in this class — confirm whether checks should be
        # time-bounded with it.
        self.health_timeout = health_timeout
        self.max_restart_attempts = max_restart_attempts
        self.restart_delay = restart_delay
        self.failure_threshold = failure_threshold
        # Monitoring state
        self._monitoring = False
        self._task: Optional[asyncio.Task] = None
        # session_id -> recent HealthCheckResult objects (capped at 10)
        self._health_history: Dict[str, List[HealthCheckResult]] = {}
        # session_id -> restart attempts; reset to 0 on a healthy check
        self._restart_counts: Dict[str, int] = {}
        # Dependencies (injected via set_dependencies)
        self.session_manager = None
        self.docker_client = None
        logger.info(
            "Container health monitor initialized",
            extra={
                "check_interval": check_interval,
                "health_timeout": health_timeout,
                "max_restart_attempts": max_restart_attempts,
            },
        )

    def set_dependencies(self, session_manager, docker_client):
        """Set dependencies for health monitoring.

        session_manager is expected to expose a `.sessions` mapping of
        session objects (with .session_id, .container_id and .status) and
        optionally a sync `docker_client`; docker_client is the async
        wrapper client (may be None — the sync client is used as fallback).
        """
        self.session_manager = session_manager
        self.docker_client = docker_client

    async def start_monitoring(self):
        """Start the health monitoring loop (no-op if already running)."""
        if self._monitoring:
            logger.warning("Health monitoring already running")
            return
        self._monitoring = True
        self._task = asyncio.create_task(self._monitoring_loop())
        logger.info("Container health monitoring started")

    async def stop_monitoring(self):
        """Stop the health monitoring loop and wait for the task to finish."""
        if not self._monitoring:
            return
        self._monitoring = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass  # expected when the task is cancelled mid-sleep

    async def _monitoring_loop(self):
        """Main monitoring loop.

        Exceptions are logged and swallowed so a single bad pass cannot
        kill the loop; the interval sleep runs even after a failure.
        """
        while self._monitoring:
            try:
                await self._perform_health_checks()
                await self._cleanup_old_history()
            except Exception as e:
                logger.error("Error in health monitoring loop", extra={"error": str(e)})
            await asyncio.sleep(self.check_interval)

    async def _perform_health_checks(self):
        """Perform health checks on all running containers concurrently."""
        if not self.session_manager:
            return  # dependencies not injected yet
        # Get all running sessions
        running_sessions = [
            session
            for session in self.session_manager.sessions.values()
            if session.status == "running"
        ]
        if not running_sessions:
            return
        logger.debug(f"Checking health of {len(running_sessions)} running containers")
        # Perform health checks concurrently
        tasks = [self._check_container_health(session) for session in running_sessions]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        # Process results; exceptions surfaced by gather are logged, not raised
        for i, result in enumerate(results):
            session = running_sessions[i]
            if isinstance(result, Exception):
                logger.error(
                    "Health check failed",
                    extra={
                        "session_id": session.session_id,
                        "container_id": session.container_id,
                        "error": str(result),
                    },
                )
                continue
            await self._process_health_result(result)

    async def _check_container_health(self, session) -> HealthCheckResult:
        """Check the health of a single container.

        Never raises: failures are folded into the returned result's status
        and error_message. response_time is reported in milliseconds.
        """
        start_time = asyncio.get_event_loop().time()
        try:
            # Check if container exists and is running
            if not session.container_id:
                return HealthCheckResult(
                    session.session_id,
                    session.container_id or "unknown",
                    ContainerStatus.UNKNOWN,
                    error_message="No container ID",
                )
            # Get container status
            container_info = await self._get_container_info(session.container_id)
            if not container_info:
                return HealthCheckResult(
                    session.session_id,
                    session.container_id,
                    ContainerStatus.FAILED,
                    error_message="Container not found",
                )
            # Check container state
            state = container_info.get("State", {})
            status = state.get("Status", "unknown")
            if status != "running":
                return HealthCheckResult(
                    session.session_id,
                    session.container_id,
                    ContainerStatus.FAILED,
                    error_message=f"Container status: {status}",
                )
            # Check Docker's own health check result if one is configured
            health = state.get("Health", {})
            if health:
                health_status = health.get("Status", "unknown")
                if health_status == "healthy":
                    response_time = (
                        asyncio.get_event_loop().time() - start_time
                    ) * 1000
                    return HealthCheckResult(
                        session.session_id,
                        session.container_id,
                        ContainerStatus.HEALTHY,
                        response_time=response_time,
                        metadata={
                            "docker_status": status,
                            "health_status": health_status,
                        },
                    )
                elif health_status in ["unhealthy", "starting"]:
                    # "starting" is treated as unhealthy, so it counts toward
                    # the restart failure threshold.
                    return HealthCheckResult(
                        session.session_id,
                        session.container_id,
                        ContainerStatus.UNHEALTHY,
                        error_message=f"Health check: {health_status}",
                        metadata={
                            "docker_status": status,
                            "health_status": health_status,
                        },
                    )
            # If no health check configured, consider running containers healthy
            response_time = (asyncio.get_event_loop().time() - start_time) * 1000
            return HealthCheckResult(
                session.session_id,
                session.container_id,
                ContainerStatus.HEALTHY,
                response_time=response_time,
                metadata={"docker_status": status},
            )
        except Exception as e:
            response_time = (asyncio.get_event_loop().time() - start_time) * 1000
            return HealthCheckResult(
                session.session_id,
                session.container_id or "unknown",
                ContainerStatus.UNKNOWN,
                response_time=response_time,
                error_message=str(e),
            )

    async def _get_container_info(self, container_id: str) -> Optional[Dict[str, Any]]:
        """Get container information (inspect payload) from Docker.

        Prefers the injected async client, falling back to the session
        manager's sync client. Returns None when the container cannot be
        found or no client is available.
        """
        try:
            if self.docker_client:
                # Try async Docker client first
                container = await self.docker_client.get_container(container_id)
                # HACK: reaches into the wrapper's private `_container`
                # attribute when present — fragile across client versions.
                if hasattr(container, "_container"):
                    return await container._container.show()
                elif hasattr(container, "show"):
                    return await container.show()
            else:
                # Fallback to sync client if available
                if (
                    hasattr(self.session_manager, "docker_client")
                    and self.session_manager.docker_client
                ):
                    container = self.session_manager.docker_client.containers.get(
                        container_id
                    )
                    return container.attrs
        except Exception as e:
            logger.debug(
                f"Failed to get container info for {container_id}",
                extra={"error": str(e)},
            )
        return None

    async def _process_health_result(self, result: HealthCheckResult):
        """Record a health check result, log it, and trigger restarts if needed."""
        # Store result in history
        if result.session_id not in self._health_history:
            self._health_history[result.session_id] = []
        self._health_history[result.session_id].append(result)
        # Keep only recent history (last 10 checks) to bound memory
        if len(self._health_history[result.session_id]) > 10:
            self._health_history[result.session_id] = self._health_history[
                result.session_id
            ][-10:]
        # Log at a severity matching the result
        log_extra = result.to_dict()
        if result.status == ContainerStatus.HEALTHY:
            logger.debug("Container health check passed", extra=log_extra)
        elif result.status == ContainerStatus.UNHEALTHY:
            logger.warning("Container health check failed", extra=log_extra)
        elif result.status in [ContainerStatus.FAILED, ContainerStatus.UNKNOWN]:
            logger.error("Container health check critical", extra=log_extra)
        # Check if restart is needed
        await self._check_restart_needed(result)

    async def _check_restart_needed(self, result: HealthCheckResult):
        """Restart the container once `failure_threshold` recent checks failed."""
        if result.status == ContainerStatus.HEALTHY:
            # Reset restart count on successful health check
            if result.session_id in self._restart_counts:
                self._restart_counts[result.session_id] = 0
            return
        # Count failures within the last `failure_threshold` results
        recent_results = self._health_history.get(result.session_id, [])
        recent_failures = sum(
            1
            for r in recent_results[-self.failure_threshold :]
            if r.status
            in [
                ContainerStatus.UNHEALTHY,
                ContainerStatus.FAILED,
                ContainerStatus.UNKNOWN,
            ]
        )
        if recent_failures >= self.failure_threshold:
            await self._restart_container(result.session_id, result.container_id)

    async def _restart_container(self, session_id: str, container_id: str):
        """Restart a failed container, respecting the restart-attempt cap."""
        # Check restart limit
        restart_count = self._restart_counts.get(session_id, 0)
        if restart_count >= self.max_restart_attempts:
            logger.error(
                "Container restart limit exceeded",
                extra={
                    "session_id": session_id,
                    "container_id": container_id,
                    "restart_attempts": restart_count,
                },
            )
            # Mark session as permanently failed — no more restarts attempted
            await self._mark_session_failed(
                session_id, f"Restart limit exceeded ({restart_count} attempts)"
            )
            return
        logger.info(
            "Attempting container restart",
            extra={
                "session_id": session_id,
                "container_id": container_id,
                "restart_attempt": restart_count + 1,
            },
        )
        try:
            # Stop the container (best-effort; failures are logged inside)
            await self._stop_container(container_id)
            # Wait before restart
            await asyncio.sleep(self.restart_delay)
            # Start new container for the session
            session = await self.session_manager.get_session(session_id)
            if session:
                # Update restart count before acting, so a crash mid-restart
                # still consumes an attempt
                self._restart_counts[session_id] = restart_count + 1
                # Mark as restarting
                await self._update_session_status(session_id, "restarting")
                # Trigger container restart through session manager
                if self.session_manager:
                    # NOTE(review): create_session() is called with no
                    # arguments — this looks like it creates a brand-new
                    # session rather than recreating the container for
                    # `session_id`; confirm intended behavior.
                    await self.session_manager.create_session()
                logger.info(
                    "Container restart initiated",
                    extra={
                        "session_id": session_id,
                        "restart_attempt": restart_count + 1,
                    },
                )
                # Log security event for the audit trail
                log_security_event(
                    "container_restart",
                    "warning",
                    {
                        "session_id": session_id,
                        "container_id": container_id,
                        "reason": "health_check_failure",
                    },
                )
        except Exception as e:
            logger.error(
                "Container restart failed",
                extra={
                    "session_id": session_id,
                    "container_id": container_id,
                    "error": str(e),
                },
            )

    async def _stop_container(self, container_id: str):
        """Stop a container via the async client or the sync fallback (best-effort)."""
        try:
            if self.docker_client:
                container = await self.docker_client.get_container(container_id)
                await self.docker_client.stop_container(container, timeout=10)
            elif (
                hasattr(self.session_manager, "docker_client")
                and self.session_manager.docker_client
            ):
                container = self.session_manager.docker_client.containers.get(
                    container_id
                )
                container.stop(timeout=10)
        except Exception as e:
            # Non-fatal: restart proceeds even if the stop failed
            logger.warning(
                "Failed to stop container during restart",
                extra={"container_id": container_id, "error": str(e)},
            )

    async def _update_session_status(self, session_id: str, status: str):
        """Update session status in memory and, when enabled, in the database."""
        if self.session_manager:
            session = self.session_manager.sessions.get(session_id)
            if session:
                session.status = status
                # Update in database if using database storage
                if (
                    hasattr(self.session_manager, "USE_DATABASE_STORAGE")
                    and self.session_manager.USE_DATABASE_STORAGE
                ):
                    try:
                        # Imported lazily to avoid a hard dependency when
                        # database storage is disabled
                        from database import SessionModel

                        await SessionModel.update_session(
                            session_id, {"status": status}
                        )
                    except Exception as e:
                        logger.warning(
                            "Failed to update session status in database",
                            extra={"session_id": session_id, "error": str(e)},
                        )

    async def _mark_session_failed(self, session_id: str, reason: str):
        """Mark a session as permanently failed and record a security event."""
        await self._update_session_status(session_id, "failed")
        logger.error(
            "Session marked as failed",
            extra={"session_id": session_id, "reason": reason},
        )
        # Log security event
        log_security_event(
            "session_failure", "error", {"session_id": session_id, "reason": reason}
        )

    async def _cleanup_old_history(self):
        """Drop health results older than one hour to bound memory usage."""
        cutoff_time = datetime.utcnow() - timedelta(hours=1)  # Keep last hour
        for session_id in list(self._health_history.keys()):
            # Remove old results
            self._health_history[session_id] = [
                result
                for result in self._health_history[session_id]
                if result.timestamp > cutoff_time
            ]
            # Remove empty histories
            if not self._health_history[session_id]:
                del self._health_history[session_id]

    def get_health_stats(self, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Get health monitoring statistics.

        When session_id is given and known, a "session_<id>" key is added
        with per-session detail over the last 10 checks.
        """
        stats = {
            "monitoring_active": self._monitoring,
            "check_interval": self.check_interval,
            "total_sessions_monitored": len(self._health_history),
            "sessions_with_failures": len(
                [
                    sid
                    for sid, history in self._health_history.items()
                    if any(
                        r.status != ContainerStatus.HEALTHY for r in history[-5:]
                    )  # Last 5 checks
                ]
            ),
            "restart_counts": dict(self._restart_counts),
        }
        if session_id and session_id in self._health_history:
            recent_results = self._health_history[session_id][-10:]  # Last 10 checks
            stats[f"session_{session_id}"] = {
                "total_checks": len(recent_results),
                "healthy_checks": sum(
                    1 for r in recent_results if r.status == ContainerStatus.HEALTHY
                ),
                "failed_checks": sum(
                    1 for r in recent_results if r.status != ContainerStatus.HEALTHY
                ),
                # max(1, ...) guards against division by zero when no check
                # recorded a response time
                "average_response_time": sum(
                    r.response_time or 0 for r in recent_results if r.response_time
                )
                / max(1, sum(1 for r in recent_results if r.response_time)),
                "last_check": recent_results[-1].to_dict() if recent_results else None,
            }
        return stats

    def get_health_history(
        self, session_id: str, limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Get up to *limit* most-recent health results (oldest first)."""
        if session_id not in self._health_history:
            return []
        return [
            result.to_dict() for result in self._health_history[session_id][-limit:]
        ]
# Global health monitor instance shared by the module-level helpers below.
_container_health_monitor = ContainerHealthMonitor()


def get_container_health_monitor() -> ContainerHealthMonitor:
    """Return the process-wide ContainerHealthMonitor singleton."""
    return _container_health_monitor
async def start_container_health_monitoring(session_manager=None, docker_client=None):
    """Start container health monitoring, wiring in any provided dependencies.

    BUG FIX: the original injected dependencies only when session_manager
    was provided, silently dropping a docker_client passed on its own.
    """
    monitor = get_container_health_monitor()
    if session_manager is not None or docker_client is not None:
        monitor.set_dependencies(session_manager, docker_client)
    await monitor.start_monitoring()
async def stop_container_health_monitoring():
    """Stop the global container health monitoring loop."""
    await get_container_health_monitor().stop_monitoring()
def get_container_health_stats(session_id: Optional[str] = None) -> Dict[str, Any]:
    """Return monitoring statistics, optionally detailed for one session."""
    return get_container_health_monitor().get_health_stats(session_id)
def get_container_health_history(
    session_id: str, limit: int = 50
) -> List[Dict[str, Any]]:
    """Return up to *limit* recent health-check results for a session."""
    return get_container_health_monitor().get_health_history(session_id, limit)

406
session-manager/database.py Normal file
View File

@@ -0,0 +1,406 @@
"""
Database Models and Connection Management for Session Persistence
Provides PostgreSQL-backed session storage with connection pooling,
migrations, and health monitoring for production reliability.
"""
import os
import asyncpg
import json
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
from contextlib import asynccontextmanager
import logging
from logging_config import get_logger
logger = get_logger(__name__)  # structured logger from the project's logging_config
class DatabaseConnection:
    """PostgreSQL connection management with pooling.

    Pool settings come from environment variables at construction time.
    connect()/disconnect() manage the pool; get_connection()/release_connection()
    hand out individual connections.
    """

    def __init__(self):
        self._pool: Optional[asyncpg.Pool] = None
        # Read once at construction; later env changes have no effect.
        self._config = {
            "host": os.getenv("DB_HOST", "localhost"),
            "port": int(os.getenv("DB_PORT", "5432")),
            "user": os.getenv("DB_USER", "lovdata"),
            "password": os.getenv("DB_PASSWORD", "password"),
            "database": os.getenv("DB_NAME", "lovdata_chat"),
            "min_size": int(os.getenv("DB_MIN_CONNECTIONS", "5")),
            "max_size": int(os.getenv("DB_MAX_CONNECTIONS", "20")),
            "max_queries": int(os.getenv("DB_MAX_QUERIES", "50000")),
            "max_inactive_connection_lifetime": float(
                os.getenv("DB_MAX_INACTIVE_LIFETIME", "300.0")
            ),
        }

    async def connect(self) -> None:
        """Establish the connection pool (no-op when already connected).

        Raises:
            Exception: whatever asyncpg raises when the pool cannot be
                created (logged before re-raising).
        """
        if self._pool:
            return
        try:
            self._pool = await asyncpg.create_pool(**self._config)
            logger.info(
                "Database connection pool established",
                extra={
                    "host": self._config["host"],
                    "port": self._config["port"],
                    "database": self._config["database"],
                    "min_connections": self._config["min_size"],
                    "max_connections": self._config["max_size"],
                },
            )
        except Exception as e:
            logger.error(
                "Failed to establish database connection", extra={"error": str(e)}
            )
            raise

    async def disconnect(self) -> None:
        """Close the pool and reset it so connect() can be called again."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database connection pool closed")

    async def get_connection(self) -> asyncpg.Connection:
        """Acquire a connection from the pool, connecting lazily if needed."""
        if not self._pool:
            await self.connect()
        return await self._pool.acquire()

    async def release_connection(self, conn: asyncpg.Connection) -> None:
        """Return *conn* to the pool."""
        if self._pool:
            await self._pool.release(conn)

    async def health_check(self) -> Dict[str, Any]:
        """Run SELECT 1 and report pool health; never raises.

        Returns:
            {"status": "healthy", "timestamp": ...} or
            {"status": "unhealthy", "error": ...}.
        """
        try:
            conn = await self.get_connection()
            try:
                result = await conn.fetchval("SELECT 1")
            finally:
                # BUG FIX: the connection was previously leaked when
                # fetchval raised — release was skipped on the error path.
                await self.release_connection(conn)
            if result == 1:
                return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}
            return {"status": "unhealthy", "error": "Health check query failed"}
        except Exception as e:
            logger.error("Database health check failed", extra={"error": str(e)})
            return {"status": "unhealthy", "error": str(e)}

    @asynccontextmanager
    async def transaction(self):
        """Yield a pooled connection wrapped in a transaction.

        The transaction commits on clean exit and rolls back on exception;
        the connection is always released.
        """
        conn = await self.get_connection()
        try:
            async with conn.transaction():
                yield conn
        finally:
            await self.release_connection(conn)
# Global database connection instance shared across the module.
_db_connection = DatabaseConnection()


@asynccontextmanager
async def get_db_connection():
    """Acquire a pooled connection; guaranteed to be released on exit."""
    conn = await _db_connection.get_connection()
    try:
        yield conn
    finally:
        await _db_connection.release_connection(conn)
async def init_database() -> None:
    """Initialize database and run migrations.

    Idempotent: every statement uses IF NOT EXISTS / CREATE OR REPLACE,
    so this is safe to call on every startup.
    """
    logger.info("Initializing database")
    async with get_db_connection() as conn:
        # Create sessions table
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS sessions (
                session_id VARCHAR(32) PRIMARY KEY,
                container_name VARCHAR(255) NOT NULL,
                container_id VARCHAR(255),
                host_dir VARCHAR(1024) NOT NULL,
                port INTEGER,
                auth_token VARCHAR(255),
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                last_accessed TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                status VARCHAR(50) NOT NULL DEFAULT 'creating',
                metadata JSONB DEFAULT '{}'
            )
        """)
        # Create indexes for performance
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
            CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
            CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
            CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name);
        """)
        # Cleanup function used by SessionModel.cleanup_expired_sessions();
        # note the 1-hour expiry window is hard-coded here in SQL.
        await conn.execute("""
            CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
            RETURNS INTEGER AS $$
            DECLARE
                deleted_count INTEGER;
            BEGIN
                DELETE FROM sessions
                WHERE last_accessed < NOW() - INTERVAL '1 hour';
                GET DIAGNOSTICS deleted_count = ROW_COUNT;
                RETURN deleted_count;
            END;
            $$ LANGUAGE plpgsql;
        """)
    logger.info("Database initialized and migrations applied")
async def shutdown_database() -> None:
    """Close the global connection pool."""
    await _db_connection.disconnect()
class SessionModel:
    """Database model for sessions (all static async accessors)."""

    # Columns callers may modify via update_session(); guards the
    # dynamically built UPDATE against arbitrary SQL identifiers.
    _UPDATABLE_COLUMNS = {
        "container_name",
        "container_id",
        "host_dir",
        "port",
        "auth_token",
        "status",
        "metadata",
        "last_accessed",
    }

    @staticmethod
    def _row_to_session(row) -> Dict[str, Any]:
        """Convert an asyncpg record to a plain dict, decoding JSONB metadata."""
        session = dict(row)
        session["metadata"] = json.loads(session["metadata"] or "{}")
        return session

    @staticmethod
    async def create_session(session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a new session row and return it as a dict.

        Raises:
            ValueError: if the INSERT unexpectedly returned no row.
        """
        async with get_db_connection() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO sessions (
                    session_id, container_name, container_id, host_dir, port,
                    auth_token, status, metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_data["session_id"],
                session_data["container_name"],
                session_data.get("container_id"),
                session_data["host_dir"],
                session_data.get("port"),
                session_data.get("auth_token"),
                session_data.get("status", "creating"),
                json.dumps(session_data.get("metadata", {})),
            )
            if row:
                # Decode metadata for consistency with get_session() and
                # list_sessions() (previously returned the raw JSON string).
                return SessionModel._row_to_session(row)
            raise ValueError("Failed to create session")

    @staticmethod
    async def get_session(session_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a session by ID, touching last_accessed as a side effect."""
        async with get_db_connection() as conn:
            row = await conn.fetchrow(
                """
                UPDATE sessions
                SET last_accessed = NOW()
                WHERE session_id = $1
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_id,
            )
            return SessionModel._row_to_session(row) if row else None

    @staticmethod
    async def update_session(session_id: str, updates: Dict[str, Any]) -> bool:
        """Update the given session columns; True on success or no-op.

        BUG FIX: the placeholder counter previously advanced out of step
        with the values list (e.g. when "last_accessed" or "metadata"
        appeared alongside other fields), producing $n placeholders that
        did not match the bound parameters. The counter now advances
        exactly when a value is appended. Unknown/immutable keys
        (session_id, created_at, anything not in _UPDATABLE_COLUMNS) are
        skipped instead of being interpolated into the SQL.
        """
        if not updates:
            return True
        async with get_db_connection() as conn:
            set_parts: List[str] = []
            values: List[Any] = [session_id]  # $1 is the WHERE key
            param_index = 2
            for key, value in updates.items():
                if key not in SessionModel._UPDATABLE_COLUMNS:
                    continue
                if key == "metadata":
                    set_parts.append(f"metadata = ${param_index}")
                    values.append(json.dumps(value))
                    param_index += 1
                elif key == "last_accessed":
                    # Server-side timestamp; consumes no bind parameter.
                    set_parts.append("last_accessed = NOW()")
                else:
                    set_parts.append(f"{key} = ${param_index}")
                    values.append(value)
                    param_index += 1
            if not set_parts:
                return True
            query = f"""
                UPDATE sessions
                SET {", ".join(set_parts)}
                WHERE session_id = $1
            """
            result = await conn.execute(query, *values)
            return result == "UPDATE 1"

    @staticmethod
    async def delete_session(session_id: str) -> bool:
        """Delete a session; True when a row was actually removed."""
        async with get_db_connection() as conn:
            result = await conn.execute(
                "DELETE FROM sessions WHERE session_id = $1", session_id
            )
            return result == "DELETE 1"

    @staticmethod
    async def list_sessions(limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """List sessions newest-first with pagination."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit,
                offset,
            )
            return [SessionModel._row_to_session(row) for row in rows]

    @staticmethod
    async def count_sessions() -> int:
        """Count total sessions."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM sessions")

    @staticmethod
    async def cleanup_expired_sessions() -> int:
        """Delete expired sessions via the cleanup_expired_sessions() SQL
        function (1-hour idle window defined there); returns rows removed."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT cleanup_expired_sessions()")

    @staticmethod
    async def get_active_sessions_count() -> int:
        """Count sessions whose status is 'running'."""
        async with get_db_connection() as conn:
            return await conn.fetchval("""
                SELECT COUNT(*) FROM sessions
                WHERE status = 'running'
            """)

    @staticmethod
    async def get_sessions_by_status(status: str) -> List[Dict[str, Any]]:
        """Fetch all sessions with the given status, newest-first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                WHERE status = $1
                ORDER BY created_at DESC
                """,
                status,
            )
            return [SessionModel._row_to_session(row) for row in rows]
# Migration utilities
async def run_migrations():
    """Apply idempotent schema migrations to the sessions table.

    Each step uses IF NOT EXISTS so re-running is safe; failures are
    logged as warnings (the object most likely already exists) and do
    not abort the remaining steps.
    """
    logger.info("Running database migrations")
    # (sql, success message, warning prefix) triples, applied in order.
    steps = [
        (
            """
            ALTER TABLE sessions ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}'
            """,
            "Migration: Added metadata column",
            "Migration metadata column may already exist",
        ),
        (
            """
            CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
            CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
            CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
            """,
            "Migration: Added performance indexes",
            "Migration indexes may already exist",
        ),
    ]
    async with get_db_connection() as conn:
        for sql, ok_message, warn_prefix in steps:
            try:
                await conn.execute(sql)
                logger.info(ok_message)
            except Exception as e:
                logger.warning(f"{warn_prefix}: {e}")
    logger.info("Database migrations completed")
# Health monitoring
async def get_database_stats() -> Dict[str, Any]:
"""Get database statistics and health information."""
try:
async with get_db_connection() as conn:
# Get basic stats
session_count = await conn.fetchval("SELECT COUNT(*) FROM sessions")
active_sessions = await conn.fetchval(
"SELECT COUNT(*) FROM sessions WHERE status = 'running'"
)
oldest_session = await conn.fetchval("SELECT MIN(created_at) FROM sessions")
newest_session = await conn.fetchval("SELECT MAX(created_at) FROM sessions")
# Get database size information
db_size = await conn.fetchval("""
SELECT pg_size_pretty(pg_database_size(current_database()))
""")
return {
"total_sessions": session_count,
"active_sessions": active_sessions,
"oldest_session": oldest_session.isoformat()
if oldest_session
else None,
"newest_session": newest_session.isoformat()
if newest_session
else None,
"database_size": db_size,
"status": "healthy",
}
except Exception as e:
logger.error("Failed to get database stats", extra={"error": str(e)})
return {
"status": "unhealthy",
"error": str(e),
}

View File

@@ -0,0 +1,635 @@
"""
Docker Service Layer
Provides a clean abstraction for Docker operations, separating container management
from business logic. Enables easy testing, mocking, and future Docker client changes.
"""
import os
import logging
from typing import Dict, List, Optional, Any, Tuple
from datetime import datetime
from logging_config import get_logger
logger = get_logger(__name__)
class ContainerInfo:
    """Lightweight record describing a Docker container.

    Round-trips through plain dicts via to_dict()/from_dict(), with
    timestamps serialized as ISO-8601 strings.
    """

    def __init__(
        self,
        container_id: str,
        name: str,
        image: str,
        status: str,
        ports: Optional[Dict[str, int]] = None,
        created_at: Optional[datetime] = None,
        health_status: Optional[str] = None,
    ):
        self.container_id = container_id
        self.name = name
        self.image = image
        self.status = status
        # Missing/empty ports become a fresh empty mapping.
        self.ports = ports if ports else {}
        # Default creation time is "now" when none was recorded.
        self.created_at = created_at if created_at else datetime.utcnow()
        self.health_status = health_status

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (created_at as ISO-8601 or None)."""
        created = self.created_at.isoformat() if self.created_at else None
        return {
            "container_id": self.container_id,
            "name": self.name,
            "image": self.image,
            "status": self.status,
            "ports": self.ports,
            "created_at": created,
            "health_status": self.health_status,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ContainerInfo":
        """Rebuild a ContainerInfo from the mapping produced by to_dict()."""
        raw_ts = data.get("created_at")
        timestamp = datetime.fromisoformat(raw_ts) if raw_ts else None
        return cls(
            container_id=data["container_id"],
            name=data["name"],
            image=data["image"],
            status=data["status"],
            ports=data.get("ports", {}),
            created_at=timestamp,
            health_status=data.get("health_status"),
        )
class DockerOperationError(Exception):
    """Raised when a Docker operation cannot be completed.

    Carries the failing operation name, the affected container id (if
    any), and the underlying error message.
    """

    def __init__(
        self, operation: str, container_id: Optional[str] = None, message: str = ""
    ):
        super().__init__(f"Docker {operation} failed: {message}")
        self.operation = operation
        self.container_id = container_id
        self.message = message
class DockerService:
    """
    Docker service abstraction layer.

    Provides a clean interface for container operations,
    enabling easy testing and future Docker client changes.

    Two backends are supported, chosen at construction time:
    async (``async_docker_client.AsyncDockerClient``) and the sync
    ``docker`` SDK. All public methods are coroutines either way; the
    sync backend simply blocks inside them.
    """

    def __init__(self, use_async: bool = True):
        """
        Initialize Docker service.

        Args:
            use_async: Whether to use async Docker operations
        """
        self.use_async = use_async
        # Backend client object; created lazily by initialize().
        self._docker_client = None
        self._initialized = False
        logger.info("Docker service initialized", extra={"async_mode": use_async})

    async def initialize(self) -> None:
        """Initialize the Docker client connection (no-op if already done).

        Raises:
            DockerOperationError: If the backend client cannot connect.
        """
        if self._initialized:
            return
        try:
            if self.use_async:
                # Initialize async Docker client; connect() also pings the daemon.
                from async_docker_client import AsyncDockerClient

                self._docker_client = AsyncDockerClient()
                await self._docker_client.connect()
            else:
                # Initialize sync Docker client
                import docker

                # TLS material paths are overridable via environment variables.
                tls_config = docker.tls.TLSConfig(
                    ca_cert=os.getenv("DOCKER_CA_CERT", "/etc/docker/certs/ca.pem"),
                    client_cert=(
                        os.getenv(
                            "DOCKER_CLIENT_CERT", "/etc/docker/certs/client-cert.pem"
                        ),
                        os.getenv(
                            "DOCKER_CLIENT_KEY", "/etc/docker/certs/client-key.pem"
                        ),
                    ),
                    verify=True,
                )
                docker_host = os.getenv(
                    "DOCKER_HOST", "tcp://host.docker.internal:2376"
                )
                # Point the low-level API client at the (possibly remote) daemon.
                self._docker_client = docker.from_env()
                self._docker_client.api = docker.APIClient(
                    base_url=docker_host, tls=tls_config, version="auto"
                )
                # Test connection
                self._docker_client.ping()
            self._initialized = True
            logger.info("Docker service connection established")
        except Exception as e:
            logger.error("Failed to initialize Docker service", extra={"error": str(e)})
            raise DockerOperationError("initialization", message=str(e))

    async def shutdown(self) -> None:
        """Shutdown the Docker client connection."""
        if not self._initialized:
            return
        try:
            if self.use_async and self._docker_client:
                await self._docker_client.disconnect()
            # Sync client doesn't need explicit shutdown
            # NOTE(review): if disconnect() raises, _initialized stays True
            # and a later call will retry — confirm that is intended.
            self._initialized = False
            logger.info("Docker service connection closed")
        except Exception as e:
            logger.warning(
                "Error during Docker service shutdown", extra={"error": str(e)}
            )

    async def ping(self) -> bool:
        """Test Docker daemon connectivity.

        Returns:
            bool: True if the daemon answered, False on any error.
        """
        if not self._initialized:
            await self.initialize()
        try:
            if self.use_async:
                return await self._docker_client.ping()
            else:
                # Sync ping raises on failure rather than returning False.
                self._docker_client.ping()
                return True
        except Exception as e:
            logger.warning("Docker ping failed", extra={"error": str(e)})
            return False

    async def create_container(
        self,
        name: str,
        image: str,
        volumes: Optional[Dict[str, Dict[str, str]]] = None,
        ports: Optional[Dict[str, int]] = None,
        environment: Optional[Dict[str, str]] = None,
        network_mode: str = "bridge",
        mem_limit: Optional[str] = None,
        cpu_quota: Optional[int] = None,
        cpu_period: Optional[int] = None,
        tmpfs: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> ContainerInfo:
        """
        Create a Docker container.

        Args:
            name: Container name
            image: Container image
            volumes: Volume mounts
            ports: Port mappings
            environment: Environment variables
            network_mode: Network mode
            mem_limit: Memory limit
            cpu_quota: CPU quota
            cpu_period: CPU period
            tmpfs: tmpfs mounts
            **kwargs: Additional options

        Returns:
            ContainerInfo: Information about created container

        Raises:
            DockerOperationError: If container creation fails
        """
        if not self._initialized:
            await self.initialize()
        try:
            logger.info(
                "Creating container",
                extra={
                    "container_name": name,
                    "image": image,
                    "memory_limit": mem_limit,
                    "cpu_quota": cpu_quota,
                },
            )
            if self.use_async:
                # Async backend only creates; status is "created", not started.
                container = await self._docker_client.create_container(
                    image=image,
                    name=name,
                    volumes=volumes,
                    ports=ports,
                    environment=environment,
                    network_mode=network_mode,
                    mem_limit=mem_limit,
                    cpu_quota=cpu_quota,
                    cpu_period=cpu_period,
                    tmpfs=tmpfs,
                    **kwargs,
                )
                return ContainerInfo(
                    container_id=container.id,
                    name=name,
                    image=image,
                    status="created",
                    ports=ports,
                )
            else:
                # Sync backend uses run(detach=True), so the container starts.
                # NOTE(review): the port mapping is built from ports.values()
                # only, mapping each value to itself — confirm this matches
                # the intended host/container port semantics of `ports`.
                container = self._docker_client.containers.run(
                    image,
                    name=name,
                    volumes=volumes,
                    ports={f"{port}/tcp": port for port in ports.values()}
                    if ports
                    else None,
                    environment=environment,
                    network_mode=network_mode,
                    mem_limit=mem_limit,
                    cpu_quota=cpu_quota,
                    cpu_period=cpu_period,
                    tmpfs=tmpfs,
                    detach=True,
                    **kwargs,
                )
                return ContainerInfo(
                    container_id=container.id,
                    name=name,
                    image=image,
                    status="running",
                    ports=ports,
                )
        except Exception as e:
            logger.error(
                "Container creation failed",
                extra={"container_name": name, "image": image, "error": str(e)},
            )
            raise DockerOperationError("create_container", name, str(e))

    async def start_container(self, container_id: str) -> None:
        """
        Start a Docker container.

        Args:
            container_id: Container ID

        Raises:
            DockerOperationError: If container start fails
        """
        if not self._initialized:
            await self.initialize()
        try:
            logger.info("Starting container", extra={"container_id": container_id})
            if self.use_async:
                container = await self._docker_client.get_container(container_id)
                await self._docker_client.start_container(container)
            else:
                container = self._docker_client.containers.get(container_id)
                container.start()
            logger.info(
                "Container started successfully", extra={"container_id": container_id}
            )
        except Exception as e:
            logger.error(
                "Container start failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            raise DockerOperationError("start_container", container_id, str(e))

    async def stop_container(self, container_id: str, timeout: int = 10) -> None:
        """
        Stop a Docker container.

        Args:
            container_id: Container ID
            timeout: Stop timeout in seconds

        Raises:
            DockerOperationError: If container stop fails
        """
        if not self._initialized:
            await self.initialize()
        try:
            logger.info(
                "Stopping container",
                extra={"container_id": container_id, "timeout": timeout},
            )
            if self.use_async:
                container = await self._docker_client.get_container(container_id)
                await self._docker_client.stop_container(container, timeout)
            else:
                container = self._docker_client.containers.get(container_id)
                container.stop(timeout=timeout)
            logger.info(
                "Container stopped successfully", extra={"container_id": container_id}
            )
        except Exception as e:
            logger.error(
                "Container stop failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            raise DockerOperationError("stop_container", container_id, str(e))

    async def remove_container(self, container_id: str, force: bool = False) -> None:
        """
        Remove a Docker container.

        Args:
            container_id: Container ID
            force: Force removal if running

        Raises:
            DockerOperationError: If container removal fails
        """
        if not self._initialized:
            await self.initialize()
        try:
            logger.info(
                "Removing container",
                extra={"container_id": container_id, "force": force},
            )
            if self.use_async:
                container = await self._docker_client.get_container(container_id)
                await self._docker_client.remove_container(container, force)
            else:
                container = self._docker_client.containers.get(container_id)
                container.remove(force=force)
            logger.info(
                "Container removed successfully", extra={"container_id": container_id}
            )
        except Exception as e:
            logger.error(
                "Container removal failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            raise DockerOperationError("remove_container", container_id, str(e))

    async def get_container_info(self, container_id: str) -> Optional[ContainerInfo]:
        """
        Get information about a container.

        Args:
            container_id: Container ID

        Returns:
            ContainerInfo or None: Container information (None on any error,
            logged at debug level).
        """
        if not self._initialized:
            await self.initialize()
        try:
            if self.use_async:
                container_info = await self._docker_client._get_container_info(
                    container_id
                )
                if container_info:
                    state = container_info.get("State", {})
                    config = container_info.get("Config", {})
                    # NOTE(review): Docker inspect puts "Name" at the top
                    # level of the payload, not under "Config" — if that
                    # holds for this client too, name will always be "".
                    # Verify against the aiodocker response shape.
                    return ContainerInfo(
                        container_id=container_id,
                        name=config.get("Name", "").lstrip("/"),
                        image=config.get("Image", ""),
                        status=state.get("Status", "unknown"),
                        health_status=state.get("Health", {}).get("Status"),
                    )
            else:
                container = self._docker_client.containers.get(container_id)
                # Prefer the first image tag; fall back to the image id.
                return ContainerInfo(
                    container_id=container.id,
                    name=container.name,
                    image=container.image.tags[0]
                    if container.image.tags
                    else container.image.id,
                    status=container.status,
                )
            return None
        except Exception as e:
            logger.debug(
                "Container info retrieval failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            return None

    async def list_containers(
        self, all: bool = False, filters: Optional[Dict[str, Any]] = None
    ) -> List[ContainerInfo]:
        """
        List Docker containers.

        Args:
            all: Include stopped containers
            filters: Container filters

        Returns:
            List[ContainerInfo]: List of container information (empty on error).
        """
        if not self._initialized:
            await self.initialize()
        try:
            if self.use_async:
                containers = await self._docker_client.list_containers(
                    all=all, filters=filters
                )
                result = []
                # One inspect call per container to get State/Config details.
                for container in containers:
                    container_info = await self._docker_client._get_container_info(
                        container.id
                    )
                    if container_info:
                        state = container_info.get("State", {})
                        config = container_info.get("Config", {})
                        # NOTE(review): same "Name" under "Config" concern as
                        # in get_container_info — verify the payload shape.
                        result.append(
                            ContainerInfo(
                                container_id=container.id,
                                name=config.get("Name", "").lstrip("/"),
                                image=config.get("Image", ""),
                                status=state.get("Status", "unknown"),
                                health_status=state.get("Health", {}).get("Status"),
                            )
                        )
                return result
            else:
                containers = self._docker_client.containers.list(
                    all=all, filters=filters
                )
                result = []
                for container in containers:
                    result.append(
                        ContainerInfo(
                            container_id=container.id,
                            name=container.name,
                            image=container.image.tags[0]
                            if container.image.tags
                            else container.image.id,
                            status=container.status,
                        )
                    )
                return result
        except Exception as e:
            logger.error("Container listing failed", extra={"error": str(e)})
            return []

    async def get_container_logs(self, container_id: str, tail: int = 100) -> str:
        """
        Get container logs.

        Args:
            container_id: Container ID
            tail: Number of log lines to retrieve

        Returns:
            str: Container logs (combined stdout/stderr; "" on error).
        """
        if not self._initialized:
            await self.initialize()
        try:
            if self.use_async:
                container = await self._docker_client.get_container(container_id)
                # aiodocker returns a list of lines; join into one string.
                logs = await container.log(stdout=True, stderr=True, tail=tail)
                return "\n".join(logs)
            else:
                container = self._docker_client.containers.get(container_id)
                logs = container.logs(tail=tail).decode("utf-8")
                return logs
        except Exception as e:
            logger.warning(
                "Container log retrieval failed",
                extra={"container_id": container_id, "error": str(e)},
            )
            return ""

    async def get_system_info(self) -> Optional[Dict[str, Any]]:
        """
        Get Docker system information.

        Returns:
            Dict or None: System information (None on error).
        """
        if not self._initialized:
            await self.initialize()
        try:
            if self.use_async:
                return await self._docker_client.get_system_info()
            else:
                return self._docker_client.info()
        except Exception as e:
            logger.warning("System info retrieval failed", extra={"error": str(e)})
            return None

    # Context manager support
    async def __aenter__(self):
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.shutdown()
class MockDockerService(DockerService):
    """
    Mock Docker service for testing without actual Docker.

    Provides the same interface but with in-memory operations.
    """

    def __init__(self):
        super().__init__(use_async=False)
        # In-memory registry of ContainerInfo objects keyed by container id.
        self._containers: Dict[str, ContainerInfo] = {}
        self._next_id = 1

    async def initialize(self) -> None:
        """Mock initialization - always succeeds."""
        self._initialized = True
        logger.info("Mock Docker service initialized")

    async def shutdown(self) -> None:
        """Mock shutdown."""
        self._containers.clear()
        self._initialized = False
        logger.info("Mock Docker service shutdown")

    async def ping(self) -> bool:
        """Mock ping - always succeeds."""
        return True

    async def create_container(self, name: str, image: str, **kwargs) -> ContainerInfo:
        """Mock container creation."""
        cid = f"mock-{self._next_id}"
        self._next_id += 1
        info = ContainerInfo(
            container_id=cid, name=name, image=image, status="created"
        )
        self._containers[cid] = info
        logger.info(
            "Mock container created",
            extra={
                "container_id": cid,
                "container_name": name,
                "image": image,
            },
        )
        return info

    async def start_container(self, container_id: str) -> None:
        """Mock container start."""
        entry = self._containers.get(container_id)
        if entry is not None:
            entry.status = "running"
        logger.info("Mock container started", extra={"container_id": container_id})

    async def stop_container(self, container_id: str, timeout: int = 10) -> None:
        """Mock container stop."""
        entry = self._containers.get(container_id)
        if entry is not None:
            entry.status = "exited"
        logger.info("Mock container stopped", extra={"container_id": container_id})

    async def remove_container(self, container_id: str, force: bool = False) -> None:
        """Mock container removal."""
        self._containers.pop(container_id, None)
        logger.info("Mock container removed", extra={"container_id": container_id})

    async def get_container_info(self, container_id: str) -> Optional[ContainerInfo]:
        """Mock container info retrieval."""
        return self._containers.get(container_id)

    async def list_containers(
        self, all: bool = False, filters: Optional[Dict[str, Any]] = None
    ) -> List[ContainerInfo]:
        """Mock container listing."""
        return list(self._containers.values())

View File

@@ -0,0 +1,252 @@
"""
Host IP Detection Utilities
Provides robust methods to detect the Docker host IP from within a container,
supporting multiple Docker environments and network configurations.
"""
import os
import socket
import asyncio
import logging
from typing import Optional, List
from functools import lru_cache
import time
logger = logging.getLogger(__name__)
class HostIPDetector:
    """Detects the Docker host IP address from container perspective.

    Successful results are cached in-process for ``_cache_timeout``
    seconds; after that, detection runs again. Instances hold only this
    small cache and are cheap to recreate.
    """

    # Common Docker gateway IPs to try as fallbacks
    COMMON_GATEWAYS = [
        "172.17.0.1",  # Default Docker bridge
        "172.18.0.1",  # Docker networks
        "192.168.65.1",  # Docker Desktop
        "192.168.66.1",  # Alternative Docker Desktop
    ]

    def __init__(self):
        self._detected_ip: Optional[str] = None  # last successful result
        self._last_detection: float = 0  # wall-clock time of last success
        self._cache_timeout: float = 300  # 5 minutes cache

    def detect_host_ip(self) -> str:
        """
        Detect the Docker host IP using multiple methods with fallbacks.

        Returns:
            str: The detected host IP address

        Raises:
            RuntimeError: If no host IP can be detected

        Fix: a previous revision decorated this method with
        ``@lru_cache(maxsize=1)``. Because functools caches bound-method
        calls keyed on ``self``, that leaked detector instances (ruff B019)
        and, worse, made the time-based cache below dead code — the first
        result was returned forever and the 5-minute refresh never ran.
        The manual TTL cache is now the single caching mechanism.
        """
        current_time = time.time()
        # Use cached result if it is still fresh.
        if (
            self._detected_ip
            and (current_time - self._last_detection) < self._cache_timeout
        ):
            logger.debug(f"Using cached host IP: {self._detected_ip}")
            return self._detected_ip
        logger.info("Detecting Docker host IP...")
        # Ordered from most to least specific; first valid result wins.
        detection_methods = [
            self._detect_via_docker_internal,
            self._detect_via_gateway_env,
            self._detect_via_route_table,
            self._detect_via_network_connect,
            self._detect_via_common_gateways,
        ]
        for method in detection_methods:
            try:
                ip = method()
                if ip and self._validate_ip(ip):
                    logger.info(
                        f"Successfully detected host IP using {method.__name__}: {ip}"
                    )
                    self._detected_ip = ip
                    self._last_detection = current_time
                    return ip
                else:
                    logger.debug(f"Method {method.__name__} returned invalid IP: {ip}")
            except Exception as e:
                logger.debug(f"Method {method.__name__} failed: {e}")
        # If all methods fail, raise an error
        raise RuntimeError(
            "Could not detect Docker host IP. Tried all detection methods. "
            "Please check your Docker network configuration or set HOST_IP environment variable."
        )

    def _detect_via_docker_internal(self) -> Optional[str]:
        """Detect via host.docker.internal (Docker Desktop, Docker for Mac/Windows)."""
        try:
            # Try to resolve host.docker.internal
            ip = socket.gethostbyname("host.docker.internal")
            if ip != "127.0.0.1":  # Make sure it's not localhost
                return ip
        except socket.gaierror:
            pass
        return None

    def _detect_via_gateway_env(self) -> Optional[str]:
        """Detect via Docker gateway environment variables.

        Returns the first non-empty value among the known variables,
        or None when none is set.
        """
        gateway_vars = [
            "DOCKER_HOST_GATEWAY",
            "GATEWAY",
            "HOST_IP",
        ]
        for var in gateway_vars:
            ip = os.getenv(var)
            if ip:
                logger.debug(f"Found host IP in environment variable {var}: {ip}")
                return ip
        return None

    def _detect_via_route_table(self) -> Optional[str]:
        """Detect via Linux route table (/proc/net/route).

        The default route's gateway field is a little-endian hex string;
        it is decoded byte-pair by byte-pair and reversed into dotted-quad.
        """
        try:
            with open("/proc/net/route", "r") as f:
                for line in f:
                    fields = line.strip().split()
                    if (
                        len(fields) >= 8
                        and fields[0] != "Iface"
                        and fields[7] == "00000000"
                    ):
                        # Found default route, convert hex gateway to IP
                        gateway_hex = fields[2]
                        if len(gateway_hex) == 8:
                            # Convert from hex to IP (little endian)
                            ip_parts = []
                            for i in range(0, 8, 2):
                                ip_parts.append(str(int(gateway_hex[i : i + 2], 16)))
                            ip = ".".join(reversed(ip_parts))
                            if ip != "0.0.0.0":
                                return ip
        except (IOError, ValueError, IndexError) as e:
            logger.debug(f"Failed to read route table: {e}")
        return None

    def _detect_via_network_connect(self) -> Optional[str]:
        """Detect by resolving the host named in DOCKER_HOST (tcp:// URLs only)."""
        try:
            docker_host = os.getenv("DOCKER_HOST", "tcp://host.docker.internal:2376")
            if docker_host.startswith("tcp://"):
                host_part = docker_host[6:].split(":")[0]
                if host_part not in ["localhost", "127.0.0.1"]:
                    # Try to resolve the host
                    try:
                        ip = socket.gethostbyname(host_part)
                        if ip != "127.0.0.1":
                            return ip
                    except socket.gaierror:
                        pass
        except Exception as e:
            logger.debug(f"Network connect detection failed: {e}")
        return None

    def _detect_via_common_gateways(self) -> Optional[str]:
        """Try common Docker gateway IPs, returning the first reachable one."""
        for gateway in self.COMMON_GATEWAYS:
            if self._test_ip_connectivity(gateway):
                logger.debug(f"Found working gateway: {gateway}")
                return gateway
        return None

    def _test_ip_connectivity(self, ip: str) -> bool:
        """Test if an IP address is reachable (TCP connect to port 22, 1s timeout)."""
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(1.0)
            result = sock.connect_ex((ip, 22))  # SSH port, commonly available
            sock.close()
            return result == 0
        except Exception:
            return False

    def _validate_ip(self, ip: str) -> bool:
        """Validate that the IP address is a plausible private host address.

        Rejects loopback, 0.0.0.0, non-dotted-quad forms, and anything
        outside the 10.x / 172.x / 192.x first-octet ranges.
        """
        try:
            socket.inet_aton(ip)
            if ip.startswith("127."):
                return False
            if ip == "0.0.0.0":
                return False
            parts = ip.split(".")
            if len(parts) != 4:
                return False
            first_octet = int(parts[0])
            # Common Docker gateway ranges
            return first_octet in [10, 172, 192]
        except socket.error:
            return False

    async def async_detect_host_ip(self) -> str:
        """Async version of detect_host_ip.

        Runs the blocking detection in the event loop's default executor
        instead of spawning (and tearing down) a dedicated
        ThreadPoolExecutor on every call.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.detect_host_ip)
# Global instance for caching
_host_detector = HostIPDetector()


def get_host_ip() -> str:
    """
    Get the Docker host IP address from container perspective.

    This function caches the result for performance and tries multiple
    detection methods with fallbacks for different Docker environments.

    Returns:
        str: The detected host IP address

    Raises:
        RuntimeError: If host IP detection fails
    """
    return _host_detector.detect_host_ip()


async def async_get_host_ip() -> str:
    """
    Async version of get_host_ip for use in async contexts.

    The blocking detection runs in the event loop's default executor.
    Fix: the previous implementation built a brand-new ThreadPoolExecutor
    on every call (and used the deprecated asyncio.get_event_loop());
    run_in_executor(None, ...) reuses the loop's shared pool instead.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, get_host_ip)


def reset_host_ip_cache():
    """Reset the cached host IP detection result.

    Replaces the module-level detector so the next call re-runs detection.
    """
    global _host_detector
    _host_detector = HostIPDetector()

View File

@@ -0,0 +1,182 @@
"""
HTTP Connection Pool Manager
Provides a global httpx.AsyncClient instance with connection pooling
to eliminate the overhead of creating new HTTP clients for each proxy request.
"""
import asyncio
import logging
import time
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
import httpx
logger = logging.getLogger(__name__)
class HTTPConnectionPool:
    """Global HTTP connection pool manager for proxy operations.

    Wraps a single long-lived httpx.AsyncClient with keep-alive
    connection pooling, lazy creation, and a periodic health check that
    recreates the client after connection-level failures.
    """

    def __init__(self):
        self._client: Optional[httpx.AsyncClient] = None
        self._last_health_check: float = 0
        self._health_check_interval: float = 60  # Check health every 60 seconds
        self._is_healthy: bool = True
        # Serializes client (re)creation across concurrent callers.
        self._reconnect_lock = asyncio.Lock()
        # Connection pool configuration (kwargs for httpx.AsyncClient).
        self._config = {
            "limits": httpx.Limits(
                max_keepalive_connections=20,  # Keep connections alive
                max_connections=100,  # Max total connections
                keepalive_expiry=300.0,  # Keep connections alive for 5 minutes
            ),
            "timeout": httpx.Timeout(
                connect=10.0,  # Connection timeout
                read=30.0,  # Read timeout
                write=10.0,  # Write timeout
                pool=5.0,  # Pool timeout
            ),
            "follow_redirects": False,
            "http2": False,  # Disable HTTP/2 for simplicity
        }

    async def __aenter__(self):
        await self.ensure_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Intentionally keep the client alive — the pool is shared and
        # long-lived; close() is called explicitly at shutdown.
        pass

    async def ensure_client(self) -> None:
        """Ensure the HTTP client is initialized and healthy.

        Creates the client on first use, and re-runs the health check at
        most once per _health_check_interval seconds.
        """
        if self._client is None:
            await self._create_client()
        # Periodic health check
        current_time = time.time()
        if current_time - self._last_health_check > self._health_check_interval:
            if not await self._check_client_health():
                logger.warning("HTTP client health check failed, recreating client")
                await self._recreate_client()
            self._last_health_check = current_time

    async def _create_client(self) -> None:
        """Create a new HTTP client with connection pooling.

        Closes any existing client first; guarded by the reconnect lock.
        """
        async with self._reconnect_lock:
            if self._client:
                await self._client.aclose()
            self._client = httpx.AsyncClient(**self._config)
            self._is_healthy = True
            logger.info("HTTP connection pool client created")

    async def _recreate_client(self) -> None:
        """Recreate the HTTP client (used when health check fails)."""
        logger.info("Recreating HTTP connection pool client")
        await self._create_client()

    async def _check_client_health(self) -> bool:
        """Check if the HTTP client is still healthy.

        Currently a passive check: reports the flag set by request()
        failures rather than probing an endpoint.
        """
        if not self._client:
            return False
        try:
            # Passive check only; a real probe would ping a health endpoint.
            return self._is_healthy
        except Exception as e:
            logger.warning(f"HTTP client health check error: {e}")
            return False

    async def request(self, method: str, url: str, **kwargs) -> httpx.Response:
        """Make an HTTP request using the connection pool.

        Args:
            method: HTTP method (GET, POST, ...).
            url: Target URL.
            **kwargs: Passed through to httpx.AsyncClient.request.

        Raises:
            RuntimeError: If the client could not be created.
            httpx.HTTPError: Propagated from the underlying client;
                connection-level errors also mark the pool unhealthy so
                the next ensure_client() recreates it.
        """
        await self.ensure_client()
        if not self._client:
            raise RuntimeError("HTTP client not available")
        try:
            return await self._client.request(method, url, **kwargs)
        except (httpx.ConnectError, httpx.ConnectTimeout, httpx.PoolTimeout) as e:
            # Connection-related errors - client might be unhealthy.
            logger.warning(f"Connection error, marking client as unhealthy: {e}")
            self._is_healthy = False
            raise
        # Fix: a redundant `except Exception as e: raise` clause was removed —
        # it re-raised unchanged and only shadowed an unused variable.

    async def close(self) -> None:
        """Close the HTTP client and cleanup resources."""
        async with self._reconnect_lock:
            if self._client:
                await self._client.aclose()
                self._client = None
            self._is_healthy = False
            logger.info("HTTP connection pool client closed")

    async def get_pool_stats(self) -> Dict[str, Any]:
        """Get connection pool statistics.

        httpx does not expose live pool counters, so this reports the
        health flag and the static configuration.
        """
        if not self._client:
            return {"status": "not_initialized"}
        return {
            "status": "healthy" if self._is_healthy else "unhealthy",
            "last_health_check": self._last_health_check,
            "config": {
                "max_keepalive_connections": self._config[
                    "limits"
                ].max_keepalive_connections,
                "max_connections": self._config["limits"].max_connections,
                "keepalive_expiry": self._config["limits"].keepalive_expiry,
                "connect_timeout": self._config["timeout"].connect,
                "read_timeout": self._config["timeout"].read,
            },
        }
# Global HTTP connection pool instance
_http_pool = HTTPConnectionPool()


@asynccontextmanager
async def get_http_client():
    """Yield the shared HTTP connection pool, ensuring it is initialized."""
    async with _http_pool:
        yield _http_pool


async def make_http_request(method: str, url: str, **kwargs) -> httpx.Response:
    """Issue a request through the shared connection pool."""
    async with get_http_client() as pool:
        return await pool.request(method, url, **kwargs)


async def get_connection_pool_stats() -> Dict[str, Any]:
    """Report health and configuration of the shared pool."""
    return await _http_pool.get_pool_stats()


async def close_connection_pool() -> None:
    """Close the global connection pool (for cleanup)."""
    await _http_pool.close()


# Lifecycle management for FastAPI
async def init_http_pool() -> None:
    """Eagerly create the pooled client at application startup."""
    logger.info("Initializing HTTP connection pool")
    await _http_pool.ensure_client()


async def shutdown_http_pool() -> None:
    """Dispose of the pooled client at application shutdown."""
    logger.info("Shutting down HTTP connection pool")
    await _http_pool.close()

View File

@@ -0,0 +1,317 @@
"""
Structured Logging Configuration
Provides comprehensive logging infrastructure with structured logging,
request tracking, log formatting, and aggregation capabilities.
"""
import os
import sys
import json
import logging
import logging.handlers
from typing import Dict, Any, Optional
from datetime import datetime
from pathlib import Path
import threading
import uuid
class StructuredFormatter(logging.Formatter):
    """Structured JSON formatter for production logging.

    Emits one JSON object per record with fixed base fields, optional
    well-known context fields (request_id, session_id, ...), plus any
    custom attributes attached via ``extra=``.
    """

    # LogRecord attributes that are internal to the logging machinery and
    # must not be copied into the JSON payload as custom fields.
    # Fix: was a list rebuilt and linearly scanned for every record
    # attribute; a class-level frozenset gives O(1) membership.
    _RESERVED = frozenset({
        "name",
        "msg",
        "args",
        "levelname",
        "levelno",
        "pathname",
        "filename",
        "module",
        "exc_info",
        "exc_text",
        "stack_info",
        "lineno",
        "funcName",
        "created",
        "msecs",
        "relativeCreated",
        "thread",
        "threadName",
        "processName",
        "process",
        "message",
    })

    def format(self, record: logging.LogRecord) -> str:
        """Render *record* as a single-line JSON string."""
        # Create structured log entry with the fixed base fields first so
        # they lead the serialized output.
        # NOTE: datetime.utcnow() is deprecated in 3.12 but kept here so
        # the "<naive-iso>Z" timestamp format stays byte-compatible.
        log_entry = {
            "timestamp": datetime.utcnow().isoformat() + "Z",
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }
        # Add exception info if present
        if record.exc_info:
            log_entry["exception"] = self.formatException(record.exc_info)
        # Add well-known context fields in a stable order.
        if hasattr(record, "request_id"):
            log_entry["request_id"] = record.request_id
        if hasattr(record, "session_id"):
            log_entry["session_id"] = record.session_id
        if hasattr(record, "user_id"):
            log_entry["user_id"] = record.user_id
        if hasattr(record, "operation"):
            log_entry["operation"] = record.operation
        if hasattr(record, "duration_ms"):
            log_entry["duration_ms"] = record.duration_ms
        if hasattr(record, "status_code"):
            log_entry["status_code"] = record.status_code
        # Copy any remaining custom attributes (from `extra=`), skipping
        # the logging-internal names. (The always-true
        # `hasattr(record, "__dict__")` guard was removed.)
        for key, value in record.__dict__.items():
            if key not in self._RESERVED:
                log_entry[key] = value
        # default=str keeps non-JSON-native values (datetimes, etc.) serializable.
        return json.dumps(log_entry, default=str)
class HumanReadableFormatter(logging.Formatter):
    """Human-readable formatter for development.

    When the record carries a ``request_id`` attribute, it is rendered
    in brackets between the location and the message.
    """

    _BASE_FMT = (
        "%(asctime)s [%(levelname)8s] %(name)s:%(funcName)s:%(lineno)d - %(message)s"
    )
    _REQUEST_FMT = (
        "%(asctime)s [%(levelname)8s] %(name)s:%(funcName)s:%(lineno)d"
        " [%(request_id)s] - %(message)s"
    )

    def __init__(self):
        super().__init__(fmt=self._BASE_FMT, datefmt="%Y-%m-%d %H:%M:%S")

    def format(self, record: logging.LogRecord) -> str:
        """Format *record*, switching templates when a request_id is attached.

        Fix: the previous implementation assigned ``self._fmt``, which
        Python 3's logging.Formatter ignores — the active template lives
        in ``self._style._fmt`` — so the request-id variant was never
        actually used. Both attributes are now updated together.
        """
        fmt = self._REQUEST_FMT if hasattr(record, "request_id") else self._BASE_FMT
        self._style._fmt = fmt
        self._fmt = fmt  # keep the legacy attribute in sync
        return super().format(record)
class RequestContext:
    """Context manager that scopes a request ID to the current thread.

    Nested contexts restore the previous ID on exit; the outermost exit
    removes the thread-local attribute entirely.
    """

    _local = threading.local()

    def __init__(self, request_id: Optional[str] = None):
        # Generate a short random ID when none (or an empty one) is supplied.
        if not request_id:
            request_id = str(uuid.uuid4())[:8]
        self.request_id = request_id
        # Snapshot whatever ID is active right now so __exit__ can restore it.
        self._old_request_id = getattr(self._local, "request_id", None)

    def __enter__(self):
        self._local.request_id = self.request_id
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._old_request_id is None:
            # No enclosing context: remove the attribute entirely.
            delattr(self._local, "request_id")
        else:
            self._local.request_id = self._old_request_id

    @classmethod
    def get_current_request_id(cls) -> Optional[str]:
        """Return the request ID active on this thread, or None."""
        return getattr(cls._local, "request_id", None)
class RequestAdapter(logging.LoggerAdapter):
    """Logger adapter that injects the active request ID into records."""

    def __init__(self, logger: logging.Logger):
        super().__init__(logger, {})

    def process(self, msg: str, kwargs: Any) -> tuple:
        """Attach the thread's current request ID (if any) via ``extra``."""
        current = RequestContext.get_current_request_id()
        if current:
            extra = kwargs.setdefault("extra", {})
            extra["request_id"] = current
        return msg, kwargs
def setup_logging(
    level: str = "INFO",
    format_type: str = "auto",  # "json", "human", or "auto"
    log_file: Optional[str] = None,
    max_file_size: int = 10 * 1024 * 1024,  # 10MB
    backup_count: int = 5,
    enable_console: bool = True,
    enable_file: bool = True,
) -> "RequestAdapter":
    """
    Set up comprehensive logging configuration.

    Args:
        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        format_type: Log format - "json", "human", or "auto" (detects from environment)
        log_file: Path to log file (optional)
        max_file_size: Maximum log file size in bytes
        backup_count: Number of backup files to keep
        enable_console: Enable console logging
        enable_file: Enable file logging

    Returns:
        A RequestAdapter wrapping the configured root logger.  (Bug fix: the
        annotation previously claimed ``logging.Logger``, but this function
        has always returned the adapter.)
    """
    # Determine format type: JSON for production, human-readable for dev.
    if format_type == "auto":
        format_type = "json" if os.getenv("ENVIRONMENT") == "production" else "human"

    # Reset the root logger so repeated setup does not duplicate handlers.
    root_logger = logging.getLogger()
    root_logger.handlers.clear()

    numeric_level = getattr(logging, level.upper(), logging.INFO)
    root_logger.setLevel(numeric_level)

    # Console output uses the selected formatter.
    if format_type == "json":
        formatter = StructuredFormatter()
    else:
        formatter = HumanReadableFormatter()

    if enable_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(numeric_level)
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

    # File handler with rotation; files always get JSON for machine parsing.
    if enable_file and log_file:
        # Ensure log directory exists
        log_path = Path(log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.handlers.RotatingFileHandler(
            log_file, maxBytes=max_file_size, backupCount=backup_count, encoding="utf-8"
        )
        file_handler.setLevel(numeric_level)
        file_handler.setFormatter(StructuredFormatter())  # Always use JSON for files
        root_logger.addHandler(file_handler)

    # Quiet chatty third-party loggers.
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("docker").setLevel(logging.WARNING)
    # NOTE(review): "aiodeocker" matches this project's import spelling —
    # confirm whether the intended package name is "aiodocker".
    logging.getLogger("aiodeocker").setLevel(logging.WARNING)
    logging.getLogger("asyncio").setLevel(logging.WARNING)

    return RequestAdapter(root_logger)
def get_logger(name: str) -> RequestAdapter:
    """Return a request-context-aware logger adapter for *name*."""
    return RequestAdapter(logging.getLogger(name))
def log_performance(operation: str, duration_ms: float, **kwargs) -> None:
    """Emit an INFO record carrying timing metadata for *operation*."""
    perf_logger = get_logger(__name__)
    context = {"operation": operation, "duration_ms": duration_ms}
    context.update(kwargs)
    perf_logger.info(
        f"Performance: {operation} completed in {duration_ms:.2f}ms", extra=context
    )
def log_request(
    method: str, path: str, status_code: int, duration_ms: float, **kwargs
) -> None:
    """Log one HTTP request: WARNING for 4xx/5xx responses, INFO otherwise."""
    request_logger = get_logger(__name__)
    context = {
        "operation": "http_request",
        "method": method,
        "path": path,
        "status_code": status_code,
        "duration_ms": duration_ms,
    }
    context.update(kwargs)
    severity = logging.WARNING if status_code >= 400 else logging.INFO
    request_logger.log(
        severity,
        f"HTTP {method} {path} -> {status_code} ({duration_ms:.2f}ms)",
        extra=context,
    )
def log_session_operation(session_id: str, operation: str, **kwargs) -> None:
    """Emit an INFO record describing a session lifecycle operation."""
    session_logger = get_logger(__name__)
    context = {"session_id": session_id, "operation": operation}
    context.update(kwargs)
    session_logger.info(f"Session {operation}: {session_id}", extra=context)
def log_security_event(event: str, severity: str = "info", **kwargs) -> None:
    """Log a security-related event at the level named by *severity*."""
    security_logger = get_logger(__name__)
    context = {"security_event": event, "severity": severity}
    context.update(kwargs)
    numeric = getattr(logging, severity.upper(), logging.INFO)
    security_logger.log(numeric, f"Security: {event}", extra=context)
# Global logger instance
# Shared module-level logger; request-context aware via RequestAdapter.
logger = get_logger(__name__)
# Initialize logging on import
# One-shot guard: init_logging() flips this after the first successful setup.
_setup_complete = False
def init_logging():
    """Configure the logging system from environment variables (idempotent)."""
    global _setup_complete
    if _setup_complete:
        return

    # Pull configuration from the environment.
    env = os.getenv
    level = env("LOG_LEVEL", "INFO")
    format_type = env("LOG_FORMAT", "auto")  # json, human, auto
    log_file = env("LOG_FILE")
    max_file_size = int(env("LOG_MAX_SIZE_MB", "10")) * 1024 * 1024
    backup_count = int(env("LOG_BACKUP_COUNT", "5"))
    enable_console = env("LOG_CONSOLE", "true").lower() == "true"
    file_flag = env("LOG_FILE_ENABLED", "true").lower() == "true"
    enable_file = file_flag and log_file is not None

    setup_logging(
        level=level,
        format_type=format_type,
        log_file=log_file,
        max_file_size=max_file_size,
        backup_count=backup_count,
        enable_console=enable_console,
        enable_file=enable_file,
    )

    logger.info(
        "Structured logging initialized",
        extra={
            "level": level,
            "format": format_type,
            "log_file": log_file,
            "max_file_size_mb": max_file_size // (1024 * 1024),
            "backup_count": backup_count,
        },
    )
    _setup_complete = True
# Initialize on import
# Configure logging as soon as this module is imported so that even the
# earliest log calls use the structured/human formatters.
init_logging()

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,9 @@
fastapi==0.104.1
uvicorn==0.24.0
docker>=7.1.0
aiodeocker>=0.21.0
asyncpg>=0.29.0
pydantic==2.5.0
python-multipart==0.0.6
httpx==0.25.2
httpx==0.25.2
psutil>=5.9.0

View File

@@ -0,0 +1,248 @@
"""
Resource Management and Monitoring Utilities
Provides validation, enforcement, and monitoring of container resource limits
to prevent resource exhaustion attacks and ensure fair resource allocation.
"""
import logging
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

import psutil
logger = logging.getLogger(__name__)
@dataclass
class ResourceLimits:
    """Container resource limits configuration."""

    memory_limit: str  # e.g., "4g", "512m"
    cpu_quota: int  # CPU quota in microseconds
    cpu_period: int  # CPU period in microseconds

    def validate(self) -> Tuple[bool, str]:
        """Validate resource limits configuration; returns (is_valid, reason)."""
        # Validate memory limit: an integer, optionally suffixed with g/m/k.
        memory_limit = self.memory_limit.lower()
        if memory_limit.endswith(("g", "m", "k")):
            digits = memory_limit[:-1]
        else:
            digits = memory_limit
        # Bug fix: previously only the suffix (or all-digits) was checked, so
        # values like "xg" or a bare "g" passed validation here and then blew
        # up later in parse_memory_limit.
        if not digits.isdigit():
            return (
                False,
                f"Invalid memory limit format: {self.memory_limit}. Use format like '4g', '512m', '256k'",
            )
        # Validate CPU quota and period
        if self.cpu_quota <= 0:
            return False, f"CPU quota must be positive, got {self.cpu_quota}"
        if self.cpu_period <= 0:
            return False, f"CPU period must be positive, got {self.cpu_period}"
        if self.cpu_quota > self.cpu_period:
            return (
                False,
                f"CPU quota ({self.cpu_quota}) cannot exceed CPU period ({self.cpu_period})",
            )
        return True, "Valid"

    def to_docker_limits(self) -> Dict[str, Any]:
        """Convert to Docker container limits keyword arguments.

        Annotation fixed: was ``Dict[str, any]`` (the builtin function),
        now ``typing.Any``.
        """
        return {
            "mem_limit": self.memory_limit,
            "cpu_quota": self.cpu_quota,
            "cpu_period": self.cpu_period,
        }
class ResourceMonitor:
    """Monitor system and container resource usage."""

    def __init__(self):
        self._last_check = datetime.now()
        self._alerts_sent = set()  # Track alerts to prevent spam

    def get_system_resources(self) -> Dict[str, any]:
        """Return a snapshot of host memory/CPU usage ({} on failure)."""
        try:
            mem = psutil.virtual_memory()
            cpu_pct = psutil.cpu_percent(interval=1)
            gib = 1024**3
            return {
                "memory_percent": mem.percent / 100.0,
                "memory_used_gb": mem.used / gib,
                "memory_total_gb": mem.total / gib,
                "cpu_percent": cpu_pct / 100.0,
                "cpu_count": psutil.cpu_count(),
            }
        except Exception as e:
            logger.warning(f"Failed to get system resources: {e}")
            return {}

    def check_resource_limits(
        self, limits: ResourceLimits, warning_thresholds: Dict[str, float]
    ) -> Dict[str, any]:
        """Compare current usage against warning thresholds and collect alerts."""
        snapshot = self.get_system_resources()
        alerts = []
        # (metric name, current usage, threshold, message template)
        checks = (
            (
                "memory",
                snapshot.get("memory_percent", 0),
                warning_thresholds.get("memory", 0.8),
                "System memory usage at {:.1%}",
            ),
            (
                "cpu",
                snapshot.get("cpu_percent", 0),
                warning_thresholds.get("cpu", 0.9),
                "System CPU usage at {:.1%}",
            ),
        )
        for kind, usage, threshold, template in checks:
            if usage >= threshold:
                alerts.append(
                    {
                        "type": kind,
                        "level": "warning" if usage < 0.95 else "critical",
                        "message": template.format(usage),
                        "current": usage,
                        "threshold": threshold,
                    }
                )
        return {
            "system_resources": snapshot,
            "alerts": alerts,
            "timestamp": datetime.now(),
        }

    def should_throttle_sessions(self, resource_check: Dict) -> Tuple[bool, str]:
        """Decide whether new sessions should be throttled.

        Throttle on any critical alert, or on two or more warnings.
        """
        alerts = resource_check.get("alerts", [])
        critical = [a for a in alerts if a["level"] == "critical"]
        if critical:
            return (
                True,
                f"Critical resource usage: {[a['message'] for a in critical]}",
            )
        warnings = [a for a in alerts if a["level"] == "warning"]
        if len(warnings) >= 2:
            return (
                True,
                f"Multiple resource warnings: {[a['message'] for a in warnings]}",
            )
        return False, "Resources OK"
class ResourceValidator:
    """Validate and parse resource limit configurations."""

    @staticmethod
    def parse_memory_limit(memory_str: str) -> Tuple[int, str]:
        """Parse a memory limit string into (bytes, unit label).

        Accepts "<n>g", "<n>m", "<n>k" (case-insensitive, whitespace
        tolerated) or a bare integer byte count.

        Raises:
            ValueError: if the string is empty, malformed, or non-positive.
        """
        if not memory_str:
            raise ValueError("Memory limit cannot be empty")
        memory_str = memory_str.lower().strip()
        # Handle different units
        if memory_str.endswith("g"):
            bytes_val = int(memory_str[:-1]) * (1024**3)
            unit = "GB"
        elif memory_str.endswith("m"):
            bytes_val = int(memory_str[:-1]) * (1024**2)
            unit = "MB"
        elif memory_str.endswith("k"):
            bytes_val = int(memory_str[:-1]) * 1024
            unit = "KB"
        else:
            # Assume bytes if no unit
            bytes_val = int(memory_str)
            unit = "bytes"
        if bytes_val <= 0:
            raise ValueError(f"Memory limit must be positive, got {bytes_val}")
        # Reasonable limits check
        if bytes_val > 32 * (1024**3):  # 32GB
            logger.warning(f"Very high memory limit: {bytes_val} bytes")
        return bytes_val, unit

    @staticmethod
    def validate_resource_config(
        config: Dict[str, Any],
    ) -> Tuple[bool, str, Optional[ResourceLimits]]:
        """Validate a raw config dict; returns (ok, message, limits-or-None).

        Annotation fixed: the parameter was typed ``Dict[str, any]`` — the
        builtin ``any`` function, not ``typing.Any``.
        """
        try:
            limits = ResourceLimits(
                memory_limit=config.get("memory_limit", "4g"),
                cpu_quota=config.get("cpu_quota", 100000),
                cpu_period=config.get("cpu_period", 100000),
            )
            valid, message = limits.validate()
            if not valid:
                return False, message, None
            # Additional validation
            memory_bytes, _ = ResourceValidator.parse_memory_limit(limits.memory_limit)
            # Warn about potentially problematic configurations
            if memory_bytes < 128 * (1024**2):  # Less than 128MB
                logger.warning("Very low memory limit may cause container instability")
            return True, "Configuration valid", limits
        except (ValueError, TypeError) as e:
            return False, f"Invalid configuration: {e}", None
# Global instances
# Shared monitor used by the module-level helper functions below.
resource_monitor = ResourceMonitor()
def get_resource_limits() -> ResourceLimits:
    """Build and validate ResourceLimits from environment variables.

    Raises:
        ValueError: if the environment supplies an invalid configuration.
    """
    raw = {
        "memory_limit": os.getenv("CONTAINER_MEMORY_LIMIT", "4g"),
        "cpu_quota": int(os.getenv("CONTAINER_CPU_QUOTA", "100000")),
        "cpu_period": int(os.getenv("CONTAINER_CPU_PERIOD", "100000")),
    }
    valid, message, limits = ResourceValidator.validate_resource_config(raw)
    if not valid or limits is None:
        raise ValueError(f"Resource configuration error: {message}")
    logger.info(
        f"Using resource limits: memory={limits.memory_limit}, cpu_quota={limits.cpu_quota}"
    )
    return limits
def check_system_resources() -> Dict[str, any]:
    """Run a resource check against configured limits and warning thresholds."""
    limits = get_resource_limits()
    thresholds = {
        "memory": float(os.getenv("MEMORY_WARNING_THRESHOLD", "0.8")),
        "cpu": float(os.getenv("CPU_WARNING_THRESHOLD", "0.9")),
    }
    return resource_monitor.check_resource_limits(limits, thresholds)
def should_throttle_sessions() -> Tuple[bool, str]:
    """Report whether new sessions should be throttled by resource pressure."""
    return resource_monitor.should_throttle_sessions(check_system_resources())

View File

@@ -0,0 +1,235 @@
"""
Token-Based Authentication for OpenCode Sessions
Provides secure token generation, validation, and management for individual
user sessions to prevent unauthorized access to OpenCode servers.
"""
import os
import uuid
import secrets
import hashlib
import hmac
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
import logging
logger = logging.getLogger(__name__)
class SessionTokenManager:
    """Manages authentication tokens for OpenCode user sessions.

    Tokens are random bearer secrets stored in-process (one per session)
    and checked with a constant-time comparison.
    """

    def __init__(self):
        # Token storage - in production, this should be in Redis/database
        self._session_tokens: Dict[str, Dict] = {}
        # Token configuration
        self._token_length = int(os.getenv("SESSION_TOKEN_LENGTH", "32"))
        self._token_expiry_hours = int(os.getenv("SESSION_TOKEN_EXPIRY_HOURS", "24"))
        # NOTE(review): this secret is generated but never used anywhere in
        # this class — tokens are plain random values, not HMAC-signed.
        # Confirm whether signing was intended.
        self._token_secret = os.getenv("SESSION_TOKEN_SECRET", self._generate_secret())
        # Cleanup configuration
        self._cleanup_interval_minutes = int(
            os.getenv("TOKEN_CLEANUP_INTERVAL_MINUTES", "60")
        )

    def _generate_secret(self) -> str:
        """Generate a secure secret for token signing."""
        return secrets.token_hex(32)

    def generate_session_token(self, session_id: str) -> str:
        """Generate, store, and return a new authentication token for a session.

        Args:
            session_id: The session identifier

        Returns:
            str: The authentication token
        """
        # Generate cryptographically secure random token
        token = secrets.token_urlsafe(self._token_length)
        # Create token data with expiry
        expiry = datetime.now() + timedelta(hours=self._token_expiry_hours)
        # Store token information (replaces any previous token for the session)
        self._session_tokens[session_id] = {
            "token": token,
            "session_id": session_id,
            "created_at": datetime.now(),
            "expires_at": expiry,
            "last_used": datetime.now(),
        }
        logger.info(f"Generated authentication token for session {session_id}")
        return token

    def validate_session_token(self, session_id: str, token: str) -> Tuple[bool, str]:
        """Validate a session token.

        Args:
            session_id: The session identifier
            token: The token to validate

        Returns:
            Tuple[bool, str]: (is_valid, reason)
        """
        # Check if session exists
        if session_id not in self._session_tokens:
            return False, "Session not found"
        session_data = self._session_tokens[session_id]
        # Security fix: constant-time comparison so token validity cannot be
        # probed via timing differences (was a plain `!=`).
        try:
            token_matches = hmac.compare_digest(session_data["token"], token)
        except TypeError:
            # Non-ASCII / non-str input can never match a token_urlsafe value.
            token_matches = False
        if not token_matches:
            return False, "Invalid token"
        # Check if token has expired
        if datetime.now() > session_data["expires_at"]:
            # Clean up expired token
            del self._session_tokens[session_id]
            return False, "Token expired"
        # Update last used time
        session_data["last_used"] = datetime.now()
        return True, "Valid"

    def revoke_session_token(self, session_id: str) -> bool:
        """Revoke a session token.

        Args:
            session_id: The session identifier

        Returns:
            bool: True if token was revoked, False if not found
        """
        if session_id in self._session_tokens:
            del self._session_tokens[session_id]
            logger.info(f"Revoked authentication token for session {session_id}")
            return True
        return False

    def rotate_session_token(self, session_id: str) -> Optional[str]:
        """Rotate (regenerate) a session token.

        Args:
            session_id: The session identifier

        Returns:
            Optional[str]: New token if session exists, None otherwise
        """
        if session_id not in self._session_tokens:
            return None
        # Generate new token (overwrites the stored entry)
        new_token = self.generate_session_token(session_id)
        logger.info(f"Rotated authentication token for session {session_id}")
        return new_token

    def cleanup_expired_tokens(self) -> int:
        """Clean up expired tokens.

        Returns:
            int: Number of tokens cleaned up
        """
        now = datetime.now()
        # Collect first, then delete, to avoid mutating while iterating.
        expired_sessions = [
            session_id
            for session_id, session_data in self._session_tokens.items()
            if now > session_data["expires_at"]
        ]
        for session_id in expired_sessions:
            del self._session_tokens[session_id]
        if expired_sessions:
            logger.info(
                f"Cleaned up {len(expired_sessions)} expired authentication tokens"
            )
        return len(expired_sessions)

    def get_session_token_info(self, session_id: str) -> Optional[Dict]:
        """Get information about a session token (without the token value).

        Args:
            session_id: The session identifier

        Returns:
            Optional[Dict]: Token information or None if not found
        """
        if session_id not in self._session_tokens:
            return None
        session_data = self._session_tokens[session_id].copy()
        # Remove sensitive token value
        session_data.pop("token", None)
        return session_data

    def get_active_sessions_count(self) -> int:
        """Get the number of active sessions with tokens."""
        return len(self._session_tokens)

    def list_active_sessions(self) -> Dict[str, Dict]:
        """List all active sessions with token information (without token values)."""
        result = {}
        for session_id, session_data in self._session_tokens.items():
            # Create copy without sensitive token
            info = session_data.copy()
            info.pop("token", None)
            result[session_id] = info
        return result
# Global token manager instance
# Module-level singleton backing the convenience functions below.
_session_token_manager = SessionTokenManager()
def generate_session_auth_token(session_id: str) -> str:
    """Create and register a new auth token for *session_id*."""
    manager = _session_token_manager
    return manager.generate_session_token(session_id)
def validate_session_auth_token(session_id: str, token: str) -> Tuple[bool, str]:
    """Check *token* against the stored token for *session_id*."""
    manager = _session_token_manager
    return manager.validate_session_token(session_id, token)
def revoke_session_auth_token(session_id: str) -> bool:
    """Invalidate the token for *session_id*; True if one existed."""
    manager = _session_token_manager
    return manager.revoke_session_token(session_id)
def rotate_session_auth_token(session_id: str) -> Optional[str]:
    """Rotate a session authentication token.

    Returns:
        The new token, or None if the session does not exist.
    """
    # Bug fix: this previously called the non-existent manager method
    # ``rotate_session_auth_token`` (its own name) instead of
    # ``rotate_session_token``, raising AttributeError at runtime.
    return _session_token_manager.rotate_session_token(session_id)
def cleanup_expired_auth_tokens() -> int:
    """Drop expired tokens; returns how many were removed."""
    manager = _session_token_manager
    return manager.cleanup_expired_tokens()
def get_session_auth_info(session_id: str) -> Optional[Dict]:
    """Return token metadata (no token value) for *session_id*, or None."""
    manager = _session_token_manager
    return manager.get_session_token_info(session_id)
def get_active_auth_sessions_count() -> int:
    """Return how many sessions currently hold a token."""
    manager = _session_token_manager
    return manager.get_active_sessions_count()
def list_active_auth_sessions() -> Dict[str, Dict]:
    """Return metadata for every active session (token values omitted)."""
    manager = _session_token_manager
    return manager.list_active_sessions()