Files
lovdata-chat/docker/scripts/test-database-persistence.py
2026-01-18 23:29:04 +01:00

407 lines
13 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Database Persistence Test Script
Tests the PostgreSQL-backed session storage system for reliability,
performance, and multi-instance deployment support.

Each ``test_*`` coroutine prints its own progress lines and returns
True/False; ``run_all_database_tests`` awaits them in order and prints a
pass/fail summary.
"""
import os
import sys
import asyncio
import json
from pathlib import Path
# Put the directory containing this script at the front of sys.path so the
# sibling ``database`` module below can be imported.
# NOTE(review): this comment previously said "Add session-manager to path",
# but the code inserts this file's own directory, not a session-manager
# directory. Confirm where database.py lives and align comment or path.
sys.path.insert(0, str(Path(__file__).parent))
from database import (
    DatabaseConnection,
    SessionModel,
    get_database_stats,
    init_database,
    run_migrations,
    _db_connection,
)
# NOTE(review): ``_db_connection`` is a private name bound here once, at
# import time. If the ``database`` module rebinds its own module-level
# ``_db_connection`` later (e.g. inside init_database), this script keeps
# the old object - confirm against database.py.
# ``os``, ``json`` and ``DatabaseConnection`` are not referenced anywhere in
# the code of this script as shown.
# Set up logging
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def test_database_connection():
    """Check that the shared database connection reports itself healthy.

    Prints a section banner followed by a single ✅/❌ line.

    Returns:
        True when ``_db_connection.health_check()`` yields a mapping whose
        ``"status"`` is ``"healthy"``; False when it reports anything else
        or when the check (or inspecting its result) raises.
    """
    print("🗄️ Testing Database Connection")
    print("=" * 50)
    # NOTE(review): the suite runs this before test_database_schema calls
    # init_database(); confirm health_check() is usable at that point.
    try:
        health = await _db_connection.health_check()
        if health.get("status") != "healthy":
            print(f"❌ Database connection failed: {health}")
            return False
        print("✅ Database connection established")
        return True
    except Exception as err:
        print(f"❌ Database connection error: {err}")
        return False
async def test_database_schema():
    """Initialise the database, run migrations and verify the sessions schema.

    Checks that a ``sessions`` table exists and that each expected index is
    present. An index counts as present when its expected name is a
    substring of some ``pg_indexes.indexname`` for the table.

    Returns:
        True only if the table exists AND every expected index was found.
        False if the table or any index is missing, or if any step raises.
        (Previously a missing index was printed as ❌ but the test still
        passed.)
    """
    print("\n📋 Testing Database Schema")
    print("=" * 50)
    expected_indexes = [
        "sessions_pkey",
        "idx_sessions_status",
        "idx_sessions_last_accessed",
        "idx_sessions_created_at",
    ]
    try:
        # Initialize database and run migrations
        await init_database()
        await run_migrations()
        # Verify schema by checking if we can query the table
        async with _db_connection.get_connection() as conn:
            table_exists = await conn.fetchval(
                """
                SELECT EXISTS (
                    SELECT 1 FROM information_schema.tables
                    WHERE table_name = 'sessions'
                )
                """
            )
            if not table_exists:
                print("❌ Sessions table not found")
                return False
            print("✅ Database schema created successfully")
            index_rows = await conn.fetch(
                """
                SELECT indexname FROM pg_indexes
                WHERE tablename = 'sessions'
                """
            )
            index_names = [row["indexname"] for row in index_rows]

        missing = []
        for expected in expected_indexes:
            if any(expected in name for name in index_names):
                print(f"✅ Index {expected} exists")
            else:
                print(f"❌ Index {expected} missing")
                missing.append(expected)
        # A missing index is a schema failure, not just a printed warning.
        return not missing
    except Exception as e:
        print(f"❌ Schema creation failed: {e}")
        return False
async def test_session_crud():
    """Exercise create, read, update and delete for a single session.

    Create and read failures abort the test immediately, because later
    steps depend on them. The remaining checks (metadata, update,
    update-verification, delete, delete-verification) are recorded, and the
    test keeps going so every problem is reported in one run.

    Unless the delete step succeeded, the test row is removed again in
    ``finally`` (best effort, failures only logged). That way a failed run
    does not leave a row behind that could collide with the next run's
    create.

    Returns:
        True only if every step and verification succeeded. False
        otherwise, including when any call raises. (Previously most failed
        checks were printed as ❌ but the test still returned True.)
    """
    print("\n🔄 Testing Session CRUD Operations")
    print("=" * 50)
    test_session = {
        "session_id": "test-session-db-123",
        "container_name": "test-container-123",
        "host_dir": "/tmp/test-session",
        "port": 8081,
        "auth_token": "test-token-abc123",
        "status": "creating",
        "metadata": {"test": True, "created_by": "test_script"},
    }
    session_id = test_session["session_id"]
    all_passed = True
    deleted = False
    try:
        # Create session
        created = await SessionModel.create_session(test_session)
        if not created or created["session_id"] != session_id:
            print("❌ Session creation failed")
            return False
        print("✅ Session created successfully")

        # Read session
        retrieved = await SessionModel.get_session(session_id)
        if not retrieved or retrieved["session_id"] != session_id:
            print("❌ Session retrieval failed")
            return False
        print("✅ Session retrieved successfully")
        # `or {}` also covers a metadata value stored as NULL/None.
        # NOTE(review): if metadata comes back as JSON text rather than a
        # dict, .get() raises and the test fails via the except below.
        if (retrieved.get("metadata") or {}).get("test"):
            print("✅ Session metadata preserved")
        else:
            print("❌ Session metadata missing")
            all_passed = False

        # Update session
        updates = {"status": "running", "container_id": "container-abc123"}
        if await SessionModel.update_session(session_id, updates):
            print("✅ Session updated successfully")
            updated_session = await SessionModel.get_session(session_id)
            if (
                updated_session is not None
                and updated_session.get("status") == "running"
                and updated_session.get("container_id") == "container-abc123"
            ):
                print("✅ Session update verified")
            else:
                print("❌ Session update not reflected")
                all_passed = False
        else:
            print("❌ Session update failed")
            all_passed = False

        # Delete session
        deleted = await SessionModel.delete_session(session_id)
        if deleted:
            print("✅ Session deleted successfully")
            if await SessionModel.get_session(session_id) is None:
                print("✅ Session deletion verified")
            else:
                print("❌ Session still exists after deletion")
                all_passed = False
        else:
            print("❌ Session deletion failed")
            all_passed = False
        return all_passed
    except Exception as e:
        print(f"❌ CRUD operation failed: {e}")
        return False
    finally:
        if not deleted:
            # Best-effort cleanup so a failed run does not leave state behind.
            try:
                await SessionModel.delete_session(session_id)
            except Exception as cleanup_err:
                logger.warning(
                    "Cleanup of session %s failed: %s", session_id, cleanup_err
                )
async def test_concurrent_sessions():
    """Create, read and update several sessions concurrently, then clean up.

    Five sessions go through each phase (create, retrieve, update) with
    ``asyncio.gather(..., return_exceptions=True)``. A call that raises is
    printed and counted as a failure instead of aborting the phase. As a
    result, every call in a phase has finished before the next phase or the
    cleanup begins; none is left running unawaited.

    Deletion of all five sessions always runs in ``finally``, even if a
    phase failed, so reruns start from a clean table. Cleanup errors are
    logged, not raised.

    Returns:
        True only if every create, retrieve and update succeeded.
    """
    print("\n👥 Testing Concurrent Sessions")
    print("=" * 50)
    concurrent_sessions = [
        {
            "session_id": f"concurrent-session-{i}",
            "container_name": f"container-{i}",
            "host_dir": f"/tmp/session-{i}",
            "port": 8080 + i,
            "auth_token": f"token-{i}",
            "status": "creating",
        }
        for i in range(5)
    ]
    session_ids = [s["session_id"] for s in concurrent_sessions]
    total = len(concurrent_sessions)

    def _report_errors(phase, results):
        # Print every exception that gather() collected for this phase.
        for result in results:
            if isinstance(result, BaseException):
                print(f"❌ Concurrent {phase} error: {result}")

    def _succeeded(result):
        # A result counts as success when it is neither None nor an exception.
        return result is not None and not isinstance(result, BaseException)

    try:
        # Create sessions concurrently
        created = await asyncio.gather(
            *(SessionModel.create_session(s) for s in concurrent_sessions),
            return_exceptions=True,
        )
        _report_errors("create", created)
        successful_creates = sum(1 for r in created if _succeeded(r))
        print(f"✅ Created {successful_creates}/{total} concurrent sessions")

        # Retrieve sessions concurrently
        retrieved = await asyncio.gather(
            *(SessionModel.get_session(sid) for sid in session_ids),
            return_exceptions=True,
        )
        _report_errors("retrieve", retrieved)
        successful_retrieves = sum(1 for r in retrieved if _succeeded(r))
        print(f"✅ Retrieved {successful_retrieves}/{total} concurrent sessions")

        # Update sessions concurrently
        updated = await asyncio.gather(
            *(
                SessionModel.update_session(sid, {"status": "running"})
                for sid in session_ids
            ),
            return_exceptions=True,
        )
        _report_errors("update", updated)
        successful_updates = sum(
            1 for r in updated if r and not isinstance(r, BaseException)
        )
        print(f"✅ Updated {successful_updates}/{total} concurrent sessions")

        print("✅ Concurrent session operations completed")
        return (
            successful_creates == total
            and successful_retrieves == total
            and successful_updates == total
        )
    except Exception as e:
        print(f"❌ Concurrent operations failed: {e}")
        return False
    finally:
        # Clean up unconditionally so leftover rows cannot break later runs.
        cleanup_results = await asyncio.gather(
            *(SessionModel.delete_session(sid) for sid in session_ids),
            return_exceptions=True,
        )
        for sid, result in zip(session_ids, cleanup_results):
            if isinstance(result, BaseException):
                logger.warning("Cleanup of session %s failed: %s", sid, result)
async def test_database_performance():
    """Print database statistics and session counts.

    Fails only when ``get_database_stats()`` does not return a dict, or
    when any call raises. A ``"status"`` other than ``"healthy"`` is shown
    as a ⚠️ warning but does not fail the test.

    Returns:
        True if the statistics and both count queries were obtained,
        otherwise False.
    """
    print("\n⚡ Testing Database Performance")
    print("=" * 50)
    try:
        stats = await get_database_stats()
        if not isinstance(stats, dict):
            print("❌ Database statistics not available")
            return False

        print("✅ Database statistics retrieved")
        for label, key in (
            ("Total sessions", "total_sessions"),
            ("Active sessions", "active_sessions"),
            ("Database size", "database_size"),
        ):
            print(f" {label}: {stats.get(key, 'N/A')}")

        health_status = stats.get("status")
        if health_status == "healthy":
            print("✅ Database health check passed")
        else:
            print(f"⚠️ Database health status: {health_status}")

        session_total = await SessionModel.count_sessions()
        print(f"✅ Session count query: {session_total} sessions")
        active_total = await SessionModel.get_active_sessions_count()
        print(f"✅ Active session count: {active_total} sessions")
        return True
    except Exception as err:
        print(f"❌ Performance testing failed: {err}")
        return False
async def test_session_queries():
    """Exercise the list and status-filter session queries.

    Creates four sessions with known statuses and lists sessions. It then
    checks that ``get_sessions_by_status`` returns the right number of this
    test's sessions for both "running" and "creating". Only rows whose
    ``session_id`` belongs to this test are counted. Other sessions already
    in the shared table (real traffic, other instances) therefore cannot
    skew the comparison.

    The test sessions are deleted in ``finally`` (best effort, failures
    only logged) whether or not the checks pass.

    Returns:
        True if every create succeeded and both status filters matched.
        False otherwise, including when any call raises. (Previously an
        incorrect filter was printed as ❌ but the test still returned
        True.)
    """
    print("\n🔍 Testing Session Queries")
    print("=" * 50)
    # Create test sessions with different statuses
    test_sessions = [
        {
            "session_id": "query-test-1",
            "container_name": "container-1",
            "status": "creating",
        },
        {
            "session_id": "query-test-2",
            "container_name": "container-2",
            "status": "running",
        },
        {
            "session_id": "query-test-3",
            "container_name": "container-3",
            "status": "running",
        },
        {
            "session_id": "query-test-4",
            "container_name": "container-4",
            "status": "stopped",
        },
    ]
    test_ids = {s["session_id"] for s in test_sessions}
    try:
        for session in test_sessions:
            if not await SessionModel.create_session(session):
                print(f"❌ Failed to create test session {session['session_id']}")
                return False

        all_sessions = await SessionModel.list_sessions(limit=10)
        print(f"✅ Listed {len(all_sessions)} sessions")

        by_status = {}
        for status in ("running", "creating"):
            by_status[status] = await SessionModel.get_sessions_by_status(status)
            print(f"✅ Found {len(by_status[status])} {status} sessions")

        filtering_ok = True
        for status, found in by_status.items():
            expected = sum(1 for s in test_sessions if s["status"] == status)
            # Count only this test's rows; the table may hold other sessions.
            # NOTE(review): assumes each returned row is subscriptable by
            # "session_id", like get_session()'s result - confirm.
            actual = sum(1 for row in found if row["session_id"] in test_ids)
            if actual != expected:
                print(
                    f"❌ Status filtering incorrect for {status!r}: "
                    f"expected {expected}, got {actual}"
                )
                filtering_ok = False
        if filtering_ok:
            print("✅ Status filtering accurate")
        print("✅ Query testing completed")
        return filtering_ok
    except Exception as e:
        print(f"❌ Query testing failed: {e}")
        return False
    finally:
        # Clean up test sessions even when a step above failed or raised.
        for session in test_sessions:
            try:
                await SessionModel.delete_session(session["session_id"])
            except Exception as cleanup_err:
                logger.warning(
                    "Cleanup of session %s failed: %s",
                    session["session_id"],
                    cleanup_err,
                )
async def run_all_database_tests():
    """Run the database persistence tests in sequence and print a summary.

    Each test coroutine is awaited in turn. An exception escaping a test is
    reported as an ERROR and counted as a failure, and the remaining tests
    still run. If anything failed, connection-troubleshooting hints are
    printed after the summary.

    Returns:
        True if every test returned a truthy result, otherwise False.
    """
    print("💾 Database Persistence Test Suite")
    print("=" * 70)

    suite = (
        ("Database Connection", test_database_connection),
        ("Database Schema", test_database_schema),
        ("Session CRUD", test_session_crud),
        ("Concurrent Sessions", test_concurrent_sessions),
        ("Database Performance", test_database_performance),
        ("Session Queries", test_session_queries),
    )
    rule = "=" * 25
    outcomes = []
    for label, run_test in suite:
        print(f"\n{rule} {label} {rule}")
        try:
            outcome = await run_test()
            outcomes.append(outcome)
            verdict = "✅ PASSED" if outcome else "❌ FAILED"
            print(f"\n{verdict}: {label}")
        except Exception as exc:
            print(f"\n❌ ERROR in {label}: {exc}")
            outcomes.append(False)

    # Summary
    print("\n" + "=" * 70)
    passed, total = sum(outcomes), len(outcomes)
    print(f"📊 Test Results: {passed}/{total} tests passed")
    all_passed = passed == total
    if all_passed:
        print("🎉 All database persistence tests completed successfully!")
        print("💾 PostgreSQL backend provides reliable session storage.")
    else:
        # NOTE(review): these defaults are only stated here; confirm they
        # match what the database module actually falls back to.
        for hint in (
            "⚠️ Some tests failed. Check the output above for details.",
            "💡 Ensure PostgreSQL is running and connection settings are correct.",
            " Required environment variables:",
            " - DB_HOST (default: localhost)",
            " - DB_PORT (default: 5432)",
            " - DB_USER (default: lovdata)",
            " - DB_PASSWORD (default: password)",
            " - DB_NAME (default: lovdata_chat)",
        ):
            print(hint)
    return all_passed
if __name__ == "__main__":
    # Propagate the suite result as the process exit status (0 = all passed,
    # 1 = at least one failure) so shell scripts and CI can detect failures;
    # previously the result was discarded and the script always exited 0.
    sys.exit(0 if asyncio.run(run_all_database_tests()) else 1)