mirror of https://github.com/JDAI-CV/fast-reid.git
comment tensorboard output and output cls_accuracy in train loop
parent
ceb7618f68
commit
ae180917b4
fastreid
engine
utils
|
@ -208,7 +208,7 @@ class DefaultTrainer(TrainerBase):
|
|||
# ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
|
||||
# for part of the parameters is not updated.
|
||||
model = DistributedDataParallel(
|
||||
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True
|
||||
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=False
|
||||
)
|
||||
|
||||
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
|
||||
|
@ -336,7 +336,7 @@ class DefaultTrainer(TrainerBase):
|
|||
# It may not always print what you want to see, since it prints "common" metrics only.
|
||||
CommonMetricPrinter(self.max_iter),
|
||||
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
|
||||
TensorboardXWriter(self.cfg.OUTPUT_DIR),
|
||||
# TensorboardXWriter(self.cfg.OUTPUT_DIR),
|
||||
]
|
||||
|
||||
def train(self):
|
||||
|
|
|
@ -229,6 +229,11 @@ class CommonMetricPrinter(EventWriter):
|
|||
except KeyError:
|
||||
lr = "N/A"
|
||||
|
||||
try:
|
||||
cls_accuracy = "{:.4f}".format( storage.history("cls_accuracy").latest())
|
||||
except KeyError:
|
||||
accuracy = "N/A"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
|
||||
else:
|
||||
|
@ -236,7 +241,7 @@ class CommonMetricPrinter(EventWriter):
|
|||
|
||||
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
|
||||
self.logger.info(
|
||||
" {eta}epoch/iter: {epoch}/{iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
|
||||
" {eta}epoch/iter: {epoch}/{iter} {losses} cls_accuracy: {cls_accuracy} {time}{data_time}lr: {lr} {memory}".format(
|
||||
eta=f"eta: {eta_string} " if eta_string else "",
|
||||
epoch=epoch,
|
||||
iter=iteration,
|
||||
|
@ -247,6 +252,7 @@ class CommonMetricPrinter(EventWriter):
|
|||
if "loss" in k
|
||||
]
|
||||
),
|
||||
cls_accuracy = cls_accuracy,
|
||||
time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
|
||||
data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
|
||||
lr=lr,
|
||||
|
@ -458,4 +464,4 @@ class EventStorage:
|
|||
Delete all the stored histograms for visualization.
|
||||
This should be called after histograms are written to tensorboard.
|
||||
"""
|
||||
self._histograms = []
|
||||
self._histograms = []
|
||||
|
|
Loading…
Reference in New Issue