implementation of TFA (#13)

* update script

* update script

* fix comments

* create pr

* fix comments

Co-authored-by: zhangshilong <2392587229zsl@gmail.com>
Co-authored-by: Shilong Zhang <61961338+jshilong@users.noreply.github.com>
pull/1/head
Linyiqi 2021-07-30 16:21:43 +08:00 committed by GitHub
parent 87b36f102a
commit a643f7ee9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1225 additions and 3 deletions

View File

@ -0,0 +1,36 @@
# TFA fine-tuning: Faster R-CNN R-101 on the 10-shot COCO benchmark.
_base_ = [
    '../../_base_/datasets/finetune_based/few_shot_coco.py',
    '../../_base_/schedules/schedule.py',
    '../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotCocoDataset;
# FewShotCocoDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotCocoDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='10SHOT')],
            num_novel_shots=10,
            num_base_shots=10)))
evaluation = dict(interval=80000)
checkpoint_config = dict(interval=80000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup_iters=10, step=[144000])
runner = dict(max_iters=160000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/faster_rcnn_r101_fpn_coco_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=80,
            scale=20)))

View File

@ -0,0 +1,39 @@
# TFA fine-tuning: Faster R-CNN R-101 on the 30-shot COCO benchmark.
_base_ = [
    '../../_base_/datasets/finetune_based/few_shot_coco.py',
    '../../_base_/schedules/schedule.py',
    '../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotCocoDataset;
# FewShotCocoDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotCocoDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='30SHOT')],
            num_novel_shots=30,
            num_base_shots=30)))
evaluation = dict(interval=120000)
checkpoint_config = dict(interval=120000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup_iters=10, step=[216000])
runner = dict(max_iters=240000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/faster_rcnn_r101_fpn_coco_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=80,
            scale=20)))

View File

@ -0,0 +1,13 @@
# TFA base training: Faster R-CNN R-101 on the 60 COCO base classes.
_base_ = [
    '../../_base_/datasets/finetune_based/base_coco.py',
    '../../_base_/schedules/schedule.py',
    '../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../_base_/default_runtime.py',
]
lr_config = dict(warmup_iters=1000, step=[85000, 100000])
runner = dict(max_iters=110000)
# model settings: swap in a ResNet-101 backbone and a 60-way box head.
model = dict(
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101),
    roi_head=dict(bbox_head=dict(num_classes=60)))

View File

@ -0,0 +1,45 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 1, 10-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_10SHOT')],
            num_novel_shots=10,
            num_base_shots=10,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=40000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=40000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup_iters=10, step=[36000])
runner = dict(max_iters=40000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split1_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,45 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 1, 1-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_1SHOT')],
            num_novel_shots=1,
            num_base_shots=1,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=4000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=4000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup=None, step=[3000])
runner = dict(max_iters=4000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split1_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,45 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 1, 2-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_2SHOT')],
            num_novel_shots=2,
            num_base_shots=2,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=8000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=8000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup=None, step=[7000])
runner = dict(max_iters=8000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split1_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,45 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 1, 3-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_3SHOT')],
            num_novel_shots=3,
            num_base_shots=3,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=12000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=12000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup_iters=10, step=[11000])
runner = dict(max_iters=12000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split1_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,45 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 1, 5-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_5SHOT')],
            num_novel_shots=5,
            num_base_shots=5,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=20000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=20000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup_iters=10, step=[18000])
runner = dict(max_iters=20000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split1_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,18 @@
# TFA base training: Faster R-CNN R-101 on the 15 base classes of VOC split 1.
_base_ = [
    '../../../_base_/datasets/finetune_based/base_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset.
data = dict(
    train=dict(dataset=dict(classes='BASE_CLASSES_SPLIT1')),
    val=dict(classes='BASE_CLASSES_SPLIT1'),
    test=dict(classes='BASE_CLASSES_SPLIT1'))
lr_config = dict(warmup_iters=100, step=[12000, 16000])
runner = dict(max_iters=18000)
# model settings: ResNet-101 backbone, 15-way box head.
model = dict(
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101),
    roi_head=dict(bbox_head=dict(num_classes=15)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 2, 10-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_10SHOT')],
            num_novel_shots=10,
            num_base_shots=10,
            classes='ALL_CLASSES_SPLIT2')),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=40000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=40000)
