mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
add ranklogger; update variable names
This commit is contained in:
parent
01fe76cf27
commit
57682b768e
@ -19,7 +19,7 @@ from torchreid import models
|
|||||||
from torchreid.losses import CrossEntropyLoss, DeepSupervision
|
from torchreid.losses import CrossEntropyLoss, DeepSupervision
|
||||||
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
||||||
from torchreid.utils.avgmeter import AverageMeter
|
from torchreid.utils.avgmeter import AverageMeter
|
||||||
from torchreid.utils.logger import Logger
|
from torchreid.utils.loggers import Logger, RankLogger
|
||||||
from torchreid.utils.torchtools import set_bn_to_eval, count_num_param
|
from torchreid.utils.torchtools import set_bn_to_eval, count_num_param
|
||||||
from torchreid.utils.reidtools import visualize_ranked_results
|
from torchreid.utils.reidtools import visualize_ranked_results
|
||||||
from torchreid.eval_metrics import evaluate
|
from torchreid.eval_metrics import evaluate
|
||||||
@ -91,7 +91,7 @@ def main():
|
|||||||
if args.evaluate:
|
if args.evaluate:
|
||||||
print("Evaluate only")
|
print("Evaluate only")
|
||||||
|
|
||||||
for name in args.target:
|
for name in args.target_names:
|
||||||
print("Evaluating {} ...".format(name))
|
print("Evaluating {} ...".format(name))
|
||||||
queryloader = testloader_dict[name]['query']
|
queryloader = testloader_dict[name]['query']
|
||||||
galleryloader = testloader_dict[name]['gallery']
|
galleryloader = testloader_dict[name]['gallery']
|
||||||
@ -106,6 +106,7 @@ def main():
|
|||||||
return
|
return
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
ranklogger = RankLogger(args.source_names, args.target_names)
|
||||||
train_time = 0
|
train_time = 0
|
||||||
print("==> Start training")
|
print("==> Start training")
|
||||||
|
|
||||||
@ -130,11 +131,12 @@ def main():
|
|||||||
if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
|
if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
|
||||||
print("==> Test")
|
print("==> Test")
|
||||||
|
|
||||||
for name in args.target:
|
for name in args.target_names:
|
||||||
print("Evaluating {} ...".format(name))
|
print("Evaluating {} ...".format(name))
|
||||||
queryloader = testloader_dict[name]['query']
|
queryloader = testloader_dict[name]['query']
|
||||||
galleryloader = testloader_dict[name]['gallery']
|
galleryloader = testloader_dict[name]['gallery']
|
||||||
rank1 = test(model, queryloader, galleryloader, use_gpu)
|
rank1 = test(model, queryloader, galleryloader, use_gpu)
|
||||||
|
ranklogger.write(name, epoch + 1, rank1)
|
||||||
|
|
||||||
if use_gpu:
|
if use_gpu:
|
||||||
state_dict = model.module.state_dict()
|
state_dict = model.module.state_dict()
|
||||||
@ -151,6 +153,7 @@ def main():
|
|||||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||||
train_time = str(datetime.timedelta(seconds=train_time))
|
train_time = str(datetime.timedelta(seconds=train_time))
|
||||||
print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
|
print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
|
||||||
|
ranklogger.show_summary()
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=False):
|
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=False):
|
||||||
|
@ -19,7 +19,7 @@ from torchreid import models
|
|||||||
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
|
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
|
||||||
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
||||||
from torchreid.utils.avgmeter import AverageMeter
|
from torchreid.utils.avgmeter import AverageMeter
|
||||||
from torchreid.utils.logger import Logger
|
from torchreid.utils.loggers import Logger, RankLogger
|
||||||
from torchreid.utils.torchtools import count_num_param
|
from torchreid.utils.torchtools import count_num_param
|
||||||
from torchreid.utils.reidtools import visualize_ranked_results
|
from torchreid.utils.reidtools import visualize_ranked_results
|
||||||
from torchreid.eval_metrics import evaluate
|
from torchreid.eval_metrics import evaluate
|
||||||
@ -87,7 +87,7 @@ def main():
|
|||||||
if args.evaluate:
|
if args.evaluate:
|
||||||
print("Evaluate only")
|
print("Evaluate only")
|
||||||
|
|
||||||
for name in args.target:
|
for name in args.target_names:
|
||||||
print("Evaluating {} ...".format(name))
|
print("Evaluating {} ...".format(name))
|
||||||
queryloader = testloader_dict[name]['query']
|
queryloader = testloader_dict[name]['query']
|
||||||
galleryloader = testloader_dict[name]['gallery']
|
galleryloader = testloader_dict[name]['gallery']
|
||||||
@ -102,6 +102,7 @@ def main():
|
|||||||
return
|
return
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
ranklogger = RankLogger(args.source_names, args.target_names)
|
||||||
train_time = 0
|
train_time = 0
|
||||||
print("==> Start training")
|
print("==> Start training")
|
||||||
|
|
||||||
@ -115,11 +116,12 @@ def main():
|
|||||||
if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
|
if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
|
||||||
print("==> Test")
|
print("==> Test")
|
||||||
|
|
||||||
for name in args.target:
|
for name in args.target_names:
|
||||||
print("Evaluating {} ...".format(name))
|
print("Evaluating {} ...".format(name))
|
||||||
queryloader = testloader_dict[name]['query']
|
queryloader = testloader_dict[name]['query']
|
||||||
galleryloader = testloader_dict[name]['gallery']
|
galleryloader = testloader_dict[name]['gallery']
|
||||||
rank1 = test(model, queryloader, galleryloader, use_gpu)
|
rank1 = test(model, queryloader, galleryloader, use_gpu)
|
||||||
|
ranklogger.write(name, epoch + 1, rank1)
|
||||||
|
|
||||||
if use_gpu:
|
if use_gpu:
|
||||||
state_dict = model.module.state_dict()
|
state_dict = model.module.state_dict()
|
||||||
@ -136,6 +138,7 @@ def main():
|
|||||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||||
train_time = str(datetime.timedelta(seconds=train_time))
|
train_time = str(datetime.timedelta(seconds=train_time))
|
||||||
print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
|
print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
|
||||||
|
ranklogger.show_summary()
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
|
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
|
||||||
|
@ -20,7 +20,7 @@ from torchreid import models
|
|||||||
from torchreid.losses import CrossEntropyLoss
|
from torchreid.losses import CrossEntropyLoss
|
||||||
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
||||||
from torchreid.utils.avgmeter import AverageMeter
|
from torchreid.utils.avgmeter import AverageMeter
|
||||||
from torchreid.utils.logger import Logger
|
from torchreid.utils.loggers import Logger, RankLogger
|
||||||
from torchreid.utils.torchtools import set_bn_to_eval, count_num_param
|
from torchreid.utils.torchtools import set_bn_to_eval, count_num_param
|
||||||
from torchreid.utils.reidtools import visualize_ranked_results
|
from torchreid.utils.reidtools import visualize_ranked_results
|
||||||
from torchreid.eval_metrics import evaluate
|
from torchreid.eval_metrics import evaluate
|
||||||
@ -92,7 +92,7 @@ def main():
|
|||||||
if args.evaluate:
|
if args.evaluate:
|
||||||
print("Evaluate only")
|
print("Evaluate only")
|
||||||
|
|
||||||
for name in args.target:
|
for name in args.target_names:
|
||||||
print("Evaluating {} ...".format(name))
|
print("Evaluating {} ...".format(name))
|
||||||
queryloader = testloader_dict[name]['query']
|
queryloader = testloader_dict[name]['query']
|
||||||
galleryloader = testloader_dict[name]['gallery']
|
galleryloader = testloader_dict[name]['gallery']
|
||||||
@ -107,6 +107,7 @@ def main():
|
|||||||
return
|
return
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
ranklogger = RankLogger(args.source_names, args.target_names)
|
||||||
train_time = 0
|
train_time = 0
|
||||||
print("==> Start training")
|
print("==> Start training")
|
||||||
|
|
||||||
@ -131,11 +132,12 @@ def main():
|
|||||||
if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
|
if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
|
||||||
print("==> Test")
|
print("==> Test")
|
||||||
|
|
||||||
for name in args.target:
|
for name in args.target_names:
|
||||||
print("Evaluating {} ...".format(name))
|
print("Evaluating {} ...".format(name))
|
||||||
queryloader = testloader_dict[name]['query']
|
queryloader = testloader_dict[name]['query']
|
||||||
galleryloader = testloader_dict[name]['gallery']
|
galleryloader = testloader_dict[name]['gallery']
|
||||||
rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
|
rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
|
||||||
|
ranklogger.write(name, epoch + 1, rank1)
|
||||||
|
|
||||||
if use_gpu:
|
if use_gpu:
|
||||||
state_dict = model.module.state_dict()
|
state_dict = model.module.state_dict()
|
||||||
@ -152,6 +154,7 @@ def main():
|
|||||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||||
train_time = str(datetime.timedelta(seconds=train_time))
|
train_time = str(datetime.timedelta(seconds=train_time))
|
||||||
print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
|
print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
|
||||||
|
ranklogger.show_summary()
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=False):
|
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=False):
|
||||||
|
@ -20,7 +20,7 @@ from torchreid import models
|
|||||||
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
|
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
|
||||||
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
||||||
from torchreid.utils.avgmeter import AverageMeter
|
from torchreid.utils.avgmeter import AverageMeter
|
||||||
from torchreid.utils.logger import Logger
|
from torchreid.utils.loggers import Logger, RankLogger
|
||||||
from torchreid.utils.torchtools import count_num_param
|
from torchreid.utils.torchtools import count_num_param
|
||||||
from torchreid.utils.reidtools import visualize_ranked_results
|
from torchreid.utils.reidtools import visualize_ranked_results
|
||||||
from torchreid.eval_metrics import evaluate
|
from torchreid.eval_metrics import evaluate
|
||||||
@ -89,7 +89,7 @@ def main():
|
|||||||
if args.evaluate:
|
if args.evaluate:
|
||||||
print("Evaluate only")
|
print("Evaluate only")
|
||||||
|
|
||||||
for name in args.target:
|
for name in args.target_names:
|
||||||
print("Evaluating {} ...".format(name))
|
print("Evaluating {} ...".format(name))
|
||||||
queryloader = testloader_dict[name]['query']
|
queryloader = testloader_dict[name]['query']
|
||||||
galleryloader = testloader_dict[name]['gallery']
|
galleryloader = testloader_dict[name]['gallery']
|
||||||
@ -104,6 +104,7 @@ def main():
|
|||||||
return
|
return
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
ranklogger = RankLogger(args.source_names, args.target_names)
|
||||||
train_time = 0
|
train_time = 0
|
||||||
print("==> Start training")
|
print("==> Start training")
|
||||||
|
|
||||||
@ -117,11 +118,12 @@ def main():
|
|||||||
if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
|
if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
|
||||||
print("==> Test")
|
print("==> Test")
|
||||||
|
|
||||||
for name in args.target:
|
for name in args.target_names:
|
||||||
print("Evaluating {} ...".format(name))
|
print("Evaluating {} ...".format(name))
|
||||||
queryloader = testloader_dict[name]['query']
|
queryloader = testloader_dict[name]['query']
|
||||||
galleryloader = testloader_dict[name]['gallery']
|
galleryloader = testloader_dict[name]['gallery']
|
||||||
rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
|
rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
|
||||||
|
ranklogger.write(name, epoch + 1, rank1)
|
||||||
|
|
||||||
if use_gpu:
|
if use_gpu:
|
||||||
state_dict = model.module.state_dict()
|
state_dict = model.module.state_dict()
|
||||||
@ -138,6 +140,7 @@ def main():
|
|||||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||||
train_time = str(datetime.timedelta(seconds=train_time))
|
train_time = str(datetime.timedelta(seconds=train_time))
|
||||||
print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
|
print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
|
||||||
|
ranklogger.show_summary()
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
|
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user