pull/520/head
Your Name 2023-04-21 14:52:20 +08:00 committed by liukai
parent e4303d110b
commit f994c6d09e
3 changed files with 63 additions and 61 deletions

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DistributedSampler
def set_seed(seed):
@ -106,3 +108,45 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
return get_ptb(nsamples, seed, seqlen, model)
if 'c4' in name:
return get_c4(nsamples, seed, seqlen, model)
def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
    """Fold a (1, N) token tensor into shape (B, batch_seq_len).

    Trailing tokens that do not fill a complete row of ``batch_seq_len``
    are discarded.
    """
    total = tokens.shape[1]
    leftover = total % batch_seq_len
    if leftover:
        # Drop the incomplete tail so the reshape divides evenly.
        tokens = tokens[:, :total - leftover]
    return tokens.reshape([-1, batch_seq_len])  # B N
class LanguageDataset(TorchDataset):
    """Dataset serving fixed-length token windows cut from one long stream.

    Args:
        seq: token ids with shape (1, N).
        seq_len: window length; trailing tokens that do not fill a full
            window are dropped.
    """

    def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
        super().__init__()
        # seq: 1, N
        self.seq_len = seq_len
        # Bug fix: pass seq_len through to fold_tokens. Previously the call
        # was fold_tokens(seq), which always folded with the default window
        # of 2048 and silently ignored the requested seq_len.
        self.seq = fold_tokens(seq, seq_len)  # B N

    def __len__(self) -> int:
        # One item per full window.
        return self.seq.shape[0]

    def __getitem__(self, index):
        # Returns a 1-D tensor of seq_len token ids.
        return self.seq[index]
def build_language_loader(testloader, world_size, rank, model, batch_size=128):
    """Build a distributed, non-shuffled DataLoader over a token stream.

    Args:
        testloader: object exposing ``input_ids`` of shape (1, N) — presumably
            a tokenized evaluation corpus; confirm against callers.
        world_size: total number of distributed replicas.
        rank: rank of the current process.
        model: model exposing ``seqlen``, used as the window length.
        batch_size: upper bound on the per-replica batch size.

    Returns:
        DataLoader whose DistributedSampler shards the windows across
        replicas deterministically; incomplete batches are dropped.
    """
    val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
    # num_replicas/rank are passed explicitly so no process group needs to
    # be initialized just to shard the data.
    distributed_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    # Shrink the batch so every replica sees at least one full batch, but
    # clamp at 1: a dataset smaller than world_size would otherwise yield
    # batch_size=0, which DataLoader rejects.
    batch_size = max(1, min(len(val_dataset) // world_size, batch_size))
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
        sampler=distributed_sampler)
    return val_dataloader

View File

@ -5,15 +5,14 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from datautils import get_loaders
from datautils import build_language_loader, get_loaders
from opt_sparse_gpt import get_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DistributedSampler
from utils import init_on_meta
from mmrazor.implementations.pruning import sparse_gpt
from mmrazor.utils import print_log
@ -25,32 +24,7 @@ def setup(rank, world_size):
dist.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
    """Fold a (1, N) token tensor into shape (B, batch_seq_len), dropping
    the trailing tokens that do not fill a complete row."""
    # tokens: 1 N
    N = tokens.shape[1]
    num_drop = N % batch_seq_len
    if num_drop != 0:
        # Trim the incomplete tail so the reshape divides evenly.
        tokens = tokens[:, :-num_drop]
    tokens = tokens.reshape([-1, batch_seq_len])  # B N
    return tokens
class LanguageDataset(TorchDataset):
    """Dataset of fixed-length token windows folded from a (1, N) stream."""

    def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
        super().__init__()
        # seq: 1, N
        self.seq_len = seq_len
        # NOTE(review): fold_tokens is called without seq_len, so folding
        # always uses its default window of 2048 and self.seq_len is
        # effectively ignored — confirm whether fold_tokens(seq, seq_len)
        # was intended.
        self.seq = fold_tokens(seq)  # B N

    def __len__(self) -> int:
        # One item per full window.
        return self.seq.shape[0]

    def __getitem__(self, index):
        # Returns the index-th window of token ids.
        return self.seq[index]
print_log(f'init {rank}/{world_size}')
@torch.no_grad()
@ -112,37 +86,6 @@ def opt_infer(
print_log(f'{(i+1)*B} / {len(dataloader.dataset)}')
def build_language_loader(testloader, world_size, rank, model, batch_size=128):
    """Build a distributed, non-shuffled DataLoader over a token stream.

    Args:
        testloader: object exposing ``input_ids`` of shape (1, N) — presumably
            a tokenized evaluation corpus; confirm against callers.
        world_size: total number of distributed replicas.
        rank: rank of the current process.
        model: model exposing ``seqlen``, used as the window length.
        batch_size: upper bound on the per-replica batch size.

    Returns:
        DataLoader sharded across replicas; incomplete batches are dropped.
    """
    val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
    # num_replicas/rank passed explicitly, so the sampler does not need an
    # initialized process group.
    distributed_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    # Shrink the batch so each replica gets at least one full batch.
    # NOTE(review): this can evaluate to 0 when len(val_dataset) < world_size,
    # which DataLoader rejects — confirm inputs are always large enough.
    batch_size = min(len(val_dataset) // world_size, batch_size)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
        sampler=distributed_sampler)
    return val_dataloader
class init_on_meta:
    """Context manager that switches torch's default device to 'meta'.

    Tensors created inside the block carry shape/dtype but no storage,
    allowing large models to be instantiated without allocating memory.
    """

    def __init__(self, enable=True) -> None:
        # enable: when False, __enter__/__exit__ are no-ops.
        self.enable = enable
        # NOTE(review): the device to restore is snapshotted at construction
        # time, not at __enter__ — if the default device changes between the
        # two (or contexts are nested), the wrong device is restored on exit.
        self.default_device = torch.ones([]).device

    def __enter__(self):
        if self.enable:
            torch.set_default_device('meta')

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enable:
            torch.set_default_device(self.default_device)
def _materialize_meta_module(module: nn.Module, ):
# Run default meta device initialization

View File

@ -78,3 +78,18 @@ def opt_infer(
if (i + 1) * batch_size >= num_samples:
break
class init_on_meta:
    """Context manager that switches torch's default device to 'meta'.

    Tensors created inside the block carry shape/dtype but no storage,
    allowing large models to be instantiated without allocating memory.
    On exit the default device active at entry time is restored.

    Args:
        enable: when False, entering/exiting the context is a no-op.
    """

    def __init__(self, enable=True) -> None:
        self.enable = enable
        # Fallback snapshot; refreshed in __enter__ (see below) so the
        # attribute always exists even if the context is never entered.
        self.default_device = torch.ones([]).device

    def __enter__(self):
        if self.enable:
            # Bug fix: capture the default device at entry time rather than
            # relying solely on the construction-time snapshot, so a default
            # changed between __init__ and __enter__ is restored correctly.
            self.default_device = torch.ones([]).device
            torch.set_default_device('meta')

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enable:
            torch.set_default_device(self.default_device)