[Feature] Support persistent_workers in DataLoader (PyTorch>=1.7.0) (#646)

pull/1801/head
Jerry Jiarui XU 2021-06-28 02:39:32 -07:00 committed by GitHub
parent 60baa4e841
commit bf746bf737
1 changed file with 33 additions and 24 deletions

View File

@@ -4,11 +4,11 @@ import random
 from functools import partial
 
 import numpy as np
+import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg
-from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler
 
 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -84,7 +84,7 @@ def build_dataloader(dataset,
                      seed=None,
                      drop_last=False,
                      pin_memory=True,
-                     dataloader_type='PoolDataLoader',
+                     persistent_workers=True,
                      **kwargs):
     """Build PyTorch DataLoader.
@@ -106,7 +106,11 @@ def build_dataloader(dataset,
             Default: False
         pin_memory (bool): Whether to use pin_memory in DataLoader.
             Default: True
-        dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+        persistent_workers (bool): If True, the data loader will not shutdown
+            the worker processes after a dataset has been consumed once.
+            This allows to maintain the workers Dataset instances alive.
+            The argument also has effect in PyTorch>=1.7.0.
+            Default: True
         kwargs: any keyword argument to be used to initialize DataLoader
 
     Returns:
@@ -128,16 +132,21 @@ def build_dataloader(dataset,
         worker_init_fn, num_workers=num_workers, rank=rank,
         seed=seed) if seed is not None else None
 
-    assert dataloader_type in (
-        'DataLoader',
-        'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
-
-    if dataloader_type == 'PoolDataLoader':
-        dataloader = PoolDataLoader
-    elif dataloader_type == 'DataLoader':
-        dataloader = DataLoader
-
-    data_loader = dataloader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
+    if torch.__version__ >= '1.7.0':
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=num_workers,
+            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+            pin_memory=pin_memory,
+            shuffle=shuffle,
+            worker_init_fn=init_fn,
+            drop_last=drop_last,
+            persistent_workers=persistent_workers,
+            **kwargs)
+    else:
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,