mirror of https://github.com/JDAI-CV/fast-reid.git
feat: add save best model checkpoint
parent
5469e8ce76
commit
bb7a00e615
|
@ -15,6 +15,7 @@ from termcolor import colored
|
|||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
|
||||
from fastreid.utils.file_io import PathManager
|
||||
from fastreid.utils.events import get_event_storage
|
||||
|
||||
|
||||
class _IncompatibleKeys(
|
||||
|
@ -307,6 +308,7 @@ class PeriodicCheckpointer:
|
|||
self.checkpointer = checkpointer
|
||||
self.period = int(period)
|
||||
self.max_epoch = max_epoch
|
||||
self.best_metric = -1
|
||||
|
||||
def step(self, epoch: int, **kwargs: Any):
|
||||
"""
|
||||
|
@ -323,8 +325,17 @@ class PeriodicCheckpointer:
|
|||
self.checkpointer.save(
|
||||
"model_{:04d}".format(epoch), **additional_state
|
||||
)
|
||||
if additional_state["metric"] > self.best_metric:
|
||||
self.checkpointer.save(
|
||||
"model_best", **additional_state
|
||||
)
|
||||
self.best_metric = additional_state["metric"]
|
||||
if epoch >= self.max_epoch - 1:
|
||||
self.checkpointer.save("model_final", **additional_state)
|
||||
if additional_state["metric"] > self.best_metric:
|
||||
self.checkpointer.save(
|
||||
"model_best", **additional_state
|
||||
)
|
||||
|
||||
def save(self, name: str, **kwargs: Any):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue