407 lines
13 KiB
Python
Executable File
407 lines
13 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Database Persistence Test Script
|
|
|
|
Tests the PostgreSQL-backed session storage system for reliability,
|
|
performance, and multi-instance deployment support.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import asyncio
|
|
import json
|
|
from pathlib import Path
|
|
|
|
# Add session-manager to path for imports
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
from database import (
|
|
DatabaseConnection,
|
|
SessionModel,
|
|
get_database_stats,
|
|
init_database,
|
|
run_migrations,
|
|
_db_connection,
|
|
)
|
|
|
|
# Configure root logging at INFO level and obtain this module's logger.
import logging

logging.basicConfig(level="INFO")
# NOTE(review): `logger` is not referenced anywhere else in this script.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
async def test_database_connection():
    """Ask the shared database connection for a health report.

    Awaits ``_db_connection.health_check()`` and treats a result whose
    ``"status"`` entry equals ``"healthy"`` as success.

    Returns:
        ``True`` when the reported status is ``"healthy"``; ``False`` when it
        is anything else or when the health check (or reading its result)
        raises.
    """
    print("🗄️ Testing Database Connection")
    print("=" * 50)

    try:
        health = await _db_connection.health_check()
        # Guard clause: report the full health payload when it is not healthy.
        if health.get("status") != "healthy":
            print(f"❌ Database connection failed: {health}")
            return False
        print("✅ Database connection established")
        return True
    except Exception as exc:
        print(f"❌ Database connection error: {exc}")
        return False
|
|
|
|
|
|
async def test_database_schema():
    """Initialise the schema, run migrations, and verify the sessions table.

    Awaits ``init_database()`` and ``run_migrations()``. It then checks
    ``information_schema.tables`` for a ``sessions`` table and ``pg_indexes``
    for each expected index on it.

    Returns:
        ``True`` only when the table exists and every expected index was
        found. ``False`` when the table or any index is missing, or when any
        step raises. (Previously a missing index was reported but still
        counted as a pass.)
    """
    print("\n📋 Testing Database Schema")
    print("=" * 50)

    # An index counts as present when its name *contains* the expected name
    # (substring match), which tolerates prefixes/suffixes on generated names.
    expected_indexes = [
        "sessions_pkey",
        "idx_sessions_status",
        "idx_sessions_last_accessed",
        "idx_sessions_created_at",
    ]

    try:
        await init_database()
        await run_migrations()

        async with _db_connection.get_connection() as conn:
            # NOTE(review): not filtered by schema; a `sessions` table in any
            # schema satisfies this check.
            table_exists = await conn.fetchval(
                """
                SELECT EXISTS (
                    SELECT 1 FROM information_schema.tables
                    WHERE table_name = 'sessions'
                )
                """
            )
            if not table_exists:
                print("❌ Sessions table not found")
                return False
            print("✅ Database schema created successfully")

            indexes = await conn.fetch(
                """
                SELECT indexname FROM pg_indexes
                WHERE tablename = 'sessions'
                """
            )

        index_names = [row["indexname"] for row in indexes]

        missing = []
        for expected in expected_indexes:
            if any(expected in name for name in index_names):
                print(f"✅ Index {expected} exists")
            else:
                print(f"❌ Index {expected} missing")
                missing.append(expected)

        # A missing index is a schema failure, not just a printed warning.
        return not missing

    except Exception as e:
        print(f"❌ Schema creation failed: {e}")
        return False
