From fce777938160f7aa5e8ed3702ce93c5a1c3760bc Mon Sep 17 00:00:00 2001
From: KaiyangZhou
Date: Tue, 7 Aug 2018 17:14:44 +0100
Subject: [PATCH] add check_isfile

---
 train_imgreid_xent.py      | 21 ++++++++++-----------
 train_imgreid_xent_htri.py | 21 ++++++++++-----------
 train_vidreid_xent.py      | 21 ++++++++++-----------
 train_vidreid_xent_htri.py | 21 ++++++++++-----------
 4 files changed, 40 insertions(+), 44 deletions(-)

diff --git a/train_imgreid_xent.py b/train_imgreid_xent.py
index d9cd838..6ac8d94 100755
--- a/train_imgreid_xent.py
+++ b/train_imgreid_xent.py
@@ -20,7 +20,7 @@ from dataset_loader import ImageDataset
 import transforms as T
 import models
 from losses import CrossEntropyLabelSmooth, DeepSupervision
-from utils.iotools import save_checkpoint
+from utils.iotools import save_checkpoint, check_isfile
 from utils.avgmeter import AverageMeter
 from utils.logger import Logger
 from utils.torchtools import set_bn_to_eval, count_num_param
@@ -183,24 +183,23 @@ def main():
     if args.load_weights:
         # load pretrained weights but ignore layers that don't match in size
-        print("Loading pretrained weights from '{}'".format(args.load_weights))
-        checkpoint = torch.load(args.load_weights)
-        pretrain_dict = checkpoint['state_dict']
-        model_dict = model.state_dict()
-        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
-        model_dict.update(pretrain_dict)
-        model.load_state_dict(model_dict)
+        if check_isfile(args.load_weights):
+            checkpoint = torch.load(args.load_weights)
+            pretrain_dict = checkpoint['state_dict']
+            model_dict = model.state_dict()
+            pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
+            model_dict.update(pretrain_dict)
+            model.load_state_dict(model_dict)
+            print("Loaded pretrained weights from '{}'".format(args.load_weights))

     if args.resume:
-        if osp.isfile(args.resume):
+        if check_isfile(args.resume):
             checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'])
             args.start_epoch = checkpoint['epoch']
             rank1 = checkpoint['rank1']
             print("Loaded checkpoint from '{}'".format(args.resume))
             print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
-        else:
-            print("=> No checkpoint found at '{}'".format(args.resume))

     if use_gpu:
         model = nn.DataParallel(model).cuda()
diff --git a/train_imgreid_xent_htri.py b/train_imgreid_xent_htri.py
index c2367b1..e50bc4a 100755
--- a/train_imgreid_xent_htri.py
+++ b/train_imgreid_xent_htri.py
@@ -20,7 +20,7 @@ from dataset_loader import ImageDataset
 import transforms as T
 import models
 from losses import CrossEntropyLabelSmooth, TripletLoss, DeepSupervision
-from utils.iotools import save_checkpoint
+from utils.iotools import save_checkpoint, check_isfile
 from utils.avgmeter import AverageMeter
 from utils.logger import Logger
 from utils.torchtools import count_num_param
@@ -180,24 +180,23 @@ def main():
     if args.load_weights:
         # load pretrained weights but ignore layers that don't match in size
-        print("Loading pretrained weights from '{}'".format(args.load_weights))
-        checkpoint = torch.load(args.load_weights)
-        pretrain_dict = checkpoint['state_dict']
-        model_dict = model.state_dict()
-        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
-        model_dict.update(pretrain_dict)
-        model.load_state_dict(model_dict)
+        if check_isfile(args.load_weights):
+            checkpoint = torch.load(args.load_weights)
+            pretrain_dict = checkpoint['state_dict']
+            model_dict = model.state_dict()
+            pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
+            model_dict.update(pretrain_dict)
+            model.load_state_dict(model_dict)
+            print("Loaded pretrained weights from '{}'".format(args.load_weights))

     if args.resume:
-        if osp.isfile(args.resume):
+        if check_isfile(args.resume):
             checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'])
             args.start_epoch = checkpoint['epoch']
             rank1 = checkpoint['rank1']
             print("Loaded checkpoint from '{}'".format(args.resume))
             print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
-        else:
-            print("=> No checkpoint found at '{}'".format(args.resume))

     if use_gpu:
         model = nn.DataParallel(model).cuda()
diff --git a/train_vidreid_xent.py b/train_vidreid_xent.py
index 2852465..7192554 100755
--- a/train_vidreid_xent.py
+++ b/train_vidreid_xent.py
@@ -20,7 +20,7 @@ from dataset_loader import ImageDataset, VideoDataset
 import transforms as T
 import models
 from losses import CrossEntropyLabelSmooth
-from utils.iotools import save_checkpoint
+from utils.iotools import save_checkpoint, check_isfile
 from utils.avgmeter import AverageMeter
 from utils.logger import Logger
 from utils.torchtools import set_bn_to_eval, count_num_param
@@ -174,24 +174,23 @@ def main():
     if args.load_weights:
         # load pretrained weights but ignore layers that don't match in size
-        print("Loading pretrained weights from '{}'".format(args.load_weights))
-        checkpoint = torch.load(args.load_weights)
-        pretrain_dict = checkpoint['state_dict']
-        model_dict = model.state_dict()
-        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
-        model_dict.update(pretrain_dict)
-        model.load_state_dict(model_dict)
+        if check_isfile(args.load_weights):
+            checkpoint = torch.load(args.load_weights)
+            pretrain_dict = checkpoint['state_dict']
+            model_dict = model.state_dict()
+            pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
+            model_dict.update(pretrain_dict)
+            model.load_state_dict(model_dict)
+            print("Loaded pretrained weights from '{}'".format(args.load_weights))

     if args.resume:
-        if osp.isfile(args.resume):
+        if check_isfile(args.resume):
             checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'])
             args.start_epoch = checkpoint['epoch']
             rank1 = checkpoint['rank1']
             print("Loaded checkpoint from '{}'".format(args.resume))
             print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
-        else:
-            print("=> No checkpoint found at '{}'".format(args.resume))

     if use_gpu:
         model = nn.DataParallel(model).cuda()
diff --git a/train_vidreid_xent_htri.py b/train_vidreid_xent_htri.py
index c01d685..a1120a7 100755
--- a/train_vidreid_xent_htri.py
+++ b/train_vidreid_xent_htri.py
@@ -20,7 +20,7 @@ from dataset_loader import ImageDataset, VideoDataset
 import transforms as T
 import models
 from losses import CrossEntropyLabelSmooth, TripletLoss
-from utils.iotools import save_checkpoint
+from utils.iotools import save_checkpoint, check_isfile
 from utils.avgmeter import AverageMeter
 from utils.logger import Logger
 from utils.torchtools import count_num_param
@@ -171,24 +171,23 @@ def main():
     if args.load_weights:
         # load pretrained weights but ignore layers that don't match in size
-        print("Loading pretrained weights from '{}'".format(args.load_weights))
-        checkpoint = torch.load(args.load_weights)
-        pretrain_dict = checkpoint['state_dict']
-        model_dict = model.state_dict()
-        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
-        model_dict.update(pretrain_dict)
-        model.load_state_dict(model_dict)
+        if check_isfile(args.load_weights):
+            checkpoint = torch.load(args.load_weights)
+            pretrain_dict = checkpoint['state_dict']
+            model_dict = model.state_dict()
+            pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
+            model_dict.update(pretrain_dict)
+            model.load_state_dict(model_dict)
+            print("Loaded pretrained weights from '{}'".format(args.load_weights))

     if args.resume:
-        if osp.isfile(args.resume):
+        if check_isfile(args.resume):
             checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'])
             args.start_epoch = checkpoint['epoch']
             rank1 = checkpoint['rank1']
             print("Loaded checkpoint from '{}'".format(args.resume))
             print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
-        else:
-            print("=> No checkpoint found at '{}'".format(args.resume))

     if use_gpu:
         model = nn.DataParallel(model).cuda()
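
Note: the patch imports check_isfile from utils.iotools but does not include its definition. Below is a minimal sketch of the assumed contract, not the actual file from the repo: return True when the path is an existing regular file, and warn instead of raising otherwise, which is what lets both the load_weights and resume branches skip missing files without crashing.

    # Hypothetical sketch of utils/iotools.check_isfile -- the real
    # implementation is not shown in this patch.
    import os.path as osp
    import warnings

    def check_isfile(path):
        # True only for an existing regular file (not a directory)
        isfile = osp.isfile(path)
        if not isfile:
            # warn rather than raise, so callers can fall through
            warnings.warn("No file found at '{}'".format(path))
        return isfile

Under that assumed contract, the rewritten blocks read consistently: the load is attempted only when the file exists, and the confirmation message is printed after load_state_dict succeeds rather than before the load is attempted, as in the old code.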