Files
operating-room-monitor-server/backend/algorithm_subprocesses/5.15/visualize_result_video.py

401 lines
14 KiB
Python
Raw Normal View History

2026-05-21 15:48:03 +08:00
#!/usr/bin/env python3
"""根据 output/result.txt 生成手部融合框可视化视频。"""
from __future__ import annotations
import argparse
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import cv2
import numpy as np
try:
    from PIL import Image, ImageDraw, ImageFont
except Exception:  # noqa: BLE001 - Pillow is optional; drawing helpers fall back to OpenCV text
    Image = ImageDraw = ImageFont = None
# Directory containing this script; bundled sources are expected under PACK_ROOT / "src".
PACK_ROOT = Path(__file__).resolve().parent
# Put PACK_ROOT/src first on sys.path so the local `paths` helper module can be imported.
sys.path.insert(0, str(PACK_ROOT / "src"))
from paths import ensure_code_on_path  # noqa: E402
# NOTE(review): presumably extends sys.path further so the project imports that follow
# (`pipeline...`, `run_segments_consumable_vote`) resolve — behavior lives in paths.py, not visible here.
ensure_code_on_path(PACK_ROOT)
from pipeline.hand_roi_merge import HandMergeConfig, HandRoiGrouper, two_largest_hands, union_xyxy # noqa: E402
from run_segments_consumable_vote import collect_hand_boxes, pad_box as _pad_box # noqa: E402
from ultralytics import YOLO # noqa: E402
@dataclass
class SegmentRow:
    """One parsed data row of the result.txt segment table."""

    rank: int  # value of the ``rank`` column
    start_sec: float  # segment start, in seconds (``start_sec`` column)
    end_sec: float  # segment end, in seconds (``end_sec`` column)
    top1_name: str  # label text drawn while the segment is active (``top1_name`` column)
_FONT_CANDIDATES = [
Path("/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc"),
Path("/usr/share/fonts/opentype/noto/NotoSerifCJK-Regular.ttc"),
2026-05-22 10:45:47 +08:00
Path("/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc"),
2026-05-21 15:48:03 +08:00
Path("/usr/share/fonts/truetype/wqy/wqy-microhei.ttc"),
2026-05-22 10:45:47 +08:00
Path("/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc"),
2026-05-21 15:48:03 +08:00
]
def log(msg: str) -> None:
    """Print *msg* to stdout prefixed with the local time as ``[HH:MM:SS]``, flushing immediately."""
    stamp = time.strftime("%H:%M:%S")
    print(f"[{stamp}] {msg}", flush=True)
def compute_scaled_size(width: int, height: int, max_width: int) -> tuple[int, int]:
    """Return an encoder-friendly ``(w, h)`` that fits inside *max_width*.

    Frames wider than *max_width* are scaled down proportionally; narrower
    frames, or ``max_width <= 0``, keep their size. Each dimension is then
    rounded down to an even number (encoders want even sizes), never below 2.
    Example: 3840x2160 with max_width=1920 -> 1920x1080.
    """

    def _floor_even(value: int) -> int:
        return value - 1 if value % 2 else value

    if 0 < max_width < width:
        factor = max_width / float(width)
        target_w = max(1, int(round(width * factor)))
        target_h = max(1, int(round(height * factor)))
    else:
        target_w, target_h = width, height
    return max(2, _floor_even(target_w)), max(2, _floor_even(target_h))
