Comment out the TensorBoard writer and log cls_accuracy in the train loop

pull/608/head
zuchen.wang 2021-11-21 15:20:07 +08:00
parent ceb7618f68
commit ae180917b4
2 changed files with 10 additions and 4 deletions

View File

@ -208,7 +208,7 @@ class DefaultTrainer(TrainerBase):
# ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True` # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
# for part of the parameters is not updated. # for part of the parameters is not updated.
model = DistributedDataParallel( model = DistributedDataParallel(
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=False
) )
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)( self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
@ -336,7 +336,7 @@ class DefaultTrainer(TrainerBase):
# It may not always print what you want to see, since it prints "common" metrics only. # It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(self.max_iter), CommonMetricPrinter(self.max_iter),
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
TensorboardXWriter(self.cfg.OUTPUT_DIR), # TensorboardXWriter(self.cfg.OUTPUT_DIR),
] ]
def train(self): def train(self):

View File

@ -229,6 +229,11 @@ class CommonMetricPrinter(EventWriter):
except KeyError: except KeyError:
lr = "N/A" lr = "N/A"
try:
cls_accuracy = "{:.4f}".format( storage.history("cls_accuracy").latest())
except KeyError:
accuracy = "N/A"
if torch.cuda.is_available(): if torch.cuda.is_available():
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
else: else:
@ -236,7 +241,7 @@ class CommonMetricPrinter(EventWriter):
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh" # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info( self.logger.info(
" {eta}epoch/iter: {epoch}/{iter} {losses} {time}{data_time}lr: {lr} {memory}".format( " {eta}epoch/iter: {epoch}/{iter} {losses} cls_accuracy: {cls_accuracy} {time}{data_time}lr: {lr} {memory}".format(
eta=f"eta: {eta_string} " if eta_string else "", eta=f"eta: {eta_string} " if eta_string else "",
epoch=epoch, epoch=epoch,
iter=iteration, iter=iteration,
@ -247,6 +252,7 @@ class CommonMetricPrinter(EventWriter):
if "loss" in k if "loss" in k
] ]
), ),
cls_accuracy = cls_accuracy,
time="time: {:.4f} ".format(iter_time) if iter_time is not None else "", time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "", data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
lr=lr, lr=lr,
@ -458,4 +464,4 @@ class EventStorage:
Delete all the stored histograms for visualization. Delete all the stored histograms for visualization.
This should be called after histograms are written to tensorboard. This should be called after histograms are written to tensorboard.
""" """
self._histograms = [] self._histograms = []