# PaddleClas/ppcls/engine/train/train.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import time
import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info
from ppcls.utils import profiler


def train_epoch(engine, epoch_id, print_batch_step):
    """Train ``engine.model`` for one epoch over ``engine.train_dataloader``.

    Args:
        engine: training engine holding the model, dataloader, config,
            optimizer(s), lr scheduler(s), AMP scaler and the loss/metric/time
            accumulators that this function updates as a side effect.
        epoch_id: index of the current epoch; used only for logging.
        print_batch_step: a log line is emitted every ``print_batch_step``
            iterations (via ``log_info``).
    """
    tic = time.time()
    for iter_id, batch in enumerate(engine.train_dataloader):
        # An epoch can be capped at a fixed number of iterations.
        if iter_id >= engine.max_iter:
            break
        profiler.add_profiler_step(engine.config["profiler_options"])
        # Discard timing stats gathered during the first few warm-up
        # iterations so the logged averages are not skewed by startup cost.
        if iter_id == 5:
            for key in engine.time_info:
                engine.time_info[key].reset()
        # Time spent waiting on the dataloader since the end of the last step.
        engine.time_info["reader_cost"].update(time.time() - tic)
        if engine.use_dali:
            # DALI yields [{'data': ..., 'label': ...}]; normalize it to the
            # [images, labels] list shape the rest of this loop expects.
            batch = [
                paddle.to_tensor(batch[0]['data']),
                paddle.to_tensor(batch[0]['label'])
            ]
        batch_size = batch[0].shape[0]
        if not engine.config["Global"].get("use_multilabel", False):
            # Single-label case: force labels to 2-D (batch_size, -1).
            batch[1] = batch[1].reshape([batch_size, -1])
        engine.global_step += 1

        # Forward pass + loss, under AMP autocast when enabled.
        if engine.amp:
            amp_level = engine.config['AMP'].get("level", "O1").upper()
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=amp_level):
                out = forward(engine, batch)
                loss_dict = engine.train_loss_func(out, batch[1])
        else:
            out = forward(engine, batch)
            loss_dict = engine.train_loss_func(out, batch[1])

        # step opt: backward once on the total loss, then step every
        # optimizer (engine.optimizer is a list).
        if engine.amp:
            scaled = engine.scaler.scale(loss_dict["loss"])
            scaled.backward()
            for i in range(len(engine.optimizer)):
                engine.scaler.minimize(engine.optimizer[i], scaled)
        else:
            loss_dict["loss"].backward()
            for i in range(len(engine.optimizer)):
                engine.optimizer[i].step()

        # NOTE(review): assumes engine.model always exposes a `neck`
        # attribute — confirm, otherwise this raises AttributeError.
        # Zeroing the neck BN bias grad HERE, after step() and right before
        # clear_grad(), has no visible effect in this code; presumably it was
        # meant to run before the optimizer step (to freeze that bias).
        # Placement is ambiguous in this indentation-stripped paste — verify
        # against the upstream repository.
        if hasattr(engine.model.neck, 'bn'):
            engine.model.neck.bn.bias.grad.set_value(
                paddle.zeros_like(engine.model.neck.bn.bias.grad))

        # clear grad of every optimizer for the next iteration
        for i in range(len(engine.optimizer)):
            engine.optimizer[i].clear_grad()

        # step every lr scheduler
        for i in range(len(engine.lr_sch)):
            engine.lr_sch[i].step()

        # Below code is just for logging: accumulate metrics and loss
        # into engine state, then periodically emit a log line.
        # update metric_for_logger
        update_metric(engine, out, batch, batch_size)
        # update_loss_for_logger
        update_loss(engine, loss_dict, batch_size)
        engine.time_info["batch_cost"].update(time.time() - tic)
        if iter_id % print_batch_step == 0:
            log_info(engine, batch_size, epoch_id, iter_id)
        tic = time.time()


def forward(engine, batch):
    """Run the model on one batch.

    Recognition models (``engine.is_rec``) receive both images and labels;
    plain classification models receive only the images (``batch[0]``).
    Returns whatever ``engine.model`` returns.
    """
    if engine.is_rec:
        return engine.model(batch[0], batch[1])
    return engine.model(batch[0])