def parse_result_txt(path: Path) -> tuple[list[SegmentRow], str]:
    """Parse the tab-separated segment table in a result.txt file.

    The file may contain free-form lines before a header row whose first
    column is ``rank`` (matched case-insensitively), followed by one
    tab-separated data row per segment. Lines starting with ``医生信息:``
    (ASCII or full-width colon) anywhere in the file are captured as the
    doctor-info text; the last such line wins.

    Header column names are matched case-insensitively and ignoring
    surrounding whitespace, consistent with how the header row is detected.
    Only the required columns (rank, start_sec, end_sec, top1_name) have to be
    present in a data row; rows that are too short or whose numbers do not
    parse are skipped.

    Args:
        path: result file (UTF-8; a leading BOM is tolerated).

    Returns:
        ``(segments, doctor_info)``: segments sorted by
        ``(start_sec, end_sec, rank)``, and the stripped doctor-info line or "".

    Raises:
        ValueError: if no header row is found or a required column is missing.
    """
    doctor_prefixes = ("医生信息:", "医生信息：")
    required = ("rank", "start_sec", "end_sec", "top1_name")

    # "utf-8-sig" drops a leading BOM; str.strip() would keep it and the
    # header on the first line would then go undetected.
    lines = path.read_text(encoding="utf-8-sig").splitlines()
    header_idx = None
    doctor_info = ""
    for i, raw in enumerate(lines):
        line = raw.strip()
        if line.startswith(doctor_prefixes):
            doctor_info = line
        if line.lower().startswith("rank\t"):
            header_idx = i
            break
    if header_idx is None:
        raise ValueError(f"未找到结果表头: {path}")

    # Split the raw (unstripped) header exactly like the data rows below so
    # column positions line up; normalize names the way the header was detected.
    header = lines[header_idx].split("\t")
    col_idx = {name.strip().lower(): idx for idx, name in enumerate(header)}
    for key in required:
        if key not in col_idx:
            raise ValueError(f"结果文件缺少列 {key}: {path}")
    rank_i, start_i, end_i, name_i = (col_idx[k] for k in required)
    # Only the required columns must exist; trailing optional columns may be missing.
    need = max(rank_i, start_i, end_i, name_i) + 1

    out: list[SegmentRow] = []
    for raw in lines[header_idx + 1 :]:
        line = raw.strip()
        if not line:
            continue
        if line.startswith(doctor_prefixes):
            doctor_info = line
            continue
        parts = raw.split("\t")
        if len(parts) < need:
            continue
        try:
            rank = int(parts[rank_i].strip())
            start_sec = float(parts[start_i].strip())
            end_sec = float(parts[end_i].strip())
        except ValueError:
            continue
        out.append(
            SegmentRow(
                rank=rank,
                start_sec=start_sec,
                end_sec=end_sec,
                top1_name=parts[name_i].strip(),
            )
        )
    out.sort(key=lambda x: (x.start_sec, x.end_sec, x.rank))
    return out, doctor_info
def active_segment_at(segments: list[SegmentRow], idx_hint: int, t_sec: float) -> tuple[int, SegmentRow | None]:
    """Find the segment covering *t_sec*, scanning forward from *idx_hint*.

    Segments whose end (plus a 1e-6 s tolerance) lies before *t_sec* are
    skipped; the first remaining one is returned if *t_sec* falls inside it
    (inclusive, same tolerance). The scan never moves backwards, so it is meant
    for a non-decreasing *t_sec* over segments sorted by start time, as
    returned by parse_result_txt.

    Returns:
        ``(new_hint, segment_or_None)``; pass ``new_hint`` to the next call.
    """
    tol = 1e-6
    pos = idx_hint
    count = len(segments)
    # Drop segments that have already finished.
    while pos < count and t_sec > segments[pos].end_sec + tol:
        pos += 1
    if pos >= count:
        return pos, None
    candidate = segments[pos]
    inside = candidate.start_sec - tol <= t_sec <= candidate.end_sec + tol
    return pos, (candidate if inside else None)
