2021-12-08 14:40:03 +08:00
|
|
|
import paddle
|
2021-12-08 19:47:43 +08:00
|
|
|
from ppcls.arch import build_model
|
2021-12-08 19:48:59 +08:00
|
|
|
from ppcls.utils.config import parse_config, parse_args
|
2021-12-08 14:40:03 +08:00
|
|
|
|
|
|
|
|
|
|
|
def build_gallery_feature(feature_extractor):
    """Build the gallery layer that is applied after the feature extractor.

    Placeholder: not implemented yet. ``feature_extractor`` is currently
    ignored and the function always returns ``None``.

    Args:
        feature_extractor: the model produced by ``build_model``.

    Returns:
        None (until implemented).
    """
    # NOTE(review): FuseModel stores this result as ``gallery_layer`` and
    # calls it in ``forward``, so a real implementation must return a
    # callable layer.
    return None
|
|
|
|
|
|
|
|
|
|
|
|
def save_fuse_model(fuse_model):
    """Persist the fused model.

    Placeholder: not implemented yet. The argument is ignored and nothing
    is written, so callers currently get a silent no-op.

    Args:
        fuse_model: the ``FuseModel`` instance to save.

    Returns:
        None.
    """
    return None
|
|
|
|
|
|
|
|
|
|
|
|
class FuseModel(paddle.nn.Layer):
    """Network that chains a feature extractor with a gallery layer.

    Args:
        configs: the parsed configuration (as returned by ``parse_config``),
            passed unchanged to ``build_model``.
    """

    def __init__(self, configs):
        super().__init__()
        # Backbone model built from the parsed configuration.
        self.feature_extractor = build_model(configs)
        # Layer derived from the extractor; applied to its output in forward.
        # NOTE(review): build_gallery_feature is currently a stub returning
        # None, so forward cannot succeed until it is implemented.
        self.gallery_layer = build_gallery_feature(self.feature_extractor)

    def forward(self, x):
        """Return ``gallery_layer(feature_extractor(x))``.

        Args:
            x: model input; presumably an image batch tensor — confirm
                against the ``build_model`` config.

        Returns:
            Whatever ``gallery_layer`` returns for the extracted features.
        """
        # Fixed: this previously read ``self.feature_model``, an attribute
        # that is never set, so every forward pass raised AttributeError.
        features = self.feature_extractor(x)
        return self.gallery_layer(features)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Script entry point: build a FuseModel from the CLI config and save it.

    Reads the ``config`` argument from the command line, parses that config,
    constructs the fused model, and hands it to ``save_fuse_model``.
    """
    config_path = parse_args().config
    save_fuse_model(FuseModel(parse_config(config_path)))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the build-and-save pipeline only when executed as a script,
    # not when this module is imported.
    main()
|