commit
2f1bd9ab39
|
@ -30,6 +30,6 @@ class FC(nn.Layer):
|
|||
self.fc = paddle.nn.Linear(
|
||||
self.embedding_size, self.class_num, weight_attr=weight_attr)
|
||||
|
||||
def forward(self, input, label):
|
||||
def forward(self, input):
|
||||
out = self.fc(input)
|
||||
return out
|
||||
|
|
|
@ -31,7 +31,7 @@ from ppcls.utils import logger
|
|||
from ppcls.data import build_dataloader
|
||||
from ppcls.arch import build_model
|
||||
from ppcls.loss import build_loss
|
||||
from ppcls.arch.loss_metrics import build_metrics
|
||||
from ppcls.metric import build_metrics
|
||||
from ppcls.optimizer import build_optimizer
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain
|
||||
from ppcls.utils.save_load import init_model
|
||||
|
@ -81,43 +81,35 @@ class Trainer(object):
|
|||
self.vdl_writer = LogWriter(logdir=vdl_writer_path)
|
||||
logger.info('train with paddle {} and device {}'.format(
|
||||
paddle.__version__, self.device))
|
||||
|
||||
def _build_metric_info(self, metric_config, mode="train"):
|
||||
"""
|
||||
_build_metric_info: build metrics according to current mode
|
||||
Return:
|
||||
metric: dict of the metrics info
|
||||
"""
|
||||
metric = None
|
||||
mode = mode.capitalize()
|
||||
if mode in metric_config and metric_config[mode] is not None:
|
||||
metric = build_metrics(metric_config[mode])
|
||||
return metric
|
||||
|
||||
def _build_loss_info(self, loss_config, mode="train"):
|
||||
"""
|
||||
_build_loss_info: build loss according to current mode
|
||||
Return:
|
||||
loss_dict: dict of the loss info
|
||||
"""
|
||||
loss = None
|
||||
mode = mode.capitalize()
|
||||
if mode in loss_config and loss_config[mode] is not None:
|
||||
loss = build_loss(loss_config[mode])
|
||||
return loss
|
||||
# init members
|
||||
self.train_dataloader = None
|
||||
self.eval_dataloader = None
|
||||
self.gallery_dataloader = None
|
||||
self.query_dataloader = None
|
||||
self.eval_mode = self.config["Global"].get("eval_mode",
|
||||
"classification")
|
||||
self.train_loss_func = None
|
||||
self.eval_loss_func = None
|
||||
self.train_metric_func = None
|
||||
self.eval_metric_func = None
|
||||
|
||||
def train(self):
|
||||
# build train loss and metric info
|
||||
loss_func = self._build_loss_info(self.config["Loss"])
|
||||
if "Metric" in self.config:
|
||||
metric_func = self._build_metric_info(self.config["Metric"])
|
||||
else:
|
||||
metric_func = None
|
||||
if self.train_loss_func is None:
|
||||
loss_info = self.config["Loss"]["Train"]
|
||||
self.train_loss_func = build_loss(loss_info)
|
||||
if self.train_metric_func is None:
|
||||
metric_config = self.config.get("Metric")
|
||||
if metric_config is not None:
|
||||
metric_config = metric_config.get("Train")
|
||||
if metric_config is not None:
|
||||
self.train_metric_func = build_metrics(metric_config)
|
||||
|
||||
train_dataloader = build_dataloader(self.config["DataLoader"], "Train",
|
||||
self.device)
|
||||
if self.train_dataloader is None:
|
||||
self.train_dataloader = build_dataloader(self.config["DataLoader"],
|
||||
"Train", self.device)
|
||||
|
||||
step_each_epoch = len(train_dataloader)
|
||||
step_each_epoch = len(self.train_dataloader)
|
||||
|
||||
optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
|
||||
self.config["Global"]["epochs"],
|
||||
|
@ -146,8 +138,7 @@ class Trainer(object):
|
|||
for epoch_id in range(best_metric["epoch"] + 1,
|
||||
self.config["Global"]["epochs"] + 1):
|
||||
acc = 0.0
|
||||
self.model.train()
|
||||
for iter_id, batch in enumerate(train_dataloader()):
|
||||
for iter_id, batch in enumerate(self.train_dataloader()):
|
||||
batch_size = batch[0].shape[0]
|
||||
batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
|
||||
.reshape([-1, 1]))
|
||||
|
@ -158,15 +149,15 @@ class Trainer(object):
|
|||
else:
|
||||
out = self.model(batch[0], batch[1])
|
||||
# calc loss
|
||||
loss_dict = loss_func(out, batch[1])
|
||||
loss_dict = self.train_loss_func(out, batch[1])
|
||||
for key in loss_dict:
|
||||
if not key in output_info:
|
||||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
output_info[key].update(loss_dict[key].numpy()[0],
|
||||
batch_size)
|
||||
# calc metric
|
||||
if metric_func is not None:
|
||||
metric_dict = metric_func(out, batch[-1])
|
||||
if self.train_metric_func is not None:
|
||||
metric_dict = self.train_metric_func(out, batch[-1])
|
||||
for key in metric_dict:
|
||||
if not key in output_info:
|
||||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
|
@ -181,7 +172,7 @@ class Trainer(object):
|
|||
])
|
||||
logger.info("[Train][Epoch {}][Iter: {}/{}]{}, {}".format(
|
||||
epoch_id, iter_id,
|
||||
len(train_dataloader), lr_msg, metric_msg))
|
||||
len(self.train_dataloader), lr_msg, metric_msg))
|
||||
|
||||
# step opt and lr
|
||||
loss_dict["loss"].backward()
|
||||
|
@ -212,6 +203,7 @@ class Trainer(object):
|
|||
self.output_dir,
|
||||
model_name=self.config["Arch"]["name"],
|
||||
prefix="best_model")
|
||||
self.model.train()
|
||||
|
||||
# save model
|
||||
if epoch_id % save_interval == 0:
|
||||
|
@ -228,20 +220,56 @@ class Trainer(object):
|
|||
|
||||
@paddle.no_grad()
|
||||
def eval(self, epoch_id=0):
|
||||
output_info = dict()
|
||||
|
||||
eval_dataloader = build_dataloader(self.config["DataLoader"], "Eval",
|
||||
self.device)
|
||||
|
||||
self.model.eval()
|
||||
if self.eval_loss_func is None:
|
||||
loss_config = self.config.get("Loss", None)
|
||||
if loss_config is not None:
|
||||
loss_config = loss_config.get("Eval")
|
||||
if loss_config is not None:
|
||||
self.eval_loss_func = build_loss(loss_config)
|
||||
if self.eval_mode == "classification":
|
||||
if self.eval_dataloader is None:
|
||||
self.eval_dataloader = build_dataloader(
|
||||
self.config["DataLoader"], "Eval", self.device)
|
||||
|
||||
if self.eval_metric_func is None:
|
||||
metric_config = self.config.get("Metric")
|
||||
if metric_config is not None:
|
||||
metric_config = metric_config.get("Eval")
|
||||
if metric_config is not None:
|
||||
self.eval_metric_func = build_metrics(metric_config)
|
||||
|
||||
eval_result = self.eval_cls(epoch_id)
|
||||
|
||||
elif self.eval_mode == "retrieval":
|
||||
if self.gallery_dataloader is None:
|
||||
self.gallery_dataloader = build_dataloader(
|
||||
self.config["DataLoader"]["Eval"], "Gallery", self.device)
|
||||
|
||||
if self.query_dataloader is None:
|
||||
self.query_dataloader = build_dataloader(
|
||||
self.config["DataLoader"]["Eval"], "Query", self.device)
|
||||
# build metric info
|
||||
if self.eval_metric_func is None:
|
||||
metric_config = self.config.get("Metric", None)
|
||||
if metric_config is None:
|
||||
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
|
||||
else:
|
||||
metric_config = metric_config["Eval"]
|
||||
self.eval_metric_func = build_metrics(metric_config)
|
||||
eval_result = self.eval_retrieval(epoch_id)
|
||||
else:
|
||||
logger.warning("Invalid eval mode: {}".format(self.eval_mode))
|
||||
eval_result = None
|
||||
self.model.train()
|
||||
return eval_result
|
||||
|
||||
def eval_cls(self, epoch_id=0):
|
||||
output_info = dict()
|
||||
print_batch_step = self.config["Global"]["print_batch_step"]
|
||||
|
||||
# build train loss and metric info
|
||||
loss_func = self._build_loss_info(self.config["Loss"], "eval")
|
||||
metric_func = self._build_metric_info(self.config["Metric"], "eval")
|
||||
metric_key = None
|
||||
|
||||
for iter_id, batch in enumerate(eval_dataloader()):
|
||||
for iter_id, batch in enumerate(self.eval_dataloader()):
|
||||
batch_size = batch[0].shape[0]
|
||||
batch[0] = paddle.to_tensor(batch[0]).astype("float32")
|
||||
batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
|
||||
|
@ -250,32 +278,32 @@ class Trainer(object):
|
|||
out = self.model(batch[0], batch[1])
|
||||
else:
|
||||
out = self.model(batch[0])
|
||||
# calc build
|
||||
if loss_func is not None:
|
||||
loss_dict = loss_func(out, batch[-1])
|
||||
# calc loss
|
||||
if self.eval_loss_func is not None:
|
||||
loss_dict = self.eval_loss_func(out, batch[-1])
|
||||
for key in loss_dict:
|
||||
if not key in output_info:
|
||||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
output_info[key].update(loss_dict[key].numpy()[0],
|
||||
batch_size)
|
||||
# calc metric
|
||||
if metric_func is not None:
|
||||
metric_dict = metric_func(out, batch[-1])
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
for key in metric_dict:
|
||||
paddle.distributed.all_reduce(
|
||||
metric_dict[key],
|
||||
op=paddle.distributed.ReduceOp.SUM)
|
||||
metric_dict[key] = metric_dict[
|
||||
key] / paddle.distributed.get_world_size()
|
||||
# calc metric
|
||||
if self.eval_metric_func is not None:
|
||||
metric_dict = self.eval_metric_func(out, batch[-1])
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
for key in metric_dict:
|
||||
if metric_key is None:
|
||||
metric_key = key
|
||||
if not key in output_info:
|
||||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
paddle.distributed.all_reduce(
|
||||
metric_dict[key],
|
||||
op=paddle.distributed.ReduceOp.SUM)
|
||||
metric_dict[key] = metric_dict[
|
||||
key] / paddle.distributed.get_world_size()
|
||||
for key in metric_dict:
|
||||
if metric_key is None:
|
||||
metric_key = key
|
||||
if not key in output_info:
|
||||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
|
||||
output_info[key].update(metric_dict[key].numpy()[0],
|
||||
batch_size)
|
||||
output_info[key].update(metric_dict[key].numpy()[0],
|
||||
batch_size)
|
||||
|
||||
if iter_id % print_batch_step == 0:
|
||||
metric_msg = ", ".join([
|
||||
|
@ -283,7 +311,7 @@ class Trainer(object):
|
|||
for key in output_info
|
||||
])
|
||||
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}".format(
|
||||
epoch_id, iter_id, len(eval_dataloader), metric_msg))
|
||||
epoch_id, iter_id, len(self.eval_dataloader), metric_msg))
|
||||
|
||||
metric_msg = ", ".join([
|
||||
"{}: {:.5f}".format(key, output_info[key].avg)
|
||||
|
@ -291,13 +319,128 @@ class Trainer(object):
|
|||
])
|
||||
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
|
||||
|
||||
self.model.train()
|
||||
# do not try to save best model
|
||||
if metric_func is None:
|
||||
if self.eval_metric_func is None:
|
||||
return -1
|
||||
# return 1st metric in the dict
|
||||
return output_info[metric_key].avg
|
||||
|
||||
def eval_retrieval(self, epoch_id=0):
|
||||
self.model.eval()
|
||||
cum_similarity_matrix = None
|
||||
# step1. build gallery
|
||||
gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature(
|
||||
name='gallery')
|
||||
query_feas, query_img_id, query_camera_id = self._cal_feature(
|
||||
name='query')
|
||||
gallery_img_id = gallery_img_id
|
||||
# if gallery_camera_id is not None:
|
||||
# gallery_camera_id = gallery_camera_id
|
||||
# step2. do evaluation
|
||||
sim_block_size = self.config["Global"].get("sim_block_size", 64)
|
||||
sections = [sim_block_size] * (len(query_feas) // sim_block_size)
|
||||
if len(query_feas) % sim_block_size:
|
||||
sections.append(len(query_feas) % sim_block_size)
|
||||
fea_blocks = paddle.split(query_feas, num_or_sections=sections)
|
||||
if query_camera_id is not None:
|
||||
camera_id_blocks = paddle.split(
|
||||
query_camera_id, num_or_sections=sections)
|
||||
image_id_blocks = paddle.split(
|
||||
query_img_id, num_or_sections=sections)
|
||||
metric_key = None
|
||||
|
||||
for block_idx, block_fea in enumerate(fea_blocks):
|
||||
similarity_matrix = paddle.matmul(
|
||||
block_fea, gallery_feas, transpose_y=True)
|
||||
if query_camera_id is not None:
|
||||
camera_id_block = camera_id_blocks[block_idx]
|
||||
camera_id_mask = (camera_id_block != gallery_camera_id.t())
|
||||
|
||||
image_id_block = image_id_blocks[block_idx]
|
||||
image_id_mask = (image_id_block != gallery_img_id.t())
|
||||
|
||||
keep_mask = paddle.logical_or(camera_id_mask, image_id_mask)
|
||||
similarity_matrix = similarity_matrix * keep_mask.astype(
|
||||
"float32")
|
||||
if cum_similarity_matrix is None:
|
||||
cum_similarity_matrix = similarity_matrix
|
||||
else:
|
||||
cum_similarity_matrix = paddle.concat(
|
||||
[cum_similarity_matrix, similarity_matrix], axis=0)
|
||||
|
||||
# calc metric
|
||||
if self.eval_metric_func is not None:
|
||||
metric_dict = self.eval_metric_func(cum_similarity_matrix,
|
||||
query_img_id, gallery_img_id)
|
||||
else:
|
||||
metric_dict = {metric_key: 0.}
|
||||
metric_info_list = []
|
||||
|
||||
for key in metric_dict:
|
||||
if metric_key is None:
|
||||
metric_key = key
|
||||
metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key]))
|
||||
metric_msg = ", ".join(metric_info_list)
|
||||
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
|
||||
|
||||
return metric_dict[metric_key]
|
||||
|
||||
def _cal_feature(self, name='gallery'):
|
||||
all_feas = None
|
||||
all_image_id = None
|
||||
all_camera_id = None
|
||||
if name == 'gallery':
|
||||
dataloader = self.gallery_dataloader
|
||||
elif name == 'query':
|
||||
dataloader = self.query_dataloader
|
||||
else:
|
||||
raise RuntimeError("Only support gallery or query dataset")
|
||||
|
||||
has_cam_id = False
|
||||
for idx, batch in enumerate(dataloader(
|
||||
)): # load is very time-consuming
|
||||
batch = [paddle.to_tensor(x) for x in batch]
|
||||
batch[1] = batch[1].reshape([-1, 1])
|
||||
if len(batch) == 3:
|
||||
has_cam_id = True
|
||||
batch[2] = batch[2].reshape([-1, 1])
|
||||
out = self.model(batch[0], batch[1])
|
||||
batch_feas = out["features"]
|
||||
|
||||
# do norm
|
||||
if self.config["Global"].get("feature_normalize", True):
|
||||
feas_norm = paddle.sqrt(
|
||||
paddle.sum(paddle.square(batch_feas), axis=1,
|
||||
keepdim=True))
|
||||
batch_feas = paddle.divide(batch_feas, feas_norm)
|
||||
|
||||
if all_feas is None:
|
||||
all_feas = batch_feas
|
||||
if has_cam_id:
|
||||
all_camera_id = batch[2]
|
||||
all_image_id = batch[1]
|
||||
else:
|
||||
all_feas = paddle.concat([all_feas, batch_feas])
|
||||
all_image_id = paddle.concat([all_image_id, batch[1]])
|
||||
if has_cam_id:
|
||||
all_camera_id = paddle.concat([all_camera_id, batch[2]])
|
||||
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
feat_list = []
|
||||
img_id_list = []
|
||||
cam_id_list = []
|
||||
paddle.distributed.all_gather(feat_list, all_feas)
|
||||
paddle.distributed.all_gather(img_id_list, all_image_id)
|
||||
all_feas = paddle.concat(feat_list, axis=0)
|
||||
all_image_id = paddle.concat(img_id_list, axis=0)
|
||||
if has_cam_id:
|
||||
paddle.distributed.all_gather(cam_id_list, all_camera_id)
|
||||
all_camera_id = paddle.concat(cam_id_list, axis=0)
|
||||
|
||||
logger.info("Build {} done, all feat shape: {}, begin to eval..".
|
||||
format(name, all_feas.shape))
|
||||
return all_feas, all_image_id, all_camera_id
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self, ):
|
||||
total_trainer = paddle.distributed.get_world_size()
|
||||
|
|
|
@ -1,208 +0,0 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from .trainer import Trainer
|
||||
from ppcls.utils import logger
|
||||
from ppcls.data import build_dataloader
|
||||
|
||||
|
||||
class TrainerReID(Trainer):
|
||||
def __init__(self, config, mode="train"):
|
||||
super().__init__(config, mode)
|
||||
|
||||
self.gallery_dataloader = build_dataloader(self.config["DataLoader"],
|
||||
"Gallery", self.device)
|
||||
|
||||
self.query_dataloader = build_dataloader(self.config["DataLoader"],
|
||||
"Query", self.device)
|
||||
|
||||
@paddle.no_grad()
|
||||
def eval(self, epoch_id=0):
|
||||
output_info = dict()
|
||||
self.model.eval()
|
||||
print_batch_step = self.config["Global"]["print_batch_step"]
|
||||
|
||||
# step1. build gallery
|
||||
gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature(
|
||||
name='gallery')
|
||||
query_feas, query_img_id, query_camera_id = self._cal_feature(
|
||||
name='query')
|
||||
|
||||
# step2. do evaluation
|
||||
if "num_split" in self.config["Global"]:
|
||||
num_split = self.config["Global"]["num_split"]
|
||||
else:
|
||||
num_split = 1
|
||||
fea_blocks = paddle.split(query_feas, num_or_sections=1)
|
||||
|
||||
total_similarities_matrix = None
|
||||
|
||||
for block_fea in fea_blocks:
|
||||
similarities_matrix = paddle.matmul(
|
||||
block_fea, gallery_feas, transpose_y=True)
|
||||
if total_similarities_matrix is None:
|
||||
total_similarities_matrix = similarities_matrix
|
||||
else:
|
||||
total_similarities_matrix = paddle.concat(
|
||||
[total_similarities_matrix, similarities_matrix])
|
||||
|
||||
# distmat = (1 - total_similarities_matrix).numpy()
|
||||
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
|
||||
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
|
||||
if query_camera_id is not None and gallery_camera_id is not None:
|
||||
q_camids = query_camera_id.numpy().reshape(
|
||||
(query_camera_id.shape[0]))
|
||||
g_camids = gallery_camera_id.numpy().reshape(
|
||||
(gallery_camera_id.shape[0]))
|
||||
max_rank = 50
|
||||
|
||||
num_q, num_g = total_similarities_matrix.shape
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.
|
||||
format(num_g))
|
||||
|
||||
# indices = np.argsort(distmat, axis=1)
|
||||
indices = paddle.argsort(
|
||||
total_similarities_matrix, axis=1, descending=True).numpy()
|
||||
|
||||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||
|
||||
# compute cmc curve for each query
|
||||
all_cmc = []
|
||||
all_AP = []
|
||||
all_INP = []
|
||||
num_valid_q = 0. # number of valid query
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
q_camid = q_camids[q_idx]
|
||||
|
||||
# remove gallery samples that have the same pid and camid with query
|
||||
order = indices[q_idx]
|
||||
if query_camera_id is not None and gallery_camera_id is not None:
|
||||
remove = (g_pids[order] == q_pid) & (
|
||||
g_camids[order] == q_camid)
|
||||
else:
|
||||
remove = g_pids[order] == q_pid
|
||||
keep = np.invert(remove)
|
||||
|
||||
# compute cmc curve
|
||||
raw_cmc = matches[q_idx][
|
||||
keep] # binary vector, positions with value 1 are correct matches
|
||||
if not np.any(raw_cmc):
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
cmc = raw_cmc.cumsum()
|
||||
|
||||
pos_idx = np.where(raw_cmc == 1)
|
||||
max_pos_idx = np.max(pos_idx)
|
||||
inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
|
||||
all_INP.append(inp)
|
||||
|
||||
cmc[cmc > 1] = 1
|
||||
|
||||
all_cmc.append(cmc[:max_rank])
|
||||
num_valid_q += 1.
|
||||
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
num_rel = raw_cmc.sum()
|
||||
tmp_cmc = raw_cmc.cumsum()
|
||||
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
|
||||
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
|
||||
AP = tmp_cmc.sum() / num_rel
|
||||
all_AP.append(AP)
|
||||
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
||||
|
||||
all_cmc = np.asarray(all_cmc).astype(np.float32)
|
||||
all_cmc = all_cmc.sum(0) / num_valid_q
|
||||
|
||||
mAP = np.mean(all_AP)
|
||||
mINP = np.mean(all_INP)
|
||||
logger.info(
|
||||
"[Eval][Epoch {}]: mAP: {:.5f}, mINP: {:.5f},rank_1: {:.5f}, rank_5: {:.5f}"
|
||||
.format(epoch_id, mAP, mINP, all_cmc[0], all_cmc[4]))
|
||||
return mAP
|
||||
|
||||
def _cal_feature(self, name='gallery'):
|
||||
all_feas = None
|
||||
all_image_id = None
|
||||
all_camera_id = None
|
||||
if name == 'gallery':
|
||||
dataloader = self.gallery_dataloader
|
||||
elif name == 'query':
|
||||
dataloader = self.query_dataloader
|
||||
else:
|
||||
raise RuntimeError("Only support gallery or query dataset")
|
||||
|
||||
has_cam_id = False
|
||||
for idx, batch in enumerate(dataloader(
|
||||
)): # load is very time-consuming
|
||||
batch = [paddle.to_tensor(x) for x in batch]
|
||||
batch[1] = batch[1].reshape([-1, 1])
|
||||
if len(batch) == 3:
|
||||
has_cam_id = True
|
||||
batch[2] = batch[2].reshape([-1, 1])
|
||||
out = self.model(batch[0], batch[1])
|
||||
batch_feas = out["features"]
|
||||
|
||||
# do norm
|
||||
if self.config["Global"].get("feature_normalize", True):
|
||||
feas_norm = paddle.sqrt(
|
||||
paddle.sum(paddle.square(batch_feas), axis=1,
|
||||
keepdim=True))
|
||||
batch_feas = paddle.divide(batch_feas, feas_norm)
|
||||
|
||||
batch_feas = batch_feas
|
||||
batch_image_labels = batch[1]
|
||||
if has_cam_id:
|
||||
batch_camera_labels = batch[2]
|
||||
|
||||
if all_feas is None:
|
||||
all_feas = batch_feas
|
||||
if has_cam_id:
|
||||
all_camera_id = batch[2]
|
||||
all_image_id = batch[1]
|
||||
else:
|
||||
all_feas = paddle.concat([all_feas, batch_feas])
|
||||
all_image_id = paddle.concat([all_image_id, batch[1]])
|
||||
if has_cam_id:
|
||||
all_camera_id = paddle.concat([all_camera_id, batch[2]])
|
||||
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
feat_list = []
|
||||
img_id_list = []
|
||||
cam_id_list = []
|
||||
paddle.distributed.all_gather(feat_list, all_feas)
|
||||
paddle.distributed.all_gather(img_id_list, all_image_id)
|
||||
all_feas = paddle.concat(feat_list, axis=0)
|
||||
all_image_id = paddle.concat(img_id_list, axis=0)
|
||||
if has_cam_id:
|
||||
paddle.distributed.all_gather(cam_id_list, all_camera_id)
|
||||
all_camera_id = paddle.concat(cam_id_list, axis=0)
|
||||
|
||||
logger.info("Build {} done, all feat shape: {}, begin to eval..".
|
||||
format(name, all_feas.shape))
|
||||
return all_feas, all_image_id, all_camera_id
|
|
@ -0,0 +1,45 @@
|
|||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from paddle import nn
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
|
||||
from .metrics import TopkAcc, mAP, mINP, Recallk
|
||||
|
||||
|
||||
class CombinedMetrics(nn.Layer):
|
||||
def __init__(self, config_list):
|
||||
super().__init__()
|
||||
self.metric_func_list = []
|
||||
assert isinstance(config_list, list), (
|
||||
'operator config should be a list')
|
||||
for config in config_list:
|
||||
assert isinstance(config,
|
||||
dict) and len(config) == 1, "yaml format error"
|
||||
metric_name = list(config)[0]
|
||||
metric_params = config[metric_name]
|
||||
self.metric_func_list.append(eval(metric_name)(**metric_params))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
metric_dict = OrderedDict()
|
||||
for idx, metric_func in enumerate(self.metric_func_list):
|
||||
metric_dict.update(metric_func(*args, **kwargs))
|
||||
|
||||
return metric_dict
|
||||
|
||||
|
||||
def build_metrics(config):
|
||||
metrics_list = CombinedMetrics(copy.deepcopy(config))
|
||||
return metrics_list
|
|
@ -0,0 +1,147 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
|
||||
# TODO: fix the format
|
||||
class TopkAcc(nn.Layer):
|
||||
def __init__(self, topk=(1, 5)):
|
||||
super().__init__()
|
||||
assert isinstance(topk, (int, list, tuple))
|
||||
if isinstance(topk, int):
|
||||
topk = [topk]
|
||||
self.topk = topk
|
||||
|
||||
def forward(self, x, label):
|
||||
if isinstance(x, dict):
|
||||
x = x["logits"]
|
||||
|
||||
metric_dict = dict()
|
||||
for k in self.topk:
|
||||
metric_dict["top{}".format(k)] = paddle.metric.accuracy(
|
||||
x, label, k=k)
|
||||
return metric_dict
|
||||
|
||||
|
||||
class mAP(nn.Layer):
|
||||
def __init__(self, max_rank=50):
|
||||
super().__init__()
|
||||
self.max_rank = max_rank
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
|
||||
metric_dict = dict()
|
||||
num_q, num_g = similarities_matrix.shape
|
||||
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
|
||||
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
|
||||
if num_g < self.max_rank:
|
||||
self.max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.
|
||||
format(num_g))
|
||||
indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True).numpy()
|
||||
_, all_AP, _ = get_metrics(indices, num_q, num_g, q_pids, g_pids,
|
||||
self.max_rank)
|
||||
|
||||
mAP = np.mean(all_AP)
|
||||
metric_dict["mAP"] = mAP
|
||||
return metric_dict
|
||||
|
||||
|
||||
class mINP(nn.Layer):
|
||||
def __init__(self, max_rank=50):
|
||||
super().__init__()
|
||||
self.max_rank = max_rank
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
|
||||
metric_dict = dict()
|
||||
num_q, num_g = similarities_matrix.shape
|
||||
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
|
||||
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
|
||||
if num_g < self.max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.
|
||||
format(num_g))
|
||||
indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True).numpy()
|
||||
_, _, all_INP = get_metrics(indices, num_q, num_g, q_pids, g_pids,
|
||||
self.max_rank)
|
||||
|
||||
mINP = np.mean(all_INP)
|
||||
metric_dict["mINP"] = mINP
|
||||
return metric_dict
|
||||
|
||||
|
||||
class Recallk(nn.Layer):
|
||||
def __init__(self, max_rank=50, topk=(1, 5)):
|
||||
super().__init__()
|
||||
self.max_rank = max_rank
|
||||
assert isinstance(topk, (int, list))
|
||||
if isinstance(topk, int):
|
||||
topk = [topk]
|
||||
self.topk = topk
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
|
||||
metric_dict = dict()
|
||||
num_q, num_g = similarities_matrix.shape
|
||||
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
|
||||
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
|
||||
if num_g < self.max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.
|
||||
format(num_g))
|
||||
indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True).numpy()
|
||||
all_cmc, _, _ = get_metrics(indices, num_q, num_g, q_pids, g_pids,
|
||||
self.max_rank)
|
||||
|
||||
for k in self.topk:
|
||||
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
|
||||
return metric_dict
|
||||
|
||||
|
||||
def get_metrics(indices, num_q, num_g, q_pids, g_pids, max_rank=50):
|
||||
all_cmc = []
|
||||
all_AP = []
|
||||
all_INP = []
|
||||
num_valid_q = 0
|
||||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||
for q_idx in range(num_q):
|
||||
raw_cmc = matches[q_idx]
|
||||
if not np.any(raw_cmc):
|
||||
continue
|
||||
cmc = raw_cmc.cumsum()
|
||||
pos_idx = np.where(raw_cmc == 1)
|
||||
max_pos_idx = np.max(pos_idx)
|
||||
inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
|
||||
all_INP.append(inp)
|
||||
cmc[cmc > 1] = 1
|
||||
|
||||
all_cmc.append(cmc[:max_rank])
|
||||
num_valid_q += 1.
|
||||
|
||||
num_rel = raw_cmc.sum()
|
||||
tmp_cmc = raw_cmc.cumsum()
|
||||
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
|
||||
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
|
||||
AP = tmp_cmc.sum() / num_rel
|
||||
all_AP.append(AP)
|
||||
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
||||
|
||||
all_cmc = np.asarray(all_cmc).astype(np.float32)
|
||||
all_cmc = all_cmc.sum(0) / num_valid_q
|
||||
|
||||
return all_cmc, all_AP, all_INP
|
|
@ -22,13 +22,9 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
|||
|
||||
from ppcls.utils import config
|
||||
from ppcls.engine.trainer import Trainer
|
||||
from ppcls.engine.trainer_reid import TrainerReID
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = config.parse_args()
|
||||
config = config.get_config(args.config, overrides=args.override, show=True)
|
||||
if "Trainer" in config:
|
||||
trainer = eval(config["Trainer"]["name"])(config, mode="train")
|
||||
else:
|
||||
trainer = Trainer(config, mode="train")
|
||||
trainer = Trainer(config, mode="train")
|
||||
trainer.train()
|
||||
|
|
Loading…
Reference in New Issue