|
|
|
|
|
|
async def test_session_crud():
    """Create, read, update and delete one session, verifying each step.

    Every step prints a ✅/❌ line. A failed verification (metadata, update
    or deletion) makes the test fail instead of only printing ❌.

    If the delete step was not reached, or did not report success, a
    best-effort delete runs in ``finally``. This stops a leftover row from
    making the next run's create collide.

    Returns:
        ``True`` only when every operation and verification succeeded.
        ``False`` otherwise, including when any call raises.
    """
    print("\n🔄 Testing Session CRUD Operations")
    print("=" * 50)

    test_session = {
        "session_id": "test-session-db-123",
        "container_name": "test-container-123",
        "host_dir": "/tmp/test-session",
        "port": 8081,
        "auth_token": "test-token-abc123",
        "status": "creating",
        "metadata": {"test": True, "created_by": "test_script"},
    }
    session_id = test_session["session_id"]
    all_ok = True
    deleted = False

    try:
        # Create
        created = await SessionModel.create_session(test_session)
        if not (created and created["session_id"] == session_id):
            print("❌ Session creation failed")
            return False
        print("✅ Session created successfully")

        # Read
        retrieved = await SessionModel.get_session(session_id)
        if not (retrieved and retrieved["session_id"] == session_id):
            print("❌ Session retrieval failed")
            return False
        print("✅ Session retrieved successfully")

        # `or {}` covers a stored NULL (None) as well as a missing key.
        metadata = retrieved.get("metadata") or {}
        if isinstance(metadata, (str, bytes)):
            # Some drivers (e.g. asyncpg without a JSON codec) return
            # json/jsonb columns as text; decode before inspecting.
            metadata = json.loads(metadata)
        if metadata.get("test"):
            print("✅ Session metadata preserved")
        else:
            print("❌ Session metadata missing")
            all_ok = False

        # Update
        updates = {"status": "running", "container_id": "container-abc123"}
        updated = await SessionModel.update_session(session_id, updates)
        if updated:
            print("✅ Session updated successfully")

            updated_session = await SessionModel.get_session(session_id)
            if (
                updated_session
                and updated_session.get("status") == "running"
                and updated_session.get("container_id") == "container-abc123"
            ):
                print("✅ Session update verified")
            else:
                print("❌ Session update not reflected")
                all_ok = False
        else:
            print("❌ Session update failed")
            all_ok = False

        # Delete
        deleted = await SessionModel.delete_session(session_id)
        if deleted:
            print("✅ Session deleted successfully")

            if await SessionModel.get_session(session_id) is None:
                print("✅ Session deletion verified")
            else:
                print("❌ Session still exists after deletion")
                all_ok = False
        else:
            print("❌ Session deletion failed")
            all_ok = False

        return all_ok

    except Exception as e:
        print(f"❌ CRUD operation failed: {e}")
        return False

    finally:
        if not deleted:
            # Best-effort cleanup. Report a failure here rather than raise, so
            # the caller still sees the original outcome.
            try:
                await SessionModel.delete_session(session_id)
            except Exception as cleanup_error:
                print(f"⚠️ Cleanup of {session_id} failed: {cleanup_error}")
|
|
|
|
|
|
async def test_concurrent_sessions():
    """Create, read and update several sessions concurrently, then delete them.

    Each phase fans out one call per session with
    ``asyncio.gather(..., return_exceptions=True)``. This has two effects:

    - Every call finishes before the next phase starts, so no orphaned task
      is still writing while later phases or cleanup run.
    - A failing call is counted and its error printed, instead of aborting
      the phase.

    The sessions are always deleted in ``finally``, so a failed run does not
    leave rows behind that make the next run's creates collide.

    Returns:
        ``True`` only if every create and retrieve returned a non-``None``
        result and every update returned a truthy result.
    """
    print("\n👥 Testing Concurrent Sessions")
    print("=" * 50)

    concurrent_sessions = [
        {
            "session_id": f"concurrent-session-{i}",
            "container_name": f"container-{i}",
            "host_dir": f"/tmp/session-{i}",
            "port": 8080 + i,
            "auth_token": f"token-{i}",
            "status": "creating",
        }
        for i in range(5)
    ]
    session_ids = [s["session_id"] for s in concurrent_sessions]
    total = len(concurrent_sessions)

    def tally(verb, action, results, succeeded):
        """Print a per-phase summary of gather() results; return the success count."""
        count = 0
        for result in results:
            if isinstance(result, BaseException):
                print(f"⚠️ Concurrent {action} failed: {result}")
            elif succeeded(result):
                count += 1
        mark = "✅" if count == total else "❌"
        print(f"{mark} {verb} {count}/{total} concurrent sessions")
        return count

    try:
        created = await asyncio.gather(
            *(SessionModel.create_session(s) for s in concurrent_sessions),
            return_exceptions=True,
        )
        successful_creates = tally(
            "Created", "create", created, lambda r: r is not None
        )

        retrieved = await asyncio.gather(
            *(SessionModel.get_session(sid) for sid in session_ids),
            return_exceptions=True,
        )
        successful_retrieves = tally(
            "Retrieved", "retrieve", retrieved, lambda r: r is not None
        )

        updated = await asyncio.gather(
            *(
                SessionModel.update_session(sid, {"status": "running"})
                for sid in session_ids
            ),
            return_exceptions=True,
        )
        successful_updates = tally("Updated", "update", updated, bool)

        print("✅ Concurrent session operations completed")

        return (
            successful_creates == total
            and successful_retrieves == total
            and successful_updates == total
        )

    except Exception as e:
        print(f"❌ Concurrent operations failed: {e}")
        return False

    finally:
        # Always remove the test rows, even after a failure; report (but do
        # not raise) individual delete errors.
        cleanup = await asyncio.gather(
            *(SessionModel.delete_session(sid) for sid in session_ids),
            return_exceptions=True,
        )
        for sid, outcome in zip(session_ids, cleanup):
            if isinstance(outcome, BaseException):
                print(f"⚠️ Cleanup of {sid} failed: {outcome}")
