fix data dtype for amp training

This commit is contained in:
zhangting2020 2023-04-26 07:37:04 +00:00 committed by Tingquan Gao
parent 731006f1fc
commit e7bef51f9e

View File

@ -242,11 +242,14 @@ def build(config,
mode = "Train" if is_train else "Eval"
use_mix = "batch_transform_ops" in config["DataLoader"][mode][
"dataset"]
data_dtype = "float32"
if 'AMP' in config and config["AMP"]["level"] == 'O2':
data_dtype = "float16"
feeds = create_feeds(
config["Global"]["image_shape"],
use_mix,
class_num=class_num,
dtype="float32")
dtype=data_dtype)
# build model
# data_format should be assigned in arch-dict