2019-03-20 01:26:08 +08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import print_function
|
|
|
|
from __future__ import division
|
|
|
|
|
|
|
|
import time
|
|
|
|
import datetime
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
import torchreid
|
|
|
|
from torchreid.engine.image import ImageSoftmaxEngine
|
|
|
|
|
|
|
|
|
|
|
|
class VideoSoftmaxEngine(ImageSoftmaxEngine):
    """Softmax-loss engine variant for video (clip) inputs.

    Reuses ``ImageSoftmaxEngine`` and overrides two hooks so that 5-D clip
    tensors laid out as ``(batch, seq_len, channels, height, width)`` are
    handled:

    * ``_parse_data_for_train`` flattens each clip into individual frames
      and repeats the clip's person ID for every frame.
    * ``_extract_features`` runs the model on every frame and pools the
      per-frame features over the sequence axis.

    Args:
        datamanager: forwarded unchanged to ``ImageSoftmaxEngine``.
        model: forwarded unchanged to ``ImageSoftmaxEngine``.
        optimizer: forwarded unchanged to ``ImageSoftmaxEngine``.
        scheduler (optional): forwarded to ``ImageSoftmaxEngine``.
        use_cpu (bool): forwarded to ``ImageSoftmaxEngine``.
        label_smooth (bool): forwarded to ``ImageSoftmaxEngine``.
        pooling_method (str): ``'avg'`` averages frame features over the
            sequence; any other value takes the element-wise maximum.
    """

    def __init__(self, datamanager, model, optimizer, scheduler=None,
                 use_cpu=False, label_smooth=True, pooling_method='avg'):
        super(VideoSoftmaxEngine, self).__init__(
            datamanager, model, optimizer, scheduler=scheduler,
            use_cpu=use_cpu, label_smooth=label_smooth)
        # Read by _extract_features: 'avg' -> temporal mean, otherwise max.
        self.pooling_method = pooling_method

    def _parse_data_for_train(self, data):
        """Unpack a training batch and flatten video clips into frames.

        Args:
            data: indexable batch; ``data[0]`` holds the images and
                ``data[1]`` the person IDs.

        Returns:
            tuple: ``(imgs, pids)``. If ``imgs`` is 5-D ``(b, s, c, h, w)``
            it becomes ``(b*s, c, h, w)``, and ``pids`` (one per clip) is
            repeated ``s`` times so every frame carries its clip's label.
            Any other input is returned unchanged.
        """
        imgs = data[0]
        pids = data[1]
        if imgs.dim() == 5:
            # b: batch size, s: sequence length, c: channel depth,
            # h: height, w: width
            b, s, c, h, w = imgs.size()
            # reshape (rather than view) also accepts non-contiguous clip
            # tensors; for contiguous input it is still a zero-copy view.
            imgs = imgs.reshape(b * s, c, h, w)
            # (b,) -> (b, s) -> (b*s,): label order matches the row-major
            # frame order produced by flattening imgs above.
            pids = pids.reshape(b, 1).expand(b, s).reshape(b * s)
        return imgs, pids

    def _extract_features(self, input):
        """Compute one pooled feature vector per video clip.

        Switches the model to eval mode (and does not switch it back), runs
        it on every frame of the clip batch, then pools the per-frame
        features over the sequence axis.

        Args:
            input: clip tensor of shape ``(b, s, c, h, w)``: batch size,
                sequence length, channel depth, height, width.

        Returns:
            Tensor of shape ``(b, d)``, where ``d`` is the flattened size of
            the model's per-frame output. Mean-pooled over ``s`` when
            ``self.pooling_method == 'avg'``, max-pooled otherwise.
        """
        self.model.eval()
        b, s, c, h, w = input.size()
        # Fold the sequence into the batch so the image model sees frames.
        frames = input.reshape(b * s, c, h, w)
        features = self.model(frames)
        # Regroup per-frame features by clip; reshape tolerates a
        # non-contiguous model output where view would raise.
        features = features.reshape(b, s, -1)
        if self.pooling_method == 'avg':
            return torch.mean(features, 1)
        # NOTE(review): any value other than 'avg' (including typos such as
        # 'average') silently falls through to max pooling.
        return torch.max(features, 1)[0]
|