# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import math
import re
from pathlib import Path

import mmcv
import numpy as np
from mmcv import Config, DictAction
from mmcv.utils import to_2tuple
from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm

from mmcls.apis import init_model
from mmcls.datasets.pipelines import Compose

try:
    from pytorch_grad_cam import (EigenCAM, GradCAM, GradCAMPlusPlus, XGradCAM,
                                  EigenGradCAM, LayerCAM)
    from pytorch_grad_cam.activations_and_gradients import (
        ActivationsAndGradients)
    from pytorch_grad_cam.utils.image import show_cam_on_image
except ImportError:
    raise ImportError('Please run `pip install "grad-cam>=1.3.6"` to install '
                      '3rd party package pytorch_grad_cam.')

# set of transforms, which just change data format, not change the pictures
FORMAT_TRANSFORMS_SET = {'ToTensor', 'Normalize', 'ImageToTensor', 'Collect'}

# Supported grad-cam type map
METHOD_MAP = {
    'gradcam': GradCAM,
    'gradcam++': GradCAMPlusPlus,
    'xgradcam': XGradCAM,
    'eigencam': EigenCAM,
    'eigengradcam': EigenGradCAM,
    'layercam': LayerCAM,
}


def parse_args():
    """Parse and validate the command line arguments.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If ``--method`` (compared case-insensitively) is not a
            key of ``METHOD_MAP``.
    """
    parser = argparse.ArgumentParser(description='Visualize CAM')
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--target-layers',
        default=[],
        nargs='+',
        type=str,
        help='The target layers to get CAM, if not set, the tool will '
        'specify the norm layer in the last block. Backbones '
        'implemented by users are recommended to manually specify'
        ' target layers in commmad statement.')
    parser.add_argument(
        '--preview-model',
        default=False,
        action='store_true',
        help='To preview all the model layers')
    parser.add_argument(
        '--method',
        default='GradCAM',
        help='Type of method to use, supports '
        f'{", ".join(list(METHOD_MAP.keys()))}.')
    parser.add_argument(
        '--target-category',
        default=None,
        type=int,
        help='The target category to get CAM, default to use result '
        'get from given model.')
    parser.add_argument(
        '--eigen-smooth',
        default=False,
        action='store_true',
        help='Reduce noise by taking the first principle componenet of '
        '``cam_weights*activations``')
    parser.add_argument(
        '--aug-smooth',
        default=False,
        action='store_true',
        help='Wether to use test time augmentation, default not to use')
    parser.add_argument(
        '--save-path',
        type=Path,
        help='The path to save visualize cam image, default not to save.')
    parser.add_argument('--device', default='cpu', help='Device to use cpu')
    parser.add_argument(
        '--vit-like',
        action='store_true',
        help='Whether the network is a ViT-like network.')
    parser.add_argument(
        '--num-extra-tokens',
        type=int,
        help='The number of extra tokens in ViT-like backbones. Defaults to'
        ' use num_extra_tokens of the backbone.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    # The method name is matched case-insensitively against METHOD_MAP.
    if args.method.lower() not in METHOD_MAP.keys():
        raise ValueError(f'invalid CAM type {args.method},'
                         f' supports {", ".join(list(METHOD_MAP.keys()))}.')

    return args


def build_reshape_transform(model, args):
    """Build the ``reshape_transform`` callable handed to the CAM object.

    Args:
        model: The classifier; ``model.backbone.num_extra_tokens`` is read
            when ``--num-extra-tokens`` is not given.
        args (argparse.Namespace): Parsed arguments; ``vit_like`` and
            ``num_extra_tokens`` are read.

    Returns:
        callable: For non ViT-like models, a function that only asserts the
        feature is not 3-dimensional and returns it unchanged. For ViT-like
        models, a function that drops the extra tokens and reshapes the
        remaining tokens into a square, channel-first feature map.
    """
    # ViT_based_Transformers have an additional clstoken in features
    if not args.vit_like:

        def check_shape(tensor):
            assert len(tensor.size()) != 3, \
                (f"The input feature's shape is {tensor.size()}, and it seems "
                 'to have been flattened or from a vit-like network. '
                 "Please use `--vit-like` if it's from a vit-like network.")
            return tensor

        return check_shape

    # Priority: command line value > backbone attribute > 1.
    if args.num_extra_tokens is not None:
        num_extra_tokens = args.num_extra_tokens
    elif hasattr(model.backbone, 'num_extra_tokens'):
        num_extra_tokens = model.backbone.num_extra_tokens
    else:
        num_extra_tokens = 1

    def _reshape_transform(tensor):
        """reshape_transform helper."""
        assert len(tensor.size()) == 3, \
            (f"The input feature's shape is {tensor.size()}, "
             'and the feature seems not from a vit-like network?')
        # Drop the leading extra tokens along dim 1.
        tensor = tensor[:, num_extra_tokens:, :]
        # get heat_map_height and heat_map_width, preset input is a square
        heat_map_area = tensor.size()[1]
        height, width = to_2tuple(int(math.sqrt(heat_map_area)))
        assert height * height == heat_map_area, \
            (f"The input feature's length ({heat_map_area+num_extra_tokens}) "
             f'minus num-extra-tokens ({num_extra_tokens}) is {heat_map_area},'
             ' which is not a perfect square number. Please check if you used '
             'a wrong num-extra-tokens.')
        result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))

        # Bring the channels to the first dimension, like in CNNs.
        result = result.transpose(2, 3).transpose(1, 2)
        return result

    return _reshape_transform


