PaddleClas/tools/program.py

368 lines
12 KiB
Python
Raw Normal View History

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2020-04-09 02:16:30 +08:00
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2020-04-09 02:16:30 +08:00
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2020-04-09 02:16:30 +08:00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from collections import OrderedDict
2020-08-29 17:44:30 +08:00
import paddle
2020-09-13 15:09:43 +08:00
from paddle import to_tensor
import paddle.nn as nn
import paddle.nn.functional as F
2020-04-09 02:16:30 +08:00
from ppcls.optimizer import LearningRateBuilder
from ppcls.optimizer import OptimizerBuilder
from ppcls.modeling import architectures
from ppcls.modeling.loss import CELoss
from ppcls.modeling.loss import MixCELoss
2020-04-17 12:43:42 +08:00
from ppcls.modeling.loss import JSDivLoss
2020-04-09 02:16:30 +08:00
from ppcls.modeling.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
2020-06-03 19:52:55 +08:00
def create_model(architecture, classes_num):
    """
    Build a network instance from its architecture description.

    Args:
        architecture(dict): architecture information; its "name" key
            (such as ResNet50) selects a constructor from
            ppcls.modeling.architectures, and its optional "params" key
            supplies extra keyword arguments for that constructor
        classes_num(int): num of classes, forwarded as ``class_dim``
    Returns:
        the object returned by the selected architecture constructor
    """
    arch_name = architecture["name"]
    extra_kwargs = architecture.get("params", {})
    builder = architectures.__dict__[arch_name]
    return builder(class_dim=classes_num, **extra_kwargs)
2020-04-09 02:16:30 +08:00
2020-06-29 14:06:47 +08:00
def create_loss(feeds,
                out,
                architecture,
                classes_num=1000,
                epsilon=None,
                use_mix=False,
                use_distillation=False):
    """
    Build the loss for optimization and evaluate it on the model output.
    Supported variants:
        1. CrossEntropy loss
        2. CrossEntropy loss with label smoothing
        3. CrossEntropy loss with mix (mixup, cutmix, fmix)
        4. CrossEntropy loss with label smoothing and mix
        5. GoogLeNet loss (model with 3 outputs)
        6. JS-divergence loss for distillation (model with 2 outputs)
    Args:
        feeds(dict): dict of model input variables; "label" is read for
            the plain/GoogLeNet losses, "y_a"/"y_b"/"lam" for the mix loss
        out(variable): model output variable (a sequence of 3 outputs for
            GoogLeNet, of 2 outputs when use_distillation is set)
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        use_distillation(bool): whether to use distillation training
    Returns:
        loss(variable): loss variable
    """
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        googlenet_loss = GoogLeNetLoss(class_dim=classes_num, epsilon=epsilon)
        # NOTE(review): this branch always reads feeds["label"]; create_feeds
        # does not provide "label" when use_mix is on, so GoogLeNet combined
        # with mix would raise KeyError here.
        return googlenet_loss(out[0], out[1], out[2], feeds["label"])

    if use_distillation:
        assert len(out) == 2, ("distillation output length must be 2, "
                               "but got {}".format(len(out)))
        distill_loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
        # out[1] is passed first and out[0] second, as JSDivLoss expects
        return distill_loss(out[1], out[0])

    if use_mix:
        mix_loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
        return mix_loss(out, feeds['y_a'], feeds['y_b'], feeds['lam'])

    ce_loss = CELoss(class_dim=classes_num, epsilon=epsilon)
    return ce_loss(out, feeds["label"])
2020-04-09 02:16:30 +08:00
2020-04-29 14:22:53 +08:00
def create_metric(out,
                  label,
                  architecture,
                  topk=5,
                  classes_num=1000,
                  use_distillation=False,
                  mode="train"):
    """
    Create measures of model accuracy, such as top1 and top5.
    Args:
        out(variable): model output variable (a sequence of 3 outputs for
            GoogLeNet, of 2 outputs when use_distillation is set)
        label(variable): ground-truth labels
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        topk(int): usually top5; clamped to classes_num
        classes_num(int): num of classes
        use_distillation(bool): whether to use distillation training
        mode(str): mode, train/valid/eval; outside train mode the
            accuracies are averaged over all ranks when running on
            multiple cards
    Returns:
        fetchs(dict): OrderedDict with 'top1' and 'top{k}' where
            k = min(topk, classes_num)
    """
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        out = out[0]
    elif use_distillation:
        # just need student output to get metrics
        out = out[1]
    softmax_out = F.softmax(out)

    k = min(topk, classes_num)
    top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
    topk_acc = paddle.metric.accuracy(softmax_out, label=label, k=k)

    # multi cards' eval: average the accuracies across all ranks
    if mode != "train":
        world_size = paddle.distributed.get_world_size()
        if world_size > 1:
            # all_reduce reduces its argument in place; its return value is
            # not the reduced tensor in every Paddle release (documented as
            # None), so it is deliberately not used.
            paddle.distributed.all_reduce(
                top1, op=paddle.distributed.ReduceOp.SUM)
            paddle.distributed.all_reduce(
                topk_acc, op=paddle.distributed.ReduceOp.SUM)
            top1 = top1 / world_size
            topk_acc = topk_acc / world_size

    fetchs = OrderedDict()
    fetchs['top1'] = top1
    fetchs['top{}'.format(k)] = topk_acc
    return fetchs
