#!/usr/bin/env python3
"""
Database Persistence Test Script

Tests the PostgreSQL-backed session storage system for reliability,
performance, and multi-instance deployment support.

Each test prints its own progress and returns True only if every one of its
checks passed. Tests that insert rows delete them again in a ``finally``
block, so an aborted run does not leave fixed-ID sessions behind that would
collide with the next run. When executed as a script, the process exits with
status 0 if every test passed and 1 otherwise, so it can gate CI jobs.
"""

import os
import sys
import asyncio
import json
from pathlib import Path

# Add session-manager to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from database import (
    DatabaseConnection,
    SessionModel,
    get_database_stats,
    init_database,
    run_migrations,
    _db_connection,
)

# Set up logging
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def _cleanup_sessions(session_ids):
    """Delete test sessions by ID; log, never raise, individual failures.

    Called before a test (to clear leftovers from an earlier aborted run) and
    from ``finally`` blocks after it, so it must not mask the test's own
    outcome by raising.
    """
    results = await asyncio.gather(
        *(SessionModel.delete_session(session_id) for session_id in session_ids),
        return_exceptions=True,
    )
    for session_id, result in zip(session_ids, results):
        if isinstance(result, Exception):
            logger.warning("Cleanup of test session %s failed: %s", session_id, result)


def _count_successes(results, is_success=lambda result: result is not None):
    """Count results of ``asyncio.gather(..., return_exceptions=True)`` that succeeded.

    A result counts when it is not an exception and ``is_success(result)`` is
    true. Exceptions are logged so a partial failure remains diagnosable.
    """
    count = 0
    for result in results:
        if isinstance(result, BaseException):
            logger.warning("Concurrent operation raised: %r", result)
        elif is_success(result):
            count += 1
    return count


async def test_database_connection():
    """Test database connection and basic operations.

    Returns True when ``_db_connection.health_check()`` reports
    ``status == "healthy"``; False for any other status or on an exception.
    """
    print("🗄️ Testing Database Connection")
    print("=" * 50)

    # NOTE(review): this runs before test_database_schema() calls
    # init_database(). If init_database() is what opens the connection pool,
    # or if it rebinds database._db_connection (making the name imported above
    # stale), this check could fail spuriously — confirm against database.py.
    try:
        health = await _db_connection.health_check()
        if health.get("status") == "healthy":
            print("✅ Database connection established")
            return True
        print(f"❌ Database connection failed: {health}")
        return False
    except Exception as e:
        print(f"❌ Database connection error: {e}")
        return False


async def test_database_schema():
    """Test database schema creation and migrations.

    Runs ``init_database()`` and ``run_migrations()``, then checks that a
    ``sessions`` table exists and that every expected index is present.
    Returns False if the table or any expected index is missing.
    """
    print("\n📋 Testing Database Schema")
    print("=" * 50)

    expected_indexes = [
        "sessions_pkey",
        "idx_sessions_status",
        "idx_sessions_last_accessed",
        "idx_sessions_created_at",
    ]

    try:
        # Initialize database and run migrations
        await init_database()
        await run_migrations()

        # Verify schema by checking if we can query the table
        async with _db_connection.get_connection() as conn:
            table_exists = await conn.fetchval("""
                SELECT EXISTS (
                    SELECT 1 FROM information_schema.tables
                    WHERE table_name = 'sessions'
                )
            """)
            if not table_exists:
                print("❌ Sessions table not found")
                return False
            print("✅ Database schema created successfully")

            indexes = await conn.fetch("""
                SELECT indexname FROM pg_indexes
                WHERE tablename = 'sessions'
            """)

        index_names = [row["indexname"] for row in indexes]
        missing = []
        for expected in expected_indexes:
            # Substring match: accepts any index whose name contains the
            # expected name.
            if any(expected in name for name in index_names):
                print(f"✅ Index {expected} exists")
            else:
                print(f"❌ Index {expected} missing")
                missing.append(expected)

        # A missing index is a schema defect, so it fails the test.
        return not missing
    except Exception as e:
        print(f"❌ Schema creation failed: {e}")
        return False


