bug fix in VideoTripletEngine

pull/405/head
KaiyangZhou 2020-08-03 16:17:49 +01:00
parent 069356a6e4
commit 5656aaba85
1 changed files with 33 additions and 2 deletions

View File

@ -1,10 +1,10 @@
from __future__ import division, print_function, absolute_import
import torch
from torchreid.engine.image import ImageTripletEngine
from torchreid.engine.video import VideoSoftmaxEngine
class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
class VideoTripletEngine(ImageTripletEngine):
"""Triplet-loss engine for video-reid.
Args:
@ -89,3 +89,34 @@ class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
label_smooth=label_smooth
)
self.pooling_method = pooling_method
def parse_data_for_train(self, data):
imgs = data['img']
pids = data['pid']
if imgs.dim() == 5:
# b: batch size
# s: sqeuence length
# c: channel depth
# h: height
# w: width
b, s, c, h, w = imgs.size()
imgs = imgs.view(b * s, c, h, w)
pids = pids.view(b, 1).expand(b, s)
pids = pids.contiguous().view(b * s)
return imgs, pids
def extract_features(self, input):
# b: batch size
# s: sqeuence length
# c: channel depth
# h: height
# w: width
b, s, c, h, w = input.size()
input = input.view(b * s, c, h, w)
features = self.model(input)
features = features.view(b, s, -1)
if self.pooling_method == 'avg':
features = torch.mean(features, 1)
else:
features = torch.max(features, 1)[0]
return features