def apply_transforms(img_path, pipeline_cfg):
    """Apply the transforms pipeline and get both the formatted data and the
    image before formatting.

    Args:
        img_path (str): Path of the image to load.
        pipeline_cfg (list[dict]): Transform configs. It is not modified;
            a ``LoadImageFromFile`` step is prepended to a copy when the
            first transform is of another type.

    Returns:
        tuple: ``(format_data, inference_img)`` where ``format_data`` is the
        output of the whole pipeline and ``inference_img`` is a deep copy of
        the ``'img'`` entry taken before the format transforms run.
    """
    data = dict(img_info=dict(filename=img_path), img_prefix=None)

    def split_pipeline_cfg(pipeline_cfg):
        """Split the transforms into image_transforms and
        format_transforms."""
        # Work on a shallow copy so the caller's config list is untouched.
        pipeline_cfg = list(pipeline_cfg)
        image_transforms_cfg, format_transforms_cfg = [], []
        if pipeline_cfg[0]['type'] != 'LoadImageFromFile':
            pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        for transform in pipeline_cfg:
            if transform['type'] in FORMAT_TRANSFORMS_SET:
                format_transforms_cfg.append(transform)
            else:
                image_transforms_cfg.append(transform)
        return image_transforms_cfg, format_transforms_cfg

    image_transforms, format_transforms = split_pipeline_cfg(pipeline_cfg)
    image_transforms = Compose(image_transforms)
    format_transforms = Compose(format_transforms)

    intermediate_data = image_transforms(data)
    # Copy before formatting so later transforms cannot alter this image.
    inference_img = copy.deepcopy(intermediate_data['img'])
    format_data = format_transforms(intermediate_data)

    return format_data, inference_img


class MMActivationsAndGradients(ActivationsAndGradients):
    """Activations and gradients manager for mmcls models."""

    def __call__(self, x):
        # Reset the recorded values, then run the model with the mmcls
        # specific keyword arguments used below.
        self.gradients = []
        self.activations = []
        return self.model(
            x, return_loss=False, softmax=False, post_process=False)


def init_cam(method, model, target_layers, use_cuda, reshape_transform):
    """Construct the CAM object once.

    In order to be compatible with mmcls, the ``ActivationsAndGradients``
    object created by the CAM class is replaced with
    ``MMActivationsAndGradients``.

    Args:
        method (str): Key of ``METHOD_MAP`` (case-insensitive).
        model: The model to visualize.
        target_layers (list): Layers passed to the CAM class.
        use_cuda (bool): Forwarded to the CAM class.
        reshape_transform (callable): Forwarded to
            ``MMActivationsAndGradients``.

    Returns:
        The constructed CAM object.
    """
    GradCAM_Class = METHOD_MAP[method.lower()]
    cam = GradCAM_Class(
        model=model, target_layers=target_layers, use_cuda=use_cuda)
    # Release the original hooks in ActivationsAndGradients to use
    # MMActivationsAndGradients.
    cam.activations_and_grads.release()
    cam.activations_and_grads = MMActivationsAndGradients(
        cam.model, cam.target_layers, reshape_transform)

    return cam


def get_layer(layer_str, model):
    """Get a model layer from a dotted string such as ``a.b[0].c``.

    Args:
        layer_str (str): Dot separated attribute path; each part may carry
            one or more trailing ``[...]`` index expressions.
        model: The object the path is resolved against.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: If an attribute or index in the path is missing.
    """
    cur_layer = model
    layer_names = layer_str.strip().split('.')

    def get_children_by_name(model, name):
        try:
            return getattr(model, name)
        except AttributeError as e:
            raise AttributeError(
                e.args[0] +
                '. Please use `--preview-model` to check keys at first.'
            ) from e

    def get_children_by_eval(model, name):
        # NOTE(security): `name` comes from the command line and is passed
        # to `eval`. Only `model` is exposed as a name, but this must not be
        # fed untrusted input.
        try:
            return eval(f'model{name}', {}, {'model': model})
        except (AttributeError, IndexError) as e:
            raise AttributeError(
                e.args[0] +
                '. Please use `--preview-model` to check keys at first.'
            ) from e

    for layer_name in layer_names:
        # Split `name[idx]...` into the attribute name and its indices.
        match_res = re.match(r'(?P<name>.+?)(?P<indices>(\[.+\])+)',
                             layer_name)
        if match_res:
            layer_name = match_res.groupdict()['name']
            indices = match_res.groupdict()['indices']
            cur_layer = get_children_by_name(cur_layer, layer_name)
            cur_layer = get_children_by_eval(cur_layer, indices)
        else:
            cur_layer = get_children_by_name(cur_layer, layer_name)

    return cur_layer


