[Refactor] Deprecate imgs_per_gpu and use samples_per_gpu (#204)

* [Refactor] change imgs_per_gpu to samples_per_gpu in config files

* [Docs] change imgs_per_gpu to samples_per_gpu in docs

* [Refactor] change imgs_per_gpu to samples_per_gpu in code and add warnings

* [Fix] fix isort

* [Docs] fix docs format

* [Refactor] add related unit tests

* [Fix] fix isort
Yixiao Fang 2022-02-09 17:45:41 +08:00 committed by GitHub
parent 16b3f7b61e
commit af331b043f
41 changed files with 168 additions and 60 deletions
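
In short, every config and call site swaps the deprecated `imgs_per_gpu` key for `samples_per_gpu`; old configs keep working but emit a deprecation warning. A minimal before/after sketch (values illustrative, not from a specific file):

```python
# Before (deprecated): per-GPU batch size set via `imgs_per_gpu`.
data = dict(
    imgs_per_gpu=32,
    workers_per_gpu=4,
)

# After: `samples_per_gpu` is the canonical key. If `imgs_per_gpu` is still
# present, it takes precedence (with a warning) and is copied into
# `samples_per_gpu`.
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=4,
)
```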

View File

@@ -20,7 +20,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=128,
+    samples_per_gpu=128,
     workers_per_gpu=2,
     train=dict(
         type=dataset_type,

View File

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32x8=256, 8GPU linear cls
+    samples_per_gpu=32,  # total 32x8=256, 8GPU linear cls
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

View File

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32x8=256, 8GPU linear cls
+    samples_per_gpu=32,  # total 32x8=256, 8GPU linear cls
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

View File

@@ -25,7 +25,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32x8=256, 8GPU linear cls
+    samples_per_gpu=32,  # total 32x8=256, 8GPU linear cls
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

View File

@@ -10,7 +10,7 @@ model = dict(backbone=dict(norm_cfg=dict(type='SyncBN')))
 # dataset settings
 data = dict(
-    imgs_per_gpu=64,  # total 64x4=256
+    samples_per_gpu=64,  # total 64x4=256
     train=dict(
         data_source=dict(ann_file='data/imagenet/meta/train_10pct.txt')))

View File

@@ -10,7 +10,7 @@ model = dict(backbone=dict(norm_cfg=dict(type='SyncBN')))
 # dataset settings
 data = dict(
-    imgs_per_gpu=64,  # total 64x4=256
+    samples_per_gpu=64,  # total 64x4=256
     train=dict(
         data_source=dict(ann_file='data/imagenet/meta/train_1percent.txt')))

View File

@@ -8,7 +8,7 @@ _base_ = [
 model = dict(backbone=dict(frozen_stages=4))
 # dataset summary
-data = dict(imgs_per_gpu=512)  # total 512*8=4096, 8GPU linear cls
+data = dict(samples_per_gpu=512)  # total 512*8=4096, 8GPU linear cls
 # simsiam setting
 # runtime settings

View File

@@ -9,7 +9,7 @@ _base_ = [
 model = dict(backbone=dict(frozen_stages=12, norm_eval=True))
 # dataset summary
-data = dict(imgs_per_gpu=128)  # total 128*8=1024, 8 GPU linear cls
+data = dict(samples_per_gpu=128)  # total 128*8=1024, 8 GPU linear cls
 # optimizer
 optimizer = dict(type='SGD', lr=12, momentum=0.9, weight_decay=0.)

View File

@@ -5,7 +5,7 @@ split_name = ['voc07_trainval', 'voc07_test']
 img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 data = dict(
-    imgs_per_gpu=32,
+    samples_per_gpu=32,
     workers_per_gpu=4,
     extract=dict(
         type=dataset_type,

View File

@@ -4,7 +4,7 @@ name = 'imagenet_val'
 img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 data = dict(
-    imgs_per_gpu=8,
+    samples_per_gpu=8,
     workers_per_gpu=4,
     extract=dict(
         type='SingleViewDataset',

View File

@@ -51,7 +51,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8(gpu)=256
+    samples_per_gpu=32,  # total 32*8(gpu)=256
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

View File

@@ -31,7 +31,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=64,  # 64*8
+    samples_per_gpu=64,  # 64*8
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,
@@ -49,7 +49,7 @@ custom_hooks = [
     dict(
         type='DeepClusterHook',
         extractor=dict(
-            imgs_per_gpu=128,
+            samples_per_gpu=128,
            workers_per_gpu=8,
            dataset=dict(
                type=dataset_type,

View File

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8=256
+    samples_per_gpu=32,  # total 32*8=256
     workers_per_gpu=4,
     drop_last=True,
     train=dict(

View File

@@ -30,7 +30,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8=256
+    samples_per_gpu=32,  # total 32*8=256
     workers_per_gpu=4,
     drop_last=True,
     train=dict(

View File

@@ -51,7 +51,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=256,  # 256*16(gpu)=4096
+    samples_per_gpu=256,  # 256*16(gpu)=4096
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

View File

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8
+    samples_per_gpu=32,  # total 32*8
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

View File

@@ -31,7 +31,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=64,  # 64*8
+    samples_per_gpu=64,  # 64*8
     sampling_replace=True,
     workers_per_gpu=4,
     train=dict(
@@ -50,7 +50,7 @@ custom_hooks = [
     dict(
         type='DeepClusterHook',
         extractor=dict(
-            imgs_per_gpu=128,
+            samples_per_gpu=128,
            workers_per_gpu=8,
            dataset=dict(
                type=dataset_type,

View File

@@ -21,7 +21,7 @@ prefetch = False
 # dataset summary
 data = dict(
-    imgs_per_gpu=64,  # 64 x 8 = 512
+    samples_per_gpu=64,  # 64 x 8 = 512
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

View File

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=16,  # (16*4) x 8 = 512
+    samples_per_gpu=16,  # (16*4) x 8 = 512
     workers_per_gpu=2,
     train=dict(
         type=dataset_type,

View File

@@ -29,7 +29,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8
+    samples_per_gpu=32,  # total 32*8
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

View File

@@ -51,7 +51,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8=256
+    samples_per_gpu=32,  # total 32*8=256
     workers_per_gpu=4,
     drop_last=True,
     train=dict(

View File

@@ -6,7 +6,7 @@ _base_ = [
 ]
 # dataset summary
-data = dict(imgs_per_gpu=256)
+data = dict(samples_per_gpu=256)
 # additional hooks
 # interval for accumulate gradient, total 8*256*2(interval)=4096

View File

@@ -58,7 +58,8 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=128, train=dict(pipelines=[train_pipeline1, train_pipeline2]))
+    samples_per_gpu=128,
+    train=dict(pipelines=[train_pipeline1, train_pipeline2]))
 # MoCo v3 use the same momentum update method as BYOL
 custom_hooks = [dict(type='MomentumUpdateHook')]

View File

@@ -1,4 +1,4 @@
 _base_ = 'simclr_resnet50_8xb32-coslr-200e_in1k.py'

 # dataset summary
-data = dict(imgs_per_gpu=64)  # total 64*8
+data = dict(samples_per_gpu=64)  # total 64*8

View File

@@ -13,7 +13,7 @@ custom_hooks = [
     dict(
         type='SwAVHook',
         priority='VERY_HIGH',
-        batch_size={{_base_.data.imgs_per_gpu}},
+        batch_size={{_base_.data.samples_per_gpu}},
         epoch_queue_starts=15,
         crops_for_assign=[0, 1],
         feat_dim=128,

View File

@@ -207,7 +207,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # Batch size of a single GPU, total 32*8=256
+    samples_per_gpu=32,  # Batch size of a single GPU, total 32*8=256
     workers_per_gpu=4,  # Worker to pre-fetch data for each single GPU
     drop_last=True,  # Whether to drop the last batch of data
     train=dict(
@@ -304,7 +304,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,
+    samples_per_gpu=32,
     workers_per_gpu=4,
     drop_last=True,
     train=dict(type=dataset_type, type=data_source, data_prefix=...),

View File

@@ -154,14 +154,14 @@ When there is not enough computation resource, the batch size can only be set to
 Here is an example:
 ```python
-data = dict(imgs_per_gpu=64)
+data = dict(samples_per_gpu=64)
 optimizer_config = dict(type="DistOptimizerHook", update_interval=4)
 ```
 Indicates that during training, back-propagation is performed every 4 iters. And the above is equivalent to:
 ```python
-data = dict(imgs_per_gpu=256)
+data = dict(samples_per_gpu=256)
 optimizer_config = dict(type="OptimizerHook")
 ```
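
The equivalence holds because gradients are accumulated over `update_interval` iterations before each optimizer step, so the effective per-GPU batch is the product of the two settings. A quick check (a sketch; names taken from the config above):

```python
samples_per_gpu = 64   # per-iteration batch size on one GPU
update_interval = 4    # iterations between optimizer steps
# effective per-GPU batch size under gradient accumulation
assert samples_per_gpu * update_interval == 256
```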

View File

@@ -70,7 +70,7 @@ bash tools/benchmarks/classification/slurm_train_linear.sh ${PARTITION} ${JOB_NA
 ```
 Remarks:
-- The default GPU number is 8. When changing GPUS, please also change imgs_per_gpu in the config file accordingly to ensure the total batch size is 256.
+- The default GPU number is 8. When changing GPUS, please also change `samples_per_gpu` in the config file accordingly to ensure the total batch size is 256.
 - `CONFIG`: Use config files under `configs/benchmarks/classification/`, excluding svm_voc07.py and tsne_imagenet.py and imagenet_*percent folders.
 - `PRETRAIN`: the pretrained model file.

View File

@@ -207,7 +207,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # Batch size of a single GPU, total 32*8=256
+    samples_per_gpu=32,  # Batch size of a single GPU, total 32*8=256
     workers_per_gpu=4,  # Worker to pre-fetch data for each single GPU
     drop_last=True,  # Whether to drop the last batch of data
     train=dict(
@@ -304,7 +304,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,
+    samples_per_gpu=32,
     workers_per_gpu=4,
     drop_last=True,
     train=dict(type=dataset_type, type=data_source, data_prefix=...),

View File

@@ -154,14 +154,14 @@ optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
 Here is an example:
 ```python
-data = dict(imgs_per_gpu=64)
+data = dict(samples_per_gpu=64)
 optimizer_config = dict(type="DistOptimizerHook", update_interval=4)
 ```
 This means that back-propagation is performed once every 4 iters during training. Since the per-GPU batch size is 64 here, this is equivalent to a single-iteration per-GPU batch size of 256, i.e.:
 ```python
-data = dict(imgs_per_gpu=256)
+data = dict(samples_per_gpu=256)
 optimizer_config = dict(type="OptimizerHook")
 ```

View File

@@ -70,7 +70,7 @@ bash tools/benchmarks/classification/slurm_train_linear.sh ${PARTITION} ${JOB_NA
 ```
 Remarks:
-- The default GPU number is 8. When changing GPUS, please also change imgs_per_gpu in the config file accordingly to ensure the total batch size is 256.
+- The default GPU number is 8. When changing GPUS, please also change `samples_per_gpu` in the config file accordingly to ensure the total batch size is 256.
 - `CONFIG`: Use config files under `configs/benchmarks/classification/`, excluding svm_voc07.py and tsne_imagenet.py and imagenet_*percent folders.
 - `PRETRAIN`: the pretrained model file.

View File

@@ -75,13 +75,26 @@ def train_model(model,
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+    if 'imgs_per_gpu' in cfg.data:
+        logger.warning('"imgs_per_gpu" is deprecated. '
+                       'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiment')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{cfg.data.imgs_per_gpu} in this experiment')
+        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+
     data_loaders = [
         build_dataloader(
             ds,
-            cfg.data.imgs_per_gpu,
-            cfg.data.workers_per_gpu,
-            # cfg.gpus will be ignored if distributed
+            samples_per_gpu=cfg.data.samples_per_gpu,
+            workers_per_gpu=cfg.data.workers_per_gpu,
+            # `num_gpus` will be ignored if distributed
             num_gpus=len(cfg.gpu_ids),
             dist=distributed,
             replace=getattr(cfg.data, 'sampling_replace', False),
@@ -161,7 +174,7 @@ def train_model(model,
         val_dataset = build_dataset(cfg.data.val)
         val_dataloader = build_dataloader(
             val_dataset,
-            imgs_per_gpu=cfg.data.imgs_per_gpu,
+            samples_per_gpu=cfg.data.samples_per_gpu,
             workers_per_gpu=cfg.data.workers_per_gpu,
             dist=distributed,
             shuffle=False,
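
The same guard recurs in the hook and tools below. A condensed sketch of the pattern as a hypothetical helper (`resolve_samples_per_gpu` is illustrative only and not part of this commit):

```python
def resolve_samples_per_gpu(data_cfg, logger):
    """Hypothetical helper summarizing the fallback applied at each site."""
    if 'imgs_per_gpu' in data_cfg:
        logger.warning('"imgs_per_gpu" is deprecated. '
                       'Please use "samples_per_gpu" instead')
        # When both keys are present, the deprecated one takes precedence.
        data_cfg['samples_per_gpu'] = data_cfg['imgs_per_gpu']
    return data_cfg['samples_per_gpu']
```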

View File

@@ -7,6 +7,7 @@ from mmcv.utils import print_log
 from mmselfsup.utils import Extractor
 from mmselfsup.utils import clustering as _clustering
+from mmselfsup.utils import get_root_logger


 @HOOKS.register_module()
@@ -41,6 +42,23 @@ class DeepClusterHook(Hook):
                  interval=1,
                  dist_mode=True,
                  data_loaders=None):
+        logger = get_root_logger()
+        if 'imgs_per_gpu' in extractor:
+            logger.warning('"imgs_per_gpu" is deprecated. '
+                           'Please use "samples_per_gpu" instead')
+            if 'samples_per_gpu' in extractor:
+                logger.warning(
+                    f'Got "imgs_per_gpu"={extractor["imgs_per_gpu"]} and '
+                    f'"samples_per_gpu"={extractor["samples_per_gpu"]}, '
+                    f'"imgs_per_gpu"={extractor["imgs_per_gpu"]} is used in '
+                    'this experiment')
+            else:
+                logger.warning(
+                    'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                    f'{extractor["imgs_per_gpu"]} in this experiment')
+            extractor['samples_per_gpu'] = extractor['imgs_per_gpu']
+
         self.extractor = Extractor(dist_mode=dist_mode, **extractor)
         self.clustering_type = clustering.pop('type')
         self.clustering_cfg = clustering

View File

@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import platform
 import random
+import warnings
 from functools import partial

 import numpy as np
@@ -45,8 +46,9 @@ def build_dataset(cfg, default_args=None):
 def build_dataloader(dataset,
-                     imgs_per_gpu,
-                     workers_per_gpu,
+                     imgs_per_gpu=None,
+                     samples_per_gpu=None,
+                     workers_per_gpu=1,
                      num_gpus=1,
                      dist=True,
                      shuffle=True,
@@ -62,10 +64,13 @@ def build_dataloader(dataset,
     Args:
         dataset (Dataset): A PyTorch dataset.
-        imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of
-            each GPU.
+        imgs_per_gpu (int): (Deprecated, please use samples_per_gpu) Number of
+            images on each GPU, i.e., batch size of each GPU. Defaults to None.
+        samples_per_gpu (int): Number of images on each GPU, i.e., batch size
+            of each GPU. Defaults to None.
         workers_per_gpu (int): How many subprocesses to use for data loading
-            for each GPU.
+            for each GPU. `persistent_workers` option needs num_workers > 0.
+            Defaults to 1.
         num_gpus (int): Number of GPUs. Only used in non-distributed training.
         dist (bool): Distributed training/test or not. Defaults to True.
         shuffle (bool): Whether to shuffle the data at every epoch.
@@ -85,18 +90,33 @@ def build_dataloader(dataset,
     Returns:
         DataLoader: A PyTorch dataloader.
     """
+    if imgs_per_gpu is None and samples_per_gpu is None:
+        raise ValueError(
+            'Please indicate the number of images on each GPU: '
+            '"imgs_per_gpu" and "samples_per_gpu" cannot be None at the '
+            'same time. "imgs_per_gpu" is deprecated, please use '
+            '"samples_per_gpu".')
+    if imgs_per_gpu is not None:
+        warnings.warn(f'Got "imgs_per_gpu"={imgs_per_gpu} and '
+                      f'"samples_per_gpu"={samples_per_gpu}, "imgs_per_gpu"'
+                      f'={imgs_per_gpu} is used in this experiment. '
+                      'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                      f'{imgs_per_gpu}')
+        samples_per_gpu = imgs_per_gpu
+
     rank, world_size = get_dist_info()
     if dist:
         sampler = DistributedSampler(
             dataset, world_size, rank, shuffle=shuffle, replace=replace)
         shuffle = False
-        batch_size = imgs_per_gpu
+        batch_size = samples_per_gpu
         num_workers = workers_per_gpu
     else:
         if replace:
             return NotImplemented
         sampler = None  # TODO: set replace
-        batch_size = num_gpus * imgs_per_gpu
+        batch_size = num_gpus * samples_per_gpu
         num_workers = num_gpus * workers_per_gpu

     init_fn = partial(
@@ -117,7 +137,7 @@ def build_dataloader(dataset,
         batch_size=batch_size,
         sampler=sampler,
         num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
+        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
         pin_memory=pin_memory,
         shuffle=shuffle,
         worker_init_fn=init_fn,
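
Taken together, the keyword handling added above resolves as follows; a stand-alone sketch assuming plain Python values (not the verbatim implementation):

```python
import warnings


def resolve_batch_size(imgs_per_gpu=None, samples_per_gpu=None):
    # Mirrors the checks in build_dataloader: at least one key must be
    # given, and the deprecated key wins when both are present.
    if imgs_per_gpu is None and samples_per_gpu is None:
        raise ValueError('"imgs_per_gpu" and "samples_per_gpu" cannot both '
                         'be None; please use "samples_per_gpu"')
    if imgs_per_gpu is not None:
        warnings.warn('"imgs_per_gpu" is deprecated, '
                      'please use "samples_per_gpu" instead')
        samples_per_gpu = imgs_per_gpu
    return samples_per_gpu
```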

View File

@@ -11,8 +11,8 @@ class Extractor(object):
     Args:
         dataset (Dataset | dict): A PyTorch dataset or dict that indicates
             the dataset.
-        imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of
-            each GPU.
+        samples_per_gpu (int): Number of images on each GPU, i.e., batch size
+            of each GPU.
         workers_per_gpu (int): How many subprocesses to use for data loading
             for each GPU.
         dist_mode (bool): Use distributed extraction or not. Defaults to False.
@@ -25,7 +25,7 @@ class Extractor(object):
     def __init__(self,
                  dataset,
-                 imgs_per_gpu,
+                 samples_per_gpu,
                  workers_per_gpu,
                  dist_mode=False,
                  persistent_workers=True,
@@ -40,8 +40,8 @@ class Extractor(object):
                 f'not {type(dataset)}')
         self.data_loader = datasets.build_dataloader(
             self.dataset,
-            imgs_per_gpu,
-            workers_per_gpu,
+            samples_per_gpu=samples_per_gpu,
+            workers_per_gpu=workers_per_gpu,
             dist=dist_mode,
             shuffle=False,
             persistent_workers=persistent_workers,

View File

@@ -3,7 +3,7 @@ from unittest.mock import ANY
 import pytest

 from mmselfsup.datasets import (ConcatDataset, DeepClusterDataset,
-                                RepeatDataset, build_dataset)
+                                RepeatDataset, build_dataloader, build_dataset)

 DATASET_CONFIG = dict(
     type='DeepClusterDataset',
@@ -50,3 +50,19 @@ DATASET_CONFIG = dict(
 ])
 def test_build_dataset(cfg, expected_type):
     assert isinstance(build_dataset(cfg), expected_type)
+
+
+def test_build_dataloader():
+    dataset = build_dataset(DATASET_CONFIG)
+    with pytest.raises(ValueError):
+        data_loader = build_dataloader(dataset)
+    data_loader = build_dataloader(
+        dataset,
+        imgs_per_gpu=1,
+        samples_per_gpu=None,
+        dist=False,
+    )
+    assert len(data_loader) == 2
+    assert data_loader.batch_size == 1
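
New code can avoid the deprecation path entirely by passing the new keyword; a minimal usage sketch mirroring the test above:

```python
# Preferred call in new code: pass `samples_per_gpu` directly,
# leaving the deprecated `imgs_per_gpu` unset.
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=1,
    workers_per_gpu=1,
    dist=False,
    shuffle=False)
assert data_loader.batch_size == 1
```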

View File

@@ -156,9 +156,22 @@ def main():
         dataset.data_source.data_infos = tmp_infos
     logger.info(f'Apply t-SNE to visualize {len(dataset)} samples.')
+    if 'imgs_per_gpu' in cfg.data:
+        logger.warning('"imgs_per_gpu" is deprecated. '
+                       'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiment')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{cfg.data.imgs_per_gpu} in this experiment')
+        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
     data_loader = build_dataloader(
         dataset,
-        imgs_per_gpu=dataset_cfg.data.imgs_per_gpu,
+        samples_per_gpu=dataset_cfg.data.samples_per_gpu,
         workers_per_gpu=dataset_cfg.data.workers_per_gpu,
         dist=distributed,
         shuffle=False)

View File

@@ -6,7 +6,7 @@ set -x
 CFG=$1  # use cfgs under "configs/benchmarks/classification/imagenet/*.py"
 PRETRAIN=$2  # pretrained model
 PY_ARGS=${@:3}
-GPUS=${GPUS:-8}  # When changing GPUS, please also change imgs_per_gpu in the config file accordingly to ensure the total batch size is 256.
+GPUS=${GPUS:-8}  # When changing GPUS, please also change samples_per_gpu in the config file accordingly to ensure the total batch size is 256.
 PORT=${PORT:-29500}

 # set work_dir according to config path and pretrained model to distinguish different models

View File

@@ -8,7 +8,7 @@ JOB_NAME=$2
 CFG=$3  # use cfgs under "configs/benchmarks/classification/imagenet/*.py"
 PRETRAIN=$4  # pretrained model
 PY_ARGS=${@:5}
-GPUS=${GPUS:-8}  # When changing GPUS, please also change imgs_per_gpu in the config file accordingly to ensure the total batch size is 256.
+GPUS=${GPUS:-8}  # When changing GPUS, please also change samples_per_gpu in the config file accordingly to ensure the total batch size is 256.
 GPUS_PER_NODE=${GPUS_PER_NODE:-8}
 CPUS_PER_TASK=${CPUS_PER_TASK:-5}
 PORT=${PORT:-29500}

View File

@@ -96,9 +96,23 @@ def main():
     # build the dataloader
     dataset_cfg = mmcv.Config.fromfile(args.dataset_config)
     dataset = build_dataset(dataset_cfg.data.extract)
+    if 'imgs_per_gpu' in cfg.data:
+        logger.warning('"imgs_per_gpu" is deprecated. '
+                       'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiment')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{cfg.data.imgs_per_gpu} in this experiment')
+        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
     data_loader = build_dataloader(
         dataset,
-        imgs_per_gpu=dataset_cfg.data.imgs_per_gpu,
+        samples_per_gpu=dataset_cfg.data.samples_per_gpu,
         workers_per_gpu=dataset_cfg.data.workers_per_gpu,
         dist=distributed,
         shuffle=False)

View File

@@ -96,9 +96,22 @@ def main():
     # build the dataloader
     dataset = build_dataset(cfg.data.val)
+    if 'imgs_per_gpu' in cfg.data:
+        logger.warning('"imgs_per_gpu" is deprecated. '
+                       'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiment')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{cfg.data.imgs_per_gpu} in this experiment')
+        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
     data_loader = build_dataloader(
         dataset,
-        imgs_per_gpu=cfg.data.imgs_per_gpu,
+        samples_per_gpu=cfg.data.samples_per_gpu,
         workers_per_gpu=cfg.data.workers_per_gpu,
         dist=distributed,
         shuffle=False)