339 lines
12 KiB
Python
339 lines
12 KiB
Python
#!/usr/bin/env python3
"""Render a visualization video of fused hand boxes from output/result.txt."""
from __future__ import annotations

import argparse
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import cv2
import numpy as np

try:
    from PIL import Image, ImageDraw, ImageFont
except Exception:  # noqa: BLE001
    # Pillow is optional: the drawing helpers check these names for None and
    # fall back to OpenCV's built-in font when it is unavailable.
    Image = None
    ImageDraw = None
    ImageFont = None

PACK_ROOT = Path(__file__).resolve().parent
# Make the bundled "src" directory importable before the project imports below.
sys.path.insert(0, str(PACK_ROOT / "src"))

from paths import ensure_code_on_path  # noqa: E402

ensure_code_on_path(PACK_ROOT)

from pipeline.hand_roi_merge import HandMergeConfig, HandRoiGrouper, two_largest_hands, union_xyxy  # noqa: E402
from run_segments_consumable_vote import collect_hand_boxes, pad_box as _pad_box  # noqa: E402
from ultralytics import YOLO  # noqa: E402
|
||
|
||
|
||
@dataclass
class SegmentRow:
    """One parsed data row of the result table: a time segment and its top-1 label."""

    rank: int  # value of the "rank" column
    start_sec: float  # value of the "start_sec" column
    end_sec: float  # value of the "end_sec" column
    top1_name: str  # value of the "top1_name" column; drawn as the box label
|
||
|
||
|
||
# CJK-capable fonts probed in order when no usable --font-path is given.
# The first three entries keep their original priority; the rest only widen
# the fallback to other common Linux packages, macOS and Windows.
_FONT_CANDIDATES = [
    Path("/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc"),
    Path("/usr/share/fonts/opentype/noto/NotoSerifCJK-Regular.ttc"),
    Path("/usr/share/fonts/truetype/wqy/wqy-microhei.ttc"),
    Path("/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc"),
    Path("/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf"),
    Path("/System/Library/Fonts/PingFang.ttc"),
    Path("C:/Windows/Fonts/msyh.ttc"),
    Path("C:/Windows/Fonts/simhei.ttf"),
]
|
||
|
||
|
||
def log(msg: str) -> None:
    """Print *msg* to stdout prefixed with the local time as ``[HH:MM:SS]``.

    The stream is flushed on every call so progress stays visible when the
    output is piped or redirected.
    """
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)
|
||
|
||
|
||
def parse_result_txt(path: Path) -> tuple[list[SegmentRow], str]:
    """Parse the tab-separated result table stored in *path*.

    The table header is the first line that, once trimmed, starts with
    ``rank<TAB>`` (case-insensitive). Every following line that has the
    required columns and numeric ``rank`` / ``start_sec`` / ``end_sec`` values
    becomes a :class:`SegmentRow`; all other lines are skipped.

    Args:
        path: Location of the result text file (UTF-8, with or without a BOM).

    Returns:
        A tuple ``(rows, doctor_info)``. ``rows`` is sorted by
        ``(start_sec, end_sec, rank)``. ``doctor_info`` is the last trimmed
        line starting with ``医生信息:`` seen before the header or among the
        data lines, or ``""`` when there is none.

    Raises:
        ValueError: If no header line is found or a required column is missing.
    """
    # "utf-8-sig" also reads plain UTF-8 and drops a leading BOM, which would
    # otherwise hide a header sitting on the very first line.
    lines = path.read_text(encoding="utf-8-sig").splitlines()
    header_idx = None
    doctor_info = ""
    for i, raw in enumerate(lines):
        line = raw.strip()
        if line.startswith("医生信息:"):
            doctor_info = line
        if line.lower().startswith("rank\t"):
            header_idx = i
            break
    if header_idx is None:
        raise ValueError(f"未找到结果表头: {path}")

    required = ("rank", "start_sec", "end_sec", "top1_name")
    # Normalise column names the same way the header line was detected
    # (trimmed, case-insensitive) so "Rank" or " rank" are not reported missing.
    col_idx = {name.strip().lower(): idx for idx, name in enumerate(lines[header_idx].split("\t"))}
    for key in required:
        if key not in col_idx:
            raise ValueError(f"结果文件缺少列 {key}: {path}")
    # Only the columns actually read must be present; a row that merely lacks
    # trailing optional columns is still usable.
    need = max(col_idx[key] for key in required) + 1

    out: list[SegmentRow] = []
    for raw in lines[header_idx + 1 :]:
        line = raw.strip()
        if not line:
            continue
        if line.startswith("医生信息:"):
            doctor_info = line
            continue
        parts = raw.split("\t")
        if len(parts) < need:
            continue
        try:
            rank = int(parts[col_idx["rank"]].strip())
            start_sec = float(parts[col_idx["start_sec"]].strip())
            end_sec = float(parts[col_idx["end_sec"]].strip())
        except ValueError:
            # Not a data row (free text, malformed numbers): ignore it.
            continue
        out.append(
            SegmentRow(
                rank=rank,
                start_sec=start_sec,
                end_sec=end_sec,
                top1_name=parts[col_idx["top1_name"]].strip(),
            )
        )
    out.sort(key=lambda x: (x.start_sec, x.end_sec, x.rank))
    return out, doctor_info
