Refactor subnet sampler
parent 1e3f8e9f67
commit 2d5e8bc675

@@ -2,8 +2,9 @@
 from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
 from .distill_val_loop import SingleTeacherDistillValLoop
 from .evolution_search_loop import EvolutionSearchLoop
+from .subnet_sampler_loop import GreedySamplerTrainLoop

 __all__ = [
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
-    'DartsIterBasedTrainLoop', 'EvolutionSearchLoop'
+    'DartsIterBasedTrainLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop'
 ]

@@ -0,0 +1,335 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import random
from abc import abstractmethod
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import torch
from mmengine.evaluator import Evaluator
from mmengine.runner import IterBasedTrainLoop
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.subnet import (MULTI_MUTATORS_RANDOM_SUBNET,
                                   SINGLE_MUTATOR_RANDOM_SUBNET, Candidates,
                                   FlopsEstimator)
from mmrazor.registry import LOOPS

random_subnet_type = Union[SINGLE_MUTATOR_RANDOM_SUBNET,
                           MULTI_MUTATORS_RANDOM_SUBNET]

class BaseSamplerTrainLoop(IterBasedTrainLoop):
    """IterBasedTrainLoop for base sampler.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader for training the model.
        max_iters (int): Total training iters.
        val_begin (int): The iteration that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1000.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[Dict, DataLoader],
                 max_iters: int,
                 val_begin: int = 1,
                 val_interval: int = 1000):
        super().__init__(runner, dataloader, max_iters, val_begin,
                         val_interval)
        if self.runner.distributed:
            self.model = runner.model.module
        else:
            self.model = runner.model

    @abstractmethod
    def sample_subnet(self) -> random_subnet_type:
        """Sample a subnet to train the supernet."""

    def run_iter(self, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
        # Sample a subnet and activate it in the supernet before the
        # optimization step. `train_step` runs inside the optim_wrapper
        # context, which enables gradient accumulation and avoids unnecessary
        # gradient synchronization; `outputs` should be a dict of losses.
        subnet = self.sample_subnet()
        self.model.set_subnet(subnet)
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)
        self.runner.message_hub.update_info('train_logs', outputs)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=self._iter,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1


@LOOPS.register_module()
class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
    """IterBasedTrainLoop for greedy sampler.

    In GreedySamplerTrainLoop, `Greedy` means that only the top-scoring
    sampled candidates are used to train the supernet. So
    GreedySamplerTrainLoop mainly picks the top candidates based on their
    val scores, then uses them to train the supernet one by one.

    Steps:
        1. Sample from the supernet and the candidates.
        2. Validate these sampled candidates to get each candidate's score.
        3. Get top-k candidates based on their scores, then use them to train
            the supernet one by one.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader for training the model.
        dataloader_val (Dataloader or dict): A dataloader object or a dict to
            build a dataloader for evaluating the candidates.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        max_iters (int): Total training iters.
        val_begin (int): The iteration that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1000.
        score_key (str): Specify one metric in evaluation results to score
            candidates. Defaults to 'accuracy_top-1'.
        flops_range (tuple, optional): Flops range used for screening
            candidates. Defaults to (0., 330 * 1e6).
        num_candidates (int): The size of the candidate pool, which consists
            of subnets sampled from the supernet and from the pool itself.
            Defaults to 1000.
        num_samples (int): The number of subnets sampled in each sampling
            step. Defaults to 10.
        top_k (int): Choose top_k subnets from the candidates to train
            the supernet. Defaults to 5.
        prob_schedule (str): The schedule to generate the probability of
            sampling from the candidates. The probability will increase from
            init_prob to max_prob during [schedule_start_iter,
            schedule_end_iter]. Both the 'linear' schedule and the 'cosine'
            schedule are supported. Defaults to 'linear'.
        schedule_start_iter (int): The start iter of the prob_schedule.
            Defaults to 10000, which corresponds to a batch_size of 1024.
            You should adapt it to your batch_size.
        schedule_end_iter (int): The end iter of the prob_schedule.
            Defaults to 144360. 144360 = 120 (epochs) * 1203 (iters/epoch)
            with a batch_size of 1024. You should adapt it to your batch_size
            and the total training epochs.
        init_prob (float): The initial probability of the prob_schedule.
            Defaults to 0.0.
        max_prob (float): The max probability of the prob_schedule.
            Defaults to 0.8.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[Dict, DataLoader],
                 dataloader_val: Union[Dict, DataLoader],
                 evaluator: Union[Evaluator, Dict, List],
                 max_iters: int,
                 val_begin: int = 1,
                 val_interval: int = 1000,
                 score_key: str = 'accuracy_top-1',
                 flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
                 num_candidates: int = 1000,
                 num_samples: int = 10,
                 top_k: int = 5,
                 prob_schedule: str = 'linear',
                 schedule_start_iter: int = 10000,
                 schedule_end_iter: int = 144360,
                 init_prob: float = 0.,
                 max_prob: float = 0.8) -> None:
        super().__init__(runner, dataloader, max_iters, val_begin,
                         val_interval)
        if isinstance(dataloader_val, dict):
            self.dataloader_val = runner.build_dataloader(
                dataloader_val, seed=runner.seed)
        else:
            self.dataloader_val = dataloader_val

        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
            self.evaluator = runner.build_evaluator(evaluator)
        else:
            self.evaluator = evaluator

        self.score_key = score_key
        self.flops_range = flops_range
        self.num_candidates = num_candidates
        self.num_samples = num_samples
        self.top_k = top_k
        assert prob_schedule in ['linear', 'cosine']
        self.prob_schedule = prob_schedule
        self.schedule_start_iter = schedule_start_iter
        self.schedule_end_iter = schedule_end_iter
        self.init_prob = init_prob
        self.max_prob = max_prob
        self.cur_prob: float = 0.

        self.candidates = Candidates()
        self.top_k_candidates = Candidates()
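        # `Candidates` behaves like a list of (subnet, score) pairs:
        # `sample_subnet` below sorts by x[1] (the score) and `pop(0)[0]`
        # returns the subnet of the best remaining pair.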

    def run(self) -> None:
        """Launch training."""
        self.runner.call_hook('before_train')
        # In iteration-based training loop, we treat the whole training
        # process as a big epoch and execute the corresponding hook.
        self.runner.call_hook('before_train_epoch')
        while self._iter < self._max_iters:
            self.runner.model.train()

            data_batch = next(self.dataloader_iterator)
            self.run_iter(data_batch)

            if (self.runner.val_loop is not None
                    and self._iter >= self.runner.val_begin
                    and self._iter % self.runner.val_interval == 0):
                self.runner.val_loop.run()
                self._save_candidates()

        self.runner.call_hook('after_train_epoch')
        self.runner.call_hook('after_train')

    def sample_subnet(self) -> random_subnet_type:
        """Sample a subnet from the top_k candidates one by one, then train
        the supernet with the subnet.

        Steps:
            1. Update and get the `top_k_candidates`.
                1.1. Update the prob of sampling from the `candidates` based
                    on the `prob_schedule` and the current iter.
                1.2. Sample `num_samples` candidates from the supernet and
                    the `candidates` based on the updated prob (step 1.1).
                1.3. Val all candidates to get their scores, including the
                    sampled candidates (step 1.2).
                1.4. Update the `top_k_candidates` based on
                    their scores (step 1.3).
            2. Pop from the `top_k_candidates` one by one to train
                the supernet.
        """
        if len(self.top_k_candidates) == 0:
            self.update_cur_prob(cur_iter=self._iter)

            sampled_candidates, num_sample_from_supernet = \
                self.get_candidates_with_sample(num_samples=self.num_samples)

            self.candidates.extend(sampled_candidates)

            self.update_candidates_scores()

            self.candidates.sort(key=lambda x: x[1], reverse=True)
            self.candidates = Candidates(
                self.candidates[:self.num_candidates])
            self.top_k_candidates = Candidates(self.candidates[:self.top_k])

            top1_score = self.top_k_candidates.scores[0]
            if (self._iter % self.val_interval) < self.top_k:
                self.runner.logger.info(
                    f'GreedySampler: [{self._iter:>6d}] '
                    f'prob {self.cur_prob:.3f} '
                    f'num_sample_from_supernet '
                    f'{num_sample_from_supernet}/{self.num_samples} '
                    f'top1_score {top1_score:.3f} '
                    f'cur_num_candidates: {len(self.candidates)}')
        return self.top_k_candidates.pop(0)[0]
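
    # Each call pops one subnet off `top_k_candidates`; once all top-k
    # subnets have been trained for one iteration each, the branch above
    # resamples, re-scores and rebuilds the pool.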

    def update_cur_prob(self, cur_iter: int) -> None:
        """Update the current probability of sampling from the candidates,
        which is generated based on the probability schedule and the current
        iter."""
        if cur_iter > self.schedule_end_iter:
            self.cur_prob = self.max_prob
        elif cur_iter < self.schedule_start_iter:
            self.cur_prob = self.init_prob
        else:
            schedule_all_steps = self.schedule_end_iter - \
                self.schedule_start_iter
            schedule_cur_steps = cur_iter - self.schedule_start_iter
            if self.prob_schedule == 'linear':
                tmp = self.max_prob - self.init_prob
                self.cur_prob = self.init_prob + \
                    tmp / schedule_all_steps * schedule_cur_steps
            elif self.prob_schedule == 'cosine':
                tmp = (self.max_prob - self.init_prob) * 0.5
                self.cur_prob = self.init_prob + \
                    tmp * (1 - math.cos(math.pi * schedule_cur_steps /
                                        schedule_all_steps))
            else:
                raise ValueError('`prob_schedule` is invalid, it should be '
                                 'one of `linear` and `cosine`.')
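
    # Schedule shapes, with T = schedule_end_iter - schedule_start_iter and
    # t = cur_iter - schedule_start_iter:
    #   linear: p(t) = init_prob + (max_prob - init_prob) * t / T
    #   cosine: p(t) = init_prob + (max_prob - init_prob)
    #                  * (1 - cos(pi * t / T)) / 2
    # Both rise monotonically from init_prob to max_prob over the window.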

    def get_candidates_with_sample(self,
                                   num_samples: int) -> Tuple[Candidates, int]:
        """Get candidates by sampling from the supernet and the existing
        candidates based on the current probability."""
        num_sample_from_supernet = 0
        sampled_candidates = Candidates()
        for _ in range(num_samples):
            if random.random() >= self.cur_prob or len(self.candidates) == 0:
                subnet = self._sample_from_supernet()
                if self._check_constraints(subnet):
                    sampled_candidates.append(subnet)
                    num_sample_from_supernet += 1
            else:
                sampled_candidates.append(self._sample_from_candidates())
        return sampled_candidates, num_sample_from_supernet
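
    # With probability `cur_prob` a subnet is drawn from the candidate pool
    # (exploitation), otherwise from the whole supernet (exploration).
    # Supernet samples that violate the flops constraint are dropped, so
    # fewer than `num_samples` subnets may be returned.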

    def update_candidates_scores(self) -> None:
        """Update candidates' scores, which are validated with the
        `dataloader_val`."""
        for i, candidate in enumerate(self.candidates.subnets):
            self.model.set_subnet(candidate)
            metrics = self._val_candidate()
            score = metrics[self.score_key] if len(metrics) != 0 else 0.
            self.candidates.set_score(i, score)

    @torch.no_grad()
    def _val_candidate(self) -> Dict:
        """Run validation."""
        self.runner.model.eval()
        for data_batch in self.dataloader_val:
            outputs = self.runner.model.val_step(data_batch)
            self.evaluator.process(data_batch, outputs)
        metrics = self.evaluator.evaluate(len(self.dataloader_val.dataset))
        return metrics
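
    # mmengine Evaluator protocol: `process` accumulates per-batch results
    # and `evaluate(dataset_size)` reduces them into a metrics dict, from
    # which `score_key` picks the candidate's score.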

    def _sample_from_supernet(self) -> random_subnet_type:
        """Sample a subnet from the supernet."""
        subnet = self.model.sample_subnet()
        return subnet

    def _sample_from_candidates(self) -> random_subnet_type:
        """Sample a subnet from the candidates."""
        assert len(self.candidates) > 0
        subnet = random.choice(self.candidates)
        return subnet

    def _check_constraints(self, random_subnet: random_subnet_type) -> bool:
        """Check whether the subnet's flops are within `flops_range`.

        Returns:
            bool: The result of checking.
        """
        if self.flops_range is None:
            return True

        self.model.set_subnet(random_subnet)
        fix_subnet = self.model.export_fix_subnet()
        flops = FlopsEstimator.get_model_complexity_info(
            self.model, fix_subnet=fix_subnet, as_strings=False)[0]
        if self.flops_range[0] < flops < self.flops_range[1]:
            return True
        else:
            return False
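
    # The default flops_range, (0., 330 * 1e6), keeps subnets under roughly
    # 330M flops. Both bounds are strict, so boundary values are rejected.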

    def _save_candidates(self) -> None:
        """Save the candidates to init the next searching."""
        save_path = os.path.join(self.runner.work_dir, 'candidates.pkl')
        mmcv.fileio.dump(self.candidates, save_path)
        self.runner.logger.info(f'candidates.pkl saved in '
                                f'{self.runner.work_dir}')

@@ -22,6 +22,8 @@ def crossover(random_subnet1: SINGLE_MUTATOR_RANDOM_SUBNET,
     Returns:
         SINGLE_MUTATOR_RANDOM_SUBNET: The result of crossover.
     """
+    assert prob >= 0. and prob <= 1., \
+        'The probability of crossover has to be between 0 and 1'
     crossover_subnet = copy.deepcopy(random_subnet1)
     for group_id, choice in random_subnet2.items():
         if np.random.random_sample() < prob:

@@ -1 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import sys

 import pytest
 import torch

@@ -9,7 +10,9 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from mmrazor.models import *  # noqa: F401,F403
 from mmrazor.models.mutables import *  # noqa: F401,F403
 from mmrazor.registry import MODELS
-from .utils import MockMutable

+sys.path.append('tests/test_models/test_architectures/test_backbones')
+from utils import MockMutable  # noqa: E402

 _FIRST_STAGE_MUTABLE = dict(type='MockMutable', choices=['c1'])
 _OTHER_STAGE_MUTABLE = dict(

@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import os
+import sys
 import tempfile

 import pytest

@@ -12,7 +13,9 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from mmrazor.models import *  # noqa: F401,F403
 from mmrazor.models.mutables import *  # noqa: F401,F403
 from mmrazor.registry import MODELS
-from .utils import MockMutable

+sys.path.append('tests/test_models/test_architectures/test_backbones')
+from utils import MockMutable  # noqa: E402

 STAGE_MUTABLE = dict(type='MockMutable', choices=['c1', 'c2', 'c3', 'c4'])
 ARCHSETTING_CFG = [

@@ -0,0 +1,221 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import MagicMock, patch

import torch
import torch.nn as nn
from mmengine.config import Config
from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner
from torch.utils.data import Dataset

from mmrazor.models.subnet import FixSubnet
from mmrazor.registry import DATASETS, METRICS, MODELS
from mmrazor.runners import GreedySamplerTrainLoop  # noqa: F401


@MODELS.register_module()
class ToyModel_GreedySamplerTrainLoop(BaseModel):

    @patch('mmrazor.models.mutators.OneShotModuleMutator')
    def __init__(self, mock_mutator):
        super().__init__()
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)
        self.mutator = mock_mutator

    def forward(self, batch_inputs, labels, mode='tensor'):
        labels = torch.stack(labels)
        outputs = self.linear1(batch_inputs)
        outputs = self.linear2(outputs)

        if mode == 'tensor':
            return outputs
        elif mode == 'loss':
            loss = (labels - outputs).sum()
            outputs = dict(loss=loss)
            return outputs
        elif mode == 'predict':
            outputs = dict(log_vars=dict(a=1, b=0.5))
            return outputs

    def sample_subnet(self):
        return self.mutator.sample_choices()

    def set_subnet(self, subnet):
        self.mutator.set_choices(subnet)

    def export_fix_subnet(self):
        pass


@DATASETS.register_module()
class ToyDataset_GreedySamplerTrainLoop(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])


@METRICS.register_module()
class ToyMetric_GreedySamplerTrainLoop(BaseMetric):

    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
        self.dummy_metrics = dummy_metrics

    def process(self, data_samples, predictions):
        result = {'acc': 1}
        self.results.append(result)

    def compute_metrics(self, results):
        return dict(acc=1)


class TestGreedySamplerTrainLoop(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()

        val_dataloader = dict(
            dataset=dict(type='ToyDataset_GreedySamplerTrainLoop'),
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=3,
            num_workers=0)
        val_evaluator = dict(type='ToyMetric_GreedySamplerTrainLoop')

        iter_based_cfg = dict(
            default_scope='mmrazor',
            model=dict(type='ToyModel_GreedySamplerTrainLoop'),
            work_dir=self.temp_dir,
            train_dataloader=dict(
                dataset=dict(type='ToyDataset_GreedySamplerTrainLoop'),
                sampler=dict(type='InfiniteSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=val_dataloader,
            optim_wrapper=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
            val_evaluator=val_evaluator,
            train_cfg=dict(
                type='GreedySamplerTrainLoop',
                dataloader_val=val_dataloader,
                evaluator=val_evaluator,
                max_iters=12,
                val_interval=2,
                score_key='acc',
                flops_range=None,
                num_candidates=4,
                num_samples=2,
                top_k=2,
                prob_schedule='linear',
                schedule_start_iter=4,
                schedule_end_iter=10,
                init_prob=0.,
                max_prob=0.8),
            val_cfg=dict(),
            custom_hooks=[],
            default_hooks=dict(
                runtime_info=dict(type='RuntimeInfoHook'),
                timer=dict(type='IterTimerHook'),
                logger=dict(type='LoggerHook'),
                param_scheduler=dict(type='ParamSchedulerHook'),
                checkpoint=dict(
                    type='CheckpointHook', interval=1, by_epoch=False),
                sampler_seed=dict(type='DistSamplerSeedHook')),
            launcher='none',
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
        )
        self.iter_based_cfg = Config(iter_based_cfg)

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_init(self):
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_init_GreedySamplerTrainLoop'
        runner = Runner.from_cfg(cfg)
        loop = runner.build_train_loop(cfg.train_cfg)
        self.assertIsInstance(loop, GreedySamplerTrainLoop)

    def test_update_cur_prob(self):
        # prob_schedule = linear
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_update_cur_prob1'
        runner = Runner.from_cfg(cfg)
        loop = runner.build_train_loop(cfg.train_cfg)

        loop.update_cur_prob(loop.schedule_end_iter - 1)
        self.assertGreater(loop.max_prob, loop.cur_prob)
        loop.update_cur_prob(loop.schedule_end_iter + 1)
        self.assertEqual(loop.max_prob, loop.cur_prob)

        # prob_schedule = cosine
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_update_cur_prob2'
        cfg.train_cfg.prob_schedule = 'cosine'
        runner = Runner.from_cfg(cfg)
        loop = runner.build_train_loop(cfg.train_cfg)

        loop.update_cur_prob(loop.schedule_end_iter - 1)
        self.assertGreater(loop.max_prob, loop.cur_prob)
        loop.update_cur_prob(loop.schedule_end_iter + 1)
        self.assertEqual(loop.max_prob, loop.cur_prob)

    def test_sample_subnet(self):
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_sample_subnet'
        runner = Runner.from_cfg(cfg)
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
        loop = runner.build_train_loop(cfg.train_cfg)
        loop.cur_prob = loop.max_prob
        self.assertEqual(len(loop.top_k_candidates), 0)

        loop._iter = loop.val_interval
        subnet = loop.sample_subnet()
        self.assertEqual(subnet, fake_subnet)
        self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)
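        # sample_subnet pops one subnet off the freshly built top-k pool to
        # train with, hence top_k - 1 candidates remain.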

    @patch('mmrazor.models.subnet.FlopsEstimator.get_model_complexity_info')
    def test_run(self, mock_flops):
        # test run with flops_range=None
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_run1'
        runner = Runner.from_cfg(cfg)
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
        runner.train()

        self.assertEqual(runner.iter, runner.max_iters)
        assert os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl'))

        # test run with _check_constraints
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_run2'
        cfg.train_cfg.flops_range = (0, 100)
        runner = Runner.from_cfg(cfg)
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
        mock_flops.return_value = (50., 1)
        fix_subnet = FixSubnet(modules=fake_subnet)
        runner.model.export_fix_subnet = MagicMock(return_value=fix_subnet)
        runner.train()

        self.assertEqual(runner.iter, runner.max_iters)
        assert os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl'))