# PaddleOCR/ppstructure/vqa/train_ser.py (314 lines, 11 KiB, Python)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import random
import copy
import logging
import argparse
import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from xfun import XFUNDataset
from utils import parse_args
from utils import get_bio_label_maps
logger = logging.getLogger(__name__)
def set_seed(args):
    """Seed the Python, NumPy and Paddle random generators with ``args.seed``.

    Args:
        args: Object exposing an integer ``seed`` attribute.
    """
    # Same order as before: Python's `random`, then NumPy, then Paddle.
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(args.seed)
def _save_checkpoint(args, model, tokenizer, output_dir):
    """Write the model, tokenizer and training arguments into ``output_dir``.

    Args:
        args: Parsed training arguments; stored as ``training_args.bin``.
        model: The network being trained, optionally wrapped in
            ``paddle.DataParallel``.
        tokenizer: Tokenizer saved next to the model.
        output_dir: Destination directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    # In distributed mode `model` is the DataParallel wrapper; the network
    # that provides `save_pretrained` is held in its `_layers` attribute.
    network = model._layers if isinstance(model,
                                          paddle.DataParallel) else model
    network.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    paddle.save(args, os.path.join(output_dir, "training_args.bin"))
    logger.info("Saving model checkpoint to %s", output_dir)


def train(args):
    """Fine-tune LayoutXLM for token classification and save checkpoints.

    Builds the tokenizer, model, training dataset, LR schedule and AdamW
    optimizer from ``args``, then runs ``args.num_train_epochs`` epochs.
    On rank 0 it optionally evaluates every ``args.eval_steps`` steps
    (keeping the best-F1 model under ``<output_dir>/best_model``) and writes
    ``<output_dir>/checkpoint-<step>`` every ``args.save_steps`` steps.

    Args:
        args: Parsed arguments. Attributes read here: ``output_dir``,
            ``label_map_path``, ``model_name_or_path``, ``train_data_dir``,
            ``train_label_path``, ``per_gpu_train_batch_size``,
            ``num_train_epochs``, ``learning_rate``, ``warmup_steps``,
            ``adam_epsilon``, ``weight_decay``, ``seed``, ``eval_steps``,
            ``evaluate_during_training`` and ``save_steps``.
            ``args.train_batch_size`` is set as a side effect.

    Returns:
        tuple: ``(global_step, mean_loss)`` where ``mean_loss`` is the
        accumulated loss divided by the number of optimizer steps
        (``0.0`` when no step was executed).
    """
    os.makedirs(args.output_dir, exist_ok=True)
    is_main_process = paddle.distributed.get_rank() == 0
    world_size = paddle.distributed.get_world_size()

    # Only rank 0 logs to the file at INFO level; other ranks stay at WARN.
    logging.basicConfig(
        filename=os.path.join(args.output_dir, "train.log")
        if is_main_process else None,
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process else logging.WARN, )

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    logger.addHandler(ch)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    # Padded positions carry the loss' ignore_index so they are skipped.
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    # dist mode
    if world_size > 1:
        paddle.distributed.init_parallel_env()

    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
    base_model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
    model = LayoutXLMForTokenClassification(
        base_model, num_classes=len(label2id_map), dropout=None)

    # dist mode
    if world_size > 1:
        model = paddle.DataParallel(model)

    train_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.train_data_dir,
        label_path=args.train_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)

    # Global batch size: per-device batch size times the number of devices.
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, world_size)

    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=0,
        use_shared_memory=True,
        collate_fn=None, )

    steps_per_epoch = len(train_dataloader)
    t_total = steps_per_epoch * args.num_train_epochs

    # build linear decay with warmup lr sch
    lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=t_total,
        end_lr=0.0,
        power=1.0)
    if args.warmup_steps > 0:
        lr_scheduler = paddle.optimizer.lr.LinearWarmup(
            lr_scheduler,
            args.warmup_steps,
            start_lr=0,
            end_lr=args.learning_rate, )

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        epsilon=args.adam_epsilon,
        weight_decay=args.weight_decay)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    # args.train_batch_size already includes the world size, so it must not
    # be multiplied by it a second time.
    logger.info(
        " Total train batch size (w. parallel, distributed) = %d",
        args.train_batch_size, )
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss = 0.0
    # NOTE(review): seeding happens after the model and dataset are built, so
    # any randomness used during their construction is presumably not covered
    # by args.seed. Kept in place to preserve existing run behaviour.
    set_seed(args)
    best_metrics = None

    for epoch_id in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            model.train()
            outputs = model(**batch)
            # model outputs are always tuple in ppnlp (see doc)
            loss = outputs[0].mean()
            # item() works for both 0-D and single-element 1-D tensors,
            # unlike numpy()[0]; read once and reuse.
            loss_value = loss.item()
            logger.info(
                "[epoch {}/{}][iter: {}/{}] lr: {:.5f}, train loss: {:.5f}, ".
                format(epoch_id, args.num_train_epochs, step,
                       steps_per_epoch, lr_scheduler.get_lr(), loss_value))

            loss.backward()
            tr_loss += loss_value
            optimizer.step()
            lr_scheduler.step()  # Update learning rate schedule
            optimizer.clear_grad()
            global_step += 1

            # Only evaluate on rank 0, otherwise metrics may not average well
            if (is_main_process and args.eval_steps > 0 and
                    global_step % args.eval_steps == 0 and
                    args.evaluate_during_training):
                results, _ = evaluate(
                    args,
                    model,
                    tokenizer,
                    label2id_map,
                    id2label_map,
                    pad_token_label_id, )

                # ">=" means a tie replaces the stored best model with the
                # most recent one.
                if best_metrics is None or results["f1"] >= best_metrics[
                        "f1"]:
                    best_metrics = copy.deepcopy(results)
                    _save_checkpoint(
                        args, model, tokenizer,
                        os.path.join(args.output_dir, "best_model"))

                logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
                    epoch_id, args.num_train_epochs, step, steps_per_epoch,
                    results))
                if best_metrics is not None:
                    logger.info("best metrics: {}".format(best_metrics))

            if (is_main_process and args.save_steps > 0 and
                    global_step % args.save_steps == 0):
                # Save model checkpoint
                _save_checkpoint(
                    args, model, tokenizer,
                    os.path.join(args.output_dir,
                                 "checkpoint-{}".format(global_step)))

    # Guard against an empty dataloader / zero epochs.
    mean_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, mean_loss
def _write_label_sequences(path, sequences):
    """Write one line per sequence, each label followed by a tab."""
    with open(path, "w", encoding="utf-8") as fout:
        for labels in sequences:
            fout.write("".join(label + "\t" for label in labels))
            fout.write("\n")


def evaluate(args,
             model,
             tokenizer,
             label2id_map,
             id2label_map,
             pad_token_label_id,
             prefix=""):
    """Run the model over the evaluation set and compute seqeval metrics.

    The ground-truth and predicted label sequences are also written to
    ``test_gt.txt`` and ``test_pred.txt`` inside ``args.output_dir``.

    Args:
        args: Parsed arguments. Attributes read here: ``eval_data_dir``,
            ``eval_label_path``, ``per_gpu_eval_batch_size`` and
            ``output_dir``. ``args.eval_batch_size`` is set as a side effect.
        model: Network called as ``model(**batch)``; its first two outputs
            are used as loss and logits. It is left in eval mode.
        tokenizer: Tokenizer handed to ``XFUNDataset``.
        label2id_map: Label-name to id mapping handed to ``XFUNDataset``.
        id2label_map: Id to label-name mapping used to decode labels and
            predictions.
        pad_token_label_id: Label id marking positions to skip when decoding.
        prefix: Optional text inserted into the log banners.

    Returns:
        tuple: ``(results, preds_list)`` where ``results`` is a dict with
        keys ``"loss"``, ``"precision"``, ``"recall"`` and ``"f1"``, and
        ``preds_list`` holds the predicted label names per example.

    Raises:
        ValueError: If the evaluation dataloader yields no batches.
    """
    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(
        1, paddle.distributed.get_world_size())

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.eval_batch_size,
        num_workers=0,
        use_shared_memory=True,
        collate_fn=None, )

    # Eval!
    logger.info("***** Running evaluation %s *****", prefix)
    logger.info(" Num examples = %d", len(eval_dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)

    is_main_process = paddle.distributed.get_rank() == 0
    eval_loss = 0.0
    nb_eval_steps = 0
    pred_chunks = []
    label_chunks = []

    model.eval()
    for idx, batch in enumerate(eval_dataloader):
        with paddle.no_grad():
            outputs = model(**batch)
            tmp_eval_loss, logits = outputs[:2]
            # item() handles 0-D and single-element tensors alike.
            batch_loss = tmp_eval_loss.mean().item()

            if is_main_process:
                logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
                    idx, len(eval_dataloader), batch_loss))

            eval_loss += batch_loss
        nb_eval_steps += 1
        # Take the argmax per batch (class axis is axis 2, as in the original
        # code) so the full logits never have to be kept or re-copied.
        pred_chunks.append(np.argmax(logits.numpy(), axis=2))
        label_chunks.append(batch["labels"].numpy())

    if nb_eval_steps == 0:
        raise ValueError(
            "Evaluation dataloader produced no batches; cannot compute "
            "metrics.")

    eval_loss = eval_loss / nb_eval_steps
    preds = np.concatenate(pred_chunks, axis=0)
    out_label_ids = np.concatenate(label_chunks, axis=0)

    # Decode ids to label names, dropping positions marked as padding.
    out_label_list = []
    preds_list = []
    for label_row, pred_row in zip(out_label_ids, preds):
        keep = label_row != pad_token_label_id
        out_label_list.append(
            [id2label_map[label_id] for label_id in label_row[keep].tolist()])
        preds_list.append(
            [id2label_map[pred_id] for pred_id in pred_row[keep].tolist()])

    results = {
        "loss": eval_loss,
        "precision": precision_score(out_label_list, preds_list),
        "recall": recall_score(out_label_list, preds_list),
        "f1": f1_score(out_label_list, preds_list),
    }

    _write_label_sequences(
        os.path.join(args.output_dir, "test_gt.txt"), out_label_list)
    _write_label_sequences(
        os.path.join(args.output_dir, "test_pred.txt"), preds_list)

    report = classification_report(out_label_list, preds_list)
    logger.info("\n" + report)

    logger.info("***** Eval results %s *****", prefix)
    for key in sorted(results.keys()):
        logger.info(" %s = %s", key, str(results[key]))

    return results, preds_list
def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name.

    Args:
        args: Object whose ``vars()`` gives the configuration, e.g. an
            ``argparse.Namespace``.
    """
    settings = vars(args)
    print('----------- Configuration Arguments -----------')
    for name in sorted(settings):
        print('%s: %s' % (name, settings[name]))
    print('------------------------------------------------')
# Script entry point: parse the CLI options, echo them, then start training.
if __name__ == "__main__":
    cli_args = parse_args()
    print_arguments(cli_args)
    train(cli_args)