[Feature] Support batch augmentation through BatchAugSampler (#1757)

* [Fix] RepeatAugSampler -> BatchAugSampler

* update docs
This commit is contained in:
Tong Gao 2023-03-07 11:29:53 +08:00 committed by GitHub
parent 82f81ff67c
commit 47f7fc06ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 45 deletions

View File

@ -9,6 +9,18 @@ mmocr.datasets
:local: :local:
:backlinks: top :backlinks: top
.. currentmodule:: mmocr.datasets.samplers
Samplers
---------------------------------------------
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
BatchAugSampler
.. currentmodule:: mmocr.datasets .. currentmodule:: mmocr.datasets
Datasets Datasets

View File

@ -9,6 +9,18 @@ mmocr.datasets
:local: :local:
:backlinks: top :backlinks: top
.. currentmodule:: mmocr.datasets.samplers
Samplers
---------------------------------------------
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
BatchAugSampler
.. currentmodule:: mmocr.datasets .. currentmodule:: mmocr.datasets
Datasets Datasets

View File

@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .repeat_aug import RepeatAugSampler from .batch_aug import BatchAugSampler
__all__ = ['RepeatAugSampler'] __all__ = ['BatchAugSampler']

View File

@ -2,20 +2,22 @@ import math
from typing import Iterator, Optional, Sized from typing import Iterator, Optional, Sized
import torch import torch
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler from torch.utils.data import Sampler
from mmocr.registry import DATA_SAMPLERS from mmocr.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module() @DATA_SAMPLERS.register_module()
class RepeatAugSampler(Sampler): class BatchAugSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset for """Sampler that repeats each data element num_repeats times. The
distributed, with repeated augmentation. It ensures that different each batch size should be divisible by num_repeats.
It ensures that each
augmented version of a sample will be visible to a different process (GPU). augmented version of a sample will be visible to a different process (GPU).
Heavily based on torch.utils.data.DistributedSampler. Heavily based on torch.utils.data.DistributedSampler.
This sampler was taken from This sampler was modified from
https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
Used in Used in
Copyright (c) 2015-present, Facebook, Inc. Copyright (c) 2015-present, Facebook, Inc.
@ -40,11 +42,6 @@ class RepeatAugSampler(Sampler):
self.dataset = dataset self.dataset = dataset
self.shuffle = shuffle self.shuffle = shuffle
if not self.shuffle and is_main_process():
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.warning('The RepeatAugSampler always picks a '
'fixed part of data if `shuffle=False`.')
if seed is None: if seed is None:
seed = sync_random_seed() seed = sync_random_seed()
@ -82,7 +79,7 @@ class RepeatAugSampler(Sampler):
assert len(indices) == self.num_samples assert len(indices) == self.num_samples
# return up to num selected samples # return up to num selected samples
return iter(indices[:self.num_selected_samples]) return iter(indices)
def __len__(self) -> int: def __len__(self) -> int:
"""The number of samples in this rank.""" """The number of samples in this rank."""

View File

@ -1,15 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import math
from unittest import TestCase from unittest import TestCase
from unittest.mock import patch from unittest.mock import patch
import torch import torch
from mmengine.logging import MMLogger from mmengine.logging import MMLogger
from mmocr.datasets import RepeatAugSampler from mmocr.datasets import BatchAugSampler
file = 'mmocr.datasets.samplers.repeat_aug.' file = 'mmocr.datasets.samplers.batch_aug.'
class MockDist: class MockDist:
@ -28,7 +27,7 @@ class MockDist:
return self.dist_info[0] == 0 return self.dist_info[0] == 0
class TestRepeatAugSampler(TestCase): class TestBatchAugSampler(TestCase):
def setUp(self): def setUp(self):
self.data_length = 100 self.data_length = 100
@ -36,63 +35,51 @@ class TestRepeatAugSampler(TestCase):
@patch(file + 'get_dist_info', return_value=(0, 1)) @patch(file + 'get_dist_info', return_value=(0, 1))
def test_non_dist(self, mock): def test_non_dist(self, mock):
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False) sampler = BatchAugSampler(self.dataset, num_repeats=3, shuffle=False)
self.assertEqual(sampler.world_size, 1) self.assertEqual(sampler.world_size, 1)
self.assertEqual(sampler.rank, 0) self.assertEqual(sampler.rank, 0)
self.assertEqual(sampler.total_size, self.data_length * 3) self.assertEqual(sampler.total_size, self.data_length * 3)
self.assertEqual(sampler.num_samples, self.data_length * 3) self.assertEqual(sampler.num_samples, self.data_length * 3)
self.assertEqual(sampler.num_selected_samples, self.data_length)
self.assertEqual(len(sampler), sampler.num_selected_samples)
indices = [x for x in range(self.data_length) for _ in range(3)] indices = [x for x in range(self.data_length) for _ in range(3)]
self.assertEqual(list(sampler), indices[:self.data_length]) self.assertEqual(list(sampler), indices)
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
sampler = RepeatAugSampler(self.dataset, shuffle=False)
self.assertIn('always picks a fixed part', log.output[0])
@patch(file + 'get_dist_info', return_value=(2, 3)) @patch(file + 'get_dist_info', return_value=(2, 3))
@patch(file + 'is_main_process', return_value=False) def test_dist(self, mock):
def test_dist(self, mock1, mock2): sampler = BatchAugSampler(self.dataset, num_repeats=3, shuffle=False)
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
self.assertEqual(sampler.world_size, 3) self.assertEqual(sampler.world_size, 3)
self.assertEqual(sampler.rank, 2) self.assertEqual(sampler.rank, 2)
self.assertEqual(sampler.num_samples, self.data_length) self.assertEqual(sampler.num_samples, self.data_length)
self.assertEqual(sampler.total_size, self.data_length * 3) self.assertEqual(sampler.total_size, self.data_length * 3)
self.assertEqual(sampler.num_selected_samples,
math.ceil(self.data_length / 3))
self.assertEqual(len(sampler), sampler.num_selected_samples)
indices = [x for x in range(self.data_length) for _ in range(3)]
self.assertEqual(
list(sampler), indices[2::3][:sampler.num_selected_samples])
logger = MMLogger.get_current_instance() logger = MMLogger.get_current_instance()
with patch.object(logger, 'warning') as mock_log: with patch.object(logger, 'warning') as mock_log:
sampler = RepeatAugSampler(self.dataset, shuffle=False) sampler = BatchAugSampler(self.dataset, shuffle=False)
mock_log.assert_not_called() mock_log.assert_not_called()
@patch(file + 'get_dist_info', return_value=(0, 1)) @patch(file + 'get_dist_info', return_value=(0, 1))
@patch(file + 'sync_random_seed', return_value=7) @patch(file + 'sync_random_seed', return_value=7)
def test_shuffle(self, mock1, mock2): def test_shuffle(self, mock1, mock2):
# test seed=None # test seed=None
sampler = RepeatAugSampler(self.dataset, seed=None) sampler = BatchAugSampler(self.dataset, seed=None)
self.assertEqual(sampler.seed, 7) self.assertEqual(sampler.seed, 7)
# test random seed # test random seed
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0) sampler = BatchAugSampler(self.dataset, shuffle=True, seed=0)
sampler.set_epoch(10) sampler.set_epoch(10)
g = torch.Generator() g = torch.Generator()
g.manual_seed(10) g.manual_seed(10)
indices = torch.randperm(len(self.dataset), generator=g).tolist() indices = [
indices = [x for x in indices x for x in torch.randperm(len(self.dataset), generator=g)
for _ in range(3)][:sampler.num_selected_samples] for _ in range(3)
]
self.assertEqual(list(sampler), indices) self.assertEqual(list(sampler), indices)
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42) sampler = BatchAugSampler(self.dataset, shuffle=True, seed=42)
sampler.set_epoch(10) sampler.set_epoch(10)
g = torch.Generator() g = torch.Generator()
g.manual_seed(42 + 10) g.manual_seed(42 + 10)
indices = torch.randperm(len(self.dataset), generator=g).tolist() indices = [
indices = [x for x in indices x for x in torch.randperm(len(self.dataset), generator=g)
for _ in range(3)][:sampler.num_selected_samples] for _ in range(3)
]
self.assertEqual(list(sampler), indices) self.assertEqual(list(sampler), indices)