mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support grad-based cam and grad-free cam (#234)
* support cam
* update
* update
* done
* fix lint
* add docstr
* add doc
* update
* update
* FIX
parent
1045b41b68
commit
9221499af4
demo/boxam_vis_demo.py
@@ -0,0 +1,276 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This script is in the experimental verification stage and cannot be
guaranteed to be completely correct. Currently Grad-based CAM and Grad-free CAM
are supported.

The object detection task is different from the classification task. Its
output not only includes the AM of each category, but also information such
as the bbox and mask, so this script is named boxam.
"""
import argparse
import os.path
import warnings
from functools import partial

import cv2
import mmcv
from mmengine import Config, DictAction, MessageHub
from mmengine.utils import ProgressBar

from mmyolo.utils import register_all_modules
from mmyolo.utils.boxam_utils import (BoxAMDetectorVisualizer,
                                      BoxAMDetectorWrapper, DetAblationLayer,
                                      DetBoxScoreTarget, GradCAM,
                                      GradCAMPlusPlus, reshape_transform)
from mmyolo.utils.misc import get_file_list

try:
    from pytorch_grad_cam import AblationCAM, EigenCAM
except ImportError:
    raise ImportError('Please run `pip install "grad-cam"` to install the '
                      'pytorch_grad_cam package.')

GRAD_FREE_METHOD_MAP = {
    'ablationcam': AblationCAM,
    'eigencam': EigenCAM,
    # 'scorecam': ScoreCAM, # consumes too much memory
}

GRAD_BASED_METHOD_MAP = {'gradcam': GradCAM, 'gradcam++': GradCAMPlusPlus}

ALL_SUPPORT_METHODS = list(GRAD_FREE_METHOD_MAP.keys()
                           | GRAD_BASED_METHOD_MAP.keys())

IGNORE_LOSS_PARAMS = {
    'yolov5': ['loss_obj'],
    'yolov6': ['loss_cls'],
    'yolox': ['loss_obj'],
    'rtmdet': ['loss_cls'],
}

# This parameter is required by some algorithms
# when calculating the loss
message_hub = MessageHub.get_current_instance()
message_hub.runtime_info['epoch'] = 0


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize Box AM')
    parser.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--method',
        default='gradcam',
        choices=ALL_SUPPORT_METHODS,
        help='Type of method to use, supports '
        f'{", ".join(ALL_SUPPORT_METHODS)}.')
    parser.add_argument(
        '--target-layers',
        default=['neck.out_layers[2]'],
        nargs='+',
        type=str,
        help='The target layers to get Box AM. If not set, the tool will '
        'default to neck.out_layers[2]')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    parser.add_argument(
        '--show', action='store_true', help='Show the CAM results')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--topk',
        type=int,
        default=-1,
        help='Select the top-k prediction results to show. -1 means all.')
    parser.add_argument(
        '--max-shape',
        nargs='+',
        type=int,
        default=-1,
        help='Max shape of the activation map. Its purpose is to save GPU '
        'memory: the activation map is scaled and then evaluated. '
        'If set to -1, it means no scaling.')
    parser.add_argument(
        '--preview-model',
        default=False,
        action='store_true',
        help='To preview all the model layers')
    parser.add_argument(
        '--norm-in-bbox',
        action='store_true',
        help='Normalize the AM inside each bbox of the AM image')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # Only used by AblationCAM
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1,
        help='Batch size for AblationCAM inference')
    parser.add_argument(
        '--ratio-channels-to-ablate',
        type=float,
        default=0.5,
        help='Makes AblationCAM much faster. '
        'The parameter controls how many channels should be ablated')

    args = parser.parse_args()
    return args


def init_detector_and_visualizer(args, cfg):
    max_shape = args.max_shape
    if not isinstance(max_shape, list):
        max_shape = [args.max_shape]
    assert len(max_shape) == 1 or len(max_shape) == 2

    model_wrapper = BoxAMDetectorWrapper(
        cfg, args.checkpoint, args.score_thr, device=args.device)

    if args.preview_model:
        print(model_wrapper.detector)
        print('\n Please remove `--preview-model` to get the BoxAM.')
        return None, None

    target_layers = []
    for target_layer in args.target_layers:
        try:
            target_layers.append(
                eval(f'model_wrapper.detector.{target_layer}'))
        except Exception as e:
            print(model_wrapper.detector)
            raise RuntimeError('layer does not exist', e)

    ablationcam_extra_params = {
        'batch_size': args.batch_size,
        'ablation_layer': DetAblationLayer(),
        'ratio_channels_to_ablate': args.ratio_channels_to_ablate
    }

    if args.method in GRAD_BASED_METHOD_MAP:
        method_class = GRAD_BASED_METHOD_MAP[args.method]
        is_need_grad = True
    else:
        method_class = GRAD_FREE_METHOD_MAP[args.method]
        is_need_grad = False

    boxam_detector_visualizer = BoxAMDetectorVisualizer(
        method_class,
        model_wrapper,
        target_layers,
        reshape_transform=partial(
            reshape_transform, max_shape=max_shape, is_need_grad=is_need_grad),
        is_need_grad=is_need_grad,
        extra_params=ablationcam_extra_params)
    return model_wrapper, boxam_detector_visualizer


def main():
    register_all_modules()

    args = parse_args()

    # hard-coded: infer the algorithm name from the config path
    # to decide which loss terms to ignore
    ignore_loss_params = None
    for param_keys in IGNORE_LOSS_PARAMS:
        if param_keys in args.config:
            print(f'The algorithm currently used is {param_keys}')
            ignore_loss_params = IGNORE_LOSS_PARAMS[param_keys]
            break

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if not os.path.exists(args.out_dir) and not args.show:
        os.mkdir(args.out_dir)

    model_wrapper, boxam_detector_visualizer = init_detector_and_visualizer(
        args, cfg)

    # get file list
    image_list, source_type = get_file_list(args.img)

    progress_bar = ProgressBar(len(image_list))

    for image_path in image_list:
        image = cv2.imread(image_path)
        model_wrapper.set_input_data(image)

        # forward detection results
        result = model_wrapper()[0]

        pred_instances = result.pred_instances
        # Get candidate predict info with score threshold
        pred_instances = pred_instances[pred_instances.scores > args.score_thr]

        if len(pred_instances) == 0:
            warnings.warn('empty detection results! skip this image')
            continue

        if args.topk > 0:
            pred_instances = pred_instances[:args.topk]

        targets = [
            DetBoxScoreTarget(
                pred_instances,
                device=args.device,
                ignore_loss_params=ignore_loss_params)
        ]

        if args.method in GRAD_BASED_METHOD_MAP:
            model_wrapper.need_loss(True)
            model_wrapper.set_input_data(image, pred_instances)
            boxam_detector_visualizer.switch_activations_and_grads(
                model_wrapper)

        # get box am image
        grayscale_boxam = boxam_detector_visualizer(image, targets=targets)

        # draw cam on image
        pred_instances = pred_instances.numpy()
        image_with_bounding_boxes = boxam_detector_visualizer.show_am(
            image,
            pred_instances,
            grayscale_boxam,
            with_norm_in_bboxes=args.norm_in_bbox)

        if source_type['is_dir']:
            filename = os.path.relpath(image_path, args.img).replace('/', '_')
        else:
            filename = os.path.basename(image_path)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        if out_file:
            mmcv.imwrite(image_with_bounding_boxes, out_file)
        else:
            cv2.namedWindow(filename, 0)
            cv2.imshow(filename, image_with_bounding_boxes)
            cv2.waitKey(0)

        # switch the wrapper back to inference mode for the next image
        if args.method in GRAD_BASED_METHOD_MAP:
            model_wrapper.need_loss(False)
            boxam_detector_visualizer.switch_activations_and_grads(
                model_wrapper)

        progress_bar.update()

    if not args.show:
        print(f'All done!'
              f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


if __name__ == '__main__':
    main()
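A typical invocation, assuming the YOLOv5-s config and checkpoint used in the docs below (the `gradcam++` method and the alternative target layer are just illustrative choices):

```shell
python demo/boxam_vis_demo.py \
        demo/dog.jpg \
        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
        --method gradcam++ \
        --target-layers 'neck.out_layers[1]'
```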
demo/featmap_vis_demo.py
@@ -13,7 +13,6 @@ from mmyolo.utils import register_all_modules
from mmyolo.utils.misc import auto_arrange_images, get_file_list


# TODO: Refine
def parse_args():
    parser = argparse.ArgumentParser(description='Visualize feature map')
    parser.add_argument(
@@ -190,8 +189,9 @@ def main():
    if args.show:
        visualizer.show(shown_imgs)

-    print(f'All done!'
-          f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
+    if not args.show:
+        print(f'All done!'
+              f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


# Please refer to the usage tutorial:
docs (en): visualization.md
@@ -1,5 +1,7 @@
# Visualization

+This article covers feature map visualization and Grad-Based and Grad-Free CAM visualization.
+
## Feature map visualization

<div align=center>
@@ -13,7 +15,7 @@ In MMYOLO, you can use the `Visualizer` provided in MMEngine for feature map vis
- Support basic drawing interfaces and feature map visualization.
- Support selecting different layers in the model to get the feature map. The display methods include `squeeze_mean`, `select_max`, and `topk`. Users can also customize the layout of the feature map display with `arrangement`.

-## Feature map generation
+### Feature map generation

You can use `demo/featmap_vis_demo.py` to get a quick view of the visualization results. To better understand all functions, we list all primary parameters and their features here as follows:

@@ -51,7 +53,7 @@ You can use `demo/featmap_vis_demo.py` to get a quick view of the visualization

**Note: When the image and feature map scales are different, the `draw_featmap` function will automatically perform an upsampling alignment. If your image has an operation such as `Pad` in the preprocessing during the inference, the feature map obtained is processed with `Pad`, which may cause misalignment problems if you directly upsample the image.**

-## Usage examples
+### Usage examples

Take the pre-trained YOLOv5-s model as an example. Please download the model weight file to the root directory.

@@ -88,7 +90,7 @@ The original `test_pipeline` is:
test_pipeline = [
    dict(
        type='LoadImageFromFile',
-        file_client_args={{_base_.file_client_args}}),
+        file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
@@ -166,7 +168,7 @@ python demo/featmap_vis_demo.py demo/dog.jpg \
```

<div align=center>
-<img src="https://user-images.githubusercontent.com/17425982/198522489-8adee6ae-9915-4e9d-bf50-167b8a12c275.png" width="1200" alt="image"/>
+<img src="https://user-images.githubusercontent.com/17425982/198522489-8adee6ae-9915-4e9d-bf50-167b8a12c275.png" width="800" alt="image"/>
</div>

(5) When the visualization process finishes, you can choose to display the result or store it locally. You only need to add the parameter `--out-file xxx.jpg`:

@@ -179,3 +181,113 @@ python demo/featmap_vis_demo.py demo/dog.jpg \
--channel-reduction select_max \
--out-file featmap_backbone.jpg
```

## Grad-Based and Grad-Free CAM Visualization

Object detection CAM visualization is much more complex than, and quite different from, classification CAM.
This article only briefly explains the usage; a separate document will describe the implementation principles and precautions in detail later.

You can call `demo/boxam_vis_demo.py` to get the AM visualization results at the Box level easily and quickly. Currently, `YOLOv5/YOLOv6/YOLOX/RTMDet` are supported.

Taking YOLOv5 as an example, as with the feature map visualization, you need to modify the `test_pipeline` first; otherwise there will be a misalignment problem between the feature map and the original image.

The original `test_pipeline` is:

```python
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]
```

Change to the following version:

```python
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=_base_.file_client_args),
    dict(type='mmdet.Resize', scale=img_scale, keep_ratio=False), # change the LetterResize to mmdet.Resize
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
```

(1) Use the `GradCAM` method to visualize the AM of the last output layer of the neck module

```shell
python demo/boxam_vis_demo.py \
        demo/dog.jpg \
        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth
```

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203775584-c4aebf11-4ff8-4530-85fe-7dda897e95a8.jpg" width="800" alt="image"/>
</div>

The corresponding feature AM is as follows:

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203774801-1555bcfb-a8f9-4688-8ed6-982d6ad38e1d.jpg" width="800" alt="image"/>
</div>

It can be seen that `GradCAM` can highlight the AM information at the box level.

You can choose to visualize only the prediction boxes with the highest prediction scores via the `--topk` parameter:

```shell
python demo/boxam_vis_demo.py \
        demo/dog.jpg \
        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
        --topk 2
```

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203778700-3165aa72-ecaf-40cc-b470-6911646e6046.jpg" width="800" alt="image"/>
</div>

(2) Use the `AblationCAM` method to visualize the AM of the last output layer of the neck module

```shell
python demo/boxam_vis_demo.py \
        demo/dog.jpg \
        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
        --method ablationcam
```

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203776978-b5a9b383-93b4-4b35-9e6a-7cac684b372c.jpg" width="800" alt="image"/>
</div>

Since `AblationCAM` is weighted by the contribution of each channel to the score, it cannot visualize only the AM information at the box level the way `GradCAM` does, but you can use `--norm-in-bbox` to show only the AM inside each bbox:

```shell
python demo/boxam_vis_demo.py \
        demo/dog.jpg \
        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
        --method ablationcam \
        --norm-in-bbox
```

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203777566-7c74e82f-b477-488e-958f-91e1d10833b9.jpg" width="800" alt="image"/>
</div>
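The demo wraps a small set of reusable pieces from `mmyolo.utils.boxam_utils`. As a reference, here is a minimal sketch of driving them directly from Python, mirroring the flow of `demo/boxam_vis_demo.py` (the output filename is an assumption; everything else comes from the demo script):

```python
from functools import partial

import cv2
from mmengine import Config, MessageHub

from mmyolo.utils import register_all_modules
from mmyolo.utils.boxam_utils import (BoxAMDetectorVisualizer,
                                      BoxAMDetectorWrapper, DetBoxScoreTarget,
                                      GradCAM, reshape_transform)

register_all_modules()
# Some algorithms read the epoch from the MessageHub when computing the loss.
MessageHub.get_current_instance().runtime_info['epoch'] = 0

cfg = Config.fromfile(
    'configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py')
model_wrapper = BoxAMDetectorWrapper(
    cfg,
    'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth',
    score_thr=0.3,
    device='cuda:0')
target_layers = [model_wrapper.detector.neck.out_layers[2]]

visualizer = BoxAMDetectorVisualizer(
    GradCAM,
    model_wrapper,
    target_layers,
    reshape_transform=partial(
        reshape_transform, max_shape=[-1], is_need_grad=True),
    is_need_grad=True)

image = cv2.imread('demo/dog.jpg')
model_wrapper.set_input_data(image)
pred_instances = model_wrapper()[0].pred_instances
pred_instances = pred_instances[pred_instances.scores > 0.3]

# Grad-based methods compute a loss, so the predictions act as pseudo gt.
targets = [
    DetBoxScoreTarget(
        pred_instances, device='cuda:0', ignore_loss_params=['loss_obj'])
]
model_wrapper.need_loss(True)
model_wrapper.set_input_data(image, pred_instances)
visualizer.switch_activations_and_grads(model_wrapper)

grayscale_am = visualizer(image, targets=targets)
am_image = visualizer.show_am(image, pred_instances.numpy(), grayscale_am)
cv2.imwrite('boxam_dog.jpg', am_image)  # output name is an assumption
```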
docs (zh_cn): visualization.md
@@ -1,5 +1,7 @@
# Visualization

+This article covers feature map visualization and Grad-Based and Grad-Free CAM visualization.
+
## Feature map visualization

<div align=center>
@@ -12,7 +14,7 @@ In MMYOLO, the `Visualizer` provided by MMEngine is used for feature map visualization
- Support basic drawing interfaces and feature map visualization.
- Support selecting different layers in the model to get the feature map, with three display modes: `squeeze_mean`, `select_max`, and `topk`. Users can also customize the layout of the feature map display with `arrangement`.

-## Feature map drawing
+### Feature map drawing

You can call `demo/featmap_vis_demo.py` to quickly obtain the visualization results. To make this easy to understand, the functions of its main parameters are summarized as follows:

@@ -50,7 +52,7 @@ In MMYOLO, the `Visualizer` provided by MMEngine is used for feature map visualization

**Note: When the image and feature map scales are different, the `draw_featmap` function will automatically perform an upsampling alignment. If your image has an operation such as `Pad` in the preprocessing during inference, the feature map obtained is also processed with `Pad`, so directly upsampling it may cause misalignment problems.**

-## Usage examples
+### Usage examples

Take the pre-trained YOLOv5-s model as an example:

@@ -167,7 +169,7 @@ python demo/featmap_vis_demo.py demo/dog.jpg \
```

<div align=center>
-<img src="https://user-images.githubusercontent.com/17425982/198522489-8adee6ae-9915-4e9d-bf50-167b8a12c275.png" width="1200" alt="image"/>
+<img src="https://user-images.githubusercontent.com/17425982/198522489-8adee6ae-9915-4e9d-bf50-167b8a12c275.png" width="800" alt="image"/>
</div>

(5) Store the drawn image: when drawing finishes, you can either show the result in a local window or save it locally by adding the parameter `--out-file xxx.jpg`:

@@ -180,3 +182,113 @@ python demo/featmap_vis_demo.py demo/dog.jpg \
--channel-reduction select_max \
--out-file featmap_backbone.jpg
```

## Grad-Based and Grad-Free CAM Visualization

Object detection CAM visualization is much more complex than, and quite different from, classification CAM. This article only briefly explains the usage; a separate document will describe the implementation principles and precautions in detail later.

You can call `demo/boxam_vis_demo.py` to easily and quickly get Box-level AM visualization results. Currently, `YOLOv5/YOLOv6/YOLOX/RTMDet` are supported.

Taking YOLOv5 as an example, as with feature map visualization, you need to modify the `test_pipeline` first; otherwise the feature map and the original image will be misaligned.

The original `test_pipeline` is:

```python
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]
```

Change to the following configuration:

```python
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=_base_.file_client_args),
    dict(type='mmdet.Resize', scale=img_scale, keep_ratio=False), # change the LetterResize to mmdet.Resize
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
```

(1) Use the `GradCAM` method to visualize the AM of the last output layer of the neck module

```shell
python demo/boxam_vis_demo.py \
        demo/dog.jpg \
        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth
```

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203775584-c4aebf11-4ff8-4530-85fe-7dda897e95a8.jpg" width="800" alt="image"/>
</div>

The corresponding feature AM is as follows:

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203774801-1555bcfb-a8f9-4688-8ed6-982d6ad38e1d.jpg" width="800" alt="image"/>
</div>

It can be seen that `GradCAM` can highlight the AM information at the box level.

You can use the `--topk` parameter to visualize only the prediction boxes with the highest prediction scores:

```shell
python demo/boxam_vis_demo.py \
        demo/dog.jpg \
        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
        --topk 2
```

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203778700-3165aa72-ecaf-40cc-b470-6911646e6046.jpg" width="800" alt="image"/>
</div>

(2) Use the `AblationCAM` method to visualize the AM of the last output layer of the neck module

```shell
python demo/boxam_vis_demo.py \
        demo/dog.jpg \
        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
        --method ablationcam
```

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203776978-b5a9b383-93b4-4b35-9e6a-7cac684b372c.jpg" width="800" alt="image"/>
</div>

Since `AblationCAM` weights by each channel's contribution to the score, it cannot visualize only box-level AM information the way `GradCAM` does, but you can use `--norm-in-bbox` to show only the AM inside each bbox:

```shell
python demo/boxam_vis_demo.py \
        demo/dog.jpg \
        configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
        yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
        --method ablationcam \
        --norm-in-bbox
```

<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/203777566-7c74e82f-b477-488e-958f-91e1d10833b9.jpg" width="800" alt="image"/>
</div>
mmyolo/models/dense_heads/yolov6_head.py
@@ -377,7 +377,7 @@ class YOLOv6Head(YOLOv5Head):
            loss_cls=loss_cls * world_size, loss_bbox=loss_bbox * world_size)

    @staticmethod
-    def gt_instances_preprocess(batch_gt_instances: Tensor,
+    def gt_instances_preprocess(batch_gt_instances: Union[Tensor, Sequence],
                                batch_size: int) -> Tensor:
        """Split batch_gt_instances with batch size, from [all_gt_bboxes, 6]
        to.

@@ -393,28 +393,51 @@ class YOLOv6Head(YOLOv5Head):
        Returns:
            Tensor: batch gt instances data, shape [batch_size, number_gt, 5]
        """
        if isinstance(batch_gt_instances, Sequence):
            max_gt_bbox_len = max(
                [len(gt_instances) for gt_instances in batch_gt_instances])
            # fill [-1., 0., 0., 0., 0.] if the shape of a
            # single batch is not equal to max_gt_bbox_len
            batch_instance_list = []
            for index, gt_instance in enumerate(batch_gt_instances):
                bboxes = gt_instance.bboxes
                labels = gt_instance.labels
                batch_instance_list.append(
                    torch.cat((labels[:, None], bboxes), dim=-1))

                if bboxes.shape[0] >= max_gt_bbox_len:
                    continue

                fill_tensor = bboxes.new_full(
                    [max_gt_bbox_len - bboxes.shape[0], 5], 0)
                fill_tensor[:, 0] = -1.
                batch_instance_list[index] = torch.cat(
                    (batch_instance_list[-1], fill_tensor), dim=0)

            return torch.stack(batch_instance_list)
        else:
            # faster version
            # split batch gt instance [all_gt_bboxes, 6] ->
            # [batch_size, number_gt_each_batch, 5]
            batch_instance_list = []
            max_gt_bbox_len = 0
            for i in range(batch_size):
                single_batch_instance = \
                    batch_gt_instances[batch_gt_instances[:, 0] == i, :]
                single_batch_instance = single_batch_instance[:, 1:]
                batch_instance_list.append(single_batch_instance)
                if len(single_batch_instance) > max_gt_bbox_len:
                    max_gt_bbox_len = len(single_batch_instance)

            # fill [-1., 0., 0., 0., 0.] if the shape of a
            # single batch is not equal to max_gt_bbox_len
            for index, gt_instance in enumerate(batch_instance_list):
                if gt_instance.shape[0] >= max_gt_bbox_len:
                    continue
                fill_tensor = batch_gt_instances.new_full(
                    [max_gt_bbox_len - gt_instance.shape[0], 5], 0)
                fill_tensor[:, 0] = -1.
                batch_instance_list[index] = torch.cat(
                    (batch_instance_list[index], fill_tensor), dim=0)

            return torch.stack(batch_instance_list)
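To make the shapes and the `[-1., 0., 0., 0., 0.]` padding concrete, here is a small sketch of the fast Tensor path with assumed toy values (the import path is an assumption):

```python
import torch

from mmyolo.models.dense_heads import YOLOv6Head

# Toy input: 3 gt boxes for a batch of 2 images; rows are
# [batch_idx, label, x1, y1, x2, y2] -> shape [all_gt_bboxes, 6].
batch_gt = torch.tensor([
    [0., 1., 10., 10., 20., 20.],
    [0., 2., 30., 30., 40., 40.],
    [1., 0., 5., 5., 15., 15.],
])

out = YOLOv6Head.gt_instances_preprocess(batch_gt, batch_size=2)
print(out.shape)  # torch.Size([2, 2, 5])
# Image 1 has only one gt box, so its second row is the
# [-1., 0., 0., 0., 0.] fill value.
print(out[1])
```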
mmyolo/utils/boxam_utils.py
@@ -0,0 +1,504 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import copy
import warnings
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision
from mmcv.transforms import Compose
from mmdet.evaluation import get_classes
from mmdet.models import build_detector
from mmdet.utils import ConfigType
from mmengine.config import Config
from mmengine.runner import load_checkpoint
from mmengine.structures import InstanceData
from torch import Tensor

try:
    from pytorch_grad_cam import (AblationCAM, AblationLayer,
                                  ActivationsAndGradients)
    from pytorch_grad_cam import GradCAM as Base_GradCAM
    from pytorch_grad_cam import GradCAMPlusPlus as Base_GradCAMPlusPlus
    from pytorch_grad_cam.base_cam import BaseCAM
    from pytorch_grad_cam.utils.image import scale_cam_image, show_cam_on_image
    from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
except ImportError:
    pass


def init_detector(
    config: Union[str, Path, Config],
    checkpoint: Optional[str] = None,
    palette: str = 'coco',
    device: str = 'cuda:0',
    cfg_options: Optional[dict] = None,
) -> nn.Module:
    """Initialize a detector from a config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        palette (str): Color palette used for visualization. If a palette
            is stored in the checkpoint, use the checkpoint's palette first,
            otherwise use the externally passed palette. Currently, supports
            'coco', 'voc', 'citys' and 'random'. Defaults to coco.
        device (str): The device where the anchors will be put on.
            Defaults to cuda:0.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    elif 'init_cfg' in config.model.backbone:
        config.model.backbone.init_cfg = None

    # The only change from the usual init_detector: train_cfg is kept
    # because the grad-based methods need it to compute the loss.
    # config.model.train_cfg = None

    model = build_detector(config.model)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # Weights converted from elsewhere may not have meta fields.
        checkpoint_meta = checkpoint.get('meta', {})
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in checkpoint_meta:
            # mmdet 3.x
            model.dataset_meta = checkpoint_meta['dataset_meta']
        elif 'CLASSES' in checkpoint_meta:
            # < mmdet 3.x
            classes = checkpoint_meta['CLASSES']
            model.dataset_meta = {'CLASSES': classes, 'PALETTE': palette}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, use COCO classes by default.')
            model.dataset_meta = {
                'CLASSES': get_classes('coco'),
                'PALETTE': palette
            }

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def reshape_transform(feats: Union[Tensor, List[Tensor]],
                      max_shape: Tuple[int, int] = (20, 20),
                      is_need_grad: bool = False):
    """Reshape and aggregate feature maps when the input is a multi-layer
    feature map.

    Takes these tensors with different sizes, resizes them to a common shape,
    and concatenates them.
    """
    if len(max_shape) == 1:
        max_shape = max_shape * 2

    if isinstance(feats, torch.Tensor):
        feats = [feats]
    else:
        if is_need_grad:
            raise NotImplementedError('The `grad_base` method does not '
                                      'support outputting multiple '
                                      'activation layers')

    max_h = max([im.shape[-2] for im in feats])
    max_w = max([im.shape[-1] for im in feats])
    if -1 in max_shape:
        max_shape = (max_h, max_w)
    else:
        max_shape = (min(max_h, max_shape[0]), min(max_w, max_shape[1]))

    activations = []
    for feat in feats:
        activations.append(
            torch.nn.functional.interpolate(
                torch.abs(feat), max_shape, mode='bilinear'))

    activations = torch.cat(activations, axis=1)
    return activations


class BoxAMDetectorWrapper(nn.Module):
    """Wrap the mmdet model class to facilitate handling of non-tensor
    situations during inference."""

    def __init__(self,
                 cfg: ConfigType,
                 checkpoint: str,
                 score_thr: float,
                 device: str = 'cuda:0'):
        super().__init__()
        self.cfg = cfg
        self.device = device
        self.score_thr = score_thr
        self.checkpoint = checkpoint
        self.detector = init_detector(self.cfg, self.checkpoint, device=device)

        pipeline_cfg = copy.deepcopy(self.cfg.test_dataloader.dataset.pipeline)
        pipeline_cfg[0].type = 'mmdet.LoadImageFromNDArray'

        new_test_pipeline = []
        for pipeline in pipeline_cfg:
            if not pipeline['type'].endswith('LoadAnnotations'):
                new_test_pipeline.append(pipeline)
        self.test_pipeline = Compose(new_test_pipeline)

        self.is_need_loss = False
        self.input_data = None
        self.image = None

    def need_loss(self, is_need_loss: bool):
        """Grad-based methods require the loss."""
        self.is_need_loss = is_need_loss

    def set_input_data(self,
                       image: np.ndarray,
                       pred_instances: Optional[InstanceData] = None):
        """Set the input data to be used in the next step."""
        self.image = image

        if self.is_need_loss:
            assert pred_instances is not None
            pred_instances = pred_instances.numpy()
            data = dict(
                img=self.image,
                img_id=0,
                gt_bboxes=pred_instances.bboxes,
                gt_bboxes_labels=pred_instances.labels)
            data = self.test_pipeline(data)
        else:
            data = dict(img=self.image, img_id=0)
            data = self.test_pipeline(data)
            data['inputs'] = [data['inputs']]
            data['data_samples'] = [data['data_samples']]
        self.input_data = data

    def __call__(self, *args, **kwargs):
        assert self.input_data is not None
        if self.is_need_loss:
            # Maybe this is a direction that can be optimized:
            # self.detector.init_weights()

            if hasattr(self.detector.bbox_head, 'featmap_sizes'):
                # Prevent a model algorithm error when calculating the loss
                self.detector.bbox_head.featmap_sizes = None

            data_ = {}
            data_['inputs'] = [self.input_data['inputs']]
            data_['data_samples'] = [self.input_data['data_samples']]
            data = self.detector.data_preprocessor(data_, training=False)
            loss = self.detector._run_forward(data, mode='loss')

            if hasattr(self.detector.bbox_head, 'featmap_sizes'):
                self.detector.bbox_head.featmap_sizes = None

            return [loss]
        else:
            with torch.no_grad():
                results = self.detector.test_step(self.input_data)
                return results


class BoxAMDetectorVisualizer:
    """Box AM visualization class."""

    def __init__(self,
                 method_class,
                 model: nn.Module,
                 target_layers: List,
                 reshape_transform: Optional[Callable] = None,
                 is_need_grad: bool = False,
                 extra_params: Optional[dict] = None):
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.is_need_grad = is_need_grad

        if method_class.__name__ == 'AblationCAM':
            batch_size = extra_params.get('batch_size', 1)
            ratio_channels_to_ablate = extra_params.get(
                'ratio_channels_to_ablate', 1.)
            self.cam = AblationCAM(
                model,
                target_layers,
                use_cuda=True if 'cuda' in model.device else False,
                reshape_transform=reshape_transform,
                batch_size=batch_size,
                ablation_layer=extra_params['ablation_layer'],
                ratio_channels_to_ablate=ratio_channels_to_ablate)
        else:
            self.cam = method_class(
                model,
                target_layers,
                use_cuda=True if 'cuda' in model.device else False,
                reshape_transform=reshape_transform,
            )
            if self.is_need_grad:
                self.cam.activations_and_grads.release()

        self.classes = model.detector.dataset_meta['CLASSES']
        self.COLORS = np.random.uniform(0, 255, size=(len(self.classes), 3))

    def switch_activations_and_grads(self, model) -> None:
        """In the grad-based method, we need to switch the
        ``ActivationsAndGradients`` layer, otherwise an error will occur."""
        self.cam.model = model

        if self.is_need_grad is True:
            self.cam.activations_and_grads = ActivationsAndGradients(
                model, self.target_layers, self.reshape_transform)
            self.is_need_grad = False
        else:
            self.cam.activations_and_grads.release()
            self.is_need_grad = True

    def __call__(self, img, targets, aug_smooth=False, eigen_smooth=False):
        img = torch.from_numpy(img)[None].permute(0, 3, 1, 2)
        return self.cam(img, targets, aug_smooth, eigen_smooth)[0, :]

    def show_am(self,
                image: np.ndarray,
                pred_instance: InstanceData,
                grayscale_am: np.ndarray,
                with_norm_in_bboxes: bool = False):
        """Draw the AM heatmap on the image. If ``with_norm_in_bboxes`` is
        True, normalize the AM to the range [0, 1] inside every bounding box
        and zero it outside of the bounding boxes."""
        boxes = pred_instance.bboxes
        labels = pred_instance.labels

        if with_norm_in_bboxes is True:
            boxes = boxes.astype(np.int32)
            renormalized_am = np.zeros(grayscale_am.shape, dtype=np.float32)
            images = []
            for x1, y1, x2, y2 in boxes:
                img = renormalized_am * 0
                img[y1:y2, x1:x2] = scale_cam_image(
                    [grayscale_am[y1:y2, x1:x2].copy()])[0]
                images.append(img)

            renormalized_am = np.max(np.float32(images), axis=0)
            renormalized_am = scale_cam_image([renormalized_am])[0]
        else:
            renormalized_am = grayscale_am

        am_image_renormalized = show_cam_on_image(
            image / 255, renormalized_am, use_rgb=False)

        image_with_bounding_boxes = self._draw_boxes(
            boxes, labels, am_image_renormalized, pred_instance.get('scores'))
        return image_with_bounding_boxes

    def _draw_boxes(self,
                    boxes: List,
                    labels: List,
                    image: np.ndarray,
                    scores: Optional[List] = None):
        """Draw boxes and class labels on the image."""
        for i, box in enumerate(boxes):
            label = labels[i]
            color = self.COLORS[label]
            cv2.rectangle(image, (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])), color, 2)
            if scores is not None:
                score = scores[i]
                text = str(self.classes[label]) + ': ' + str(
                    round(score * 100, 1))
            else:
                text = self.classes[label]

            cv2.putText(
                image,
                text, (int(box[0]), int(box[1] - 5)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                color,
                1,
                lineType=cv2.LINE_AA)
        return image


class DetAblationLayer(AblationLayer):
    """Det AblationLayer."""

    def __init__(self):
        super().__init__()
        self.activations = None

    def set_next_batch(self, input_batch_index, activations,
                       num_channels_to_ablate):
        """Extract the next batch member from activations, and repeat it
        num_channels_to_ablate times."""
        if isinstance(activations, torch.Tensor):
            return super().set_next_batch(input_batch_index, activations,
                                          num_channels_to_ablate)

        self.activations = []
        for activation in activations:
            activation = activation[
                input_batch_index, :, :, :].clone().unsqueeze(0)
            self.activations.append(
                activation.repeat(num_channels_to_ablate, 1, 1, 1))

    def __call__(self, x):
        """Go over the activation indices to be ablated, stored in
        self.indices."""
        result = self.activations

        if isinstance(result, torch.Tensor):
            return super().__call__(x)

        channel_cumsum = np.cumsum([r.shape[1] for r in result])
        num_channels_to_ablate = result[0].size(0)  # batch
        for i in range(num_channels_to_ablate):
            pyramid_layer = bisect.bisect_right(channel_cumsum,
                                                self.indices[i])
            if pyramid_layer > 0:
                index_in_pyramid_layer = self.indices[i] - channel_cumsum[
                    pyramid_layer - 1]
            else:
                index_in_pyramid_layer = self.indices[i]
            result[pyramid_layer][i, index_in_pyramid_layer, :, :] = -1000
        return result


class DetBoxScoreTarget:
    """Det score calculation class.

    In the case of the grad-free method, the calculation method is:
    for every original detected bounding box specified in "bboxes",
    assign a score based on how well the current bounding boxes match it,

    1. in bbox IoU,
    2. in the classification score,
    3. in mask IoU if ``segms`` exist.

    If there is not a large enough overlap, or the category has changed,
    assign a score of 0. The total score is the sum of all the box scores.

    In the case of the grad-based method, the calculation method is
    the sum of the losses after excluding specific keys.
    """

    def __init__(self,
                 pred_instance: InstanceData,
                 match_iou_thr: float = 0.5,
                 device: str = 'cuda:0',
                 ignore_loss_params: Optional[List] = None):
        self.focal_bboxes = pred_instance.bboxes
        self.focal_labels = pred_instance.labels
        self.match_iou_thr = match_iou_thr
        self.device = device
        self.ignore_loss_params = ignore_loss_params
        if ignore_loss_params is not None:
            assert isinstance(self.ignore_loss_params, list)

    def __call__(self, results):
        output = torch.tensor([0.], device=self.device)

        if 'loss_cls' in results:
            # grad-based method
            # results is a dict of losses
            for loss_key, loss_value in results.items():
                if 'loss' not in loss_key or (
                        self.ignore_loss_params is not None
                        and loss_key in self.ignore_loss_params):
                    continue
                if isinstance(loss_value, list):
                    output += sum(loss_value)
                else:
                    output += loss_value
            return output
        else:
            # grad-free method
            # results is a DetDataSample
            pred_instances = results.pred_instances
            if len(pred_instances) == 0:
                return output

            pred_bboxes = pred_instances.bboxes
            pred_scores = pred_instances.scores
            pred_labels = pred_instances.labels

            for focal_box, focal_label in zip(self.focal_bboxes,
                                              self.focal_labels):
                ious = torchvision.ops.box_iou(focal_box[None],
                                               pred_bboxes[..., :4])
                index = ious.argmax()
                if ious[0, index] > self.match_iou_thr and pred_labels[
                        index] == focal_label:
                    # TODO: Adaptive adjustment of weights based on algorithms
                    score = ious[0, index] + pred_scores[index]
                    output = output + score
            return output


class SpatialBaseCAM(BaseCAM):
    """CAM that maintains spatial information.

    Gradients are often averaged over the spatial dimension in CAM
    visualization for classification, but this is unreasonable in detection
    tasks: there is no need to average the gradients in the detection task.
    """

    def get_cam_image(self,
                      input_tensor: torch.Tensor,
                      target_layer: torch.nn.Module,
                      targets: List[torch.nn.Module],
                      activations: torch.Tensor,
                      grads: torch.Tensor,
                      eigen_smooth: bool = False) -> np.ndarray:

        weights = self.get_cam_weights(input_tensor, target_layer, targets,
                                       activations, grads)
        weighted_activations = weights * activations
        if eigen_smooth:
            cam = get_2d_projection(weighted_activations)
        else:
            cam = weighted_activations.sum(axis=1)
        return cam


class GradCAM(SpatialBaseCAM, Base_GradCAM):
    """Gradients are no longer averaged over the spatial dimension."""

    def get_cam_weights(self, input_tensor, target_layer, target_category,
                        activations, grads):
        return grads


class GradCAMPlusPlus(SpatialBaseCAM, Base_GradCAMPlusPlus):
    """Gradients are no longer averaged over the spatial dimension."""

    def get_cam_weights(self, input_tensor, target_layers, target_category,
                        activations, grads):
        grads_power_2 = grads**2
        grads_power_3 = grads_power_2 * grads
        # Equation 19 in https://arxiv.org/abs/1710.11063
        sum_activations = np.sum(activations, axis=(2, 3))
        eps = 0.000001
        aij = grads_power_2 / (
            2 * grads_power_2 +
            sum_activations[:, :, None, None] * grads_power_3 + eps)
        # Now bring back the ReLU from eq.7 in the paper,
        # and zero out aijs where the activations are 0
        aij = np.where(grads != 0, aij, 0)

        weights = np.maximum(grads, 0) * aij
        return weights
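As a quick sanity check of `reshape_transform`, here is a small sketch with assumed FPN-style shapes: the three feature levels are resampled to the common shape capped by `max_shape` and concatenated along the channel dimension.

```python
import torch

from mmyolo.utils.boxam_utils import reshape_transform

# Assumed toy multi-level feature maps (batch 1, FPN-like shapes).
feats = [
    torch.randn(1, 256, 80, 80),
    torch.randn(1, 512, 40, 40),
    torch.randn(1, 1024, 20, 20),
]

out = reshape_transform(feats, max_shape=(20, 20), is_need_grad=False)
print(out.shape)  # torch.Size([1, 1792, 20, 20]) -> 256 + 512 + 1024 channels
```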