rec_r45_abinet.yml add max_length and image_size (#10744)

* rec_r45_abinet.yml add max_length and image_shape

* image_shape to image_size
pull/10847/head
xlg-go 2023-08-31 14:23:47 +08:00 committed by GitHub
parent 66461c3325
commit e3cd343341
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 8 deletions

View File

@@ -16,7 +16,7 @@ Global:
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
max_text_length: &max_text_length 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_abinet.txt
@@ -45,6 +45,8 @@ Architecture:
name: ABINetHead
use_lang: True
iter_size: 3
max_length: *max_text_length
image_size: [ &h 32, &w 128 ] # [ h, w ]
Loss:
@@ -70,7 +72,7 @@ Train:
- ABINetLabelEncode: # Class handling label
ignore_index: *ignore_index
- ABINetRecResizeImg:
image_shape: [3, 32, 128]
image_shape: [3, *h, *w]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
@@ -90,7 +92,7 @@ Eval:
- ABINetLabelEncode: # Class handling label
ignore_index: *ignore_index
- ABINetRecResizeImg:
image_shape: [3, 32, 128]
image_shape: [3, *h, *w]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:

View File

@@ -182,11 +182,13 @@ class ABINetHead(nn.Layer):
dropout=0.1,
max_length=25,
use_lang=False,
iter_size=1):
iter_size=1,
image_size=(32, 128)):
super().__init__()
self.max_length = max_length + 1
h, w = image_size[0] // 4, image_size[1] // 4
self.pos_encoder = PositionalEncoding(
dropout=0.1, dim=d_model, max_len=8 * 32)
dropout=0.1, dim=d_model, max_len=h * w)
self.encoder = nn.LayerList([
TransformerBlock(
d_model=d_model,
@@ -199,7 +201,7 @@ class ABINetHead(nn.Layer):
])
self.decoder = PositionAttention(
max_length=max_length + 1, # additional stop token
mode='nearest', )
mode='nearest', h=h, w=w)
self.out_channels = out_channels
self.cls = nn.Linear(d_model, self.out_channels)
self.use_lang = use_lang