# deep-person-reid/torchreid/engine/video/triplet.py
# (24 lines, 933 B, Python; captured from a repository "Raw / Normal View / History" page)
# Blame timestamp: 2019-03-20 01:26:08 +08:00
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import datetime
import torch
import torchreid
from torchreid.engine.image import ImageTripletEngine
from torchreid.engine.video import VideoSoftmaxEngine
class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
    """Engine combining ``ImageTripletEngine`` and ``VideoSoftmaxEngine``.

    The two are combined through multiple inheritance. ``ImageTripletEngine``
    is listed first, so it precedes ``VideoSoftmaxEngine`` in the method
    resolution order. Wherever both bases define the same attribute or method,
    the ``ImageTripletEngine`` version is the one used.
    """

    def __init__(self, datamanager, model, optimizer, margin=0.3,
                 weight_t=1, weight_x=1, scheduler=None, use_cpu=False,
                 label_smooth=True, pooling_method='avg'):
        """Initialize the engine and record the pooling method.

        Args:
            datamanager: passed positionally to the parent initializer.
            model: passed positionally to the parent initializer.
            optimizer: passed positionally to the parent initializer.
            margin: forwarded by keyword as ``margin`` (default 0.3).
                Presumably the triplet-loss margin; confirm in
                ImageTripletEngine.
            weight_t: forwarded by keyword as ``weight_t`` (default 1).
                Presumably the weight of the triplet term; TODO confirm.
            weight_x: forwarded by keyword as ``weight_x`` (default 1).
                Presumably the weight of the softmax/cross-entropy term;
                TODO confirm.
            scheduler: forwarded unchanged (default None).
            use_cpu: forwarded unchanged (default False).
            label_smooth: forwarded unchanged (default True).
            pooling_method: stored as ``self.pooling_method`` and NOT passed
                to the parent initializer (default 'avg'). Presumably selects
                how per-frame features are aggregated for a video; confirm
                against VideoSoftmaxEngine.
        """
        # Loss-related options, forwarded to the parent initializer by keyword.
        loss_options = dict(
            margin=margin,
            weight_t=weight_t,
            weight_x=weight_x,
            label_smooth=label_smooth,
        )
        # Two-argument super() keeps Python 2 compatibility. Here it resolves
        # to ImageTripletEngine.__init__, the first base in the MRO.
        # NOTE(review): any super() call inside that initializer continues to
        # VideoSoftmaxEngine next in the MRO, not directly to their common
        # base. Confirm the argument lists along that chain are compatible.
        super(VideoTripletEngine, self).__init__(
            datamanager, model, optimizer,
            scheduler=scheduler, use_cpu=use_cpu,
            **loss_options
        )
        # This is assigned after the parent initializer has run, so it replaces
        # any pooling_method value a parent class may have set.
        self.pooling_method = pooling_method