Files
FishServer/fish_api/app/db.py

775 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""仅 FastAPI 进程使用 SQLite落库测量/健康结果与 watch 已处理路径。
FishMeasure / FishAction 子进程不连接、不依赖本库;它们只读写各自文件(如 measure_output 下
weight_prediction.json、临时 pred.json 等),由 fish_api 在子进程结束后读文件并写入本表。
预览视频在 media_root;start_fresh 会清空 measure_output、media、ingest 临时目录。
"""
from __future__ import annotations
import json
import math
import shutil
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from app.settings import Settings
from app.state import HealthSnapshot, MeasureSnapshot
# Client id assumed when a request carries none. All such requests share one
# delivery cursor, which keeps the behaviour from before per-client cursors.
DEFAULT_CLIENT_ID = "default"
# Upper bound on client id length; normalize_client_id truncates longer ids.
MAX_CLIENT_ID_LEN = 128
# Per-client cache (keyed by normalized client id) of the slice index most
# recently returned by pop_next_health (-1 when that row was not a slice);
# get_last_health_slice_index reads it so the water/video endpoints can align.
_client_health_slice_index: Dict[str, int] = {}
def _parse_slice_index_from_source_path(source_path: Optional[str]) -> int:
    """Extract the slice index from a ``source_path`` like ``video.mp4#slice{N}``.

    The text after the *last* ``#slice`` marker is parsed as an integer.

    Returns:
        The slice index (``>= 0``), or ``-1`` when ``source_path`` is empty, has
        no ``#slice`` marker, its suffix is not an integer, or the parsed value
        is negative. ``-1`` means "not a slice", so a negative suffix such as
        ``#slice-4`` must not leak through as a bogus index.
    """
    if not source_path:
        return -1
    _, marker, suffix = source_path.rpartition("#slice")
    if not marker:
        # rpartition returns an empty separator when "#slice" is absent.
        return -1
    try:
        idx = int(suffix)
    except ValueError:
        return -1
    return idx if idx >= 0 else -1
def get_last_health_slice_index(client_id: str) -> int:
    """Return the slice index last handed to ``client_id`` by ``pop_next_health``.

    Used to align the water/video endpoints with the health stream. The id is
    normalized exactly as when polling. Returns ``-1`` when nothing has been
    recorded for this client (or the last delivered row was not a slice).
    """
    return _client_health_slice_index.get(normalize_client_id(client_id), -1)
def normalize_client_id(raw: Optional[str]) -> str:
    """Canonicalize a client id supplied by a polling request.

    ``None`` or a blank/whitespace-only value maps to ``DEFAULT_CLIENT_ID``.
    Anything else is converted with ``str``, stripped, and truncated to at
    most ``MAX_CLIENT_ID_LEN`` characters.
    """
    cleaned = "" if raw is None else str(raw).strip()
    return cleaned[:MAX_CLIENT_ID_LEN] if cleaned else DEFAULT_CLIENT_ID
def _connect(path: Path) -> sqlite3.Connection:
    """Open a SQLite connection to ``path`` configured for this module.

    Creates the parent directory if needed. The connection is in autocommit
    mode (``isolation_level=None``; callers issue ``BEGIN`` themselves when
    they need a transaction), may be used from threads other than the one that
    created it (``check_same_thread=False``), returns ``sqlite3.Row`` rows,
    uses WAL journaling and enforces foreign keys.

    If configuring the connection fails (for example ``path`` is not a SQLite
    database, which only surfaces on the first statement), the connection is
    closed before the error propagates instead of being leaked.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(path), check_same_thread=False, isolation_level=None)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
    except Exception:
        conn.close()
        raise
    return conn
