Initial commit
This commit is contained in:
480
backend/app/services/storage.py
Normal file
480
backend/app/services/storage.py
Normal file
@ -0,0 +1,480 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator
|
||||
|
||||
from sqlalchemy import JSON, Date, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, create_engine, delete, select
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship, sessionmaker
|
||||
|
||||
from app.models import (
|
||||
Account,
|
||||
ClsNewsDocument,
|
||||
ClsNewsItem,
|
||||
ClsNewsSummary,
|
||||
ClsSectorImpact,
|
||||
DailyInputAccount,
|
||||
DailyInputDocument,
|
||||
OpinionArticle,
|
||||
ReportDocument,
|
||||
ReportListItem,
|
||||
)
|
||||
|
||||
# This file lives at backend/app/services/storage.py, so three parents up
# is the repository root.
PROJECT_ROOT = Path(__file__).resolve().parents[3]
CONFIG_DIR = PROJECT_ROOT / "backend" / "config"
# JSON file holding MySQL connection settings; validated by load_database_config().
DATABASE_CONFIG_PATH = CONFIG_DIR / "database.json"
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models in this module."""

    pass
|
||||
|
||||
|
||||
class AccountRecord(Base):
    """ORM row for a tracked account (table ``accounts``)."""

    __tablename__ = "accounts"

    # Caller-supplied string id (not auto-generated).
    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    description: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
|
||||
class DailyInputRecord(Base):
    """One row per calendar day of collected input links (table ``daily_inputs``)."""

    __tablename__ = "daily_inputs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Unique: at most one daily-input row per date.
    date: Mapped[Any] = mapped_column(Date, nullable=False, unique=True, index=True)
    updated_at: Mapped[Any] = mapped_column(DateTime(timezone=True), nullable=False)

    # Child link rows, ordered by sort_order; orphans are deleted with the parent.
    links: Mapped[list["DailyInputLinkRecord"]] = relationship(
        back_populates="daily_input",
        cascade="all, delete-orphan",
        order_by="DailyInputLinkRecord.sort_order",
    )
|
||||
|
||||
|
||||
class DailyInputLinkRecord(Base):
    """A single URL submitted for an account on a given day (table ``daily_input_links``)."""

    __tablename__ = "daily_input_links"
    # A given URL (identified by its hash) may appear at most once per
    # account within one daily input.
    __table_args__ = (
        UniqueConstraint("daily_input_id", "account_id", "url_hash", name="uq_daily_input_account_url"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    daily_input_id: Mapped[int] = mapped_column(ForeignKey("daily_inputs.id", ondelete="CASCADE"), nullable=False)
    account_id: Mapped[str] = mapped_column(ForeignKey("accounts.id"), nullable=False, index=True)
    url: Mapped[str] = mapped_column(String(1024), nullable=False)
    # SHA-256 hex digest of url; used in the unique constraint because the
    # full URL column is too long to index reliably.
    url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

    daily_input: Mapped[DailyInputRecord] = relationship(back_populates="links")
|
||||
|
||||
|
||||
class ReportRecord(Base):
    """Generated daily report, keyed by date (table ``reports``)."""

    __tablename__ = "reports"

    # The report date itself is the primary key: one report per day.
    date: Mapped[Any] = mapped_column(Date, primary_key=True)
    generated_at: Mapped[Any] = mapped_column(DateTime(timezone=True), nullable=False)
    summary: Mapped[str] = mapped_column(Text, nullable=False)
    # JSON array of sector names highlighted by the report.
    focus_sectors: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    article_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    account_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

    # Article rows, ordered by sort_order; orphans are deleted with the parent.
    articles: Mapped[list["ReportArticleRecord"]] = relationship(
        back_populates="report",
        cascade="all, delete-orphan",
        order_by="ReportArticleRecord.sort_order",
    )
|
||||
|
||||
|
||||
class ReportArticleRecord(Base):
    """A single article included in a daily report (table ``report_articles``).

    Account id/name are denormalized here so a report remains readable even
    if the account row later changes.
    """

    __tablename__ = "report_articles"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    report_date: Mapped[Any] = mapped_column(ForeignKey("reports.date", ondelete="CASCADE"), nullable=False, index=True)
    # Preserves the article order of the source document.
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    article_id: Mapped[str] = mapped_column(String(128), nullable=False)
    account_id: Mapped[str] = mapped_column(String(64), nullable=False)
    account_name: Mapped[str] = mapped_column(String(128), nullable=False)
    title: Mapped[str] = mapped_column(String(255), nullable=False)
    # Stored as the original string representation, not parsed to DateTime.
    published_at: Mapped[str] = mapped_column(String(64), nullable=False)
    summary: Mapped[str] = mapped_column(Text, nullable=False)
    source_url: Mapped[str] = mapped_column(Text, nullable=False)
    # JSON array of sector names.
    sectors: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    sentiment: Mapped[str] = mapped_column(String(16), nullable=False)
    article_type: Mapped[str] = mapped_column(String(64), nullable=False)

    report: Mapped[ReportRecord] = relationship(back_populates="articles")
|
||||
|
||||
|
||||
class ClsNewsSnapshotRecord(Base):
    """Daily CLS news snapshot, keyed by date (table ``cls_news_snapshots``)."""

    __tablename__ = "cls_news_snapshots"

    # One snapshot per day.
    date: Mapped[Any] = mapped_column(Date, primary_key=True)
    updated_at: Mapped[Any] = mapped_column(DateTime(timezone=True), nullable=False)
    # Human-readable label of the time window this snapshot covers.
    window_label: Mapped[str] = mapped_column(String(64), nullable=False)
    overview: Mapped[str] = mapped_column(Text, nullable=False)
    hot_topics: Mapped[str] = mapped_column(Text, nullable=False)
    # JSON array of watch-list entries.
    watch_list: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)

    # Ordered children; orphans are deleted with the parent snapshot.
    sector_impacts: Mapped[list["ClsSectorImpactRecord"]] = relationship(
        back_populates="snapshot",
        cascade="all, delete-orphan",
        order_by="ClsSectorImpactRecord.sort_order",
    )
    items: Mapped[list["ClsNewsItemRecord"]] = relationship(
        back_populates="snapshot",
        cascade="all, delete-orphan",
        order_by="ClsNewsItemRecord.sort_order",
    )
|
||||
|
||||
|
||||
class ClsSectorImpactRecord(Base):
    """Per-sector impact assessment attached to a snapshot (table ``cls_sector_impacts``)."""

    __tablename__ = "cls_sector_impacts"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    snapshot_date: Mapped[Any] = mapped_column(ForeignKey("cls_news_snapshots.date", ondelete="CASCADE"), nullable=False, index=True)
    # Preserves the order of impacts in the source document.
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    sector: Mapped[str] = mapped_column(String(64), nullable=False)
    sentiment: Mapped[str] = mapped_column(String(16), nullable=False)
    reason: Mapped[str] = mapped_column(Text, nullable=False)
    # JSON array of related news titles backing this assessment.
    related_titles: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)

    snapshot: Mapped[ClsNewsSnapshotRecord] = relationship(back_populates="sector_impacts")
|
||||
|
||||
|
||||
class ClsNewsItemRecord(Base):
    """A single news item attached to a snapshot (table ``cls_news_items``)."""

    __tablename__ = "cls_news_items"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    snapshot_date: Mapped[Any] = mapped_column(ForeignKey("cls_news_snapshots.date", ondelete="CASCADE"), nullable=False, index=True)
    # Preserves the item order of the source document.
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    # External identifier of the item as provided by the source document.
    item_id: Mapped[str] = mapped_column(String(128), nullable=False)
    title: Mapped[str] = mapped_column(String(255), nullable=False)
    # Stored as the original string representation, not parsed to DateTime.
    published_at: Mapped[str] = mapped_column(String(64), nullable=False)
    source: Mapped[str] = mapped_column(String(128), nullable=False)
    summary: Mapped[str] = mapped_column(Text, nullable=False)
    reference_url: Mapped[str] = mapped_column(Text, nullable=False)
    # JSON array of sector names.
    sectors: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    sentiment: Mapped[str] = mapped_column(String(16), nullable=False)

    snapshot: Mapped[ClsNewsSnapshotRecord] = relationship(back_populates="items")
|
||||
|
||||
|
||||
def load_database_config() -> dict[str, Any]:
    """Read and validate the MySQL connection settings JSON file.

    Returns the parsed configuration with ``charset`` and ``echo`` defaults
    applied.

    Raises:
        RuntimeError: when the file is missing, a required field is absent
            or empty, or any required field still holds the ``CHANGE_ME``
            placeholder.
    """
    if not DATABASE_CONFIG_PATH.exists():
        raise RuntimeError(f"Database config not found: {DATABASE_CONFIG_PATH}")

    raw_text = DATABASE_CONFIG_PATH.read_text(encoding="utf-8")
    config = json.loads(raw_text)

    required_fields = ("host", "port", "username", "password", "database")
    missing = []
    for field in required_fields:
        if not config.get(field):
            missing.append(field)
    if missing:
        raise RuntimeError(f"Database config is incomplete: {', '.join(missing)}")

    placeholder_present = any(
        str(config.get(field)).strip() == "CHANGE_ME" for field in required_fields
    )
    if placeholder_present:
        raise RuntimeError(f"Database config still contains placeholders: {DATABASE_CONFIG_PATH}")

    # Sensible defaults for optional settings.
    config.setdefault("charset", "utf8mb4")
    config.setdefault("echo", False)
    return config
|
||||
|
||||
|
||||
# Process-wide singletons, created lazily by get_engine() and
# get_session_factory() on first use.
_ENGINE = None
_SESSION_FACTORY: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
def parse_date(value: str) -> date:
    """Convert an ISO-8601 date string (``YYYY-MM-DD``) to a ``date``."""
    parsed = date.fromisoformat(value)
    return parsed
|
||||
|
||||
|
||||
def parse_datetime(value: str) -> datetime:
    """Convert an ISO-8601 datetime string to a ``datetime`` object."""
    parsed = datetime.fromisoformat(value)
    return parsed
|
||||
|
||||
|
||||
def hash_url(value: str) -> str:
    """Return the hex SHA-256 digest of *value* (UTF-8 encoded)."""
    digest = hashlib.sha256(value.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def get_engine():
    """Return the shared SQLAlchemy engine, creating it on first call.

    The connection URL is built from load_database_config(); subsequent
    calls reuse the cached engine.
    """
    global _ENGINE
    if _ENGINE is not None:
        return _ENGINE

    config = load_database_config()
    url = (
        f"mysql+pymysql://{config['username']}:{config['password']}"
        f"@{config['host']}:{config['port']}/{config['database']}?charset={config['charset']}"
    )
    _ENGINE = create_engine(
        url,
        echo=bool(config.get("echo", False)),
        future=True,
    )
    return _ENGINE
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker[Session]:
    """Return the shared sessionmaker, creating it on first call."""
    global _SESSION_FACTORY
    if _SESSION_FACTORY is not None:
        return _SESSION_FACTORY

    _SESSION_FACTORY = sessionmaker(bind=get_engine(), autoflush=False, autocommit=False, future=True)
    return _SESSION_FACTORY
|
||||
|
||||
|
||||
def init_database() -> None:
    """Create every table defined on Base.metadata if it does not already exist."""
    Base.metadata.create_all(get_engine())
|
||||
|
||||
|
||||
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional session scope.

    Commits when the caller's block finishes without raising; rolls back
    (and re-raises) on any exception, including a failed commit. The
    session is always closed.
    """
    session = get_session_factory()()
    try:
        yield session
        # Commit is inside the try so a commit failure also triggers rollback.
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
|
||||
|
||||
|
||||
def save_accounts(accounts: list[Account]) -> None:
    """Insert new accounts and update existing ones by primary key."""
    with session_scope() as session:
        for account in accounts:
            record = session.get(AccountRecord, account.id)
            if record is not None:
                # Refresh mutable fields of the existing row.
                record.name = account.name
                record.description = account.description
            else:
                session.add(AccountRecord(id=account.id, name=account.name, description=account.description))
