feat: add save best model checkpoint

pull/372/head
liaoxingyu 2020-12-22 15:50:23 +08:00
parent 5469e8ce76
commit bb7a00e615
1 changed files with 11 additions and 0 deletions

View File

@ -15,6 +15,7 @@ from termcolor import colored
from torch.nn.parallel import DataParallel, DistributedDataParallel
from fastreid.utils.file_io import PathManager
from fastreid.utils.events import get_event_storage
class _IncompatibleKeys(
@ -307,6 +308,7 @@ class PeriodicCheckpointer:
self.checkpointer = checkpointer
self.period = int(period)
self.max_epoch = max_epoch
self.best_metric = -1
def step(self, epoch: int, **kwargs: Any):
"""
@ -323,8 +325,17 @@ class PeriodicCheckpointer:
self.checkpointer.save(
"model_{:04d}".format(epoch), **additional_state
)
if additional_state["metric"] > self.best_metric:
self.checkpointer.save(
"model_best", **additional_state
)
self.best_metric = additional_state["metric"]
if epoch >= self.max_epoch - 1:
self.checkpointer.save("model_final", **additional_state)
if additional_state["metric"] > self.best_metric:
self.checkpointer.save(
"model_best", **additional_state
)
def save(self, name: str, **kwargs: Any):
"""