def init_db(settings: Settings) -> None:
    """Create the schema at ``settings.sqlite_path`` if needed and run all migrations.

    Safe to call repeatedly: every table uses ``IF NOT EXISTS`` and each
    migration checks the current schema before altering anything.

    ``idx_measure_client_id`` is created only *after* the migrations. On a
    legacy database ``measure_snapshots`` already exists without ``client_id``
    (``CREATE TABLE IF NOT EXISTS`` leaves it unchanged). Creating that index
    inside the initial script therefore failed with "no such column" and
    aborted ``init_db`` before ``_migrate_add_client_id_column`` could add the
    column. Every caller runs ``init_db`` first, so every call failed.
    """
    conn = _connect(settings.sqlite_path)
    try:
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS measure_snapshots (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                created_at TEXT NOT NULL,
                result_json TEXT NOT NULL,
                video_left TEXT NOT NULL DEFAULT '',
                video_right TEXT NOT NULL DEFAULT '',
                error TEXT,
                raw_prediction_path TEXT,
                source_path TEXT,
                client_id TEXT DEFAULT NULL,
                pred REAL,
                star INTEGER DEFAULT 0
            );
            CREATE TABLE IF NOT EXISTS health_snapshots (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                created_at TEXT NOT NULL,
                behavior_result TEXT NOT NULL DEFAULT '',
                health_result TEXT NOT NULL DEFAULT '',
                raw_class_en TEXT NOT NULL DEFAULT '',
                error TEXT,
                source_path TEXT
            );
            CREATE TABLE IF NOT EXISTS watch_processed (
                path TEXT NOT NULL,
                kind TEXT NOT NULL CHECK (kind IN ('measure', 'action')),
                PRIMARY KEY (path, kind)
            );
            CREATE TABLE IF NOT EXISTS delivery_client_cursor (
                client_id TEXT NOT NULL,
                kind TEXT NOT NULL CHECK (kind IN ('measure', 'health')),
                last_delivered_id INTEGER NOT NULL DEFAULT 0,
                PRIMARY KEY (client_id, kind)
            );
            """
        )
        # Legacy single-cursor table -> per-client cursors, then seed defaults.
        _migrate_delivery_cursor_from_legacy(conn)
        _ensure_delivery_cursors(conn)
        # Column additions for databases created by older versions.
        _migrate_add_client_id_column(conn)
        _migrate_add_pred_star_columns(conn)
        _migrate_add_calculation_log_column(conn)
        # Depends on client_id, which may only exist after the migration above.
        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_measure_client_id "
            "ON measure_snapshots(client_id)"
        )
    finally:
        conn.close()
def _migrate_add_client_id_column(conn: sqlite3.Connection) -> None:
    """Ensure ``measure_snapshots`` has a ``client_id`` column and its index.

    No-op when the table does not exist. Idempotent: the column is added only
    when missing, and the index is always ensured with ``IF NOT EXISTS``. A
    database whose column already exists but whose index is missing (for
    example after an interrupted earlier migration, since each autocommit
    statement stands alone) therefore still gets the index.
    """
    table = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='measure_snapshots'"
    ).fetchone()
    if table is None:
        return
    # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
    columns = {col[1] for col in conn.execute("PRAGMA table_info(measure_snapshots)")}
    if "client_id" not in columns:
        conn.execute("ALTER TABLE measure_snapshots ADD COLUMN client_id TEXT DEFAULT NULL")
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_measure_client_id ON measure_snapshots(client_id)"
    )
    conn.commit()
def _migrate_add_pred_star_columns(conn: sqlite3.Connection) -> None:
    """Add the ``pred`` (REAL) and ``star`` (INTEGER DEFAULT 0) columns if missing.

    Targets ``measure_snapshots`` in databases created by older versions.
    Does nothing if that table does not exist. Each column is added
    independently, so a table that already has one of them only gets the other.
    """
    table_exists = (
        conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='measure_snapshots'"
        ).fetchone()
        is not None
    )
    if not table_exists:
        return
    existing = set()
    for info in conn.execute("PRAGMA table_info(measure_snapshots)").fetchall():
        existing.add(info[1])
    pending_ddl = [
        ddl
        for column, ddl in (
            ("pred", "ALTER TABLE measure_snapshots ADD COLUMN pred REAL"),
            ("star", "ALTER TABLE measure_snapshots ADD COLUMN star INTEGER DEFAULT 0"),
        )
        if column not in existing
    ]
    for ddl in pending_ddl:
        conn.execute(ddl)
    conn.commit()
def _migrate_add_calculation_log_column(conn: sqlite3.Connection) -> None:
    """Add the ``calculation_log`` TEXT column to a legacy ``measure_snapshots`` table.

    The column holds the weight-estimation process text (aligned with the
    test_dgcnn terminal output). This is a no-op when the table is missing or
    already has the column.
    """
    if conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='measure_snapshots'"
    ).fetchone() is None:
        return
    table_info = conn.execute("PRAGMA table_info(measure_snapshots)").fetchall()
    if not any(column[1] == "calculation_log" for column in table_info):
        conn.execute("ALTER TABLE measure_snapshots ADD COLUMN calculation_log TEXT")
    conn.commit()
def _migrate_delivery_cursor_from_legacy(conn: sqlite3.Connection) -> None:
    """Move cursors from legacy ``delivery_cursor(kind)`` into ``delivery_client_cursor``.

    Each legacy cursor row (``measure`` / ``health``) is copied under
    ``DEFAULT_CLIENT_ID``, replacing any existing row for that client and
    kind. The legacy table is then dropped. This is a no-op when the legacy
    table is absent. Rows are read by column name, so ``conn`` must return
    ``sqlite3.Row`` rows (as ``_connect`` configures).
    """
    legacy_table = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='delivery_cursor'"
    ).fetchone()
    if legacy_table is None:
        return
    for kind in ("measure", "health"):
        old_cursor = conn.execute(
            "SELECT last_delivered_id FROM delivery_cursor WHERE kind = ?", (kind,)
        ).fetchone()
        if old_cursor is None:
            continue
        conn.execute(
            """
            INSERT OR REPLACE INTO delivery_client_cursor
            (client_id, kind, last_delivered_id)
            VALUES (?, ?, ?)
            """,
            (DEFAULT_CLIENT_ID, kind, int(old_cursor["last_delivered_id"])),
        )
    conn.execute("DROP TABLE delivery_cursor")
    conn.commit()
def _ensure_delivery_cursors(conn: sqlite3.Connection) -> None:
    """Create the default client's cursor row for each kind if it is missing.

    A newly created cursor starts at the current ``MAX(id)`` of the matching
    snapshot table (0 when empty). That way, after an upgrade, the default
    client is not replayed the whole snapshot history one row at a time.
    Existing cursors are left untouched.
    """
    snapshot_tables = {"measure": "measure_snapshots", "health": "health_snapshots"}
    for kind, table in snapshot_tables.items():
        present = conn.execute(
            "SELECT 1 FROM delivery_client_cursor WHERE client_id = ? AND kind = ?",
            (DEFAULT_CLIENT_ID, kind),
        ).fetchone()
        if present is not None:
            continue
        # Table names come from the fixed mapping above, never from input.
        start_id = conn.execute(
            f"SELECT COALESCE(MAX(id), 0) FROM {table}"
        ).fetchone()[0]
        conn.execute(
            "INSERT INTO delivery_client_cursor (client_id, kind, last_delivered_id) "
            "VALUES (?, ?, ?)",
            (DEFAULT_CLIENT_ID, kind, int(start_id)),
        )
    conn.commit()
def save_measure_snapshot(
    settings: Settings,
    snap: MeasureSnapshot,
    source_path: Optional[str] = None,
    client_id: Optional[str] = None,
) -> None:
    """Append ``snap`` as a new ``measure_snapshots`` row.

    The schema is initialised first. ``created_at`` is ``snap.updated_at`` in
    ISO-8601 form, or the current UTC time when the snapshot has none.
    ``snap.result`` is stored as JSON (non-ASCII kept as-is) and ``star`` as 0/1.

    ``client_id``: ``None`` is stored as NULL, which ``pop_next_measure`` makes
    visible to every client. Any other value is passed through
    ``normalize_client_id`` so it is stored exactly as the polling side will
    compare it (stripped, truncated to ``MAX_CLIENT_ID_LEN``, blank becomes
    ``DEFAULT_CLIENT_ID``). Stored raw, an id such as ``" cam1 "`` or an
    over-long id would never match the poller's ``client_id = ?`` filter, and
    the row would be invisible to that client. Normalizing is idempotent, so
    already-normalized ids are stored unchanged.
    """
    init_db(settings)
    stored_client_id = None if client_id is None else normalize_client_id(client_id)
    conn = _connect(settings.sqlite_path)
    try:
        ts = (
            snap.updated_at.isoformat()
            if snap.updated_at
            else datetime.now(timezone.utc).isoformat()
        )
        conn.execute(
            """
            INSERT INTO measure_snapshots (
                created_at, result_json, video_left, video_right,
                error, raw_prediction_path, source_path, client_id, pred, star,
                calculation_log
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                ts,
                json.dumps(snap.result, ensure_ascii=False),
                snap.video_left,
                snap.video_right,
                snap.error,
                snap.raw_prediction_path,
                source_path,
                stored_client_id,
                snap.pred,
                1 if snap.star else 0,
                snap.calculation_log,
            ),
        )
    finally:
        conn.close()
