mirror of https://github.com/JDAI-CV/fast-reid.git
comment tensorboard output and output cls_accuracy in train loop
parent
ceb7618f68
commit
ae180917b4
|
@ -208,7 +208,7 @@ class DefaultTrainer(TrainerBase):
|
||||||
# ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
|
# ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
|
||||||
# for part of the parameters is not updated.
|
# for part of the parameters is not updated.
|
||||||
model = DistributedDataParallel(
|
model = DistributedDataParallel(
|
||||||
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True
|
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
|
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
|
||||||
|
@ -336,7 +336,7 @@ class DefaultTrainer(TrainerBase):
|
||||||
# It may not always print what you want to see, since it prints "common" metrics only.
|
# It may not always print what you want to see, since it prints "common" metrics only.
|
||||||
CommonMetricPrinter(self.max_iter),
|
CommonMetricPrinter(self.max_iter),
|
||||||
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
|
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
|
||||||
TensorboardXWriter(self.cfg.OUTPUT_DIR),
|
# TensorboardXWriter(self.cfg.OUTPUT_DIR),
|
||||||
]
|
]
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
|
|
@ -229,6 +229,11 @@ class CommonMetricPrinter(EventWriter):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
lr = "N/A"
|
lr = "N/A"
|
||||||
|
|
||||||
|
try:
|
||||||
|
cls_accuracy = "{:.4f}".format( storage.history("cls_accuracy").latest())
|
||||||
|
except KeyError:
|
||||||
|
accuracy = "N/A"
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
|
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
|
||||||
else:
|
else:
|
||||||
|
@ -236,7 +241,7 @@ class CommonMetricPrinter(EventWriter):
|
||||||
|
|
||||||
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
|
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
" {eta}epoch/iter: {epoch}/{iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
|
" {eta}epoch/iter: {epoch}/{iter} {losses} cls_accuracy: {cls_accuracy} {time}{data_time}lr: {lr} {memory}".format(
|
||||||
eta=f"eta: {eta_string} " if eta_string else "",
|
eta=f"eta: {eta_string} " if eta_string else "",
|
||||||
epoch=epoch,
|
epoch=epoch,
|
||||||
iter=iteration,
|
iter=iteration,
|
||||||
|
@ -247,6 +252,7 @@ class CommonMetricPrinter(EventWriter):
|
||||||
if "loss" in k
|
if "loss" in k
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
cls_accuracy = cls_accuracy,
|
||||||
time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
|
time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
|
||||||
data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
|
data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
|
||||||
lr=lr,
|
lr=lr,
|
||||||
|
|
Loading…
Reference in New Issue