mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
feat: add save best model mechanism
This commit is contained in:
parent
04fe9fb2d8
commit
5469e8ce76
@ -17,7 +17,7 @@ from fastreid.evaluation.testing import flatten_results_dict
|
||||
from fastreid.solver import optim
|
||||
from fastreid.utils import comm
|
||||
from fastreid.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
|
||||
from fastreid.utils.events import EventStorage, EventWriter
|
||||
from fastreid.utils.events import EventStorage, EventWriter, get_event_storage
|
||||
from fastreid.utils.file_io import PathManager
|
||||
from fastreid.utils.precision_bn import update_bn_stats, get_bn_modules
|
||||
from fastreid.utils.timer import Timer
|
||||
@ -198,10 +198,18 @@ class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
|
||||
|
||||
def before_train(self):
|
||||
self.max_epoch = self.trainer.max_epoch
|
||||
if len(self.trainer.cfg.DATASETS.TESTS) == 1:
|
||||
self.metric_name = "metric"
|
||||
else:
|
||||
self.metric_name = self.trainer.cfg.DATASETS.TESTS[0] + "/metric"
|
||||
|
||||
def after_epoch(self):
|
||||
# No way to use **kwargs
|
||||
self.step(self.trainer.epoch)
|
||||
storage = get_event_storage()
|
||||
metric_dict = dict(
|
||||
metric=storage.latest()[self.metric_name][0] if self.metric_name in storage.latest() else -1
|
||||
)
|
||||
self.step(self.trainer.epoch, **metric_dict)
|
||||
|
||||
|
||||
class LRScheduler(HookBase):
|
||||
@ -473,7 +481,7 @@ class LayerFreeze(HookBase):
|
||||
p.requires_grad_(False)
|
||||
|
||||
self.is_frozen = True
|
||||
freeze_layers = ",".join(self.freeze_layers)
|
||||
freeze_layers = ", ".join(self.freeze_layers)
|
||||
self._logger.info(f'Freeze layer group "{freeze_layers}" training for {self.freeze_iters:d} iterations')
|
||||
|
||||
def open_all_layer(self):
|
||||
|
@ -106,6 +106,7 @@ class ReidEvaluator(DatasetEvaluator):
|
||||
self._results['Rank-{}'.format(r)] = cmc[r - 1]
|
||||
self._results['mAP'] = mAP
|
||||
self._results['mINP'] = mINP
|
||||
self._results["metric"] = (mAP + cmc[0]) / 2
|
||||
|
||||
if self.cfg.TEST.ROC_ENABLED:
|
||||
scores, labels = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
|
||||
|
Loading…
x
Reference in New Issue
Block a user