99 lines
3.7 KiB
Python
99 lines
3.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
import math
|
|
from unittest import TestCase
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
from mmengine.logging import MMLogger
|
|
|
|
from mmcls.datasets import RepeatAugSampler
|
|
|
|
file = 'mmcls.datasets.samplers.repeat_aug.'
|
|
|
|
|
|
class MockDist:
|
|
|
|
def __init__(self, dist_info=(0, 1), seed=7):
|
|
self.dist_info = dist_info
|
|
self.seed = seed
|
|
|
|
def get_dist_info(self):
|
|
return self.dist_info
|
|
|
|
def sync_random_seed(self):
|
|
return self.seed
|
|
|
|
def is_main_process(self):
|
|
return self.dist_info[0] == 0
|
|
|
|
|
|
class TestRepeatAugSampler(TestCase):
|
|
|
|
def setUp(self):
|
|
self.data_length = 100
|
|
self.dataset = list(range(self.data_length))
|
|
|
|
@patch(file + 'get_dist_info', return_value=(0, 1))
|
|
def test_non_dist(self, mock):
|
|
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
|
self.assertEqual(sampler.world_size, 1)
|
|
self.assertEqual(sampler.rank, 0)
|
|
self.assertEqual(sampler.total_size, self.data_length * 3)
|
|
self.assertEqual(sampler.num_samples, self.data_length * 3)
|
|
self.assertEqual(sampler.num_selected_samples, self.data_length)
|
|
self.assertEqual(len(sampler), sampler.num_selected_samples)
|
|
indices = [x for x in range(self.data_length) for _ in range(3)]
|
|
self.assertEqual(list(sampler), indices[:self.data_length])
|
|
|
|
logger = MMLogger.get_current_instance()
|
|
with self.assertLogs(logger, 'WARN') as log:
|
|
sampler = RepeatAugSampler(self.dataset, shuffle=False)
|
|
self.assertIn('always picks a fixed part', log.output[0])
|
|
|
|
@patch(file + 'get_dist_info', return_value=(2, 3))
|
|
@patch(file + 'is_main_process', return_value=False)
|
|
def test_dist(self, mock1, mock2):
|
|
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
|
self.assertEqual(sampler.world_size, 3)
|
|
self.assertEqual(sampler.rank, 2)
|
|
self.assertEqual(sampler.num_samples, self.data_length)
|
|
self.assertEqual(sampler.total_size, self.data_length * 3)
|
|
self.assertEqual(sampler.num_selected_samples,
|
|
math.ceil(self.data_length / 3))
|
|
self.assertEqual(len(sampler), sampler.num_selected_samples)
|
|
indices = [x for x in range(self.data_length) for _ in range(3)]
|
|
self.assertEqual(
|
|
list(sampler), indices[2::3][:sampler.num_selected_samples])
|
|
|
|
logger = MMLogger.get_current_instance()
|
|
with patch.object(logger, 'warning') as mock_log:
|
|
sampler = RepeatAugSampler(self.dataset, shuffle=False)
|
|
mock_log.assert_not_called()
|
|
|
|
@patch(file + 'get_dist_info', return_value=(0, 1))
|
|
@patch(file + 'sync_random_seed', return_value=7)
|
|
def test_shuffle(self, mock1, mock2):
|
|
# test seed=None
|
|
sampler = RepeatAugSampler(self.dataset, seed=None)
|
|
self.assertEqual(sampler.seed, 7)
|
|
|
|
# test random seed
|
|
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0)
|
|
sampler.set_epoch(10)
|
|
g = torch.Generator()
|
|
g.manual_seed(10)
|
|
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
indices = [x for x in indices
|
|
for _ in range(3)][:sampler.num_selected_samples]
|
|
self.assertEqual(list(sampler), indices)
|
|
|
|
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42)
|
|
sampler.set_epoch(10)
|
|
g = torch.Generator()
|
|
g.manual_seed(42 + 10)
|
|
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
indices = [x for x in indices
|
|
for _ in range(3)][:sampler.num_selected_samples]
|
|
self.assertEqual(list(sampler), indices)
|