mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
polist seed code
This commit is contained in:
parent
1396186815
commit
a5280c0f40
@ -75,7 +75,7 @@ Train:
|
|||||||
channel_first: False
|
channel_first: False
|
||||||
- SEEDLabelEncode: # Class handling label
|
- SEEDLabelEncode: # Class handling label
|
||||||
- RecResizeImg:
|
- RecResizeImg:
|
||||||
character_type: en
|
character_dict_path:
|
||||||
image_shape: [3, 64, 256]
|
image_shape: [3, 64, 256]
|
||||||
padding: False
|
padding: False
|
||||||
- KeepKeys:
|
- KeepKeys:
|
||||||
@ -96,7 +96,7 @@ Eval:
|
|||||||
channel_first: False
|
channel_first: False
|
||||||
- SEEDLabelEncode: # Class handling label
|
- SEEDLabelEncode: # Class handling label
|
||||||
- RecResizeImg:
|
- RecResizeImg:
|
||||||
character_type: en
|
character_dict_path:
|
||||||
image_shape: [3, 64, 256]
|
image_shape: [3, 64, 256]
|
||||||
padding: False
|
padding: False
|
||||||
- KeepKeys:
|
- KeepKeys:
|
||||||
|
@ -344,8 +344,12 @@ class SEEDLabelEncode(BaseRecLabelEncode):
|
|||||||
max_text_length, character_dict_path, use_space_char)
|
max_text_length, character_dict_path, use_space_char)
|
||||||
|
|
||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
|
self.padding = "padding"
|
||||||
self.end_str = "eos"
|
self.end_str = "eos"
|
||||||
dict_character = dict_character + [self.end_str]
|
self.unknown = "unknown"
|
||||||
|
dict_character = dict_character + [
|
||||||
|
self.end_str, self.padding, self.unknown
|
||||||
|
]
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
@ -356,8 +360,8 @@ class SEEDLabelEncode(BaseRecLabelEncode):
|
|||||||
if len(text) >= self.max_text_len:
|
if len(text) >= self.max_text_len:
|
||||||
return None
|
return None
|
||||||
data['length'] = np.array(len(text)) + 1 # conclude eos
|
data['length'] = np.array(len(text)) + 1 # conclude eos
|
||||||
text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
|
text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
|
||||||
)
|
self.max_text_len - len(text) - 1)
|
||||||
data['label'] = np.array(text)
|
data['label'] = np.array(text)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class AsterHead(nn.Layer):
|
|||||||
self.time_step = time_step
|
self.time_step = time_step
|
||||||
self.embeder = Embedding(self.time_step, in_channels)
|
self.embeder = Embedding(self.time_step, in_channels)
|
||||||
self.beam_width = beam_width
|
self.beam_width = beam_width
|
||||||
self.eos = self.num_classes - 1
|
self.eos = self.num_classes - 3
|
||||||
|
|
||||||
def forward(self, x, targets=None, embed=None):
|
def forward(self, x, targets=None, embed=None):
|
||||||
return_dict = {}
|
return_dict = {}
|
||||||
|
@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode):
|
|||||||
use_space_char)
|
use_space_char)
|
||||||
|
|
||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
self.beg_str = "sos"
|
self.padding_str = "padding"
|
||||||
self.end_str = "eos"
|
self.end_str = "eos"
|
||||||
dict_character = dict_character + [self.end_str]
|
self.unknown = "unknown"
|
||||||
|
dict_character = dict_character + [
|
||||||
|
self.end_str, self.padding_str, self.unknown
|
||||||
|
]
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
def get_ignored_tokens(self):
|
def get_ignored_tokens(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user