# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os

import mmcv
import torch
from mmengine.config import Config, DictAction
from mmengine.registry import VISUALIZERS
from mmengine.utils import import_modules_from_strings

from mmrazor.models.task_modules import RecorderManager
from mmrazor.utils import register_all_modules
from mmrazor.visualization.local_visualizer import modify


def parse_args():
    parser = argparse.ArgumentParser(description='Feature map visualization')
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('vis_config', help='visualization config file path')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('--out-file', default=None, help='Path to output file')
    parser.add_argument(
        '--device', default='cpu', help='Device used for inference')
    parser.add_argument('--repo', help='the corresponding repo name')
    parser.add_argument(
        '--use-norm',
        action='store_true',
        help='normalize the featmap before visualization')
    parser.add_argument(
        '--overlaid', action='store_true', help='overlaid image')
    parser.add_argument(
        '--channel-reduction',
        help='Reduce multiple channels to a single channel. The optional '
        'values are \'squeeze_mean\', \'select_max\' and \'pixel_wise_max\'.',
        default=None)
    parser.add_argument(
        '--topk',
        type=int,
        help='If channel_reduction is not None and topk > 0, it will select '
        'the topk channels to show, ranked by the sum of each channel. If '
        'topk <= 0, tensor_chw is asserted to have one or three channels.',
        default=20)
    parser.add_argument(
        '--arrangement',
        nargs='+',
        type=int,
        help='the arrangement of featmap when channel_reduction is not None '
        'and topk > 0.',
        default=[4, 5])
    parser.add_argument(
        '--resize-shape',
        nargs='+',
        type=int,
        help='the shape to scale the feature map',
        default=None)
    parser.add_argument(
        '--alpha',
        # parse as float so a CLI-supplied alpha is not passed on as a string
        type=float,
        default=0.5,
        help='the transparency of featmap')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config; key-value pairs in '
        'xxx=yyy format will be merged into the config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'Nested list/tuple values are also allowed, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.',
        default={})

    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


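# Illustrative invocation only; the script name, config paths and checkpoint
# below are placeholders for whatever the downstream repo (here mmdet)
# actually provides:
#
#   python feat_visualization.py demo.jpg configs/retinanet_r50_fpn.py \
#       vis_config.py retinanet_r50_fpn.pth --repo mmdet \
#       --channel-reduction squeeze_mean --overlaid --out-file feat.jpg

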
def norm(feat):
    """Standardize ``feat`` channel-wise to zero mean and unit variance.

    Statistics are computed per channel over the batch and spatial
    dimensions of an ``(N, C, H, W)`` tensor.
    """
    N, C, H, W = feat.shape
    feat = feat.permute(1, 0, 2, 3).reshape(C, -1)
    mean = feat.mean(dim=-1, keepdim=True)
    std = feat.std(dim=-1, keepdim=True)
    centered = (feat - mean) / (std + 1e-6)
    centered = centered.reshape(C, N, H, W).permute(1, 0, 2, 3)
    return centered


def main(args):
    # register mmrazor modules first, then the downstream repo's own modules
    register_all_modules(False)
    mod = import_modules_from_strings(f'{args.repo}.utils')
    mod.register_all_modules()

    # grab the repo's inference_* / init_* helpers
    # (e.g. inference_detector and init_detector from mmdet.apis)
    apis = import_modules_from_strings(f'{args.repo}.apis')
    inference_model, init_model = None, None
    for attr_name in dir(apis):
        if 'inference_' in attr_name:
            inference_model = getattr(apis, attr_name)
        if 'init_' in attr_name:
            init_model = getattr(apis, attr_name)
    assert inference_model and init_model

    model = init_model(args.config, args.checkpoint, device=args.device)
    # init visualizer and swap in mmrazor's modified draw_featmap
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.draw_featmap = modify

    visualization_cfg = Config.fromfile(args.vis_config)
    recorder_cfg = visualization_cfg.recorders
    mappings = visualization_cfg.mappings
    recorder_manager = RecorderManager(recorder_cfg)
    recorder_manager.initialize(model)

    with recorder_manager:
        # test a single image
        result = inference_model(model, args.img)

    overlaid_image = mmcv.imread(
        args.img, channel_order='rgb') if args.overlaid else None

    # draw every feature map the recorders captured during inference
    for name, record in mappings.items():
        recorder = recorder_manager.get_recorder(record.recorder)
        record_idx = getattr(record, 'record_idx', 0)
        data_idx = getattr(record, 'data_idx')
        feats = recorder.get_record_data(record_idx, data_idx)
        if isinstance(feats, torch.Tensor):
            feats = (feats, )

        for i, feat in enumerate(feats):
            if args.use_norm:
                feat = norm(feat)
            drawn_img = visualizer.draw_featmap(
                feat[0],
                overlaid_image,
                args.channel_reduction,
                topk=args.topk,
                arrangement=tuple(args.arrangement),
                resize_shape=tuple(args.resize_shape)
                if args.resize_shape else None,
                alpha=args.alpha)
            visualizer.add_datasample(
                f'{name}_{i}',
                drawn_img,
                data_sample=result,
                draw_gt=False,
                show=args.out_file is None,
                wait_time=0.1,
                out_file=args.out_file,
                **args.cfg_options)


if __name__ == '__main__':
    args = parse_args()
    main(args)
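
# A minimal sketch of the ``vis_config`` file consumed by main(), assuming the
# downstream repo is mmdet and that the neck (FPN) outputs are the tensors of
# interest; the recorder source and mapping names are illustrative only. Each
# mapping must name a recorder and a data_idx, and may give a record_idx:
#
#   recorders = dict(
#       neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
#   mappings = dict(
#       p3=dict(recorder='neck', data_idx=0),
#       p4=dict(recorder='neck', data_idx=1))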