# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

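"""Training-loop helpers for the ppcls trainer: running-average bookkeeping
for losses and metrics, plus per-iteration console and VisualDL logging."""
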
from __future__ import absolute_import, division, print_function

import datetime

from ppcls.utils import logger
from ppcls.utils.misc import AverageMeter


def update_metric(trainer, out, batch, batch_size):
    # Compute the configured training metrics (if any) and fold them into
    # the running averages kept in trainer.output_info.
    if trainer.train_metric_func is not None:
        metric_dict = trainer.train_metric_func(out, batch[-1])
        for key in metric_dict:
            if key not in trainer.output_info:
                trainer.output_info[key] = AverageMeter(key, '7.5f')
            trainer.output_info[key].update(
                float(metric_dict[key]), batch_size)


def update_loss(trainer, loss_dict, batch_size):
    # Fold each loss term into the running averages in trainer.output_info.
    for key in loss_dict:
        if key not in trainer.output_info:
            trainer.output_info[key] = AverageMeter(key, '7.5f')
        trainer.output_info[key].update(float(loss_dict[key]), batch_size)


def log_info(trainer, batch_size, epoch_id, iter_id):
    lr_msg = ", ".join([
        "lr({}): {:.8f}".format(type_name(lr), lr.get_lr())
        for lr in trainer.lr_sch
    ])
    metric_msg = ", ".join([
        "{}: {:.5f}".format(key, trainer.output_info[key].avg)
        for key in trainer.output_info
    ])
    time_msg = "s, ".join([
        "{}: {:.5f}".format(key, trainer.time_info[key].avg)
        for key in trainer.time_info
    ])

    ips_msg = "ips: {:.5f} samples/s".format(
        batch_size / trainer.time_info["batch_cost"].avg)

    # Remaining time = remaining iterations * average cost per batch
    # (epoch_id is 1-based, so the current epoch still contributes
    # iter_per_epoch - iter_id iterations).
    eta_sec = (
        (trainer.config["Global"]["epochs"] - epoch_id + 1) *
        trainer.iter_per_epoch - iter_id) * trainer.time_info["batch_cost"].avg
    eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
    logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
        epoch_id, trainer.config["Global"]["epochs"], iter_id,
        trainer.iter_per_epoch, lr_msg, metric_msg, time_msg, ips_msg,
        eta_msg))
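
    # Illustrative example of the emitted line (all values made up):
    # [Train][Epoch 1/120][Iter: 0/312]lr(PiecewiseDecay): 0.10000000,
    # CELoss: 6.91000, loss: 6.91000, batch_cost: 0.50000s, reader_cost:
    # 0.10000, ips: 64.00000 samples/s, eta: 5:12:00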

    # Mirror the same scalars to the VisualDL writer.
    for lr in trainer.lr_sch:
        logger.scaler(
            name="lr({})".format(type_name(lr)),
            value=lr.get_lr(),
            step=trainer.global_step,
            writer=trainer.vdl_writer)
    for key in trainer.output_info:
        logger.scaler(
            name="train_{}".format(key),
            value=trainer.output_info[key].avg,
            step=trainer.global_step,
            writer=trainer.vdl_writer)


def type_name(obj: object) -> str:
    """Return the class name of an object."""
    return obj.__class__.__name__
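

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the trainer):
    # drive update_loss with a bare-bones stand-in for the trainer object.
    # The SimpleNamespace stub and the loss values below are made up.
    from types import SimpleNamespace

    trainer = SimpleNamespace(output_info={}, train_metric_func=None)
    update_loss(trainer, {"CELoss": 0.731}, batch_size=32)
    update_loss(trainer, {"CELoss": 0.702}, batch_size=32)
    for key, meter in trainer.output_info.items():
        # AverageMeter keeps a batch-size weighted running mean.
        print("{}: avg={:.5f}".format(key, meter.avg))  # CELoss: avg=0.71650
    print(type_name(trainer))  # SimpleNamespace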