[Fix] Fix verify dataset tool in 1.x (#1062)
parent
8c5d86a388
commit
080eb79f94
|
@ -7,7 +7,8 @@ from pathlib import Path
|
||||||
from mmengine import (Config, DictAction, track_parallel_progress,
|
from mmengine import (Config, DictAction, track_parallel_progress,
|
||||||
track_progress)
|
track_progress)
|
||||||
|
|
||||||
from mmcls.datasets import PIPELINES, build_dataset
|
from mmcls.datasets import build_dataset
|
||||||
|
from mmcls.registry import TRANSFORMS
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -46,15 +47,14 @@ def parse_args():
|
||||||
class DatasetValidator():
|
class DatasetValidator():
|
||||||
"""the dataset tool class to check if all file are broken."""
|
"""the dataset tool class to check if all file are broken."""
|
||||||
|
|
||||||
def __init__(self, dataset_cfg, log_file_path, phase):
|
def __init__(self, dataset_cfg, log_file_path):
|
||||||
super(DatasetValidator, self).__init__()
|
super(DatasetValidator, self).__init__()
|
||||||
# keep only LoadImageFromFile pipeline
|
# keep only LoadImageFromFile pipeline
|
||||||
assert dataset_cfg.data[phase].pipeline[0][
|
assert dataset_cfg.pipeline[0]['type'] == 'LoadImageFromFile', (
|
||||||
'type'] == 'LoadImageFromFile', 'This tool is only for dataset ' \
|
'This tool is only for datasets needs to load image from files.')
|
||||||
'that needs to load image from files.'
|
self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[0])
|
||||||
self.pipeline = PIPELINES.build(dataset_cfg.data[phase].pipeline[0])
|
dataset_cfg.pipeline = []
|
||||||
dataset_cfg.data[phase].pipeline = []
|
dataset = build_dataset(dataset_cfg)
|
||||||
dataset = build_dataset(dataset_cfg.data[phase])
|
|
||||||
|
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.log_file_path = log_file_path
|
self.log_file_path = log_file_path
|
||||||
|
@ -102,13 +102,22 @@ def main():
|
||||||
# touch output file to save broken files list.
|
# touch output file to save broken files list.
|
||||||
output_path = Path(args.out_path)
|
output_path = Path(args.out_path)
|
||||||
if not output_path.parent.exists():
|
if not output_path.parent.exists():
|
||||||
raise Exception('log_file parent directory not found.')
|
raise Exception("Path '--out-path' parent directory not found.")
|
||||||
if output_path.exists():
|
if output_path.exists():
|
||||||
os.remove(output_path)
|
os.remove(output_path)
|
||||||
output_path.touch()
|
output_path.touch()
|
||||||
|
|
||||||
# do valid
|
if args.phase == 'train':
|
||||||
validator = DatasetValidator(cfg, output_path, args.phase)
|
dataset_cfg = cfg.train_dataloader.dataset
|
||||||
|
elif args.phase == 'val':
|
||||||
|
dataset_cfg = cfg.val_dataloader.dataset
|
||||||
|
elif args.phase == 'test':
|
||||||
|
dataset_cfg = cfg.test_dataloader.dataset
|
||||||
|
else:
|
||||||
|
raise ValueError("'--phase' only support 'train', 'val' and 'test'.")
|
||||||
|
|
||||||
|
# do validate
|
||||||
|
validator = DatasetValidator(dataset_cfg, output_path)
|
||||||
|
|
||||||
if args.num_process > 1:
|
if args.num_process > 1:
|
||||||
# The default chunksize calcuation method of Pool.map
|
# The default chunksize calcuation method of Pool.map
|
||||||
|
|
Loading…
Reference in New Issue