add dygrapgh load and static load
parent
95ed78e2a6
commit
5a15c16581
|
@ -0,0 +1,70 @@
|
|||
mode: 'train'
|
||||
ARCHITECTURE:
|
||||
name: 'CSPResNet50'
|
||||
pretrained_model: ""
|
||||
model_save_dir: "./output/"
|
||||
classes_num: 1000
|
||||
total_images: 1020
|
||||
save_interval: 1
|
||||
validate: False
|
||||
valid_interval: 1
|
||||
epochs: 1
|
||||
topk: 5
|
||||
image_shape: [3, 224, 224]
|
||||
|
||||
LEARNING_RATE:
|
||||
function: 'Cosine'
|
||||
params:
|
||||
lr: 0.0125
|
||||
|
||||
OPTIMIZER:
|
||||
function: 'Momentum'
|
||||
params:
|
||||
momentum: 0.9
|
||||
regularizer:
|
||||
function: 'L2'
|
||||
factor: 0.00001
|
||||
|
||||
TRAIN:
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
file_list: "./dataset/flowers102/train_list.txt"
|
||||
data_dir: "./dataset/flowers102/"
|
||||
shuffle_seed: 0
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
to_np: False
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
|
||||
VALID:
|
||||
batch_size: 20
|
||||
num_workers: 4
|
||||
file_list: "./dataset/flowers102/val_list.txt"
|
||||
data_dir: "./dataset/flowers102/"
|
||||
shuffle_seed: 0
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
to_np: False
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
|
@ -2,6 +2,7 @@ mode: 'train'
|
|||
ARCHITECTURE:
|
||||
name: 'MobileNetV3_large_x1_0'
|
||||
pretrained_model: "./pretrained/MobileNetV3_large_x1_0_pretrained"
|
||||
load_static_weights: True
|
||||
model_save_dir: "./output/"
|
||||
classes_num: 102
|
||||
total_images: 1020
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
mode: 'train'
|
||||
ARCHITECTURE:
|
||||
name: 'MobileNetV3_large_x1_0'
|
||||
params:
|
||||
lr_mult_list: [0.25, 0.25, 0.5, 0.5, 0.75]
|
||||
pretrained_model: "./pretrained/MobileNetV3_large_x1_0_ssld_pretrained"
|
||||
load_static_weights: True
|
||||
model_save_dir: "./output/"
|
||||
classes_num: 102
|
||||
total_images: 1020
|
||||
save_interval: 1
|
||||
validate: True
|
||||
valid_interval: 1
|
||||
epochs: 20
|
||||
topk: 5
|
||||
image_shape: [3, 224, 224]
|
||||
|
||||
LEARNING_RATE:
|
||||
function: 'Cosine'
|
||||
params:
|
||||
lr: 0.00375
|
||||
|
||||
OPTIMIZER:
|
||||
function: 'Momentum'
|
||||
params:
|
||||
momentum: 0.9
|
||||
regularizer:
|
||||
function: 'L2'
|
||||
factor: 0.000001
|
||||
|
||||
TRAIN:
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
file_list: "./dataset/flowers102/train_list.txt"
|
||||
data_dir: "./dataset/flowers102/"
|
||||
shuffle_seed: 0
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
to_np: False
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
|
||||
VALID:
|
||||
batch_size: 20
|
||||
num_workers: 4
|
||||
file_list: "./dataset/flowers102/val_list.txt"
|
||||
data_dir: "./dataset/flowers102/"
|
||||
shuffle_seed: 0
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
to_np: False
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
|
@ -5,6 +5,9 @@ ARCHITECTURE:
|
|||
pretrained_model:
|
||||
- "./pretrained/flowers102_R50_vd_final/ppcls"
|
||||
- "./pretrained/MobileNetV3_large_x1_0_pretrained/"
|
||||
load_static_weights:
|
||||
- False
|
||||
- True
|
||||
model_save_dir: "./output/"
|
||||
classes_num: 102
|
||||
total_images: 7169
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
mode: 'train'
|
||||
ARCHITECTURE:
|
||||
name: 'ResNet50_vd'
|
||||
|
||||
checkpoints: ""
|
||||
pretrained_model: ""
|
||||
load_static_weights: True
|
||||
model_save_dir: "./output/"
|
||||
classes_num: 102
|
||||
total_images: 1020
|
||||
|
|
|
@ -4,6 +4,7 @@ ARCHITECTURE:
|
|||
params:
|
||||
lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
|
||||
pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
|
||||
load_static_weights: True
|
||||
model_save_dir: "./output/"
|
||||
classes_num: 102
|
||||
total_images: 1020
|
||||
|
|
|
@ -4,6 +4,7 @@ ARCHITECTURE:
|
|||
params:
|
||||
lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
|
||||
pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
|
||||
load_static_weights: True
|
||||
model_save_dir: "./output/"
|
||||
classes_num: 102
|
||||
total_images: 1020
|
||||
|
|
|
@ -28,3 +28,5 @@ from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75
|
|||
from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
|
||||
from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
|
||||
from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
|
||||
|
||||
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0
|
|
@ -1,16 +1,16 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -32,17 +32,20 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
class ResNet50_vd_distill_MobileNetV3_large_x1_0():
|
||||
def net(self, input, class_dim=1000):
|
||||
# student
|
||||
student = MobileNetV3_large_x1_0()
|
||||
out_student = student.net(input, class_dim=class_dim)
|
||||
# teacher
|
||||
teacher = ResNet50_vd()
|
||||
out_teacher = teacher.net(input, class_dim=class_dim)
|
||||
out_teacher.stop_gradient = True
|
||||
class ResNet50_vd_distill_MobileNetV3_large_x1_0(fluid.dygraph.Layer):
|
||||
def __init__(self, class_dim=1000, **args):
|
||||
super(ResNet50_vd_distill_MobileNetV3_large_x1_0, self).__init__()
|
||||
|
||||
return out_teacher, out_student
|
||||
self.teacher = ResNet50_vd(class_dim=class_dim, **args)
|
||||
|
||||
self.student = MobileNetV3_large_x1_0(class_dim=class_dim, **args)
|
||||
|
||||
def forward(self, input):
|
||||
teacher_label = self.teacher(input)
|
||||
|
||||
student_label = self.student(input)
|
||||
|
||||
return teacher_label, student_label
|
||||
|
||||
|
||||
class ResNeXt101_32x16d_wsl_distill_ResNet50_vd():
|
||||
|
|
|
@ -16,9 +16,6 @@ import logging
|
|||
import os
|
||||
import datetime
|
||||
|
||||
from imp import reload
|
||||
reload(logging)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s: %(message)s",
|
||||
|
|
|
@ -26,7 +26,7 @@ import paddle.fluid as fluid
|
|||
|
||||
from ppcls.utils import logger
|
||||
|
||||
__all__ = ['init_model', 'save_model']
|
||||
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
|
||||
|
||||
|
||||
def _mkdir_if_not_exist(path):
|
||||
|
@ -45,71 +45,35 @@ def _mkdir_if_not_exist(path):
|
|||
raise OSError('Failed to mkdir {}'.format(path))
|
||||
|
||||
|
||||
def _load_state(path):
|
||||
if os.path.exists(path + '.pdopt'):
|
||||
# XXX another hack to ignore the optimizer state
|
||||
tmp = tempfile.mkdtemp()
|
||||
dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
|
||||
shutil.copy(path + '.pdparams', dst + '.pdparams')
|
||||
state = fluid.io.load_program_state(dst)
|
||||
shutil.rmtree(tmp)
|
||||
else:
|
||||
state = fluid.io.load_program_state(path)
|
||||
return state
|
||||
|
||||
|
||||
def load_params(exe, prog, path, ignore_params=None):
|
||||
"""
|
||||
Load model from the given path.
|
||||
Args:
|
||||
exe (fluid.Executor): The fluid.Executor object.
|
||||
prog (fluid.Program): load weight to which Program object.
|
||||
path (string): URL string or loca model path.
|
||||
ignore_params (list): ignore variable to load when finetuning.
|
||||
It can be specified by finetune_exclude_pretrained_params
|
||||
and the usage can refer to the document
|
||||
docs/advanced_tutorials/TRANSFER_LEARNING.md
|
||||
"""
|
||||
def load_dygraph_pretrain(
|
||||
model,
|
||||
path=None,
|
||||
load_static_weights=False, ):
|
||||
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
|
||||
raise ValueError("Model pretrain path {} does not "
|
||||
"exists.".format(path))
|
||||
if load_static_weights:
|
||||
pre_state_dict = fluid.load_program_state(path)
|
||||
param_state_dict = {}
|
||||
model_dict = model.state_dict()
|
||||
for key in model_dict.keys():
|
||||
weight_name = model_dict[key].name
|
||||
print("dyg key: {}, weight_name: {}".format(key, weight_name))
|
||||
if weight_name in pre_state_dict.keys():
|
||||
print('Load weight: {}, shape: {}'.format(
|
||||
weight_name, pre_state_dict[weight_name].shape))
|
||||
param_state_dict[key] = pre_state_dict[weight_name]
|
||||
else:
|
||||
param_state_dict[key] = model_dict[key]
|
||||
model.set_dict(param_state_dict)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
logger.coloring('Loading parameters from {}...'.format(path),
|
||||
'HEADER'))
|
||||
|
||||
ignore_set = set()
|
||||
state = _load_state(path)
|
||||
|
||||
# ignore the parameter which mismatch the shape
|
||||
# between the model and pretrain weight.
|
||||
all_var_shape = {}
|
||||
for block in prog.blocks:
|
||||
for param in block.all_parameters():
|
||||
all_var_shape[param.name] = param.shape
|
||||
ignore_set.update([
|
||||
name for name, shape in all_var_shape.items()
|
||||
if name in state and shape != state[name].shape
|
||||
])
|
||||
|
||||
if ignore_params:
|
||||
all_var_names = [var.name for var in prog.list_vars()]
|
||||
ignore_list = filter(
|
||||
lambda var: any([re.match(name, var) for name in ignore_params]),
|
||||
all_var_names)
|
||||
ignore_set.update(list(ignore_list))
|
||||
|
||||
if len(ignore_set) > 0:
|
||||
for k in ignore_set:
|
||||
if k in state:
|
||||
logger.warning(
|
||||
'variable {} is already excluded automatically'.format(k))
|
||||
del state[k]
|
||||
|
||||
fluid.io.set_program_state(prog, state)
|
||||
param_state_dict, optim_state_dict = fluid.load_dygraph(path)
|
||||
model.set_dict(param_state_dict)
|
||||
return
|
||||
|
||||
|
||||
def init_model(config, net, optimizer):
|
||||
def init_model(config, net, optimizer=None):
|
||||
"""
|
||||
load model from checkpoint or pretrained_model
|
||||
"""
|
||||
|
@ -128,16 +92,24 @@ def init_model(config, net, optimizer):
|
|||
return
|
||||
|
||||
pretrained_model = config.get('pretrained_model')
|
||||
load_static_weights = config.get('load_static_weights', False)
|
||||
use_distillation = config.get('use_distillation', False)
|
||||
if pretrained_model:
|
||||
if not isinstance(pretrained_model, list):
|
||||
pretrained_model = [pretrained_model]
|
||||
# TODO: load pretrained_model
|
||||
raise NotImplementedError
|
||||
for pretrain in pretrained_model:
|
||||
load_params(exe, program, pretrain)
|
||||
logger.info(
|
||||
logger.coloring("Finish initing model from {}".format(
|
||||
pretrained_model), "HEADER"))
|
||||
if not isinstance(load_static_weights, list):
|
||||
load_static_weights = [load_static_weights] * len(pretrained_model)
|
||||
for idx, pretrained in enumerate(pretrained_model):
|
||||
load_static = load_static_weights[idx]
|
||||
model = net
|
||||
if use_distillation and not load_static:
|
||||
model = net.teacher
|
||||
load_dygraph_pretrain(
|
||||
model, path=pretrained, load_static_weights=load_static)
|
||||
|
||||
logger.info(
|
||||
logger.coloring("Finish initing model from {}".format(
|
||||
pretrained_model), "HEADER"))
|
||||
|
||||
|
||||
def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
|
||||
|
|
|
@ -18,6 +18,8 @@ import numpy as np
|
|||
|
||||
import paddle.fluid as fluid
|
||||
from ppcls.modeling import architectures
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain
|
||||
|
||||
|
||||
def parse_args():
|
||||
def str2bool(v):
|
||||
|
@ -28,9 +30,11 @@ def parse_args():
|
|||
parser.add_argument("-m", "--model", type=str)
|
||||
parser.add_argument("-p", "--pretrained_model", type=str)
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--load_static_weights", type=str2bool, default=True)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_operators():
|
||||
size = 224
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
|
@ -66,32 +70,32 @@ def main():
|
|||
args = parse_args()
|
||||
operators = create_operators()
|
||||
# assign the place
|
||||
gpu_id = fluid.dygraph.parallel.Env().dev_id
|
||||
place = fluid.CUDAPlace(gpu_id)
|
||||
|
||||
pre_weights_dict = fluid.load_program_state(args.pretrained_model)
|
||||
if args.use_gpu:
|
||||
gpu_id = fluid.dygraph.parallel.Env().dev_id
|
||||
place = fluid.CUDAPlace(gpu_id)
|
||||
else:
|
||||
place = fluid.CPUPlace()
|
||||
|
||||
with fluid.dygraph.guard(place):
|
||||
net = architectures.__dict__[args.model]()
|
||||
data = preprocess(args.image_file, operators)
|
||||
data = np.expand_dims(data, axis=0)
|
||||
data = fluid.dygraph.to_variable(data)
|
||||
dy_weights_dict = net.state_dict()
|
||||
pre_weights_dict_new = {}
|
||||
for key in dy_weights_dict:
|
||||
weights_name = dy_weights_dict[key].name
|
||||
pre_weights_dict_new[key] = pre_weights_dict[weights_name]
|
||||
net.set_dict(pre_weights_dict_new)
|
||||
load_dygraph_pretrain(net, args.pretrained_model,
|
||||
args.load_static_weights)
|
||||
net.eval()
|
||||
outputs = net(data)
|
||||
outputs = fluid.layers.softmax(outputs)
|
||||
outputs = outputs.numpy()
|
||||
|
||||
|
||||
probs = postprocess(outputs)
|
||||
rank = 1
|
||||
for idx, prob in probs:
|
||||
print("top{:d}, class id: {:d}, probability: {:.4f}".format(
|
||||
rank, idx, prob))
|
||||
print("top{:d}, class id: {:d}, probability: {:.4f}".format(rank, idx,
|
||||
prob))
|
||||
rank += 1
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -71,6 +71,8 @@ def create_model(architecture, classes_num):
|
|||
"""
|
||||
name = architecture["name"]
|
||||
params = architecture.get("params", {})
|
||||
print(name)
|
||||
print(params)
|
||||
return architectures.__dict__[name](class_dim=classes_num, **params)
|
||||
|
||||
|
||||
|
|
|
@ -102,7 +102,7 @@ def main(args):
|
|||
config.model_save_dir,
|
||||
config.ARCHITECTURE["name"])
|
||||
save_model(net, optimizer, model_path,
|
||||
"best_model_in_epoch_" + str(epoch_id))
|
||||
"best_model")
|
||||
|
||||
# 3. save the persistable model
|
||||
if epoch_id % config.save_interval == 0:
|
||||
|
|
Loading…
Reference in New Issue