mmengine/examples/llama2/fsdp_finetune.py

169 lines
6.1 KiB
Python
Raw Normal View History

2023-07-24 10:20:21 +08:00
import argparse
import copy
from functools import partial
import torch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.data import default_data_collator
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from mmengine import load
from mmengine._strategy import FSDPStrategy
from mmengine.dataset import DefaultSampler
from mmengine.dist.utils import is_main_process
from mmengine.optim import StepLR
from mmengine.utils import apply_to
from mmengine.visualization import Visualizer, WandbVisBackend
ORI_BATCH_SIZE = 4
PROMPT_DICT = {
'prompt_input':
('Below is an instruction that describes a task, paired with an input '
'that provides further context. '
'Write a response that appropriately completes the request.\n\n'
'### Instruction:\n{instruction}\n\n'
'### Input:\n{input}\n\n### Response:'),
'prompt_no_input':
('Below is an instruction that describes a task. '
'Write a response that appropriately completes the request.\n\n'
'### Instruction:\n{instruction}\n\n### Response:'),
}
# Modified from https://github.com/facebookresearch/llama-recipes/blob/main/ft_datasets/alpaca_dataset.py # noqa: E501
class AlpacaDataset(Dataset):
def __init__(self, data_path, tokenizer, max_words=224):
self.ann = load(data_path)
self.max_words = max_words
self.tokenizer = tokenizer
def __len__(self):
return len(self.ann)
def __getitem__(self, index):
ann = self.ann[index]
if ann.get('input', '') == '':
prompt = PROMPT_DICT['prompt_no_input'].format_map(ann)
else:
prompt = PROMPT_DICT['prompt_input'].format_map(ann)
example = prompt + ann['output']
prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64)
example = self.tokenizer.encode(example)
example.append(self.tokenizer.eos_token_id)
example = torch.tensor(example, dtype=torch.int64)
padding = self.max_words - example.shape[0]
if padding > 0:
example = torch.cat(
(example, torch.zeros(padding, dtype=torch.int64) - 1))
elif padding < 0:
example = example[:self.max_words]
labels = copy.deepcopy(example)
labels[:len(prompt)] = -1
example_mask = example.ge(0)
label_mask = labels.ge(0)
example[~example_mask] = 0
labels[~label_mask] = 0
example_mask = example_mask.float()
label_mask = label_mask.float()
return {
'input_ids': example,
'labels': labels,
'attention_mask': example_mask,
}
def parse_args():
parser = argparse.ArgumentParser(description='Train alpaca with llama2')
parser.add_argument('data_root', type=str)
parser.add_argument('checkpoint', type=str)
parser.add_argument('--output-dir', type=str, default='work_dirs')
parser.add_argument('--max-epoch', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--save-interval', type=int, default=500)
args = parser.parse_args()
return args
def train():
args = parse_args()
# Setup distributed related component in Strategy.
strategy = FSDPStrategy(
model_wrapper=dict(
auto_wrap_policy=partial(
transformer_auto_wrap_policy,
transformer_layer_cls={LlamaDecoderLayer})),
state_dict_cfg='full',
env_kwargs=dict(randomness=dict(seed=42)))
visualizer = Visualizer(
name='mmengine',
save_dir=args.output_dir,
vis_backends=[dict(type=WandbVisBackend)])
# Prepare model
tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
tokenizer.add_special_tokens({'pad_token': '<PAD>'})
model = LlamaForCausalLM.from_pretrained(args.checkpoint)
model.to(torch.bfloat16)
model.train()
# Prepare dataset
train_dataset = AlpacaDataset(
tokenizer=tokenizer, data_path=args.data_root)
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
sampler=DefaultSampler(train_dataset, seed=0),
collate_fn=default_data_collator,
drop_last=True)
# Get the prepared model, scheduler and optimizer from strategy
epoch_length = len(train_dataloader)
max_iters = epoch_length * args.max_epoch
optim_cfg = dict(
optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0),
accumulative_counts=ORI_BATCH_SIZE / args.batch_size)
scheduler_cfgs = [dict(type=StepLR, step_size=1, gamma=0.85)]
model, optimizer, schedulers = strategy.prepare(
model,
optim_wrapper=optim_cfg,
param_scheduler=scheduler_cfgs,
dispatch_kwargs=dict(max_iters=max_iters, max_epochs=args.max_epoch))
for epoch in range(args.max_epoch):
for idx, inputs in enumerate(train_dataloader):
# Convert inputs to target device.
inputs = apply_to(inputs, lambda m: isinstance(m, torch.Tensor),
lambda m: m.cuda())
loss = model(**inputs).loss
optimizer.update_params(loss)
max_memory = torch.cuda.max_memory_allocated()
strategy.logger.info(f'Epoch: {epoch+1}/{args.max_epoch}, '
f'Iter: {idx+1}/{epoch_length}, '
f'Loss: {loss.item():.3f}, '
f'Lr: {optimizer.get_lr()["lr"][0]:.6f} '
f'Memory: {max_memory/1e9:.3f}G')
visualizer.add_scalars({'loss': loss.item()})
torch.cuda.reset_peak_memory_stats()
for scheduler in schedulers:
scheduler.step()
save_dir = f'{args.output_dir}/epoch_{epoch+1}'
state_dict = model.state_dict()
if is_main_process():
model.save_pretrained(save_dir, state_dict=state_dict)
tokenizer.save_pretrained(save_dir)
if __name__ == '__main__':
train()