add check_isfile

pull/62/head
KaiyangZhou 2018-08-07 17:14:44 +01:00
parent 85bdceb64a
commit fce7779381
4 changed files with 40 additions and 44 deletions

View File

@ -20,7 +20,7 @@ from dataset_loader import ImageDataset
import transforms as T
import models
from losses import CrossEntropyLabelSmooth, DeepSupervision
from utils.iotools import save_checkpoint
from utils.iotools import save_checkpoint, check_isfile
from utils.avgmeter import AverageMeter
from utils.logger import Logger
from utils.torchtools import set_bn_to_eval, count_num_param
@ -183,24 +183,23 @@ def main():
if args.load_weights:
# load pretrained weights but ignore layers that don't match in size
print("Loading pretrained weights from '{}'".format(args.load_weights))
checkpoint = torch.load(args.load_weights)
pretrain_dict = checkpoint['state_dict']
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
if check_isfile(args.load_weights):
checkpoint = torch.load(args.load_weights)
pretrain_dict = checkpoint['state_dict']
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
print("Loaded pretrained weights from '{}'".format(args.load_weights))
if args.resume:
if osp.isfile(args.resume):
if check_isfile(args.resume):
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
args.start_epoch = checkpoint['epoch']
rank1 = checkpoint['rank1']
print("Loaded checkpoint from '{}'".format(args.resume))
print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
else:
print("=> No checkpoint found at '{}'".format(args.resume))
if use_gpu:
model = nn.DataParallel(model).cuda()

View File

@ -20,7 +20,7 @@ from dataset_loader import ImageDataset
import transforms as T
import models
from losses import CrossEntropyLabelSmooth, TripletLoss, DeepSupervision
from utils.iotools import save_checkpoint
from utils.iotools import save_checkpoint, check_isfile
from utils.avgmeter import AverageMeter
from utils.logger import Logger
from utils.torchtools import count_num_param
@ -180,24 +180,23 @@ def main():
if args.load_weights:
# load pretrained weights but ignore layers that don't match in size
print("Loading pretrained weights from '{}'".format(args.load_weights))
checkpoint = torch.load(args.load_weights)
pretrain_dict = checkpoint['state_dict']
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
if check_isfile(args.load_weights):
checkpoint = torch.load(args.load_weights)
pretrain_dict = checkpoint['state_dict']
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
print("Loaded pretrained weights from '{}'".format(args.load_weights))
if args.resume:
if osp.isfile(args.resume):
if check_isfile(args.resume):
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
args.start_epoch = checkpoint['epoch']
rank1 = checkpoint['rank1']
print("Loaded checkpoint from '{}'".format(args.resume))
print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
else:
print("=> No checkpoint found at '{}'".format(args.resume))
if use_gpu:
model = nn.DataParallel(model).cuda()

View File

@ -20,7 +20,7 @@ from dataset_loader import ImageDataset, VideoDataset
import transforms as T
import models
from losses import CrossEntropyLabelSmooth
from utils.iotools import save_checkpoint
from utils.iotools import save_checkpoint, check_isfile
from utils.avgmeter import AverageMeter
from utils.logger import Logger
from utils.torchtools import set_bn_to_eval, count_num_param
@ -174,24 +174,23 @@ def main():
if args.load_weights:
# load pretrained weights but ignore layers that don't match in size
print("Loading pretrained weights from '{}'".format(args.load_weights))
checkpoint = torch.load(args.load_weights)
pretrain_dict = checkpoint['state_dict']
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
if check_isfile(args.load_weights):
checkpoint = torch.load(args.load_weights)
pretrain_dict = checkpoint['state_dict']
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
print("Loaded pretrained weights from '{}'".format(args.load_weights))
if args.resume:
if osp.isfile(args.resume):
if check_isfile(args.resume):
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
args.start_epoch = checkpoint['epoch']
rank1 = checkpoint['rank1']
print("Loaded checkpoint from '{}'".format(args.resume))
print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
else:
print("=> No checkpoint found at '{}'".format(args.resume))
if use_gpu:
model = nn.DataParallel(model).cuda()

View File

@ -20,7 +20,7 @@ from dataset_loader import ImageDataset, VideoDataset
import transforms as T
import models
from losses import CrossEntropyLabelSmooth, TripletLoss
from utils.iotools import save_checkpoint
from utils.iotools import save_checkpoint, check_isfile
from utils.avgmeter import AverageMeter
from utils.logger import Logger
from utils.torchtools import count_num_param
@ -171,24 +171,23 @@ def main():
if args.load_weights:
# load pretrained weights but ignore layers that don't match in size
print("Loading pretrained weights from '{}'".format(args.load_weights))
checkpoint = torch.load(args.load_weights)
pretrain_dict = checkpoint['state_dict']
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
if check_isfile(args.load_weights):
checkpoint = torch.load(args.load_weights)
pretrain_dict = checkpoint['state_dict']
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
print("Loaded pretrained weights from '{}'".format(args.load_weights))
if args.resume:
if osp.isfile(args.resume):
if check_isfile(args.resume):
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
args.start_epoch = checkpoint['epoch']
rank1 = checkpoint['rank1']
print("Loaded checkpoint from '{}'".format(args.resume))
print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
else:
print("=> No checkpoint found at '{}'".format(args.resume))
if use_gpu:
model = nn.DataParallel(model).cuda()