def save_health_snapshot(
    settings: Settings,
    snap: HealthSnapshot,
    source_path: Optional[str] = None,
) -> None:
    """Append ``snap`` as a new ``health_snapshots`` row.

    The schema is initialised first. ``created_at`` is ``snap.updated_at`` in
    ISO-8601 form, or the current UTC time when the snapshot has none.
    ``source_path`` is stored verbatim; ``pop_next_health`` later parses a
    ``#slice{N}`` suffix from it.
    """
    init_db(settings)
    conn = _connect(settings.sqlite_path)
    try:
        stamp = snap.updated_at or datetime.now(timezone.utc)
        values = (
            stamp.isoformat(),
            snap.behavior_result,
            snap.health_result,
            snap.raw_class_en,
            snap.error,
            source_path,
        )
        conn.execute(
            """
            INSERT INTO health_snapshots (
                created_at, behavior_result, health_result,
                raw_class_en, error, source_path
            ) VALUES (?, ?, ?, ?, ?, ?)
            """,
            values,
        )
    finally:
        conn.close()
def _parse_dt(s: Optional[str]) -> Optional[datetime]:
    """Parse a stored ISO-8601 timestamp; return ``None`` for empty or malformed input.

    Every ``"Z"`` is rewritten to ``"+00:00"`` first, because
    ``datetime.fromisoformat`` only accepts the ``Z`` suffix from Python 3.11 on.
    """
    if s:
        try:
            return datetime.fromisoformat(s.replace("Z", "+00:00"))
        except ValueError:
            pass
    return None
