"""
Database Models and Connection Management for Session Persistence

Provides PostgreSQL-backed session storage with connection pooling,
migrations, and health monitoring for production reliability.
"""

import os
import asyncpg
import json
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta, timezone
from contextlib import asynccontextmanager
import logging

from logging_config import get_logger

logger = get_logger(__name__)

# Columns that SessionModel.update_session is allowed to change. Column names
# cannot be bound as query parameters, so they end up in the SQL text; keeping
# them to this fixed allow-list prevents SQL injection through the keys of the
# ``updates`` dict.
_UPDATABLE_COLUMNS = frozenset(
    {
        "container_name",
        "container_id",
        "host_dir",
        "port",
        "auth_token",
        "status",
        "metadata",
        "last_accessed",
    }
)

# Columns that update_session silently skips (never changed after insert).
_IMMUTABLE_COLUMNS = frozenset({"session_id", "created_at"})


def _decode_metadata(value: Any) -> Dict[str, Any]:
    """Return the ``metadata`` column value as a Python object.

    asyncpg returns JSONB as ``str`` unless a type codec has been registered,
    in which case the value arrives already decoded. NULL and empty strings
    map to an empty dict, matching the previous ``json.loads(x or "{}")``.
    """
    if value is None or value == "":
        return {}
    if isinstance(value, (str, bytes, bytearray)):
        return json.loads(value)
    return value


def _row_to_session(row: "asyncpg.Record") -> Dict[str, Any]:
    """Convert a ``sessions`` row into a plain dict with decoded metadata."""
    session = dict(row)
    session["metadata"] = _decode_metadata(session.get("metadata"))
    return session


