Clean depercated tools
parent
c6a2d482fd
commit
cebc1aa810
|
@ -1,12 +0,0 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
|
||||
MKL_NUM_THREADS=4
|
||||
OMP_NUM_THREADS=1
|
||||
|
||||
|
||||
|
||||
# bash tools/slurm_train.sh mm_model detnas_train configs/nas/detnas/detnas_supernet_shufflenetv2_coco_1x_2.0_frcnn.py /mnt/lustre/dongpeijie/checkpoints/tests/detnas_pretrain_test
|
||||
|
||||
|
||||
bash tools/slurm_test.sh mm_model angle_test configs/nas/spos/spos_subnet_mobilenet_proxyless_gpu_8xb128_in1k_2.0.py /mnt/lustre/dongpeijie/spos_angelnas_flops_0.49G_acc_75.98_20220307-54f4698f_2.0.pth
|
|
@ -1,56 +0,0 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
|
||||
MKL_NUM_THREADS=4
|
||||
OMP_NUM_THREADS=1
|
||||
|
||||
# train
|
||||
# srun --partition=mm_model \
|
||||
# --job-name=spos_train \
|
||||
# --gres=gpu:8 \
|
||||
# --ntasks=8 \
|
||||
# --ntasks-per-node=8 \
|
||||
# --cpus-per-task=8 \
|
||||
# --kill-on-bad-exit=1 \
|
||||
# python tools/train.py configs/nas/spos/spos_supernet_shufflenetv2_8xb128_in1k_2.0_example.py
|
||||
|
||||
# bash tools/slurm_train.sh mm_model spos_train configs/nas/spos/spos_supernet_shufflenetv2_8xb128_in1k_2.0_example.py ./work_dir/spos
|
||||
|
||||
# SPOS test
|
||||
# srun --partition=mm_model \
|
||||
# --job-name=spos_test \
|
||||
# --gres=gpu:1 \
|
||||
# --ntasks=1 \
|
||||
# --ntasks-per-node=1 \
|
||||
# --cpus-per-task=8 \
|
||||
# --kill-on-bad-exit=1 \
|
||||
# python tools/test.py configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example.py "/mnt/lustre/dongpeijie/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_2.0.pth"
|
||||
|
||||
# DetNAS train
|
||||
# srun --partition=mm_model \
|
||||
# --job-name=detnas_train \
|
||||
# --gres=gpu:8 \
|
||||
# --ntasks=8 \
|
||||
# --ntasks-per-node=8 \
|
||||
# --cpus-per-task=8 \
|
||||
# --kill-on-bad-exit=1 \
|
||||
# python tools/train.py configs/nas/detnas/detnas_supernet_shufflenetv2_coco_1x_2.0_frcnn.py
|
||||
|
||||
# bash tools/slurm_train.sh mm_model detnas_train configs/nas/detnas/detnas_supernet_shufflenetv2_coco_1x_2.0_frcnn.py ./work_dir/detnas_pretrain
|
||||
|
||||
# DetNAS test
|
||||
# srun --partition=mm_model \
|
||||
# --job-name=detnas_test \
|
||||
# --gres=gpu:1 \
|
||||
# --ntasks=1 \
|
||||
# --ntasks-per-node=1 \
|
||||
# --cpus-per-task=8 \
|
||||
# --kill-on-bad-exit=1 \
|
||||
# python tools/test.py configs/nas/detnas/detnas_subnet_shufflenetv2_8xb128_in1k_2.0_frcnn.py "/mnt/lustre/dongpeijie/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_2.0.pth"
|
||||
|
||||
|
||||
# CREAM Test
|
||||
# bash tools/slurm_test.sh mm_model cream_test configs/nas/cream/cream_14_subnet_mobilenet.py '/mnt/lustre/dongpeijie/14_2.0.pth'
|
||||
|
||||
# CREAM Train
|
||||
bash tools/slurm_train.sh mm_model cream_train configs/nas/cream/cream_14_subnet_mobilenet.py
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
|
||||
MKL_NUM_THREADS=4
|
||||
OMP_NUM_THREADS=1
|
||||
|
||||
bash tools/slurm_test.sh mm_model spos_test configs/nas/darts/darts_subnet_1xb96_cifar10_2.0.py '/mnt/lustre/dongpeijie/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_2.0.pth'
|
|
@ -1,31 +0,0 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
|
||||
MKL_NUM_THREADS=4
|
||||
OMP_NUM_THREADS=1
|
||||
|
||||
# DetNAS train
|
||||
# srun --partition=mm_model \
|
||||
# --job-name=detnas_train \
|
||||
# --gres=gpu:8 \
|
||||
# --ntasks=8 \
|
||||
# --ntasks-per-node=8 \
|
||||
# --cpus-per-task=8 \
|
||||
# --kill-on-bad-exit=1 \
|
||||
# python tools/train.py configs/nas/detnas/detnas_supernet_shufflenetv2_coco_1x_2.0_frcnn.py
|
||||
|
||||
# bash tools/slurm_train.sh mm_model detnas_train configs/nas/detnas/detnas_supernet_shufflenetv2_coco_1x_2.0_frcnn.py /mnt/lustre/dongpeijie/checkpoints/tests/detnas_pretrain_test
|
||||
|
||||
|
||||
# bash tools/slurm_test.sh mm_model detnas_test configs/nas/detnas/detnas_supernet_shufflenetv2_coco_1x_2.0_frcnn.py /mnt/lustre/dongpeijie/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_2.0.pth
|
||||
|
||||
# DetNAS test
|
||||
srun --partition=mm_model \
|
||||
--job-name=detnas_test \
|
||||
--gres=gpu:1 \
|
||||
--ntasks=1 \
|
||||
--ntasks-per-node=1 \
|
||||
--cpus-per-task=8 \
|
||||
--kill-on-bad-exit=1 \
|
||||
--quotatype=auto \
|
||||
python tools/test.py configs/nas/detnas/detnas_subnet_shufflenetv2_8xb128_in1k_2.0_frcnn.py "/mnt/lustre/dongpeijie/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_2.0.pth" --launcher=slurm
|
|
@ -1,51 +0,0 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
|
||||
MKL_NUM_THREADS=4
|
||||
OMP_NUM_THREADS=1
|
||||
|
||||
# train
|
||||
# srun --partition=mm_model \
|
||||
# --job-name=spos_train \
|
||||
# --gres=gpu:8 \
|
||||
# --ntasks=8 \
|
||||
# --ntasks-per-node=8 \
|
||||
# --cpus-per-task=8 \
|
||||
# --kill-on-bad-exit=1 \
|
||||
# python tools/train.py configs/nas/spos/spos_supernet_shufflenetv2_8xb128_in1k_2.0_example.py
|
||||
|
||||
# bash tools/slurm_train.sh mm_model spos_train configs/nas/spos/spos_supernet_shufflenetv2_8xb128_in1k_2.0_example.py /mnt/lustre/dongpeijie/checkpoints/work_dirs/spos_format_output
|
||||
|
||||
# bash tools/slurm_train.sh mm_model spos_retrain configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example.py /mnt/lustre/dongpeijie/checkpoints/work_dirs/spos_retrain_detnas_with_ceph
|
||||
|
||||
# 55% wrong settings of PolyLR
|
||||
# bash tools/slurm_train.sh mm_model spos_retrain_w_cj configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example.py /mnt/lustre/dongpeijie/checkpoints/work_dirs/spos_retrain_detnas_with_ceph
|
||||
|
||||
# fix setting of PolyLR and rerun with colorjittor
|
||||
# bash tools/slurm_train.sh mm_model spos_retrain_w_cj configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example.py /mnt/lustre/dongpeijie/checkpoints/work_dirs/retrain_detnas_spos_with_colorjittor
|
||||
|
||||
# fix setting of PolyLR and rerun w/o colorjittor
|
||||
# bash tools/slurm_train.sh mm_model spos_retrain_wo_cj configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example_wo_colorjittor.py /mnt/lustre/dongpeijie/checkpoints/work_dirs/retrain_detnas_spos_wo_colorjittor
|
||||
|
||||
# fix setting of optimizer decay[wo cj] (paramwise_cfg)
|
||||
# bash tools/slurm_train.sh mm_model spos_retrain_fix_decay_wo_cj configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example_wo_colorjittor.py /mnt/lustre/dongpeijie/checkpoints/work_dirs/retrain_detnas_spos_retrain_fix_decay_wo_cj
|
||||
|
||||
# fix setting of optimizer decay[with cj] (paramwise_cfg)
|
||||
# bash tools/slurm_train.sh mm_model spos_retrain_fix_decay_w_cj configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example.py /mnt/lustre/dongpeijie/checkpoints/work_dirs/retrain_detnas_spos_retrain_fix_decay_w_cj
|
||||
|
||||
|
||||
|
||||
# SPOS test
|
||||
# srun --partition=mm_model \
|
||||
# --job-name=spos_test \
|
||||
# --gres=gpu:1 \
|
||||
# --ntasks=1 \
|
||||
# --ntasks-per-node=1 \
|
||||
# --cpus-per-task=8 \
|
||||
# --kill-on-bad-exit=1 \
|
||||
# python tools/test.py configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example.py "/mnt/lustre/dongpeijie/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_2.0.pth"
|
||||
|
||||
|
||||
bash tools/slurm_test.sh mm_model spos_test configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example.py '/mnt/lustre/dongpeijie/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a_2.0.pth'
|
||||
|
||||
# bash tools/slurm_train.sh mm_model spos_retrain configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k_2.0_example.py /mnt/lustre/dongpeijie/checkpoints/work_dirs/spos_retrain_detnas_spos
|
372
convert_keys.py
372
convert_keys.py
|
@ -1,372 +0,0 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from mmengine.config import Config
|
||||
|
||||
from mmrazor.core import * # noqa: F401,F403
|
||||
from mmrazor.models import * # noqa: F401,F403
|
||||
from mmrazor.registry import MODELS
|
||||
from mmrazor.utils import register_all_modules
|
||||
|
||||
|
||||
def convert_spos_key(old_path, new_path):
|
||||
old_dict = torch.load(old_path)
|
||||
new_dict = {'meta': old_dict['meta'], 'state_dict': {}}
|
||||
|
||||
mapping = {
|
||||
'choices': '_candidates',
|
||||
'architecture.': '',
|
||||
'model.': '',
|
||||
}
|
||||
|
||||
for k, v in old_dict['state_dict'].items():
|
||||
new_key = k
|
||||
for _from, _to in mapping.items():
|
||||
new_key = new_key.replace(_from, _to)
|
||||
|
||||
new_key = f'architecture.{new_key}'
|
||||
|
||||
new_dict['state_dict'][new_key] = v
|
||||
|
||||
torch.save(new_dict, new_path)
|
||||
|
||||
|
||||
def convert_detnas_key(old_path, new_path):
|
||||
old_dict = torch.load(old_path)
|
||||
new_dict = {'meta': old_dict['meta'], 'state_dict': {}}
|
||||
|
||||
mapping = {
|
||||
'choices': '_candidates',
|
||||
'model.': '',
|
||||
}
|
||||
|
||||
for k, v in old_dict['state_dict'].items():
|
||||
new_key = k
|
||||
for _from, _to in mapping.items():
|
||||
new_key = new_key.replace(_from, _to)
|
||||
|
||||
new_dict['state_dict'][new_key] = v
|
||||
torch.save(new_dict, new_path)
|
||||
|
||||
|
||||
def convert_anglenas_key(old_path, new_path):
|
||||
old_dict = torch.load(old_path)
|
||||
new_dict = {'state_dict': {}}
|
||||
|
||||
mapping = {
|
||||
'choices': '_candidates',
|
||||
'model.': '',
|
||||
'mbv2': 'mb',
|
||||
}
|
||||
|
||||
for k, v in old_dict.items():
|
||||
new_key = k
|
||||
for _from, _to in mapping.items():
|
||||
new_key = new_key.replace(_from, _to)
|
||||
|
||||
new_dict['state_dict'][new_key] = v
|
||||
torch.save(new_dict, new_path)
|
||||
|
||||
|
||||
def convert_darts_key(old_path, new_path):
|
||||
old_dict = torch.load(old_path)
|
||||
new_dict = {'meta': old_dict['meta'], 'state_dict': {}}
|
||||
cfg = Config.fromfile(
|
||||
'configs/nas/darts/darts_subnet_1xb96_cifar10_2.0.py')
|
||||
# import ipdb; ipdb.set_trace()
|
||||
model = MODELS.build(cfg.model)
|
||||
|
||||
print('============> module name')
|
||||
for name, module in model.state_dict().items():
|
||||
print(name)
|
||||
|
||||
mapping = {
|
||||
'choices': '_candidates',
|
||||
'model.': '',
|
||||
'edges': 'route',
|
||||
}
|
||||
|
||||
for k, v in old_dict['state_dict'].items():
|
||||
new_key = k
|
||||
for _from, _to in mapping.items():
|
||||
new_key = new_key.replace(_from, _to)
|
||||
# cells.0.nodes.0.edges.choices.normal_n2_p1.0.choices.sep_conv_3x3.conv1.2.weight
|
||||
splited_list = new_key.split('.')
|
||||
if len(splited_list) > 10 and splited_list[-6] == '0':
|
||||
del splited_list[-6]
|
||||
new_key = '.'.join(splited_list)
|
||||
elif len(splited_list) > 10 and splited_list[-5] == '0':
|
||||
del splited_list[-5]
|
||||
new_key = '.'.join(splited_list)
|
||||
|
||||
new_dict['state_dict'][new_key] = v
|
||||
|
||||
print('============> new dict')
|
||||
for key, v in new_dict['state_dict'].items():
|
||||
print(key)
|
||||
|
||||
model.load_state_dict(new_dict['state_dict'], strict=True)
|
||||
|
||||
torch.save(new_dict, new_path)
|
||||
|
||||
|
||||
def convert_cream_key(old_path, new_path):
|
||||
|
||||
old_dict = torch.load(old_path, map_location=torch.device('cpu'))
|
||||
new_dict = {'state_dict': {}} # noqa: F841
|
||||
|
||||
ordered_old_dict = OrderedDict(old_dict['state_dict'])
|
||||
|
||||
cfg = Config.fromfile('configs/nas/cream/cream_14_subnet_mobilenet.py')
|
||||
model = MODELS.build(cfg.model)
|
||||
|
||||
model_name_list = []
|
||||
model_module_list = []
|
||||
|
||||
# TODO show structure of model and checkpoint
|
||||
print('=' * 30, 'the key of model')
|
||||
for k, v in model.state_dict().items():
|
||||
print(k)
|
||||
|
||||
print('=' * 30, 'the key of ckpt')
|
||||
for k, v in ordered_old_dict.items():
|
||||
print(k)
|
||||
|
||||
# final mapping dict
|
||||
mapping = {}
|
||||
|
||||
middle_razor2cream = { # noqa: F841
|
||||
# point-wise expansion
|
||||
'expand_conv.conv.weight': 'conv_pw.weight',
|
||||
'expand_conv.bn.weight': 'bn1.weight',
|
||||
'expand_conv.bn.bias': 'bn1.bias',
|
||||
'expand_conv.bn.running_mean': 'bn1.running_mean',
|
||||
'expand_conv.bn.running_var': 'bn1.running_var',
|
||||
'expand_conv.bn.num_batches_tracked': 'bn1.num_batches_tracked',
|
||||
|
||||
# se
|
||||
'se.conv1.conv.weight': 'se.conv_reduce.weight',
|
||||
'se.conv1.conv.bias': 'se.conv_reduce.bias',
|
||||
'se.conv2.conv.weight': 'se.conv_expand.weight',
|
||||
'se.conv2.conv.bias': 'se.conv_expand.bias',
|
||||
|
||||
# depth-wise conv
|
||||
'depthwise_conv.conv.weight': 'conv_dw.weight',
|
||||
'depthwise_conv.bn.weight': 'bn2.weight',
|
||||
'depthwise_conv.bn.bias': 'bn2.bias',
|
||||
'depthwise_conv.bn.running_mean': 'bn2.running_mean',
|
||||
'depthwise_conv.bn.running_var': 'bn2.running_var',
|
||||
'depthwise_conv.bn.num_batches_tracked': 'bn2.num_batches_tracked',
|
||||
|
||||
# point-wise linear projection
|
||||
'linear_conv.conv.weight': 'conv_pwl.weight',
|
||||
'linear_conv.bn.weight': 'bn3.weight',
|
||||
'linear_conv.bn.bias': 'bn3.bias',
|
||||
'linear_conv.bn.running_mean': 'bn3.running_mean',
|
||||
'linear_conv.bn.running_var': 'bn3.running_var',
|
||||
'linear_conv.bn.num_batches_tracked': 'bn3.num_batches_tracked',
|
||||
|
||||
}
|
||||
|
||||
first_razor2cream = {
|
||||
# for first depthsepconv dw
|
||||
'conv_dw.conv.weight': 'conv_dw.weight',
|
||||
'conv_dw.bn.weight': 'bn1.weight',
|
||||
'conv_dw.bn.bias': 'bn1.bias',
|
||||
'conv_dw.bn.running_mean': 'bn1.running_mean',
|
||||
'conv_dw.bn.running_var': 'bn1.running_var',
|
||||
'conv_dw.bn.num_batches_tracked': 'bn1.num_batches_tracked',
|
||||
|
||||
# for first depthsepconv pw
|
||||
'conv_pw.conv.weight': 'conv_pw.weight',
|
||||
'conv_pw.bn.weight': 'bn2.weight',
|
||||
'conv_pw.bn.bias': 'bn2.bias',
|
||||
'conv_pw.bn.running_mean': 'bn2.running_mean',
|
||||
'conv_pw.bn.running_var': 'bn2.running_var',
|
||||
'conv_pw.bn.num_batches_tracked': 'bn2.num_batches_tracked',
|
||||
|
||||
# se
|
||||
'se.conv1.conv.weight': 'se.conv_reduce.weight',
|
||||
'se.conv1.conv.bias': 'se.conv_reduce.bias',
|
||||
'se.conv2.conv.weight': 'se.conv_expand.weight',
|
||||
'se.conv2.conv.bias': 'se.conv_expand.bias',
|
||||
}
|
||||
|
||||
last_razor2cream = {
|
||||
# for last convbnact
|
||||
'conv2.conv.weight': 'conv.weight',
|
||||
'conv2.bn.weight': 'bn1.weight',
|
||||
'conv2.bn.bias': 'bn1.bias',
|
||||
'conv2.bn.running_mean': 'bn1.running_mean',
|
||||
'conv2.bn.running_var': 'bn1.running_var',
|
||||
'conv2.bn.num_batches_tracked': 'bn1.num_batches_tracked',
|
||||
}
|
||||
|
||||
middle_cream2razor = {v: k for k, v in middle_razor2cream.items()}
|
||||
first_cream2razor = {v: k for k, v in first_razor2cream.items()}
|
||||
last_cream2razor = {v: k for k, v in last_razor2cream.items()}
|
||||
|
||||
# 1. group the razor's module names
|
||||
grouped_razor_module_name = {
|
||||
'middle': {},
|
||||
'first': [],
|
||||
'last': [],
|
||||
}
|
||||
|
||||
for name, module in model.state_dict().items():
|
||||
tmp_name: str = name.split(
|
||||
'backbone.')[1] if 'backbone' in name else name
|
||||
model_name_list.append(tmp_name)
|
||||
model_module_list.append(module)
|
||||
|
||||
if 'conv1' in tmp_name and len(tmp_name) <= 35:
|
||||
# belong to stem conv
|
||||
grouped_razor_module_name['first'].append(name)
|
||||
elif 'head' in tmp_name:
|
||||
# belong to last linear
|
||||
grouped_razor_module_name['last'].append(name)
|
||||
else:
|
||||
# middle
|
||||
if tmp_name.startswith('layer'):
|
||||
key_of_middle = tmp_name[5:8]
|
||||
if key_of_middle not in grouped_razor_module_name['middle']:
|
||||
grouped_razor_module_name['middle'][key_of_middle] = [name]
|
||||
else:
|
||||
grouped_razor_module_name['middle'][key_of_middle].append(
|
||||
name)
|
||||
elif tmp_name.startswith('conv2'):
|
||||
key_of_middle = '7.0'
|
||||
if key_of_middle not in grouped_razor_module_name['middle']:
|
||||
grouped_razor_module_name['middle'][key_of_middle] = [name]
|
||||
else:
|
||||
grouped_razor_module_name['middle'][key_of_middle].append(
|
||||
name)
|
||||
|
||||
# 2. group the cream's module names
|
||||
grouped_cream_module_name = {
|
||||
'middle': {},
|
||||
'first': [],
|
||||
'last': [],
|
||||
}
|
||||
|
||||
for k in ordered_old_dict.keys():
|
||||
if 'classifier' in k or 'conv_head' in k:
|
||||
# last conv
|
||||
grouped_cream_module_name['last'].append(k)
|
||||
elif 'blocks' in k:
|
||||
# middle blocks
|
||||
key_of_middle = k[7:10]
|
||||
if key_of_middle not in grouped_cream_module_name['middle']:
|
||||
grouped_cream_module_name['middle'][key_of_middle] = [k]
|
||||
else:
|
||||
grouped_cream_module_name['middle'][key_of_middle].append(k)
|
||||
else:
|
||||
# first blocks
|
||||
grouped_cream_module_name['first'].append(k)
|
||||
|
||||
# 4. process the first modules
|
||||
for cream_item in grouped_cream_module_name['first']:
|
||||
if 'conv_stem' in cream_item:
|
||||
# get corresponding item from razor
|
||||
for razor_item in grouped_razor_module_name['first']:
|
||||
if 'conv.weight' in razor_item:
|
||||
mapping[cream_item] = razor_item
|
||||
grouped_razor_module_name['first'].remove(razor_item)
|
||||
break
|
||||
else:
|
||||
kws = cream_item.split('.')[-1]
|
||||
# get corresponding item from razor
|
||||
for razor_item in grouped_razor_module_name['first']:
|
||||
if kws in razor_item:
|
||||
mapping[cream_item] = razor_item
|
||||
grouped_razor_module_name['first'].remove(razor_item)
|
||||
|
||||
# 5. process the last modules
|
||||
for cream_item in grouped_cream_module_name['last']:
|
||||
if 'classifier' in cream_item:
|
||||
kws = cream_item.split('.')[-1]
|
||||
for razor_item in grouped_razor_module_name['last']:
|
||||
if 'fc' in razor_item:
|
||||
if kws in razor_item:
|
||||
mapping[cream_item] = razor_item
|
||||
grouped_razor_module_name['last'].remove(razor_item)
|
||||
break
|
||||
|
||||
elif 'conv_head' in cream_item:
|
||||
kws = cream_item.split('.')[-1]
|
||||
for razor_item in grouped_razor_module_name['last']:
|
||||
if 'head.conv2' in razor_item:
|
||||
if kws in razor_item:
|
||||
mapping[cream_item] = razor_item
|
||||
grouped_razor_module_name['last'].remove(razor_item)
|
||||
|
||||
# 6. process the middle modules
|
||||
for cream_group_id, cream_items in grouped_cream_module_name[
|
||||
'middle'].items():
|
||||
# get the corresponding group from razor
|
||||
razor_group_id: str = str(float(cream_group_id) + 1)
|
||||
razor_items: list = grouped_razor_module_name['middle'][razor_group_id]
|
||||
|
||||
if int(razor_group_id[0]) == 1:
|
||||
key_cream2razor = first_cream2razor
|
||||
elif int(razor_group_id[0]) == 7:
|
||||
key_cream2razor = last_cream2razor
|
||||
else:
|
||||
key_cream2razor = middle_cream2razor
|
||||
|
||||
# matching razor items and cream items
|
||||
for cream_item in cream_items:
|
||||
# traverse all of key_cream2razor
|
||||
for cream_match, razor_match in key_cream2razor.items():
|
||||
if cream_match in cream_item:
|
||||
# traverse razor_items to get the corresponding razor name
|
||||
for razor_item in razor_items:
|
||||
if razor_match in razor_item:
|
||||
mapping[cream_item] = razor_item
|
||||
break
|
||||
|
||||
print('=' * 100)
|
||||
print('length of mapping: ', len(mapping.keys()))
|
||||
for k, v in mapping.items():
|
||||
print(k, '\t=>\t', v)
|
||||
print('#' * 100)
|
||||
|
||||
# TODO DELETE this print
|
||||
print('**' * 20)
|
||||
for c, cm, r, rm in zip(ordered_old_dict.keys(), ordered_old_dict.values(),
|
||||
model_name_list, model_module_list):
|
||||
print(f'{c}: shape {cm.shape} => {r}: shape {rm.shape}')
|
||||
print('**' * 20)
|
||||
|
||||
for k, v in ordered_old_dict.items():
|
||||
print(f'Mapping from {k} to {mapping[k]}......')
|
||||
new_dict['state_dict'][mapping[k]] = v
|
||||
|
||||
model.load_state_dict(new_dict['state_dict'], strict=True)
|
||||
|
||||
torch.save(new_dict, new_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
register_all_modules(True)
|
||||
# old_path = '/mnt/lustre/dongpeijie/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a.pth' # noqa: E501
|
||||
# new_path = '/mnt/lustre/dongpeijie/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a_2.0.pth' # noqa: E501
|
||||
# convert_spos_key(old_path, new_path)
|
||||
|
||||
# old_path = '/mnt/lustre/dongpeijie/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.pth' # noqa: E501
|
||||
# new_path = '/mnt/lustre/dongpeijie/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_2.0.pth' # noqa: E501
|
||||
# convert_detnas_key(old_path, new_path)
|
||||
|
||||
# old_path = './data/14.pth.tar'
|
||||
# new_path = './data/14_2.0.pth'
|
||||
# old_path = '/mnt/lustre/dongpeijie/14.pth.tar'
|
||||
# new_path = '/mnt/lustre/dongpeijie/14_2.0.pth'
|
||||
# convert_cream_key(old_path, new_path)
|
||||
|
||||
# old_path = '/mnt/lustre/dongpeijie/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921.pth' # noqa: E501
|
||||
# new_path = '/mnt/lustre/dongpeijie/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_2.0.pth' # noqa: E501
|
||||
# convert_darts_key(old_path, new_path)
|
||||
|
||||
old_path = '/mnt/lustre/dongpeijie/spos_angelnas_flops_0.49G_acc_75.98_20220307-54f4698f.pth' # noqa: E501
|
||||
new_path = '/mnt/lustre/dongpeijie/spos_angelnas_flops_0.49G_acc_75.98_20220307-54f4698f_2.0.pth' # noqa: E501
|
||||
convert_anglenas_key(old_path, new_path)
|
Loading…
Reference in New Issue