def get_latest_measure(settings: Settings) -> MeasureSnapshot:
    """Return the most recently stored measure snapshot (highest id).

    ``pred`` and ``star`` are filled from the row, as ``pop_next_measure``
    does. Previously they were omitted, so this reader reported the snapshot
    type's defaults even when the stored row had values. A ``result_json``
    that is not a JSON list is exposed as ``[]``. When the table is empty, an
    empty ``MeasureSnapshot`` is returned.
    """
    init_db(settings)
    conn = _connect(settings.sqlite_path)
    try:
        row = conn.execute(
            """
            SELECT created_at, result_json, video_left, video_right,
                   error, raw_prediction_path, pred, star, calculation_log
            FROM measure_snapshots
            ORDER BY id DESC
            LIMIT 1
            """
        ).fetchone()
        if row is None:
            return MeasureSnapshot(result=[], video_left="", video_right="")
        data: Any = json.loads(row["result_json"])
        if not isinstance(data, list):
            data = []
        return MeasureSnapshot(
            result=data,
            video_left=row["video_left"] or "",
            video_right=row["video_right"] or "",
            updated_at=_parse_dt(row["created_at"]),
            error=row["error"],
            raw_prediction_path=row["raw_prediction_path"],
            pred=row["pred"],
            # NULL star (rows from before the column existed) reads as False.
            star=bool(row["star"]) if row["star"] is not None else False,
            calculation_log=row["calculation_log"],
        )
    finally:
        conn.close()
def list_all_measure_snapshots(settings: Settings) -> List[Dict[str, Any]]:
    """Return every ``measure_snapshots`` row as a dict, newest (highest id) first.

    Intended for debug endpoints. ``result_json`` is decoded into ``"result"``;
    anything that is not a JSON list becomes ``[]``. Empty video paths become
    ``""``, and ``star`` is returned as a bool (NULL reads as ``False``).
    """
    init_db(settings)
    conn = _connect(settings.sqlite_path)
    try:
        cursor = conn.execute(
            """
            SELECT id, created_at, result_json, video_left, video_right,
                   error, raw_prediction_path, source_path, client_id, pred, star,
                   calculation_log
            FROM measure_snapshots
            ORDER BY id DESC
            """
        )

        def to_dict(row: sqlite3.Row) -> Dict[str, Any]:
            decoded: Any = json.loads(row["result_json"])
            star_flag = row["star"]
            return {
                "id": row["id"],
                "created_at": row["created_at"],
                "result": decoded if isinstance(decoded, list) else [],
                "video_left": row["video_left"] or "",
                "video_right": row["video_right"] or "",
                "error": row["error"],
                "raw_prediction_path": row["raw_prediction_path"],
                "source_path": row["source_path"],
                "client_id": row["client_id"],
                "pred": row["pred"],
                "star": False if star_flag is None else bool(star_flag),
                "calculation_log": row["calculation_log"],
            }

        return [to_dict(row) for row in cursor.fetchall()]
    finally:
        conn.close()
