fix bug (#407)
* add resnet50 * fix bug * fix bug * fix bug * refine * fix bug * add choice and mask of units to checkpoint (#397) * add choice and mask of units to checkpoint * update * fix bug * remove device operation * fix bug * fix circle ci error * fix error in numpy for circle ci * fix bug in requirements * restore * add a note * a new solution * save mutable_channel.mask as float for dist training * refine * mv meta file test Co-authored-by: liukai <your_email@abc.example> Co-authored-by: jacky <jacky@xx.com> * fix bug * add assert * fix bug * change iter to epoch * bn_imp use abs Co-authored-by: jacky <jacky@xx.com> Co-authored-by: liukai <your_email@abc.example>pull/413/head
parent
a91e2c7de2
commit
122ee38d69
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
@ -8,7 +9,7 @@ import yaml
|
|||
MMRAZOR_ROOT = Path(__file__).absolute().parents[1]
|
||||
|
||||
|
||||
class TestMetafiles:
|
||||
class TestMetafiles(unittest.TestCase):
|
||||
|
||||
def get_metafiles(self, code_path):
|
||||
"""
|
||||
|
@ -51,3 +52,7 @@ class TestMetafiles:
|
|||
assert model['Name'] == correct_name, \
|
||||
f'name error in {metafile}, correct name should ' \
|
||||
f'be {correct_name}'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -0,0 +1,29 @@
|
|||
_base_ = ['mmcls::resnet/resnet50_8xb32_in1k.py']
|
||||
|
||||
data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}
|
||||
architecture = _base_.model
|
||||
architecture.update({
|
||||
'init_cfg': {
|
||||
'type':
|
||||
'Pretrained',
|
||||
'checkpoint':
|
||||
'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa
|
||||
}
|
||||
})
|
||||
|
||||
model = dict(
|
||||
_delete_=True,
|
||||
_scope_='mmrazor',
|
||||
type='ChexAlgorithm',
|
||||
architecture=architecture,
|
||||
mutator_cfg=dict(
|
||||
type='ChexMutator',
|
||||
channel_unit_cfg=dict(
|
||||
type='ChexUnit', default_args=dict(choice_mode='number', )),
|
||||
channel_ratio=0.7,
|
||||
),
|
||||
delta_t=2,
|
||||
total_steps=60,
|
||||
init_growth_rate=0.3,
|
||||
)
|
||||
custom_hooks = [{'type': 'mmrazor.ChexHook'}]
|
|
@ -204,23 +204,24 @@ class ItePruneAlgorithm(BaseAlgorithm):
|
|||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
mode: str = 'tensor') -> ForwardResults:
|
||||
"""Forward."""
|
||||
if not hasattr(self, 'prune_config_manager'):
|
||||
# self._iters_per_epoch() only available after initiation
|
||||
self.prune_config_manager = self._init_prune_config_manager()
|
||||
|
||||
if self.prune_config_manager.is_prune_time(self._iter):
|
||||
if self.training:
|
||||
if not hasattr(self, 'prune_config_manager'):
|
||||
# self._iters_per_epoch() only available after initiation
|
||||
self.prune_config_manager = self._init_prune_config_manager()
|
||||
if self.prune_config_manager.is_prune_time(self._iter):
|
||||
|
||||
config = self.prune_config_manager.prune_at(self._iter)
|
||||
config = self.prune_config_manager.prune_at(self._iter)
|
||||
|
||||
self.mutator.set_choices(config)
|
||||
self.mutator.set_choices(config)
|
||||
|
||||
logger = MMLogger.get_current_instance()
|
||||
if (self.by_epoch):
|
||||
logger.info(
|
||||
f'The model is pruned at {self._epoch}th epoch once.')
|
||||
else:
|
||||
logger.info(
|
||||
f'The model is pruned at {self._iter}th iter once.')
|
||||
logger = MMLogger.get_current_instance()
|
||||
if (self.by_epoch):
|
||||
logger.info(
|
||||
f'The model is pruned at {self._epoch}th epoch once.')
|
||||
else:
|
||||
logger.info(
|
||||
f'The model is pruned at {self._iter}th iter once.')
|
||||
|
||||
return super().forward(inputs, data_samples, mode)
|
||||
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .chex_algorithm import ChexAlgorithm
|
||||
from .chex_hook import ChexHook
|
||||
from .chex_mutator import ChexMutator
|
||||
from .chex_ops import ChexConv2d, ChexLinear, ChexMixin
|
||||
from .chex_unit import ChexUnit
|
||||
|
||||
__all__ = [
|
||||
'ChexAlgorithm', 'ChexMutator', 'ChexUnit', 'ChexConv2d', 'ChexLinear',
|
||||
'ChexMixin'
|
||||
'ChexMixin', 'ChexHook'
|
||||
]
|
||||
|
|
|
@ -1,12 +1,17 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import math
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine import dist
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.model.utils import convert_sync_batchnorm
|
||||
|
||||
from mmrazor.models.algorithms import BaseAlgorithm
|
||||
from mmrazor.registry import MODELS
|
||||
from mmrazor.utils import print_log
|
||||
from .chex_mutator import ChexMutator
|
||||
from .utils import RuntimeInfo
|
||||
|
||||
|
@ -26,6 +31,9 @@ class ChexAlgorithm(BaseAlgorithm):
|
|||
init_cfg: Optional[Dict] = None):
|
||||
super().__init__(architecture, data_preprocessor, init_cfg)
|
||||
|
||||
if dist.is_distributed():
|
||||
self.architecture = convert_sync_batchnorm(self.architecture)
|
||||
|
||||
self.delta_t = delta_t
|
||||
self.total_steps = total_steps
|
||||
self.init_growth_rate = init_growth_rate
|
||||
|
@ -35,17 +43,31 @@ class ChexAlgorithm(BaseAlgorithm):
|
|||
|
||||
def forward(self, inputs, data_samples=None, mode: str = 'tensor'):
|
||||
if self.training: #
|
||||
if RuntimeInfo.iter() % self.delta_t == 0 and \
|
||||
RuntimeInfo.iter() // self.delta_t < self.total_steps:
|
||||
self.mutator.prune()
|
||||
self.mutator.grow(self.growth_ratio)
|
||||
if RuntimeInfo.epoch() % self.delta_t == 0 and \
|
||||
RuntimeInfo.epoch() < self.total_steps and \
|
||||
RuntimeInfo.iter_by_epoch() == 0:
|
||||
with torch.no_grad():
|
||||
self.mutator.prune()
|
||||
print_log(f'prune model with {self.mutator.channel_ratio}')
|
||||
self.log_choices()
|
||||
|
||||
self.mutator.grow(self.growth_ratio)
|
||||
print_log(f'grow model with {self.growth_ratio}')
|
||||
self.log_choices()
|
||||
return super().forward(inputs, data_samples, mode)
|
||||
|
||||
@property
|
||||
def growth_ratio(self):
|
||||
# return growth ratio in current epoch
|
||||
def cos():
|
||||
a = math.pi * RuntimeInfo.epoch() / RuntimeInfo.max_epochs()
|
||||
a = math.pi * RuntimeInfo.epoch() / self.total_steps
|
||||
return (math.cos(a) + 1) / 2
|
||||
|
||||
return self.init_growth_rate * cos()
|
||||
|
||||
def log_choices(self):
|
||||
if dist.get_rank() == 0:
|
||||
config = {}
|
||||
for unit in self.mutator.mutable_units:
|
||||
config[unit.name] = unit.current_choice
|
||||
print_log(json.dumps(config, indent=4))
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.hooks import Hook
|
||||
|
||||
from mmrazor.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class ChexHook(Hook):
|
||||
pass
|
||||
# @classmethod
|
||||
# def algorithm(cls, runner):
|
||||
# if dist.is_distributed():
|
||||
# return runner.model.module
|
||||
# else:
|
||||
# return runner.model
|
||||
|
||||
# def before_val(self, runner) -> None:
|
||||
# algorithm = self.algorithm(runner)
|
||||
# if dist.get_rank() == 0:
|
||||
# config = {}
|
||||
# for unit in algorithm.mutator.mutable_units:
|
||||
# config[unit.name] = unit.current_choice
|
||||
# print_log(json.dumps(config, indent=4))
|
||||
# print_log(f'growth_ratio: {algorithm.growth_ratio}')
|
|
@ -29,9 +29,10 @@ class ChexMutator(ChannelMutator):
|
|||
step1: get pruning structure
|
||||
step2: prune based on ChexMixin.prune_imp
|
||||
"""
|
||||
choices = self._get_prune_choices()
|
||||
for unit in self.mutable_units:
|
||||
unit.prune(choices[unit.name])
|
||||
with torch.no_grad():
|
||||
choices = self._get_prune_choices()
|
||||
for unit in self.mutable_units:
|
||||
unit.prune(choices[unit.name])
|
||||
|
||||
def grow(self, growth_ratio=0.0):
|
||||
"""Make the model grow.
|
||||
|
@ -60,9 +61,16 @@ class ChexMutator(ChannelMutator):
|
|||
unit: ChexUnit
|
||||
bn_imps[unit.name] = unit.bn_imp
|
||||
bn_imp: torch.Tensor = torch.cat(list(bn_imps.values()), dim=0)
|
||||
num_remain = int(self.channel_ratio * len(bn_imp))
|
||||
threshold = bn_imp.topk(num_remain)[0][-1]
|
||||
|
||||
num_total_channel = len(bn_imp)
|
||||
num_min_remained = int(self.channel_ratio * num_total_channel)
|
||||
threshold = bn_imp.topk(num_min_remained)[0][-1]
|
||||
|
||||
num_remained = 0
|
||||
for unit in self.mutable_units:
|
||||
num = (bn_imps[unit.name] >= threshold).float().sum().long().item()
|
||||
choices[unit.name] = num
|
||||
num = (bn_imps[unit.name] >= threshold).long().sum().item()
|
||||
choices[unit.name] = max(num, 1)
|
||||
num_remained += choices[unit.name]
|
||||
assert num_remained >= num_min_remained, \
|
||||
f'{num_remained},{num_min_remained}'
|
||||
return choices
|
||||
|
|
|
@ -32,22 +32,18 @@ class ChexUnit(L1MutableChannelUnit):
|
|||
def prune(self, num_remaining):
|
||||
# prune the channels to num_remaining
|
||||
def get_prune_imp():
|
||||
prune_imp: torch.Tensor = torch.zeros([self.num_channels])
|
||||
prune_imp = 0
|
||||
for channel in self.chex_channels:
|
||||
module = channel.module
|
||||
prune_imp = prune_imp.to(
|
||||
module.prune_imp(num_remaining).device)
|
||||
prune_imp = prune_imp + module.prune_imp(
|
||||
num_remaining)[channel.start:channel.end]
|
||||
return prune_imp
|
||||
|
||||
prune_imp = get_prune_imp()
|
||||
index = prune_imp.topk(num_remaining)[1]
|
||||
mask: torch.Tensor = torch.zeros([self.num_channels],
|
||||
device=prune_imp.device)
|
||||
mask.scatter_(-1, index, 1.0)
|
||||
mask = mask.bool()
|
||||
self.mutable_channel.current_choice.data = mask
|
||||
with torch.no_grad():
|
||||
prune_imp = get_prune_imp()
|
||||
index = prune_imp.topk(num_remaining)[1]
|
||||
self.mutable_channel.mask.fill_(0.0)
|
||||
self.mutable_channel.mask.data.scatter_(-1, index, 1.0)
|
||||
|
||||
def grow(self, num):
|
||||
assert num >= 0
|
||||
|
@ -55,10 +51,9 @@ class ChexUnit(L1MutableChannelUnit):
|
|||
return
|
||||
|
||||
def get_growth_imp():
|
||||
growth_imp: torch.Tensor = torch.zeros([self.num_channels])
|
||||
growth_imp = 0
|
||||
for channel in self.chex_channels:
|
||||
module = channel.module
|
||||
growth_imp = growth_imp.to(module.growth_imp.device)
|
||||
growth_imp = growth_imp + module.growth_imp[channel.
|
||||
start:channel.end]
|
||||
return growth_imp
|
||||
|
@ -73,23 +68,22 @@ class ChexUnit(L1MutableChannelUnit):
|
|||
select_index = index_free[select_index]
|
||||
else:
|
||||
select_index = index_free
|
||||
mask.index_fill_(-1, select_index, 1.0)
|
||||
|
||||
self.mutable_channel.current_choice.data = mask
|
||||
self.mutable_channel.mask.index_fill_(-1, select_index, 1.0)
|
||||
|
||||
@property
|
||||
def bn_imp(self):
|
||||
imp = torch.zeros([self.num_channels])
|
||||
num_layers = 0
|
||||
for channel in self.output_related:
|
||||
module = channel.module
|
||||
if isinstance(module, nn.modules.batchnorm._BatchNorm):
|
||||
imp = imp.to(module.weight.device)
|
||||
imp = imp + module.weight[channel.start:channel.end]
|
||||
num_layers += 1
|
||||
assert num_layers > 0
|
||||
imp = imp / num_layers
|
||||
return imp
|
||||
with torch.no_grad():
|
||||
imp = 0
|
||||
num_layers = 0
|
||||
for channel in self.output_related:
|
||||
module = channel.module
|
||||
if isinstance(module, nn.modules.batchnorm._BatchNorm):
|
||||
imp = imp + module.weight[channel.start:channel.end].abs()
|
||||
num_layers += 1
|
||||
assert num_layers > 0
|
||||
imp = imp / num_layers
|
||||
return imp
|
||||
|
||||
@property
|
||||
def chex_channels(self):
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
from mmengine.logging import MessageHub
|
||||
|
||||
|
||||
|
@ -23,3 +25,12 @@ class RuntimeInfo():
|
|||
@classmethod
|
||||
def iter(cls):
|
||||
return cls.get_info('iter')
|
||||
|
||||
@classmethod
|
||||
def max_iters(cls):
|
||||
return cls.get_info('max_iters')
|
||||
|
||||
@classmethod
|
||||
def iter_by_epoch(cls):
|
||||
iter_per_epoch = math.ceil(cls.max_iters() / cls.max_epochs())
|
||||
return cls.iter() % iter_per_epoch
|
||||
|
|
|
@ -27,7 +27,6 @@ class SquentialMutableChannel(SimpleMutableChannel):
|
|||
super().__init__(num_channels, **kwargs)
|
||||
assert choice_mode in ['ratio', 'number']
|
||||
self.choice_mode = choice_mode
|
||||
self.mask = torch.ones([self.num_channels]).bool()
|
||||
|
||||
@property
|
||||
def is_num_mode(self):
|
||||
|
@ -50,14 +49,13 @@ class SquentialMutableChannel(SimpleMutableChannel):
|
|||
int_choice = self._ratio2num(choice)
|
||||
else:
|
||||
int_choice = choice
|
||||
mask = torch.zeros([self.num_channels], device=self.mask.device)
|
||||
mask[0:int_choice] = 1
|
||||
self.mask = mask.bool()
|
||||
self.mask.fill_(0.0)
|
||||
self.mask[0:int_choice] = 1.0
|
||||
|
||||
@property
|
||||
def current_mask(self) -> torch.Tensor:
|
||||
"""Return current mask."""
|
||||
return self.mask
|
||||
return self.mask.bool()
|
||||
|
||||
# methods for
|
||||
|
||||
|
|
|
@ -20,7 +20,10 @@ class SimpleMutableChannel(BaseMutableChannel):
|
|||
|
||||
def __init__(self, num_channels: int, **kwargs) -> None:
|
||||
super().__init__(num_channels, **kwargs)
|
||||
self.mask = torch.ones(num_channels).bool()
|
||||
mask = torch.ones([self.num_channels
|
||||
]) # save bool as float for dist training
|
||||
self.register_buffer('mask', mask)
|
||||
self.mask: torch.Tensor
|
||||
|
||||
# choice
|
||||
|
||||
|
@ -32,7 +35,7 @@ class SimpleMutableChannel(BaseMutableChannel):
|
|||
@current_choice.setter
|
||||
def current_choice(self, choice: torch.Tensor):
|
||||
"""Set current choice."""
|
||||
self.mask = choice.to(self.mask.device).bool()
|
||||
self.mask = choice.to(self.mask.device).float()
|
||||
|
||||
@property
|
||||
def current_mask(self) -> torch.Tensor:
|
||||
|
|
|
@ -4,6 +4,7 @@ interrogate
|
|||
isort==4.3.21
|
||||
nbconvert
|
||||
nbformat
|
||||
numpy < 1.24.0 # A temporary solution for tests with mmdet.
|
||||
pytest
|
||||
xdoctest >= 0.10.0
|
||||
yapf
|
||||
|
|
|
@ -261,3 +261,30 @@ class TestItePruneAlgorithm(unittest.TestCase):
|
|||
algorithm.forward(
|
||||
data['inputs'], data['data_samples'], mode='loss')
|
||||
self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch)
|
||||
|
||||
def test_resume(self):
|
||||
algorithm: ItePruneAlgorithm = ItePruneAlgorithm(
|
||||
MODEL_CFG,
|
||||
mutator_cfg=MUTATOR_CONFIG_NUM,
|
||||
target_pruning_ratio=None,
|
||||
step_freq=1,
|
||||
prune_times=1,
|
||||
).to(DEVICE)
|
||||
algorithm.mutator.set_choices(algorithm.mutator.sample_choices())
|
||||
state_dict = algorithm.state_dict()
|
||||
print(state_dict.keys())
|
||||
|
||||
algorithm2: ItePruneAlgorithm = ItePruneAlgorithm(
|
||||
MODEL_CFG,
|
||||
mutator_cfg=MUTATOR_CONFIG_NUM,
|
||||
target_pruning_ratio=None,
|
||||
step_freq=1,
|
||||
prune_times=1,
|
||||
).to(DEVICE)
|
||||
|
||||
algorithm2.load_state_dict(state_dict)
|
||||
|
||||
print(algorithm.mutator.current_choices)
|
||||
print(algorithm2.mutator.current_choices)
|
||||
self.assertDictEqual(algorithm.mutator.current_choices,
|
||||
algorithm2.mutator.current_choices)
|
||||
|
|
Loading…
Reference in New Issue