Initial commit: FishServer monorepo (FishAction, FishMeasure, fish_api)
Made-with: Cursor
This commit is contained in:
347
FishMeasure/weight_estimator/dataset.py
Executable file
347
FishMeasure/weight_estimator/dataset.py
Executable file
@@ -0,0 +1,347 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Build a dataset index mapping point clouds (.ply) to weight labels.
|
||||
|
||||
Data:
|
||||
- Point clouds live under:
|
||||
/home/ubuntu/data/fish/2025-11-19-output/{sample_id}/cloud/*.ply
|
||||
- Labels live at:
|
||||
/home/ubuntu/projects/FishMeasure/measure/data/label.csv
|
||||
|
||||
CSV format:
|
||||
- Column B (index 1): sample_id (svo2 name / folder name)
|
||||
- Column F (index 5): weight in grams (float)
|
||||
|
||||
Output JSON:
|
||||
{
|
||||
"meta": {...},
|
||||
"items": [
|
||||
{"ply": "<abs_or_rel_path>", "sample_id": "...", "weight_g": 123.45}
|
||||
],
|
||||
"mapping": {
|
||||
"<abs_or_rel_path>": 123.45
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
# Default dataset root for building the index when --data-root is not given.
DEFAULT_DATA_ROOT = "/home/ubuntu/data/fish/2025-11-19-output"
# Default label CSV path; main() uses it as the default for --label-csv.
DEFAULT_LABEL_CSV = "/home/ubuntu/projects/FishMeasure/measure/data/label.csv"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LabelRow:
    """Immutable record for one usable row of the label CSV."""

    # Sample identifier (read from CSV column B by load_labels); it doubles as
    # the per-sample folder name under the data root.
    sample_id: str
    # Weight in grams, or None when the weight cell was empty or unparseable.
    weight_g: Optional[float]
|
||||
|
||||
|
||||
def _parse_weight_g(cell: str) -> Optional[float]:
    """Parse a CSV weight cell into a weight in grams.

    Args:
        cell: Raw cell text. ``None`` or a blank string means "no weight".

    Returns:
        The weight as a finite float, or ``None`` when the cell is empty,
        is not a number, or parses to NaN / infinity.
    """
    import math  # local import: the module header does not import math

    text = (cell or "").strip()
    if not text:
        return None
    try:
        value = float(text)
    except ValueError:
        # Non-numeric text (e.g. a header cell or a note) is treated as missing.
        return None
    # float() accepts "nan" and "inf"; such values would corrupt averages and
    # serialize as non-standard JSON tokens, so treat them as missing too.
    return value if math.isfinite(value) else None
|
||||
|
||||
|
||||
def load_labels(label_csv: Path) -> List[LabelRow]:
    """Read the label CSV and return one LabelRow per usable row.

    Column B (index 1) is the sample id and column F (index 5) is the weight
    in grams. Rows with fewer than two columns, a blank sample id, or the
    placeholder id "xxx" (case-insensitive) are skipped. A missing or
    unparseable weight cell yields ``weight_g=None``.

    Duplicate sample ids are kept; resolving them is left to the caller.
    No row is treated specially as a header: a header row, if present, is
    returned like any other row (its weight will normally parse to None).

    Args:
        label_csv: Path to the UTF-8 encoded CSV file.

    Returns:
        The rows in file order.
    """
    rows: List[LabelRow] = []
    # newline="" is what the csv module expects: it lets the reader handle
    # line endings itself, so quoted fields with embedded newlines stay intact.
    with label_csv.open("r", encoding="utf-8", newline="") as f:
        for record in csv.reader(f):
            if len(record) < 2:
                continue
            sample_id = record[1].strip()
            if not sample_id or sample_id.lower() == "xxx":
                continue
            weight_cell = record[5] if len(record) > 5 else ""
            rows.append(
                LabelRow(sample_id=sample_id, weight_g=_parse_weight_g(weight_cell))
            )
    return rows
|
||||
|
||||
|
||||
def resolve_sample_weights(
    label_rows: List[LabelRow],
    duplicate_policy: str = "mean",
) -> Tuple[Dict[str, Optional[float]], Dict[str, List[Optional[float]]]]:
    """Collapse label rows into a single weight per sample id.

    Args:
        label_rows: Rows exposing ``sample_id`` and ``weight_g``; the same
            sample id may appear several times.
        duplicate_policy: How to combine several non-null weights for one id:
            - "mean": average of the non-null weights
            - "first": first non-null weight, in row order
            - "error": raise if the non-null weights are not all equal

    Returns:
        A ``(resolved, raw)`` tuple:
        - resolved: ``{sample_id: weight_g}``, with ``None`` when the sample
          has no non-null weight at all
        - raw: ``{sample_id: [weight_g_or_None, ...]}`` in row order

    Raises:
        ValueError: If ``duplicate_policy`` is not one of the names above, or
            if the policy is "error" and a sample has conflicting weights.
    """
    # Validate before touching the data so that a bad policy is reported even
    # when there are no rows or when every sample lacks a weight.
    if duplicate_policy not in ("mean", "first", "error"):
        raise ValueError(f"Unknown duplicate_policy: {duplicate_policy}")

    raw: Dict[str, List[Optional[float]]] = {}
    for row in label_rows:
        raw.setdefault(row.sample_id, []).append(row.weight_g)

    resolved: Dict[str, Optional[float]] = {}
    for sample_id, weights in raw.items():
        non_null = [w for w in weights if w is not None]
        if not non_null:
            resolved[sample_id] = None
        elif duplicate_policy == "first":
            resolved[sample_id] = non_null[0]
        elif duplicate_policy == "mean":
            resolved[sample_id] = sum(non_null) / len(non_null)
        else:  # "error"
            distinct = sorted(set(non_null))
            if len(distinct) > 1:
                raise ValueError(
                    f"Duplicate sample_id with multiple weights: {sample_id}: {distinct}"
                )
            resolved[sample_id] = distinct[0]

    return resolved, raw
|
||||
|
||||
|
||||
def collect_ply_items(
    data_root: Path,
    sample_weights: Dict[str, Optional[float]],
    relative_paths: bool,
) -> Tuple[List[dict], Dict[str, float], Dict[str, int]]:
    """Pair every ``{data_root}/{sample_id}/cloud/*.ply`` file with its weight.

    Samples are visited in sorted id order and files in sorted path order, so
    the output is deterministic for a given directory tree.

    Args:
        data_root: Directory containing one sub-folder per sample id.
        sample_weights: ``{sample_id: weight_g}``; samples whose weight is
            ``None`` are skipped and counted as ``missing_weights``.
        relative_paths: When true, store each path relative to ``data_root``
            (falling back to the path as globbed if it is not under the root).

    Returns:
        A ``(items, mapping, stats)`` tuple:
        - items: list of ``{"ply", "sample_id", "weight_g"}`` dicts
        - mapping: ``{ply_path_string: weight_g}``
        - stats: counts for ``missing_folders`` (samples whose cloud directory
          is absent or holds no .ply file), ``missing_weights`` and
          ``total_plys``
    """
    items: List[dict] = []
    mapping: Dict[str, float] = {}
    missing_folders = 0
    missing_weights = 0
    total_plys = 0

    for sample_id in sorted(sample_weights):
        weight_g = sample_weights[sample_id]
        if weight_g is None:
            missing_weights += 1
            continue

        cloud_dir = data_root / sample_id / "cloud"
        # A missing directory and a directory without any .ply file are both
        # reported under the same "missing_folders" counter.
        ply_files = sorted(cloud_dir.glob("*.ply")) if cloud_dir.is_dir() else []
        if not ply_files:
            missing_folders += 1
            continue

        weight = float(weight_g)  # normalize once per sample (e.g. int -> float)
        for ply in ply_files:
            ply_path = ply
            if relative_paths:
                try:
                    ply_path = ply.relative_to(data_root)
                except ValueError:
                    # Not located under data_root (e.g. an absolute sample_id):
                    # keep the path exactly as globbed.
                    ply_path = ply

            ply_key = str(ply_path)
            items.append({"ply": ply_key, "sample_id": sample_id, "weight_g": weight})
            mapping[ply_key] = weight
            total_plys += 1

    stats = {
        "missing_folders": missing_folders,
        "missing_weights": missing_weights,
        "total_plys": total_plys,
    }
    return items, mapping, stats
|
||||
|
||||
|
||||
def build_index(
    data_root: Path,
    label_csv: Path,
    output_json: Path,
    duplicate_policy: str = "mean",
    relative_paths: bool = False,
) -> dict:
    """Build the PLY-to-weight index, write it as JSON, and return it.

    The labels are loaded from ``label_csv``, collapsed to one weight per
    sample id using ``duplicate_policy``, and matched against the point clouds
    found under ``data_root``. The result is written to ``output_json``
    (parent directories are created as needed) and also returned.

    Returns:
        A dict with three keys: "meta" (inputs, row/sample counts and the
        collection stats), "items" (one entry per PLY file) and "mapping"
        (PLY path -> weight in grams).
    """
    rows = load_labels(label_csv)
    resolved, per_sample = resolve_sample_weights(rows, duplicate_policy=duplicate_policy)
    items, mapping, stats = collect_ply_items(
        data_root=data_root,
        sample_weights=resolved,
        relative_paths=relative_paths,
    )

    # Sample ids that appear on more than one label row.
    duplicated_ids = [sid for sid, ws in per_sample.items() if len(ws) > 1]

    meta = {
        "data_root": str(data_root),
        "label_csv": str(label_csv),
        "weight_column": "F",
        "duplicate_policy": duplicate_policy,
        "relative_paths": relative_paths,
        "num_label_rows": len(rows),
        "num_sample_ids": len(resolved),
        "num_sample_ids_with_duplicate_rows": len(duplicated_ids),
    }
    # Collection stats come last so the key order of the JSON stays stable.
    meta.update(stats)

    index = {"meta": meta, "items": items, "mapping": mapping}

    output_json.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(index, indent=2, ensure_ascii=False)
    output_json.write_text(payload, encoding="utf-8")
    return index
|
||||
|
||||
|
||||
def prune_missing_files(index_json: Path, data_root: Optional[Path] = None) -> dict:
    """Return a copy of the index without entries whose PLY file is gone.

    The index is only read; writing the pruned result back is up to the
    caller.

    Args:
        index_json: Path to an existing index JSON file.
        data_root: Base directory for relative "ply" paths. When omitted, the
            ``data_root`` recorded in the index's "meta" is used, then "/".

    Returns:
        A dict with "meta", "items" and "mapping". "items" keeps only entries
        whose "ply" points at an existing regular file, "mapping" is rebuilt
        from them, and "meta" gains ``total_plys`` (entries kept) and
        ``pruned_missing`` (entries removed).

    Raises:
        KeyError: If a kept item has no "weight_g" key.
    """
    with index_json.open("r", encoding="utf-8") as f:
        index = json.load(f)

    meta = index.get("meta", {})
    root = Path(data_root or meta.get("data_root") or "/").expanduser().resolve()

    kept: List[dict] = []
    removed = 0
    for item in index.get("items", []):
        ply_str = item.get("ply", "")
        is_present = False
        # An empty or missing path would otherwise resolve to the root
        # directory itself, which exists, and the bogus entry would be kept.
        if ply_str:
            # Expand "~" before the absolute check; after joining with the
            # root it could never take effect.
            ply_path = Path(ply_str).expanduser()
            if not ply_path.is_absolute():
                ply_path = root / ply_path
            # is_file() rather than exists(): a directory is not a point cloud.
            is_present = ply_path.resolve().is_file()

        if is_present:
            kept.append(item)
        else:
            removed += 1

    mapping = {it["ply"]: it["weight_g"] for it in kept}
    meta["total_plys"] = len(kept)
    meta["pruned_missing"] = removed

    return {
        "meta": meta,
        "items": kept,
        "mapping": mapping,
    }
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point: build the dataset index or prune an existing one.

    Default mode reads the label CSV, scans the data root for point clouds and
    writes the index JSON to ``--output``.

    With ``--prune-missing`` the existing index at ``--output`` is loaded,
    entries whose PLY file no longer exists are dropped, and the file is
    overwritten. In that mode relative paths are resolved against
    ``--data-root`` when it is given explicitly, otherwise against the
    ``data_root`` recorded in the index itself.

    Raises:
        SystemExit: If a required input (index JSON, data root, label CSV)
            does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Create a JSON mapping each .ply to weight (grams) using label.csv column F."
    )
    parser.add_argument(
        "--data-root",
        type=str,
        # No argparse-level default: prune mode must be able to tell "not
        # given" apart from the default so it can fall back to the data_root
        # stored in the index. Build mode applies DEFAULT_DATA_ROOT below.
        default=None,
        help=(
            "Dataset root containing {sample_id}/cloud/*.ply "
            f"(default: {DEFAULT_DATA_ROOT}; with --prune-missing, defaults to the "
            "data_root recorded in the index)"
        ),
    )
    parser.add_argument(
        "--label-csv",
        type=str,
        default=DEFAULT_LABEL_CSV,
        help=f"Path to label.csv (default: {DEFAULT_LABEL_CSV})",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="weight_estimator/dataset_index.json",
        help="Output JSON path (default: weight_estimator/dataset_index.json)",
    )
    parser.add_argument(
        "--duplicate-policy",
        type=str,
        default="mean",
        choices=["mean", "first", "error"],
        help="How to resolve duplicate sample_id rows in label.csv (default: mean).",
    )
    parser.add_argument(
        "--relative-paths",
        action="store_true",
        help="Store PLY paths relative to --data-root instead of absolute paths.",
    )
    parser.add_argument(
        "--prune-missing",
        action="store_true",
        help="Load existing index, remove entries for non-existent PLY files, and overwrite.",
    )

    args = parser.parse_args()

    output_json = Path(args.output).expanduser().resolve()

    if args.prune_missing:
        if not output_json.exists():
            raise SystemExit(f"Index JSON not found (required for --prune-missing): {output_json}")
        # None lets prune_missing_files use the data_root stored in the index.
        data_root = Path(args.data_root).expanduser().resolve() if args.data_root else None
        out = prune_missing_files(output_json, data_root=data_root)
        output_json.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
        meta = out["meta"]
        print("Pruned dataset index (removed non-existent PLY entries).")
        print(f"  output: {output_json}")
        print(f"  total_plys: {meta['total_plys']} (removed {meta.get('pruned_missing', 0)} missing)")
        return

    data_root = Path(args.data_root or DEFAULT_DATA_ROOT).expanduser().resolve()
    label_csv = Path(args.label_csv).expanduser().resolve()

    if not data_root.exists():
        raise SystemExit(f"data root does not exist: {data_root}")
    if not label_csv.exists():
        raise SystemExit(f"label csv does not exist: {label_csv}")

    out = build_index(
        data_root=data_root,
        label_csv=label_csv,
        output_json=output_json,
        duplicate_policy=args.duplicate_policy,
        relative_paths=args.relative_paths,
    )

    meta = out["meta"]
    print("Dataset index written.")
    print(f"  output: {output_json}")
    print(f"  total_plys: {meta['total_plys']}")
    print(f"  missing_weights: {meta['missing_weights']}")
    print(f"  missing_folders: {meta['missing_folders']}")
    print(f"  duplicate_policy: {meta['duplicate_policy']}")
|
||||
|
||||
|
||||
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
|
||||
Reference in New Issue
Block a user