Add a text translation example (#1283)
parent d9fee4fbb1
commit 398d229910

examples/text_translation.py (new file, +120 lines)
@@ -0,0 +1,120 @@
import numpy as np
import torch
from datasets import load_dataset
from torchtext.data.metrics import bleu_score
from transformers import AutoTokenizer, T5ForConditionalGeneration

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner

tokenizer = AutoTokenizer.from_pretrained('t5-small')

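
# MMEngine's BaseModel dispatches on the ``mode`` argument of ``forward``:
# in 'loss' mode it must return a dict of losses, and in 'predict' mode it
# returns predictions that are forwarded to the evaluator. The wrapper below
# maps these two modes onto the Hugging Face T5 API.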
class MMT5ForTranslation(BaseModel):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, label, input_ids, attention_mask, mode):
        if mode == 'loss':
            output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=label)
            return {'loss': output.loss}
        elif mode == 'predict':
            output = self.model.generate(input_ids)
            return output, label

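
# Decode token ids back to text for metric computation: the -100 sentinel in
# the labels is replaced with the real pad token id before decoding, and both
# predictions and references are whitespace-tokenized because torchtext's
# ``bleu_score`` expects candidate token lists and lists of reference token
# lists.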
def post_process(preds, labels):
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = torch.where(labels != -100, labels, tokenizer.pad_token_id)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds = [pred.split() for pred in preds]
    decoded_labels = [[label.split()] for label in labels]
    return decoded_preds, decoded_labels

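
# Despite its name, this metric reports BLEU and the mean generated length
# rather than classification accuracy: ``process`` stores one result dict per
# validation batch and ``compute_metrics`` averages them over the whole set.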
class Accuracy(BaseMetric):

    def process(self, data_batch, data_samples):
        outputs, labels = data_samples
        decoded_preds, decoded_labels = post_process(outputs, labels)
        score = bleu_score(decoded_preds, decoded_labels)
        prediction_lens = torch.tensor([
            torch.count_nonzero(pred != tokenizer.pad_token_id)
            for pred in outputs
        ],
                                       dtype=torch.float64)

        gen_len = torch.mean(prediction_lens).item()
        self.results.append({
            'gen_len': gen_len,
            'bleu': score,
        })

    def compute_metrics(self, results):
        return dict(
            gen_len=np.mean([item['gen_len'] for item in results]),
            bleu_score=np.mean([item['bleu'] for item in results]),
        )

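
# Collate raw opus_books records into model inputs: T5 expects the task
# prefix on every source sentence, batches are padded to the longest sequence,
# and pad positions in the labels are set to -100 so they do not contribute to
# the cross-entropy loss.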
def collate_fn(data):
    prefix = 'translate English to French: '
    input_sequences = [prefix + item['translation']['en'] for item in data]
    target_sequences = [item['translation']['fr'] for item in data]
    input_dict = tokenizer(
        input_sequences,
        padding='longest',
        return_tensors='pt',
    )

    label = tokenizer(
        target_sequences,
        padding='longest',
        return_tensors='pt',
    ).input_ids
    label[label == tokenizer.pad_token_id] = -100  # ignore contribution to loss
    return dict(
        label=label,
        input_ids=input_dict.input_ids,
        attention_mask=input_dict.attention_mask)

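
# The dataloaders are plain dicts that the Runner builds lazily, using the
# registered 'DefaultSampler'; training runs for two epochs with validation
# (and BLEU evaluation) after each epoch.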
def main():
    model = T5ForConditionalGeneration.from_pretrained('t5-small')

    books = load_dataset('opus_books', 'en-fr')
    books = books['train'].train_test_split(test_size=0.2)
    train_set, test_set = books['train'], books['test']

    train_loader = dict(
        batch_size=16,
        dataset=train_set,
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=collate_fn)
    test_loader = dict(
        batch_size=32,
        dataset=test_set,
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=collate_fn)
    runner = Runner(
        model=MMT5ForTranslation(model),
        train_dataloader=train_loader,
        val_dataloader=test_loader,
        optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)),
        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
        val_cfg=dict(),
        work_dir='t5_work_dir',
        val_evaluator=dict(type=Accuracy))

    runner.train()

if __name__ == '__main__':
    main()
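
# A minimal sketch of how to launch this example on a single device, assuming
# the dependencies imported above are installed:
#
#     python examples/text_translation.py
#
# Logs and checkpoints are written to the configured work_dir ('t5_work_dir').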