# Copyright (c) OpenMMLab. All rights reserved.
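# Example usage (the image, config and checkpoint paths below are
# illustrative; substitute your own files):
#   python vis_cam.py demo/bird.JPEG \
#       configs/resnet/resnet50_8xb32_in1k.py resnet50.pth \
#       --method GradCAM --save-path cam.jpg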
import argparse
import copy
import math
import pkg_resources
from functools import partial
from pathlib import Path

import mmcv
import numpy as np
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.config import Config, DictAction
from mmengine.dataset import default_collate
from mmengine.utils import to_2tuple
from mmengine.utils.dl_utils import is_norm

from mmpretrain import digit_version
from mmpretrain.apis import get_model
from mmpretrain.registry import TRANSFORMS

try:
    import pytorch_grad_cam as cam
    from pytorch_grad_cam.activations_and_gradients import \
        ActivationsAndGradients
    from pytorch_grad_cam.utils.image import show_cam_on_image
except ImportError:
    raise ImportError('Please run `pip install "grad-cam>=1.3.6"` to install '
                      'the 3rd-party package pytorch_grad_cam.')

# Map the lower-cased names of all BaseCAM subclasses (plus the
# 'gradcam++' alias) to their classes.
METHOD_MAP = {
    'gradcam++': cam.GradCAMPlusPlus,
}
METHOD_MAP.update({
    cam_class.__name__.lower(): cam_class
    for cam_class in cam.base_cam.BaseCAM.__subclasses__()
})
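# With grad-cam 1.3.x this typically exposes keys such as 'gradcam',
# 'gradcam++', 'scorecam', 'xgradcam', 'eigencam' and 'layercam'; the exact
# set depends on the installed grad-cam version.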


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize CAM')
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--target-layers',
        default=[],
        nargs='+',
        type=str,
        help='The target layers to get CAM from. If not set, the tool will '
        'use the norm layer in the last block. For backbones implemented '
        'by users, it is recommended to specify the target layers '
        'manually on the command line.')
    parser.add_argument(
        '--preview-model',
        default=False,
        action='store_true',
        help='Preview all the model layers')
    parser.add_argument(
        '--method',
        default='GradCAM',
        help='Type of method to use, supports '
        f'{", ".join(list(METHOD_MAP.keys()))}.')
    parser.add_argument(
        '--target-category',
        default=[],
        nargs='+',
        type=int,
        help='The target category to get CAM for. Defaults to the category '
        'predicted by the given model.')
    parser.add_argument(
        '--eigen-smooth',
        default=False,
        action='store_true',
        help='Reduce noise by taking the first principal component of '
        '``cam_weights*activations``')
    parser.add_argument(
        '--aug-smooth',
        default=False,
        action='store_true',
        help='Whether to use test-time augmentation. Defaults to False.')
    parser.add_argument(
        '--save-path',
        type=Path,
        help='The path to save the visualized CAM image. Defaults to not '
        'saving.')
    parser.add_argument('--device', default='cpu', help='Device to use.')
    parser.add_argument(
        '--vit-like',
        action='store_true',
        help='Whether the network is a ViT-like network.')
    parser.add_argument(
        '--num-extra-tokens',
        type=int,
        help='The number of extra tokens in ViT-like backbones. Defaults to '
        'the num_extra_tokens attribute of the backbone.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config. The key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    if args.method.lower() not in METHOD_MAP.keys():
        raise ValueError(f'Invalid CAM type {args.method},'
                         f' supports {", ".join(list(METHOD_MAP.keys()))}.')

    return args
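

# For ViT-like backbones, pass `--vit-like` so the (B, L, C) token output
# can be reshaped into a (B, C, H, W) feature map, e.g. (paths are
# illustrative):
#   python vis_cam.py demo/bird.JPEG configs/vit/vit-base-p16.py vit.pth \
#       --vit-like --method GradCAM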


def reshape_transform(tensor, model, args):
    """Build reshape_transform for `cam.activations_and_grads`, which is
    necessary for ViT-like networks."""
    # ViT-like transformers carry additional cls/extra tokens in the features.
    if tensor.ndim == 4:
        # For (B, C, H, W)
        return tensor
    elif tensor.ndim == 3:
        if not args.vit_like:
            raise ValueError(f"The tensor shape is {tensor.shape}, if it's a "
                             'vit-like backbone, please specify `--vit-like`.')
        # For (B, L, C)
        num_extra_tokens = args.num_extra_tokens or getattr(
            model.backbone, 'num_extra_tokens', 1)

        # Strip the extra tokens (e.g. the cls token) before reshaping.
        tensor = tensor[:, num_extra_tokens:, :]
        # Get the heat map height and width; the input is assumed to be
        # square.
        heat_map_area = tensor.size()[1]
        height, width = to_2tuple(int(math.sqrt(heat_map_area)))
        assert height * width == heat_map_area, \
            (f"The input feature's length ({heat_map_area+num_extra_tokens}) "
             f'minus num-extra-tokens ({num_extra_tokens}) is {heat_map_area},'
             ' which is not a perfect square number. Please check if you used '
             'a wrong num-extra-tokens.')
        # (B, L, C) -> (B, H, W, C)
        result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
        # (B, H, W, C) -> (B, C, H, W)
        result = result.permute(0, 3, 1, 2)
        return result
    else:
        raise ValueError(f'Unsupported tensor shape {tensor.shape}.')
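

# As a concrete example: a ViT-B/16 backbone on a 224x224 input produces
# tokens of shape (B, 197, 768): one cls token plus 14 * 14 = 196 patch
# tokens, so reshape_transform returns a (B, 768, 14, 14) feature map.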


def init_cam(method, model, target_layers, use_cuda, reshape_transform):
    """Construct the CAM object once. To be compatible with mmpretrain,
    here we modify the ActivationsAndGradients object."""
    GradCAM_Class = METHOD_MAP[method.lower()]
    cam = GradCAM_Class(
        model=model, target_layers=target_layers, use_cuda=use_cuda)
    # Release the original hooks in ActivationsAndGradients, then replace it
    # with one that applies our reshape_transform.
    cam.activations_and_grads.release()
    cam.activations_and_grads = ActivationsAndGradients(
        cam.model, cam.target_layers, reshape_transform)

    return cam


def get_layer(layer_str, model):
    """Get the model layer from the given string."""
    for name, layer in model.named_modules():
        if name == layer_str:
            return layer
    raise AttributeError(
        f'Cannot get the layer "{layer_str}". Please choose from: \n' +
        '\n'.join(name for name, _ in model.named_modules()))
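

# Example (the layer name is illustrative; it depends on the model):
#   get_layer('backbone.layer4.2.bn3', model)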


def show_cam_grad(grayscale_cam, src_img, title, out_path=None):
    """Fuse src_img and grayscale_cam, then show or save the result."""
    grayscale_cam = grayscale_cam[0, :]
    src_img = np.float32(src_img) / 255
    visualization_img = show_cam_on_image(
        src_img, grayscale_cam, use_rgb=False)

    if out_path:
        mmcv.imwrite(visualization_img, str(out_path))
    else:
        mmcv.imshow(visualization_img, win_name=title)


def get_default_target_layers(model, args):
    """Get the default target layers from the given model; norm-type layers
    are chosen as the default."""
    norm_layers = [
        (name, layer)
        for name, layer in model.backbone.named_modules(prefix='backbone')
        if is_norm(layer)
    ]
    if args.vit_like:
        # For ViT models, the final classification is done on the class
        # token, and the patch tokens and class token won't interact with
        # each other after the final attention layer. Therefore, we need to
        # choose a norm layer before the last attention layer.
        num_extra_tokens = args.num_extra_tokens or getattr(
            model.backbone, 'num_extra_tokens', 1)

        # Models like Swin have no 'out_type' attribute; default to
        # 'avg_featmap'.
        out_type = getattr(model.backbone, 'out_type', 'avg_featmap')
        if out_type == 'cls_token' or num_extra_tokens > 0:
            # Assume the backbone feature is the class token.
            name, layer = norm_layers[-3]
            print('Automatically choose the last norm layer before the '
                  f'final attention block "{name}" as the target layer.')
            return [layer]

    # For CNN models, use the last norm layer as the target layer.
    name, layer = norm_layers[-1]
    print('Automatically choose the last norm layer '
          f'"{name}" as the target layer.')
    return [layer]


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # build the model from a config file and a checkpoint file
    model: nn.Module = get_model(cfg, args.checkpoint, device=args.device)
    if args.preview_model:
        print(model)
        print('\n Please remove `--preview-model` to get the CAM.')
        return

    # apply the test pipeline and prepare the data
    transforms = Compose(
        [TRANSFORMS.build(t) for t in cfg.test_dataloader.dataset.pipeline])
    data = transforms({'img_path': args.img})
    src_img = copy.deepcopy(data['inputs']).numpy().transpose(1, 2, 0)
    data = model.data_preprocessor(default_collate([data]), False)

    # build the target layers
    if args.target_layers:
        target_layers = [
            get_layer(layer, model) for layer in args.target_layers
        ]
    else:
        target_layers = get_default_target_layers(model, args)

    # init a CAM grad calculator
    use_cuda = ('cuda' in args.device)
    cam = init_cam(args.method, model, target_layers, use_cuda,
                   partial(reshape_transform, model=model, args=args))

    # wrap the target_category with ClassifierOutputTarget in grad_cam>=1.3.7
    # to fix the bug in #654.
    targets = None
    if args.target_category:
        grad_cam_v = pkg_resources.get_distribution('grad_cam').version
        if digit_version(grad_cam_v) >= digit_version('1.3.7'):
            from pytorch_grad_cam.utils.model_targets import \
                ClassifierOutputTarget
            targets = [ClassifierOutputTarget(c) for c in args.target_category]
        else:
            targets = args.target_category

    # calculate the CAM grads and show or save the visualization image
    grayscale_cam = cam(
        data['inputs'],
        targets,
        eigen_smooth=args.eigen_smooth,
        aug_smooth=args.aug_smooth)
    show_cam_grad(
        grayscale_cam, src_img, title=args.method, out_path=args.save_path)


if __name__ == '__main__':
    main()
|