parent 83056d44d5
commit dd79f81fd7
@@ -12,8 +12,54 @@
#See the License for the specific language governing permissions and
#limitations under the License.

import copy
import importlib

import paddle.nn as nn

from . import backbone

from .backbone import *
from ppcls.arch.loss_metrics.loss import *
from .utils import *


def build_model(config):
    config = copy.deepcopy(config)
    model_type = config.pop("name")
    mod = importlib.import_module(__name__)
    arch = getattr(mod, model_type)(**config)
    return arch


class RecModel(nn.Layer):
    def __init__(self, **config):
        super().__init__()
        backbone_config = config["Backbone"]
        backbone_name = backbone_config.pop("name")
        self.backbone = getattr(backbone_name)(**backbone_config)
        if "backbone_stop_layer" in config:
            backbone_stop_layer = config["backbone_stop_layer"]
            self.backbone.stop_layer(backbone_stop_layer)

        if "Neck" in config:
            neck_config = config["Neck"]
            neck_name = neck_config.pop("name")
            self.neck = getattr(neck_name)(**neck_config)
        else:
            self.neck = None

        if "Head" in config:
            head_config = config["Head"]
            head_name = head_config.pop("name")
            self.head = getattr(head_name)(**head_config)
        else:
            self.head = None

    def forward(self, x):
        y = self.backbone(x)
        if self.neck is not None:
            y = self.neck(y)
        if self.head is not None:
            y = self.head(y)
        return y
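Note: `getattr(backbone_name)`, `getattr(neck_name)` and `getattr(head_name)` are called with a single argument, which raises a TypeError at runtime; `getattr` needs the object to look the name up on. A minimal sketch of the intended lookup, assuming the backbone classes are reachable through the `backbone` package this file already imports (the neck/head resolution below is a hypothetical placeholder, since no matching module is imported here):

    # sketch only: resolve the class from an explicit module instead of a bare name
    self.backbone = getattr(backbone, backbone_name)(**backbone_config)
    # necks/heads would need an analogous module (or registry) to resolve their names from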
@@ -17,11 +17,11 @@ from __future__ import absolute_import, division, print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn import MaxPool2D

from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain

__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]
@@ -149,7 +149,12 @@ class ConvBlock(TheseusLayer):


class VGGNet(TheseusLayer):
-    def __init__(self, config, stop_grad_layers=0, class_num=1000):
+    def __init__(self,
+                 config,
+                 stop_grad_layers=0,
+                 class_num=1000,
+                 pretrained=False,
+                 **args):
        super().__init__()

        self.stop_grad_layers = stop_grad_layers
@@ -176,6 +181,9 @@ class VGGNet(TheseusLayer):
        self._fc2 = Linear(4096, 4096)
        self._out = Linear(4096, class_num)

+        if pretrained is not None:
+            load_dygraph_pretrain(self, pretrained)
+
    def forward(self, inputs):
        x = self._conv_block_1(inputs)
        x = self._conv_block_2(x)
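Note: with the new default `pretrained=False`, the guard `pretrained is not None` is always true, so `load_dygraph_pretrain(self, False)` is attempted even when no weights were requested. A minimal sketch of a guard that matches the default:

    if pretrained:
        load_dygraph_pretrain(self, pretrained)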
@@ -0,0 +1,88 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import sys
import copy

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


# TODO: fix the format
class CELoss(nn.Layer):
    """
    """

    def __init__(self, name="loss", epsilon=None):
        super().__init__()
        self.name = name
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        self.epsilon = epsilon

    def _labelsmoothing(self, target, class_num):
        if target.shape[-1] != class_num:
            one_hot_target = F.one_hot(target, class_num)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
        return soft_target

    def forward(self, logits, label, mode="train"):
        loss_dict = {}
        if self.epsilon is not None:
            class_num = logits.shape[-1]
            label = self._labelsmoothing(label, class_num)
            x = -F.log_softmax(x, axis=-1)
            loss = paddle.sum(x * label, axis=-1)
        else:
            if label.shape[-1] == logits.shape[-1]:
                label = F.softmax(label, axis=-1)
                soft_label = True
            else:
                soft_label = False
            loss = F.cross_entropy(logits, label=label, soft_label=soft_label)
        loss_dict[self.name] = paddle.mean(loss)
        return loss_dict


# TODO: fix the format
class Topk(nn.Layer):
    def __init__(self, topk=[1, 5]):
        super().__init__()
        assert isinstance(topk, (int, list))
        if isinstance(topk, int):
            topk = [topk]
        self.topk = topk

    def forward(self, x, label):
        metric_dict = dict()
        for k in self.topk:
            metric_dict["top{}".format(k)] = paddle.metric.accuracy(
                x, label, k=k)
        return metric_dict


# TODO: fix the format
def build_loss(config):
    loss_func = CELoss()
    return loss_func


# TODO: fix the format
def build_metrics(config):
    metrics_func = Topk()
    return metrics_func
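Note: in the label-smoothing branch of `CELoss.forward`, `x = -F.log_softmax(x, axis=-1)` reads `x` before it has been assigned, which raises UnboundLocalError; the log-softmax is meant to be taken over `logits`. A minimal corrected sketch of that branch:

    x = -F.log_softmax(logits, axis=-1)
    loss = paddle.sum(x * label, axis=-1)

Used this way, `CELoss(epsilon=0.1)(logits, labels)` returns a dict such as `{"loss": <mean tensor>}`, which is what the trainer below expects.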
@@ -0,0 +1,100 @@
# global configs
Global:
  pretrained_model: ""
  output_dir: "./output/"
  device: "gpu"
  class_num: 1000
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 90
  print_batch_step: 10
  use_visualdl: False
  image_shape: [3, 224, 224]
  infer_imgs:

# model architecture
Arch:
  name: "ResNet50"

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
  Eval:
    - CELoss:
        weight: 1.0


Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Piecewise
    learning_rate: 0.1
    decay_epochs: [30, 60, 90]
    values: [0.1, 0.01, 0.001, 0.0001]
  regularizer:
    name: 'L2'
    coeff: 0.0001


# data loader for train and eval
DataLoader:
  Train:
    # Dataset:
    # Sampler:
    # Loader:
    batch_size: 256
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
      - DecodeImage:
          to_rgb: True
          channel_first: False
      - RandCropImage:
          size: 224
      - RandFlipImage:
          flip_code: 1
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: ''
      - ToCHWImage:
  Eval:
    # TOTO: modify to the latest trainer
    # Dataset:
    # Sampler:
    # Loader:
    batch_size: 128
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
      - DecodeImage:
          to_rgb: True
          channel_first: False
      - ResizeImage:
          resize_short: 256
      - CropImage:
          size: 224
      - NormalizeImage:
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: ''
      - ToCHWImage:

Metric:
  Train:
    - Topk:
        k: [1, 5]
  Eval:
    - Topk:
        k: [1, 5]
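This YAML is the layout the new `Trainer` expects: `Global` for run settings, `Arch` for `build_model`, `Loss`/`Metric` split into `Train`/`Eval`, `Optimizer` with a nested `lr`, and `DataLoader` with `Train`/`Eval` sections. A minimal sketch of wiring the first two pieces together by hand, assuming the config has been saved to a file (the file name below is illustrative only):

    from ppcls.utils import config
    from ppcls.arch import build_model

    cfg = config.get_config("ResNet50.yaml", overrides=[], show=False)  # hypothetical file name
    model = build_model(cfg["Arch"])  # pops "name" and instantiates ResNet50 with the remaining keys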
@@ -12,4 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .reader import Reader
import copy
import paddle
import os
from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader

from ppcls.utils import logger


# TODO: fix the format
def build_dataloader(config, mode, device, seed=None):
    from . import reader
    from .reader import Reader
    dataloader = Reader(config, mode=mode, places=device)()
    return dataloader
@@ -250,13 +250,14 @@ class Reader:

    def __init__(self, config, mode='train', places=None):
        try:
-            self.params = config[mode.upper()]
+            self.params = config[mode.capitalize()]
        except KeyError:
            raise ModeException(mode=mode)

        use_mix = config.get('use_mix')
        self.params['mode'] = mode
        self.shuffle = mode == "train"
        self.is_train = mode == "train"

        self.collate_fn = None
        self.batch_ops = []
@@ -298,7 +299,7 @@ class Reader:
                shuffle=False,
                num_workers=self.params["num_workers"])
        else:
-            is_train = self.params['mode'] == "train"
+            is_train = self.is_train
            batch_sampler = DistributedBatchSampler(
                dataset,
                batch_size=batch_size,
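The switch to `mode.capitalize()` matches the new config layout, where the per-split sections are spelled `Train`/`Eval` rather than `TRAIN`/`VALID`. A tiny sketch of the lookup under that assumption:

    config = {"Train": {...}, "Eval": {...}}   # shape used by the new yaml
    mode = "train"
    params = config[mode.capitalize()]         # -> config["Train"]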
@@ -0,0 +1,277 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist

from ppcls.utils import config
from ppcls.utils.check import check_gpu

from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.arch.loss_metrics import build_loss
from ppcls.arch.loss_metrics import build_metrics
from ppcls.optimizer import build_optimizer

from ppcls.utils import save_load


class Trainer(object):
    def __init__(self, mode="train"):
        args = config.parse_args()
        self.config = config.get_config(
            args.config, overrides=args.override, show=True)
        self.mode = mode
        self.output_dir = self.config['Global']['output_dir']
        # set device
        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        # set dist
        self.config["Global"][
            "distributed"] = paddle.distributed.get_world_size() != 1
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
        self.model = build_model(self.config["Arch"])

        if self.config["Global"]["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
        if self.config['Global']['use_visualdl']:
            from visualdl import LogWriter
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

    def _build_metric_info(self, metric_config, mode="train"):
        """
        _build_metric_info: build metrics according to current mode
        Return:
            metric: dict of the metrics info
        """
        metric = None
        mode = mode.capitalize()
        if mode in metric_config and metric_config[mode] is not None:
            metric = build_metrics(metric_config[mode])
        return metric

    def _build_loss_info(self, loss_config, mode="train"):
        """
        _build_loss_info: build loss according to current mode
        Return:
            loss_dict: dict of the loss info
        """
        loss = None
        mode = mode.capitalize()
        if mode in loss_config and loss_config[mode] is not None:
            loss = build_loss(loss_config[mode])
        return loss

    def train(self):
        # build train loss and metric info
        loss_func = self._build_loss_info(self.config["Loss"])

        metric_func = self._build_metric_info(self.config["Metric"])

        train_dataloader = build_dataloader(self.config["DataLoader"], "train",
                                            self.device)

        step_each_epoch = len(train_dataloader)

        optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
                                            self.config["Global"]["epochs"],
                                            step_each_epoch,
                                            self.model.parameters())

        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]

        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # key:
        # val: metrics list word
        output_info = dict()
        # global iter counter
        global_step = 0

        for epoch_id in range(1, self.config["Global"]["epochs"] + 1):
            self.model.train()
            for iter_id, batch in enumerate(train_dataloader()):
                batch_size = batch[0].shape[0]
                batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
                                            .reshape([-1, 1]))
                global_step += 1
                # image input
                out = self.model(batch[0])
                # calc loss
                loss_dict = loss_func(out, batch[-1])
                for key in loss_dict:
                    if not key in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
                # calc metric
                if metric_func is not None:
                    metric_dict = metric_func(out, batch[-1])
                    for key in metric_dict:
                        if not key in output_info:
                            output_info[key] = AverageMeter(key, '7.5f')
                        output_info[key].update(metric_dict[key].numpy()[0],
                                                batch_size)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    metric_msg = ", ".join([
                        "{}: {:.5f}".format(key, output_info[key].avg)
                        for key in output_info
                    ])
                    logger.info("[Train][Epoch {}][Iter: {}/{}]{}, {}".format(
                        epoch_id, iter_id,
                        len(train_dataloader), lr_msg, metric_msg))

                # step opt and lr
                loss_dict["loss"].backward()
                optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, output_info[key].avg)
                for key in output_info
            ])
            logger.info("[Train][Epoch {}][Avg]{}".format(epoch_id,
                                                          metric_msg))
            output_info.clear()

            # eval model and save model if possible
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_during_train"] == 0:
                acc = self.eval(epoch_id)
                if acc >= best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        optimizer,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model")

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    optimizer,
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="ppcls_epoch_{}".format(epoch_id))

    def build_avg_metrics(self, info_dict):
        return {key: AverageMeter(key, '7.5f') for key in info_dict}

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        output_info = dict()

        eval_dataloader = build_dataloader(self.config["DataLoader"], "eval",
                                           self.device)

        self.model.eval()
        print_batch_step = self.config["Global"]["print_batch_step"]

        # build train loss and metric info
        loss_func = self._build_loss_info(self.config["Loss"], "eval")
        metric_func = self._build_metric_info(self.config["Metric"], "eval")
        metric_key = None

        for iter_id, batch in enumerate(eval_dataloader()):
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
            # image input
            out = self.model(batch[0])
            # calc build
            if loss_func is not None:
                loss_dict = loss_func(out, batch[-1])
                for key in loss_dict:
                    if not key in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
            # calc metric
            if metric_func is not None:
                metric_dict = metric_func(out, batch[-1])
                if paddle.distributed.get_world_size() > 1:
                    for key in metric_dict:
                        paddle.distributed.all_reduce(
                            metric_dict[key],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[key] = metric_dict[
                            key] / paddle.distributed.get_world_size()
                for key in metric_dict:
                    if metric_key is None:
                        metric_key = key
                    if not key in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')

                    output_info[key].update(metric_dict[key].numpy()[0],
                                            batch_size)

            if iter_id % print_batch_step == 0:
                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}".format(
                    epoch_id, iter_id, len(eval_dataloader), metric_msg))

        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        self.model.train()
        # do not try to save best model
        if metric_func is None:
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg


def main():
    trainer = Trainer()
    trainer.train()


if __name__ == "__main__":
    main()
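Note: the eval trigger uses `eval_during_train` (a boolean) as the modulus, so `epoch_id % True == 0` is always true and the `eval_interval` key defined in the config above is never consulted. A minimal sketch of the presumably intended check:

    if self.config["Global"]["eval_during_train"] and \
            epoch_id % self.config["Global"]["eval_interval"] == 0:
        acc = self.eval(epoch_id)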
@@ -12,8 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from . import optimizer
-from . import learning_rate
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-from .optimizer import OptimizerBuilder
-from .learning_rate import LearningRateBuilder
import copy
import paddle

from ppcls.utils import logger

from . import optimizer

__all__ = ['build_optimizer']


def build_lr_scheduler(lr_config, epochs, step_each_epoch):
    from . import learning_rate
    lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
    if 'name' in lr_config:
        lr_name = lr_config.pop('name')
        lr = getattr(learning_rate, lr_name)(**lr_config)()
    else:
        lr = lr_config['learning_rate']
    return lr


def build_optimizer(config, epochs, step_each_epoch, parameters):
    config = copy.deepcopy(config)
    # step1 build lr
    lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
    logger.info("build lr ({}) success..".format(lr))
    # step2 build regularization
    if 'regularizer' in config and config['regularizer'] is not None:
        reg_config = config.pop('regularizer')
        reg_name = reg_config.pop('name') + 'Decay'
        reg = getattr(paddle.regularizer, reg_name)(**reg_config)
    else:
        reg = None
    logger.info("build regularizer ({}) success..".format(reg))
    # step3 build optimizer
    optim_name = config.pop('name')
    if 'clip_norm' in config:
        clip_norm = config.pop('clip_norm')
        grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
    else:
        grad_clip = None
    optim = getattr(optimizer, optim_name)(learning_rate=lr,
                                           weight_decay=reg,
                                           grad_clip=grad_clip,
                                           parameter_list=parameters,
                                           **config)()
    logger.info("build optimizer ({}) success..".format(optim))
    return optim, lr
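`build_optimizer` consumes the YAML's `Optimizer` section: `lr` becomes a scheduler via the local `learning_rate` module, `regularizer.name` is suffixed with `Decay` and looked up in `paddle.regularizer`, and the remaining keys go to the wrapper class in the local `optimizer` module. A minimal sketch of a call matching the config above, assuming `model` and `train_dataloader` have already been built:

    optim_cfg = {
        "name": "Momentum",
        "momentum": 0.9,
        "lr": {"name": "Piecewise", "learning_rate": 0.1,
               "decay_epochs": [30, 60, 90],
               "values": [0.1, 0.01, 0.001, 0.0001]},
        "regularizer": {"name": "L2", "coeff": 0.0001},
    }
    optimizer, lr_sch = build_optimizer(optim_cfg, epochs=90,
                                        step_each_epoch=len(train_dataloader),
                                        parameters=model.parameters())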
@@ -11,149 +11,173 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import math

from paddle.optimizer.lr import LinearWarmup
from paddle.optimizer.lr import PiecewiseDecay
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.optimizer.lr import ExponentialDecay

__all__ = ['LearningRateBuilder']
from __future__ import unicode_literals
from paddle.optimizer import lr


class Cosine(CosineAnnealingDecay):
class Linear(object):
    """
    Linear learning rate decay
    Args:
        lr (float): The initial learning rate. It is a python float number.
        epochs(int): The decay step size. It determines the decay cycle.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial. Default: 1.0.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
    """

    def __init__(self,
                 learning_rate,
                 epochs,
                 step_each_epoch,
                 end_lr=0.0,
                 power=1.0,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(Linear, self).__init__()
        self.learning_rate = learning_rate
        self.epochs = epochs * step_each_epoch
        self.end_lr = end_lr
        self.power = power
        self.last_epoch = last_epoch
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)

    def __call__(self):
        learning_rate = lr.PolynomialDecay(
            learning_rate=self.learning_rate,
            decay_steps=self.epochs,
            end_lr=self.end_lr,
            power=self.power,
            last_epoch=self.last_epoch)
        if self.warmup_epoch > 0:
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_epoch,
                start_lr=0.0,
                end_lr=self.learning_rate,
                last_epoch=self.last_epoch)
        return learning_rate


class Cosine(object):
    """
    Cosine learning rate decay
    lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1)

    Args:
        lr(float): initial learning rate
        step_each_epoch(int): steps each epoch
        epochs(int): total training epochs
    """

    def __init__(self, lr, step_each_epoch, epochs, **kwargs):
        super(Cosine, self).__init__(
            learning_rate=lr,
            T_max=step_each_epoch * epochs, )

        self.update_specified = False


class Piecewise(PiecewiseDecay):
    """
    Piecewise learning rate decay

    Args:
        lr(float): initial learning rate
        step_each_epoch(int): steps each epoch
        decay_epochs(list): piecewise decay epochs
        gamma(float): decay factor
    """

    def __init__(self, lr, step_each_epoch, decay_epochs, gamma=0.1, **kwargs):
        boundaries = [step_each_epoch * e for e in decay_epochs]
        lr_values = [lr * (gamma**i) for i in range(len(boundaries) + 1)]
        super(Piecewise, self).__init__(
            boundaries=boundaries, values=lr_values)

        self.update_specified = False


class CosineWarmup(LinearWarmup):
    """
    Cosine learning rate decay with warmup
    [0, warmup_epoch): linear warmup
    [warmup_epoch, epochs): cosine decay

    Args:
        lr(float): initial learning rate
        step_each_epoch(int): steps each epoch
        epochs(int): total training epochs
        warmup_epoch(int): epoch num of warmup
    """

    def __init__(self, lr, step_each_epoch, epochs, warmup_epoch=5, **kwargs):
        assert epochs > warmup_epoch, "total epoch({}) should be larger than warmup_epoch({}) in CosineWarmup.".format(
            epochs, warmup_epoch)
        warmup_step = warmup_epoch * step_each_epoch
        start_lr = 0.0
        end_lr = lr
        lr_sch = Cosine(lr, step_each_epoch, epochs - warmup_epoch)

        super(CosineWarmup, self).__init__(
            learning_rate=lr_sch,
            warmup_steps=warmup_step,
            start_lr=start_lr,
            end_lr=end_lr)

        self.update_specified = False


class ExponentialWarmup(LinearWarmup):
    """
    Exponential learning rate decay with warmup
    [0, warmup_epoch): linear warmup
    [warmup_epoch, epochs): Exponential decay

    Args:
        lr(float): initial learning rate
        step_each_epoch(int): steps each epoch
        decay_epochs(float): decay epochs
        decay_rate(float): decay rate
        warmup_epoch(int): epoch num of warmup
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
    """

    def __init__(self,
                 lr,
                 learning_rate,
                 step_each_epoch,
                 decay_epochs=2.4,
                 decay_rate=0.97,
                 warmup_epoch=5,
                 epochs,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        warmup_step = warmup_epoch * step_each_epoch
        start_lr = 0.0
        end_lr = lr
        lr_sch = ExponentialDecay(lr, decay_rate)

        super(ExponentialWarmup, self).__init__(
            learning_rate=lr_sch,
            warmup_steps=warmup_step,
            start_lr=start_lr,
            end_lr=end_lr)

        # NOTE: hac method to update exponential lr scheduler
        self.update_specified = True
        self.update_start_step = warmup_step
        self.update_step_interval = int(decay_epochs * step_each_epoch)
        self.step_each_epoch = step_each_epoch


class LearningRateBuilder():
    """
    Build learning rate variable
    https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/layers_cn.html

    Args:
        function(str): class name of learning rate
        params(dict): parameters used for init the class
    """

    def __init__(self,
                 function='Linear',
                 params={'lr': 0.1,
                         'steps': 100,
                         'end_lr': 0.0}):
        self.function = function
        self.params = params
        super(Cosine, self).__init__()
        self.learning_rate = learning_rate
        self.T_max = step_each_epoch * epochs
        self.last_epoch = last_epoch
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)

    def __call__(self):
        mod = sys.modules[__name__]
        lr = getattr(mod, self.function)(**self.params)
        return lr
        learning_rate = lr.CosineAnnealingDecay(
            learning_rate=self.learning_rate,
            T_max=self.T_max,
            last_epoch=self.last_epoch)
        if self.warmup_epoch > 0:
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_epoch,
                start_lr=0.0,
                end_lr=self.learning_rate,
                last_epoch=self.last_epoch)
        return learning_rate


class Step(object):
    """
    Piecewise learning rate decay
    Args:
        step_each_epoch(int): steps each epoch
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): the interval to update.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
    """

    def __init__(self,
                 learning_rate,
                 step_size,
                 step_each_epoch,
                 gamma,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(Step, self).__init__()
        self.step_size = step_each_epoch * step_size
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.last_epoch = last_epoch
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)

    def __call__(self):
        learning_rate = lr.StepDecay(
            learning_rate=self.learning_rate,
            step_size=self.step_size,
            gamma=self.gamma,
            last_epoch=self.last_epoch)
        if self.warmup_epoch > 0:
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_epoch,
                start_lr=0.0,
                end_lr=self.learning_rate,
                last_epoch=self.last_epoch)
        return learning_rate


class Piecewise(object):
    """
    Piecewise learning rate decay
    Args:
        boundaries(list): A list of steps numbers. The type of element in the list is python int.
        values(list): A list of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
    """

    def __init__(self,
                 step_each_epoch,
                 decay_epochs,
                 values,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(Piecewise, self).__init__()
        self.boundaries = [step_each_epoch * e for e in decay_epochs]
        self.values = values
        self.last_epoch = last_epoch
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)

    def __call__(self):
        learning_rate = lr.PiecewiseDecay(
            boundaries=self.boundaries,
            values=self.values,
            last_epoch=self.last_epoch)
        if self.warmup_epoch > 0:
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_epoch,
                start_lr=0.0,
                end_lr=self.values[0],
                last_epoch=self.last_epoch)
        return learning_rate
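The new scheduler classes (`Linear`, `Cosine`, `Step`, `Piecewise`) are plain wrappers: the constructor only records settings, and calling the instance builds the actual `paddle.optimizer.lr` scheduler, optionally wrapped in `LinearWarmup`. A minimal sketch matching the Piecewise entry in the YAML above (the step count is illustrative):

    sched = Piecewise(step_each_epoch=5000,
                      decay_epochs=[30, 60, 90],
                      values=[0.1, 0.01, 0.001, 0.0001],
                      learning_rate=0.1,   # extra config keys are absorbed by **kwargs
                      epochs=90)()
    # sched is a paddle.optimizer.lr.PiecewiseDecay instance, ready to hand to the optimizer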
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import copy
import argparse
import yaml

from ppcls.utils import check
from ppcls.utils import logger

from ppcls.utils import check
__all__ = ['get_config']
@@ -31,6 +31,9 @@ class AttrDict(dict):
        else:
            self[key] = value

+    def __deepcopy__(self, content):
+        return copy.deepcopy(dict(self))
+

def create_attr_dict(yaml_config):
    from ast import literal_eval
@@ -76,7 +79,6 @@ def print_dict(d, delimiter=0):
            logger.info("{}{} : {}".format(delimiter * " ",
                                           logger.coloring(k, "HEADER"),
                                           logger.coloring(v, "OKGREEN")))

        if k.isupper():
            logger.info(placeholder)
@@ -84,7 +86,6 @@ def print_dict(d, delimiter=0):
def print_config(config):
    """
    visualize configs

    Arguments:
        config: configs
    """
@@ -97,21 +98,15 @@ def check_config(config):
    Check config
    """
    check.check_version()

    use_gpu = config.get('use_gpu', True)
    if use_gpu:
        check.check_gpu()

    architecture = config.get('ARCHITECTURE')
    check.check_architecture(architecture)
    check.check_model_with_running_mode(architecture)

    #check.check_architecture(architecture)
    use_mix = config.get('use_mix', False)
    check.check_mix(architecture, use_mix)

    classes_num = config.get('classes_num')
    check.check_classes_num(classes_num)

    mode = config.get('mode', 'train')
    if mode.lower() == 'train':
        check.check_function_params(config, 'LEARNING_RATE')
@@ -121,7 +116,6 @@ def check_config(config):
def override(dl, ks, v):
    """
    Recursively replace dict of list

    Args:
        dl(dict or list): dict or list to be replaced
        ks(list): list of keys
@@ -147,19 +141,15 @@ def override(dl, ks, v):
        if len(ks) == 1:
            # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
            if not ks[0] in dl:
-                logger.warning('A new filed ({}) detected!'.format(ks[0]))
+                logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
            dl[ks[0]] = str2num(v)
        else:
            if not ks[0] in dl:
                logger.warning('A new filed ({}) detected!'.format(ks[0]))
            dl[ks[0]] = {}
            override(dl[ks[0]], ks[1:], v)


def override_config(config, options=None):
    """
    Recursively override the config

    Args:
        config(dict): dict to be replaced
        options(list): list of pairs(key0.key1.idx.key2=value)
@@ -167,7 +157,6 @@ def override_config(config, options=None):
            'topk=2',
            'VALID.transforms.1.ResizeImage.resize_short=300'
        ]

    Returns:
        config(dict): replaced config
    """
@@ -183,7 +172,6 @@ def override_config(config, options=None):
            key, value = pair
            keys = key.split('.')
            override(config, keys, value)

    return config
@@ -197,5 +185,23 @@ def get_config(fname, overrides=None, show=True):
    override_config(config, overrides)
    if show:
        print_config(config)
-    check_config(config)
+    # check_config(config)
    return config


def parse_args():
    parser = argparse.ArgumentParser("generic-image-rec train script")
    parser.add_argument(
        '-c',
        '--config',
        type=str,
        default='configs/config.yaml',
        help='config file path')
    parser.add_argument(
        '-o',
        '--override',
        action='append',
        default=[],
        help='config options to be overridden')
    args = parser.parse_args()
    return args
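With `action='append'`, every `-o key.subkey=value` pair is collected into a list and routed through `override_config`, so nested YAML entries can be changed from the command line without editing the file. A small sketch of the same path called directly (file name illustrative only):

    # equivalent to passing: -o Global.epochs=120 -o DataLoader.Train.batch_size=64
    cfg = get_config("ResNet50.yaml",
                     overrides=["Global.epochs=120",
                                "DataLoader.Train.batch_size=64"],
                     show=False)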
@@ -146,13 +146,13 @@ def _save_student_model(net, model_prefix):
                                              student_model_prefix))


-def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
+def save_model(net, optimizer, model_path, model_name="", prefix='ppcls'):
    """
    save model to the target path
    """
    if paddle.distributed.get_rank() != 0:
        return
-    model_path = os.path.join(model_path, str(epoch_id))
+    model_path = os.path.join(model_path, model_name)
    _mkdir_if_not_exist(model_path)
    model_prefix = os.path.join(model_path, prefix)