[Feature] Support two options in build_dataloader. (#349)

* Support persistent_workers in dataloader.

* Use pin_memory by default
Ma Zerun 2021-07-14 15:21:49 +08:00 committed by GitHub
parent eb7c70c5c7
commit 76c5d34dcc

@@ -1,8 +1,10 @@
 import platform
 import random
+from distutils.version import LooseVersion
 from functools import partial
 
 import numpy as np
+import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg
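The two new imports exist only to feed the version gate added in the final hunk: persistent_workers is accepted by torch.utils.data.DataLoader starting with PyTorch 1.7.0, so the flag has to be forwarded conditionally. Reduced to the bare pattern (the variable name extra_kwargs is illustrative, not from the commit):

    from distutils.version import LooseVersion

    import torch

    extra_kwargs = {}
    # Forward persistent_workers only when the installed PyTorch supports it;
    # older versions raise TypeError for the unknown DataLoader keyword.
    if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
        extra_kwargs['persistent_workers'] = True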
@@ -47,6 +49,8 @@ def build_dataloader(dataset,
                      shuffle=True,
                      round_up=True,
                      seed=None,
+                     pin_memory=True,
+                     persistent_workers=True,
                      **kwargs):
     """Build PyTorch DataLoader.
 
@@ -65,6 +69,13 @@ def build_dataloader(dataset,
             Default: True.
         round_up (bool): Whether to round up the length of dataset by adding
             extra samples to make it evenly divisible. Default: True.
+        pin_memory (bool): Whether to use pin_memory in DataLoader.
+            Default: True
+        persistent_workers (bool): If True, the data loader will not shut
+            down the worker processes after a dataset has been consumed
+            once. This keeps the worker Dataset instances alive.
+            The argument only takes effect in PyTorch>=1.7.0.
+            Default: True
         kwargs: any keyword argument to be used to initialize DataLoader
 
     Returns:
@@ -86,13 +97,16 @@ def build_dataloader(dataset,
         worker_init_fn, num_workers=num_workers, rank=rank,
         seed=seed) if seed is not None else None
 
+    if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
+        kwargs['persistent_workers'] = persistent_workers
+
     data_loader = DataLoader(
         dataset,
         batch_size=batch_size,
         sampler=sampler,
         num_workers=num_workers,
         collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-        pin_memory=False,
+        pin_memory=pin_memory,
         shuffle=shuffle,
         worker_init_fn=init_fn,
         **kwargs)
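For callers, the change is transparent: both options surface as keyword arguments on build_dataloader and default to True. A minimal usage sketch, assuming a standard mmcls setup (the config path and batch settings below are placeholders, not part of this commit):

    from mmcv import Config
    from mmcls.datasets import build_dataloader, build_dataset

    # Placeholder config path; any mmcls dataset config works here.
    cfg = Config.fromfile('configs/resnet/resnet18_b32x8_imagenet.py')
    dataset = build_dataset(cfg.data.train)

    # pin_memory and persistent_workers now default to True; pass False to
    # restore the previous behaviour. On PyTorch < 1.7.0 the
    # persistent_workers flag is dropped before DataLoader is constructed.
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=32,
        workers_per_gpu=4,
        dist=False,
        shuffle=True,
        pin_memory=True,
        persistent_workers=True)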