polist seed code

pull/4960/head
tink2123 2021-12-17 21:42:53 +08:00
parent 1396186815
commit a5280c0f40
4 changed files with 15 additions and 8 deletions

View File

@ -75,7 +75,7 @@ Train:
channel_first: False
- SEEDLabelEncode: # Class handling label
- RecResizeImg:
character_type: en
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
@ -96,7 +96,7 @@ Eval:
channel_first: False
- SEEDLabelEncode: # Class handling label
- RecResizeImg:
character_type: en
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:

View File

@ -344,8 +344,12 @@ class SEEDLabelEncode(BaseRecLabelEncode):
max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.padding = "padding"
self.end_str = "eos"
dict_character = dict_character + [self.end_str]
self.unknown = "unknown"
dict_character = dict_character + [
self.end_str, self.padding, self.unknown
]
return dict_character
def __call__(self, data):
@ -356,8 +360,8 @@ class SEEDLabelEncode(BaseRecLabelEncode):
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text)) + 1 # conclude eos
text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
)
text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
self.max_text_len - len(text) - 1)
data['label'] = np.array(text)
return data

View File

@ -47,7 +47,7 @@ class AsterHead(nn.Layer):
self.time_step = time_step
self.embeder = Embedding(self.time_step, in_channels)
self.beam_width = beam_width
self.eos = self.num_classes - 1
self.eos = self.num_classes - 3
def forward(self, x, targets=None, embed=None):
return_dict = {}

View File

@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode):
use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.padding_str = "padding"
self.end_str = "eos"
dict_character = dict_character + [self.end_str]
self.unknown = "unknown"
dict_character = dict_character + [
self.end_str, self.padding_str, self.unknown
]
return dict_character
def get_ignored_tokens(self):