mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
74 lines
2.5 KiB
Python
74 lines
2.5 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
|
||
|
from collections import UserList
|
||
|
from unittest import TestCase
|
||
|
|
||
|
from mmrazor.models.subnet import Candidates
|
||
|
|
||
|
|
||
|
class TestCandidates(TestCase):
    """Unit tests for the ``Candidates`` container.

    Each test builds ``Candidates`` from plain subnet dicts and/or
    ``(subnet, score)`` tuples and checks the resulting length or the
    values exposed through its properties.
    """

    def setUp(self) -> None:
        """Create the fixtures shared by every test."""
        # A bare subnet, and the same subnet paired with a score.
        self.fake_subnet = {'1': 'choice1', '2': 'choice2'}
        self.fake_subnet_with_score = (self.fake_subnet, 1.)

    def test_init(self):
        """Construct ``Candidates`` from nothing, a list and a UserList."""
        # initlist is None
        candidates = Candidates()
        self.assertEqual(len(candidates.data), 0)
        # initlist is list
        data = [self.fake_subnet] * 2
        candidates = Candidates(data)
        self.assertEqual(len(candidates.data), 2)
        # initlist is UserList
        data = UserList([self.fake_subnet] * 2)
        # Build a fresh instance from the UserList; without this the
        # assertion below would only re-check the list-based instance.
        candidates = Candidates(data)
        self.assertEqual(len(candidates.data), 2)

    def test_scores(self):
        """Check the ``scores`` property against scored inputs."""
        # test property: scores
        data = [self.fake_subnet_with_score] * 2
        candidates = Candidates(data)
        self.assertEqual(candidates.scores, [1., 1.])

    def test_subnets(self):
        """Check the ``subnets`` property against scored inputs."""
        # test property: subnets
        data = [self.fake_subnet_with_score] * 2
        candidates = Candidates(data)
        self.assertEqual(candidates.subnets, [self.fake_subnet] * 2)

    def test_append(self):
        """Append a dict item and a tuple item to empty containers."""
        # item is dict
        candidates = Candidates()
        candidates.append(self.fake_subnet)
        self.assertEqual(len(candidates), 1)
        # item is tuple
        candidates = Candidates()
        candidates.append(self.fake_subnet_with_score)
        self.assertEqual(len(candidates), 1)

    def test_insert(self):
        """Insert a dict item and a tuple item at index 1."""
        # item is dict
        candidates = Candidates([self.fake_subnet_with_score])
        candidates.insert(1, self.fake_subnet)
        self.assertEqual(len(candidates), 2)
        # item is tuple
        candidates = Candidates([self.fake_subnet_with_score])
        candidates.insert(1, self.fake_subnet_with_score)
        self.assertEqual(len(candidates), 2)

    def test_extend(self):
        """Extend the container with a list and with a UserList."""
        # other is list
        candidates = Candidates([self.fake_subnet_with_score])
        candidates.extend([self.fake_subnet])
        self.assertEqual(len(candidates), 2)
        # other is UserList
        candidates = Candidates([self.fake_subnet_with_score])
        candidates.extend(UserList([self.fake_subnet_with_score]))
        self.assertEqual(len(candidates), 2)

    def test_set_score(self):
        """Overwrite the score of the first candidate via ``set_score``."""
        # test set_score
        candidates = Candidates([self.fake_subnet_with_score])
        candidates.set_score(0, 0.5)
        self.assertEqual(candidates[0][1], 0.5)