bug fix in VideoTripletEngine
parent
069356a6e4
commit
5656aaba85
|
@ -1,10 +1,10 @@
|
||||||
from __future__ import division, print_function, absolute_import
|
from __future__ import division, print_function, absolute_import
|
||||||
|
import torch
|
||||||
|
|
||||||
from torchreid.engine.image import ImageTripletEngine
|
from torchreid.engine.image import ImageTripletEngine
|
||||||
from torchreid.engine.video import VideoSoftmaxEngine
|
|
||||||
|
|
||||||
|
|
||||||
class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
|
class VideoTripletEngine(ImageTripletEngine):
|
||||||
"""Triplet-loss engine for video-reid.
|
"""Triplet-loss engine for video-reid.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -89,3 +89,34 @@ class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
|
||||||
label_smooth=label_smooth
|
label_smooth=label_smooth
|
||||||
)
|
)
|
||||||
self.pooling_method = pooling_method
|
self.pooling_method = pooling_method
|
||||||
|
|
||||||
|
def parse_data_for_train(self, data):
|
||||||
|
imgs = data['img']
|
||||||
|
pids = data['pid']
|
||||||
|
if imgs.dim() == 5:
|
||||||
|
# b: batch size
|
||||||
|
# s: sqeuence length
|
||||||
|
# c: channel depth
|
||||||
|
# h: height
|
||||||
|
# w: width
|
||||||
|
b, s, c, h, w = imgs.size()
|
||||||
|
imgs = imgs.view(b * s, c, h, w)
|
||||||
|
pids = pids.view(b, 1).expand(b, s)
|
||||||
|
pids = pids.contiguous().view(b * s)
|
||||||
|
return imgs, pids
|
||||||
|
|
||||||
|
def extract_features(self, input):
|
||||||
|
# b: batch size
|
||||||
|
# s: sqeuence length
|
||||||
|
# c: channel depth
|
||||||
|
# h: height
|
||||||
|
# w: width
|
||||||
|
b, s, c, h, w = input.size()
|
||||||
|
input = input.view(b * s, c, h, w)
|
||||||
|
features = self.model(input)
|
||||||
|
features = features.view(b, s, -1)
|
||||||
|
if self.pooling_method == 'avg':
|
||||||
|
features = torch.mean(features, 1)
|
||||||
|
else:
|
||||||
|
features = torch.max(features, 1)[0]
|
||||||
|
return features
|
||||||
|
|
Loading…
Reference in New Issue