def get_latest_health(settings: Settings) -> HealthSnapshot:
    """Return the most recently stored health snapshot (highest id).

    NULL text columns are exposed as ``""``. An empty ``HealthSnapshot`` is
    returned when the table has no rows.
    """
    init_db(settings)
    conn = _connect(settings.sqlite_path)
    try:
        latest = conn.execute(
            """
            SELECT created_at, behavior_result, health_result,
                   raw_class_en, error
            FROM health_snapshots
            ORDER BY id DESC
            LIMIT 1
            """
        ).fetchone()
    finally:
        conn.close()
    if latest is None:
        return HealthSnapshot(behavior_result="", health_result="")
    return HealthSnapshot(
        behavior_result=latest["behavior_result"] or "",
        health_result=latest["health_result"] or "",
        updated_at=_parse_dt(latest["created_at"]),
        error=latest["error"],
        raw_class_en=latest["raw_class_en"] or "",
    )
def _coerce_finite_number(v: Any) -> Optional[float]:
    """Return ``v`` as a finite float, or ``None`` if it is not a usable number.

    Accepts ints, floats and numeric strings (surrounding whitespace ignored).
    Rejects ``None``, booleans (``bool`` subclasses ``int`` but is never a
    measurement), blank or non-numeric strings, NaN/inf, and any other type.

    Integers too large for a float (e.g. a several-hundred-digit JSON integer)
    yield ``None`` instead of letting ``float()`` raise ``OverflowError``. An
    exception here would escape the deliverability check and abort the
    caller instead of just marking the record as unusable.
    """
    if v is None or isinstance(v, bool):
        return None
    if isinstance(v, (int, float)):
        try:
            x = float(v)
        except OverflowError:
            return None
    elif isinstance(v, str):
        s = v.strip()
        if not s:
            return None
        try:
            x = float(s)
        except ValueError:
            return None
    else:
        return None
    return x if math.isfinite(x) else None
def _coerce_track_id(v: Any) -> Optional[int]:
    """Return ``v`` as a non-negative integer track id, or ``None``.

    Ints are accepted as-is, and strings are parsed as base-10 after
    stripping whitespace. Booleans, floats, negatives, unparsable strings and
    any other type yield ``None``.
    """
    if isinstance(v, bool):
        # bool subclasses int, but True/False are never valid track ids.
        return None
    candidate: Optional[int] = None
    if isinstance(v, int):
        candidate = v
    elif isinstance(v, str):
        try:
            candidate = int(v.strip(), 10)
        except ValueError:
            candidate = None
    if candidate is None or candidate < 0:
        return None
    return candidate
def measure_result_deliverable(result: Any, error: Optional[str]) -> bool:
    """Return True when ``result`` has at least one complete measurement and no error.

    A record is complete when it is a dict with a valid track ``id`` and finite
    numeric ``weight`` (g) and ``length`` (mm). Non-dict entries are ignored.
    Any truthy ``error``, a non-list ``result``, or an empty list gives False.
    """
    if error or not isinstance(result, list) or not result:
        return False
    for entry in (item for item in result if isinstance(item, dict)):
        # All three are coerced before testing, as the checks are independent.
        track_id = _coerce_track_id(entry.get("id"))
        weight = _coerce_finite_number(entry.get("weight"))
        length = _coerce_finite_number(entry.get("length"))
        if None not in (track_id, weight, length):
            return True
    return False
def measure_snapshot_deliverable(snap: MeasureSnapshot) -> bool:
    """Return whether ``snap`` is error-free and holds at least one complete record.

    Thin wrapper that applies ``measure_result_deliverable`` to a snapshot.
    """
    result, error = snap.result, snap.error
    return measure_result_deliverable(result, error)
def health_snapshot_deliverable(snap: HealthSnapshot) -> bool:
    """Return True when ``snap`` has no error and at least one non-blank label.

    The behaviour text, health text and raw English class name are each
    stripped (``None`` counts as empty). Any non-empty one makes the snapshot
    deliverable.
    """
    if snap.error:
        return False
    labels = (snap.behavior_result, snap.health_result, snap.raw_class_en)
    return any([(text or "").strip() for text in labels])
def _health_row_deliverable(
    behavior_result: str,
    health_result: str,
    raw_class_en: str,
    error: Optional[str],
) -> bool:
    """Apply ``health_snapshot_deliverable`` to the raw column values of a health row.

    ``None`` text columns are normalized to ``""`` before the snapshot is built.
    """
    fields = {
        "behavior_result": behavior_result or "",
        "health_result": health_result or "",
        "raw_class_en": raw_class_en or "",
        "error": error,
    }
    return health_snapshot_deliverable(HealthSnapshot(**fields))