|
||||
|
||||
|
||||
def fetch_accounts() -> list[Account]:
    """Load every account, ordered by id."""
    with session_scope() as session:
        rows = session.scalars(select(AccountRecord).order_by(AccountRecord.id)).all()
        accounts = []
        for row in rows:
            accounts.append(Account(id=row.id, name=row.name, description=row.description))
        return accounts
|
||||
|
||||
|
||||
def fetch_daily_input_document(date_str: str) -> DailyInputDocument | None:
    """Return the daily-input document for *date_str*, or None when absent.

    Every known account appears in the result; accounts with no stored
    links get an empty link list.
    """
    target_date = parse_date(date_str)
    with session_scope() as session:
        daily = session.scalar(select(DailyInputRecord).where(DailyInputRecord.date == target_date))
        if daily is None:
            return None

        all_accounts = session.scalars(select(AccountRecord).order_by(AccountRecord.id)).all()

        # Group stored URLs by account, preserving sort order.
        grouped: dict[str, list[str]] = {}
        for link in daily.links:
            grouped.setdefault(link.account_id, []).append(link.url)

        account_entries = []
        for account in all_accounts:
            account_entries.append(
                DailyInputAccount(
                    account_id=account.id,
                    account_name=account.name,
                    links=grouped.get(account.id, []),
                )
            )

        return DailyInputDocument(
            date=str(daily.date),
            updated_at=daily.updated_at.isoformat(timespec="seconds"),
            accounts=account_entries,
        )
