rec_r45_abinet.yml add max_length and image_size (#10744)
* rec_r45_abinet.yml add max_length and image_shape * image_shape to image_sizepull/10847/head
parent
66461c3325
commit
e3cd343341
|
@ -16,7 +16,7 @@ Global:
|
|||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
max_text_length: 25
|
||||
max_text_length: &max_text_length 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_abinet.txt
|
||||
|
@ -45,6 +45,8 @@ Architecture:
|
|||
name: ABINetHead
|
||||
use_lang: True
|
||||
iter_size: 3
|
||||
max_length: *max_text_length
|
||||
image_size: [ &h 32, &w 128 ] # [ h, w ]
|
||||
|
||||
|
||||
Loss:
|
||||
|
@ -70,7 +72,7 @@ Train:
|
|||
- ABINetLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- ABINetRecResizeImg:
|
||||
image_shape: [3, 32, 128]
|
||||
image_shape: [3, *h, *w]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
|
@ -90,7 +92,7 @@ Eval:
|
|||
- ABINetLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- ABINetRecResizeImg:
|
||||
image_shape: [3, 32, 128]
|
||||
image_shape: [3, *h, *w]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
|
|
|
@ -182,11 +182,13 @@ class ABINetHead(nn.Layer):
|
|||
dropout=0.1,
|
||||
max_length=25,
|
||||
use_lang=False,
|
||||
iter_size=1):
|
||||
iter_size=1,
|
||||
image_size=(32, 128)):
|
||||
super().__init__()
|
||||
self.max_length = max_length + 1
|
||||
h, w = image_size[0] // 4, image_size[1] // 4
|
||||
self.pos_encoder = PositionalEncoding(
|
||||
dropout=0.1, dim=d_model, max_len=8 * 32)
|
||||
dropout=0.1, dim=d_model, max_len=h * w)
|
||||
self.encoder = nn.LayerList([
|
||||
TransformerBlock(
|
||||
d_model=d_model,
|
||||
|
@ -199,7 +201,7 @@ class ABINetHead(nn.Layer):
|
|||
])
|
||||
self.decoder = PositionAttention(
|
||||
max_length=max_length + 1, # additional stop token
|
||||
mode='nearest', )
|
||||
mode='nearest', h=h, w=w)
|
||||
self.out_channels = out_channels
|
||||
self.cls = nn.Linear(d_model, self.out_channels)
|
||||
self.use_lang = use_lang
|
||||
|
|
Loading…
Reference in New Issue