Comment out TensorBoard output and log cls_accuracy in the train loop

pull/608/head
zuchen.wang 2021-11-21 15:20:07 +08:00
parent ceb7618f68
commit ae180917b4
2 changed files with 10 additions and 4 deletions

View File

@ -208,7 +208,7 @@ class DefaultTrainer(TrainerBase):
# ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True` # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
# for part of the parameters is not updated. # for part of the parameters is not updated.
model = DistributedDataParallel( model = DistributedDataParallel(
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=False
) )
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)( self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
@ -336,7 +336,7 @@ class DefaultTrainer(TrainerBase):
# It may not always print what you want to see, since it prints "common" metrics only. # It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(self.max_iter), CommonMetricPrinter(self.max_iter),
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
TensorboardXWriter(self.cfg.OUTPUT_DIR), # TensorboardXWriter(self.cfg.OUTPUT_DIR),
] ]
def train(self): def train(self):

View File

@ -229,6 +229,11 @@ class CommonMetricPrinter(EventWriter):
except KeyError: except KeyError:
lr = "N/A" lr = "N/A"
try:
    # Format the most recent "cls_accuracy" value from the event storage
    # with four decimals for the log line.
    cls_accuracy = "{:.4f}".format(storage.history("cls_accuracy").latest())
except KeyError:
    # The history lookup failed; fall back to a placeholder. The fallback
    # must bind the same name the log format consumes (`cls_accuracy`),
    # otherwise the later `.format(cls_accuracy=cls_accuracy)` call raises
    # NameError whenever the metric is missing.
    cls_accuracy = "N/A"
if torch.cuda.is_available(): if torch.cuda.is_available():
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
else: else:
@ -236,7 +241,7 @@ class CommonMetricPrinter(EventWriter):
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh" # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info( self.logger.info(
" {eta}epoch/iter: {epoch}/{iter} {losses} {time}{data_time}lr: {lr} {memory}".format( " {eta}epoch/iter: {epoch}/{iter} {losses} cls_accuracy: {cls_accuracy} {time}{data_time}lr: {lr} {memory}".format(
eta=f"eta: {eta_string} " if eta_string else "", eta=f"eta: {eta_string} " if eta_string else "",
epoch=epoch, epoch=epoch,
iter=iteration, iter=iteration,
@ -247,6 +252,7 @@ class CommonMetricPrinter(EventWriter):
if "loss" in k if "loss" in k
] ]
), ),
cls_accuracy = cls_accuracy,
time="time: {:.4f} ".format(iter_time) if iter_time is not None else "", time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "", data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
lr=lr, lr=lr,