2020-06-29 15:19:59 +08:00
def create_fetchs(feeds, net, config, mode="train"):
    """
    Run the network on one batch and collect its outputs: the loss from
    create_loss and, unless mix is active, the metrics from create_metric.
    Args:
        feeds(dict): dict of model input variables holding "image" and
            either "label" or, with mix, "y_a"/"y_b"/"lam"
        net: the network, called on feeds["image"]
        config: global config; reads ARCHITECTURE, topk, classes_num and
            the optional ls_epsilon, use_mix and use_distillation entries
        mode(str): train/valid/eval; use_mix is only honored in train mode
    Returns:
        fetchs(dict): OrderedDict with 'loss' first, followed by the
            metrics when mix is not used
    """
    architecture = config.ARCHITECTURE
    topk = config.topk
    classes_num = config.classes_num
    epsilon = config.get('ls_epsilon')
    use_mix = config.get('use_mix') and mode == 'train'
    use_distillation = config.get('use_distillation')

    outputs = net(feeds["image"])

    fetchs = OrderedDict()
    fetchs['loss'] = create_loss(
        feeds,
        outputs,
        architecture,
        classes_num=classes_num,
        epsilon=epsilon,
        use_mix=use_mix,
        use_distillation=use_distillation)
    if use_mix:
        # mixed batches carry no single "label", so no accuracy is reported
        return fetchs

    fetchs.update(
        create_metric(
            outputs,
            feeds["label"],
            architecture,
            topk=topk,
            classes_num=classes_num,
            use_distillation=use_distillation,
            mode=mode))
    return fetchs
2020-06-03 19:52:55 +08:00
def create_optimizer(config, parameter_list=None):
    """
    Create the learning-rate object and the optimizer described by config.
    Args:
        config(dict): global config; besides 'epochs', 'total_images' and
            config['TRAIN']['batch_size'] it holds, for example,
            {
                'LEARNING_RATE':
                    {'function': 'Cosine',
                     'params': {'lr': 0.1}
                    },
                'OPTIMIZER':
                    {'function': 'Momentum',
                     'params': {'momentum': 0.9},
                     'regularizer':
                        {'function': 'L2', 'factor': 0.0001}
                    }
            }
            NOTE: config['LEARNING_RATE']['params'] is updated in place
            with 'epochs' and 'step_each_epoch'.
        parameter_list: parameters to optimize, forwarded to the optimizer
    Returns:
        (optimizer, lr): the optimizer instance and the learning-rate object
    """
    lr_config = config['LEARNING_RATE']
    schedule_params = lr_config['params']
    schedule_params.update({
        'epochs': config['epochs'],
        # iterations per epoch, derived from the dataset size
        'step_each_epoch':
        config['total_images'] // config['TRAIN']['batch_size'],
    })
    learning_rate = LearningRateBuilder(**lr_config)()

    optimizer_builder = OptimizerBuilder(**config['OPTIMIZER'])
    return optimizer_builder(learning_rate, parameter_list), learning_rate
2020-04-09 02:16:30 +08:00
2020-06-29 14:06:47 +08:00
def create_feeds(batch, use_mix):
    """
    Convert one dataloader batch into the feed dict used by create_fetchs.
    Args:
        batch: indexable batch; batch[0] is the image input (used as-is).
            Without mix, batch[1] holds the labels; with mix, batch[1],
            batch[2] and batch[3] hold y_a, y_b and lam.
        use_mix(bool): whether the batch comes from mix augmentation
    Returns:
        feeds(dict): {"image", "label"} or {"image", "y_a", "y_b", "lam"};
            every non-image entry is reshaped to (-1, 1), with int64 for
            labels and float32 for lam
    """
    def _as_column(field, dtype):
        # cast through numpy and reshape into a single column
        return to_tensor(field.numpy().astype(dtype).reshape(-1, 1))

    if not use_mix:
        return {"image": batch[0], "label": _as_column(batch[1], 'int64')}
    return {
        "image": batch[0],
        "y_a": _as_column(batch[1], "int64"),
        "y_b": _as_column(batch[2], "int64"),
        "lam": _as_column(batch[3], "float32"),
    }