class DatabaseConnection:
    """PostgreSQL connection management with pooling.

    Pool settings are read from environment variables when the instance is
    constructed; the pool itself is created lazily on first use.
    """

    def __init__(self):
        self._pool: Optional[asyncpg.Pool] = None
        # Keyword arguments passed unchanged to asyncpg.create_pool().
        # NOTE(review): DB_PASSWORD falls back to a hard-coded default; consider
        # failing fast when it is unset in production.
        self._config = {
            "host": os.getenv("DB_HOST", "localhost"),
            "port": int(os.getenv("DB_PORT", "5432")),
            "user": os.getenv("DB_USER", "lovdata"),
            "password": os.getenv("DB_PASSWORD", "password"),
            "database": os.getenv("DB_NAME", "lovdata_chat"),
            "min_size": int(os.getenv("DB_MIN_CONNECTIONS", "5")),
            "max_size": int(os.getenv("DB_MAX_CONNECTIONS", "20")),
            "max_queries": int(os.getenv("DB_MAX_QUERIES", "50000")),
            "max_inactive_connection_lifetime": float(
                os.getenv("DB_MAX_INACTIVE_LIFETIME", "300.0")
            ),
        }

    async def connect(self) -> None:
        """Establish the connection pool; a no-op if it already exists.

        Raises:
            Exception: whatever ``asyncpg.create_pool`` raises, after logging.
        """
        if self._pool:
            return
        try:
            self._pool = await asyncpg.create_pool(**self._config)
            logger.info(
                "Database connection pool established",
                extra={
                    "host": self._config["host"],
                    "port": self._config["port"],
                    "database": self._config["database"],
                    "min_connections": self._config["min_size"],
                    "max_connections": self._config["max_size"],
                },
            )
        except Exception as e:
            logger.error(
                "Failed to establish database connection", extra={"error": str(e)}
            )
            raise

    async def disconnect(self) -> None:
        """Close the connection pool if one is open."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database connection pool closed")

    async def get_connection(self) -> asyncpg.Connection:
        """Acquire a connection from the pool, creating the pool if needed.

        The caller must hand the connection back via ``release_connection``.
        """
        if not self._pool:
            await self.connect()
        return await self._pool.acquire()

    async def release_connection(self, conn: asyncpg.Connection) -> None:
        """Return a connection to the pool (ignored if the pool is closed)."""
        if self._pool:
            await self._pool.release(conn)

    async def health_check(self) -> Dict[str, Any]:
        """Run ``SELECT 1`` and report the outcome.

        Returns:
            ``{"status": "healthy", "timestamp": <UTC ISO-8601>}`` on success,
            otherwise ``{"status": "unhealthy", "error": <message>}``. Never
            raises.
        """
        try:
            conn = await self.get_connection()
            try:
                result = await conn.fetchval("SELECT 1")
            finally:
                # Release even when the query fails so the pool is not drained.
                await self.release_connection(conn)

            if result == 1:
                return {
                    "status": "healthy",
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                }
            return {"status": "unhealthy", "error": "Health check query failed"}

        except Exception as e:
            logger.error("Database health check failed", extra={"error": str(e)})
            return {"status": "unhealthy", "error": str(e)}

    @asynccontextmanager
    async def transaction(self):
        """Yield a pooled connection inside a transaction.

        The transaction commits on normal exit and rolls back if the body
        raises; the connection is always returned to the pool.
        """
        conn = await self.get_connection()
        try:
            async with conn.transaction():
                yield conn
        finally:
            await self.release_connection(conn)


# Global database connection instance
_db_connection = DatabaseConnection()


@asynccontextmanager
async def get_db_connection():
    """Yield a pooled connection from the global instance and release it on exit."""
    conn = await _db_connection.get_connection()
    try:
        yield conn
    finally:
        await _db_connection.release_connection(conn)


async def init_database() -> None:
    """Create the sessions table, its indexes and the cleanup SQL function.

    Every statement is idempotent (``IF NOT EXISTS`` / ``CREATE OR REPLACE``),
    so this is safe to call on every startup.
    """
    logger.info("Initializing database")

    async with get_db_connection() as conn:
        # Create sessions table
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS sessions (
                session_id VARCHAR(32) PRIMARY KEY,
                container_name VARCHAR(255) NOT NULL,
                container_id VARCHAR(255),
                host_dir VARCHAR(1024) NOT NULL,
                port INTEGER,
                auth_token VARCHAR(255),
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                last_accessed TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                status VARCHAR(50) NOT NULL DEFAULT 'creating',
                metadata JSONB DEFAULT '{}'
            )
        """)

        # Create indexes for performance
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
            CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
            CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
            CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name);
        """)

        # Cleanup function: deletes sessions not accessed within the last hour
        # (the interval is fixed inside the SQL) and returns the deleted count.
        await conn.execute("""
            CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
            RETURNS INTEGER AS $$
            DECLARE
                deleted_count INTEGER;
            BEGIN
                DELETE FROM sessions
                WHERE last_accessed < NOW() - INTERVAL '1 hour';

                GET DIAGNOSTICS deleted_count = ROW_COUNT;
                RETURN deleted_count;
            END;
            $$ LANGUAGE plpgsql;
        """)

    logger.info("Database initialized and migrations applied")


async def shutdown_database() -> None:
    """Close the global connection pool."""
    await _db_connection.disconnect()


class SessionModel:
    """Database access methods for the ``sessions`` table.

    Every method that returns session rows returns plain dicts whose
    ``metadata`` value has already been JSON-decoded.
    """

    @staticmethod
    async def create_session(session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a new session row and return it.

        Args:
            session_data: must contain ``session_id``, ``container_name`` and
                ``host_dir``; ``container_id``, ``port``, ``auth_token``,
                ``status`` (default ``"creating"``) and ``metadata``
                (default ``{}``) are optional.

        Raises:
            KeyError: a required key is missing.
            ValueError: the insert returned no row.
        """
        async with get_db_connection() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO sessions (
                    session_id, container_name, container_id, host_dir,
                    port, auth_token, status, metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING session_id, container_name, container_id, host_dir,
                         port, auth_token, created_at, last_accessed, status, metadata
                """,
                session_data["session_id"],
                session_data["container_name"],
                session_data.get("container_id"),
                session_data["host_dir"],
                session_data.get("port"),
                session_data.get("auth_token"),
                session_data.get("status", "creating"),
                json.dumps(session_data.get("metadata", {})),
            )

        if row:
            # Decode metadata so the result matches the other readers.
            return _row_to_session(row)
        raise ValueError("Failed to create session")

    @staticmethod
    async def get_session(session_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a session by ID, touching ``last_accessed`` as a side effect.

        Returns:
            The session dict, or ``None`` if no row matches.
        """
        async with get_db_connection() as conn:
            # A read counts as access: bump last_accessed and return the row
            # in a single statement.
            row = await conn.fetchrow(
                """
                UPDATE sessions
                SET last_accessed = NOW()
                WHERE session_id = $1
                RETURNING session_id, container_name, container_id, host_dir,
                         port, auth_token, created_at, last_accessed, status, metadata
                """,
                session_id,
            )

        return _row_to_session(row) if row else None

    @staticmethod
    async def update_session(session_id: str, updates: Dict[str, Any]) -> bool:
        """Update selected columns of a session.

        ``session_id`` and ``created_at`` keys are ignored. A ``last_accessed``
        key sets the column to the database's ``NOW()``; its value is ignored.
        ``metadata`` is JSON-encoded before being stored.

        Returns:
            ``True`` if exactly one row was updated, or if there was nothing
            to update; ``False`` otherwise.

        Raises:
            ValueError: ``updates`` names a column that is not updatable.
        """
        set_parts: List[str] = []
        # $1 is always the session_id used in the WHERE clause.
        values: List[Any] = [session_id]

        for key, value in updates.items():
            if key in _IMMUTABLE_COLUMNS:
                continue
            if key not in _UPDATABLE_COLUMNS:
                # Keys are interpolated into SQL; reject anything unexpected.
                raise ValueError(f"Cannot update unknown session column: {key!r}")
            if key == "last_accessed":
                # No bound value is added, so no placeholder number is used.
                set_parts.append("last_accessed = NOW()")
                continue
            if key == "metadata":
                value = json.dumps(value)
            values.append(value)
            # Placeholder number always matches the position of the value.
            set_parts.append(f"{key} = ${len(values)}")

        if not set_parts:
            return True

        query = f"""
            UPDATE sessions
            SET {", ".join(set_parts)}
            WHERE session_id = $1
        """

        async with get_db_connection() as conn:
            result = await conn.execute(query, *values)
        return result == "UPDATE 1"

    @staticmethod
    async def delete_session(session_id: str) -> bool:
        """Delete a session; return ``True`` if exactly one row was removed."""
        async with get_db_connection() as conn:
            result = await conn.execute(
                "DELETE FROM sessions WHERE session_id = $1", session_id
            )
            return result == "DELETE 1"

    @staticmethod
    async def list_sessions(limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """Return a page of sessions, newest ``created_at`` first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit,
                offset,
            )

        return [_row_to_session(row) for row in rows]

    @staticmethod
    async def count_sessions() -> int:
        """Return the total number of session rows."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM sessions")

    @staticmethod
    async def cleanup_expired_sessions() -> int:
        """Call the ``cleanup_expired_sessions()`` SQL function; return rows deleted."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT cleanup_expired_sessions()")

    @staticmethod
    async def get_active_sessions_count() -> int:
        """Return the number of sessions whose status is ``'running'``."""
        async with get_db_connection() as conn:
            return await conn.fetchval("""
                SELECT COUNT(*) FROM sessions
                WHERE status = 'running'
            """)

    @staticmethod
    async def get_sessions_by_status(status: str) -> List[Dict[str, Any]]:
        """Return all sessions with the given status, newest first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                WHERE status = $1
                ORDER BY created_at DESC
                """,
                status,
            )

        return [_row_to_session(row) for row in rows]


# Migration utilities
async def run_migrations():
    """Apply idempotent schema migrations, logging (not raising) on failure."""
    logger.info("Running database migrations")

    async with get_db_connection() as conn:
        # Migration 1: Add metadata column if it doesn't exist
        try:
            await conn.execute("""
                ALTER TABLE sessions ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}'
            """)
            logger.info("Migration: Added metadata column")
        except Exception as e:
            logger.warning(f"Migration metadata column may already exist: {e}")

        # Migration 2: Add indexes if they don't exist (same set as init_database)
        try:
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
                CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
                CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
                CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name);
            """)
            logger.info("Migration: Added performance indexes")
        except Exception as e:
            logger.warning(f"Migration indexes may already exist: {e}")

    logger.info("Database migrations completed")


