# Copyright (c) OpenMMLab. All rights reserved.
"""Unit tests for ``mmpretrain.datasets.RepeatAugSampler``."""
import math
from unittest import TestCase
from unittest.mock import patch

import torch
from mmengine.logging import MMLogger

from mmpretrain.datasets import RepeatAugSampler

# Dotted prefix of the sampler's module; patch targets are formed by
# appending the name of the distributed helper being replaced.
file = 'mmpretrain.datasets.samplers.repeat_aug.'


class MockDist:
    """Small stand-in that exposes the distributed helpers as methods.

    Args:
        dist_info (tuple): ``(rank, world_size)`` pair handed back by
            :meth:`get_dist_info`. Defaults to ``(0, 1)``.
        seed (int): Value handed back by :meth:`sync_random_seed`.
            Defaults to 7.
    """

    def __init__(self, dist_info=(0, 1), seed=7):
        self.dist_info = dist_info
        self.seed = seed

    def get_dist_info(self):
        """Return the stored ``(rank, world_size)`` pair."""
        return self.dist_info

    def sync_random_seed(self):
        """Return the stored seed."""
        return self.seed

    def is_main_process(self):
        """Return True when the stored rank is 0."""
        rank = self.dist_info[0]
        return rank == 0


class TestRepeatAugSampler(TestCase):
    """Sizing, ordering, seeding and warning behaviour of the sampler."""

    def setUp(self):
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @staticmethod
    def _repeat_each(indices, times=3):
        """Return ``indices`` with each element repeated ``times`` in a row.

        For example ``[0, 1]`` with ``times=3`` becomes
        ``[0, 0, 0, 1, 1, 1]``.
        """
        return [idx for idx in indices for _ in range(times)]

    def _expected_shuffled(self, sampler, seed):
        """Rebuild the order a shuffling sampler is expected to yield.

        Args:
            sampler: The sampler under test; only its
                ``num_selected_samples`` is read.
            seed (int): Value passed to ``torch.Generator.manual_seed``
                before drawing the permutation.

        Returns:
            list: A ``torch.randperm`` permutation of the dataset indices,
            each repeated three times, truncated to
            ``sampler.num_selected_samples``.
        """
        generator = torch.Generator()
        generator.manual_seed(seed)
        perm = torch.randperm(len(self.dataset), generator=generator).tolist()
        return self._repeat_each(perm, 3)[:sampler.num_selected_samples]

    @patch(file + 'get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock_dist_info):
        n = self.data_length
        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)

        # Single process: rank 0 out of a world size of 1.
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)
        # The only rank owns all n * 3 repeated indices, yet yields just n.
        self.assertEqual(sampler.total_size, n * 3)
        self.assertEqual(sampler.num_samples, n * 3)
        self.assertEqual(sampler.num_selected_samples, n)
        self.assertEqual(len(sampler), sampler.num_selected_samples)

        # Unshuffled, the output is the leading n entries of the repeated
        # index list.
        expected = self._repeat_each(range(n), 3)[:n]
        self.assertEqual(list(sampler), expected)

        # Constructing with default repeats in this setup emits a warning.
        logger = MMLogger.get_current_instance()
        with self.assertLogs(logger, 'WARN') as captured:
            RepeatAugSampler(self.dataset, shuffle=False)
        self.assertIn('always picks a fixed part', captured.output[0])

    @patch(file + 'get_dist_info', return_value=(2, 3))
    @patch(file + 'is_main_process', return_value=False)
    def test_dist(self, mock_is_main, mock_dist_info):
        # Stacked patches are injected bottom-up: is_main_process first.
        n = self.data_length
        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)

        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)
        self.assertEqual(sampler.num_samples, n)
        self.assertEqual(sampler.total_size, n * 3)
        self.assertEqual(sampler.num_selected_samples, math.ceil(n / 3))
        self.assertEqual(len(sampler), sampler.num_selected_samples)

        # Rank 2 of 3 takes every third repeated index starting at offset 2,
        # cut down to its selected-sample budget.
        repeated = self._repeat_each(range(n), 3)
        expected = repeated[2::3][:sampler.num_selected_samples]
        self.assertEqual(list(sampler), expected)

        # With world size 3 on a non-main rank no warning is expected.
        logger = MMLogger.get_current_instance()
        with patch.object(logger, 'warning') as mock_warning:
            RepeatAugSampler(self.dataset, shuffle=False)
        mock_warning.assert_not_called()

    @patch(file + 'get_dist_info', return_value=(0, 1))
    @patch(file + 'sync_random_seed', return_value=7)
    def test_shuffle(self, mock_sync_seed, mock_dist_info):
        # seed=None falls back to the (patched) synchronised seed.
        sampler = RepeatAugSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # An explicit seed is offset by the epoch before permuting.
        epoch = 10
        for seed in (0, 42):
            sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=seed)
            sampler.set_epoch(epoch)
            expected = self._expected_shuffled(sampler, seed + epoch)
            self.assertEqual(list(sampler), expected)