[Fix] #1428: remove max_seq_len inconsistency (#1433)

pull/1437/head
Antonio Lanza 2022-10-09 08:05:33 +02:00 committed by GitHub
parent b422dedd8d
commit 5fc920495a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 80 additions and 25 deletions

View File

@ -11,7 +11,7 @@ label_convertor = dict(
with_unknown=False,
with_padding=False,
lower=True,
)
max_seq_len=max_seq_len)
model = dict(
type='ABINet',

View File

@ -1,5 +1,10 @@
max_seq_len = 30
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
type='AttnConvertor',
dict_type='DICT90',
with_unknown=True,
max_seq_len=max_seq_len)
model = dict(
type='MASTER',
@ -58,4 +63,4 @@ model = dict(
feat_size=6 * 40),
loss=dict(type='TFLoss', reduction='mean'),
label_convertor=label_convertor,
max_seq_len=30)
max_seq_len=max_seq_len)

View File

@ -1,5 +1,11 @@
max_seq_len = 40
label_convertor = dict(
type='AttnConvertor', dict_type='DICT36', with_unknown=True, lower=True)
type='AttnConvertor',
dict_type='DICT36',
with_unknown=True,
lower=True,
max_seq_len=max_seq_len)
model = dict(
type='NRTR',
@ -8,4 +14,4 @@ model = dict(
decoder=dict(type='NRTRDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)
max_seq_len=max_seq_len)

View File

@ -1,5 +1,10 @@
max_seq_len = 30
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
type='AttnConvertor',
dict_type='DICT90',
with_unknown=True,
max_seq_len=max_seq_len)
hybrid_decoder = dict(type='SequenceAttentionDecoder')
@ -21,4 +26,4 @@ model = dict(
position_decoder=position_decoder),
loss=dict(type='SARLoss'),
label_convertor=label_convertor,
max_seq_len=30)
max_seq_len=max_seq_len)

View File

@ -1,5 +1,10 @@
max_seq_len = 30
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
type='AttnConvertor',
dict_type='DICT90',
with_unknown=True,
max_seq_len=max_seq_len)
model = dict(
type='SARNet',
@ -21,4 +26,4 @@ model = dict(
pred_concat=True),
loss=dict(type='SARLoss'),
label_convertor=label_convertor,
max_seq_len=30)
max_seq_len=max_seq_len)

View File

@ -1,5 +1,11 @@
max_seq_len = 40
label_convertor = dict(
type='AttnConvertor', dict_type='DICT36', with_unknown=True, lower=True)
type='AttnConvertor',
dict_type='DICT36',
with_unknown=True,
lower=True,
max_seq_len=max_seq_len)
model = dict(
type='SATRN',
@ -8,4 +14,4 @@ model = dict(
decoder=dict(type='TFDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)
max_seq_len=max_seq_len)

View File

@ -21,7 +21,7 @@ label_convertor = dict(
with_unknown=False,
with_padding=False,
lower=True,
)
max_seq_len=max_seq_len)
model = dict(
type='ABINet',

View File

@ -12,8 +12,13 @@ test_list = {{_base_.test_list}}
train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}
max_seq_len = 40
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
type='AttnConvertor',
dict_type='DICT90',
with_unknown=True,
max_seq_len=max_seq_len)
model = dict(
type='NRTR',
@ -27,7 +32,7 @@ model = dict(
decoder=dict(type='NRTRDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)
max_seq_len=max_seq_len)
data = dict(
samples_per_gpu=128,

View File

@ -12,8 +12,13 @@ test_list = {{_base_.test_list}}
train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}
max_seq_len = 40
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
type='AttnConvertor',
dict_type='DICT90',
with_unknown=True,
max_seq_len=max_seq_len)
model = dict(
type='NRTR',
@ -27,7 +32,7 @@ model = dict(
decoder=dict(type='NRTRDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)
max_seq_len=max_seq_len)
data = dict(
samples_per_gpu=64,

View File

@ -2,10 +2,13 @@ _base_ = [
'../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_adam_step_5e.py'
]
max_seq_len = 30
dict_file = 'data/chineseocr/labels/dict_printed_chinese_english_digits.txt'
label_convertor = dict(
type='AttnConvertor', dict_file=dict_file, with_unknown=True)
type='AttnConvertor',
dict_file=dict_file,
with_unknown=True,
max_seq_len=max_seq_len)
model = dict(
type='SARNet',
@ -27,7 +30,7 @@ model = dict(
pred_concat=True),
loss=dict(type='SARLoss'),
label_convertor=label_convertor,
max_seq_len=30)
max_seq_len=max_seq_len)
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [

View File

@ -12,8 +12,13 @@ test_list = {{_base_.test_list}}
train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}
max_seq_len = 30
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
type='AttnConvertor',
dict_type='DICT90',
with_unknown=True,
max_seq_len=max_seq_len)
model = dict(
type='SARNet',
@ -35,7 +40,7 @@ model = dict(
pred_concat=True),
loss=dict(type='SARLoss'),
label_convertor=label_convertor,
max_seq_len=30)
max_seq_len=max_seq_len)
data = dict(
samples_per_gpu=64,

View File

@ -12,8 +12,13 @@ test_list = {{_base_.test_list}}
train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}
max_seq_len = 25
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
type='AttnConvertor',
dict_type='DICT90',
with_unknown=True,
max_seq_len=max_seq_len)
model = dict(
type='SATRN',
@ -39,7 +44,7 @@ model = dict(
d_v=512 // 8),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=25)
max_seq_len=max_seq_len)
# optimizer
optimizer = dict(type='Adam', lr=3e-4)

View File

@ -12,8 +12,13 @@ test_list = {{_base_.test_list}}
train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}
max_seq_len = 25
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
type='AttnConvertor',
dict_type='DICT90',
with_unknown=True,
max_seq_len=max_seq_len)
model = dict(
type='SATRN',
@ -39,7 +44,7 @@ model = dict(
d_v=256 // 8),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=25)
max_seq_len=max_seq_len)
# optimizer
optimizer = dict(type='Adam', lr=3e-4)