fixed all remaining issues with the session manager
This commit is contained in:
406
session-manager/database.py
Normal file
406
session-manager/database.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
Database Models and Connection Management for Session Persistence
|
||||
|
||||
Provides PostgreSQL-backed session storage with connection pooling,
|
||||
migrations, and health monitoring for production reliability.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncpg
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseConnection:
    """PostgreSQL connection management with pooling.

    Wraps an asyncpg pool that is created lazily from environment-driven
    configuration. Provides acquire/release helpers, a health check, and a
    transaction context manager.
    """

    def __init__(self):
        # Pool is created lazily by connect() / get_connection().
        self._pool: Optional[asyncpg.Pool] = None
        # Connection settings come from the environment, with defaults
        # suitable for local development.
        self._config = {
            "host": os.getenv("DB_HOST", "localhost"),
            "port": int(os.getenv("DB_PORT", "5432")),
            "user": os.getenv("DB_USER", "lovdata"),
            "password": os.getenv("DB_PASSWORD", "password"),
            "database": os.getenv("DB_NAME", "lovdata_chat"),
            "min_size": int(os.getenv("DB_MIN_CONNECTIONS", "5")),
            "max_size": int(os.getenv("DB_MAX_CONNECTIONS", "20")),
            "max_queries": int(os.getenv("DB_MAX_QUERIES", "50000")),
            "max_inactive_connection_lifetime": float(
                os.getenv("DB_MAX_INACTIVE_LIFETIME", "300.0")
            ),
        }

    async def connect(self) -> None:
        """Establish database connection pool (idempotent).

        Raises:
            Exception: re-raises whatever asyncpg.create_pool() raised after
                logging the failure.
        """
        if self._pool:
            return

        try:
            self._pool = await asyncpg.create_pool(**self._config)
            logger.info(
                "Database connection pool established",
                extra={
                    "host": self._config["host"],
                    "port": self._config["port"],
                    "database": self._config["database"],
                    "min_connections": self._config["min_size"],
                    "max_connections": self._config["max_size"],
                },
            )
        except Exception as e:
            logger.error(
                "Failed to establish database connection", extra={"error": str(e)}
            )
            raise

    async def disconnect(self) -> None:
        """Close database connection pool. Safe to call when not connected."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database connection pool closed")

    async def get_connection(self) -> asyncpg.Connection:
        """Get a database connection from the pool, connecting lazily."""
        if not self._pool:
            await self.connect()
        return await self._pool.acquire()

    async def release_connection(self, conn: asyncpg.Connection) -> None:
        """Release a database connection back to the pool."""
        if self._pool:
            await self._pool.release(conn)

    async def health_check(self) -> Dict[str, Any]:
        """Perform database health check.

        Returns:
            A dict with "status" ("healthy"/"unhealthy") plus an ISO
            timestamp on success or an "error" string on failure.
            Never raises.
        """
        conn = None
        try:
            conn = await self.get_connection()
            result = await conn.fetchval("SELECT 1")
            if result == 1:
                return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}
            return {"status": "unhealthy", "error": "Health check query failed"}
        except Exception as e:
            logger.error("Database health check failed", extra={"error": str(e)})
            return {"status": "unhealthy", "error": str(e)}
        finally:
            # BUG FIX: the connection was previously leaked when fetchval()
            # raised (release only happened on the success path); always
            # return it to the pool.
            if conn is not None:
                await self.release_connection(conn)

    @asynccontextmanager
    async def transaction(self):
        """Async context manager yielding a pooled connection inside a
        transaction; the connection is always released on exit."""
        conn = await self.get_connection()
        try:
            async with conn.transaction():
                yield conn
        finally:
            await self.release_connection(conn)
|
||||
|
||||
|
||||
# Global database connection instance, shared by the module-level helpers
# (get_db_connection, init_database, shutdown_database). The pool itself is
# created lazily on first use, so constructing this at import time is cheap.
_db_connection = DatabaseConnection()
|
||||
|
||||
|
||||
@asynccontextmanager
async def get_db_connection():
    """Yield a connection from the module-level pool.

    The connection is released back to the pool on exit, even if the body
    raises.
    """
    connection = await _db_connection.get_connection()
    try:
        yield connection
    finally:
        await _db_connection.release_connection(connection)
|
||||
|
||||
|
||||
async def init_database() -> None:
    """Initialize database and run migrations.

    Creates the sessions table, its performance indexes, and the
    cleanup_expired_sessions() SQL function. All statements are idempotent
    (IF NOT EXISTS / CREATE OR REPLACE), so this is safe to run on every
    startup.
    """
    logger.info("Initializing database")

    # Schema: one row per chat session, keyed by a 32-char session id.
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS sessions (
            session_id VARCHAR(32) PRIMARY KEY,
            container_name VARCHAR(255) NOT NULL,
            container_id VARCHAR(255),
            host_dir VARCHAR(1024) NOT NULL,
            port INTEGER,
            auth_token VARCHAR(255),
            created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
            last_accessed TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
            status VARCHAR(50) NOT NULL DEFAULT 'creating',
            metadata JSONB DEFAULT '{}'
        )
    """

    # Indexes covering the common lookup/ordering columns.
    create_indexes_sql = """
        CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
        CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
        CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
        CREATE INDEX IF NOT EXISTS idx_sessions_container_name ON sessions(container_name);
    """

    # Server-side cleanup: deletes sessions idle for more than one hour and
    # reports how many rows were removed.
    create_cleanup_fn_sql = """
        CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
        RETURNS INTEGER AS $$
        DECLARE
            deleted_count INTEGER;
        BEGIN
            DELETE FROM sessions
            WHERE last_accessed < NOW() - INTERVAL '1 hour';

            GET DIAGNOSTICS deleted_count = ROW_COUNT;
            RETURN deleted_count;
        END;
        $$ LANGUAGE plpgsql;
    """

    async with get_db_connection() as conn:
        await conn.execute(create_table_sql)
        await conn.execute(create_indexes_sql)
        await conn.execute(create_cleanup_fn_sql)

    logger.info("Database initialized and migrations applied")
