wrap resume; save optimizer in ckpt

pull/119/head
KaiyangZhou 2019-02-20 21:52:43 +00:00
parent 74a2377ab7
commit cd5dfd62a1
4 changed files with 40 additions and 45 deletions

View File

@ -16,10 +16,11 @@ from args import argument_parser, image_dataset_kwargs, optimizer_kwargs, lr_sch
from torchreid.data_manager import ImageDataManager
from torchreid import models
from torchreid.losses import CrossEntropyLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.iotools import check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, \
load_pretrained_weights, save_checkpoint, resume_from_checkpoint
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.utils.generaltools import set_random_seed
from torchreid.eval_metrics import evaluate
@ -60,19 +61,15 @@ def main():
if args.load_weights and check_isfile(args.load_weights):
load_pretrained_weights(model, args.load_weights)
if args.resume and check_isfile(args.resume):
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
args.start_epoch = checkpoint['epoch'] + 1
print('Loaded checkpoint from "{}"'.format(args.resume))
print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, checkpoint['rank1']))
model = nn.DataParallel(model).cuda() if use_gpu else model
criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
optimizer = init_optimizer(model, **optimizer_kwargs(args))
scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
if args.resume and check_isfile(args.resume):
args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
if args.evaluate:
print('Evaluate only')
@ -127,8 +124,10 @@ def main():
save_checkpoint({
'state_dict': model.state_dict(),
'rank1': rank1,
'epoch': epoch,
}, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
'epoch': epoch + 1,
'arch': args.arch,
'optimizer': optimizer.state_dict(),
}, args.save_dir)
elapsed = round(time.time() - start_time)
elapsed = str(datetime.timedelta(seconds=elapsed))

View File

@ -16,10 +16,11 @@ from args import argument_parser, image_dataset_kwargs, optimizer_kwargs, lr_sch
from torchreid.data_manager import ImageDataManager
from torchreid import models
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.iotools import check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, \
load_pretrained_weights, save_checkpoint, resume_from_checkpoint
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.utils.generaltools import set_random_seed
from torchreid.eval_metrics import evaluate
@ -61,13 +62,6 @@ def main():
if args.load_weights and check_isfile(args.load_weights):
load_pretrained_weights(model, args.load_weights)
if args.resume and check_isfile(args.resume):
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
args.start_epoch = checkpoint['epoch'] + 1
print('Loaded checkpoint from "{}"'.format(args.resume))
print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, checkpoint['rank1']))
model = nn.DataParallel(model).cuda() if use_gpu else model
criterion_xent = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
@ -75,6 +69,9 @@ def main():
optimizer = init_optimizer(model, **optimizer_kwargs(args))
scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
if args.resume and check_isfile(args.resume):
args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
if args.evaluate:
print('Evaluate only')
@ -129,8 +126,10 @@ def main():
save_checkpoint({
'state_dict': model.state_dict(),
'rank1': rank1,
'epoch': epoch,
}, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
'epoch': epoch + 1,
'arch': args.arch,
'optimizer': optimizer.state_dict(),
}, args.save_dir)
elapsed = round(time.time() - start_time)
elapsed = str(datetime.timedelta(seconds=elapsed))

View File

@ -17,10 +17,11 @@ from args import argument_parser, video_dataset_kwargs, optimizer_kwargs, lr_sch
from torchreid.data_manager import VideoDataManager
from torchreid import models
from torchreid.losses import CrossEntropyLoss
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.iotools import check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, \
load_pretrained_weights, save_checkpoint, resume_from_checkpoint
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.utils.generaltools import set_random_seed
from torchreid.eval_metrics import evaluate
@ -61,19 +62,15 @@ def main():
if args.load_weights and check_isfile(args.load_weights):
load_pretrained_weights(model, args.load_weights)
if args.resume and check_isfile(args.resume):
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
args.start_epoch = checkpoint['epoch'] + 1
print('Loaded checkpoint from "{}"'.format(args.resume))
print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, checkpoint['rank1']))
model = nn.DataParallel(model).cuda() if use_gpu else model
criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
optimizer = init_optimizer(model, **optimizer_kwargs(args))
scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
if args.resume and check_isfile(args.resume):
args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
if args.evaluate:
print('Evaluate only')
@ -128,8 +125,10 @@ def main():
save_checkpoint({
'state_dict': model.state_dict(),
'rank1': rank1,
'epoch': epoch,
}, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
'epoch': epoch + 1,
'arch': args.arch,
'optimizer': optimizer.state_dict(),
}, args.save_dir)
elapsed = round(time.time() - start_time)
elapsed = str(datetime.timedelta(seconds=elapsed))

View File

@ -17,10 +17,11 @@ from args import argument_parser, video_dataset_kwargs, optimizer_kwargs, lr_sch
from torchreid.data_manager import VideoDataManager
from torchreid import models
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.iotools import check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, \
load_pretrained_weights, save_checkpoint, resume_from_checkpoint
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.utils.generaltools import set_random_seed
from torchreid.eval_metrics import evaluate
@ -62,14 +63,6 @@ def main():
if args.load_weights and check_isfile(args.load_weights):
load_pretrained_weights(model, args.load_weights)
if args.resume and check_isfile(args.resume):
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
args.start_epoch = checkpoint['epoch'] + 1
best_rank1 = checkpoint['rank1']
print('Loaded checkpoint from "{}"'.format(args.resume))
print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, best_rank1))
model = nn.DataParallel(model).cuda() if use_gpu else model
criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
@ -77,6 +70,9 @@ def main():
optimizer = init_optimizer(model, **optimizer_kwargs(args))
scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
if args.resume and check_isfile(args.resume):
args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
if args.evaluate:
print('Evaluate only')
@ -131,8 +127,10 @@ def main():
save_checkpoint({
'state_dict': model.state_dict(),
'rank1': rank1,
'epoch': epoch,
}, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
'epoch': epoch + 1,
'arch': args.arch,
'optimizer': optimizer.state_dict(),
}, args.save_dir)
elapsed = round(time.time() - start_time)
elapsed = str(datetime.timedelta(seconds=elapsed))