2020-09-15 17:43:19 +08:00
def _step_lr_scheduler(lr_scheduler, epoch, idx):
    """Advance lr_scheduler after the optimizer step of iteration idx."""
    if not lr_scheduler.update_specified:
        lr_scheduler.step()
        return
    curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
    # NOTE(review): before update_start_step the max() clamps the offset to
    # 0, so the scheduler steps on every iteration until the start step is
    # reached; afterwards it steps every update_step_interval iterations.
    # Confirm the per-iteration stepping before the start is intended.
    offset = max(0, curr_global_counter - lr_scheduler.update_start_step)
    if offset % lr_scheduler.update_step_interval == 0:
        lr_scheduler.step()


def run(dataloader,
        config,
        net,
        optimizer=None,
        lr_scheduler=None,
        epoch=0,
        mode='train'):
    """
    Feed data to the model for one pass over dataloader, optimize the
    model in train mode, and log the measures and the loss.
    Args:
        dataloader: callable returning an iterable of batches
        config: global config; reads print_interval, use_mix, topk and
            classes_num (plus whatever create_fetchs reads)
        net: the network
        optimizer: optimizer stepped once per batch in train mode
        lr_scheduler: optional learning-rate object stepped after each
            training iteration
        epoch(int): epoch of training or validation
        mode(str): 'train', 'valid' or 'eval'; affects logging and which
            values are returned
    Returns:
        the average top1 accuracy when mode == 'valid' (used to save the
        best model); None otherwise
    """
    print_interval = config.get("print_interval", 10)
    use_mix = config.get("use_mix", False) and mode == "train"

    metric_list = [
        ("loss", AverageMeter('loss', '7.5f', postfix=",")),
        ("lr", AverageMeter('lr', 'f', postfix=",", need_avg=False)),
        ("batch_time", AverageMeter('batch_cost', '.5f', postfix=" s,")),
        ("reader_time", AverageMeter('reader_cost', '.5f', postfix=" s,")),
    ]
    if not use_mix:
        # create_metric reports 'top{k}' with k = min(topk, classes_num);
        # the meter must use the same name or the update below raises
        # KeyError for datasets with fewer than topk classes.
        topk_name = 'top{}'.format(min(config.topk, config.classes_num))
        metric_list.insert(
            1, (topk_name, AverageMeter(topk_name, '.5f', postfix=",")))
        metric_list.insert(
            1, ("top1", AverageMeter("top1", '.5f', postfix=",")))
    metric_list = OrderedDict(metric_list)

    total_samples = 0
    tic = time.time()
    for idx, batch in enumerate(dataloader()):
        metric_list['reader_time'].update(time.time() - tic)
        batch_size = len(batch[0])
        total_samples += batch_size
        feeds = create_feeds(batch, use_mix)
        fetchs = create_fetchs(feeds, net, config, mode)
        if mode == 'train':
            avg_loss = fetchs['loss']
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            metric_list['lr'].update(
                optimizer._global_learning_rate().numpy()[0], batch_size)
            if lr_scheduler is not None:
                _step_lr_scheduler(lr_scheduler, epoch, idx)

        for name, fetch in fetchs.items():
            metric_list[name].update(fetch.numpy()[0], batch_size)
        metric_list["batch_time"].update(time.time() - tic)
        tic = time.time()

        if idx % print_interval == 0:
            # only format the per-step summary when it is actually logged
            fetchs_str = ' '.join(
                [str(m.value) for m in metric_list.values()])
            step_cost = metric_list["batch_time"].val
            step_ips = batch_size / step_cost if step_cost > 0 else 0.0
            ips_info = "ips: {:.5f} images/sec.".format(step_ips)
            if mode == 'eval':
                logger.info("{:s} step:{:<4d}, {:s} {:s}".format(
                    mode, idx, fetchs_str, ips_info))
            else:
                epoch_str = "epoch:{:<3d}".format(epoch)
                step_str = "{:s} step:{:<4d}".format(mode, idx)
                logger.info("{:s}, {:s}, {:s} {:s}".format(
                    logger.coloring(epoch_str, "HEADER")
                    if idx == 0 else epoch_str,
                    logger.coloring(step_str, "PURPLE"),
                    logger.coloring(fetchs_str, 'OKGREEN'),
                    logger.coloring(ips_info, 'OKGREEN')))

    end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                       [metric_list['batch_time'].total])
    # Average throughput over all samples actually processed (the last batch
    # may be smaller than the others); an empty dataloader reports 0.
    total_cost = metric_list["batch_time"].sum
    avg_ips = total_samples / total_cost if total_cost > 0 else 0.0
    ips_info = "ips: {:.5f} images/sec.".format(avg_ips)
    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(
            logger.coloring(end_epoch_str, "RED"),
            logger.coloring(mode, "PURPLE"),
            logger.coloring(end_str, "OKGREEN"),
            logger.coloring(ips_info, "OKGREEN"), ))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return metric_list['top1'].avg