From f4f2555324672b20d9e109dbdff08b6be645edf1 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Mon, 24 Jul 2023 10:20:21 +0800
Subject: [PATCH] Llama2 example (#1264)

---
 README.md                        |   4 +-
 README_zh-CN.md                  |   2 +
 examples/llama2/README.md        |  42 ++++++++
 examples/llama2/fsdp_finetune.py | 168 +++++++++++++++++++++++++++++++
 examples/llama2/generate.py      |  36 +++++++
 5 files changed, 250 insertions(+), 2 deletions(-)
 create mode 100644 examples/llama2/README.md
 create mode 100644 examples/llama2/fsdp_finetune.py
 create mode 100644 examples/llama2/generate.py

diff --git a/README.md b/README.md
index 00e98a6b..cea7aa33 100644
--- a/README.md
+++ b/README.md
@@ -58,10 +58,10 @@ English | [简体中文](README_zh-CN.md)
 
 ## What's New
 
-v0.8.0 was released on 2023-06-30.
-
 Highlights:
 
+- Add an [example](./examples/llama2/) to finetune Llama2.
+
 - Support training with [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html?highlight=fsdp) and [DeepSpeed](https://www.deepspeed.ai/). Refer to the [Training Large Models](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html) for more detailed usages.
 
 - Introduce the pure Python style configuration file:
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 7f524047..86a639c9 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -62,6 +62,8 @@
 
 亮点:
 
+- 新增微调 Llama2 的[示例](./examples/llama2/)。
+
 - 支持使用 [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html?highlight=fsdp) 和 [DeepSpeed](https://www.deepspeed.ai/) 进行训练。可阅读[大模型训练](https://mmengine.readthedocs.io/zh_cn/latest/common_usage/large_model_training.html)了解用法。
 
 - 引入纯 Python 风格的配置文件:
diff --git a/examples/llama2/README.md b/examples/llama2/README.md
new file mode 100644
index 00000000..288bd4fd
--- /dev/null
+++ b/examples/llama2/README.md
@@ -0,0 +1,42 @@
+# Train Llama2 with MMEngine
+
+## Set up the environment
+
+Note: This example requires PyTorch 2.0+ and MMEngine 0.8.0+.
+
+- Install MMEngine
+
+  ```bash
+  git clone https://github.com/open-mmlab/mmengine.git
+  cd mmengine
+  pip install -e . -v
+  ```
+
+- Install third-party dependencies
+
+  ```bash
+  pip install -U transformers accelerate tokenizers
+  ```
+
+## Prepare data
+
+```bash
+mkdir data
+wget https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json -O data/alpaca_data.json
+```
+
+## Prepare model
+
+Download the model weights from [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf).
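+
+A minimal download sketch using `huggingface_hub` is shown below (this assumes you have been granted access to the gated repository and have logged in with `huggingface-cli login`; the local directory name is only an example):
+
+```python
+from huggingface_hub import snapshot_download
+
+# Fetch the HF-format Llama2-7B weights into a local directory.
+snapshot_download('meta-llama/Llama-2-7b-hf', local_dir='Llama-2-7b-hf')
+```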
+
+## Train
+
+```bash
+torchrun --nproc-per-node 8 examples/llama2/fsdp_finetune.py data/alpaca_data.json ${model_weights}
+```
+
+`${model_weights}` is the directory that holds the downloaded Llama2 weights.
+
+## Inference
+
+```bash
+python examples/llama2/generate.py ${checkpoints}
+```
+
+`${checkpoints}` can be either a directory saved by the training script (e.g. `work_dirs/epoch_3`) or the original weights.
diff --git a/examples/llama2/fsdp_finetune.py b/examples/llama2/fsdp_finetune.py
new file mode 100644
index 00000000..0d7e2751
--- /dev/null
+++ b/examples/llama2/fsdp_finetune.py
@@ -0,0 +1,168 @@
+import argparse
+import copy
+from functools import partial
+
+import torch
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from torch.optim import AdamW
+from torch.utils.data import DataLoader, Dataset
+from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers.data import default_data_collator
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+from mmengine import load
+from mmengine._strategy import FSDPStrategy
+from mmengine.dataset import DefaultSampler
+from mmengine.dist.utils import is_main_process
+from mmengine.optim import StepLR
+from mmengine.utils import apply_to
+from mmengine.visualization import Visualizer, WandbVisBackend
+
+# Reference batch size used to derive the gradient accumulation steps below.
+ORI_BATCH_SIZE = 4
+# Alpaca prompt templates, with and without the optional `input` field.
+PROMPT_DICT = {
+    'prompt_input':
+    ('Below is an instruction that describes a task, paired with an input '
+     'that provides further context. '
+     'Write a response that appropriately completes the request.\n\n'
+     '### Instruction:\n{instruction}\n\n'
+     '### Input:\n{input}\n\n### Response:'),
+    'prompt_no_input':
+    ('Below is an instruction that describes a task. '
+     'Write a response that appropriately completes the request.\n\n'
+     '### Instruction:\n{instruction}\n\n### Response:'),
+}
+
+
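+# Each item pairs an Alpaca prompt with its reference output: the concatenated
+# text is tokenized, an EOS token is appended, and the sequence is padded with
+# -1 (or truncated) to `max_words`. Prompt and padding positions in `labels`
+# are then zeroed out, and `attention_mask` marks the non-padding tokens.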
+# Modified from https://github.com/facebookresearch/llama-recipes/blob/main/ft_datasets/alpaca_dataset.py  # noqa: E501
+class AlpacaDataset(Dataset):
+
+    def __init__(self, data_path, tokenizer, max_words=224):
+        self.ann = load(data_path)
+        self.max_words = max_words
+        self.tokenizer = tokenizer
+
+    def __len__(self):
+        return len(self.ann)
+
+    def __getitem__(self, index):
+        ann = self.ann[index]
+        if ann.get('input', '') == '':
+            prompt = PROMPT_DICT['prompt_no_input'].format_map(ann)
+        else:
+            prompt = PROMPT_DICT['prompt_input'].format_map(ann)
+        example = prompt + ann['output']
+        prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64)
+        example = self.tokenizer.encode(example)
+        example.append(self.tokenizer.eos_token_id)
+        example = torch.tensor(example, dtype=torch.int64)
+        padding = self.max_words - example.shape[0]
+        if padding > 0:
+            example = torch.cat(
+                (example, torch.zeros(padding, dtype=torch.int64) - 1))
+        elif padding < 0:
+            example = example[:self.max_words]
+        labels = copy.deepcopy(example)
+        labels[:len(prompt)] = -1
+        example_mask = example.ge(0)
+        label_mask = labels.ge(0)
+        example[~example_mask] = 0
+        labels[~label_mask] = 0
+        example_mask = example_mask.float()
+        label_mask = label_mask.float()
+
+        return {
+            'input_ids': example,
+            'labels': labels,
+            'attention_mask': example_mask,
+        }
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Finetune Llama2 on Alpaca')
+    parser.add_argument('data_root', type=str)
+    parser.add_argument('checkpoint', type=str)
+    parser.add_argument('--output-dir', type=str, default='work_dirs')
+    parser.add_argument('--max-epoch', type=int, default=3)
+    parser.add_argument('--batch-size', type=int, default=4)
+    # Note: `--save-interval` is currently unused; checkpoints are saved at
+    # the end of every epoch.
+    parser.add_argument('--save-interval', type=int, default=500)
+    args = parser.parse_args()
+    return args
+
+
+def train():
+    args = parse_args()
+    # Set up the distributed environment and randomness via the strategy.
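+    # `transformer_auto_wrap_policy` with `LlamaDecoderLayer` makes FSDP wrap
+    # each decoder layer as its own unit, so parameters, gradients and
+    # optimizer state are sharded layer by layer across the ranks.
+    # `state_dict_cfg='full'` requests a full (unsharded) state dict, so the
+    # `model.state_dict()` call below can be saved with `save_pretrained` and
+    # later loaded by `generate.py`.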
+    strategy = FSDPStrategy(
+        model_wrapper=dict(
+            auto_wrap_policy=partial(
+                transformer_auto_wrap_policy,
+                transformer_layer_cls={LlamaDecoderLayer})),
+        state_dict_cfg='full',
+        env_kwargs=dict(randomness=dict(seed=42)))
+    visualizer = Visualizer(
+        name='mmengine',
+        save_dir=args.output_dir,
+        vis_backends=[dict(type=WandbVisBackend)])
+
+    # Prepare model
+    tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
+    tokenizer.add_special_tokens({'pad_token': '<PAD>'})
+    model = LlamaForCausalLM.from_pretrained(args.checkpoint)
+    model.to(torch.bfloat16)
+    model.train()
+
+    # Prepare dataset
+    train_dataset = AlpacaDataset(
+        tokenizer=tokenizer, data_path=args.data_root)
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        sampler=DefaultSampler(train_dataset, seed=0),
+        collate_fn=default_data_collator,
+        drop_last=True)
+
+    # Get the wrapped model, optimizer and schedulers from the strategy.
+    epoch_length = len(train_dataloader)
+    max_iters = epoch_length * args.max_epoch
+    optim_cfg = dict(
+        optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0),
+        # Accumulate gradients so that the effective per-GPU batch size
+        # stays at ORI_BATCH_SIZE.
+        accumulative_counts=ORI_BATCH_SIZE / args.batch_size)
+    scheduler_cfgs = [dict(type=StepLR, step_size=1, gamma=0.85)]
+    model, optimizer, schedulers = strategy.prepare(
+        model,
+        optim_wrapper=optim_cfg,
+        param_scheduler=scheduler_cfgs,
+        dispatch_kwargs=dict(max_iters=max_iters, max_epochs=args.max_epoch))
+
+    for epoch in range(args.max_epoch):
+        for idx, inputs in enumerate(train_dataloader):
+            # Move the input tensors to the current GPU.
+            inputs = apply_to(inputs, lambda m: isinstance(m, torch.Tensor),
+                              lambda m: m.cuda())
+
+            loss = model(**inputs).loss
+            # update_params scales the loss for gradient accumulation, runs
+            # backward, and steps the optimizer when needed.
+            optimizer.update_params(loss)
+
+            max_memory = torch.cuda.max_memory_allocated()
+            strategy.logger.info(f'Epoch: {epoch+1}/{args.max_epoch}, '
+                                 f'Iter: {idx+1}/{epoch_length}, '
+                                 f'Loss: {loss.item():.3f}, '
+                                 f'Lr: {optimizer.get_lr()["lr"][0]:.6f} '
+                                 f'Memory: {max_memory/1e9:.3f}G')
+            visualizer.add_scalars({'loss': loss.item()})
+
+            torch.cuda.reset_peak_memory_stats()
+
+        for scheduler in schedulers:
+            scheduler.step()
+
+        # Save a HuggingFace-format checkpoint at the end of each epoch.
+        save_dir = f'{args.output_dir}/epoch_{epoch+1}'
+        state_dict = model.state_dict()
+
+        if is_main_process():
+            model.save_pretrained(save_dir, state_dict=state_dict)
+            tokenizer.save_pretrained(save_dir)
+
+
+if __name__ == '__main__':
+    train()
diff --git a/examples/llama2/generate.py b/examples/llama2/generate.py
new file mode 100644
index 00000000..85635c37
--- /dev/null
+++ b/examples/llama2/generate.py
@@ -0,0 +1,36 @@
+import argparse
+
+import torch
+from transformers import AutoTokenizer, LlamaForCausalLM
+
+# flake8: noqa
+
+prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+Imagine you are from the 1700s. Try to write a sentence in the language used in that era.
+
+### Response:"""
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Llama2 inference')
+    parser.add_argument('checkpoint', type=str)
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
+    model = LlamaForCausalLM.from_pretrained(args.checkpoint).half().cuda()
+    model.eval()
+
+    inputs = tokenizer(prompt, return_tensors='pt')
+    with torch.no_grad():
+        # Generate a completion of at most 300 tokens (prompt included).
+        generate_ids = model.generate(inputs.input_ids.cuda(), max_length=300)
+    print(
+        tokenizer.batch_decode(
+            generate_ids,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False)[0])