mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Feature] Support batch augmentation through BatchAugSampler (#1757)
* [Fix] RepeatAugSampler -> BatchAugSampler
* update docs
This commit is contained in:
parent
82f81ff67c
commit
47f7fc06ed
@ -9,6 +9,18 @@ mmocr.datasets
|
|||||||
:local:
|
:local:
|
||||||
:backlinks: top
|
:backlinks: top
|
||||||
|
|
||||||
|
.. currentmodule:: mmocr.datasets.samplers
|
||||||
|
|
||||||
|
Samplers
|
||||||
|
---------------------------------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
BatchAugSampler
|
||||||
|
|
||||||
.. currentmodule:: mmocr.datasets
|
.. currentmodule:: mmocr.datasets
|
||||||
|
|
||||||
Datasets
|
Datasets
|
||||||
|
@ -9,6 +9,18 @@ mmocr.datasets
|
|||||||
:local:
|
:local:
|
||||||
:backlinks: top
|
:backlinks: top
|
||||||
|
|
||||||
|
.. currentmodule:: mmocr.datasets.samplers
|
||||||
|
|
||||||
|
Samplers
|
||||||
|
---------------------------------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
BatchAugSampler
|
||||||
|
|
||||||
.. currentmodule:: mmocr.datasets
|
.. currentmodule:: mmocr.datasets
|
||||||
|
|
||||||
Datasets
|
Datasets
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .repeat_aug import RepeatAugSampler
|
from .batch_aug import BatchAugSampler
|
||||||
|
|
||||||
__all__ = ['RepeatAugSampler']
|
__all__ = ['BatchAugSampler']
|
||||||
|
@ -2,20 +2,22 @@ import math
|
|||||||
from typing import Iterator, Optional, Sized
|
from typing import Iterator, Optional, Sized
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
|
from mmengine.dist import get_dist_info, sync_random_seed
|
||||||
from torch.utils.data import Sampler
|
from torch.utils.data import Sampler
|
||||||
|
|
||||||
from mmocr.registry import DATA_SAMPLERS
|
from mmocr.registry import DATA_SAMPLERS
|
||||||
|
|
||||||
|
|
||||||
@DATA_SAMPLERS.register_module()
|
@DATA_SAMPLERS.register_module()
|
||||||
class RepeatAugSampler(Sampler):
|
class BatchAugSampler(Sampler):
|
||||||
"""Sampler that restricts data loading to a subset of the dataset for
|
"""Sampler that repeats the same data elements for num_repeats times. The
|
||||||
distributed, with repeated augmentation. It ensures that different each
|
batch size should be divisible by num_repeats.
|
||||||
|
|
||||||
|
It ensures that each
|
||||||
augmented version of a sample will be visible to a different process (GPU).
|
augmented version of a sample will be visible to a different process (GPU).
|
||||||
Heavily based on torch.utils.data.DistributedSampler.
|
Heavily based on torch.utils.data.DistributedSampler.
|
||||||
|
|
||||||
This sampler was taken from
|
This sampler was modified from
|
||||||
https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
|
https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
|
||||||
Used in
|
Used in
|
||||||
Copyright (c) 2015-present, Facebook, Inc.
|
Copyright (c) 2015-present, Facebook, Inc.
|
||||||
@ -40,11 +42,6 @@ class RepeatAugSampler(Sampler):
|
|||||||
|
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
if not self.shuffle and is_main_process():
|
|
||||||
from mmengine.logging import MMLogger
|
|
||||||
logger = MMLogger.get_current_instance()
|
|
||||||
logger.warning('The RepeatAugSampler always picks a '
|
|
||||||
'fixed part of data if `shuffle=False`.')
|
|
||||||
|
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = sync_random_seed()
|
seed = sync_random_seed()
|
||||||
@ -82,7 +79,7 @@ class RepeatAugSampler(Sampler):
|
|||||||
assert len(indices) == self.num_samples
|
assert len(indices) == self.num_samples
|
||||||
|
|
||||||
# return up to num selected samples
|
# return up to num selected samples
|
||||||
return iter(indices[:self.num_selected_samples])
|
return iter(indices)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""The number of samples in this rank."""
|
"""The number of samples in this rank."""
|
@ -1,15 +1,14 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
|
||||||
import math
|
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmengine.logging import MMLogger
|
from mmengine.logging import MMLogger
|
||||||
|
|
||||||
from mmocr.datasets import RepeatAugSampler
|
from mmocr.datasets import BatchAugSampler
|
||||||
|
|
||||||
file = 'mmocr.datasets.samplers.repeat_aug.'
|
file = 'mmocr.datasets.samplers.batch_aug.'
|
||||||
|
|
||||||
|
|
||||||
class MockDist:
|
class MockDist:
|
||||||
@ -28,7 +27,7 @@ class MockDist:
|
|||||||
return self.dist_info[0] == 0
|
return self.dist_info[0] == 0
|
||||||
|
|
||||||
|
|
||||||
class TestRepeatAugSampler(TestCase):
|
class TestBatchAugSampler(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.data_length = 100
|
self.data_length = 100
|
||||||
@ -36,63 +35,51 @@ class TestRepeatAugSampler(TestCase):
|
|||||||
|
|
||||||
@patch(file + 'get_dist_info', return_value=(0, 1))
|
@patch(file + 'get_dist_info', return_value=(0, 1))
|
||||||
def test_non_dist(self, mock):
|
def test_non_dist(self, mock):
|
||||||
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
sampler = BatchAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
||||||
self.assertEqual(sampler.world_size, 1)
|
self.assertEqual(sampler.world_size, 1)
|
||||||
self.assertEqual(sampler.rank, 0)
|
self.assertEqual(sampler.rank, 0)
|
||||||
self.assertEqual(sampler.total_size, self.data_length * 3)
|
self.assertEqual(sampler.total_size, self.data_length * 3)
|
||||||
self.assertEqual(sampler.num_samples, self.data_length * 3)
|
self.assertEqual(sampler.num_samples, self.data_length * 3)
|
||||||
self.assertEqual(sampler.num_selected_samples, self.data_length)
|
|
||||||
self.assertEqual(len(sampler), sampler.num_selected_samples)
|
|
||||||
indices = [x for x in range(self.data_length) for _ in range(3)]
|
indices = [x for x in range(self.data_length) for _ in range(3)]
|
||||||
self.assertEqual(list(sampler), indices[:self.data_length])
|
self.assertEqual(list(sampler), indices)
|
||||||
|
|
||||||
logger = MMLogger.get_current_instance()
|
|
||||||
with self.assertLogs(logger, 'WARN') as log:
|
|
||||||
sampler = RepeatAugSampler(self.dataset, shuffle=False)
|
|
||||||
self.assertIn('always picks a fixed part', log.output[0])
|
|
||||||
|
|
||||||
@patch(file + 'get_dist_info', return_value=(2, 3))
|
@patch(file + 'get_dist_info', return_value=(2, 3))
|
||||||
@patch(file + 'is_main_process', return_value=False)
|
def test_dist(self, mock):
|
||||||
def test_dist(self, mock1, mock2):
|
sampler = BatchAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
||||||
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
|
||||||
self.assertEqual(sampler.world_size, 3)
|
self.assertEqual(sampler.world_size, 3)
|
||||||
self.assertEqual(sampler.rank, 2)
|
self.assertEqual(sampler.rank, 2)
|
||||||
self.assertEqual(sampler.num_samples, self.data_length)
|
self.assertEqual(sampler.num_samples, self.data_length)
|
||||||
self.assertEqual(sampler.total_size, self.data_length * 3)
|
self.assertEqual(sampler.total_size, self.data_length * 3)
|
||||||
self.assertEqual(sampler.num_selected_samples,
|
|
||||||
math.ceil(self.data_length / 3))
|
|
||||||
self.assertEqual(len(sampler), sampler.num_selected_samples)
|
|
||||||
indices = [x for x in range(self.data_length) for _ in range(3)]
|
|
||||||
self.assertEqual(
|
|
||||||
list(sampler), indices[2::3][:sampler.num_selected_samples])
|
|
||||||
|
|
||||||
logger = MMLogger.get_current_instance()
|
logger = MMLogger.get_current_instance()
|
||||||
with patch.object(logger, 'warning') as mock_log:
|
with patch.object(logger, 'warning') as mock_log:
|
||||||
sampler = RepeatAugSampler(self.dataset, shuffle=False)
|
sampler = BatchAugSampler(self.dataset, shuffle=False)
|
||||||
mock_log.assert_not_called()
|
mock_log.assert_not_called()
|
||||||
|
|
||||||
@patch(file + 'get_dist_info', return_value=(0, 1))
|
@patch(file + 'get_dist_info', return_value=(0, 1))
|
||||||
@patch(file + 'sync_random_seed', return_value=7)
|
@patch(file + 'sync_random_seed', return_value=7)
|
||||||
def test_shuffle(self, mock1, mock2):
|
def test_shuffle(self, mock1, mock2):
|
||||||
# test seed=None
|
# test seed=None
|
||||||
sampler = RepeatAugSampler(self.dataset, seed=None)
|
sampler = BatchAugSampler(self.dataset, seed=None)
|
||||||
self.assertEqual(sampler.seed, 7)
|
self.assertEqual(sampler.seed, 7)
|
||||||
|
|
||||||
# test random seed
|
# test random seed
|
||||||
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0)
|
sampler = BatchAugSampler(self.dataset, shuffle=True, seed=0)
|
||||||
sampler.set_epoch(10)
|
sampler.set_epoch(10)
|
||||||
g = torch.Generator()
|
g = torch.Generator()
|
||||||
g.manual_seed(10)
|
g.manual_seed(10)
|
||||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
indices = [
|
||||||
indices = [x for x in indices
|
x for x in torch.randperm(len(self.dataset), generator=g)
|
||||||
for _ in range(3)][:sampler.num_selected_samples]
|
for _ in range(3)
|
||||||
|
]
|
||||||
self.assertEqual(list(sampler), indices)
|
self.assertEqual(list(sampler), indices)
|
||||||
|
|
||||||
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42)
|
sampler = BatchAugSampler(self.dataset, shuffle=True, seed=42)
|
||||||
sampler.set_epoch(10)
|
sampler.set_epoch(10)
|
||||||
g = torch.Generator()
|
g = torch.Generator()
|
||||||
g.manual_seed(42 + 10)
|
g.manual_seed(42 + 10)
|
||||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
indices = [
|
||||||
indices = [x for x in indices
|
x for x in torch.randperm(len(self.dataset), generator=g)
|
||||||
for _ in range(3)][:sampler.num_selected_samples]
|
for _ in range(3)
|
||||||
|
]
|
||||||
self.assertEqual(list(sampler), indices)
|
self.assertEqual(list(sampler), indices)
|
Loading…
x
Reference in New Issue
Block a user