add visactmap code; upgrade visrank to plot single figure

pull/218/head
KaiyangZhou 2019-08-03 17:35:06 +01:00
parent c809e361b7
commit 27d4f2f093
5 changed files with 184 additions and 47 deletions

View File

@ -164,8 +164,10 @@ def init_parser():
parser.add_argument('--visrank', action='store_true',
help='visualize ranked results, only available in evaluation mode')
parser.add_argument('--visrank-topk', type=int, default=20,
parser.add_argument('--visrank-topk', type=int, default=15,
help='visualize topk ranks')
parser.add_argument('--visactmap', action='store_true',
help='visualize CNN activation maps')
# ************************************************************
# Miscs
@ -273,5 +275,6 @@ def engine_run_kwargs(parsed_args):
'visrank_topk': parsed_args.visrank_topk,
'use_metric_cuhk03': parsed_args.use_metric_cuhk03,
'ranks': parsed_args.ranks,
'rerank': parsed_args.rerank
'rerank': parsed_args.rerank,
'visactmap': parsed_args.visactmap
}

View File

@ -27,6 +27,8 @@ class DataManager(object):
use_cpu=False):
self.sources = sources
self.targets = targets
self.height = height
self.width = width
if self.sources is None:
raise ValueError('sources must not be None')
@ -41,7 +43,7 @@ class DataManager(object):
self.targets = [self.targets]
self.transform_tr, self.transform_te = build_transforms(
height, width, transforms
self.height, self.width, transforms
)
self.use_gpu = (torch.cuda.is_available() and not use_cpu)
@ -108,6 +110,7 @@ class ImageDataManager(DataManager):
batch_size=32
)
"""
data_type = 'image'
def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
use_cpu=False, split_id=0, combineall=False,
@ -261,6 +264,7 @@ class VideoDataManager(DataManager):
training, you need to modify the transformation functions for video reid such that each function
applies the same operation to all images in a tracklet to keep consistency.
"""
data_type = 'video'
def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
use_cpu=False, split_id=0, combineall=False,

View File

