update export model

This commit is contained in:
weishengyu 2021-06-05 17:36:24 +08:00
parent 9e4a1045f0
commit a445ff1f28

View File

@ -29,7 +29,7 @@ from ppcls.arch import build_model
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain
class ClasModel(nn.Layer): class ExportModel(nn.Layer):
""" """
ClasModel: add softmax onto the model ClasModel: add softmax onto the model
""" """
@ -37,7 +37,11 @@ class ClasModel(nn.Layer):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.base_model = build_model(config) self.base_model = build_model(config)
self.softmax = nn.Softmax(axis=-1) self.infer_output_key = config.get("infer_output_key")
if config.get("infer_add_softmax", False):
self.softmax = nn.Softmax(axis=-1)
else:
self.softmax = None
def eval(self): def eval(self):
self.training = False self.training = False
@ -47,7 +51,10 @@ class ClasModel(nn.Layer):
def forward(self, x): def forward(self, x):
x = self.base_model(x) x = self.base_model(x)
x = self.softmax(x) if self.infer_output_key is not None:
x = x[self.infer_output_key]
if self.softmax is not None:
x = self.softmax(x)
return x return x
@ -57,8 +64,7 @@ if __name__ == "__main__":
# set device # set device
assert config["Global"]["device"] in ["cpu", "gpu", "xpu"] assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
device = paddle.set_device(config["Global"]["device"]) device = paddle.set_device(config["Global"]["device"])
model = ExportModel(config["Arch"])
model = ClasModel(config["Arch"])
if config["Global"]["pretrained_model"] is not None: if config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(model.base_model, load_dygraph_pretrain(model.base_model,