|
||||
|
||||
|
||||
async def shutdown_database() -> None:
    """Shutdown database connections.

    Closes the module-level connection pool; safe to call when no pool has
    been created.
    """
    await _db_connection.disconnect()
|
||||
|
||||
|
||||
class SessionModel:
    """CRUD helpers for rows in the ``sessions`` table.

    All methods are static and borrow a pooled connection per call.
    Returned rows are plain dicts with the ``metadata`` JSONB column decoded
    to a Python dict.
    """

    # Columns callers may modify via update_session(). Keys not in this set
    # are silently ignored, which both protects immutable columns
    # (session_id, created_at) and prevents SQL injection through
    # attacker-controlled dict keys, since column names are interpolated
    # into the query text.
    _UPDATABLE_COLUMNS = frozenset(
        {
            "container_name",
            "container_id",
            "host_dir",
            "port",
            "auth_token",
            "status",
            "metadata",
            "last_accessed",
        }
    )

    @staticmethod
    def _row_to_session(row) -> Dict[str, Any]:
        """Convert an asyncpg Record to a dict with metadata decoded."""
        session = dict(row)
        session["metadata"] = json.loads(session["metadata"] or "{}")
        return session

    @staticmethod
    async def create_session(session_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a new session in the database.

        Args:
            session_data: must contain "session_id", "container_name" and
                "host_dir"; "container_id", "port", "auth_token", "status"
                and "metadata" are optional.

        Returns:
            The inserted row as a dict (metadata decoded, consistent with
            get_session()).

        Raises:
            ValueError: if the INSERT returned no row.
        """
        async with get_db_connection() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO sessions (
                    session_id, container_name, container_id, host_dir, port,
                    auth_token, status, metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_data["session_id"],
                session_data["container_name"],
                session_data.get("container_id"),
                session_data["host_dir"],
                session_data.get("port"),
                session_data.get("auth_token"),
                session_data.get("status", "creating"),
                json.dumps(session_data.get("metadata", {})),
            )

            if row:
                # CONSISTENCY FIX: decode metadata like get_session() and
                # list_sessions() do, instead of returning the raw JSON text.
                return SessionModel._row_to_session(row)
            raise ValueError("Failed to create session")

    @staticmethod
    async def get_session(session_id: str) -> Optional[Dict[str, Any]]:
        """Get a session by ID, bumping last_accessed as a side effect.

        Returns None when no such session exists.
        """
        async with get_db_connection() as conn:
            # Update last_accessed timestamp and fetch the row in one round
            # trip via RETURNING.
            row = await conn.fetchrow(
                """
                UPDATE sessions
                SET last_accessed = NOW()
                WHERE session_id = $1
                RETURNING session_id, container_name, container_id, host_dir, port,
                          auth_token, created_at, last_accessed, status, metadata
                """,
                session_id,
            )

            if row:
                return SessionModel._row_to_session(row)
            return None

    @staticmethod
    async def update_session(session_id: str, updates: Dict[str, Any]) -> bool:
        """Update session fields.

        Args:
            session_id: primary key of the row to update.
            updates: column -> new value. Keys outside _UPDATABLE_COLUMNS
                are ignored; "metadata" is JSON-encoded; "last_accessed"
                is set from the database clock regardless of the value.

        Returns:
            True when exactly one row was updated (or there was nothing
            applicable to update), False otherwise.
        """
        if not updates:
            return True

        async with get_db_connection() as conn:
            # Build a dynamic UPDATE; $1 is reserved for session_id.
            set_parts = []
            values = [session_id]
            param_index = 2

            for key, value in updates.items():
                if key not in SessionModel._UPDATABLE_COLUMNS:
                    # Immutable (session_id, created_at) or unknown column.
                    continue

                if key == "last_accessed":
                    # Timestamp comes from the database clock; no parameter.
                    set_parts.append("last_accessed = NOW()")
                    continue

                if key == "metadata":
                    set_parts.append(f"metadata = ${param_index}")
                    values.append(json.dumps(value))
                else:
                    set_parts.append(f"{key} = ${param_index}")
                    values.append(value)
                # BUG FIX: only advance the placeholder index when a value was
                # actually appended. Previously a "last_accessed" key also
                # incremented the counter, leaving a gap between the $N
                # placeholders and the argument list and making asyncpg fail
                # on any subsequent column.
                param_index += 1

            if not set_parts:
                return True

            query = f"""
                UPDATE sessions
                SET {", ".join(set_parts)}
                WHERE session_id = $1
            """

            result = await conn.execute(query, *values)
            return result == "UPDATE 1"

    @staticmethod
    async def delete_session(session_id: str) -> bool:
        """Delete a session. Returns True if exactly one row was removed."""
        async with get_db_connection() as conn:
            result = await conn.execute(
                "DELETE FROM sessions WHERE session_id = $1", session_id
            )
            return result == "DELETE 1"

    @staticmethod
    async def list_sessions(limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """List sessions with pagination, newest first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit,
                offset,
            )

            return [SessionModel._row_to_session(row) for row in rows]

    @staticmethod
    async def count_sessions() -> int:
        """Count total sessions."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM sessions")

    @staticmethod
    async def cleanup_expired_sessions() -> int:
        """Clean up expired sessions using the server-side SQL function;
        returns the number of rows deleted."""
        async with get_db_connection() as conn:
            return await conn.fetchval("SELECT cleanup_expired_sessions()")

    @staticmethod
    async def get_active_sessions_count() -> int:
        """Get count of active (status = 'running') sessions."""
        async with get_db_connection() as conn:
            return await conn.fetchval("""
                SELECT COUNT(*) FROM sessions
                WHERE status = 'running'
            """)

    @staticmethod
    async def get_sessions_by_status(status: str) -> List[Dict[str, Any]]:
        """Get all sessions with the given status, newest first."""
        async with get_db_connection() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id, container_name, container_id, host_dir, port,
                       auth_token, created_at, last_accessed, status, metadata
                FROM sessions
                WHERE status = $1
                ORDER BY created_at DESC
                """,
                status,
            )

            return [SessionModel._row_to_session(row) for row in rows]
