[Feature] Support persistent_workers in DataLoader (PyTorch>=1.7.0) (#646)

pull/1801/head
Jerry Jiarui XU 2021-06-28 02:39:32 -07:00 committed by GitHub
parent 60baa4e841
commit bf746bf737
1 changed file with 33 additions and 24 deletions


@@ -4,11 +4,11 @@ import random
 from functools import partial
 import numpy as np
+import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg
-from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler
 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -84,7 +84,7 @@ def build_dataloader(dataset,
                      seed=None,
                      drop_last=False,
                      pin_memory=True,
-                     dataloader_type='PoolDataLoader',
+                     persistent_workers=True,
                      **kwargs):
     """Build PyTorch DataLoader.
@@ -106,7 +106,11 @@ def build_dataloader(dataset,
             Default: False
         pin_memory (bool): Whether to use pin_memory in DataLoader.
             Default: True
-        dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+        persistent_workers (bool): If True, the data loader will not shut down
+            the worker processes after a dataset has been consumed once.
+            This keeps the worker Dataset instances alive.
+            The argument only takes effect when PyTorch>=1.7.0.
+            Default: True
         kwargs: any keyword argument to be used to initialize DataLoader
     Returns:
@@ -128,26 +132,31 @@ def build_dataloader(dataset,
         worker_init_fn, num_workers=num_workers, rank=rank,
         seed=seed) if seed is not None else None
-    assert dataloader_type in (
-        'DataLoader',
-        'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
-    if dataloader_type == 'PoolDataLoader':
-        dataloader = PoolDataLoader
-    elif dataloader_type == 'DataLoader':
-        dataloader = DataLoader
-    data_loader = dataloader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-        pin_memory=pin_memory,
-        shuffle=shuffle,
-        worker_init_fn=init_fn,
-        drop_last=drop_last,
-        **kwargs)
+    if torch.__version__ >= '1.7.0':
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=num_workers,
+            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+            pin_memory=pin_memory,
+            shuffle=shuffle,
+            worker_init_fn=init_fn,
+            drop_last=drop_last,
+            persistent_workers=persistent_workers,
+            **kwargs)
+    else:
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            num_workers=num_workers,
+            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+            pin_memory=pin_memory,
+            shuffle=shuffle,
+            worker_init_fn=init_fn,
+            drop_last=drop_last,
+            **kwargs)
     return data_loader
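
For context (not part of the commit itself): with persistent_workers=True, PyTorch keeps the worker processes alive between epochs instead of shutting them down after each pass over the dataset, so per-worker Dataset state and any expensive initialization survive across epochs. A minimal, self-contained sketch of the flag using plain torch.utils.data; the TensorDataset and the numeric values are illustrative, not taken from this repository:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset; any map-style Dataset behaves the same way.
ds = TensorDataset(torch.arange(8, dtype=torch.float32))

# persistent_workers requires num_workers > 0 and PyTorch >= 1.7.0.
dl = DataLoader(ds, batch_size=2, num_workers=2, persistent_workers=True)

for epoch in range(3):
    # The same worker processes serve every epoch; without the flag they
    # would be torn down and re-forked each time iteration over `dl` starts.
    for (batch,) in dl:
        pass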
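One caveat with the gating in this diff: torch.__version__ >= '1.7.0' is a plain string comparison, which misorders multi-digit releases (for example '1.10.0' >= '1.7.0' evaluates to False lexicographically). A sketch of a more robust check, assuming the packaging library is available; the flag name below is hypothetical and not part of the diff (mmcv also ships a digit_version helper for the same purpose):

from packaging import version

import torch

# Hypothetical helper flag, not part of the commit above.
# version.parse compares release components numerically, so '1.10.0' > '1.7.0'.
TORCH_SUPPORTS_PERSISTENT_WORKERS = (
    version.parse(torch.__version__) >= version.parse('1.7.0'))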
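Finally, a sketch of how the new argument is consumed downstream. Only the persistent_workers keyword comes from this commit; the dataset config, paths, and numeric values are illustrative assumptions and require the corresponding data to exist on disk:

from mmseg.datasets import build_dataloader, build_dataset

# Assumed dataset config; any entry registered in DATASETS works the same way.
train_cfg = dict(
    type='CityscapesDataset',
    data_root='data/cityscapes/',
    img_dir='leftImg8bit/train',
    ann_dir='gtFine/train',
    pipeline=[])

dataset = build_dataset(train_cfg)
loader = build_dataloader(
    dataset,
    samples_per_gpu=2,
    workers_per_gpu=2,
    dist=False,
    shuffle=True,
    seed=42,
    persistent_workers=True)  # forwarded only on PyTorch >= 1.7.0, per the branch in the diff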