PaddleClas/ppcls/utils/gallery2fc.py

43 lines
1.0 KiB
Python
Raw Normal View History

2021-12-08 14:40:03 +08:00
import paddle
2021-12-08 19:47:43 +08:00
from ppcls.arch import build_model
2021-12-08 19:48:59 +08:00
from ppcls.utils.config import parse_config, parse_args
2021-12-08 19:56:30 +08:00
from ppcls.utils.save_load import load_dygraph_pretrain
2021-12-08 20:02:13 +08:00
from ppcls.utils.logger import init_logger
2021-12-08 19:56:30 +08:00
def load_feature_extractor(configs):
    """Build the model described by ``configs["Arch"]`` and load pretrained weights.

    Args:
        configs: Parsed configuration mapping. Reads ``configs["Arch"]`` (passed
            to ``build_model``) and ``configs["Global"]["pretrained_model"]``
            (passed to ``load_dygraph_pretrain``).

    Returns:
        The model object returned by ``build_model``, after
        ``load_dygraph_pretrain`` has been applied to it.
    """
    arch = build_model(configs["Arch"])
    load_dygraph_pretrain(arch, configs["Global"]["pretrained_model"])
    # The model must be returned; previously the function fell off the end and
    # returned None, so callers never received the loaded network.
    return arch
2021-12-08 14:40:03 +08:00
def build_gallery_feature(feature_extractor):
    """Placeholder for constructing the gallery layer.

    Currently ignores ``feature_extractor`` and returns ``None``.
    NOTE(review): presumably meant to build a layer from gallery features
    produced by ``feature_extractor`` — confirm intended design.
    """
    return None
def save_fuse_model(fuse_model):
    """Placeholder for persisting ``fuse_model``.

    Currently a no-op that returns ``None``.
    NOTE(review): judging by the name this should save the fused model to
    disk — output format/location not yet decided here.
    """
    return None
class FuseModel(paddle.nn.Layer):
    """A Paddle Layer chaining a feature extractor and a gallery layer.

    Args:
        configs: Parsed configuration passed to ``load_feature_extractor``.
    """

    def __init__(self, configs):
        super().__init__()
        self.feature_extractor = load_feature_extractor(configs)
        # NOTE(review): the visible build_gallery_feature stub returns None, in
        # which case forward() cannot call gallery_layer until it is implemented.
        self.gallery_layer = build_gallery_feature(self.feature_extractor)

    def forward(self, x):
        """Run ``x`` through the feature extractor, then the gallery layer.

        Args:
            x: Model input; its expected shape/dtype is whatever the configured
                architecture accepts (not visible here).

        Returns:
            The output of ``self.gallery_layer``.
        """
        # Use the attribute actually assigned in __init__; the previous
        # ``self.feature_model`` was never defined and raised AttributeError.
        x = self.feature_extractor(x)
        x = self.gallery_layer(x)
        return x
def main():
    """Command-line entry point.

    Parses arguments, loads the config file named by ``--config``,
    initialises the 'gallery2fc' logger, builds a FuseModel and hands it to
    ``save_fuse_model``.
    """
    cli_args = parse_args()
    config = parse_config(cli_args.config)
    init_logger(name='gallery2fc')
    save_fuse_model(FuseModel(config))
# Run the conversion only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()