Llama2 example (#1264)
parent 86387da4a5
commit f4f2555324
@@ -58,10 +58,10 @@ English | [简体中文](README_zh-CN.md)

## What's New

v0.8.0 was released on 2023-06-30.

Highlights:

- Add an [example](./examples/llama2/) to finetune Llama2.
- Support training with [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html?highlight=fsdp) and [DeepSpeed](https://www.deepspeed.ai/). Refer to [Training Large Models](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html) for more detailed usage.
- Introduce the pure Python style configuration file (sketched below):
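As a rough illustration of what the pure Python style means, here is a hand-written sketch (it is not taken from this repository; the optimizer and scheduler settings simply mirror the ones used in `fsdp_finetune.py` further down):

```python
# toy_config.py -- in the pure Python style, classes are imported and
# referenced directly rather than being looked up via registry strings.
from torch.optim import AdamW

from mmengine.optim import StepLR

optim_wrapper = dict(
    optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0))
param_scheduler = [dict(type=StepLR, step_size=1, gamma=0.85)]
```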

@@ -62,6 +62,8 @@

Highlights:

- Added an [example](./examples/llama2/) of finetuning Llama2.
- Support training with [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html?highlight=fsdp) and [DeepSpeed](https://www.deepspeed.ai/). See [Training Large Models](https://mmengine.readthedocs.io/zh_cn/latest/common_usage/large_model_training.html) for usage details.
- Introduce the pure Python style configuration file:

@@ -0,0 +1,42 @@

# Train Llama2 in MMEngine

## Setup env

Note: This example requires PyTorch 2.0+ and MMEngine 0.8.0+.

- Install MMEngine

  ```bash
  git clone https://github.com/open-mmlab/mmengine.git
  cd mmengine
  pip install -e . -v
  ```

- Install third-party dependencies

  ```bash
  pip install -U transformers accelerate tokenizers
  ```

## Prepare data

```bash
mkdir data
wget https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json -O data/alpaca_data.json
```
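Each record in `alpaca_data.json` is an object with `instruction`, `input`, and `output` fields, which is what the training script below relies on when filling its prompt templates. A made-up record for illustration:

```python
# Illustrative only: one record in the format expected by AlpacaDataset.
# Records with an empty "input" use the shorter prompt template.
example_record = {
    'instruction': 'Translate the following sentence into French.',
    'input': 'The weather is nice today.',
    'output': "Il fait beau aujourd'hui.",
}
```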

## Prepare model

Download model weights from https://huggingface.co/meta-llama/Llama-2-7b-hf
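One way to fetch the weights is a short `huggingface_hub` script (a sketch; it assumes you have accepted the Llama 2 license on the Hugging Face Hub and logged in with `huggingface-cli login`):

```python
# Sketch: download the Llama-2-7b-hf snapshot into the local Hugging Face cache.
from huggingface_hub import snapshot_download

# snapshot_download returns the local directory containing the files;
# pass that path (or a copy of it) as ${model_weights} when training.
local_path = snapshot_download(repo_id='meta-llama/Llama-2-7b-hf')
print(local_path)
```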

## Train

```bash
torchrun --nproc-per-node 8 examples/llama2/fsdp_finetune.py data/alpaca_data.json ${model_weights}
```
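Here `${model_weights}` is the directory containing the Llama-2-7b-hf files prepared above. One checkpoint is saved at the end of each epoch under `work_dirs/epoch_*` (change the location with `--output-dir`).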

## Inference

```bash
python examples/llama2/generate.py ${checkpoints}
```
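`${checkpoints}` is one of the `work_dirs/epoch_*` directories produced by training; both the finetuned model and the tokenizer are saved there, so `generate.py` can load them directly with `from_pretrained`.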

@@ -0,0 +1,168 @@

import argparse
import copy
from functools import partial

import torch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.data import default_data_collator
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from mmengine import load
from mmengine._strategy import FSDPStrategy
from mmengine.dataset import DefaultSampler
from mmengine.dist.utils import is_main_process
from mmengine.optim import StepLR
from mmengine.utils import apply_to
from mmengine.visualization import Visualizer, WandbVisBackend

ORI_BATCH_SIZE = 4
PROMPT_DICT = {
    'prompt_input':
    ('Below is an instruction that describes a task, paired with an input '
     'that provides further context. '
     'Write a response that appropriately completes the request.\n\n'
     '### Instruction:\n{instruction}\n\n'
     '### Input:\n{input}\n\n### Response:'),
    'prompt_no_input':
    ('Below is an instruction that describes a task. '
     'Write a response that appropriately completes the request.\n\n'
     '### Instruction:\n{instruction}\n\n### Response:'),
}

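# AlpacaDataset maps one Alpaca record to a fixed-length sample: it fills in
# the matching prompt template, tokenizes prompt + answer, pads or truncates
# to `max_words` (using -1 as a temporary padding sentinel), zeroes the
# prompt and padding positions in `labels`, and builds an attention mask
# that covers only the real (non-padding) tokens.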
# Modified from https://github.com/facebookresearch/llama-recipes/blob/main/ft_datasets/alpaca_dataset.py  # noqa: E501
class AlpacaDataset(Dataset):

    def __init__(self, data_path, tokenizer, max_words=224):
        self.ann = load(data_path)
        self.max_words = max_words
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]
        if ann.get('input', '') == '':
            prompt = PROMPT_DICT['prompt_no_input'].format_map(ann)
        else:
            prompt = PROMPT_DICT['prompt_input'].format_map(ann)
        example = prompt + ann['output']
        prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64)
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)
        padding = self.max_words - example.shape[0]
        if padding > 0:
            example = torch.cat(
                (example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[:self.max_words]
        labels = copy.deepcopy(example)
        labels[:len(prompt)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = 0
        example_mask = example_mask.float()
        label_mask = label_mask.float()

        return {
            'input_ids': example,
            'labels': labels,
            'attention_mask': example_mask,
        }


def parse_args():
    parser = argparse.ArgumentParser(description='Train alpaca with llama2')
    parser.add_argument('data_root', type=str)
    parser.add_argument('checkpoint', type=str)
    parser.add_argument('--output-dir', type=str, default='work_dirs')
    parser.add_argument('--max-epoch', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--save-interval', type=int, default=500)
    args = parser.parse_args()
    return args


def train():
    args = parse_args()
    # Setup distributed related component in Strategy.
    strategy = FSDPStrategy(
        model_wrapper=dict(
            auto_wrap_policy=partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={LlamaDecoderLayer})),
        state_dict_cfg='full',
        env_kwargs=dict(randomness=dict(seed=42)))
    visualizer = Visualizer(
        name='mmengine',
        save_dir=args.output_dir,
        vis_backends=[dict(type=WandbVisBackend)])

    # Prepare model
    tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
    tokenizer.add_special_tokens({'pad_token': '<PAD>'})
    model = LlamaForCausalLM.from_pretrained(args.checkpoint)
    model.to(torch.bfloat16)
    model.train()

    # Prepare dataset
    train_dataset = AlpacaDataset(
        tokenizer=tokenizer, data_path=args.data_root)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=DefaultSampler(train_dataset, seed=0),
        collate_fn=default_data_collator,
        drop_last=True)

    # Get the prepared model, scheduler and optimizer from strategy
    epoch_length = len(train_dataloader)
    max_iters = epoch_length * args.max_epoch
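    # Gradient accumulation keeps the effective per-GPU batch size at
    # ORI_BATCH_SIZE (4) even when --batch-size is lowered to save memory.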
    optim_cfg = dict(
        optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0),
        accumulative_counts=ORI_BATCH_SIZE / args.batch_size)
    scheduler_cfgs = [dict(type=StepLR, step_size=1, gamma=0.85)]
    model, optimizer, schedulers = strategy.prepare(
        model,
        optim_wrapper=optim_cfg,
        param_scheduler=scheduler_cfgs,
        dispatch_kwargs=dict(max_iters=max_iters, max_epochs=args.max_epoch))

    for epoch in range(args.max_epoch):
        for idx, inputs in enumerate(train_dataloader):
            # Convert inputs to target device.
            inputs = apply_to(inputs, lambda m: isinstance(m, torch.Tensor),
                              lambda m: m.cuda())

            loss = model(**inputs).loss
            optimizer.update_params(loss)

            max_memory = torch.cuda.max_memory_allocated()
            strategy.logger.info(f'Epoch: {epoch+1}/{args.max_epoch}, '
                                 f'Iter: {idx+1}/{epoch_length}, '
                                 f'Loss: {loss.item():.3f}, '
                                 f'Lr: {optimizer.get_lr()["lr"][0]:.6f} '
                                 f'Memory: {max_memory/1e9:.3f}G')
            visualizer.add_scalars({'loss': loss.item()})

            torch.cuda.reset_peak_memory_stats()

        for scheduler in schedulers:
            scheduler.step()

        save_dir = f'{args.output_dir}/epoch_{epoch+1}'
        state_dict = model.state_dict()

        if is_main_process():
            model.save_pretrained(save_dir, state_dict=state_dict)
            tokenizer.save_pretrained(save_dir)


if __name__ == '__main__':
    train()

@@ -0,0 +1,36 @@

import argparse

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

# flake8: noqa

prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Imagine you are from the 1700s. Try to write a sentence in the language used in that era.

### Response:"""


def parse_args():
    parser = argparse.ArgumentParser(description='llama2 inference')
    parser.add_argument('checkpoint', type=str)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
    model = LlamaForCausalLM.from_pretrained(args.checkpoint).half().cuda()
    model.eval()

    inputs = tokenizer(prompt, return_tensors='pt')
    with torch.no_grad():
        generate_ids = model.generate(inputs.input_ids.cuda(), max_length=300)
    print(
        tokenizer.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False)[0])