250 lines
9.3 KiB
Python
250 lines
9.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import shutil
|
|
import tempfile
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmengine.config import Config
|
|
from mmengine.hooks import Hook
|
|
from mmengine.model import BaseDataPreprocessor, BaseModel
|
|
from mmengine.runner import Runner
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
from mmrazor.engine import DartsEpochBasedTrainLoop # noqa: F401
|
|
from mmrazor.engine import DartsIterBasedTrainLoop # noqa: F401
|
|
from mmrazor.registry import DATASETS, HOOKS, MODELS
|
|
|
|
|
|
class ToyDataPreprocessor(BaseDataPreprocessor):
    """Toy data preprocessor used by the DARTS train-loop tests."""

    def collate_data(self, data):
        """Split a batch into inputs and optional data samples.

        Args:
            data: Iterable of indexable items; element ``0`` of each item is
                a mapping with an ``'inputs'`` entry and, optionally, a
                ``'data_sample'`` entry.

        Returns:
            tuple: ``(inputs, batch_data_samples)``. ``inputs`` is a list of
            the ``'inputs'`` entries moved to ``self._device``.
            ``batch_data_samples`` is the list of ``'data_sample'`` entries
            moved to ``self._device``, or ``None`` when no item has one.
        """
        samples = [item[0] for item in data]
        inputs = [sample['inputs'].to(self._device) for sample in samples]

        # The model can produce predictions without any data samples, so the
        # 'data_sample' entry is optional; present ones are moved from CPU to
        # the corresponding device.
        moved_samples = [
            sample['data_sample'].to(self._device) for sample in samples
            if 'data_sample' in sample
        ]

        # An empty list is reported as None.
        return inputs, (moved_samples or None)
|
|
|
|
|
|
@MODELS.register_module()
class ToyModel_DartsLoop(BaseModel):
    """Two-layer linear toy model registered for the DARTS loop tests."""

    def __init__(self, data_preprocessor=None):
        """Build the toy model.

        Args:
            data_preprocessor: Forwarded to ``BaseModel.__init__``. When
                ``None`` (the default) a fresh ``ToyDataPreprocessor`` is
                created for this model, so models never share a single
                preprocessor instance built at function-definition time.
        """
        if data_preprocessor is None:
            data_preprocessor = ToyDataPreprocessor()
        super().__init__(data_preprocessor=data_preprocessor)
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, batch_inputs, labels, mode='tensor'):
        """Run the two linear layers and format the result for ``mode``.

        Args:
            batch_inputs: Tensor fed to ``linear1`` (presumably of shape
                ``(N, 2)`` -- TODO confirm against the preprocessor output).
            labels: Sequence of tensors stacked for the loss. Only used when
                ``mode == 'loss'``; may be ``None`` in the other modes.
            mode: One of ``'tensor'``, ``'loss'`` or ``'predict'``.

        Returns:
            The raw output tensor for ``'tensor'``, ``dict(loss=...)`` for
            ``'loss'``, a fixed ``dict(log_vars=...)`` for ``'predict'`` and
            ``None`` for any other mode.
        """
        outputs = self.linear2(self.linear1(batch_inputs))

        if mode == 'tensor':
            return outputs
        if mode == 'loss':
            # Labels are only needed here; stacking them unconditionally
            # would fail when no data samples accompany the batch.
            stacked_labels = torch.stack(labels)
            return dict(loss=(stacked_labels - outputs).sum())
        if mode == 'predict':
            return dict(log_vars=dict(a=1, b=0.5))
        return None