def _last_delivered_id(
    conn: sqlite3.Connection,
    kind: str,
    snapshots_table: str,
    client_id: str,
) -> int:
    """Return the delivery cursor for ``(client_id, kind)``, creating it on first use.

    A client seen for the first time starts at the current ``MAX(id)`` of
    ``snapshots_table`` (0 when empty), so only rows written after that point
    are delivered to it. ``snapshots_table`` is interpolated into the SQL and
    must be a trusted table name. The existing cursor is read by column name,
    so ``conn`` must return ``sqlite3.Row`` rows.
    """
    existing = conn.execute(
        """
        SELECT last_delivered_id FROM delivery_client_cursor
        WHERE kind = ? AND client_id = ?
        """,
        (kind, client_id),
    ).fetchone()
    if existing is not None:
        return int(existing["last_delivered_id"])
    (max_id,) = conn.execute(
        f"SELECT COALESCE(MAX(id), 0) FROM {snapshots_table}"
    ).fetchone()
    start_id = int(max_id)
    conn.execute(
        """
        INSERT INTO delivery_client_cursor (client_id, kind, last_delivered_id)
        VALUES (?, ?, ?)
        """,
        (client_id, kind, start_id),
    )
    return start_id
def pop_next_measure(
    settings: Settings,
    client_id: str = DEFAULT_CLIENT_ID,
) -> Tuple[MeasureSnapshot, bool, Optional[int]]:
    """Pop the next deliverable measure snapshot for ``client_id``.

    Rows are scanned in ascending id order after the client's cursor. Only
    rows whose ``client_id`` equals the normalized id, or is NULL (visible to
    every client, for backward compatibility), are considered. The cursor is
    advanced past every row examined, and rows that are not deliverable (see
    ``measure_result_deliverable``) are skipped. Everything runs in a single
    ``BEGIN IMMEDIATE`` transaction that is rolled back on error.

    A row whose ``result_json`` cannot be decoded is treated as an empty
    (undeliverable) result and skipped. If the decode error propagated, the
    rollback would undo the cursor advance and the same row would fail again
    on every later poll, wedging the queue.

    Returns:
        ``(snapshot, True, row_id)`` when a deliverable row is found, otherwise
        ``(empty snapshot, False, None)``.
    """
    cid = normalize_client_id(client_id)
    init_db(settings)
    conn = _connect(settings.sqlite_path)
    try:
        # IMMEDIATE takes the write lock up front, serializing concurrent pops.
        conn.execute("BEGIN IMMEDIATE")
        last_id = _last_delivered_id(conn, "measure", "measure_snapshots", cid)
        while True:
            next_row = conn.execute(
                """
                SELECT id, created_at, result_json, video_left, video_right,
                       error, raw_prediction_path, pred, star, calculation_log
                FROM measure_snapshots
                WHERE id > ? AND (client_id = ? OR client_id IS NULL)
                ORDER BY id ASC
                LIMIT 1
                """,
                (last_id, cid),
            ).fetchone()
            if next_row is None:
                conn.commit()
                return MeasureSnapshot(result=[], video_left="", video_right=""), False, None
            nid = int(next_row["id"])
            err: Optional[str] = next_row["error"]
            try:
                data: Any = json.loads(next_row["result_json"])
            except (TypeError, ValueError):
                # Corrupt JSON: fall through as an empty, undeliverable result.
                data = []
            if not isinstance(data, list):
                data = []
            # Advance first so undeliverable rows are consumed as well.
            conn.execute(
                """
                UPDATE delivery_client_cursor SET last_delivered_id = ?
                WHERE kind = ? AND client_id = ?
                """,
                (nid, "measure", cid),
            )
            if not measure_result_deliverable(data, err):
                last_id = nid
                continue
            conn.commit()
            snap = MeasureSnapshot(
                result=data,
                video_left=next_row["video_left"] or "",
                video_right=next_row["video_right"] or "",
                updated_at=_parse_dt(next_row["created_at"]),
                error=err,
                raw_prediction_path=next_row["raw_prediction_path"],
                pred=next_row["pred"],
                star=bool(next_row["star"]) if next_row["star"] is not None else False,
                calculation_log=next_row["calculation_log"],
            )
            return snap, True, nid
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
def pop_next_health(
    settings: Settings,
    client_id: str = DEFAULT_CLIENT_ID,
) -> Tuple[HealthSnapshot, bool, Optional[int]]:
    """Pop the next deliverable health snapshot for ``client_id``.

    Unlike measure rows, health rows are not filtered by client: every client
    walks the whole table with its own cursor. The cursor is advanced past
    each examined row inside one ``BEGIN IMMEDIATE`` transaction (rolled back
    on error), and undeliverable rows are skipped. On success, the slice index
    parsed from the row's ``source_path`` is cached for this client (read by
    ``get_last_health_slice_index``) so the water/video endpoints can align.

    Returns:
        ``(snapshot, True, row_id)`` for a deliverable row, otherwise
        ``(empty snapshot, False, None)``.
    """
    cid = normalize_client_id(client_id)
    init_db(settings)
    conn = _connect(settings.sqlite_path)
    try:
        conn.execute("BEGIN IMMEDIATE")
        cursor_id = _last_delivered_id(conn, "health", "health_snapshots", cid)
        while True:
            row = conn.execute(
                """
                SELECT id, created_at, behavior_result, health_result,
                       raw_class_en, error, source_path
                FROM health_snapshots
                WHERE id > ?
                ORDER BY id ASC
                LIMIT 1
                """,
                (cursor_id,),
            ).fetchone()
            if row is None:
                conn.commit()
                return HealthSnapshot(behavior_result="", health_result=""), False, None
            cursor_id = int(row["id"])
            behavior = row["behavior_result"] or ""
            health = row["health_result"] or ""
            class_en = row["raw_class_en"] or ""
            error: Optional[str] = row["error"]
            # Advance first so undeliverable rows are consumed as well.
            conn.execute(
                """
                UPDATE delivery_client_cursor SET last_delivered_id = ?
                WHERE kind = ? AND client_id = ?
                """,
                (cursor_id, "health", cid),
            )
            if _health_row_deliverable(behavior, health, class_en, error):
                break
        conn.commit()
        # Remember which slice this client just received (-1 if not a slice).
        _client_health_slice_index[cid] = _parse_slice_index_from_source_path(
            row["source_path"]
        )
        snap = HealthSnapshot(
            behavior_result=behavior,
            health_result=health,
            updated_at=_parse_dt(row["created_at"]),
            error=error,
            raw_class_en=class_en,
        )
        return snap, True, cursor_id
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
def _load_json_processed_set(path: Path) -> Set[str]:
    """Read a legacy JSON watch-state file and return the processed paths it lists.

    Two layouts are accepted: a bare JSON list, or an object whose
    ``"processed"`` key holds an iterable. Every entry is converted with
    ``str``. This is a best-effort legacy import. A missing or unreadable
    file, non-UTF-8 bytes, invalid JSON, or an unexpected layout all yield an
    empty set instead of raising.
    """
    if not path.is_file():
        return set()
    try:
        with open(path, encoding="utf-8") as f:
            data: Any = json.load(f)
        if isinstance(data, list):
            return {str(x) for x in data}
        if isinstance(data, dict) and "processed" in data:
            return {str(x) for x in data["processed"]}
    except (ValueError, OSError, TypeError):
        # ValueError covers JSONDecodeError *and* UnicodeDecodeError (bad bytes);
        # TypeError covers a non-iterable "processed" value such as null or 5.
        pass
    return set()
