mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Support batch augmentation through BatchAugSampler (#1757)
* [Fix] RepeatAugSampler -> BatchAugSampler * update docspull/1760/head
parent
82f81ff67c
commit
47f7fc06ed
|
@ -9,6 +9,18 @@ mmocr.datasets
|
|||
:local:
|
||||
:backlinks: top
|
||||
|
||||
.. currentmodule:: mmocr.datasets.samplers
|
||||
|
||||
Samplers
|
||||
---------------------------------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
BatchAugSampler
|
||||
|
||||
.. currentmodule:: mmocr.datasets
|
||||
|
||||
Datasets
|
||||
|
|
|
@ -9,6 +9,18 @@ mmocr.datasets
|
|||
:local:
|
||||
:backlinks: top
|
||||
|
||||
.. currentmodule:: mmocr.datasets.samplers
|
||||
|
||||
Samplers
|
||||
---------------------------------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
BatchAugSampler
|
||||
|
||||
.. currentmodule:: mmocr.datasets
|
||||
|
||||
Datasets
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .repeat_aug import RepeatAugSampler
|
||||
from .batch_aug import BatchAugSampler
|
||||
|
||||
__all__ = ['RepeatAugSampler']
|
||||
__all__ = ['BatchAugSampler']
|
||||
|
|
|
@ -2,20 +2,22 @@ import math
|
|||
from typing import Iterator, Optional, Sized
|
||||
|
||||
import torch
|
||||
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
|
||||
from mmengine.dist import get_dist_info, sync_random_seed
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from mmocr.registry import DATA_SAMPLERS
|
||||
|
||||
|
||||
@DATA_SAMPLERS.register_module()
|
||||
class RepeatAugSampler(Sampler):
|
||||
"""Sampler that restricts data loading to a subset of the dataset for
|
||||
distributed, with repeated augmentation. It ensures that different each
|
||||
class BatchAugSampler(Sampler):
|
||||
"""Sampler that repeats the same data elements for num_repeats times. The
|
||||
batch size should be divisible by num_repeats.
|
||||
|
||||
It ensures that different each
|
||||
augmented version of a sample will be visible to a different process (GPU).
|
||||
Heavily based on torch.utils.data.DistributedSampler.
|
||||
|
||||
This sampler was taken from
|
||||
This sampler was modified from
|
||||
https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
|
||||
Used in
|
||||
Copyright (c) 2015-present, Facebook, Inc.
|
||||
|
@ -40,11 +42,6 @@ class RepeatAugSampler(Sampler):
|
|||
|
||||
self.dataset = dataset
|
||||
self.shuffle = shuffle
|
||||
if not self.shuffle and is_main_process():
|
||||
from mmengine.logging import MMLogger
|
||||
logger = MMLogger.get_current_instance()
|
||||
logger.warning('The RepeatAugSampler always picks a '
|
||||
'fixed part of data if `shuffle=False`.')
|
||||
|
||||
if seed is None:
|
||||
seed = sync_random_seed()
|
||||
|
@ -82,7 +79,7 @@ class RepeatAugSampler(Sampler):
|
|||
assert len(indices) == self.num_samples
|
||||
|
||||
# return up to num selected samples
|
||||
return iter(indices[:self.num_selected_samples])
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of samples in this rank."""
|
|
@ -1,15 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import math
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from mmengine.logging import MMLogger
|
||||
|
||||
from mmocr.datasets import RepeatAugSampler
|
||||
from mmocr.datasets import BatchAugSampler
|
||||
|
||||
file = 'mmocr.datasets.samplers.repeat_aug.'
|
||||
file = 'mmocr.datasets.samplers.batch_aug.'
|
||||
|
||||
|
||||
class MockDist:
|
||||
|
@ -28,7 +27,7 @@ class MockDist:
|
|||
return self.dist_info[0] == 0
|
||||
|
||||
|
||||
class TestRepeatAugSampler(TestCase):
|
||||
class TestBatchAugSampler(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.data_length = 100
|
||||
|
@ -36,63 +35,51 @@ class TestRepeatAugSampler(TestCase):
|
|||
|
||||
@patch(file + 'get_dist_info', return_value=(0, 1))
|
||||
def test_non_dist(self, mock):
|
||||
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
||||
sampler = BatchAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
||||
self.assertEqual(sampler.world_size, 1)
|
||||
self.assertEqual(sampler.rank, 0)
|
||||
self.assertEqual(sampler.total_size, self.data_length * 3)
|
||||
self.assertEqual(sampler.num_samples, self.data_length * 3)
|
||||
self.assertEqual(sampler.num_selected_samples, self.data_length)
|
||||
self.assertEqual(len(sampler), sampler.num_selected_samples)
|
||||
indices = [x for x in range(self.data_length) for _ in range(3)]
|
||||
self.assertEqual(list(sampler), indices[:self.data_length])
|
||||
|
||||
logger = MMLogger.get_current_instance()
|
||||
with self.assertLogs(logger, 'WARN') as log:
|
||||
sampler = RepeatAugSampler(self.dataset, shuffle=False)
|
||||
self.assertIn('always picks a fixed part', log.output[0])
|
||||
self.assertEqual(list(sampler), indices)
|
||||
|
||||
@patch(file + 'get_dist_info', return_value=(2, 3))
|
||||
@patch(file + 'is_main_process', return_value=False)
|
||||
def test_dist(self, mock1, mock2):
|
||||
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
||||
def test_dist(self, mock):
|
||||
sampler = BatchAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
||||
self.assertEqual(sampler.world_size, 3)
|
||||
self.assertEqual(sampler.rank, 2)
|
||||
self.assertEqual(sampler.num_samples, self.data_length)
|
||||
self.assertEqual(sampler.total_size, self.data_length * 3)
|
||||
self.assertEqual(sampler.num_selected_samples,
|
||||
math.ceil(self.data_length / 3))
|
||||
self.assertEqual(len(sampler), sampler.num_selected_samples)
|
||||
indices = [x for x in range(self.data_length) for _ in range(3)]
|
||||
self.assertEqual(
|
||||
list(sampler), indices[2::3][:sampler.num_selected_samples])
|
||||
|
||||
logger = MMLogger.get_current_instance()
|
||||
with patch.object(logger, 'warning') as mock_log:
|
||||
sampler = RepeatAugSampler(self.dataset, shuffle=False)
|
||||
sampler = BatchAugSampler(self.dataset, shuffle=False)
|
||||
mock_log.assert_not_called()
|
||||
|
||||
@patch(file + 'get_dist_info', return_value=(0, 1))
|
||||
@patch(file + 'sync_random_seed', return_value=7)
|
||||
def test_shuffle(self, mock1, mock2):
|
||||
# test seed=None
|
||||
sampler = RepeatAugSampler(self.dataset, seed=None)
|
||||
sampler = BatchAugSampler(self.dataset, seed=None)
|
||||
self.assertEqual(sampler.seed, 7)
|
||||
|
||||
# test random seed
|
||||
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0)
|
||||
sampler = BatchAugSampler(self.dataset, shuffle=True, seed=0)
|
||||
sampler.set_epoch(10)
|
||||
g = torch.Generator()
|
||||
g.manual_seed(10)
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||
indices = [x for x in indices
|
||||
for _ in range(3)][:sampler.num_selected_samples]
|
||||
indices = [
|
||||
x for x in torch.randperm(len(self.dataset), generator=g)
|
||||
for _ in range(3)
|
||||
]
|
||||
self.assertEqual(list(sampler), indices)
|
||||
|
||||
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42)
|
||||
sampler = BatchAugSampler(self.dataset, shuffle=True, seed=42)
|
||||
sampler.set_epoch(10)
|
||||
g = torch.Generator()
|
||||
g.manual_seed(42 + 10)
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||
indices = [x for x in indices
|
||||
for _ in range(3)][:sampler.num_selected_samples]
|
||||
indices = [
|
||||
x for x in torch.randperm(len(self.dataset), generator=g)
|
||||
for _ in range(3)
|
||||
]
|
||||
self.assertEqual(list(sampler), indices)
|
Loading…
Reference in New Issue