Refactor subnet sampler

pull/198/head
humu789 2022-07-11 02:44:18 +00:00 committed by pppppM
parent 1e3f8e9f67
commit 2d5e8bc675
7 changed files with 568 additions and 4 deletions

View File

@@ -2,8 +2,9 @@
from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
from .distill_val_loop import SingleTeacherDistillValLoop
from .evolution_search_loop import EvolutionSearchLoop
from .subnet_sampler_loop import GreedySamplerTrainLoop
__all__ = [
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'EvolutionSearchLoop'
'DartsIterBasedTrainLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop'
]
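
Note: once exported here and registered in LOOPS, GreedySamplerTrainLoop can
be selected through a runner config's `train_cfg`. A minimal sketch (keys
mirror the test config at the end of this diff; the dataloader and evaluator
values are placeholders):

train_cfg = dict(
    type='GreedySamplerTrainLoop',
    dataloader_val=val_dataloader,  # dict or DataLoader
    evaluator=val_evaluator,  # dict, list or Evaluator
    max_iters=144360,
    val_interval=1000,
    score_key='accuracy_top-1',
    num_candidates=1000,
    num_samples=10,
    top_k=5)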

View File

@@ -0,0 +1,335 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import random
from abc import abstractmethod
from typing import Dict, List, Optional, Sequence, Tuple, Union
import mmcv
import torch
from mmengine.evaluator import Evaluator
from mmengine.runner import IterBasedTrainLoop
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader
from mmrazor.models.subnet import (MULTI_MUTATORS_RANDOM_SUBNET,
SINGLE_MUTATOR_RANDOM_SUBNET, Candidates,
FlopsEstimator)
from mmrazor.registry import LOOPS
random_subnet_type = Union[SINGLE_MUTATOR_RANDOM_SUBNET,
MULTI_MUTATORS_RANDOM_SUBNET]
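# For orientation (mirroring the toy subnets used in the tests later in this
# diff), a SINGLE_MUTATOR_RANDOM_SUBNET is a mapping from group id to the
# sampled choice, e.g. {'1': 'choice1', '2': 'choice2'};
# MULTI_MUTATORS_RANDOM_SUBNET presumably bundles one such mapping per
# mutator.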
class BaseSamplerTrainLoop(IterBasedTrainLoop):
"""IterBasedTrainLoop for base sampler.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader for training the model.
max_iters (int): Total training iters.
val_begin (int): The iteration that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1000.
"""
def __init__(self,
runner,
dataloader: Union[Dict, DataLoader],
max_iters: int,
val_begin: int = 1,
val_interval: int = 1000):
super().__init__(runner, dataloader, max_iters, val_begin,
val_interval)
if self.runner.distributed:
self.model = runner.model.module
else:
self.model = runner.model
@abstractmethod
def sample_subnet(self) -> random_subnet_type:
"""Sample a subnet to train the supernet."""
def run_iter(self, data_batch: Sequence[dict]) -> None:
"""Iterate one mini-batch.
Args:
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self.runner.call_hook(
'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
# Enable gradient accumulation mode and avoid unnecessary gradient
# synchronization during gradient accumulation process.
# outputs should be a dict of loss.
subnet = self.sample_subnet()
self.model.set_subnet(subnet)
outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.message_hub.update_info('train_logs', outputs)
self.runner.call_hook(
'after_train_iter',
batch_idx=self._iter,
data_batch=data_batch,
outputs=outputs)
self._iter += 1
@LOOPS.register_module()
class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
"""IterBasedTrainLoop for greedy sampler.
In GreedySamplerTrainLoop, `Greedy` means that only some top sampled
candidates are used to train the supernet. So GreedySamplerTrainLoop
mainly picks the top candidates based on their val scores, then uses
them to train the supernet one by one.
Steps:
1. Sample from the supernet and the candidates.
2. Validate these sampled candidates to get each candidate's score.
3. Get top-k candidates based on their scores, then use them to train
the supernet one by one.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader for training the model.
dataloader_val (Dataloader or dict): A dataloader object or a dict to
build a dataloader for evaluating the candidates.
evaluator (Evaluator or dict or list): Used for computing metrics.
max_iters (int): Total training iters.
val_begin (int): The iteration that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1000.
score_key (str): Specify one metric in evaluation results to score
candidates. Defaults to 'accuracy_top-1'.
flops_range (tuple, optional): FLOPs range used to screen candidates.
Defaults to (0., 330 * 1e6).
num_candidates (int): The maximum number of kept candidates, which
consist of subnets sampled from the supernet and from the candidate
pool itself. Defaults to 1000.
num_samples (int): The number of subnets sampled in each sampling step.
Defaults to 10.
top_k (int): The number of top-scoring subnets chosen from the
candidates to train the supernet. Defaults to 5.
prob_schedule (str): The schedule that generates the probability of
sampling from the candidates. The probability increases from
init_prob to max_prob during [schedule_start_iter,
schedule_end_iter]. Both the 'linear' schedule and the 'cosine'
schedule are supported. Defaults to 'linear'.
schedule_start_iter (int): The start iter of the prob_schedule.
Defaults to 10000, which corresponds to batch_size 1024. You should
adapt it to your batch_size.
schedule_end_iter (int): The end iter of the prob_schedule.
Defaults to 144360 (120 epochs * 1203 iters/epoch at batch_size
1024). You should adapt it to your batch_size and the total number
of training epochs.
init_prob (float): The initial probability of the prob_schedule.
Defaults to 0.0.
max_prob (float): The max probability of the prob_schedule.
Defaults to 0.8.
"""
def __init__(self,
runner,
dataloader: Union[Dict, DataLoader],
dataloader_val: Union[Dict, DataLoader],
evaluator: Union[Evaluator, Dict, List],
max_iters: int,
val_begin: int = 1,
val_interval: int = 1000,
score_key: str = 'accuracy_top-1',
flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
num_candidates: int = 1000,
num_samples: int = 10,
top_k: int = 5,
prob_schedule: str = 'linear',
schedule_start_iter: int = 10000,
schedule_end_iter: int = 144360,
init_prob: float = 0.,
max_prob: float = 0.8) -> None:
super().__init__(runner, dataloader, max_iters, val_begin,
val_interval)
if isinstance(dataloader_val, dict):
self.dataloader_val = runner.build_dataloader(
dataloader_val, seed=runner.seed)
else:
self.dataloader_val = dataloader_val
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
self.evaluator = runner.build_evaluator(evaluator)
else:
self.evaluator = evaluator
self.score_key = score_key
self.flops_range = flops_range
self.num_candidates = num_candidates
self.num_samples = num_samples
self.top_k = top_k
assert prob_schedule in ['linear', 'cosine']
self.prob_schedule = prob_schedule
self.schedule_start_iter = schedule_start_iter
self.schedule_end_iter = schedule_end_iter
self.init_prob = init_prob
self.max_prob = max_prob
self.cur_prob: float = 0.
self.candidates = Candidates()
self.top_k_candidates = Candidates()
def run(self) -> None:
"""Launch training."""
self.runner.call_hook('before_train')
# In iteration-based training loop, we treat the whole training process
# as a big epoch and execute the corresponding hook.
self.runner.call_hook('before_train_epoch')
while self._iter < self._max_iters:
self.runner.model.train()
data_batch = next(self.dataloader_iterator)
self.run_iter(data_batch)
if (self.runner.val_loop is not None
and self._iter >= self.val_begin
and self._iter % self.val_interval == 0):
self.runner.val_loop.run()
self._save_candidates()
self.runner.call_hook('after_train_epoch')
self.runner.call_hook('after_train')
def sample_subnet(self) -> random_subnet_type:
"""Sample a subnet from top_k candidates one by one, then to train the
surpernet with the subnet.
Steps:
1. Update and get the `top_k_candidates`.
1.1. Update the prob of sampling from the `candidates` based on
the `prob_schedule` and the current iter.
1.2. Sample `num_samples` candidates from the supernet and the
`candidates` based on the updated prob (step 1.1).
1.3. Validate all candidates, including those sampled in
step 1.2, to get their scores.
1.4. Update the `top_k_candidates` based on their
scores (step 1.3).
2. Pop from the `top_k_candidates` one by one to train
the supernet.
"""
if len(self.top_k_candidates) == 0:
self.update_cur_prob(cur_iter=self._iter)
sampled_candidates, num_sample_from_supernet = \
self.get_candidates_with_sample(num_samples=self.num_samples)
self.candidates.extend(sampled_candidates)
self.update_candidates_scores()
self.candidates.sort(key=lambda x: x[1], reverse=True)
self.candidates = Candidates(self.candidates[:self.num_candidates])
self.top_k_candidates = Candidates(self.candidates[:self.top_k])
top1_score = self.top_k_candidates.scores[0]
if (self._iter % self.val_interval) < self.top_k:
self.runner.logger.info(
f'GreedySampler: [{self._iter:>6d}] '
f'prob {self.cur_prob:.3f} '
f'num_sample_from_supernet '
f'{num_sample_from_supernet}/{self.num_samples} '
f'top1_score {top1_score:.3f} '
f'cur_num_candidates: {len(self.candidates)}')
return self.top_k_candidates.pop(0)[0]
def update_cur_prob(self, cur_iter: int) -> None:
"""update current probablity of sampling from the candidates, which is
generated based on the probablity strategy and current iter."""
if cur_iter > self.schedule_end_iter:
self.cur_prob = self.max_prob
elif cur_iter < self.schedule_start_iter:
self.cur_prob = self.init_prob
else:
schedule_all_steps = self.schedule_end_iter - \
self.schedule_start_iter
schedule_cur_steps = cur_iter - self.schedule_start_iter
if self.prob_schedule == 'linear':
tmp = self.max_prob - self.init_prob
self.cur_prob = self.init_prob + \
tmp / schedule_all_steps * schedule_cur_steps
elif self.prob_schedule == 'cosine':
# Ramp up from init_prob to max_prob along a half-cosine curve,
# as documented in the class docstring.
tmp_1 = (self.max_prob - self.init_prob) * 0.5
tmp_2 = math.pi * schedule_cur_steps
tmp_3 = schedule_all_steps
self.cur_prob = tmp_1 * (1 - math.cos(tmp_2 / tmp_3)) \
+ self.init_prob
else:
raise ValueError('`prob_schedule` is invalid, it should be \
one of `linear` and `cosine`.')
def get_candidates_with_sample(self,
num_samples: int) -> Tuple[Candidates, int]:
"""Get candidates with sampling from supernet and the candidates based
on the current probablity."""
num_sample_from_supernet = 0
sampled_candidates = Candidates()
for _ in range(num_samples):
if random.random() >= self.cur_prob or len(self.candidates) == 0:
subnet = self._sample_from_supernet()
if self._check_constraints(subnet):
sampled_candidates.append(subnet)
num_sample_from_supernet += 1
else:
sampled_candidates.append(self._sample_from_candidates())
return sampled_candidates, num_sample_from_supernet
def update_candidates_scores(self) -> None:
"""Update candidates' scores, which are validated with the
`dataloader_val`."""
for i, candidate in enumerate(self.candidates.subnets):
self.model.set_subnet(candidate)
metrics = self._val_candidate()
score = metrics[self.score_key] if len(metrics) != 0 else 0.
self.candidates.set_score(i, score)
@torch.no_grad()
def _val_candidate(self) -> Dict:
"""Run validation."""
self.runner.model.eval()
for data_batch in self.dataloader_val:
outputs = self.runner.model.val_step(data_batch)
self.evaluator.process(data_batch, outputs)
metrics = self.evaluator.evaluate(len(self.dataloader_val.dataset))
return metrics
def _sample_from_supernet(self) -> random_subnet_type:
"""Sample from the supernet."""
subnet = self.model.sample_subnet()
return subnet
def _sample_from_candidates(self) -> random_subnet_type:
"""Sample from the candidates."""
assert len(self.candidates) > 0
subnet = random.choice(self.candidates)
return subnet
def _check_constraints(self, random_subnet: random_subnet_type) -> bool:
"""Check whether is beyond constraints.
Returns:
bool: The result of checking.
"""
if self.flops_range is None:
return True
self.model.set_subnet(random_subnet)
fix_subnet = self.model.export_fix_subnet()
flops = FlopsEstimator.get_model_complexity_info(
self.model, fix_subnet=fix_subnet, as_strings=False)[0]
if self.flops_range[0] < flops < self.flops_range[1]:
return True
else:
return False
def _save_candidates(self) -> None:
"""Save the candidates to init the next searching."""
save_path = os.path.join(self.runner.work_dir, 'candidates.pkl')
mmcv.fileio.dump(self.candidates, save_path)
self.runner.logger.info(f'candidates.pkl saved in '
f'{self.runner.work_dir}')
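
To sanity-check the schedule used by update_cur_prob above, here is a
minimal, self-contained sketch (not part of this commit; the function name
and defaults are illustrative) of the documented ramp from init_prob to
max_prob over [schedule_start_iter, schedule_end_iter]:

import math

def scheduled_prob(cur_iter, start=10000, end=144360,
                   init_prob=0., max_prob=0.8, schedule='linear'):
    """Probability of sampling from the candidate pool at `cur_iter`."""
    if cur_iter <= start:
        return init_prob
    if cur_iter >= end:
        return max_prob
    t = (cur_iter - start) / (end - start)  # progress in [0, 1]
    if schedule == 'linear':
        return init_prob + (max_prob - init_prob) * t
    # 'cosine': slow near both endpoints, fastest in the middle
    return init_prob + (max_prob - init_prob) * 0.5 * (
        1 - math.cos(math.pi * t))

For example, with the defaults, the midpoint iter 77180 gives t = 0.5 and a
linear probability of 0.4.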

View File

@@ -22,6 +22,8 @@ def crossover(random_subnet1: SINGLE_MUTATOR_RANDOM_SUBNET,
Returns:
SINGLE_MUTATOR_RANDOM_SUBNET: The result of crossover.
"""
assert prob >= 0. and prob <= 1., \
'The probability of crossover has to be between 0 and 1'
crossover_subnet = copy.deepcopy(random_subnet1)
for group_id, choice in random_subnet2.items():
if np.random.random_sample() < prob:
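
The hunk above is truncated by the viewer right after the probability check.
For context, a minimal self-contained sketch of such a uniform crossover,
assuming (as the visible lines suggest) that the matching group choice from
random_subnet2 is swapped into the copy of random_subnet1:

import copy

import numpy as np

def crossover_sketch(subnet1, subnet2, prob=0.5):
    """Uniform crossover over choice groups (illustrative only)."""
    assert prob >= 0. and prob <= 1., \
        'The probability of crossover has to be between 0 and 1'
    crossover_subnet = copy.deepcopy(subnet1)
    for group_id, choice in subnet2.items():
        # With probability `prob`, take subnet2's choice for this group;
        # otherwise keep subnet1's choice.
        if np.random.random_sample() < prob:
            crossover_subnet[group_id] = choice
    return crossover_subnet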

View File

@@ -1 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import sys
import pytest
import torch
@@ -9,7 +10,9 @@ from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models import * # noqa: F401,F403
from mmrazor.models.mutables import * # noqa: F401,F403
from mmrazor.registry import MODELS
from .utils import MockMutable
sys.path.append('tests/test_models/test_architectures/test_backbones')
from utils import MockMutable # noqa: E402
_FIRST_STAGE_MUTABLE = dict(type='MockMutable', choices=['c1'])
_OTHER_STAGE_MUTABLE = dict(

View File

@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import sys
import tempfile
import pytest
@@ -12,7 +13,9 @@ from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models import * # noqa: F401,F403
from mmrazor.models.mutables import * # noqa: F401,F403
from mmrazor.registry import MODELS
from .utils import MockMutable
sys.path.append('tests/test_models/test_architectures/test_backbones')
from utils import MockMutable # noqa: E402
STAGE_MUTABLE = dict(type='MockMutable', choices=['c1', 'c2', 'c3', 'c4'])
ARCHSETTING_CFG = [

View File

@@ -0,0 +1,221 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import MagicMock, patch
import torch
import torch.nn as nn
from mmengine.config import Config
from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner
from torch.utils.data import Dataset
from mmrazor.models.subnet import FixSubnet
from mmrazor.registry import DATASETS, METRICS, MODELS
from mmrazor.runners import GreedySamplerTrainLoop # noqa: F401
@MODELS.register_module()
class ToyModel_GreedySamplerTrainLoop(BaseModel):
@patch('mmrazor.models.mutators.OneShotModuleMutator')
def __init__(self, mock_mutator):
super().__init__()
self.linear1 = nn.Linear(2, 2)
self.linear2 = nn.Linear(2, 1)
self.mutator = mock_mutator
def forward(self, batch_inputs, labels, mode='tensor'):
labels = torch.stack(labels)
outputs = self.linear1(batch_inputs)
outputs = self.linear2(outputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (labels - outputs).sum()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
outputs = dict(log_vars=dict(a=1, b=0.5))
return outputs
def sample_subnet(self):
return self.mutator.sample_choices()
def set_subnet(self, subnet):
self.mutator.set_choices(subnet)
def export_fix_subnet(self):
pass
@DATASETS.register_module()
class ToyDataset_GreedySamplerTrainLoop(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
label = torch.ones(12)
@property
def metainfo(self):
return self.METAINFO
def __len__(self):
return self.data.size(0)
def __getitem__(self, index):
return dict(inputs=self.data[index], data_sample=self.label[index])
@METRICS.register_module()
class ToyMetric_GreedySamplerTrainLoop(BaseMetric):
def __init__(self, collect_device='cpu', dummy_metrics=None):
super().__init__(collect_device=collect_device)
self.dummy_metrics = dummy_metrics
def process(self, data_samples, predictions):
result = {'acc': 1}
self.results.append(result)
def compute_metrics(self, results):
return dict(acc=1)
class TestGreedySamplerTrainLoop(TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
val_dataloader = dict(
dataset=dict(type='ToyDataset_GreedySamplerTrainLoop'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0)
val_evaluator = dict(type='ToyMetric_GreedySamplerTrainLoop')
iter_based_cfg = dict(
default_scope='mmrazor',
model=dict(type='ToyModel_GreedySamplerTrainLoop'),
work_dir=self.temp_dir,
train_dataloader=dict(
dataset=dict(type='ToyDataset_GreedySamplerTrainLoop'),
sampler=dict(type='InfiniteSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=val_dataloader,
optim_wrapper=dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
val_evaluator=val_evaluator,
train_cfg=dict(
type='GreedySamplerTrainLoop',
dataloader_val=val_dataloader,
evaluator=val_evaluator,
max_iters=12,
val_interval=2,
score_key='acc',
flops_range=None,
num_candidates=4,
num_samples=2,
top_k=2,
prob_schedule='linear',
schedule_start_iter=4,
schedule_end_iter=10,
init_prob=0.,
max_prob=0.8),
val_cfg=dict(),
custom_hooks=[],
default_hooks=dict(
runtime_info=dict(type='RuntimeInfoHook'),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', interval=1, by_epoch=False),
sampler_seed=dict(type='DistSamplerSeedHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
)
self.iter_based_cfg = Config(iter_based_cfg)
def tearDown(self):
shutil.rmtree(self.temp_dir)
def test_init(self):
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_init_GreedySamplerTrainLoop'
runner = Runner.from_cfg(cfg)
loop = runner.build_train_loop(cfg.train_cfg)
self.assertIsInstance(loop, GreedySamplerTrainLoop)
def test_update_cur_prob(self):
# prob_schedule = linear
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_update_cur_prob1'
runner = Runner.from_cfg(cfg)
loop = runner.build_train_loop(cfg.train_cfg)
loop.update_cur_prob(loop.schedule_end_iter - 1)
self.assertGreater(loop.max_prob, loop.cur_prob)
loop.update_cur_prob(loop.schedule_end_iter + 1)
self.assertEqual(loop.max_prob, loop.cur_prob)
# prob_schedule = cosine
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_update_cur_prob2'
cfg.train_cfg.prob_schedule = 'cosine'
runner = Runner.from_cfg(cfg)
loop = runner.build_train_loop(cfg.train_cfg)
loop.update_cur_prob(loop.schedule_end_iter - 1)
self.assertGreater(loop.max_prob, loop.cur_prob)
loop.update_cur_prob(loop.schedule_end_iter + 1)
self.assertEqual(loop.max_prob, loop.cur_prob)
def test_sample_subnet(self):
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_sample_subnet'
runner = Runner.from_cfg(cfg)
fake_subnet = {'1': 'choice1', '2': 'choice2'}
runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
loop = runner.build_train_loop(cfg.train_cfg)
loop.cur_prob = loop.max_prob
self.assertEqual(len(loop.top_k_candidates), 0)
loop._iter = loop.val_interval
subnet = loop.sample_subnet()
self.assertEqual(subnet, fake_subnet)
self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)
@patch('mmrazor.models.subnet.FlopsEstimator.get_model_complexity_info')
def test_run(self, mock_flops):
# test run with flops_range=None
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_run1'
runner = Runner.from_cfg(cfg)
fake_subnet = {'1': 'choice1', '2': 'choice2'}
runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
runner.train()
self.assertEqual(runner.iter, runner.max_iters)
assert os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl'))
# test run with _check_constraints
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_run2'
cfg.train_cfg.flops_range = (0, 100)
runner = Runner.from_cfg(cfg)
fake_subnet = {'1': 'choice1', '2': 'choice2'}
runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
mock_flops.return_value = (50., 1)
fix_subnet = FixSubnet(modules=fake_subnet)
runner.model.export_fix_subnet = MagicMock(return_value=fix_subnet)
runner.train()
self.assertEqual(runner.iter, runner.max_iters)
assert os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl'))