diff --git a/tools/misc/verify_dataset.py b/tools/misc/verify_dataset.py index 4b5eb7e0e..05e7bb931 100644 --- a/tools/misc/verify_dataset.py +++ b/tools/misc/verify_dataset.py @@ -7,7 +7,8 @@ from pathlib import Path from mmengine import (Config, DictAction, track_parallel_progress, track_progress) -from mmcls.datasets import PIPELINES, build_dataset +from mmcls.datasets import build_dataset +from mmcls.registry import TRANSFORMS def parse_args(): @@ -46,15 +47,14 @@ def parse_args(): class DatasetValidator(): """the dataset tool class to check if all file are broken.""" - def __init__(self, dataset_cfg, log_file_path, phase): + def __init__(self, dataset_cfg, log_file_path): super(DatasetValidator, self).__init__() # keep only LoadImageFromFile pipeline - assert dataset_cfg.data[phase].pipeline[0][ - 'type'] == 'LoadImageFromFile', 'This tool is only for dataset ' \ - 'that needs to load image from files.' - self.pipeline = PIPELINES.build(dataset_cfg.data[phase].pipeline[0]) - dataset_cfg.data[phase].pipeline = [] - dataset = build_dataset(dataset_cfg.data[phase]) + assert dataset_cfg.pipeline[0]['type'] == 'LoadImageFromFile', ( + 'This tool is only for datasets needs to load image from files.') + self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[0]) + dataset_cfg.pipeline = [] + dataset = build_dataset(dataset_cfg) self.dataset = dataset self.log_file_path = log_file_path @@ -102,13 +102,22 @@ def main(): # touch output file to save broken files list. output_path = Path(args.out_path) if not output_path.parent.exists(): - raise Exception('log_file parent directory not found.') + raise Exception("Path '--out-path' parent directory not found.") if output_path.exists(): os.remove(output_path) output_path.touch() - # do valid - validator = DatasetValidator(cfg, output_path, args.phase) + if args.phase == 'train': + dataset_cfg = cfg.train_dataloader.dataset + elif args.phase == 'val': + dataset_cfg = cfg.val_dataloader.dataset + elif args.phase == 'test': + dataset_cfg = cfg.test_dataloader.dataset + else: + raise ValueError("'--phase' only support 'train', 'val' and 'test'.") + + # do validate + validator = DatasetValidator(dataset_cfg, output_path) if args.num_process > 1: # The default chunksize calcuation method of Pool.map