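"""Storage layer for the cjfx backend (backend/app/services/storage.py).

SQLAlchemy ORM models and MySQL persistence helpers: tables for accounts, daily
link inputs, generated reports, and CLS news snapshots, plus save/fetch functions
that convert between ORM records and the document models defined in app.models.
"""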
from __future__ import annotations
import hashlib
import json
from contextlib import contextmanager
from datetime import date, datetime
from pathlib import Path
from typing import Any, Iterator
from sqlalchemy import JSON, Date, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship, sessionmaker
from app.models import (
    Account,
    ClsNewsDocument,
    ClsNewsItem,
    ClsNewsSummary,
    ClsSectorImpact,
    DailyInputAccount,
    DailyInputDocument,
    OpinionArticle,
    ReportDocument,
    ReportListItem,
)

PROJECT_ROOT = Path(__file__).resolve().parents[3]
CONFIG_DIR = PROJECT_ROOT / "backend" / "config"
DATABASE_CONFIG_PATH = CONFIG_DIR / "database.json"


class Base(DeclarativeBase):
    pass


class AccountRecord(Base):
    __tablename__ = "accounts"
    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    description: Mapped[str] = mapped_column(String(255), nullable=False)


class DailyInputRecord(Base):
    __tablename__ = "daily_inputs"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    date: Mapped[Any] = mapped_column(Date, nullable=False, unique=True, index=True)
    updated_at: Mapped[Any] = mapped_column(DateTime(timezone=True), nullable=False)
    links: Mapped[list["DailyInputLinkRecord"]] = relationship(
        back_populates="daily_input",
        cascade="all, delete-orphan",
        order_by="DailyInputLinkRecord.sort_order",
    )


class DailyInputLinkRecord(Base):
    __tablename__ = "daily_input_links"
    __table_args__ = (
        UniqueConstraint("daily_input_id", "account_id", "url_hash", name="uq_daily_input_account_url"),
    )
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    daily_input_id: Mapped[int] = mapped_column(ForeignKey("daily_inputs.id", ondelete="CASCADE"), nullable=False)
    account_id: Mapped[str] = mapped_column(ForeignKey("accounts.id"), nullable=False, index=True)
    url: Mapped[str] = mapped_column(String(1024), nullable=False)
    url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    daily_input: Mapped[DailyInputRecord] = relationship(back_populates="links")


class ReportRecord(Base):
    __tablename__ = "reports"
    date: Mapped[Any] = mapped_column(Date, primary_key=True)
    generated_at: Mapped[Any] = mapped_column(DateTime(timezone=True), nullable=False)
    summary: Mapped[str] = mapped_column(Text, nullable=False)
    focus_sectors: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    article_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    account_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    articles: Mapped[list["ReportArticleRecord"]] = relationship(
        back_populates="report",
        cascade="all, delete-orphan",
        order_by="ReportArticleRecord.sort_order",
    )


class ReportArticleRecord(Base):
    __tablename__ = "report_articles"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    report_date: Mapped[Any] = mapped_column(ForeignKey("reports.date", ondelete="CASCADE"), nullable=False, index=True)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    article_id: Mapped[str] = mapped_column(String(128), nullable=False)
    account_id: Mapped[str] = mapped_column(String(64), nullable=False)
    account_name: Mapped[str] = mapped_column(String(128), nullable=False)
    title: Mapped[str] = mapped_column(String(255), nullable=False)
    published_at: Mapped[str] = mapped_column(String(64), nullable=False)
    summary: Mapped[str] = mapped_column(Text, nullable=False)
    source_url: Mapped[str] = mapped_column(Text, nullable=False)
    sectors: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    sentiment: Mapped[str] = mapped_column(String(16), nullable=False)
    article_type: Mapped[str] = mapped_column(String(64), nullable=False)
    report: Mapped[ReportRecord] = relationship(back_populates="articles")


class ClsNewsSnapshotRecord(Base):
    __tablename__ = "cls_news_snapshots"
    date: Mapped[Any] = mapped_column(Date, primary_key=True)
    updated_at: Mapped[Any] = mapped_column(DateTime(timezone=True), nullable=False)
    window_label: Mapped[str] = mapped_column(String(64), nullable=False)
    overview: Mapped[str] = mapped_column(Text, nullable=False)
    hot_topics: Mapped[str] = mapped_column(Text, nullable=False)
    watch_list: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    sector_impacts: Mapped[list["ClsSectorImpactRecord"]] = relationship(
        back_populates="snapshot",
        cascade="all, delete-orphan",
        order_by="ClsSectorImpactRecord.sort_order",
    )
    items: Mapped[list["ClsNewsItemRecord"]] = relationship(
        back_populates="snapshot",
        cascade="all, delete-orphan",
        order_by="ClsNewsItemRecord.sort_order",
    )


class ClsSectorImpactRecord(Base):
    __tablename__ = "cls_sector_impacts"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    snapshot_date: Mapped[Any] = mapped_column(ForeignKey("cls_news_snapshots.date", ondelete="CASCADE"), nullable=False, index=True)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    sector: Mapped[str] = mapped_column(String(64), nullable=False)
    sentiment: Mapped[str] = mapped_column(String(16), nullable=False)
    reason: Mapped[str] = mapped_column(Text, nullable=False)
    related_titles: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    snapshot: Mapped[ClsNewsSnapshotRecord] = relationship(back_populates="sector_impacts")


class ClsNewsItemRecord(Base):
    __tablename__ = "cls_news_items"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    snapshot_date: Mapped[Any] = mapped_column(ForeignKey("cls_news_snapshots.date", ondelete="CASCADE"), nullable=False, index=True)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    item_id: Mapped[str] = mapped_column(String(128), nullable=False)
    title: Mapped[str] = mapped_column(String(255), nullable=False)
    published_at: Mapped[str] = mapped_column(String(64), nullable=False)
    source: Mapped[str] = mapped_column(String(128), nullable=False)
    summary: Mapped[str] = mapped_column(Text, nullable=False)
    reference_url: Mapped[str] = mapped_column(Text, nullable=False)
    sectors: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    sentiment: Mapped[str] = mapped_column(String(16), nullable=False)
    snapshot: Mapped[ClsNewsSnapshotRecord] = relationship(back_populates="items")


def load_database_config() -> dict[str, Any]:
    """Load and validate MySQL connection settings from backend/config/database.json."""
    if not DATABASE_CONFIG_PATH.exists():
        raise RuntimeError(f"Database config not found: {DATABASE_CONFIG_PATH}")
    config = json.loads(DATABASE_CONFIG_PATH.read_text(encoding="utf-8"))
    required_fields = ("host", "port", "username", "password", "database")
    missing = [field for field in required_fields if not config.get(field)]
    if missing:
        raise RuntimeError(f"Database config is incomplete: {', '.join(missing)}")
    if any(str(config.get(field)).strip() == "CHANGE_ME" for field in required_fields):
        raise RuntimeError(f"Database config still contains placeholders: {DATABASE_CONFIG_PATH}")
    config.setdefault("charset", "utf8mb4")
    config.setdefault("echo", False)
    return config
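
# For reference, database.json is expected to look roughly like this (the values
# below are illustrative placeholders, not the project's real settings):
#
#   {
#     "host": "127.0.0.1",
#     "port": 3306,
#     "username": "app_user",
#     "password": "secret",
#     "database": "cjfx",
#     "charset": "utf8mb4",
#     "echo": false
#   }
#
# "charset" and "echo" are optional and default to "utf8mb4" and false; any required
# field left empty or still set to "CHANGE_ME" is rejected by load_database_config.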


_ENGINE = None
_SESSION_FACTORY: sessionmaker[Session] | None = None


def parse_date(value: str) -> date:
    return date.fromisoformat(value)


def parse_datetime(value: str) -> datetime:
    return datetime.fromisoformat(value)


def hash_url(value: str) -> str:
    return hashlib.sha256(value.encode("utf-8")).hexdigest()


def get_engine():
    """Return the process-wide SQLAlchemy engine, creating it on first use."""
    global _ENGINE
    if _ENGINE is None:
        config = load_database_config()
        _ENGINE = create_engine(
            (
                f"mysql+pymysql://{config['username']}:{config['password']}"
                f"@{config['host']}:{config['port']}/{config['database']}?charset={config['charset']}"
            ),
            echo=bool(config.get("echo", False)),
            future=True,
        )
    return _ENGINE


def get_session_factory() -> sessionmaker[Session]:
    """Return the shared session factory bound to the lazily created engine."""
    global _SESSION_FACTORY
    if _SESSION_FACTORY is None:
        _SESSION_FACTORY = sessionmaker(bind=get_engine(), autoflush=False, autocommit=False, future=True)
    return _SESSION_FACTORY


def init_database() -> None:
    """Create all tables defined on Base if they do not already exist."""
    Base.metadata.create_all(get_engine())


@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional session: commit on success, roll back on error."""
    session = get_session_factory()()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
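
# A minimal usage sketch of session_scope (hypothetical caller code): the block
# commits when it exits normally and rolls back if an exception escapes.
#
#   with session_scope() as session:
#       session.add(AccountRecord(id="acct-1", name="Example", description="Demo"))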


def save_accounts(accounts: list[Account]) -> None:
    """Insert new accounts and update the name/description of existing ones."""
    with session_scope() as session:
        for account in accounts:
            existing = session.get(AccountRecord, account.id)
            if existing is None:
                session.add(AccountRecord(id=account.id, name=account.name, description=account.description))
                continue
            existing.name = account.name
            existing.description = account.description


def fetch_accounts() -> list[Account]:
    """Return all accounts ordered by id."""
    with session_scope() as session:
        records = session.scalars(select(AccountRecord).order_by(AccountRecord.id)).all()
        return [Account(id=record.id, name=record.name, description=record.description) for record in records]


def fetch_daily_input_document(date_str: str) -> DailyInputDocument | None:
    """Load the daily link input for a date, or None if nothing was saved."""
    with session_scope() as session:
        record = session.scalar(select(DailyInputRecord).where(DailyInputRecord.date == parse_date(date_str)))
        if record is None:
            return None
        account_records = session.scalars(select(AccountRecord).order_by(AccountRecord.id)).all()
        links_by_account: dict[str, list[str]] = {}
        for link in record.links:
            links_by_account.setdefault(link.account_id, []).append(link.url)
        return DailyInputDocument(
            date=str(record.date),
            updated_at=record.updated_at.isoformat(timespec="seconds"),
            accounts=[
                DailyInputAccount(
                    account_id=account.id,
                    account_name=account.name,
                    links=links_by_account.get(account.id, []),
                )
                for account in account_records
            ],
        )


def save_daily_input_document(document: DailyInputDocument) -> DailyInputDocument:
    """Create or replace the stored links for the document's date and return the document."""
    with session_scope() as session:
        record = session.scalar(select(DailyInputRecord).where(DailyInputRecord.date == parse_date(document.date)))
        if record is None:
            record = DailyInputRecord(date=parse_date(document.date), updated_at=parse_datetime(document.updated_at))
            session.add(record)
            session.flush()
        else:
            record.updated_at = parse_datetime(document.updated_at)
            record.links.clear()
            session.flush()
        sort_order = 0
        for account in document.accounts:
            for url in account.links:
                record.links.append(
                    DailyInputLinkRecord(
                        account_id=account.account_id,
                        url=url,
                        url_hash=hash_url(url),
                        sort_order=sort_order,
                    )
                )
                sort_order += 1
        return document


def fetch_report_document(date_str: str) -> ReportDocument | None:
    """Load the generated report for a date, or None if it does not exist."""
    with session_scope() as session:
        record = session.get(ReportRecord, parse_date(date_str))
        if record is None:
            return None
        return ReportDocument(
            date=str(record.date),
            generated_at=record.generated_at.isoformat(timespec="seconds"),
            summary=record.summary,
            focus_sectors=list(record.focus_sectors or []),
            article_count=record.article_count,
            account_count=record.account_count,
            articles=[
                OpinionArticle(
                    id=article.article_id,
                    account_id=article.account_id,
                    account_name=article.account_name,
                    title=article.title,
                    published_at=article.published_at,
                    summary=article.summary,
                    source_url=article.source_url,
                    sectors=list(article.sectors or []),
                    sentiment=article.sentiment,
                    article_type=article.article_type,
                )
                for article in record.articles
            ],
        )


def save_report_document(document: ReportDocument) -> ReportDocument:
    """Create or replace the report for the document's date and return the document."""
    with session_scope() as session:
        record = session.get(ReportRecord, parse_date(document.date))
        if record is None:
            record = ReportRecord(
                date=parse_date(document.date),
                generated_at=parse_datetime(document.generated_at),
                summary=document.summary,
                focus_sectors=document.focus_sectors,
                article_count=document.article_count,
                account_count=document.account_count,
            )
            session.add(record)
        else:
            record.generated_at = parse_datetime(document.generated_at)
            record.summary = document.summary
            record.focus_sectors = document.focus_sectors
            record.article_count = document.article_count
            record.account_count = document.account_count
            record.articles.clear()
        session.flush()
        for index, article in enumerate(document.articles):
            record.articles.append(
                ReportArticleRecord(
                    sort_order=index,
                    article_id=article.id,
                    account_id=article.account_id,
                    account_name=article.account_name,
                    title=article.title,
                    published_at=article.published_at,
                    summary=article.summary,
                    source_url=article.source_url,
                    sectors=article.sectors,
                    sentiment=article.sentiment,
                    article_type=article.article_type,
                )
            )
        return document


def fetch_report_list() -> list[ReportListItem]:
    """Return summary rows for all stored reports, newest first."""
    with session_scope() as session:
        records = session.scalars(select(ReportRecord).order_by(ReportRecord.date.desc())).all()
        return [
            ReportListItem(
                date=str(record.date),
                generated_at=record.generated_at.isoformat(timespec="seconds"),
                summary=record.summary,
                article_count=record.article_count,
                focus_sectors=list(record.focus_sectors or []),
            )
            for record in records
        ]


def fetch_cls_news_document(date_str: str) -> ClsNewsDocument | None:
    """Load the CLS news snapshot for a date, or None if it does not exist."""
    with session_scope() as session:
        record = session.get(ClsNewsSnapshotRecord, parse_date(date_str))
        if record is None:
            return None
        return ClsNewsDocument(
            date=str(record.date),
            updated_at=record.updated_at.isoformat(timespec="seconds"),
            window_label=record.window_label,
            summary=ClsNewsSummary(
                overview=record.overview,
                hot_topics=record.hot_topics,
                watch_list=list(record.watch_list or []),
            ),
            sector_impacts=[
                ClsSectorImpact(
                    sector=item.sector,
                    sentiment=item.sentiment,
                    reason=item.reason,
                    related_titles=list(item.related_titles or []),
                )
                for item in record.sector_impacts
            ],
            items=[
                ClsNewsItem(
                    id=item.item_id,
                    title=item.title,
                    published_at=item.published_at,
                    source=item.source,
                    summary=item.summary,
                    reference_url=item.reference_url,
                    sectors=list(item.sectors or []),
                    sentiment=item.sentiment,
                )
                for item in record.items
            ],
        )


def save_cls_news_document(document: ClsNewsDocument) -> ClsNewsDocument:
    """Create or replace the CLS news snapshot for the document's date and return the document."""
    with session_scope() as session:
        record = session.get(ClsNewsSnapshotRecord, parse_date(document.date))
        if record is None:
            record = ClsNewsSnapshotRecord(
                date=parse_date(document.date),
                updated_at=parse_datetime(document.updated_at),
                window_label=document.window_label,
                overview=document.summary.overview,
                hot_topics=document.summary.hot_topics,
                watch_list=document.summary.watch_list,
            )
            session.add(record)
        else:
            record.updated_at = parse_datetime(document.updated_at)
            record.window_label = document.window_label
            record.overview = document.summary.overview
            record.hot_topics = document.summary.hot_topics
            record.watch_list = document.summary.watch_list
            session.execute(
                delete(ClsSectorImpactRecord).where(ClsSectorImpactRecord.snapshot_date == record.date)
            )
            session.execute(
                delete(ClsNewsItemRecord).where(ClsNewsItemRecord.snapshot_date == record.date)
            )
            record.sector_impacts = []
            record.items = []
        session.flush()
        for index, impact in enumerate(document.sector_impacts):
            record.sector_impacts.append(
                ClsSectorImpactRecord(
                    sort_order=index,
                    sector=impact.sector,
                    sentiment=impact.sentiment,
                    reason=impact.reason,
                    related_titles=impact.related_titles,
                )
            )
        for index, item in enumerate(document.items):
            record.items.append(
                ClsNewsItemRecord(
                    sort_order=index,
                    item_id=item.id,
                    title=item.title,
                    published_at=item.published_at,
                    source=item.source,
                    summary=item.summary,
                    reference_url=item.reference_url,
                    sectors=item.sectors,
                    sentiment=item.sentiment,
                )
            )
        return document
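

# Illustrative end-to-end sketch (hypothetical caller code, assuming a populated
# backend/config/database.json and the document models imported from app.models):
#
#   init_database()
#   save_accounts([Account(id="acct-1", name="Example", description="Demo account")])
#   report = fetch_report_document("2026-03-20")  # None until a report is saved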