Spaces:
Sleeping
Sleeping
| """ | |
| SAAP Database Connection Management - Production Ready | |
| SQLAlchemy database connection, session management, and health monitoring | |
| """ | |
import asyncio
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import AsyncGenerator, Optional, Dict, Any

from sqlalchemy import create_engine, text, event
from sqlalchemy.exc import SQLAlchemyError, OperationalError
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool, NullPool, AsyncAdaptedQueuePool

from config.settings import settings
from database.models import Base, DBHealthCheck
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class DatabaseManager:
    """
    Production-ready database connection manager.

    Features:
    - Connection pooling with health monitoring
    - Async and sync session management
    - Automatic retry and fallback mechanisms
    - Database health checks and metrics
    - Migration support
    """

    def __init__(self):
        # Engines and session factories are created lazily in initialize().
        self.engine = None               # sync engine (migrations/admin tasks)
        self.async_engine = None         # async engine (main application)
        self.SessionLocal = None         # sync sessionmaker
        self.AsyncSessionLocal = None    # async_sessionmaker
        self.is_initialized = False
        self.last_health_check = None    # datetime of the last health probe
        self.health_status = {"status": "initializing"}

    def _get_sync_engine_kwargs(self) -> Dict[str, Any]:
        """Build create_engine() kwargs appropriate for the configured database."""
        database_url = settings.get_database_url()
        base_kwargs = {
            "echo": settings.debug,
            "future": True,
        }
        if database_url.startswith("sqlite"):
            # SQLite-specific configuration: no connection pooling needed.
            base_kwargs.update({
                "poolclass": NullPool,
                "connect_args": {
                    "check_same_thread": settings.database.sqlite_check_same_thread
                },
            })
        else:
            # PostgreSQL/MySQL configuration with connection pooling.
            base_kwargs.update({
                "poolclass": QueuePool,  # QueuePool for sync engines
                "pool_size": settings.database.pool_size,
                "max_overflow": settings.database.max_overflow,
                "pool_timeout": settings.database.pool_timeout,
                "pool_recycle": settings.database.pool_recycle,
                "pool_pre_ping": True,  # verify connections before use
            })
        return base_kwargs

    def _get_async_engine_kwargs(self) -> Dict[str, Any]:
        """Build create_async_engine() kwargs appropriate for the configured database."""
        database_url = settings.get_database_url()
        base_kwargs = {
            "echo": settings.debug,
            "future": True,
        }
        if database_url.startswith("sqlite"):
            # SQLite-specific configuration for async: no pooling.
            base_kwargs.update({
                "poolclass": NullPool,
            })
        else:
            # PostgreSQL/MySQL configuration with async connection pooling.
            base_kwargs.update({
                "poolclass": AsyncAdaptedQueuePool,  # pool variant for async engines
                "pool_size": settings.database.pool_size,
                "max_overflow": settings.database.max_overflow,
                "pool_timeout": settings.database.pool_timeout,
                "pool_recycle": settings.database.pool_recycle,
                "pool_pre_ping": True,  # verify connections before use
            })
        return base_kwargs

    def _setup_sql_logging(self, engine):
        """Attach a statement logger to *engine* when debug mode is enabled."""
        if settings.debug:
            def log_sql(conn, cursor, statement, parameters, context, executemany):
                """Log SQL statements in debug mode."""
                logger.debug(f"SQL: {statement}")
                if parameters:
                    logger.debug(f"Parameters: {parameters}")
            # FIX: the listener was defined but never registered, so no SQL
            # was ever logged. Hook it into the engine's cursor-execute event.
            event.listen(engine, "before_cursor_execute", log_sql)

    async def initialize(self):
        """Initialize database connections and create tables."""
        try:
            logger.info("π Initializing SAAP Database Connection...")
            database_url = settings.get_database_url()

            # Sync engine for migrations and admin tasks.
            sync_kwargs = self._get_sync_engine_kwargs()
            self.engine = create_engine(database_url, **sync_kwargs)
            self._setup_sql_logging(self.engine)

            # Async engine for the main application: map the URL onto the
            # matching async driver (aiosqlite / asyncpg).
            async_url = database_url.replace("sqlite://", "sqlite+aiosqlite://")
            if not database_url.startswith("sqlite"):
                async_url = database_url.replace("postgresql://", "postgresql+asyncpg://")
            async_kwargs = self._get_async_engine_kwargs()
            self.async_engine = create_async_engine(async_url, **async_kwargs)
            self._setup_sql_logging(self.async_engine.sync_engine)

            # Session factories for both engines.
            self.SessionLocal = sessionmaker(
                bind=self.engine,
                autocommit=False,
                autoflush=False,
                expire_on_commit=False,
            )
            self.AsyncSessionLocal = async_sessionmaker(
                bind=self.async_engine,
                class_=AsyncSession,
                autocommit=False,
                autoflush=False,
                expire_on_commit=False,
            )

            # Configure mapper registry ONCE to prevent recursion when both
            # sync and async engines share the same Base.
            try:
                Base.registry.configure()
                logger.info("β Mapper registry configured successfully")
            except Exception as mapper_error:
                logger.warning(f"β οΈ Mapper already configured: {mapper_error}")

            await self._create_tables()

            # Mark initialized BEFORE the health check so get_async_session()
            # does not recurse back into initialize().
            self.is_initialized = True
            await self._update_health_status()
            logger.info(f"β Database initialized successfully: {database_url}")
        except Exception as e:
            logger.error(f"β Database initialization failed: {e}")
            self.health_status = {"status": "failed", "error": str(e)}
            raise

    async def _create_tables(self):
        """Create database tables if they don't exist."""
        try:
            if settings.debug:
                logger.debug("π§ Creating database tables...")
            if settings.get_database_url().startswith("sqlite"):
                # For SQLite, use the sync engine.
                Base.metadata.create_all(bind=self.engine)
                logger.info("β Database tables created (SQLite)")
            else:
                # For PostgreSQL/MySQL, use the async engine.
                async with self.async_engine.begin() as conn:
                    await conn.run_sync(Base.metadata.create_all)
                logger.info("β Database tables created (Async)")
        except Exception as e:
            logger.error(f"β Failed to create database tables: {e}")
            raise

    @asynccontextmanager
    async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Async context manager yielding a database session with auto commit/rollback.

        FIX: every call site uses ``async with self.get_async_session()``, which
        requires the @asynccontextmanager decorator; a bare async generator
        would raise TypeError on ``async with``.
        """
        if not self.is_initialized:
            await self.initialize()
        session = self.AsyncSessionLocal()
        try:
            yield session
            await session.commit()
        except Exception as e:
            await session.rollback()
            logger.error(f"β Database session error: {e}")
            raise
        finally:
            await session.close()

    def get_sync_session(self) -> Session:
        """Get a sync database session (for migrations and admin tasks)."""
        if not self.engine:
            raise RuntimeError("Database not initialized")
        return self.SessionLocal()

    async def _update_health_status(self):
        """Probe the database and refresh ``self.health_status``."""
        try:
            start_time = datetime.utcnow()
            # Test basic connectivity.
            async with self.get_async_session() as session:
                result = await session.execute(text("SELECT 1"))
                result.fetchone()
            end_time = datetime.utcnow()
            response_time = (end_time - start_time).total_seconds() * 1000  # ms

            # Agent counts — tolerate the table not existing yet.
            agent_count = 0
            active_agent_count = 0
            try:
                async with self.get_async_session() as session:
                    # Check whether the agents table exists first.
                    check_table_query = text("""
                        SELECT COUNT(*) FROM information_schema.tables
                        WHERE table_name = 'agents'
                    """)
                    if settings.get_database_url().startswith("sqlite"):
                        # SQLite uses sqlite_master instead of information_schema.
                        check_table_query = text("""
                            SELECT COUNT(*) FROM sqlite_master
                            WHERE type='table' AND name='agents'
                        """)
                    table_exists_result = await session.execute(check_table_query)
                    table_exists = table_exists_result.scalar() > 0
                    if table_exists:
                        agent_count_result = await session.execute(text("SELECT COUNT(*) FROM agents"))
                        agent_count = agent_count_result.scalar()
                        active_agent_count_result = await session.execute(
                            text("SELECT COUNT(*) FROM agents WHERE status = 'active'")
                        )
                        active_agent_count = active_agent_count_result.scalar()
            except Exception as table_error:
                # Expected during initial setup.
                logger.debug(f"Tables not ready yet: {table_error}")

            # FIX: Pool.size/checkedout are methods — the old
            # getattr(pool, 'size', 0) stored bound methods, not numbers.
            pool = self.async_engine.pool if self.async_engine else None
            self.health_status = {
                "status": "healthy",
                "database_type": settings.get_database_url().split("://")[0],
                "response_time_ms": response_time,
                "agent_count": agent_count,
                "active_agent_count": active_agent_count,
                "connection_pool": {
                    "size": pool.size() if pool is not None and hasattr(pool, "size") else 0,
                    "checked_out": pool.checkedout() if pool is not None and hasattr(pool, "checkedout") else 0,
                },
                "timestamp": datetime.utcnow().isoformat(),
            }
            self.last_health_check = datetime.utcnow()

            # Persist the health check (best-effort).
            await self._save_health_check(response_time, agent_count, active_agent_count)
        except Exception as e:
            logger.error(f"β Database health check failed: {e}")
            self.health_status = {
                "status": "error",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat(),
            }

    async def _save_health_check(self, response_time: float, agent_count: int, active_agent_count: int):
        """Save a health check row to the database (failure is only a warning)."""
        try:
            async with self.get_async_session() as session:
                health_check = DBHealthCheck(
                    component="database",
                    status=self.health_status["status"],
                    response_time_ms=response_time,
                    agent_count=agent_count,
                    active_agent_count=active_agent_count,
                    details={"database_type": settings.get_database_url().split("://")[0]},
                )
                session.add(health_check)
                await session.commit()
        except Exception as e:
            logger.warning(f"β οΈ Failed to save health check: {e}")

    async def health_check(self) -> Dict[str, Any]:
        """Return the current health status, refreshing it if stale (>30s)."""
        # FIX: timedelta.seconds wraps every 24h; use total_seconds() so a
        # check older than a day still triggers a refresh.
        if (not self.last_health_check
                or (datetime.utcnow() - self.last_health_check).total_seconds() > 30):
            await self._update_health_status()
        return self.health_status

    async def get_performance_metrics(self) -> Dict[str, Any]:
        """Summarize the ten most recent database health checks."""
        try:
            async with self.get_async_session() as session:
                recent_checks_query = text("""
                    SELECT * FROM health_checks
                    WHERE component = 'database'
                    ORDER BY created_at DESC
                    LIMIT 10
                """)
                recent_checks = await session.execute(recent_checks_query)
                checks = recent_checks.fetchall()
                if checks:
                    avg_response_time = sum(check.response_time_ms for check in checks) / len(checks)
                    latest_check = checks[0]
                    return {
                        "average_response_time_ms": avg_response_time,
                        "latest_agent_count": latest_check.agent_count,
                        "latest_active_agents": latest_check.active_agent_count,
                        "health_checks_count": len(checks),
                        "timestamp": datetime.utcnow().isoformat(),
                    }
                else:
                    return {"message": "No performance data available"}
        except Exception as e:
            logger.error(f"β Failed to get performance metrics: {e}")
            return {"error": str(e)}

    async def cleanup_old_data(self, days: int = 30):
        """Delete chat messages, health checks, and logs older than *days*."""
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days)
            async with self.get_async_session() as session:
                # Clean old chat messages.
                await session.execute(
                    text("DELETE FROM chat_messages WHERE created_at < :cutoff_date"),
                    {"cutoff_date": cutoff_date},
                )
                # Clean old health checks.
                await session.execute(
                    text("DELETE FROM health_checks WHERE created_at < :cutoff_date"),
                    {"cutoff_date": cutoff_date},
                )
                # Clean old system logs.
                await session.execute(
                    text("DELETE FROM system_logs WHERE created_at < :cutoff_date"),
                    {"cutoff_date": cutoff_date},
                )
                await session.commit()
            logger.info(f"β Cleaned up data older than {days} days")
        except Exception as e:
            logger.error(f"β Data cleanup failed: {e}")

    async def close(self):
        """Dispose both engines and mark the manager uninitialized."""
        try:
            logger.info("π§ Closing database connections...")
            if self.async_engine:
                await self.async_engine.dispose()
            if self.engine:
                self.engine.dispose()
            self.is_initialized = False
            logger.info("β Database connections closed")
        except Exception as e:
            logger.error(f"β Error closing database: {e}")
# Single shared database manager for the whole application.
db_manager = DatabaseManager()
# Convenience functions for dependency injection
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency that yields one application database session."""
    async with db_manager.get_async_session() as db:
        yield db
def get_sync_db_session() -> Session:
    """Return a synchronous session (migrations and admin tasks)."""
    return db_manager.get_sync_session()
| if __name__ == "__main__": | |
| async def test_database(): | |
| """Test database connectivity""" | |
| await db_manager.initialize() | |
| health = await db_manager.health_check() | |
| print(f"π Database Health: {health}") | |
| metrics = await db_manager.get_performance_metrics() | |
| print(f"π Performance Metrics: {metrics}") | |
| await db_manager.close() | |
| asyncio.run(test_database()) | |