diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index ef24094c2..05b9f3913 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -221,7 +221,7 @@ class Engine(object): AMP_RELATED_FLAGS_SETTING.update({ 'FLAGS_cudnn_batchnorm_spatial_persistent': 1 }) - paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) + paddle.set_flags(AMP_RELATED_FLAGS_SETTING) self.scale_loss = self.config["AMP"].get("scale_loss", 1.0) self.use_dynamic_loss_scaling = self.config["AMP"].get( diff --git a/ppcls/static/save_load.py b/ppcls/static/save_load.py index 13badfddc..5d124fcf7 100644 --- a/ppcls/static/save_load.py +++ b/ppcls/static/save_load.py @@ -62,8 +62,8 @@ def load_params(exe, prog, path, ignore_params=None): """ Load model from the given path. Args: - exe (fluid.Executor): The fluid.Executor object. - prog (fluid.Program): load weight to which Program object. + exe (paddle.static.Executor): The paddle.static.Executor object. + prog (paddle.static.Program): load weight to which Program object. path (string): URL string or loca model path. ignore_params (list): ignore variable to load when finetuning. It can be specified by finetune_exclude_pretrained_params diff --git a/ppcls/static/train.py b/ppcls/static/train.py index eb803970b..86e832499 100644 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -87,7 +87,7 @@ def main(args): 'FLAGS_max_inplace_grad_add': 8, } os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1' - paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) + paddle.set_flags(AMP_RELATED_FLAGS_SETTING) use_xpu = global_config.get("use_xpu", False) use_npu = global_config.get("use_npu", False) diff --git a/ppcls/utils/download.py b/ppcls/utils/download.py index 9c4575048..51d454388 100644 --- a/ppcls/utils/download.py +++ b/ppcls/utils/download.py @@ -112,7 +112,7 @@ def get_path_from_url(url, str: a local path to save downloaded models & weights & datasets. """ - from paddle.fluid.dygraph.parallel import ParallelEnv + from paddle.distributed import ParallelEnv assert is_url(url), "downloading from {} not a url".format(url) # parse path after download to decompress under root_dir