From 5a15c1658198b801d5df85a07ca0c3c842647af4 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Fri, 28 Aug 2020 09:43:27 +0000
Subject: [PATCH] add dygraph load and static load

---
 configs/quick_start/CSPResNet50.yaml | 70 ++++++++++++
 .../MobileNetV3_large_x1_0_finetune.yaml | 1 +
 .../MobileNetV3_large_x1_0_ssld_finetune.yaml | 73 ++++++++++++
 .../R50_vd_distill_MV3_large_x1_0.yaml | 3 +
 configs/quick_start/ResNet50_vd.yaml | 3 +
 .../ResNet50_vd_ssld_finetune.yaml | 1 +
 ...Net50_vd_ssld_random_erasing_finetune.yaml | 1 +
 ppcls/modeling/architectures/__init__.py | 2 +
 .../architectures/distillation_models.py | 41 +++----
 ppcls/utils/logger.py | 3 -
 ppcls/utils/save_load.py | 106 +++++++-----
 tools/infer/infer.py | 30 ++---
 tools/program.py | 2 +
 tools/train.py | 2 +-
 14 files changed, 235 insertions(+), 103 deletions(-)
 create mode 100644 configs/quick_start/CSPResNet50.yaml
 create mode 100644 configs/quick_start/MobileNetV3_large_x1_0_ssld_finetune.yaml

diff --git a/configs/quick_start/CSPResNet50.yaml b/configs/quick_start/CSPResNet50.yaml
new file mode 100644
index 000000000..ead1071c6
--- /dev/null
+++ b/configs/quick_start/CSPResNet50.yaml
@@ -0,0 +1,70 @@
+mode: 'train'
+ARCHITECTURE:
+    name: 'CSPResNet50'
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1020
+save_interval: 1
+validate: False
+valid_interval: 1
+epochs: 1
+topk: 5
+image_shape: [3, 224, 224]
+
+LEARNING_RATE:
+    function: 'Cosine'
+    params:
+        lr: 0.0125
+
+OPTIMIZER:
+    function: 'Momentum'
+    params:
+        momentum: 0.9
+    regularizer:
+        function: 'L2'
+        factor: 0.00001
+
+TRAIN:
+    batch_size: 32
+    num_workers: 4
+    file_list: "./dataset/flowers102/train_list.txt"
+    data_dir: "./dataset/flowers102/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1./255.
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+
+VALID:
+    batch_size: 20
+    num_workers: 4
+    file_list: "./dataset/flowers102/val_list.txt"
+    data_dir: "./dataset/flowers102/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
diff --git a/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml b/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml
index 827029b7d..dc47e2269 100644
--- a/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml
+++ b/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml
@@ -2,6 +2,7 @@ mode: 'train'
 ARCHITECTURE:
     name: 'MobileNetV3_large_x1_0'
 pretrained_model: "./pretrained/MobileNetV3_large_x1_0_pretrained"
+load_static_weights: True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 1020
diff --git a/configs/quick_start/MobileNetV3_large_x1_0_ssld_finetune.yaml b/configs/quick_start/MobileNetV3_large_x1_0_ssld_finetune.yaml
new file mode 100644
index 000000000..1d73eaba4
--- /dev/null
+++ b/configs/quick_start/MobileNetV3_large_x1_0_ssld_finetune.yaml
@@ -0,0 +1,73 @@
+mode: 'train'
+ARCHITECTURE:
+    name: 'MobileNetV3_large_x1_0'
+    params:
+        lr_mult_list: [0.25, 0.25, 0.5, 0.5, 0.75]
+pretrained_model: "./pretrained/MobileNetV3_large_x1_0_ssld_pretrained"
+load_static_weights: True
+model_save_dir: "./output/"
+classes_num: 102
+total_images: 1020
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 20
+topk: 5
+image_shape: [3, 224, 224]
+
+LEARNING_RATE:
+    function: 'Cosine'
+    params:
+        lr: 0.00375
+
+OPTIMIZER:
+    function: 'Momentum'
+    params:
+        momentum: 0.9
+    regularizer:
+        function: 'L2'
+        factor: 0.000001
+
+TRAIN:
+    batch_size: 32
+    num_workers: 4
+    file_list: "./dataset/flowers102/train_list.txt"
+    data_dir: "./dataset/flowers102/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1./255.
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+
+VALID:
+    batch_size: 20
+    num_workers: 4
+    file_list: "./dataset/flowers102/val_list.txt"
+    data_dir: "./dataset/flowers102/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
diff --git a/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml b/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml
index 93839e99c..3e4838212 100644
--- a/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml
+++ b/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml
@@ -5,6 +5,9 @@ ARCHITECTURE:
 pretrained_model:
     - "./pretrained/flowers102_R50_vd_final/ppcls"
     - "./pretrained/MobileNetV3_large_x1_0_pretrained/"
