modify attr export model

pull/1960/head
zhiboniu 2022-05-25 08:13:38 +00:00
parent 9cf1abaee6
commit a4e1da6610
3 changed files with 6 additions and 4 deletions

View File

@ -287,7 +287,8 @@ class ResNet(TheseusLayer):
data_format="NCHW",
input_image_channel=3,
return_patterns=None,
return_stages=None):
return_stages=None,
**kargs):
super().__init__()
self.cfg = config

View File

@ -20,6 +20,7 @@ Arch:
name: "ResNet50"
pretrained: True
class_num: 26
infer_add_softmax: False
# loss function config for traing/eval process
Loss:
@ -110,5 +111,3 @@ DataLoader:
Metric:
Eval:
- ATTRMetric:

View File

@ -457,7 +457,9 @@ class Engine(object):
def export(self):
assert self.mode == "export"
use_multilabel = self.config["Global"].get("use_multilabel", False)
use_multilabel = self.config["Global"].get(
"use_multilabel",
False) and not "ATTRMetric" in self.config["Metric"]["Eval"][0]
model = ExportModel(self.config["Arch"], self.model, use_multilabel)
if self.config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(model.base_model,