# Source listing: mmengine/tests/test_dataset/test_sampler.py
# (file viewer metadata: 142 lines, 5.3 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import patch
import numpy as np
import torch
from mmengine.dataset import DefaultSampler, InfiniteSampler
class TestDefaultSampler(TestCase):
    """Tests for ``DefaultSampler`` with distributed info patched out."""

    def setUp(self):
        # Identity dataset: the item at position ``i`` is ``i`` itself.
        self.dataset = list(range(100))
        self.data_length = len(self.dataset)

    def _seeded_randperm(self, seed):
        """Return ``torch.randperm`` over the dataset length.

        The permutation is drawn from a fresh ``torch.Generator`` seeded
        with ``seed``; it is the reference the shuffled sampler is
        compared against.
        """
        generator = torch.Generator()
        generator.manual_seed(seed)
        return torch.randperm(len(self.dataset), generator=generator).tolist()

    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock):
        sampler = DefaultSampler(self.dataset)
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)

        # Single process: for both round_up settings the test expects no
        # padding and every index exactly once, in order.
        ordered = list(range(self.data_length))
        for round_up in (True, False):
            sampler = DefaultSampler(
                self.dataset, round_up=round_up, shuffle=False)
            self.assertEqual(sampler.total_size, self.data_length)
            self.assertEqual(sampler.num_samples, self.data_length)
            self.assertEqual(list(sampler), ordered)

    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(2, 3))
    def test_dist(self, mock):
        # Must agree with the (rank, world_size) tuple patched in above.
        rank, world_size = 2, 3
        sampler = DefaultSampler(self.dataset)
        self.assertEqual(sampler.world_size, world_size)
        self.assertEqual(sampler.rank, rank)

        ordered = list(range(self.data_length))
        # This rank's share: every ``world_size``-th index from ``rank``.
        rank_slice = ordered[rank::world_size]

        # round_up=True: total_size becomes 34 * 3 = 102, so this rank's
        # last position (101) lies past the data and is expected to wrap
        # around to index 1.
        sampler = DefaultSampler(self.dataset, round_up=True, shuffle=False)
        self.assertEqual(sampler.num_samples,
                         np.ceil(self.data_length / world_size))
        self.assertEqual(sampler.total_size, sampler.num_samples * world_size)
        self.assertEqual(len(sampler), sampler.num_samples)
        self.assertEqual(list(sampler), rank_slice + [1])

        # round_up=False: no padding, so the rank sees only real indices.
        sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False)
        self.assertEqual(sampler.num_samples,
                         np.ceil((self.data_length - rank) / world_size))
        self.assertEqual(sampler.total_size, self.data_length)
        self.assertEqual(len(sampler), sampler.num_samples)
        self.assertEqual(list(sampler), rank_slice)

    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
    @patch('mmengine.dataset.sampler.sync_random_seed', return_value=7)
    def test_shuffle(self, mock1, mock2):
        # Decorators apply bottom-up: mock1 patches sync_random_seed and
        # mock2 patches get_dist_info.
        # seed=None: the sampler should adopt the synced seed (patched to 7).
        sampler = DefaultSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # Explicit seed: after set_epoch(epoch) the order is expected to
        # match randperm from a generator seeded with seed + epoch.
        epoch = 10
        for seed in (0, 42):
            sampler = DefaultSampler(self.dataset, shuffle=True, seed=seed)
            sampler.set_epoch(epoch)
            expected = self._seeded_randperm(seed + epoch)
            self.assertEqual(list(sampler), expected)
class TestInfiniteSampler(TestCase):
def setUp(self):
    # Identity dataset of 100 items: the item at position ``i`` is ``i``.
    self.dataset = list(range(100))
    self.data_length = len(self.dataset)
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
def test_non_dist(self, mock):
    # Single process: the sampler should report rank 0 of 1.
    sampler = InfiniteSampler(self.dataset)
    self.assertEqual(sampler.world_size, 1)
    self.assertEqual(sampler.rank, 0)

    # Unshuffled, two passes' worth of draws should be the ordered
    # indices repeated twice, i.e. iteration restarts after each pass.
    sampler = InfiniteSampler(self.dataset, shuffle=False)
    self.assertEqual(len(sampler), self.data_length)
    self.assertEqual(sampler.size, self.data_length)
    num_draws = 2 * self.data_length
    stream = iter(sampler)
    drawn = [next(stream) for _ in range(num_draws)]
    expected = list(range(self.data_length)) * 2
    self.assertEqual(drawn, expected)
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(2, 3))
def test_dist(self, mock):
    # Rank 2 of 3 should draw every third index, starting at 2, and keep
    # that stride across consecutive passes over the dataset.
    sampler = InfiniteSampler(self.dataset)
    self.assertEqual(sampler.world_size, 3)
    self.assertEqual(sampler.rank, 2)

    # test iteration
    sampler = InfiniteSampler(self.dataset, shuffle=False)
    # len() and .size are expected to report the full dataset length,
    # not this rank's share.
    self.assertEqual(len(sampler), self.data_length)
    self.assertEqual(sampler.size, self.data_length)
    # Expected stream: the stride-3 slice from offset 2 over two
    # back-to-back passes of the ordered indices. This also checks the
    # wrap from the end of pass one into pass two.
    targets = (list(range(self.data_length)) * 2)[2::3]
    sampler_iter = iter(sampler)
    samples = [next(sampler_iter) for _ in range(len(targets))]
    self.assertEqual(samples, targets)
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
@patch('mmengine.dataset.sampler.sync_random_seed', return_value=7)
def test_shuffle(self, mock1, mock2):
    # Checks the seed fallback, then checks that one shuffled pass matches
    # torch.randperm. Decorators apply bottom-up: mock1 patches
    # sync_random_seed and mock2 patches get_dist_info.
    # seed=None: the sampler should adopt the synced seed (patched to 7).
    sampler = InfiniteSampler(self.dataset, seed=None)
    self.assertEqual(sampler.seed, 7)

    # Explicit seed: the first pass is expected to equal randperm drawn
    # from a generator seeded with that same seed.
    seed = 42
    sampler = InfiniteSampler(self.dataset, shuffle=True, seed=seed)
    stream = iter(sampler)
    first_pass = [next(stream) for _ in range(self.data_length)]
    reference_gen = torch.Generator()
    reference_gen.manual_seed(seed)
    expected = torch.randperm(
        self.data_length, generator=reference_gen).tolist()
    self.assertEqual(first_pass, expected)
def test_set_epoch(self):
sampler = InfiniteSampler(self.dataset)
sampler.set_epoch(10)