optimizer = dict(lr=0.005)
lr_config = dict(warmup_iters=10, step=[36000])
runner = dict(max_iters=40000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split2_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 2, 1-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_1SHOT')],
            num_novel_shots=1,
            num_base_shots=1,
            classes='ALL_CLASSES_SPLIT2')),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=4000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=4000)
optimizer = dict(lr=0.005)
lr_config = dict(warmup=None, step=[3000])
runner = dict(max_iters=4000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split2_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 2, 2-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_2SHOT')],
            num_novel_shots=2,
            num_base_shots=2,
            classes='ALL_CLASSES_SPLIT2')),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=8000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=8000)
optimizer = dict(lr=0.005)
lr_config = dict(warmup=None, step=[7000])
runner = dict(max_iters=8000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split2_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 2, 3-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_3SHOT')],
            num_novel_shots=3,
            num_base_shots=3,
            classes='ALL_CLASSES_SPLIT2')),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=12000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=12000)
optimizer = dict(lr=0.005)
lr_config = dict(warmup_iters=10, step=[11000])
runner = dict(max_iters=12000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split2_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 2, 5-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_5SHOT')],
            num_novel_shots=5,
            num_base_shots=5,
            classes='ALL_CLASSES_SPLIT2')),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=20000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=20000)
optimizer = dict(lr=0.005)
lr_config = dict(warmup_iters=10, step=[18000])
runner = dict(max_iters=20000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split2_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,18 @@
# TFA base training: Faster R-CNN R-101 on the 15 base classes of VOC split 2.
_base_ = [
    '../../../_base_/datasets/finetune_based/base_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset.
data = dict(
    train=dict(dataset=dict(classes='BASE_CLASSES_SPLIT2')),
    val=dict(classes='BASE_CLASSES_SPLIT2'),
    test=dict(classes='BASE_CLASSES_SPLIT2'))
lr_config = dict(warmup_iters=100, step=[12000, 16000])
runner = dict(max_iters=18000)
# model settings: ResNet-101 backbone, 15-way box head.
model = dict(
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101),
    roi_head=dict(bbox_head=dict(num_classes=15)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 3, 10-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_10SHOT')],
            num_novel_shots=10,
            num_base_shots=10,
            classes='ALL_CLASSES_SPLIT3')),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=40000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=40000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup_iters=10, step=[36000])
runner = dict(max_iters=40000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split3_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 3, 1-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_1SHOT')],
            num_novel_shots=1,
            num_base_shots=1,
            classes='ALL_CLASSES_SPLIT3')),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=4000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=4000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup=None, step=[3000])
runner = dict(max_iters=4000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split3_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 3, 2-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_2SHOT')],
            num_novel_shots=2,
            num_base_shots=2,
            classes='ALL_CLASSES_SPLIT3')),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=8000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=8000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup=None, step=[7000])
runner = dict(max_iters=8000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split3_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 3, 3-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_3SHOT')],
            num_novel_shots=3,
            num_base_shots=3,
            classes='ALL_CLASSES_SPLIT3')),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=12000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=12000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup_iters=10, step=[11000])
runner = dict(max_iters=12000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split3_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,46 @@
# TFA fine-tuning: Faster R-CNN R-101 on VOC split 3, 5-shot setting.
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset;
# FewShotVOCDefaultDataset pins ann_cfg so results are reproducible.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_5SHOT')],
            num_novel_shots=5,
            num_base_shots=5,
            classes='ALL_CLASSES_SPLIT3')),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
# Report base/novel mAP separately at the end of training.
evaluation = dict(
    interval=20000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=20000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup_iters=10, step=[18000])
runner = dict(max_iters=20000)
# Base-training checkpoint after classifier weight surgery.
load_from = ('work_dirs/tfa_faster_rcnn_r101_fpn_voc_split3_base_training/'
             'model_reset_surgery.pth')
model = dict(
    # Only the box classifier / regressor remain trainable.
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))

View File

@ -0,0 +1,18 @@
# TFA base training: Faster R-CNN R-101 on the 15 base classes of VOC split 3.
_base_ = [
    '../../../_base_/datasets/finetune_based/base_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py',
]
# Class splits are predefined in FewShotVOCDataset.
data = dict(
    train=dict(dataset=dict(classes='BASE_CLASSES_SPLIT3')),
    val=dict(classes='BASE_CLASSES_SPLIT3'),
    test=dict(classes='BASE_CLASSES_SPLIT3'))