def fused_box_padded(
    frame,
    hands: list[list[float]],
    grouper: HandRoiGrouper,
) -> tuple[int, int, int, int] | None:
    """Return the single padded box to draw for the detected hands, or None.

    With one hand that box is padded. With two or more, the two boxes returned
    by ``two_largest_hands`` are merged with ``union_xyxy`` and the union is
    padded. Padding is ``grouper.pad_box_fn(box, frame_w, frame_h,
    grouper.pad_ratio)``.
    """
    if not hands:
        return None
    frame_h, frame_w = frame.shape[:2]
    pad_fn, ratio = grouper.pad_box_fn, grouper.pad_ratio
    if len(hands) == 1:
        box = hands[0]
    else:
        # Requirement: never draw the two hands separately; merge them into one enclosing box.
        first, second = two_largest_hands(hands)
        box = union_xyxy(first, second)
    return pad_fn(box, frame_w, frame_h, ratio)
def load_pil_font(font_path: Path | None, font_size: int):
    """Load a Pillow font for label text, preferring CJK-capable font files.

    Tries *font_path* first (when given), then each entry of
    ``_FONT_CANDIDATES``; the first existing file Pillow can open wins.
    Otherwise falls back to Pillow's built-in default font.

    Returns:
        ``(font, path)`` where *path* is the file used, or None for the
        default font; ``(None, None)`` when Pillow is unavailable or even the
        default font cannot be loaded.
    """
    if ImageFont is None:
        return None, None
    search_order = [font_path] if font_path is not None else []
    search_order += _FONT_CANDIDATES
    for candidate in search_order:
        if not candidate.is_file():
            continue
        try:
            font = ImageFont.truetype(str(candidate), font_size)
        except Exception:  # noqa: BLE001 - unreadable/unsupported file: try the next one
            continue
        return font, candidate
    try:
        return ImageFont.load_default(), None
    except Exception:  # noqa: BLE001
        return None, None
