wrap resume; save optimizer in ckpt
parent 74a2377ab7
commit cd5dfd62a1
@@ -16,10 +16,11 @@ from args import argument_parser, image_dataset_kwargs, optimizer_kwargs, lr_scheduler_kwargs
 from torchreid.data_manager import ImageDataManager
 from torchreid import models
 from torchreid.losses import CrossEntropyLoss, DeepSupervision
-from torchreid.utils.iotools import save_checkpoint, check_isfile
+from torchreid.utils.iotools import check_isfile
 from torchreid.utils.avgmeter import AverageMeter
 from torchreid.utils.loggers import Logger, RankLogger
-from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
+from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, \
+    load_pretrained_weights, save_checkpoint, resume_from_checkpoint
 from torchreid.utils.reidtools import visualize_ranked_results
 from torchreid.utils.generaltools import set_random_seed
 from torchreid.eval_metrics import evaluate
@@ -60,19 +61,15 @@ def main():
     if args.load_weights and check_isfile(args.load_weights):
         load_pretrained_weights(model, args.load_weights)
 
-    if args.resume and check_isfile(args.resume):
-        checkpoint = torch.load(args.resume)
-        model.load_state_dict(checkpoint['state_dict'])
-        args.start_epoch = checkpoint['epoch'] + 1
-        print('Loaded checkpoint from "{}"'.format(args.resume))
-        print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, checkpoint['rank1']))
-
     model = nn.DataParallel(model).cuda() if use_gpu else model
 
     criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
     optimizer = init_optimizer(model, **optimizer_kwargs(args))
     scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
 
+    if args.resume and check_isfile(args.resume):
+        args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
+
     if args.evaluate:
         print('Evaluate only')
 
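The inline resume code removed above is now wrapped by resume_from_checkpoint, which is also handed the optimizer so its state can be restored along with the model. A minimal sketch of what such a helper could look like follows; this is an illustration under assumed behaviour, not the actual torchreid.utils.torchtools implementation (it also ignores DataParallel 'module.' key prefixes):

import torch

def resume_from_checkpoint(fpath, model, optimizer=None):
    # Sketch: load a checkpoint written by save_checkpoint() and restore
    # the model (and, if given, the optimizer) in place.
    checkpoint = torch.load(fpath, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']  # stored as epoch + 1, see the save_checkpoint hunks below
    print('Loaded checkpoint from "{}"'.format(fpath))
    print('- start_epoch: {}\n- rank1: {}'.format(start_epoch, checkpoint.get('rank1', 'n/a')))
    return start_epoch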
@@ -127,8 +124,10 @@ def main():
            save_checkpoint({
                'state_dict': model.state_dict(),
                'rank1': rank1,
-                'epoch': epoch,
-            }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
+                'epoch': epoch + 1,
+                'arch': args.arch,
+                'optimizer': optimizer.state_dict(),
+            }, args.save_dir)
 
     elapsed = round(time.time() - start_time)
     elapsed = str(datetime.timedelta(seconds=elapsed))
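The save side changes accordingly: the call now passes a richer state dict (epoch stored as epoch + 1, plus arch and the optimizer state) together with a directory instead of a full file path and an is_best flag. A rough sketch of a helper with that assumed two-argument signature, for illustration only and not the library's actual code:

import os
import os.path as osp
import torch

def save_checkpoint(state, save_dir):
    # Sketch: build the file name from the stored (already 1-based) epoch
    # and write the full training state in a single file.
    if not osp.exists(save_dir):
        os.makedirs(save_dir)
    fpath = osp.join(save_dir, 'checkpoint_ep' + str(state['epoch']) + '.pth.tar')
    torch.save(state, fpath)
    print('Checkpoint saved to "{}"'.format(fpath))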
@@ -16,10 +16,11 @@ from args import argument_parser, image_dataset_kwargs, optimizer_kwargs, lr_scheduler_kwargs
 from torchreid.data_manager import ImageDataManager
 from torchreid import models
 from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
-from torchreid.utils.iotools import save_checkpoint, check_isfile
+from torchreid.utils.iotools import check_isfile
 from torchreid.utils.avgmeter import AverageMeter
 from torchreid.utils.loggers import Logger, RankLogger
-from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
+from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, \
+    load_pretrained_weights, save_checkpoint, resume_from_checkpoint
 from torchreid.utils.reidtools import visualize_ranked_results
 from torchreid.utils.generaltools import set_random_seed
 from torchreid.eval_metrics import evaluate
@@ -61,13 +62,6 @@ def main():
     if args.load_weights and check_isfile(args.load_weights):
         load_pretrained_weights(model, args.load_weights)
 
-    if args.resume and check_isfile(args.resume):
-        checkpoint = torch.load(args.resume)
-        model.load_state_dict(checkpoint['state_dict'])
-        args.start_epoch = checkpoint['epoch'] + 1
-        print('Loaded checkpoint from "{}"'.format(args.resume))
-        print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, checkpoint['rank1']))
-
     model = nn.DataParallel(model).cuda() if use_gpu else model
 
     criterion_xent = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
@@ -75,6 +69,9 @@ def main():
     optimizer = init_optimizer(model, **optimizer_kwargs(args))
     scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
 
+    if args.resume and check_isfile(args.resume):
+        args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
+
     if args.evaluate:
         print('Evaluate only')
 
@@ -129,8 +126,10 @@ def main():
            save_checkpoint({
                'state_dict': model.state_dict(),
                'rank1': rank1,
-                'epoch': epoch,
-            }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
+                'epoch': epoch + 1,
+                'arch': args.arch,
+                'optimizer': optimizer.state_dict(),
+            }, args.save_dir)
 
     elapsed = round(time.time() - start_time)
     elapsed = str(datetime.timedelta(seconds=elapsed))
@@ -17,10 +17,11 @@ from args import argument_parser, video_dataset_kwargs, optimizer_kwargs, lr_scheduler_kwargs
 from torchreid.data_manager import VideoDataManager
 from torchreid import models
 from torchreid.losses import CrossEntropyLoss
-from torchreid.utils.iotools import save_checkpoint, check_isfile
+from torchreid.utils.iotools import check_isfile
 from torchreid.utils.avgmeter import AverageMeter
 from torchreid.utils.loggers import Logger, RankLogger
-from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
+from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, \
+    load_pretrained_weights, save_checkpoint, resume_from_checkpoint
 from torchreid.utils.reidtools import visualize_ranked_results
 from torchreid.utils.generaltools import set_random_seed
 from torchreid.eval_metrics import evaluate
@@ -61,19 +62,15 @@ def main():
     if args.load_weights and check_isfile(args.load_weights):
         load_pretrained_weights(model, args.load_weights)
 
-    if args.resume and check_isfile(args.resume):
-        checkpoint = torch.load(args.resume)
-        model.load_state_dict(checkpoint['state_dict'])
-        args.start_epoch = checkpoint['epoch'] + 1
-        print('Loaded checkpoint from "{}"'.format(args.resume))
-        print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, checkpoint['rank1']))
-
     model = nn.DataParallel(model).cuda() if use_gpu else model
 
     criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
     optimizer = init_optimizer(model, **optimizer_kwargs(args))
     scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
 
+    if args.resume and check_isfile(args.resume):
+        args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
+
     if args.evaluate:
         print('Evaluate only')
 
@@ -128,8 +125,10 @@ def main():
            save_checkpoint({
                'state_dict': model.state_dict(),
                'rank1': rank1,
-                'epoch': epoch,
-            }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
+                'epoch': epoch + 1,
+                'arch': args.arch,
+                'optimizer': optimizer.state_dict(),
+            }, args.save_dir)
 
     elapsed = round(time.time() - start_time)
     elapsed = str(datetime.timedelta(seconds=elapsed))
@@ -17,10 +17,11 @@ from args import argument_parser, video_dataset_kwargs, optimizer_kwargs, lr_scheduler_kwargs
 from torchreid.data_manager import VideoDataManager
 from torchreid import models
 from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
-from torchreid.utils.iotools import save_checkpoint, check_isfile
+from torchreid.utils.iotools import check_isfile
 from torchreid.utils.avgmeter import AverageMeter
 from torchreid.utils.loggers import Logger, RankLogger
-from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
+from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, \
+    load_pretrained_weights, save_checkpoint, resume_from_checkpoint
 from torchreid.utils.reidtools import visualize_ranked_results
 from torchreid.utils.generaltools import set_random_seed
 from torchreid.eval_metrics import evaluate
@@ -62,14 +63,6 @@ def main():
     if args.load_weights and check_isfile(args.load_weights):
         load_pretrained_weights(model, args.load_weights)
 
-    if args.resume and check_isfile(args.resume):
-        checkpoint = torch.load(args.resume)
-        model.load_state_dict(checkpoint['state_dict'])
-        args.start_epoch = checkpoint['epoch'] + 1
-        best_rank1 = checkpoint['rank1']
-        print('Loaded checkpoint from "{}"'.format(args.resume))
-        print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, best_rank1))
-
     model = nn.DataParallel(model).cuda() if use_gpu else model
 
     criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
@@ -77,6 +70,9 @@ def main():
     optimizer = init_optimizer(model, **optimizer_kwargs(args))
     scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
 
+    if args.resume and check_isfile(args.resume):
+        args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
+
     if args.evaluate:
         print('Evaluate only')
 
@@ -131,8 +127,10 @@ def main():
            save_checkpoint({
                'state_dict': model.state_dict(),
                'rank1': rank1,
-                'epoch': epoch,
-            }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
+                'epoch': epoch + 1,
+                'arch': args.arch,
+                'optimizer': optimizer.state_dict(),
+            }, args.save_dir)
 
     elapsed = round(time.time() - start_time)
     elapsed = str(datetime.timedelta(seconds=elapsed))
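Why saving optimizer.state_dict() matters: optimizers such as SGD with momentum or Adam keep per-parameter buffers, and restoring only the model weights silently resets them, so a resumed run does not behave like an uninterrupted one. A self-contained toy example in plain PyTorch (independent of the scripts above) showing the save/resume round trip these changes enable:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

x, y = torch.randn(8, 4), torch.randn(8, 1)
for epoch in range(3):
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Save the same fields the updated scripts write.
torch.save({
    'state_dict': model.state_dict(),
    'epoch': 3,
    'optimizer': optimizer.state_dict(),
}, 'checkpoint_ep3.pth.tar')

# Resume: restoring the optimizer keeps the momentum buffers; loading only
# the model weights would restart them from zero.
checkpoint = torch.load('checkpoint_ep3.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']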