[Feature] Support persistent_workers in DataLoader (PyTorch>=1.7.0) (#646)
parent
60baa4e841
commit
bf746bf737
|
@ -4,11 +4,11 @@ import random
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from mmcv.parallel import collate
|
from mmcv.parallel import collate
|
||||||
from mmcv.runner import get_dist_info
|
from mmcv.runner import get_dist_info
|
||||||
from mmcv.utils import Registry, build_from_cfg
|
from mmcv.utils import Registry, build_from_cfg
|
||||||
from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
from torch.utils.data import DistributedSampler
|
|
||||||
|
|
||||||
if platform.system() != 'Windows':
|
if platform.system() != 'Windows':
|
||||||
# https://github.com/pytorch/pytorch/issues/973
|
# https://github.com/pytorch/pytorch/issues/973
|
||||||
|
@ -84,7 +84,7 @@ def build_dataloader(dataset,
|
||||||
seed=None,
|
seed=None,
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
dataloader_type='PoolDataLoader',
|
persistent_workers=True,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Build PyTorch DataLoader.
|
"""Build PyTorch DataLoader.
|
||||||
|
|
||||||
|
@ -106,7 +106,11 @@ def build_dataloader(dataset,
|
||||||
Default: False
|
Default: False
|
||||||
pin_memory (bool): Whether to use pin_memory in DataLoader.
|
pin_memory (bool): Whether to use pin_memory in DataLoader.
|
||||||
Default: True
|
Default: True
|
||||||
dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
|
persistent_workers (bool): If True, the data loader will not shutdown
|
||||||
|
the worker processes after a dataset has been consumed once.
|
||||||
|
This allows to maintain the workers Dataset instances alive.
|
||||||
|
The argument also has effect in PyTorch>=1.7.0.
|
||||||
|
Default: True
|
||||||
kwargs: any keyword argument to be used to initialize DataLoader
|
kwargs: any keyword argument to be used to initialize DataLoader
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -128,16 +132,21 @@ def build_dataloader(dataset,
|
||||||
worker_init_fn, num_workers=num_workers, rank=rank,
|
worker_init_fn, num_workers=num_workers, rank=rank,
|
||||||
seed=seed) if seed is not None else None
|
seed=seed) if seed is not None else None
|
||||||
|
|
||||||
assert dataloader_type in (
|
if torch.__version__ >= '1.7.0':
|
||||||
'DataLoader',
|
data_loader = DataLoader(
|
||||||
'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
|
dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
if dataloader_type == 'PoolDataLoader':
|
sampler=sampler,
|
||||||
dataloader = PoolDataLoader
|
num_workers=num_workers,
|
||||||
elif dataloader_type == 'DataLoader':
|
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
|
||||||
dataloader = DataLoader
|
pin_memory=pin_memory,
|
||||||
|
shuffle=shuffle,
|
||||||
data_loader = dataloader(
|
worker_init_fn=init_fn,
|
||||||
|
drop_last=drop_last,
|
||||||
|
persistent_workers=persistent_workers,
|
||||||
|
**kwargs)
|
||||||
|
else:
|
||||||
|
data_loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
|
|
Loading…
Reference in New Issue