|
||||
|
||||
|
||||
def save_daily_input_document(document: DailyInputDocument) -> DailyInputDocument:
    """Upsert the daily-input row for ``document.date`` and rewrite its links.

    Existing links are discarded and replaced by the document's links in
    order. Duplicate URLs within the same account are skipped: the
    ``uq_daily_input_account_url`` constraint on daily_input_links would
    otherwise raise an IntegrityError at commit time.

    Returns the document unchanged for caller convenience.
    """
    with session_scope() as session:
        record = session.scalar(select(DailyInputRecord).where(DailyInputRecord.date == parse_date(document.date)))
        if record is None:
            record = DailyInputRecord(date=parse_date(document.date), updated_at=parse_datetime(document.updated_at))
            session.add(record)
            session.flush()
        else:
            record.updated_at = parse_datetime(document.updated_at)
            # delete-orphan cascade removes the old link rows on flush.
            record.links.clear()
            session.flush()

        sort_order = 0
        for account in document.accounts:
            # Track hashes per account: uniqueness is scoped to
            # (daily_input_id, account_id, url_hash).
            seen_hashes: set[str] = set()
            for url in account.links:
                url_hash = hash_url(url)
                if url_hash in seen_hashes:
                    continue
                seen_hashes.add(url_hash)
                record.links.append(
                    DailyInputLinkRecord(
                        account_id=account.account_id,
                        url=url,
                        url_hash=url_hash,
                        sort_order=sort_order,
                    )
                )
                sort_order += 1
        return document
