add support for train and eval (#752)

* add support for train and eval
* rm unused code
* add support for metric save and load ckp

pull/756/head
parent decdb51bb0
commit 3881343484
@@ -1,13 +1,14 @@
 # global configs
 Global:
-  pretrained_model: ""
+  checkpoints: null
+  pretrained_model: null
   output_dir: "./output/"
   device: "gpu"
   class_num: 1000
   save_interval: 1
   eval_during_train: True
   eval_interval: 1
-  epochs: 90
+  epochs: 120
   print_batch_step: 10
   use_visualdl: False
   image_shape: [3, 224, 224]
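The new `Global` keys split warm-starting in two: `checkpoints` resumes full training state (weights, optimizer, and the saved metric/epoch dict), while `pretrained_model` only seeds weights. A minimal sketch of reading these keys (file name and values are illustrative; assumes PyYAML):

    import yaml

    with open("ResNet50.yaml") as f:  # hypothetical config matching the hunk above
        cfg = yaml.safe_load(f)

    glob = cfg["Global"]
    resume_prefix = glob.get("checkpoints")      # e.g. "./output/ResNet50/ppcls_epoch_10"
    weights_only = glob.get("pretrained_model")  # weights only, no optimizer state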
@@ -25,9 +25,7 @@ import paddle
 import paddle.nn as nn
 import paddle.distributed as dist
 
-from ppcls.utils import config
-from ppcls.utils.check import check_gpu
 
 from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.data import build_dataloader
@@ -35,16 +33,15 @@ from ppcls.arch import build_model
 from ppcls.arch.loss_metrics import build_loss
 from ppcls.arch.loss_metrics import build_metrics
 from ppcls.optimizer import build_optimizer
 
+from ppcls.utils.save_load import load_dygraph_pretrain
+from ppcls.utils.save_load import init_model
 from ppcls.utils import save_load
 
 
 class Trainer(object):
-    def __init__(self, mode="train"):
-        args = config.parse_args()
-        self.config = config.get_config(
-            args.config, overrides=args.override, show=True)
+    def __init__(self, config, mode="train"):
         self.mode = mode
+        self.config = config
         self.output_dir = self.config['Global']['output_dir']
         # set device
         assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
@@ -56,6 +53,10 @@ class Trainer(object):
             dist.init_parallel_env()
 
         self.model = build_model(self.config["Arch"])
 
+        if self.config["Global"]["pretrained_model"] is not None:
+            load_dygraph_pretrain(self.model,
+                                  self.config["Global"]["pretrained_model"])
+
         if self.config["Global"]["distributed"]:
             self.model = paddle.DataParallel(self.model)
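Since `Trainer` no longer parses the command line itself, it can be driven directly from a parsed config. A minimal construction sketch (config path and override values are illustrative):

    from ppcls.utils import config as conf
    from ppcls.engine.trainer import Trainer

    # parse the YAML and apply `-o`-style overrides, as the entry points below do
    cfg = conf.get_config(
        "ppcls/configs/ResNet/ResNet50.yaml",  # hypothetical path
        overrides=["Global.epochs=1", "Global.device=cpu"],
        show=False)

    Trainer(cfg, mode="train").train()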
@@ -122,7 +123,15 @@ class Trainer(object):
         # global iter counter
         global_step = 0
 
-        for epoch_id in range(1, self.config["Global"]["epochs"] + 1):
+        if self.config["Global"]["checkpoints"] is not None:
+            metric_info = init_model(self.config["Global"], self.model,
+                                     optimizer)
+            if metric_info is not None:
+                best_metric.update(metric_info)
+
+        for epoch_id in range(best_metric["epoch"] + 1,
+                              self.config["Global"]["epochs"] + 1):
+            acc = 0.0
             self.model.train()
             for iter_id, batch in enumerate(train_dataloader()):
                 batch_size = batch[0].shape[0]
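The new loop bounds are what make resume work: `best_metric["epoch"]` comes back from the checkpoint, so training continues where it stopped instead of restarting at epoch 1. A sketch of the bookkeeping, with the dict shape taken from the hunks below (values are illustrative):

    # state the trainer carries across a resume
    best_metric = {"metric": 0.0, "epoch": 0}

    # init_model returns the dict stored in the ".pdstates" file, e.g.:
    metric_info = {"metric": 0.7612, "epoch": 37}
    best_metric.update(metric_info)

    start_epoch = best_metric["epoch"] + 1  # resumes at epoch 38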
@@ -176,12 +185,13 @@ class Trainer(object):
                     "eval_during_train"] and epoch_id % self.config["Global"][
                         "eval_during_train"] == 0:
                 acc = self.eval(epoch_id)
-                if acc >= best_metric["metric"]:
+                if acc > best_metric["metric"]:
                     best_metric["metric"] = acc
+                    best_metric["epoch"] = epoch_id
                     save_load.save_model(
                         self.model,
                         optimizer,
+                        best_metric,
                         self.output_dir,
                         model_name=self.config["Arch"]["name"],
                         prefix="best_model")
@@ -190,7 +200,8 @@ class Trainer(object):
             if epoch_id % save_interval == 0:
                 save_load.save_model(
                     self.model,
-                    optimizer,
+                    optimizer, {"metric": acc,
+                                "epoch": epoch_id},
                     self.output_dir,
                     model_name=self.config["Arch"]["name"],
                     prefix="ppcls_epoch_{}".format(epoch_id))
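Each periodic checkpoint now carries its own metric snapshot, so any epoch checkpoint is resumable, not just `best_model`. A quick way to inspect what was stored (path is illustrative):

    import paddle

    # the ".pdstates" file is a pickled dict written by save_model (see below)
    states = paddle.load("./output/ResNet50/ppcls_epoch_12.pdstates")
    print(states)  # e.g. {'metric': 0.7612, 'epoch': 12}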
@@ -266,12 +277,3 @@ class Trainer(object):
             return -1
         # return 1st metric in the dict
         return output_info[metric_key].avg
-
-
-def main():
-    trainer = Trainer()
-    trainer.train()
-
-
-if __name__ == "__main__":
-    main()
@@ -71,11 +71,17 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
     return
 
 
-def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld, load_static_weights=False):
+def load_dygraph_pretrain_from_url(model,
+                                   pretrained_url,
+                                   use_ssld,
+                                   load_static_weights=False):
     if use_ssld:
-        pretrained_url = pretrained_url.replace("_pretrained", "_ssld_pretrained")
-    local_weight_path = get_weights_path_from_url(pretrained_url).replace(".pdparams", "")
-    load_dygraph_pretrain(model, path=local_weight_path, load_static_weights=load_static_weights)
+        pretrained_url = pretrained_url.replace("_pretrained",
+                                                "_ssld_pretrained")
+    local_weight_path = get_weights_path_from_url(pretrained_url).replace(
+        ".pdparams", "")
+    load_dygraph_pretrain(
+        model, path=local_weight_path, load_static_weights=load_static_weights)
     return
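The `use_ssld` switch is a plain substring rewrite of the download URL, so weight files only need to follow the `_pretrained` / `_ssld_pretrained` naming convention. A quick check of the substitution (URL is illustrative):

    url = "https://example.com/ResNet50_vd_pretrained.pdparams"  # hypothetical URL
    print(url.replace("_pretrained", "_ssld_pretrained"))
    # -> https://example.com/ResNet50_vd_ssld_pretrained.pdparams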
@@ -121,10 +127,11 @@ def init_model(config, net, optimizer=None):
             "Given dir {}.pdopt not exist.".format(checkpoints)
         para_dict = paddle.load(checkpoints + ".pdparams")
         opti_dict = paddle.load(checkpoints + ".pdopt")
+        metric_dict = paddle.load(checkpoints + ".pdstates")
         net.set_dict(para_dict)
         optimizer.set_state_dict(opti_dict)
         logger.info("Finish load checkpoints from {}".format(checkpoints))
-        return
+        return metric_dict
 
     pretrained_model = config.get('pretrained_model')
     load_static_weights = config.get('load_static_weights', False)
@@ -155,7 +162,12 @@ def _save_student_model(net, model_prefix):
         student_model_prefix))
 
 
-def save_model(net, optimizer, model_path, model_name="", prefix='ppcls'):
+def save_model(net,
+               optimizer,
+               metric_info,
+               model_path,
+               model_name="",
+               prefix='ppcls'):
     """
     save model to the target path
     """
@@ -169,4 +181,5 @@ def save_model(net, optimizer, model_path, model_name="", prefix='ppcls'):
 
     paddle.save(net.state_dict(), model_prefix + ".pdparams")
     paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
+    paddle.save(metric_info, model_prefix + ".pdstates")
     logger.info("Already save model in {}".format(model_path))
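Together, `save_model` and `init_model` now form a save/resume round trip over three sibling files. A minimal sketch with toy stand-ins for the real ppcls objects (prefix and values are illustrative):

    import paddle
    import paddle.nn as nn

    net = nn.Linear(4, 2)
    opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=net.parameters())

    prefix = "demo_ckpt"  # hypothetical checkpoint prefix
    paddle.save(net.state_dict(), prefix + ".pdparams")             # weights
    paddle.save(opt.state_dict(), prefix + ".pdopt")                # optimizer state
    paddle.save({"metric": 0.5, "epoch": 3}, prefix + ".pdstates")  # metrics

    # ...later, the init_model path reloads all three pieces:
    net.set_dict(paddle.load(prefix + ".pdparams"))
    opt.set_state_dict(paddle.load(prefix + ".pdopt"))
    metric_dict = paddle.load(prefix + ".pdstates")  # {'metric': 0.5, 'epoch': 3}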
tools/eval.py (112 changed lines)
@@ -1,10 +1,10 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+#    http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,105 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import paddle
-import paddle.nn.functional as F
-
-import argparse
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 import os
 import sys
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
 
-from ppcls.utils import logger
-from ppcls.utils.save_load import init_model
-from ppcls.utils.config import get_config
-from ppcls.utils import multi_hot_encode
-from ppcls.utils import accuracy_score
-from ppcls.utils import mean_average_precision
-from ppcls.utils import precision_recall_fscore
-from ppcls.data import Reader
-import program
+from ppcls.utils import config
+from ppcls.engine.trainer import Trainer
 
-import numpy as np
-
-
-def parse_args():
-    parser = argparse.ArgumentParser("PaddleClas eval script")
-    parser.add_argument(
-        '-c',
-        '--config',
-        type=str,
-        default='./configs/eval.yaml',
-        help='config file path')
-    parser.add_argument(
-        '-o',
-        '--override',
-        action='append',
-        default=[],
-        help='config options to be overridden')
-    args = parser.parse_args()
-    return args
-
-
-def main(args, return_dict={}):
-    config = get_config(args.config, overrides=args.override, show=True)
-    config.mode = "valid"
-    # assign place
-    use_gpu = config.get("use_gpu", True)
-    place = paddle.set_device('gpu' if use_gpu else 'cpu')
-    multilabel = config.get("multilabel", False)
-
-    trainer_num = paddle.distributed.get_world_size()
-    use_data_parallel = trainer_num != 1
-    config["use_data_parallel"] = use_data_parallel
-
-    if config["use_data_parallel"]:
-        paddle.distributed.init_parallel_env()
-
-    net = program.create_model(config.ARCHITECTURE, config.classes_num)
-
-    init_model(config, net, optimizer=None)
-    valid_dataloader = Reader(config, 'valid', places=place)()
-    if len(valid_dataloader) <= 0:
-        logger.error(
-            "valid dataloader is empty, please check your data config again!")
-        sys.exit(-1)
-    net.eval()
-    with paddle.no_grad():
-        if not multilabel:
-            top1_acc = program.run(valid_dataloader, config, net, None, None,
-                                   0, 'valid')
-            return_dict["top1_acc"] = top1_acc
-
-            return top1_acc
-        else:
-            all_outs = []
-            targets = []
-            for _, batch in enumerate(valid_dataloader()):
-                feeds = program.create_feeds(batch, False, config.classes_num,
-                                             multilabel)
-                out = net(feeds["image"])
-                out = F.sigmoid(out)
-
-                use_distillation = config.get("use_distillation", False)
-                if use_distillation:
-                    out = out[1]
-
-                all_outs.extend(list(out.numpy()))
-                targets.extend(list(feeds["label"].numpy()))
-            all_outs = np.array(all_outs)
-            targets = np.array(targets)
-
-            mAP = mean_average_precision(all_outs, targets)
-
-            return_dict["mean average precision"] = mAP
-
-            return mAP
-
-
-if __name__ == '__main__':
-    args = parse_args()
-    return_dict = {}
-    main(args, return_dict)
-    print(return_dict)
+if __name__ == "__main__":
+    args = config.parse_args()
+    config = config.get_config(args.config, overrides=args.override, show=True)
+    trainer = Trainer(config, mode="eval")
+    trainer.eval()
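The rewritten entry point defers all work to `Trainer.eval()`, which returns a scalar (the first metric's running average, per the `@@ -266,12` hunk above). A sketch of consuming it programmatically (paths are illustrative):

    from ppcls.utils import config as conf
    from ppcls.engine.trainer import Trainer

    cfg = conf.get_config(
        "ppcls/configs/ResNet/ResNet50.yaml",  # hypothetical path
        overrides=["Global.pretrained_model=./output/ResNet50/best_model"],
        show=False)
    top1 = Trainer(cfg, mode="eval").eval()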
tools/train.py (148 changed lines)
@@ -1,10 +1,10 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+#    http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,144 +15,16 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import argparse
 import os
 import sys
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
 
-import paddle
+from ppcls.utils import config
+from ppcls.engine.trainer import Trainer
 
-from ppcls.data import Reader
-from ppcls.utils.config import get_config
-from ppcls.utils.save_load import init_model, save_model
-from ppcls.utils import logger
-import program
-
-
-def parse_args():
-    parser = argparse.ArgumentParser("PaddleClas train script")
-    parser.add_argument(
-        '-c',
-        '--config',
-        type=str,
-        default='configs/ResNet/ResNet50.yaml',
-        help='config file path')
-    parser.add_argument(
-        '-p',
-        '--profiler_options',
-        type=str,
-        default=None,
-        help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
-    )
-    parser.add_argument(
-        '-o',
-        '--override',
-        action='append',
-        default=[],
-        help='config options to be overridden')
-    args = parser.parse_args()
-    return args
-
-
-def main(args):
-    paddle.seed(12345)
-
-    config = get_config(args.config, overrides=args.override, show=True)
-    # assign the place
-    use_gpu = config.get("use_gpu", True)
-    use_xpu = config.get("use_xpu", False)
-    assert (
-        use_gpu and use_xpu
-    ) is not True, "gpu and xpu can not be true in the same time in static mode!"
-    if use_gpu:
-        place = paddle.set_device('gpu')
-    elif use_xpu:
-        place = paddle.set_device('xpu')
-    else:
-        place = paddle.set_device('cpu')
-
-    trainer_num = paddle.distributed.get_world_size()
-    use_data_parallel = trainer_num != 1
-    config["use_data_parallel"] = use_data_parallel
-
-    if config["use_data_parallel"]:
-        paddle.distributed.init_parallel_env()
-
-    net = program.create_model(config.ARCHITECTURE, config.classes_num)
-    optimizer, lr_scheduler = program.create_optimizer(
-        config, parameter_list=net.parameters())
-
-    dp_net = net
-    if config["use_data_parallel"]:
-        find_unused_parameters = config.get("find_unused_parameters", False)
-        dp_net = paddle.DataParallel(
-            net, find_unused_parameters=find_unused_parameters)
-
-    # load model from checkpoint or pretrained model
-    init_model(config, net, optimizer)
-
-    train_dataloader = Reader(config, 'train', places=place)()
-    if len(train_dataloader) <= 0:
-        logger.error(
-            "train dataloader is empty, please check your data config again!")
-        sys.exit(-1)
-
-    if config.validate:
-        valid_dataloader = Reader(config, 'valid', places=place)()
-        if len(valid_dataloader) <= 0:
-            logger.error(
-                "valid dataloader is empty, please check your data config again!"
-            )
-            sys.exit(-1)
-
-    last_epoch_id = config.get("last_epoch", -1)
-    best_top1_acc = 0.0  # best top1 acc record
-    best_top1_epoch = last_epoch_id
-
-    vdl_writer_path = config.get("vdl_dir", None)
-    vdl_writer = None
-    if vdl_writer_path:
-        from visualdl import LogWriter
-        vdl_writer = LogWriter(vdl_writer_path)
-    # Ensure that the vdl log file can be closed normally
-    try:
-        for epoch_id in range(last_epoch_id + 1, config.epochs):
-            net.train()
-            # 1. train with train dataset
-            program.run(train_dataloader, config, dp_net, optimizer,
-                        lr_scheduler, epoch_id, 'train', vdl_writer,
-                        args.profiler_options)
-
-            # 2. validate with validate dataset
-            if config.validate and epoch_id % config.valid_interval == 0:
-                net.eval()
-                with paddle.no_grad():
-                    top1_acc = program.run(valid_dataloader, config, net, None,
-                                           None, epoch_id, 'valid', vdl_writer)
-                if top1_acc > best_top1_acc:
-                    best_top1_acc = top1_acc
-                    best_top1_epoch = epoch_id
-                    model_path = os.path.join(config.model_save_dir,
-                                              config.ARCHITECTURE["name"])
-                    save_model(net, optimizer, model_path, "best_model")
-                    message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
-                        best_top1_acc, best_top1_epoch)
-                    logger.info(message)
-
-            # 3. save the persistable model
-            if epoch_id % config.save_interval == 0:
-                model_path = os.path.join(config.model_save_dir,
-                                          config.ARCHITECTURE["name"])
-                save_model(net, optimizer, model_path, epoch_id)
-    except Exception as e:
-        logger.error(e)
-    finally:
-        vdl_writer.close() if vdl_writer else None
-
-
-if __name__ == '__main__':
-    args = parse_args()
-    main(args)
+if __name__ == "__main__":
+    args = config.parse_args()
+    config = config.get_config(args.config, overrides=args.override, show=True)
+    trainer = Trainer(config, mode="train")
+    trainer.train()
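tools/train.py now shares the same four-line pattern as tools/eval.py, and resuming an interrupted run is a single override. A sketch simulating the command line, assuming `config.parse_args` keeps the `-c`/`-o` flags of the argparse block it replaces:

    import sys
    from ppcls.utils import config as conf
    from ppcls.engine.trainer import Trainer

    # simulates: python tools/train.py -c <cfg> -o Global.checkpoints=<prefix>
    sys.argv = [
        "train.py", "-c", "ppcls/configs/ResNet/ResNet50.yaml",  # hypothetical path
        "-o", "Global.checkpoints=./output/ResNet50/ppcls_epoch_10",
    ]
    args = conf.parse_args()
    cfg = conf.get_config(args.config, overrides=args.override, show=False)
    Trainer(cfg, mode="train").train()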