From 398d229910f02dbddfd4659ee53d5b570ea6a14c Mon Sep 17 00:00:00 2001 From: Desjajja Date: Mon, 7 Aug 2023 15:33:48 +0800 Subject: [PATCH] Add a text translation example (#1283) --- examples/text_translation.py | 120 +++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 examples/text_translation.py diff --git a/examples/text_translation.py b/examples/text_translation.py new file mode 100644 index 00000000..61f43baf --- /dev/null +++ b/examples/text_translation.py @@ -0,0 +1,120 @@ +import numpy as np +import torch +from datasets import load_dataset +from torchtext.data.metrics import bleu_score +from transformers import AutoTokenizer, T5ForConditionalGeneration + +from mmengine.evaluator import BaseMetric +from mmengine.model import BaseModel +from mmengine.runner import Runner + +tokenizer = AutoTokenizer.from_pretrained('t5-small') + + +class MMT5ForTranslation(BaseModel): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, label, input_ids, attention_mask, mode): + if mode == 'loss': + output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=label) + return {'loss': output.loss} + elif mode == 'predict': + output = self.model.generate(input_ids) + return output, label + + +def post_process(preds, labels): + preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + labels = torch.where(labels != -100, labels, tokenizer.pad_token_id) + labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + decoded_preds = [pred.split() for pred in preds] + decoded_labels = [[label.split()] for label in labels] + return decoded_preds, decoded_labels + + +class Accuracy(BaseMetric): + + def process(self, data_batch, data_samples): + outputs, labels = data_samples + decoded_preds, decoded_labels = post_process(outputs, labels) + score = bleu_score(decoded_preds, decoded_labels) + prediction_lens = torch.tensor([ + torch.count_nonzero(pred != tokenizer.pad_token_id) + for pred in outputs + ], + dtype=torch.float64) + + gen_len = torch.mean(prediction_lens).item() + self.results.append({ + 'gen_len': gen_len, + 'bleu': score, + }) + + def compute_metrics(self, results): + return dict( + gen_len=np.mean([item['gen_len'] for item in results]), + bleu_score=np.mean([item['bleu'] for item in results]), + ) + + +def collate_fn(data): + prefix = 'translate English to French: ' + input_sequences = [prefix + item['translation']['en'] for item in data] + target_sequences = [item['translation']['fr'] for item in data] + input_dict = tokenizer( + input_sequences, + padding='longest', + return_tensors='pt', + ) + + label = tokenizer( + target_sequences, + padding='longest', + return_tensors='pt', + ).input_ids + label[label == + tokenizer.pad_token_id] = -100 # ignore contribution to loss + return dict( + label=label, + input_ids=input_dict.input_ids, + attention_mask=input_dict.attention_mask) + + +def main(): + model = T5ForConditionalGeneration.from_pretrained('t5-small') + + books = load_dataset('opus_books', 'en-fr') + books = books['train'].train_test_split(test_size=0.2) + train_set, test_set = books['train'], books['test'] + + train_loader = dict( + batch_size=16, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=collate_fn) + test_loader = dict( + batch_size=32, + dataset=test_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=collate_fn) + runner = Runner( + model=MMT5ForTranslation(model), + train_dataloader=train_loader, + val_dataloader=test_loader, + optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)), + train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), + val_cfg=dict(), + work_dir='t5_work_dir', + val_evaluator=dict(type=Accuracy)) + + runner.train() + + +if __name__ == '__main__': + main()