refactor: adapt to static graph in deprecating MixCELoss
parent 5025d09fd8
commit ed459a2a16
@@ -112,6 +112,7 @@ class Engine(object):
             paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
 
         #TODO(gaotingquan): support rec
         class_num = config["Arch"].get("class_num", None)
+        self.config["DataLoader"].update({"class_num": class_num})
         # build dataloader
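Note: with MixCELoss going away, the mix-style batch transforms configured under DataLoader need to know the label dimension so they can emit dense soft targets, which is why the engine copies class_num from the Arch section into the DataLoader section before building the dataloader. A minimal sketch of that plumbing (toy config, not PaddleClas code):

```python
# Toy config illustrating the hunk above: class_num lives under Arch,
# but the batch transforms that build soft labels are constructed from
# the DataLoader section, so it is copied across.
config = {
    "Arch": {"name": "ResNet50", "class_num": 1000},
    "DataLoader": {
        "Train": {"dataset": {"batch_transform_ops": [{"MixupOperator": {"alpha": 0.2}}]}},
    },
}

class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num})

assert config["DataLoader"]["class_num"] == 1000
```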
@@ -41,13 +41,14 @@ from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger, profiler
 
 
-def create_feeds(image_shape, use_mix=None, dtype="float32"):
+def create_feeds(image_shape, use_mix=False, class_num=None, dtype="float32"):
     """
     Create feeds as model input
 
     Args:
         image_shape(list[int]): model input shape, such as [3, 224, 224]
         use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
+        class_num(int): the class number of network, required if use_mix
 
     Returns:
         feeds(dict): dict of model input variables
@@ -55,13 +56,14 @@ def create_feeds(image_shape, use_mix=None, dtype="float32"):
     feeds = OrderedDict()
     feeds['data'] = paddle.static.data(
         name="data", shape=[None] + image_shape, dtype=dtype)
 
     if use_mix:
-        feeds['y_a'] = paddle.static.data(
-            name="y_a", shape=[None, 1], dtype="int64")
-        feeds['y_b'] = paddle.static.data(
-            name="y_b", shape=[None, 1], dtype="int64")
-        feeds['lam'] = paddle.static.data(
-            name="lam", shape=[None, 1], dtype=dtype)
+        if class_num is None:
+            msg = "When use MixUp, CutMix and so on, you must set class_num."
+            logger.error(msg)
+            raise Exception(msg)
+        feeds['target'] = paddle.static.data(
+            name="target", shape=[None, class_num], dtype="float32")
     else:
         feeds['label'] = paddle.static.data(
             name="label", shape=[None, 1], dtype="int64")
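The static feeds for mixed training collapse from three placeholders (y_a, y_b, lam) into a single dense target of shape [None, class_num]. A minimal numpy sketch of the soft target a mixup-style batch transform would now feed (soft_mix_target is a hypothetical helper, not part of this diff):

```python
import numpy as np

def soft_mix_target(y_a, y_b, lam, class_num):
    """Hypothetical helper: build the dense [N, class_num] target the new
    'target' feed expects, from the two hard labels and the mix ratio."""
    one_hot_a = np.eye(class_num, dtype="float32")[y_a]
    one_hot_b = np.eye(class_num, dtype="float32")[y_b]
    return lam * one_hot_a + (1.0 - lam) * one_hot_b

# Example: a batch of 2 samples mixed with lam=0.7 over 5 classes.
target = soft_mix_target(np.array([1, 3]), np.array([4, 0]), 0.7, 5)
print(target.shape)  # (2, 5)
print(target[0])     # [0.  0.7 0.  0.  0.3]
```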
@@ -74,6 +76,7 @@ def create_fetchs(out,
                   architecture,
                   topk=5,
                   epsilon=None,
+                  class_num=None,
                   use_mix=False,
                   config=None,
                   mode="Train"):
@@ -88,6 +91,7 @@ def create_fetchs(out,
         name(such as ResNet50) is needed
         topk(int): usually top5
         epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
+        class_num(int): the class number of network, required if use_mix
         use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
         config(dict): model config
 
@@ -97,18 +101,16 @@ def create_fetchs(out,
     fetchs = OrderedDict()
     # build loss
     if use_mix:
-        y_a = paddle.reshape(feeds['y_a'], [-1, 1])
-        y_b = paddle.reshape(feeds['y_b'], [-1, 1])
-        lam = paddle.reshape(feeds['lam'], [-1, 1])
+        if class_num is None:
+            msg = "When use MixUp, CutMix and so on, you must set class_num."
+            logger.error(msg)
+            raise Exception(msg)
+        target = paddle.reshape(feeds['target'], [-1, class_num])
     else:
         target = paddle.reshape(feeds['label'], [-1, 1])
 
     loss_func = build_loss(config["Loss"][mode])
-
-    if use_mix:
-        loss_dict = loss_func(out, [y_a, y_b, lam])
-    else:
-        loss_dict = loss_func(out, target)
+    loss_dict = loss_func(out, target)
 
     loss_out = loss_dict["loss"]
     fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True))
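This hunk is the heart of the deprecation: because cross-entropy is linear in the target distribution, a plain soft-label CE over the pre-mixed target reproduces what MixCELoss computed as lam * CE(out, y_a) + (1 - lam) * CE(out, y_b), so a single loss path covers both the mixed and plain branches. A small numpy check of that identity (assumed shapes, independent of Paddle):

```python
import numpy as np

def soft_ce(logits, target):
    """Cross-entropy against a dense (soft) target distribution."""
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -(target * log_prob).sum(axis=1).mean()

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))
y_a = rng.integers(0, 10, size=4)
y_b = rng.integers(0, 10, size=4)
lam = 0.3

eye = np.eye(10)
mixed = soft_ce(logits, lam * eye[y_a] + (1 - lam) * eye[y_b])
split = lam * soft_ce(logits, eye[y_a]) + (1 - lam) * soft_ce(logits, eye[y_b])
np.testing.assert_allclose(mixed, split, rtol=1e-6)  # CE is linear in the target
```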
@@ -218,6 +220,7 @@ def mixed_precision_optimizer(config, optimizer):
 def build(config,
           main_prog,
           startup_prog,
+          class_num=None,
           step_each_epoch=100,
           is_train=True,
           is_distributed=True):
@@ -233,6 +236,7 @@ def build(config,
         config(dict): config
         main_prog(): main program
         startup_prog(): startup program
+        class_num(int): the class number of network, required if use_mix
         is_train(bool): train or eval
         is_distributed(bool): whether to use distributed training method
 
@@ -245,10 +249,10 @@ def build(config,
     mode = "Train" if is_train else "Eval"
     use_mix = "batch_transform_ops" in config["DataLoader"][mode][
         "dataset"]
     use_dali = config["Global"].get('use_dali', False)
     feeds = create_feeds(
         config["Global"]["image_shape"],
-        use_mix,
+        use_mix=use_mix,
+        class_num=class_num,
         dtype="float32")
 
     # build model
@@ -264,6 +268,7 @@ def build(config,
         feeds,
         config["Arch"],
         epsilon=config.get('ls_epsilon'),
+        class_num=class_num,
         use_mix=use_mix,
         config=config,
         mode=mode)

@@ -115,6 +115,8 @@ def main(args):
     eval_dataloader = None
     use_dali = global_config.get('use_dali', False)
 
+    class_num = config["Arch"].get("class_num", None)
+    config["DataLoader"].update({"class_num": class_num})
     train_dataloader = build_dataloader(
         config["DataLoader"], "Train", device=device, use_dali=use_dali)
     if global_config["eval_during_train"]:
@@ -134,6 +136,7 @@ def main(args):
         config,
         train_prog,
         startup_prog,
+        class_num,
         step_each_epoch=step_each_epoch,
         is_train=True,
         is_distributed=global_config.get("is_distributed", True))
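For context, a self-contained sketch of the static-graph pattern these hunks converge on: one dense target feed plus soft-label cross-entropy, instead of separate mix placeholders and a dedicated MixCELoss (the toy fc model and shapes are assumptions, not PaddleClas code):

```python
import numpy as np
import paddle

paddle.enable_static()
class_num = 10

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # One dense 'target' placeholder replaces y_a / y_b / lam.
    data = paddle.static.data("data", shape=[None, 8], dtype="float32")
    target = paddle.static.data("target", shape=[None, class_num], dtype="float32")
    logits = paddle.static.nn.fc(data, class_num)
    # Soft-label CE handles mixed and plain batches through one path.
    loss = paddle.nn.functional.cross_entropy(logits, target, soft_label=True)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
out = exe.run(main_prog,
              feed={"data": np.random.rand(4, 8).astype("float32"),
                    "target": np.full((4, class_num), 1.0 / class_num, "float32")},
              fetch_list=[loss])
print(out[0])
```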