Mirror of https://github.com/PaddlePaddle/PaddleClas.git, synced 2025-06-03 21:55:06 +08:00
fix googlenet infer

commit 01f1314fbd (parent c492e1b2a2)
@@ -555,6 +555,8 @@ class Trainer(object):
             if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                 batch_tensor = paddle.to_tensor(batch_data)
                 out = self.model(batch_tensor)
+                if isinstance(out, list):
+                    out = out[0]
                 result = postprocess_func(out, image_file_list)
                 print(result)
                 batch_data.clear()
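GoogLeNet-style architectures carry auxiliary classifier heads, so the model's forward pass can return a list of logits instead of a single tensor; the added check keeps only the first (main) head before post-processing. A minimal sketch of that unwrapping, with the helper name being illustrative rather than taken from the repository:

    import paddle

    def unwrap_main_output(out):
        # Models such as GoogLeNet may return [main_logits, aux1_logits, aux2_logits];
        # the post-processing step expects a single tensor, so keep the main head.
        if isinstance(out, list):
            out = out[0]
        return out

    # e.g. inside the batch loop:
    # out = unwrap_main_output(self.model(batch_tensor))
    # result = postprocess_func(out, image_file_list)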
@@ -38,6 +38,7 @@ class ExportModel(nn.Layer):
 
     def __init__(self, config):
         super().__init__()
+        print (config)
         self.base_model = build_model(config)
 
         # we should choose a final model to export
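For context, the config printed and passed to build_model here is the Arch section of the YAML configuration. A hypothetical example of what it might contain for GoogLeNet; the exact keys and values depend on the config file chosen at export time and are assumptions, not taken from this commit:

    # Hypothetical Arch section (Python form of the YAML) consumed by build_model.
    arch_config = {
        "name": "GoogLeNet",  # backbone selected by build_model
        "class_num": 1000,    # number of classifier outputs
    }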
@@ -63,6 +64,8 @@ class ExportModel(nn.Layer):
 
     def forward(self, x):
         x = self.base_model(x)
+        if isinstance(x, list):
+            x = x[0]
         if self.infer_model_name is not None:
             x = x[self.infer_model_name]
         if self.infer_output_key is not None:
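Together with the inference fix above, forward now unwraps list outputs before the optional dict-style indexing. A simplified sketch of how the method reads after this change; the tail of the method (for example, an optional softmax) is not visible in the hunk and is omitted here:

    def forward(self, x):
        x = self.base_model(x)
        # GoogLeNet returns a list of classifier heads; export only the main one.
        if isinstance(x, list):
            x = x[0]
        # Some models return dicts; these optional keys narrow the output.
        if self.infer_model_name is not None:
            x = x[self.infer_model_name]
        if self.infer_output_key is not None:
            x = x[self.infer_output_key]
        return x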
@@ -76,7 +79,6 @@ if __name__ == "__main__":
     args = config.parse_args()
     config = config.get_config(
         args.config, overrides=args.override, show=False)
-
     log_file = os.path.join(config['Global']['output_dir'],
                             config["Arch"]["name"], "export.log")
     init_logger(name='root', log_file=log_file)
@@ -86,7 +88,6 @@ if __name__ == "__main__":
     assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
     device = paddle.set_device(config["Global"]["device"])
     model = ExportModel(config["Arch"])
-
     if config["Global"]["pretrained_model"] is not None:
         load_dygraph_pretrain(model.base_model,
                               config["Global"]["pretrained_model"])
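After the pretrained weights are loaded, the export script typically traces the dynamic-graph model to a static graph and saves it for inference. A hedged sketch of that final step using standard Paddle APIs; the input shape and output path are assumptions, not taken from this commit:

    import paddle
    from paddle.static import InputSpec

    model.eval()
    # Trace the dygraph model with a fixed input signature and save an inference model.
    static_model = paddle.jit.to_static(
        model,
        input_spec=[InputSpec(shape=[None, 3, 224, 224], dtype='float32')])
    paddle.jit.save(static_model, "./inference/inference")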