# Health monitoring
async def get_database_stats() -> Dict[str, Any]:
    """Collect session counts, session age range and database size.

    Returns:
        A dict with ``status: "healthy"`` and the statistics, or
        ``{"status": "unhealthy", "error": ...}`` if any query fails.
        Never raises.
    """
    try:
        async with get_db_connection() as conn:
            # Get basic stats
            session_count = await conn.fetchval("SELECT COUNT(*) FROM sessions")
            active_sessions = await conn.fetchval(
                "SELECT COUNT(*) FROM sessions WHERE status = 'running'"
            )
            oldest_session = await conn.fetchval("SELECT MIN(created_at) FROM sessions")
            newest_session = await conn.fetchval("SELECT MAX(created_at) FROM sessions")

            # Get database size information
            db_size = await conn.fetchval("""
                SELECT pg_size_pretty(pg_database_size(current_database()))
            """)

            return {
                "total_sessions": session_count,
                "active_sessions": active_sessions,
                "oldest_session": oldest_session.isoformat() if oldest_session else None,
                "newest_session": newest_session.isoformat() if newest_session else None,
                "database_size": db_size,
                "status": "healthy",
            }

    except Exception as e:
        logger.error("Failed to get database stats", extra={"error": str(e)})
        return {
            "status": "unhealthy",
            "error": str(e),
        }