refactor: simplify engine
parent 4cd9dc0e05
commit 376d83d46e
@@ -15,6 +15,8 @@
import inspect
import copy
import random
import platform

import paddle
import numpy as np
import paddle.distributed as dist
@@ -86,7 +88,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    random.seed(worker_seed)


def build_dataloader(config, mode, device, use_dali=False, seed=None):
def build(config, mode, device, use_dali=False, seed=None):
    assert mode in [
        'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
    ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
@@ -187,3 +189,79 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
    logger.debug("build data_loader({}) success...".format(data_loader))
    return data_loader


def build_dataloader(engine):
    if "class_num" in engine.config["Global"]:
        global_class_num = engine.config["Global"]["class_num"]
        if "class_num" not in engine.config["Arch"]:
            engine.config["Arch"]["class_num"] = global_class_num
            msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
        else:
            msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
        logger.warning(msg)

    class_num = engine.config["Arch"].get("class_num", None)
    engine.config["DataLoader"].update({"class_num": class_num})
    engine.config["DataLoader"].update({
        "epochs": engine.config["Global"]["epochs"]
    })

    use_dali = engine.config['Global'].get("use_dali", False)
    dataloader_dict = {
        "Train": None,
        "UnLabelTrain": None,
        "Eval": None,
        "Query": None,
        "Gallery": None,
        "GalleryQuery": None
    }
    if engine.mode == 'train':
        train_dataloader = build(
            engine.config["DataLoader"],
            "Train",
            engine.device,
            use_dali,
            seed=None)
        iter_per_epoch = len(train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(train_dataloader)
        if engine.config["Global"].get("iter_per_epoch", None):
            # set max iteration per epoch manually, when training by iteration(s), such as XBM, FixMatch.
            iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
        iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
        engine.iter_per_epoch = iter_per_epoch
        train_dataloader.iter_per_epoch = iter_per_epoch
        dataloader_dict["Train"] = train_dataloader

    if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
        dataloader_dict["UnLabelTrain"] = build(
            engine.config["DataLoader"],
            "UnLabelTrain",
            engine.device,
            use_dali,
            seed=None)

    if engine.mode == "eval" or (engine.mode == "train" and
                                 engine.config["Global"]["eval_during_train"]):
        if engine.eval_mode in ["classification", "adaface"]:
            dataloader_dict["Eval"] = build(
                engine.config["DataLoader"],
                "Eval",
                engine.device,
                use_dali,
                seed=None)
        elif engine.eval_mode == "retrieval":
            if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
                key = list(engine.config["DataLoader"]["Eval"].keys())[0]
                dataloader_dict["GalleryQuery"] = build(
                    engine.config["DataLoader"]["Eval"], key, engine.device,
                    use_dali)
            else:
                dataloader_dict["Gallery"] = build(
                    engine.config["DataLoader"]["Eval"], "Gallery",
                    engine.device, use_dali)
                dataloader_dict["Query"] = build(
                    engine.config["DataLoader"]["Eval"], "Query",
                    engine.device, use_dali)

    return dataloader_dict
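For orientation, here is a hypothetical sketch (not part of this commit) of the engine-side attributes that the new build_dataloader(engine) reads. The field names follow the code above; the concrete values and the SimpleNamespace stand-in are invented for illustration only.

from types import SimpleNamespace

# Hypothetical stand-in for the Engine instance; the real Engine.__init__ sets
# these fields before calling build_dataloader(engine).
engine = SimpleNamespace(
    mode="train",                    # "train", "eval", ...
    device="gpu",                    # result of paddle.set_device(...)
    eval_mode="classification",      # Global.eval_mode
    update_freq=1,                   # Global.update_freq
    config={
        "Global": {"epochs": 100, "eval_during_train": True, "use_dali": False},
        "Arch": {"class_num": 1000},
        "DataLoader": {},            # real configs carry Train/Eval dataset and sampler sections
    },
)
# dataloader_dict = build_dataloader(engine)
# -> keys: "Train", "UnLabelTrain", "Eval", "Query", "Gallery", "GalleryQuery"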
@@ -15,7 +15,6 @@ from __future__ import division
from __future__ import print_function

import os
import platform
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
@@ -52,168 +51,60 @@ class Engine(object):
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_mode = self.config["Global"].get("train_mode", None)

        # set seed
        self._init_seed()

        # init logger
        init_logger(self.config, mode=mode)
        print_config(config)

        # for visualdl
        self.vdl_writer = self._init_vdl()

        # is_rec
        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
                                                                    False):
            self.is_rec = True
        else:
            self.is_rec = False

        # set seed
        seed = self.config["Global"].get("seed", False)
        if seed or seed == 0:
            assert isinstance(seed, int), "The 'seed' must be an integer!"
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        # init logger
        self.output_dir = self.config['Global']['output_dir']
        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(log_file=log_file)
        print_config(config)

        # init train_func and eval_func
        assert self.eval_mode in [
            "classification", "retrieval", "adaface"
        ], logger.error("Invalid eval mode: {}".format(self.eval_mode))
        self.train_mode = self.config["Global"].get("train_mode", None)
        if self.train_mode is None:
            self.train_epoch_func = train_method.train_epoch
        else:
            self.train_epoch_func = getattr(train_method,
                                            "train_epoch_" + self.train_mode)

        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        assert self.eval_mode in [
            "classification", "retrieval", "adaface"
        ], logger.error("Invalid eval mode: {}".format(self.eval_mode))
        self.eval_func = getattr(evaluation, self.eval_mode + "_eval")

        self.use_dali = self.config['Global'].get("use_dali", False)

        # for visualdl
        self.vdl_writer = None
        if self.config['Global'][
                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)

        # set device
        assert self.config["Global"][
            "device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))
        self.device = self._init_device()

        # gradient accumulation
        self.update_freq = self.config["Global"].get("update_freq", 1)

        if "class_num" in config["Global"]:
            global_class_num = config["Global"]["class_num"]
            if "class_num" not in config["Arch"]:
                config["Arch"]["class_num"] = global_class_num
                msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
            else:
                msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
            logger.warning(msg)
        #TODO(gaotingquan): support rec
        class_num = config["Arch"].get("class_num", None)
        self.config["DataLoader"].update({"class_num": class_num})
        self.config["DataLoader"].update({
            "epochs": self.config["Global"]["epochs"]
        })

        # build dataloader
        if self.mode == 'train':
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)
            if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
                self.unlabel_train_dataloader = build_dataloader(
                    self.config["DataLoader"], "UnLabelTrain", self.device,
                    self.use_dali)
            else:
                self.unlabel_train_dataloader = None

            self.iter_per_epoch = len(
                self.train_dataloader) - 1 if platform.system(
                ) == "Windows" else len(self.train_dataloader)
            if self.config["Global"].get("iter_per_epoch", None):
                # set max iteration per epoch manually, when training by iteration(s), such as XBM, FixMatch.
                self.iter_per_epoch = self.config["Global"].get(
                    "iter_per_epoch")
            self.iter_per_epoch = self.iter_per_epoch // self.update_freq * self.update_freq

        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            if self.eval_mode in ["classification", "adaface"]:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device,
                    self.use_dali)
            elif self.eval_mode == "retrieval":
                self.gallery_query_dataloader = None
                if len(self.config["DataLoader"]["Eval"].keys()) == 1:
                    key = list(self.config["DataLoader"]["Eval"].keys())[0]
                    self.gallery_query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], key, self.device,
                        self.use_dali)
                else:
                    self.gallery_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Gallery",
                        self.device, self.use_dali)
                    self.query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Query",
                        self.device, self.use_dali)
        self.dataloader_dict = build_dataloader(self)
        self.train_dataloader, self.unlabel_train_dataloader, self.eval_dataloader = self.dataloader_dict[
            "Train"], self.dataloader_dict[
                "UnLabelTrain"], self.dataloader_dict["Eval"]
        self.gallery_query_dataloader, self.gallery_dataloader, self.query_dataloader = self.dataloader_dict[
            "GalleryQuery"], self.dataloader_dict[
                "Gallery"], self.dataloader_dict["Query"]

        # build loss
        if self.mode == "train":
            label_loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(label_loss_info)
            unlabel_loss_info = self.config.get("UnLabelLoss", {}).get("Train",
                                                                       None)
            self.unlabel_train_loss_func = build_loss(unlabel_loss_info)
        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
                else:
                    self.eval_loss_func = None
            else:
                self.eval_loss_func = None
        self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
            self.config, self.mode)

        # build metric
        if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
                "Metric"] and self.config["Metric"]["Train"]:
            metric_config = self.config["Metric"]["Train"]
            if hasattr(self.train_dataloader, "collate_fn"
                       ) and self.train_dataloader.collate_fn is not None:
                for m_idx, m in enumerate(metric_config):
                    if "TopkAcc" in m:
                        msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                        logger.warning(msg)
                        metric_config.pop(m_idx)
            self.train_metric_func = build_metrics(metric_config)
        else:
            self.train_metric_func = None

        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            if self.eval_mode == "classification":
                if "Metric" in self.config and "Eval" in self.config["Metric"]:
                    self.eval_metric_func = build_metrics(self.config["Metric"]
                                                          ["Eval"])
                else:
                    self.eval_metric_func = None
            elif self.eval_mode == "retrieval":
                if "Metric" in self.config and "Eval" in self.config["Metric"]:
                    metric_config = self.config["Metric"]["Eval"]
                else:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                self.eval_metric_func = build_metrics(metric_config)
            else:
                self.eval_metric_func = None
        self.train_metric_func, self.eval_metric_func = build_metrics(self)

        # build model
        self.model = build_model(self.config, self.mode)
@@ -221,139 +112,18 @@ class Engine(object):
        apply_to_static(self.config, self.model)

        # load_pretrain
        if self.config["Global"]["pretrained_model"] is not None:
            if self.config["Global"]["pretrained_model"].startswith("http"):
                load_dygraph_pretrain_from_url(
                    [self.model, getattr(self, 'train_loss_func', None)],
                    self.config["Global"]["pretrained_model"])
            else:
                load_dygraph_pretrain(
                    [self.model, getattr(self, 'train_loss_func', None)],
                    self.config["Global"]["pretrained_model"])
        self._init_pretrained()

        # build optimizer
        if self.mode == 'train':
            self.optimizer, self.lr_sch = build_optimizer(
                self.config["Optimizer"], self.config["Global"]["epochs"],
                self.iter_per_epoch // self.update_freq,
                [self.model, self.train_loss_func])
            self.optimizer, self.lr_sch = build_optimizer(
                self.config, self.train_dataloader,
                [self.model, self.train_loss_func])

        # AMP training and evaluating
        self.amp = "AMP" in self.config and self.config["AMP"] is not None
        self.amp_eval = False
        # for amp
        if self.amp:
            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
            if paddle.is_compiled_with_cuda():
                AMP_RELATED_FLAGS_SETTING.update({
                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
                })
            paddle.set_flags(AMP_RELATED_FLAGS_SETTING)

            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
            self.use_dynamic_loss_scaling = self.config["AMP"].get(
                "use_dynamic_loss_scaling", False)
            self.scaler = paddle.amp.GradScaler(
                init_loss_scaling=self.scale_loss,
                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)

            self.amp_level = self.config['AMP'].get("level", "O1")
            if self.amp_level not in ["O1", "O2"]:
                msg = "[Parameter Error]: The optimize level of AMP only supports 'O1' and 'O2'. The level has been set to 'O1'."
                logger.warning(msg)
                self.config['AMP']["level"] = "O1"
                self.amp_level = "O1"

            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
            # TODO(gaotingquan): Paddle does not yet support FP32 evaluation when training with AMP O2
            if self.mode == "train" and self.config["Global"].get(
                    "eval_during_train",
                    True) and self.amp_level == "O2" and self.amp_eval == False:
                msg = "PaddlePaddle only supports FP16 evaluation when training with AMP O2 now. "
                logger.warning(msg)
                self.config["AMP"]["use_fp16_test"] = True
                self.amp_eval = True

            # TODO(gaotingquan): to be compatible with different versions of Paddle
            paddle_version = paddle.__version__[:3]
            # paddle version < 2.3.0 and not develop
            if paddle_version not in ["2.3", "0.0"]:
                if self.mode == "train":
                    self.model, self.optimizer = paddle.amp.decorate(
                        models=self.model,
                        optimizers=self.optimizer,
                        level=self.amp_level,
                        save_dtype='float32')
                elif self.amp_eval:
                    if self.amp_level == "O2":
                        msg = "The installed PaddlePaddle does not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please note that the Eval Dataset output_fp16 should be 'False'."
                        logger.warning(msg)
                        self.amp_eval = False
                    else:
                        self.model, self.optimizer = paddle.amp.decorate(
                            models=self.model,
                            level=self.amp_level,
                            save_dtype='float32')
            # paddle version >= 2.3.0 or develop
            else:
                if self.mode == "train" or self.amp_eval:
                    self.model = paddle.amp.decorate(
                        models=self.model,
                        level=self.amp_level,
                        save_dtype='float32')

            if self.mode == "train" and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.amp.decorate(
                    models=self.train_loss_func,
                    level=self.amp_level,
                    save_dtype='float32')

        # build EMA model
        self.ema = "EMA" in self.config and self.mode == "train"
        if self.ema:
            self.model_ema = ExponentialMovingAverage(
                self.model, self.config['EMA'].get("decay", 0.9999))

        # check the gpu num
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        if self.mode == "train":
            std_gpu_num = 8 if isinstance(
                self.config["Optimizer"],
                dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
            if world_size != std_gpu_num:
                msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpus is {world_size} in current training. Please modify the strategy (learning rate, batch size and so on) if you use this config to train."
                logger.warning(msg)
        self._init_amp()

        # for distributed
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)

            # set different seed in different GPU manually in distributed environment
            if seed is None:
                logger.warning(
                    "The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
                )
                self.config["Global"]["seed"] = seed = 42
            logger.info(
                f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
            )
            paddle.seed(int(seed) + dist.get_rank())
            np.random.seed(int(seed) + dist.get_rank())
            random.seed(int(seed) + dist.get_rank())

        # build postprocess for infer
        if self.mode == 'infer':
            self.preprocess_func = create_operators(self.config["Infer"][
                "transforms"])
            self.postprocess_func = build_postprocess(self.config["Infer"][
                "PostProcess"])
        self._init_dist()

    def train(self):
        assert self.mode == "train"
@@ -363,10 +133,17 @@
            "metric": -1.0,
            "epoch": 0,
        }
        ema_module = None

        # build EMA model
        self.ema = "EMA" in self.config and self.mode == "train"
        if self.ema:
            self.model_ema = ExponentialMovingAverage(
                self.model, self.config['EMA'].get("decay", 0.9999))
            best_metric_ema = 0.0
            ema_module = self.model_ema.module
        else:
            ema_module = None

        # key:
        # val: metrics list word
        self.output_info = dict()
@@ -392,8 +169,6 @@
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)

            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join(
                [self.output_info[key].avg_info for key in self.output_info])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
@@ -499,6 +274,12 @@
    @paddle.no_grad()
    def infer(self):
        assert self.mode == "infer" and self.eval_mode == "classification"

        self.preprocess_func = create_operators(self.config["Infer"][
            "transforms"])
        self.postprocess_func = build_postprocess(self.config["Infer"][
            "PostProcess"])

        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
@@ -586,6 +367,148 @@
            f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
        )

    def _init_vdl(self):
        if self.config['Global'][
                'use_visualdl'] and self.mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            return LogWriter(logdir=vdl_writer_path)
        return None

    def _init_seed(self):
        seed = self.config["Global"].get("seed", False)
        if dist.get_world_size() != 1:
            # if self.config["Global"]["distributed"]:
            # set different seed in different GPU manually in distributed environment
            if not seed:
                logger.warning(
                    "The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
                )
                self.config["Global"]["seed"] = seed = 42
            logger.info(
                f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
            )
            dist_seed = int(seed) + dist.get_rank()
            paddle.seed(dist_seed)
            np.random.seed(dist_seed)
            random.seed(dist_seed)
        elif seed or seed == 0:
            assert isinstance(seed, int), "The 'seed' must be an integer!"
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

    def _init_device(self):
        device = self.config["Global"]["device"]
        assert device in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, device))
        return paddle.set_device(device)

    def _init_pretrained(self):
        if self.config["Global"]["pretrained_model"] is not None:
            if self.config["Global"]["pretrained_model"].startswith("http"):
                load_dygraph_pretrain_from_url(
                    [self.model, getattr(self, 'train_loss_func', None)],
                    self.config["Global"]["pretrained_model"])
            else:
                load_dygraph_pretrain(
                    [self.model, getattr(self, 'train_loss_func', None)],
                    self.config["Global"]["pretrained_model"])

    def _init_amp(self):
        self.amp = "AMP" in self.config and self.config["AMP"] is not None
        self.amp_eval = False
        # for amp
        if self.amp:
            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
            if paddle.is_compiled_with_cuda():
                AMP_RELATED_FLAGS_SETTING.update({
                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
                })
            paddle.set_flags(AMP_RELATED_FLAGS_SETTING)

            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
            self.use_dynamic_loss_scaling = self.config["AMP"].get(
                "use_dynamic_loss_scaling", False)
            self.scaler = paddle.amp.GradScaler(
                init_loss_scaling=self.scale_loss,
                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)

            self.amp_level = self.config['AMP'].get("level", "O1")
            if self.amp_level not in ["O1", "O2"]:
                msg = "[Parameter Error]: The optimize level of AMP only supports 'O1' and 'O2'. The level has been set to 'O1'."
                logger.warning(msg)
                self.config['AMP']["level"] = "O1"
                self.amp_level = "O1"

            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
            # TODO(gaotingquan): Paddle does not yet support FP32 evaluation when training with AMP O2
            if self.mode == "train" and self.config["Global"].get(
                    "eval_during_train",
                    True) and self.amp_level == "O2" and self.amp_eval == False:
                msg = "PaddlePaddle only supports FP16 evaluation when training with AMP O2 now. "
                logger.warning(msg)
                self.config["AMP"]["use_fp16_test"] = True
                self.amp_eval = True

            # TODO(gaotingquan): to be compatible with different versions of Paddle
            paddle_version = paddle.__version__[:3]
            # paddle version < 2.3.0 and not develop
            if paddle_version not in ["2.3", "0.0"]:
                if self.mode == "train":
                    self.model, self.optimizer = paddle.amp.decorate(
                        models=self.model,
                        optimizers=self.optimizer,
                        level=self.amp_level,
                        save_dtype='float32')
                elif self.amp_eval:
                    if self.amp_level == "O2":
                        msg = "The installed PaddlePaddle does not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please note that the Eval Dataset output_fp16 should be 'False'."
                        logger.warning(msg)
                        self.amp_eval = False
                    else:
                        self.model, self.optimizer = paddle.amp.decorate(
                            models=self.model,
                            level=self.amp_level,
                            save_dtype='float32')
            # paddle version >= 2.3.0 or develop
            else:
                if self.mode == "train" or self.amp_eval:
                    self.model = paddle.amp.decorate(
                        models=self.model,
                        level=self.amp_level,
                        save_dtype='float32')

            if self.mode == "train" and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.amp.decorate(
                    models=self.train_loss_func,
                    level=self.amp_level,
                    save_dtype='float32')

    def _init_dist(self):
        # check the gpu num
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        # TODO(gaotingquan):
        if self.mode == "train":
            std_gpu_num = 8 if isinstance(
                self.config["Optimizer"],
                dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
            if world_size != std_gpu_num:
                msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpus is {world_size} in current training. Please modify the strategy (learning rate, batch size and so on) if you use this config to train."
                logger.warning(msg)

        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)


class ExportModel(TheseusLayer):
    """
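To make the per-trainer seeding in _init_seed above concrete, a tiny illustration with invented numbers:

base_seed = 42   # hypothetical Global.seed on a multi-GPU run
world_size = 4   # hypothetical trainer count
# each rank seeds paddle, numpy and random with base_seed + rank
per_rank_seeds = [base_seed + rank for rank in range(world_size)]
print(per_rank_seeds)  # [42, 43, 44, 45]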
@@ -51,7 +51,7 @@ from .metabinloss import IntraDomainScatterLoss
class CombinedLoss(nn.Layer):
    def __init__(self, config_list):
        super().__init__()
        self.loss_func = []
        loss_func = []
        self.loss_weight = []
        assert isinstance(config_list, list), (
            'operator config should be a list')

@@ -63,8 +63,9 @@ class CombinedLoss(nn.Layer):
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            self.loss_weight.append(param.pop("weight"))
            self.loss_func.append(eval(name)(**param))
        self.loss_func = nn.LayerList(self.loss_func)
            loss_func.append(eval(name)(**param))
        self.loss_func = nn.LayerList(loss_func)
        logger.debug("build loss {} success.".format(loss_func))

    def __call__(self, input, batch):
        loss_dict = {}

@@ -83,9 +84,22 @@ class CombinedLoss(nn.Layer):
        return loss_dict


def build_loss(config):
    if config is None:
        return None
    module_class = CombinedLoss(copy.deepcopy(config))
    logger.debug("build loss {} success.".format(module_class))
    return module_class
def build_loss(config, mode="train"):
    train_loss_func, unlabel_train_loss_func, eval_loss_func = None, None, None
    if mode == "train":
        label_loss_info = config["Loss"]["Train"]
        if label_loss_info:
            train_loss_func = CombinedLoss(copy.deepcopy(label_loss_info))
        unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
        if unlabel_loss_info:
            unlabel_train_loss_func = CombinedLoss(
                copy.deepcopy(unlabel_loss_info))
    if mode == "eval" or (mode == "train" and
                          config["Global"]["eval_during_train"]):
        loss_config = config.get("Loss", None)
        if loss_config is not None:
            loss_config = loss_config.get("Eval")
            if loss_config is not None:
                eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))

    return train_loss_func, unlabel_train_loss_func, eval_loss_func
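A minimal, hypothetical config fragment showing how the new build_loss(config, mode) is driven; the loss name and weight below are placeholders for illustration and are not part of this commit:

config = {
    "Global": {"eval_during_train": True},
    "Loss": {
        "Train": [{"CELoss": {"weight": 1.0}}],
        "Eval": [{"CELoss": {"weight": 1.0}}],
    },
}
train_loss_func, unlabel_train_loss_func, eval_loss_func = build_loss(config, mode="train")
# unlabel_train_loss_func stays None because no UnLabelLoss.Train section is configured.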
@@ -65,6 +65,38 @@ class CombinedMetrics(AvgMetrics):
            metric.reset()


def build_metrics(config):
    metrics_list = CombinedMetrics(copy.deepcopy(config))
    return metrics_list
def build_metrics(engine):
    config, mode = engine.config, engine.mode
    if mode == 'train' and "Metric" in config and "Train" in config[
            "Metric"] and config["Metric"]["Train"]:
        metric_config = config["Metric"]["Train"]
        if hasattr(engine.train_dataloader, "collate_fn"
                   ) and engine.train_dataloader.collate_fn is not None:
            for m_idx, m in enumerate(metric_config):
                if "TopkAcc" in m:
                    msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                    logger.warning(msg)
                    metric_config.pop(m_idx)
        train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
    else:
        train_metric_func = None

    if mode == "eval" or (mode == "train" and
                          config["Global"]["eval_during_train"]):
        eval_mode = config["Global"].get("eval_mode", "classification")
        if eval_mode == "classification":
            if "Metric" in config and "Eval" in config["Metric"]:
                eval_metric_func = CombinedMetrics(
                    copy.deepcopy(config["Metric"]["Eval"]))
            else:
                eval_metric_func = None
        elif eval_mode == "retrieval":
            if "Metric" in config and "Eval" in config["Metric"]:
                metric_config = config["Metric"]["Eval"]
            else:
                metric_config = [{"name": "Recallk", "topk": (1, 5)}]
            eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
    else:
        eval_metric_func = None

    return train_metric_func, eval_metric_func
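Likewise, a hypothetical sketch of the Metric sections consumed by the new build_metrics(engine); the metric names mirror those referenced above and the values are placeholders:

# engine is assumed to be an Engine-like object as in the earlier sketch.
engine.config["Metric"] = {
    "Train": [{"TopkAcc": {"topk": [1, 5]}}],
    "Eval": [{"TopkAcc": {"topk": [1, 5]}}],
}
train_metric_func, eval_metric_func = build_metrics(engine)
# In retrieval eval mode with no Metric.Eval section, the code above falls back
# to [{"name": "Recallk", "topk": (1, 5)}].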
@@ -45,8 +45,11 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):


# model_list is None in static graph
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
    optim_config = copy.deepcopy(config)
def build_optimizer(config, dataloader, model_list=None):
    optim_config = copy.deepcopy(config["Optimizer"])
    epochs = config["Global"]["epochs"]
    update_freq = config["Global"].get("update_freq", 1)
    step_each_epoch = dataloader.iter_per_epoch // update_freq
    if isinstance(optim_config, dict):
        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
        optim_name = optim_config.pop("name")
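A small worked example (numbers invented) of how the scheduler step count is now derived from the dataloader instead of being passed in:

iter_per_epoch = 1250   # taken from dataloader.iter_per_epoch
update_freq = 4         # Global.update_freq, the gradient accumulation factor
step_each_epoch = iter_per_epoch // update_freq
print(step_each_epoch)  # 312 optimizer steps per epoch drive the LR schedule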
@@ -22,16 +22,15 @@ import paddle.distributed as dist
_logger = None


def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified a FileHandler will also be added.
    added.
    Args:
        config(dict): Training config.
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.

@@ -63,6 +62,8 @@ def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
    if init_flag:
        _logger.addHandler(stream_handler)

    log_file = os.path.join(config['Global']['output_dir'],
                            config["Arch"]["name"], f"{mode}.log")
    if log_file is not None and dist.get_rank() == 0:
        log_file_folder = os.path.split(log_file)[0]
        os.makedirs(log_file_folder, exist_ok=True)
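For illustration, with hypothetical values Global.output_dir = "./output", Arch.name = "ResNet50" and mode = "train", the log file resolved above becomes:

import os
log_file = os.path.join("./output", "ResNet50", "train.log")
print(log_file)  # e.g. ./output/ResNet50/train.log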