diff --git a/mmcls/datasets/builder.py b/mmcls/datasets/builder.py
index ebef4c9b4..fb6028a18 100644
--- a/mmcls/datasets/builder.py
+++ b/mmcls/datasets/builder.py
@@ -1,8 +1,10 @@
 import platform
 import random
+from distutils.version import LooseVersion
 from functools import partial

 import numpy as np
+import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg
@@ -47,6 +49,8 @@ def build_dataloader(dataset,
                      shuffle=True,
                      round_up=True,
                      seed=None,
+                     pin_memory=True,
+                     persistent_workers=True,
                      **kwargs):
     """Build PyTorch DataLoader.

@@ -65,6 +69,13 @@ def build_dataloader(dataset,
             Default: True.
         round_up (bool): Whether to round up the length of dataset by adding
             extra samples to make it evenly divisible. Default: True.
+        pin_memory (bool): Whether to use pin_memory in DataLoader.
+            Default: True.
+        persistent_workers (bool): If True, the data loader will not shut down
+            the worker processes after a dataset has been consumed once.
+            This keeps the worker Dataset instances alive.
+            The argument only takes effect when PyTorch>=1.7.0.
+            Default: True.
         kwargs: any keyword argument to be used to initialize DataLoader

     Returns:
@@ -86,13 +97,16 @@ def build_dataloader(dataset,
         worker_init_fn, num_workers=num_workers, rank=rank,
         seed=seed) if seed is not None else None

+    if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
+        kwargs['persistent_workers'] = persistent_workers
+
     data_loader = DataLoader(
         dataset,
         batch_size=batch_size,
         sampler=sampler,
         num_workers=num_workers,
         collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-        pin_memory=False,
+        pin_memory=pin_memory,
         shuffle=shuffle,
         worker_init_fn=init_fn,
         **kwargs)
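
A minimal usage sketch of the updated `build_dataloader`. The dataset config is hypothetical, and the `samples_per_gpu`/`workers_per_gpu`/`dist` parameters are assumed from the existing signature since these hunks do not show them; the snippet only illustrates how the two new flags are passed through.

```python
from mmcls.datasets import build_dataloader, build_dataset

# Hypothetical dataset config; any registered dataset/pipeline works here.
dataset = build_dataset(
    dict(type='CIFAR10', data_prefix='data/cifar10', pipeline=[]))

data_loader = build_dataloader(
    dataset,
    samples_per_gpu=32,        # batch size per GPU
    workers_per_gpu=2,         # DataLoader workers per GPU
    dist=False,
    shuffle=True,
    pin_memory=True,           # now configurable; previously hard-coded to False
    persistent_workers=True)   # forwarded to DataLoader only when PyTorch >= 1.7.0
```

Gating `persistent_workers` behind the `LooseVersion` check keeps the builder compatible with older PyTorch releases, where `DataLoader` would raise a `TypeError` on the unknown keyword.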