2020-10-01 18:10:02 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
|
|
|
@author: xingyu liao
|
|
|
|
@contact: sherlockliao01@gmail.com
|
|
|
|
"""
|
2020-10-01 18:12:51 +08:00
|
|
|
|
2020-10-01 18:10:02 +08:00
|
|
|
import torch
|
|
|
|
from ray import tune
|
|
|
|
|
|
|
|
from fastreid.engine.hooks import EvalHook, flatten_results_dict
|
2020-10-01 18:12:51 +08:00
|
|
|
from fastreid.utils.checkpoint import Checkpointer
|
2020-10-01 18:10:02 +08:00
|
|
|
|
|
|
|
|
|
|
|
class TuneReportHook(EvalHook):
    """Evaluation hook that reports results to Ray Tune.

    Periodically runs the wrapped evaluation function (scheduling is
    inherited from ``EvalHook``), saves a checkpoint through Ray Tune's
    ``tune.checkpoint_dir`` API, and reports Rank-1 / mAP metrics back to
    Tune so its trial scheduler can make pruning/promotion decisions.
    """

    def __init__(self, eval_period, eval_function):
        """
        Args:
            eval_period (int): run the evaluation every ``eval_period``
                iterations (interpreted by the ``EvalHook`` base class).
            eval_function (callable): no-argument function returning a
                nested dict of metrics; expected to contain 'Rank-1' and
                'mAP' entries (see ``_do_eval``).
        """
        super().__init__(eval_period, eval_function)
        # Monotonic counter used to tag Tune checkpoints, one per evaluation.
        self.step = 0

    def _do_eval(self):
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                # Validate early that every leaf value is numeric; a clear
                # error here beats an opaque failure downstream. Chain the
                # original exception so the root cause stays visible.
                try:
                    float(v)
                except Exception as e:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    ) from e

        # Remove extra memory cache of main process due to evaluation
        torch.cuda.empty_cache()

        self.step += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and will potentially be passed as the `checkpoint_dir`
        # parameter in future iterations.
        with tune.checkpoint_dir(step=self.step) as checkpoint_dir:
            additional_state = {"iteration": int(self.trainer.iter)}
            Checkpointer(
                # Assume you want to save checkpoints together with logs/statistics
                self.trainer.model,
                checkpoint_dir,
                save_to_disk=True,
                optimizer=self.trainer.optimizer,
                scheduler=self.trainer.scheduler,
            ).save(name="checkpoint", **additional_state)

        # NOTE(review): this assumes `results` always contains 'Rank-1' and
        # 'mAP'; an empty/None result would raise here even though the
        # validation above is guarded by `if results:` — confirm upstream
        # that the eval function never returns an empty result.
        metrics = dict(
            r1=results['Rank-1'],
            map=results['mAP'],
            score=(results['Rank-1'] + results['mAP']) / 2,
        )
        tune.report(**metrics)
|