mirror of https://github.com/JDAI-CV/fast-reid.git

feat: support flip testing

parent bb7a00e615
commit 66941cf27a
@@ -331,18 +331,17 @@ class DefaultTrainer(TrainerBase):
         # be saved by checkpointer.
         # This is not always the best: if checkpointing has a different frequency,
         # some checkpoints may have more precise statistics than others.
-        if comm.is_main_process():
-            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
 
         def test_and_save_results():
             self._last_eval_results = self.test(self.cfg, self.model)
             return self._last_eval_results
 
-        # Do evaluation after checkpointer, because then if it fails,
+        # Do evaluation before checkpointer, because then if it fails,
         # we can use the saved checkpoint to debug.
         ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
 
         if comm.is_main_process():
+            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
             # run writers in the end, so that evaluation metrics are written
             ret.append(hooks.PeriodicWriter(self.build_writers(), 200))
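The reordering above relies on hooks firing in registration order: the EvalHook is now appended before the PeriodicCheckpointer, so each periodic checkpoint is written after the matching evaluation. A minimal sketch of that pattern, with hypothetical Hook/Trainer classes rather than fastreid's actual HookBase/TrainerBase:

# Minimal sketch of hooks firing in registration order; the class
# names here are illustrative, not fastreid's real implementation.
class EvalHook:
    def __init__(self, period, eval_fn):
        self.period, self.eval_fn = period, eval_fn

    def after_step(self, trainer):
        if (trainer.iter + 1) % self.period == 0:
            trainer.last_eval = self.eval_fn()

class CheckpointHook:
    def __init__(self, period, save_fn):
        self.period, self.save_fn = period, save_fn

    def after_step(self, trainer):
        if (trainer.iter + 1) % self.period == 0:
            self.save_fn()  # runs after EvalHook on the same iteration

class Trainer:
    def __init__(self, hooks):
        self.hooks, self.iter, self.last_eval = hooks, 0, None

    def run(self, max_iter):
        for self.iter in range(max_iter):
            # ... one optimization step would happen here ...
            for hook in self.hooks:  # fires in registration order
                hook.after_step(self)

# Eval is registered first, so a failure during checkpointing still
# leaves trainer.last_eval populated for debugging.
trainer = Trainer([EvalHook(50, lambda: "metrics"),
                   CheckpointHook(50, lambda: None)])
trainer.run(100)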
@@ -474,7 +473,7 @@ class DefaultTrainer(TrainerBase):
                 )
                 results[dataset_name] = {}
                 continue
-            results_i = inference_on_dataset(model, data_loader, evaluator)
+            results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP_ENABLED)
             results[dataset_name] = results_i
 
             if comm.is_main_process():
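The new call site reads cfg.TEST.FLIP_ENABLED, so that flag must be declared in the config defaults, which this diff does not show. A minimal yacs-style sketch of what the declaration could look like (the exact node layout is an assumption):

# Hypothetical excerpt of a config defaults file; fastreid uses
# yacs-style CfgNode configs, but this exact layout is assumed.
from yacs.config import CfgNode as CN

_C = CN()
_C.TEST = CN()
_C.TEST.EVAL_PERIOD = 50       # iterations between EvalHook runs
_C.TEST.FLIP_ENABLED = False   # flip testing stays off unless enabled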
@@ -78,7 +78,7 @@ class DatasetEvaluator:
         # return results
 
 
-def inference_on_dataset(model, data_loader, evaluator):
+def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
     """
     Run model on the data_loader and evaluate the metrics with evaluator.
     The model will be used in eval mode.
@@ -92,6 +92,7 @@ def inference_on_dataset(model, data_loader, evaluator):
         evaluator (DatasetEvaluator): the evaluator to run. Use
             :class:`DatasetEvaluators([])` if you only want to benchmark, but
             don't want to do any evaluation.
+        flip_test (bool): whether to also extract features from horizontally flipped images
     Returns:
         The return value of `evaluator.evaluate()`
     """
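With the new keyword argument, flip testing can also be toggled per call, independent of the config flag. A hypothetical call site, where build_model, build_test_loader, and ReidEvaluator are placeholder names for whatever builders the surrounding project provides:

# Placeholder builder names; only the inference_on_dataset call
# reflects the signature introduced in this diff.
model = build_model(cfg)
data_loader = build_test_loader(cfg, dataset_name="Market1501")
evaluator = ReidEvaluator(cfg, num_query)

# Evaluate once with flip testing forced on.
results = inference_on_dataset(model, data_loader, evaluator, flip_test=True)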
@@ -112,6 +113,11 @@ def inference_on_dataset(model, data_loader, evaluator):
 
             start_compute_time = time.perf_counter()
             outputs = model(inputs)
+            # Flip test
+            if flip_test:
+                inputs["images"] = inputs["images"].flip(dims=[3])
+                flip_outputs = model(inputs)
+                outputs = (outputs + flip_outputs) / 2
             total_compute_time += time.perf_counter() - start_compute_time
             evaluator.process(inputs, outputs)
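What the added branch computes: features for the original batch, features for the horizontally mirrored batch (dims=[3] is the width axis of an NCHW tensor), and their element-wise mean. A self-contained sketch with a toy embedder standing in for the reid model:

# Toy embedder, not the actual fastreid model; assumes NCHW input,
# so flip(dims=[3]) mirrors images along the width axis.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),                             # -> (N, 8) feature vectors
)
model.eval()

with torch.no_grad():
    images = torch.randn(4, 3, 256, 128)      # batch of person crops
    feat = model(images)                      # original features
    feat_flip = model(images.flip(dims=[3]))  # mirrored features
    fused = (feat + feat_flip) / 2            # the diff's averaging step

print(fused.shape)  # torch.Size([4, 8])

Note that the averaging step assumes the model returns a plain feature tensor in eval mode, and that the diff rebinds inputs["images"] in place, so evaluator.process receives the mirrored images rather than the originals.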