@ -7,13 +7,16 @@ import os.path as osp
import time
import datetime
import numpy as np
import cv2
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchreid
from torchreid.utils import AverageMeter, visualize_ranked_results, save_checkpoint, re_ranking
from torchreid.utils import AverageMeter, visualize_ranked_results, save_checkpoint, re_ranking, mkdir_if_missing
from torchreid.losses import DeepSupervision
from torchreid import metrics
@ -44,7 +47,7 @@ class Engine(object):
def run(self, save_dir='log', max_epoch=0, start_epoch=0, fixbase_epoch=0, open_layers=None,
start_eval=0, eval_freq=-1, test_only=False, print_freq=10,
dist_metric='euclidean', normalize_feature=False, visrank=False, visrank_topk=20,
use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False):
use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False, visactmap=False):
r"""A unified pipeline for training and evaluating a model.
Args:
@ -65,19 +68,22 @@ class Engine(object):
between query and gallery. Default is "euclidean".
normalize_feature (bool, optional): performs L2 normalization on feature vectors before
computing feature distance. Default is False.
visrank (bool, optional): visualizes ranked results. Default is False. Visualization
will be performed every test time, so it is recommended to enable ``visrank`` when
``test_only`` is True. The ranked images will be saved to
"save_dir/ranks-epoch/dataset_name", e.g. "save_dir/ranks-60/market1501".
visrank (bool, optional): visualizes ranked results. Default is False. It is recommended to
enable ``visrank`` when ``test_only`` is True. The ranked images will be saved to
"save_dir/visrank_dataset", e.g. "save_dir/visrank_market1501".
visrank_topk (int, optional): top-k ranked images to be visualized. Default is 20.
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
Default is False. This should be enabled when using cuhk03 classic split.
ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
rerank (bool, optional): uses person re-ranking (by Zhong et al. CVPR'17).
Default is False. This is only enabled when test_only=True.
visactmap (bool, optional): visualizes activation maps. Default is False.
"""
trainloader, testloader = self.datamanager.return_dataloaders()
if visrank and not test_only:
raise ValueError('visrank=True is valid only if test_only=True')
if test_only:
self.test(
0,
@ -93,6 +99,10 @@ class Engine(object):
)
return
if visactmap:
self.visactmap(testloader, save_dir, self.datamanager.width, self.datamanager.height, print_freq)
return
time_start = time.time()
print('=> Start training')
@ -144,7 +154,7 @@ class Engine(object):
.. note::
This needs to be implemented in subclasses.
This must be implemented in subclasses.
"""
raise NotImplementedError
@ -155,34 +165,14 @@ class Engine(object):
.. note::
This function has been called in ``run()`` when necessary.
This function has been called in ``run()``.
.. note::
The test pipeline implemented in this function suits both image- and
video-reid. In general, a subclass of Engine only needs to re-implement
``_extract_features()`` and ``_parse_data_for_eval()`` when necessary,
``_extract_features()`` and ``_parse_data_for_eval()`` (most of the time),
but not a must. Please refer to the source code for more details.
Args:
epoch (int): current epoch.
testloader (dict): dictionary containing
{dataset_name: 'query': queryloader, 'gallery': galleryloader}.
dist_metric (str, optional): distance metric used to compute distance matrix
between query and gallery. Default is "euclidean".
normalize_feature (bool, optional): performs L2 normalization on feature vectors before
computing feature distance. Default is False.
visrank (bool, optional): visualizes ranked results. Default is False. Visualization
will be performed every test time, so it is recommended to enable ``visrank`` when
``test_only`` is True. The ranked images will be saved to
"save_dir/ranks-epoch/dataset_name", e.g. "save_dir/ranks-60/market1501".
visrank_topk (int, optional): top-k ranked images to be visualized. Default is 20.
save_dir (str): directory to save visualized results if ``visrank`` is True.
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
Default is False. This should be enabled when using cuhk03 classic split.
ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
rerank (bool, optional): uses person re-ranking (by Zhong et al. CVPR'17).
Default is False.
"""
targets = list(testloader.keys())
@ -291,12 +281,115 @@ class Engine(object):
visualize_ranked_results(
distmat,
self.datamanager.return_testdataset_by_name(dataset_name),
save_dir=osp.join(save_dir, 'visrank-'+str(epoch+1), dataset_name),
self.datamanager.data_type,
width=self.datamanager.width,
height=self.datamanager.height,
save_dir=osp.join(save_dir, 'visrank_'+dataset_name),
topk=visrank_topk
)
return cmc[0]
@torch.no_grad()
def visactmap(self, testloader, save_dir, width, height, print_freq):
"""Visualizes CNN activation maps to see where the CNN focuses on to extract features.
This function takes as input the query images of target datasets
Reference:
- Zagoruyko and Komodakis. Paying more attention to attention: Improving the
performance of convolutional neural networks via attention transfer. ICLR, 2017
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
"""
self.model.eval()
if not hasattr(self.model , 'featuremaps'):
raise AttributeError('Model must have method featuremaps(), which returns the feature maps '
'of shape (b, c, h, w)')
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
tensor2pil = torchvision.transforms.ToPILImage()
for target in list(testloader.keys()):
queryloader = testloader[target]['query']
# original images and activation maps are saved individually
actmap_image_dir = osp.join(save_dir, 'actmap_'+target, 'images')
mkdir_if_missing(actmap_image_dir)
# original image and activation map are saved in one single figure
actmap_fig_dir = osp.join(save_dir, 'actmap_'+target, 'single_figure')
mkdir_if_missing(actmap_fig_dir)
print('Visualizing activation maps for {} ...'.format(target))
for batch_idx, data in enumerate(queryloader):
imgs, paths = data[0], data[3]
if self.use_gpu:
imgs = imgs.cuda()
# forward to get convolutional feature maps
outputs = self.model.featuremaps(imgs)
if outputs.dim() != 4:
raise ValueError('The model output is supposed to have ' \
'shape of (b, c, h, w), i.e. 4 dimensions, but got {} dimensions. '
'Please make sure you set the model output at eval mode '
'to be the last convolutional feature maps'.format(outputs.dim()))
# compute activation maps
outputs = (outputs**2).sum(1)
b, h, w = outputs.size()
outputs = outputs.view(b, h*w)
outputs = F.normalize(outputs, p=2, dim=1)
outputs = outputs.view(b, h, w)
if self.use_gpu:
imgs, outputs = imgs.cpu(), outputs.cpu()
for j in range(outputs.size(0)):
# get image name
path = paths[j]
imname = osp.basename(osp.splitext(path)[0])
# RGB image
img = imgs[j, ...]
for t, m, s in zip(img, imagenet_mean, imagenet_std):
t.mul_(s).add_(m).clamp_(0, 1)
img_pil = tensor2pil(img)
img_pil.save(osp.join(actmap_image_dir, imname+'_image.jpg'))
# activation map
img_np = np.uint8(np.floor(img.numpy() * 255))
img_np = img_np.transpose((1, 2, 0)) # (c, h, w) -> (h, w, c)
out = outputs[j, ...].numpy()
out = cv2.resize(out, (width, height))
out = 255 * (out - np.max(out)) / (np.max(out) - np.min(out) + 1e-12)
out = np.uint8(np.floor(out))
out = cv2.applyColorMap(out, cv2.COLORMAP_JET)
# combined
combined = img_np * 0.5 + out * 0.5
combined[combined>255] = 255
combined = combined.astype(np.uint8)
cv2.imwrite(osp.join(actmap_image_dir, imname+'_combined.jpg'), combined)
cv2.imwrite(osp.join(actmap_image_dir, imname+'_map.jpg'), out)
# save images in a single figure
fig = plt.figure()
fig.add_subplot(1, 3, 1)
plt.axis('off')
plt.title('Original image')
plt.imshow(img_np)
fig.add_subplot(1, 3, 2)
plt.axis('off')
plt.title('Activation map')
plt.imshow(out[:, :, ::-1])
fig.add_subplot(1, 3, 3)
plt.axis('off')
plt.imshow(combined[:, :, ::-1])
fig.savefig(osp.join(actmap_fig_dir, imname+'.pdf'), bbox_inches='tight')
plt.close()
if (batch_idx+1) % print_freq == 0:
print('- done batch {}/{}'.format(batch_idx+1, len(queryloader)))
def _compute_loss(self, criterion, outputs, targets):
if isinstance(outputs, (tuple, list)):
loss = DeepSupervision(criterion, outputs, targets)

