mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
32 lines
908 B
Python
32 lines
908 B
Python
from easycv.datasets.registry import PIPELINES
|
|
|
|
|
|
@PIPELINES.register_module()
class TextTokenizer:
    """Tokenize ``results['text']`` with a HuggingFace ``BertTokenizerFast``.

    The tokenizer output is written back into ``results`` as
    ``text_input_ids`` and ``text_input_mask`` (both flattened with
    ``reshape([-1])``).

    Args:
        tokenizer_type (str): Model name or path passed to
            ``BertTokenizerFast.from_pretrained``.
        max_length (int): ``max_length`` forwarded to the tokenizer.
        padding (bool | str): Padding strategy forwarded to the tokenizer.
        truncation (bool | str): Truncation strategy forwarded to the
            tokenizer.
    """

    def __init__(
        self,
        tokenizer_type='bert-base-chinese',
        max_length=50,
        padding='max_length',
        truncation=True,
    ):
        # Function-scope import: `transformers` is only needed once the
        # pipeline step is actually instantiated.
        from transformers import BertTokenizerFast

        self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_type)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(self, results):
        """Tokenize ``results['text']`` and store the ids and mask.

        Args:
            results (dict): Pipeline results; must contain the key
                ``'text'``. Presumably a single string per sample; verify
                against the dataset that feeds this pipeline.

        Returns:
            dict: The same ``results`` dict, updated in place with
            ``'text_input_ids'`` and ``'text_input_mask'``.
        """
        text = results['text']
        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            padding=self.padding,
            return_tensors='pt',
            # Use the configured strategy. This was previously hard-coded to
            # True, which silently ignored the `truncation` constructor arg.
            truncation=self.truncation)

        # With return_tensors='pt' and a single input string, the tokenizer
        # returns a (1, seq_len) batch; flatten it to one dimension. Note that
        # a list of strings would be flattened across the batch as well.
        results['text_input_ids'] = tokens.input_ids.reshape([-1])
        results['text_input_mask'] = tokens.attention_mask.reshape([-1])

        return results
|