# lovdata-chat/session-manager/database.py
"""
Database Models and Connection Management for Session Persistence
Provides PostgreSQL-backed session storage with connection pooling,
migrations, and health monitoring for production reliability.
"""
import os
import asyncpg
import json
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
from contextlib import asynccontextmanager
import logging
from logging_config import get_logger
logger = get_logger(__name__)
class DatabaseConnection:
    """PostgreSQL connection management with pooling.

    Connection settings are read from environment variables at construction
    time; the asyncpg pool itself is created lazily on first use so that
    importing this module never performs I/O.
    """

    def __init__(self) -> None:
        # Pool is created lazily by connect(); None means "not connected yet".
        self._pool: Optional[asyncpg.Pool] = None
        self._config: Dict[str, Any] = {
            "host": os.getenv("DB_HOST", "localhost"),
            "port": int(os.getenv("DB_PORT", "5432")),
            "user": os.getenv("DB_USER", "lovdata"),
            "password": os.getenv("DB_PASSWORD", "password"),
            "database": os.getenv("DB_NAME", "lovdata_chat"),
            "min_size": int(os.getenv("DB_MIN_CONNECTIONS", "5")),
            "max_size": int(os.getenv("DB_MAX_CONNECTIONS", "20")),
            "max_queries": int(os.getenv("DB_MAX_QUERIES", "50000")),
            "max_inactive_connection_lifetime": float(
                os.getenv("DB_MAX_INACTIVE_LIFETIME", "300.0")
            ),
        }

    async def connect(self) -> None:
        """Establish the database connection pool (idempotent).

        Raises:
            Exception: re-raises whatever asyncpg.create_pool() raises after
                logging it, so callers can decide how to handle startup failure.
        """
        if self._pool:
            return
        try:
            self._pool = await asyncpg.create_pool(**self._config)
            logger.info(
                "Database connection pool established",
                extra={
                    "host": self._config["host"],
                    "port": self._config["port"],
                    "database": self._config["database"],
                    "min_connections": self._config["min_size"],
                    "max_connections": self._config["max_size"],
                },
            )
        except Exception as e:
            logger.error(
                "Failed to establish database connection", extra={"error": str(e)}
            )
            raise

    async def disconnect(self) -> None:
        """Close the connection pool; safe to call when not connected."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database connection pool closed")

    async def get_connection(self) -> asyncpg.Connection:
        """Acquire a connection from the pool, connecting first if needed.

        Callers are responsible for returning it via release_connection()
        (or using the transaction()/get_db_connection() context managers).
        """
        if not self._pool:
            await self.connect()
        return await self._pool.acquire()

    async def release_connection(self, conn: asyncpg.Connection) -> None:
        """Return a connection to the pool (no-op if the pool is gone)."""
        if self._pool:
            await self._pool.release(conn)

    async def health_check(self) -> Dict[str, Any]:
        """Run a trivial query to verify database connectivity.

        Returns:
            {"status": "healthy", "timestamp": ...} on success, otherwise
            {"status": "unhealthy", "error": ...}. Never raises.
        """
        try:
            conn = await self.get_connection()
            try:
                result = await conn.fetchval("SELECT 1")
            finally:
                # BUG FIX: previously the connection leaked if fetchval()
                # raised, because release_connection() was only reached on
                # the success path.
                await self.release_connection(conn)
            if result == 1:
                # NOTE(review): datetime.utcnow() returns a naive timestamp
                # and is deprecated in Python 3.12 — kept for output
                # compatibility; consider datetime.now(timezone.utc).
                return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}
            else:
                return {"status": "unhealthy", "error": "Health check query failed"}
        except Exception as e:
            logger.error("Database health check failed", extra={"error": str(e)})
            return {"status": "unhealthy", "error": str(e)}

    @asynccontextmanager
    async def transaction(self):
        """Yield a pooled connection inside a database transaction.

        asyncpg's conn.transaction() commits on normal exit and rolls back
        on exception; the connection is always released afterwards.
        """
        conn = await self.get_connection()
        try:
            async with conn.transaction():
                yield conn
        finally:
            await self.release_connection(conn)
# Module-wide singleton pool manager shared by every helper below.
_db_connection = DatabaseConnection()


@asynccontextmanager
async def get_db_connection():
    """Yield a pooled connection, always returning it to the pool on exit."""
    connection = await _db_connection.get_connection()
    try:
        yield connection
    finally:
        await _db_connection.release_connection(connection)
async def init_database() -> None:
    """Create the sessions schema: table, indexes, and cleanup function.

    All statements are idempotent (IF NOT EXISTS / CREATE OR REPLACE), so
    this is safe to run on every startup.
    """
    logger.info("Initializing database")
    ddl_statements = (
        # Sessions table
        """
            CREATE TABLE IF NOT EXISTS sessions (
                session_id VARCHAR(32) PRIMARY KEY,
                container_name VARCHAR(255) NOT NULL,
                container_id VARCHAR(255),
                host_dir VARCHAR(1024) NOT NULL,
                port INTEGER,
                auth_token VARCHAR(255),
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                last_accessed TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                status VARCHAR(50) NOT NULL DEFAULT 'creating',
                metadata JSONB DEFAULT '{}'
            )
        """,
        # Indexes for the common lookup/cleanup query patterns
        """
            CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
            CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
            CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
            CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name);
        """,
        # Server-side cleanup of sessions idle for more than one hour
        """
            CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
            RETURNS INTEGER AS $$
            DECLARE
                deleted_count INTEGER;
            BEGIN
                DELETE FROM sessions
                WHERE last_accessed < NOW() - INTERVAL '1 hour';
                GET DIAGNOSTICS deleted_count = ROW_COUNT;
                RETURN deleted_count;
            END;
            $$ LANGUAGE plpgsql;
        """,
    )
    async with get_db_connection() as conn:
        for statement in ddl_statements:
            await conn.execute(statement)
    logger.info("Database initialized and migrations applied")
async def shutdown_database() -> None:
    """Shutdown database connections.

    Closes the module-global connection pool; safe to call even if the
    pool was never opened.
    """
    await _db_connection.disconnect()
class SessionModel:
    """Data-access layer for the ``sessions`` table.

    All methods are static and acquire their own pooled connection via
    get_db_connection(). Rows are returned as plain dicts with the JSONB
    ``metadata`` column decoded into a Python dict.
    """

    # Columns callers may change through update_session(). Guards the
    # dynamically built SET clause against SQL injection via dict keys.
    _UPDATABLE_COLUMNS = frozenset(
        {
            "container_name",
            "container_id",
            "host_dir",
            "port",
            "auth_token",
            "status",
            "metadata",
            "last_accessed",
        }
    )

    @staticmethod
    def _row_to_dict(row) -> Dict[str, Any]:
        """Convert a database row to a dict, decoding the metadata JSON."""
        session = dict(row)
        session["metadata"] = json.loads(session["metadata"] or "{}")
        return session

    @staticmethod
    def _build_update_query(session_id: str, updates: Dict[str, Any]):
        """Build the dynamic UPDATE statement used by update_session().

        Args:
            session_id: Primary key; becomes bind parameter $1.
            updates: Column -> new value. ``session_id``/``created_at`` are
                silently skipped (immutable); ``last_accessed`` is set to
                NOW() server-side; ``metadata`` is JSON-encoded.

        Returns:
            (query, values) — query is None when nothing is updatable.

        Raises:
            ValueError: if a key is not an allowed column.

        BUG FIX: the original advanced the bind-parameter index for
        ``last_accessed`` even though that branch binds no value, leaving
        gaps like ``$3`` with only two bound values and breaking any update
        that combined last_accessed with another column.
        """
        set_parts: List[str] = []
        values: List[Any] = [session_id]
        param_index = 2
        for key, value in updates.items():
            if key in ("session_id", "created_at"):  # immutable columns
                continue
            if key not in SessionModel._UPDATABLE_COLUMNS:
                raise ValueError(f"Cannot update unknown column: {key}")
            if key == "last_accessed":
                # Server-side timestamp — binds no parameter.
                set_parts.append("last_accessed = NOW()")
                continue
            if key == "metadata":
                set_parts.append(f"metadata = ${param_index}")
                values.append(json.dumps(value))
            else:
                set_parts.append(f"{key} = ${param_index}")
                values.append(value)
            param_index += 1
        if not set_parts:
            return None, values
        query = f"UPDATE sessions SET {', '.join(set_parts)} WHERE session_id = $1"
        return query, values

    @staticmethod
    async def create_session(session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a new session row and return it as a dict.

        Raises:
            ValueError: if the INSERT unexpectedly returns no row.
        """
        async with get_db_connection() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO sessions (
                    session_id, container_name, container_id, host_dir, port,
                    auth_token, status, metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_data["session_id"],
                session_data["container_name"],
                session_data.get("container_id"),
                session_data["host_dir"],
                session_data.get("port"),
                session_data.get("auth_token"),
                session_data.get("status", "creating"),
                json.dumps(session_data.get("metadata", {})),
            )
            if row:
                # CONSISTENCY FIX: decode metadata like every read path does;
                # previously this returned the raw JSON string.
                return SessionModel._row_to_dict(row)
            raise ValueError("Failed to create session")

    @staticmethod
    async def get_session(session_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a session by ID, bumping last_accessed; None if missing."""
        async with get_db_connection() as conn:
            # UPDATE ... RETURNING touches last_accessed and reads the row
            # in a single round trip.
            row = await conn.fetchrow(
                """
                UPDATE sessions
                SET last_accessed = NOW()
                WHERE session_id = $1
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_id,
            )
            if row:
                return SessionModel._row_to_dict(row)
            return None

    @staticmethod
    async def update_session(session_id: str, updates: Dict[str, Any]) -> bool:
        """Update the given columns; True if exactly one row was updated.

        Raises:
            ValueError: if `updates` contains a non-updatable column name.
        """
        if not updates:
            return True
        query, values = SessionModel._build_update_query(session_id, updates)
        if query is None:
            # Only immutable keys were supplied — nothing to do.
            return True
        async with get_db_connection() as conn:
            result = await conn.execute(query, *values)
            return result == "UPDATE 1"

    @staticmethod
    async def delete_session(session_id: str) -> bool:
        """Delete a session; True if a row was actually removed."""
        async with get_db_connection() as conn:
            result = await conn.execute(
                "DELETE FROM sessions WHERE session_id = $1", session_id
            )
            return result == "DELETE 1"

    @staticmethod
    async def list_sessions(limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """List sessions newest-first with LIMIT/OFFSET pagination."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit,
                offset,
            )
            return [SessionModel._row_to_dict(row) for row in rows]

    @staticmethod
    async def count_sessions() -> int:
        """Return the total number of session rows."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM sessions")

    @staticmethod
    async def cleanup_expired_sessions() -> int:
        """Delete sessions idle >1h via the server-side function; returns count."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT cleanup_expired_sessions()")

    @staticmethod
    async def get_active_sessions_count() -> int:
        """Return the number of sessions whose status is 'running'."""
        async with get_db_connection() as conn:
            return await conn.fetchval("""
                SELECT COUNT(*) FROM sessions
                WHERE status = 'running'
            """)

    @staticmethod
    async def get_sessions_by_status(status: str) -> List[Dict[str, Any]]:
        """Return all sessions with the given status, newest-first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                WHERE status = $1
                ORDER BY created_at DESC
                """,
                status,
            )
            return [SessionModel._row_to_dict(row) for row in rows]
# Migration utilities
async def run_migrations():
    """Apply idempotent schema migrations, logging each outcome.

    Failures are logged as warnings and do not abort later migrations,
    since each statement is expected to no-op when already applied.
    """
    logger.info("Running database migrations")
    migrations = (
        (
            """
            ALTER TABLE sessions ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}'
            """,
            "Migration: Added metadata column",
            "Migration metadata column may already exist: {}",
        ),
        (
            """
            CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
            CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
            CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
            """,
            "Migration: Added performance indexes",
            "Migration indexes may already exist: {}",
        ),
    )
    async with get_db_connection() as conn:
        for statement, success_msg, warning_template in migrations:
            try:
                await conn.execute(statement)
                logger.info(success_msg)
            except Exception as exc:
                logger.warning(warning_template.format(exc))
    logger.info("Database migrations completed")
# Health monitoring
async def get_database_stats() -> Dict[str, Any]:
    """Collect session counts, age extremes, and database size.

    Returns a dict with ``status: "healthy"`` plus the statistics, or
    ``status: "unhealthy"`` and the error string. Never raises.
    """
    try:
        async with get_db_connection() as conn:
            total = await conn.fetchval("SELECT COUNT(*) FROM sessions")
            running = await conn.fetchval(
                "SELECT COUNT(*) FROM sessions WHERE status = 'running'"
            )
            oldest = await conn.fetchval("SELECT MIN(created_at) FROM sessions")
            newest = await conn.fetchval("SELECT MAX(created_at) FROM sessions")
            # Human-readable size of the whole database.
            size_pretty = await conn.fetchval("""
                SELECT pg_size_pretty(pg_database_size(current_database()))
            """)
        return {
            "total_sessions": total,
            "active_sessions": running,
            "oldest_session": oldest.isoformat() if oldest else None,
            "newest_session": newest.isoformat() if newest else None,
            "database_size": size_pretty,
            "status": "healthy",
        }
    except Exception as exc:
        logger.error("Failed to get database stats", extra={"error": str(exc)})
        return {
            "status": "unhealthy",
            "error": str(exc),
        }