mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
minor update
This commit is contained in:
parent
7fbdf1fe82
commit
3f35eb449d
@ -218,7 +218,7 @@ _C.SOLVER.SWA.LR_SCHED = False
|
|||||||
|
|
||||||
_C.SOLVER.CHECKPOINT_PERIOD = 5000
|
_C.SOLVER.CHECKPOINT_PERIOD = 5000
|
||||||
|
|
||||||
# Number of images per batch
|
# Number of images per batch across all machines.
|
||||||
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
|
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
|
||||||
# see 2 images per batch
|
# see 2 images per batch
|
||||||
_C.SOLVER.IMS_PER_BATCH = 64
|
_C.SOLVER.IMS_PER_BATCH = 64
|
||||||
@ -228,7 +228,9 @@ _C.SOLVER.IMS_PER_BATCH = 64
|
|||||||
_C.TEST = CN()
|
_C.TEST = CN()
|
||||||
|
|
||||||
_C.TEST.EVAL_PERIOD = 50
|
_C.TEST.EVAL_PERIOD = 50
|
||||||
_C.TEST.IMS_PER_BATCH = 128
|
|
||||||
|
# Number of images per batch in one process.
|
||||||
|
_C.TEST.IMS_PER_BATCH = 64
|
||||||
_C.TEST.METRIC = "cosine"
|
_C.TEST.METRIC = "cosine"
|
||||||
|
|
||||||
# Average query expansion
|
# Average query expansion
|
||||||
|
@ -31,17 +31,16 @@ def build_reid_train_loader(cfg):
|
|||||||
train_set = CommDataset(train_items, train_transforms, relabel=True)
|
train_set = CommDataset(train_items, train_transforms, relabel=True)
|
||||||
|
|
||||||
num_workers = cfg.DATALOADER.NUM_WORKERS
|
num_workers = cfg.DATALOADER.NUM_WORKERS
|
||||||
mini_batch_size = cfg.SOLVER.IMS_PER_BATCH
|
|
||||||
num_instance = cfg.DATALOADER.NUM_INSTANCE
|
num_instance = cfg.DATALOADER.NUM_INSTANCE
|
||||||
global_batch_size = mini_batch_size * comm.get_world_size()
|
mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
|
||||||
|
|
||||||
if cfg.DATALOADER.PK_SAMPLER:
|
if cfg.DATALOADER.PK_SAMPLER:
|
||||||
if cfg.DATALOADER.NAIVE_WAY:
|
if cfg.DATALOADER.NAIVE_WAY:
|
||||||
data_sampler = samplers.NaiveIdentitySampler(train_set.img_items,
|
data_sampler = samplers.NaiveIdentitySampler(train_set.img_items,
|
||||||
global_batch_size, num_instance)
|
cfg.SOLVER.IMS_PER_BATCH, num_instance)
|
||||||
else:
|
else:
|
||||||
data_sampler = samplers.BalancedIdentitySampler(train_set.img_items,
|
data_sampler = samplers.BalancedIdentitySampler(train_set.img_items,
|
||||||
global_batch_size, num_instance)
|
cfg.SOLVER.IMS_PER_BATCH, num_instance)
|
||||||
else:
|
else:
|
||||||
data_sampler = samplers.TrainingSampler(len(train_set))
|
data_sampler = samplers.TrainingSampler(len(train_set))
|
||||||
batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
|
batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
|
||||||
|
@ -489,7 +489,7 @@ class DefaultTrainer(SimpleTrainer):
|
|||||||
frozen = cfg.is_frozen()
|
frozen = cfg.is_frozen()
|
||||||
cfg.defrost()
|
cfg.defrost()
|
||||||
|
|
||||||
iters_per_epoch = len(data_loader.dataset) // (cfg.SOLVER.IMS_PER_BATCH * comm.get_world_size())
|
iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
|
||||||
cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
|
cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
|
||||||
cfg.SOLVER.MAX_ITER *= iters_per_epoch
|
cfg.SOLVER.MAX_ITER *= iters_per_epoch
|
||||||
cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
|
cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
|
||||||
@ -502,8 +502,8 @@ class DefaultTrainer(SimpleTrainer):
|
|||||||
cfg.SOLVER.CHECKPOINT_PERIOD *= iters_per_epoch
|
cfg.SOLVER.CHECKPOINT_PERIOD *= iters_per_epoch
|
||||||
|
|
||||||
# Evaluation period must be divided by 200 for writing into tensorboard.
|
# Evaluation period must be divided by 200 for writing into tensorboard.
|
||||||
num_mode = 200 - (cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
|
num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
|
||||||
cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + num_mode
|
cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + num_mod
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -201,13 +201,13 @@ class SimpleTrainer(TrainerBase):
|
|||||||
"""
|
"""
|
||||||
If your want to do something with the heads, you can wrap the model.
|
If your want to do something with the heads, you can wrap the model.
|
||||||
"""
|
"""
|
||||||
outputs = self.model(data)
|
outputs, targets = self.model(data)
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
if isinstance(self.model, DistributedDataParallel):
|
if isinstance(self.model, DistributedDataParallel):
|
||||||
loss_dict = self.model.module.losses(outputs)
|
loss_dict = self.model.module.losses(outputs, targets)
|
||||||
else:
|
else:
|
||||||
loss_dict = self.model.losses(outputs)
|
loss_dict = self.model.losses(outputs, targets)
|
||||||
|
|
||||||
losses = sum(loss_dict.values())
|
losses = sum(loss_dict.values())
|
||||||
self._detect_anomaly(losses, loss_dict)
|
self._detect_anomaly(losses, loss_dict)
|
||||||
|
@ -1,166 +0,0 @@
|
|||||||
# encoding: utf-8
|
|
||||||
"""
|
|
||||||
@author: liaoxingyu
|
|
||||||
@contact: sherlockliao01@gmail.com
|
|
||||||
"""
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from collections import defaultdict
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.backends import cudnn
|
|
||||||
from fastreid.modeling import build_model
|
|
||||||
from fastreid.utils.checkpoint import Checkpointer
|
|
||||||
from fastreid.config import get_cfg
|
|
||||||
|
|
||||||
cudnn.benchmark = True
|
|
||||||
|
|
||||||
|
|
||||||
class Reid(object):
|
|
||||||
|
|
||||||
def __init__(self, config_file):
|
|
||||||
cfg = get_cfg()
|
|
||||||
cfg.merge_from_file(config_file)
|
|
||||||
cfg.defrost()
|
|
||||||
cfg.MODEL.WEIGHTS = 'projects/bjzProject/logs/bjz/arcface_adam/model_final.pth'
|
|
||||||
model = build_model(cfg)
|
|
||||||
Checkpointer(model).resume_or_load(cfg.MODEL.WEIGHTS)
|
|
||||||
|
|
||||||
model.cuda()
|
|
||||||
model.eval()
|
|
||||||
self.model = model
|
|
||||||
# self.model = torch.jit.load("reid_model.pt")
|
|
||||||
# self.model.eval()
|
|
||||||
# self.model.cuda()
|
|
||||||
|
|
||||||
example = torch.rand(1, 3, 256, 128)
|
|
||||||
example = example.cuda()
|
|
||||||
traced_script_module = torch.jit.trace_module(model, {'inference': example})
|
|
||||||
traced_script_module.save("reid_feat_extractor.pt")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def preprocess(cls, img_path):
|
|
||||||
img = cv2.imread(img_path)
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
||||||
img = cv2.resize(img, (128, 256))
|
|
||||||
img = img / 255.0
|
|
||||||
img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
|
|
||||||
img = img.transpose((2, 0, 1)).astype(np.float32)
|
|
||||||
img = img[np.newaxis, :, :, :]
|
|
||||||
data = torch.from_numpy(img).cuda().float()
|
|
||||||
return data
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def demo(self, img_path):
|
|
||||||
data = self.preprocess(img_path)
|
|
||||||
output = self.model.inference(data)
|
|
||||||
feat = output.cpu().data.numpy()
|
|
||||||
return feat
|
|
||||||
|
|
||||||
# @torch.no_grad()
|
|
||||||
# def extract_feat(self, dataloader):
|
|
||||||
# prefetcher = test_data_prefetcher(dataloader)
|
|
||||||
# feats = []
|
|
||||||
# labels = []
|
|
||||||
# batch = prefetcher.next()
|
|
||||||
# num_count = 0
|
|
||||||
# while batch[0] is not None:
|
|
||||||
# img, pid, camid = batch
|
|
||||||
# feat = self.model(img)
|
|
||||||
# feats.append(feat.cpu())
|
|
||||||
# labels.extend(np.asarray(pid))
|
|
||||||
#
|
|
||||||
# # if num_count > 2:
|
|
||||||
# # break
|
|
||||||
# batch = prefetcher.next()
|
|
||||||
# # num_count += 1
|
|
||||||
#
|
|
||||||
# feats = torch.cat(feats, dim=0)
|
|
||||||
# id_feats = defaultdict(list)
|
|
||||||
# for f, i in zip(feats, labels):
|
|
||||||
# id_feats[i].append(f)
|
|
||||||
# all_feats = []
|
|
||||||
# label_names = []
|
|
||||||
# for i in id_feats:
|
|
||||||
# all_feats.append(torch.stack(id_feats[i], dim=0).mean(dim=0))
|
|
||||||
# label_names.append(i)
|
|
||||||
#
|
|
||||||
# label_names = np.asarray(label_names)
|
|
||||||
# all_feats = torch.stack(all_feats, dim=0) # (n, 2048)
|
|
||||||
# all_feats = F.normalize(all_feats, p=2, dim=1)
|
|
||||||
# np.save('feats.npy', all_feats.cpu())
|
|
||||||
# np.save('labels.npy', label_names)
|
|
||||||
# cos = torch.mm(all_feats, all_feats.t()).numpy() # (n, n)
|
|
||||||
# cos -= np.eye(all_feats.shape[0])
|
|
||||||
# f = open('check_cross_folder_similarity.txt', 'w')
|
|
||||||
# for i in range(len(label_names)):
|
|
||||||
# sim_indx = np.argwhere(cos[i] > 0.5)[:, 0]
|
|
||||||
# sim_name = label_names[sim_indx]
|
|
||||||
# write_str = label_names[i] + ' '
|
|
||||||
# # f.write(label_names[i]+'\t')
|
|
||||||
# for n in sim_name:
|
|
||||||
# write_str += (n + ' ')
|
|
||||||
# # f.write(n+'\t')
|
|
||||||
# f.write(write_str+'\n')
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# def prepare_gt(self, json_file):
|
|
||||||
# feat = []
|
|
||||||
# label = []
|
|
||||||
# with open(json_file, 'r') as f:
|
|
||||||
# total = json.load(f)
|
|
||||||
# for index in total:
|
|
||||||
# label.append(index)
|
|
||||||
# feat.append(np.array(total[index]))
|
|
||||||
# time_label = [int(i[0:10]) for i in label]
|
|
||||||
#
|
|
||||||
# return np.array(feat), np.array(label), np.array(time_label)
|
|
||||||
|
|
||||||
def compute_topk(self, k, feat, feats, label):
|
|
||||||
|
|
||||||
# num_gallery = feats.shape[0]
|
|
||||||
# new_feat = np.tile(feat,[num_gallery,1])
|
|
||||||
norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
|
|
||||||
norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
|
|
||||||
matrix = np.sum(np.multiply(feat, feats), axis=-1)
|
|
||||||
dist = matrix / np.multiply(norm_feat, norm_feats)
|
|
||||||
# print('feat:',feat.shape)
|
|
||||||
# print('feats:',feats.shape)
|
|
||||||
# print('label:',label.shape)
|
|
||||||
# print('dist:',dist.shape)
|
|
||||||
|
|
||||||
index = np.argsort(-dist)
|
|
||||||
|
|
||||||
# print('index:',index.shape)
|
|
||||||
result = []
|
|
||||||
for i in range(min(feats.shape[0], k)):
|
|
||||||
print(dist[index[i]])
|
|
||||||
result.append(label[index[i]])
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
reid_sys = Reid(config_file='../../projects/bjzProject/configs/bjz.yml')
|
|
||||||
img_path = '/export/home/lxy/beijingStationReID/reid_model/demo_imgs/003740_c5s2_1561733125170.000000.jpg'
|
|
||||||
feat = reid_sys.demo(img_path)
|
|
||||||
feat_extractor = torch.jit.load('reid_feat_extractor.pt')
|
|
||||||
data = reid_sys.preprocess(img_path)
|
|
||||||
feat2 = feat_extractor.inference(data)
|
|
||||||
from ipdb import set_trace; set_trace()
|
|
||||||
# imgs = os.listdir(img_path)
|
|
||||||
# feats = {}
|
|
||||||
# for i in range(len(imgs)):
|
|
||||||
# feat = reid.demo(os.path.join(img_path, imgs[i]))
|
|
||||||
# feats[imgs[i]] = feat
|
|
||||||
# feat = reid.demo(os.path.join(img_path, 'crop_img0.jpg'))
|
|
||||||
# out1 = feats['dog.jpg']
|
|
||||||
# out2 = feats['kobe2.jpg']
|
|
||||||
# innerProduct = np.dot(out1, out2.T)
|
|
||||||
# cosineSimilarity = innerProduct / (np.linalg.norm(out1, ord=2) * np.linalg.norm(out2, ord=2))
|
|
||||||
# print(f'cosine similarity is {cosineSimilarity[0][0]:.4f}')
|
|
@ -43,19 +43,14 @@ class BNneckHead(nn.Module):
|
|||||||
if not self.training: return bn_feat
|
if not self.training: return bn_feat
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
try:
|
try: cls_outputs = self.classifier(bn_feat)
|
||||||
cls_outputs = self.classifier(bn_feat)
|
except TypeError: cls_outputs = self.classifier(bn_feat, targets)
|
||||||
pred_class_logits = cls_outputs.detach()
|
|
||||||
except TypeError:
|
|
||||||
cls_outputs = self.classifier(bn_feat, targets)
|
|
||||||
pred_class_logits = F.linear(F.normalize(bn_feat.detach()), F.normalize(self.classifier.weight.detach()))
|
|
||||||
# Log prediction accuracy
|
|
||||||
CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
|
|
||||||
|
|
||||||
if self.neck_feat == "before":
|
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
|
||||||
feat = global_feat[..., 0, 0]
|
|
||||||
elif self.neck_feat == "after":
|
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
|
||||||
feat = bn_feat
|
elif self.neck_feat == "after": feat = bn_feat
|
||||||
else:
|
else:
|
||||||
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
|
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
|
||||||
return cls_outputs, feat
|
|
||||||
|
return cls_outputs, pred_class_logits, feat
|
||||||
|
@ -38,13 +38,9 @@ class LinearHead(nn.Module):
|
|||||||
if not self.training: return global_feat
|
if not self.training: return global_feat
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
try:
|
try: cls_outputs = self.classifier(global_feat)
|
||||||
cls_outputs = self.classifier(global_feat)
|
except TypeError: cls_outputs = self.classifier(global_feat, targets)
|
||||||
pred_class_logits = cls_outputs.detach()
|
|
||||||
except TypeError:
|
|
||||||
cls_outputs = self.classifier(global_feat, targets)
|
|
||||||
pred_class_logits = F.linear(F.normalize(global_feat.detach()), F.normalize(self.classifier.weight.detach()))
|
|
||||||
# Log prediction accuracy
|
|
||||||
CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
|
|
||||||
|
|
||||||
return cls_outputs, global_feat
|
pred_class_logits = F.linear(global_feat, self.classifier.weight)
|
||||||
|
|
||||||
|
return cls_outputs, pred_class_logits, global_feat
|
||||||
|
@ -55,19 +55,15 @@ class ReductionHead(nn.Module):
|
|||||||
if not self.training: return bn_feat
|
if not self.training: return bn_feat
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
try:
|
try: cls_outputs = self.classifier(bn_feat)
|
||||||
cls_outputs = self.classifier(bn_feat)
|
except TypeError: cls_outputs = self.classifier(bn_feat, targets)
|
||||||
pred_class_logits = cls_outputs.detach()
|
|
||||||
except TypeError:
|
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
|
||||||
cls_outputs = self.classifier(bn_feat, targets)
|
|
||||||
pred_class_logits = F.linear(F.normalize(bn_feat.detach()), F.normalize(self.classifier.weight.detach()))
|
|
||||||
# Log prediction accuracy
|
|
||||||
CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
|
|
||||||
|
|
||||||
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
|
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
|
||||||
elif self.neck_feat == "after": feat = bn_feat
|
elif self.neck_feat == "after": feat = bn_feat
|
||||||
else:
|
else:
|
||||||
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
|
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
|
||||||
|
|
||||||
return cls_outputs, feat
|
return cls_outputs, pred_class_logits, feat
|
||||||
|
|
||||||
|
@ -58,11 +58,9 @@ class Baseline(nn.Module):
|
|||||||
# throw an error. We just set all the targets to 0 to avoid this problem.
|
# throw an error. We just set all the targets to 0 to avoid this problem.
|
||||||
if targets.sum() < 0: targets.zero_()
|
if targets.sum() < 0: targets.zero_()
|
||||||
|
|
||||||
cls_outputs, features = self.heads(features, targets)
|
return self.heads(features, targets), targets
|
||||||
return cls_outputs, features, targets
|
|
||||||
else:
|
else:
|
||||||
pred_features = self.heads(features)
|
return self.heads(features)
|
||||||
return pred_features
|
|
||||||
|
|
||||||
def preprocess_image(self, batched_inputs):
|
def preprocess_image(self, batched_inputs):
|
||||||
"""
|
"""
|
||||||
@ -73,22 +71,22 @@ class Baseline(nn.Module):
|
|||||||
images.sub_(self.pixel_mean).div_(self.pixel_std)
|
images.sub_(self.pixel_mean).div_(self.pixel_std)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def losses(self, outputs):
|
def losses(self, outputs, gt_labels):
|
||||||
r"""
|
r"""
|
||||||
Compute loss from modeling's outputs, the loss function input arguments
|
Compute loss from modeling's outputs, the loss function input arguments
|
||||||
must be the same as the outputs of the model forwarding.
|
must be the same as the outputs of the model forwarding.
|
||||||
"""
|
"""
|
||||||
cls_outputs, pred_features, gt_labels = outputs
|
cls_outputs, pred_class_logits, pred_features = outputs
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
loss_names = self._cfg.MODEL.LOSSES.NAME
|
loss_names = self._cfg.MODEL.LOSSES.NAME
|
||||||
|
|
||||||
|
# Log prediction accuracy
|
||||||
|
CrossEntropyLoss.log_accuracy(pred_class_logits.detach(), gt_labels)
|
||||||
|
|
||||||
if "CrossEntropyLoss" in loss_names:
|
if "CrossEntropyLoss" in loss_names:
|
||||||
loss_dict['loss_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)
|
loss_dict['loss_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)
|
||||||
|
|
||||||
if "TripletLoss" in loss_names:
|
if "TripletLoss" in loss_names:
|
||||||
loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels)
|
loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels)
|
||||||
|
|
||||||
if "CircleLoss" in loss_names:
|
|
||||||
loss_dict['loss_circle'] = CircleLoss(self._cfg)(pred_features, gt_labels)
|
|
||||||
|
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
@ -153,31 +153,21 @@ class MGN(nn.Module):
|
|||||||
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
|
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
|
||||||
targets = batched_inputs["targets"].long().to(self.device)
|
targets = batched_inputs["targets"].long().to(self.device)
|
||||||
|
|
||||||
if targets.sum() < 0:
|
if targets.sum() < 0: targets.zero_()
|
||||||
targets.zero_()
|
|
||||||
self.b1_head(b1_pool_feat, targets)
|
|
||||||
self.b2_head(b2_pool_feat, targets)
|
|
||||||
self.b21_head(b21_pool_feat, targets)
|
|
||||||
self.b22_head(b22_pool_feat, targets)
|
|
||||||
self.b3_head(b3_pool_feat, targets)
|
|
||||||
self.b31_head(b31_pool_feat, targets)
|
|
||||||
self.b32_head(b32_pool_feat, targets)
|
|
||||||
self.b33_head(b33_pool_feat, targets)
|
|
||||||
return
|
|
||||||
|
|
||||||
b1_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets)
|
b1_logits, pred_class_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets)
|
||||||
b2_logits, b2_pool_feat = self.b2_head(b2_pool_feat, targets)
|
b2_logits, _, b2_pool_feat = self.b2_head(b2_pool_feat, targets)
|
||||||
b21_logits, b21_pool_feat = self.b21_head(b21_pool_feat, targets)
|
b21_logits, _, b21_pool_feat = self.b21_head(b21_pool_feat, targets)
|
||||||
b22_logits, b22_pool_feat = self.b22_head(b22_pool_feat, targets)
|
b22_logits, _, b22_pool_feat = self.b22_head(b22_pool_feat, targets)
|
||||||
b3_logits, b3_pool_feat = self.b3_head(b3_pool_feat, targets)
|
b3_logits, _, b3_pool_feat = self.b3_head(b3_pool_feat, targets)
|
||||||
b31_logits, b31_pool_feat = self.b31_head(b31_pool_feat, targets)
|
b31_logits, _, b31_pool_feat = self.b31_head(b31_pool_feat, targets)
|
||||||
b32_logits, b32_pool_feat = self.b32_head(b32_pool_feat, targets)
|
b32_logits, _, b32_pool_feat = self.b32_head(b32_pool_feat, targets)
|
||||||
b33_logits, b33_pool_feat = self.b33_head(b33_pool_feat, targets)
|
b33_logits, _, b33_pool_feat = self.b33_head(b33_pool_feat, targets)
|
||||||
return b1_logits, b2_logits, b21_logits, b22_logits, \
|
|
||||||
b3_logits, b31_logits, b32_logits, b33_logits, \
|
return (b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits,
|
||||||
b1_pool_feat, b2_pool_feat, b3_pool_feat, \
|
b1_pool_feat, b2_pool_feat, b3_pool_feat,
|
||||||
torch.cat((b21_pool_feat, b22_pool_feat), dim=1), \
|
torch.cat((b21_pool_feat, b22_pool_feat), dim=1),
|
||||||
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1), targets
|
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1), pred_class_logits), targets
|
||||||
|
|
||||||
else:
|
else:
|
||||||
b1_pool_feat = self.b1_head(b1_pool_feat)
|
b1_pool_feat = self.b1_head(b1_pool_feat)
|
||||||
@ -202,12 +192,16 @@ class MGN(nn.Module):
|
|||||||
images.sub_(self.pixel_mean).div_(self.pixel_std)
|
images.sub_(self.pixel_mean).div_(self.pixel_std)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def losses(self, outputs):
|
def losses(self, outputs, gt_labels):
|
||||||
b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits, \
|
b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits, \
|
||||||
b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat, gt_labels = outputs
|
b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat, pred_class_logits = outputs
|
||||||
|
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
loss_names = self._cfg.MODEL.LOSSES.NAME
|
loss_names = self._cfg.MODEL.LOSSES.NAME
|
||||||
|
|
||||||
|
# Log prediction accuracy
|
||||||
|
CrossEntropyLoss.log_accuracy(pred_class_logits.detach(), gt_labels)
|
||||||
|
|
||||||
if "CrossEntropyLoss" in loss_names:
|
if "CrossEntropyLoss" in loss_names:
|
||||||
loss_dict['loss_cls_b1'] = CrossEntropyLoss(self._cfg)(b1_logits, gt_labels)
|
loss_dict['loss_cls_b1'] = CrossEntropyLoss(self._cfg)(b1_logits, gt_labels)
|
||||||
loss_dict['loss_cls_b2'] = CrossEntropyLoss(self._cfg)(b2_logits, gt_labels)
|
loss_dict['loss_cls_b2'] = CrossEntropyLoss(self._cfg)(b2_logits, gt_labels)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user