mmpretrain/tools/convert_models/repvgg_to_mmcls.py

[Feature] Add RepVGG backbone and checkpoints. (#414)
Co-authored-by: Ezra-Yu <1105212286@qq.com>
Co-authored-by: Ma Zerun <mzr1996@163.com>
2021-09-29 11:06:23 +08:00
import argparse
from collections import OrderedDict
from pathlib import Path

import torch


def convert(src, dst):
    """Rename the keys of an original RepVGG checkpoint to the mmcls
    naming scheme and save the result to ``dst``."""
    print('Converting...')
    blobs = torch.load(src, map_location='cpu')
    converted_state_dict = OrderedDict()

    for key in blobs:
        split_key = key.split('.')
        # Rename the BN layers and the three branches of each RepVGG block.
        split_key = ['norm' if i == 'bn' else i for i in split_key]
        split_key = [
            'branch_norm' if i == 'rbr_identity' else i for i in split_key
        ]
        split_key = [
            'branch_1x1' if i == 'rbr_1x1' else i for i in split_key
        ]
        split_key = [
            'branch_3x3' if i == 'rbr_dense' else i for i in split_key
        ]
        # 'stage0' becomes the stem. This must run before the generic
        # 'stage<N>' rule below, which would otherwise also match it.
        split_key = [
            'backbone.stem' if i[:6] == 'stage0' else i for i in split_key
        ]
        split_key = [
            'backbone.stage_' + i[5] if i[:5] == 'stage' else i
            for i in split_key
        ]
        # Rename the SE module and its two convolutions.
        split_key = ['se_layer' if i == 'se' else i for i in split_key]
        split_key = ['conv1.conv' if i == 'down' else i for i in split_key]
        split_key = ['conv2.conv' if i == 'up' else i for i in split_key]
        # The final linear layer becomes the classification head.
        split_key = ['head.fc' if i == 'linear' else i for i in split_key]
        new_key = '.'.join(split_key)
        converted_state_dict[new_key] = blobs[key]

    torch.save(converted_state_dict, dst)
    print('Done!')


def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='path to the source RepVGG checkpoint')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    dst = Path(args.dst)
    if dst.suffix != '.pth':
        # parser.error prints the message and exits with a non-zero status.
        parser.error('dst must include a file name ending in .pth')
    dst.parent.mkdir(parents=True, exist_ok=True)

    convert(args.src, args.dst)


if __name__ == '__main__':
    main()
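
The renaming rules in convert() can be checked without loading a checkpoint. The following is a minimal sketch, not part of the script: `RENAMES` and `rename_key` are hypothetical names that restate the same rules as a lookup table, and the sample keys are assumed examples of the original RepVGG naming rather than keys read from a real checkpoint.

```python
# Hypothetical table-driven restatement of the renaming rules in convert().
# RENAMES and rename_key are illustrative names, not part of the script.
RENAMES = {
    'bn': 'norm',
    'rbr_identity': 'branch_norm',
    'rbr_1x1': 'branch_1x1',
    'rbr_dense': 'branch_3x3',
    'se': 'se_layer',
    'down': 'conv1.conv',
    'up': 'conv2.conv',
    'linear': 'head.fc',
}


def rename_key(key):
    """Map one original-style key to the mmcls-style key."""
    parts = []
    for part in key.split('.'):
        if part.startswith('stage0'):
            # The first stage is the stem in mmcls.
            parts.append('backbone.stem')
        elif part.startswith('stage'):
            # 'stage3' -> 'backbone.stage_3' (single-digit stage index).
            parts.append('backbone.stage_' + part[5])
        else:
            parts.append(RENAMES.get(part, part))
    return '.'.join(parts)


# Assumed sample keys, for illustration only.
samples = [
    'stage0.rbr_dense.conv.weight',
    'stage1.0.rbr_dense.bn.weight',
    'stage2.1.rbr_identity.running_mean',
    'stage1.0.se.down.weight',
    'linear.bias',
]
for old in samples:
    print(old, '->', rename_key(old))
```

Like the script, this sketch reads only one character after `stage`, so it assumes stage indices stay below 10.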