|
|
|
|
|
|
@DATASETS.register_module()
class ToyDataset_DartsLoop(Dataset):
    """Fixed 12-sample toy dataset registered for the DARTS loop tests."""

    METAINFO = {}  # type: ignore
    # 12 random 2-element inputs, drawn once when the class body executes.
    data = torch.randn(12, 2)
    # One all-ones label per input row.
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the ``METAINFO`` mapping."""
        return self.METAINFO

    def __len__(self):
        """Return the number of rows in ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the input row and its label as a dict."""
        sample = {'inputs': self.data[index]}
        sample['data_sample'] = self.label[index]
        return sample
|
|
|
|
|
|
class TestDartsLoop(TestCase):
    """Tests for ``DartsEpochBasedTrainLoop`` / ``DartsIterBasedTrainLoop``."""

    def setUp(self):
        """Create a temporary work dir and the epoch-/iter-based configs."""
        self.temp_dir = tempfile.mkdtemp()
        epoch_based_cfg = dict(
            default_scope='mmrazor',
            model=dict(type='ToyModel_DartsLoop'),
            work_dir=self.temp_dir,
            train_dataloader=dict(
                dataset=dict(type='ToyDataset_DartsLoop'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            optim_wrapper=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
            train_cfg=dict(
                type='DartsEpochBasedTrainLoop',
                max_epochs=3,
                val_interval=1,
                val_begin=2),
            custom_hooks=[],
            default_hooks=dict(
                runtime_info=dict(type='RuntimeInfoHook'),
                timer=dict(type='IterTimerHook'),
                logger=dict(type='LoggerHook'),
                param_scheduler=dict(type='ParamSchedulerHook'),
                checkpoint=dict(
                    type='CheckpointHook', interval=1, by_epoch=True),
                sampler_seed=dict(type='DistSamplerSeedHook')),
            launcher='none',
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
        )
        self.epoch_based_cfg = Config(epoch_based_cfg)
        # The mutator dataloader reuses the training dataloader config.
        self.epoch_based_cfg.train_cfg['mutator_dataloader'] = \
            self.epoch_based_cfg.train_dataloader

        # The iter-based config starts as a copy of the epoch-based one and
        # overrides the dataloader, the train loop and the default hooks.
        self.iter_based_cfg = copy.deepcopy(self.epoch_based_cfg)
        self.iter_based_cfg.train_dataloader = dict(
            dataset=dict(type='ToyDataset_DartsLoop'),
            sampler=dict(type='InfiniteSampler', shuffle=True),
            batch_size=3,
            num_workers=0)
        self.iter_based_cfg.train_cfg = dict(
            type='DartsIterBasedTrainLoop',
            mutator_dataloader=self.iter_based_cfg.train_dataloader,
            max_iters=12,
            val_interval=4,
            val_begin=4)
        self.iter_based_cfg.default_hooks = dict(
            runtime_info=dict(type='RuntimeInfoHook'),
            timer=dict(type='IterTimerHook'),
            logger=dict(type='LoggerHook'),
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
            sampler_seed=dict(type='DistSamplerSeedHook'))

    def tearDown(self):
        """Remove the temporary work dir created in ``setUp``."""
        shutil.rmtree(self.temp_dir)

    def _assert_recorded_matches(self, results, targets):
        """Assert that ``results`` agrees element-wise with ``targets``.

        ``zip`` stops at the shorter sequence, so only the values that were
        actually recorded are compared; an empty ``results`` passes.
        """
        for result, target in zip(results, targets):
            self.assertEqual(result, target)

    def test_init(self):
        """Both loop types can be built from config with expected fields."""
        # 1. DartsEpochBasedTrainLoop
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_init1'
        runner = Runner.from_cfg(cfg)
        loop = runner.build_train_loop(cfg.train_cfg)

        self.assertIsInstance(loop, DartsEpochBasedTrainLoop)
        self.assertIsInstance(loop.runner, Runner)
        self.assertEqual(loop.max_epochs, 3)
        self.assertEqual(loop.max_iters, 12)
        self.assertIsInstance(loop.mutator_dataloader, DataLoader)

        # 2. DartsIterBasedTrainLoop
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_init2'
        runner = Runner.from_cfg(cfg)
        loop = runner.build_train_loop(cfg.train_cfg)

        self.assertIsInstance(loop, DartsIterBasedTrainLoop)
        self.assertIsInstance(loop.runner, Runner)
        self.assertEqual(loop.max_iters, 12)
        self.assertIsInstance(loop.mutator_dataloader, DataLoader)

    def test_run(self):
        """Training runs and the hooks observe the expected counters."""
        # 1. test DartsEpochBasedTrainLoop
        epoch_results = []
        epoch_targets = list(range(3))
        iter_results = []
        iter_targets = list(range(4 * 3))
        batch_idx_results = []
        batch_idx_targets = list(range(4)) * 3  # train and val
        val_epoch_results = []
        val_epoch_targets = list(range(2, 4))

        @HOOKS.register_module()
        class TestEpochHook(Hook):

            def before_train_epoch(self, runner):
                epoch_results.append(runner.epoch)

            def before_train_iter(self, runner, batch_idx, data_batch=None):
                iter_results.append(runner.iter)
                batch_idx_results.append(batch_idx)

            def before_val_epoch(self, runner):
                val_epoch_results.append(runner.epoch)

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train1'
        cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
        runner = Runner.from_cfg(cfg)
        runner.train()

        self.assertIsInstance(runner.train_loop, DartsEpochBasedTrainLoop)
        self._assert_recorded_matches(epoch_results, epoch_targets)
        self._assert_recorded_matches(iter_results, iter_targets)
        self._assert_recorded_matches(batch_idx_results, batch_idx_targets)
        self._assert_recorded_matches(val_epoch_results, val_epoch_targets)

        # 2. test DartsIterBasedTrainLoop
        # Rebind the recorders so the iter-based run starts from empty lists.
        epoch_results = []
        iter_results = []
        batch_idx_results = []
        val_iter_results = []
        val_batch_idx_results = []
        iter_targets = list(range(12))
        batch_idx_targets = list(range(12))
        val_iter_targets = list(range(4, 12))
        val_batch_idx_targets = list(range(4)) * 2

        @HOOKS.register_module()
        class TestIterHook(Hook):

            def before_train_epoch(self, runner):
                epoch_results.append(runner.epoch)

            def before_train_iter(self, runner, batch_idx, data_batch=None):
                iter_results.append(runner.iter)
                batch_idx_results.append(batch_idx)

            def before_val_iter(self, runner, batch_idx, data_batch=None):
                # Validation iterations belong in val_iter_results, the list
                # that is compared against val_iter_targets below.
                val_iter_results.append(runner.iter)
                val_batch_idx_results.append(batch_idx)

        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_train2'
        cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
        runner = Runner.from_cfg(cfg)
        runner.train()

        self.assertIsInstance(runner.train_loop, DartsIterBasedTrainLoop)
        self.assertEqual(len(epoch_results), 1)
        self.assertEqual(epoch_results[0], 0)
        self.assertEqual(runner.val_interval, 4)
        self.assertEqual(runner.val_begin, 4)
        self._assert_recorded_matches(iter_results, iter_targets)
        self._assert_recorded_matches(batch_idx_results, batch_idx_targets)
        self._assert_recorded_matches(val_iter_results, val_iter_targets)
        self._assert_recorded_matches(val_batch_idx_results,
                                      val_batch_idx_targets)
|