mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
refactor evaluation code
This commit is contained in:
parent
154a06b875
commit
5b88736e1d
@ -112,8 +112,6 @@ def inference_on_dataset(model, data_loader, evaluator):
|
||||
|
||||
start_compute_time = time.perf_counter()
|
||||
outputs = model(inputs)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
total_compute_time += time.perf_counter() - start_compute_time
|
||||
evaluator.process(inputs, outputs)
|
||||
|
||||
|
@ -68,8 +68,9 @@ class ReidEvaluator(DatasetEvaluator):
|
||||
camids = comm.gather(self.camids)
|
||||
camids = sum(camids, [])
|
||||
|
||||
if not comm.is_main_process():
|
||||
return {}
|
||||
# fmt: off
|
||||
if not comm.is_main_process(): return {}
|
||||
# fmt: on
|
||||
else:
|
||||
features = self.features
|
||||
pids = self.pids
|
||||
|
@ -11,8 +11,8 @@ from . import optim
|
||||
def build_optimizer(cfg, model):
|
||||
params = []
|
||||
for key, value in model.named_parameters():
|
||||
if not value.requires_grad:
|
||||
continue
|
||||
if not value.requires_grad: continue
|
||||
|
||||
lr = cfg.SOLVER.BASE_LR
|
||||
weight_decay = cfg.SOLVER.WEIGHT_DECAY
|
||||
if "heads" in key:
|
||||
@ -23,13 +23,10 @@ def build_optimizer(cfg, model):
|
||||
params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay, "freeze": False}]
|
||||
|
||||
solver_opt = cfg.SOLVER.OPT
|
||||
if hasattr(optim, solver_opt):
|
||||
if solver_opt == "SGD":
|
||||
opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM)
|
||||
else:
|
||||
opt_fns = getattr(optim, solver_opt)(params)
|
||||
else:
|
||||
raise NameError("optimizer {} not support".format(cfg.SOLVER.OPT))
|
||||
# fmt: off
|
||||
if solver_opt == "SGD": opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM)
|
||||
else: opt_fns = getattr(optim, solver_opt)(params)
|
||||
# fmt: on
|
||||
return opt_fns
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user