add check_isfile
parent
85bdceb64a
commit
fce7779381
|
@ -20,7 +20,7 @@ from dataset_loader import ImageDataset
|
|||
import transforms as T
|
||||
import models
|
||||
from losses import CrossEntropyLabelSmooth, DeepSupervision
|
||||
from utils.iotools import save_checkpoint
|
||||
from utils.iotools import save_checkpoint, check_isfile
|
||||
from utils.avgmeter import AverageMeter
|
||||
from utils.logger import Logger
|
||||
from utils.torchtools import set_bn_to_eval, count_num_param
|
||||
|
@ -183,24 +183,23 @@ def main():
|
|||
|
||||
if args.load_weights:
|
||||
# load pretrained weights but ignore layers that don't match in size
|
||||
print("Loading pretrained weights from '{}'".format(args.load_weights))
|
||||
if check_isfile(args.load_weights):
|
||||
checkpoint = torch.load(args.load_weights)
|
||||
pretrain_dict = checkpoint['state_dict']
|
||||
model_dict = model.state_dict()
|
||||
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
|
||||
model_dict.update(pretrain_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
print("Loaded pretrained weights from '{}'".format(args.load_weights))
|
||||
|
||||
if args.resume:
|
||||
if osp.isfile(args.resume):
|
||||
if check_isfile(args.resume):
|
||||
checkpoint = torch.load(args.resume)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
args.start_epoch = checkpoint['epoch']
|
||||
rank1 = checkpoint['rank1']
|
||||
print("Loaded checkpoint from '{}'".format(args.resume))
|
||||
print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
|
||||
else:
|
||||
print("=> No checkpoint found at '{}'".format(args.resume))
|
||||
|
||||
if use_gpu:
|
||||
model = nn.DataParallel(model).cuda()
|
||||
|
|
|
@ -20,7 +20,7 @@ from dataset_loader import ImageDataset
|
|||
import transforms as T
|
||||
import models
|
||||
from losses import CrossEntropyLabelSmooth, TripletLoss, DeepSupervision
|
||||
from utils.iotools import save_checkpoint
|
||||
from utils.iotools import save_checkpoint, check_isfile
|
||||
from utils.avgmeter import AverageMeter
|
||||
from utils.logger import Logger
|
||||
from utils.torchtools import count_num_param
|
||||
|
@ -180,24 +180,23 @@ def main():
|
|||
|
||||
if args.load_weights:
|
||||
# load pretrained weights but ignore layers that don't match in size
|
||||
print("Loading pretrained weights from '{}'".format(args.load_weights))
|
||||
if check_isfile(args.load_weights):
|
||||
checkpoint = torch.load(args.load_weights)
|
||||
pretrain_dict = checkpoint['state_dict']
|
||||
model_dict = model.state_dict()
|
||||
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
|
||||
model_dict.update(pretrain_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
print("Loaded pretrained weights from '{}'".format(args.load_weights))
|
||||
|
||||
if args.resume:
|
||||
if osp.isfile(args.resume):
|
||||
if check_isfile(args.resume):
|
||||
checkpoint = torch.load(args.resume)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
args.start_epoch = checkpoint['epoch']
|
||||
rank1 = checkpoint['rank1']
|
||||
print("Loaded checkpoint from '{}'".format(args.resume))
|
||||
print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
|
||||
else:
|
||||
print("=> No checkpoint found at '{}'".format(args.resume))
|
||||
|
||||
if use_gpu:
|
||||
model = nn.DataParallel(model).cuda()
|
||||
|
|
|
@ -20,7 +20,7 @@ from dataset_loader import ImageDataset, VideoDataset
|
|||
import transforms as T
|
||||
import models
|
||||
from losses import CrossEntropyLabelSmooth
|
||||
from utils.iotools import save_checkpoint
|
||||
from utils.iotools import save_checkpoint, check_isfile
|
||||
from utils.avgmeter import AverageMeter
|
||||
from utils.logger import Logger
|
||||
from utils.torchtools import set_bn_to_eval, count_num_param
|
||||
|
@ -174,24 +174,23 @@ def main():
|
|||
|
||||
if args.load_weights:
|
||||
# load pretrained weights but ignore layers that don't match in size
|
||||
print("Loading pretrained weights from '{}'".format(args.load_weights))
|
||||
if check_isfile(args.load_weights):
|
||||
checkpoint = torch.load(args.load_weights)
|
||||
pretrain_dict = checkpoint['state_dict']
|
||||
model_dict = model.state_dict()
|
||||
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
|
||||
model_dict.update(pretrain_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
print("Loaded pretrained weights from '{}'".format(args.load_weights))
|
||||
|
||||
if args.resume:
|
||||
if osp.isfile(args.resume):
|
||||
if check_isfile(args.resume):
|
||||
checkpoint = torch.load(args.resume)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
args.start_epoch = checkpoint['epoch']
|
||||
rank1 = checkpoint['rank1']
|
||||
print("Loaded checkpoint from '{}'".format(args.resume))
|
||||
print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
|
||||
else:
|
||||
print("=> No checkpoint found at '{}'".format(args.resume))
|
||||
|
||||
if use_gpu:
|
||||
model = nn.DataParallel(model).cuda()
|
||||
|
|
|
@ -20,7 +20,7 @@ from dataset_loader import ImageDataset, VideoDataset
|
|||
import transforms as T
|
||||
import models
|
||||
from losses import CrossEntropyLabelSmooth, TripletLoss
|
||||
from utils.iotools import save_checkpoint
|
||||
from utils.iotools import save_checkpoint, check_isfile
|
||||
from utils.avgmeter import AverageMeter
|
||||
from utils.logger import Logger
|
||||
from utils.torchtools import count_num_param
|
||||
|
@ -171,24 +171,23 @@ def main():
|
|||
|
||||
if args.load_weights:
|
||||
# load pretrained weights but ignore layers that don't match in size
|
||||
print("Loading pretrained weights from '{}'".format(args.load_weights))
|
||||
if check_isfile(args.load_weights):
|
||||
checkpoint = torch.load(args.load_weights)
|
||||
pretrain_dict = checkpoint['state_dict']
|
||||
model_dict = model.state_dict()
|
||||
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
|
||||
model_dict.update(pretrain_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
print("Loaded pretrained weights from '{}'".format(args.load_weights))
|
||||
|
||||
if args.resume:
|
||||
if osp.isfile(args.resume):
|
||||
if check_isfile(args.resume):
|
||||
checkpoint = torch.load(args.resume)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
args.start_epoch = checkpoint['epoch']
|
||||
rank1 = checkpoint['rank1']
|
||||
print("Loaded checkpoint from '{}'".format(args.resume))
|
||||
print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
|
||||
else:
|
||||
print("=> No checkpoint found at '{}'".format(args.resume))
|
||||
|
||||
if use_gpu:
|
||||
model = nn.DataParallel(model).cuda()
|
||||
|
|
Loading…
Reference in New Issue