View File

@ -200,8 +200,7 @@ class OSNet(nn.Module):
"""Omni-Scale Network.
Reference:
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ArXiv preprint, 2019.
https://arxiv.org/abs/1905.00953
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
"""
def __init__(self, num_classes, blocks, layers, channels, feature_dim=512, loss='softmax', IN=False, **kwargs):

View File

@ -7,27 +7,37 @@ import numpy as np
import os
import os.path as osp
import shutil
import cv2
from matplotlib import pyplot as plt
from .tools import mkdir_if_missing
def visualize_ranked_results(distmat, dataset, save_dir='', topk=20):
PLOT_FONT_SIZE = 3
def visualize_ranked_results(distmat, dataset, data_type, width=128, height=256, save_dir='', topk=20):
"""Visualizes ranked results.
Supports both image-reid and video-reid.
For image-reid, ranks will be plotted in a single figure. For video-reid, ranks will be
saved in folders each containing a tracklet.
Args:
distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
dataset (tuple): a 2-tuple containing (query, gallery), each of which contains
tuples of (img_path(s), pid, camid).
data_type (str): "image" or "video".
width (int, optional): resized image width. Default is 128.
height (int, optional): resized image height. Default is 256.
save_dir (str): directory to save output images.
topk (int, optional): denoting top-k images in the rank list to be visualized.
"""
num_q, num_g = distmat.shape
print('Visualizing top-{} ranks'.format(topk))
print('Visualizing top-{} ranks ...'.format(topk))
print('# query: {}\n# gallery {}'.format(num_q, num_g))
print('Saving images to "{}"'.format(save_dir))
query, gallery = dataset
assert num_q == len(query)
@ -36,16 +46,21 @@ def visualize_ranked_results(distmat, dataset, save_dir='', topk=20):
indices = np.argsort(distmat, axis=1)
mkdir_if_missing(save_dir)
def _cp_img_to(src, dst, rank, prefix):
def _cp_img_to(src, dst, rank, prefix, matched=False):
"""
Args:
src: image path or tuple (for vidreid)
dst: target directory
rank: int, denoting ranked position, starting from 1
prefix: string
matched: bool
"""
if isinstance(src, tuple) or isinstance(src, list):
dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
if prefix == 'gallery':
suffix = 'TRUE' if matched else 'FALSE'
dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) + '_' + suffix
else:
dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
mkdir_if_missing(dst)
for img_path in src:
shutil.copy(img_path, dst)
@ -55,21 +70,44 @@ def visualize_ranked_results(distmat, dataset, save_dir='', topk=20):
for q_idx in range(num_q):
qimg_path, qpid, qcamid = query[q_idx]
if isinstance(qimg_path, tuple) or isinstance(qimg_path, list):
qdir = osp.join(save_dir, osp.basename(qimg_path[0]))
if data_type == 'image':
qimg = cv2.imread(qimg_path)
qimg = cv2.resize(qimg, (width, height))
fig = plt.figure()
fig.add_subplot(1, topk+1, 1) # totally 1 query and topk gallery
plt.axis('off')
plt.title('Query', fontsize=PLOT_FONT_SIZE)
plt.imshow(qimg)
else:
qdir = osp.join(save_dir, osp.basename(qimg_path))
mkdir_if_missing(qdir)
_cp_img_to(qimg_path, qdir, rank=0, prefix='query')
qdir = osp.join(save_dir, osp.basename(osp.splitext(qimg_path)[0]))
mkdir_if_missing(qdir)
_cp_img_to(qimg_path, qdir, rank=0, prefix='query')
rank_idx = 1
for g_idx in indices[q_idx,:]:
gimg_path, gpid, gcamid = gallery[g_idx]
invalid = (qpid == gpid) & (qcamid == gcamid)
if not invalid:
_cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
if data_type == 'image':
gimg = cv2.imread(gimg_path)
gimg = cv2.resize(gimg, (width, height))
fig.add_subplot(1, topk+1, rank_idx+1)
plt.axis('off')
title_color = 'green' if gpid == qpid else 'red'
plt.title('Rank-'+str(rank_idx), fontsize=PLOT_FONT_SIZE, color=title_color)
plt.imshow(gimg)
else:
_cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
rank_idx += 1
if rank_idx > topk:
break
print("Done")
if data_type == 'image':
imname = osp.basename(osp.splitext(qimg_path)[0])
fig.savefig(osp.join(save_dir, imname+'.pdf'), bbox_inches='tight')
plt.close()
print('Done. Images have been saved to "{}" ...'.format(save_dir))