lr_config = dict(warmup_iters=100, step=[12000, 16000])
runner = dict(max_iters=18000)
# model settings: ResNet-101 backbone, 15-way box head.
model = dict(
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101),
    roi_head=dict(bbox_head=dict(num_classes=15)))

View File

@ -1,5 +1,5 @@
# Detector implementations of the few-shot package.
# NOTE(review): this span contained both the pre- and post-image lines of a
# diff (conflicting `FsDetView`/`FSDetView` imports and two `__all__`
# assignments); only the updated lines (per the `+1,5` hunk) are kept.
from .attention_rpn_detector import AttentionRPNDetector
from .base_query_support import BaseQuerySupportDetector
from .fsdetview import FSDetView

__all__ = ['BaseQuerySupportDetector', 'AttentionRPNDetector', 'FSDetView']

View File

@ -0,0 +1,13 @@
# Public API of the roi_heads subpackage: re-exports the few-shot RoI heads,
# bbox heads and shared heads listed in ``__all__``.
from .attention_rpn_roi_head import AttentionRPNRoIHead
from .bbox_heads import (ContrastiveBBoxHead, CosineSimBBoxHead,
                         MultiRelationBBoxHead)
from .contrastive_roi_head import ContrastiveRoIHead
from .fsdetview_roi_head import FSDetViewRoIHead
from .meta_rcnn_roi_head import MetaRCNNRoIHead
from .shared_heads import MetaRCNNResLayer

__all__ = [
    'CosineSimBBoxHead', 'ContrastiveBBoxHead', 'MultiRelationBBoxHead',
    'ContrastiveRoIHead', 'AttentionRPNRoIHead', 'FSDetViewRoIHead',
    'MetaRCNNRoIHead', 'MetaRCNNResLayer'
]

View File

@ -0,0 +1,9 @@
# Public API of the bbox_heads subpackage: re-exports the few-shot bbox heads
# listed in ``__all__``.
from .contrastive_bbox_head import ContrastiveBBoxHead
from .cosine_sim_bbox_head import CosineSimBBoxHead
from .meta_bbox_head import MetaBBoxHead
from .multi_relation_bbox_head import MultiRelationBBoxHead

__all__ = [
    'CosineSimBBoxHead', 'ContrastiveBBoxHead', 'MultiRelationBBoxHead',
    'MetaBBoxHead'
]

View File

@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from mmdet.models.builder import HEADS
from mmdet.models.roi_heads import ConvFCBBoxHead

# Small constant avoiding division by zero when normalizing features/weights.
EPS = 1e-5


