From 9221499af4439cebac868952fd89b2366fd1b637 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?=
 <huanghaian@pjlab.org.cn>
Date: Fri, 25 Nov 2022 16:45:40 +0800
Subject: [PATCH] [Feature] Support grad-based cam and grad-free cam  (#234)

* support cam

* update

* update

* done

* fix lint

* add docstr

* add doc

* update

* update

* FIX
---
 demo/boxam_vis_demo.py                   | 276 +++++++++++++
 demo/featmap_vis_demo.py                 |   6 +-
 docs/en/user_guides/visualization.md     | 120 +++++-
 docs/zh_cn/user_guides/visualization.md  | 118 +++++-
 mmyolo/models/dense_heads/yolov6_head.py |  69 ++--
 mmyolo/utils/boxam_utils.py              | 504 +++++++++++++++++++++++
 6 files changed, 1060 insertions(+), 33 deletions(-)
 create mode 100644 demo/boxam_vis_demo.py
 create mode 100644 mmyolo/utils/boxam_utils.py

diff --git a/demo/boxam_vis_demo.py b/demo/boxam_vis_demo.py
new file mode 100644
index 00000000..3672b727
--- /dev/null
+++ b/demo/boxam_vis_demo.py
@@ -0,0 +1,276 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This script is in the experimental verification stage and cannot be
+guaranteed to be completely correct. Currently Grad-based CAM and Grad-free CAM
+are supported.
+
+The target detection task is different from the classification task. It not
+only includes the AM map of the category, but also includes information such as
+bbox and mask, so this script is named bboxam.
+"""
+
+import argparse
+import os.path
+import warnings
+from functools import partial
+
+import cv2
+import mmcv
+from mmengine import Config, DictAction, MessageHub
+from mmengine.utils import ProgressBar
+
+from mmyolo.utils import register_all_modules
+from mmyolo.utils.boxam_utils import (BoxAMDetectorVisualizer,
+                                      BoxAMDetectorWrapper, DetAblationLayer,
+                                      DetBoxScoreTarget, GradCAM,
+                                      GradCAMPlusPlus, reshape_transform)
+from mmyolo.utils.misc import get_file_list
+
+try:
+    from pytorch_grad_cam import AblationCAM, EigenCAM
+except ImportError:
+    raise ImportError('Please run `pip install "grad-cam"` to install '
+                      'pytorch_grad_cam package.')
+
+GRAD_FREE_METHOD_MAP = {
+    'ablationcam': AblationCAM,
+    'eigencam': EigenCAM,
+    # 'scorecam': ScoreCAM, # consumes too much memory
+}
+
+GRAD_BASED_METHOD_MAP = {'gradcam': GradCAM, 'gradcam++': GradCAMPlusPlus}
+
+ALL_SUPPORT_METHODS = list(GRAD_FREE_METHOD_MAP.keys()
+                           | GRAD_BASED_METHOD_MAP.keys())
+
+IGNORE_LOSS_PARAMS = {
+    'yolov5': ['loss_obj'],
+    'yolov6': ['loss_cls'],
+    'yolox': ['loss_obj'],
+    'rtmdet': ['loss_cls'],
+}
+
+# The runtime `epoch` info is required by some algorithms
+# when calculating the loss
+message_hub = MessageHub.get_current_instance()
+message_hub.runtime_info['epoch'] = 0
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Visualize Box AM')
+    parser.add_argument(
+        'img', help='Image path, including an image file, a dir or a URL.')
+    parser.add_argument('config', help='Config file')
+    parser.add_argument('checkpoint', help='Checkpoint file')
+    parser.add_argument(
+        '--method',
+        default='gradcam',
+        choices=ALL_SUPPORT_METHODS,
+        help='Type of method to use, supports '
+        f'{", ".join(ALL_SUPPORT_METHODS)}.')
+    parser.add_argument(
+        '--target-layers',
+        default=['neck.out_layers[2]'],
+        nargs='+',
+        type=str,
+        help='The target layers to get Box AM. If not set, the tool will '
+        'use neck.out_layers[2] by default.')
+    parser.add_argument(
+        '--out-dir', default='./output', help='Path to output file')
+    parser.add_argument(
+        '--show', action='store_true', help='Show the CAM results')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
+    parser.add_argument(
+        '--topk',
+        type=int,
+        default=-1,
+        help='Select the top-k prediction results to show. -1 means all.')
+    parser.add_argument(
+        '--max-shape',
+        nargs='+',
+        type=int,
+        default=-1,
+        help='Max shape of the activation map. Its purpose is to save GPU '
+        'memory: the activation map is rescaled to this shape before being '
+        'evaluated. If set to -1, no rescaling is performed.')
+    parser.add_argument(
+        '--preview-model',
+        default=False,
+        action='store_true',
+        help='To preview all the model layers')
+    parser.add_argument(
+        '--norm-in-bbox', action='store_true', help='Norm in bbox of am image')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    # Only used by AblationCAM
+    parser.add_argument(
+        '--batch-size',
+        type=int,
+        default=1,
+        help='Batch size used during AblationCAM inference')
+    parser.add_argument(
+        '--ratio-channels-to-ablate',
+        type=float,
+        default=0.5,
+        help='Speed up AblationCAM by only ablating a ratio of the channels. '
+        'The parameter controls how many channels should be ablated.')
+
+    args = parser.parse_args()
+    return args
+
+
+def init_detector_and_visualizer(args, cfg):
+    max_shape = args.max_shape
+    if not isinstance(max_shape, list):
+        max_shape = [args.max_shape]
+    assert len(max_shape) == 1 or len(max_shape) == 2
+
+    model_wrapper = BoxAMDetectorWrapper(
+        cfg, args.checkpoint, args.score_thr, device=args.device)
+
+    if args.preview_model:
+        print(model_wrapper.detector)
+        print('\n Please remove `--preview-model` to get the BoxAM.')
+        return None, None
+
+    target_layers = []
+    for target_layer in args.target_layers:
+        try:
+            target_layers.append(
+                eval(f'model_wrapper.detector.{target_layer}'))
+        except Exception as e:
+            print(model_wrapper.detector)
+            raise RuntimeError('layer does not exist', e)
+
+    ablationcam_extra_params = {
+        'batch_size': args.batch_size,
+        'ablation_layer': DetAblationLayer(),
+        'ratio_channels_to_ablate': args.ratio_channels_to_ablate
+    }
+
+    if args.method in GRAD_BASED_METHOD_MAP:
+        method_class = GRAD_BASED_METHOD_MAP[args.method]
+        is_need_grad = True
+    else:
+        method_class = GRAD_FREE_METHOD_MAP[args.method]
+        is_need_grad = False
+
+    boxam_detector_visualizer = BoxAMDetectorVisualizer(
+        method_class,
+        model_wrapper,
+        target_layers,
+        reshape_transform=partial(
+            reshape_transform, max_shape=max_shape, is_need_grad=is_need_grad),
+        is_need_grad=is_need_grad,
+        extra_params=ablationcam_extra_params)
+    return model_wrapper, boxam_detector_visualizer
+
+
+def main():
+    register_all_modules()
+
+    args = parse_args()
+
+    # Hard-coded: infer the algorithm from the config filename to decide
+    # which loss terms should be ignored by the grad-based methods
+    ignore_loss_params = None
+    for param_keys in IGNORE_LOSS_PARAMS:
+        if param_keys in args.config:
+            print(f'The algorithm currently used is {param_keys}')
+            ignore_loss_params = IGNORE_LOSS_PARAMS[param_keys]
+            break
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    if not os.path.exists(args.out_dir) and not args.show:
+        os.mkdir(args.out_dir)
+
+    model_wrapper, boxam_detector_visualizer = init_detector_and_visualizer(
+        args, cfg)
+
+    # get file list
+    image_list, source_type = get_file_list(args.img)
+
+    progress_bar = ProgressBar(len(image_list))
+
+    for image_path in image_list:
+        image = cv2.imread(image_path)
+        model_wrapper.set_input_data(image)
+
+        # forward detection results
+        result = model_wrapper()[0]
+
+        pred_instances = result.pred_instances
+        # Get candidate predict info with score threshold
+        pred_instances = pred_instances[pred_instances.scores > args.score_thr]
+
+        if len(pred_instances) == 0:
+            warnings.warn('Empty detection results! Skipping this image.')
+            continue
+
+        if args.topk > 0:
+            pred_instances = pred_instances[:args.topk]
+
+        targets = [
+            DetBoxScoreTarget(
+                pred_instances,
+                device=args.device,
+                ignore_loss_params=ignore_loss_params)
+        ]
+
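+        # Grad-based methods compute the AM from the loss, so switch the
+        # wrapper to loss mode and rebuild the ActivationsAndGradients hooks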
+        if args.method in GRAD_BASED_METHOD_MAP:
+            model_wrapper.need_loss(True)
+            model_wrapper.set_input_data(image, pred_instances)
+            boxam_detector_visualizer.switch_activations_and_grads(
+                model_wrapper)
+
+        # get box am image
+        grayscale_boxam = boxam_detector_visualizer(image, targets=targets)
+
+        # draw cam on image
+        pred_instances = pred_instances.numpy()
+        image_with_bounding_boxes = boxam_detector_visualizer.show_am(
+            image,
+            pred_instances,
+            grayscale_boxam,
+            with_norm_in_bboxes=args.norm_in_bbox)
+
+        if source_type['is_dir']:
+            filename = os.path.relpath(image_path, args.img).replace('/', '_')
+        else:
+            filename = os.path.basename(image_path)
+        out_file = None if args.show else os.path.join(args.out_dir, filename)
+
+        if out_file:
+            mmcv.imwrite(image_with_bounding_boxes, out_file)
+        else:
+            cv2.namedWindow(filename, 0)
+            cv2.imshow(filename, image_with_bounding_boxes)
+            cv2.waitKey(0)
+
+        # Switch back to prediction mode for the next image
+        if args.method in GRAD_BASED_METHOD_MAP:
+            model_wrapper.need_loss(False)
+            boxam_detector_visualizer.switch_activations_and_grads(
+                model_wrapper)
+
+        progress_bar.update()
+
+    if not args.show:
+        print(f'All done!'
+              f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/featmap_vis_demo.py b/demo/featmap_vis_demo.py
index c9fb8eda..2006c7af 100644
--- a/demo/featmap_vis_demo.py
+++ b/demo/featmap_vis_demo.py
@@ -13,7 +13,6 @@ from mmyolo.utils import register_all_modules
 from mmyolo.utils.misc import auto_arrange_images, get_file_list
 
 
-# TODO: Refine
 def parse_args():
     parser = argparse.ArgumentParser(description='Visualize feature map')
     parser.add_argument(
@@ -190,8 +189,9 @@ def main():
         if args.show:
             visualizer.show(shown_imgs)
 
-    print(f'All done!'
-          f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
+    if not args.show:
+        print(f'All done!'
+              f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
 
 
 # Please refer to the usage tutorial:
diff --git a/docs/en/user_guides/visualization.md b/docs/en/user_guides/visualization.md
index eb5a530c..7835a434 100644
--- a/docs/en/user_guides/visualization.md
+++ b/docs/en/user_guides/visualization.md
@@ -1,5 +1,7 @@
 # Visualization
 
+This article covers feature map visualization as well as Grad-based and Grad-free CAM visualization.
+
 ## Feature map visualization
 
 <div align=center>
@@ -13,7 +15,7 @@ In MMYOLO, you can use the `Visualizer` provided in MMEngine for feature map vis
 - Support basic drawing interfaces and feature map visualization.
 - Support selecting different layers in the model to get the feature map. The display methods include `squeeze_mean`, `select_max`, and `topk`. Users can also customize the layout of the feature map display with `arrangement`.
 
-## Feature map generation
+### Feature map generation
 
 You can use `demo/featmap_vis_demo.py` to get a quick view of the visualization results. To better understand all functions, we list all primary parameters and their features here as follows:
 
@@ -51,7 +53,7 @@ You can use `demo/featmap_vis_demo.py` to get a quick view of the visualization
 
 **Note: When the image and feature map scales are different, the `draw_featmap` function will automatically perform an upsampling alignment. If your image has an operation such as `Pad` in the preprocessing during the inference, the feature map obtained is processed with `Pad`, which may cause misalignment problems if you directly upsample the image.**
 
-## Usage examples
+### Usage examples
 
 Take the pre-trained YOLOv5-s model as an example. Please download the model weight file to the root directory.
 
@@ -88,7 +90,7 @@ The original `test_pipeline` is:
 test_pipeline = [
     dict(
         type='LoadImageFromFile',
-        file_client_args={{_base_.file_client_args}}),
+        file_client_args=_base_.file_client_args),
     dict(type='YOLOv5KeepRatioResize', scale=img_scale),
     dict(
         type='LetterResize',
@@ -166,7 +168,7 @@ python demo/featmap_vis_demo.py demo/dog.jpg \
 ```
 
 <div align=center>
-<img src="https://user-images.githubusercontent.com/17425982/198522489-8adee6ae-9915-4e9d-bf50-167b8a12c275.png" width="1200" alt="image"/>
+<img src="https://user-images.githubusercontent.com/17425982/198522489-8adee6ae-9915-4e9d-bf50-167b8a12c275.png" width="800" alt="image"/>
 </div>
 
 (5) When the visualization process finishes, you can choose to display the result or store it locally. You only need to add the parameter `--out-file xxx.jpg`:
@@ -179,3 +181,113 @@ python demo/featmap_vis_demo.py demo/dog.jpg \
                                 --channel-reduction select_max \
                                 --out-file featmap_backbone.jpg
 ```
+
+## Grad-Based and Grad-Free CAM Visualization
+
+CAM visualization for object detection is considerably more complex than, and quite different from, CAM visualization for classification.
+This article only briefly explains the usage; a separate document describing the implementation principles and caveats in detail will be added later.
+
+You can call `demo/boxam_vis_demo.py` to quickly and easily get box-level AM visualization results. Currently, `YOLOv5/YOLOv6/YOLOX/RTMDet` are supported.
+
+Taking YOLOv5 as an example, as with feature map visualization, you need to modify the `test_pipeline` first, otherwise the feature map and the original image will be misaligned.
+
+The original `test_pipeline` is:
+
+```python
+test_pipeline = [
+    dict(
+        type='LoadImageFromFile',
+        file_client_args=_base_.file_client_args),
+    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
+    dict(
+        type='LetterResize',
+        scale=img_scale,
+        allow_scale_up=False,
+        pad_val=dict(img=114)),
+    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor', 'pad_param'))
+]
+```
+
+Change to the following version:
+
+```python
+test_pipeline = [
+    dict(
+        type='LoadImageFromFile',
+        file_client_args=_base_.file_client_args),
+    dict(type='mmdet.Resize', scale=img_scale, keep_ratio=False), # change LetterResize to mmdet.Resize
+    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+```
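+
+Besides the default `GradCAM`, the script currently registers four methods (see `GRAD_BASED_METHOD_MAP` and `GRAD_FREE_METHOD_MAP` in `demo/boxam_vis_demo.py`): the grad-based `gradcam` and `gradcam++`, and the grad-free `ablationcam` and `eigencam`. They are selected with `--method`, for example:
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
+        --method gradcam++
+```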
+
+(1) Use the `GradCAM` method to visualize the AM of the last output layer of the neck module
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth
+```
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203775584-c4aebf11-4ff8-4530-85fe-7dda897e95a8.jpg" width="800" alt="image"/>
+</div>
+
+The corresponding feature AM is as follows:
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203774801-1555bcfb-a8f9-4688-8ed6-982d6ad38e1d.jpg" width="800" alt="image"/>
+</div>
+
+It can be seen that `GradCAM` highlights the AM information at the box level.
+
+You can visualize only the prediction boxes with the highest scores via the `--topk` parameter:
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
+        --topk 2
+```
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203778700-3165aa72-ecaf-40cc-b470-6911646e6046.jpg" width="800" alt="image"/>
+</div>
+
+(2) Use the `AblationCAM` method to visualize the AM of the last output layer of the neck module
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
+        --method ablationcam
+```
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203776978-b5a9b383-93b4-4b35-9e6a-7cac684b372c.jpg" width="800" alt="image"/>
+</div>
+
+Since `AblationCAM` weights the activation map by each channel's contribution to the score, it cannot restrict the visualization to box-level AM information the way `GradCAM` does. However, you can use `--norm-in-bbox` to normalize the AM inside each bbox and show only the AM within the bboxes:
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
+        --method ablationcam \
+        --norm-in-bbox
+```
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203777566-7c74e82f-b477-488e-958f-91e1d10833b9.jpg" width="800" alt="image"/>
+</div>
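+
+The script also provides `--target-layers` to select which layer's AM to visualize (the default is `neck.out_layers[2]`) and `--preview-model` to print the full model structure so that valid layer names can be looked up. As an example (the layer path below is only an illustration; since the default is `neck.out_layers[2]`, indices 0 and 1 are also available, but run `--preview-model` first to confirm the names for your model):
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
+        --target-layers neck.out_layers[1]
+```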
diff --git a/docs/zh_cn/user_guides/visualization.md b/docs/zh_cn/user_guides/visualization.md
index d8bd051b..e5975eed 100644
--- a/docs/zh_cn/user_guides/visualization.md
+++ b/docs/zh_cn/user_guides/visualization.md
@@ -1,5 +1,7 @@
 # 可视化
 
+本文包括特征图可视化以及 Grad-Based 和 Grad-Free CAM 可视化
+
 ## 特征图可视化
 
 <div align=center>
@@ -12,7 +14,7 @@ MMYOLO 中,将使用 MMEngine 提供的 `Visualizer` 可视化器进行特征
 - 支持基础绘图接口以及特征图可视化。
 - 支持选择模型中的不同层来得到特征图,包含 `squeeze_mean` , `select_max` , `topk` 三种显示方式,用户还可以使用 `arrangement` 自定义特征图显示的布局方式。
 
-## 特征图绘制
+### 特征图绘制
 
 你可以调用 `demo/featmap_vis_demo.py` 来简单快捷地得到可视化结果,为了方便理解,将其主要参数的功能梳理如下:
 
@@ -50,7 +52,7 @@ MMYOLO 中,将使用 MMEngine 提供的 `Visualizer` 可视化器进行特征
 
 **注意:当图片和特征图尺度不一样时候,`draw_featmap` 函数会自动进行上采样对齐。如果你的图片在推理过程中前处理存在类似 Pad 的操作此时得到的特征图也是 Pad 过的,那么直接上采样就可能会出现不对齐问题。**
 
-## 用法示例
+### 用法示例
 
 以预训练好的 YOLOv5-s 模型为例:
 
@@ -167,7 +169,7 @@ python demo/featmap_vis_demo.py demo/dog.jpg \
 ```
 
 <div align=center>
-<img src="https://user-images.githubusercontent.com/17425982/198522489-8adee6ae-9915-4e9d-bf50-167b8a12c275.png" width="1200" alt="image"/>
+<img src="https://user-images.githubusercontent.com/17425982/198522489-8adee6ae-9915-4e9d-bf50-167b8a12c275.png" width="800" alt="image"/>
 </div>
 
 (5) 存储绘制后的图片,在绘制完成后,可以选择本地窗口显示,也可以存储到本地,只需要加入参数 `--out-file xxx.jpg`:
@@ -180,3 +182,113 @@ python demo/featmap_vis_demo.py demo/dog.jpg \
                                 --channel-reduction select_max \
                                 --out-file featmap_backbone.jpg
 ```
+
+## Grad-Based 和 Grad-Free CAM 可视化
+
+目标检测 CAM 可视化相比于分类 CAM 复杂很多且差异很大。本文只是简要说明用法,后续会单独开文档详细描述实现原理和注意事项。
+
+你可以调用 `demo/boxam_vis_demo.py` 来简单快捷地得到 Box 级别的 AM 可视化结果,目前已经支持 `YOLOv5/YOLOv6/YOLOX/RTMDet`。
+
+以 YOLOv5 为例,和特征图可视化绘制一样,你需要先修改 `test_pipeline`,否则会出现特征图和原图不对齐问题。
+
+旧的 `test_pipeline` 为:
+
+```python
+test_pipeline = [
+    dict(
+        type='LoadImageFromFile',
+        file_client_args=_base_.file_client_args),
+    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
+    dict(
+        type='LetterResize',
+        scale=img_scale,
+        allow_scale_up=False,
+        pad_val=dict(img=114)),
+    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor', 'pad_param'))
+]
+```
+
+修改为如下配置:
+
+```python
+test_pipeline = [
+    dict(
+        type='LoadImageFromFile',
+        file_client_args=_base_.file_client_args),
+    dict(type='mmdet.Resize', scale=img_scale, keep_ratio=False), # 这里将 LetterResize 修改成 mmdet.Resize
+    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+```
+
+(1) 使用 `GradCAM` 方法可视化 neck 模块的最后一个输出层的 AM 图
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth
+```
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203775584-c4aebf11-4ff8-4530-85fe-7dda897e95a8.jpg" width="800" alt="image"/>
+</div>
+
+相对应的特征图 AM 图如下:
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203774801-1555bcfb-a8f9-4688-8ed6-982d6ad38e1d.jpg" width="800" alt="image"/>
+</div>
+
+可以看出 `GradCAM` 效果可以突出 box 级别的 AM 信息。
+
+你可以通过 `--topk` 参数选择仅仅可视化预测分值最高的前几个预测框
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
+        --topk 2
+```
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203778700-3165aa72-ecaf-40cc-b470-6911646e6046.jpg" width="800" alt="image"/>
+</div>
+
+(2) 使用 `AblationCAM` 方法可视化 neck 模块的最后一个输出层的 AM 图
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
+        --method ablationcam
+```
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203776978-b5a9b383-93b4-4b35-9e6a-7cac684b372c.jpg" width="800" alt="image"/>
+</div>
+
+由于 `AblationCAM` 是通过每个通道对分值的贡献程度来加权,因此无法实现类似 `GradCAM` 的仅仅可视化 box 级别的 AM 信息, 但是你可以使用 `--norm-in-bbox` 来仅仅显示 bbox 内部 AM
+
+```shell
+python demo/boxam_vis_demo.py \
+        demo/dog.jpg \
+        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
+        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
+        --method ablationcam \
+        --norm-in-bbox
+```
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/17425982/203777566-7c74e82f-b477-488e-958f-91e1d10833b9.jpg" width="800" alt="image"/>
+</div>
diff --git a/mmyolo/models/dense_heads/yolov6_head.py b/mmyolo/models/dense_heads/yolov6_head.py
index 624db985..b2581ef5 100644
--- a/mmyolo/models/dense_heads/yolov6_head.py
+++ b/mmyolo/models/dense_heads/yolov6_head.py
@@ -377,7 +377,7 @@ class YOLOv6Head(YOLOv5Head):
             loss_cls=loss_cls * world_size, loss_bbox=loss_bbox * world_size)
 
     @staticmethod
-    def gt_instances_preprocess(batch_gt_instances: Tensor,
+    def gt_instances_preprocess(batch_gt_instances: Union[Tensor, Sequence],
                                 batch_size: int) -> Tensor:
         """Split batch_gt_instances with batch size, from [all_gt_bboxes, 6]
         to.
@@ -393,28 +393,51 @@ class YOLOv6Head(YOLOv5Head):
         Returns:
             Tensor: batch gt instances data, shape [batch_size, number_gt, 5]
         """
+        if isinstance(batch_gt_instances, Sequence):
+            max_gt_bbox_len = max(
+                [len(gt_instances) for gt_instances in batch_gt_instances])
+            # fill [-1., 0., 0., 0., 0.] if some shape of
+            # single batch not equal max_gt_bbox_len
+            batch_instance_list = []
+            for index, gt_instance in enumerate(batch_gt_instances):
+                bboxes = gt_instance.bboxes
+                labels = gt_instance.labels
+                batch_instance_list.append(
+                    torch.cat((labels[:, None], bboxes), dim=-1))
 
-        # sqlit batch gt instance [all_gt_bboxes, 6] ->
-        # [batch_size, number_gt_each_batch, 5]
-        batch_instance_list = []
-        max_gt_bbox_len = 0
-        for i in range(batch_size):
-            single_batch_instance = \
-                batch_gt_instances[batch_gt_instances[:, 0] == i, :]
-            single_batch_instance = single_batch_instance[:, 1:]
-            batch_instance_list.append(single_batch_instance)
-            if len(single_batch_instance) > max_gt_bbox_len:
-                max_gt_bbox_len = len(single_batch_instance)
+                if bboxes.shape[0] >= max_gt_bbox_len:
+                    continue
 
-        # fill [-1., 0., 0., 0., 0.] if some shape of
-        # single batch not equal max_gt_bbox_len
-        for index, gt_instance in enumerate(batch_instance_list):
-            if gt_instance.shape[0] >= max_gt_bbox_len:
-                continue
-            fill_tensor = batch_gt_instances.new_full(
-                [max_gt_bbox_len - gt_instance.shape[0], 5], 0)
-            fill_tensor[:, 0] = -1.
-            batch_instance_list[index] = torch.cat(
-                (batch_instance_list[index], fill_tensor), dim=0)
+                fill_tensor = bboxes.new_full(
+                    [max_gt_bbox_len - bboxes.shape[0], 5], 0)
+                fill_tensor[:, 0] = -1.
+                batch_instance_list[index] = torch.cat(
+                    (batch_instance_list[index], fill_tensor), dim=0)
 
-        return torch.stack(batch_instance_list)
+            return torch.stack(batch_instance_list)
+        else:
+            # faster version
+            # split batch gt instance [all_gt_bboxes, 6] ->
+            # [batch_size, number_gt_each_batch, 5]
+            batch_instance_list = []
+            max_gt_bbox_len = 0
+            for i in range(batch_size):
+                single_batch_instance = \
+                    batch_gt_instances[batch_gt_instances[:, 0] == i, :]
+                single_batch_instance = single_batch_instance[:, 1:]
+                batch_instance_list.append(single_batch_instance)
+                if len(single_batch_instance) > max_gt_bbox_len:
+                    max_gt_bbox_len = len(single_batch_instance)
+
+            # fill [-1., 0., 0., 0., 0.] if some shape of
+            # single batch not equal max_gt_bbox_len
+            for index, gt_instance in enumerate(batch_instance_list):
+                if gt_instance.shape[0] >= max_gt_bbox_len:
+                    continue
+                fill_tensor = batch_gt_instances.new_full(
+                    [max_gt_bbox_len - gt_instance.shape[0], 5], 0)
+                fill_tensor[:, 0] = -1.
+                batch_instance_list[index] = torch.cat(
+                    (batch_instance_list[index], fill_tensor), dim=0)
+
+            return torch.stack(batch_instance_list)
diff --git a/mmyolo/utils/boxam_utils.py b/mmyolo/utils/boxam_utils.py
new file mode 100644
index 00000000..5e1ec913
--- /dev/null
+++ b/mmyolo/utils/boxam_utils.py
@@ -0,0 +1,504 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import bisect
+import copy
+import warnings
+from pathlib import Path
+from typing import Callable, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision
+from mmcv.transforms import Compose
+from mmdet.evaluation import get_classes
+from mmdet.models import build_detector
+from mmdet.utils import ConfigType
+from mmengine.config import Config
+from mmengine.runner import load_checkpoint
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+try:
+    from pytorch_grad_cam import (AblationCAM, AblationLayer,
+                                  ActivationsAndGradients)
+    from pytorch_grad_cam import GradCAM as Base_GradCAM
+    from pytorch_grad_cam import GradCAMPlusPlus as Base_GradCAMPlusPlus
+    from pytorch_grad_cam.base_cam import BaseCAM
+    from pytorch_grad_cam.utils.image import scale_cam_image, show_cam_on_image
+    from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+except ImportError:
+    pass
+
+
+def init_detector(
+    config: Union[str, Path, Config],
+    checkpoint: Optional[str] = None,
+    palette: str = 'coco',
+    device: str = 'cuda:0',
+    cfg_options: Optional[dict] = None,
+) -> nn.Module:
+    """Initialize a detector from config file.
+
+    Args:
+        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
+            :obj:`Path`, or the config object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+        palette (str): Color palette used for visualization. If palette
+            is stored in checkpoint, use checkpoint's palette first, otherwise
+            use externally passed palette. Currently, supports 'coco', 'voc',
+            'citys' and 'random'. Defaults to coco.
+        device (str): The device where the anchors will be put on.
+            Defaults to cuda:0.
+        cfg_options (dict, optional): Options to override some settings in
+            the used config.
+
+    Returns:
+        nn.Module: The constructed detector.
+    """
+    if isinstance(config, (str, Path)):
+        config = Config.fromfile(config)
+    elif not isinstance(config, Config):
+        raise TypeError('config must be a filename or Config object, '
+                        f'but got {type(config)}')
+    if cfg_options is not None:
+        config.merge_from_dict(cfg_options)
+    elif 'init_cfg' in config.model.backbone:
+        config.model.backbone.init_cfg = None
+
+    # The only change from the original `init_detector`: keep `train_cfg`,
+    # because grad-based methods need it to calculate the loss
+    # config.model.train_cfg = None
+
+    model = build_detector(config.model)
+    if checkpoint is not None:
+        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+        # Weights converted from elsewhere may not have meta fields.
+        checkpoint_meta = checkpoint.get('meta', {})
+        # save the dataset_meta in the model for convenience
+        if 'dataset_meta' in checkpoint_meta:
+            # mmdet 3.x
+            model.dataset_meta = checkpoint_meta['dataset_meta']
+        elif 'CLASSES' in checkpoint_meta:
+            # < mmdet 3.x
+            classes = checkpoint_meta['CLASSES']
+            model.dataset_meta = {'CLASSES': classes, 'PALETTE': palette}
+        else:
+            warnings.simplefilter('once')
+            warnings.warn(
+                'dataset_meta or class names are not saved in the '
+                'checkpoint\'s meta data, use COCO classes by default.')
+            model.dataset_meta = {
+                'CLASSES': get_classes('coco'),
+                'PALETTE': palette
+            }
+
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+def reshape_transform(feats: Union[Tensor, List[Tensor]],
+                      max_shape: Tuple[int, int] = (20, 20),
+                      is_need_grad: bool = False):
+    """Reshape and aggregate feature maps when the input is a multi-layer
+    feature map.
+
+    Takes tensors of different sizes, resizes them to a common shape and
+    concatenates them.
+    """
+    if len(max_shape) == 1:
+        max_shape = max_shape * 2
+
+    if isinstance(feats, torch.Tensor):
+        feats = [feats]
+    else:
+        if is_need_grad:
+            raise NotImplementedError('The grad-based method does not '
+                                      'support multi-layer activation output')
+
+    max_h = max([im.shape[-2] for im in feats])
+    max_w = max([im.shape[-1] for im in feats])
+    if -1 in max_shape:
+        max_shape = (max_h, max_w)
+    else:
+        max_shape = (min(max_h, max_shape[0]), min(max_w, max_shape[1]))
+
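+    # Resize each feature level to the common shape, then concatenate
+    # them along the channel dimension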
+    activations = []
+    for feat in feats:
+        activations.append(
+            torch.nn.functional.interpolate(
+                torch.abs(feat), max_shape, mode='bilinear'))
+
+    activations = torch.cat(activations, axis=1)
+    return activations
+
+
+class BoxAMDetectorWrapper(nn.Module):
+    """Wrap the mmdet model class to facilitate handling of non-tensor
+    situations during inference."""
+
+    def __init__(self,
+                 cfg: ConfigType,
+                 checkpoint: str,
+                 score_thr: float,
+                 device: str = 'cuda:0'):
+        super().__init__()
+        self.cfg = cfg
+        self.device = device
+        self.score_thr = score_thr
+        self.checkpoint = checkpoint
+        self.detector = init_detector(self.cfg, self.checkpoint, device=device)
+
+        pipeline_cfg = copy.deepcopy(self.cfg.test_dataloader.dataset.pipeline)
+        pipeline_cfg[0].type = 'mmdet.LoadImageFromNDArray'
+
+        new_test_pipeline = []
+        for pipeline in pipeline_cfg:
+            if not pipeline['type'].endswith('LoadAnnotations'):
+                new_test_pipeline.append(pipeline)
+        self.test_pipeline = Compose(new_test_pipeline)
+
+        self.is_need_loss = False
+        self.input_data = None
+        self.image = None
+
+    def need_loss(self, is_need_loss: bool):
+        """Grad-based methods require loss."""
+        self.is_need_loss = is_need_loss
+
+    def set_input_data(self,
+                       image: np.ndarray,
+                       pred_instances: Optional[InstanceData] = None):
+        """Set the input data to be used in the next step."""
+        self.image = image
+
+        if self.is_need_loss:
+            assert pred_instances is not None
+            pred_instances = pred_instances.numpy()
+            data = dict(
+                img=self.image,
+                img_id=0,
+                gt_bboxes=pred_instances.bboxes,
+                gt_bboxes_labels=pred_instances.labels)
+            data = self.test_pipeline(data)
+        else:
+            data = dict(img=self.image, img_id=0)
+            data = self.test_pipeline(data)
+            data['inputs'] = [data['inputs']]
+            data['data_samples'] = [data['data_samples']]
+        self.input_data = data
+
+    def __call__(self, *args, **kwargs):
+        assert self.input_data is not None
+        if self.is_need_loss:
+            # Maybe this is a direction that can be optimized
+            # self.detector.init_weights()
+
+            if hasattr(self.detector.bbox_head, 'featmap_sizes'):
+                # Reset the cached featmap_sizes so that a stale cache
+                # does not cause errors when calculating the loss
+                self.detector.bbox_head.featmap_sizes = None
+
+            data_ = {}
+            data_['inputs'] = [self.input_data['inputs']]
+            data_['data_samples'] = [self.input_data['data_samples']]
+            data = self.detector.data_preprocessor(data_, training=False)
+            loss = self.detector._run_forward(data, mode='loss')
+
+            if hasattr(self.detector.bbox_head, 'featmap_sizes'):
+                self.detector.bbox_head.featmap_sizes = None
+
+            return [loss]
+        else:
+            with torch.no_grad():
+                results = self.detector.test_step(self.input_data)
+                return results
+
+
+class BoxAMDetectorVisualizer:
+    """Box AM visualization class."""
+
+    def __init__(self,
+                 method_class,
+                 model: nn.Module,
+                 target_layers: List,
+                 reshape_transform: Optional[Callable] = None,
+                 is_need_grad: bool = False,
+                 extra_params: Optional[dict] = None):
+        self.target_layers = target_layers
+        self.reshape_transform = reshape_transform
+        self.is_need_grad = is_need_grad
+
+        if method_class.__name__ == 'AblationCAM':
+            batch_size = extra_params.get('batch_size', 1)
+            ratio_channels_to_ablate = extra_params.get(
+                'ratio_channels_to_ablate', 1.)
+            self.cam = AblationCAM(
+                model,
+                target_layers,
+                use_cuda=True if 'cuda' in model.device else False,
+                reshape_transform=reshape_transform,
+                batch_size=batch_size,
+                ablation_layer=extra_params['ablation_layer'],
+                ratio_channels_to_ablate=ratio_channels_to_ablate)
+        else:
+            self.cam = method_class(
+                model,
+                target_layers,
+                use_cuda=True if 'cuda' in model.device else False,
+                reshape_transform=reshape_transform,
+            )
+            if self.is_need_grad:
+                self.cam.activations_and_grads.release()
+
+        self.classes = model.detector.dataset_meta['CLASSES']
+        self.COLORS = np.random.uniform(0, 255, size=(len(self.classes), 3))
+
+    def switch_activations_and_grads(self, model) -> None:
+        """In the grad-based method, we need to switch
+        ``ActivationsAndGradients`` layer, otherwise an error will occur."""
+        self.cam.model = model
+
+        if self.is_need_grad is True:
+            self.cam.activations_and_grads = ActivationsAndGradients(
+                model, self.target_layers, self.reshape_transform)
+            self.is_need_grad = False
+        else:
+            self.cam.activations_and_grads.release()
+            self.is_need_grad = True
+
+    def __call__(self, img, targets, aug_smooth=False, eigen_smooth=False):
+        img = torch.from_numpy(img)[None].permute(0, 3, 1, 2)
+        return self.cam(img, targets, aug_smooth, eigen_smooth)[0, :]
+
+    def show_am(self,
+                image: np.ndarray,
+                pred_instance: InstanceData,
+                grayscale_am: np.ndarray,
+                with_norm_in_bboxes: bool = False):
+        """Normalize the AM to be in the range [0, 1] inside every bounding
+        boxes, and zero outside of the bounding boxes."""
+
+        boxes = pred_instance.bboxes
+        labels = pred_instance.labels
+
+        if with_norm_in_bboxes is True:
+            boxes = boxes.astype(np.int32)
+            renormalized_am = np.zeros(grayscale_am.shape, dtype=np.float32)
+            images = []
+            for x1, y1, x2, y2 in boxes:
+                img = renormalized_am * 0
+                img[y1:y2, x1:x2] = scale_cam_image(
+                    [grayscale_am[y1:y2, x1:x2].copy()])[0]
+                images.append(img)
+
+            renormalized_am = np.max(np.float32(images), axis=0)
+            renormalized_am = scale_cam_image([renormalized_am])[0]
+        else:
+            renormalized_am = grayscale_am
+
+        am_image_renormalized = show_cam_on_image(
+            image / 255, renormalized_am, use_rgb=False)
+
+        image_with_bounding_boxes = self._draw_boxes(
+            boxes, labels, am_image_renormalized, pred_instance.get('scores'))
+        return image_with_bounding_boxes
+
+    def _draw_boxes(self,
+                    boxes: List,
+                    labels: List,
+                    image: np.ndarray,
+                    scores: Optional[List] = None):
+        """draw boxes on image."""
+        for i, box in enumerate(boxes):
+            label = labels[i]
+            color = self.COLORS[label]
+            cv2.rectangle(image, (int(box[0]), int(box[1])),
+                          (int(box[2]), int(box[3])), color, 2)
+            if scores is not None:
+                score = scores[i]
+                text = str(self.classes[label]) + ': ' + str(
+                    round(score * 100, 1))
+            else:
+                text = self.classes[label]
+
+            cv2.putText(
+                image,
+                text, (int(box[0]), int(box[1] - 5)),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.5,
+                color,
+                1,
+                lineType=cv2.LINE_AA)
+        return image
+
+
+class DetAblationLayer(AblationLayer):
+    """Det AblationLayer."""
+
+    def __init__(self):
+        super().__init__()
+        self.activations = None
+
+    def set_next_batch(self, input_batch_index, activations,
+                       num_channels_to_ablate):
+        """Extract the next batch member from activations, and repeat it
+        num_channels_to_ablate times."""
+        if isinstance(activations, torch.Tensor):
+            return super().set_next_batch(input_batch_index, activations,
+                                          num_channels_to_ablate)
+
+        self.activations = []
+        for activation in activations:
+            activation = activation[
+                input_batch_index, :, :, :].clone().unsqueeze(0)
+            self.activations.append(
+                activation.repeat(num_channels_to_ablate, 1, 1, 1))
+
+    def __call__(self, x):
+        """Go over the activation indices to be ablated, stored in
+        self.indices."""
+        result = self.activations
+
+        if isinstance(result, torch.Tensor):
+            return super().__call__(x)
+
+        channel_cumsum = np.cumsum([r.shape[1] for r in result])
+        num_channels_to_ablate = result[0].size(0)  # batch
+        for i in range(num_channels_to_ablate):
+            pyramid_layer = bisect.bisect_right(channel_cumsum,
+                                                self.indices[i])
+            if pyramid_layer > 0:
+                index_in_pyramid_layer = self.indices[i] - channel_cumsum[
+                    pyramid_layer - 1]
+            else:
+                index_in_pyramid_layer = self.indices[i]
+            result[pyramid_layer][i, index_in_pyramid_layer, :, :] = -1000
+        return result
+
+
+class DetBoxScoreTarget:
+    """Det Score calculation class.
+
+    In the case of the grad-free method, the calculation method is that
+    for every original detected bounding box specified in "bboxes",
+    assign a score on how the current bounding boxes match it,
+
+        1. In Bbox IoU
+        2. In the classification score.
+        3. In Mask IoU if ``segms`` exist.
+
+    If there is not a large enough overlap, or the category changed,
+    assign a score of 0. The total score is the sum of all the box scores.
+
+    In the case of the grad-based method, the calculation method is
+    the sum of losses after excluding a specific key.
+    """
+
+    def __init__(self,
+                 pred_instance: InstanceData,
+                 match_iou_thr: float = 0.5,
+                 device: str = 'cuda:0',
+                 ignore_loss_params: Optional[List] = None):
+        self.focal_bboxes = pred_instance.bboxes
+        self.focal_labels = pred_instance.labels
+        self.match_iou_thr = match_iou_thr
+        self.device = device
+        self.ignore_loss_params = ignore_loss_params
+        if ignore_loss_params is not None:
+            assert isinstance(self.ignore_loss_params, list)
+
+    def __call__(self, results):
+        output = torch.tensor([0.], device=self.device)
+
+        if 'loss_cls' in results:
+            # grad-based method
+            # results is dict
+            for loss_key, loss_value in results.items():
+                if 'loss' not in loss_key or (
+                        self.ignore_loss_params is not None
+                        and loss_key in self.ignore_loss_params):
+                    continue
+                if isinstance(loss_value, list):
+                    output += sum(loss_value)
+                else:
+                    output += loss_value
+            return output
+        else:
+            # grad-free method
+            # results is DetDataSample
+            pred_instances = results.pred_instances
+            if len(pred_instances) == 0:
+                return output
+
+            pred_bboxes = pred_instances.bboxes
+            pred_scores = pred_instances.scores
+            pred_labels = pred_instances.labels
+
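+            # For each originally detected (focal) box, find the best-matching
+            # prediction; if the IoU and the category still match, add
+            # IoU + classification score to the total score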
+            for focal_box, focal_label in zip(self.focal_bboxes,
+                                              self.focal_labels):
+                ious = torchvision.ops.box_iou(focal_box[None],
+                                               pred_bboxes[..., :4])
+                index = ious.argmax()
+                if ious[0, index] > self.match_iou_thr and pred_labels[
+                        index] == focal_label:
+                    # TODO: Adaptive adjustment of weights based on algorithms
+                    score = ious[0, index] + pred_scores[index]
+                    output = output + score
+            return output
+
+
+class SpatialBaseCAM(BaseCAM):
+    """CAM that maintains spatial information.
+
+    Gradients are often averaged over the spatial dimension in CAM
+    visualization for classification, but this is unreasonable for detection:
+    spatial information matters, so the gradients are not averaged here.
+    """
+
+    def get_cam_image(self,
+                      input_tensor: torch.Tensor,
+                      target_layer: torch.nn.Module,
+                      targets: List[torch.nn.Module],
+                      activations: torch.Tensor,
+                      grads: torch.Tensor,
+                      eigen_smooth: bool = False) -> np.ndarray:
+
+        weights = self.get_cam_weights(input_tensor, target_layer, targets,
+                                       activations, grads)
+        weighted_activations = weights * activations
+        if eigen_smooth:
+            cam = get_2d_projection(weighted_activations)
+        else:
+            cam = weighted_activations.sum(axis=1)
+        return cam
+
+
+class GradCAM(SpatialBaseCAM, Base_GradCAM):
+    """Gradients are no longer averaged over the spatial dimension."""
+
+    def get_cam_weights(self, input_tensor, target_layer, target_category,
+                        activations, grads):
+        return grads
+
+
+class GradCAMPlusPlus(SpatialBaseCAM, Base_GradCAMPlusPlus):
+    """Gradients are no longer averaged over the spatial dimension."""
+
+    def get_cam_weights(self, input_tensor, target_layers, target_category,
+                        activations, grads):
+        grads_power_2 = grads**2
+        grads_power_3 = grads_power_2 * grads
+        # Equation 19 in https://arxiv.org/abs/1710.11063
+        sum_activations = np.sum(activations, axis=(2, 3))
+        eps = 0.000001
+        aij = grads_power_2 / (
+            2 * grads_power_2 +
+            sum_activations[:, :, None, None] * grads_power_3 + eps)
+        # Now bring back the ReLU from eq.7 in the paper,
+        # and zero out aij where the gradients are 0
+        aij = np.where(grads != 0, aij, 0)
+
+        weights = np.maximum(grads, 0) * aij
+        return weights