Files
FishServer/FishMeasure/detection/test_all_images.py
2026-04-08 19:32:23 +08:00

83 lines
2.5 KiB
Python
Executable File

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Test YOLO model on all images in test folder
Finds all images recursively and runs inference
"""
import argparse
from pathlib import Path
from ultralytics import YOLO
def find_all_images(root_dir: Path):
    """Return every image file under *root_dir*, searched recursively.

    A file counts as an image when its suffix, compared case-insensitively,
    is one of ``.jpg``, ``.jpeg``, ``.png``, ``.bmp``, ``.tif`` or ``.tiff``
    (so ``a.jpg``, ``b.JPG`` and ``c.Jpg`` all match).

    Args:
        root_dir: Directory to search (a ``Path`` or a path string).
            Subdirectories are included.

    Returns:
        A sorted list of ``Path`` objects, each file appearing exactly once.
        Directories whose names happen to end in an image suffix are skipped.
    """
    root_dir = Path(root_dir)
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    # Walk the tree once and compare suffixes case-insensitively. Globbing
    # '*.jpg' and '*.JPG' separately returns each file twice wherever glob
    # matching is case-insensitive (e.g. Windows), and still misses
    # mixed-case suffixes such as '.Jpg' where it is case-sensitive.
    return sorted(
        path
        for path in root_dir.rglob('*')
        if path.suffix.lower() in image_exts and path.is_file()
    )
def main():
    """Run a YOLO model over every image found under ``--source``.

    Command-line options:
        --weights  model weights path passed to ``YOLO`` (required)
        --source   folder searched recursively with ``find_all_images`` (required)
        --output   output directory, split into ``project``/``name`` for
                   ``model.predict`` (default ``runs/predict/test_all``)
        --conf     confidence threshold (default 0.25)
        --imgsz    inference image size (default 640)

    Annotated images are saved by ``model.predict(save=True)``, then a summary
    of the total and per-image detection counts is printed.

    Raises:
        SystemExit: status 1 when the source does not exist, is not a
            directory, or contains no images (message on stderr); argparse's
            own status 2 for invalid command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Test YOLO model on all test images")
    parser.add_argument("--weights", required=True, help="Model weights path")
    parser.add_argument("--source", required=True, help="Test folder path")
    parser.add_argument("--output", default="runs/predict/test_all", help="Output directory")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    parser.add_argument("--imgsz", type=int, default=640, help="Image size")
    args = parser.parse_args()

    source_dir = Path(args.source)
    # Fail with a non-zero exit status so shell scripts / CI can detect the
    # error; parser.exit prints the message to stderr, and it is the same
    # SystemExit mechanism parse_args already uses for bad arguments.
    if not source_dir.exists():
        parser.exit(1, f"[ERROR] Source directory does not exist: {source_dir}\n")
    if not source_dir.is_dir():
        parser.exit(1, f"[ERROR] Source is not a directory: {source_dir}\n")

    # Find all images
    print(f"[INFO] Searching for images in: {source_dir}")
    images = find_all_images(source_dir)
    print(f"[INFO] Found {len(images)} images")
    if not images:
        parser.exit(1, "[ERROR] No images found!\n")

    # Load model
    print(f"[INFO] Loading model: {args.weights}")
    model = YOLO(args.weights)

    # Run inference
    output_dir = Path(args.output)
    print(f"[INFO] Running inference on {len(images)} images...")
    print(f"[INFO] Results will be saved to: {args.output}")
    results = model.predict(
        source=[str(img) for img in images],
        imgsz=args.imgsz,
        conf=args.conf,
        save=True,
        project=str(output_dir.parent),
        name=output_dir.name,
        # Yield one Results object at a time instead of accumulating all of
        # them in a list, which can exhaust memory on large test folders.
        stream=True,
    )

    # Consuming the generator is what drives inference and saving.
    total_detections = 0
    for result in results:
        # boxes may be None (presumably for non-detection models); skip those.
        if result.boxes is not None:
            total_detections += len(result.boxes)

    # The directory actually written may differ from --output if Ultralytics
    # auto-incremented the name (e.g. test_all2) because it already existed;
    # fall back to the requested path if the predictor does not expose it.
    predictor = getattr(model, "predictor", None)
    save_dir = getattr(predictor, "save_dir", None) or args.output

    # Print summary (images is non-empty here, so the average is safe)
    print("\n" + "=" * 60)
    print("Test Summary:")
    print("=" * 60)
    print(f"Total images: {len(images)}")
    print(f"Total detections: {total_detections}")
    print(f"Average detections per image: {total_detections / len(images):.2f}")
    print(f"Results saved to: {save_dir}")
    print("=" * 60)
if __name__ == "__main__":
    # Entry point: run the CLI only when this file is executed directly,
    # not when it is imported as a module.
    main()