Files
zjjk/backend/app/repositories/main_capital_flow_repository.py

119 lines
4.1 KiB
Python
Raw Permalink Normal View History

2026-04-08 20:04:40 +08:00
import json
import sqlite3
from contextlib import closing
from pathlib import Path
# DDL for the `main_capital_flow` table. The statement is idempotent
# (`IF NOT EXISTS`), and the repository executes it each time it is constructed.
# `id` is the primary key and `trade_date` is UNIQUE, so the table can hold at
# most one row per trade date. `raw_extraction_json` stores a JSON-encoded
# document as text.
# NOTE(review): the `_yi` suffix presumably means the amounts are expressed in
# units of 亿 (1e8) — confirm against the producer of these records.
CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS main_capital_flow (
id TEXT PRIMARY KEY,
trade_date TEXT NOT NULL UNIQUE,
subject TEXT,
snapshot_time TEXT,
institution_amount_yi REAL,
main_force_amount_yi REAL,
large_household_amount_yi REAL,
retail_amount_yi REAL,
trend TEXT,
summary TEXT NOT NULL,
image_name TEXT NOT NULL,
image_path TEXT NOT NULL,
raw_extraction_json TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)
"""
class MainCapitalFlowRepository:
def __init__(self, db_path: Path) -> None:
self.db_path = db_path
self.db_path.parent.mkdir(parents=True, exist_ok=True)
with self._connect() as connection:
connection.execute(CREATE_TABLE_SQL)
connection.commit()
def _connect(self) -> sqlite3.Connection:
connection = sqlite3.connect(self.db_path)
connection.row_factory = sqlite3.Row
return connection
def list_records(self) -> list[dict]:
with self._connect() as connection:
rows = connection.execute(
"SELECT * FROM main_capital_flow ORDER BY trade_date DESC, created_at DESC"
).fetchall()
return [self._deserialize_row(row) for row in rows]
def get_record(self, record_id: str) -> dict | None:
with self._connect() as connection:
row = connection.execute(
"SELECT * FROM main_capital_flow WHERE id = ?",
(record_id,),
).fetchone()
return None if row is None else self._deserialize_row(row)
def get_by_trade_date(self, trade_date: str) -> dict | None:
with self._connect() as connection:
row = connection.execute(
"SELECT * FROM main_capital_flow WHERE trade_date = ?",
(trade_date,),
).fetchone()
return None if row is None else self._deserialize_row(row)
def insert_record(self, payload: dict) -> dict:
with self._connect() as connection:
connection.execute(
"""
INSERT INTO main_capital_flow (
id,
trade_date,
subject,
snapshot_time,
institution_amount_yi,
main_force_amount_yi,
large_household_amount_yi,
retail_amount_yi,
trend,
summary,
image_name,
image_path,
raw_extraction_json,
created_at,
updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
payload["id"],
payload["trade_date"],
payload.get("subject"),
payload.get("snapshot_time"),
payload.get("institution_amount_yi"),
payload.get("main_force_amount_yi"),
payload.get("large_household_amount_yi"),
payload.get("retail_amount_yi"),
payload.get("trend"),
payload["summary"],
payload["image_name"],
payload["image_path"],
json.dumps(payload.get("raw_extraction", {}), ensure_ascii=False),
payload["created_at"],
payload["updated_at"],
),
)
connection.commit()
return self.get_record(payload["id"]) or payload
def delete_record(self, record_id: str) -> dict | None:
record = self.get_record(record_id)
if record is None:
return None
with self._connect() as connection:
connection.execute("DELETE FROM main_capital_flow WHERE id = ?", (record_id,))
connection.commit()
return record
def _deserialize_row(self, row: sqlite3.Row) -> dict:
payload = dict(row)
payload["raw_extraction"] = json.loads(payload.pop("raw_extraction_json"))
return payload