|
||||
|
||||
|
||||
def fetch_report_document(date_str: str) -> ReportDocument | None:
    """Return the stored report for *date_str*, or None when absent."""
    with session_scope() as session:
        report = session.get(ReportRecord, parse_date(date_str))
        if report is None:
            return None

        # Rebuild articles in their stored sort order.
        article_models = []
        for row in report.articles:
            article_models.append(
                OpinionArticle(
                    id=row.article_id,
                    account_id=row.account_id,
                    account_name=row.account_name,
                    title=row.title,
                    published_at=row.published_at,
                    summary=row.summary,
                    source_url=row.source_url,
                    sectors=list(row.sectors or []),
                    sentiment=row.sentiment,
                    article_type=row.article_type,
                )
            )

        return ReportDocument(
            date=str(report.date),
            generated_at=report.generated_at.isoformat(timespec="seconds"),
            summary=report.summary,
            focus_sectors=list(report.focus_sectors or []),
            article_count=report.article_count,
            account_count=report.account_count,
            articles=article_models,
        )
|
||||
|
||||
|
||||
def save_report_document(document: ReportDocument) -> ReportDocument:
    """Upsert the report row for ``document.date`` and rewrite its articles.

    Returns the document unchanged for caller convenience.
    """
    report_date = parse_date(document.date)
    with session_scope() as session:
        record = session.get(ReportRecord, report_date)
        if record is not None:
            # Refresh scalar fields, then drop the old article rows
            # (delete-orphan cascade removes them on flush).
            record.generated_at = parse_datetime(document.generated_at)
            record.summary = document.summary
            record.focus_sectors = document.focus_sectors
            record.article_count = document.article_count
            record.account_count = document.account_count
            record.articles.clear()
            session.flush()
        else:
            record = ReportRecord(
                date=report_date,
                generated_at=parse_datetime(document.generated_at),
                summary=document.summary,
                focus_sectors=document.focus_sectors,
                article_count=document.article_count,
                account_count=document.account_count,
            )
            session.add(record)

        for position, article in enumerate(document.articles):
            record.articles.append(
                ReportArticleRecord(
                    sort_order=position,
                    article_id=article.id,
                    account_id=article.account_id,
                    account_name=article.account_name,
                    title=article.title,
                    published_at=article.published_at,
                    summary=article.summary,
                    source_url=article.source_url,
                    sectors=article.sectors,
                    sentiment=article.sentiment,
                    article_type=article.article_type,
                )
            )
        return document
