407 lines
14 KiB
Python
407 lines
14 KiB
Python
"""
|
|
Database Models and Connection Management for Session Persistence
|
|
|
|
Provides PostgreSQL-backed session storage with connection pooling,
|
|
migrations, and health monitoring for production reliability.
|
|
"""
|
|
|
|
import os
|
|
import asyncpg
|
|
import json
|
|
from typing import Dict, List, Optional, Any
|
|
from datetime import datetime, timedelta
|
|
from contextlib import asynccontextmanager
|
|
import logging
|
|
|
|
from logging_config import get_logger
|
|
|
|
# Module-wide logger obtained from the project's logging_config helper; used by
# every function and class below for structured (extra=...) log records.
logger = get_logger(__name__)
|
|
|
|
|
|
class DatabaseConnection:
    """PostgreSQL connection management with pooling.

    Pool settings are read from ``DB_*`` environment variables when the
    instance is constructed; the asyncpg pool itself is created lazily on the
    first ``get_connection()`` call (or by an explicit ``connect()``).
    """

    def __init__(self):
        self._pool: Optional[asyncpg.Pool] = None
        # Keyword arguments forwarded verbatim to ``asyncpg.create_pool``.
        self._config = {
            "host": os.getenv("DB_HOST", "localhost"),
            "port": int(os.getenv("DB_PORT", "5432")),
            "user": os.getenv("DB_USER", "lovdata"),
            # NOTE(review): insecure fallback password; consider requiring
            # DB_PASSWORD to be set explicitly in production deployments.
            "password": os.getenv("DB_PASSWORD", "password"),
            "database": os.getenv("DB_NAME", "lovdata_chat"),
            "min_size": int(os.getenv("DB_MIN_CONNECTIONS", "5")),
            "max_size": int(os.getenv("DB_MAX_CONNECTIONS", "20")),
            "max_queries": int(os.getenv("DB_MAX_QUERIES", "50000")),
            "max_inactive_connection_lifetime": float(
                os.getenv("DB_MAX_INACTIVE_LIFETIME", "300.0")
            ),
        }

    async def connect(self) -> None:
        """Establish the database connection pool (no-op if it already exists).

        Raises:
            Whatever ``asyncpg.create_pool`` raises; the failure is logged first.
        """
        if self._pool:
            return

        # NOTE(review): two coroutines that both observe ``_pool is None``
        # before either finishes ``create_pool`` will each build a pool and one
        # will be orphaned. Consider guarding with an asyncio.Lock if
        # concurrent first use is possible.
        try:
            self._pool = await asyncpg.create_pool(**self._config)
            logger.info(
                "Database connection pool established",
                extra={
                    "host": self._config["host"],
                    "port": self._config["port"],
                    "database": self._config["database"],
                    "min_connections": self._config["min_size"],
                    "max_connections": self._config["max_size"],
                },
            )
        except Exception as e:
            logger.error(
                "Failed to establish database connection", extra={"error": str(e)}
            )
            raise

    async def disconnect(self) -> None:
        """Close the connection pool, if one is open, and forget it."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database connection pool closed")

    async def get_connection(self) -> asyncpg.Connection:
        """Acquire a connection from the pool, creating the pool on first use.

        The caller must hand the connection back with ``release_connection``.
        """
        if not self._pool:
            await self.connect()
        return await self._pool.acquire()

    async def release_connection(self, conn: asyncpg.Connection) -> None:
        """Return ``conn`` to the pool (ignored if the pool has been closed)."""
        if self._pool:
            await self._pool.release(conn)

    async def health_check(self) -> Dict[str, Any]:
        """Run ``SELECT 1`` and report the outcome.

        Returns:
            ``{"status": "healthy", "timestamp": <UTC ISO string>}`` on success,
            otherwise ``{"status": "unhealthy", "error": <message>}``. Never
            raises.
        """
        try:
            conn = await self.get_connection()
            try:
                result = await conn.fetchval("SELECT 1")
            finally:
                # Always return the connection: previously a failing probe
                # skipped the release and leaked one pool slot per failure.
                await self.release_connection(conn)

            if result == 1:
                return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}
            else:
                return {"status": "unhealthy", "error": "Health check query failed"}
        except Exception as e:
            logger.error("Database health check failed", extra={"error": str(e)})
            return {"status": "unhealthy", "error": str(e)}

    @asynccontextmanager
    async def transaction(self):
        """Yield a pooled connection inside an asyncpg transaction.

        The transaction commits when the block exits normally and rolls back if
        it raises; the connection is released to the pool in both cases.
        """
        conn = await self.get_connection()
        try:
            async with conn.transaction():
                yield conn
        finally:
            await self.release_connection(conn)
|
|
|
|
|
|
# Global database connection instance shared by the module-level helpers
# (get_db_connection, shutdown_database). Constructing it only reads the DB_*
# environment variables; the pool itself is opened lazily on first use.
_db_connection = DatabaseConnection()
|
|
|
|
|
|
@asynccontextmanager
async def get_db_connection():
    """Borrow one connection from the module-wide pool for an ``async with`` block.

    The connection goes back to the pool when the block exits, regardless of
    whether it exited normally or by raising.
    """
    pooled = await _db_connection.get_connection()
    try:
        yield pooled
    finally:
        # Release happens even when the caller's block raised.
        await _db_connection.release_connection(pooled)
|
|
|
|
|
|
async def init_database() -> None:
    """Create the ``sessions`` schema objects if they are missing.

    Creates, in order:
      1. the ``sessions`` table,
      2. secondary indexes on status, last_accessed, created_at and
         container_name,
      3. the ``cleanup_expired_sessions()`` PL/pgSQL function, which deletes
         rows whose ``last_accessed`` is older than one hour and returns the
         number of rows removed.

    Each statement uses ``IF NOT EXISTS`` / ``CREATE OR REPLACE``, so re-running
    it against an initialized database is intended to be harmless. This
    function does not call ``run_migrations()``.
    """
    logger.info("Initializing database")

    ddl_statements = (
        # 1. Sessions table.
        """
        CREATE TABLE IF NOT EXISTS sessions (
            session_id VARCHAR(32) PRIMARY KEY,
            container_name VARCHAR(255) NOT NULL,
            container_id VARCHAR(255),
            host_dir VARCHAR(1024) NOT NULL,
            port INTEGER,
            auth_token VARCHAR(255),
            created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
            last_accessed TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
            status VARCHAR(50) NOT NULL DEFAULT 'creating',
            metadata JSONB DEFAULT '{}'
        )
        """,
        # 2. Indexes for the columns queried/ordered on elsewhere in this module.
        """
        CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
        CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
        CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
        CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name);
        """,
        # 3. Server-side cleanup of sessions idle for more than an hour.
        """
        CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
        RETURNS INTEGER AS $$
        DECLARE
            deleted_count INTEGER;
        BEGIN
            DELETE FROM sessions
            WHERE last_accessed < NOW() - INTERVAL '1 hour';

            GET DIAGNOSTICS deleted_count = ROW_COUNT;
            RETURN deleted_count;
        END;
        $$ LANGUAGE plpgsql;
        """,
    )

    async with get_db_connection() as conn:
        # Run strictly in sequence: indexes and the function reference the table.
        for statement in ddl_statements:
            await conn.execute(statement)

    logger.info("Database initialized and migrations applied")
|
|
|
|
|
|
async def shutdown_database() -> None:
    """Close the module-wide connection pool (safe to call when never opened)."""
    pool_owner = _db_connection
    await pool_owner.disconnect()
|
|
|
|
|
|
class SessionModel:
    """Static data-access helpers for the ``sessions`` table.

    Every public coroutine borrows one pooled connection via
    ``get_db_connection()`` for the duration of the operation. Rows are
    returned as plain dicts whose ``metadata`` value has been decoded from JSON.
    """

    # Keys update_session() silently skips: they identify the row or record its
    # creation time and must never be rewritten.
    _IMMUTABLE_COLUMNS = frozenset({"session_id", "created_at"})

    # Columns update_session() may write. Column names cannot be bound as query
    # parameters, so any name interpolated into the SQL text must come from this
    # fixed allow-list; otherwise caller-supplied dict keys could inject SQL.
    _UPDATABLE_COLUMNS = frozenset(
        {
            "container_name",
            "container_id",
            "host_dir",
            "port",
            "auth_token",
            "last_accessed",
            "status",
            "metadata",
        }
    )

    @staticmethod
    def _decode_row(row: Any) -> Dict[str, Any]:
        """Convert a result row to a dict with ``metadata`` decoded from JSON.

        Text values are parsed with ``json.loads`` (empty text or NULL become
        ``{}``); a value that is already decoded is kept unchanged.
        """
        session = dict(row)
        metadata = session.get("metadata")
        if metadata is None:
            session["metadata"] = {}
        elif isinstance(metadata, (str, bytes, bytearray)):
            session["metadata"] = json.loads(metadata) if metadata else {}
        return session

    @staticmethod
    def _build_update_query(session_id: str, updates: Dict[str, Any]) -> tuple:
        """Build the parameterized UPDATE for ``update_session``.

        Returns:
            ``(query, values)`` where ``values[0]`` is ``session_id`` (bound to
            ``$1``). ``query`` is ``None`` when no updatable column was given.

        Raises:
            ValueError: if ``updates`` contains a key that is not a known column.
        """
        set_parts: List[str] = []
        values: List[Any] = [session_id]

        for key, value in updates.items():
            if key in SessionModel._IMMUTABLE_COLUMNS:
                continue
            if key not in SessionModel._UPDATABLE_COLUMNS:
                raise ValueError(f"Unknown or non-updatable session column: {key!r}")

            if key == "last_accessed":
                # Always stamped server-side; the supplied value is ignored and
                # no parameter is bound for it.
                set_parts.append("last_accessed = NOW()")
                continue

            values.append(json.dumps(value) if key == "metadata" else value)
            # Derive the placeholder from the number of bound values so that
            # NOW()-only columns can never leave a gap or reuse an index.
            set_parts.append(f"{key} = ${len(values)}")

        if not set_parts:
            return None, values

        query = (
            "UPDATE sessions SET " + ", ".join(set_parts) + " WHERE session_id = $1"
        )
        return query, values

    @staticmethod
    async def create_session(session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a new session row and return it.

        ``session_id``, ``container_name`` and ``host_dir`` are required keys;
        ``status`` defaults to ``"creating"`` and ``metadata`` to ``{}``.

        Returns:
            The inserted row as a dict, with ``metadata`` decoded to match
            ``get_session`` (it was previously returned as raw JSON text).

        Raises:
            KeyError: if a required key is missing from ``session_data``.
            ValueError: if the INSERT returned no row.
        """
        async with get_db_connection() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO sessions (
                    session_id, container_name, container_id, host_dir, port,
                    auth_token, status, metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_data["session_id"],
                session_data["container_name"],
                session_data.get("container_id"),
                session_data["host_dir"],
                session_data.get("port"),
                session_data.get("auth_token"),
                session_data.get("status", "creating"),
                json.dumps(session_data.get("metadata", {})),
            )

        if row:
            return SessionModel._decode_row(row)
        raise ValueError("Failed to create session")

    @staticmethod
    async def get_session(session_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a session by ID, touching its ``last_accessed`` timestamp.

        Returns:
            The session dict, or ``None`` if no row has that ID.
        """
        async with get_db_connection() as conn:
            # A single UPDATE ... RETURNING both refreshes last_accessed and
            # reads the row.
            row = await conn.fetchrow(
                """
                UPDATE sessions
                SET last_accessed = NOW()
                WHERE session_id = $1
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_id,
            )

        return SessionModel._decode_row(row) if row else None

    @staticmethod
    async def update_session(session_id: str, updates: Dict[str, Any]) -> bool:
        """Update selected fields of a session.

        ``session_id`` and ``created_at`` keys are ignored; ``last_accessed`` is
        set to ``NOW()`` regardless of the supplied value; ``metadata`` is
        JSON-encoded.

        Returns:
            ``True`` if exactly one row was updated, or if there was nothing to
            update; ``False`` otherwise (e.g. unknown session).

        Raises:
            ValueError: if ``updates`` names a column that is not updatable.
        """
        if not updates:
            return True

        # Build (and validate) before borrowing a connection.
        query, values = SessionModel._build_update_query(session_id, updates)
        if query is None:
            return True

        async with get_db_connection() as conn:
            result = await conn.execute(query, *values)
        return result == "UPDATE 1"

    @staticmethod
    async def delete_session(session_id: str) -> bool:
        """Delete a session; return ``True`` if exactly one row was removed."""
        async with get_db_connection() as conn:
            result = await conn.execute(
                "DELETE FROM sessions WHERE session_id = $1", session_id
            )
        return result == "DELETE 1"

    @staticmethod
    async def list_sessions(limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """List sessions, newest first, using LIMIT/OFFSET pagination."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit,
                offset,
            )

        return [SessionModel._decode_row(row) for row in rows]

    @staticmethod
    async def count_sessions() -> int:
        """Return the total number of session rows."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM sessions")

    @staticmethod
    async def cleanup_expired_sessions() -> int:
        """Invoke the ``cleanup_expired_sessions()`` SQL function.

        Returns:
            The count returned by that database function.
        """
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT cleanup_expired_sessions()")

    @staticmethod
    async def get_active_sessions_count() -> int:
        """Return the number of sessions whose status is ``'running'``."""
        async with get_db_connection() as conn:
            return await conn.fetchval(
                """
                SELECT COUNT(*) FROM sessions
                WHERE status = 'running'
                """
            )

    @staticmethod
    async def get_sessions_by_status(status: str) -> List[Dict[str, Any]]:
        """Return all sessions with the given status, newest first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                WHERE status = $1
                ORDER BY created_at DESC
                """,
                status,
            )

        return [SessionModel._decode_row(row) for row in rows]
