# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import MagicMock, patch

import torch
import torch.nn as nn
from mmengine.config import Config
from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner
from torch.utils.data import Dataset

from mmrazor.models.subnet import FixSubnet
from mmrazor.registry import DATASETS, METRICS, MODELS
from mmrazor.runners import GreedySamplerTrainLoop  # noqa: F401

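# The toy fixtures below are registered so ``Runner.from_cfg`` can build
# them by type name: a supernet with a mocked mutator, a tiny random
# dataset, and a metric that always reports the same score.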
@MODELS.register_module()
class ToyModel_GreedySamplerTrainLoop(BaseModel):
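    """A two-layer toy supernet whose mutator is a ``MagicMock`` injected
    via ``@patch``, so the train loop's subnet sampling and setting calls
    can be exercised without a real search space."""
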
    @patch('mmrazor.models.mutators.OneShotModuleMutator')
    def __init__(self, mock_mutator):
        super().__init__()
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)
        self.mutator = mock_mutator

    def forward(self, batch_inputs, labels, mode='tensor'):
        labels = torch.stack(labels)
        outputs = self.linear1(batch_inputs)
        outputs = self.linear2(outputs)

        if mode == 'tensor':
            return outputs
        elif mode == 'loss':
            loss = (labels - outputs).sum()
            outputs = dict(loss=loss)
            return outputs
        elif mode == 'predict':
            outputs = dict(log_vars=dict(a=1, b=0.5))
            return outputs

    def sample_subnet(self):
        return self.mutator.sample_choices()

    def set_subnet(self, subnet):
        self.mutator.set_choices(subnet)

    def export_fix_subnet(self):
        pass


@DATASETS.register_module()
class ToyDataset_GreedySamplerTrainLoop(Dataset):
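    """A 12-sample random dataset (2 features, constant labels), just
    enough to drive the train and val dataloaders."""
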
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])


@METRICS.register_module()
class ToyMetric_GreedySamplerTrainLoop(BaseMetric):
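    """A dummy metric that always reports ``acc=1`` so candidate scoring in
    the loop is deterministic."""
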
    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
        self.dummy_metrics = dummy_metrics

    def process(self, data_samples, predictions):
        result = {'acc': 1}
        self.results.append(result)

    def compute_metrics(self, results):
        return dict(acc=1)


class TestGreedySamplerTrainLoop(TestCase):
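    """Unit tests for ``GreedySamplerTrainLoop``, built through MMEngine's
    ``Runner`` from an iteration-based config."""
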
    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()

        val_dataloader = dict(
            dataset=dict(type='ToyDataset_GreedySamplerTrainLoop'),
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=3,
            num_workers=0)
        val_evaluator = dict(type='ToyMetric_GreedySamplerTrainLoop')

        iter_based_cfg = dict(
            default_scope='mmrazor',
            model=dict(type='ToyModel_GreedySamplerTrainLoop'),
            work_dir=self.temp_dir,
            train_dataloader=dict(
                dataset=dict(type='ToyDataset_GreedySamplerTrainLoop'),
                sampler=dict(type='InfiniteSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=val_dataloader,
            optim_wrapper=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
            val_evaluator=val_evaluator,
            train_cfg=dict(
                type='GreedySamplerTrainLoop',
                dataloader_val=val_dataloader,
                evaluator=val_evaluator,
                max_iters=12,
                val_interval=2,
                score_key='acc',
                flops_range=None,
                num_candidates=4,
                num_samples=2,
                top_k=2,
                prob_schedule='linear',
                schedule_start_iter=4,
                schedule_end_iter=10,
                init_prob=0.,
                max_prob=0.8),
            val_cfg=dict(),
            custom_hooks=[],
            default_hooks=dict(
                runtime_info=dict(type='RuntimeInfoHook'),
                timer=dict(type='IterTimerHook'),
                logger=dict(type='LoggerHook'),
                param_scheduler=dict(type='ParamSchedulerHook'),
                checkpoint=dict(
                    type='CheckpointHook', interval=1, by_epoch=False),
                sampler_seed=dict(type='DistSamplerSeedHook')),
            launcher='none',
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
        )
        self.iter_based_cfg = Config(iter_based_cfg)

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_init(self):
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_init_GreedySamplerTrainLoop'
        runner = Runner.from_cfg(cfg)
        loop = runner.build_train_loop(cfg.train_cfg)
        self.assertIsInstance(loop, GreedySamplerTrainLoop)

    def test_update_cur_prob(self):
        # prob_schedule = linear
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_update_cur_prob1'
        runner = Runner.from_cfg(cfg)
        loop = runner.build_train_loop(cfg.train_cfg)

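        # Before schedule_end_iter the sampling probability is still
        # ramping up; at or past it, it should saturate at max_prob.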
        loop.update_cur_prob(loop.schedule_end_iter - 1)
        self.assertGreater(loop.max_prob, loop.cur_prob)
        loop.update_cur_prob(loop.schedule_end_iter + 1)
        self.assertEqual(loop.max_prob, loop.cur_prob)

        # prob_schedule = consine (cosine schedule; the spelling follows
        # the option name expected by GreedySamplerTrainLoop)
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_update_cur_prob2'
        cfg.train_cfg.prob_schedule = 'consine'
        runner = Runner.from_cfg(cfg)
        loop = runner.build_train_loop(cfg.train_cfg)

        loop.update_cur_prob(loop.schedule_end_iter - 1)
        self.assertGreater(loop.max_prob, loop.cur_prob)
        loop.update_cur_prob(loop.schedule_end_iter + 1)
        self.assertEqual(loop.max_prob, loop.cur_prob)

    def test_sample_subnet(self):
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_sample_subnet'
        runner = Runner.from_cfg(cfg)
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
        loop = runner.build_train_loop(cfg.train_cfg)
        loop.cur_prob = loop.max_prob
        self.assertEqual(len(loop.top_k_candidates), 0)

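        # At a validation boundary with cur_prob at its maximum, the loop
        # refreshes its candidate pool and then consumes one of the top-k
        # candidates, so top_k - 1 should remain afterwards.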
        loop._iter = loop.val_interval
        subnet = loop.sample_subnet()
        self.assertEqual(subnet, fake_subnet)
        self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)

    @patch('mmrazor.models.subnet.FlopsEstimator.get_model_complexity_info')
    def test_run(self, mock_flops):
        # test run with flops_range=None
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_run1'
        runner = Runner.from_cfg(cfg)
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
        runner.train()

        self.assertEqual(runner.iter, runner.max_iters)
        self.assertTrue(
            os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl')))

        # test run with _check_constraints
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_run2'
        cfg.train_cfg.flops_range = (0, 100)
        runner = Runner.from_cfg(cfg)
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
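        # Mock the FLOPs estimate to 50.0 so the fake subnet falls inside
        # flops_range=(0, 100) and passes the constraint check.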
        mock_flops.return_value = (50., 1)
        fix_subnet = FixSubnet(modules=fake_subnet)
        runner.model.export_fix_subnet = MagicMock(return_value=fix_subnet)
        runner.train()

        self.assertEqual(runner.iter, runner.max_iters)
        self.assertTrue(
            os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl')))