|
||||
|
||||
|
||||
def fetch_report_list() -> list[ReportListItem]:
    """Return a summary entry for every stored report, newest first."""
    with session_scope() as session:
        rows = session.scalars(select(ReportRecord).order_by(ReportRecord.date.desc())).all()
        items: list[ReportListItem] = []
        for row in rows:
            items.append(
                ReportListItem(
                    date=str(row.date),
                    generated_at=row.generated_at.isoformat(timespec="seconds"),
                    summary=row.summary,
                    article_count=row.article_count,
                    focus_sectors=list(row.focus_sectors or []),
                )
            )
        return items
|
||||
|
||||
|
||||
def fetch_cls_news_document(date_str: str) -> ClsNewsDocument | None:
    """Return the CLS news snapshot for *date_str*, or None when absent.

    Rebuilds the full document — summary, sector impacts, and news items —
    from the snapshot row and its ordered child rows.
    """
    with session_scope() as session:
        record = session.get(ClsNewsSnapshotRecord, parse_date(date_str))
        if record is None:
            return None
        return ClsNewsDocument(
            date=str(record.date),
            updated_at=record.updated_at.isoformat(timespec="seconds"),
            window_label=record.window_label,
            summary=ClsNewsSummary(
                overview=record.overview,
                hot_topics=record.hot_topics,
                # list(... or []) makes a defensive copy and tolerates NULL
                # JSON columns.
                watch_list=list(record.watch_list or []),
            ),
            # Children come back ordered by sort_order (relationship order_by).
            sector_impacts=[
                ClsSectorImpact(
                    sector=item.sector,
                    sentiment=item.sentiment,
                    reason=item.reason,
                    related_titles=list(item.related_titles or []),
                )
                for item in record.sector_impacts
            ],
            items=[
                ClsNewsItem(
                    id=item.item_id,
                    title=item.title,
                    published_at=item.published_at,
                    source=item.source,
                    summary=item.summary,
                    reference_url=item.reference_url,
                    sectors=list(item.sectors or []),
                    sentiment=item.sentiment,
                )
                for item in record.items
            ],
        )
|
||||
|
||||
|
||||
def save_cls_news_document(document: ClsNewsDocument) -> ClsNewsDocument:
    """Upsert the CLS news snapshot for ``document.date`` and rewrite its children.

    On update, child rows are removed with explicit bulk DELETEs (rather
    than relying on the ORM delete-orphan cascade) and then re-inserted in
    document order. Returns the document unchanged for caller convenience.
    """
    with session_scope() as session:
        record = session.get(ClsNewsSnapshotRecord, parse_date(document.date))
        if record is None:
            record = ClsNewsSnapshotRecord(
                date=parse_date(document.date),
                updated_at=parse_datetime(document.updated_at),
                window_label=document.window_label,
                overview=document.summary.overview,
                hot_topics=document.summary.hot_topics,
                watch_list=document.summary.watch_list,
            )
            session.add(record)
        else:
            record.updated_at = parse_datetime(document.updated_at)
            record.window_label = document.window_label
            record.overview = document.summary.overview
            record.hot_topics = document.summary.hot_topics
            record.watch_list = document.summary.watch_list
            # Bulk-delete existing children directly in the database.
            session.execute(
                delete(ClsSectorImpactRecord).where(ClsSectorImpactRecord.snapshot_date == record.date)
            )
            session.execute(
                delete(ClsNewsItemRecord).where(ClsNewsItemRecord.snapshot_date == record.date)
            )
            # Reset the in-memory collections so the session state matches
            # the bulk deletes before new children are appended.
            record.sector_impacts = []
            record.items = []
            session.flush()

        for index, impact in enumerate(document.sector_impacts):
            record.sector_impacts.append(
                ClsSectorImpactRecord(
                    sort_order=index,
                    sector=impact.sector,
                    sentiment=impact.sentiment,
                    reason=impact.reason,
                    related_titles=impact.related_titles,
                )
            )

        for index, item in enumerate(document.items):
            record.items.append(
                ClsNewsItemRecord(
                    sort_order=index,
                    item_id=item.id,
                    title=item.title,
                    published_at=item.published_at,
                    source=item.source,
                    summary=item.summary,
                    reference_url=item.reference_url,
                    sectors=item.sectors,
                    sentiment=item.sentiment,
                )
            )
        return document
|
||||
Reference in New Issue
Block a user