@HEADS.register_module()
class CosineSimBBoxHead(ConvFCBBoxHead):
    """BBoxHead for `TFA <https://arxiv.org/abs/2003.06957>`_.

    Replaces the dot-product classifier of :obj:`ConvFCBBoxHead` with a
    cosine-similarity one: the RoI feature and the (bias-free) classifier
    weights are both L2-normalized before the linear projection and the
    resulting logits are multiplied by a scaling factor.

    The code is modified from the official implementation
    https://github.com/ucbdrive/few-shot-object-detection/

    Args:
        scale (int): Scaling factor of `cls_score`. Default: 20.
        learnable_scale (bool): Learnable global scaling factor.
            Default: False.
    """

    def __init__(self, scale=20, learnable_scale=False, *args, **kwargs):
        super(CosineSimBBoxHead, self).__init__(*args, **kwargs)
        # override the fc_cls in :obj:`ConvFCBBoxHead`; no bias because the
        # cosine classifier operates on normalized features and weights
        if self.with_cls:
            self.fc_cls = nn.Linear(
                self.cls_last_dim, self.num_classes + 1, bias=False)
        # learnable global scaling factor
        if learnable_scale:
            self.scale = nn.Parameter(torch.ones(1) * scale)
        else:
            self.scale = scale

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Shape of (N, C, H, W).

        Returns:
            tuple(Tensor, Tensor): Box scores with shape of
                (N, num_classes, H, W) and Box energies /
                deltas with shape of (N, 4, H, W).
        """
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)
        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
            for fc in self.shared_fcs:
                x = self.relu(fc(x))
        # separate branches
        x_cls = x
        x_reg = x
        for conv in self.cls_convs:
            x_cls = conv(x_cls)
        if x_cls.dim() > 2:
            if self.with_avg_pool:
                x_cls = self.avg_pool(x_cls)
            x_cls = x_cls.flatten(1)
        for fc in self.cls_fcs:
            x_cls = self.relu(fc(x_cls))
        for conv in self.reg_convs:
            x_reg = conv(x_reg)
        if x_reg.dim() > 2:
            if self.with_avg_pool:
                x_reg = self.avg_pool(x_reg)
            x_reg = x_reg.flatten(1)
        for fc in self.reg_fcs:
            x_reg = self.relu(fc(x_reg))
        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
        if x_cls.dim() > 2:
            x_cls = torch.flatten(x_cls, start_dim=1)
        # normalize the input x_cls along the `input_size` dimension.
        # FIX: was `expand_as(x)`; `x` and `x_cls` only share a shape when no
        # cls-specific convs/fcs change the feature size, otherwise this raised.
        x_norm = torch.norm(x_cls, p=2, dim=1).unsqueeze(1).expand_as(x_cls)
        x_cls_normalized = x_cls.div(x_norm + EPS)
        # normalize weight in place (no grad), as in the official TFA code
        with torch.no_grad():
            temp_norm = torch.norm(
                self.fc_cls.weight, p=2,
                dim=1).unsqueeze(1).expand_as(self.fc_cls.weight)
            self.fc_cls.weight.div_(temp_norm + EPS)
        # calculate and scale cls_score
        cls_score = self.scale * self.fc_cls(x_cls_normalized) \
            if self.with_cls else None
        return cls_score, bbox_pred

View File

@ -0,0 +1,274 @@
"""Modified the classifier of base model for novel class fine-tuning.
Initialize the classifier with the checkpoint in base training for
novel class fine-tuning. For more details, It would initialize a
classifier head with total (num_base_classes + num_novel_classes)
classes, for classes that inherit from the base training,
the weight would be load from the corresponding base training
checkpoint. For novel classes, the weight would be randomly initialized.
Temporally, we only use this script in FSCE and TFA with --method randinit.
This part of code is modified from
https://github.com/ucbdrive/few-shot-object-detection/.
Example:
# VOC base model
python3 -m tools.models.checkpoint_surgery \
--src1 work_dirs/voc_split1_base_training/latest.pth \
--method randinit \
--save-dir work_dirs/voc_split1_base_training
# COCO base model
python3 -m tools.models.checkpoint_surgery \
--src1 work_dirs/coco_base_training/latest.pth \
--method randinit \
--coco \
--save-dir work_dirs/coco_base_training
"""
import argparse
import os
import torch
# COCO config
# Category ids of the 20 novel classes; presumably the VOC-overlapping
# categories used by the TFA benchmark — TODO confirm against the paper.
COCO_NOVEL_CLASSES = [
    1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72
]
# The remaining 60 COCO category ids used for base training.
COCO_BASE_CLASSES = [
    8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37,
    38, 39, 40, 41, 42, 43, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
    59, 60, 61, 65, 70, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87,
    88, 89, 90
]
COCO_ALL_CLASSES = sorted(COCO_BASE_CLASSES + COCO_NOVEL_CLASSES)
# Maps original COCO category id -> contiguous classifier index.
COCO_IDMAP = {v: i for i, v in enumerate(COCO_ALL_CLASSES)}
# Total number of classes in the target (surgery-result) classifier.
COCO_TAR_SIZE = 80
# LVIS config
# Novel (rare) LVIS category ids; base classes are the complement below.
LVIS_NOVEL_CLASSES = [
    0, 6, 9, 13, 14, 15, 20, 21, 30, 37, 38, 39, 41, 45, 48, 50, 51, 63, 64,
    69, 71, 73, 82, 85, 93, 99, 100, 104, 105, 106, 112, 115, 116, 119, 121,
    124, 126, 129, 130, 135, 139, 141, 142, 143, 146, 149, 154, 158, 160, 162,
    163, 166, 168, 172, 180, 181, 183, 195, 198, 202, 204, 205, 208, 212, 213,
    216, 217, 218, 225, 226, 230, 235, 237, 238, 240, 241, 242, 244, 245, 248,
    249, 250, 251, 252, 254, 257, 258, 264, 265, 269, 270, 272, 279, 283, 286,
    290, 292, 294, 295, 297, 299, 302, 303, 305, 306, 309, 310, 312, 315, 316,
    317, 319, 320, 321, 323, 325, 327, 328, 329, 334, 335, 341, 343, 349, 350,
    353, 355, 356, 357, 358, 359, 360, 365, 367, 368, 369, 371, 377, 378, 384,
    385, 387, 388, 392, 393, 401, 402, 403, 405, 407, 410, 412, 413, 416, 419,
    420, 422, 426, 429, 432, 433, 434, 437, 438, 440, 441, 445, 453, 454, 455,
    461, 463, 468, 472, 475, 476, 477, 482, 484, 485, 487, 488, 492, 494, 495,
    497, 508, 509, 511, 513, 514, 515, 517, 520, 523, 524, 525, 526, 529, 533,
    540, 541, 542, 544, 547, 550, 551, 552, 554, 555, 561, 563, 568, 571, 572,
    580, 581, 583, 584, 585, 586, 589, 591, 592, 593, 595, 596, 599, 601, 604,
    608, 609, 611, 612, 615, 616, 625, 626, 628, 629, 630, 633, 635, 642, 644,
    645, 649, 655, 657, 658, 662, 663, 664, 670, 673, 675, 676, 682, 683, 685,
    689, 695, 697, 699, 702, 711, 712, 715, 721, 722, 723, 724, 726, 729, 731,
    733, 734, 738, 740, 741, 744, 748, 754, 758, 764, 766, 767, 768, 771, 772,
    774, 776, 777, 781, 782, 784, 789, 790, 794, 795, 796, 798, 799, 803, 805,
    806, 807, 808, 815, 817, 820, 821, 822, 824, 825, 827, 832, 833, 835, 836,
    840, 842, 844, 846, 856, 862, 863, 864, 865, 866, 868, 869, 870, 871, 872,
    875, 877, 882, 886, 892, 893, 897, 898, 900, 901, 904, 905, 907, 915, 918,
    919, 920, 921, 922, 926, 927, 930, 931, 933, 939, 940, 944, 945, 946, 948,
    950, 951, 953, 954, 955, 956, 958, 959, 961, 962, 963, 969, 974, 975, 988,
    990, 991, 998, 999, 1001, 1003, 1005, 1008, 1009, 1010, 1012, 1015, 1020,
    1022, 1025, 1026, 1028, 1029, 1032, 1033, 1046, 1047, 1048, 1049, 1050,
    1055, 1066, 1067, 1068, 1072, 1073, 1076, 1077, 1086, 1094, 1099, 1103,
    1111, 1132, 1135, 1137, 1138, 1139, 1140, 1144, 1146, 1148, 1150, 1152,
    1153, 1156, 1158, 1165, 1166, 1167, 1168, 1169, 1171, 1178, 1179, 1180,
    1186, 1187, 1188, 1189, 1203, 1204, 1205, 1213, 1215, 1218, 1224, 1225,
    1227
]
# Base classes are every id in [0, 1230) that is not novel.
LVIS_BASE_CLASSES = [c for c in range(1230) if c not in LVIS_NOVEL_CLASSES]
LVIS_ALL_CLASSES = sorted(LVIS_BASE_CLASSES + LVIS_NOVEL_CLASSES)
# Maps original LVIS category id -> contiguous classifier index.
LVIS_IDMAP = {v: i for i, v in enumerate(LVIS_ALL_CLASSES)}
LVIS_TAR_SIZE = 1230
# VOC config
# VOC has 20 classes total (15 base + 5 novel per split).
VOC_TAR_SIZE = 20
def parse_args(argv=None):
    """Parse command line options for the checkpoint surgery script.

    Args:
        argv (list[str] | None): Argument list to parse. Defaults to
            ``sys.argv[1:]`` when None, so existing callers are unaffected.

    Returns:
        argparse.Namespace: Parsed options.
    """
    parser = argparse.ArgumentParser()
    # Paths
    # required=True: without a main checkpoint the script can only crash
    # later inside torch.load(None); fail early with a clear message.
    parser.add_argument(
        '--src1',
        type=str,
        required=True,
        help='Path to the main checkpoint')
    parser.add_argument(
        '--src2',
        type=str,
        default=None,
        help='Path to the secondary checkpoint. Only used when combining '
        'fc layers of two checkpoints')
    parser.add_argument(
        '--save-dir', type=str, default=None, help='Save directory')
    parser.add_argument(
        '--method',
        choices=['combine', 'remove', 'randinit'],
        required=True,
        help='Surgery method. combine = '
        'combine checkpoints. remove = for fine-tuning on '
        'novel dataset, remove the final layer of the '
        'base detector. randinit = randomly initialize '
        'novel weights.')
    parser.add_argument(
        '--param-name',
        type=str,
        nargs='+',
        default=['roi_head.bbox_head.fc_cls', 'roi_head.bbox_head.fc_reg'],
        help='Target parameter names')
    parser.add_argument(
        '--tar-name',
        type=str,
        default='model_reset',
        help='Name of the new checkpoint')
    # Dataset
    parser.add_argument('--coco', action='store_true', help='For COCO models')
    parser.add_argument('--lvis', action='store_true', help='For LVIS models')
    return parser.parse_args(argv)
def random_init_checkpoint(param_name, is_weight, tar_size, checkpoint, args):
    """Append randomly initialized weights for the novel classes.

    The ``checkpoint['state_dict']`` entry is replaced in place by a new
    tensor of ``tar_size`` rows: base-class rows are copied from the
    pretrained tensor, all remaining rows are drawn from N(0, 0.01) for
    weights and zeroed for biases.

    Args:
        param_name (str): Parameter name without ``.weight``/``.bias``.
        is_weight (bool): Process the ``.weight`` tensor if True, else the
            ``.bias`` tensor.
        tar_size (int): Row count of the new parameter.
        checkpoint (dict): Checkpoint edited in place.
        args (argparse.Namespace): Parsed options; ``args.coco``/``args.lvis``
            select the class-id remapping tables.

    Note: The base detector for LVIS contains weights for all classes, but
    only the weights corresponding to base classes are updated during base
    training (this design choice has no particular reason). Thus, the random
    initialization step is not really necessary.
    """
    weight_name = param_name + ('.weight' if is_weight else '.bias')
    pretrained_weight = checkpoint['state_dict'][weight_name]
    prev_cls = pretrained_weight.size(0)
    if 'fc_cls' in param_name:
        # The classifier keeps one extra background row at the end.
        prev_cls -= 1
    if is_weight:
        feat_size = pretrained_weight.size(1)
        # torch.empty instead of torch.rand: normal_() fully overwrites the
        # buffer right after, so sampling uniform values first is wasted work.
        new_weight = torch.empty((tar_size, feat_size))
        torch.nn.init.normal_(new_weight, 0, 0.01)
    else:
        new_weight = torch.zeros(tar_size)
    if args.coco or args.lvis:
        BASE_CLASSES = COCO_BASE_CLASSES if args.coco else LVIS_BASE_CLASSES
        IDMAP = COCO_IDMAP if args.coco else LVIS_IDMAP
        for i, c in enumerate(BASE_CLASSES):
            # COCO checkpoints store base classes contiguously; LVIS stores
            # them at their original category index.
            idx = i if args.coco else c
            if 'fc_cls' in param_name:
                new_weight[IDMAP[c]] = pretrained_weight[idx]
            else:
                # fc_reg stores 4 box-delta entries per class.
                new_weight[IDMAP[c] * 4:(IDMAP[c] + 1) * 4] = \
                    pretrained_weight[idx * 4:(idx + 1) * 4]
    else:
        new_weight[:prev_cls] = pretrained_weight[:prev_cls]
        if 'fc_cls' in param_name:
            new_weight[-1] = pretrained_weight[-1]  # bg class
    checkpoint['state_dict'][weight_name] = new_weight
def combine_checkpoints(param_name, is_weight, tar_size, checkpoint,
                        checkpoint2, args):
    """Combine base detector with novel detector.

    Feature extractor weights are from the base detector. Only the final
    layer weights are combined: base-class rows come from ``checkpoint``,
    novel-class rows from ``checkpoint2``, and the result is written back
    into ``checkpoint['state_dict']`` in place.
    """
    weight_name = param_name + ('.weight' if is_weight else '.bias')
    # Bias may be absent from either checkpoint; nothing to combine then.
    if not is_weight and (weight_name not in checkpoint['state_dict']
                          or weight_name not in checkpoint2['state_dict']):
        return
    base_weight = checkpoint['state_dict'][weight_name]
    is_cls = 'fc_cls' in param_name
    # The classifier keeps one extra background row at the end.
    prev_cls = base_weight.size(0) - 1 if is_cls else base_weight.size(0)
    if is_weight:
        new_weight = torch.rand((tar_size, base_weight.size(1)))
    else:
        new_weight = torch.zeros(tar_size)
    if args.coco or args.lvis:
        base_classes = COCO_BASE_CLASSES if args.coco else LVIS_BASE_CLASSES
        idmap = COCO_IDMAP if args.coco else LVIS_IDMAP
        for i, c in enumerate(base_classes):
            src = i if args.coco else c
            dst = idmap[c]
            if is_cls:
                new_weight[dst] = base_weight[src]
            else:
                new_weight[dst * 4:(dst + 1) * 4] = \
                    base_weight[src * 4:(src + 1) * 4]
    else:
        new_weight[:prev_cls] = base_weight[:prev_cls]
    novel_weight = checkpoint2['state_dict'][weight_name]
    if args.coco or args.lvis:
        novel_classes = (
            COCO_NOVEL_CLASSES if args.coco else LVIS_NOVEL_CLASSES)
        idmap = COCO_IDMAP if args.coco else LVIS_IDMAP
        for i, c in enumerate(novel_classes):
            dst = idmap[c]
            if is_cls:
                new_weight[dst] = novel_weight[i]
            else:
                new_weight[dst * 4:(dst + 1) * 4] = \
                    novel_weight[i * 4:(i + 1) * 4]
        if is_cls:
            new_weight[-1] = base_weight[-1]
    elif is_cls:
        new_weight[prev_cls:-1] = novel_weight[:-1]
        new_weight[-1] = base_weight[-1]
    else:
        new_weight[prev_cls:] = novel_weight
    checkpoint['state_dict'][weight_name] = new_weight
    return checkpoint
def reset_checkpoint(checkpoint):
    """Strip optimizer/scheduler state and zero the iteration counter.

    Edits ``checkpoint`` in place so it can serve as a clean starting
    point for fine-tuning.
    """
    for stale_key in ('scheduler', 'optimizer'):
        checkpoint.pop(stale_key, None)
    if 'iteration' in checkpoint:
        checkpoint['iteration'] = 0
def main():
    """Entry point: load, edit, and save the surgically modified checkpoint.

    Raises:
        ValueError: If ``--method`` is ``combine`` but ``--src2`` is missing.
    """
    args = parse_args()
    checkpoint = torch.load(args.src1)
    save_name = args.tar_name + f'_{args.method}.pth'
    # BUG FIX: --save-dir defaults to None but was compared against '',
    # so the default path crashed in os.path.join(None, ...). Fall back to
    # the source checkpoint's directory (or '.' for a bare filename).
    save_dir = args.save_dir or os.path.dirname(args.src1) or '.'
    save_path = os.path.join(save_dir, save_name)
    os.makedirs(save_dir, exist_ok=True)
    reset_checkpoint(checkpoint)
    if args.coco:
        TAR_SIZE = COCO_TAR_SIZE
    elif args.lvis:
        TAR_SIZE = LVIS_TAR_SIZE
    else:
        TAR_SIZE = VOC_TAR_SIZE
    if args.method == 'remove':
        # Remove parameters so fine-tuning re-creates the final layers.
        for param_name in args.param_name:
            del checkpoint['state_dict'][param_name + '.weight']
            if param_name + '.bias' in checkpoint['state_dict']:
                del checkpoint['state_dict'][param_name + '.bias']
    elif args.method == 'combine':
        if args.src2 is None:
            raise ValueError('--src2 is required when method is "combine"')
        checkpoint2 = torch.load(args.src2)
        # fc_cls has one extra background row; fc_reg has 4 deltas per class.
        tar_sizes = [TAR_SIZE + 1, TAR_SIZE * 4]
        for param_name, tar_size in zip(args.param_name, tar_sizes):
            # BUG FIX: the original calls omitted the required ``args``
            # argument, raising TypeError on every run of this branch.
            combine_checkpoints(param_name, True, tar_size, checkpoint,
                                checkpoint2, args)
            combine_checkpoints(param_name, False, tar_size, checkpoint,
                                checkpoint2, args)
    elif args.method == 'randinit':
        tar_sizes = [TAR_SIZE + 1, TAR_SIZE * 4]
        for param_name, tar_size in zip(args.param_name, tar_sizes):
            # BUG FIX: ``args`` was missing here as well.
            random_init_checkpoint(param_name, True, tar_size, checkpoint,
                                   args)
            random_init_checkpoint(param_name, False, tar_size, checkpoint,
                                   args)
    else:
        raise ValueError(f'not support method: {args.method}')
    torch.save(checkpoint, save_path)
    print(f'save changed checkpoint to {save_path}')
# Script entry point: run the surgery only when executed directly.
if __name__ == '__main__':
    main()