Source code for torchreid.engine.video.softmax

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import time
import datetime

import torch

import torchreid
from torchreid.engine.image import ImageSoftmaxEngine


[docs]class VideoSoftmaxEngine(ImageSoftmaxEngine): """Softmax-loss engine for video-reid. Args: datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager`` or ``torchreid.data.VideoDataManager``. model (nn.Module): model instance. optimizer (Optimizer): an Optimizer. scheduler (LRScheduler, optional): if None, no learning rate decay will be performed. use_cpu (bool, optional): use cpu. Default is False. label_smooth (bool, optional): use label smoothing regularizer. Default is True. pooling_method (str, optional): how to pool features for a tracklet. Default is "avg" (average). Choices are ["avg", "max"]. Examples:: import torch import torchreid # Each batch contains batch_size*seq_len images datamanager = torchreid.data.VideoDataManager( root='path/to/reid-data', sources='mars', height=256, width=128, combineall=False, batch_size=8, # number of tracklets seq_len=15 # number of images in each tracklet ) model = torchreid.models.build_model( name='resnet50', num_classes=datamanager.num_train_pids, loss='softmax' ) model = model.cuda() optimizer = torchreid.optim.build_optimizer( model, optim='adam', lr=0.0003 ) scheduler = torchreid.optim.build_lr_scheduler( optimizer, lr_scheduler='single_step', stepsize=20 ) engine = torchreid.engine.VideoSoftmaxEngine( datamanager, model, optimizer, scheduler=scheduler, pooling_method='avg' ) engine.run( max_epoch=60, save_dir='log/resnet50-softmax-mars', print_freq=10 ) """ def __init__(self, datamanager, model, optimizer, scheduler=None, use_cpu=False, label_smooth=True, pooling_method='avg'): super(VideoSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler=scheduler, use_cpu=use_cpu, label_smooth=label_smooth) self.pooling_method = pooling_method def _parse_data_for_train(self, data): imgs = data[0] pids = data[1] if imgs.dim() == 5: # b: batch size # s: sqeuence length # c: channel depth # h: height # w: width b, s, c, h, w = imgs.size() imgs = imgs.view(b*s, c, h, w) pids = pids.view(b, 1).expand(b, s) pids = pids.contiguous().view(b*s) return imgs, pids def _extract_features(self, input): self.model.eval() # b: batch size # s: sqeuence length # c: channel depth # h: height # w: width b, s, c, h, w = input.size() input = input.view(b*s, c, h, w) features = self.model(input) features = features.view(b, s, -1) if self.pooling_method == 'avg': features = torch.mean(features, 1) else: features = torch.max(features, 1)[0] return features