+load_static_weights:
+    - False
+    - True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 7169
diff --git a/configs/quick_start/ResNet50_vd.yaml b/configs/quick_start/ResNet50_vd.yaml
index 913090921..76e4d316a 100644
--- a/configs/quick_start/ResNet50_vd.yaml
+++ b/configs/quick_start/ResNet50_vd.yaml
@@ -1,7 +1,10 @@
 mode: 'train'
 ARCHITECTURE:
     name: 'ResNet50_vd'
+
+checkpoints: ""
 pretrained_model: ""
+load_static_weights: True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 1020
diff --git a/configs/quick_start/ResNet50_vd_ssld_finetune.yaml b/configs/quick_start/ResNet50_vd_ssld_finetune.yaml
index d6bf86341..511dff005 100644
--- a/configs/quick_start/ResNet50_vd_ssld_finetune.yaml
+++ b/configs/quick_start/ResNet50_vd_ssld_finetune.yaml
@@ -4,6 +4,7 @@ ARCHITECTURE:
     params:
         lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
 pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
+load_static_weights: True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 1020
diff --git a/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml b/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml
index 629f050ed..0687ea3b7 100644
--- a/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml
+++ b/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml
@@ -4,6 +4,7 @@ ARCHITECTURE:
     params:
         lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
 pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
+load_static_weights: True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 1020
diff --git a/ppcls/modeling/architectures/__init__.py b/ppcls/modeling/architectures/__init__.py
index ffc085175..82d1b2a2e 100644
--- a/ppcls/modeling/architectures/__init__.py
+++ b/ppcls/modeling/architectures/__init__.py
@@ -28,3 +28,5 @@ from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75
 from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
 from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
 from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
+
+from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0
\ No newline at end of file
diff --git a/ppcls/modeling/architectures/distillation_models.py b/ppcls/modeling/architectures/distillation_models.py
index f5f24b36a..cf2af7eb7 100644
--- a/ppcls/modeling/architectures/distillation_models.py
+++ b/ppcls/modeling/architectures/distillation_models.py
@@ -1,16 +1,16 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
 #
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from __future__ import absolute_import
 from __future__ import division
@@ -32,17 +32,20 @@ __all__ = [
 ]
 
 
-class ResNet50_vd_distill_MobileNetV3_large_x1_0():
-    def net(self, input, class_dim=1000):
-        # student
-        student = MobileNetV3_large_x1_0()
-        out_student = student.net(input, class_dim=class_dim)
-        # teacher
-        teacher = ResNet50_vd()
-        out_teacher = teacher.net(input, class_dim=class_dim)
-        out_teacher.stop_gradient = True
+class ResNet50_vd_distill_MobileNetV3_large_x1_0(fluid.dygraph.Layer):
+    def __init__(self, class_dim=1000, **args):
+        super(ResNet50_vd_distill_MobileNetV3_large_x1_0, self).__init__()
 
-        return out_teacher, out_student
+        self.teacher = ResNet50_vd(class_dim=class_dim, **args)
+
+        self.student = MobileNetV3_large_x1_0(class_dim=class_dim, **args)
+
+    def forward(self, input):
+        teacher_label = self.teacher(input)
+
+        student_label = self.student(input)
+
+        return teacher_label, student_label
 
 
 class ResNeXt101_32x16d_wsl_distill_ResNet50_vd():
