# Copyright (c) OpenMMLab. All rights reserved.
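"""Fine-tune a HuggingFace BERT classifier on IMDB with the MMEngine Runner.

``BertForSequenceClassification`` is wrapped in an MMEngine ``BaseModel`` so
the ``Runner`` can drive training, validation and (optionally) distributed
execution selected via the ``--launcher`` flag.
"""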
import argparse

import torch
from datasets import load_dataset
from transformers import BertForSequenceClassification, BertTokenizer

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner


class MMBertForClassify(BaseModel):
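    """Adapt a HuggingFace sequence-classification model to MMEngine.

    ``BaseModel.forward`` receives an extra ``mode`` argument: ``'loss'``
    during training and ``'predict'`` during validation.
    """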

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, label, input_ids, token_type_ids, attention_mask,
                mode):
        # HuggingFace models compute the loss themselves when ``labels``
        # is provided.
        output = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=label)
        if mode == 'loss':
            return {'loss': output.loss}
        elif mode == 'predict':
            return output.logits, label


class Accuracy(BaseMetric):
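    """Top-1 accuracy metric for the validation loop.

    ``process`` stores per-batch statistics in ``self.results``;
    ``compute_metrics`` reduces them to a single percentage.
    """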

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


def parse_args():
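    """Parse the launcher options used for (optionally) distributed runs."""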
    parser = argparse.ArgumentParser(description='Distributed Training')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # ``--local_rank`` is injected by ``torch.distributed.launch``.
    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()
    return args


def collate_fn(data):
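    """Stack a list of tokenized samples into batched tensors.

    Each item carries the ``input_ids``, ``token_type_ids`` and
    ``attention_mask`` fields produced by the tokenizer, plus the label.
    """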
    labels = []
    input_ids = []
    token_type_ids = []
    attention_mask = []
    for item in data:
        labels.append(item['label'])
        input_ids.append(torch.tensor(item['input_ids']))
        token_type_ids.append(torch.tensor(item['token_type_ids']))
        attention_mask.append(torch.tensor(item['attention_mask']))

    input_ids = torch.stack(input_ids)
    token_type_ids = torch.stack(token_type_ids)
    attention_mask = torch.stack(attention_mask)
    label = torch.tensor(labels)
    return dict(
        label=label,
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask)


def main():
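    """Build the model, data pipeline and Runner, then start training."""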
    args = parse_args()
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=2)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    train_set = load_dataset('imdb', split='train')
    test_set = load_dataset('imdb', split='test')

    # Tokenize the raw reviews once up front; ``batched=True`` lets the
    # tokenizer process many samples per call.
    train_set = train_set.map(
        lambda x: tokenizer(
            x['text'], truncation=True, padding=True, max_length=128),
        batched=True)
    test_set = test_set.map(
        lambda x: tokenizer(
            x['text'], truncation=True, padding=True, max_length=128),
        batched=True)

    # Dataloaders are passed as dicts; the Runner builds the actual
    # ``DataLoader``. ``DefaultSampler`` shards the data automatically
    # under a distributed launcher.
    train_loader = dict(
        batch_size=32,
        dataset=train_set,
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=collate_fn)
    test_loader = dict(
        batch_size=32,
        dataset=test_set,
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=collate_fn)
    runner = Runner(
        model=MMBertForClassify(model),
        train_dataloader=train_loader,
        val_dataloader=test_loader,
        optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)),
        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
        val_cfg=dict(),
        work_dir='bert_work_dir',
        val_evaluator=dict(type=Accuracy),
        launcher=args.launcher,
    )
    runner.train()


if __name__ == '__main__':
    main()