2019-03-20 01:26:08 +08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import print_function
|
|
|
|
from __future__ import division
|
|
|
|
|
|
|
|
import time
|
|
|
|
import datetime
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
import torchreid
|
|
|
|
from torchreid.engine.image import ImageTripletEngine
|
|
|
|
from torchreid.engine.video import VideoSoftmaxEngine
|
|
|
|
|
|
|
|
|
|
|
|
class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
    """Training engine composed from ``ImageTripletEngine`` and ``VideoSoftmaxEngine``.

    Judging by the class and base names, this adds triplet-loss training on
    top of the video softmax engine. ``ImageTripletEngine`` is listed first in
    the bases, so it precedes ``VideoSoftmaxEngine`` in the method resolution
    order and its attributes take precedence where both define the same name.
    """

    def __init__(self, datamanager, model, optimizer, margin=0.3,
                 weight_t=1, weight_x=1, scheduler=None, use_cpu=False,
                 label_smooth=True, pooling_method='avg'):
        """Initialise the parent engine chain and record the pooling method.

        Args:
            datamanager: forwarded positionally to the next initializer in the MRO.
            model: forwarded positionally to the next initializer in the MRO.
            optimizer: forwarded positionally to the next initializer in the MRO.
            margin: forwarded by keyword; presumably the triplet-loss margin
                (confirm in ``ImageTripletEngine``).
            weight_t: forwarded by keyword; presumably the weight of the
                triplet-loss term (confirm in ``ImageTripletEngine``).
            weight_x: forwarded by keyword; presumably the weight of the
                softmax/cross-entropy term (confirm in ``ImageTripletEngine``).
            scheduler: forwarded by keyword, unchanged.
            use_cpu: forwarded by keyword, unchanged.
            label_smooth: forwarded by keyword, unchanged.
            pooling_method: stored as ``self.pooling_method``; presumably selects
                how per-frame features are aggregated (confirm in
                ``VideoSoftmaxEngine``).
        """
        # Keyword arguments handed to the parent initializer, in one place.
        parent_kwargs = dict(
            margin=margin,
            weight_t=weight_t,
            weight_x=weight_x,
            scheduler=scheduler,
            use_cpu=use_cpu,
            label_smooth=label_smooth,
        )
        # Explicit two-argument super() keeps the module Python 2 compatible,
        # matching its __future__ imports.
        parent = super(VideoTripletEngine, self)
        parent.__init__(datamanager, model, optimizer, **parent_kwargs)
        # Assigned after the parent chain has run, so this value overrides any
        # pooling_method a base-class initializer may have set.
        self.pooling_method = pooling_method
|