add rotnet code (#6065)

* add rotnet code
* add config
* fix infer for ssl
* rm unused code

parent 3c9200c671
commit 4cddec7307
@@ -0,0 +1,99 @@
Global:
  debug: false
  use_gpu: true
  epoch_num: 100
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v3_rotnet
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model: null
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: 25
  infer_mode: false
  use_space_char: true
  save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
  regularizer:
    name: L2
    factor: 1.0e-05
Architecture:
  model_type: cls
  algorithm: CLS
  Transform: null
  Backbone:
    name: MobileNetV1Enhance
    scale: 0.5
    last_conv_stride: [1, 2]
    last_pool_type: avg
  Neck:
  Head:
    name: ClsHead
    class_dim: 4

Loss:
  name: ClsLoss
  main_indicator: acc

PostProcess:
  name: ClsPostProcess

Metric:
  name: ClsMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
      - ./train_data/train_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - RecAug:
          use_tia: False
      - RandAugment:
      - SSLRotateResize:
          image_shape: [3, 48, 320]
      - KeepKeys:
          keep_keys: ["image", "label"]
  loader:
    collate_fn: "SSLRotateCollate"
    shuffle: true
    batch_size_per_card: 32
    drop_last: true
    num_workers: 8
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
      - ./train_data/val_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - SSLRotateResize:
          image_shape: [3, 48, 320]
      - KeepKeys:
          keep_keys: ["image", "label"]
  loader:
    collate_fn: "SSLRotateCollate"
    shuffle: false
    drop_last: false
    batch_size_per_card: 64
    num_workers: 8
profiler_options: null
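Taken together, this config wires up RotNet-style self-supervised pretraining: SSLRotateResize (added further down) emits the 0/90/180/270-degree views of each crop with labels 0-3, SSLRotateCollate flattens them into the batch, and ClsHead with class_dim: 4 predicts which rotation was applied. A minimal sketch of inspecting those pieces, assuming the YAML above is saved locally as rotnet.yml (the new file's actual path is not visible in this diff):

import yaml

# hypothetical local copy of the config shown above
with open("rotnet.yml") as f:
    cfg = yaml.safe_load(f)

# the head predicts which of the four rotations was applied to each crop
print(cfg["Architecture"]["Head"])               # {'name': 'ClsHead', 'class_dim': 4}
print(cfg["Train"]["loader"]["collate_fn"])      # SSLRotateCollate
print(cfg["Train"]["dataset"]["transforms"][3])  # {'SSLRotateResize': {'image_shape': [3, 48, 320]}}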
@@ -35,17 +35,7 @@ from ppocr.metrics import build_metric
 import tools.program as program
 from paddleslim.dygraph.quant import QAT
 from ppocr.data import build_dataloader
-
-
-def export_single_model(quanter, model, infer_shape, save_path, logger):
-    quanter.save_quantized_model(
-        model,
-        save_path,
-        input_spec=[
-            paddle.static.InputSpec(
-                shape=[None] + infer_shape, dtype='float32')
-        ])
-    logger.info('inference QAT model is saved to {}'.format(save_path))
+from tools.export_model import export_single_model
 
 
 def main():
@@ -84,17 +74,54 @@ def main():
        config['Global'])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    # update SARLoss params
                    assert list(config['Loss']['loss_config_list'][-1].keys())[
                        0] == 'DistillationSARLoss'
                    config['Loss']['loss_config_list'][-1][
                        'DistillationSARLoss']['ignore_index'] = char_num + 1
                    out_channels_list = {}
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            # update SARLoss params
            assert list(config['Loss']['loss_config_list'][1].keys())[
                0] == 'SARLoss'
            if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
                config['Loss']['loss_config_list'][1]['SARLoss'] = {
                    'ignore_index': char_num + 1
                }
            else:
                config['Loss']['loss_config_list'][1]['SARLoss'][
                    'ignore_index'] = char_num + 1
            out_channels_list = {}
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

        if config['PostProcess']['name'] == 'SARLabelDecode':  # for SAR model
            config['Loss']['ignore_index'] = char_num - 1

    model = build_model(config['Architecture'])

    # get QAT model
@@ -120,21 +147,22 @@ def main():
     for k, v in metric.items():
         logger.info('{}:{}'.format(k, v))
 
-    infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640]
 
     save_path = config["Global"]["save_inference_dir"]
 
-    arch_config = config["Architecture"]
+    arch_config = config["Architecture"]
 
     if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
+        archs = list(arch_config["Models"].values())
         for idx, name in enumerate(model.model_name_list):
             model.model_list[idx].eval()
             sub_model_save_path = os.path.join(save_path, name, "inference")
-            export_single_model(quanter, model.model_list[idx], infer_shape,
-                                sub_model_save_path, logger)
+            export_single_model(model.model_list[idx], archs[idx],
+                                sub_model_save_path, logger, quanter)
     else:
         save_path = os.path.join(save_path, "inference")
         model.eval()
-        export_single_model(quanter, model, infer_shape, save_path, logger)
+        export_single_model(model, arch_config, save_path, logger, quanter)
 
 
 if __name__ == "__main__":
@@ -112,10 +112,48 @@ def main(config, device, logger, vdl_writer):
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    # update SARLoss params
                    assert list(config['Loss']['loss_config_list'][-1].keys())[
                        0] == 'DistillationSARLoss'
                    config['Loss']['loss_config_list'][-1][
                        'DistillationSARLoss']['ignore_index'] = char_num + 1
                    out_channels_list = {}
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            # update SARLoss params
            assert list(config['Loss']['loss_config_list'][1].keys())[
                0] == 'SARLoss'
            if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
                config['Loss']['loss_config_list'][1]['SARLoss'] = {
                    'ignore_index': char_num + 1
                }
            else:
                config['Loss']['loss_config_list'][1]['SARLoss'][
                    'ignore_index'] = char_num + 1
            out_channels_list = {}
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

        if config['PostProcess']['name'] == 'SARLabelDecode':  # for SAR model
            config['Loss']['ignore_index'] = char_num - 1
    model = build_model(config['Architecture'])

    pre_best_model_dict = dict()
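To make the head/loss bookkeeping above concrete, a worked example with a hypothetical 100-entry character set for a MultiHead (CTC + SAR) recognizer:

# hypothetical numbers, purely to illustrate the arithmetic in the block above
char_num = 100                     # len(post_process_class.character)
char_num = char_num - 2            # SARLabelDecode reserves two symbols -> 98
sar_ignore_index = char_num + 1    # 99, written into SARLoss as ignore_index
out_channels_list = {
    "CTCLabelDecode": char_num,      # 98 output channels for the CTC branch
    "SARLabelDecode": char_num + 2,  # 100 output channels for the SAR branch
}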
@@ -27,7 +27,7 @@
 
 #### 2. CDLA dataset
 - **Source**: https://github.com/buptlihang/CDLA
-- **Description**: the training split of the publaynet dataset contains 5,000 images and the validation split contains 1,000 images, covering 10 categories: `Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`. Some images and their annotated boxes are visualized below.
+- **Description**: the training split of the CDLA dataset contains 5,000 images and the validation split contains 1,000 images, covering 10 categories: `Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`. Some images and their annotated boxes are visualized below.
 
 <div align="center">
 <img src="../datasets/CDLA_demo/val_0633.jpg" width="500">
@@ -72,6 +72,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
        use_shared_memory = loader_config['use_shared_memory']
    else:
        use_shared_memory = True

    if mode == "Train":
        # Distribute data to multiple cards
        batch_sampler = DistributedBatchSampler(
@@ -56,3 +56,17 @@ class ListCollator(object):
        for idx in to_tensor_idxs:
            data_dict[idx] = paddle.to_tensor(data_dict[idx])
        return list(data_dict.values())


class SSLRotateCollate(object):
    """
    batch: [
        [(4*3xH*W), (4,)]
        [(4*3xH*W), (4,)]
        ...
    ]
    """

    def __call__(self, batch):
        output = [np.concatenate(d, axis=0) for d in zip(*batch)]
        return output
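For reference, a self-contained sketch of what this collate function does to a batch produced by SSLRotateResize (shapes follow the [3, 48, 320] setting in the config above):

import numpy as np

class SSLRotateCollate(object):
    # same behaviour as the class added above: concatenate each field of the
    # batch along axis 0, flattening the per-sample rotation dimension
    def __call__(self, batch):
        return [np.concatenate(d, axis=0) for d in zip(*batch)]

# each sample is a (4, 3, H, W) stack of rotated views plus its (4,) labels
sample = [np.zeros((4, 3, 48, 320), dtype="float32"), np.arange(4)]
images, labels = SSLRotateCollate()([sample, sample])
print(images.shape, labels.shape)  # (8, 3, 48, 320) (8,)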
@@ -24,6 +24,7 @@ from .make_pse_gt import MakePseGt

from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .ColorJitter import ColorJitter
@@ -0,0 +1,60 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import cv2
import numpy as np
import random
from PIL import Image

from .rec_img_aug import resize_norm_img


class SSLRotateResize(object):
    def __init__(self,
                 image_shape,
                 padding=False,
                 select_all=True,
                 mode="train",
                 **kwargs):
        self.image_shape = image_shape
        self.padding = padding
        self.select_all = select_all
        self.mode = mode

    def __call__(self, data):
        img = data["image"]

        data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        data["image_r180"] = cv2.rotate(data["image_r90"],
                                        cv2.ROTATE_90_CLOCKWISE)
        data["image_r270"] = cv2.rotate(data["image_r180"],
                                        cv2.ROTATE_90_CLOCKWISE)

        images = []
        for key in ["image", "image_r90", "image_r180", "image_r270"]:
            images.append(
                resize_norm_img(
                    data.pop(key),
                    image_shape=self.image_shape,
                    padding=self.padding)[0])
        data["image"] = np.stack(images, axis=0)
        data["label"] = np.array(list(range(4)))
        if not self.select_all:
            data["image"] = data["image"][0::2]  # just choose 0 and 180
            data["label"] = data["label"][0:2]  # label needs to be continuous
        if self.mode == "test":
            data["image"] = data["image"][0]
            data["label"] = data["label"][0]
        return data
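A short usage sketch of the transform above, assuming a PaddleOCR checkout with this patch on the Python path and the usual ppocr/data/imaug package layout (the import relies on the line added to the imaug __init__ earlier in this diff):

import numpy as np
from ppocr.data.imaug import SSLRotateResize  # exposed by the import added above

op = SSLRotateResize(image_shape=[3, 48, 320], select_all=True, mode="train")
data = {"image": np.random.randint(0, 255, (48, 160, 3), dtype=np.uint8)}
out = op(data)
print(out["image"].shape)  # (4, 3, 48, 320): the 0/90/180/270-degree views
print(out["label"])        # [0 1 2 3], one rotation class per view
# with mode="test" (set by the infer script change at the end of this diff)
# only the un-rotated view and label 0 are returned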
@@ -17,17 +17,26 @@ import paddle
 class ClsPostProcess(object):
     """ Convert between text-label and text-index """
 
-    def __init__(self, label_list, **kwargs):
+    def __init__(self, label_list=None, key=None, **kwargs):
         super(ClsPostProcess, self).__init__()
         self.label_list = label_list
+        self.key = key
 
     def __call__(self, preds, label=None, *args, **kwargs):
+        if self.key is not None:
+            preds = preds[self.key]
+
+        label_list = self.label_list
+        if label_list is None:
+            label_list = {idx: idx for idx in range(preds.shape[-1])}
+
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()
+
         pred_idxs = preds.argmax(axis=1)
-        decode_out = [(self.label_list[idx], preds[i, idx])
+        decode_out = [(label_list[idx], preds[i, idx])
                       for i, idx in enumerate(pred_idxs)]
         if label is None:
             return decode_out
-        label = [(self.label_list[idx], 1.0) for idx in label]
+        label = [(label_list[idx], 1.0) for idx in label]
         return decode_out, label
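With label_list=None the decoder now falls back to an identity mapping over class indices, which is what the four-way rotation head needs since the rotnet config above sets no label list; key lets the post-process pick one entry out of a dict of model outputs. A minimal sketch of the new default path (the module path below is assumed, it is not named in this hunk):

import numpy as np
from ppocr.postprocess.cls_postprocess import ClsPostProcess  # assumed location

post = ClsPostProcess()  # no label_list: indices decode to themselves
preds = np.array([[0.1, 0.7, 0.1, 0.1],
                  [0.8, 0.1, 0.05, 0.05]])
print(post(preds))  # [(1, 0.7), (0, 0.8)]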
@@ -31,7 +31,7 @@ from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser
 
 
-def export_single_model(model, arch_config, save_path, logger):
+def export_single_model(model, arch_config, save_path, logger, quanter=None):
     if arch_config["algorithm"] == "SRN":
         max_text_length = arch_config["Head"]["max_text_length"]
         other_shape = [
@@ -95,7 +95,10 @@ def export_single_model(model, arch_config, save_path, logger):
                 shape=[None] + infer_shape, dtype="float32")
         ])
 
-    paddle.jit.save(model, save_path)
+    if quanter is None:
+        paddle.jit.save(model, save_path)
+    else:
+        quanter.save_quantized_model(model, save_path)
     logger.info("inference model is saved to {}".format(save_path))
     return
 
@@ -125,7 +128,6 @@ def main():
                     char_num = char_num - 2
                 out_channels_list['CTCLabelDecode'] = char_num
                 out_channels_list['SARLabelDecode'] = char_num + 2
-                loss_list = config['Loss']['loss_config_list']
                 config['Architecture']['Models'][key]['Head'][
                     'out_channels_list'] = out_channels_list
             else:
@@ -57,6 +57,8 @@ def main():
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        elif op_name == "SSLRotateResize":
            op[op_name]["mode"] = "test"
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)