update export model
parent
9e4a1045f0
commit
a445ff1f28
|
@ -29,7 +29,7 @@ from ppcls.arch import build_model
|
|||
from ppcls.utils.save_load import load_dygraph_pretrain
|
||||
|
||||
|
||||
class ClasModel(nn.Layer):
|
||||
class ExportModel(nn.Layer):
|
||||
"""
|
||||
ClasModel: add softmax onto the model
|
||||
"""
|
||||
|
@ -37,7 +37,11 @@ class ClasModel(nn.Layer):
|
|||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.base_model = build_model(config)
|
||||
self.softmax = nn.Softmax(axis=-1)
|
||||
self.infer_output_key = config.get("infer_output_key")
|
||||
if config.get("infer_add_softmax", False):
|
||||
self.softmax = nn.Softmax(axis=-1)
|
||||
else:
|
||||
self.softmax = None
|
||||
|
||||
def eval(self):
|
||||
self.training = False
|
||||
|
@ -47,7 +51,10 @@ class ClasModel(nn.Layer):
|
|||
|
||||
def forward(self, x):
|
||||
x = self.base_model(x)
|
||||
x = self.softmax(x)
|
||||
if self.infer_output_key is not None:
|
||||
x = x[self.infer_output_key]
|
||||
if self.softmax is not None:
|
||||
x = self.softmax(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -57,8 +64,7 @@ if __name__ == "__main__":
|
|||
# set device
|
||||
assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
|
||||
device = paddle.set_device(config["Global"]["device"])
|
||||
|
||||
model = ClasModel(config["Arch"])
|
||||
model = ExportModel(config["Arch"])
|
||||
|
||||
if config["Global"]["pretrained_model"] is not None:
|
||||
load_dygraph_pretrain(model.base_model,
|
||||
|
|
Loading…
Reference in New Issue