|
|
|
|
|
|
async def test_database_performance():
    """Report database statistics and exercise the session-count queries.

    Steps:

    - Prints the ``total_sessions``, ``active_sessions`` and
      ``database_size`` entries from ``get_database_stats()``, showing
      ``N/A`` when an entry is absent.
    - Warns when the reported ``status`` is not ``"healthy"``.
    - Runs ``SessionModel.count_sessions()`` and
      ``SessionModel.get_active_sessions_count()``.

    No timings are measured.

    Returns:
        ``False`` when the stats call returns something other than a dict,
        or when any call raises; ``True`` otherwise. An unhealthy status
        only produces a warning.
    """
    print("\n⚡ Testing Database Performance")
    print("=" * 50)

    # (printed label, key in the stats dict)
    stat_fields = (
        ("Total sessions", "total_sessions"),
        ("Active sessions", "active_sessions"),
        ("Database size", "database_size"),
    )

    try:
        stats = await get_database_stats()
        if not isinstance(stats, dict):
            print("❌ Database statistics not available")
            return False

        print("✅ Database statistics retrieved")
        for label, key in stat_fields:
            print(f" {label}: {stats.get(key, 'N/A')}")

        health_status = stats.get("status")
        if health_status == "healthy":
            print("✅ Database health check passed")
        else:
            print(f"⚠️ Database health status: {health_status}")

        total_count = await SessionModel.count_sessions()
        print(f"✅ Session count query: {total_count} sessions")

        active_total = await SessionModel.get_active_sessions_count()
        print(f"✅ Active session count: {active_total} sessions")

        return True
    except Exception as exc:
        print(f"❌ Performance testing failed: {exc}")
        return False
