# Mirror of https://github.com/RE-OWOD/RE-OWOD (24 lines, 800 B, Python)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
import unittest
|
|
from torch.utils.data.sampler import SequentialSampler
|
|
|
|
from detectron2.data.samplers import GroupedBatchSampler
|
|
|
|
|
|
class TestGroupedBatchSampler(unittest.TestCase):
    """Tests for ``GroupedBatchSampler`` wrapping a ``SequentialSampler``.

    Each test materializes every batch before checking it. This way an
    empty or truncated iteration fails the test instead of passing vacuously.
    """

    def test_missing_group_id(self):
        """With a single group id, all 100 indices come out in batches of 2."""
        sampler = SequentialSampler(list(range(100)))
        group_ids = [1] * 100
        samples = GroupedBatchSampler(sampler, group_ids, 2)

        batches = list(samples)
        # Without a count check, a sampler yielding nothing would pass the
        # per-batch assertions below trivially.
        self.assertEqual(len(batches), 50)
        for mini_batch in batches:
            self.assertEqual(len(mini_batch), 2)
        # Every index from the sequential sampler should appear exactly once.
        self.assertEqual(
            sorted(idx for mini_batch in batches for idx in mini_batch),
            list(range(100)),
        )

    def test_groups(self):
        """Indices from alternating groups are never mixed within a batch."""
        sampler = SequentialSampler(list(range(100)))
        # Even indices belong to group 1, odd indices to group 0.
        group_ids = [1, 0] * 50
        samples = GroupedBatchSampler(sampler, group_ids, 2)

        batches = list(samples)
        # 50 indices per group, batch size 2 -> 25 full batches per group.
        self.assertEqual(len(batches), 50)
        for mini_batch in batches:
            # Check the size first so a short batch fails clearly
            # rather than with an IndexError.
            self.assertEqual(len(mini_batch), 2)
            # All members of a batch must share one group id.
            self.assertEqual(len({group_ids[idx] for idx in mini_batch}), 1)
        self.assertEqual(
            sorted(idx for mini_batch in batches for idx in mini_batch),
            list(range(100)),
        )