|
||
|
||
|
||
def active_segment_at(segments: list[SegmentRow], idx_hint: int, t_sec: float) -> tuple[int, SegmentRow | None]:
    """Find the segment covering time *t_sec*, scanning forward from *idx_hint*.

    Segments whose ``end_sec`` lies before *t_sec* are skipped. The scan only
    moves forward, so the result is meaningful when *segments* is ordered by
    start time and the caller queries non-decreasing times, feeding the
    returned index back in as the next hint.

    Args:
        segments: Candidate segments.
        idx_hint: Index to start scanning from; negative values are treated as 0.
        t_sec: Query time in seconds.

    Returns:
        ``(index, segment)`` where ``segment`` is the covering segment, or
        ``None`` when *t_sec* falls in a gap or past the last segment. Both
        bounds are inclusive with a 1e-6 s tolerance.
    """
    eps = 1e-6
    i = max(0, idx_hint)
    n = len(segments)
    while i < n and t_sec > segments[i].end_sec + eps:
        i += 1
    # The loop already guarantees t_sec <= end_sec + eps for segments[i],
    # so only the lower bound is left to check.
    if i < n and segments[i].start_sec - eps <= t_sec:
        return i, segments[i]
    return i, None
|
||
|
||
|
||
def fused_box_padded(
    frame,
    hands: list[list[float]],
    grouper: HandRoiGrouper,
) -> tuple[int, int, int, int] | None:
    """Return one padded box enclosing the detected hands, or ``None`` if there are none.

    Args:
        frame: Image whose ``shape[:2]`` gives ``(height, width)`` for padding.
        hands: Detected hand boxes.
        grouper: Supplies ``pad_box_fn`` and ``pad_ratio``; only these two
            attributes are read here.

    Returns:
        The result of ``grouper.pad_box_fn(box, width, height, pad_ratio)``,
        where ``box`` is the single hand, or the union of the two hands picked
        by ``two_largest_hands`` when several were detected.
    """
    if not hands:
        return None
    height, width = frame.shape[:2]
    if len(hands) == 1:
        box = hands[0]
    else:
        # Requirement: never draw the two hands separately; with two hands,
        # merge them into a single enclosing box.
        first, second = two_largest_hands(hands)
        box = union_xyxy(first, second)
    return grouper.pad_box_fn(box, width, height, grouper.pad_ratio)
