460 lines
17 KiB
Python
460 lines
17 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
Ultralytics YOLO 测试脚本
|
|||
|
|
- validate 模式:在 dataset.yaml 上评测 mAP、精度/召回
|
|||
|
|
- predict 模式:对图片/文件夹/视频/摄像头进行推理并保存可视化
|
|||
|
|
- test 模式:在指定 source 数据上评估模型指标(mAP、精度、召回等)
|
|||
|
|
- annotation 模式:对图像进行检测并生成LabelMe格式的JSON标注文件(保存到与图像相同目录)
|
|||
|
|
|
|||
|
|
示例:
|
|||
|
|
# 在验证集上评测
|
|||
|
|
python3 detection/test_yolo.py \
|
|||
|
|
--weights runs/train/single_yolov8s_20250908_152257/weights/best.pt \
|
|||
|
|
--data ../datasets/yolo_dataset/dataset.yaml \
|
|||
|
|
--imgsz 640 \
|
|||
|
|
--mode validate
|
|||
|
|
|
|||
|
|
# 对文件夹做推理并保存结果
|
|||
|
|
python3 detection/test_yolo.py \
|
|||
|
|
--weights runs/train/yolov8s_strong_20250909_114704/weights/best.pt \
|
|||
|
|
--source ./datasets/l0_9.9_sum/images/test \
|
|||
|
|
--imgsz 640 \
|
|||
|
|
--conf 0.25 \
|
|||
|
|
--mode predict \
|
|||
|
|
--project runs/predict \
|
|||
|
|
--name demo
|
|||
|
|
|
|||
|
|
# 在指定数据上评估模型指标
|
|||
|
|
python3 detection/test_yolo.py \
|
|||
|
|
--weights runs/train/yolov8s_strong_20250909_114704/weights/best.pt \
|
|||
|
|
--source ./datasets/l0_9.9_sum/images/test \
|
|||
|
|
--imgsz 640 \
|
|||
|
|
--conf 0.25 \
|
|||
|
|
--mode test \
|
|||
|
|
--project runs/test \
|
|||
|
|
--name eval_test
|
|||
|
|
|
|||
|
|
# 生成LabelMe格式的JSON标注文件,并像 predict 模式一样输出预测可视化
|
|||
|
|
python3 detection/test_yolo.py \
|
|||
|
|
--weights runs/train/yolov8s_strong_20250909_114704/weights/best.pt \
|
|||
|
|
--source ./7009-7509\
|
|||
|
|
--imgsz 640 \
|
|||
|
|
--conf 0.25 \
|
|||
|
|
--mode annotation \
|
|||
|
|
--project runs/annotations\
|
|||
|
|
--name annotations14
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import argparse
|
|||
|
|
import numpy as np
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
|
|||
|
|
def parse_args():
|
|||
|
|
ap = argparse.ArgumentParser(description="YOLO 模型测试/推理")
|
|||
|
|
ap.add_argument("--weights", required=True, help="权重路径,例如 runs/train/.../weights/best.pt")
|
|||
|
|
ap.add_argument("--mode", choices=["validate", "predict", "test", "annotation"], default="validate", help="测试模式")
|
|||
|
|
ap.add_argument("--data", default="", help="dataset.yaml 路径(用于 validate)")
|
|||
|
|
ap.add_argument("--source", default="", help="推理输入(用于 predict/test):图片/文件夹/视频/rtsp/0(摄像头)")
|
|||
|
|
ap.add_argument("--imgsz", type=int, default=640, help="输入尺寸")
|
|||
|
|
ap.add_argument("--conf", type=float, default=0.25, help="置信度阈值(预测)")
|
|||
|
|
ap.add_argument("--iou", type=float, default=0.45, help="NMS IOU 阈值(预测)")
|
|||
|
|
ap.add_argument("--device", default="", help="CUDA 设备,如 '0' 或 '0,1',留空自动")
|
|||
|
|
ap.add_argument("--project", default="runs/test", help="输出根目录")
|
|||
|
|
ap.add_argument("--name", default="", help="任务名(默认自动)")
|
|||
|
|
ap.add_argument("--save_txt", action="store_true", help="保存 YOLO txt 结果(预测)")
|
|||
|
|
ap.add_argument("--save_conf", action="store_true", help="在 txt 中保存置信度")
|
|||
|
|
ap.add_argument("--recursive", action="store_true", help="当 --source 为目录时递归查找(默认仅当前目录)")
|
|||
|
|
return ap.parse_args()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def list_images_in_dir(dir_path: Path, recursive: bool = False):
|
|||
|
|
exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
|
|||
|
|
if recursive:
|
|||
|
|
return [p for p in dir_path.rglob("*") if p.suffix.lower() in exts]
|
|||
|
|
else:
|
|||
|
|
return [p for p in dir_path.iterdir() if p.is_file() and p.suffix.lower() in exts]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def evaluate_on_source(model, source_path: str, args):
|
|||
|
|
"""
|
|||
|
|
在指定的 source 数据上评估模型指标
|
|||
|
|
"""
|
|||
|
|
from ultralytics.utils.metrics import ConfusionMatrix
|
|||
|
|
import torch
|
|||
|
|
import numpy as np
|
|||
|
|
|
|||
|
|
src = Path(source_path)
|
|||
|
|
if not src.exists():
|
|||
|
|
print(f"[错误] 源路径不存在: {src}")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
# 收集所有图像文件
|
|||
|
|
if src.is_dir():
|
|||
|
|
image_files = list_images_in_dir(src, recursive=args.recursive)
|
|||
|
|
if not image_files:
|
|||
|
|
print(f"[错误] 目录下未找到图像: {src}")
|
|||
|
|
return None
|
|||
|
|
else:
|
|||
|
|
image_files = [src]
|
|||
|
|
|
|||
|
|
print(f"[测试] 找到 {len(image_files)} 个图像文件")
|
|||
|
|
|
|||
|
|
# 创建临时 dataset.yaml 用于评估
|
|||
|
|
import tempfile
|
|||
|
|
import yaml
|
|||
|
|
|
|||
|
|
# 假设标签文件在对应的 labels 目录中
|
|||
|
|
labels_dir = src.parent.parent / "labels" / src.name
|
|||
|
|
if not labels_dir.exists():
|
|||
|
|
print(f"[警告] 未找到标签目录: {labels_dir}")
|
|||
|
|
print("[测试] 将仅进行推理,无法计算精确指标")
|
|||
|
|
return run_inference_only(model, image_files, args)
|
|||
|
|
|
|||
|
|
# 创建临时 dataset.yaml
|
|||
|
|
temp_yaml = {
|
|||
|
|
'path': str(src.parent),
|
|||
|
|
'train': '', # 不使用
|
|||
|
|
'val': str(src.relative_to(src.parent)), # 使用 source 作为验证集
|
|||
|
|
'test': '',
|
|||
|
|
'nc': 1, # 假设单类检测
|
|||
|
|
'names': ['fish'] # 类别名称
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
|||
|
|
yaml.dump(temp_yaml, f)
|
|||
|
|
temp_yaml_path = f.name
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 使用 model.val 进行评估
|
|||
|
|
print("[测试] 开始评估...")
|
|||
|
|
metrics = model.val(
|
|||
|
|
data=temp_yaml_path,
|
|||
|
|
imgsz=args.imgsz,
|
|||
|
|
conf=args.conf,
|
|||
|
|
iou=args.iou,
|
|||
|
|
device=args.device if args.device else None,
|
|||
|
|
project=args.project,
|
|||
|
|
name=args.name or "test_eval",
|
|||
|
|
verbose=True,
|
|||
|
|
save_json=True, # 保存详细结果
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 打印详细指标
|
|||
|
|
print("\n" + "="*50)
|
|||
|
|
print("评估结果:")
|
|||
|
|
print("="*50)
|
|||
|
|
|
|||
|
|
# print number of images
|
|||
|
|
print(f"总图像数: {len(image_files)}")
|
|||
|
|
|
|||
|
|
if hasattr(metrics, 'box'):
|
|||
|
|
box_metrics = metrics.box
|
|||
|
|
print(f"mAP@0.5: {box_metrics.map50:.4f}")
|
|||
|
|
print(f"mAP@0.5:0.95: {box_metrics.map:.4f}")
|
|||
|
|
print(f"Precision: {box_metrics.mp:.4f}")
|
|||
|
|
print(f"Recall: {box_metrics.mr:.4f}")
|
|||
|
|
print(f"F1-Score: {2 * box_metrics.mp * box_metrics.mr / (box_metrics.mp + box_metrics.mr):.4f}")
|
|||
|
|
|
|||
|
|
if hasattr(metrics, 'speed'):
|
|||
|
|
speed = metrics.speed
|
|||
|
|
print(f"推理速度: {speed['inference']:.2f} ms/image")
|
|||
|
|
print(f"NMS速度: {speed['postprocess']:.2f} ms/image")
|
|||
|
|
|
|||
|
|
print("="*50)
|
|||
|
|
|
|||
|
|
return metrics
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"[错误] 评估失败: {e}")
|
|||
|
|
print("[测试] 回退到仅推理模式...")
|
|||
|
|
return run_inference_only(model, image_files, args)
|
|||
|
|
finally:
|
|||
|
|
# 清理临时文件
|
|||
|
|
try:
|
|||
|
|
os.unlink(temp_yaml_path)
|
|||
|
|
except:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_inference_only(model, image_files, args):
|
|||
|
|
"""
|
|||
|
|
仅进行推理,计算基本统计信息
|
|||
|
|
"""
|
|||
|
|
print("[测试] 仅推理模式 - 无法计算精确指标")
|
|||
|
|
|
|||
|
|
total_detections = 0
|
|||
|
|
total_images = len(image_files)
|
|||
|
|
confidences = []
|
|||
|
|
|
|||
|
|
for img_path in image_files:
|
|||
|
|
try:
|
|||
|
|
results = model.predict(
|
|||
|
|
source=str(img_path),
|
|||
|
|
imgsz=args.imgsz,
|
|||
|
|
conf=args.conf,
|
|||
|
|
iou=args.iou,
|
|||
|
|
device=args.device if args.device else None,
|
|||
|
|
verbose=False,
|
|||
|
|
save=True, # 同步保存Ultralytics自带的可视化结果
|
|||
|
|
project=args.project,
|
|||
|
|
name=args.name or "annotations",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if results and len(results) > 0:
|
|||
|
|
res = results[0]
|
|||
|
|
if hasattr(res, 'boxes') and res.boxes is not None:
|
|||
|
|
num_dets = len(res.boxes)
|
|||
|
|
total_detections += num_dets
|
|||
|
|
|
|||
|
|
# 收集置信度
|
|||
|
|
for box in res.boxes:
|
|||
|
|
conf = float(box.conf[0].tolist())
|
|||
|
|
confidences.append(conf)
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"[警告] 处理图像失败 {img_path}: {e}")
|
|||
|
|
|
|||
|
|
# 打印统计信息
|
|||
|
|
print("\n" + "="*50)
|
|||
|
|
print("推理统计:")
|
|||
|
|
print("="*50)
|
|||
|
|
print(f"总图像数: {total_images}")
|
|||
|
|
print(f"总检测数: {total_detections}")
|
|||
|
|
print(f"平均检测数: {total_detections/total_images:.2f}")
|
|||
|
|
|
|||
|
|
if confidences:
|
|||
|
|
confidences = np.array(confidences)
|
|||
|
|
print(f"平均置信度: {confidences.mean():.4f}")
|
|||
|
|
print(f"最高置信度: {confidences.max():.4f}")
|
|||
|
|
print(f"最低置信度: {confidences.min():.4f}")
|
|||
|
|
|
|||
|
|
print("="*50)
|
|||
|
|
print("[注意] 这是仅推理模式,无法计算 mAP 等精确指标")
|
|||
|
|
print(" 如需精确评估,请确保有对应的标签文件")
|
|||
|
|
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_labelme_annotations(model, source_path: str, args):
|
|||
|
|
"""
|
|||
|
|
对图像进行检测并生成LabelMe格式的JSON标注文件(保存到 runs/<project>/<name>/),
|
|||
|
|
可视化结果由 Ultralytics 的 predict 保存(与 predict 模式一致)。
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
model: YOLO模型
|
|||
|
|
source_path: 图像目录路径
|
|||
|
|
args: 命令行参数
|
|||
|
|
"""
|
|||
|
|
import json
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
|
|||
|
|
src = Path(source_path)
|
|||
|
|
if not src.exists():
|
|||
|
|
print(f"[错误] 源路径不存在: {src}")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 收集所有图像文件
|
|||
|
|
if src.is_dir():
|
|||
|
|
image_files = list_images_in_dir(src, recursive=args.recursive)
|
|||
|
|
if not image_files:
|
|||
|
|
print(f"[错误] 目录下未找到图像: {src}")
|
|||
|
|
return
|
|||
|
|
else:
|
|||
|
|
image_files = [src]
|
|||
|
|
|
|||
|
|
print(f"[标注] 找到 {len(image_files)} 个图像文件")
|
|||
|
|
|
|||
|
|
# 创建输出目录 runs/<project>/<name>/
|
|||
|
|
project_dir = Path(args.project) #/ (args.name or "annotations")
|
|||
|
|
project_dir.mkdir(parents=True, exist_ok=True)
|
|||
|
|
output_dir = Path(args.project) / (args.name or "annotations")
|
|||
|
|
|
|||
|
|
# 流式推理并即时写出标注,避免一次性占用大量显存/内存
|
|||
|
|
print(f"[标注] 流式预测并保存可视化到: {output_dir}")
|
|||
|
|
processed_count = 0
|
|||
|
|
failed_count = 0
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
for res in model.predict(
|
|||
|
|
source=[str(p) for p in image_files],
|
|||
|
|
imgsz=args.imgsz,
|
|||
|
|
conf=args.conf,
|
|||
|
|
iou=args.iou,
|
|||
|
|
device=args.device if args.device else None,
|
|||
|
|
verbose=False,
|
|||
|
|
save=True,
|
|||
|
|
project=args.project,
|
|||
|
|
name=args.name or "annotations",
|
|||
|
|
stream=True,
|
|||
|
|
):
|
|||
|
|
try:
|
|||
|
|
# 结果对象自带路径与原图尺寸
|
|||
|
|
img_path = Path(getattr(res, 'path', ''))
|
|||
|
|
if not img_path:
|
|||
|
|
# 回退:无法获得路径时跳过
|
|||
|
|
failed_count += 1
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
if hasattr(res, 'orig_shape') and res.orig_shape is not None:
|
|||
|
|
img_height, img_width = int(res.orig_shape[0]), int(res.orig_shape[1])
|
|||
|
|
else:
|
|||
|
|
original_img = cv2.imread(str(img_path))
|
|||
|
|
if original_img is None:
|
|||
|
|
print(f"[错误] 无法读取图像: {img_path}")
|
|||
|
|
failed_count += 1
|
|||
|
|
continue
|
|||
|
|
img_height, img_width = original_img.shape[:2]
|
|||
|
|
|
|||
|
|
labelme_data = {
|
|||
|
|
"version": "5.8.3",
|
|||
|
|
"flags": {},
|
|||
|
|
"shapes": [],
|
|||
|
|
"imagePath": img_path.name,
|
|||
|
|
"imageData": None,
|
|||
|
|
"imageHeight": img_height,
|
|||
|
|
"imageWidth": img_width
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if hasattr(res, 'boxes') and res.boxes is not None:
|
|||
|
|
for box in res.boxes:
|
|||
|
|
x1, y1, x2, y2 = [float(v) for v in box.xyxy[0].tolist()]
|
|||
|
|
conf = float(box.conf[0].tolist())
|
|||
|
|
cls = int(box.cls[0].tolist())
|
|||
|
|
shape = {
|
|||
|
|
"label": "鱿鱼",
|
|||
|
|
"points": [[x1, y1], [x2, y2]],
|
|||
|
|
"group_id": None,
|
|||
|
|
"description": f"confidence: {conf:.3f}",
|
|||
|
|
"shape_type": "rectangle",
|
|||
|
|
"flags": {},
|
|||
|
|
"mask": None
|
|||
|
|
}
|
|||
|
|
labelme_data["shapes"].append(shape)
|
|||
|
|
|
|||
|
|
json_filename = img_path.stem + ".json"
|
|||
|
|
json_path = output_dir / json_filename
|
|||
|
|
with open(json_path, 'w', encoding='utf-8') as f:
|
|||
|
|
json.dump(labelme_data, f, ensure_ascii=False, indent=2)
|
|||
|
|
|
|||
|
|
print(f"[成功] {img_path.name} -> {json_filename} (检测到 {len(labelme_data['shapes'])} 个目标)")
|
|||
|
|
processed_count += 1
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"[错误] 处理图像失败 {getattr(res, 'path', 'unknown')}: {e}")
|
|||
|
|
failed_count += 1
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"[错误] 批处理预测失败: {e}")
|
|||
|
|
|
|||
|
|
# 输出统计信息
|
|||
|
|
print("\n" + "="*60)
|
|||
|
|
print("标注生成完成:")
|
|||
|
|
print("="*60)
|
|||
|
|
print(f"成功处理: {processed_count} 个图像")
|
|||
|
|
print(f"处理失败: {failed_count} 个图像")
|
|||
|
|
print(f"输出目录: {output_dir}")
|
|||
|
|
print("生成文件:")
|
|||
|
|
print(" - JSON标注文件: *.json")
|
|||
|
|
print(" - 预测可视化: 由 Ultralytics 自动保存 (与 predict 模式一致)")
|
|||
|
|
print("="*60)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
args = parse_args()
|
|||
|
|
try:
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
except Exception as e:
|
|||
|
|
print("[错误] 未找到 ultralytics,请先: pip install ultralytics")
|
|||
|
|
print(e)
|
|||
|
|
sys.exit(1)
|
|||
|
|
|
|||
|
|
# 检查matplotlib兼容性
|
|||
|
|
try:
|
|||
|
|
import matplotlib
|
|||
|
|
import numpy as np
|
|||
|
|
# 测试matplotlib和numpy的兼容性
|
|||
|
|
matplotlib.use('Agg') # 使用非交互式后端
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
print(f"[信息] Matplotlib版本: {matplotlib.__version__}")
|
|||
|
|
print(f"[信息] NumPy版本: {np.__version__}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"[警告] Matplotlib/NumPy兼容性问题: {e}")
|
|||
|
|
print("[建议] 尝试升级matplotlib: pip install matplotlib --upgrade")
|
|||
|
|
# 继续执行,但可能会在后续遇到问题
|
|||
|
|
|
|||
|
|
model = YOLO(args.weights)
|
|||
|
|
|
|||
|
|
if args.mode == "validate":
|
|||
|
|
if not args.data:
|
|||
|
|
print("[错误] validate 模式需要 --data 指向 dataset.yaml")
|
|||
|
|
sys.exit(1)
|
|||
|
|
metrics = model.val(
|
|||
|
|
data=args.data,
|
|||
|
|
imgsz=args.imgsz,
|
|||
|
|
device=args.device if args.device else None,
|
|||
|
|
project=args.project,
|
|||
|
|
name=args.name or "val",
|
|||
|
|
iou=args.iou,
|
|||
|
|
verbose=True,
|
|||
|
|
)
|
|||
|
|
# metrics 包含各项指标,可按需打印
|
|||
|
|
print("评测完成。mAP50:", getattr(metrics, 'box', None).map50 if hasattr(metrics, 'box') else None)
|
|||
|
|
elif args.mode == "test":
|
|||
|
|
if not args.source:
|
|||
|
|
print("[错误] test 模式需要 --source 输入")
|
|||
|
|
sys.exit(1)
|
|||
|
|
metrics = evaluate_on_source(model, args.source, args)
|
|||
|
|
if metrics:
|
|||
|
|
print("测试完成。")
|
|||
|
|
else:
|
|||
|
|
print("测试完成(仅推理模式)。")
|
|||
|
|
elif args.mode == "annotation":
|
|||
|
|
if not args.source:
|
|||
|
|
print("[错误] annotation 模式需要 --source 输入")
|
|||
|
|
sys.exit(1)
|
|||
|
|
generate_labelme_annotations(model, args.source, args)
|
|||
|
|
print("标注生成完成。")
|
|||
|
|
else: # predict mode
|
|||
|
|
if not args.source:
|
|||
|
|
print("[错误] predict 模式需要 --source 输入")
|
|||
|
|
sys.exit(1)
|
|||
|
|
|
|||
|
|
src = Path(args.source)
|
|||
|
|
sources = None
|
|||
|
|
if src.is_dir():
|
|||
|
|
# 严格限制在该目录(默认非递归)
|
|||
|
|
sources = list_images_in_dir(src, recursive=args.recursive)
|
|||
|
|
if not sources:
|
|||
|
|
print(f"[错误] 目录下未找到图像: {src}")
|
|||
|
|
sys.exit(1)
|
|||
|
|
else:
|
|||
|
|
# 单个文件或其他流式输入,直接传递
|
|||
|
|
sources = [str(src)]
|
|||
|
|
|
|||
|
|
last_save_dir = None
|
|||
|
|
# 使用流式推理避免一次性将所有结果保存在内存/显存中
|
|||
|
|
for res in model.predict(
|
|||
|
|
source=str(src) if src.is_dir() else [str(src)],
|
|||
|
|
imgsz=args.imgsz,
|
|||
|
|
conf=args.conf,
|
|||
|
|
iou=args.iou,
|
|||
|
|
device=args.device if args.device else None,
|
|||
|
|
project=args.project,
|
|||
|
|
name=args.name or "pred",
|
|||
|
|
save=True,
|
|||
|
|
save_txt=args.save_txt,
|
|||
|
|
save_conf=args.save_conf,
|
|||
|
|
verbose=True,
|
|||
|
|
stream=True,
|
|||
|
|
):
|
|||
|
|
last_save_dir = getattr(res, 'save_dir', None)
|
|||
|
|
print("推理完成。结果目录:", last_save_dir or args.project)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
main()
|