mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix data dtype for amp training
This commit is contained in:
parent
731006f1fc
commit
e7bef51f9e
@ -242,11 +242,14 @@ def build(config,
|
||||
mode = "Train" if is_train else "Eval"
|
||||
use_mix = "batch_transform_ops" in config["DataLoader"][mode][
|
||||
"dataset"]
|
||||
data_dtype = "float32"
|
||||
if 'AMP' in config and config["AMP"]["level"] == 'O2':
|
||||
data_dtype = "float16"
|
||||
feeds = create_feeds(
|
||||
config["Global"]["image_shape"],
|
||||
use_mix,
|
||||
class_num=class_num,
|
||||
dtype="float32")
|
||||
dtype=data_dtype)
|
||||
|
||||
# build model
|
||||
# data_format should be assigned in arch-dict
|
||||
|
Loading…
x
Reference in New Issue
Block a user