[Refactor] Deprecate imgs_per_gpu and use samples_per_gpu (#204)

* [Refactor] change imgs_per_gpu to samples_per_gpu in config files

* [Docs] change imgs_per_gpu to samples_per_gpu in docs

* [Refactor] change imgs_per_gpu to samples_per_gpu in code and add warnings

* [Fix] fix isort

* [Docs] fix docs format

* [Refactor] add related unit tests

* [Fix] fix isort
Yixiao Fang 2022-02-09 17:45:41 +08:00 committed by GitHub
parent 16b3f7b61e
commit af331b043f
41 changed files with 168 additions and 60 deletions
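The config-side change is a mechanical key rename in every config file; a minimal before/after sketch (values hypothetical):

```python
# Before: deprecated key, still parsed but now emits a warning at runtime.
data = dict(imgs_per_gpu=32, workers_per_gpu=4)

# After: the new key used throughout the hunks below.
data = dict(samples_per_gpu=32, workers_per_gpu=4)
```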

@@ -20,7 +20,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=128,
+    samples_per_gpu=128,
     workers_per_gpu=2,
     train=dict(
         type=dataset_type,

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32x8=256, 8GPU linear cls
+    samples_per_gpu=32,  # total 32x8=256, 8GPU linear cls
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32x8=256, 8GPU linear cls
+    samples_per_gpu=32,  # total 32x8=256, 8GPU linear cls
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

@@ -25,7 +25,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32x8=256, 8GPU linear cls
+    samples_per_gpu=32,  # total 32x8=256, 8GPU linear cls
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

@@ -10,7 +10,7 @@ model = dict(backbone=dict(norm_cfg=dict(type='SyncBN')))
 # dataset settings
 data = dict(
-    imgs_per_gpu=64,  # total 64x4=256
+    samples_per_gpu=64,  # total 64x4=256
     train=dict(
         data_source=dict(ann_file='data/imagenet/meta/train_10pct.txt')))

@@ -10,7 +10,7 @@ model = dict(backbone=dict(norm_cfg=dict(type='SyncBN')))
 # dataset settings
 data = dict(
-    imgs_per_gpu=64,  # total 64x4=256
+    samples_per_gpu=64,  # total 64x4=256
     train=dict(
         data_source=dict(ann_file='data/imagenet/meta/train_1percent.txt')))

@@ -8,7 +8,7 @@ _base_ = [
 model = dict(backbone=dict(frozen_stages=4))
 # dataset summary
-data = dict(imgs_per_gpu=512)  # total 512*8=4096, 8GPU linear cls
+data = dict(samples_per_gpu=512)  # total 512*8=4096, 8GPU linear cls
 # simsiam setting
 # runtime settings

@@ -9,7 +9,7 @@ _base_ = [
 model = dict(backbone=dict(frozen_stages=12, norm_eval=True))
 # dataset summary
-data = dict(imgs_per_gpu=128)  # total 128*8=1024, 8 GPU linear cls
+data = dict(samples_per_gpu=128)  # total 128*8=1024, 8 GPU linear cls
 # optimizer
 optimizer = dict(type='SGD', lr=12, momentum=0.9, weight_decay=0.)

@@ -5,7 +5,7 @@ split_name = ['voc07_trainval', 'voc07_test']
 img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 data = dict(
-    imgs_per_gpu=32,
+    samples_per_gpu=32,
     workers_per_gpu=4,
     extract=dict(
         type=dataset_type,

@@ -4,7 +4,7 @@ name = 'imagenet_val'
 img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 data = dict(
-    imgs_per_gpu=8,
+    samples_per_gpu=8,
     workers_per_gpu=4,
     extract=dict(
         type='SingleViewDataset',

@@ -51,7 +51,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8(gpu)=256
+    samples_per_gpu=32,  # total 32*8(gpu)=256
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

@@ -31,7 +31,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=64,  # 64*8
+    samples_per_gpu=64,  # 64*8
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,
@@ -49,7 +49,7 @@ custom_hooks = [
     dict(
         type='DeepClusterHook',
         extractor=dict(
-            imgs_per_gpu=128,
+            samples_per_gpu=128,
            workers_per_gpu=8,
            dataset=dict(
                type=dataset_type,

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8=256
+    samples_per_gpu=32,  # total 32*8=256
     workers_per_gpu=4,
     drop_last=True,
     train=dict(

@@ -30,7 +30,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8=256
+    samples_per_gpu=32,  # total 32*8=256
     workers_per_gpu=4,
     drop_last=True,
     train=dict(

@@ -51,7 +51,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=256,  # 256*16(gpu)=4096
+    samples_per_gpu=256,  # 256*16(gpu)=4096
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8
+    samples_per_gpu=32,  # total 32*8
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

@@ -31,7 +31,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=64,  # 64*8
+    samples_per_gpu=64,  # 64*8
     sampling_replace=True,
     workers_per_gpu=4,
     train=dict(
@@ -50,7 +50,7 @@ custom_hooks = [
     dict(
         type='DeepClusterHook',
         extractor=dict(
-            imgs_per_gpu=128,
+            samples_per_gpu=128,
            workers_per_gpu=8,
            dataset=dict(
                type=dataset_type,

@@ -21,7 +21,7 @@ prefetch = False
 # dataset summary
 data = dict(
-    imgs_per_gpu=64,  # 64 x 8 = 512
+    samples_per_gpu=64,  # 64 x 8 = 512
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

@@ -23,7 +23,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=16,  # (16*4) x 8 = 512
+    samples_per_gpu=16,  # (16*4) x 8 = 512
     workers_per_gpu=2,
     train=dict(
         type=dataset_type,

@@ -29,7 +29,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8
+    samples_per_gpu=32,  # total 32*8
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,

@@ -51,7 +51,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # total 32*8=256
+    samples_per_gpu=32,  # total 32*8=256
     workers_per_gpu=4,
     drop_last=True,
     train=dict(

@@ -6,7 +6,7 @@ _base_ = [
 ]
 # dataset summary
-data = dict(imgs_per_gpu=256)
+data = dict(samples_per_gpu=256)
 # additional hooks
 # interval for accumulate gradient, total 8*256*2(interval)=4096

@@ -58,7 +58,8 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=128, train=dict(pipelines=[train_pipeline1, train_pipeline2]))
+    samples_per_gpu=128,
+    train=dict(pipelines=[train_pipeline1, train_pipeline2]))
 # MoCo v3 use the same momentum update method as BYOL
 custom_hooks = [dict(type='MomentumUpdateHook')]

@@ -1,4 +1,4 @@
 _base_ = 'simclr_resnet50_8xb32-coslr-200e_in1k.py'
 # dataset summary
-data = dict(imgs_per_gpu=64)  # total 64*8
+data = dict(samples_per_gpu=64)  # total 64*8

@@ -13,7 +13,7 @@ custom_hooks = [
     dict(
         type='SwAVHook',
         priority='VERY_HIGH',
-        batch_size={{_base_.data.imgs_per_gpu}},
+        batch_size={{_base_.data.samples_per_gpu}},
         epoch_queue_starts=15,
         crops_for_assign=[0, 1],
         feat_dim=128,

@@ -207,7 +207,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # Batch size of a single GPU, total 32*8=256
+    samples_per_gpu=32,  # Batch size of a single GPU, total 32*8=256
     workers_per_gpu=4,  # Worker to pre-fetch data for each single GPU
     drop_last=True,  # Whether to drop the last batch of data
     train=dict(
@@ -304,7 +304,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,
+    samples_per_gpu=32,
     workers_per_gpu=4,
     drop_last=True,
     train=dict(type=dataset_type, type=data_source, data_prefix=...),

@@ -154,14 +154,14 @@ When there is not enough computation resource, the batch size can only be set to
 Here is an example:
 ```python
-data = dict(imgs_per_gpu=64)
+data = dict(samples_per_gpu=64)
 optimizer_config = dict(type="DistOptimizerHook", update_interval=4)
 ```
 Indicates that during training, back-propagation is performed every 4 iters. And the above is equivalent to:
 ```python
-data = dict(imgs_per_gpu=256)
+data = dict(samples_per_gpu=256)
 optimizer_config = dict(type="OptimizerHook")
 ```
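A side note on the equivalence claimed in the doc hunk above: a minimal sketch of what `update_interval=4` amounts to, assuming the hook simply accumulates gradients over 4 iterations before stepping (hypothetical training loop, not the actual DistOptimizerHook code):

```python
import torch

def train_epoch(model, loader, optimizer, update_interval=4):
    # With samples_per_gpu=64, stepping every 4 iterations gives an
    # effective batch size of 64 * 4 = 256 per GPU.
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(loader):
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        (loss / update_interval).backward()  # scale so accumulated grads average
        if (i + 1) % update_interval == 0:
            optimizer.step()
            optimizer.zero_grad()
```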

@@ -70,7 +70,7 @@ bash tools/benchmarks/classification/slurm_train_linear.sh ${PARTITION} ${JOB_NAME}
 Remarks:
-- The default GPU number is 8. When changing GPUS, please also change imgs_per_gpu in the config file accordingly to ensure the total batch size is 256.
+- The default GPU number is 8. When changing GPUS, please also change `samples_per_gpu` in the config file accordingly to ensure the total batch size is 256.
 - `CONFIG`: Use config files under `configs/benchmarks/classification/`, excluding svm_voc07.py and tsne_imagenet.py and imagenet_*percent folders.
 - `PRETRAIN`: the pretrained model file.

@@ -207,7 +207,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,  # Batch size of a single GPU, total 32*8=256
+    samples_per_gpu=32,  # Batch size of a single GPU, total 32*8=256
     workers_per_gpu=4,  # Worker to pre-fetch data for each single GPU
     drop_last=True,  # Whether to drop the last batch of data
     train=dict(
@@ -304,7 +304,7 @@ if not prefetch:
 # dataset summary
 data = dict(
-    imgs_per_gpu=32,
+    samples_per_gpu=32,
     workers_per_gpu=4,
     drop_last=True,
     train=dict(type=dataset_type, type=data_source, data_prefix=...),

@@ -154,14 +154,14 @@ optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
 Here is an example:
 ```python
-data = dict(imgs_per_gpu=64)
+data = dict(samples_per_gpu=64)
 optimizer_config = dict(type="DistOptimizerHook", update_interval=4)
 ```
 This means that back-propagation is performed once every 4 iters during training. Since the per-GPU batch size is 64, this is equivalent to a per-GPU batch size of 256 in a single iteration, i.e.:
 ```python
-data = dict(imgs_per_gpu=256)
+data = dict(samples_per_gpu=256)
 optimizer_config = dict(type="OptimizerHook")
 ```

@@ -70,7 +70,7 @@ bash tools/benchmarks/classification/slurm_train_linear.sh ${PARTITION} ${JOB_NAME}
 Remarks:
-- The default GPU number is 8. When changing GPUS, please also change imgs_per_gpu in the config file accordingly to ensure the total batch size is 256.
+- The default GPU number is 8. When changing GPUS, please also change `samples_per_gpu` in the config file accordingly to ensure the total batch size is 256.
 - `CONFIG`: Use config files under `configs/benchmarks/classification/`, excluding svm_voc07.py and tsne_imagenet.py and imagenet_*percent folders.
 - `PRETRAIN`: the pretrained model file.

@@ -75,13 +75,26 @@ def train_model(model,
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+    if 'imgs_per_gpu' in cfg.data:
+        logger.warning('"imgs_per_gpu" is deprecated. '
+                       'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiment')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{cfg.data.imgs_per_gpu} in this experiment')
+        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
     data_loaders = [
         build_dataloader(
             ds,
-            cfg.data.imgs_per_gpu,
-            cfg.data.workers_per_gpu,
-            # cfg.gpus will be ignored if distributed
+            samples_per_gpu=cfg.data.samples_per_gpu,
+            workers_per_gpu=cfg.data.workers_per_gpu,
+            # `num_gpus` will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            replace=getattr(cfg.data, 'sampling_replace', False),
@@ -161,7 +174,7 @@ def train_model(model,
         val_dataset = build_dataset(cfg.data.val)
         val_dataloader = build_dataloader(
             val_dataset,
-            imgs_per_gpu=cfg.data.imgs_per_gpu,
+            samples_per_gpu=cfg.data.samples_per_gpu,
             workers_per_gpu=cfg.data.workers_per_gpu,
             dist=distributed,
             shuffle=False,

@@ -7,6 +7,7 @@ from mmcv.utils import print_log
 from mmselfsup.utils import Extractor
 from mmselfsup.utils import clustering as _clustering
+from mmselfsup.utils import get_root_logger


 @HOOKS.register_module()
@@ -41,6 +42,23 @@ class DeepClusterHook(Hook):
                  interval=1,
                  dist_mode=True,
                  data_loaders=None):
+        logger = get_root_logger()
+        if 'imgs_per_gpu' in extractor:
+            logger.warning('"imgs_per_gpu" is deprecated. '
+                           'Please use "samples_per_gpu" instead')
+            if 'samples_per_gpu' in extractor:
+                logger.warning(
+                    f'Got "imgs_per_gpu"={extractor["imgs_per_gpu"]} and '
+                    f'"samples_per_gpu"={extractor["samples_per_gpu"]}, '
+                    f'"imgs_per_gpu"={extractor["imgs_per_gpu"]} is used in '
+                    'this experiment')
+            else:
+                logger.warning(
+                    'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                    f'{extractor["imgs_per_gpu"]} in this experiment')
+            extractor['samples_per_gpu'] = extractor['imgs_per_gpu']
+
         self.extractor = Extractor(dist_mode=dist_mode, **extractor)
         self.clustering_type = clustering.pop('type')
         self.clustering_cfg = clustering

@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import platform
 import random
+import warnings
 from functools import partial

 import numpy as np
@@ -45,8 +46,9 @@ def build_dataset(cfg, default_args=None):

 def build_dataloader(dataset,
-                     imgs_per_gpu,
-                     workers_per_gpu,
+                     imgs_per_gpu=None,
+                     samples_per_gpu=None,
+                     workers_per_gpu=1,
                      num_gpus=1,
                      dist=True,
                      shuffle=True,
@@ -62,10 +64,13 @@ def build_dataloader(dataset,
     Args:
         dataset (Dataset): A PyTorch dataset.
-        imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of
-            each GPU.
+        imgs_per_gpu (int): (Deprecated, please use samples_per_gpu) Number of
+            images on each GPU, i.e., batch size of each GPU. Defaults to None.
+        samples_per_gpu (int): Number of images on each GPU, i.e., batch size
+            of each GPU. Defaults to None.
         workers_per_gpu (int): How many subprocesses to use for data loading
-            for each GPU.
+            for each GPU. `persistent_workers` option needs num_workers > 0.
+            Defaults to 1.
         num_gpus (int): Number of GPUs. Only used in non-distributed training.
         dist (bool): Distributed training/test or not. Defaults to True.
         shuffle (bool): Whether to shuffle the data at every epoch.
@@ -85,18 +90,33 @@ def build_dataloader(dataset,
     Returns:
         DataLoader: A PyTorch dataloader.
     """
+    if imgs_per_gpu is None and samples_per_gpu is None:
+        raise ValueError(
+            'Please indicate the number of images on each GPU: '
+            '"imgs_per_gpu" and "samples_per_gpu" cannot be "None" at the '
+            'same time. "imgs_per_gpu" is deprecated, please use '
+            '"samples_per_gpu".')
+    if imgs_per_gpu is not None:
+        warnings.warn(f'Got "imgs_per_gpu"={imgs_per_gpu} and '
+                      f'"samples_per_gpu"={samples_per_gpu}, "imgs_per_gpu"'
+                      f'={imgs_per_gpu} is used in this experiment. '
+                      'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                      f'{imgs_per_gpu}')
+        samples_per_gpu = imgs_per_gpu
+
     rank, world_size = get_dist_info()
     if dist:
         sampler = DistributedSampler(
             dataset, world_size, rank, shuffle=shuffle, replace=replace)
         shuffle = False
-        batch_size = imgs_per_gpu
+        batch_size = samples_per_gpu
         num_workers = workers_per_gpu
     else:
         if replace:
             return NotImplemented
         sampler = None  # TODO: set replace
-        batch_size = num_gpus * imgs_per_gpu
+        batch_size = num_gpus * samples_per_gpu
         num_workers = num_gpus * workers_per_gpu

     init_fn = partial(
@@ -117,7 +137,7 @@ def build_dataloader(dataset,
         batch_size=batch_size,
         sampler=sampler,
         num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
+        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
         pin_memory=pin_memory,
         shuffle=shuffle,
         worker_init_fn=init_fn,
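A usage sketch against the new signature (the deprecated keyword still works but goes through the warning branch above; `train_cfg` stands in for a dataset config like the ones in the config hunks, hypothetical here):

```python
from mmselfsup.datasets import build_dataloader, build_dataset

dataset = build_dataset(train_cfg)  # train_cfg is a hypothetical dataset dict

# Preferred spelling after this change:
loader = build_dataloader(
    dataset, samples_per_gpu=32, workers_per_gpu=4, dist=False)

# Deprecated spelling: still accepted, warns, and is copied into
# samples_per_gpu internally.
loader = build_dataloader(
    dataset, imgs_per_gpu=32, workers_per_gpu=4, dist=False)
```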

@@ -11,8 +11,8 @@ class Extractor(object):
     Args:
         dataset (Dataset | dict): A PyTorch dataset or dict that indicates
             the dataset.
-        imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of
-            each GPU.
+        samples_per_gpu (int): Number of images on each GPU, i.e., batch size
+            of each GPU.
         workers_per_gpu (int): How many subprocesses to use for data loading
             for each GPU.
         dist_mode (bool): Use distributed extraction or not. Defaults to False.
@@ -25,7 +25,7 @@ class Extractor(object):
     def __init__(self,
                  dataset,
-                 imgs_per_gpu,
+                 samples_per_gpu,
                  workers_per_gpu,
                  dist_mode=False,
                  persistent_workers=True,
@@ -40,8 +40,8 @@ class Extractor(object):
                 f'not {type(dataset)}')
         self.data_loader = datasets.build_dataloader(
             self.dataset,
-            imgs_per_gpu,
-            workers_per_gpu,
+            samples_per_gpu=samples_per_gpu,
+            workers_per_gpu=workers_per_gpu,
             dist=dist_mode,
             shuffle=False,
             persistent_workers=persistent_workers,

@@ -3,7 +3,7 @@ from unittest.mock import ANY
 import pytest

 from mmselfsup.datasets import (ConcatDataset, DeepClusterDataset,
-                                RepeatDataset, build_dataset)
+                                RepeatDataset, build_dataloader, build_dataset)

 DATASET_CONFIG = dict(
     type='DeepClusterDataset',
@@ -50,3 +50,19 @@ DATASET_CONFIG = dict(
 ])
 def test_build_dataset(cfg, expected_type):
     assert isinstance(build_dataset(cfg), expected_type)
+
+
+def test_build_dataloader():
+    dataset = build_dataset(DATASET_CONFIG)
+    with pytest.raises(ValueError):
+        data_loader = build_dataloader(dataset)
+    data_loader = build_dataloader(
+        dataset,
+        imgs_per_gpu=1,
+        samples_per_gpu=None,
+        dist=False,
+    )
+    assert len(data_loader) == 2
+    assert data_loader.batch_size == 1

@@ -156,9 +156,22 @@ def main():
         dataset.data_source.data_infos = tmp_infos
     logger.info(f'Apply t-SNE to visualize {len(dataset)} samples.')
+    # Check the config that actually feeds the dataloader below (dataset_cfg).
+    if 'imgs_per_gpu' in dataset_cfg.data:
+        logger.warning('"imgs_per_gpu" is deprecated. '
+                       'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in dataset_cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={dataset_cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={dataset_cfg.data.samples_per_gpu}, '
+                f'"imgs_per_gpu"={dataset_cfg.data.imgs_per_gpu} is used in '
+                'this experiment')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{dataset_cfg.data.imgs_per_gpu} in this experiment')
+        dataset_cfg.data.samples_per_gpu = dataset_cfg.data.imgs_per_gpu
+
     data_loader = build_dataloader(
         dataset,
-        imgs_per_gpu=dataset_cfg.data.imgs_per_gpu,
+        samples_per_gpu=dataset_cfg.data.samples_per_gpu,
         workers_per_gpu=dataset_cfg.data.workers_per_gpu,
         dist=distributed,
         shuffle=False)

@@ -6,7 +6,7 @@ set -x
 CFG=$1 # use cfgs under "configs/benchmarks/classification/imagenet/*.py"
 PRETRAIN=$2 # pretrained model
 PY_ARGS=${@:3}
-GPUS=${GPUS:-8} # When changing GPUS, please also change imgs_per_gpu in the config file accordingly to ensure the total batch size is 256.
+GPUS=${GPUS:-8} # When changing GPUS, please also change samples_per_gpu in the config file accordingly to ensure the total batch size is 256.
 PORT=${PORT:-29500}

 # set work_dir according to config path and pretrained model to distinguish different models

@@ -8,7 +8,7 @@ JOB_NAME=$2
 CFG=$3 # use cfgs under "configs/benchmarks/classification/imagenet/*.py"
 PRETRAIN=$4 # pretrained model
 PY_ARGS=${@:5}
-GPUS=${GPUS:-8} # When changing GPUS, please also change imgs_per_gpu in the config file accordingly to ensure the total batch size is 256.
+GPUS=${GPUS:-8} # When changing GPUS, please also change samples_per_gpu in the config file accordingly to ensure the total batch size is 256.
 GPUS_PER_NODE=${GPUS_PER_NODE:-8}
 CPUS_PER_TASK=${CPUS_PER_TASK:-5}
 PORT=${PORT:-29500}

@@ -96,9 +96,23 @@ def main():
     # build the dataloader
     dataset_cfg = mmcv.Config.fromfile(args.dataset_config)
     dataset = build_dataset(dataset_cfg.data.extract)
+    # Check the config that actually feeds the dataloader below (dataset_cfg).
+    if 'imgs_per_gpu' in dataset_cfg.data:
+        logger.warning('"imgs_per_gpu" is deprecated. '
+                       'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in dataset_cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={dataset_cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={dataset_cfg.data.samples_per_gpu}, '
+                f'"imgs_per_gpu"={dataset_cfg.data.imgs_per_gpu} is used in '
+                'this experiment')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{dataset_cfg.data.imgs_per_gpu} in this experiment')
+        dataset_cfg.data.samples_per_gpu = dataset_cfg.data.imgs_per_gpu
+
     data_loader = build_dataloader(
         dataset,
-        imgs_per_gpu=dataset_cfg.data.imgs_per_gpu,
+        samples_per_gpu=dataset_cfg.data.samples_per_gpu,
         workers_per_gpu=dataset_cfg.data.workers_per_gpu,
         dist=distributed,
         shuffle=False)

@@ -96,9 +96,22 @@ def main():
     # build the dataloader
     dataset = build_dataset(cfg.data.val)
+    if 'imgs_per_gpu' in cfg.data:
+        logger.warning('"imgs_per_gpu" is deprecated. '
+                       'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiment')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{cfg.data.imgs_per_gpu} in this experiment')
+        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+
     data_loader = build_dataloader(
         dataset,
-        imgs_per_gpu=cfg.data.imgs_per_gpu,
+        samples_per_gpu=cfg.data.samples_per_gpu,
         workers_per_gpu=cfg.data.workers_per_gpu,
         dist=distributed,
         shuffle=False)