add d2s train for slanet and v3 (#9341)
* add d2s train for slanet and v3 * fix bugpull/9379/head^2
parent
623424fce0
commit
2e05d54af8
|
@ -17,6 +17,7 @@ Global:
|
|||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||
distributed: true
|
||||
d2s_train_image_shape: [3, -1, -1]
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
|
|
|
@ -12,6 +12,7 @@ Global:
|
|||
use_visualdl: False
|
||||
seed: 2022
|
||||
infer_img: ppstructure/docs/kie/input/zh_val_42.jpg
|
||||
d2s_train_image_shape: [3, 224, 224]
|
||||
# if you want to predict using the groundtruth ocr info,
|
||||
# you can use the following config
|
||||
# infer_img: train_data/XFUND/zh_val/val.json
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
|
||||
d2s_train_image_shape: [3, 48, -1]
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -21,6 +21,7 @@ Global:
|
|||
infer_mode: False
|
||||
use_sync_bn: True
|
||||
save_res_path: 'output/infer'
|
||||
d2s_train_image_shape: [3, -1, -1]
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -17,6 +17,7 @@ Global:
|
|||
infer_mode: false
|
||||
max_text_length: &max_text_length 500
|
||||
box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'
|
||||
d2s_train_image_shape: [3, 480, 480]
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -38,9 +38,9 @@ def build_model(config):
|
|||
def apply_to_static(model, config, logger):
|
||||
if config["Global"].get("to_static", False) is not True:
|
||||
return model
|
||||
assert "image_shape" in config[
|
||||
"Global"], "image_shape must be assigned for static training mode..."
|
||||
supported_list = ["DB", "SVTR_LCNet", "TableMaster"]
|
||||
assert "d2s_train_image_shape" in config[
|
||||
"Global"], "d2s_train_image_shape must be assigned for static training mode..."
|
||||
supported_list = ["DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet"]
|
||||
if config["Architecture"]["algorithm"] in ["Distillation"]:
|
||||
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
|
||||
else:
|
||||
|
@ -49,7 +49,7 @@ def apply_to_static(model, config, logger):
|
|||
|
||||
specs = [
|
||||
InputSpec(
|
||||
[None] + config["Global"]["image_shape"], dtype='float32')
|
||||
[None] + config["Global"]["d2s_train_image_shape"], dtype='float32')
|
||||
]
|
||||
|
||||
if algo == "SVTR_LCNet":
|
||||
|
@ -62,7 +62,7 @@ def apply_to_static(model, config, logger):
|
|||
[None], dtype='int64'), InputSpec(
|
||||
[None], dtype='float64')
|
||||
])
|
||||
if algo == "TableMaster":
|
||||
elif algo == "TableMaster":
|
||||
specs.append(
|
||||
[
|
||||
InputSpec(
|
||||
|
@ -76,6 +76,34 @@ def apply_to_static(model, config, logger):
|
|||
InputSpec(
|
||||
[None, 6], dtype='float32'),
|
||||
])
|
||||
elif algo == "LayoutXLM":
|
||||
specs = [[
|
||||
InputSpec(
|
||||
shape=[None, 512], dtype="int64"), # input_ids
|
||||
InputSpec(
|
||||
shape=[None, 512, 4], dtype="int64"), # bbox
|
||||
InputSpec(
|
||||
shape=[None, 512], dtype="int64"), # attention_mask
|
||||
InputSpec(
|
||||
shape=[None, 512], dtype="int64"), # token_type_ids
|
||||
InputSpec(
|
||||
shape=[None, 3, 224, 224], dtype="float32"), # image
|
||||
InputSpec(
|
||||
shape=[None, 512], dtype="int64"), # label
|
||||
]]
|
||||
elif algo == "SLANet":
|
||||
specs.append([
|
||||
InputSpec(
|
||||
[None, config["Global"]["max_text_length"] + 2], dtype='int64'),
|
||||
InputSpec(
|
||||
[None, config["Global"]["max_text_length"] + 2, 4],
|
||||
dtype='float32'),
|
||||
InputSpec(
|
||||
[None, config["Global"]["max_text_length"] + 2, 1],
|
||||
dtype='float32'),
|
||||
InputSpec(
|
||||
[None, 6], dtype='float64'),
|
||||
])
|
||||
model = to_static(model, input_spec=specs)
|
||||
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
|
||||
return model
|
||||
|
|
|
@ -20,6 +20,8 @@ from tqdm import tqdm
|
|||
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
MODELS_DIR = os.path.expanduser("~/.paddleocr/models/")
|
||||
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
logger = get_logger()
|
||||
|
|
|
@ -17,7 +17,7 @@ norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o
|
|||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
to_static_train:Global.to_static=true
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
|
||||
d2s_train_image_shape: [3, 48, -1]
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_d
|
|||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
to_static_train:Global.to_static=true
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
|
|
|
@ -21,6 +21,7 @@ Global:
|
|||
infer_mode: False
|
||||
use_sync_bn: True
|
||||
save_res_path: 'output/infer'
|
||||
d2s_train_image_shape: [3, -1, -1]
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o Global.print
|
|||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
to_static_train:Global.to_static=true
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
|
|
|
@ -16,7 +16,7 @@ Global:
|
|||
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
|
||||
infer_mode: false
|
||||
max_text_length: 500
|
||||
image_shape: [3, 480, 480]
|
||||
d2s_train_image_shape: [3, 480, 480]
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -17,7 +17,7 @@ norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_z
|
|||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
to_static_train:Global.to_static=true
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
|
|
Loading…
Reference in New Issue