From 01f1314fbdd1b36eafe796ccd2018fb0ba32efd2 Mon Sep 17 00:00:00 2001 From: cuicheng01 Date: Tue, 29 Jun 2021 14:35:26 +0000 Subject: [PATCH] fix googlenet infer --- ppcls/engine/trainer.py | 2 ++ tools/export_model.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index 80957757c..ed042161e 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -555,6 +555,8 @@ class Trainer(object): if len(batch_data) >= batch_size or idx == len(image_list) - 1: batch_tensor = paddle.to_tensor(batch_data) out = self.model(batch_tensor) + if isinstance(out, list): + out = out[0] result = postprocess_func(out, image_file_list) print(result) batch_data.clear() diff --git a/tools/export_model.py b/tools/export_model.py index 6ce369011..ad0e45b76 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -38,6 +38,7 @@ class ExportModel(nn.Layer): def __init__(self, config): super().__init__() + print (config) self.base_model = build_model(config) # we should choose a final model to export @@ -63,6 +64,8 @@ class ExportModel(nn.Layer): def forward(self, x): x = self.base_model(x) + if isinstance(x, list): + x = x[0] if self.infer_model_name is not None: x = x[self.infer_model_name] if self.infer_output_key is not None: @@ -76,7 +79,6 @@ if __name__ == "__main__": args = config.parse_args() config = config.get_config( args.config, overrides=args.override, show=False) - log_file = os.path.join(config['Global']['output_dir'], config["Arch"]["name"], "export.log") init_logger(name='root', log_file=log_file) @@ -86,7 +88,6 @@ if __name__ == "__main__": assert config["Global"]["device"] in ["cpu", "gpu", "xpu"] device = paddle.set_device(config["Global"]["device"]) model = ExportModel(config["Arch"]) - if config["Global"]["pretrained_model"] is not None: load_dygraph_pretrain(model.base_model, config["Global"]["pretrained_model"])