refactor code
parent
d5637367a9
commit
25edd1c0d8
|
@ -16,7 +16,7 @@ def build_gallery_layer(configs, feature_extractor):
|
|||
|
||||
embedding_size = configs["Arch"]["Head"]["embedding_size"]
|
||||
batch_size = configs["IndexProcess"]["batch_size"]
|
||||
image_shape = configs["Global"]["image_shape"]
|
||||
image_shape = configs["Global"]["image_shape"].copy()
|
||||
image_shape.insert(0, batch_size)
|
||||
input_tensor = paddle.zeros(image_shape)
|
||||
|
||||
|
@ -57,25 +57,22 @@ def build_gallery_layer(configs, feature_extractor):
|
|||
return gallery_layer
|
||||
|
||||
|
||||
class FuseModel(paddle.nn.Layer):
|
||||
def __init__(self, configs):
|
||||
class GalleryLayer(paddle.nn.Layer):
|
||||
def __init__(self, configs, feature_extractor):
|
||||
super().__init__()
|
||||
self.feature_extractor = build_model(configs)
|
||||
load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
|
||||
self.feature_extractor.eval()
|
||||
self.feature_extractor.head = IdentityHead()
|
||||
self.gallery_layer = build_gallery_layer(configs, self.feature_extractor)
|
||||
self.gallery_layer = build_gallery_layer(configs, feature_extractor)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.feature_extractor(x)["features"]
|
||||
x = paddle.nn.functional.normalize(x)
|
||||
x = self.gallery_layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def export_fuse_model(configs):
|
||||
fuse_model = FuseModel(configs)
|
||||
fuse_model = build_model(configs)
|
||||
load_dygraph_pretrain(fuse_model, configs["Global"]["pretrained_model"])
|
||||
fuse_model.eval()
|
||||
fuse_model.head = GalleryLayer(configs, fuse_model)
|
||||
save_path = configs["Global"]["save_inference_dir"]
|
||||
fuse_model.quanter.save_quantized_model(
|
||||
fuse_model,
|
||||
|
|
Loading…
Reference in New Issue