refactor evaluation code

This commit is contained in:
liaoxingyu 2020-09-23 19:32:40 +08:00
parent 154a06b875
commit 5b88736e1d
3 changed files with 9 additions and 13 deletions

View File

@ -112,8 +112,6 @@ def inference_on_dataset(model, data_loader, evaluator):
start_compute_time = time.perf_counter()
outputs = model(inputs)
if torch.cuda.is_available():
torch.cuda.synchronize()
total_compute_time += time.perf_counter() - start_compute_time
evaluator.process(inputs, outputs)

View File

@ -68,8 +68,9 @@ class ReidEvaluator(DatasetEvaluator):
camids = comm.gather(self.camids)
camids = sum(camids, [])
if not comm.is_main_process():
return {}
# fmt: off
if not comm.is_main_process(): return {}
# fmt: on
else:
features = self.features
pids = self.pids

View File

@ -11,8 +11,8 @@ from . import optim
def build_optimizer(cfg, model):
params = []
for key, value in model.named_parameters():
if not value.requires_grad:
continue
if not value.requires_grad: continue
lr = cfg.SOLVER.BASE_LR
weight_decay = cfg.SOLVER.WEIGHT_DECAY
if "heads" in key:
@ -23,13 +23,10 @@ def build_optimizer(cfg, model):
params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay, "freeze": False}]
solver_opt = cfg.SOLVER.OPT
if hasattr(optim, solver_opt):
if solver_opt == "SGD":
opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM)
else:
opt_fns = getattr(optim, solver_opt)(params)
else:
raise NameError("optimizer {} not support".format(cfg.SOLVER.OPT))
# fmt: off
if solver_opt == "SGD": opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM)
else: opt_fns = getattr(optim, solver_opt)(params)
# fmt: on
return opt_fns