diff --git a/ppcls/utils/logger.py b/ppcls/utils/logger.py
index 12789c7c8..c6efa17e0 100644
--- a/ppcls/utils/logger.py
+++ b/ppcls/utils/logger.py
@@ -16,9 +16,6 @@ import logging
 import os
 import datetime
 
-from imp import reload
-reload(logging)
-
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s: %(message)s",
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index 1d1e505b9..7fa7daba9 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -26,7 +26,7 @@ import paddle.fluid as fluid
 
 from ppcls.utils import logger
 
-__all__ = ['init_model', 'save_model']
+__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
 
 
 def _mkdir_if_not_exist(path):
@@ -45,71 +45,35 @@
             raise OSError('Failed to mkdir {}'.format(path))
 
 
-def _load_state(path):
-    if os.path.exists(path + '.pdopt'):
-        # XXX another hack to ignore the optimizer state
-        tmp = tempfile.mkdtemp()
-        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
-        shutil.copy(path + '.pdparams', dst + '.pdparams')
-        state = fluid.io.load_program_state(dst)
-        shutil.rmtree(tmp)
-    else:
-        state = fluid.io.load_program_state(path)
-    return state
-
-
-def load_params(exe, prog, path, ignore_params=None):
-    """
-    Load model from the given path.
-    Args:
-        exe (fluid.Executor): The fluid.Executor object.
-        prog (fluid.Program): load weight to which Program object.
-        path (string): URL string or loca model path.
-        ignore_params (list): ignore variable to load when finetuning.
-            It can be specified by finetune_exclude_pretrained_params
-            and the usage can refer to the document
-            docs/advanced_tutorials/TRANSFER_LEARNING.md
-    """
+def load_dygraph_pretrain(
+        model,
+        path=None,
+        load_static_weights=False, ):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
 
+    if load_static_weights:
+        pre_state_dict = fluid.load_program_state(path)
+        param_state_dict = {}
+        model_dict = model.state_dict()
+        for key in model_dict.keys():
+            weight_name = model_dict[key].name
+            print("dyg key: {}, weight_name: {}".format(key, weight_name))
+            if weight_name in pre_state_dict.keys():
+                print('Load weight: {}, shape: {}'.format(
+                    weight_name, pre_state_dict[weight_name].shape))
+                param_state_dict[key] = pre_state_dict[weight_name]
+            else:
+                param_state_dict[key] = model_dict[key]
+        model.set_dict(param_state_dict)
+        return
-    logger.info(
-        logger.coloring('Loading parameters from {}...'.format(path),
-                        'HEADER'))
-
-    ignore_set = set()
-    state = _load_state(path)
-
-    # ignore the parameter which mismatch the shape
-    # between the model and pretrain weight.
-    all_var_shape = {}
-    for block in prog.blocks:
-        for param in block.all_parameters():
-            all_var_shape[param.name] = param.shape
-    ignore_set.update([
-        name for name, shape in all_var_shape.items()
-        if name in state and shape != state[name].shape
-    ])
-
-    if ignore_params:
-        all_var_names = [var.name for var in prog.list_vars()]
-        ignore_list = filter(
-            lambda var: any([re.match(name, var) for name in ignore_params]),
-            all_var_names)
-        ignore_set.update(list(ignore_list))
-
-    if len(ignore_set) > 0:
-        for k in ignore_set:
-            if k in state:
-                logger.warning(
-                    'variable {} is already excluded automatically'.format(k))
-                del state[k]
-
-    fluid.io.set_program_state(prog, state)
+    param_state_dict, optim_state_dict = fluid.load_dygraph(path)
+    model.set_dict(param_state_dict)
+    return
 
 
-def init_model(config, net, optimizer):
+def init_model(config, net, optimizer=None):
     """
     load model from checkpoint or pretrained_model
     """
@@ -128,16 +92,24 @@ def init_model(config, net, optimizer):
         return
 
     pretrained_model = config.get('pretrained_model')