def load_watch_processed(settings: Settings, state_file: Path, kind: str) -> Set[str]:
    """Return the set of already-processed watch paths of ``kind`` from SQLite.

    Paths found in the legacy JSON state file ``state_file`` are merged into
    ``watch_processed`` first. The merge is idempotent (``INSERT OR IGNORE``)
    and runs as one transaction, so a large legacy file costs a single commit
    instead of one autocommit per path.

    Raises:
        ValueError: if ``kind`` is not ``"measure"`` or ``"action"``. This is
            checked explicitly because an ``assert`` is stripped under
            ``python -O``.
    """
    if kind not in ("measure", "action"):
        raise ValueError(f"invalid watch kind: {kind!r}")
    init_db(settings)
    legacy_paths = _load_json_processed_set(state_file)
    conn = _connect(settings.sqlite_path)
    try:
        if legacy_paths:
            conn.execute("BEGIN")
            try:
                conn.executemany(
                    "INSERT OR IGNORE INTO watch_processed (path, kind) VALUES (?, ?)",
                    [(p, kind) for p in legacy_paths],
                )
            except Exception:
                conn.rollback()
                raise
            conn.commit()
        cur = conn.execute("SELECT path FROM watch_processed WHERE kind = ?", (kind,))
        return {r[0] for r in cur}
    finally:
        conn.close()
def add_watch_processed(settings: Settings, path: str, kind: str) -> None:
    """Record ``path`` as processed by the ``kind`` watcher (idempotent).

    Raises:
        ValueError: if ``kind`` is not ``"measure"`` or ``"action"``. This is
            checked explicitly because an ``assert`` is stripped under
            ``python -O``.
    """
    if kind not in ("measure", "action"):
        raise ValueError(f"invalid watch kind: {kind!r}")
    init_db(settings)
    conn = _connect(settings.sqlite_path)
    try:
        # OR IGNORE: re-adding an already-known (path, kind) pair is a no-op.
        conn.execute(
            "INSERT OR IGNORE INTO watch_processed (path, kind) VALUES (?, ?)",
            (path, kind),
        )
        conn.commit()
    finally:
        conn.close()
