Mirror of https://github.com/KaiyangZhou/deep-person-reid.git (synced 2025-06-03 14:53:23 +08:00)
finalize fixbase functions
parent 54b92ee617
commit b85e1b59d6
args.py (2 changes)
@@ -86,7 +86,7 @@ def argument_parser():
     parser.add_argument('--test-batch-size', default=100, type=int,
                         help="test batch size")
-    parser.add_argument('--fixbase-epoch', type=int, default=0,
+    parser.add_argument('--fixbase-epoch', type=int, default=0,
                         help="how many epochs to fix base network (only train randomly initialized classifier)")
     parser.add_argument('--open-layers', type=str, nargs='+', default=['classifier'],
                         help="open specified layers for training while keeping others frozen")
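For context, these two options drive the warm-up stage implemented in the training scripts below: only the submodules named in --open-layers are trained for the first --fixbase-epoch epochs, after which all layers are opened for the remaining epochs. A hypothetical invocation (the script name and the omitted data/model flags are placeholders for illustration, not taken from this commit) might look like:

    python train_imgreid_xent.py --fixbase-epoch 5 --open-layers classifier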
@@ -61,13 +61,6 @@ def main():
     optimizer = init_optimizer(model.parameters(), **optimizer_kwargs(args))
     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)
 
-    """if args.fixbase_epoch > 0:
-        if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
-            optimizer_tmp = init_optimizer(model.classifier.parameters(), **optimizer_kwargs(args))
-        else:
-            print("Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0")
-            args.fixbase_epoch = 0"""
-
     if args.load_weights and check_isfile(args.load_weights):
         # load pretrained weights but ignore layers that don't match in size
         checkpoint = torch.load(args.load_weights)
@@ -119,7 +112,7 @@ def main():
             train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)
             train_time += round(time.time() - start_train_time)
 
-        print("Now open all layers for training")
+        print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
         optimizer.load_state_dict(initial_optim_state)
 
     for epoch in range(args.start_epoch, args.max_epoch):
@@ -20,7 +20,7 @@ from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
 from torchreid.utils.iotools import save_checkpoint, check_isfile
 from torchreid.utils.avgmeter import AverageMeter
 from torchreid.utils.loggers import Logger, RankLogger
-from torchreid.utils.torchtools import count_num_param
+from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers
 from torchreid.utils.reidtools import visualize_ranked_results
 from torchreid.eval_metrics import evaluate
 from torchreid.samplers import RandomIdentitySampler
@@ -106,6 +106,18 @@ def main():
     train_time = 0
     print("==> Start training")
 
+    if args.fixbase_epoch > 0:
+        print("Train {} for {} epochs while keeping other layers frozen".format(args.open_layers, args.fixbase_epoch))
+        initial_optim_state = optimizer.state_dict()
+
+        for epoch in range(args.fixbase_epoch):
+            start_train_time = time.time()
+            train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=True)
+            train_time += round(time.time() - start_train_time)
+
+        print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
+        optimizer.load_state_dict(initial_optim_state)
+
     for epoch in range(args.start_epoch, args.max_epoch):
         start_train_time = time.time()
         train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)
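The warm-up loop added above snapshots the optimizer state before the fixbase epochs and restores it once all layers are opened, presumably so that momentum and other optimizer statistics accumulated while most of the network was frozen do not carry over into full training. A minimal sketch of that pattern in isolation (the toy model and hyperparameters are illustrative, not from the repo):

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    initial_optim_state = optimizer.state_dict()    # snapshot before the warm-up epochs

    for epoch in range(5):                          # warm-up phase (only some layers trainable)
        loss = model(torch.randn(4, 10)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    optimizer.load_state_dict(initial_optim_state)  # discard warm-up momentum before full training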
@@ -141,13 +153,18 @@ def main():
     ranklogger.show_summary()
 
 
-def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
+def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
     losses = AverageMeter()
     batch_time = AverageMeter()
     data_time = AverageMeter()
 
     model.train()
 
+    if fixbase:
+        open_specified_layers(model, args.open_layers)
+    else:
+        open_all_layers(model)
+
     end = time.time()
     for batch_idx, (imgs, pids, _) in enumerate(trainloader):
         data_time.update(time.time() - end)
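The train() functions now delegate the freezing logic to open_all_layers / open_specified_layers imported from torchreid.utils.torchtools. Their implementation is not part of this diff; the sketch below shows what such helpers typically do (toggle requires_grad and switch frozen modules to eval mode) and is an assumption for illustration, not the repo's actual code:

    def open_all_layers(model):
        # Put the model in train mode and make every parameter trainable again.
        model.train()
        for p in model.parameters():
            p.requires_grad = True

    def open_specified_layers(model, open_layers):
        # Train only the named top-level submodules; freeze everything else.
        for name, module in model.named_children():
            if name in open_layers:
                module.train()
                for p in module.parameters():
                    p.requires_grad = True
            else:
                module.eval()  # also keeps BatchNorm running statistics fixed
                for p in module.parameters():
                    p.requires_grad = False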
@@ -21,7 +21,7 @@ from torchreid.losses import CrossEntropyLoss
 from torchreid.utils.iotools import save_checkpoint, check_isfile
 from torchreid.utils.avgmeter import AverageMeter
 from torchreid.utils.loggers import Logger, RankLogger
-from torchreid.utils.torchtools import set_bn_to_eval, count_num_param
+from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers
 from torchreid.utils.reidtools import visualize_ranked_results
 from torchreid.eval_metrics import evaluate
 from torchreid.optimizers import init_optimizer
@@ -62,14 +62,6 @@ def main():
     optimizer = init_optimizer(model.parameters(), **optimizer_kwargs(args))
     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)
 
-    if args.fixbase_epoch > 0:
-        """if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
-            optimizer_tmp = init_optimizer(model.classifier.parameters(), **optimizer_kwargs(args))
-        else:
-            print("Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0")
-            args.fixbase_epoch = 0"""
-        raise NotImplementedError
-
     if args.load_weights and check_isfile(args.load_weights):
         # load pretrained weights but ignore layers that don't match in size
         checkpoint = torch.load(args.load_weights)
@@ -113,16 +105,16 @@ def main():
     print("==> Start training")
 
     if args.fixbase_epoch > 0:
-        """print("Train classifier for {} epochs while keeping base network frozen".format(args.fixbase_epoch))
+        print("Train {} for {} epochs while keeping other layers frozen".format(args.open_layers, args.fixbase_epoch))
         initial_optim_state = optimizer.state_dict()
 
         for epoch in range(args.fixbase_epoch):
             start_train_time = time.time()
-            train(epoch, model, criterion, optimizer_tmp, trainloader, use_gpu, freeze_bn=True)
+            train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)
             train_time += round(time.time() - start_train_time)
 
-            del optimizer_tmp
-            print("Now open all layers for training")"""
-        raise NotImplementedError
+        print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
         optimizer.load_state_dict(initial_optim_state)
 
     for epoch in range(args.start_epoch, args.max_epoch):
         start_train_time = time.time()
@@ -159,15 +151,17 @@ def main():
     ranklogger.show_summary()
 
 
-def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=False):
+def train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=False):
     losses = AverageMeter()
     batch_time = AverageMeter()
     data_time = AverageMeter()
 
     model.train()
 
-    if freeze_bn or args.freeze_bn:
-        model.apply(set_bn_to_eval)
+    if fixbase:
+        open_specified_layers(model, args.open_layers)
+    else:
+        open_all_layers(model)
 
     end = time.time()
     for batch_idx, (imgs, pids, _) in enumerate(trainloader):
@@ -21,7 +21,7 @@ from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
 from torchreid.utils.iotools import save_checkpoint, check_isfile
 from torchreid.utils.avgmeter import AverageMeter
 from torchreid.utils.loggers import Logger, RankLogger
-from torchreid.utils.torchtools import count_num_param
+from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers
 from torchreid.utils.reidtools import visualize_ranked_results
 from torchreid.eval_metrics import evaluate
 from torchreid.samplers import RandomIdentitySampler
@@ -108,6 +108,18 @@ def main():
     train_time = 0
     print("==> Start training")
 
+    if args.fixbase_epoch > 0:
+        print("Train {} for {} epochs while keeping other layers frozen".format(args.open_layers, args.fixbase_epoch))
+        initial_optim_state = optimizer.state_dict()
+
+        for epoch in range(args.fixbase_epoch):
+            start_train_time = time.time()
+            train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=True)
+            train_time += round(time.time() - start_train_time)
+
+        print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
+        optimizer.load_state_dict(initial_optim_state)
+
     for epoch in range(args.start_epoch, args.max_epoch):
         start_train_time = time.time()
         train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)
@@ -143,13 +155,18 @@ def main():
     ranklogger.show_summary()
 
 
-def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
+def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
     losses = AverageMeter()
     batch_time = AverageMeter()
     data_time = AverageMeter()
 
     model.train()
 
+    if fixbase:
+        open_specified_layers(model, args.open_layers)
+    else:
+        open_all_layers(model)
+
     end = time.time()
     for batch_idx, (imgs, pids, _) in enumerate(trainloader):
         data_time.update(time.time() - end)