mirror of https://github.com/FoundationVision/GLEE
update track loss for video finetune
parent 5ad764cc18
commit 5f1832dd7a
@@ -618,3 +618,92 @@ class GLEE_Model(nn.Module):
        return ref_feats, ref_masks

    def get_tracking_contrastive_lossv3(self, video_outputs, video_targets, task):  # IDOL track loss
        if task in self.no_mask_tasks:
            indices_all = self.matcher(video_outputs, video_targets, task, cost=["cls", "box"])
        else:
            indices_all = self.matcher(video_outputs, video_targets, task)

        video_len = self.video_info['len']
        track_loss = 0
        num_inst = 0

        batch_similarity = []
        batch_label = []
        for i in range(self.video_info['bz']):  # slice out the frames belonging to each clip in the batch
            indices = indices_all[i * video_len:(i + 1) * video_len]
            bz_embedding = video_outputs['pred_track_embed'][i * video_len:(i + 1) * video_len]
            bz_target = video_targets[i * video_len:(i + 1) * video_len]
            zero = torch.tensor(0).to(bz_embedding.device)
            video_contras = {}
            for f, (findice, fembed, ftarget) in enumerate(zip(indices, bz_embedding, bz_target)):
                vf_embed_k = fembed[findice[0]]  # embeddings of the queries matched to GT
                if len(vf_embed_k.shape) == 1:
                    vf_embed_k = vf_embed_k.unsqueeze(0)
                vf_gt_id_k = ftarget['inst_id'][findice[1]]

                # negative samples: randomly drawn query slots that were not matched to any GT
                sampled_index = set(random.sample(range(300), 20))
                neg_index = list(sampled_index - set(findice[0].tolist()))
                vf_embed_neg = fembed[neg_index]
                vf_embed = torch.cat([vf_embed_k, vf_embed_neg], dim=0)
                # negatives get the filler id -2, which never matches a real instance id
                vf_gt_id = torch.cat([vf_gt_id_k, zero.repeat(len(neg_index)) - 2], dim=0)

                video_contras[f] = (vf_embed, vf_gt_id)

                if f > 0:
                    num_inst += len(ftarget['inst_id'])
                    similarity_matrix = torch.einsum("ac,bc->ab", video_contras[f - 1][0], vf_embed_k)  # [num_prev, num_gt]

                    v0_gt_id_m = video_contras[f - 1][1].unsqueeze(-1).repeat(1, len(vf_gt_id_k))
                    v1_gt_id_m = vf_gt_id_k.unsqueeze(0).repeat(len(video_contras[f - 1][1]), 1)
                    similarity_label = (v0_gt_id_m == v1_gt_id_m).float()  # can be treated as a one-hot label
                    # use focal loss instead of contrastive
                    # aux cosine
                    # aux_contrastive_embed = nn.functional.normalize(video_contras[f - 1][0].float(), dim=1)
                    # key_embed_i = nn.functional.normalize(vf_embed_k.float(), dim=1)
                    # cosine = torch.einsum('nc,kc->nk', [aux_contrastive_embed, key_embed_i])
                    # batch_similarity_aux.append(cosine.flatten())
                    batch_similarity.append(similarity_matrix.flatten())
                    batch_label.append(similarity_label.flatten())

        if len(batch_similarity) == 0 or torch.cat(batch_similarity).shape[0] == 0:
            # no valid pairs: return a zero loss that still touches the graph
            track_loss = (video_outputs['pred_track_embed'] * 0).sum()
        else:
            contras_loss = 0
            for pred, label in zip(batch_similarity, batch_label):
                if len(pred) == 0:
                    continue
                pred = pred.unsqueeze(0)
                label = label.unsqueeze(0)
                # aux_pred = aux_pred.unsqueeze(0)

                pos_inds = (label == 1)
                neg_inds = (label == 0)
                pred_pos = pred * pos_inds.float()
                pred_neg = pred * neg_inds.float()
                # use +/-inf to mask out unwanted elements
                pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf')
                pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf')
                _pos_expand = torch.repeat_interleave(pred_pos, pred.shape[1], dim=1)
                _neg_expand = pred_neg.repeat(1, pred.shape[1])
                # [bz, N]: N covers all pos/neg samples on the reference frame; label marks pos vs neg
                x = torch.nn.functional.pad((_neg_expand - _pos_expand), (0, 1), "constant", 0)
                contras_loss += torch.logsumexp(x, dim=1)

            # track_loss = (contras_loss + 1.5 * aux_loss)
            track_loss = contras_loss / max(num_inst, 1)

        # per-batch normalization is disabled: track_loss = track_loss / self.video_info['bz']
        return track_loss
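
A note on the negative sampling above: each frame draws 20 of the 300 object-query slots at random and removes any slot that was matched to a ground-truth instance, so the surviving embeddings act as negatives. A minimal standalone sketch of just that step (the 300/20 constants mirror the code; `matched` is a hypothetical stand-in for `findice[0]`):

import random
import torch

num_queries, num_sampled = 300, 20
fembed = torch.randn(num_queries, 256)   # per-frame track embeddings, one row per query
matched = torch.tensor([3, 17, 42])      # hypothetical matched query indices (stand-in for findice[0])

sampled_index = set(random.sample(range(num_queries), num_sampled))
neg_index = list(sampled_index - set(matched.tolist()))  # exclude matched slots
vf_embed_neg = fembed[neg_index]                         # at most 20 negative embeddings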
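
The supervision matrix is built purely from instance ids: a reference-frame embedding and a key-frame embedding form a positive pair exactly when they carry the same `inst_id`, and the filler id -2 assigned to negatives can never equal a real id. A self-contained sketch of the label construction with toy ids:

import torch

prev_ids = torch.tensor([5, 9, -2, -2])  # reference frame: two GT ids plus two negatives (filler id -2)
curr_ids = torch.tensor([9, 5])          # key frame GT ids

v0 = prev_ids.unsqueeze(-1).repeat(1, len(curr_ids))  # [4, 2]
v1 = curr_ids.unsqueeze(0).repeat(len(prev_ids), 1)   # [4, 2]
label = (v0 == v1).float()
# label == [[0., 1.],
#           [1., 0.],
#           [0., 0.],
#           [0., 0.]]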
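
The final reduction is the multi-positive contrastive loss from QDTrack/IDOL: it sums exp(neg − pos) over every (positive, negative) score pair, the ±inf writes mask out the unwanted pairings, and the zero pad contributes the +1 inside log(1 + Σ exp(neg − pos)). A compact sketch of the same computation on toy scores (values are illustrative only):

import torch

pred = torch.tensor([[2.0, 0.5, -1.0]])   # similarity scores; the first entry is the positive
label = torch.tensor([[1.0, 0.0, 0.0]])

pos_inds = (label == 1)
neg_inds = (label == 0)
pred_pos = pred * pos_inds.float()
pred_neg = pred * neg_inds.float()
pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf')   # (neg - inf) -> -inf, drops out of logsumexp
pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf')  # likewise for positive slots on the neg side

_pos_expand = torch.repeat_interleave(pred_pos, pred.shape[1], dim=1)  # [p0,p0,p0, p1,...]
_neg_expand = pred_neg.repeat(1, pred.shape[1])                        # [n0,n1,n2, n0,...]
x = torch.nn.functional.pad(_neg_expand - _pos_expand, (0, 1), "constant", 0)  # pad 0 => the "+1" term
loss = torch.logsumexp(x, dim=1)  # log(1 + e^{0.5-2} + e^{-1-2}) ≈ 0.2413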