+    load_static_weights = config.get('load_static_weights', False)
+    use_distillation = config.get('use_distillation', False)
     if pretrained_model:
         if not isinstance(pretrained_model, list):
             pretrained_model = [pretrained_model]
-        # TODO: load pretrained_model
-        raise NotImplementedError
-        for pretrain in pretrained_model:
-            load_params(exe, program, pretrain)
-        logger.info(
-            logger.coloring("Finish initing model from {}".format(
-                pretrained_model), "HEADER"))
+        if not isinstance(load_static_weights, list):
+            load_static_weights = [load_static_weights] * len(pretrained_model)
+        for idx, pretrained in enumerate(pretrained_model):
+            load_static = load_static_weights[idx]
+            model = net
+            if use_distillation and not load_static:
+                model = net.teacher
+            load_dygraph_pretrain(
+                model, path=pretrained, load_static_weights=load_static)
+
+        logger.info(
+            logger.coloring("Finish initing model from {}".format(
+                pretrained_model), "HEADER"))
"HEADER")) def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'): diff --git a/tools/infer/infer.py b/tools/infer/infer.py index 95aba7f39..410a90476 100644 --- a/tools/infer/infer.py +++ b/tools/infer/infer.py @@ -18,6 +18,8 @@ import numpy as np import paddle.fluid as fluid from ppcls.modeling import architectures +from ppcls.utils.save_load import load_dygraph_pretrain + def parse_args(): def str2bool(v): @@ -28,9 +30,11 @@ def parse_args(): parser.add_argument("-m", "--model", type=str) parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("--use_gpu", type=str2bool, default=True) + parser.add_argument("--load_static_weights", type=str2bool, default=True) return parser.parse_args() + def create_operators(): size = 224 img_mean = [0.485, 0.456, 0.406] @@ -66,32 +70,32 @@ def main(): args = parse_args() operators = create_operators() # assign the place - gpu_id = fluid.dygraph.parallel.Env().dev_id - place = fluid.CUDAPlace(gpu_id) - - pre_weights_dict = fluid.load_program_state(args.pretrained_model) + if args.use_gpu: + gpu_id = fluid.dygraph.parallel.Env().dev_id + place = fluid.CUDAPlace(gpu_id) + else: + place = fluid.CPUPlace() + with fluid.dygraph.guard(place): net = architectures.__dict__[args.model]() data = preprocess(args.image_file, operators) data = np.expand_dims(data, axis=0) data = fluid.dygraph.to_variable(data) - dy_weights_dict = net.state_dict() - pre_weights_dict_new = {} - for key in dy_weights_dict: - weights_name = dy_weights_dict[key].name - pre_weights_dict_new[key] = pre_weights_dict[weights_name] - net.set_dict(pre_weights_dict_new) + load_dygraph_pretrain(net, args.pretrained_model, + args.load_static_weights) net.eval() outputs = net(data) outputs = fluid.layers.softmax(outputs) outputs = outputs.numpy() - + probs = postprocess(outputs) rank = 1 for idx, prob in probs: - print("top{:d}, class id: {:d}, probability: {:.4f}".format( - rank, idx, prob)) + print("top{:d}, class id: {:d}, probability: {:.4f}".format(rank, idx, + prob)) rank += 1 + return + if __name__ == "__main__": main() diff --git a/tools/program.py b/tools/program.py index 55900b985..aaecbe0e3 100644 --- a/tools/program.py +++ b/tools/program.py @@ -71,6 +71,8 @@ def create_model(architecture, classes_num): """ name = architecture["name"] params = architecture.get("params", {}) + print(name) + print(params) return architectures.__dict__[name](class_dim=classes_num, **params) diff --git a/tools/train.py b/tools/train.py index 976136e35..7d919ba6b 100644 --- a/tools/train.py +++ b/tools/train.py @@ -102,7 +102,7 @@ def main(args): config.model_save_dir, config.ARCHITECTURE["name"]) save_model(net, optimizer, model_path, - "best_model_in_epoch_" + str(epoch_id)) + "best_model") # 3. save the persistable model if epoch_id % config.save_interval == 0: