deep-person-reid/torchreid/engine/video/triplet.py

92 lines
3.1 KiB
Python
Raw Normal View History

2019-12-01 10:35:44 +08:00
from __future__ import division, print_function, absolute_import
2019-03-20 01:26:08 +08:00
from torchreid.engine.image import ImageTripletEngine
from torchreid.engine.video import VideoSoftmaxEngine
class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
    """Triplet-loss engine for video-reid.

    Args:
        datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager``
            or ``torchreid.data.VideoDataManager``.
        model (nn.Module): model instance.
        optimizer (Optimizer): an Optimizer.
        margin (float, optional): margin for triplet loss. Default is 0.3.
        weight_t (float, optional): weight for triplet loss. Default is 1.
        weight_x (float, optional): weight for softmax loss. Default is 1.
        scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
        use_gpu (bool, optional): use gpu. Default is True.
        label_smooth (bool, optional): use label smoothing regularizer. Default is True.
        pooling_method (str, optional): how to pool features for a tracklet.
            Default is "avg" (average). Choices are ["avg", "max"].

    Examples::

        import torch
        import torchreid
        # Each batch contains batch_size_train*seq_len images
        # Each identity is sampled with num_instances tracklets
        datamanager = torchreid.data.VideoDataManager(
            root='path/to/reid-data',
            sources='mars',
            height=256,
            width=128,
            combineall=False,
            num_instances=4,
            train_sampler='RandomIdentitySampler',
            batch_size_train=8,  # number of tracklets
            seq_len=15  # number of images in each tracklet
        )
        model = torchreid.models.build_model(
            name='resnet50',
            num_classes=datamanager.num_train_pids,
            loss='triplet'
        )
        model = model.cuda()
        optimizer = torchreid.optim.build_optimizer(
            model, optim='adam', lr=0.0003
        )
        scheduler = torchreid.optim.build_lr_scheduler(
            optimizer,
            lr_scheduler='single_step',
            stepsize=20
        )
        engine = torchreid.engine.VideoTripletEngine(
            datamanager, model, optimizer, margin=0.3,
            weight_t=0.7, weight_x=1, scheduler=scheduler,
            pooling_method='avg'
        )
        engine.run(
            max_epoch=60,
            save_dir='log/resnet50-triplet-mars',
            print_freq=10
        )
    """

    def __init__(
        self,
        datamanager,
        model,
        optimizer,
        margin=0.3,
        weight_t=1,
        weight_x=1,
        scheduler=None,
        use_gpu=True,
        label_smooth=True,
        pooling_method='avg'
    ):
        # super() resolves to the next class in this class's MRO;
        # ImageTripletEngine is listed first among the bases, so the
        # triplet-loss setup (margin / weight_t / weight_x) is handled there.
        # NOTE(review): confirm that the parent initializers chain
        # cooperatively through VideoSoftmaxEngine as well.
        super(VideoTripletEngine, self).__init__(
            datamanager,
            model,
            optimizer,
            margin=margin,
            weight_t=weight_t,
            weight_x=weight_x,
            scheduler=scheduler,
            use_gpu=use_gpu,
            label_smooth=label_smooth
        )
        # pooling_method is not forwarded to the parent initializer above,
        # so it is stored here; presumably consumed by the video
        # feature-extraction path inherited from VideoSoftmaxEngine — TODO confirm.
        self.pooling_method = pooling_method