mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
update export model
This commit is contained in:
parent
9e4a1045f0
commit
a445ff1f28
@ -29,7 +29,7 @@ from ppcls.arch import build_model
|
|||||||
from ppcls.utils.save_load import load_dygraph_pretrain
|
from ppcls.utils.save_load import load_dygraph_pretrain
|
||||||
|
|
||||||
|
|
||||||
class ClasModel(nn.Layer):
|
class ExportModel(nn.Layer):
|
||||||
"""
|
"""
|
||||||
ClasModel: add softmax onto the model
|
ClasModel: add softmax onto the model
|
||||||
"""
|
"""
|
||||||
@ -37,7 +37,11 @@ class ClasModel(nn.Layer):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.base_model = build_model(config)
|
self.base_model = build_model(config)
|
||||||
self.softmax = nn.Softmax(axis=-1)
|
self.infer_output_key = config.get("infer_output_key")
|
||||||
|
if config.get("infer_add_softmax", False):
|
||||||
|
self.softmax = nn.Softmax(axis=-1)
|
||||||
|
else:
|
||||||
|
self.softmax = None
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
self.training = False
|
self.training = False
|
||||||
@ -47,7 +51,10 @@ class ClasModel(nn.Layer):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.base_model(x)
|
x = self.base_model(x)
|
||||||
x = self.softmax(x)
|
if self.infer_output_key is not None:
|
||||||
|
x = x[self.infer_output_key]
|
||||||
|
if self.softmax is not None:
|
||||||
|
x = self.softmax(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -57,8 +64,7 @@ if __name__ == "__main__":
|
|||||||
# set device
|
# set device
|
||||||
assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
|
assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
|
||||||
device = paddle.set_device(config["Global"]["device"])
|
device = paddle.set_device(config["Global"]["device"])
|
||||||
|
model = ExportModel(config["Arch"])
|
||||||
model = ClasModel(config["Arch"])
|
|
||||||
|
|
||||||
if config["Global"]["pretrained_model"] is not None:
|
if config["Global"]["pretrained_model"] is not None:
|
||||||
load_dygraph_pretrain(model.base_model,
|
load_dygraph_pretrain(model.base_model,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user