Files
OperationRoomMonitor/code/dataset.py
2026-06-02 16:59:42 +08:00

1422 lines
46 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
从 ~/data/haocai/ 递归扫描「叶子会话目录」(含 mp4 + xlsx且子目录中不再含 mp4
按 Excel 中的时间段从对应视频抽帧,输出到「输出根/images/<商品名称>/<规格>/」并生成 JSON 元数据。
输出分辨率默认与源视频帧一致;可用 --max-width / --max-height 限制最大尺寸(仅缩小、不放大)。
可选 --sample-every N按全局成功保存顺序每第 N 张在 JSON 中标记 sample=true便于抽检
可选 --limit N最多生成 N 条(图片或片段),用于快速检查 JSON 格式0 表示不限制。
可选 --extract-backend抽帧方式。默认 auto有 ffmpeg 则用 ffmpeg。默认精确 seek-ss 在 -i 之后);
可加 --ffmpeg-fast-seek 换快 seek部分 HEVC/H.265 文件会得到全灰无效帧,脚本会自动改回精确 seek 重试)。
建议安装 ffprobe 与 ffmpeg时长/帧率以 ffprobe 为准。
可选 --detect-bbox用 Grounding DINOtransformers + torch检测人体并输出 bbox 到 JSON。
可选 --save-vis在输出根下单独目录默认 vis/)生成与 images 同结构的 *_vis.jpg框与英文类别叠加在图上。
列约定(与样本数据一致):
- 单个 xlsx、两个视频约 AJ表头含「视频1」「视频2」时间段列常见为第 9、10 列)。
- 单个 xlsx、一个视频约 AI最后一列为「视频内时间段」。
- 两个 xlsx、两个视频每个文件 AI最后一列为该视频「视频内时间段」按文件名中的 01/02 与视频配对。
"""
from __future__ import annotations
import argparse
import hashlib
import json
import re
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Callable, Iterator, Optional
import cv2
import numpy as np
import pandas as pd
# Temporary / lock files: matches file names that start with "~$", "._" or ".~"
# (case-insensitive flag set; used by _is_real_xlsx to reject such names).
_IGNORE_XLSX = re.compile(r"^~\$|^\._|^\.\~", re.I)
def _log(msg: str) -> None:
"""运行日志stderr立即刷新"""
ts = time.strftime("%H:%M:%S")
print(f"[{ts}] {msg}", file=sys.stderr, flush=True)
@dataclass
class ImageRecord:
    """One saved frame, as serialised into the dataset JSON."""

    name: str  # file name of the saved image
    path: str  # path string of the saved image
    label_category: str  # product name
    size: str  # specification
    # True for every N-th saved image (see --sample-every).
    sample: bool = False
    # YOLO-format box [x_center, y_center, w, h], normalised to 0..1;
    # None when detection is disabled or nothing was detected.
    bbox_xywhn: list[float] | None = None
    detection_score: float | None = None
@dataclass
class VideoMeta:
    """Video stream metadata; taken from ffprobe when possible, since it is more
    reliable than OpenCV for HEVC / variable-frame-rate files."""

    width: int  # frame width in pixels
    height: int  # frame height in pixels
    fps: float  # frames per second
    duration_sec: float  # stream duration in seconds (0.0 when unknown)
    frame_count: int = 0  # number of frames (0 when unknown)
def _parse_fraction(s: str) -> float:
s = (s or "").strip()
if not s or s == "0/0":
return 0.0
if "/" in s:
a, b = s.split("/", 1)
try:
den = float(b)
return float(a) / den if den else 0.0
except ValueError:
return 0.0
try:
return float(s)
except ValueError:
return 0.0
def _ffprobe_video_meta(path: Path, ffprobe_bin: str) -> Optional[VideoMeta]:
if not shutil.which(ffprobe_bin):
return None
cmd = [
ffprobe_bin,
"-v",
"error",
"-select_streams",
"v:0",
"-show_entries",
"stream=width,height,avg_frame_rate,r_frame_rate,nb_frames,duration",
"-show_entries",
"format=duration",
"-of",
"json",
str(path),
]
try:
p = subprocess.run(
cmd, capture_output=True, text=True, timeout=60, check=False
)
except (subprocess.TimeoutExpired, OSError):
return None
if p.returncode != 0 or not p.stdout:
return None
try:
data = json.loads(p.stdout)
except json.JSONDecodeError:
return None
streams = data.get("streams") or []
if not streams:
return None
st = streams[0]
w = int(st.get("width") or 0)
h = int(st.get("height") or 0)
if w < 2 or h < 2:
return None
fps = _parse_fraction(str(st.get("avg_frame_rate") or ""))
if fps <= 0:
fps = _parse_fraction(str(st.get("r_frame_rate") or ""))
dur_s = float(st.get("duration") or 0.0)
fmt = data.get("format") or {}
if dur_s <= 0:
dur_s = float(fmt.get("duration") or 0.0)
nbf = st.get("nb_frames")
frame_count = 0
if nbf is not None and str(nbf).strip() and str(nbf).upper() != "N/A":
try:
frame_count = int(nbf)
except (TypeError, ValueError):
frame_count = 0
if frame_count <= 0 and dur_s > 0 and fps > 0:
frame_count = int(round(dur_s * fps))
if fps <= 0 and dur_s > 0 and frame_count > 0:
fps = frame_count / dur_s
if fps <= 0:
fps = 25.0
return VideoMeta(
width=w,
height=h,
fps=float(fps),
duration_sec=float(dur_s),
frame_count=frame_count,
)
def _opencv_video_meta(path: Path) -> VideoMeta:
    """Read stream metadata through OpenCV (FFmpeg backend).

    A file that cannot be opened yields a zero-sized VideoMeta with fps 25.0.
    The duration is computed as frame_count / fps when both are positive.
    """
    capture = cv2.VideoCapture(str(path), cv2.CAP_FFMPEG)
    if not capture.isOpened():
        return VideoMeta(0, 0, 25.0, 0.0, 0)
    try:
        frame_w = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        frame_h = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        rate = float(capture.get(cv2.CAP_PROP_FPS) or 25.0)
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        if rate > 0 and total_frames > 0:
            seconds = total_frames / rate
        else:
            seconds = 0.0
        return VideoMeta(
            width=frame_w,
            height=frame_h,
            fps=rate,
            duration_sec=seconds,
            frame_count=total_frames,
        )
    finally:
        capture.release()
# One sheet usually samples the same video many times; cache the probe result so
# ffprobe is not re-run for every time point.
_VIDEO_META_CACHE: dict[tuple[str, str], VideoMeta] = {}
def get_video_meta(path: Path, ffprobe_bin: str = "ffprobe") -> VideoMeta:
    """Return (cached) metadata for *path*: ffprobe first, OpenCV as fallback.

    The cache key is the resolved path together with the ffprobe binary name.
    """
    cache_key = (str(Path(path).resolve()), ffprobe_bin)
    meta = _VIDEO_META_CACHE.get(cache_key)
    if meta is None:
        meta = _ffprobe_video_meta(path, ffprobe_bin)
        if meta is None:
            meta = _opencv_video_meta(path)
        _VIDEO_META_CACHE[cache_key] = meta
    return meta
def _clamp_time_sec(t_sec: float, meta: VideoMeta) -> float:
if meta.duration_sec > 0:
margin = 1.0 / max(meta.fps, 1.0)
return float(
min(max(0.0, t_sec), max(0.0, meta.duration_sec - margin))
)
return max(0.0, t_sec)
def _time_to_frame_index(t_sec: float, meta: VideoMeta) -> int:
    """Convert a timestamp to a 0-based frame index.

    The time is clamped first; a non-positive fps is replaced by 25.0; the
    result is capped at the last frame when the frame count is known.
    """
    rate = meta.fps if meta.fps > 0 else 25.0
    frame_no = int(round(_clamp_time_sec(t_sec, meta) * rate))
    if meta.frame_count > 0:
        last_frame = meta.frame_count - 1
        if frame_no > last_frame:
            frame_no = last_frame
    return frame_no if frame_no > 0 else 0
def _expand_root(p: str | Path) -> Path:
return Path(p).expanduser().resolve()
def _is_real_xlsx(path: Path) -> bool:
if path.suffix.lower() not in (".xlsx", ".xls"):
return False
name = path.name
if name.startswith("~$") or name.startswith(".~"):
return False
if _IGNORE_XLSX.search(name):
return False
return True
def _is_real_mp4(path: Path) -> bool:
if path.suffix.lower() != ".mp4":
return False
if ".crdownload" in path.name.lower():
return False
return True
def _dir_has_mp4_recursive(d: Path) -> bool:
if not d.is_dir():
return False
try:
for p in d.rglob("*.mp4"):
if _is_real_mp4(p):
return True
except OSError:
pass
return False
def iter_leaf_session_dirs(root: Path) -> Iterator[Path]:
    """Yield leaf session directories below *root*.

    A leaf directly contains at least one valid mp4 and one valid xlsx, and
    none of its sub-directories contains an mp4 at any depth. Nothing is
    yielded when *root* is not a directory.
    """
    import os
    base = root.resolve()
    if not base.is_dir():
        return
    for current, subdirs, files in os.walk(base, topdown=True):
        here = Path(current)
        has_video = any(_is_real_mp4(here / f) for f in files)
        has_sheet = any(_is_real_xlsx(here / f) for f in files)
        if not (has_video and has_sheet):
            continue
        if any(_dir_has_mp4_recursive(here / sub) for sub in subdirs):
            continue
        yield here
def _video_sort_key(path: Path) -> tuple:
stem = path.stem
m = re.search(r"(\d+)", stem)
n = int(m.group(1)) if m else 10**9
return (n, stem.lower())
def list_videos(session_dir: Path) -> list[Path]:
    """Valid mp4 files directly inside *session_dir*, ordered by _video_sort_key."""
    candidates = (
        entry
        for entry in session_dir.iterdir()
        if entry.is_file() and _is_real_mp4(entry)
    )
    return sorted(candidates, key=_video_sort_key)
def list_excels(session_dir: Path) -> list[Path]:
    """Valid Excel files directly inside *session_dir*, ordered by lower-cased name."""
    workbooks = [
        entry
        for entry in session_dir.iterdir()
        if entry.is_file() and _is_real_xlsx(entry)
    ]
    workbooks.sort(key=lambda entry: entry.name.lower())
    return workbooks
def _excel_pair_key(path: Path) -> tuple:
m = re.search(r"(\d+)", path.stem)
n = int(m.group(1)) if m else 10**9
return (n, path.name.lower())
def _normalize_header(s: Any) -> str:
if s is None or (isinstance(s, float) and pd.isna(s)):
return ""
return str(s).strip()
def _find_col(df: pd.DataFrame, *candidates: str) -> str | None:
    """Return the first column whose normalised header contains a candidate.

    Candidates are tried in the given order; for each one the columns are
    scanned left to right. Returns the original column label, or None.
    """
    for wanted in candidates:
        for column in df.columns:
            # Substring containment also covers an exact match.
            if wanted in _normalize_header(column):
                return column
    return None
def normalize_haocai_class_name(name: str) -> str:
    """
    Class-name normalisation kept in line with
    build_haocai_dataset_hand_crops.row_product.
    A few consumables carry different names in the Excel sheets and in the
    training classes; those aliases are merged into one entry here.
    """
    aliases = {
        "一次性使用灭菌棉签": "一次性医用灭菌棉签",
        "一次性使用手术衣": "一次性使用手术单",
        "一次性使用手术单(一次性医用垫单)": "一次性使用手术单",
        "一次性医用垫单": "一次性使用手术单",
    }
    stripped = (name or "").strip()
    return aliases.get(stripped, stripped)
def parse_time_range(text: Any) -> tuple[float, float] | None:
    """
    Parse a "start-end" cell into ``(start_sec, end_sec)`` with start <= end.

    Supported forms (separator may be ``-``, an en/em dash or ``~``):
    - ``1.23-2.23``   -> 1 min 23 s to 2 min 23 s
    - ``0.05-0.11``   -> 0 min 5 s to 0 min 11 s (digits after the dot are
      seconds; a single digit is left-padded, so ``.5`` means 5 s)
    - ``00:10-00:16`` -> mm:ss, also with full-width colons (U+FF1A)

    Returns None for None, NaN, empty text, "nan" and anything that matches
    neither form.
    """
    if text is None or (isinstance(text, float) and pd.isna(text)):
        return None
    s = str(text).strip()
    if not s or s.lower() == "nan":
        return None
    # Normalise the full-width colon. It is written as an escape because an
    # empty search string would make str.replace insert ":" between every
    # character and no pattern below could ever match.
    s = s.replace("\uff1a", ":")
    # mm:ss - mm:ss
    m = re.match(
        r"^\s*(\d{1,2}):(\d{2})\s*[-–—~]\s*(\d{1,2}):(\d{2})\s*$",
        s,
    )
    if m:
        min1, sec1, min2, sec2 = m.groups()
        a = int(min1) * 60 + int(sec1)
        b = int(min2) * 60 + int(sec2)
        return (float(min(a, b)), float(max(a, b)))
    # M.SS - M.SS (minutes.seconds; one or two second digits)
    m = re.match(
        r"^\s*(\d+)\s*\.\s*(\d{1,2})\s*[-–—~]\s*(\d+)\s*\.\s*(\d{1,2})\s*$",
        s,
    )
    if m:
        mm1, ss1, mm2, ss2 = m.groups()
        ss1 = ss1.zfill(2)[:2]
        ss2 = ss2.zfill(2)[:2]
        a = int(mm1) * 60 + int(ss1)
        b = int(mm2) * 60 + int(ss2)
        return (float(min(a, b)), float(max(a, b)))
    return None
def _midpoint_seconds(start: float, end: float) -> float:
return max(0.0, (start + end) / 2.0)
def _sample_time_in_tear_segment(
start: float,
end: float,
*,
mode: str = "tear_first_half",
) -> float:
"""
在 Excel 标注的「撕」时间段 [start, end] 内选取抽帧时刻。
- tear_first_half默认落在区间**前半段**,取该半段内 3/4 分位
t = start + 0.375 * (end - start),与「后半段 3/4」对称。
- tear_second_half整段的后 3/4 分位 t = start + 0.75 * (end - start)。
- midpoint取 (start+end)/2。
"""
if end <= start:
return max(0.0, start)
span = end - start
if mode == "midpoint":
return _midpoint_seconds(start, end)
if mode == "tear_second_half":
return max(0.0, start + 0.75 * span)
# tear_first_half
return max(0.0, start + 0.375 * span)
def resize_frame_to_max(
    frame: Any,
    max_width: int,
    max_height: int,
) -> Any:
    """
    Fit *frame* inside max_width x max_height while keeping the aspect ratio.

    A limit of 0 (or less) leaves that direction unconstrained; with both
    limits off the frame is returned unchanged at its original resolution.
    Frames are only shrunk, never enlarged. ``None`` is passed through.
    """
    if frame is None:
        return None
    if max_width <= 0 and max_height <= 0:
        return frame
    height, width = frame.shape[:2]
    # At least one limit is positive here, so the list is never empty.
    ratios = [
        limit / extent
        for limit, extent in ((max_width, width), (max_height, height))
        if limit > 0
    ]
    factor = min(ratios)
    if factor >= 1.0:
        return frame
    target = (
        max(1, int(round(width * factor))),
        max(1, int(round(height * factor))),
    )
    return cv2.resize(frame, target, interpolation=cv2.INTER_AREA)
def save_frame_jpeg(
    frame: Any,
    out_path: Path,
    jpeg_quality: int = 85,
    max_width: int = 0,
    max_height: int = 0,
) -> tuple[bool, Optional[np.ndarray]]:
    """Optionally shrink *frame* to the given limits and write it as JPEG.

    Returns ``(True, image_as_written)`` on success, ``(False, None)`` when
    there is no frame or the write fails.
    """
    resized = resize_frame_to_max(frame, max_width, max_height)
    if resized is None:
        return False, None
    encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
    if not cv2.imwrite(str(out_path), resized, encode_params):
        return False, None
    return True, resized
def save_bbox_vis_jpeg(
    img_bgr: np.ndarray,
    out_path: Path,
    bbox_xywhn: Optional[list[float]],
    detection_score: Optional[float],
    jpeg_quality: int = 85,
) -> bool:
    """Draw the box on a copy of the image and save the copy as JPEG.

    *bbox_xywhn* is a YOLO-style normalised [cx, cy, w, h]. With a box, a green
    rectangle plus a caption (the score with two decimals, or "det" when there
    is no score) is drawn; without one, the text "no detection" is drawn.
    The parent directory is created if needed. Returns the imwrite result.
    """
    canvas = img_bgr.copy()
    height, width = canvas.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    green = (0, 220, 0)
    if bbox_xywhn and len(bbox_xywhn) == 4:
        cx, cy, bw, bh = bbox_xywhn

        def to_pixel(fraction: float, extent: int) -> int:
            # Scale to pixels and keep the coordinate inside the image.
            return max(0, min(int(round(fraction * extent)), extent - 1))

        left = to_pixel(cx - bw / 2, width)
        top = to_pixel(cy - bh / 2, height)
        right = to_pixel(cx + bw / 2, width)
        bottom = to_pixel(cy + bh / 2, height)
        thickness = max(1, min(width, height) // 400)
        cv2.rectangle(canvas, (left, top), (right, bottom), green, thickness)
        caption = f"{detection_score:.2f}" if detection_score is not None else "det"
        (text_w, text_h), _ = cv2.getTextSize(caption, font, 0.5, 1)
        # Keep the caption on the image when the box touches the top edge.
        text_y = max(top - 4, text_h + 4)
        cv2.rectangle(
            canvas,
            (left, text_y - text_h - 4),
            (left + text_w + 4, text_y + 2),
            green,
            -1,
        )
        cv2.putText(canvas, caption, (left + 2, text_y), font, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
    else:
        cv2.putText(canvas, "no detection", (8, 24), font, 0.7, (128, 128, 128), 2, cv2.LINE_AA)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
    return bool(cv2.imwrite(str(out_path), canvas, encode_params))
def _write_vis_if_enabled(
vis_out_root: Optional[Path],
label_category: str,
size: str,
fname: str,
img_bgr: np.ndarray,
bbox_xywhn: Optional[list[float]],
detection_score: Optional[float],
) -> None:
if vis_out_root is None:
return
vis_dir = _product_image_dir(vis_out_root, label_category, size)
vis_dir.mkdir(parents=True, exist_ok=True)
vis_path = vis_dir / f"{Path(fname).stem}_vis.jpg"
save_bbox_vis_jpeg(img_bgr, vis_path, bbox_xywhn, detection_score)
def _clip_xyxy_xyxy(
xyxy: list[float], w: int, h: int
) -> list[float]:
x1, y1, x2, y2 = xyxy
x1 = float(max(0, min(x1, w - 1)))
x2 = float(max(0, min(x2, w)))
y1 = float(max(0, min(y1, h - 1)))
y2 = float(max(0, min(y2, h)))
if x2 <= x1:
x2 = min(x1 + 1.0, float(w))
if y2 <= y1:
y2 = min(y1 + 1.0, float(h))
return [x1, y1, x2, y2]
def _xyxy_to_xywhn(xyxy: list[float], w: int, h: int) -> list[float]:
"""xyxy 像素 → YOLO [x_center, y_center, width, height] 归一化 01。"""
x1, y1, x2, y2 = xyxy
bw = x2 - x1
bh = y2 - y1
cx = (x1 + x2) / 2.0
cy = (y1 + y2) / 2.0
return [cx / w, cy / h, bw / w, bh / h]
class GroundingDinoDetector:
    """
    Open-vocabulary detection with Grounding DINO (HuggingFace transformers).
    Returns only the highest-scoring box, as a YOLO-style normalised
    [cx, cy, w, h], together with its score.
    """

    def __init__(
        self,
        model_id: str = "IDEA-Research/grounding-dino-base",
        prompt: str = "person .",
        box_threshold: float = 0.30,
        text_threshold: float = 0.25,
    ) -> None:
        # Heavy optional dependencies are imported only when detection is used.
        import torch
        from PIL import Image as _PILImage  # noqa: F401
        from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

        self._torch = torch
        self._PILImage = _PILImage
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
        self._model = model.to(self._device)
        self._model.eval()
        self.prompt = prompt
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        _log(f"GroundingDinoDetector loaded: {model_id} on {self._device}")

    def detect(self, img_bgr: np.ndarray) -> tuple[
        Optional[list[float]],
        Optional[float],
    ]:
        """Detect on one BGR image.

        Returns ``(bbox_xywhn, score)`` for the best box, or ``(None, None)``
        when the image is smaller than 2x2 or nothing passes the thresholds.
        """
        height, width = img_bgr.shape[:2]
        if width < 2 or height < 2:
            return None, None
        pil_image = self._PILImage.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        with self._torch.no_grad():
            inputs = self._processor(
                images=pil_image, text=self.prompt, return_tensors="pt"
            ).to(self._device)
            outputs = self._model(**inputs)
        target_sizes = self._torch.tensor([[height, width]], device=self._device)
        post_process = self._processor.post_process_grounded_object_detection
        shared_kwargs = {
            "text_threshold": self.text_threshold,
            "target_sizes": target_sizes,
        }
        # The box-threshold keyword is tried as `threshold` first and as
        # `box_threshold` if that call raises TypeError.
        try:
            results = post_process(
                outputs, inputs.input_ids, threshold=self.box_threshold, **shared_kwargs
            )[0]
        except TypeError:
            results = post_process(
                outputs, inputs.input_ids, box_threshold=self.box_threshold, **shared_kwargs
            )[0]
        if results is None or len(results["boxes"]) == 0:
            return None, None
        scores = results["scores"]
        top = int(scores.argmax().item())
        box = results["boxes"][top].tolist()
        top_score = float(scores[top].item())
        clipped = _clip_xyxy_xyxy(
            [float(box[0]), float(box[1]), float(box[2]), float(box[3])], width, height
        )
        return _xyxy_to_xywhn(clipped, width, height), top_score
def _is_degenerate_gray_frame(img: np.ndarray) -> bool:
"""ffmpeg 快 seek 在部分 HEVC 码流上可能输出近似中性灰、几乎无纹理的无效帧。"""
if img is None or img.size == 0:
return True
m = float(np.mean(img))
s = float(np.std(img))
return 118.0 <= m <= 138.0 and s < 8.0
def extract_frame_ffmpeg(
    video_path: Path,
    t_sec: float,
    *,
    ffmpeg_bin: str = "ffmpeg",
    ffprobe_bin: str = "ffprobe",
    accurate_seek: bool = True,
    timeout_sec: float = 600.0,
) -> np.ndarray | None:
    """
    Decode a single frame at *t_sec* with ffmpeg and return it as a BGR array.

    The timestamp is clamped with get_video_meta (ffprobe preferred), which
    avoids OpenCV's fps / duration deviations on HEVC.
    accurate_seek=True (default): -ss is placed after -i.
    accurate_seek=False: -ss is placed before -i (fast seek); if the result
    looks like a flat gray frame the extraction is repeated with accurate seek.

    Two attempts are made with the same seek mode: PNG over a pipe first, then
    raw bgr24 at the probed size. Returns None when ffmpeg is unavailable, the
    probed size is smaller than 2x2, or both attempts fail.
    """
    if not shutil.which(ffmpeg_bin):
        return None
    meta = get_video_meta(video_path, ffprobe_bin)
    if meta.width < 2 or meta.height < 2:
        return None
    t_clamped = _clamp_time_sec(t_sec, meta)
    w, h = meta.width, meta.height
    expected_raw = w * h * 3

    def _build_cmd(output_args: list[str]) -> list[str]:
        """ffmpeg command for one frame; only the -ss / -i order depends on the seek mode."""
        seek = ["-ss", f"{t_clamped:.6f}"]
        source = ["-i", str(video_path)]
        positioned = source + seek if accurate_seek else seek + source
        return [
            ffmpeg_bin,
            "-hide_banner",
            "-loglevel",
            "error",
            *positioned,
            "-frames:v",
            "1",
            "-an",
            *output_args,
            "-",
        ]

    def _run_ffmpeg(cmd: list[str]) -> tuple[Optional[bytes], Optional[str]]:
        """Run ffmpeg; returns (stdout, None) on success, else (None, reason)."""
        try:
            p = subprocess.run(
                cmd,
                capture_output=True,
                timeout=timeout_sec,
                check=False,
            )
        except subprocess.TimeoutExpired:
            return None, "timeout"
        except OSError as exc:
            # Same tolerance as the ffprobe call: a binary that cannot be
            # executed is a failed attempt, not a crash.
            return None, f"{type(exc).__name__}: {exc}"
        err = (p.stderr or b"").decode("utf-8", errors="replace")[:800]
        if p.returncode != 0:
            return None, err or f"exit {p.returncode}"
        if not p.stdout:
            return None, err or "empty stdout"
        return p.stdout, None

    def _accept(img: np.ndarray, what: str) -> np.ndarray | None:
        """Return *img*, or redo the extraction accurately if a fast seek gave a gray frame."""
        if accurate_seek or not _is_degenerate_gray_frame(img):
            return img
        _log(
            f"快 seek {what}疑似灰帧,改用精确 seek: {video_path.name} t={t_clamped:.2f}s"
        )
        return extract_frame_ffmpeg(
            video_path,
            t_sec,
            ffmpeg_bin=ffmpeg_bin,
            ffprobe_bin=ffprobe_bin,
            accurate_seek=True,
            timeout_sec=timeout_sec,
        )

    # 1) PNG over a pipe (general case)
    out, err = _run_ffmpeg(_build_cmd(["-f", "image2pipe", "-vcodec", "png"]))
    if out is not None:
        img = cv2.imdecode(np.frombuffer(out, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img is not None and img.size > 0:
            return _accept(img, "输出")
    if err and err != "timeout":
        _log(f"ffmpeg PNG 解码失败: {video_path.name}: {err[:200]}")

    # 2) raw BGR (no PNG encode/decode); the frame size comes from the probe
    out2, err2 = _run_ffmpeg(
        _build_cmd(["-f", "rawvideo", "-pix_fmt", "bgr24", "-s", f"{w}x{h}"])
    )
    if out2 is not None and len(out2) == expected_raw:
        img = np.frombuffer(out2, dtype=np.uint8).reshape((h, w, 3)).copy()
        return _accept(img, "raw ")
    if err2 and err2 != "timeout":
        _log(f"ffmpeg rawvideo 失败: {video_path.name}: {err2[:200]}")
    return None
def extract_frame_opencv_sequential(
    video_path: Path,
    t_sec: float,
    ffprobe_bin: str = "ffprobe",
) -> Any | None:
    """
    Read frames in order from frame 0 up to the target frame and return it.
    The target index is computed from get_video_meta (ffprobe preferred), which
    is steadier than relying on OpenCV's fps alone.
    Returns None when the video cannot be opened or a read fails on the way.
    """
    meta = get_video_meta(video_path, ffprobe_bin)
    frames_to_read = _time_to_frame_index(t_sec, meta) + 1
    capture = cv2.VideoCapture(str(video_path), cv2.CAP_FFMPEG)
    if not capture.isOpened():
        return None
    try:
        latest: Any | None = None
        while frames_to_read > 0:
            grabbed, latest = capture.read()
            if not (grabbed and latest is not None):
                return None
            frames_to_read -= 1
        return latest
    finally:
        capture.release()
def make_extract_frame_fn(
    backend: str,
    ffmpeg_bin: str,
    ffprobe_bin: str,
    accurate_seek: bool,
) -> tuple[Callable[[Path, float], Any | None], str]:
    """
    Return ``(extract_function, name_of_the_backend_actually_used)``.

    "auto" selects ffmpeg when the binary is found, otherwise OpenCV sequential
    decoding. An explicit "ffmpeg" request without the binary logs a note and
    also falls back to OpenCV. The ffmpeg function itself falls back to OpenCV
    for any frame ffmpeg fails to deliver.
    """
    chosen = backend.strip().lower()
    if chosen == "auto":
        chosen = "ffmpeg" if shutil.which(ffmpeg_bin) else "opencv"
    elif chosen == "ffmpeg" and not shutil.which(ffmpeg_bin):
        _log(f"未找到 {ffmpeg_bin!r},改用 OpenCV 顺序解码(较慢)")
        chosen = "opencv"

    def via_opencv(p: Path, t: float) -> Any | None:
        return extract_frame_opencv_sequential(p, t, ffprobe_bin)

    if chosen != "ffmpeg":
        return via_opencv, "opencv_sequential"

    def via_ffmpeg(p: Path, t: float) -> Any | None:
        decoded = extract_frame_ffmpeg(
            p,
            t,
            ffmpeg_bin=ffmpeg_bin,
            ffprobe_bin=ffprobe_bin,
            accurate_seek=accurate_seek,
        )
        return decoded if decoded is not None else via_opencv(p, t)

    return via_ffmpeg, ("ffmpeg_accurate" if accurate_seek else "ffmpeg_fast")
def _unique_image_name(
session_rel: str,
row_idx: int,
video_tag: str,
time_raw: str,
ext: str = ".jpg",
) -> str:
h = hashlib.sha1(
f"{session_rel}|{row_idx}|{video_tag}|{time_raw}".encode("utf-8")
).hexdigest()[:16]
safe = re.sub(r"[^\w\u4e00-\u9fff\-]+", "_", session_rel)[-80:]
return f"{safe}__r{row_idx}_{video_tag}_{h}{ext}"
def _sanitize_dir_segment(text: Any, fallback: str) -> str:
"""目录名:去掉路径非法字符,过长截断;空则用 fallback。"""
if text is None:
return fallback
if isinstance(text, float) and pd.isna(text):
return fallback
t = str(text).strip()
if not t:
return fallback
t = re.sub(r'[/\\:\0<>"|?*]+', "_", t)
t = t.strip(" .")
if not t or all(c == "." for c in t):
return fallback
max_len = 180
if len(t) > max_len:
t = t[:max_len].rstrip()
return t or fallback
def _product_image_dir(
    images_out: Path, label_category: str, size: str
) -> Path:
    """Directory ``<images_out>/<product name>/<specification>``, with both
    segments sanitised and placeholder names used for empty values."""
    segments = (
        _sanitize_dir_segment(label_category, "未命名商品"),
        _sanitize_dir_segment(size, "未填规格"),
    )
    return images_out.joinpath(*segments)
def _read_excel(path: Path) -> pd.DataFrame:
    """Load a workbook with pandas, using its first row as the header."""
    table = pd.read_excel(path, header=0)
    return table
def _limit_reached(records: list[ImageRecord], limit: int) -> bool:
"""limit>0 且已保存条数达到上限时返回 True。"""
return limit > 0 and len(records) >= limit
def _record_saved(
    records: list[ImageRecord],
    global_idx: list[int],
    sample_every: int,
    fname: str,
    out_path: Path,
    label_category: str,
    size: str,
    bbox_xywhn: Optional[list[float]] = None,
    detection_score: Optional[float] = None,
) -> None:
    """Append the record of one saved image.

    ``global_idx[0]`` counts successfully saved images and is incremented
    here; every ``sample_every``-th image is marked as a sample (N=10 marks
    the 10th, 20th, ... image). The stored path is the resolved *out_path*.
    """
    global_idx[0] += 1
    saved_so_far = global_idx[0]
    is_sample = sample_every > 0 and saved_so_far % sample_every == 0
    entry = ImageRecord(
        name=fname,
        path=str(out_path.resolve()),
        label_category=label_category,
        size=size,
        sample=bool(is_sample),
        bbox_xywhn=bbox_xywhn,
        detection_score=detection_score,
    )
    records.append(entry)
def _bbox_from_detector(
detector: Optional[GroundingDinoDetector],
img_bgr: Optional[np.ndarray],
) -> tuple[Optional[list[float]], Optional[float]]:
if detector is None or img_bgr is None:
return None, None
return detector.detect(img_bgr)
def process_session(
    session_dir: Path,
    data_root: Path,
    images_out: Path,
    records: list[ImageRecord],
    global_idx: list[int],
    sample_every: int,
    limit: int = 0,
    max_width: int = 0,
    max_height: int = 0,
    bbox_detector: Optional[GroundingDinoDetector] = None,
    vis_out_root: Optional[Path] = None,
    extract_frame_fn: Callable[[Path, float], Any | None] = extract_frame_opencv_sequential,
    time_sample_mode: str = "tear_first_half",
) -> int:
    """Process one leaf directory and return the number of images written.

    Layouts handled:
    - two or more workbooks and two or more videos: the first two workbooks
      (by number in the name) are paired with the first two videos; rows with
      neither product name nor specification are skipped;
    - one workbook, two or more videos, both per-video time columns present:
      every row is sampled once per video;
    - one workbook otherwise: first video, time column 「视频内时间段」 /
      「视频1内时间段」 or else the last column;
    - several workbooks but a single video: first workbook with that video.

    With ``limit > 0`` writing stops as soon as ``records`` holds ``limit``
    entries. A workbook without any column is skipped.
    """
    videos = list_videos(session_dir)
    excels = list_excels(session_dir)
    if not videos or not excels:
        return 0
    session_rel = str(session_dir.relative_to(data_root))
    vid_list = sorted(videos, key=_video_sort_key)
    n_ok = 0

    def _cell_text(row: pd.Series, col: Any) -> str:
        """Stripped text of a cell; "" for a missing column, None or NaN."""
        if col is None:
            return ""
        v = row.get(col)
        if v is None or (isinstance(v, float) and pd.isna(v)):
            return ""
        return str(v).strip()

    def _time_column(df: pd.DataFrame, *names: str) -> Any:
        """Named time column, else the last column, else None (no columns)."""
        col = _find_col(df, *names)
        if col is None and len(df.columns) > 0:
            col = df.columns[-1]
        return col

    def _run_table(
        df: pd.DataFrame,
        jobs: list[tuple[Any, Path, str]],
        *,
        skip_unlabeled: bool = False,
    ) -> bool:
        """Sample every row once per (time column, video, tag) job.

        Returns False as soon as the record limit is reached, True otherwise.
        """
        nonlocal n_ok
        # Column lookup does not depend on the row: do it once per table.
        c_name = _find_col(df, "商品名称")
        c_spec = _find_col(df, "规格")
        for ri, (_, row) in enumerate(df.iterrows()):
            for time_col, vid, video_tag in jobs:
                if _limit_reached(records, limit):
                    return False
                tr = row.get(time_col)
                pr = parse_time_range(tr)
                if pr is None:
                    continue
                label = normalize_haocai_class_name(_cell_text(row, c_name))
                size = _cell_text(row, c_spec)
                if skip_unlabeled and not label and not size:
                    continue
                t_sample = _sample_time_in_tear_segment(
                    *pr, mode=time_sample_mode
                )
                frame = extract_frame_fn(vid, t_sample)
                if frame is None:
                    continue
                fname = _unique_image_name(session_rel, ri, video_tag, str(tr))
                out_dir = _product_image_dir(images_out, label, size)
                out_dir.mkdir(parents=True, exist_ok=True)
                out_path = out_dir / fname
                saved, img_out = save_frame_jpeg(
                    frame,
                    out_path,
                    max_width=max_width,
                    max_height=max_height,
                )
                if not saved:
                    continue
                bx, ds = _bbox_from_detector(bbox_detector, img_out)
                _record_saved(
                    records, global_idx, sample_every,
                    fname, out_path, label, size,
                    bbox_xywhn=bx, detection_score=ds,
                )
                _write_vis_if_enabled(
                    vis_out_root, label, size, fname, img_out, bx, ds,
                )
                n_ok += 1
                if _limit_reached(records, limit):
                    return False
        return True

    # Two workbooks + two videos: each workbook drives its own video.
    if len(excels) >= 2 and len(videos) >= 2:
        excel_list = sorted(excels, key=_excel_pair_key)
        for pi, (xls, vid) in enumerate(zip(excel_list[:2], vid_list[:2])):
            df = _read_excel(xls)
            time_col = _time_column(
                df,
                "视频内时间段",
                "视频01内时间段",
                "视频02内时间段",
            )
            if time_col is None:
                continue
            keep_going = _run_table(
                df, [(time_col, vid, f"v{pi + 1}")], skip_unlabeled=True
            )
            if not keep_going:
                return n_ok
        return n_ok

    df = _read_excel(excels[0])
    if len(excels) == 1:
        c_v1 = _find_col(df, "视频1内时间段", "视频01内时间段")
        c_v2 = _find_col(df, "视频2内时间段", "视频02内时间段")
        if len(videos) >= 2 and c_v1 is not None and c_v2 is not None:
            jobs = [(c_v1, vid_list[0], "v1"), (c_v2, vid_list[1], "v2")]
        else:
            # Single video: named time column or the last column.
            time_col = _time_column(df, "视频内时间段", "视频1内时间段")
            jobs = [(time_col, vid_list[0], "v1")]
    else:
        # Remaining case (several workbooks, one video): first workbook + first video.
        time_col = _time_column(df, "视频内时间段")
        jobs = [(time_col, vid_list[0], "v1")]
    if any(col is None for col, _, _ in jobs):
        return n_ok
    _run_table(df, jobs)
    return n_ok
def _build_cli_parser() -> argparse.ArgumentParser:
    """Declare every command-line option understood by main()."""
    parser = argparse.ArgumentParser(description="浩材视频抽帧数据集生成")
    add = parser.add_argument
    # --- input / output locations ---
    add(
        "--data-root",
        type=str,
        default="~/data/haocai",
        help="数据根目录(默认 ~/data/haocai",
    )
    add(
        "--output-dir",
        type=str,
        default="./haocai_dataset",
        help="输出根目录(图片与 JSON 放在其下)",
    )
    add(
        "--json-name",
        type=str,
        default="dataset.json",
        help="JSON 文件名(位于 output-dir 下)",
    )
    add(
        "--images-subdir",
        type=str,
        default="images",
        help="图片子目录名(位于 output-dir 下)",
    )
    # --- sampling / size limits ---
    add(
        "--sample-every",
        type=int,
        default=0,
        metavar="N",
        help="全局按保存顺序计数,每第 N 张在 JSON 中 sample=true0 表示全部 sample=false",
    )
    add(
        "--limit",
        type=int,
        default=0,
        metavar="N",
        help="最多生成 N 条记录(与 JSON 条目数一致用于试跑检查格式0 表示不限制",
    )
    add(
        "--max-width",
        type=int,
        default=0,
        metavar="PX",
        help="输出 JPEG 最大宽度像素0=不限制(默认,保持原始分辨率)",
    )
    add(
        "--max-height",
        type=int,
        default=0,
        metavar="PX",
        help="输出 JPEG 最大高度像素0=不限制(默认)。与 --max-width 同时生效时缩放到可放入矩形内",
    )
    # --- optional detection and visualisation ---
    add(
        "--detect-bbox",
        action="store_true",
        help="用 Grounding DINO 检测人体并写 bbox 到 JSON需 pip install transformers torch pillow",
    )
    add(
        "--dino-model-id",
        type=str,
        default="IDEA-Research/grounding-dino-base",
        metavar="ID",
        help="Grounding DINO HuggingFace 模型 ID",
    )
    add(
        "--dino-prompt",
        type=str,
        default="person .",
        metavar="TEXT",
        help="Grounding DINO 检测 prompt默认 'person .'",
    )
    add(
        "--dino-box-threshold",
        type=float,
        default=0.30,
        metavar="F",
        help="Grounding DINO box 置信度阈值(默认 0.30",
    )
    add(
        "--dino-text-threshold",
        type=float,
        default=0.25,
        metavar="F",
        help="Grounding DINO text 置信度阈值(默认 0.25",
    )
    add(
        "--save-vis",
        action="store_true",
        help="在 output-dir 下写入可视化图(默认子目录 vis/),与 images 同目录结构,文件名为 <原名>_vis.jpg",
    )
    add(
        "--vis-subdir",
        type=str,
        default="vis",
        help="可视化 JPEG 所在子目录名(位于 output-dir 下,默认 vis",
    )
    # --- frame extraction backend ---
    add(
        "--extract-backend",
        type=str,
        choices=("auto", "ffmpeg", "opencv"),
        default="auto",
        help="抽帧auto=有 ffmpeg 则用 ffmpeg推荐HEVC 不易花屏);"
        "ffmpeg=必须可用 ffmpegopencv=顺序解码,无 ffmpeg 时可用但较慢",
    )
    add(
        "--ffmpeg-bin",
        type=str,
        default="ffmpeg",
        metavar="CMD",
        help="ffmpeg 可执行文件名或绝对路径(默认 ffmpeg",
    )
    add(
        "--ffprobe-bin",
        type=str,
        default="ffprobe",
        metavar="CMD",
        help="ffprobe 可执行文件名(用于时长/帧率/分辨率;默认 ffprobe",
    )
    add(
        "--ffmpeg-fast-seek",
        action="store_true",
        help="快 seek-ss 在 -i 之前,长视频抽帧快很多;默认精确 seek 从开头解码到目标时刻,故很慢",
    )
    # --- where inside a time range the frame is taken ---
    add(
        "--sample-midpoint",
        action="store_true",
        help="时间段内抽帧取中点;默认取「撕」区间前半段(半段内 3/4 分位)",
    )
    add(
        "--tear-second-half",
        action="store_true",
        help="撕时间段内用整段后半 3/4 分位(旧默认);与默认前半段二选一",
    )
    return parser
def main() -> int:
    """Command-line entry point.

    Parses the options, optionally loads the detector, processes every leaf
    session directory in sorted order until the record limit is reached, and
    writes the JSON metadata. Returns 0 on completion and 2 for an invalid
    option value or a detector that fails to load.
    """
    args = _build_cli_parser().parse_args()

    # The first violated range check is reported and ends the run.
    range_checks = (
        (args.sample_every < 0, "--sample-every 须 >= 0"),
        (args.limit < 0, "--limit 须 >= 0"),
        (
            args.max_width < 0 or args.max_height < 0,
            "--max-width / --max-height 须 >= 0",
        ),
    )
    for violated, message in range_checks:
        if violated:
            print(message, file=sys.stderr)
            return 2

    bbox_detector: Optional[GroundingDinoDetector] = None
    if args.detect_bbox:
        try:
            _log("Grounding DINO bbox detection enabled")
            _log(
                f"model={args.dino_model_id}, prompt={args.dino_prompt!r}, "
                f"box_threshold={args.dino_box_threshold}, "
                f"text_threshold={args.dino_text_threshold}"
            )
            bbox_detector = GroundingDinoDetector(
                model_id=args.dino_model_id,
                prompt=args.dino_prompt,
                box_threshold=args.dino_box_threshold,
                text_threshold=args.dino_text_threshold,
            )
        except Exception as e:
            print(
                f"启用 --detect-bbox 失败: {type(e).__name__}: {e}\n"
                "请确认已安装: pip install transformers torch pillow",
                file=sys.stderr,
            )
            return 2

    data_root = _expand_root(args.data_root)
    out_root = _expand_root(args.output_dir)
    images_out = out_root / args.images_subdir
    images_out.mkdir(parents=True, exist_ok=True)
    vis_out_root: Optional[Path] = None
    if args.save_vis:
        vis_out_root = out_root / args.vis_subdir
        vis_out_root.mkdir(parents=True, exist_ok=True)

    records: list[ImageRecord] = []
    global_idx = [0]
    total = 0
    sessions = list(iter_leaf_session_dirs(data_root))
    if not sessions:
        # Only a notice: the run continues and still writes an (empty) JSON.
        print(f"未找到叶子会话目录(需同时含 mp4 与 xlsx: {data_root}", file=sys.stderr)
    if not shutil.which(args.ffprobe_bin):
        _log(
            f"未找到 {args.ffprobe_bin!r},时长/帧率将仅用 OpenCVHEVC 可能偏差);"
            "建议: conda install ffmpeg 或 apt install ffmpeg"
        )

    extract_frame_fn, extract_mode = make_extract_frame_fn(
        args.extract_backend,
        args.ffmpeg_bin,
        args.ffprobe_bin,
        accurate_seek=not args.ffmpeg_fast_seek,
    )
    _log(f"抽帧后端: {extract_mode}")

    # --sample-midpoint takes precedence over --tear-second-half.
    if args.sample_midpoint:
        time_sample_mode = "midpoint"
    elif args.tear_second_half:
        time_sample_mode = "tear_second_half"
    else:
        time_sample_mode = "tear_first_half"
    mode_notes = {
        "midpoint": "中点(--sample-midpoint",
        "tear_second_half": "撕区间后半段 3/4--tear-second-half",
        "tear_first_half": "撕区间前半段(默认,半段内 3/4 分位)",
    }
    _log("时间段采样: " + mode_notes[time_sample_mode])
    if extract_mode.startswith("ffmpeg") and not args.ffmpeg_fast_seek:
        _log(
            "精确 seek默认在长视频、大时间戳时很慢每次抽帧都会从文件开头解码到目标时刻。"
            "若可接受略快 seek请加 --ffmpeg-fast-seek 加速。"
        )

    for sd in sorted(sessions):
        if _limit_reached(records, args.limit):
            break
        n = process_session(
            sd,
            data_root,
            images_out,
            records,
            global_idx,
            args.sample_every,
            limit=args.limit,
            max_width=args.max_width,
            max_height=args.max_height,
            bbox_detector=bbox_detector,
            vis_out_root=vis_out_root,
            extract_frame_fn=extract_frame_fn,
            time_sample_mode=time_sample_mode,
        )
        total += n
        print(f"{sd.relative_to(data_root)}: {n}")

    json_path = out_root / args.json_name
    payload = [asdict(r) for r in records]
    json_path.write_text(
        json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    lim_note = f"limit={args.limit}" if args.limit > 0 else ""
    vis_note = "" if vis_out_root is None else f",可视化目录: {vis_out_root}"
    print(
        f"共写入 {total} 张图片{lim_note}JSON 条目 {len(records)},元数据: {json_path}{vis_note}"
    )
    return 0
# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())