* add resnet50

* fix bug

* fix bug

* fix bug

* refine

* fix bug

* add choice and mask of units to checkpoint (#397)

* add choice and mask of units to checkpoint

* update

* fix bug

* remove device operation

* fix bug

* fix circle ci error

* fix error in numpy for circle ci

* fix bug in requirements

* restore

* add a note

* a new solution

* save mutable_channel.mask as float for dist training

* refine

* mv meta file test

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: jacky <jacky@xx.com>

* fix bug

* add assert

* fix bug

* change iter to epoch

* bn_imp use abs

Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: liukai <your_email@abc.example>
pull/413/head
LKJacky 2022-12-22 17:41:10 +08:00 committed by GitHub
parent a91e2c7de2
commit 122ee38d69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 183 additions and 59 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from pathlib import Path
import requests
@ -8,7 +9,7 @@ import yaml
MMRAZOR_ROOT = Path(__file__).absolute().parents[1]
class TestMetafiles:
class TestMetafiles(unittest.TestCase):
def get_metafiles(self, code_path):
"""
@ -51,3 +52,7 @@ class TestMetafiles:
assert model['Name'] == correct_name, \
f'name error in {metafile}, correct name should ' \
f'be {correct_name}'
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,29 @@
_base_ = ['mmcls::resnet/resnet50_8xb32_in1k.py']
data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}
architecture = _base_.model
architecture.update({
'init_cfg': {
'type':
'Pretrained',
'checkpoint':
'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa
}
})
model = dict(
_delete_=True,
_scope_='mmrazor',
type='ChexAlgorithm',
architecture=architecture,
mutator_cfg=dict(
type='ChexMutator',
channel_unit_cfg=dict(
type='ChexUnit', default_args=dict(choice_mode='number', )),
channel_ratio=0.7,
),
delta_t=2,
total_steps=60,
init_growth_rate=0.3,
)
custom_hooks = [{'type': 'mmrazor.ChexHook'}]

View File

@ -204,23 +204,24 @@ class ItePruneAlgorithm(BaseAlgorithm):
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'tensor') -> ForwardResults:
"""Forward."""
if not hasattr(self, 'prune_config_manager'):
# self._iters_per_epoch() only available after initiation
self.prune_config_manager = self._init_prune_config_manager()
if self.prune_config_manager.is_prune_time(self._iter):
if self.training:
if not hasattr(self, 'prune_config_manager'):
# self._iters_per_epoch() only available after initiation
self.prune_config_manager = self._init_prune_config_manager()
if self.prune_config_manager.is_prune_time(self._iter):
config = self.prune_config_manager.prune_at(self._iter)
config = self.prune_config_manager.prune_at(self._iter)
self.mutator.set_choices(config)
self.mutator.set_choices(config)
logger = MMLogger.get_current_instance()
if (self.by_epoch):
logger.info(
f'The model is pruned at {self._epoch}th epoch once.')
else:
logger.info(
f'The model is pruned at {self._iter}th iter once.')
logger = MMLogger.get_current_instance()
if (self.by_epoch):
logger.info(
f'The model is pruned at {self._epoch}th epoch once.')
else:
logger.info(
f'The model is pruned at {self._iter}th iter once.')
return super().forward(inputs, data_samples, mode)

View File

@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .chex_algorithm import ChexAlgorithm
from .chex_hook import ChexHook
from .chex_mutator import ChexMutator
from .chex_ops import ChexConv2d, ChexLinear, ChexMixin
from .chex_unit import ChexUnit
__all__ = [
'ChexAlgorithm', 'ChexMutator', 'ChexUnit', 'ChexConv2d', 'ChexLinear',
'ChexMixin'
'ChexMixin', 'ChexHook'
]

View File

@ -1,12 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import math
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
from mmengine import dist
from mmengine.model import BaseModel
from mmengine.model.utils import convert_sync_batchnorm
from mmrazor.models.algorithms import BaseAlgorithm
from mmrazor.registry import MODELS
from mmrazor.utils import print_log
from .chex_mutator import ChexMutator
from .utils import RuntimeInfo
@ -26,6 +31,9 @@ class ChexAlgorithm(BaseAlgorithm):
init_cfg: Optional[Dict] = None):
super().__init__(architecture, data_preprocessor, init_cfg)
if dist.is_distributed():
self.architecture = convert_sync_batchnorm(self.architecture)
self.delta_t = delta_t
self.total_steps = total_steps
self.init_growth_rate = init_growth_rate
@ -35,17 +43,31 @@ class ChexAlgorithm(BaseAlgorithm):
def forward(self, inputs, data_samples=None, mode: str = 'tensor'):
if self.training: #
if RuntimeInfo.iter() % self.delta_t == 0 and \
RuntimeInfo.iter() // self.delta_t < self.total_steps:
self.mutator.prune()
self.mutator.grow(self.growth_ratio)
if RuntimeInfo.epoch() % self.delta_t == 0 and \
RuntimeInfo.epoch() < self.total_steps and \
RuntimeInfo.iter_by_epoch() == 0:
with torch.no_grad():
self.mutator.prune()
print_log(f'prune model with {self.mutator.channel_ratio}')
self.log_choices()
self.mutator.grow(self.growth_ratio)
print_log(f'grow model with {self.growth_ratio}')
self.log_choices()
return super().forward(inputs, data_samples, mode)
@property
def growth_ratio(self):
# return growth ratio in current epoch
def cos():
a = math.pi * RuntimeInfo.epoch() / RuntimeInfo.max_epochs()
a = math.pi * RuntimeInfo.epoch() / self.total_steps
return (math.cos(a) + 1) / 2
return self.init_growth_rate * cos()
def log_choices(self):
if dist.get_rank() == 0:
config = {}
for unit in self.mutator.mutable_units:
config[unit.name] = unit.current_choice
print_log(json.dumps(config, indent=4))

View File

@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmrazor.registry import HOOKS
@HOOKS.register_module()
class ChexHook(Hook):
pass
# @classmethod
# def algorithm(cls, runner):
# if dist.is_distributed():
# return runner.model.module
# else:
# return runner.model
# def before_val(self, runner) -> None:
# algorithm = self.algorithm(runner)
# if dist.get_rank() == 0:
# config = {}
# for unit in algorithm.mutator.mutable_units:
# config[unit.name] = unit.current_choice
# print_log(json.dumps(config, indent=4))
# print_log(f'growth_ratio: {algorithm.growth_ratio}')

View File

@ -29,9 +29,10 @@ class ChexMutator(ChannelMutator):
step1: get pruning structure
step2: prune based on ChexMixin.prune_imp
"""
choices = self._get_prune_choices()
for unit in self.mutable_units:
unit.prune(choices[unit.name])
with torch.no_grad():
choices = self._get_prune_choices()
for unit in self.mutable_units:
unit.prune(choices[unit.name])
def grow(self, growth_ratio=0.0):
"""Make the model grow.
@ -60,9 +61,16 @@ class ChexMutator(ChannelMutator):
unit: ChexUnit
bn_imps[unit.name] = unit.bn_imp
bn_imp: torch.Tensor = torch.cat(list(bn_imps.values()), dim=0)
num_remain = int(self.channel_ratio * len(bn_imp))
threshold = bn_imp.topk(num_remain)[0][-1]
num_total_channel = len(bn_imp)
num_min_remained = int(self.channel_ratio * num_total_channel)
threshold = bn_imp.topk(num_min_remained)[0][-1]
num_remained = 0
for unit in self.mutable_units:
num = (bn_imps[unit.name] >= threshold).float().sum().long().item()
choices[unit.name] = num
num = (bn_imps[unit.name] >= threshold).long().sum().item()
choices[unit.name] = max(num, 1)
num_remained += choices[unit.name]
assert num_remained >= num_min_remained, \
f'{num_remained},{num_min_remained}'
return choices

View File

@ -32,22 +32,18 @@ class ChexUnit(L1MutableChannelUnit):
def prune(self, num_remaining):
# prune the channels to num_remaining
def get_prune_imp():
prune_imp: torch.Tensor = torch.zeros([self.num_channels])
prune_imp = 0
for channel in self.chex_channels:
module = channel.module
prune_imp = prune_imp.to(
module.prune_imp(num_remaining).device)
prune_imp = prune_imp + module.prune_imp(
num_remaining)[channel.start:channel.end]
return prune_imp
prune_imp = get_prune_imp()
index = prune_imp.topk(num_remaining)[1]
mask: torch.Tensor = torch.zeros([self.num_channels],
device=prune_imp.device)
mask.scatter_(-1, index, 1.0)
mask = mask.bool()
self.mutable_channel.current_choice.data = mask
with torch.no_grad():
prune_imp = get_prune_imp()
index = prune_imp.topk(num_remaining)[1]
self.mutable_channel.mask.fill_(0.0)
self.mutable_channel.mask.data.scatter_(-1, index, 1.0)
def grow(self, num):
assert num >= 0
@ -55,10 +51,9 @@ class ChexUnit(L1MutableChannelUnit):
return
def get_growth_imp():
growth_imp: torch.Tensor = torch.zeros([self.num_channels])
growth_imp = 0
for channel in self.chex_channels:
module = channel.module
growth_imp = growth_imp.to(module.growth_imp.device)
growth_imp = growth_imp + module.growth_imp[channel.
start:channel.end]
return growth_imp
@ -73,23 +68,22 @@ class ChexUnit(L1MutableChannelUnit):
select_index = index_free[select_index]
else:
select_index = index_free
mask.index_fill_(-1, select_index, 1.0)
self.mutable_channel.current_choice.data = mask
self.mutable_channel.mask.index_fill_(-1, select_index, 1.0)
@property
def bn_imp(self):
imp = torch.zeros([self.num_channels])
num_layers = 0
for channel in self.output_related:
module = channel.module
if isinstance(module, nn.modules.batchnorm._BatchNorm):
imp = imp.to(module.weight.device)
imp = imp + module.weight[channel.start:channel.end]
num_layers += 1
assert num_layers > 0
imp = imp / num_layers
return imp
with torch.no_grad():
imp = 0
num_layers = 0
for channel in self.output_related:
module = channel.module
if isinstance(module, nn.modules.batchnorm._BatchNorm):
imp = imp + module.weight[channel.start:channel.end].abs()
num_layers += 1
assert num_layers > 0
imp = imp / num_layers
return imp
@property
def chex_channels(self):

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from mmengine.logging import MessageHub
@ -23,3 +25,12 @@ class RuntimeInfo():
@classmethod
def iter(cls):
return cls.get_info('iter')
@classmethod
def max_iters(cls):
return cls.get_info('max_iters')
@classmethod
def iter_by_epoch(cls):
iter_per_epoch = math.ceil(cls.max_iters() / cls.max_epochs())
return cls.iter() % iter_per_epoch

View File

@ -27,7 +27,6 @@ class SquentialMutableChannel(SimpleMutableChannel):
super().__init__(num_channels, **kwargs)
assert choice_mode in ['ratio', 'number']
self.choice_mode = choice_mode
self.mask = torch.ones([self.num_channels]).bool()
@property
def is_num_mode(self):
@ -50,14 +49,13 @@ class SquentialMutableChannel(SimpleMutableChannel):
int_choice = self._ratio2num(choice)
else:
int_choice = choice
mask = torch.zeros([self.num_channels], device=self.mask.device)
mask[0:int_choice] = 1
self.mask = mask.bool()
self.mask.fill_(0.0)
self.mask[0:int_choice] = 1.0
@property
def current_mask(self) -> torch.Tensor:
"""Return current mask."""
return self.mask
return self.mask.bool()
# methods for

View File

@ -20,7 +20,10 @@ class SimpleMutableChannel(BaseMutableChannel):
def __init__(self, num_channels: int, **kwargs) -> None:
super().__init__(num_channels, **kwargs)
self.mask = torch.ones(num_channels).bool()
mask = torch.ones([self.num_channels
]) # save bool as float for dist training
self.register_buffer('mask', mask)
self.mask: torch.Tensor
# choice
@ -32,7 +35,7 @@ class SimpleMutableChannel(BaseMutableChannel):
@current_choice.setter
def current_choice(self, choice: torch.Tensor):
"""Set current choice."""
self.mask = choice.to(self.mask.device).bool()
self.mask = choice.to(self.mask.device).float()
@property
def current_mask(self) -> torch.Tensor:

View File

@ -4,6 +4,7 @@ interrogate
isort==4.3.21
nbconvert
nbformat
numpy < 1.24.0 # A temporary solution for tests with mmdet.
pytest
xdoctest >= 0.10.0
yapf

View File

@ -261,3 +261,30 @@ class TestItePruneAlgorithm(unittest.TestCase):
algorithm.forward(
data['inputs'], data['data_samples'], mode='loss')
self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch)
def test_resume(self):
algorithm: ItePruneAlgorithm = ItePruneAlgorithm(
MODEL_CFG,
mutator_cfg=MUTATOR_CONFIG_NUM,
target_pruning_ratio=None,
step_freq=1,
prune_times=1,
).to(DEVICE)
algorithm.mutator.set_choices(algorithm.mutator.sample_choices())
state_dict = algorithm.state_dict()
print(state_dict.keys())
algorithm2: ItePruneAlgorithm = ItePruneAlgorithm(
MODEL_CFG,
mutator_cfg=MUTATOR_CONFIG_NUM,
target_pruning_ratio=None,
step_freq=1,
prune_times=1,
).to(DEVICE)
algorithm2.load_state_dict(state_dict)
print(algorithm.mutator.current_choices)
print(algorithm2.mutator.current_choices)
self.assertDictEqual(algorithm.mutator.current_choices,
algorithm2.mutator.current_choices)