modify attr export model
parent
9cf1abaee6
commit
a4e1da6610
|
@ -287,7 +287,8 @@ class ResNet(TheseusLayer):
|
|||
data_format="NCHW",
|
||||
input_image_channel=3,
|
||||
return_patterns=None,
|
||||
return_stages=None):
|
||||
return_stages=None,
|
||||
**kargs):
|
||||
super().__init__()
|
||||
|
||||
self.cfg = config
|
||||
|
|
|
@ -20,6 +20,7 @@ Arch:
|
|||
name: "ResNet50"
|
||||
pretrained: True
|
||||
class_num: 26
|
||||
infer_add_softmax: False
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
|
@ -110,5 +111,3 @@ DataLoader:
|
|||
Metric:
|
||||
Eval:
|
||||
- ATTRMetric:
|
||||
|
||||
|
||||
|
|
|
@ -457,7 +457,9 @@ class Engine(object):
|
|||
|
||||
def export(self):
|
||||
assert self.mode == "export"
|
||||
use_multilabel = self.config["Global"].get("use_multilabel", False)
|
||||
use_multilabel = self.config["Global"].get(
|
||||
"use_multilabel",
|
||||
False) and not "ATTRMetric" in self.config["Metric"]["Eval"][0]
|
||||
model = ExportModel(self.config["Arch"], self.model, use_multilabel)
|
||||
if self.config["Global"]["pretrained_model"] is not None:
|
||||
load_dygraph_pretrain(model.base_model,
|
||||
|
|
Loading…
Reference in New Issue