206 lines
8.3 KiB
Python
206 lines
8.3 KiB
Python
import base64
|
|
import json
|
|
import re
|
|
import urllib.error
|
|
import urllib.request
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from uuid import uuid4
|
|
|
|
from fastapi import HTTPException, UploadFile
|
|
|
|
from app.core.config import MAIN_CAPITAL_FLOW_DB_FILE, MAIN_CAPITAL_FLOW_UPLOADS_DIR
|
|
from app.repositories.main_capital_flow_repository import MainCapitalFlowRepository
|
|
from app.repositories.monitoring_repository import MonitoringRepository
|
|
|
|
|
|
def _extract_json_block(content: str) -> dict:
    """Extract the first JSON object embedded in a model reply.

    A fenced ```json ... ``` block is preferred when present. Otherwise the
    text is scanned for the first ``{`` from which a complete JSON object can
    be decoded, so prose (even prose containing stray braces) before or after
    the object does not break parsing.

    Args:
        content: Raw text returned by the model.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If ``content`` is empty, not a string, or contains no
            decodable JSON object.
    """
    if not isinstance(content, str) or not content:
        raise ValueError("No JSON object found in model output")

    fenced_match = re.search(r"```json\s*(\{.*?\})\s*```", content, flags=re.DOTALL)
    if fenced_match:
        try:
            fenced_value = json.loads(fenced_match.group(1))
        except ValueError:
            # Malformed fenced block: fall through to the generic scan below.
            fenced_value = None
        if isinstance(fenced_value, dict):
            return fenced_value

    # raw_decode stops at the end of the first complete JSON value, so trailing
    # text (including unrelated "}" characters) cannot corrupt the result the
    # way a greedy "first { to last }" regex does.
    decoder = json.JSONDecoder()
    start = content.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(content, start)
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = content.find("{", start + 1)

    raise ValueError("No JSON object found in model output")
