minor update

This commit is contained in:
liaoxingyu 2020-07-14 11:58:06 +08:00
parent 7fbdf1fe82
commit 3f35eb449d
10 changed files with 58 additions and 244 deletions

View File

@ -218,7 +218,7 @@ _C.SOLVER.SWA.LR_SCHED = False
_C.SOLVER.CHECKPOINT_PERIOD = 5000
# Number of images per batch
# Number of images per batch across all machines.
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
_C.SOLVER.IMS_PER_BATCH = 64
@ -228,7 +228,9 @@ _C.SOLVER.IMS_PER_BATCH = 64
_C.TEST = CN()
_C.TEST.EVAL_PERIOD = 50
_C.TEST.IMS_PER_BATCH = 128
# Number of images per batch in one process.
_C.TEST.IMS_PER_BATCH = 64
_C.TEST.METRIC = "cosine"
# Average query expansion

View File

@ -31,17 +31,16 @@ def build_reid_train_loader(cfg):
train_set = CommDataset(train_items, train_transforms, relabel=True)
num_workers = cfg.DATALOADER.NUM_WORKERS
mini_batch_size = cfg.SOLVER.IMS_PER_BATCH
num_instance = cfg.DATALOADER.NUM_INSTANCE
global_batch_size = mini_batch_size * comm.get_world_size()
mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
if cfg.DATALOADER.PK_SAMPLER:
if cfg.DATALOADER.NAIVE_WAY:
data_sampler = samplers.NaiveIdentitySampler(train_set.img_items,
global_batch_size, num_instance)
cfg.SOLVER.IMS_PER_BATCH, num_instance)
else:
data_sampler = samplers.BalancedIdentitySampler(train_set.img_items,
global_batch_size, num_instance)
cfg.SOLVER.IMS_PER_BATCH, num_instance)
else:
data_sampler = samplers.TrainingSampler(len(train_set))
batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)

View File

