finalize fixbase functions

This commit is contained in:
KaiyangZhou 2018-11-09 00:02:46 +00:00
parent 54b92ee617
commit b85e1b59d6
5 changed files with 51 additions and 30 deletions

View File

@@ -86,7 +86,7 @@ def argument_parser():
parser.add_argument('--test-batch-size', default=100, type=int,
help="test batch size")
parser.add_argument('--fixbase-epoch', type=int, default=0,
parser.add_argument('--fixbase-epoch', type=int, default=0,
help="how many epochs to fix base network (only train randomly initialized classifier)")
parser.add_argument('--open-layers', type=str, nargs='+', default=['classifier'],
help="open specified layers for training while keeping others frozen")

View File

@@ -61,13 +61,6 @@ def main():
optimizer = init_optimizer(model.parameters(), **optimizer_kwargs(args))
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)
"""if args.fixbase_epoch > 0:
if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
optimizer_tmp = init_optimizer(model.classifier.parameters(), **optimizer_kwargs(args))
else:
print("Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0")
args.fixbase_epoch = 0"""
if args.load_weights and check_isfile(args.load_weights):
# load pretrained weights but ignore layers that don't match in size
checkpoint = torch.load(args.load_weights)
@@ -119,7 +112,7 @@ def main():
train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)
train_time += round(time.time() - start_train_time)
print("Now open all layers for training")
print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
optimizer.load_state_dict(initial_optim_state)
for epoch in range(args.start_epoch, args.max_epoch):

View File

@@ -20,7 +20,7 @@ from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.eval_metrics import evaluate
from torchreid.samplers import RandomIdentitySampler
@@ -106,6 +106,18 @@ def main():
train_time = 0
print("==> Start training")
if args.fixbase_epoch > 0:
print("Train {} for {} epochs while keeping other layers frozen".format(args.open_layers, args.fixbase_epoch))
initial_optim_state = optimizer.state_dict()
for epoch in range(args.fixbase_epoch):
start_train_time = time.time()
train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=True)
train_time += round(time.time() - start_train_time)
print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
optimizer.load_state_dict(initial_optim_state)
for epoch in range(args.start_epoch, args.max_epoch):
start_train_time = time.time()
train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)
@@ -141,13 +153,18 @@ def main():
ranklogger.show_summary()
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
losses = AverageMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
model.train()
if fixbase:
open_specified_layers(model, args.open_layers)
else:
open_all_layers(model)
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
data_time.update(time.time() - end)

View File

@@ -21,7 +21,7 @@ from torchreid.losses import CrossEntropyLoss
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import set_bn_to_eval, count_num_param
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.eval_metrics import evaluate
from torchreid.optimizers import init_optimizer
@@ -62,14 +62,6 @@ def main():
optimizer = init_optimizer(model.parameters(), **optimizer_kwargs(args))
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)
if args.fixbase_epoch > 0:
"""if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
optimizer_tmp = init_optimizer(model.classifier.parameters(), **optimizer_kwargs(args))
else:
print("Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0")
args.fixbase_epoch = 0"""
raise NotImplementedError
if args.load_weights and check_isfile(args.load_weights):
# load pretrained weights but ignore layers that don't match in size
checkpoint = torch.load(args.load_weights)
@@ -113,16 +105,16 @@ def main():
print("==> Start training")
if args.fixbase_epoch > 0:
"""print("Train classifier for {} epochs while keeping base network frozen".format(args.fixbase_epoch))
print("Train {} for {} epochs while keeping other layers frozen".format(args.open_layers, args.fixbase_epoch))
initial_optim_state = optimizer.state_dict()
for epoch in range(args.fixbase_epoch):
start_train_time = time.time()
train(epoch, model, criterion, optimizer_tmp, trainloader, use_gpu, freeze_bn=True)
train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)
train_time += round(time.time() - start_train_time)
del optimizer_tmp
print("Now open all layers for training")"""
raise NotImplementedError
print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
optimizer.load_state_dict(initial_optim_state)
for epoch in range(args.start_epoch, args.max_epoch):
start_train_time = time.time()
@@ -159,15 +151,17 @@ def main():
ranklogger.show_summary()
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=False):
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=False):
losses = AverageMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
model.train()
if freeze_bn or args.freeze_bn:
model.apply(set_bn_to_eval)
if fixbase:
open_specified_layers(model, args.open_layers)
else:
open_all_layers(model)
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):

View File

@@ -21,7 +21,7 @@ from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.eval_metrics import evaluate
from torchreid.samplers import RandomIdentitySampler
@@ -108,6 +108,18 @@ def main():
train_time = 0
print("==> Start training")
if args.fixbase_epoch > 0:
print("Train {} for {} epochs while keeping other layers frozen".format(args.open_layers, args.fixbase_epoch))
initial_optim_state = optimizer.state_dict()
for epoch in range(args.fixbase_epoch):
start_train_time = time.time()
train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=True)
train_time += round(time.time() - start_train_time)
print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
optimizer.load_state_dict(initial_optim_state)
for epoch in range(args.start_epoch, args.max_epoch):
start_train_time = time.time()
train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)
@@ -143,13 +155,18 @@ def main():
ranklogger.show_summary()
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
losses = AverageMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
model.train()
if fixbase:
open_specified_layers(model, args.open_layers)
else:
open_all_layers(model)
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
data_time.update(time.time() - end)