Files
zjjk/backend/app/services/main_capital_flow_service.py

206 lines
8.3 KiB
Python
Raw Normal View History

2026-04-08 20:04:40 +08:00
import asyncio
import base64
import contextlib
import json
import re
import urllib.error
import urllib.request
from datetime import datetime
from pathlib import Path
from uuid import uuid4

from fastapi import HTTPException, UploadFile

from app.core.config import MAIN_CAPITAL_FLOW_DB_FILE, MAIN_CAPITAL_FLOW_UPLOADS_DIR
from app.repositories.main_capital_flow_repository import MainCapitalFlowRepository
from app.repositories.monitoring_repository import MonitoringRepository
def _extract_json_block(content: str) -> dict:
fenced_match = re.search(r"```json\s*(\{.*?\})\s*```", content, flags=re.DOTALL)
if fenced_match:
return json.loads(fenced_match.group(1))
object_match = re.search(r"(\{.*\})", content, flags=re.DOTALL)
if object_match:
return json.loads(object_match.group(1))
raise ValueError("No JSON object found in model output")
class MainCapitalFlowService:
    """Service layer for main-capital-flow screenshot records.

    Persists records through ``MainCapitalFlowRepository``, keeps uploaded
    screenshots in ``MAIN_CAPITAL_FLOW_UPLOADS_DIR``, and asks an
    OpenAI-compatible ``/chat/completions`` endpoint (configured through
    ``MonitoringRepository.get_system_config``) to extract structured
    capital-flow figures from a screenshot.
    """

    # Accepted upload extensions and the MIME type sent to the model for each.
    # Any other (user-controlled) extension is stored as ".jpg" so files such
    # as ".html"/".svg" never land in the uploads directory, which is
    # presumably served under /main-capital-flow-images/ (see _build_image_url).
    _IMAGE_MIME_TYPES = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }
    _DEFAULT_IMAGE_SUFFIX = ".jpg"

    def __init__(self) -> None:
        self.repository = MainCapitalFlowRepository(MAIN_CAPITAL_FLOW_DB_FILE)
        self.monitoring_repository = MonitoringRepository()

    def list_records(self) -> dict:
        """Return every stored record as ``{"items": [...], "total": n}``."""
        items = [self._serialize_record(record) for record in self.repository.list_records()]
        return {"items": items, "total": len(items)}

    def get_record(self, record_id: str) -> dict:
        """Return one serialized record.

        Raises:
            HTTPException: 404 if no record has ``record_id``.
        """
        record = self.repository.get_record(record_id)
        if record is None:
            raise HTTPException(status_code=404, detail="Record not found")
        return self._serialize_record(record)

    def delete_record(self, record_id: str) -> dict:
        """Delete a record and, if it is in the uploads dir, its screenshot.

        The image is only unlinked when it sits directly inside the uploads
        directory, so a record carrying a crafted ``image_path`` cannot be
        used to delete arbitrary files.

        Raises:
            HTTPException: 404 if no record has ``record_id``.
        """
        record = self.repository.delete_record(record_id)
        if record is None:
            raise HTTPException(status_code=404, detail="Record not found")
        image_path = Path(record["image_path"])
        if self._is_upload_file_path(image_path):
            # missing_ok already covers an image that is gone.
            image_path.unlink(missing_ok=True)
        return {"deleted": True, "id": record_id}

    async def recognize_image(
        self,
        upload_file: UploadFile,
        trade_date: str | None = None,
        subject: str | None = None,
    ) -> dict:
        """Store an uploaded screenshot and extract capital-flow data from it.

        The image is saved as ``temp_<hex><suffix>`` in the uploads directory;
        the returned ``temp_image_name`` is what ``create_record`` expects.
        ``trade_date`` / ``subject`` fill in fields the model left empty.

        Raises:
            HTTPException: 500 if no model API key is configured; 502 if the
                model call fails or its output has no JSON object. The temp
                image is removed whenever recognition fails.
        """
        suffix = Path(upload_file.filename or "upload.jpg").suffix.lower()
        if suffix not in self._IMAGE_MIME_TYPES:
            suffix = self._DEFAULT_IMAGE_SUFFIX
        temp_image_name = f"temp_{uuid4().hex}{suffix}"
        stored_path = MAIN_CAPITAL_FLOW_UPLOADS_DIR / temp_image_name
        image_name = upload_file.filename or temp_image_name
        binary = await upload_file.read()
        stored_path.parent.mkdir(parents=True, exist_ok=True)
        stored_path.write_bytes(binary)
        try:
            # Config lookup and parsing stay on this thread; only the blocking
            # HTTP call (timeout up to 180 s) is moved to a worker thread so
            # the event loop is not stalled while the model runs.
            request = self._build_model_request(binary, self._IMAGE_MIME_TYPES[suffix])
            response_payload = await asyncio.to_thread(self._send_model_request, request)
            extraction = self._parse_model_response(response_payload, trade_date, subject)
        except BaseException:
            # Don't leave an orphaned temp upload behind (also on cancellation).
            with contextlib.suppress(OSError):
                stored_path.unlink(missing_ok=True)
            raise
        return {
            "temp_image_name": temp_image_name,
            "image_name": image_name,
            "image_url": self._build_image_url(stored_path),
            "trade_date": extraction.get("trade_date") or trade_date,
            "subject": extraction.get("subject") or subject,
            "snapshot_time": extraction.get("snapshot_time"),
            "institution_amount_yi": extraction.get("institution_amount_yi"),
            "main_force_amount_yi": extraction.get("main_force_amount_yi"),
            "large_household_amount_yi": extraction.get("large_household_amount_yi"),
            "retail_amount_yi": extraction.get("retail_amount_yi"),
            "trend": extraction.get("overall_trend"),
            "summary": extraction.get("intraday_summary"),
            "raw_extraction": extraction,
        }

    def create_record(self, payload: dict) -> dict:
        """Persist a record for a previously recognized screenshot.

        ``payload`` must contain ``trade_date``, ``temp_image_name`` (as
        returned by ``recognize_image``), ``summary`` and ``image_name``; the
        other extraction fields are optional.

        Raises:
            HTTPException: 409 if a record for ``trade_date`` already exists;
                400 if ``temp_image_name`` is not a bare file name of an
                existing file in the uploads directory.
        """
        if self.repository.get_by_trade_date(payload["trade_date"]):
            raise HTTPException(status_code=409, detail="该日期记录已存在")
        image_path = self._resolve_upload_path(payload["temp_image_name"])
        # is_file() (not exists()) also rejects names that resolve to a
        # directory, e.g. an empty temp_image_name.
        if image_path is None or not image_path.is_file():
            raise HTTPException(status_code=400, detail="识别图片不存在,请重新上传")
        now = datetime.now().isoformat(timespec="seconds")
        record = self.repository.insert_record(
            {
                "id": uuid4().hex,
                "trade_date": payload["trade_date"],
                "subject": payload.get("subject"),
                "snapshot_time": payload.get("snapshot_time"),
                "institution_amount_yi": payload.get("institution_amount_yi"),
                "main_force_amount_yi": payload.get("main_force_amount_yi"),
                "large_household_amount_yi": payload.get("large_household_amount_yi"),
                "retail_amount_yi": payload.get("retail_amount_yi"),
                "trend": payload.get("trend"),
                "summary": payload["summary"],
                "image_name": payload["image_name"],
                "image_path": str(image_path),
                "raw_extraction": payload.get("raw_extraction", {}),
                "created_at": now,
                "updated_at": now,
            }
        )
        return {"item": self._serialize_record(record)}

    def _extract_via_model(
        self,
        image_bytes: bytes,
        trade_date: str | None,
        subject: str | None,
        mime_type: str = "image/jpeg",
    ) -> dict:
        """Synchronously build, send and parse one extraction request.

        Blocks for the whole HTTP call; async code should use the individual
        steps (as ``recognize_image`` does) instead.
        """
        request = self._build_model_request(image_bytes, mime_type)
        response_payload = self._send_model_request(request)
        return self._parse_model_response(response_payload, trade_date, subject)

    def _build_model_request(
        self,
        image_bytes: bytes,
        mime_type: str = "image/jpeg",
    ) -> urllib.request.Request:
        """Build the ``/chat/completions`` POST request for one screenshot.

        Raises:
            HTTPException: 500 if no model API key is configured.
        """
        llm_config = self._get_llm_config()
        if not llm_config["api_key"]:
            raise HTTPException(status_code=500, detail="未配置视觉模型 API")
        encoded_image = base64.b64encode(image_bytes).decode("utf-8")
        prompt = """
You are extracting structured data from a Chinese stock capital flow screenshot.
Return only JSON with these keys:
trade_date, subject, snapshot_time, institution_amount_yi, main_force_amount_yi,
large_household_amount_yi, retail_amount_yi, overall_trend, intraday_summary.
Rules:
1. intraday_summary must describe only the intraday capital-flow trend and must not repeat the raw amounts.
2. overall_trend should be a short Chinese phrase like "震荡上行", "午后修复", "冲高回落", "弱势下探".
3. If a field is not clearly visible, set it to null.
4. If trade_date is absent in the image, keep it null.
5. Return JSON only.
"""
        request_payload = {
            "model": llm_config["model"],
            "messages": [
                {
                    "role": "system",
                    "content": "You extract structured JSON from Chinese capital-flow screenshots.",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{encoded_image}",
                            },
                        },
                    ],
                },
            ],
        }
        return urllib.request.Request(
            url=f"{llm_config['base_url'].rstrip('/')}/chat/completions",
            data=json.dumps(request_payload).encode("utf-8"),
            headers={
                "Authorization": f"Bearer {llm_config['api_key']}",
                "Content-Type": "application/json",
            },
            method="POST",
        )

    @staticmethod
    def _send_model_request(request: urllib.request.Request) -> object:
        """Perform the blocking HTTP call and return the decoded JSON body.

        Raises:
            HTTPException: 502 on HTTP errors, network errors, timeouts, or a
                response body that is not valid UTF-8 JSON.
        """
        try:
            with urllib.request.urlopen(request, timeout=180) as response:
                return json.loads(response.read().decode("utf-8"))
        # HTTPError subclasses URLError, so it must be handled first.
        except urllib.error.HTTPError as exc:
            error_text = exc.read().decode("utf-8", errors="ignore")
            raise HTTPException(status_code=502, detail=f"模型识别失败: {error_text}") from exc
        except (urllib.error.URLError, TimeoutError) as exc:
            raise HTTPException(status_code=502, detail=f"模型识别失败: {exc}") from exc
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise HTTPException(status_code=502, detail="模型识别失败: 响应不是有效的 JSON") from exc

    @staticmethod
    def _parse_model_response(
        response_payload: object,
        trade_date: str | None,
        subject: str | None,
    ) -> dict:
        """Extract the JSON object from a chat-completions response body.

        Caller-supplied ``trade_date`` / ``subject`` fill in fields the model
        left empty.

        Raises:
            HTTPException: 502 if no JSON object can be parsed from the
                first choice's message content.
        """
        choices = response_payload.get("choices") if isinstance(response_payload, dict) else None
        first_choice = choices[0] if isinstance(choices, list) and choices else None
        message = first_choice.get("message") if isinstance(first_choice, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        if isinstance(content, list):
            # Content may arrive as a list of typed parts on some
            # OpenAI-compatible servers; concatenate their text.
            content = "".join(
                part.get("text") or "" for part in content if isinstance(part, dict)
            )
        if not isinstance(content, str):
            content = ""
        try:
            parsed = _extract_json_block(content)
        except ValueError as exc:
            raise HTTPException(
                status_code=502, detail="模型识别失败: 模型输出中未找到有效的 JSON"
            ) from exc
        if subject and not parsed.get("subject"):
            parsed["subject"] = subject
        if trade_date and not parsed.get("trade_date"):
            parsed["trade_date"] = trade_date
        return parsed

    def _get_llm_config(self) -> dict:
        """Read vision-model settings, falling back to defaults when unset.

        ``or`` (rather than a ``.get`` default) also replaces values stored
        as ``None`` or ``""``, which would otherwise yield a broken URL/model.
        """
        config = self.monitoring_repository.get_system_config() or {}
        return {
            "api_key": config.get("llm_api_key") or "",
            "base_url": config.get("llm_base_url") or "https://api.openai.com/v1",
            "model": config.get("llm_vision_model") or "gpt-4.1-mini",
        }

    @staticmethod
    def _resolve_upload_path(file_name: object) -> Path | None:
        """Map a client-supplied temp image name to a path in the uploads dir.

        Returns ``None`` unless ``file_name`` is a bare file name (no
        directory components, not ``.``/``..``), which blocks path traversal
        through ``temp_image_name``.
        """
        if not isinstance(file_name, str) or file_name in ("", ".", ".."):
            return None
        if Path(file_name).name != file_name:
            return None
        return MAIN_CAPITAL_FLOW_UPLOADS_DIR / file_name

    @staticmethod
    def _is_upload_file_path(path: Path) -> bool:
        """Return True if ``path`` names an entry directly inside the uploads dir."""
        if path.name in ("", ".", ".."):
            return False
        # Only the parent is resolved, so a symlinked file is unlinked as a
        # link rather than followed.
        return path.parent.resolve() == MAIN_CAPITAL_FLOW_UPLOADS_DIR.resolve()

    def _build_image_url(self, path: Path) -> str:
        """Return the public URL path for a stored screenshot (by file name)."""
        return f"/main-capital-flow-images/{path.name}"

    def _serialize_record(self, record: dict) -> dict:
        """Return ``record`` plus an ``image_url`` derived from ``image_path``."""
        return {
            **record,
            "image_url": self._build_image_url(Path(record["image_path"])),
        }
# Module-level shared instance, constructed at import time -- so the
# repositories created in MainCapitalFlowService.__init__ are built when this
# module is first imported. Presumably imported by the API routes; verify.
main_capital_flow_service = MainCapitalFlowService()