|
|
|
|
|
|
# Migration utilities
|
|
async def run_migrations():
    """Apply idempotent schema upgrades to an existing ``sessions`` table.

    Each migration step is attempted independently on one pooled connection:
    a failing step is logged as a warning and the remaining steps still run.
    """
    logger.info("Running database migrations")

    async with get_db_connection() as conn:
        # Migration 1: Add metadata column if it doesn't exist
        try:
            await conn.execute(
                "ALTER TABLE sessions ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}'"
            )
            logger.info("Migration: Added metadata column")
        except Exception as e:
            logger.warning("Migration metadata column may already exist: %s", e)

        # Migration 2: Add indexes if they don't exist. This is the same index
        # set init_database() creates; idx_sessions_container_name was
        # previously missing here, so migrated databases lacked it.
        try:
            await conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
                CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
                CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
                CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name);
                """
            )
            logger.info("Migration: Added performance indexes")
        except Exception as e:
            logger.warning("Migration indexes may already exist: %s", e)

    logger.info("Database migrations completed")
|
|
|
|
|
|
# Health monitoring
|
|
async def get_database_stats() -> Dict[str, Any]:
    """Collect session statistics and database size for health monitoring.

    Returns:
        On success a dict with ``total_sessions``, ``active_sessions`` (status
        ``'running'``), ``oldest_session`` / ``newest_session`` (ISO strings of
        the min/max ``created_at``, or ``None`` when the table is empty),
        ``database_size`` (``pg_size_pretty`` text) and ``"status": "healthy"``.
        On any error, ``{"status": "unhealthy", "error": <message>}``. Never
        raises.
    """
    try:
        async with get_db_connection() as conn:
            # One aggregate statement: every figure comes from the same
            # snapshot (separate statements under READ COMMITTED could report
            # active > total) and it costs one round trip instead of four.
            stats = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) AS total_sessions,
                    COUNT(*) FILTER (WHERE status = 'running') AS active_sessions,
                    MIN(created_at) AS oldest_session,
                    MAX(created_at) AS newest_session
                FROM sessions
                """
            )

            # Get database size information
            db_size = await conn.fetchval(
                "SELECT pg_size_pretty(pg_database_size(current_database()))"
            )

        oldest_session = stats["oldest_session"]
        newest_session = stats["newest_session"]
        return {
            "total_sessions": stats["total_sessions"],
            "active_sessions": stats["active_sessions"],
            "oldest_session": oldest_session.isoformat() if oldest_session else None,
            "newest_session": newest_session.isoformat() if newest_session else None,
            "database_size": db_size,
            "status": "healthy",
        }
    except Exception as e:
        logger.error("Failed to get database stats", extra={"error": str(e)})
        return {
            "status": "unhealthy",
            "error": str(e),
        }
|