From 5656aaba85b3815d69624ecfb33efb1bd888acd6 Mon Sep 17 00:00:00 2001 From: KaiyangZhou Date: Mon, 3 Aug 2020 16:17:49 +0100 Subject: [PATCH] bug fix in VideoTripletEngine --- torchreid/engine/video/triplet.py | 35 +++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/torchreid/engine/video/triplet.py b/torchreid/engine/video/triplet.py index 5d43b8b..b2778db 100644 --- a/torchreid/engine/video/triplet.py +++ b/torchreid/engine/video/triplet.py @@ -1,10 +1,10 @@ from __future__ import division, print_function, absolute_import +import torch from torchreid.engine.image import ImageTripletEngine -from torchreid.engine.video import VideoSoftmaxEngine -class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine): +class VideoTripletEngine(ImageTripletEngine): """Triplet-loss engine for video-reid. Args: @@ -89,3 +89,34 @@ class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine): label_smooth=label_smooth ) self.pooling_method = pooling_method + + def parse_data_for_train(self, data): + imgs = data['img'] + pids = data['pid'] + if imgs.dim() == 5: + # b: batch size + # s: sequence length + # c: channel depth + # h: height + # w: width + b, s, c, h, w = imgs.size() + imgs = imgs.view(b * s, c, h, w) + pids = pids.view(b, 1).expand(b, s) + pids = pids.contiguous().view(b * s) + return imgs, pids + + def extract_features(self, input): + # b: batch size + # s: sequence length + # c: channel depth + # h: height + # w: width + b, s, c, h, w = input.size() + input = input.view(b * s, c, h, w) + features = self.model(input) + features = features.view(b, s, -1) + if self.pooling_method == 'avg': + features = torch.mean(features, 1) + else: + features = torch.max(features, 1)[0] + return features