|
||
|
||
|
||
def load_pil_font(font_path: Path | None, font_size: int):
    """Load a Pillow font for label rendering.

    The explicit *font_path* (if any) is tried first, then each entry of
    ``_FONT_CANDIDATES`` in order, and finally Pillow's default font.

    Args:
        font_path: Optional user-supplied font file.
        font_size: Size passed to ``ImageFont.truetype``.

    Returns:
        ``(font, path)``: the loaded font and the file it came from. ``path``
        is ``None`` for Pillow's default font, and both are ``None`` when
        Pillow is unavailable or no font could be loaded at all.
    """
    if ImageFont is None:
        return None, None

    if font_path is not None and not font_path.is_file():
        # Tell the user their explicit choice is being ignored instead of
        # silently falling through to the system fonts.
        log(f"指定字体不存在,尝试系统字体: {font_path}")

    candidates: list[Path] = [] if font_path is None else [font_path]
    candidates.extend(_FONT_CANDIDATES)
    for candidate in candidates:
        if not candidate.is_file():
            continue
        try:
            return ImageFont.truetype(str(candidate), font_size), candidate
        except Exception as exc:  # noqa: BLE001
            # Best effort: an unreadable font must not abort rendering, but it
            # should leave a trace.
            log(f"字体加载失败,已跳过: {candidate} ({exc})")

    try:
        return ImageFont.load_default(), None
    except Exception:  # noqa: BLE001
        return None, None
|
||
|
||
|
||
def draw_label_box(frame, rect: tuple[int, int, int, int], label: str, pil_font) -> None:
    """Draw the box *rect* and its text label onto *frame* in place.

    The label is normally placed 4 px above the box. When there is not enough
    room above (box touching the top of the frame) it is tucked just inside
    the box's top edge instead, and it is shifted left when it would run past
    the right edge, so the text always stays visible.

    Args:
        frame: BGR image (``height x width x 3``) modified in place.
        rect: Box corners ``(x1, y1, x2, y2)`` in pixels.
        label: Label text; blank labels are drawn as ``"unknown"`` and tabs
            are replaced by spaces.
        pil_font: Pillow font used for rendering; when it (or Pillow) is
            missing, OpenCV's Hershey font is used instead.
    """
    x1, y1, x2, y2 = rect
    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 165, 255), 2)
    text = (label.strip() or "unknown").replace("\t", " ")

    frame_h, frame_w = frame.shape[:2]
    pad = 4  # margin between the label background and the glyphs
    gap = 4  # vertical gap between the label and the box's top edge
    use_pil = Image is not None and ImageDraw is not None and pil_font is not None

    if use_pil:
        # textbbox reports where the ink lands relative to the drawing origin;
        # left/top are usually non-zero and must be compensated when drawing.
        probe = ImageDraw.Draw(Image.new("RGB", (1, 1)))
        left, top, right, bottom = probe.textbbox((0, 0), text, font=pil_font)
        left, top = int(np.floor(left)), int(np.floor(top))
        text_w = max(1, int(np.ceil(right)) - left)
        text_h = max(1, int(np.ceil(bottom)) - top)
        descent = 0
    else:
        (text_w, text_h), descent = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.62, 2)

    box_w = text_w + 2 * pad
    box_h = text_h + descent + 2 * pad

    label_y = int(y1) - gap - box_h
    if label_y < 0:
        label_y = int(y1)  # no room above the box: draw inside its top edge
    label_y = max(0, min(label_y, frame_h - box_h))
    label_x = max(0, min(int(x1), frame_w - box_w))

    if use_pil:
        # Render only the small label patch instead of converting the whole
        # frame to a PIL image and back.
        patch = Image.new("RGB", (box_w, box_h), (255, 165, 0))
        ImageDraw.Draw(patch).text((pad - left, pad - top), text, font=pil_font, fill=(0, 0, 0))
        patch_bgr = np.asarray(patch)[:, :, ::-1]
        # Crop in case the label is larger than the frame itself.
        vis_h = min(box_h, frame_h - label_y)
        vis_w = min(box_w, frame_w - label_x)
        frame[label_y : label_y + vis_h, label_x : label_x + vis_w] = patch_bgr[:vis_h, :vis_w]
        return

    cv2.rectangle(
        frame,
        (label_x, label_y),
        (label_x + box_w - 1, label_y + box_h - 1),
        (0, 165, 255),
        -1,
    )
    # putText takes the baseline origin: glyph tops sit text_h above it.
    cv2.putText(
        frame,
        text,
        (label_x + pad, label_y + pad + text_h),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.62,
        (0, 0, 0),
        2,
        cv2.LINE_AA,
    )
