83 lines
2.5 KiB
Python
83 lines
2.5 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Run a YOLO model on every image in a test folder.

Images are found recursively, annotated predictions are saved, and a
summary of detection counts is printed.
|
|
"""
|
|
|
|
import argparse
import sys
import tempfile
from pathlib import Path

from ultralytics import YOLO
|
|
|
|
|
|
def find_all_images(root_dir: Path):
    """Return every image file under *root_dir*, searched recursively.

    Files are matched on their lower-cased suffix, so ``photo.JPG`` and
    ``scan.Tiff`` are found. The tree is walked once, so each file appears
    exactly once even where glob matching is case-insensitive (e.g. Windows,
    where globbing ``*.jpg`` and ``*.JPG`` separately returns every file
    twice). Directories or dangling links whose names merely end in an image
    suffix are skipped.

    Args:
        root_dir: Directory to search.

    Returns:
        A sorted list of image paths; empty if none are found.
    """
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    return sorted(
        path
        for path in root_dir.rglob('*')
        # Check the cheap suffix test first; is_file() hits the filesystem.
        if path.suffix.lower() in image_exts and path.is_file()
    )
|
|
|
|
|
|
def main():
    """Command-line entry point: run YOLO on every image under ``--source``.

    Discovers images with ``find_all_images``, runs ``model.predict`` with
    ``save=True`` so annotated images are written under ``--output``, and
    prints a summary of how many boxes were detected.

    Exits with status 1 (via ``sys.exit``) when the source path is missing,
    is not a directory, or contains no images, so shell scripts and CI can
    detect the failure. An invalid ``--batch`` exits via ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="Test YOLO model on all test images")
    parser.add_argument("--weights", required=True, help="Model weights path")
    parser.add_argument("--source", required=True, help="Test folder path")
    parser.add_argument("--output", default="runs/predict/test_all", help="Output directory")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    parser.add_argument("--imgsz", type=int, default=640, help="Image size")
    parser.add_argument("--batch", type=int, default=16, help="Images per inference batch")
    args = parser.parse_args()

    if args.batch < 1:
        parser.error("--batch must be >= 1")

    source_dir = Path(args.source)
    if not source_dir.exists():
        print(f"[ERROR] Source directory does not exist: {source_dir}", file=sys.stderr)
        sys.exit(1)
    if not source_dir.is_dir():
        print(f"[ERROR] Source is not a directory: {source_dir}", file=sys.stderr)
        sys.exit(1)

    # Find all images
    print(f"[INFO] Searching for images in: {source_dir}")
    images = find_all_images(source_dir)
    print(f"[INFO] Found {len(images)} images")

    if not images:
        print("[ERROR] No images found!", file=sys.stderr)
        sys.exit(1)

    # Load model
    print(f"[INFO] Loading model: {args.weights}")
    model = YOLO(args.weights)

    # Run inference
    output = Path(args.output)
    print(f"[INFO] Running inference on {len(images)} images...")
    print(f"[INFO] Results will be saved to: {args.output}")

    total_detections = 0
    processed = 0
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Give ultralytics a .txt list (one path per line) instead of a Python
        # list. A list source is converted to in-memory images and inferred
        # as a single batch. Per the ultralytics predict docs, `batch` only
        # applies to directory/video/.txt sources, which are read lazily.
        list_file = Path(tmp_dir) / "images.txt"
        # Default (locale) encoding on purpose, to match how ultralytics
        # reads the file back. NOTE(review): confirm for non-ASCII paths.
        list_file.write_text("\n".join(str(p.absolute()) for p in images) + "\n")

        # stream=True yields Results one at a time instead of keeping every
        # Results object (each holding its image) alive in one list.
        # NOTE(review): saved files are named by basename only, so images
        # with the same name in different subfolders may overwrite each other.
        for result in model.predict(
            source=str(list_file),
            imgsz=args.imgsz,
            conf=args.conf,
            batch=args.batch,
            save=True,
            project=str(output.parent),
            name=output.name,
            stream=True,
        ):
            processed += 1
            if result.boxes is not None:
                total_detections += len(result.boxes)

    # ultralytics appends a numeric suffix (test_all2, ...) when the output
    # directory already exists, so report where results actually went.
    save_dir = getattr(getattr(model, "predictor", None), "save_dir", output)
    average = total_detections / processed if processed else 0.0

    # Print summary
    print("\n" + "=" * 60)
    print("Test Summary:")
    print("=" * 60)
    print(f"Total images: {processed}")
    print(f"Total detections: {total_detections}")
    print(f"Average detections per image: {average:.2f}")
    print(f"Results saved to: {save_dir}")
    print("=" * 60)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; importing this module runs nothing. main() returns
    # None, so the process exit status is 0 unless main() itself exits.
    raise SystemExit(main())
|
|
|