""" Database Manager for Portfolio Manager. Handles SQLite database operations for holdings, transactions, and price cache. """ import sqlite3 import logging from datetime import datetime from typing import List, Dict, Optional, Tuple from contextlib import contextmanager import os logger = logging.getLogger(__name__) class DBManager: """Manages SQLite database operations for the portfolio.""" def __init__(self, db_path: str = "data/portfolio.db"): """ Initialize database manager. Args: db_path: Path to SQLite database file """ self.db_path = db_path self._ensure_data_directory() self.initialize_database() def _ensure_data_directory(self) -> None: """Ensure data directory exists.""" db_dir = os.path.dirname(self.db_path) if db_dir and not os.path.exists(db_dir): os.makedirs(db_dir) logger.info(f"Created data directory: {db_dir}") @contextmanager def get_connection(self): """ Context manager for database connections. Yields: sqlite3.Connection: Database connection """ conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row # Enable column access by name try: yield conn conn.commit() except Exception as e: conn.rollback() logger.error(f"Database error: {e}", exc_info=True) raise finally: conn.close() def initialize_database(self) -> None: """Create database tables if they don't exist.""" with self.get_connection() as conn: cursor = conn.cursor() # Holdings table cursor.execute(""" CREATE TABLE IF NOT EXISTS holdings ( id INTEGER PRIMARY KEY AUTOINCREMENT, ticker TEXT NOT NULL UNIQUE, name TEXT, asset_type TEXT, quantity REAL NOT NULL, avg_price REAL NOT NULL, current_price REAL, current_value REAL, weight_pct REAL, last_updated TIMESTAMP ) """) # Transactions table cursor.execute(""" CREATE TABLE IF NOT EXISTS transactions ( id INTEGER PRIMARY KEY AUTOINCREMENT, ticker TEXT NOT NULL, transaction_type TEXT NOT NULL, date DATE NOT NULL, quantity REAL, price REAL, amount REAL NOT NULL, notes TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) # Price cache table cursor.execute(""" CREATE TABLE IF NOT EXISTS price_cache ( ticker TEXT PRIMARY KEY, price REAL NOT NULL, last_updated TIMESTAMP NOT NULL ) """) logger.info(f"Database initialized at {self.db_path}") # ========== HOLDINGS OPERATIONS ========== def add_holding( self, ticker: str, name: str, asset_type: str, quantity: float, avg_price: float ) -> None: """ Add a new holding to the portfolio. Args: ticker: Stock ticker symbol (e.g., 'VWCE.MI') name: Full name of the asset asset_type: 'ETF', 'Stock', or 'Cash' quantity: Number of shares/units avg_price: Average purchase price """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" INSERT INTO holdings (ticker, name, asset_type, quantity, avg_price, last_updated) VALUES (?, ?, ?, ?, ?, ?) """, (ticker, name, asset_type, quantity, avg_price, datetime.now())) logger.info(f"Added holding: {ticker} ({quantity} @ {avg_price})") def update_holding(self, ticker: str, **kwargs) -> None: """ Update holding fields. Args: ticker: Ticker to update **kwargs: Fields to update (quantity, avg_price, current_price, etc.) """ if not kwargs: return # Build SET clause dynamically set_clause = ", ".join([f"{key} = ?" for key in kwargs.keys()]) values = list(kwargs.values()) + [datetime.now(), ticker] with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(f""" UPDATE holdings SET {set_clause}, last_updated = ? WHERE ticker = ? """, values) logger.debug(f"Updated holding {ticker}: {kwargs}") def delete_holding(self, ticker: str) -> None: """ Delete a holding from the portfolio. Args: ticker: Ticker to delete """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute("DELETE FROM holdings WHERE ticker = ?", (ticker,)) logger.info(f"Deleted holding: {ticker}") def get_all_holdings(self) -> List[Dict]: """ Get all holdings from the database. Returns: List of holdings as dictionaries """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM holdings ORDER BY ticker") rows = cursor.fetchall() return [dict(row) for row in rows] def get_holding(self, ticker: str) -> Optional[Dict]: """ Get a specific holding by ticker. Args: ticker: Ticker to retrieve Returns: Holding dictionary or None if not found """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM holdings WHERE ticker = ?", (ticker,)) row = cursor.fetchone() return dict(row) if row else None # ========== TRANSACTIONS OPERATIONS ========== def add_transaction( self, ticker: str, tx_type: str, date: str, amount: float, quantity: Optional[float] = None, price: Optional[float] = None, notes: str = "" ) -> None: """ Add a transaction to the log. Args: ticker: Stock ticker tx_type: 'BUY', 'SELL', or 'DIVIDEND' date: Transaction date (YYYY-MM-DD) amount: Total transaction amount quantity: Number of shares (None for DIVIDEND) price: Price per share (None for DIVIDEND) notes: Optional notes """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" INSERT INTO transactions (ticker, transaction_type, date, quantity, price, amount, notes) VALUES (?, ?, ?, ?, ?, ?, ?) """, (ticker, tx_type, date, quantity, price, amount, notes)) logger.info(f"Added transaction: {tx_type} {ticker} {quantity or ''} @ {price or ''} = {amount}") def get_transactions( self, ticker: Optional[str] = None, tx_type: Optional[str] = None ) -> List[Dict]: """ Get transactions with optional filtering. Args: ticker: Filter by ticker (None for all) tx_type: Filter by type (None for all) Returns: List of transactions as dictionaries """ query = "SELECT * FROM transactions WHERE 1=1" params = [] if ticker: query += " AND ticker = ?" params.append(ticker) if tx_type: query += " AND transaction_type = ?" params.append(tx_type) query += " ORDER BY date DESC, created_at DESC" with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(query, params) rows = cursor.fetchall() return [dict(row) for row in rows] def delete_transaction(self, transaction_id: int) -> None: """ Delete a transaction by ID. Args: transaction_id: Transaction ID to delete """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute("DELETE FROM transactions WHERE id = ?", (transaction_id,)) logger.info(f"Deleted transaction ID: {transaction_id}") # ========== PRICE CACHE OPERATIONS ========== def get_cached_price(self, ticker: str) -> Optional[Tuple[float, datetime]]: """ Get cached price for a ticker. Args: ticker: Stock ticker Returns: Tuple of (price, last_updated) or None if not cached """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT price, last_updated FROM price_cache WHERE ticker = ?", (ticker,) ) row = cursor.fetchone() if row: return (row['price'], datetime.fromisoformat(row['last_updated'])) return None def update_price_cache(self, ticker: str, price: float) -> None: """ Update price cache for a ticker. Args: ticker: Stock ticker price: Current price """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" INSERT OR REPLACE INTO price_cache (ticker, price, last_updated) VALUES (?, ?, ?) """, (ticker, price, datetime.now().isoformat())) logger.debug(f"Updated price cache: {ticker} = {price}") def get_all_tickers(self) -> List[str]: """ Get all unique tickers from holdings (excluding CASH). Returns: List of ticker symbols """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" SELECT ticker FROM holdings WHERE ticker != 'CASH' AND asset_type != 'Cash' ORDER BY ticker """) return [row['ticker'] for row in cursor.fetchall()] # ========== UTILITY OPERATIONS ========== def backup_database(self, backup_path: Optional[str] = None) -> str: """ Create a backup of the database. Args: backup_path: Path for backup file (auto-generated if None) Returns: Path to backup file """ if backup_path is None: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_path = f"{self.db_path}.backup_{timestamp}" with self.get_connection() as source: dest = sqlite3.connect(backup_path) source.backup(dest) dest.close() logger.info(f"Database backed up to {backup_path}") return backup_path def get_database_stats(self) -> Dict[str, int]: """ Get database statistics. Returns: Dictionary with table row counts """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT COUNT(*) as count FROM holdings") holdings_count = cursor.fetchone()['count'] cursor.execute("SELECT COUNT(*) as count FROM transactions") transactions_count = cursor.fetchone()['count'] cursor.execute("SELECT COUNT(*) as count FROM price_cache") cache_count = cursor.fetchone()['count'] return { 'holdings': holdings_count, 'transactions': transactions_count, 'price_cache': cache_count }