def draw_label_box(frame, rect: tuple[int, int, int, int], label: str, pil_font) -> None:
    """Draw the fused hand box plus a filled name tag carrying *label*, in place.

    The tag normally sits just above the box with a 4 px gap. When the box is
    too close to the top of the frame for the tag to fit there, the tag is
    tucked inside the box's top edge instead. It is also shifted left when it
    would run past the right edge, so the text is never squashed or clipped.

    Args:
        frame: BGR image, modified in place.
        rect: ``(x1, y1, x2, y2)`` box in pixel coordinates (coerced to int).
        label: Tag text; blank becomes ``"unknown"`` and tabs become spaces.
        pil_font: Pillow font for CJK-capable rendering; when it (or Pillow)
            is unavailable, OpenCV's Hershey font is used instead.
    """
    # OpenCV drawing calls reject non-integer points.
    x1, y1, x2, y2 = (int(v) for v in rect)
    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 165, 255), 2)
    text = label.strip() if label.strip() else "unknown"
    text = text.replace("\t", " ")
    frame_h, frame_w = frame.shape[:2]

    use_pil = Image is not None and ImageDraw is not None and pil_font is not None
    if use_pil:
        # Convert after the box is drawn so the PIL round-trip keeps it.
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img)
        ink_l, ink_t, ink_r, ink_b = draw.textbbox((0, 0), text, font=pil_font)
        tw = max(1, ink_r - ink_l)
        th = max(1, ink_b - ink_t)
    else:
        (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.62, 2)

    tag_w = tw + 8
    tag_h = th + 8
    # Preferred position: directly above the box with a 4 px gap.
    by2 = y1 - 4
    by1 = by2 - tag_h
    if by1 < 0:
        # Not enough room above the box: place the tag just inside its top edge.
        by1 = max(0, y1)
        by2 = by1 + tag_h
    by2 = min(frame_h - 1, by2)
    # Keep the whole tag on screen horizontally (shift left near the right edge).
    bx1 = max(0, min(x1, frame_w - 1 - tag_w))
    bx2 = min(frame_w - 1, bx1 + tag_w)

    if use_pil:
        draw.rectangle([(bx1, by1), (bx2, by2)], fill=(255, 165, 0))
        # textbbox() was measured at origin (0, 0); subtracting its top-left
        # offset places the ink 4 px inside the tag instead of spilling past it.
        draw.text((bx1 + 4 - ink_l, by1 + 4 - ink_t), text, font=pil_font, fill=(0, 0, 0))
        frame[:, :, :] = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
        return
    cv2.rectangle(frame, (bx1, by1), (bx2, by2), (0, 165, 255), -1)
    # putText anchors at the text baseline, so position it near the tag bottom.
    cv2.putText(frame, text, (bx1 + 4, max(0, by2 - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.62, (0, 0, 0), 2, cv2.LINE_AA)
def draw_segment_label_banner(frame, label: str, pil_font) -> None:
    """Draw *label* as a filled banner at the top-left corner of *frame*, in place.

    Blank labels render as "unknown" and tabs become spaces. Text goes through
    Pillow (CJK-capable) when Pillow and *pil_font* are available; otherwise
    OpenCV's Hershey font at scale 0.62 is used.
    """
    stripped = label.strip()
    text = (stripped if stripped else "unknown").replace("\t", " ")
    origin = 12  # banner's top-left corner, in px from both edges
    max_x = frame.shape[1] - 1
    max_y = frame.shape[0] - 1

    if Image is None or ImageDraw is None or pil_font is None:
        (text_w, text_h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.62, 2)
        inset = 8
        right = min(max_x, origin + text_w + inset * 2)
        bottom = min(max_y, origin + text_h + inset * 2)
        cv2.rectangle(frame, (origin, origin), (right, bottom), (0, 165, 255), -1)
        # putText's origin is the baseline, so anchor the text from the bottom.
        cv2.putText(frame, text, (origin + inset, bottom - inset), cv2.FONT_HERSHEY_SIMPLEX, 0.62, (0, 0, 0), 2, cv2.LINE_AA)
        return

    canvas = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    pen = ImageDraw.Draw(canvas)
    bbox_l, bbox_t, bbox_r, bbox_b = pen.textbbox((0, 0), text, font=pil_font)
    text_w = max(1, bbox_r - bbox_l)
    text_h = max(1, bbox_b - bbox_t)
    inset = 10
    right = min(max_x, origin + text_w + inset * 2)
    bottom = min(max_y, origin + text_h + inset * 2)
    pen.rectangle([(origin, origin), (right, bottom)], fill=(255, 165, 0))
    pen.text((origin + inset, origin + inset), text, font=pil_font, fill=(0, 0, 0))
    frame[:, :, :] = cv2.cvtColor(np.asarray(canvas), cv2.COLOR_RGB2BGR)
def draw_bottom_right_info(frame, text: str, pil_font) -> None:
    """Draw *text* in a filled box anchored at the bottom-right corner, in place.

    Nothing is drawn when *text* is blank. Text goes through Pillow
    (CJK-capable) when Pillow and *pil_font* are available; otherwise
    OpenCV's Hershey font at scale 0.55 is used.
    """
    info = text.strip()
    if not info:
        return
    frame_w = frame.shape[1]
    frame_h = frame.shape[0]

    if Image is None or ImageDraw is None or pil_font is None:
        (text_w, text_h), _ = cv2.getTextSize(info, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 2)
        inset, margin = 8, 10
        left = max(0, frame_w - text_w - inset * 2 - margin)
        top = max(0, frame_h - text_h - inset * 2 - margin)
        right = min(frame_w - 1, left + text_w + inset * 2)
        bottom = min(frame_h - 1, top + text_h + inset * 2)
        cv2.rectangle(frame, (left, top), (right, bottom), (0, 165, 255), -1)
        # putText's origin is the baseline, so anchor the text from the bottom.
        cv2.putText(frame, info, (left + inset, bottom - inset), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 2, cv2.LINE_AA)
        return

    canvas = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    pen = ImageDraw.Draw(canvas)
    bbox_l, bbox_t, bbox_r, bbox_b = pen.textbbox((0, 0), info, font=pil_font)
    text_w = max(1, bbox_r - bbox_l)
    text_h = max(1, bbox_b - bbox_t)
    inset, margin = 10, 12
    left = max(0, frame_w - text_w - inset * 2 - margin)
    top = max(0, frame_h - text_h - inset * 2 - margin)
    right = min(frame_w - 1, left + text_w + inset * 2)
    bottom = min(frame_h - 1, top + text_h + inset * 2)
    pen.rectangle([(left, top), (right, bottom)], fill=(255, 165, 0))
    pen.text((left + inset, top + inset), info, font=pil_font, fill=(0, 0, 0))
    frame[:, :, :] = cv2.cvtColor(np.asarray(canvas), cv2.COLOR_RGB2BGR)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser; default paths are relative to PACK_ROOT."""
    ap = argparse.ArgumentParser(description="按 result.txt 时间段绘制手部融合框+耗材标签,输出 MP4。")
    ap.add_argument("--video", type=Path, default=PACK_ROOT / "input" / "sample.mp4")
    ap.add_argument("--result-txt", type=Path, default=PACK_ROOT / "output" / "result.txt")
    ap.add_argument("--hand-model", type=Path, default=PACK_ROOT / "weights" / "hand_detect.pt")
    ap.add_argument("--out-video", type=Path, default=PACK_ROOT / "output" / "result_vis.mp4")
    ap.add_argument("--det-conf", type=float, default=0.6)
    ap.add_argument("--imgsz-det", type=int, default=640)
    ap.add_argument("--pad-ratio", type=float, default=0.20)
    ap.add_argument("--merge-iou-gt", type=float, default=0.0)
    ap.add_argument("--merge-center-dist-max-px", type=float, default=None)
    ap.add_argument("--merge-center-dist-max-frac-diag", type=float, default=None)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--half", action="store_true", help="传给 YOLO predict 的 half=True")
    ap.add_argument(
        "--font-path",
        type=Path,
        default=None,
        help="中文字体文件(ttf/ttc)路径;不传则自动尝试系统常见 CJK 字体",
    )
    ap.add_argument(
        "--max-width",
        type=int,
        default=1920,
        help="输出最大宽度(默认 1920≈1080p);>0 时在读帧后缩放再跑 YOLO/写盘,<=0 保持原分辨率。",
    )
    return ap


def main() -> int:
    """Render the hand-box / consumable-label overlay video described by result.txt.

    For every frame whose timestamp lies inside a parsed segment, the hand
    detector runs on the (possibly downscaled) frame. Detected hands are fused
    into one padded box tagged with the segment's ``top1_name``. If no hand is
    found, the label is shown as a top-left banner instead. The doctor-info
    line, if present, is drawn bottom-right on every frame.

    Returns:
        Process exit code: 0 on success; 1 when an input file is missing, the
        result table is unreadable or empty, or the video cannot be opened,
        sized or written.
    """
    os.environ.setdefault("OPENCV_FFMPEG_LOGLEVEL", "8")
    args = _build_arg_parser().parse_args()
    video_path = args.video.resolve()
    txt_path = args.result_txt.resolve()
    model_path = args.hand_model.resolve()
    out_path = args.out_video.resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    for p, name in ((video_path, "输入视频"), (txt_path, "结果txt"), (model_path, "手部权重")):
        if not p.is_file():
            print(f"缺少{name}: {p}", file=sys.stderr)
            return 1

    # A malformed table (no header / missing column / bad encoding) is a user
    # error like the others above: report it and exit 1 instead of a traceback.
    try:
        segs, doctor_info_text = parse_result_txt(txt_path)
    except ValueError as exc:
        print(str(exc), file=sys.stderr)
        return 1
    if not segs:
        print(f"未在 txt 中解析到有效时间段: {txt_path}", file=sys.stderr)
        return 1
    if doctor_info_text:
        log(f"医生信息: {doctor_info_text}")

    log(f"加载手部模型: {model_path}")
    det = YOLO(str(model_path))
    merge_cfg = HandMergeConfig(
        merge_iou_gt=float(args.merge_iou_gt),
        merge_center_dist_max_px=args.merge_center_dist_max_px,
        merge_center_dist_max_frac_diag=args.merge_center_dist_max_frac_diag,
    )
    grouper = HandRoiGrouper(merge_cfg, pad_box_fn=_pad_box, pad_ratio=float(args.pad_ratio))

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"无法打开视频: {video_path}", file=sys.stderr)
        return 1
    fps = float(cap.get(cv2.CAP_PROP_FPS))
    # Containers may report 0, NaN or inf; NaN fails every comparison, so this catches it too.
    if not 0 < fps < float("inf"):
        fps = 25.0
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    if w <= 0 or h <= 0:
        # Without a real size compute_scaled_size would yield 2x2 and every frame would be shrunk to that.
        cap.release()
        print(f"无法获取视频分辨率: {video_path}", file=sys.stderr)
        return 1
    out_w, out_h = compute_scaled_size(w, h, int(args.max_width))
    if (out_w, out_h) != (w, h):
        log(f"输出分辨率: {w}x{h} -> {out_w}x{out_h} (max_width={args.max_width})")
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    font_size = max(18, int(out_h * 0.028))
    font_path = args.font_path.resolve() if args.font_path is not None else None
    pil_font, font_used = load_pil_font(font_path, font_size)
    if font_used is not None:
        log(f"标签字体: {font_used}")
    elif pil_font is not None:
        log("标签字体: Pillow 默认字体(可能不支持中文)")
    else:
        log("标签字体: 回退 OpenCV 内置字体(中文可能显示异常)")

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(out_path), fourcc, fps, (out_w, out_h))
    if not writer.isOpened():
        cap.release()
        print(f"无法创建视频写入器: {out_path}", file=sys.stderr)
        return 1

    # Detector arguments are loop-invariant; build them once.
    predict_kw: dict[str, Any] = {
        "conf": float(args.det_conf),
        "imgsz": int(args.imgsz_det),
        "verbose": False,
        "device": args.device,
    }
    if args.half:
        predict_kw["half"] = True

    frame_idx = 0
    seg_idx = 0
    n_box = 0  # frames tagged with a fused hand box
    n_banner = 0  # frames inside a segment but with no hand box (top-left banner)
    try:
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                break
            # Decide from the decoded frame, not container metadata: frames whose
            # size differs from the writer's may be rejected by VideoWriter.
            if frame.shape[1] != out_w or frame.shape[0] != out_h:
                frame = cv2.resize(frame, (out_w, out_h), interpolation=cv2.INTER_AREA)
            frame_idx += 1
            # NOTE(review): uses the 1-based frame count, so the first frame maps to
            # 1/fps rather than 0; confirm this matches how result.txt times were produced.
            t_sec = frame_idx / fps
            seg_idx, seg = active_segment_at(segs, seg_idx, t_sec)
            if seg is not None:
                r0 = det.predict(frame, **predict_kw)[0]
                hands = collect_hand_boxes(det, r0.boxes) if r0.boxes else []
                fused = fused_box_padded(frame, hands, grouper)
                if fused is not None:
                    draw_label_box(frame, fused, seg.top1_name, pil_font)
                    n_box += 1
                else:
                    draw_segment_label_banner(frame, seg.top1_name, pil_font)
                    n_banner += 1
            if doctor_info_text:
                draw_bottom_right_info(frame, doctor_info_text, pil_font)
            writer.write(frame)
            if frame_idx % 200 == 0:
                log(f"处理中: {frame_idx}/{max(total, 1)}")
    finally:
        writer.release()
        cap.release()
    log(f"完成: 输出 {out_path}")
    log(f"共绘制 {n_box} 帧融合框, {n_banner} 帧顶部标签(总帧 {frame_idx})")
    return 0
if __name__ == "__main__":
    # sys.exit(code) raises SystemExit(code): main()'s return value becomes the exit status.
    sys.exit(main())