|
||
|
||
|
||
def draw_bottom_right_info(frame, text: str, pil_font) -> None:
    """Draw *text* on an orange panel in the bottom-right corner of *frame*, in place.

    Nothing is drawn when *text* is blank.

    Args:
        frame: BGR image (``height x width x 3``) modified in place.
        text: Text to show; surrounding whitespace is trimmed.
        pil_font: Pillow font used for rendering; when it (or Pillow) is
            missing, OpenCV's Hershey font is used instead.
    """
    info = text.strip()
    if not info:
        return

    frame_h, frame_w = frame.shape[:2]

    if Image is not None and ImageDraw is not None and pil_font is not None:
        pad = 10  # margin between the panel edge and the glyphs
        margin = 12  # distance of the panel from the frame's right/bottom edge
        # textbbox reports where the ink lands relative to the drawing origin;
        # left/top are usually non-zero and must be compensated when drawing.
        probe = ImageDraw.Draw(Image.new("RGB", (1, 1)))
        left, top, right, bottom = probe.textbbox((0, 0), info, font=pil_font)
        left, top = int(np.floor(left)), int(np.floor(top))
        text_w = max(1, int(np.ceil(right)) - left)
        text_h = max(1, int(np.ceil(bottom)) - top)
        box_w = text_w + 2 * pad
        box_h = text_h + 2 * pad
        box_x = max(0, frame_w - box_w - margin)
        box_y = max(0, frame_h - box_h - margin)

        # Render only the panel instead of converting the whole frame to a
        # PIL image and back.
        patch = Image.new("RGB", (box_w, box_h), (255, 165, 0))
        ImageDraw.Draw(patch).text((pad - left, pad - top), info, font=pil_font, fill=(0, 0, 0))
        patch_bgr = np.asarray(patch)[:, :, ::-1]
        # Crop in case the panel is larger than the frame itself.
        vis_h = min(box_h, frame_h - box_y)
        vis_w = min(box_w, frame_w - box_x)
        frame[box_y : box_y + vis_h, box_x : box_x + vis_w] = patch_bgr[:vis_h, :vis_w]
        return

    (tw, th), _ = cv2.getTextSize(info, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 2)
    pad = 8
    x1 = max(0, frame_w - tw - pad * 2 - 10)
    y1 = max(0, frame_h - th - pad * 2 - 10)
    x2 = min(frame_w - 1, x1 + tw + pad * 2)
    y2 = min(frame_h - 1, y1 + th + pad * 2)
    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 165, 255), -1)
    # putText takes the baseline origin, hence the bottom-anchored y.
    cv2.putText(frame, info, (x1 + pad, y2 - pad), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 2, cv2.LINE_AA)
