mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix googlenet infer
This commit is contained in:
parent
c492e1b2a2
commit
01f1314fbd
@ -555,6 +555,8 @@ class Trainer(object):
|
||||
if len(batch_data) >= batch_size or idx == len(image_list) - 1:
|
||||
batch_tensor = paddle.to_tensor(batch_data)
|
||||
out = self.model(batch_tensor)
|
||||
if isinstance(out, list):
|
||||
out = out[0]
|
||||
result = postprocess_func(out, image_file_list)
|
||||
print(result)
|
||||
batch_data.clear()
|
||||
|
@ -38,6 +38,7 @@ class ExportModel(nn.Layer):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
print (config)
|
||||
self.base_model = build_model(config)
|
||||
|
||||
# we should choose a final model to export
|
||||
@ -63,6 +64,8 @@ class ExportModel(nn.Layer):
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base_model(x)
|
||||
if isinstance(x, list):
|
||||
x = x[0]
|
||||
if self.infer_model_name is not None:
|
||||
x = x[self.infer_model_name]
|
||||
if self.infer_output_key is not None:
|
||||
@ -76,7 +79,6 @@ if __name__ == "__main__":
|
||||
args = config.parse_args()
|
||||
config = config.get_config(
|
||||
args.config, overrides=args.override, show=False)
|
||||
|
||||
log_file = os.path.join(config['Global']['output_dir'],
|
||||
config["Arch"]["name"], "export.log")
|
||||
init_logger(name='root', log_file=log_file)
|
||||
@ -86,7 +88,6 @@ if __name__ == "__main__":
|
||||
assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
|
||||
device = paddle.set_device(config["Global"]["device"])
|
||||
model = ExportModel(config["Arch"])
|
||||
|
||||
if config["Global"]["pretrained_model"] is not None:
|
||||
load_dygraph_pretrain(model.base_model,
|
||||
config["Global"]["pretrained_model"])
|
||||
|
Loading…
x
Reference in New Issue
Block a user