feat: support flip testing

pull/372/head
liaoxingyu 2020-12-22 15:50:50 +08:00
parent bb7a00e615
commit 66941cf27a
2 changed files with 10 additions and 5 deletions

View File

@ -331,18 +331,17 @@ class DefaultTrainer(TrainerBase):
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
# Do evaluation after checkpointer, because then if it fails,
# Do evaluation before checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), 200))
@ -474,7 +473,7 @@ class DefaultTrainer(TrainerBase):
)
results[dataset_name] = {}
continue
results_i = inference_on_dataset(model, data_loader, evaluator)
results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP_ENABLED)
results[dataset_name] = results_i
if comm.is_main_process():

View File

@ -78,7 +78,7 @@ class DatasetEvaluator:
# return results
def inference_on_dataset(model, data_loader, evaluator):
def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
"""
Run model on the data_loader and evaluate the metrics with evaluator.
The model will be used in eval mode.
@ -92,6 +92,7 @@ def inference_on_dataset(model, data_loader, evaluator):
evaluator (DatasetEvaluator): the evaluator to run. Use
:class:`DatasetEvaluators([])` if you only want to benchmark, but
don't want to do any evaluation.
flip_test (bool): If get features with flipped images
Returns:
The return value of `evaluator.evaluate()`
"""
@ -112,6 +113,11 @@ def inference_on_dataset(model, data_loader, evaluator):
start_compute_time = time.perf_counter()
outputs = model(inputs)
# Flip test
if flip_test:
inputs["images"] = inputs["images"].flip(dims=[3])
flip_outputs = model(inputs)
outputs = (outputs + flip_outputs) / 2
total_compute_time += time.perf_counter() - start_compute_time
evaluator.process(inputs, outputs)