|
||
|
||
|
||
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the visualization script."""
    ap = argparse.ArgumentParser(description="按 result.txt 时间段绘制手部融合框+耗材标签,输出 MP4。")
    ap.add_argument("--video", type=Path, default=PACK_ROOT / "input" / "sample.mp4")
    ap.add_argument("--result-txt", type=Path, default=PACK_ROOT / "output" / "result.txt")
    ap.add_argument("--hand-model", type=Path, default=PACK_ROOT / "weights" / "hand_detect.pt")
    ap.add_argument("--out-video", type=Path, default=PACK_ROOT / "output" / "result_vis.mp4")
    ap.add_argument("--det-conf", type=float, default=0.6)
    ap.add_argument("--imgsz-det", type=int, default=640)
    ap.add_argument("--pad-ratio", type=float, default=0.20)
    ap.add_argument("--merge-iou-gt", type=float, default=0.0)
    ap.add_argument("--merge-center-dist-max-px", type=float, default=None)
    ap.add_argument("--merge-center-dist-max-frac-diag", type=float, default=None)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--half", action="store_true", help="传给 YOLO predict 的 half=True")
    ap.add_argument(
        "--font-path",
        type=Path,
        default=None,
        help="中文字体文件(ttf/ttc)路径;不传则自动尝试系统常见 CJK 字体",
    )
    return ap


def main() -> int:
    """Render the annotated video described by the command-line arguments.

    For every frame that falls inside a segment of the result table, hands are
    detected, fused into one padded box and labelled with the segment's
    ``top1_name``. The doctor-info line, if present, is drawn on every frame.

    Returns:
        Process exit code: 0 on success, 1 when an input is missing, the
        table has no usable segment, or the video cannot be opened/written.
    """
    os.environ.setdefault("OPENCV_FFMPEG_LOGLEVEL", "8")
    args = _build_arg_parser().parse_args()

    video_path = args.video.resolve()
    txt_path = args.result_txt.resolve()
    model_path = args.hand_model.resolve()
    out_path = args.out_video.resolve()

    for p, name in ((video_path, "输入视频"), (txt_path, "结果txt"), (model_path, "手部权重")):
        if not p.is_file():
            print(f"缺少{name}: {p}", file=sys.stderr)
            return 1
    # Create the output directory only once the inputs are known to exist.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    segs, doctor_info_text = parse_result_txt(txt_path)
    if not segs:
        print(f"未在 txt 中解析到有效时间段: {txt_path}", file=sys.stderr)
        return 1
    if doctor_info_text:
        log(f"医生信息: {doctor_info_text}")

    log(f"加载手部模型: {model_path}")
    det = YOLO(str(model_path))
    merge_cfg = HandMergeConfig(
        merge_iou_gt=float(args.merge_iou_gt),
        merge_center_dist_max_px=args.merge_center_dist_max_px,
        merge_center_dist_max_frac_diag=args.merge_center_dist_max_frac_diag,
    )
    grouper = HandRoiGrouper(merge_cfg, pad_box_fn=_pad_box, pad_ratio=float(args.pad_ratio))

    predict_kw: dict[str, Any] = {"device": args.device}
    if bool(args.half):
        predict_kw["half"] = True

    frame_idx = 0
    seg_idx = 0
    n_drawn = 0
    writer = None
    cap = cv2.VideoCapture(str(video_path))
    # One try/finally covers everything after the capture is created, so the
    # capture and writer are released on every exit path, including errors
    # raised while loading fonts or creating the writer.
    try:
        if not cap.isOpened():
            print(f"无法打开视频: {video_path}", file=sys.stderr)
            return 1
        fps = float(cap.get(cv2.CAP_PROP_FPS))
        if fps <= 0:
            fps = 25.0  # some containers report no frame rate
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Scale the label font with the frame height, never below 18 px.
        font_size = max(18, int(h * 0.028))
        font_path = args.font_path.resolve() if args.font_path is not None else None
        pil_font, font_used = load_pil_font(font_path, font_size)
        if font_used is not None:
            log(f"标签字体: {font_used}")
        elif pil_font is not None:
            log("标签字体: Pillow 默认字体(可能不支持中文)")
        else:
            log("标签字体: 回退 OpenCV 内置字体(中文可能显示异常)")

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(str(out_path), fourcc, fps, (w, h))
        if not writer.isOpened():
            print(f"无法创建视频写入器: {out_path}", file=sys.stderr)
            return 1

        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                break
            frame_idx += 1
            # Timestamps are 1-based: the first frame maps to 1/fps.
            # NOTE(review): confirm this matches how result.txt times were produced.
            t_sec = frame_idx / fps

            seg_idx, seg = active_segment_at(segs, seg_idx, t_sec)
            if seg is not None:
                # Detection only runs inside labelled segments.
                r0 = det.predict(
                    frame,
                    conf=float(args.det_conf),
                    imgsz=int(args.imgsz_det),
                    verbose=False,
                    **predict_kw,
                )[0]
                hands = collect_hand_boxes(det, r0.boxes) if r0.boxes else []
                fused = fused_box_padded(frame, hands, grouper)
                if fused is not None:
                    draw_label_box(frame, fused, seg.top1_name, pil_font)
                    n_drawn += 1
            if doctor_info_text:
                draw_bottom_right_info(frame, doctor_info_text, pil_font)
            writer.write(frame)

            if frame_idx % 200 == 0:
                log(f"处理中: {frame_idx}/{max(total, 1)} 帧")
    finally:
        if writer is not None:
            writer.release()
        cap.release()

    log(f"完成: 输出 {out_path}")
    log(f"共绘制 {n_drawn} 帧融合框(总帧 {frame_idx})")
    return 0
|
||
|
||
|
||
# Run as a script: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|