def _safe_rm_tree(path: Path) -> None:
    """Best-effort removal of ``path``, whether a directory tree or a single file.

    The path is resolved first, and a non-existent target is ignored. Any
    ``OSError`` raised while deleting (e.g. permission denied) is printed and
    swallowed.

    NOTE(review): ``resolve()`` follows symlinks, so passing a symlink removes
    the link's *target* rather than the link itself. Confirm this is intended.
    """
    target = Path(path).resolve()
    if not target.exists():
        return
    try:
        remover = shutil.rmtree if target.is_dir() else Path.unlink
        remover(target)
    except OSError as exc:
        print(f"[prestart-fresh] skip remove {target}: {exc}", flush=True)
def clear_runtime_compute_dirs(settings: Settings) -> None:
    """Empty the runtime work directories while keeping the directories themselves.

    Clears ``measure_output_root``, ``action_output_root``, ``media_root`` and
    ``stream_tmp_dir``; a root that is not an existing directory is skipped.
    Meant to be called from the startup script together with
    ``remove_sqlite_database_files`` so both pipelines recompute after restart.

    Symlinked entries are removed as links and never followed.
    ``Path.is_dir()`` is True for a link to a directory, but ``shutil.rmtree``
    refuses symlinks with ``OSError``, so such entries would otherwise only be
    reported as "skip remove" and left behind. Other ``OSError``s are printed
    and skipped (best-effort).
    """
    for base in (
        settings.measure_output_root,
        settings.action_output_root,
        settings.media_root,
        settings.stream_tmp_dir,
    ):
        root = Path(base).resolve()
        if not root.is_dir():
            continue
        for child in root.iterdir():
            try:
                if child.is_dir() and not child.is_symlink():
                    shutil.rmtree(child)
                else:
                    # Regular files and symlinks (to files or directories).
                    child.unlink()
            except OSError as e:
                print(f"[prestart-fresh] skip remove {child}: {e}", flush=True)
def remove_sqlite_database_files(settings: Settings) -> None:
    """Delete the SQLite database file and its ``-wal`` / ``-shm`` side files.

    Missing files are ignored; the next ``init_db`` recreates an empty
    database. Removal is best-effort, but any other ``OSError`` (permission
    denied, a directory in the way, a file still held open) is printed rather
    than silently discarded. A database that survives means the fresh start
    did not actually happen.
    """
    base = settings.sqlite_path.resolve()
    for p in (base, Path(f"{base}-wal"), Path(f"{base}-shm")):
        try:
            p.unlink()
        except FileNotFoundError:
            continue
        except OSError as e:
            print(f"[prestart-fresh] skip remove {p}: {e}", flush=True)
def clear_watch_cache_and_snapshots(settings: Settings) -> None:
    """Forget processed watch paths and their snapshots so inference can rerun.

    This follows the same ``use_state_file`` switches as the measure/action
    watchers (``measure_watch_use_state_file`` /
    ``action_watch_use_state_file``). For each enabled pipeline it deletes
    that kind's ``watch_processed`` rows and all of its snapshots, and resets
    every client's delivery cursor for that kind to 0.

    All statements run in one ``BEGIN IMMEDIATE`` transaction that is rolled
    back on error. A failure part-way can therefore not leave the tables
    half-reset, e.g. processed paths cleared while the old snapshots remain.
    """
    init_db(settings)
    conn = _connect(settings.sqlite_path)
    try:
        conn.execute("BEGIN IMMEDIATE")
        if settings.measure_watch_use_state_file:
            conn.execute("DELETE FROM watch_processed WHERE kind = ?", ("measure",))
            conn.execute("DELETE FROM measure_snapshots")
            conn.execute(
                "UPDATE delivery_client_cursor SET last_delivered_id = 0 WHERE kind = ?",
                ("measure",),
            )
        if settings.action_watch_use_state_file:
            # Action watcher results are stored as health snapshots.
            conn.execute("DELETE FROM watch_processed WHERE kind = ?", ("action",))
            conn.execute("DELETE FROM health_snapshots")
            conn.execute(
                "UPDATE delivery_client_cursor SET last_delivered_id = 0 WHERE kind = ?",
                ("health",),
            )
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()