import json import sqlite3 from pathlib import Path CREATE_TABLE_SQL = """ CREATE TABLE IF NOT EXISTS capital_image_records ( id TEXT PRIMARY KEY, trade_date TEXT, subject TEXT, snapshot_time TEXT, main_force_amount_yi REAL, institution_amount_yi REAL, large_household_amount_yi REAL, retail_amount_yi REAL, overall_trend TEXT, intraday_summary TEXT, review_status TEXT NOT NULL, extraction_method TEXT NOT NULL, image_name TEXT NOT NULL, image_path TEXT NOT NULL, raw_extraction_json TEXT NOT NULL, created_at TEXT NOT NULL, updated_at TEXT NOT NULL ) """ class CapitalImageRepository: def __init__(self, db_path: Path) -> None: self.db_path = db_path self.db_path.parent.mkdir(parents=True, exist_ok=True) with self._connect() as connection: connection.execute(CREATE_TABLE_SQL) connection.commit() def _connect(self) -> sqlite3.Connection: connection = sqlite3.connect(self.db_path) connection.row_factory = sqlite3.Row return connection def list_records(self, trade_date: str | None = None, subject: str | None = None) -> list[dict]: query = "SELECT * FROM capital_image_records" clauses: list[str] = [] params: list[str] = [] if trade_date: clauses.append("trade_date = ?") params.append(trade_date) if subject: clauses.append("subject LIKE ?") params.append(f"%{subject}%") if clauses: query += " WHERE " + " AND ".join(clauses) query += " ORDER BY created_at DESC" with self._connect() as connection: rows = connection.execute(query, params).fetchall() return [self._deserialize_row(row) for row in rows] def get_record(self, record_id: str) -> dict | None: with self._connect() as connection: row = connection.execute( "SELECT * FROM capital_image_records WHERE id = ?", (record_id,), ).fetchone() if row is None: return None return self._deserialize_row(row) def insert_record(self, payload: dict) -> dict: with self._connect() as connection: connection.execute( """ INSERT INTO capital_image_records ( id, trade_date, subject, snapshot_time, main_force_amount_yi, institution_amount_yi, large_household_amount_yi, retail_amount_yi, overall_trend, intraday_summary, review_status, extraction_method, image_name, image_path, raw_extraction_json, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( payload["id"], payload.get("trade_date"), payload.get("subject"), payload.get("snapshot_time"), payload.get("main_force_amount_yi"), payload.get("institution_amount_yi"), payload.get("large_household_amount_yi"), payload.get("retail_amount_yi"), payload.get("overall_trend"), payload.get("intraday_summary"), payload["review_status"], payload["extraction_method"], payload["image_name"], payload["image_path"], json.dumps(payload.get("raw_extraction", {}), ensure_ascii=False), payload["created_at"], payload["updated_at"], ), ) connection.commit() return self.get_record(payload["id"]) or payload def _deserialize_row(self, row: sqlite3.Row) -> dict: payload = dict(row) payload["raw_extraction"] = json.loads(payload.pop("raw_extraction_json")) return payload