polist seed code
parent
1396186815
commit
a5280c0f40
|
@ -75,7 +75,7 @@ Train:
|
|||
channel_first: False
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
character_type: en
|
||||
character_dict_path:
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
|
@ -96,7 +96,7 @@ Eval:
|
|||
channel_first: False
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
character_type: en
|
||||
character_dict_path:
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
|
|
|
@ -344,8 +344,12 @@ class SEEDLabelEncode(BaseRecLabelEncode):
|
|||
max_text_length, character_dict_path, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
self.padding = "padding"
|
||||
self.end_str = "eos"
|
||||
dict_character = dict_character + [self.end_str]
|
||||
self.unknown = "unknown"
|
||||
dict_character = dict_character + [
|
||||
self.end_str, self.padding, self.unknown
|
||||
]
|
||||
return dict_character
|
||||
|
||||
def __call__(self, data):
|
||||
|
@ -356,8 +360,8 @@ class SEEDLabelEncode(BaseRecLabelEncode):
|
|||
if len(text) >= self.max_text_len:
|
||||
return None
|
||||
data['length'] = np.array(len(text)) + 1 # conclude eos
|
||||
text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
|
||||
)
|
||||
text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
|
||||
self.max_text_len - len(text) - 1)
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ class AsterHead(nn.Layer):
|
|||
self.time_step = time_step
|
||||
self.embeder = Embedding(self.time_step, in_channels)
|
||||
self.beam_width = beam_width
|
||||
self.eos = self.num_classes - 1
|
||||
self.eos = self.num_classes - 3
|
||||
|
||||
def forward(self, x, targets=None, embed=None):
|
||||
return_dict = {}
|
||||
|
|
|
@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode):
|
|||
use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.padding_str = "padding"
|
||||
self.end_str = "eos"
|
||||
dict_character = dict_character + [self.end_str]
|
||||
self.unknown = "unknown"
|
||||
dict_character = dict_character + [
|
||||
self.end_str, self.padding_str, self.unknown
|
||||
]
|
||||
return dict_character
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
|
|
Loading…
Reference in New Issue