PaddleClas/ppcls/engine/train/utils.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import datetime

import paddle

from ppcls.utils import logger
from ppcls.utils.misc import AverageMeter
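
# `AverageMeter` (from ppcls.utils.misc) keeps a running sum and count and
# exposes the mean via `.avg`; the helpers below create one meter per
# loss/metric key. Minimal usage sketch (the values are made up):
#
#   meter = AverageMeter("CELoss", '7.5f')
#   meter.update(0.75, n=32)   # batch-mean loss 0.75 over 32 samples
#   meter.update(0.65, n=32)
#   meter.avg                  # -> 0.70, the sample-weighted running mean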


def update_metric(trainer, out, batch, batch_size):
    """Accumulate the configured training metrics for the current batch."""
    if trainer.train_metric_func is not None:
        metric_dict = trainer.train_metric_func(out, batch[-1])
        for key in metric_dict:
            if key not in trainer.output_info:
                trainer.output_info[key] = AverageMeter(key, '7.5f')
            trainer.output_info[key].update(
                float(metric_dict[key]), batch_size)
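
# Usage sketch (hypothetical call site, assuming the engine's usual batch
# layout where `batch[0]` holds the images and `batch[-1]` the labels):
#
#   out = trainer.model(batch[0])
#   update_metric(trainer, out, batch, batch_size=batch[0].shape[0])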


def update_loss(trainer, loss_dict, batch_size):
    """Accumulate each entry of `loss_dict` into trainer.output_info."""
    for key in loss_dict:
        if key not in trainer.output_info:
            trainer.output_info[key] = AverageMeter(key, '7.5f')
        trainer.output_info[key].update(float(loss_dict[key]), batch_size)
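
# Usage sketch (hypothetical; `train_loss_func` is assumed to return a dict
# of named loss terms, e.g. {"CELoss": ..., "loss": ...}):
#
#   loss_dict = trainer.train_loss_func(out, batch[-1])
#   update_loss(trainer, loss_dict, batch_size=batch[0].shape[0])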


def log_info(trainer, batch_size, epoch_id, iter_id):
    """Log learning rates, running metrics, timing, throughput, ETA and,
    optionally, peak GPU memory for the current training iteration."""
    lr_msg = ", ".join([
        "lr({}): {:.8f}".format(type_name(lr), lr.get_lr())
        for lr in trainer.lr_sch
    ])
    metric_msg = ", ".join([
        "{}: {:.5f}".format(key, trainer.output_info[key].avg)
        for key in trainer.output_info
    ])
    time_msg = "s, ".join([
        "{}: {:.5f}".format(key, trainer.time_info[key].avg)
        for key in trainer.time_info
    ])

    ips_msg = "ips: {:.5f} samples/s".format(
        batch_size / trainer.time_info["batch_cost"].avg)

    # Remaining work = all iterations of the epochs still to run, minus the
    # iterations already finished in the current (1-based) epoch.
    global_epochs = trainer.config["Global"]["epochs"]
    eta_sec = ((global_epochs - epoch_id + 1) * trainer.iter_per_epoch -
               iter_id) * trainer.time_info["batch_cost"].avg
    eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))

    max_mem_msg = ""
    if trainer.config["Global"].get("print_mem_info", False):
        if paddle.device.is_compiled_with_cuda():
            max_mem_reserved = paddle.device.cuda.max_memory_reserved() / (
                1024**2)
            max_mem_allocated = paddle.device.cuda.max_memory_allocated() / (
                1024**2)
            max_mem_msg = (f", max_mem_reserved: {max_mem_reserved:.2f} MB, "
                           f"max_mem_allocated: {max_mem_allocated:.2f} MB")

    logger.info(
        f"[Train][Epoch {epoch_id}/{global_epochs}][Iter: {iter_id}/"
        f"{trainer.iter_per_epoch}]{lr_msg}, {metric_msg}, {time_msg}, "
        f"{ips_msg}, {eta_msg}{max_mem_msg}")

    # Timing meters restart every logging interval; loss/metric meters do not.
    for key in trainer.time_info:
        trainer.time_info[key].reset()

    # Mirror the same scalars to the VisualDL writer, if one is attached.
    for lr in trainer.lr_sch:
        logger.scaler(
            name="lr({})".format(type_name(lr)),
            value=lr.get_lr(),
            step=trainer.global_step,
            writer=trainer.vdl_writer)
    for key in trainer.output_info:
        logger.scaler(
            name="train_{}".format(key),
            value=trainer.output_info[key].avg,
            step=trainer.global_step,
            writer=trainer.vdl_writer)
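
# Sketch of the assumed call site: the training loop logs every
# `Global.print_batch_step` iterations (this guard is illustrative, not
# copied from the engine):
#
#   if iter_id % trainer.config["Global"]["print_batch_step"] == 0:
#       log_info(trainer, batch_size, epoch_id, iter_id)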


def type_name(obj: object) -> str:
    """Get the class name of an object (used above to label LR schedulers)."""
    return obj.__class__.__name__
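
# For example, with a Paddle LR scheduler (illustrative):
#
#   from paddle.optimizer.lr import CosineAnnealingDecay
#   sched = CosineAnnealingDecay(learning_rate=0.1, T_max=100)
#   type_name(sched)   # -> "CosineAnnealingDecay"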