add check_isfile

parent 85bdceb64a
commit fce7779381
@@ -20,7 +20,7 @@ from dataset_loader import ImageDataset
 import transforms as T
 import models
 from losses import CrossEntropyLabelSmooth, DeepSupervision
-from utils.iotools import save_checkpoint
+from utils.iotools import save_checkpoint, check_isfile
 from utils.avgmeter import AverageMeter
 from utils.logger import Logger
 from utils.torchtools import set_bn_to_eval, count_num_param
@@ -183,24 +183,23 @@ def main():
 
     if args.load_weights:
         # load pretrained weights but ignore layers that don't match in size
-        print("Loading pretrained weights from '{}'".format(args.load_weights))
-        checkpoint = torch.load(args.load_weights)
-        pretrain_dict = checkpoint['state_dict']
-        model_dict = model.state_dict()
-        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
-        model_dict.update(pretrain_dict)
-        model.load_state_dict(model_dict)
+        if check_isfile(args.load_weights):
+            checkpoint = torch.load(args.load_weights)
+            pretrain_dict = checkpoint['state_dict']
+            model_dict = model.state_dict()
+            pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
+            model_dict.update(pretrain_dict)
+            model.load_state_dict(model_dict)
+            print("Loaded pretrained weights from '{}'".format(args.load_weights))
 
     if args.resume:
-        if osp.isfile(args.resume):
+        if check_isfile(args.resume):
             checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'])
             args.start_epoch = checkpoint['epoch']
             rank1 = checkpoint['rank1']
             print("Loaded checkpoint from '{}'".format(args.resume))
             print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
-        else:
-            print("=> No checkpoint found at '{}'".format(args.resume))
 
     if use_gpu:
         model = nn.DataParallel(model).cuda()
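The guard used throughout this commit lives in `utils/iotools.py`. For reference, here is a minimal sketch of what a `check_isfile` helper along these lines could look like (an illustrative assumption, not the repository's exact code); warning instead of raising is what lets the hunks above drop their `=> No checkpoint found` else-branch:

```python
import os.path as osp
import warnings

def check_isfile(path):
    # Sketch only: assumed behaviour of utils.iotools.check_isfile.
    # Returns True for a regular file; otherwise warns and returns False,
    # so callers can skip loading instead of crashing on a bad path.
    isfile = osp.isfile(path)
    if not isfile:
        warnings.warn("no file found at '{}'".format(path))
    return isfile
```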
@@ -20,7 +20,7 @@ from dataset_loader import ImageDataset
 import transforms as T
 import models
 from losses import CrossEntropyLabelSmooth, TripletLoss, DeepSupervision
-from utils.iotools import save_checkpoint
+from utils.iotools import save_checkpoint, check_isfile
 from utils.avgmeter import AverageMeter
 from utils.logger import Logger
 from utils.torchtools import count_num_param
@@ -180,24 +180,23 @@ def main():
 
     if args.load_weights:
         # load pretrained weights but ignore layers that don't match in size
-        print("Loading pretrained weights from '{}'".format(args.load_weights))
-        checkpoint = torch.load(args.load_weights)
-        pretrain_dict = checkpoint['state_dict']
-        model_dict = model.state_dict()
-        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
-        model_dict.update(pretrain_dict)
-        model.load_state_dict(model_dict)
+        if check_isfile(args.load_weights):
+            checkpoint = torch.load(args.load_weights)
+            pretrain_dict = checkpoint['state_dict']
+            model_dict = model.state_dict()
+            pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
+            model_dict.update(pretrain_dict)
+            model.load_state_dict(model_dict)
+            print("Loaded pretrained weights from '{}'".format(args.load_weights))
 
     if args.resume:
-        if osp.isfile(args.resume):
+        if check_isfile(args.resume):
             checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'])
             args.start_epoch = checkpoint['epoch']
             rank1 = checkpoint['rank1']
             print("Loaded checkpoint from '{}'".format(args.resume))
             print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
-        else:
-            print("=> No checkpoint found at '{}'".format(args.resume))
 
     if use_gpu:
         model = nn.DataParallel(model).cuda()
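The `load_weights` branch repeated in each script is a size-matched partial load: only checkpoint tensors whose names exist in the current model with identical shapes are copied, so a checkpoint trained with a different classifier head still initializes the backbone. The same technique as a self-contained sketch (the function name and return value are illustrative, not part of the repo):

```python
import torch

def load_matched_weights(model, weight_path):
    # Copy only the checkpoint entries that match the model by name and size;
    # mismatched layers (e.g. a classifier head trained for another number
    # of identities) keep their fresh initialization.
    checkpoint = torch.load(weight_path)
    pretrain_dict = checkpoint['state_dict']
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrain_dict.items()
               if k in model_dict and model_dict[k].size() == v.size()}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return matched  # useful for logging which layers were actually loaded
```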
@@ -20,7 +20,7 @@ from dataset_loader import ImageDataset, VideoDataset
 import transforms as T
 import models
 from losses import CrossEntropyLabelSmooth
-from utils.iotools import save_checkpoint
+from utils.iotools import save_checkpoint, check_isfile
 from utils.avgmeter import AverageMeter
 from utils.logger import Logger
 from utils.torchtools import set_bn_to_eval, count_num_param
@@ -174,24 +174,23 @@ def main():
 
     if args.load_weights:
         # load pretrained weights but ignore layers that don't match in size
-        print("Loading pretrained weights from '{}'".format(args.load_weights))
-        checkpoint = torch.load(args.load_weights)
-        pretrain_dict = checkpoint['state_dict']
-        model_dict = model.state_dict()
-        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
-        model_dict.update(pretrain_dict)
-        model.load_state_dict(model_dict)
+        if check_isfile(args.load_weights):
+            checkpoint = torch.load(args.load_weights)
+            pretrain_dict = checkpoint['state_dict']
+            model_dict = model.state_dict()
+            pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
+            model_dict.update(pretrain_dict)
+            model.load_state_dict(model_dict)
+            print("Loaded pretrained weights from '{}'".format(args.load_weights))
 
     if args.resume:
-        if osp.isfile(args.resume):
+        if check_isfile(args.resume):
             checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'])
             args.start_epoch = checkpoint['epoch']
             rank1 = checkpoint['rank1']
             print("Loaded checkpoint from '{}'".format(args.resume))
             print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
-        else:
-            print("=> No checkpoint found at '{}'".format(args.resume))
 
     if use_gpu:
         model = nn.DataParallel(model).cuda()
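The resume branch expects a checkpoint dict carrying `state_dict`, `epoch`, and `rank1`. A hedged sketch of a matching save/resume pair built directly on `torch.save`/`torch.load` (the repository's own `save_checkpoint` in `utils/iotools.py` may use a different signature):

```python
import torch

def save_state(model, epoch, rank1, fpath='checkpoint.pth.tar'):
    # Illustrative saver producing exactly the keys the resume branch reads.
    torch.save({'state_dict': model.state_dict(),
                'epoch': epoch,
                'rank1': rank1}, fpath)

def resume_state(model, fpath):
    # Mirror of the diff's resume logic as a standalone helper.
    checkpoint = torch.load(fpath)
    model.load_state_dict(checkpoint['state_dict'])
    return checkpoint['epoch'], checkpoint['rank1']
```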
@@ -20,7 +20,7 @@ from dataset_loader import ImageDataset, VideoDataset
 import transforms as T
 import models
 from losses import CrossEntropyLabelSmooth, TripletLoss
-from utils.iotools import save_checkpoint
+from utils.iotools import save_checkpoint, check_isfile
 from utils.avgmeter import AverageMeter
 from utils.logger import Logger
 from utils.torchtools import count_num_param
@@ -171,24 +171,23 @@ def main():
 
     if args.load_weights:
         # load pretrained weights but ignore layers that don't match in size
-        print("Loading pretrained weights from '{}'".format(args.load_weights))
-        checkpoint = torch.load(args.load_weights)
-        pretrain_dict = checkpoint['state_dict']
-        model_dict = model.state_dict()
-        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
-        model_dict.update(pretrain_dict)
-        model.load_state_dict(model_dict)
+        if check_isfile(args.load_weights):
+            checkpoint = torch.load(args.load_weights)
+            pretrain_dict = checkpoint['state_dict']
+            model_dict = model.state_dict()
+            pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
+            model_dict.update(pretrain_dict)
+            model.load_state_dict(model_dict)
+            print("Loaded pretrained weights from '{}'".format(args.load_weights))
 
     if args.resume:
-        if osp.isfile(args.resume):
+        if check_isfile(args.resume):
             checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'])
             args.start_epoch = checkpoint['epoch']
             rank1 = checkpoint['rank1']
             print("Loaded checkpoint from '{}'".format(args.resume))
             print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
-        else:
-            print("=> No checkpoint found at '{}'".format(args.resume))
 
     if use_gpu:
         model = nn.DataParallel(model).cuda()
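Note that in all four scripts the weights are loaded before the model is wrapped in `nn.DataParallel`. The order matters: wrapping registers the network under a `module.` prefix, so a checkpoint saved from a wrapped model would not match a bare model's parameter names. A small illustration (toy model, runs on CPU):

```python
import torch.nn as nn

model = nn.Linear(4, 2)
print(list(model.state_dict())[0])    # 'weight'

wrapped = nn.DataParallel(model)
print(list(wrapped.state_dict())[0])  # 'module.weight'

# Common workaround when a checkpoint does carry the prefix:
state = {k.replace('module.', '', 1): v
         for k, v in wrapped.state_dict().items()}
model.load_state_dict(state)          # loads cleanly into the bare model
```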