bug fix in VideoTripletEngine
parent
069356a6e4
commit
5656aaba85
|
@ -1,10 +1,10 @@
|
|||
from __future__ import division, print_function, absolute_import
|
||||
import torch
|
||||
|
||||
from torchreid.engine.image import ImageTripletEngine
|
||||
from torchreid.engine.video import VideoSoftmaxEngine
|
||||
|
||||
|
||||
class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
|
||||
class VideoTripletEngine(ImageTripletEngine):
|
||||
"""Triplet-loss engine for video-reid.
|
||||
|
||||
Args:
|
||||
|
@ -89,3 +89,34 @@ class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
|
|||
label_smooth=label_smooth
|
||||
)
|
||||
self.pooling_method = pooling_method
|
||||
|
||||
def parse_data_for_train(self, data):
|
||||
imgs = data['img']
|
||||
pids = data['pid']
|
||||
if imgs.dim() == 5:
|
||||
# b: batch size
|
||||
# s: sqeuence length
|
||||
# c: channel depth
|
||||
# h: height
|
||||
# w: width
|
||||
b, s, c, h, w = imgs.size()
|
||||
imgs = imgs.view(b * s, c, h, w)
|
||||
pids = pids.view(b, 1).expand(b, s)
|
||||
pids = pids.contiguous().view(b * s)
|
||||
return imgs, pids
|
||||
|
||||
def extract_features(self, input):
|
||||
# b: batch size
|
||||
# s: sqeuence length
|
||||
# c: channel depth
|
||||
# h: height
|
||||
# w: width
|
||||
b, s, c, h, w = input.size()
|
||||
input = input.view(b * s, c, h, w)
|
||||
features = self.model(input)
|
||||
features = features.view(b, s, -1)
|
||||
if self.pooling_method == 'avg':
|
||||
features = torch.mean(features, 1)
|
||||
else:
|
||||
features = torch.max(features, 1)[0]
|
||||
return features
|
||||
|
|
Loading…
Reference in New Issue