Implementation of TFA (#13)
* update script
* fix comments
* create pr

Co-authored-by: zhangshilong <2392587229zsl@gmail.com>
Co-authored-by: Shilong Zhang <61961338+jshilong@users.noreply.github.com>
parent 87b36f102a
commit a643f7ee9b
@@ -0,0 +1,36 @@
_base_ = [
    '../../_base_/datasets/finetune_based/few_shot_coco.py',
    '../../_base_/schedules/schedule.py',
    '../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotCocoDataset
# FewShotCocoDefaultDataset predefines ann_cfg for model reproducibility
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotCocoDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='10SHOT')],
            num_novel_shots=10,
            num_base_shots=10)))
evaluation = dict(interval=80000)
checkpoint_config = dict(interval=80000)
optimizer = dict(lr=0.001)
lr_config = dict(warmup_iters=10, step=[144000])
runner = dict(max_iters=160000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'faster_rcnn_r101_fpn_coco_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=80,
            scale=20)))
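Every fine-tuning config in this PR uses `frozen_parameters` to keep the backbone, neck, RPN, and shared fcs fixed so that only the box head trains during the second stage. As a minimal sketch of what prefix-based freezing amounts to (`freeze_by_prefix` is a hypothetical helper, not the mmfewshot implementation):

import torch.nn as nn


def freeze_by_prefix(model: nn.Module, prefixes) -> None:
    """Turn off gradients for every parameter whose name starts with a prefix."""
    for name, param in model.named_parameters():
        if any(name.startswith(p) for p in prefixes):
            param.requires_grad = False


# Toy usage: freeze the 'backbone' of a dummy two-part model.
net = nn.ModuleDict({'backbone': nn.Linear(8, 8), 'head': nn.Linear(8, 2)})
freeze_by_prefix(net, ['backbone'])
print([n for n, p in net.named_parameters() if p.requires_grad])
# -> ['head.weight', 'head.bias']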
@@ -0,0 +1,39 @@
_base_ = [
    '../../_base_/datasets/finetune_based/few_shot_coco.py',
    '../../_base_/schedules/schedule.py',
    '../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotCocoDataset
# FewShotCocoDefaultDataset predefines ann_cfg for model reproducibility
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotCocoDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='30SHOT')],
            num_novel_shots=30,
            num_base_shots=30)))
evaluation = dict(interval=120000)
checkpoint_config = dict(interval=120000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup_iters=10, step=[
        216000,
    ])
runner = dict(max_iters=240000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'faster_rcnn_r101_fpn_coco_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=80,
            scale=20)))
@@ -0,0 +1,13 @@
_base_ = [
    '../../_base_/datasets/finetune_based/base_coco.py',
    '../../_base_/schedules/schedule.py',
    '../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../_base_/default_runtime.py'
]
lr_config = dict(warmup_iters=1000, step=[85000, 100000])
runner = dict(max_iters=110000)
# model settings
model = dict(
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101),
    roi_head=dict(bbox_head=dict(num_classes=60)))
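Because most settings come from the inherited `_base_` files, it can be worth resolving the merged config before launching base training. A small sketch, assuming mmcv is installed and using an illustrative file path:

from mmcv import Config

# Hypothetical path; substitute the actual location of the config above.
cfg = Config.fromfile('configs/tfa/coco/base_training.py')
# Inherited and overridden values are merged into one dict-like object.
print(cfg.runner.max_iters)                      # 110000
print(cfg.model.roi_head.bbox_head.num_classes)  # 60 base classes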
@@ -0,0 +1,45 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_10SHOT')],
            num_novel_shots=10,
            num_base_shots=10,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
evaluation = dict(
    interval=40000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=40000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup_iters=10, step=[
        36000,
    ])
runner = dict(max_iters=40000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split1_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,45 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_1SHOT')],
            num_novel_shots=1,
            num_base_shots=1,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
evaluation = dict(
    interval=4000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=4000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup=None, step=[
        3000,
    ])
runner = dict(max_iters=4000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split1_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,45 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_2SHOT')],
            num_novel_shots=2,
            num_base_shots=2,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
evaluation = dict(
    interval=8000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=8000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup=None, step=[
        7000,
    ])
runner = dict(max_iters=8000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split1_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,45 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_3SHOT')],
            num_novel_shots=3,
            num_base_shots=3,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
evaluation = dict(
    interval=12000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=12000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup_iters=10, step=[
        11000,
    ])
runner = dict(max_iters=12000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split1_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,45 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT1_5SHOT')],
            num_novel_shots=5,
            num_base_shots=5,
            classes='ALL_CLASSES_SPLIT1')),
    val=dict(classes='ALL_CLASSES_SPLIT1'),
    test=dict(classes='ALL_CLASSES_SPLIT1'))
evaluation = dict(
    interval=20000,
    class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
checkpoint_config = dict(interval=20000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup_iters=10, step=[
        18000,
    ])
runner = dict(max_iters=20000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split1_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,18 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/base_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
data = dict(
    train=dict(dataset=dict(classes='BASE_CLASSES_SPLIT1')),
    val=dict(classes='BASE_CLASSES_SPLIT1'),
    test=dict(classes='BASE_CLASSES_SPLIT1'))
lr_config = dict(warmup_iters=100, step=[12000, 16000])
runner = dict(max_iters=18000)
# model settings
model = dict(
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101),
    roi_head=dict(bbox_head=dict(num_classes=15)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_10SHOT')],
            num_novel_shots=10,
            num_base_shots=10,
            classes='ALL_CLASSES_SPLIT2',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
evaluation = dict(
    interval=40000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=40000)
optimizer = dict(lr=0.005)
lr_config = dict(
    warmup_iters=10, step=[
        36000,
    ])
runner = dict(max_iters=40000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split2_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_1SHOT')],
            num_novel_shots=1,
            num_base_shots=1,
            classes='ALL_CLASSES_SPLIT2',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
evaluation = dict(
    interval=4000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=4000)
optimizer = dict(lr=0.005)
lr_config = dict(
    warmup=None, step=[
        3000,
    ])
runner = dict(max_iters=4000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split2_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_2SHOT')],
            num_novel_shots=2,
            num_base_shots=2,
            classes='ALL_CLASSES_SPLIT2',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
evaluation = dict(
    interval=8000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=8000)
optimizer = dict(lr=0.005)
lr_config = dict(
    warmup=None, step=[
        7000,
    ])
runner = dict(max_iters=8000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split2_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_3SHOT')],
            num_novel_shots=3,
            num_base_shots=3,
            classes='ALL_CLASSES_SPLIT2',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
evaluation = dict(
    interval=12000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=12000)
optimizer = dict(lr=0.005)
lr_config = dict(
    warmup_iters=10, step=[
        11000,
    ])
runner = dict(max_iters=12000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split2_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT2_5SHOT')],
            num_novel_shots=5,
            num_base_shots=5,
            classes='ALL_CLASSES_SPLIT2',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT2'),
    test=dict(classes='ALL_CLASSES_SPLIT2'))
evaluation = dict(
    interval=20000,
    class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
checkpoint_config = dict(interval=20000)
optimizer = dict(lr=0.005)
lr_config = dict(
    warmup_iters=10, step=[
        18000,
    ])
runner = dict(max_iters=20000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split2_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,18 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/base_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
data = dict(
    train=dict(dataset=dict(classes='BASE_CLASSES_SPLIT2')),
    val=dict(classes='BASE_CLASSES_SPLIT2'),
    test=dict(classes='BASE_CLASSES_SPLIT2'))
lr_config = dict(warmup_iters=100, step=[12000, 16000])
runner = dict(max_iters=18000)
# model settings
model = dict(
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101),
    roi_head=dict(bbox_head=dict(num_classes=15)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_10SHOT')],
            num_novel_shots=10,
            num_base_shots=10,
            classes='ALL_CLASSES_SPLIT3',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
evaluation = dict(
    interval=40000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=40000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup_iters=10, step=[
        36000,
    ])
runner = dict(max_iters=40000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split3_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_1SHOT')],
            num_novel_shots=1,
            num_base_shots=1,
            classes='ALL_CLASSES_SPLIT3',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
evaluation = dict(
    interval=4000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=4000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup=None, step=[
        3000,
    ])
runner = dict(max_iters=4000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split3_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_2SHOT')],
            num_novel_shots=2,
            num_base_shots=2,
            classes='ALL_CLASSES_SPLIT3',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
evaluation = dict(
    interval=8000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=8000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup=None, step=[
        7000,
    ])
runner = dict(max_iters=8000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split3_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_3SHOT')],
            num_novel_shots=3,
            num_base_shots=3,
            classes='ALL_CLASSES_SPLIT3',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
evaluation = dict(
    interval=12000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=12000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup_iters=10, step=[
        11000,
    ])
runner = dict(max_iters=12000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split3_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,46 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/few_shot_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
# FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
data = dict(
    train=dict(
        dataset=dict(
            type='FewShotVOCDefaultDataset',
            ann_cfg=[dict(method='TFA', setting='SPLIT3_5SHOT')],
            num_novel_shots=5,
            num_base_shots=5,
            classes='ALL_CLASSES_SPLIT3',
        )),
    val=dict(classes='ALL_CLASSES_SPLIT3'),
    test=dict(classes='ALL_CLASSES_SPLIT3'))
evaluation = dict(
    interval=20000,
    class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
checkpoint_config = dict(interval=20000)
optimizer = dict(lr=0.001)
lr_config = dict(
    warmup_iters=10, step=[
        18000,
    ])
runner = dict(max_iters=20000)
# load_from = 'path of base training model'
load_from = \
    'work_dirs/' \
    'tfa_faster_rcnn_r101_fpn_voc_split3_base_training/' \
    'model_reset_surgery.pth'
model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ],
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101, frozen_stages=4),
    roi_head=dict(
        bbox_head=dict(
            type='CosineSimBBoxHead',
            num_shared_fcs=2,
            num_classes=20,
            scale=20)))
@@ -0,0 +1,18 @@
_base_ = [
    '../../../_base_/datasets/finetune_based/base_voc.py',
    '../../../_base_/schedules/schedule.py',
    '../../../_base_/models/faster_rcnn_r50_caffe_fpn.py',
    '../../../_base_/default_runtime.py'
]
# class splits are predefined in FewShotVOCDataset
data = dict(
    train=dict(dataset=dict(classes='BASE_CLASSES_SPLIT3')),
    val=dict(classes='BASE_CLASSES_SPLIT3'),
    test=dict(classes='BASE_CLASSES_SPLIT3'))
lr_config = dict(warmup_iters=100, step=[12000, 16000])
runner = dict(max_iters=18000)
# model settings
model = dict(
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(depth=101),
    roi_head=dict(bbox_head=dict(num_classes=15)))
@@ -1,5 +1,5 @@
-from .attention_rpn import AttentionRPN
+from .attention_rpn_detector import AttentionRPNDetector
 from .base_query_support import BaseQuerySupportDetector
-from .fsdetview import FsDetView
+from .fsdetview import FSDetView

-__all__ = ['BaseQuerySupportDetector', 'AttentionRPN', 'FsDetView']
+__all__ = ['BaseQuerySupportDetector', 'AttentionRPNDetector', 'FSDetView']
@@ -0,0 +1,13 @@
from .attention_rpn_roi_head import AttentionRPNRoIHead
from .bbox_heads import (ContrastiveBBoxHead, CosineSimBBoxHead,
                         MultiRelationBBoxHead)
from .contrastive_roi_head import ContrastiveRoIHead
from .fsdetview_roi_head import FSDetViewRoIHead
from .meta_rcnn_roi_head import MetaRCNNRoIHead
from .shared_heads import MetaRCNNResLayer

__all__ = [
    'CosineSimBBoxHead', 'ContrastiveBBoxHead', 'MultiRelationBBoxHead',
    'ContrastiveRoIHead', 'AttentionRPNRoIHead', 'FSDetViewRoIHead',
    'MetaRCNNRoIHead', 'MetaRCNNResLayer'
]
@@ -0,0 +1,9 @@
from .contrastive_bbox_head import ContrastiveBBoxHead
from .cosine_sim_bbox_head import CosineSimBBoxHead
from .meta_bbox_head import MetaBBoxHead
from .multi_relation_bbox_head import MultiRelationBBoxHead

__all__ = [
    'CosineSimBBoxHead', 'ContrastiveBBoxHead', 'MultiRelationBBoxHead',
    'MetaBBoxHead'
]
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from mmdet.models.builder import HEADS
from mmdet.models.roi_heads import ConvFCBBoxHead

EPS = 1e-5


@HEADS.register_module()
class CosineSimBBoxHead(ConvFCBBoxHead):
    """BBoxHead for `TFA <https://arxiv.org/abs/2003.06957>`_.

    The code is modified from the official implementation
    https://github.com/ucbdrive/few-shot-object-detection/

    Args:
        scale (int): Scaling factor of `cls_score`. Default: 20.
        learnable_scale (bool): Learnable global scaling factor.
            Default: False.
    """

    def __init__(self, scale=20, learnable_scale=False, *args, **kwargs):
        super(CosineSimBBoxHead, self).__init__(*args, **kwargs)
        # override the fc_cls in :obj:`ConvFCBBoxHead`
        if self.with_cls:
            self.fc_cls = nn.Linear(
                self.cls_last_dim, self.num_classes + 1, bias=False)

        # learnable global scaling factor
        if learnable_scale:
            self.scale = nn.Parameter(torch.ones(1) * scale)
        else:
            self.scale = scale

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Shape of (N, C, H, W).

        Returns:
            tuple(Tensor, Tensor): Classification scores with shape
                (N, num_classes + 1) and box energies / deltas with
                shape (N, num_classes * 4).
        """
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)

            x = x.flatten(1)

            for fc in self.shared_fcs:
                x = self.relu(fc(x))
        # separate branches
        x_cls = x
        x_reg = x

        for conv in self.cls_convs:
            x_cls = conv(x_cls)
        if x_cls.dim() > 2:
            if self.with_avg_pool:
                x_cls = self.avg_pool(x_cls)
            x_cls = x_cls.flatten(1)
        for fc in self.cls_fcs:
            x_cls = self.relu(fc(x_cls))

        for conv in self.reg_convs:
            x_reg = conv(x_reg)
        if x_reg.dim() > 2:
            if self.with_avg_pool:
                x_reg = self.avg_pool(x_reg)
            x_reg = x_reg.flatten(1)
        for fc in self.reg_fcs:
            x_reg = self.relu(fc(x_reg))

        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None

        if x_cls.dim() > 2:
            x_cls = torch.flatten(x_cls, start_dim=1)

        # normalize the input x_cls along the `input_size` dimension
        x_norm = torch.norm(x_cls, p=2, dim=1).unsqueeze(1).expand_as(x_cls)
        x_cls_normalized = x_cls.div(x_norm + EPS)
        # normalize weight
        with torch.no_grad():
            temp_norm = torch.norm(
                self.fc_cls.weight, p=2,
                dim=1).unsqueeze(1).expand_as(self.fc_cls.weight)
            self.fc_cls.weight.div_(temp_norm + EPS)
        # calculate and scale cls_score
        cls_score = self.scale * self.fc_cls(x_cls_normalized) \
            if self.with_cls else None

        return cls_score, bbox_pred
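Up to the `EPS` stabilizer and the in-place weight normalization, the forward pass above scores each RoI by a scaled cosine similarity between its feature vector and the per-class weight rows of the bias-free `fc_cls`. A standalone sketch of that scoring rule, with illustrative tensor sizes:

import torch
import torch.nn.functional as F

feats = torch.randn(4, 1024)     # 4 RoIs with 1024-d shared-fc features
weights = torch.randn(21, 1024)  # 20 classes + background, one row each
# L2-normalize features and class weights, then take scaled dot products.
scores = 20 * F.linear(F.normalize(feats, dim=1), F.normalize(weights, dim=1))
print(scores.shape)  # torch.Size([4, 21]); every score lies in [-20, 20]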
@@ -0,0 +1,274 @@
"""Modify the classifier of a base model for novel class fine-tuning.

Initialize the classifier for novel class fine-tuning from the base
training checkpoint. In detail, a classifier head with
(num_base_classes + num_novel_classes) classes in total is initialized;
for classes inherited from base training, the weights are loaded from
the corresponding base training checkpoint, while the weights for novel
classes are randomly initialized. For now, we only use this script in
FSCE and TFA with --method randinit.
This part of the code is modified from
https://github.com/ucbdrive/few-shot-object-detection/.

Example:
    # VOC base model
    python3 -m tools.models.checkpoint_surgery \
        --src1 work_dirs/voc_split1_base_training/latest.pth \
        --method randinit \
        --save-dir work_dirs/voc_split1_base_training
    # COCO base model
    python3 -m tools.models.checkpoint_surgery \
        --src1 work_dirs/coco_base_training/latest.pth \
        --method randinit \
        --coco \
        --save-dir work_dirs/coco_base_training
"""

import argparse
import os

import torch

# COCO config
COCO_NOVEL_CLASSES = [
    1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72
]
COCO_BASE_CLASSES = [
    8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37,
    38, 39, 40, 41, 42, 43, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
    59, 60, 61, 65, 70, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87,
    88, 89, 90
]
COCO_ALL_CLASSES = sorted(COCO_BASE_CLASSES + COCO_NOVEL_CLASSES)
COCO_IDMAP = {v: i for i, v in enumerate(COCO_ALL_CLASSES)}
COCO_TAR_SIZE = 80
# LVIS config
LVIS_NOVEL_CLASSES = [
    0, 6, 9, 13, 14, 15, 20, 21, 30, 37, 38, 39, 41, 45, 48, 50, 51, 63, 64,
    69, 71, 73, 82, 85, 93, 99, 100, 104, 105, 106, 112, 115, 116, 119, 121,
    124, 126, 129, 130, 135, 139, 141, 142, 143, 146, 149, 154, 158, 160, 162,
    163, 166, 168, 172, 180, 181, 183, 195, 198, 202, 204, 205, 208, 212, 213,
    216, 217, 218, 225, 226, 230, 235, 237, 238, 240, 241, 242, 244, 245, 248,
    249, 250, 251, 252, 254, 257, 258, 264, 265, 269, 270, 272, 279, 283, 286,
    290, 292, 294, 295, 297, 299, 302, 303, 305, 306, 309, 310, 312, 315, 316,
    317, 319, 320, 321, 323, 325, 327, 328, 329, 334, 335, 341, 343, 349, 350,
    353, 355, 356, 357, 358, 359, 360, 365, 367, 368, 369, 371, 377, 378, 384,
    385, 387, 388, 392, 393, 401, 402, 403, 405, 407, 410, 412, 413, 416, 419,
    420, 422, 426, 429, 432, 433, 434, 437, 438, 440, 441, 445, 453, 454, 455,
    461, 463, 468, 472, 475, 476, 477, 482, 484, 485, 487, 488, 492, 494, 495,
    497, 508, 509, 511, 513, 514, 515, 517, 520, 523, 524, 525, 526, 529, 533,
    540, 541, 542, 544, 547, 550, 551, 552, 554, 555, 561, 563, 568, 571, 572,
    580, 581, 583, 584, 585, 586, 589, 591, 592, 593, 595, 596, 599, 601, 604,
    608, 609, 611, 612, 615, 616, 625, 626, 628, 629, 630, 633, 635, 642, 644,
    645, 649, 655, 657, 658, 662, 663, 664, 670, 673, 675, 676, 682, 683, 685,
    689, 695, 697, 699, 702, 711, 712, 715, 721, 722, 723, 724, 726, 729, 731,
    733, 734, 738, 740, 741, 744, 748, 754, 758, 764, 766, 767, 768, 771, 772,
    774, 776, 777, 781, 782, 784, 789, 790, 794, 795, 796, 798, 799, 803, 805,
    806, 807, 808, 815, 817, 820, 821, 822, 824, 825, 827, 832, 833, 835, 836,
    840, 842, 844, 846, 856, 862, 863, 864, 865, 866, 868, 869, 870, 871, 872,
    875, 877, 882, 886, 892, 893, 897, 898, 900, 901, 904, 905, 907, 915, 918,
    919, 920, 921, 922, 926, 927, 930, 931, 933, 939, 940, 944, 945, 946, 948,
    950, 951, 953, 954, 955, 956, 958, 959, 961, 962, 963, 969, 974, 975, 988,
    990, 991, 998, 999, 1001, 1003, 1005, 1008, 1009, 1010, 1012, 1015, 1020,
    1022, 1025, 1026, 1028, 1029, 1032, 1033, 1046, 1047, 1048, 1049, 1050,
    1055, 1066, 1067, 1068, 1072, 1073, 1076, 1077, 1086, 1094, 1099, 1103,
    1111, 1132, 1135, 1137, 1138, 1139, 1140, 1144, 1146, 1148, 1150, 1152,
    1153, 1156, 1158, 1165, 1166, 1167, 1168, 1169, 1171, 1178, 1179, 1180,
    1186, 1187, 1188, 1189, 1203, 1204, 1205, 1213, 1215, 1218, 1224, 1225,
    1227
]
LVIS_BASE_CLASSES = [c for c in range(1230) if c not in LVIS_NOVEL_CLASSES]
LVIS_ALL_CLASSES = sorted(LVIS_BASE_CLASSES + LVIS_NOVEL_CLASSES)
LVIS_IDMAP = {v: i for i, v in enumerate(LVIS_ALL_CLASSES)}
LVIS_TAR_SIZE = 1230
# VOC config
VOC_TAR_SIZE = 20


def parse_args():
    parser = argparse.ArgumentParser()
    # Paths
    parser.add_argument('--src1', type=str, help='Path to the main checkpoint')
    parser.add_argument(
        '--src2',
        type=str,
        default=None,
        help='Path to the secondary checkpoint. Only used when combining '
        'fc layers of two checkpoints')
    parser.add_argument(
        '--save-dir', type=str, default=None, help='Save directory')
    parser.add_argument(
        '--method',
        choices=['combine', 'remove', 'randinit'],
        required=True,
        help='Surgery method. combine = '
        'combine checkpoints. remove = for fine-tuning on '
        'novel dataset, remove the final layer of the '
        'base detector. randinit = randomly initialize '
        'novel weights.')
    parser.add_argument(
        '--param-name',
        type=str,
        nargs='+',
        default=['roi_head.bbox_head.fc_cls', 'roi_head.bbox_head.fc_reg'],
        help='Target parameter names')
    parser.add_argument(
        '--tar-name',
        type=str,
        default='model_reset',
        help='Name of the new checkpoint')
    # Dataset
    parser.add_argument('--coco', action='store_true', help='For COCO models')
    parser.add_argument('--lvis', action='store_true', help='For LVIS models')
    return parser.parse_args()


def random_init_checkpoint(param_name, is_weight, tar_size, checkpoint, args):
    """Build a classifier of `tar_size` classes: weights for base classes are
    copied from the base training checkpoint, while weights for novel classes
    are randomly initialized.

    Note: The base detector for LVIS contains weights for all classes, but
    only the weights corresponding to base classes are updated during base
    training (this design choice has no particular reason). Thus, the random
    initialization step is not really necessary.
    """
    weight_name = param_name + ('.weight' if is_weight else '.bias')
    pretrained_weight = checkpoint['state_dict'][weight_name]
    prev_cls = pretrained_weight.size(0)
    if 'fc_cls' in param_name:
        prev_cls -= 1
    if is_weight:
        feat_size = pretrained_weight.size(1)
        new_weight = torch.rand((tar_size, feat_size))
        torch.nn.init.normal_(new_weight, 0, 0.01)
    else:
        new_weight = torch.zeros(tar_size)
    if args.coco or args.lvis:
        BASE_CLASSES = COCO_BASE_CLASSES if args.coco else LVIS_BASE_CLASSES
        IDMAP = COCO_IDMAP if args.coco else LVIS_IDMAP
        for i, c in enumerate(BASE_CLASSES):
            idx = i if args.coco else c
            if 'fc_cls' in param_name:
                new_weight[IDMAP[c]] = pretrained_weight[idx]
            else:
                new_weight[IDMAP[c] * 4:(IDMAP[c] + 1) * 4] = \
                    pretrained_weight[idx * 4:(idx + 1) * 4]
    else:
        new_weight[:prev_cls] = pretrained_weight[:prev_cls]
    if 'fc_cls' in param_name:
        new_weight[-1] = pretrained_weight[-1]  # bg class
    checkpoint['state_dict'][weight_name] = new_weight


def combine_checkpoints(param_name, is_weight, tar_size, checkpoint,
                        checkpoint2, args):
    """Combine base detector with novel detector.

    Feature extractor weights are from the base detector. Only the final
    layer weights are combined.
    """
    if not is_weight and param_name + '.bias' not in checkpoint['state_dict']:
        return
    if not is_weight and param_name + '.bias' not in checkpoint2['state_dict']:
        return
    weight_name = param_name + ('.weight' if is_weight else '.bias')
    pretrained_weight = checkpoint['state_dict'][weight_name]
    prev_cls = pretrained_weight.size(0)
    if 'fc_cls' in param_name:
        prev_cls -= 1
    if is_weight:
        feat_size = pretrained_weight.size(1)
        new_weight = torch.rand((tar_size, feat_size))
    else:
        new_weight = torch.zeros(tar_size)
    if args.coco or args.lvis:
        BASE_CLASSES = COCO_BASE_CLASSES if args.coco else LVIS_BASE_CLASSES
        IDMAP = COCO_IDMAP if args.coco else LVIS_IDMAP
        for i, c in enumerate(BASE_CLASSES):
            idx = i if args.coco else c
            if 'fc_cls' in param_name:
                new_weight[IDMAP[c]] = pretrained_weight[idx]
            else:
                new_weight[IDMAP[c] * 4:(IDMAP[c] + 1) * 4] = \
                    pretrained_weight[idx * 4:(idx + 1) * 4]
    else:
        new_weight[:prev_cls] = pretrained_weight[:prev_cls]

    checkpoint2_weight = checkpoint2['state_dict'][weight_name]
    if args.coco or args.lvis:
        NOVEL_CLASSES = COCO_NOVEL_CLASSES if args.coco else LVIS_NOVEL_CLASSES
        IDMAP = COCO_IDMAP if args.coco else LVIS_IDMAP
        for i, c in enumerate(NOVEL_CLASSES):
            if 'fc_cls' in param_name:
                new_weight[IDMAP[c]] = checkpoint2_weight[i]
            else:
                new_weight[IDMAP[c] * 4:(IDMAP[c] + 1) * 4] = \
                    checkpoint2_weight[i * 4:(i + 1) * 4]
        if 'fc_cls' in param_name:
            new_weight[-1] = pretrained_weight[-1]
    else:
        if 'fc_cls' in param_name:
            new_weight[prev_cls:-1] = checkpoint2_weight[:-1]
            new_weight[-1] = pretrained_weight[-1]
        else:
            new_weight[prev_cls:] = checkpoint2_weight
    checkpoint['state_dict'][weight_name] = new_weight
    return checkpoint


def reset_checkpoint(checkpoint):
    if 'scheduler' in checkpoint:
        del checkpoint['scheduler']
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    if 'iteration' in checkpoint:
        checkpoint['iteration'] = 0


def main():
    args = parse_args()
    checkpoint = torch.load(args.src1)
    save_name = args.tar_name + f'_{args.method}.pth'
    save_dir = args.save_dir \
        if args.save_dir is not None else os.path.dirname(args.src1)
    save_path = os.path.join(save_dir, save_name)
    os.makedirs(save_dir, exist_ok=True)
    reset_checkpoint(checkpoint)

    if args.coco:
        TAR_SIZE = COCO_TAR_SIZE
    elif args.lvis:
        TAR_SIZE = LVIS_TAR_SIZE
    else:
        TAR_SIZE = VOC_TAR_SIZE

    if args.method == 'remove':
        # Remove parameters
        for param_name in args.param_name:
            del checkpoint['state_dict'][param_name + '.weight']
            if param_name + '.bias' in checkpoint['state_dict']:
                del checkpoint['state_dict'][param_name + '.bias']
    elif args.method == 'combine':
        checkpoint2 = torch.load(args.src2)
        tar_sizes = [TAR_SIZE + 1, TAR_SIZE * 4]
        for param_name, tar_size in zip(args.param_name, tar_sizes):
            combine_checkpoints(param_name, True, tar_size, checkpoint,
                                checkpoint2, args)
            combine_checkpoints(param_name, False, tar_size, checkpoint,
                                checkpoint2, args)
    elif args.method == 'randinit':
        tar_sizes = [TAR_SIZE + 1, TAR_SIZE * 4]
        for param_name, tar_size in zip(args.param_name, tar_sizes):
            random_init_checkpoint(param_name, True, tar_size, checkpoint,
                                   args)
            random_init_checkpoint(param_name, False, tar_size, checkpoint,
                                   args)
    else:
        raise ValueError(f'Unsupported method: {args.method}')

    torch.save(checkpoint, save_path)
    print(f'Saved the modified checkpoint to {save_path}')


if __name__ == '__main__':
    main()
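For the VOC branch of `randinit` (no `--coco`/`--lvis`), the surgery above amounts to copying the base-class rows and the background row into a wider, freshly initialized `fc_cls`. A toy sketch with illustrative sizes (15 base classes widened to 20, plus background):

import torch

base_w = torch.randn(16, 1024)                  # fc_cls.weight after base training
new_w = torch.empty(21, 1024).normal_(0, 0.01)  # widened head, randomly initialized
new_w[:15] = base_w[:15]                        # keep the base-class rows
new_w[-1] = base_w[-1]                          # keep the background row
print(new_w.shape)                              # torch.Size([21, 1024])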