|
|
|
|
|
|
class MainCapitalFlowService:
    """Service layer for main-capital-flow records and screenshot recognition.

    Records are persisted through ``MainCapitalFlowRepository``; the vision
    model settings are read from ``MonitoringRepository.get_system_config()``.
    Uploaded screenshots are stored under ``MAIN_CAPITAL_FLOW_UPLOADS_DIR``.
    """

    def __init__(self) -> None:
        """Create the record repository and the system-config repository."""
        self.repository = MainCapitalFlowRepository(MAIN_CAPITAL_FLOW_DB_FILE)
        self.monitoring_repository = MonitoringRepository()

    def list_records(self) -> dict:
        """Return every stored record, serialized, plus the item count."""
        items = [self._serialize_record(record) for record in self.repository.list_records()]
        return {"items": items, "total": len(items)}

    def get_record(self, record_id: str) -> dict:
        """Return one serialized record.

        Raises:
            HTTPException: 404 when the repository has no such record.
        """
        record = self.repository.get_record(record_id)
        if record is None:
            raise HTTPException(status_code=404, detail="Record not found")
        return self._serialize_record(record)

    def delete_record(self, record_id: str) -> dict:
        """Delete a record and remove its stored image file if it still exists.

        Raises:
            HTTPException: 404 when the repository has no such record.
        """
        record = self.repository.delete_record(record_id)
        if record is None:
            raise HTTPException(status_code=404, detail="Record not found")

        # missing_ok already covers the "file is gone" case; no exists() probe needed.
        Path(record["image_path"]).unlink(missing_ok=True)
        return {"deleted": True, "id": record_id}

    async def recognize_image(
        self,
        upload_file: UploadFile,
        trade_date: str | None = None,
        subject: str | None = None,
    ) -> dict:
        """Store an uploaded screenshot and run the vision model on it.

        The file is saved under a generated ``temp_<hex><suffix>`` name; that
        name is returned so a later ``create_record`` call can reference it.

        Args:
            upload_file: The uploaded screenshot.
            trade_date: Fallback trade date used when the model returns none.
            subject: Fallback subject used when the model returns none.

        Returns:
            A dict with the temp image name, display name, image URL, the
            extracted fields and the raw model extraction.

        Raises:
            HTTPException: Propagated from ``_extract_via_model``.
        """
        # Local import keeps this block self-contained; asyncio is stdlib.
        import asyncio

        suffix = Path(upload_file.filename or "upload.jpg").suffix or ".jpg"
        temp_image_name = f"temp_{uuid4().hex}{suffix.lower()}"
        stored_path = MAIN_CAPITAL_FLOW_UPLOADS_DIR / temp_image_name
        image_name = upload_file.filename or temp_image_name

        binary = await upload_file.read()
        stored_path.parent.mkdir(parents=True, exist_ok=True)
        stored_path.write_bytes(binary)

        # Advertise the real image type to the model when the client sent one.
        content_type = getattr(upload_file, "content_type", None) or ""
        mime_type = content_type if content_type.startswith("image/") else "image/jpeg"

        try:
            # Read the config on the calling thread, then push only the
            # blocking HTTP round-trip (up to 180 s) onto a worker thread so
            # the event loop keeps serving other requests.
            llm_config = self._get_llm_config()
            extraction = await asyncio.to_thread(
                self._extract_via_model,
                binary,
                trade_date=trade_date,
                subject=subject,
                llm_config=llm_config,
                mime_type=mime_type,
            )
        except Exception:
            # Do not leave an orphaned temp upload behind when recognition fails.
            stored_path.unlink(missing_ok=True)
            raise

        return {
            "temp_image_name": temp_image_name,
            "image_name": image_name,
            "image_url": self._build_image_url(stored_path),
            "trade_date": extraction.get("trade_date") or trade_date,
            "subject": extraction.get("subject") or subject,
            "snapshot_time": extraction.get("snapshot_time"),
            "institution_amount_yi": extraction.get("institution_amount_yi"),
            "main_force_amount_yi": extraction.get("main_force_amount_yi"),
            "large_household_amount_yi": extraction.get("large_household_amount_yi"),
            "retail_amount_yi": extraction.get("retail_amount_yi"),
            "trend": extraction.get("overall_trend"),
            "summary": extraction.get("intraday_summary"),
            "raw_extraction": extraction,
        }

    def create_record(self, payload: dict) -> dict:
        """Persist a recognized screenshot as a record.

        Args:
            payload: Must contain ``trade_date``, ``temp_image_name``,
                ``summary`` and ``image_name``; the remaining fields are
                optional and read with ``payload.get``.

        Returns:
            ``{"item": <serialized record>}``.

        Raises:
            HTTPException: 409 when a record for ``trade_date`` already exists;
                400 when ``temp_image_name`` is not a plain file name or does
                not refer to an existing file in the uploads directory.
        """
        if self.repository.get_by_trade_date(payload["trade_date"]):
            raise HTTPException(status_code=409, detail="该日期记录已存在")

        # temp_image_name comes from the client. Accept only a bare file name:
        # path separators, "..", or an absolute path would let the stored
        # image_path (which delete_record later unlinks) escape the uploads dir.
        temp_image_name = payload["temp_image_name"]
        if (
            not isinstance(temp_image_name, str)
            or not temp_image_name
            or Path(temp_image_name).name != temp_image_name
        ):
            raise HTTPException(status_code=400, detail="识别图片不存在,请重新上传")

        image_path = MAIN_CAPITAL_FLOW_UPLOADS_DIR / temp_image_name
        # is_file() also rejects ".." and other directory entries.
        if not image_path.is_file():
            raise HTTPException(status_code=400, detail="识别图片不存在,请重新上传")

        now = datetime.now().isoformat(timespec="seconds")
        record = self.repository.insert_record(
            {
                "id": uuid4().hex,
                "trade_date": payload["trade_date"],
                "subject": payload.get("subject"),
                "snapshot_time": payload.get("snapshot_time"),
                "institution_amount_yi": payload.get("institution_amount_yi"),
                "main_force_amount_yi": payload.get("main_force_amount_yi"),
                "large_household_amount_yi": payload.get("large_household_amount_yi"),
                "retail_amount_yi": payload.get("retail_amount_yi"),
                "trend": payload.get("trend"),
                "summary": payload["summary"],
                "image_name": payload["image_name"],
                "image_path": str(image_path),
                "raw_extraction": payload.get("raw_extraction", {}),
                "created_at": now,
                "updated_at": now,
            }
        )
        return {"item": self._serialize_record(record)}

    def _extract_via_model(
        self,
        image_bytes: bytes,
        trade_date: str | None,
        subject: str | None,
        llm_config: dict | None = None,
        mime_type: str = "image/jpeg",
    ) -> dict:
        """Send the image to the chat-completions endpoint and parse its JSON reply.

        Args:
            image_bytes: Raw image content, sent base64-encoded as a data URL.
            trade_date: Used for ``trade_date`` when the model leaves it empty.
            subject: Used for ``subject`` when the model leaves it empty.
            llm_config: Optional pre-loaded result of ``_get_llm_config()``;
                loaded here when omitted.
            mime_type: MIME type placed in the data URL.

        Returns:
            The JSON object extracted from the model reply.

        Raises:
            HTTPException: 500 when no API key is configured; 502 when the
                request fails, the response is not valid JSON, or the reply
                contains no JSON object.
        """
        if llm_config is None:
            llm_config = self._get_llm_config()
        if not llm_config["api_key"]:
            raise HTTPException(status_code=500, detail="未配置视觉模型 API")

        encoded_image = base64.b64encode(image_bytes).decode("utf-8")
        # The prompt text is kept flush-left so the string sent to the model
        # carries no indentation.
        prompt = """
You are extracting structured data from a Chinese stock capital flow screenshot.
Return only JSON with these keys:
trade_date, subject, snapshot_time, institution_amount_yi, main_force_amount_yi,
large_household_amount_yi, retail_amount_yi, overall_trend, intraday_summary.

Rules:
1. intraday_summary must describe only the intraday capital-flow trend and must not repeat the raw amounts.
2. overall_trend should be a short Chinese phrase like "震荡上行", "午后修复", "冲高回落", "弱势下探".
3. If a field is not clearly visible, set it to null.
4. If trade_date is absent in the image, keep it null.
5. Return JSON only.
"""
        request_payload = {
            "model": llm_config["model"],
            "messages": [
                {
                    "role": "system",
                    "content": "You extract structured JSON from Chinese capital-flow screenshots."
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{encoded_image}",
                            },
                        },
                    ],
                },
            ],
        }

        request = urllib.request.Request(
            url=f"{llm_config['base_url'].rstrip('/')}/chat/completions",
            data=json.dumps(request_payload).encode("utf-8"),
            headers={
                "Authorization": f"Bearer {llm_config['api_key']}",
                "Content-Type": "application/json",
            },
            method="POST",
        )
        try:
            with urllib.request.urlopen(request, timeout=180) as response:
                response_payload = json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError as exc:
            # HTTPError must be handled before OSError (it is a subclass via URLError).
            error_text = exc.read().decode("utf-8", errors="ignore")
            raise HTTPException(status_code=502, detail=f"模型识别失败: {error_text}") from exc
        except (OSError, ValueError) as exc:
            # OSError: connection failures and timeouts (URLError, TimeoutError).
            # ValueError: malformed URL, undecodable body, or invalid JSON.
            raise HTTPException(status_code=502, detail=f"模型识别失败: {exc}") from exc

        choices = response_payload.get("choices") if isinstance(response_payload, dict) else None
        content = choices[0].get("message", {}).get("content", "") if choices else ""
        if not isinstance(content, str):
            # A null or non-text content has no JSON to extract.
            content = ""

        try:
            parsed = _extract_json_block(content)
        except ValueError as exc:
            raise HTTPException(status_code=502, detail=f"模型识别失败: {exc}") from exc

        if subject and not parsed.get("subject"):
            parsed["subject"] = subject
        if trade_date and not parsed.get("trade_date"):
            parsed["trade_date"] = trade_date
        return parsed

    def _get_llm_config(self) -> dict:
        """Return the vision-model settings from the system config.

        Missing, ``None`` or empty values fall back to the defaults (empty API
        key, the OpenAI base URL, and ``gpt-4.1-mini``).
        """
        config = self.monitoring_repository.get_system_config()
        # "or" rather than a .get() default: a key stored as "" or None must
        # still fall back, otherwise base_url.rstrip() fails later.
        return {
            "api_key": config.get("llm_api_key") or "",
            "base_url": config.get("llm_base_url") or "https://api.openai.com/v1",
            "model": config.get("llm_vision_model") or "gpt-4.1-mini",
        }

    def _build_image_url(self, path: Path) -> str:
        """Return the URL path for an image, built from the file name only."""
        return f"/main-capital-flow-images/{path.name}"

    def _serialize_record(self, record: dict) -> dict:
        """Return a copy of ``record`` with an ``image_url`` key added."""
        return {
            **record,
            "image_url": self._build_image_url(Path(record["image_path"])),
        }
|
|
|
|
|
|
# Module-level service instance, constructed once when this module is imported.
main_capital_flow_service = MainCapitalFlowService()
|