[Feature] Support two options in build_dataloader. (#349)

* Support presistent_works in dataloader.

* Use pin_memory by default
This commit is contained in:
Ma Zerun 2021-07-14 15:21:49 +08:00 committed by GitHub
parent eb7c70c5c7
commit 76c5d34dcc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,10 @@
import platform
import random
from distutils.version import LooseVersion
from functools import partial
import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
@ -47,6 +49,8 @@ def build_dataloader(dataset,
shuffle=True,
round_up=True,
seed=None,
pin_memory=True,
persistent_workers=True,
**kwargs):
"""Build PyTorch DataLoader.
@ -65,6 +69,13 @@ def build_dataloader(dataset,
Default: True.
round_up (bool): Whether to round up the length of dataset by adding
extra samples to make it evenly divisible. Default: True.
pin_memory (bool): Whether to use pin_memory in DataLoader.
Default: True
persistent_workers (bool): If True, the data loader will not shutdown
the worker processes after a dataset has been consumed once.
This allows to maintain the workers Dataset instances alive.
The argument also has effect in PyTorch>=1.7.0.
Default: True
kwargs: any keyword argument to be used to initialize DataLoader
Returns:
@ -86,13 +97,16 @@ def build_dataloader(dataset,
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None
if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
kwargs['persistent_workers'] = persistent_workers
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=False,
pin_memory=pin_memory,
shuffle=shuffle,
worker_init_fn=init_fn,
**kwargs)