mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Feature] Support two options in build_dataloader
. (#349)
* Support presistent_works in dataloader. * Use pin_memory by default
This commit is contained in:
parent
eb7c70c5c7
commit
76c5d34dcc
@ -1,8 +1,10 @@
|
|||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
|
from distutils.version import LooseVersion
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from mmcv.parallel import collate
|
from mmcv.parallel import collate
|
||||||
from mmcv.runner import get_dist_info
|
from mmcv.runner import get_dist_info
|
||||||
from mmcv.utils import Registry, build_from_cfg
|
from mmcv.utils import Registry, build_from_cfg
|
||||||
@ -47,6 +49,8 @@ def build_dataloader(dataset,
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
round_up=True,
|
round_up=True,
|
||||||
seed=None,
|
seed=None,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=True,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Build PyTorch DataLoader.
|
"""Build PyTorch DataLoader.
|
||||||
|
|
||||||
@ -65,6 +69,13 @@ def build_dataloader(dataset,
|
|||||||
Default: True.
|
Default: True.
|
||||||
round_up (bool): Whether to round up the length of dataset by adding
|
round_up (bool): Whether to round up the length of dataset by adding
|
||||||
extra samples to make it evenly divisible. Default: True.
|
extra samples to make it evenly divisible. Default: True.
|
||||||
|
pin_memory (bool): Whether to use pin_memory in DataLoader.
|
||||||
|
Default: True
|
||||||
|
persistent_workers (bool): If True, the data loader will not shutdown
|
||||||
|
the worker processes after a dataset has been consumed once.
|
||||||
|
This allows to maintain the workers Dataset instances alive.
|
||||||
|
The argument also has effect in PyTorch>=1.7.0.
|
||||||
|
Default: True
|
||||||
kwargs: any keyword argument to be used to initialize DataLoader
|
kwargs: any keyword argument to be used to initialize DataLoader
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -86,13 +97,16 @@ def build_dataloader(dataset,
|
|||||||
worker_init_fn, num_workers=num_workers, rank=rank,
|
worker_init_fn, num_workers=num_workers, rank=rank,
|
||||||
seed=seed) if seed is not None else None
|
seed=seed) if seed is not None else None
|
||||||
|
|
||||||
|
if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
|
||||||
|
kwargs['persistent_workers'] = persistent_workers
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
|
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
|
||||||
pin_memory=False,
|
pin_memory=pin_memory,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
worker_init_fn=init_fn,
|
worker_init_fn=init_fn,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user