Llama2 example (#1264)

pull/1276/head
Mashiro 2023-07-24 10:20:21 +08:00 committed by GitHub
parent 86387da4a5
commit f4f2555324
5 changed files with 250 additions and 2 deletions

README.md

@@ -58,10 +58,10 @@ English | [简体中文](README_zh-CN.md)
## What's New
v0.8.0 was released on 2023-06-30.
Highlights:
- Add an [example](./examples/llama2/) of fine-tuning Llama2.
- Support training with [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html?highlight=fsdp) and [DeepSpeed](https://www.deepspeed.ai/). Refer to [Training Large Models](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html) for detailed usage.
- Introduce the pure Python style configuration file:

README_zh-CN.md

@@ -62,6 +62,8 @@
Highlights:
- Add an [example](./examples/llama2/) of fine-tuning Llama2.
- Support training with [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html?highlight=fsdp) and [DeepSpeed](https://www.deepspeed.ai/). See [Training Large Models](https://mmengine.readthedocs.io/zh_cn/latest/common_usage/large_model_training.html) for usage details.
- Introduce the pure Python style configuration file:

examples/llama2/README.md

@@ -0,0 +1,42 @@
# Train Llama2 in MMEngine

## Setup env

Note: This example requires PyTorch 2.0+ and MMEngine 0.8.0+.

- Install MMEngine

  ```bash
  git clone https://github.com/open-mmlab/mmengine.git
  cd mmengine
  pip install -e . -v
  ```

- Install third-party dependencies

  ```bash
  pip install -U transformers accelerate tokenizers
  ```
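
After installation, a quick sanity check of the environment can be run. This is a minimal sketch; the version requirements come from the note above.

```python
import mmengine
import torch

# The example expects PyTorch 2.0+ and MMEngine 0.8.0+ (see the note above).
print('torch:', torch.__version__)
print('mmengine:', mmengine.__version__)
print('CUDA available:', torch.cuda.is_available())
```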
## Prepare data

```bash
mkdir data
wget https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json -O data/alpaca_data.json
```
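
The downloaded file is a JSON list of instruction-following records, each with `instruction`, `input` and `output` fields, which is the structure the training script relies on. A quick way to inspect it (a minimal sketch, assuming the file was saved to `data/alpaca_data.json` as above):

```python
import json

# Load the Alpaca instruction-tuning data downloaded in the previous step.
with open('data/alpaca_data.json') as f:
    data = json.load(f)

print(len(data))               # number of samples
print(sorted(data[0].keys()))  # ['input', 'instruction', 'output']
```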
## Prepare model

Download the model weights from https://huggingface.co/meta-llama/Llama-2-7b-hf. Note that access to the Llama 2 checkpoints on Hugging Face is gated and has to be requested first.
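
The training and inference scripts load the weights with `from_pretrained`, so any local directory containing the downloaded files works. A quick check that the download is usable (a sketch; `path/to/Llama-2-7b-hf` is a placeholder for your local directory):

```python
from transformers import LlamaTokenizer

# 'path/to/Llama-2-7b-hf' is a hypothetical local directory holding the
# downloaded checkpoint; pass the same path as ${model_weights} below.
tokenizer = LlamaTokenizer.from_pretrained('path/to/Llama-2-7b-hf')
print(tokenizer.encode('Hello, Llama 2!'))
```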
## Train

```bash
torchrun --nproc-per-node 8 examples/llama2/fsdp_finetune.py data/alpaca_data.json ${model_weights}
```

`${model_weights}` is the local directory that holds the downloaded Llama-2-7b-hf checkpoint.
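
`fsdp_finetune.py` keeps the effective per-GPU batch size at 4 (`ORI_BATCH_SIZE`) by accumulating gradients when a smaller `--batch-size` is passed. A worked example of that arithmetic (illustrative only):

```python
# Mirrors `accumulative_counts = ORI_BATCH_SIZE / args.batch_size` in the script.
ORI_BATCH_SIZE = 4   # reference per-GPU batch size
batch_size = 2       # value passed via --batch-size
accumulative_counts = ORI_BATCH_SIZE / batch_size
print(accumulative_counts)               # 2.0 -> two forward/backward passes per update
print(batch_size * accumulative_counts)  # 4.0 -> effective per-GPU batch size
```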
## Inference

```bash
python examples/llama2/generate.py ${checkpoints}
```

`${checkpoints}` is one of the directories saved during training, e.g. `work_dirs/epoch_3`.

examples/llama2/fsdp_finetune.py

@@ -0,0 +1,168 @@
import argparse
import copy
from functools import partial

import torch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.data import default_data_collator
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from mmengine import load
from mmengine._strategy import FSDPStrategy
from mmengine.dataset import DefaultSampler
from mmengine.dist.utils import is_main_process
from mmengine.optim import StepLR
from mmengine.utils import apply_to
from mmengine.visualization import Visualizer, WandbVisBackend

ORI_BATCH_SIZE = 4

PROMPT_DICT = {
    'prompt_input':
    ('Below is an instruction that describes a task, paired with an input '
     'that provides further context. '
     'Write a response that appropriately completes the request.\n\n'
     '### Instruction:\n{instruction}\n\n'
     '### Input:\n{input}\n\n### Response:'),
    'prompt_no_input':
    ('Below is an instruction that describes a task. '
     'Write a response that appropriately completes the request.\n\n'
     '### Instruction:\n{instruction}\n\n### Response:'),
}


# Modified from https://github.com/facebookresearch/llama-recipes/blob/main/ft_datasets/alpaca_dataset.py # noqa: E501
class AlpacaDataset(Dataset):

    def __init__(self, data_path, tokenizer, max_words=224):
        self.ann = load(data_path)
        self.max_words = max_words
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]
        if ann.get('input', '') == '':
            prompt = PROMPT_DICT['prompt_no_input'].format_map(ann)
        else:
            prompt = PROMPT_DICT['prompt_input'].format_map(ann)
        example = prompt + ann['output']
        prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64)
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)
        padding = self.max_words - example.shape[0]
        if padding > 0:
            example = torch.cat(
                (example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[:self.max_words]
        labels = copy.deepcopy(example)
        labels[:len(prompt)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = 0
        example_mask = example_mask.float()
        label_mask = label_mask.float()

        return {
            'input_ids': example,
            'labels': labels,
            'attention_mask': example_mask,
        }


def parse_args():
    parser = argparse.ArgumentParser(description='Train alpaca with llama2')
    parser.add_argument('data_root', type=str)
    parser.add_argument('checkpoint', type=str)
    parser.add_argument('--output-dir', type=str, default='work_dirs')
    parser.add_argument('--max-epoch', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--save-interval', type=int, default=500)
    args = parser.parse_args()
    return args


def train():
    args = parse_args()

    # Setup distributed related component in Strategy.
    strategy = FSDPStrategy(
        model_wrapper=dict(
            auto_wrap_policy=partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={LlamaDecoderLayer})),
        state_dict_cfg='full',
        env_kwargs=dict(randomness=dict(seed=42)))
    visualizer = Visualizer(
        name='mmengine',
        save_dir=args.output_dir,
        vis_backends=[dict(type=WandbVisBackend)])

    # Prepare model
    tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
    tokenizer.add_special_tokens({'pad_token': '<PAD>'})
    model = LlamaForCausalLM.from_pretrained(args.checkpoint)
    model.to(torch.bfloat16)
    model.train()

    # Prepare dataset
    train_dataset = AlpacaDataset(
        tokenizer=tokenizer, data_path=args.data_root)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=DefaultSampler(train_dataset, seed=0),
        collate_fn=default_data_collator,
        drop_last=True)

    # Get the prepared model, scheduler and optimizer from strategy
    epoch_length = len(train_dataloader)
    max_iters = epoch_length * args.max_epoch
    optim_cfg = dict(
        optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0),
        accumulative_counts=ORI_BATCH_SIZE / args.batch_size)
    scheduler_cfgs = [dict(type=StepLR, step_size=1, gamma=0.85)]
    model, optimizer, schedulers = strategy.prepare(
        model,
        optim_wrapper=optim_cfg,
        param_scheduler=scheduler_cfgs,
        dispatch_kwargs=dict(max_iters=max_iters, max_epochs=args.max_epoch))

    for epoch in range(args.max_epoch):
        for idx, inputs in enumerate(train_dataloader):
            # Convert inputs to target device.
            inputs = apply_to(inputs, lambda m: isinstance(m, torch.Tensor),
                              lambda m: m.cuda())
            loss = model(**inputs).loss
            optimizer.update_params(loss)

            max_memory = torch.cuda.max_memory_allocated()
            strategy.logger.info(f'Epoch: {epoch+1}/{args.max_epoch}, '
                                 f'Iter: {idx+1}/{epoch_length}, '
                                 f'Loss: {loss.item():.3f}, '
                                 f'Lr: {optimizer.get_lr()["lr"][0]:.6f} '
                                 f'Memory: {max_memory/1e9:.3f}G')
            visualizer.add_scalars({'loss': loss.item()})
            torch.cuda.reset_peak_memory_stats()

        for scheduler in schedulers:
            scheduler.step()

        save_dir = f'{args.output_dir}/epoch_{epoch+1}'
        state_dict = model.state_dict()
        if is_main_process():
            model.save_pretrained(save_dir, state_dict=state_dict)
            tokenizer.save_pretrained(save_dir)


if __name__ == '__main__':
    train()
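
To see what a training example looks like before tokenization, the prompt template in `fsdp_finetune.py` can be exercised directly. This is a minimal sketch, assuming it is run from `examples/llama2/`; the sample record is made up for illustration.

```python
# Illustrates the Alpaca-style template defined in fsdp_finetune.py above.
from fsdp_finetune import PROMPT_DICT

sample = {
    'instruction': 'Name three primary colors.',
    'input': '',
    'output': 'Red, yellow and blue.',
}

# Records with an empty 'input' field use the 'prompt_no_input' template.
prompt = PROMPT_DICT['prompt_no_input'].format_map(sample)
print(prompt)                     # '### Instruction: ... ### Response:' prompt
print(prompt + sample['output'])  # the full example before tokenization/masking
```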

examples/llama2/generate.py

@@ -0,0 +1,36 @@
import argparse

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

# flake8: noqa
prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Imagine you are from the 1700s. Try to write a sentence in the language used in that era.

### Response:"""


def parse_args():
    parser = argparse.ArgumentParser(description='llama2 inference')
    parser.add_argument('checkpoint', type=str)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
    model = LlamaForCausalLM.from_pretrained(args.checkpoint).half().cuda()
    model.eval()

    inputs = tokenizer(prompt, return_tensors='pt')
    with torch.no_grad():
        generate_ids = model.generate(inputs.input_ids.cuda(), max_length=300)

    print(
        tokenizer.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False)[0])