diff --git a/configs/detection/_base_/datasets/finetune_based/base_coco.py b/configs/detection/_base_/datasets/finetune_based/base_coco.py index dcc8d7d..3a1d33f 100644 --- a/configs/detection/_base_/datasets/finetune_based/base_coco.py +++ b/configs/detection/_base_/datasets/finetune_based/base_coco.py @@ -31,8 +31,7 @@ test_pipeline = [ dict(type='Collect', keys=['img']) ]) ] -# Predefined ann_cfg, classes and class_splits are defined in -# mmfewshot.detection.datasets.few_shot_data_config +# classes splits are predefined in FewShotCocoDataset data_root = 'data/coco/' data = dict( samples_per_gpu=2, diff --git a/configs/detection/_base_/datasets/finetune_based/base_voc.py b/configs/detection/_base_/datasets/finetune_based/base_voc.py index fe5ef50..3dda3d9 100644 --- a/configs/detection/_base_/datasets/finetune_based/base_voc.py +++ b/configs/detection/_base_/datasets/finetune_based/base_voc.py @@ -32,8 +32,7 @@ test_pipeline = [ dict(type='Collect', keys=['img']) ]) ] -# Predefined ann_cfg, classes and class_splits are defined in -# mmfewshot.detection.datasets.few_shot_data_config +# classes splits are predefined in FewShotVOCDataset data_root = 'data/VOCdevkit/' data = dict( samples_per_gpu=2, diff --git a/configs/detection/_base_/datasets/finetune_based/few_shot_coco.py b/configs/detection/_base_/datasets/finetune_based/few_shot_coco.py index bc57af2..e023659 100644 --- a/configs/detection/_base_/datasets/finetune_based/few_shot_coco.py +++ b/configs/detection/_base_/datasets/finetune_based/few_shot_coco.py @@ -31,8 +31,7 @@ test_pipeline = [ dict(type='Collect', keys=['img']) ]) ] -# Predefined ann_cfg, classes and class_splits are defined in -# mmfewshot.detection.datasets.few_shot_data_config +# classes splits are predefined in FewShotCocoDataset data_root = 'data/coco/' data = dict( samples_per_gpu=2, diff --git a/configs/detection/_base_/datasets/finetune_based/few_shot_voc.py b/configs/detection/_base_/datasets/finetune_based/few_shot_voc.py index b126c58..c80aad2 100644 --- a/configs/detection/_base_/datasets/finetune_based/few_shot_voc.py +++ b/configs/detection/_base_/datasets/finetune_based/few_shot_voc.py @@ -34,7 +34,7 @@ test_pipeline = [ dict(type='Collect', keys=['img']) ]) ] - +# classes splits are predefined in FewShotVOCDataset data_root = 'data/VOCdevkit/' data = dict( samples_per_gpu=2, diff --git a/configs/detection/_base_/datasets/nway_kshot/base_coco.py b/configs/detection/_base_/datasets/nway_kshot/base_coco.py index bf8fcad..0d05c6a 100644 --- a/configs/detection/_base_/datasets/nway_kshot/base_coco.py +++ b/configs/detection/_base_/datasets/nway_kshot/base_coco.py @@ -19,7 +19,7 @@ train_multi_pipelines = dict( dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True), dict(type='Normalize', **img_norm_cfg), - dict(type='ResizeWithMask', target_size=(224, 224)), + dict(type='GenerateMask', target_size=(224, 224)), dict(type='RandomFlip', flip_ratio=0.0), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) @@ -44,7 +44,6 @@ data_root = 'data/coco/' data = dict( samples_per_gpu=4, workers_per_gpu=2, - copy_random_support=True, train=dict( type='NwayKshotDataset', num_support_ways=60, @@ -87,7 +86,8 @@ data = dict( pipeline=test_pipeline, test_mode=True, classes='BASE_CLASSES'), - support_template=dict( + model_init=dict( + copy_from_train_dataset=True, samples_per_gpu=16, workers_per_gpu=1, type='FewShotCocoDataset', @@ -96,5 +96,5 @@ data = dict( pipeline=train_multi_pipelines['support'], 
instance_wise=True, classes='BASE_CLASSES', - dataset_name='support template')) + dataset_name='model_init')) evaluation = dict(interval=20000, metric='bbox', classwise=True) diff --git a/configs/detection/_base_/datasets/nway_kshot/base_voc.py b/configs/detection/_base_/datasets/nway_kshot/base_voc.py index bd6eb0f..27a2b70 100644 --- a/configs/detection/_base_/datasets/nway_kshot/base_voc.py +++ b/configs/detection/_base_/datasets/nway_kshot/base_voc.py @@ -19,7 +19,7 @@ train_multi_pipelines = dict( dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True), dict(type='Normalize', **img_norm_cfg), - dict(type='ResizeWithMask', target_size=(224, 224)), + dict(type='GenerateMask', target_size=(224, 224)), dict(type='RandomFlip', flip_ratio=0.0), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) @@ -44,7 +44,6 @@ data_root = 'data/VOCdevkit/' data = dict( samples_per_gpu=4, workers_per_gpu=2, - copy_random_support=True, train=dict( type='NwayKshotDataset', num_support_ways=15, @@ -91,7 +90,8 @@ data = dict( pipeline=test_pipeline, test_mode=True, classes=None), - support_template=dict( + model_init=dict( + copy_from_train_dataset=True, samples_per_gpu=16, workers_per_gpu=1, type='FewShotVOCDataset', @@ -101,5 +101,5 @@ data = dict( use_difficult=False, instance_wise=True, classes=None, - dataset_name='support template')) + dataset_name='model_init')) evaluation = dict(interval=5000, metric='mAP') diff --git a/configs/detection/_base_/datasets/nway_kshot/few_shot_coco.py b/configs/detection/_base_/datasets/nway_kshot/few_shot_coco.py index c2133d1..2bbee23 100644 --- a/configs/detection/_base_/datasets/nway_kshot/few_shot_coco.py +++ b/configs/detection/_base_/datasets/nway_kshot/few_shot_coco.py @@ -19,7 +19,7 @@ train_multi_pipelines = dict( dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True), dict(type='Normalize', **img_norm_cfg), - dict(type='ResizeWithMask', target_size=(224, 224)), + dict(type='GenerateMask', target_size=(224, 224)), dict(type='RandomFlip', flip_ratio=0.0), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) @@ -44,7 +44,6 @@ data_root = 'data/coco/' data = dict( samples_per_gpu=4, workers_per_gpu=2, - copy_random_support=True, train=dict( type='NwayKshotDataset', num_support_ways=80, @@ -87,7 +86,8 @@ data = dict( pipeline=test_pipeline, test_mode=True, classes='ALL_CLASSES'), - support_template=dict( + model_init=dict( + copy_from_train_dataset=True, samples_per_gpu=16, workers_per_gpu=1, type='FewShotCocoDataset', @@ -96,7 +96,7 @@ data = dict( pipeline=train_multi_pipelines['support'], instance_wise=True, classes='ALL_CLASSES', - dataset_name='support template')) + dataset_name='model_init')) evaluation = dict( interval=3000, metric='bbox', diff --git a/configs/detection/_base_/datasets/nway_kshot/few_shot_voc.py b/configs/detection/_base_/datasets/nway_kshot/few_shot_voc.py index 0097088..a40479d 100644 --- a/configs/detection/_base_/datasets/nway_kshot/few_shot_voc.py +++ b/configs/detection/_base_/datasets/nway_kshot/few_shot_voc.py @@ -19,7 +19,7 @@ train_multi_pipelines = dict( dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True), dict(type='Normalize', **img_norm_cfg), - dict(type='ResizeWithMask', target_size=(224, 224)), + dict(type='GenerateMask', target_size=(224, 224)), dict(type='RandomFlip', flip_ratio=0.0), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 
'gt_labels']) @@ -44,7 +44,6 @@ data_root = 'data/VOCdevkit/' data = dict( samples_per_gpu=2, workers_per_gpu=2, - copy_random_support=True, train=dict( type='NwayKshotDataset', num_support_ways=20, @@ -91,7 +90,8 @@ data = dict( pipeline=test_pipeline, test_mode=True, classes=None), - support_template=dict( + model_init=dict( + copy_from_train_dataset=True, samples_per_gpu=16, workers_per_gpu=1, type='FewShotVOCDataset', @@ -102,5 +102,5 @@ data = dict( instance_wise=True, num_novel_shots=None, classes=None, - dataset_name='support template')) + dataset_name='model_init')) evaluation = dict(interval=3000, metric='mAP', class_splits=None) diff --git a/configs/detection/_base_/datasets/query_aware/base_coco.py b/configs/detection/_base_/datasets/query_aware/base_coco.py index 5b8b2a6..bff23c9 100644 --- a/configs/detection/_base_/datasets/query_aware/base_coco.py +++ b/configs/detection/_base_/datasets/query_aware/base_coco.py @@ -20,8 +20,8 @@ train_multi_pipelines = dict( dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True), dict( - type='AttentionRPNCropResizeSupport', - context_pixel=16, + type='CropResizeInstance', + num_context_pixels=16, target_size=(320, 320)), dict(type='RandomFlip', flip_ratio=0.0), dict(type='Normalize', **img_norm_cfg), @@ -48,7 +48,6 @@ data_root = 'data/coco/' data = dict( samples_per_gpu=2, workers_per_gpu=2, - copy_random_support=False, train=dict( type='QueryAwareDataset', num_support_ways=2, @@ -90,7 +89,8 @@ data = dict( test_mode=True, classes='BASE_CLASSES'), # random sample 10 shot base instance to evaluate training - support_template=dict( + model_init=dict( + copy_from_train_dataset=False, samples_per_gpu=16, workers_per_gpu=1, type='FewShotCocoDataset', @@ -105,5 +105,5 @@ data = dict( num_base_shots=10, instance_wise=True, min_bbox_area_filter=32 * 32, - dataset_name='support template')) + dataset_name='model_init')) evaluation = dict(interval=20000, metric='bbox', classwise=True) diff --git a/configs/detection/_base_/datasets/query_aware/base_voc.py b/configs/detection/_base_/datasets/query_aware/base_voc.py index c73a331..70cb5d7 100644 --- a/configs/detection/_base_/datasets/query_aware/base_voc.py +++ b/configs/detection/_base_/datasets/query_aware/base_voc.py @@ -20,8 +20,8 @@ train_multi_pipelines = dict( dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True), dict( - type='AttentionRPNCropResizeSupport', - context_pixel=16, + type='CropResizeInstance', + num_context_pixels=16, target_size=(320, 320)), dict(type='RandomFlip', flip_ratio=0.0), dict(type='Normalize', **img_norm_cfg), @@ -48,7 +48,6 @@ data_root = 'data/VOCdevkit/' data = dict( samples_per_gpu=2, workers_per_gpu=2, - copy_random_support=False, train=dict( type='QueryAwareDataset', num_support_ways=2, @@ -97,7 +96,8 @@ data = dict( classes=None, ), # random sample 10 shot base instance to evaluate training - support_template=dict( + model_init=dict( + copy_from_train_dataset=False, samples_per_gpu=16, workers_per_gpu=1, type='FewShotVOCDataset', @@ -116,5 +116,5 @@ data = dict( instance_wise=True, classes=None, min_bbox_area_filter=32 * 32, - dataset_name='support template')) + dataset_name='model_init')) evaluation = dict(interval=20000, metric='mAP') diff --git a/configs/detection/_base_/datasets/query_aware/few_shot_coco.py b/configs/detection/_base_/datasets/query_aware/few_shot_coco.py index 4f7488b..7d2966a 100644 --- a/configs/detection/_base_/datasets/query_aware/few_shot_coco.py +++ 
b/configs/detection/_base_/datasets/query_aware/few_shot_coco.py @@ -20,8 +20,8 @@ train_multi_pipelines = dict( dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True), dict( - type='AttentionRPNCropResizeSupport', - context_pixel=16, + type='CropResizeInstance', + num_context_pixels=16, target_size=(320, 320)), dict(type='RandomFlip', flip_ratio=0.0), dict(type='Normalize', **img_norm_cfg), @@ -48,7 +48,6 @@ data_root = 'data/coco/' data = dict( samples_per_gpu=1, workers_per_gpu=2, - copy_random_support=True, train=dict( type='QueryAwareDataset', num_support_ways=None, @@ -90,7 +89,8 @@ data = dict( pipeline=test_pipeline, test_mode=True, classes=None), - support_template=dict( + model_init=dict( + copy_from_train_dataset=True, samples_per_gpu=16, workers_per_gpu=1, type='FewShotCocoDataset', diff --git a/configs/detection/_base_/datasets/query_aware/few_shot_voc.py b/configs/detection/_base_/datasets/query_aware/few_shot_voc.py index 3d55a75..61b254c 100644 --- a/configs/detection/_base_/datasets/query_aware/few_shot_voc.py +++ b/configs/detection/_base_/datasets/query_aware/few_shot_voc.py @@ -20,8 +20,8 @@ train_multi_pipelines = dict( dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True), dict( - type='AttentionRPNCropResizeSupport', - context_pixel=16, + type='CropResizeInstance', + num_context_pixels=16, target_size=(320, 320)), dict(type='RandomFlip', flip_ratio=0.0), dict(type='Normalize', **img_norm_cfg), @@ -48,7 +48,6 @@ data_root = 'data/VOCdevkit/' data = dict( samples_per_gpu=1, workers_per_gpu=2, - copy_random_support=True, train=dict( type='QueryAwareDataset', num_support_ways=None, @@ -94,7 +93,8 @@ data = dict( pipeline=test_pipeline, test_mode=True, classes=None), - support_template=dict( + model_init=dict( + copy_from_train_dataset=True, samples_per_gpu=16, workers_per_gpu=1, type='FewShotVOCDataset', @@ -106,5 +106,5 @@ data = dict( num_novel_shots=None, classes=None, min_bbox_area_filter=32 * 32, - dataset_name='support template')) + dataset_name='model_init')) evaluation = dict(interval=3000, metric='mAP', class_splits=None) diff --git a/mmfewshot/builders/dataset_builder.py b/mmfewshot/builders/dataset_builder.py index 9919d5f..1dcd751 100644 --- a/mmfewshot/builders/dataset_builder.py +++ b/mmfewshot/builders/dataset_builder.py @@ -1,4 +1,4 @@ -# this file only for unittests +# this file is only used for testing the model from mmcls.datasets.builder import build_dataloader as build_cls_dataloader from mmcls.datasets.builder import build_dataset as build_cls_dataset diff --git a/mmfewshot/builders/model_builder.py b/mmfewshot/builders/model_builder.py index ecafc2c..d80c637 100644 --- a/mmfewshot/builders/model_builder.py +++ b/mmfewshot/builders/model_builder.py @@ -1,4 +1,4 @@ -# this file only for unittests +# this file is only used for testing the model from mmcls.models.builder import build_classifier as build_cls_model from mmfewshot.detection.models import build_detector as build_det_model diff --git a/mmfewshot/detection/apis/__init__.py b/mmfewshot/detection/apis/__init__.py index c9191aa..ea58d86 100644 --- a/mmfewshot/detection/apis/__init__.py +++ b/mmfewshot/detection/apis/__init__.py @@ -1,8 +1,9 @@ -from .test import (multi_gpu_extract_support_template, - single_gpu_extract_support_template) +from .test import (multi_gpu_model_init, multi_gpu_test, single_gpu_model_init, + single_gpu_test) from .train import get_root_logger, set_random_seed, train_detector __all__ = [ 'get_root_logger', 'set_random_seed', 
'train_detector', - 'single_gpu_extract_support_template', 'multi_gpu_extract_support_template' + 'single_gpu_model_init', 'multi_gpu_model_init', 'single_gpu_test', + 'multi_gpu_test' ] diff --git a/mmfewshot/detection/apis/test.py b/mmfewshot/detection/apis/test.py new file mode 100644 index 0000000..9cbf3d1 --- /dev/null +++ b/mmfewshot/detection/apis/test.py @@ -0,0 +1,190 @@ +import os.path as osp +import time + +import mmcv +import torch +from mmcv.image import tensor2imgs +from mmcv.runner import get_dist_info +from mmdet.apis.test import collect_results_cpu, collect_results_gpu +from mmdet.core import encode_mask_results +from mmdet.utils import get_root_logger + + +def single_gpu_test(model, + data_loader, + show=False, + out_dir=None, + show_score_thr=0.3): + """Test model with single gpu for meta-learning based detector. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_dir (str or None): The dir to write the image. + Default: None. + show_score_thr (float, optional): Minimum score of bboxes to be shown. + Default: 0.3. + + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(mode='test', rescale=True, **data) + + batch_size = len(result) + if show or out_dir: + # make sure each time only one image to be shown + if batch_size == 1 and isinstance(data['img'][0], torch.Tensor): + img_tensor = data['img'][0] + else: + img_tensor = data['img'][0].data[0] + img_metas = data['img_metas'][0].data[0] + imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) + assert len(imgs) == len(img_metas) + + for i, (img, img_meta) in enumerate(zip(imgs, img_metas)): + h, w, _ = img_meta['img_shape'] + img_show = img[:h, :w, :] + + ori_h, ori_w = img_meta['ori_shape'][:-1] + img_show = mmcv.imresize(img_show, (ori_w, ori_h)) + + if out_dir: + out_file = osp.join(out_dir, img_meta['ori_filename']) + else: + out_file = None + + model.module.show_result( + img_show, + result[i], + show=show, + out_file=out_file, + score_thr=show_score_thr) + + # encode mask results + if isinstance(result[0], tuple): + result = [(bbox_results, encode_mask_results(mask_results)) + for bbox_results, mask_results in result] + results.extend(result) + + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): + """Test model with multiple gpus for meta-learning based detector. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' + it encodes results to gpu tensors and use gpu communication for results + collection. On cpu mode it saves the results on different gpus to 'tmpdir' + and collects them by the rank 0 worker. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + + Returns: + list: The prediction results. 
+ """ + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = mmcv.ProgressBar(len(dataset)) + time.sleep(2) # This line can prevent deadlock problem in some cases. + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(mode='test', rescale=True, **data) + # encode mask results + if isinstance(result[0], tuple): + result = [(bbox_results, encode_mask_results(mask_results)) + for bbox_results, mask_results in result] + results.extend(result) + + if rank == 0: + batch_size = len(result) + for _ in range(batch_size * world_size): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def single_gpu_model_init(model, data_loader): + """Extracting support template features for meta-learning methods in query- + support fashion with single gpu. + + Args: + model (nn.Module): Model used for extract support template features. + data_loader (nn.Dataloader): Pytorch data loader. + + Returns: + list[Tensor]: Extracted support template features. + """ + model.eval() + results = [] + dataset = data_loader.dataset + logger = get_root_logger() + logger.info('starting model initialization...') + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(mode='model_init', **data) + results.append(result) + for _ in range(len(data['img_metas'].data[0])): + prog_bar.update() + model.module.model_init() + logger.info('model initialization done.') + + return results + + +def multi_gpu_model_init(model, data_loader): + """Extracting support template features for meta-learning methods in query- + support fashion with multi gpus. + + Args: + model (nn.Module): Model used for extract support template features. + data_loader (nn.Dataloader): Pytorch data loader. + + Returns: + list[Tensor]: Extracted support template features. + """ + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + logger = get_root_logger() + logger.info('starting model initialization...') + prog_bar = mmcv.ProgressBar(len(dataset)) + time.sleep(2) # This line can prevent deadlock problem in some cases. + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(mode='model_init', **data) + results.append(result) + if rank == 0: + for _ in range(len(data['img_metas'].data[0])): + prog_bar.update() + model.module.model_init() + if rank == 0: + logger.info('model initialization done.') + return results diff --git a/mmfewshot/detection/apis/train.py b/mmfewshot/detection/apis/train.py index 9da0e94..16730a7 100644 --- a/mmfewshot/detection/apis/train.py +++ b/mmfewshot/detection/apis/train.py @@ -12,6 +12,8 @@ from mmdet.core import DistEvalHook, EvalHook from mmdet.datasets import replace_ImageToTensor from mmdet.utils import get_root_logger +from mmfewshot.detection.core import (QuerySupportDistEvalHook, + QuerySupportEvalHook) from mmfewshot.detection.datasets import build_dataloader, build_dataset @@ -145,8 +147,51 @@ def train_detector(model, shuffle=False) eval_cfg = cfg.get('evaluation', {}) eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' - eval_hook = DistEvalHook if distributed else EvalHook - runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) + + # Prepare dataset for model initialization. 
In most cases, + # the dataset would be used to generate the support templates. + if cfg.data.get('model_init', None) is not None: + + if cfg.data.model_init.pop('copy_from_train_dataset', False): + if cfg.data.model_init.ann_cfg is not None: + warnings.warn( + 'model_init dataset will copy support ' + 'dataset used for training and original ' + 'ann_cfg will be discarded', UserWarning) + # modify dataset type to support copying data_infos operation + if cfg.data.model_init.type == 'FewShotVOCDataset': + cfg.data.model_init.type = 'FewShotVOCCopyDataset' + elif cfg.data.model_init.type == 'FewShotCocoDataset': + cfg.data.model_init.type = 'FewShotCocoCopyDataset' + else: + raise TypeError(f'{cfg.data.model_init.type} ' + f'not support copy data_infos operation.') + if not hasattr(dataset[0], 'get_support_data_infos'): + raise NotImplementedError( + f'`get_support_data_infos` is not implemented ' + f'in {dataset[0].__class__.__name__}.') + cfg.data.model_init.ann_cfg = [ + dict(data_infos=dataset[0].get_support_data_infos()) + ] + samples_per_gpu = cfg.data.model_init.pop('samples_per_gpu', 1) + workers_per_gpu = cfg.data.model_init.pop('workers_per_gpu', 1) + model_init_dataset = build_dataset(cfg.data.model_init) + # disable `dist` to make all gpu get same data + model_init_dataloader = build_dataloader( + model_init_dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + dist=False, + shuffle=False) + + # eval hook for meta-learning based query-support detector + eval_hook = QuerySupportDistEvalHook \ + if distributed else QuerySupportEvalHook + runner.register_hook( + eval_hook(model_init_dataloader, val_dataloader, **eval_cfg)) + else: + eval_hook = DistEvalHook if distributed else EvalHook + runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) # user-defined hooks if cfg.get('custom_hooks', None): diff --git a/mmfewshot/detection/core/evaluation/eval_hooks.py b/mmfewshot/detection/core/evaluation/eval_hooks.py new file mode 100644 index 0000000..a0869be --- /dev/null +++ b/mmfewshot/detection/core/evaluation/eval_hooks.py @@ -0,0 +1,98 @@ +import os.path as osp + +import torch.distributed as dist +from mmcv.runner import DistEvalHook as BaseDistEvalHook +from mmcv.runner import EvalHook as BaseEvalHook +from torch.nn.modules.batchnorm import _BatchNorm + + +class QuerySupportEvalHook(BaseEvalHook): + """Evaluation hook for query support data pipeline, this hook will first + traverse `model_init_dataloader` to extract support features for model + initialization and then evaluate the data from `val_dataloader`. + + Args: + model_init_dataloader (nn.DataLoader): A PyTorch dataloader of + `model_init` dataset. + val_dataloader (nn.DataLoader): A PyTorch dataloader of dataset to be + evaluated. + **eval_kwargs: Evaluation arguments fed into the evaluate function of + the dataset. 
+ """ + + def __init__(self, model_init_dataloader, val_dataloader, **eval_kwargs): + super().__init__(val_dataloader, **eval_kwargs) + self.model_init_dataloader = model_init_dataloader + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + if not self._should_evaluate(runner): + return + # extract support template features + from mmfewshot.detection.apis import \ + single_gpu_model_init, single_gpu_test + single_gpu_model_init(runner.model, self.model_init_dataloader) + results = single_gpu_test(runner.model, self.dataloader, show=False) + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + if self.save_best: + self._save_ckpt(runner, key_score) + + +class QuerySupportDistEvalHook(BaseDistEvalHook): + """Distributed evaluation hook for query support data pipeline, this hook + will first traverse `model_init_dataloader` to extract support features for + model initialization and then evaluate the data from `val_dataloader`. + + Args: + model_init_dataloader (nn.DataLoader): A PyTorch dataloader of + `model_init` dataset. + val_dataloader (nn.DataLoader): A PyTorch dataloader of dataset to be + evaluated. + **eval_kwargs: Evaluation arguments fed into the evaluate function of + the dataset. + """ + + def __init__(self, model_init_dataloader, val_dataloader, **eval_kwargs): + super().__init__(val_dataloader, **eval_kwargs) + self.model_init_dataloader = model_init_dataloader + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + # Synchronization of BatchNorm's buffer (running_mean + # and running_var) is not supported in the DDP of pytorch, + # which may cause the inconsistent performance of models in + # different ranks, so we broadcast BatchNorm's buffers + # of rank 0 to other ranks to avoid this. 
+ if self.broadcast_bn_buffer: + model = runner.model + for name, module in model.named_modules(): + if isinstance(module, + _BatchNorm) and module.track_running_stats: + dist.broadcast(module.running_var, 0) + dist.broadcast(module.running_mean, 0) + + if not self._should_evaluate(runner): + return + + tmpdir = self.tmpdir + if tmpdir is None: + tmpdir = osp.join(runner.work_dir, '.eval_hook') + + # extract support template features + from mmfewshot.detection.apis import \ + multi_gpu_model_init, multi_gpu_test + multi_gpu_model_init(runner.model, self.model_init_dataloader) + + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=tmpdir, + gpu_collect=self.gpu_collect) + if runner.rank == 0: + print('\n') + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + + if self.save_best: + self._save_ckpt(runner, key_score) diff --git a/mmfewshot/detection/datasets/__init__.py b/mmfewshot/detection/datasets/__init__.py index 1266e60..d1bdbce 100644 --- a/mmfewshot/detection/datasets/__init__.py +++ b/mmfewshot/detection/datasets/__init__.py @@ -3,14 +3,14 @@ from .coco import FewShotCocoDataset from .dataloader_wrappers import NwayKshotDataloader from .dataset_wrappers import NwayKshotDataset, QueryAwareDataset from .few_shot_custom import FewShotCustomDataset -from .pipelines import AttentionRPNCropResizeSupport, ResizeWithMask -from .utils import NumpyEncoder, Visualizer, query_support_collate_fn +from .pipelines import CropResizeInstance, GenerateMask +from .utils import NumpyEncoder, query_support_collate_fn +from .visualize import Visualizer from .voc import FewShotVOCDataset __all__ = [ 'build_dataloader', 'build_dataset', 'QueryAwareDataset', 'NwayKshotDataset', 'NwayKshotDataloader', 'query_support_collate_fn', 'FewShotCustomDataset', 'FewShotVOCDataset', 'FewShotCocoDataset', - 'AttentionRPNCropResizeSupport', 'ResizeWithMask', 'NumpyEncoder', - 'Visualizer' + 'CropResizeInstance', 'GenerateMask', 'NumpyEncoder', 'Visualizer' ] diff --git a/mmfewshot/detection/datasets/coco.py b/mmfewshot/detection/datasets/coco.py index 93eed9e..fe18854 100644 --- a/mmfewshot/detection/datasets/coco.py +++ b/mmfewshot/detection/datasets/coco.py @@ -706,7 +706,7 @@ class FewShotCocoDataset(FewShotCustomDataset): @DATASETS.register_module() class FewShotCocoCopyDataset(FewShotCocoDataset): - """For some meta learning method, the random sampled sampled support data + """For some meta-learning method, the random sampled sampled support data is required for evaluation. FewShotVOCCopyDataset allow copy diff --git a/mmfewshot/detection/datasets/dataset_wrappers.py b/mmfewshot/detection/datasets/dataset_wrappers.py index b2a4193..635f789 100644 --- a/mmfewshot/detection/datasets/dataset_wrappers.py +++ b/mmfewshot/detection/datasets/dataset_wrappers.py @@ -231,7 +231,7 @@ class QueryAwareDataset(object): def get_support_data_infos(self): """Return data_infos of support dataset.""" - return self.support_dataset.data_infos + return copy.deepcopy(self.support_dataset.data_infos) @DATASETS.register_module() @@ -385,7 +385,7 @@ class NwayKshotDataset(object): dataset_len: Length of pre sample batch indexes. Returns: - List[List[(data_idx, gt_idx)]]: Pre sample batch indexes. + list[list[(data_idx, gt_idx)]]: Pre sample batch indexes. 
""" total_index = [] for _ in range(dataset_len): @@ -429,11 +429,11 @@ class NwayKshotDataset(object): def get_support_data_infos(self): """Get support data infos from batch index.""" - return [ + return copy.deepcopy([ self._get_shot_data_info(idx, gt_idx) for class_name in self.data_infos_by_class.keys() for (idx, gt_idx) in self.data_infos_by_class[class_name] - ] + ]) def _get_shot_data_info(self, idx, gt_idx): """Get data info by idx and gt idx.""" diff --git a/mmfewshot/detection/datasets/pipelines/__init__.py b/mmfewshot/detection/datasets/pipelines/__init__.py new file mode 100644 index 0000000..b2be64d --- /dev/null +++ b/mmfewshot/detection/datasets/pipelines/__init__.py @@ -0,0 +1,3 @@ +from .transforms import CropResizeInstance, GenerateMask + +__all__ = ['CropResizeInstance', 'GenerateMask'] diff --git a/mmfewshot/detection/datasets/pipelines/transforms.py b/mmfewshot/detection/datasets/pipelines/transforms.py new file mode 100644 index 0000000..cce2689 --- /dev/null +++ b/mmfewshot/detection/datasets/pipelines/transforms.py @@ -0,0 +1,230 @@ +import math + +import mmcv +import numpy as np +from mmdet.datasets import PIPELINES + +# TODO: Simplify pipelines by decoupling operation. + + +@PIPELINES.register_module() +class CropResizeInstance(object): + """Crop and resize instance according to bbox form image. + + Args: + num_context_pixels (int): Padding pixel around instance. Default: 16. + target_size (tuple[int, int]): Resize cropped instance to target size. + Default: (320, 320). + """ + + def __init__(self, num_context_pixels=16, target_size=(320, 320)): + assert isinstance(num_context_pixels, int) + assert len(target_size) == 2, 'target_size' + self.num_context_pixels = num_context_pixels + self.target_size = target_size + + def __call__(self, results): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Cropped and resized instance results. 
+ """ + img = results['img'] + gt_bbox = results['gt_bboxes'] + img_h, img_w = img.shape[:2] # h, w + x1, y1, x2, y2 = list(map(int, gt_bbox.tolist()[0])) + + bbox_w = x2 - x1 + bbox_h = y2 - y1 + t_x1, t_y1, t_x2, t_y2 = 0, 0, bbox_w, bbox_h + + if bbox_w >= bbox_h: + crop_x1 = x1 - self.num_context_pixels + crop_x2 = x2 + self.num_context_pixels + # t_x1 and t_x2 will change when crop context or overflow + t_x1 = t_x1 + self.num_context_pixels + t_x2 = t_x1 + bbox_w + if crop_x1 < 0: + t_x1 = t_x1 + crop_x1 + t_x2 = t_x1 + bbox_w + crop_x1 = 0 + if crop_x2 > img_w: + crop_x2 = img_w + + short_size = bbox_h + long_size = crop_x2 - crop_x1 + y_center = int((y2 + y1) / 2) # math.ceil((y2 + y1) / 2) + crop_y1 = int( + y_center - + (long_size / 2)) # int(y_center - math.ceil(long_size / 2)) + crop_y2 = int( + y_center + + (long_size / 2)) # int(y_center + math.floor(long_size / 2)) + + # t_y1 and t_y2 will change when crop context or overflow + t_y1 = t_y1 + math.ceil((long_size - short_size) / 2) + t_y2 = t_y1 + bbox_h + + if crop_y1 < 0: + t_y1 = t_y1 + crop_y1 + t_y2 = t_y1 + bbox_h + crop_y1 = 0 + if crop_y2 > img_h: + crop_y2 = img_h + + crop_short_size = crop_y2 - crop_y1 + crop_long_size = crop_x2 - crop_x1 + + square = np.zeros((crop_long_size, crop_long_size, 3), + dtype=np.uint8) + delta = int( + (crop_long_size - crop_short_size) / + 2) # int(math.ceil((crop_long_size - crop_short_size) / 2)) + square_y1 = delta + square_y2 = delta + crop_short_size + + t_y1 = t_y1 + delta + t_y2 = t_y2 + delta + + crop_box = img[crop_y1:crop_y2, crop_x1:crop_x2, :] + square[square_y1:square_y2, :, :] = crop_box + else: + crop_y1 = y1 - self.num_context_pixels + crop_y2 = y2 + self.num_context_pixels + + # t_y1 and t_y2 will change when crop context or overflow + t_y1 = t_y1 + self.num_context_pixels + t_y2 = t_y1 + bbox_h + if crop_y1 < 0: + t_y1 = t_y1 + crop_y1 + t_y2 = t_y1 + bbox_h + crop_y1 = 0 + if crop_y2 > img_h: + crop_y2 = img_h + + short_size = bbox_w + long_size = crop_y2 - crop_y1 + x_center = int((x2 + x1) / 2) # math.ceil((x2 + x1) / 2) + crop_x1 = int( + x_center - + (long_size / 2)) # int(x_center - math.ceil(long_size / 2)) + crop_x2 = int( + x_center + + (long_size / 2)) # int(x_center + math.floor(long_size / 2)) + + # t_x1 and t_x2 will change when crop context or overflow + t_x1 = t_x1 + math.ceil((long_size - short_size) / 2) + t_x2 = t_x1 + bbox_w + if crop_x1 < 0: + t_x1 = t_x1 + crop_x1 + t_x2 = t_x1 + bbox_w + crop_x1 = 0 + if crop_x2 > img_w: + crop_x2 = img_w + + crop_short_size = crop_x2 - crop_x1 + crop_long_size = crop_y2 - crop_y1 + square = np.zeros((crop_long_size, crop_long_size, 3), + dtype=np.uint8) + delta = int( + (crop_long_size - crop_short_size) / + 2) # int(math.ceil((crop_long_size - crop_short_size) / 2)) + square_x1 = delta + square_x2 = delta + crop_short_size + + t_x1 = t_x1 + delta + t_x2 = t_x2 + delta + crop_box = img[crop_y1:crop_y2, crop_x1:crop_x2, :] + square[:, square_x1:square_x2, :] = crop_box + + square = square.astype(np.float32, copy=False) + square, square_scale = mmcv.imrescale( + square, self.target_size, return_scale=True, backend='cv2') + + square = square.astype(np.uint8) + + t_x1 = int(t_x1 * square_scale) + t_y1 = int(t_y1 * square_scale) + t_x2 = int(t_x2 * square_scale) + t_y2 = int(t_y2 * square_scale) + results['img'] = square + results['img_shape'] = img.shape + results['gt_bboxes'] = np.array([t_x1, t_y1, t_x2, + t_y2]).astype(np.float32) + + return results + + def __repr__(self): + return self.__class__.__name__ + \ + 
f'(num_context_pixels={self.num_context_pixels},' \ + f' target_size={self.target_size})' + + +@PIPELINES.register_module() +class GenerateMask(object): + """Resize support image and generate a mask. + + Args: + target_size (tuple[int, int]): Resize the support image to the target + size. Default: (224, 224). + """ + + def __init__(self, target_size=(224, 224)): + self.target_size = target_size + + def _resize_bboxes(self, results): + """Resize bounding boxes with ``results['scale_factor']``.""" + for key in results.get('bbox_fields', []): + bboxes = results[key] * results['scale_factor'] + results[key] = bboxes + + def _resize_img(self, results): + """Resize images to ``self.target_size``.""" + for key in results.get('img_fields', ['img']): + img, w_scale, h_scale = mmcv.imresize( + results[key], + self.target_size, + return_scale=True, + backend='cv2') + results[key] = img + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img_shape'] = img.shape + # in case that there is no padding + results['pad_shape'] = img.shape + results['scale_factor'] = scale_factor + + def _generate_mask(self, results): + mask = np.zeros(self.target_size, dtype=np.float32) + gt_bboxes = results['gt_bboxes'][0] + mask[int(gt_bboxes[1]):int(gt_bboxes[3]), + int(gt_bboxes[0]):int(gt_bboxes[2])] = 1 + results['img'] = np.concatenate( + [results['img'], np.expand_dims(mask, axis=2)], axis=2) + + return results + + def __call__(self, results): + """Call function to resize the support image and generate an instance + mask. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized image with an additional bbox mask channel. + """ + self._resize_img(results) + self._resize_bboxes(results) + self._generate_mask(results) + + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(target_size={self.target_size})' diff --git a/mmfewshot/detection/datasets/voc.py b/mmfewshot/detection/datasets/voc.py index 1d54077..e8ad076 100644 --- a/mmfewshot/detection/datasets/voc.py +++ b/mmfewshot/detection/datasets/voc.py @@ -473,7 +473,7 @@ class FewShotVOCDataset(FewShotCustomDataset): @DATASETS.register_module() class FewShotVOCCopyDataset(FewShotVOCDataset): - """For some meta learning method, the random sampled sampled support data + """For some meta-learning methods, the randomly sampled support data is required for evaluation. 
FewShotVOCCopyDataset allow copy diff --git a/mmfewshot/detection/models/__init__.py b/mmfewshot/detection/models/__init__.py new file mode 100644 index 0000000..5913419 --- /dev/null +++ b/mmfewshot/detection/models/__init__.py @@ -0,0 +1,8 @@ +from mmdet.models.builder import * # noqa: F401,F403 + +from .backbones import * # noqa: F401,F403 +from .dense_heads import * # noqa: F401,F403 +from .detectors import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .roi_heads import * # noqa: F401,F403 diff --git a/mmfewshot/detection/models/detectors/__init__.py b/mmfewshot/detection/models/detectors/__init__.py new file mode 100644 index 0000000..6d73524 --- /dev/null +++ b/mmfewshot/detection/models/detectors/__init__.py @@ -0,0 +1,5 @@ +from .attention_rpn import AttentionRPN +from .base_query_support import BaseQuerySupportDetector +from .fsdetview import FsDetView + +__all__ = ['BaseQuerySupportDetector', 'AttentionRPN', 'FsDetView'] diff --git a/mmfewshot/detection/models/detectors/base_query_support.py b/mmfewshot/detection/models/detectors/base_query_support.py new file mode 100644 index 0000000..391a89f --- /dev/null +++ b/mmfewshot/detection/models/detectors/base_query_support.py @@ -0,0 +1,322 @@ +import copy +from abc import abstractmethod + +from mmcv.runner import auto_fp16 +from mmdet.models.builder import (DETECTORS, build_backbone, build_head, + build_neck) +from mmdet.models.detectors import BaseDetector + + +@DETECTORS.register_module() +class BaseQuerySupportDetector(BaseDetector): + """Base class for two-stage detectors in query-support fashion. Query- + support detectors typically consist of a region proposal network and a + task-specific regression head. There are two data pipelines (query and + support data) for query-support detectors. + + Args: + backbone (dict): Config of the backbone for query data. + neck (dict | None): Config of the neck for query data and + probably for support data. Default: None. + support_backbone (dict | None): Config of the backbone for + support data only. If None, support and query data will + share the same backbone. Default: None. + support_neck (dict | None): Config of the neck for support + data only. Default: None. + rpn_head (dict | None): Config of rpn_head. Default: None. + roi_head (dict | None): Config of roi_head. Default: None. + train_cfg (dict | None): Training config of the detector. + Default: None. + test_cfg (dict | None): Testing config of the detector. Default: None. + pretrained (str, optional): Path to the pretrained model. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + backbone, + neck=None, + support_backbone=None, + support_neck=None, + rpn_head=None, + roi_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(BaseQuerySupportDetector, self).__init__(init_cfg) + backbone.pretrained = pretrained + self.backbone = build_backbone(backbone) + self.neck = build_neck(neck) if neck is not None else None + # if `support_backbone` is None, the support and query pipelines will + # share the same backbone. + self.support_backbone = build_backbone(support_backbone) \ + if support_backbone is not None else self.backbone + # the support neck only forwards support data. 
+ self.support_neck = build_neck(support_neck) \ + if support_neck is not None else None + assert roi_head is not None, 'missing config of roi_head' + # when rpn with aggregation neck, the input of rpn will consist of + # query and support data. otherwise the input of rpn only + # has query data. + self.with_rpn = False + self.rpn_with_support = False + if rpn_head is not None: + self.with_rpn = True + if rpn_head.get('aggregation_neck', None) is not None: + self.rpn_with_support = True + rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None + rpn_head_ = copy.deepcopy(rpn_head) + rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn) + self.rpn_head = build_head(rpn_head_) + + if roi_head is not None: + # update train and test cfg here for now + rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None + roi_head.update(train_cfg=rcnn_train_cfg) + roi_head.update(test_cfg=test_cfg.rcnn) + roi_head.pretrained = pretrained + self.roi_head = build_head(roi_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def extract_query_feat(self, img): + """Extract features of query data. + + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + + Returns: + list[Tensor]: Features of query images. + """ + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def extract_feat(self, img): + """Extract features of query data. + + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + + Returns: + list[Tensor]: Features of query images. + """ + return self.extract_query_feat(img) + + @abstractmethod + def extract_support_feat(self, img, gt_bboxes=None): + """Extract features of support data.""" + raise NotImplementedError + + @auto_fp16(apply_to=('img', )) + def forward(self, + query_data=None, + support_data=None, + img=None, + img_metas=None, + mode='train', + **kwargs): + """Calls one of (:func:`forward_train`, :func:`forward_test` and + :func:`forward_model_init`) depending on which `mode`. Note this + setting will change the expected inputs of the corresponding function. + + - When `mode` is 'train', the input will be query and support data + for training. + + - When `mode` is 'model_init', the input will be support template + data at least including (img, img_metas). + + - When `mode` is 'test', the input will be test data at least + including (img, img_metas). + + Args: + query_data (dict): Used for :func:`forward_train`. Dict of + query data and data info where each dict has: `img`, + `img_metas`, `gt_bboxes`, `gt_labels`, `gt_bboxes_ignore`. + Default: None. + support_data (dict): Used for :func:`forward_train`. Dict of + support data and data info dict where each dict has: `img`, + `img_metas`, `gt_bboxes`, `gt_labels`, `gt_bboxes_ignore`. + Default: None. + img (list[Tensor]): Used for func:`forward_test` or + :func:`forward_model_init`. List of tensors of shape + (1, C, H, W). Typically these should be mean centered + and std scaled. Default: None. + img_metas (list[dict]): Used for func:`forward_test` or + :func:`forward_model_init`. List of image info dict + where each dict has: `img_shape`, `scale_factor`, `flip`, + and may also contain `filename`, `ori_shape`, `pad_shape`, + and `img_norm_cfg`. For details on the values of these keys, + see :class:`mmdet.datasets.pipelines.Collect`. Default: None. + mode (str): Indicate which function to call. Options are 'train', + 'model_init' and 'test'. 
Default: 'train'. + """ + if mode == 'train': + return self.forward_train(query_data, support_data, **kwargs) + elif mode == 'model_init': + return self.forward_model_init(img, img_metas, **kwargs) + elif mode == 'test': + return self.forward_test(img, img_metas, **kwargs) + else: + raise ValueError(f'invalid forward mode {mode}.') + + def train_step(self, data, optimizer): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. For most of query-support detectors, the + batch size denote the batch size of query data. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + + - ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + - ``log_vars`` contains all the variables to be sent to the + logger. + - ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + # For most of query-support detectors, the batch size denote the + # batch size of query data. + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(data['query_data']['img_metas'])) + + return outputs + + def val_step(self, data, optimizer=None): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + # For most of query-support detectors, the batch size denote the + # batch size of query data. + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(data['query_data']['img_metas'])) + + return outputs + + def forward_train(self, + query_data, + support_data, + proposals=None, + **kwargs): + """ + Args: + query_data (dict): In most cases, dict of query data contains: + `img`, `img_metas`, `gt_bboxes`, `gt_labels`, + `gt_bboxes_ignore`. + support_data (dict): In most cases, dict of support data contains: + `img`, `img_metas`, `gt_bboxes`, `gt_labels`, + `gt_bboxes_ignore`. + proposals (list): Override rpn proposals with custom proposals. + Use when `with_rpn` is False. Default: None. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + query_img = query_data['img'] + support_img = support_data['img'] + query_x = self.extract_query_feat(query_img) + support_x = self.extract_support_feat(support_img, + support_data['gt_bboxes']) + + losses = dict() + + # RPN forward and loss + if self.with_rpn: + proposal_cfg = self.train_cfg.get('rpn_proposal', + self.test_cfg.rpn) + if self.rpn_with_support: + x = (query_x, support_x) + else: + x = query_x + rpn_losses, proposal_list = self.rpn_head.forward_train( + x, + copy.deepcopy(query_data['img_metas']), + copy.deepcopy(query_data['gt_bboxes']), + gt_labels=None, + gt_bboxes_ignore=copy.deepcopy( + query_data.get('gt_bboxes_ignore', None)), + proposal_cfg=proposal_cfg) + losses.update(rpn_losses) + else: + proposal_list = proposals + + # ROI head forward and loss + x = (query_x, support_x) + img_metas = (query_data['img_metas'], support_data['img_metas']) + gt_bboxes = (query_data['gt_bboxes'], support_data['gt_bboxes']) + gt_labels = (query_data['gt_labels'], support_data['gt_labels']) + gt_bboxes_ignore = (query_data.get('gt_bboxes_ignore', None), + support_data.get('gt_bboxes_ignore', None)) + + roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list, + gt_bboxes, gt_labels, + gt_bboxes_ignore, **kwargs) + losses.update(roi_losses) + + return losses + + def forward_dummy(self, img): + """Used for computing network flops. + + See `mmdetection/tools/analysis_tools/get_flops.py` + """ + raise NotImplementedError + + def simple_test(self, img, img_metas, proposals=None, rescale=False): + """Test without augmentation.""" + raise NotImplementedError + + async def async_simple_test(self, **kwargs): + """Async test without augmentation.""" + raise NotImplementedError + + def aug_test(self, **kwargs): + """Test with augmentation.""" + raise NotImplementedError + + @abstractmethod + def forward_model_init(self, + img, + img_metas, + gt_bboxes=None, + gt_labels=None, + **kwargs): + """extract and save support features for model initialization.""" + raise NotImplementedError + + @abstractmethod + def model_init(self, **kwargs): + """process the saved support features for model initialization.""" + raise NotImplementedError