from __future__ import annotations

import hashlib
import json
from contextlib import contextmanager
from datetime import date, datetime
from pathlib import Path
from typing import Any, Iterator

from sqlalchemy import JSON, Date, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship, sessionmaker

from app.models import (
    Account,
    ClsNewsDocument,
    ClsNewsItem,
    ClsNewsSummary,
    ClsSectorImpact,
    DailyInputAccount,
    DailyInputDocument,
    OpinionArticle,
    ReportDocument,
    ReportListItem,
)

PROJECT_ROOT = Path(__file__).resolve().parents[3]
CONFIG_DIR = PROJECT_ROOT / "backend" / "config"
DATABASE_CONFIG_PATH = CONFIG_DIR / "database.json"


class Base(DeclarativeBase):
    pass


class AccountRecord(Base):
    __tablename__ = "accounts"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    description: Mapped[str] = mapped_column(String(255), nullable=False)


class DailyInputRecord(Base):
    __tablename__ = "daily_inputs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    date: Mapped[Any] = mapped_column(Date, nullable=False, unique=True, index=True)
    updated_at: Mapped[Any] = mapped_column(DateTime(timezone=True), nullable=False)

    links: Mapped[list["DailyInputLinkRecord"]] = relationship(
        back_populates="daily_input",
        cascade="all, delete-orphan",
        order_by="DailyInputLinkRecord.sort_order",
    )


class DailyInputLinkRecord(Base):
    __tablename__ = "daily_input_links"
    __table_args__ = (
        UniqueConstraint("daily_input_id", "account_id", "url_hash", name="uq_daily_input_account_url"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    daily_input_id: Mapped[int] = mapped_column(ForeignKey("daily_inputs.id", ondelete="CASCADE"), nullable=False)
    account_id: Mapped[str] = mapped_column(ForeignKey("accounts.id"), nullable=False, index=True)
    url: Mapped[str] = mapped_column(String(1024), nullable=False)
    url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

    daily_input: Mapped[DailyInputRecord] = relationship(back_populates="links")
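
# Note: the uq_daily_input_account_url constraint above deduplicates a URL per account within
# a single daily input. url_hash stores the SHA-256 hex digest of the URL (see hash_url below),
# which keeps the constrained column at a fixed 64 characters even for very long URLs.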


class ReportRecord(Base):
    __tablename__ = "reports"

    date: Mapped[Any] = mapped_column(Date, primary_key=True)
    generated_at: Mapped[Any] = mapped_column(DateTime(timezone=True), nullable=False)
    summary: Mapped[str] = mapped_column(Text, nullable=False)
    focus_sectors: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    article_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    account_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

    articles: Mapped[list["ReportArticleRecord"]] = relationship(
        back_populates="report",
        cascade="all, delete-orphan",
        order_by="ReportArticleRecord.sort_order",
    )


class ReportArticleRecord(Base):
    __tablename__ = "report_articles"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    report_date: Mapped[Any] = mapped_column(ForeignKey("reports.date", ondelete="CASCADE"), nullable=False, index=True)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    article_id: Mapped[str] = mapped_column(String(128), nullable=False)
    account_id: Mapped[str] = mapped_column(String(64), nullable=False)
    account_name: Mapped[str] = mapped_column(String(128), nullable=False)
    title: Mapped[str] = mapped_column(String(255), nullable=False)
    published_at: Mapped[str] = mapped_column(String(64), nullable=False)
    summary: Mapped[str] = mapped_column(Text, nullable=False)
    source_url: Mapped[str] = mapped_column(Text, nullable=False)
    sectors: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    sentiment: Mapped[str] = mapped_column(String(16), nullable=False)
    article_type: Mapped[str] = mapped_column(String(64), nullable=False)

    report: Mapped[ReportRecord] = relationship(back_populates="articles")


class ClsNewsSnapshotRecord(Base):
    __tablename__ = "cls_news_snapshots"

    date: Mapped[Any] = mapped_column(Date, primary_key=True)
    updated_at: Mapped[Any] = mapped_column(DateTime(timezone=True), nullable=False)
    window_label: Mapped[str] = mapped_column(String(64), nullable=False)
    overview: Mapped[str] = mapped_column(Text, nullable=False)
    hot_topics: Mapped[str] = mapped_column(Text, nullable=False)
    watch_list: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)

    sector_impacts: Mapped[list["ClsSectorImpactRecord"]] = relationship(
        back_populates="snapshot",
        cascade="all, delete-orphan",
        order_by="ClsSectorImpactRecord.sort_order",
    )
    items: Mapped[list["ClsNewsItemRecord"]] = relationship(
        back_populates="snapshot",
        cascade="all, delete-orphan",
        order_by="ClsNewsItemRecord.sort_order",
    )


class ClsSectorImpactRecord(Base):
    __tablename__ = "cls_sector_impacts"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    snapshot_date: Mapped[Any] = mapped_column(ForeignKey("cls_news_snapshots.date", ondelete="CASCADE"), nullable=False, index=True)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    sector: Mapped[str] = mapped_column(String(64), nullable=False)
    sentiment: Mapped[str] = mapped_column(String(16), nullable=False)
    reason: Mapped[str] = mapped_column(Text, nullable=False)
    related_titles: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)

    snapshot: Mapped[ClsNewsSnapshotRecord] = relationship(back_populates="sector_impacts")


class ClsNewsItemRecord(Base):
    __tablename__ = "cls_news_items"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    snapshot_date: Mapped[Any] = mapped_column(ForeignKey("cls_news_snapshots.date", ondelete="CASCADE"), nullable=False, index=True)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    item_id: Mapped[str] = mapped_column(String(128), nullable=False)
    title: Mapped[str] = mapped_column(String(255), nullable=False)
    published_at: Mapped[str] = mapped_column(String(64), nullable=False)
    source: Mapped[str] = mapped_column(String(128), nullable=False)
    summary: Mapped[str] = mapped_column(Text, nullable=False)
    reference_url: Mapped[str] = mapped_column(Text, nullable=False)
    sectors: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    sentiment: Mapped[str] = mapped_column(String(16), nullable=False)

    snapshot: Mapped[ClsNewsSnapshotRecord] = relationship(back_populates="items")
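
# Together, the three cls_news_* tables persist one ClsNewsDocument per calendar day: the
# snapshot row carries the summary fields, while sector impacts and news items live in ordered
# child tables keyed by snapshot_date with ON DELETE CASCADE.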


def load_database_config() -> dict[str, Any]:
    if not DATABASE_CONFIG_PATH.exists():
        raise RuntimeError(f"Database config not found: {DATABASE_CONFIG_PATH}")
    config = json.loads(DATABASE_CONFIG_PATH.read_text(encoding="utf-8"))
    required_fields = ("host", "port", "username", "password", "database")
    missing = [field for field in required_fields if not config.get(field)]
    if missing:
        raise RuntimeError(f"Database config is incomplete: {', '.join(missing)}")
    if any(str(config.get(field)).strip() == "CHANGE_ME" for field in required_fields):
        raise RuntimeError(f"Database config still contains placeholders: {DATABASE_CONFIG_PATH}")
    config.setdefault("charset", "utf8mb4")
    config.setdefault("echo", False)
    return config
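
# For illustration, a minimal backend/config/database.json might look like this (placeholder
# values, not real credentials):
#
#   {
#     "host": "127.0.0.1",
#     "port": 3306,
#     "username": "app_user",
#     "password": "change-me",
#     "database": "opinion_reports"
#   }
#
# "charset" (default "utf8mb4") and "echo" (default false) are optional.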


_ENGINE = None
_SESSION_FACTORY: sessionmaker[Session] | None = None


def parse_date(value: str) -> date:
    return date.fromisoformat(value)


def parse_datetime(value: str) -> datetime:
    return datetime.fromisoformat(value)


def hash_url(value: str) -> str:
    return hashlib.sha256(value.encode("utf-8")).hexdigest()
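
# parse_date/parse_datetime convert the ISO-8601 strings carried by the app.models documents
# into native date/datetime values for MySQL, and hash_url produces the digest stored in
# DailyInputLinkRecord.url_hash for the per-day uniqueness constraint.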


def get_engine():
    global _ENGINE
    if _ENGINE is None:
        config = load_database_config()
        _ENGINE = create_engine(
            (
                f"mysql+pymysql://{config['username']}:{config['password']}"
                f"@{config['host']}:{config['port']}/{config['database']}?charset={config['charset']}"
            ),
            echo=bool(config.get("echo", False)),
            future=True,
        )
    return _ENGINE
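
# The engine is created lazily and cached as a module-level singleton. Credentials are
# interpolated straight into the URL, so a password containing URL-reserved characters would
# need to be escaped (e.g. with urllib.parse.quote_plus) before this would connect cleanly.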


def get_session_factory() -> sessionmaker[Session]:
    global _SESSION_FACTORY
    if _SESSION_FACTORY is None:
        _SESSION_FACTORY = sessionmaker(bind=get_engine(), autoflush=False, autocommit=False, future=True)
    return _SESSION_FACTORY


def init_database() -> None:
    Base.metadata.create_all(get_engine())
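
# Typical bootstrap, as a sketch (the import path is hypothetical and depends on where this
# module lives under the app package):
#
#   from app.storage import init_database, save_accounts
#
#   init_database()          # create any missing tables via Base.metadata.create_all
#   save_accounts(accounts)  # seed or refresh the accounts table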


@contextmanager
def session_scope() -> Iterator[Session]:
    session = get_session_factory()()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
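
# Usage sketch: every helper below wraps its work in this context manager, e.g.
#
#   with session_scope() as session:
#       session.add(some_record)   # some_record is a placeholder for any mapped instance
#
# The session commits on success, rolls back on any exception, and is always closed.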


def save_accounts(accounts: list[Account]) -> None:
    with session_scope() as session:
        for account in accounts:
            existing = session.get(AccountRecord, account.id)
            if existing is None:
                session.add(AccountRecord(id=account.id, name=account.name, description=account.description))
                continue
            existing.name = account.name
            existing.description = account.description


def fetch_accounts() -> list[Account]:
    with session_scope() as session:
        records = session.scalars(select(AccountRecord).order_by(AccountRecord.id)).all()
        return [Account(id=record.id, name=record.name, description=record.description) for record in records]
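
# Example with made-up data: save_accounts([Account(id="acct-1", name="Example", description="Demo")])
# upserts rows keyed on Account.id; fetch_accounts() then returns them ordered by id.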


def fetch_daily_input_document(date_str: str) -> DailyInputDocument | None:
    with session_scope() as session:
        record = session.scalar(select(DailyInputRecord).where(DailyInputRecord.date == parse_date(date_str)))
        if record is None:
            return None
        account_records = session.scalars(select(AccountRecord).order_by(AccountRecord.id)).all()
        links_by_account: dict[str, list[str]] = {}
        for link in record.links:
            links_by_account.setdefault(link.account_id, []).append(link.url)
        return DailyInputDocument(
            date=str(record.date),
            updated_at=record.updated_at.isoformat(timespec="seconds"),
            accounts=[
                DailyInputAccount(
                    account_id=account.id,
                    account_name=account.name,
                    links=links_by_account.get(account.id, []),
                )
                for account in account_records
            ],
        )


def save_daily_input_document(document: DailyInputDocument) -> DailyInputDocument:
    with session_scope() as session:
        record = session.scalar(select(DailyInputRecord).where(DailyInputRecord.date == parse_date(document.date)))
        if record is None:
            record = DailyInputRecord(date=parse_date(document.date), updated_at=parse_datetime(document.updated_at))
            session.add(record)
            session.flush()
        else:
            record.updated_at = parse_datetime(document.updated_at)
            record.links.clear()
            session.flush()

        sort_order = 0
        for account in document.accounts:
            for url in account.links:
                record.links.append(
                    DailyInputLinkRecord(
                        account_id=account.account_id,
                        url=url,
                        url_hash=hash_url(url),
                        sort_order=sort_order,
                    )
                )
                sort_order += 1
    return document
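
# Saving a daily input replaces the whole day: any existing links for that date are cleared
# and rewritten, with sort_order following the order of accounts and links in the incoming
# DailyInputDocument.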


def fetch_report_document(date_str: str) -> ReportDocument | None:
    with session_scope() as session:
        record = session.get(ReportRecord, parse_date(date_str))
        if record is None:
            return None
        return ReportDocument(
            date=str(record.date),
            generated_at=record.generated_at.isoformat(timespec="seconds"),
            summary=record.summary,
            focus_sectors=list(record.focus_sectors or []),
            article_count=record.article_count,
            account_count=record.account_count,
            articles=[
                OpinionArticle(
                    id=article.article_id,
                    account_id=article.account_id,
                    account_name=article.account_name,
                    title=article.title,
                    published_at=article.published_at,
                    summary=article.summary,
                    source_url=article.source_url,
                    sectors=list(article.sectors or []),
                    sentiment=article.sentiment,
                    article_type=article.article_type,
                )
                for article in record.articles
            ],
        )


def save_report_document(document: ReportDocument) -> ReportDocument:
    with session_scope() as session:
        record = session.get(ReportRecord, parse_date(document.date))
        if record is None:
            record = ReportRecord(
                date=parse_date(document.date),
                generated_at=parse_datetime(document.generated_at),
                summary=document.summary,
                focus_sectors=document.focus_sectors,
                article_count=document.article_count,
                account_count=document.account_count,
            )
            session.add(record)
        else:
            record.generated_at = parse_datetime(document.generated_at)
            record.summary = document.summary
            record.focus_sectors = document.focus_sectors
            record.article_count = document.article_count
            record.account_count = document.account_count
            record.articles.clear()
            session.flush()

        for index, article in enumerate(document.articles):
            record.articles.append(
                ReportArticleRecord(
                    sort_order=index,
                    article_id=article.id,
                    account_id=article.account_id,
                    account_name=article.account_name,
                    title=article.title,
                    published_at=article.published_at,
                    summary=article.summary,
                    source_url=article.source_url,
                    sectors=article.sectors,
                    sentiment=article.sentiment,
                    article_type=article.article_type,
                )
            )
    return document
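
# save_report_document is an upsert keyed on the report date: scalar fields are overwritten
# and the report_articles child rows are rebuilt, with sort_order preserving the article order
# of the incoming ReportDocument.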


def fetch_report_list() -> list[ReportListItem]:
    with session_scope() as session:
        records = session.scalars(select(ReportRecord).order_by(ReportRecord.date.desc())).all()
        return [
            ReportListItem(
                date=str(record.date),
                generated_at=record.generated_at.isoformat(timespec="seconds"),
                summary=record.summary,
                article_count=record.article_count,
                focus_sectors=list(record.focus_sectors or []),
            )
            for record in records
        ]


def fetch_cls_news_document(date_str: str) -> ClsNewsDocument | None:
    with session_scope() as session:
        record = session.get(ClsNewsSnapshotRecord, parse_date(date_str))
        if record is None:
            return None
        return ClsNewsDocument(
            date=str(record.date),
            updated_at=record.updated_at.isoformat(timespec="seconds"),
            window_label=record.window_label,
            summary=ClsNewsSummary(
                overview=record.overview,
                hot_topics=record.hot_topics,
                watch_list=list(record.watch_list or []),
            ),
            sector_impacts=[
                ClsSectorImpact(
                    sector=item.sector,
                    sentiment=item.sentiment,
                    reason=item.reason,
                    related_titles=list(item.related_titles or []),
                )
                for item in record.sector_impacts
            ],
            items=[
                ClsNewsItem(
                    id=item.item_id,
                    title=item.title,
                    published_at=item.published_at,
                    source=item.source,
                    summary=item.summary,
                    reference_url=item.reference_url,
                    sectors=list(item.sectors or []),
                    sentiment=item.sentiment,
                )
                for item in record.items
            ],
        )


def save_cls_news_document(document: ClsNewsDocument) -> ClsNewsDocument:
    with session_scope() as session:
        record = session.get(ClsNewsSnapshotRecord, parse_date(document.date))
        if record is None:
            record = ClsNewsSnapshotRecord(
                date=parse_date(document.date),
                updated_at=parse_datetime(document.updated_at),
                window_label=document.window_label,
                overview=document.summary.overview,
                hot_topics=document.summary.hot_topics,
                watch_list=document.summary.watch_list,
            )
            session.add(record)
        else:
            record.updated_at = parse_datetime(document.updated_at)
            record.window_label = document.window_label
            record.overview = document.summary.overview
            record.hot_topics = document.summary.hot_topics
            record.watch_list = document.summary.watch_list
            session.execute(
                delete(ClsSectorImpactRecord).where(ClsSectorImpactRecord.snapshot_date == record.date)
            )
            session.execute(
                delete(ClsNewsItemRecord).where(ClsNewsItemRecord.snapshot_date == record.date)
            )
            record.sector_impacts = []
            record.items = []
            session.flush()

        for index, impact in enumerate(document.sector_impacts):
            record.sector_impacts.append(
                ClsSectorImpactRecord(
                    sort_order=index,
                    sector=impact.sector,
                    sentiment=impact.sentiment,
                    reason=impact.reason,
                    related_titles=impact.related_titles,
                )
            )

        for index, item in enumerate(document.items):
            record.items.append(
                ClsNewsItemRecord(
                    sort_order=index,
                    item_id=item.id,
                    title=item.title,
                    published_at=item.published_at,
                    source=item.source,
                    summary=item.summary,
                    reference_url=item.reference_url,
                    sectors=item.sectors,
                    sentiment=item.sentiment,
                )
            )
    return document
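
# save_cls_news_document mirrors the other save_* helpers: when a snapshot already exists for
# the date, its child rows are removed (via the explicit DELETEs and by resetting the ORM
# collections) and then rebuilt in order from the incoming ClsNewsDocument.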