Update according to issue #12
parent
2b4296a2c5
commit
8cbe846f9d
|
@ -0,0 +1,11 @@
|
||||||
|
# Experiment all tricks with center loss : 256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005
|
||||||
|
# Dataset 1: market1501
|
||||||
|
# imagesize: 256x128
|
||||||
|
# batchsize: 16x4
|
||||||
|
# warmup_step 10
|
||||||
|
# random erase prob 0.5
|
||||||
|
# labelsmooth: on
|
||||||
|
# last stride 1
|
||||||
|
# bnneck on
|
||||||
|
# with center loss
|
||||||
|
python3 tools/train.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('2')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haoluo/data')" MODEL.PRETRAIN_CHOICE "('self')" MODEL.PRETRAIN_PATH "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-pretrain_choice_all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005/resnet50_model_2.pth')" OUTPUT_DIR "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-pretrain_choice_all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005')"
|
|
@ -9,4 +9,4 @@
|
||||||
# bnneck on
|
# bnneck on
|
||||||
# with center loss
|
# with center loss
|
||||||
# without re-ranking
|
# without re-ranking
|
||||||
python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('dukemtmc')" DATASETS.ROOT_DIR "('/home/haoluo/data')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/dukemtmc/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005/resnet50_model_120.pth')"
|
python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('dukemtmc')" DATASETS.ROOT_DIR "('/home/haoluo/data')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/dukemtmc/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005/resnet50_model_120.pth')"
|
|
@ -9,4 +9,4 @@
|
||||||
# bnneck on
|
# bnneck on
|
||||||
# with center loss
|
# with center loss
|
||||||
# without re-ranking
|
# without re-ranking
|
||||||
python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haoluo/data')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005/resnet50_model_120.pth')"
|
python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haoluo/data')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005/resnet50_model_120.pth')"
|
|
@ -9,4 +9,4 @@
|
||||||
# bnneck on
|
# bnneck on
|
||||||
# without center loss
|
# without center loss
|
||||||
# without re-ranking
|
# without re-ranking
|
||||||
python3 tools/test.py --config_file='configs/softmax_triplet.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('dukemtmc')" DATASETS.ROOT_DIR "('/home/haoluo/data')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/dukemtmc/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on/resnet50_model_120.pth')"
|
python3 tools/test.py --config_file='configs/softmax_triplet.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('dukemtmc')" DATASETS.ROOT_DIR "('/home/haoluo/data')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/dukemtmc/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on/resnet50_model_120.pth')"
|
|
@ -9,4 +9,4 @@
|
||||||
# bnneck on
|
# bnneck on
|
||||||
# without center loss
|
# without center loss
|
||||||
# without re-ranking
|
# without re-ranking
|
||||||
python3 tools/test.py --config_file='configs/softmax_triplet.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haoluo/data')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on/resnet50_model_120.pth')"
|
python3 tools/test.py --config_file='configs/softmax_triplet.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haoluo/data')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on/resnet50_model_120.pth')"
|
|
@ -9,4 +9,4 @@
|
||||||
# bnneck on
|
# bnneck on
|
||||||
# with center loss
|
# with center loss
|
||||||
# with re-ranking
|
# with re-ranking
|
||||||
python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('dukemtmc')" TEST.RE_RANKING "('yes')" DATASETS.ROOT_DIR "('/home/haoluo/data')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/dukemtmc/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005/resnet50_model_120.pth')"
|
python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('dukemtmc')" TEST.RE_RANKING "('yes')" DATASETS.ROOT_DIR "('/home/haoluo/data')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/dukemtmc/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005/resnet50_model_120.pth')"
|
|
@ -9,4 +9,4 @@
|
||||||
# bnneck on
|
# bnneck on
|
||||||
# with center loss
|
# with center loss
|
||||||
# with re-ranking
|
# with re-ranking
|
||||||
python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haoluo/data')" TEST.RE_RANKING "('yes')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005/resnet50_model_120.pth')"
|
python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haoluo/data')" TEST.RE_RANKING "('yes')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005/resnet50_model_120.pth')"
|
|
@ -27,6 +27,9 @@ _C.MODEL.NAME = 'resnet50'
|
||||||
_C.MODEL.LAST_STRIDE = 1
|
_C.MODEL.LAST_STRIDE = 1
|
||||||
# Path to pretrained model of backbone
|
# Path to pretrained model of backbone
|
||||||
_C.MODEL.PRETRAIN_PATH = ''
|
_C.MODEL.PRETRAIN_PATH = ''
|
||||||
|
# Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
|
||||||
|
# Options: 'imagenet' or 'self'
|
||||||
|
_C.MODEL.PRETRAIN_CHOICE = 'imagenet'
|
||||||
# If train with BNNeck, options: 'bnneck' or 'no'
|
# If train with BNNeck, options: 'bnneck' or 'no'
|
||||||
_C.MODEL.NECK = 'bnneck'
|
_C.MODEL.NECK = 'bnneck'
|
||||||
# If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
|
# If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
MODEL:
|
MODEL:
|
||||||
|
PRETRAIN_CHOICE: 'imagenet'
|
||||||
PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
|
PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
|
||||||
LAST_STRIDE: 2
|
LAST_STRIDE: 2
|
||||||
NECK: 'no'
|
NECK: 'no'
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
MODEL:
|
MODEL:
|
||||||
|
PRETRAIN_CHOICE: 'imagenet'
|
||||||
PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
|
PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
|
||||||
METRIC_LOSS_TYPE: 'triplet'
|
METRIC_LOSS_TYPE: 'triplet'
|
||||||
IF_LABELSMOOTH: 'on'
|
IF_LABELSMOOTH: 'on'
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
MODEL:
|
MODEL:
|
||||||
|
PRETRAIN_CHOICE: 'imagenet'
|
||||||
PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
|
PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
|
||||||
METRIC_LOSS_TYPE: 'triplet_center'
|
METRIC_LOSS_TYPE: 'triplet_center'
|
||||||
IF_LABELSMOOTH: 'on'
|
IF_LABELSMOOTH: 'on'
|
||||||
|
|
|
@ -130,7 +130,8 @@ def do_train(
|
||||||
optimizer,
|
optimizer,
|
||||||
scheduler,
|
scheduler,
|
||||||
loss_fn,
|
loss_fn,
|
||||||
num_query
|
num_query,
|
||||||
|
start_epoch
|
||||||
):
|
):
|
||||||
log_period = cfg.SOLVER.LOG_PERIOD
|
log_period = cfg.SOLVER.LOG_PERIOD
|
||||||
checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
|
checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
|
||||||
|
@ -155,6 +156,10 @@ def do_train(
|
||||||
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
|
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
|
||||||
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
|
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
|
||||||
|
|
||||||
|
@trainer.on(Events.STARTED)
|
||||||
|
def start_training(engine):
|
||||||
|
engine.state.epoch = start_epoch
|
||||||
|
|
||||||
@trainer.on(Events.EPOCH_STARTED)
|
@trainer.on(Events.EPOCH_STARTED)
|
||||||
def adjust_learning_rate(engine):
|
def adjust_learning_rate(engine):
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
@ -201,7 +206,8 @@ def do_train_with_center(
|
||||||
optimizer_center,
|
optimizer_center,
|
||||||
scheduler,
|
scheduler,
|
||||||
loss_fn,
|
loss_fn,
|
||||||
num_query
|
num_query,
|
||||||
|
start_epoch
|
||||||
):
|
):
|
||||||
log_period = cfg.SOLVER.LOG_PERIOD
|
log_period = cfg.SOLVER.LOG_PERIOD
|
||||||
checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
|
checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
|
||||||
|
@ -218,7 +224,9 @@ def do_train_with_center(
|
||||||
timer = Timer(average=True)
|
timer = Timer(average=True)
|
||||||
|
|
||||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
|
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
|
||||||
'optimizer': optimizer.state_dict()})
|
'optimizer': optimizer.state_dict(),
|
||||||
|
'optimizer_center': optimizer_center.state_dict()})
|
||||||
|
|
||||||
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
|
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
|
||||||
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
|
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
|
||||||
|
|
||||||
|
@ -226,6 +234,10 @@ def do_train_with_center(
|
||||||
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
|
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
|
||||||
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
|
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
|
||||||
|
|
||||||
|
@trainer.on(Events.STARTED)
|
||||||
|
def start_training(engine):
|
||||||
|
engine.state.epoch = start_epoch
|
||||||
|
|
||||||
@trainer.on(Events.EPOCH_STARTED)
|
@trainer.on(Events.EPOCH_STARTED)
|
||||||
def adjust_learning_rate(engine):
|
def adjust_learning_rate(engine):
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
|
@ -10,5 +10,5 @@ from .baseline import Baseline
|
||||||
def build_model(cfg, num_classes):
|
def build_model(cfg, num_classes):
|
||||||
# if cfg.MODEL.NAME == 'resnet50':
|
# if cfg.MODEL.NAME == 'resnet50':
|
||||||
# model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT)
|
# model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT)
|
||||||
model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME)
|
model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE)
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -37,7 +37,7 @@ def weights_init_classifier(m):
|
||||||
class Baseline(nn.Module):
|
class Baseline(nn.Module):
|
||||||
in_planes = 2048
|
in_planes = 2048
|
||||||
|
|
||||||
def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name):
|
def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice):
|
||||||
super(Baseline, self).__init__()
|
super(Baseline, self).__init__()
|
||||||
if model_name == 'resnet18':
|
if model_name == 'resnet18':
|
||||||
self.base = ResNet(last_stride=last_stride,
|
self.base = ResNet(last_stride=last_stride,
|
||||||
|
@ -117,14 +117,14 @@ class Baseline(nn.Module):
|
||||||
last_stride=last_stride)
|
last_stride=last_stride)
|
||||||
elif model_name == 'senet154':
|
elif model_name == 'senet154':
|
||||||
self.base = SENet(block=SEBottleneck,
|
self.base = SENet(block=SEBottleneck,
|
||||||
layers=[3, 8, 36, 3],
|
layers=[3, 8, 36, 3],
|
||||||
groups=64,
|
groups=64,
|
||||||
reduction=16,
|
reduction=16,
|
||||||
dropout_p=0.2,
|
dropout_p=0.2,
|
||||||
last_stride=last_stride)
|
last_stride=last_stride)
|
||||||
|
|
||||||
|
if pretrain_choice == 'imagenet':
|
||||||
self.base.load_param(model_path)
|
self.base.load_param(model_path)
|
||||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||||
# self.gap = nn.AdaptiveMaxPool2d(1)
|
# self.gap = nn.AdaptiveMaxPool2d(1)
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
|
@ -167,6 +167,6 @@ class Baseline(nn.Module):
|
||||||
def load_param(self, trained_path):
|
def load_param(self, trained_path):
|
||||||
param_dict = torch.load(trained_path)
|
param_dict = torch.load(trained_path)
|
||||||
for i in param_dict:
|
for i in param_dict:
|
||||||
if 'classifier' in i:
|
# if 'classifier' in i:
|
||||||
continue
|
# continue
|
||||||
self.state_dict()[i].copy_(param_dict[i])
|
self.state_dict()[i].copy_(param_dict[i])
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import torch
|
||||||
|
|
||||||
from torch.backends import cudnn
|
from torch.backends import cudnn
|
||||||
|
|
||||||
|
@ -31,11 +32,28 @@ def train(cfg):
|
||||||
if cfg.MODEL.IF_WITH_CENTER == 'no':
|
if cfg.MODEL.IF_WITH_CENTER == 'no':
|
||||||
print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
|
print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
|
||||||
optimizer = make_optimizer(cfg, model)
|
optimizer = make_optimizer(cfg, model)
|
||||||
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
# scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
||||||
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
|
# cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
|
||||||
|
|
||||||
loss_func = make_loss(cfg, num_classes) # modified by gu
|
loss_func = make_loss(cfg, num_classes) # modified by gu
|
||||||
|
|
||||||
|
# Add for using self trained model
|
||||||
|
if cfg.MODEL.PRETRAIN_CHOICE == 'self':
|
||||||
|
start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])
|
||||||
|
print('Start epoch:', start_epoch)
|
||||||
|
path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
|
||||||
|
print('Path to the checkpoint of optimizer:', path_to_optimizer)
|
||||||
|
model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
|
||||||
|
optimizer.load_state_dict(torch.load(path_to_optimizer))
|
||||||
|
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
||||||
|
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
|
||||||
|
elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
|
||||||
|
start_epoch = 0
|
||||||
|
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
||||||
|
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
|
||||||
|
else:
|
||||||
|
print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE))
|
||||||
|
|
||||||
arguments = {}
|
arguments = {}
|
||||||
|
|
||||||
do_train(
|
do_train(
|
||||||
|
@ -44,19 +62,40 @@ def train(cfg):
|
||||||
train_loader,
|
train_loader,
|
||||||
val_loader,
|
val_loader,
|
||||||
optimizer,
|
optimizer,
|
||||||
scheduler,
|
scheduler, # modify for using self trained model
|
||||||
loss_func,
|
loss_func,
|
||||||
num_query
|
num_query,
|
||||||
|
start_epoch # add for using self trained model
|
||||||
)
|
)
|
||||||
elif cfg.MODEL.IF_WITH_CENTER == 'yes':
|
elif cfg.MODEL.IF_WITH_CENTER == 'yes':
|
||||||
print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
|
print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
|
||||||
loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu
|
loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu
|
||||||
optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion)
|
optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion)
|
||||||
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
# scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
||||||
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
|
# cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
|
||||||
|
|
||||||
arguments = {}
|
arguments = {}
|
||||||
|
|
||||||
|
# Add for using self trained model
|
||||||
|
if cfg.MODEL.PRETRAIN_CHOICE == 'self':
|
||||||
|
start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])
|
||||||
|
print('Start epoch:', start_epoch)
|
||||||
|
path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
|
||||||
|
print('Path to the checkpoint of optimizer:', path_to_optimizer)
|
||||||
|
path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center')
|
||||||
|
print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center)
|
||||||
|
model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
|
||||||
|
optimizer.load_state_dict(torch.load(path_to_optimizer))
|
||||||
|
optimizer_center.load_state_dict(torch.load(path_to_optimizer_center))
|
||||||
|
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
||||||
|
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
|
||||||
|
elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
|
||||||
|
start_epoch = 0
|
||||||
|
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
||||||
|
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
|
||||||
|
else:
|
||||||
|
print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE))
|
||||||
|
|
||||||
do_train_with_center(
|
do_train_with_center(
|
||||||
cfg,
|
cfg,
|
||||||
model,
|
model,
|
||||||
|
@ -65,9 +104,10 @@ def train(cfg):
|
||||||
val_loader,
|
val_loader,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_center,
|
optimizer_center,
|
||||||
scheduler,
|
scheduler, # modify for using self trained model
|
||||||
loss_func,
|
loss_func,
|
||||||
num_query
|
num_query,
|
||||||
|
start_epoch # add for using self trained model
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print("Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(cfg.MODEL.IF_WITH_CENTER))
|
print("Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(cfg.MODEL.IF_WITH_CENTER))
|
||||||
|
|
Loading…
Reference in New Issue