def show_cam_grad(grayscale_cam, src_img, title, out_path=None):
    """Fuse ``src_img`` and ``grayscale_cam``, then show or save the result.

    Args:
        grayscale_cam: CAM result; only its first entry along dim 0 is used.
        src_img: Source image, divided by 255 before fusing (presumably in
            the 0-255 range; confirm against the pipeline output).
        title (str): Window name used when the image is shown.
        out_path (Path, optional): If set, the image is written there
            instead of being shown.
    """
    grayscale_cam = grayscale_cam[0, :]
    src_img = np.float32(src_img) / 255
    visualization_img = show_cam_on_image(
        src_img, grayscale_cam, use_rgb=False)

    if out_path:
        mmcv.imwrite(visualization_img, str(out_path))
    else:
        mmcv.imshow(visualization_img, win_name=title)


def get_default_traget_layers(model, args):
    """Get default target layers from the given model; a norm type layer is
    chosen as the default target layer.

    Args:
        model: The classifier; ``model.backbone`` is searched for norm
            layers.
        args (argparse.Namespace): Parsed arguments; ``vit_like`` and
            ``num_extra_tokens`` are read.

    Returns:
        list: A single-element list holding the chosen norm layer.

    Raises:
        ValueError: If the backbone contains no norm layer.
        AttributeError: If the model is ViT-like and the number of extra
            tokens is neither given nor available on the backbone.
    """
    norm_layers = [
        m for m in model.backbone.modules()
        if isinstance(m, (BatchNorm2d, LayerNorm, GroupNorm, BatchNorm1d))
    ]
    if len(norm_layers) == 0:
        raise ValueError(
            '`--target-layers` is empty. Please use `--preview-model`'
            ' to check keys at first and then specify `target-layers`.')
    # if the model is CNN model or Swin model, just use the last norm
    # layer as the target-layer, if the model is ViT model, the final
    # classification is done on the class token computed in the last
    # attention block, the output will not be affected by the 14x14
    # channels in the last layer. The gradient of the output with
    # respect to them, will be 0! here use the last 3rd norm layer.
    # means the first norm of the last decoder block.
    if args.vit_like:
        # `is not None` so that an explicit `--num-extra-tokens 0` is
        # honoured instead of being treated as "not set".
        if args.num_extra_tokens is not None:
            num_extra_tokens = args.num_extra_tokens
        elif hasattr(model.backbone, 'num_extra_tokens'):
            num_extra_tokens = model.backbone.num_extra_tokens
        else:
            raise AttributeError('Please set num_extra_tokens in backbone'
                                 " or using 'num-extra-tokens'")

        # if a vit-like backbone's num_extra_tokens bigger than 0, view it
        # as a VisionTransformer backbone, eg. DeiT, T2T-ViT.
        if num_extra_tokens >= 1:
            print('Automatically choose the last norm layer before the '
                  'final attention block as target_layer..')
            return [norm_layers[-3]]

    print('Automatically choose the last norm layer as target_layer.')
    target_layers = [norm_layers[-1]]
    return target_layers


def main():
    """Run the CAM visualization for the image given on the command line."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # build the model from a config file and a checkpoint file
    model = init_model(cfg, args.checkpoint, device=args.device)
    if args.preview_model:
        print(model)
        print('\n Please remove `--preview-model` to get the CAM.')
        return

    # apply transform and prepare data
    data, src_img = apply_transforms(args.img, cfg.data.test.pipeline)

    # build target layers
    if args.target_layers:
        target_layers = [
            get_layer(layer, model) for layer in args.target_layers
        ]
    else:
        target_layers = get_default_traget_layers(model, args)

    # init a cam grad calculator
    use_cuda = ('cuda' in args.device)
    reshape_transform = build_reshape_transform(model, args)
    cam = init_cam(args.method, model, target_layers, use_cuda,
                   reshape_transform)

    # calculate cam grads and show|save the visualization image
    grayscale_cam = cam(
        input_tensor=data['img'].unsqueeze(0),
        target_category=args.target_category,
        eigen_smooth=args.eigen_smooth,
        aug_smooth=args.aug_smooth)
    show_cam_grad(
        grayscale_cam, src_img, title=args.method, out_path=args.save_path)


if __name__ == '__main__':
    main()