refactor: simplify engine

pull/2698/head
gaotingquan 2023-02-21 09:26:57 +00:00 committed by Wei Shengyu
parent 4cd9dc0e05
commit 376d83d46e
6 changed files with 339 additions and 288 deletions

View File

@ -15,6 +15,8 @@
import inspect
import copy
import random
import platform
import paddle
import numpy as np
import paddle.distributed as dist
@ -86,7 +88,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed)
def build_dataloader(config, mode, device, use_dali=False, seed=None):
def build(config, mode, device, use_dali=False, seed=None):
assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
@ -187,3 +189,79 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
logger.debug("build data_loader({}) success...".format(data_loader))
return data_loader
def build_dataloader(engine):
if "class_num" in engine.config["Global"]:
global_class_num = engine.config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
engine.config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
class_num = engine.config["Arch"].get("class_num", None)
engine.config["DataLoader"].update({"class_num": class_num})
engine.config["DataLoader"].update({
"epochs": engine.config["Global"]["epochs"]
})
use_dali = engine.config['Global'].get("use_dali", False)
dataloader_dict = {
"Train": None,
"UnLabelTrain": None,
"Eval": None,
"Query": None,
"Gallery": None,
"GalleryQuery": None
}
if engine.mode == 'train':
train_dataloader = build(
engine.config["DataLoader"],
"Train",
engine.device,
use_dali,
seed=None)
iter_per_epoch = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
if engine.config["Global"].get("iter_per_epoch", None):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
engine.iter_per_epoch = iter_per_epoch
train_dataloader.iter_per_epoch = iter_per_epoch
dataloader_dict["Train"] = train_dataloader
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build(
engine.config["DataLoader"],
"UnLabelTrain",
engine.device,
use_dali,
seed=None)
if engine.mode == "eval" or (engine.mode == "train" and
engine.config["Global"]["eval_during_train"]):
if engine.eval_mode in ["classification", "adaface"]:
dataloader_dict["Eval"] = build(
engine.config["DataLoader"],
"Eval",
engine.device,
use_dali,
seed=None)
elif engine.eval_mode == "retrieval":
if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
key = list(engine.config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], key, engine.device,
use_dali)
else:
dataloader_dict["Gallery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Gallery",
engine.device, use_dali)
dataloader_dict["Query"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Query",
engine.device, use_dali)
return dataloader_dict

View File