async def test_session_crud():
    """Test session create, read, update, delete operations.

    Creation and retrieval failures abort the test immediately. Metadata,
    update and deletion checks are all run, and any failure among them also
    makes the test return False. The test row is removed in ``finally`` if it
    was not already confirmed deleted, so an exception cannot leave it behind.
    """
    print("\n🔄 Testing Session CRUD Operations")
    print("=" * 50)

    test_session = {
        "session_id": "test-session-db-123",
        "container_name": "test-container-123",
        "host_dir": "/tmp/test-session",
        "port": 8081,
        "auth_token": "test-token-abc123",
        "status": "creating",
        "metadata": {"test": True, "created_by": "test_script"},
    }
    session_id = test_session["session_id"]
    all_checks_passed = True
    deletion_confirmed = False

    try:
        # Remove a leftover row from an earlier aborted run so creation does
        # not collide with it (session_id is presumably the primary key).
        await _cleanup_sessions([session_id])

        # Create session
        created = await SessionModel.create_session(test_session)
        if not (created and created["session_id"] == session_id):
            print("❌ Session creation failed")
            return False
        print("✅ Session created successfully")

        # Read session
        retrieved = await SessionModel.get_session(session_id)
        if not (retrieved and retrieved["session_id"] == session_id):
            print("❌ Session retrieval failed")
            return False
        print("✅ Session retrieved successfully")

        # Verify metadata; `or {}` also tolerates a stored NULL metadata value.
        if (retrieved.get("metadata") or {}).get("test"):
            print("✅ Session metadata preserved")
        else:
            print("❌ Session metadata missing")
            all_checks_passed = False

        # Update session
        updates = {"status": "running", "container_id": "container-abc123"}
        if await SessionModel.update_session(session_id, updates):
            print("✅ Session updated successfully")
            updated_session = await SessionModel.get_session(session_id)
            if (
                updated_session is not None
                and updated_session.get("status") == "running"
                and updated_session.get("container_id") == "container-abc123"
            ):
                print("✅ Session update verified")
            else:
                print("❌ Session update not reflected")
                all_checks_passed = False
        else:
            print("❌ Session update failed")
            all_checks_passed = False

        # Delete session
        if await SessionModel.delete_session(session_id):
            print("✅ Session deleted successfully")
            if await SessionModel.get_session(session_id) is None:
                print("✅ Session deletion verified")
                deletion_confirmed = True
            else:
                print("❌ Session still exists after deletion")
                all_checks_passed = False
        else:
            print("❌ Session deletion failed")
            all_checks_passed = False

        return all_checks_passed
    except Exception as e:
        print(f"❌ CRUD operation failed: {e}")
        return False
    finally:
        if not deletion_confirmed:
            await _cleanup_sessions([session_id])


async def test_concurrent_sessions():
    """Test handling multiple concurrent sessions.

    Creates, reads and updates five sessions via ``asyncio.gather``. Each
    phase uses ``return_exceptions=True`` so a failing operation is counted
    as a failure rather than aborting the phase while its sibling tasks keep
    running unobserved. Passes only if every operation in every phase
    succeeded; the sessions are always deleted afterwards.
    """
    print("\n👥 Testing Concurrent Sessions")
    print("=" * 50)

    concurrent_sessions = [
        {
            "session_id": f"concurrent-session-{i}",
            "container_name": f"container-{i}",
            "host_dir": f"/tmp/session-{i}",
            "port": 8080 + i,
            "auth_token": f"token-{i}",
            "status": "creating",
        }
        for i in range(5)
    ]
    session_ids = [session["session_id"] for session in concurrent_sessions]
    total = len(concurrent_sessions)

    try:
        # Clear leftovers from an earlier aborted run.
        await _cleanup_sessions(session_ids)

        # Create sessions concurrently
        created_sessions = await asyncio.gather(
            *(SessionModel.create_session(s) for s in concurrent_sessions),
            return_exceptions=True,
        )
        successful_creates = _count_successes(created_sessions)
        mark = "✅" if successful_creates == total else "❌"
        print(f"{mark} Created {successful_creates}/{total} concurrent sessions")

        # Retrieve sessions concurrently
        retrieved_sessions = await asyncio.gather(
            *(SessionModel.get_session(sid) for sid in session_ids),
            return_exceptions=True,
        )
        successful_retrieves = _count_successes(retrieved_sessions)
        mark = "✅" if successful_retrieves == total else "❌"
        print(f"{mark} Retrieved {successful_retrieves}/{total} concurrent sessions")

        # Update sessions concurrently; update_session reports success via
        # its truthiness, so count truthy results.
        update_results = await asyncio.gather(
            *(
                SessionModel.update_session(sid, {"status": "running"})
                for sid in session_ids
            ),
            return_exceptions=True,
        )
        successful_updates = _count_successes(update_results, bool)
        mark = "✅" if successful_updates == total else "❌"
        print(f"{mark} Updated {successful_updates}/{total} concurrent sessions")

        print("✅ Concurrent session operations completed")
        return all(
            count == total
            for count in (successful_creates, successful_retrieves, successful_updates)
        )
    except Exception as e:
        print(f"❌ Concurrent operations failed: {e}")
        return False
    finally:
        await _cleanup_sessions(session_ids)