|
||||
|
||||
|
||||
# Migration utilities
async def run_migrations():
    """Run database migrations.

    Each migration is idempotent (IF NOT EXISTS) and failures are logged as
    warnings rather than raised, so reruns against an up-to-date schema are
    harmless.
    """
    logger.info("Running database migrations")

    # Migration 1: Add metadata column if it doesn't exist
    add_metadata_sql = """
        ALTER TABLE sessions ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}'
    """

    # Migration 2: Add indexes if they don't exist
    add_indexes_sql = """
        CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status);
        CREATE INDEX IF NOT EXISTS idx_sessions_last_accessed ON sessions(last_accessed);
        CREATE INDEX IF NOT EXISTS idx_sessions_created_at ON sessions(created_at);
    """

    async with get_db_connection() as conn:
        try:
            await conn.execute(add_metadata_sql)
            logger.info("Migration: Added metadata column")
        except Exception as e:
            logger.warning(f"Migration metadata column may already exist: {e}")

        try:
            await conn.execute(add_indexes_sql)
            logger.info("Migration: Added performance indexes")
        except Exception as e:
            logger.warning(f"Migration indexes may already exist: {e}")

    logger.info("Database migrations completed")
|
||||
|
||||
|
||||
# Health monitoring
async def get_database_stats() -> Dict[str, Any]:
    """Get database statistics and health information.

    Returns:
        On success: counts, oldest/newest session timestamps (ISO strings or
        None), pretty-printed database size, and "status": "healthy".
        On failure: "status": "unhealthy" plus the error text. Never raises.
    """

    def _iso(ts) -> Optional[str]:
        # Timestamps are None when the table is empty.
        return ts.isoformat() if ts else None

    try:
        async with get_db_connection() as conn:
            # Basic row statistics.
            session_count = await conn.fetchval("SELECT COUNT(*) FROM sessions")
            active_sessions = await conn.fetchval(
                "SELECT COUNT(*) FROM sessions WHERE status = 'running'"
            )
            oldest_session = await conn.fetchval("SELECT MIN(created_at) FROM sessions")
            newest_session = await conn.fetchval("SELECT MAX(created_at) FROM sessions")

            # Human-readable size of the current database.
            db_size = await conn.fetchval("""
                SELECT pg_size_pretty(pg_database_size(current_database()))
            """)

            return {
                "total_sessions": session_count,
                "active_sessions": active_sessions,
                "oldest_session": _iso(oldest_session),
                "newest_session": _iso(newest_session),
                "database_size": db_size,
                "status": "healthy",
            }
    except Exception as e:
        logger.error("Failed to get database stats", extra={"error": str(e)})
        return {
            "status": "unhealthy",
            "error": str(e),
        }
|
||||
Reference in New Issue
Block a user