comment tensorboard output and output cls_accuracy in train loop

pull/608/head
zuchen.wang 2021-11-21 15:20:07 +08:00
parent ceb7618f68
commit ae180917b4
2 changed files with 10 additions and 4 deletions
fastreid

View File

@ -208,7 +208,7 @@ class DefaultTrainer(TrainerBase):
# ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
# for part of the parameters is not updated.
model = DistributedDataParallel(
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=False
)
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
@ -336,7 +336,7 @@ class DefaultTrainer(TrainerBase):
# It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(self.max_iter),
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
TensorboardXWriter(self.cfg.OUTPUT_DIR),
# TensorboardXWriter(self.cfg.OUTPUT_DIR),
]
def train(self):

View File

@ -229,6 +229,11 @@ class CommonMetricPrinter(EventWriter):
except KeyError:
lr = "N/A"
try:
cls_accuracy = "{:.4f}".format( storage.history("cls_accuracy").latest())
except KeyError:
accuracy = "N/A"
if torch.cuda.is_available():
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
else:
@ -236,7 +241,7 @@ class CommonMetricPrinter(EventWriter):
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
" {eta}epoch/iter: {epoch}/{iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
" {eta}epoch/iter: {epoch}/{iter} {losses} cls_accuracy: {cls_accuracy} {time}{data_time}lr: {lr} {memory}".format(
eta=f"eta: {eta_string} " if eta_string else "",
epoch=epoch,
iter=iteration,
@ -247,6 +252,7 @@ class CommonMetricPrinter(EventWriter):
if "loss" in k
]
),
cls_accuracy = cls_accuracy,
time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
lr=lr,
@ -458,4 +464,4 @@ class EventStorage:
Delete all the stored histograms for visualization.
This should be called after histograms are written to tensorboard.
"""
self._histograms = []
self._histograms = []