# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from unittest import TestCase
from unittest.mock import patch

import numpy as np
import torch

from mmengine.data import DefaultSampler, InfiniteSampler


class TestDefaultSampler(TestCase):
    """Tests for ``DefaultSampler`` in single- and multi-rank settings."""

    def setUp(self):
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock):
        """With rank 0 of 1, every index is yielded exactly once, in order."""
        sampler = DefaultSampler(self.dataset)
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)

        # test round_up=True
        sampler = DefaultSampler(self.dataset, round_up=True, shuffle=False)
        self.assertEqual(sampler.total_size, self.data_length)
        self.assertEqual(sampler.num_samples, self.data_length)
        self.assertEqual(list(sampler), list(range(self.data_length)))

        # test round_up=False
        sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False)
        self.assertEqual(sampler.total_size, self.data_length)
        self.assertEqual(sampler.num_samples, self.data_length)
        self.assertEqual(list(sampler), list(range(self.data_length)))

    @patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
    def test_dist(self, mock):
        """With rank 2 of 3, the sampler yields every third index from 2."""
        sampler = DefaultSampler(self.dataset)
        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)

        # test round_up=True
        sampler = DefaultSampler(self.dataset, round_up=True, shuffle=False)
        self.assertEqual(sampler.num_samples, np.ceil(self.data_length / 3))
        self.assertEqual(sampler.total_size, sampler.num_samples * 3)
        self.assertEqual(len(sampler), sampler.num_samples)
        # 100 indices are padded to total_size = 102 by wrapping around to
        # the start, so rank 2's last slot (position 101) holds index 1.
        self.assertEqual(
            list(sampler),
            list(range(self.data_length))[2::3] + [1])

        # test round_up=False
        sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False)
        # Without padding, rank 2 only gets the indices 2, 5, ..., 98.
        self.assertEqual(sampler.num_samples,
                         np.ceil((self.data_length - 2) / 3))
        self.assertEqual(sampler.total_size, self.data_length)
        self.assertEqual(len(sampler), sampler.num_samples)
        self.assertEqual(list(sampler), list(range(self.data_length))[2::3])

    # ``patch`` decorators inject mocks bottom-up:
    # mock1 -> sync_random_seed, mock2 -> get_dist_info.
    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    @patch('mmengine.data.sampler.sync_random_seed', return_value=7)
    def test_shuffle(self, mock1, mock2):
        """Shuffled order matches ``torch.randperm`` seeded by seed + epoch."""
        # test seed=None: the seed comes from the (mocked) sync_random_seed
        sampler = DefaultSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # test random seed
        sampler = DefaultSampler(self.dataset, shuffle=True, seed=0)
        sampler.set_epoch(10)
        g = torch.Generator()
        g.manual_seed(10)  # seed 0 + epoch 10
        self.assertEqual(
            list(sampler),
            torch.randperm(len(self.dataset), generator=g).tolist())

        sampler = DefaultSampler(self.dataset, shuffle=True, seed=42)
        sampler.set_epoch(10)
        g = torch.Generator()
        g.manual_seed(42 + 10)
        self.assertEqual(
            list(sampler),
            torch.randperm(len(self.dataset), generator=g).tolist())


class TestInfiniteSampler(TestCase):
    """Tests for ``InfiniteSampler`` in single- and multi-rank settings."""

    def setUp(self):
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock):
        """With rank 0 of 1, the index sequence repeats past one epoch."""
        sampler = InfiniteSampler(self.dataset)
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)

        # test iteration
        sampler = InfiniteSampler(self.dataset, shuffle=False)
        self.assertEqual(len(sampler), self.data_length)
        self.assertEqual(sampler.size, self.data_length)
        # Hold a single iterator: calling iter() per element would only work
        # if the sampler happened to share generator state across iterators.
        sampler_iter = iter(sampler)
        items = [next(sampler_iter) for _ in range(self.data_length * 2)]
        self.assertEqual(items, list(range(self.data_length)) * 2)

    @patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
    def test_dist(self, mock):
        """With rank 2 of 3, every third index of the stream is yielded."""
        sampler = InfiniteSampler(self.dataset)
        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)

        # test iteration
        sampler = InfiniteSampler(self.dataset, shuffle=False)
        self.assertEqual(len(sampler), self.data_length)
        self.assertEqual(sampler.size, self.data_length)
        # Rank 2 takes positions 2, 5, 8, ... of the repeating index stream,
        # crossing the epoch boundary seamlessly.
        targets = (list(range(self.data_length)) * 2)[2::3]
        sampler_iter = iter(sampler)
        samples = [next(sampler_iter) for _ in range(len(targets))]
        self.assertEqual(samples, targets)

    # ``patch`` decorators inject mocks bottom-up:
    # mock1 -> sync_random_seed, mock2 -> get_dist_info.
    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    @patch('mmengine.data.sampler.sync_random_seed', return_value=7)
    def test_shuffle(self, mock1, mock2):
        """The first shuffled pass matches ``torch.randperm`` with the seed."""
        # test seed=None: the seed comes from the (mocked) sync_random_seed
        sampler = InfiniteSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # test the random seed
        sampler = InfiniteSampler(self.dataset, shuffle=True, seed=42)
        sampler_iter = iter(sampler)
        samples = [next(sampler_iter) for _ in range(self.data_length)]

        g = torch.Generator()
        g.manual_seed(42)
        self.assertEqual(
            samples,
            torch.randperm(self.data_length, generator=g).tolist())

    def test_set_epoch(self):
        """``set_epoch`` is expected to raise ``NotImplementedError``."""
        sampler = InfiniteSampler(self.dataset)
        self.assertRaises(NotImplementedError, partial(sampler.set_epoch, 10))