async def test_database_performance():
    """Test database performance and statistics.

    Prints the figures from ``get_database_stats()`` and exercises the
    session-count queries. No timings are measured. A non-"healthy" status is
    only a warning; the test fails if the stats are not a dict or a query
    raises.
    """
    print("\n⚡ Testing Database Performance")
    print("=" * 50)

    try:
        # Get database statistics
        stats = await get_database_stats()
        if not isinstance(stats, dict):
            print("❌ Database statistics not available")
            return False

        print("✅ Database statistics retrieved")
        print(f"   Total sessions: {stats.get('total_sessions', 'N/A')}")
        print(f"   Active sessions: {stats.get('active_sessions', 'N/A')}")
        print(f"   Database size: {stats.get('database_size', 'N/A')}")

        if stats.get("status") == "healthy":
            print("✅ Database health check passed")
        else:
            print(f"⚠️ Database health status: {stats.get('status')}")

        # Test session counting
        count = await SessionModel.count_sessions()
        print(f"✅ Session count query: {count} sessions")

        # Test active session counting
        active_count = await SessionModel.get_active_sessions_count()
        print(f"✅ Active session count: {active_count} sessions")

        return True
    except Exception as e:
        print(f"❌ Performance testing failed: {e}")
        return False


async def test_session_queries():
    """Test various session query operations.

    Inserts four sessions with known statuses, then checks that
    ``get_sessions_by_status`` returns exactly the expected number of *this
    test's* sessions for each status. Results are restricted to the test's own
    session IDs, so unrelated rows already in the database cannot skew the
    comparison. Returns False on a failed insert or any count mismatch; the
    test sessions are always deleted afterwards.
    """
    print("\n🔍 Testing Session Queries")
    print("=" * 50)

    # Create test sessions with different statuses
    test_sessions = [
        {
            "session_id": "query-test-1",
            "container_name": "container-1",
            "status": "creating",
        },
        {
            "session_id": "query-test-2",
            "container_name": "container-2",
            "status": "running",
        },
        {
            "session_id": "query-test-3",
            "container_name": "container-3",
            "status": "running",
        },
        {
            "session_id": "query-test-4",
            "container_name": "container-4",
            "status": "stopped",
        },
    ]
    session_ids = [session["session_id"] for session in test_sessions]
    own_ids = set(session_ids)

    try:
        # Clear leftovers from an earlier aborted run.
        await _cleanup_sessions(session_ids)

        # Create test sessions; a silent failure would skew the counts below.
        for session in test_sessions:
            if not await SessionModel.create_session(session):
                print(f"❌ Failed to create test session {session['session_id']}")
                return False

        # Test list sessions
        all_sessions = await SessionModel.list_sessions(limit=10)
        print(f"✅ Listed {len(all_sessions)} sessions")

        # Test filter by status
        filtering_accurate = True
        for status in ("running", "creating"):
            matches = await SessionModel.get_sessions_by_status(status)
            # Assumes each returned row supports ["session_id"], like the rows
            # returned by get_session() — TODO confirm against database.py.
            found = sum(1 for row in matches if row["session_id"] in own_ids)
            expected = sum(1 for s in test_sessions if s["status"] == status)
            print(
                f"✅ Found {len(matches)} {status} sessions "
                f"({found} created by this test)"
            )
            if found == expected:
                print(f"✅ Status filtering accurate for '{status}'")
            else:
                print(
                    f"❌ Status filtering incorrect for '{status}': "
                    f"expected {expected}, got {found}"
                )
                filtering_accurate = False

        print("✅ Query testing completed")
        return filtering_accurate
    except Exception as e:
        print(f"❌ Query testing failed: {e}")
        return False
    finally:
        await _cleanup_sessions(session_ids)


async def run_all_database_tests():
    """Run all database persistence tests.

    Runs each test in order, treating an escaped exception as a failure,
    prints a summary, and returns True only if every test passed.
    """
    print("💾 Database Persistence Test Suite")
    print("=" * 70)

    tests = [
        ("Database Connection", test_database_connection),
        ("Database Schema", test_database_schema),
        ("Session CRUD", test_session_crud),
        ("Concurrent Sessions", test_concurrent_sessions),
        ("Database Performance", test_database_performance),
        ("Session Queries", test_session_queries),
    ]

    results = []
    for test_name, test_func in tests:
        print(f"\n{'=' * 25} {test_name} {'=' * 25}")
        try:
            result = bool(await test_func())
        except Exception as e:
            print(f"\n❌ ERROR in {test_name}: {e}")
            result = False
        else:
            status = "✅ PASSED" if result else "❌ FAILED"
            print(f"\n{status}: {test_name}")
        results.append(result)

    # Summary
    print(f"\n{'=' * 70}")
    passed = sum(results)
    total = len(results)
    print(f"📊 Test Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All database persistence tests completed successfully!")
        print("💾 PostgreSQL backend provides reliable session storage.")
    else:
        print("⚠️ Some tests failed. Check the output above for details.")
        print("💡 Ensure PostgreSQL is running and connection settings are correct.")
        print("   Required environment variables:")
        print("   - DB_HOST (default: localhost)")
        print("   - DB_PORT (default: 5432)")
        print("   - DB_USER (default: lovdata)")
        print("   - DB_PASSWORD (default: password)")
        print("   - DB_NAME (default: lovdata_chat)")

    return passed == total


if __name__ == "__main__":
    # Propagate the suite result as the process exit status (0 = all passed)
    # so shells and CI can detect failures.
    sys.exit(0 if asyncio.run(run_all_database_tests()) else 1)