170 lines
5.9 KiB
Python
170 lines
5.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import os
|
|
|
|
import mmcv
|
|
import torch
|
|
from mmengine.config import Config
|
|
from mmengine.registry import VISUALIZERS
|
|
from mmengine.utils import import_modules_from_strings
|
|
|
|
from mmrazor.models.task_modules import RecorderManager
|
|
from mmrazor.utils import register_all_modules
|
|
from mmrazor.visualization.local_visualizer import modify
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command line of the feature-difference visualizer.

    Args:
        argv (list[str], optional): Argument strings to parse. ``None`` (the
            default, used by the script entry point) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed arguments.

    Side effects:
        Sets ``os.environ['LOCAL_RANK']`` from ``--local_rank`` when that
        environment variable is not already defined.
    """
    parser = argparse.ArgumentParser(description='Feature map visualization')
    parser.add_argument('img', help='Image file')
    parser.add_argument(
        'config1', help='train config file path for the first model')
    parser.add_argument(
        'config2', help='train config file path for the second model')
    parser.add_argument('vis_config', help='visualization config file path')
    parser.add_argument(
        'checkpoint1', help='Checkpoint file for the first model')
    parser.add_argument(
        'checkpoint2', help='Checkpoint file for the second model')
    parser.add_argument('--out-file', default=None, help='Path to output file')
    parser.add_argument(
        '--device', default='cpu', help='Device used for inference')
    # `main` imports `<repo>.utils` and `<repo>.apis`, so without a repo name
    # it would fail later with an obscure `None.utils` import error.
    parser.add_argument(
        '--repo', required=True, help='the corresponding repo name')
    parser.add_argument(
        '--use-norm',
        action='store_true',
        help='normalize the featmap before visualization')
    parser.add_argument(
        '--overlaid', action='store_true', help='overlaid image')
    parser.add_argument(
        '--channel-reduction',
        help='Reduce multiple channels to a single channel. The optional value'
        ' is \'squeeze_mean\', \'select_max\' or \'pixel_wise_max\'.',
        default=None)
    parser.add_argument(
        '--topk',
        help='If channel_reduction is not None and topk > 0, it will select '
        'topk channel to show by the sum of each channel. If topk <= 0, '
        'tensor_chw is assert to be one or three.',
        type=int,
        default=20)
    parser.add_argument(
        '--arrangement',
        nargs='+',
        type=int,
        help='the arrangement of featmap when channel_reduction is not None '
        'and topk > 0.',
        default=[4, 5])
    parser.add_argument(
        '--resize-shape',
        nargs='+',
        type=int,
        help='the shape to scale the feature map',
        default=None)
    # Without `type=float`, a value given on the command line would be
    # passed on as a string.
    parser.add_argument(
        '--alpha',
        type=float,
        help='the transparency of featmap',
        default=0.5)

    # Accept both spellings: newer torch.distributed launchers pass
    # `--local-rank`; dest stays `local_rank` (first long option wins).
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

    args = parser.parse_args(argv)
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
|
|
|
|
|
|
def norm(feat):
    """Standardize a 4-D feature map channel by channel.

    The tensor is unpacked as ``(N, C, H, W)``. For each index of the second
    dimension, the mean and standard deviation are computed jointly over the
    other three dimensions (torch's default ``std`` correction). Values are
    then shifted by that mean and divided by ``std + 1e-6``; the epsilon
    keeps constant channels from dividing by zero.

    Args:
        feat (torch.Tensor): 4-D tensor unpacked as ``(N, C, H, W)``.

    Returns:
        torch.Tensor: The standardized tensor, same shape as ``feat``.
    """
    batch, channels, height, width = feat.shape

    # One row per channel, holding all of that channel's values.
    rows = feat.permute(1, 0, 2, 3).reshape(channels, -1)
    row_mean = rows.mean(dim=-1, keepdim=True)
    row_std = rows.std(dim=-1, keepdim=True)
    standardized = (rows - row_mean) / (row_std + 1e-6)

    # Undo the flattening: (C, N*H*W) -> (C, N, H, W) -> (N, C, H, W).
    restored = standardized.reshape(channels, batch, height, width)
    return restored.permute(1, 0, 2, 3)
|
|
|
|
|
|
def _load_repo_apis(repo):
    """Register ``repo``'s modules and return its ``(init_model, inference)``.

    Imports ``<repo>.utils`` (calling its ``register_all_modules``) and
    ``<repo>.apis``, then picks the callables whose attribute names contain
    ``'init_'`` and ``'inference_'``.

    Raises:
        RuntimeError: If either callable cannot be found.
    """
    mod = import_modules_from_strings(f'{repo}.utils')
    mod.register_all_modules()

    apis = import_modules_from_strings(f'{repo}.apis')
    inference_model, init_model = None, None
    # NOTE(review): `dir()` is alphabetical and the LAST matching name wins,
    # so a module exporting several `inference_*` / `init_*` callables may
    # select an unintended one — confirm for each supported repo.
    for attr_name in dir(apis):
        if 'inference_' in attr_name:
            inference_model = getattr(apis, attr_name)
        if 'init_' in attr_name:
            init_model = getattr(apis, attr_name)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (inference_model and init_model):
        raise RuntimeError(
            f'Could not find both an `init_*` and an `inference_*` function '
            f'in `{repo}.apis`.')
    return init_model, inference_model


def _get_record_feats(recorder_manager, record):
    """Fetch the data described by one mapping entry, always as a tuple.

    ``record`` must provide ``recorder`` and ``data_idx``; ``record_idx``
    defaults to 0.
    """
    recorder = recorder_manager.get_recorder(record.recorder)
    record_idx = getattr(record, 'record_idx', 0)
    data_idx = getattr(record, 'data_idx')
    feats = recorder.get_record_data(record_idx, data_idx)
    # Wrap a lone tensor so callers can iterate uniformly.
    if isinstance(feats, torch.Tensor):
        feats = (feats, )
    return feats


def _out_file_for(out_file, sample_name, num_outputs):
    """Return the path for one drawn image (``None`` means show instead).

    With at most one output the user-given path is used unchanged. With
    several, ``sample_name`` is inserted before the extension so later
    images do not overwrite earlier ones.
    """
    if out_file is None or num_outputs <= 1:
        return out_file
    root, ext = os.path.splitext(out_file)
    return f'{root}_{sample_name}{ext}'


def main(args):
    """Visualize ``|feat1 - feat2|`` between two models' recorded features.

    Builds both models with the ``--repo`` package's ``init_*`` API and
    attaches the recorders from ``args.vis_config`` (``recorders1``/``2``).
    Runs inference on ``args.img`` with each model, then pairs the entries of
    ``mappings1`` and ``mappings2`` by position. For each pair it draws the
    absolute difference of the recorded tensors, optionally after ``norm``.

    Images are shown when ``args.out_file`` is None. Otherwise they are
    written to ``args.out_file``; when several images are produced, each
    gets its sample name appended before the extension.

    Raises:
        RuntimeError: If the repo's ``apis`` module lacks the needed callables.
    """
    import warnings

    register_all_modules(False)
    init_model, inference_model = _load_repo_apis(args.repo)

    model1 = init_model(args.config1, args.checkpoint1, device=args.device)
    # init visualizer (from the first model's config) and swap in the
    # mmrazor `draw_featmap` implementation.
    visualizer = VISUALIZERS.build(model1.cfg.visualizer)
    visualizer.draw_featmap = modify

    model2 = init_model(args.config2, args.checkpoint2, device=args.device)

    visualization_cfg = Config.fromfile(args.vis_config)
    recorder_cfg1 = visualization_cfg.recorders1
    mappings1 = visualization_cfg.mappings1
    recorder_cfg2 = visualization_cfg.recorders2
    mappings2 = visualization_cfg.mappings2

    recorder_manager1 = RecorderManager(recorder_cfg1)
    recorder_manager1.initialize(model1)

    recorder_manager2 = RecorderManager(recorder_cfg2)
    recorder_manager2.initialize(model2)

    with recorder_manager1:
        # test a single image
        _ = inference_model(model1, args.img)

    with recorder_manager2:
        # test a single image
        _ = inference_model(model2, args.img)

    overlaid_image = mmcv.imread(
        args.img, channel_order='rgb') if args.overlaid else None

    # Mappings are paired by position; `zip` drops the surplus silently.
    if len(mappings1) != len(mappings2):
        warnings.warn(
            f'mappings1 has {len(mappings1)} entries but mappings2 has '
            f'{len(mappings2)}; only the first '
            f'{min(len(mappings1), len(mappings2))} pairs are visualized.')

    # Compute every difference first, so the number of output images is
    # known before choosing file names (avoids overwriting one `out_file`).
    diffs = []
    for name1, name2 in zip(mappings1.keys(), mappings2.keys()):
        feats1 = _get_record_feats(recorder_manager1, mappings1[name1])
        feats2 = _get_record_feats(recorder_manager2, mappings2[name2])
        for i, (feat1, feat2) in enumerate(zip(feats1, feats2)):
            diff = torch.abs(feat1 - feat2)
            if args.use_norm:
                diff = norm(diff)
            diffs.append((f'model1_{name1}_model2_{name2}_{i}', diff))

    for sample_name, diff in diffs:
        # Only the first element along dim 0 is drawn.
        drawn_img = visualizer.draw_featmap(
            diff[0],
            overlaid_image,
            args.channel_reduction,
            topk=args.topk,
            arrangement=tuple(args.arrangement),
            resize_shape=tuple(args.resize_shape)
            if args.resize_shape else None,
            alpha=float(args.alpha))
        visualizer.add_datasample(
            sample_name,
            drawn_img,
            show=args.out_file is None,
            wait_time=0.1,
            out_file=_out_file_for(args.out_file, sample_name, len(diffs)))
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line and run the comparison.
    main(parse_args())
|