|
|
|
|
|
|
async def test_session_queries():
    """Check session listing and status filtering against known test sessions.

    Steps:

    - Creates four sessions with statuses creating/running/running/stopped.
    - Lists sessions and queries the "running" and "creating" statuses.
    - Compares only the test sessions' own IDs within each query result, so
      other rows already in the database cannot skew the check.

    The test sessions are always deleted in ``finally``, so a failed run does
    not leave rows behind that collide with the next run.

    Returns:
        ``True`` if both status queries contain exactly the expected test
        sessions. ``False`` on a mismatch or if any call raises.
    """
    print("\n🔍 Testing Session Queries")
    print("=" * 50)

    test_sessions = [
        {
            "session_id": "query-test-1",
            "container_name": "container-1",
            "status": "creating",
        },
        {
            "session_id": "query-test-2",
            "container_name": "container-2",
            "status": "running",
        },
        {
            "session_id": "query-test-3",
            "container_name": "container-3",
            "status": "running",
        },
        {
            "session_id": "query-test-4",
            "container_name": "container-4",
            "status": "stopped",
        },
    ]
    test_ids = {s["session_id"] for s in test_sessions}

    def ids_with_status(status):
        """IDs of the test sessions created with the given status."""
        return {s["session_id"] for s in test_sessions if s["status"] == status}

    def own_ids(rows):
        """IDs in a query result that belong to this test's sessions."""
        # Assumes rows are dict-like with a "session_id" key, as the
        # get_session() results used elsewhere in this script are.
        # TODO confirm for get_sessions_by_status().
        return {row["session_id"] for row in rows} & test_ids

    try:
        for session in test_sessions:
            await SessionModel.create_session(session)

        all_sessions = await SessionModel.list_sessions(limit=10)
        print(f"✅ Listed {len(all_sessions)} sessions")

        running_sessions = await SessionModel.get_sessions_by_status("running")
        print(f"✅ Found {len(running_sessions)} running sessions")

        creating_sessions = await SessionModel.get_sessions_by_status("creating")
        print(f"✅ Found {len(creating_sessions)} creating sessions")

        filtering_ok = True
        for status, rows in (
            ("running", running_sessions),
            ("creating", creating_sessions),
        ):
            expected_ids = ids_with_status(status)
            found_ids = own_ids(rows)
            if found_ids != expected_ids:
                filtering_ok = False
                print(
                    f"❌ Status filtering incorrect for '{status}': "
                    f"expected {sorted(expected_ids)}, got {sorted(found_ids)}"
                )
        if filtering_ok:
            print("✅ Status filtering accurate")

        print("✅ Query testing completed")
        return filtering_ok

    except Exception as e:
        print(f"❌ Query testing failed: {e}")
        return False

    finally:
        # Best-effort cleanup of every test session, even after a failure.
        for session in test_sessions:
            try:
                await SessionModel.delete_session(session["session_id"])
            except Exception as cleanup_error:
                print(
                    f"⚠️ Cleanup of {session['session_id']} failed: {cleanup_error}"
                )
|
|
|
|
|
|
async def run_all_database_tests():
    """Run each database persistence check in sequence and print a summary.

    Every check is awaited in the order listed. A check that raises is
    reported and recorded as a failure instead of stopping the suite. When
    anything fails, the summary also prints hints about the DB_* environment
    variables.

    NOTE(review): "Database Connection" runs before "Database Schema", which
    is the first check to call ``init_database()``. Confirm that the shared
    connection is usable before initialisation.

    Returns:
        ``True`` when every check reported success, otherwise ``False``.
    """
    print("💾 Database Persistence Test Suite")
    print("=" * 70)

    checks = (
        ("Database Connection", test_database_connection),
        ("Database Schema", test_database_schema),
        ("Session CRUD", test_session_crud),
        ("Concurrent Sessions", test_concurrent_sessions),
        ("Database Performance", test_database_performance),
        ("Session Queries", test_session_queries),
    )
    divider = "=" * 25

    outcomes = []
    for label, check in checks:
        print(f"\n{divider} {label} {divider}")
        try:
            outcome = await check()
            outcomes.append(outcome)
            verdict = "✅ PASSED" if outcome else "❌ FAILED"
            print(f"\n{verdict}: {label}")
        except Exception as exc:
            print(f"\n❌ ERROR in {label}: {exc}")
            outcomes.append(False)

    # Summary
    print(f"\n{'=' * 70}")
    passed_count = sum(outcomes)
    total_count = len(outcomes)
    print(f"📊 Test Results: {passed_count}/{total_count} tests passed")

    everything_passed = passed_count == total_count
    if everything_passed:
        print("🎉 All database persistence tests completed successfully!")
        print("💾 PostgreSQL backend provides reliable session storage.")
    else:
        troubleshooting = (
            "⚠️ Some tests failed. Check the output above for details.",
            "💡 Ensure PostgreSQL is running and connection settings are correct.",
            " Required environment variables:",
            " - DB_HOST (default: localhost)",
            " - DB_PORT (default: 5432)",
            " - DB_USER (default: lovdata)",
            " - DB_PASSWORD (default: password)",
            " - DB_NAME (default: lovdata_chat)",
        )
        for hint in troubleshooting:
            print(hint)

    return everything_passed
|
|
|
|
|
|
if __name__ == "__main__":
    # asyncio.run() returns the suite's boolean result. Turn it into the
    # process exit status (0 = all passed, 1 = failures) so shells and CI can
    # detect a failing run instead of always seeing success.
    sys.exit(0 if asyncio.run(run_all_database_tests()) else 1)
|