import json import sqlite3 from pathlib import Path CREATE_TABLE_SQL = """ CREATE TABLE IF NOT EXISTS main_capital_flow ( id TEXT PRIMARY KEY, trade_date TEXT NOT NULL UNIQUE, subject TEXT, snapshot_time TEXT, institution_amount_yi REAL, main_force_amount_yi REAL, large_household_amount_yi REAL, retail_amount_yi REAL, trend TEXT, summary TEXT NOT NULL, image_name TEXT NOT NULL, image_path TEXT NOT NULL, raw_extraction_json TEXT NOT NULL, created_at TEXT NOT NULL, updated_at TEXT NOT NULL ) """ class MainCapitalFlowRepository: def __init__(self, db_path: Path) -> None: self.db_path = db_path self.db_path.parent.mkdir(parents=True, exist_ok=True) with self._connect() as connection: connection.execute(CREATE_TABLE_SQL) connection.commit() def _connect(self) -> sqlite3.Connection: connection = sqlite3.connect(self.db_path) connection.row_factory = sqlite3.Row return connection def list_records(self) -> list[dict]: with self._connect() as connection: rows = connection.execute( "SELECT * FROM main_capital_flow ORDER BY trade_date DESC, created_at DESC" ).fetchall() return [self._deserialize_row(row) for row in rows] def get_record(self, record_id: str) -> dict | None: with self._connect() as connection: row = connection.execute( "SELECT * FROM main_capital_flow WHERE id = ?", (record_id,), ).fetchone() return None if row is None else self._deserialize_row(row) def get_by_trade_date(self, trade_date: str) -> dict | None: with self._connect() as connection: row = connection.execute( "SELECT * FROM main_capital_flow WHERE trade_date = ?", (trade_date,), ).fetchone() return None if row is None else self._deserialize_row(row) def insert_record(self, payload: dict) -> dict: with self._connect() as connection: connection.execute( """ INSERT INTO main_capital_flow ( id, trade_date, subject, snapshot_time, institution_amount_yi, main_force_amount_yi, large_household_amount_yi, retail_amount_yi, trend, summary, image_name, image_path, raw_extraction_json, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( payload["id"], payload["trade_date"], payload.get("subject"), payload.get("snapshot_time"), payload.get("institution_amount_yi"), payload.get("main_force_amount_yi"), payload.get("large_household_amount_yi"), payload.get("retail_amount_yi"), payload.get("trend"), payload["summary"], payload["image_name"], payload["image_path"], json.dumps(payload.get("raw_extraction", {}), ensure_ascii=False), payload["created_at"], payload["updated_at"], ), ) connection.commit() return self.get_record(payload["id"]) or payload def delete_record(self, record_id: str) -> dict | None: record = self.get_record(record_id) if record is None: return None with self._connect() as connection: connection.execute("DELETE FROM main_capital_flow WHERE id = ?", (record_id,)) connection.commit() return record def _deserialize_row(self, row: sqlite3.Row) -> dict: payload = dict(row) payload["raw_extraction"] = json.loads(payload.pop("raw_extraction_json")) return payload