mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
refactor evaluation code
This commit is contained in:
parent
154a06b875
commit
5b88736e1d
@ -112,8 +112,6 @@ def inference_on_dataset(model, data_loader, evaluator):
|
|||||||
|
|
||||||
start_compute_time = time.perf_counter()
|
start_compute_time = time.perf_counter()
|
||||||
outputs = model(inputs)
|
outputs = model(inputs)
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
total_compute_time += time.perf_counter() - start_compute_time
|
total_compute_time += time.perf_counter() - start_compute_time
|
||||||
evaluator.process(inputs, outputs)
|
evaluator.process(inputs, outputs)
|
||||||
|
|
||||||
|
@ -68,8 +68,9 @@ class ReidEvaluator(DatasetEvaluator):
|
|||||||
camids = comm.gather(self.camids)
|
camids = comm.gather(self.camids)
|
||||||
camids = sum(camids, [])
|
camids = sum(camids, [])
|
||||||
|
|
||||||
if not comm.is_main_process():
|
# fmt: off
|
||||||
return {}
|
if not comm.is_main_process(): return {}
|
||||||
|
# fmt: on
|
||||||
else:
|
else:
|
||||||
features = self.features
|
features = self.features
|
||||||
pids = self.pids
|
pids = self.pids
|
||||||
|
@ -11,8 +11,8 @@ from . import optim
|
|||||||
def build_optimizer(cfg, model):
|
def build_optimizer(cfg, model):
|
||||||
params = []
|
params = []
|
||||||
for key, value in model.named_parameters():
|
for key, value in model.named_parameters():
|
||||||
if not value.requires_grad:
|
if not value.requires_grad: continue
|
||||||
continue
|
|
||||||
lr = cfg.SOLVER.BASE_LR
|
lr = cfg.SOLVER.BASE_LR
|
||||||
weight_decay = cfg.SOLVER.WEIGHT_DECAY
|
weight_decay = cfg.SOLVER.WEIGHT_DECAY
|
||||||
if "heads" in key:
|
if "heads" in key:
|
||||||
@ -23,13 +23,10 @@ def build_optimizer(cfg, model):
|
|||||||
params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay, "freeze": False}]
|
params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay, "freeze": False}]
|
||||||
|
|
||||||
solver_opt = cfg.SOLVER.OPT
|
solver_opt = cfg.SOLVER.OPT
|
||||||
if hasattr(optim, solver_opt):
|
# fmt: off
|
||||||
if solver_opt == "SGD":
|
if solver_opt == "SGD": opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM)
|
||||||
opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM)
|
else: opt_fns = getattr(optim, solver_opt)(params)
|
||||||
else:
|
# fmt: on
|
||||||
opt_fns = getattr(optim, solver_opt)(params)
|
|
||||||
else:
|
|
||||||
raise NameError("optimizer {} not support".format(cfg.SOLVER.OPT))
|
|
||||||
return opt_fns
|
return opt_fns
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user