@ -489,7 +489,7 @@ class DefaultTrainer(SimpleTrainer):
frozen = cfg.is_frozen()
cfg.defrost()
iters_per_epoch = len(data_loader.dataset) // (cfg.SOLVER.IMS_PER_BATCH * comm.get_world_size())
iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
cfg.SOLVER.MAX_ITER *= iters_per_epoch
cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
@ -502,8 +502,8 @@ class DefaultTrainer(SimpleTrainer):
cfg.SOLVER.CHECKPOINT_PERIOD *= iters_per_epoch
# Evaluation period must be divided by 200 for writing into tensorboard.
num_mode = 200 - (cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + num_mode
num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + num_mod
logger = logging.getLogger(__name__)
logger.info(

View File

@ -201,13 +201,13 @@ class SimpleTrainer(TrainerBase):
"""
If your want to do something with the heads, you can wrap the model.
"""
outputs = self.model(data)
outputs, targets = self.model(data)
# Compute loss
if isinstance(self.model, DistributedDataParallel):
loss_dict = self.model.module.losses(outputs)
loss_dict = self.model.module.losses(outputs, targets)
else:
loss_dict = self.model.losses(outputs)
loss_dict = self.model.losses(outputs, targets)
losses = sum(loss_dict.values())
self._detect_anomaly(losses, loss_dict)

View File

@ -1,166 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from collections import defaultdict
import argparse
import json
import os
import sys
import time
import cv2
import numpy as np
import torch
from torch.backends import cudnn
from fastreid.modeling import build_model
from fastreid.utils.checkpoint import Checkpointer
from fastreid.config import get_cfg
cudnn.benchmark = True
class Reid(object):
def __init__(self, config_file):
cfg = get_cfg()
cfg.merge_from_file(config_file)
cfg.defrost()
cfg.MODEL.WEIGHTS = 'projects/bjzProject/logs/bjz/arcface_adam/model_final.pth'
model = build_model(cfg)
Checkpointer(model).resume_or_load(cfg.MODEL.WEIGHTS)
model.cuda()
model.eval()
self.model = model
# self.model = torch.jit.load("reid_model.pt")
# self.model.eval()
# self.model.cuda()
example = torch.rand(1, 3, 256, 128)
example = example.cuda()
traced_script_module = torch.jit.trace_module(model, {'inference': example})
traced_script_module.save("reid_feat_extractor.pt")
@classmethod
def preprocess(cls, img_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (128, 256))
img = img / 255.0
img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
img = img.transpose((2, 0, 1)).astype(np.float32)
img = img[np.newaxis, :, :, :]
data = torch.from_numpy(img).cuda().float()
return data
@torch.no_grad()
def demo(self, img_path):
data = self.preprocess(img_path)
output = self.model.inference(data)
feat = output.cpu().data.numpy()
return feat
# @torch.no_grad()
# def extract_feat(self, dataloader):
# prefetcher = test_data_prefetcher(dataloader)
# feats = []
# labels = []
# batch = prefetcher.next()
# num_count = 0
# while batch[0] is not None:
# img, pid, camid = batch
# feat = self.model(img)
# feats.append(feat.cpu())
# labels.extend(np.asarray(pid))
#
# # if num_count > 2:
# # break
# batch = prefetcher.next()
# # num_count += 1
#
# feats = torch.cat(feats, dim=0)
# id_feats = defaultdict(list)
# for f, i in zip(feats, labels):
# id_feats[i].append(f)
# all_feats = []
# label_names = []
# for i in id_feats:
# all_feats.append(torch.stack(id_feats[i], dim=0).mean(dim=0))
# label_names.append(i)
#
# label_names = np.asarray(label_names)
# all_feats = torch.stack(all_feats, dim=0) # (n, 2048)
# all_feats = F.normalize(all_feats, p=2, dim=1)
# np.save('feats.npy', all_feats.cpu())
# np.save('labels.npy', label_names)
# cos = torch.mm(all_feats, all_feats.t()).numpy() # (n, n)
# cos -= np.eye(all_feats.shape[0])
# f = open('check_cross_folder_similarity.txt', 'w')
# for i in range(len(label_names)):
# sim_indx = np.argwhere(cos[i] > 0.5)[:, 0]
# sim_name = label_names[sim_indx]
# write_str = label_names[i] + ' '
# # f.write(label_names[i]+'\t')
# for n in sim_name:
# write_str += (n + ' ')
# # f.write(n+'\t')
# f.write(write_str+'\n')
#
#
# def prepare_gt(self, json_file):
# feat = []
# label = []
# with open(json_file, 'r') as f:
# total = json.load(f)
# for index in total:
# label.append(index)
# feat.append(np.array(total[index]))
# time_label = [int(i[0:10]) for i in label]
#
# return np.array(feat), np.array(label), np.array(time_label)
def compute_topk(self, k, feat, feats, label):
# num_gallery = feats.shape[0]
# new_feat = np.tile(feat,[num_gallery,1])
norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
matrix = np.sum(np.multiply(feat, feats), axis=-1)
dist = matrix / np.multiply(norm_feat, norm_feats)
# print('feat:',feat.shape)
# print('feats:',feats.shape)
# print('label:',label.shape)
# print('dist:',dist.shape)
index = np.argsort(-dist)
# print('index:',index.shape)
result = []
for i in range(min(feats.shape[0], k)):
print(dist[index[i]])
result.append(label[index[i]])
return result
if __name__ == '__main__':
reid_sys = Reid(config_file='../../projects/bjzProject/configs/bjz.yml')
img_path = '/export/home/lxy/beijingStationReID/reid_model/demo_imgs/003740_c5s2_1561733125170.000000.jpg'
feat = reid_sys.demo(img_path)
feat_extractor = torch.jit.load('reid_feat_extractor.pt')
data = reid_sys.preprocess(img_path)
feat2 = feat_extractor.inference(data)
from ipdb import set_trace; set_trace()
# imgs = os.listdir(img_path)
# feats = {}
# for i in range(len(imgs)):
# feat = reid.demo(os.path.join(img_path, imgs[i]))
# feats[imgs[i]] = feat
# feat = reid.demo(os.path.join(img_path, 'crop_img0.jpg'))
# out1 = feats['dog.jpg']
# out2 = feats['kobe2.jpg']
# innerProduct = np.dot(out1, out2.T)
# cosineSimilarity = innerProduct / (np.linalg.norm(out1, ord=2) * np.linalg.norm(out2, ord=2))
# print(f'cosine similarity is {cosineSimilarity[0][0]:.4f}')

View File

@ -43,19 +43,14 @@ class BNneckHead(nn.Module):
if not self.training: return bn_feat
# Training
try:
cls_outputs = self.classifier(bn_feat)
pred_class_logits = cls_outputs.detach()
except TypeError:
cls_outputs = self.classifier(bn_feat, targets)
pred_class_logits = F.linear(F.normalize(bn_feat.detach()), F.normalize(self.classifier.weight.detach()))
# Log prediction accuracy
CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
try: cls_outputs = self.classifier(bn_feat)
except TypeError: cls_outputs = self.classifier(bn_feat, targets)
if self.neck_feat == "before":
feat = global_feat[..., 0, 0]
elif self.neck_feat == "after":
feat = bn_feat
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
elif self.neck_feat == "after": feat = bn_feat
else:
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
return cls_outputs, feat
return cls_outputs, pred_class_logits, feat

View File

@ -38,13 +38,9 @@ class LinearHead(nn.Module):
if not self.training: return global_feat
# Training
try:
cls_outputs = self.classifier(global_feat)
pred_class_logits = cls_outputs.detach()
except TypeError:
cls_outputs = self.classifier(global_feat, targets)
pred_class_logits = F.linear(F.normalize(global_feat.detach()), F.normalize(self.classifier.weight.detach()))
# Log prediction accuracy
CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
try: cls_outputs = self.classifier(global_feat)
except TypeError: cls_outputs = self.classifier(global_feat, targets)
return cls_outputs, global_feat
pred_class_logits = F.linear(global_feat, self.classifier.weight)
return cls_outputs, pred_class_logits, global_feat

View File

@ -55,19 +55,15 @@ class ReductionHead(nn.Module):
if not self.training: return bn_feat
# Training
try:
cls_outputs = self.classifier(bn_feat)
pred_class_logits = cls_outputs.detach()
except TypeError:
cls_outputs = self.classifier(bn_feat, targets)
pred_class_logits = F.linear(F.normalize(bn_feat.detach()), F.normalize(self.classifier.weight.detach()))
# Log prediction accuracy
CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
try: cls_outputs = self.classifier(bn_feat)
except TypeError: cls_outputs = self.classifier(bn_feat, targets)
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
elif self.neck_feat == "after": feat = bn_feat
else:
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
return cls_outputs, feat
return cls_outputs, pred_class_logits, feat

View File

@ -58,11 +58,9 @@ class Baseline(nn.Module):
# throw an error. We just set all the targets to 0 to avoid this problem.
if targets.sum() < 0: targets.zero_()
cls_outputs, features = self.heads(features, targets)
return cls_outputs, features, targets
return self.heads(features, targets), targets
else:
pred_features = self.heads(features)
return pred_features
return self.heads(features)
def preprocess_image(self, batched_inputs):
"""
@ -73,22 +71,22 @@ class Baseline(nn.Module):
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs):
def losses(self, outputs, gt_labels):
r"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
cls_outputs, pred_features, gt_labels = outputs
cls_outputs, pred_class_logits, pred_features = outputs
loss_dict = {}
loss_names = self._cfg.MODEL.LOSSES.NAME
# Log prediction accuracy
CrossEntropyLoss.log_accuracy(pred_class_logits.detach(), gt_labels)
if "CrossEntropyLoss" in loss_names:
loss_dict['loss_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)
if "TripletLoss" in loss_names:
loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels)
if "CircleLoss" in loss_names:
loss_dict['loss_circle'] = CircleLoss(self._cfg)(pred_features, gt_labels)
return loss_dict

View File

@ -153,31 +153,21 @@ class MGN(nn.Module):
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
targets = batched_inputs["targets"].long().to(self.device)
if targets.sum() < 0:
targets.zero_()
self.b1_head(b1_pool_feat, targets)
self.b2_head(b2_pool_feat, targets)
self.b21_head(b21_pool_feat, targets)
self.b22_head(b22_pool_feat, targets)
self.b3_head(b3_pool_feat, targets)
self.b31_head(b31_pool_feat, targets)
self.b32_head(b32_pool_feat, targets)
self.b33_head(b33_pool_feat, targets)
return
if targets.sum() < 0: targets.zero_()
b1_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets)
b2_logits, b2_pool_feat = self.b2_head(b2_pool_feat, targets)
b21_logits, b21_pool_feat = self.b21_head(b21_pool_feat, targets)
b22_logits, b22_pool_feat = self.b22_head(b22_pool_feat, targets)
b3_logits, b3_pool_feat = self.b3_head(b3_pool_feat, targets)
b31_logits, b31_pool_feat = self.b31_head(b31_pool_feat, targets)
b32_logits, b32_pool_feat = self.b32_head(b32_pool_feat, targets)
b33_logits, b33_pool_feat = self.b33_head(b33_pool_feat, targets)
return b1_logits, b2_logits, b21_logits, b22_logits, \
b3_logits, b31_logits, b32_logits, b33_logits, \
b1_pool_feat, b2_pool_feat, b3_pool_feat, \
torch.cat((b21_pool_feat, b22_pool_feat), dim=1), \
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1), targets
b1_logits, pred_class_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets)
b2_logits, _, b2_pool_feat = self.b2_head(b2_pool_feat, targets)
b21_logits, _, b21_pool_feat = self.b21_head(b21_pool_feat, targets)
b22_logits, _, b22_pool_feat = self.b22_head(b22_pool_feat, targets)
b3_logits, _, b3_pool_feat = self.b3_head(b3_pool_feat, targets)
b31_logits, _, b31_pool_feat = self.b31_head(b31_pool_feat, targets)
b32_logits, _, b32_pool_feat = self.b32_head(b32_pool_feat, targets)
b33_logits, _, b33_pool_feat = self.b33_head(b33_pool_feat, targets)
return (b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits,
b1_pool_feat, b2_pool_feat, b3_pool_feat,
torch.cat((b21_pool_feat, b22_pool_feat), dim=1),
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1), pred_class_logits), targets
else:
b1_pool_feat = self.b1_head(b1_pool_feat)
@ -202,12 +192,16 @@ class MGN(nn.Module):
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs):
def losses(self, outputs, gt_labels):
b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits, \
b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat, gt_labels = outputs
b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat, pred_class_logits = outputs
loss_dict = {}
loss_names = self._cfg.MODEL.LOSSES.NAME
# Log prediction accuracy
CrossEntropyLoss.log_accuracy(pred_class_logits.detach(), gt_labels)
if "CrossEntropyLoss" in loss_names:
loss_dict['loss_cls_b1'] = CrossEntropyLoss(self._cfg)(b1_logits, gt_labels)
loss_dict['loss_cls_b2'] = CrossEntropyLoss(self._cfg)(b2_logits, gt_labels)