add visactmap code; upgrade visrank to plot single figure
parent
c809e361b7
commit
27d4f2f093
|
@ -164,8 +164,10 @@ def init_parser():
|
|||
|
||||
parser.add_argument('--visrank', action='store_true',
|
||||
help='visualize ranked results, only available in evaluation mode')
|
||||
parser.add_argument('--visrank-topk', type=int, default=20,
|
||||
parser.add_argument('--visrank-topk', type=int, default=15,
|
||||
help='visualize topk ranks')
|
||||
parser.add_argument('--visactmap', action='store_true',
|
||||
help='visualize CNN activation maps')
|
||||
|
||||
# ************************************************************
|
||||
# Miscs
|
||||
|
@ -273,5 +275,6 @@ def engine_run_kwargs(parsed_args):
|
|||
'visrank_topk': parsed_args.visrank_topk,
|
||||
'use_metric_cuhk03': parsed_args.use_metric_cuhk03,
|
||||
'ranks': parsed_args.ranks,
|
||||
'rerank': parsed_args.rerank
|
||||
'rerank': parsed_args.rerank,
|
||||
'visactmap': parsed_args.visactmap
|
||||
}
|
||||
|
|
|
@ -27,6 +27,8 @@ class DataManager(object):
|
|||
use_cpu=False):
|
||||
self.sources = sources
|
||||
self.targets = targets
|
||||
self.height = height
|
||||
self.width = width
|
||||
|
||||
if self.sources is None:
|
||||
raise ValueError('sources must not be None')
|
||||
|
@ -41,7 +43,7 @@ class DataManager(object):
|
|||
self.targets = [self.targets]
|
||||
|
||||
self.transform_tr, self.transform_te = build_transforms(
|
||||
height, width, transforms
|
||||
self.height, self.width, transforms
|
||||
)
|
||||
|
||||
self.use_gpu = (torch.cuda.is_available() and not use_cpu)
|
||||
|
@ -108,6 +110,7 @@ class ImageDataManager(DataManager):
|
|||
batch_size=32
|
||||
)
|
||||
"""
|
||||
data_type = 'image'
|
||||
|
||||
def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
|
||||
use_cpu=False, split_id=0, combineall=False,
|
||||
|
@ -261,6 +264,7 @@ class VideoDataManager(DataManager):
|
|||
training, you need to modify the transformation functions for video reid such that each function
|
||||
applies the same operation to all images in a tracklet to keep consistency.
|
||||
"""
|
||||
data_type = 'video'
|
||||
|
||||
def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
|
||||
use_cpu=False, split_id=0, combineall=False,
|
||||
|
|
|
@ -7,13 +7,16 @@ import os.path as osp
|
|||
import time
|
||||
import datetime
|
||||
import numpy as np
|
||||
import cv2
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision
|
||||
|
||||
import torchreid
|
||||
from torchreid.utils import AverageMeter, visualize_ranked_results, save_checkpoint, re_ranking
|
||||
from torchreid.utils import AverageMeter, visualize_ranked_results, save_checkpoint, re_ranking, mkdir_if_missing
|
||||
from torchreid.losses import DeepSupervision
|
||||
from torchreid import metrics
|
||||
|
||||
|
@ -44,7 +47,7 @@ class Engine(object):
|
|||
def run(self, save_dir='log', max_epoch=0, start_epoch=0, fixbase_epoch=0, open_layers=None,
|
||||
start_eval=0, eval_freq=-1, test_only=False, print_freq=10,
|
||||
dist_metric='euclidean', normalize_feature=False, visrank=False, visrank_topk=20,
|
||||
use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False):
|
||||
use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False, visactmap=False):
|
||||
r"""A unified pipeline for training and evaluating a model.
|
||||
|
||||
Args:
|
||||
|
@ -65,19 +68,22 @@ class Engine(object):
|
|||
between query and gallery. Default is "euclidean".
|
||||
normalize_feature (bool, optional): performs L2 normalization on feature vectors before
|
||||
computing feature distance. Default is False.
|
||||
visrank (bool, optional): visualizes ranked results. Default is False. Visualization
|
||||
will be performed every test time, so it is recommended to enable ``visrank`` when
|
||||
``test_only`` is True. The ranked images will be saved to
|
||||
"save_dir/ranks-epoch/dataset_name", e.g. "save_dir/ranks-60/market1501".
|
||||
visrank (bool, optional): visualizes ranked results. Default is False. It is recommended to
|
||||
enable ``visrank`` when ``test_only`` is True. The ranked images will be saved to
|
||||
"save_dir/visrank_dataset", e.g. "save_dir/visrank_market1501".
|
||||
visrank_topk (int, optional): top-k ranked images to be visualized. Default is 20.
|
||||
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
|
||||
Default is False. This should be enabled when using cuhk03 classic split.
|
||||
ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
|
||||
rerank (bool, optional): uses person re-ranking (by Zhong et al. CVPR'17).
|
||||
Default is False. This is only enabled when test_only=True.
|
||||
visactmap (bool, optional): visualizes activation maps. Default is False.
|
||||
"""
|
||||
trainloader, testloader = self.datamanager.return_dataloaders()
|
||||
|
||||
if visrank and not test_only:
|
||||
raise ValueError('visrank=True is valid only if test_only=True')
|
||||
|
||||
if test_only:
|
||||
self.test(
|
||||
0,
|
||||
|
@ -93,6 +99,10 @@ class Engine(object):
|
|||
)
|
||||
return
|
||||
|
||||
if visactmap:
|
||||
self.visactmap(testloader, save_dir, self.datamanager.width, self.datamanager.height, print_freq)
|
||||
return
|
||||
|
||||
time_start = time.time()
|
||||
print('=> Start training')
|
||||
|
||||
|
@ -144,7 +154,7 @@ class Engine(object):
|
|||
|
||||
.. note::
|
||||
|
||||
This needs to be implemented in subclasses.
|
||||
This must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -155,34 +165,14 @@ class Engine(object):
|
|||
|
||||
.. note::
|
||||
|
||||
This function has been called in ``run()`` when necessary.
|
||||
This function has been called in ``run()``.
|
||||
|
||||
.. note::
|
||||
|
||||
The test pipeline implemented in this function suits both image- and
|
||||
video-reid. In general, a subclass of Engine only needs to re-implement
|
||||
``_extract_features()`` and ``_parse_data_for_eval()`` when necessary,
|
||||
``_extract_features()`` and ``_parse_data_for_eval()`` (most of the time),
|
||||
but not a must. Please refer to the source code for more details.
|
||||
|
||||
Args:
|
||||
epoch (int): current epoch.
|
||||
testloader (dict): dictionary containing
|
||||
{dataset_name: 'query': queryloader, 'gallery': galleryloader}.
|
||||
dist_metric (str, optional): distance metric used to compute distance matrix
|
||||
between query and gallery. Default is "euclidean".
|
||||
normalize_feature (bool, optional): performs L2 normalization on feature vectors before
|
||||
computing feature distance. Default is False.
|
||||
visrank (bool, optional): visualizes ranked results. Default is False. Visualization
|
||||
will be performed every test time, so it is recommended to enable ``visrank`` when
|
||||
``test_only`` is True. The ranked images will be saved to
|
||||
"save_dir/ranks-epoch/dataset_name", e.g. "save_dir/ranks-60/market1501".
|
||||
visrank_topk (int, optional): top-k ranked images to be visualized. Default is 20.
|
||||
save_dir (str): directory to save visualized results if ``visrank`` is True.
|
||||
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
|
||||
Default is False. This should be enabled when using cuhk03 classic split.
|
||||
ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].
|
||||
rerank (bool, optional): uses person re-ranking (by Zhong et al. CVPR'17).
|
||||
Default is False.
|
||||
"""
|
||||
targets = list(testloader.keys())
|
||||
|
||||
|
@ -291,12 +281,115 @@ class Engine(object):
|
|||
visualize_ranked_results(
|
||||
distmat,
|
||||
self.datamanager.return_testdataset_by_name(dataset_name),
|
||||
save_dir=osp.join(save_dir, 'visrank-'+str(epoch+1), dataset_name),
|
||||
self.datamanager.data_type,
|
||||
width=self.datamanager.width,
|
||||
height=self.datamanager.height,
|
||||
save_dir=osp.join(save_dir, 'visrank_'+dataset_name),
|
||||
topk=visrank_topk
|
||||
)
|
||||
|
||||
return cmc[0]
|
||||
|
||||
@torch.no_grad()
|
||||
def visactmap(self, testloader, save_dir, width, height, print_freq):
|
||||
"""Visualizes CNN activation maps to see where the CNN focuses on to extract features.
|
||||
|
||||
This function takes as input the query images of target datasets
|
||||
|
||||
Reference:
|
||||
- Zagoruyko and Komodakis. Paying more attention to attention: Improving the
|
||||
performance of convolutional neural networks via attention transfer. ICLR, 2017
|
||||
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
|
||||
"""
|
||||
self.model.eval()
|
||||
if not hasattr(self.model , 'featuremaps'):
|
||||
raise AttributeError('Model must have method featuremaps(), which returns the feature maps '
|
||||
'of shape (b, c, h, w)')
|
||||
|
||||
imagenet_mean = [0.485, 0.456, 0.406]
|
||||
imagenet_std = [0.229, 0.224, 0.225]
|
||||
tensor2pil = torchvision.transforms.ToPILImage()
|
||||
|
||||
for target in list(testloader.keys()):
|
||||
queryloader = testloader[target]['query']
|
||||
# original images and activation maps are saved individually
|
||||
actmap_image_dir = osp.join(save_dir, 'actmap_'+target, 'images')
|
||||
mkdir_if_missing(actmap_image_dir)
|
||||
# original image and activation map are saved in one single figure
|
||||
actmap_fig_dir = osp.join(save_dir, 'actmap_'+target, 'single_figure')
|
||||
mkdir_if_missing(actmap_fig_dir)
|
||||
print('Visualizing activation maps for {} ...'.format(target))
|
||||
|
||||
for batch_idx, data in enumerate(queryloader):
|
||||
imgs, paths = data[0], data[3]
|
||||
if self.use_gpu:
|
||||
imgs = imgs.cuda()
|
||||
|
||||
# forward to get convolutional feature maps
|
||||
outputs = self.model.featuremaps(imgs)
|
||||
if outputs.dim() != 4:
|
||||
raise ValueError('The model output is supposed to have ' \
|
||||
'shape of (b, c, h, w), i.e. 4 dimensions, but got {} dimensions. '
|
||||
'Please make sure you set the model output at eval mode '
|
||||
'to be the last convolutional feature maps'.format(outputs.dim()))
|
||||
|
||||
# compute activation maps
|
||||
outputs = (outputs**2).sum(1)
|
||||
b, h, w = outputs.size()
|
||||
outputs = outputs.view(b, h*w)
|
||||
outputs = F.normalize(outputs, p=2, dim=1)
|
||||
outputs = outputs.view(b, h, w)
|
||||
|
||||
if self.use_gpu:
|
||||
imgs, outputs = imgs.cpu(), outputs.cpu()
|
||||
|
||||
for j in range(outputs.size(0)):
|
||||
# get image name
|
||||
path = paths[j]
|
||||
imname = osp.basename(osp.splitext(path)[0])
|
||||
|
||||
# RGB image
|
||||
img = imgs[j, ...]
|
||||
for t, m, s in zip(img, imagenet_mean, imagenet_std):
|
||||
t.mul_(s).add_(m).clamp_(0, 1)
|
||||
img_pil = tensor2pil(img)
|
||||
img_pil.save(osp.join(actmap_image_dir, imname+'_image.jpg'))
|
||||
|
||||
# activation map
|
||||
img_np = np.uint8(np.floor(img.numpy() * 255))
|
||||
img_np = img_np.transpose((1, 2, 0)) # (c, h, w) -> (h, w, c)
|
||||
out = outputs[j, ...].numpy()
|
||||
out = cv2.resize(out, (width, height))
|
||||
out = 255 * (out - np.max(out)) / (np.max(out) - np.min(out) + 1e-12)
|
||||
out = np.uint8(np.floor(out))
|
||||
out = cv2.applyColorMap(out, cv2.COLORMAP_JET)
|
||||
|
||||
# combined
|
||||
combined = img_np * 0.5 + out * 0.5
|
||||
combined[combined>255] = 255
|
||||
combined = combined.astype(np.uint8)
|
||||
cv2.imwrite(osp.join(actmap_image_dir, imname+'_combined.jpg'), combined)
|
||||
cv2.imwrite(osp.join(actmap_image_dir, imname+'_map.jpg'), out)
|
||||
|
||||
# save images in a single figure
|
||||
fig = plt.figure()
|
||||
fig.add_subplot(1, 3, 1)
|
||||
plt.axis('off')
|
||||
plt.title('Original image')
|
||||
plt.imshow(img_np)
|
||||
fig.add_subplot(1, 3, 2)
|
||||
plt.axis('off')
|
||||
plt.title('Activation map')
|
||||
plt.imshow(out[:, :, ::-1])
|
||||
fig.add_subplot(1, 3, 3)
|
||||
plt.axis('off')
|
||||
plt.imshow(combined[:, :, ::-1])
|
||||
fig.savefig(osp.join(actmap_fig_dir, imname+'.pdf'), bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
if (batch_idx+1) % print_freq == 0:
|
||||
print('- done batch {}/{}'.format(batch_idx+1, len(queryloader)))
|
||||
|
||||
def _compute_loss(self, criterion, outputs, targets):
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
loss = DeepSupervision(criterion, outputs, targets)
|
||||
|
|
|
@ -200,8 +200,7 @@ class OSNet(nn.Module):
|
|||
"""Omni-Scale Network.
|
||||
|
||||
Reference:
|
||||
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ArXiv preprint, 2019.
|
||||
https://arxiv.org/abs/1905.00953
|
||||
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes, blocks, layers, channels, feature_dim=512, loss='softmax', IN=False, **kwargs):
|
||||
|
|
|
@ -7,27 +7,37 @@ import numpy as np
|
|||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import cv2
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from .tools import mkdir_if_missing
|
||||
|
||||
|
||||
def visualize_ranked_results(distmat, dataset, save_dir='', topk=20):
|
||||
PLOT_FONT_SIZE = 3
|
||||
|
||||
|
||||
def visualize_ranked_results(distmat, dataset, data_type, width=128, height=256, save_dir='', topk=20):
|
||||
"""Visualizes ranked results.
|
||||
|
||||
Supports both image-reid and video-reid.
|
||||
|
||||
For image-reid, ranks will be plotted in a single figure. For video-reid, ranks will be
|
||||
saved in folders each containing a tracklet.
|
||||
|
||||
Args:
|
||||
distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
|
||||
dataset (tuple): a 2-tuple containing (query, gallery), each of which contains
|
||||
tuples of (img_path(s), pid, camid).
|
||||
data_type (str): "image" or "video".
|
||||
width (int, optional): resized image width. Default is 128.
|
||||
height (int, optional): resized image height. Default is 256.
|
||||
save_dir (str): directory to save output images.
|
||||
topk (int, optional): denoting top-k images in the rank list to be visualized.
|
||||
"""
|
||||
num_q, num_g = distmat.shape
|
||||
|
||||
print('Visualizing top-{} ranks'.format(topk))
|
||||
print('Visualizing top-{} ranks ...'.format(topk))
|
||||
print('# query: {}\n# gallery {}'.format(num_q, num_g))
|
||||
print('Saving images to "{}"'.format(save_dir))
|
||||
|
||||
query, gallery = dataset
|
||||
assert num_q == len(query)
|
||||
|
@ -36,16 +46,21 @@ def visualize_ranked_results(distmat, dataset, save_dir='', topk=20):
|
|||
indices = np.argsort(distmat, axis=1)
|
||||
mkdir_if_missing(save_dir)
|
||||
|
||||
def _cp_img_to(src, dst, rank, prefix):
|
||||
def _cp_img_to(src, dst, rank, prefix, matched=False):
|
||||
"""
|
||||
Args:
|
||||
src: image path or tuple (for vidreid)
|
||||
dst: target directory
|
||||
rank: int, denoting ranked position, starting from 1
|
||||
prefix: string
|
||||
matched: bool
|
||||
"""
|
||||
if isinstance(src, tuple) or isinstance(src, list):
|
||||
dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
|
||||
if prefix == 'gallery':
|
||||
suffix = 'TRUE' if matched else 'FALSE'
|
||||
dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) + '_' + suffix
|
||||
else:
|
||||
dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
|
||||
mkdir_if_missing(dst)
|
||||
for img_path in src:
|
||||
shutil.copy(img_path, dst)
|
||||
|
@ -55,21 +70,44 @@ def visualize_ranked_results(distmat, dataset, save_dir='', topk=20):
|
|||
|
||||
for q_idx in range(num_q):
|
||||
qimg_path, qpid, qcamid = query[q_idx]
|
||||
if isinstance(qimg_path, tuple) or isinstance(qimg_path, list):
|
||||
qdir = osp.join(save_dir, osp.basename(qimg_path[0]))
|
||||
|
||||
if data_type == 'image':
|
||||
qimg = cv2.imread(qimg_path)
|
||||
qimg = cv2.resize(qimg, (width, height))
|
||||
fig = plt.figure()
|
||||
fig.add_subplot(1, topk+1, 1) # totally 1 query and topk gallery
|
||||
plt.axis('off')
|
||||
plt.title('Query', fontsize=PLOT_FONT_SIZE)
|
||||
plt.imshow(qimg)
|
||||
else:
|
||||
qdir = osp.join(save_dir, osp.basename(qimg_path))
|
||||
mkdir_if_missing(qdir)
|
||||
_cp_img_to(qimg_path, qdir, rank=0, prefix='query')
|
||||
qdir = osp.join(save_dir, osp.basename(osp.splitext(qimg_path)[0]))
|
||||
mkdir_if_missing(qdir)
|
||||
_cp_img_to(qimg_path, qdir, rank=0, prefix='query')
|
||||
|
||||
rank_idx = 1
|
||||
for g_idx in indices[q_idx,:]:
|
||||
gimg_path, gpid, gcamid = gallery[g_idx]
|
||||
invalid = (qpid == gpid) & (qcamid == gcamid)
|
||||
|
||||
if not invalid:
|
||||
_cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
|
||||
if data_type == 'image':
|
||||
gimg = cv2.imread(gimg_path)
|
||||
gimg = cv2.resize(gimg, (width, height))
|
||||
fig.add_subplot(1, topk+1, rank_idx+1)
|
||||
plt.axis('off')
|
||||
title_color = 'green' if gpid == qpid else 'red'
|
||||
plt.title('Rank-'+str(rank_idx), fontsize=PLOT_FONT_SIZE, color=title_color)
|
||||
plt.imshow(gimg)
|
||||
else:
|
||||
_cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
|
||||
|
||||
rank_idx += 1
|
||||
if rank_idx > topk:
|
||||
break
|
||||
|
||||
print("Done")
|
||||
if data_type == 'image':
|
||||
imname = osp.basename(osp.splitext(qimg_path)[0])
|
||||
fig.savefig(osp.join(save_dir, imname+'.pdf'), bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
print('Done. Images have been saved to "{}" ...'.format(save_dir))
|
||||
|
|
Loading…
Reference in New Issue