refactor code

pull/1599/head
weishengyu 2021-12-23 20:51:15 +08:00
parent d5637367a9
commit 25edd1c0d8
1 changed files with 7 additions and 10 deletions

View File

@ -16,7 +16,7 @@ def build_gallery_layer(configs, feature_extractor):
embedding_size = configs["Arch"]["Head"]["embedding_size"]
batch_size = configs["IndexProcess"]["batch_size"]
image_shape = configs["Global"]["image_shape"]
image_shape = configs["Global"]["image_shape"].copy()
image_shape.insert(0, batch_size)
input_tensor = paddle.zeros(image_shape)
@ -57,25 +57,22 @@ def build_gallery_layer(configs, feature_extractor):
return gallery_layer
class FuseModel(paddle.nn.Layer):
def __init__(self, configs):
class GalleryLayer(paddle.nn.Layer):
def __init__(self, configs, feature_extractor):
super().__init__()
self.feature_extractor = build_model(configs)
load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
self.feature_extractor.eval()
self.feature_extractor.head = IdentityHead()
self.gallery_layer = build_gallery_layer(configs, self.feature_extractor)
self.gallery_layer = build_gallery_layer(configs, feature_extractor)
def forward(self, x):
x = self.feature_extractor(x)["features"]
x = paddle.nn.functional.normalize(x)
x = self.gallery_layer(x)
return x
def export_fuse_model(configs):
fuse_model = FuseModel(configs)
fuse_model = build_model(configs)
load_dygraph_pretrain(fuse_model, configs["Global"]["pretrained_model"])
fuse_model.eval()
fuse_model.head = GalleryLayer(configs, fuse_model)
save_path = configs["Global"]["save_inference_dir"]
fuse_model.quanter.save_quantized_model(
fuse_model,