mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix data dtype for amp training
This commit is contained in:
parent
731006f1fc
commit
e7bef51f9e
@ -242,11 +242,14 @@ def build(config,
|
|||||||
mode = "Train" if is_train else "Eval"
|
mode = "Train" if is_train else "Eval"
|
||||||
use_mix = "batch_transform_ops" in config["DataLoader"][mode][
|
use_mix = "batch_transform_ops" in config["DataLoader"][mode][
|
||||||
"dataset"]
|
"dataset"]
|
||||||
|
data_dtype = "float32"
|
||||||
|
if 'AMP' in config and config["AMP"]["level"] == 'O2':
|
||||||
|
data_dtype = "float16"
|
||||||
feeds = create_feeds(
|
feeds = create_feeds(
|
||||||
config["Global"]["image_shape"],
|
config["Global"]["image_shape"],
|
||||||
use_mix,
|
use_mix,
|
||||||
class_num=class_num,
|
class_num=class_num,
|
||||||
dtype="float32")
|
dtype=data_dtype)
|
||||||
|
|
||||||
# build model
|
# build model
|
||||||
# data_format should be assigned in arch-dict
|
# data_format should be assigned in arch-dict
|
||||||
|
Loading…
x
Reference in New Issue
Block a user