Mirror of https://github.com/alibaba/EasyCV.git
from easycv.datasets.registry import PIPELINES


@PIPELINES.register_module()
class TextTokenizer:
    """Tokenize the raw text in ``results['text']`` with a HuggingFace
    BertTokenizerFast and write the token ids and attention mask back
    into the results dict.

    Args:
        tokenizer_type (str): Name or path of the pretrained tokenizer.
        max_length (int): Maximum sequence length.
        padding (str): Padding strategy passed to the tokenizer.
        truncation (bool): Whether to truncate sequences longer than
            ``max_length``.
    """

    def __init__(
        self,
        tokenizer_type='bert-base-chinese',
        max_length=50,
        padding='max_length',
        truncation=True,
    ):
        # Import lazily so transformers is only required when this
        # pipeline is actually instantiated.
        from transformers import BertTokenizerFast
        self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_type)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(self, results):
        text = results['text']
        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            padding=self.padding,
            return_tensors='pt',
            truncation=self.truncation)

        # Flatten the (1, max_length) tensors to 1-D so downstream
        # collate functions can stack samples along the batch dimension.
        results['text_input_ids'] = tokens.input_ids.reshape([-1])
        results['text_input_mask'] = tokens.attention_mask.reshape([-1])
        return results
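
A minimal usage sketch, not part of the source file: it assumes transformers is installed and that the 'bert-base-chinese' tokenizer files can be downloaded or are cached locally; the sample text and the printed shapes are only illustrative.

# Illustrative only: run the pipeline step on a single sample dict.
tokenizer = TextTokenizer(max_length=50)
results = tokenizer({'text': '一只可爱的猫'})
print(results['text_input_ids'].shape)   # torch.Size([50]) after max_length padding
print(results['text_input_mask'].shape)  # torch.Size([50])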