@ -15,7 +15,6 @@ from __future__ import division
from __future__ import print_function
import os
import platform
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
@ -52,168 +51,60 @@ class Engine(object):
assert mode in ["train", "eval", "infer", "export"]
self.mode = mode
self.config = config
self.eval_mode = self.config["Global"].get("eval_mode",
"classification")
self.train_mode = self.config["Global"].get("train_mode", None)
# set seed
self._init_seed()
# init logger
init_logger(self.config, mode=mode)
print_config(config)
# for visualdl
self.vdl_writer = self._init_vdl()
# is_rec
if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
False):
self.is_rec = True
else:
self.is_rec = False
# set seed
seed = self.config["Global"].get("seed", False)
if seed or seed == 0:
assert isinstance(seed, int), "The 'seed' must be a integer!"
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
# init logger
self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
f"{mode}.log")
init_logger(log_file=log_file)
print_config(config)
# init train_func and eval_func
assert self.eval_mode in [
"classification", "retrieval", "adaface"
], logger.error("Invalid eval mode: {}".format(self.eval_mode))
self.train_mode = self.config["Global"].get("train_mode", None)
if self.train_mode is None:
self.train_epoch_func = train_method.train_epoch
else:
self.train_epoch_func = getattr(train_method,
"train_epoch_" + self.train_mode)
self.eval_mode = self.config["Global"].get("eval_mode",
"classification")
assert self.eval_mode in [
"classification", "retrieval", "adaface"
], logger.error("Invalid eval mode: {}".format(self.eval_mode))
self.eval_func = getattr(evaluation, self.eval_mode + "_eval")
self.use_dali = self.config['Global'].get("use_dali", False)
# for visualdl
self.vdl_writer = None
if self.config['Global'][
'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
vdl_writer_path = os.path.join(self.output_dir, "vdl")
if not os.path.exists(vdl_writer_path):
os.makedirs(vdl_writer_path)
self.vdl_writer = LogWriter(logdir=vdl_writer_path)
# set device
assert self.config["Global"][
"device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
self.device = paddle.set_device(self.config["Global"]["device"])
logger.info('train with paddle {} and device {}'.format(
paddle.__version__, self.device))
self.device = self._init_device()
# gradient accumulation
self.update_freq = self.config["Global"].get("update_freq", 1)
if "class_num" in config["Global"]:
global_class_num = config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
#TODO(gaotingquan): support rec
class_num = config["Arch"].get("class_num", None)
self.config["DataLoader"].update({"class_num": class_num})
self.config["DataLoader"].update({
"epochs": self.config["Global"]["epochs"]
})
# build dataloader
if self.mode == 'train':
self.train_dataloader = build_dataloader(
self.config["DataLoader"], "Train", self.device, self.use_dali)
if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
self.unlabel_train_dataloader = build_dataloader(
self.config["DataLoader"], "UnLabelTrain", self.device,
self.use_dali)
else:
self.unlabel_train_dataloader = None
self.iter_per_epoch = len(
self.train_dataloader) - 1 if platform.system(
) == "Windows" else len(self.train_dataloader)
if self.config["Global"].get("iter_per_epoch", None):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
self.iter_per_epoch = self.config["Global"].get(
"iter_per_epoch")
self.iter_per_epoch = self.iter_per_epoch // self.update_freq * self.update_freq
if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
if self.eval_mode in ["classification", "adaface"]:
self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device,
self.use_dali)
elif self.eval_mode == "retrieval":
self.gallery_query_dataloader = None
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
key = list(self.config["DataLoader"]["Eval"].keys())[0]
self.gallery_query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], key, self.device,
self.use_dali)
else:
self.gallery_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Gallery",
self.device, self.use_dali)
self.query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Query",
self.device, self.use_dali)
self.dataloader_dict = build_dataloader(self)
self.train_dataloader, self.unlabel_train_dataloader, self.eval_dataloader = self.dataloader_dict[
"Train"], self.dataloader_dict[
"UnLabelTrain"], self.dataloader_dict["Eval"]
self.gallery_query_dataloader, self.gallery_dataloader, self.query_dataloader = self.dataloader_dict[
"GalleryQuery"], self.dataloader_dict[
"Gallery"], self.dataloader_dict["Query"]
# build loss
if self.mode == "train":
label_loss_info = self.config["Loss"]["Train"]
self.train_loss_func = build_loss(label_loss_info)
unlabel_loss_info = self.config.get("UnLabelLoss", {}).get("Train",
None)
self.unlabel_train_loss_func = build_loss(unlabel_loss_info)
if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
loss_config = self.config.get("Loss", None)
if loss_config is not None:
loss_config = loss_config.get("Eval")
if loss_config is not None:
self.eval_loss_func = build_loss(loss_config)
else:
self.eval_loss_func = None
else:
self.eval_loss_func = None
self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
self.config, self.mode)
# build metric
if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
"Metric"] and self.config["Metric"]["Train"]:
metric_config = self.config["Metric"]["Train"]
if hasattr(self.train_dataloader, "collate_fn"
) and self.train_dataloader.collate_fn is not None:
for m_idx, m in enumerate(metric_config):
if "TopkAcc" in m:
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
logger.warning(msg)
metric_config.pop(m_idx)
self.train_metric_func = build_metrics(metric_config)
else:
self.train_metric_func = None
if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
if self.eval_mode == "classification":
if "Metric" in self.config and "Eval" in self.config["Metric"]:
self.eval_metric_func = build_metrics(self.config["Metric"]
["Eval"])
else:
self.eval_metric_func = None
elif self.eval_mode == "retrieval":
if "Metric" in self.config and "Eval" in self.config["Metric"]:
metric_config = self.config["Metric"]["Eval"]
else:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
self.eval_metric_func = build_metrics(metric_config)
else:
self.eval_metric_func = None
self.train_metric_func, self.eval_metric_func = build_metrics(self)
# build model
self.model = build_model(self.config, self.mode)
@ -221,139 +112,18 @@ class Engine(object):
apply_to_static(self.config, self.model)
# load_pretrain
if self.config["Global"]["pretrained_model"] is not None:
if self.config["Global"]["pretrained_model"].startswith("http"):
load_dygraph_pretrain_from_url(
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
else:
load_dygraph_pretrain(
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
self._init_pretrained()
# build optimizer
if self.mode == 'train':
self.optimizer, self.lr_sch = build_optimizer(
self.config["Optimizer"], self.config["Global"]["epochs"],
self.iter_per_epoch // self.update_freq,
[self.model, self.train_loss_func])
self.optimizer, self.lr_sch = build_optimizer(
self.config, self.train_dataloader,
[self.model, self.train_loss_func])
# AMP training and evaluating
self.amp = "AMP" in self.config and self.config["AMP"] is not None
self.amp_eval = False
# for amp
if self.amp:
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update({
'FLAGS_cudnn_batchnorm_spatial_persistent': 1
})
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get(
"use_dynamic_loss_scaling", False)
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=self.scale_loss,
use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
self.amp_level = self.config['AMP'].get("level", "O1")
if self.amp_level not in ["O1", "O2"]:
msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
logger.warning(msg)
self.config['AMP']["level"] = "O1"
self.amp_level = "O1"
self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
# TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
if self.mode == "train" and self.config["Global"].get(
"eval_during_train",
True) and self.amp_level == "O2" and self.amp_eval == False:
msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
logger.warning(msg)
self.config["AMP"]["use_fp16_test"] = True
self.amp_eval = True
# TODO(gaotingquan): to compatible with different versions of Paddle
paddle_version = paddle.__version__[:3]
# paddle version < 2.3.0 and not develop
if paddle_version not in ["2.3", "0.0"]:
if self.mode == "train":
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
level=self.amp_level,
save_dtype='float32')
elif self.amp_eval:
if self.amp_level == "O2":
msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
logger.warning(msg)
self.amp_eval = False
else:
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
# paddle version >= 2.3.0 or develop
else:
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.amp.decorate(
models=self.train_loss_func,
level=self.amp_level,
save_dtype='float32')
# build EMA model
self.ema = "EMA" in self.config and self.mode == "train"
if self.ema:
self.model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
# check the gpu num
world_size = dist.get_world_size()
self.config["Global"]["distributed"] = world_size != 1
if self.mode == "train":
std_gpu_num = 8 if isinstance(
self.config["Optimizer"],
dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
if world_size != std_gpu_num:
msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use this config to train."
logger.warning(msg)
self._init_amp()
# for distributed
if self.config["Global"]["distributed"]:
dist.init_parallel_env()
self.model = paddle.DataParallel(self.model)
if self.mode == 'train' and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.DataParallel(
self.train_loss_func)
# set different seed in different GPU manually in distributed environment
if seed is None:
logger.warning(
"The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
)
self.config["Global"]["seed"] = seed = 42
logger.info(
f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
)
paddle.seed(int(seed) + dist.get_rank())
np.random.seed(int(seed) + dist.get_rank())
random.seed(int(seed) + dist.get_rank())
# build postprocess for infer
if self.mode == 'infer':
self.preprocess_func = create_operators(self.config["Infer"][
"transforms"])
self.postprocess_func = build_postprocess(self.config["Infer"][
"PostProcess"])
self._init_dist()
def train(self):
assert self.mode == "train"
@ -363,10 +133,17 @@ class Engine(object):
"metric": -1.0,
"epoch": 0,
}
ema_module = None
# build EMA model
self.ema = "EMA" in self.config and self.mode == "train"
if self.ema:
self.model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
best_metric_ema = 0.0
ema_module = self.model_ema.module
else:
ema_module = None
# key:
# val: metrics list word
self.output_info = dict()
@ -392,8 +169,6 @@ class Engine(object):
# for one epoch train
self.train_epoch_func(self, epoch_id, print_batch_step)
if self.use_dali:
self.train_dataloader.reset()
metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
@ -499,6 +274,12 @@ class Engine(object):
@paddle.no_grad()
def infer(self):
assert self.mode == "infer" and self.eval_mode == "classification"
self.preprocess_func = create_operators(self.config["Infer"][
"transforms"])
self.postprocess_func = build_postprocess(self.config["Infer"][
"PostProcess"])
total_trainer = dist.get_world_size()
local_rank = dist.get_rank()
image_list = get_image_list(self.config["Infer"]["infer_imgs"])
@ -586,6 +367,148 @@ class Engine(object):
f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
)
def _init_vdl(self):
if self.config['Global'][
'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
vdl_writer_path = os.path.join(self.output_dir, "vdl")
if not os.path.exists(vdl_writer_path):
os.makedirs(vdl_writer_path)
return LogWriter(logdir=vdl_writer_path)
return None
def _init_seed(self):
seed = self.config["Global"].get("seed", False)
if dist.get_world_size() != 1:
# if self.config["Global"]["distributed"]:
# set different seed in different GPU manually in distributed environment
if not seed:
logger.warning(
"The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
)
self.config["Global"]["seed"] = seed = 42
logger.info(
f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
)
dist_seed = int(seed) + dist.get_rank()
paddle.seed(dist_seed)
np.random.seed(dist_seed)
random.seed(dist_seed)
elif seed or seed == 0:
assert isinstance(seed, int), "The 'seed' must be a integer!"
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
def _init_device(self):
device = self.config["Global"]["device"]
assert device in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
logger.info('train with paddle {} and device {}'.format(
paddle.__version__, device))
return paddle.set_device(device)
def _init_pretrained(self):
if self.config["Global"]["pretrained_model"] is not None:
if self.config["Global"]["pretrained_model"].startswith("http"):
load_dygraph_pretrain_from_url(
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
else:
load_dygraph_pretrain(
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
def _init_amp(self):
self.amp = "AMP" in self.config and self.config["AMP"] is not None
self.amp_eval = False
# for amp
if self.amp:
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update({
'FLAGS_cudnn_batchnorm_spatial_persistent': 1
})
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get(
"use_dynamic_loss_scaling", False)
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=self.scale_loss,
use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
self.amp_level = self.config['AMP'].get("level", "O1")
if self.amp_level not in ["O1", "O2"]:
msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
logger.warning(msg)
self.config['AMP']["level"] = "O1"
self.amp_level = "O1"
self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
# TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
if self.mode == "train" and self.config["Global"].get(
"eval_during_train",
True) and self.amp_level == "O2" and self.amp_eval == False:
msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
logger.warning(msg)
self.config["AMP"]["use_fp16_test"] = True
self.amp_eval = True
# TODO(gaotingquan): to compatible with different versions of Paddle
paddle_version = paddle.__version__[:3]
# paddle version < 2.3.0 and not develop
if paddle_version not in ["2.3", "0.0"]:
if self.mode == "train":
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
level=self.amp_level,
save_dtype='float32')
elif self.amp_eval:
if self.amp_level == "O2":
msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
logger.warning(msg)
self.amp_eval = False
else:
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
# paddle version >= 2.3.0 or develop
else:
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.amp.decorate(
models=self.train_loss_func,
level=self.amp_level,
save_dtype='float32')
def _init_dist(self):
# check the gpu num
world_size = dist.get_world_size()
self.config["Global"]["distributed"] = world_size != 1
# TODO(gaotingquan):
if self.mode == "train":
std_gpu_num = 8 if isinstance(
self.config["Optimizer"],
dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
if world_size != std_gpu_num:
msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use this config to train."
logger.warning(msg)
if self.config["Global"]["distributed"]:
dist.init_parallel_env()
self.model = paddle.DataParallel(self.model)
if self.mode == 'train' and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.DataParallel(
self.train_loss_func)
class ExportModel(TheseusLayer):
"""

View File

@ -51,7 +51,7 @@ from .metabinloss import IntraDomainScatterLoss
class CombinedLoss(nn.Layer):
def __init__(self, config_list):
super().__init__()
self.loss_func = []
loss_func = []
self.loss_weight = []
assert isinstance(config_list, list), (
'operator config should be a list')
@ -63,8 +63,9 @@ class CombinedLoss(nn.Layer):
assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys())
self.loss_weight.append(param.pop("weight"))
self.loss_func.append(eval(name)(**param))
self.loss_func = nn.LayerList(self.loss_func)
loss_func.append(eval(name)(**param))
self.loss_func = nn.LayerList(loss_func)
logger.debug("build loss {} success.".format(loss_func))
def __call__(self, input, batch):
loss_dict = {}
@ -83,9 +84,22 @@ class CombinedLoss(nn.Layer):
return loss_dict
def build_loss(config):
if config is None:
return None
module_class = CombinedLoss(copy.deepcopy(config))
logger.debug("build loss {} success.".format(module_class))
return module_class
def build_loss(config, mode="train"):
train_loss_func, unlabel_train_loss_func, eval_loss_func = None, None, None
if mode == "train":
label_loss_info = config["Loss"]["Train"]
if label_loss_info:
train_loss_func = CombinedLoss(copy.deepcopy(label_loss_info))
unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
if unlabel_loss_info:
unlabel_train_loss_func = CombinedLoss(
copy.deepcopy(unlabel_loss_info))
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
loss_config = config.get("Loss", None)
if loss_config is not None:
loss_config = loss_config.get("Eval")
if loss_config is not None:
eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
return train_loss_func, unlabel_train_loss_func, eval_loss_func

View File

@ -65,6 +65,38 @@ class CombinedMetrics(AvgMetrics):
metric.reset()
def build_metrics(config):
metrics_list = CombinedMetrics(copy.deepcopy(config))
return metrics_list
def build_metrics(engine):
config, mode = engine.config, engine.mode
if mode == 'train' and "Metric" in config and "Train" in config[
"Metric"] and config["Metric"]["Train"]:
metric_config = config["Metric"]["Train"]
if hasattr(engine.train_dataloader, "collate_fn"
) and engine.train_dataloader.collate_fn is not None:
for m_idx, m in enumerate(metric_config):
if "TopkAcc" in m:
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
logger.warning(msg)
metric_config.pop(m_idx)
train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
train_metric_func = None
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
eval_mode = config["Global"].get("eval_mode", "classification")
if eval_mode == "classification":
if "Metric" in config and "Eval" in config["Metric"]:
eval_metric_func = CombinedMetrics(
copy.deepcopy(config["Metric"]["Eval"]))
else:
eval_metric_func = None
elif eval_mode == "retrieval":
if "Metric" in config and "Eval" in config["Metric"]:
metric_config = config["Metric"]["Eval"]
else:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
else:
eval_metric_func = None
return train_metric_func, eval_metric_func

View File

@ -45,8 +45,11 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
optim_config = copy.deepcopy(config)
def build_optimizer(config, dataloader, model_list=None):
optim_config = copy.deepcopy(config["Optimizer"])
epochs = config["Global"]["epochs"]
update_freq = config["Global"].get("update_freq", 1)
step_each_epoch = dataloader.iter_per_epoch // update_freq
if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
optim_name = optim_config.pop("name")

View File

@ -22,16 +22,15 @@ import paddle.distributed as dist
_logger = None
def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
added.
Args:
config(dict): Training config.
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
@ -63,6 +62,8 @@ def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
if init_flag:
_logger.addHandler(stream_handler)
log_file = os.path.join(config['Global']['output_dir'],
config["Arch"]["name"], f"{mode}.log")
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)