[Feature] Support batch augmentation through BatchAugSampler (#1757)

* [Fix] RepeatAugSampler -> BatchAugSampler

* update docs
pull/1760/head
Tong Gao 2023-03-07 11:29:53 +08:00 committed by GitHub
parent 82f81ff67c
commit 47f7fc06ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 45 deletions

View File

@ -9,6 +9,18 @@ mmocr.datasets
:local:
:backlinks: top
.. currentmodule:: mmocr.datasets.samplers
Samplers
---------------------------------------------
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
BatchAugSampler
.. currentmodule:: mmocr.datasets
Datasets

View File

@ -9,6 +9,18 @@ mmocr.datasets
:local:
:backlinks: top
.. currentmodule:: mmocr.datasets.samplers
Samplers
---------------------------------------------
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
BatchAugSampler
.. currentmodule:: mmocr.datasets
Datasets

View File

@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .repeat_aug import RepeatAugSampler
from .batch_aug import BatchAugSampler
__all__ = ['RepeatAugSampler']
__all__ = ['BatchAugSampler']

View File

@ -2,20 +2,22 @@ import math
from typing import Iterator, Optional, Sized
import torch
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmocr.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class RepeatAugSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset for
distributed, with repeated augmentation. It ensures that different each
class BatchAugSampler(Sampler):
"""Sampler that repeats the same data elements for num_repeats times. The
batch size should be divisible by num_repeats.
It ensures that different each
augmented version of a sample will be visible to a different process (GPU).
Heavily based on torch.utils.data.DistributedSampler.
This sampler was taken from
This sampler was modified from
https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
Used in
Copyright (c) 2015-present, Facebook, Inc.
@ -40,11 +42,6 @@ class RepeatAugSampler(Sampler):
self.dataset = dataset
self.shuffle = shuffle
if not self.shuffle and is_main_process():
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.warning('The RepeatAugSampler always picks a '
'fixed part of data if `shuffle=False`.')
if seed is None:
seed = sync_random_seed()
@ -82,7 +79,7 @@ class RepeatAugSampler(Sampler):
assert len(indices) == self.num_samples
# return up to num selected samples
return iter(indices[:self.num_selected_samples])
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""

View File

@ -1,15 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from unittest import TestCase
from unittest.mock import patch
import torch
from mmengine.logging import MMLogger
from mmocr.datasets import RepeatAugSampler
from mmocr.datasets import BatchAugSampler
file = 'mmocr.datasets.samplers.repeat_aug.'
file = 'mmocr.datasets.samplers.batch_aug.'
class MockDist:
@ -28,7 +27,7 @@ class MockDist:
return self.dist_info[0] == 0
class TestRepeatAugSampler(TestCase):
class TestBatchAugSampler(TestCase):
def setUp(self):
self.data_length = 100
@ -36,63 +35,51 @@ class TestRepeatAugSampler(TestCase):
@patch(file + 'get_dist_info', return_value=(0, 1))
def test_non_dist(self, mock):
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
sampler = BatchAugSampler(self.dataset, num_repeats=3, shuffle=False)
self.assertEqual(sampler.world_size, 1)
self.assertEqual(sampler.rank, 0)
self.assertEqual(sampler.total_size, self.data_length * 3)
self.assertEqual(sampler.num_samples, self.data_length * 3)
self.assertEqual(sampler.num_selected_samples, self.data_length)
self.assertEqual(len(sampler), sampler.num_selected_samples)
indices = [x for x in range(self.data_length) for _ in range(3)]
self.assertEqual(list(sampler), indices[:self.data_length])
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
sampler = RepeatAugSampler(self.dataset, shuffle=False)
self.assertIn('always picks a fixed part', log.output[0])
self.assertEqual(list(sampler), indices)
@patch(file + 'get_dist_info', return_value=(2, 3))
@patch(file + 'is_main_process', return_value=False)
def test_dist(self, mock1, mock2):
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
def test_dist(self, mock):
sampler = BatchAugSampler(self.dataset, num_repeats=3, shuffle=False)
self.assertEqual(sampler.world_size, 3)
self.assertEqual(sampler.rank, 2)
self.assertEqual(sampler.num_samples, self.data_length)
self.assertEqual(sampler.total_size, self.data_length * 3)
self.assertEqual(sampler.num_selected_samples,
math.ceil(self.data_length / 3))
self.assertEqual(len(sampler), sampler.num_selected_samples)
indices = [x for x in range(self.data_length) for _ in range(3)]
self.assertEqual(
list(sampler), indices[2::3][:sampler.num_selected_samples])
logger = MMLogger.get_current_instance()
with patch.object(logger, 'warning') as mock_log:
sampler = RepeatAugSampler(self.dataset, shuffle=False)
sampler = BatchAugSampler(self.dataset, shuffle=False)
mock_log.assert_not_called()
@patch(file + 'get_dist_info', return_value=(0, 1))
@patch(file + 'sync_random_seed', return_value=7)
def test_shuffle(self, mock1, mock2):
# test seed=None
sampler = RepeatAugSampler(self.dataset, seed=None)
sampler = BatchAugSampler(self.dataset, seed=None)
self.assertEqual(sampler.seed, 7)
# test random seed
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0)
sampler = BatchAugSampler(self.dataset, shuffle=True, seed=0)
sampler.set_epoch(10)
g = torch.Generator()
g.manual_seed(10)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
indices = [x for x in indices
for _ in range(3)][:sampler.num_selected_samples]
indices = [
x for x in torch.randperm(len(self.dataset), generator=g)
for _ in range(3)
]
self.assertEqual(list(sampler), indices)
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42)
sampler = BatchAugSampler(self.dataset, shuffle=True, seed=42)
sampler.set_epoch(10)
g = torch.Generator()
g.manual_seed(42 + 10)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
indices = [x for x in indices
for _ in range(3)][:sampler.num_selected_samples]
indices = [
x for x in torch.randperm(len(self.dataset), generator=g)
for _ in range(3)
]
self.assertEqual(list(sampler), indices)