[Feature] Support persistent_workers in DataLoader (PyTorch>=1.7.0) (#646)
parent
98067bec5c
commit
170a9d1f7c
|
@ -4,11 +4,11 @@ import random
|
|||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import collate
|
||||
from mmcv.runner import get_dist_info
|
||||
from mmcv.utils import Registry, build_from_cfg
|
||||
from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
|
||||
from torch.utils.data import DistributedSampler
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
||||
if platform.system() != 'Windows':
|
||||
# https://github.com/pytorch/pytorch/issues/973
|
||||
|
@ -84,7 +84,7 @@ def build_dataloader(dataset,
|
|||
seed=None,
|
||||
drop_last=False,
|
||||
pin_memory=True,
|
||||
dataloader_type='PoolDataLoader',
|
||||
persistent_workers=True,
|
||||
**kwargs):
|
||||
"""Build PyTorch DataLoader.
|
||||
|
||||
|
@ -106,7 +106,11 @@ def build_dataloader(dataset,
|
|||
Default: False
|
||||
pin_memory (bool): Whether to use pin_memory in DataLoader.
|
||||
Default: True
|
||||
dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
|
||||
persistent_workers (bool): If True, the data loader will not shutdown
|
||||
the worker processes after a dataset has been consumed once.
|
||||
This allows the workers' `Dataset` instances to stay alive.
|
||||
The argument only takes effect when PyTorch>=1.7.0.
|
||||
Default: True
|
||||
kwargs: any keyword argument to be used to initialize DataLoader
|
||||
|
||||
Returns:
|
||||
|
@ -128,26 +132,31 @@ def build_dataloader(dataset,
|
|||
worker_init_fn, num_workers=num_workers, rank=rank,
|
||||
seed=seed) if seed is not None else None
|
||||
|
||||
assert dataloader_type in (
|
||||
'DataLoader',
|
||||
'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
|
||||
|
||||
if dataloader_type == 'PoolDataLoader':
|
||||
dataloader = PoolDataLoader
|
||||
elif dataloader_type == 'DataLoader':
|
||||
dataloader = DataLoader
|
||||
|
||||
data_loader = dataloader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
|
||||
pin_memory=pin_memory,
|
||||
shuffle=shuffle,
|
||||
worker_init_fn=init_fn,
|
||||
drop_last=drop_last,
|
||||
**kwargs)
|
||||
if torch.__version__ >= '1.7.0':
|
||||
data_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
|
||||
pin_memory=pin_memory,
|
||||
shuffle=shuffle,
|
||||
worker_init_fn=init_fn,
|
||||
drop_last=drop_last,
|
||||
persistent_workers=persistent_workers,
|
||||
**kwargs)
|
||||
else:
|
||||
data_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
|
||||
pin_memory=pin_memory,
|
||||
shuffle=shuffle,
|
||||
worker_init_fn=init_fn,
|
||||
drop_last=drop_last,
|
||||
**kwargs)
|
||||
|
||||
return data_loader
|
||||
|
||||
|
|
Loading…
Reference in New Issue