"""
Database Manager for AI Trading System
Sprint 1: Foundation - Database connection and session management
"""

import os
from contextlib import contextmanager
from typing import Generator

from loguru import logger
from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker

from src.database.models import Base


class DatabaseManager:
    """Manages database connections and sessions"""

    def __init__(self, db_path: str = "data/trading_system.db", echo: bool = False):
        """
        Initialize database manager

        Creates the parent directory of ``db_path`` if it has one, builds a
        SQLite engine for that file, turns on foreign-key enforcement for every
        new connection made by this engine, and prepares a session factory.

        Args:
            db_path: Path to SQLite database file. A bare filename (no
                directory component) is used as-is, i.e. relative to the
                current working directory.
            echo: Enable SQL query logging
        """
        self.db_path = db_path
        self.echo = echo

        # Ensure data directory exists. os.path.dirname() returns "" for a bare
        # filename and os.makedirs("") raises FileNotFoundError, so only create
        # the directory when there actually is one.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        # Create engine
        self.engine: Engine = create_engine(
            f"sqlite:///{db_path}",
            echo=echo,
            # Let pooled sqlite3 connections be used from threads other than
            # the one that created them.
            connect_args={"check_same_thread": False},
        )

        # Enable foreign keys for SQLite. SQLite does not enforce foreign keys
        # unless the pragma is set, and it is a per-connection setting, so it
        # is applied on every new DBAPI connection.
        # The listener is bound to THIS engine only: registering it on the
        # Engine class would run it for every engine in the process (including
        # non-SQLite ones) and stack one more listener per DatabaseManager.
        @event.listens_for(self.engine, "connect")
        def set_sqlite_pragma(dbapi_conn, connection_record):
            cursor = dbapi_conn.cursor()
            cursor.execute("PRAGMA foreign_keys=ON")
            cursor.close()

        # Create session factory
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)

        logger.info(f"Database manager initialized: {db_path}")

    def create_tables(self) -> None:
        """
        Create all database tables registered on ``Base.metadata``.

        Raises:
            Exception: Any error from ``create_all`` is logged and re-raised.
        """
        try:
            Base.metadata.create_all(bind=self.engine)
            logger.info("Database tables created successfully")
        except Exception as e:
            logger.error(f"Error creating database tables: {e}")
            raise

    def drop_tables(self) -> None:
        """
        Drop all database tables registered on ``Base.metadata`` (use with caution!)

        Raises:
            Exception: Any error from ``drop_all`` is logged and re-raised.
        """
        try:
            Base.metadata.drop_all(bind=self.engine)
            logger.warning("All database tables dropped")
        except Exception as e:
            logger.error(f"Error dropping database tables: {e}")
            raise

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """
        Context manager for database sessions

        Commits when the ``with`` body finishes without error. On any exception
        it rolls back, logs the error and re-raises. The session is always
        closed.

        Usage:
            with db_manager.get_session() as session:
                # Do database operations
                session.query(Stock).all()
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Database session error: {e}")
            raise
        finally:
            session.close()

    def health_check(self) -> bool:
        """
        Check database connectivity

        Returns:
            True if database is accessible, False otherwise
        """
        try:
            with self.get_session() as session:
                session.execute(text("SELECT 1"))
            logger.info("Database health check: OK")
            return True
        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return False

    def get_table_count(self, table_name: str) -> int:
        """
        Get row count for a specific table

        Args:
            table_name: Name of the table. It is quoted/escaped as a SQL
                identifier before being placed in the query, so it cannot
                inject arbitrary SQL.

        Returns:
            Number of rows in the table, or 0 if the query fails (e.g. the
            table does not exist); the failure is logged, not raised.
        """
        try:
            # Identifiers cannot be bound as query parameters, so escape the
            # name with the dialect's identifier preparer instead of
            # interpolating the raw caller-supplied string into SQL.
            quoted_name = self.engine.dialect.identifier_preparer.quote(table_name)
            with self.get_session() as session:
                result = session.execute(text(f"SELECT COUNT(*) FROM {quoted_name}"))
                count = result.scalar()
                return count
        except Exception as e:
            logger.error(f"Error getting count for table {table_name}: {e}")
            return 0

    def close(self) -> None:
        """
        Close database engine

        Disposes the engine's connection pool. Errors are logged, not raised,
        so this is safe to call during shutdown.
        """
        try:
            self.engine.dispose()
            logger.info("Database connection closed")
        except Exception as e:
            logger.error(f"Error closing database: {e}")