pull/520/head
Your Name 2023-04-21 14:52:20 +08:00 committed by liukai
parent e4303d110b
commit f994c6d09e
3 changed files with 63 additions and 61 deletions

View File

@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DistributedSampler
def set_seed(seed):
@@ -106,3 +108,45 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
        return get_ptb(nsamples, seed, seqlen, model)
    if 'c4' in name:
        return get_c4(nsamples, seed, seqlen, model)
def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
    # tokens: 1 N
    N = tokens.shape[1]
    num_drop = N % batch_seq_len
    if num_drop != 0:
        tokens = tokens[:, :-num_drop]
    tokens = tokens.reshape([-1, batch_seq_len])  # B N
    return tokens
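For orientation, a small shape walk-through of the fold_tokens helper above; the 5000-token stream is an illustrative value, not something taken from this patch:

# Illustrative only (not part of the patch): fold a [1, 5000] token stream.
tokens = torch.randint(0, 50000, (1, 5000))
folded = fold_tokens(tokens, batch_seq_len=2048)
# 5000 % 2048 == 904 trailing tokens are dropped, so 4096 tokens remain
# and folded.shape == torch.Size([2, 2048]).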
class LanguageDataset(TorchDataset):

    def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
        super().__init__()
        # seq: 1, N
        self.seq_len = seq_len
        self.seq = fold_tokens(seq, seq_len)  # B N, fold with the configured seq_len
def build_language_loader(testloader, world_size, rank, model, batch_size=128):
    val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
    distributed_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    batch_size = min(len(val_dataset) // world_size, batch_size)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
        sampler=distributed_sampler)
    return val_dataloader
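For context, a hedged sketch of how these relocated helpers are meant to be driven from the FSDP evaluation script; the (trainloader, testloader) return shape of get_loaders, the .input_ids attribute on testloader, and the make_val_loader wrapper name are assumptions for illustration, not part of this patch:

# Hedged usage sketch (not in the patch): build a per-rank validation loader.
# Assumes get_loaders returns (trainloader, testloader) and that testloader
# exposes .input_ids as a [1, N] token tensor, as LanguageDataset expects.
def make_val_loader(model, model_name, rank, world_size, batch_size=128):
    _, testloader = get_loaders(
        'c4', seed=0, seqlen=model.seqlen, model=model_name)
    return build_language_loader(
        testloader, world_size, rank, model, batch_size=batch_size)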

View File

@@ -5,15 +5,14 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
-from datautils import get_loaders
+from datautils import build_language_loader, get_loaders
from opt_sparse_gpt import get_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader
-from torch.utils.data import Dataset as TorchDataset
-from torch.utils.data import DistributedSampler
+from utils import init_on_meta
from mmrazor.implementations.pruning import sparse_gpt
from mmrazor.utils import print_log
@@ -25,32 +24,7 @@ def setup(rank, world_size):
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    print_log(f'init {rank}/{world_size}')
def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
    # tokens: 1 N
    N = tokens.shape[1]
    num_drop = N % batch_seq_len
    if num_drop != 0:
        tokens = tokens[:, :-num_drop]
    tokens = tokens.reshape([-1, batch_seq_len])  # B N
    return tokens
class LanguageDataset(TorchDataset):

    def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
        super().__init__()
        # seq: 1, N
        self.seq_len = seq_len
        self.seq = fold_tokens(seq)  # B N

    def __len__(self) -> int:
        return self.seq.shape[0]

    def __getitem__(self, index):
        return self.seq[index]
@torch.no_grad()
@@ -112,37 +86,6 @@ def opt_infer(
        print_log(f'{(i+1)*B} / {len(dataloader.dataset)}')
def build_language_loader(testloader, world_size, rank, model, batch_size=128):
    val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
    distributed_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    batch_size = min(len(val_dataset) // world_size, batch_size)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
        sampler=distributed_sampler)
    return val_dataloader
class init_on_meta:

    def __init__(self, enable=True) -> None:
        self.enable = enable
        self.default_device = torch.ones([]).device

    def __enter__(self):
        if self.enable:
            torch.set_default_device('meta')

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enable:
            torch.set_default_device(self.default_device)
def _materialize_meta_module(module: nn.Module, ):
    # Run default meta device initialization

View File

@@ -78,3 +78,18 @@ def opt_infer(
        if (i + 1) * batch_size >= num_samples:
            break
class init_on_meta:

    def __init__(self, enable=True) -> None:
        self.enable = enable
        self.default_device = torch.ones([]).device

    def __enter__(self):
        if self.enable:
            torch.set_default_device('meta')

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enable:
            torch.set_default_device(self.default_device)
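For reference, a minimal usage sketch of the init_on_meta context manager added here (torch.set_default_device requires PyTorch 2.0+); build_model is a hypothetical factory and the materialization step (e.g. FSDP sharding or loading a checkpoint) is left out:

# Minimal sketch: construct a large model without allocating real weights.
# Inside the context, torch.set_default_device('meta') makes newly created
# parameters meta tensors; they must be materialized before real use.
with init_on_meta(enable=True):
    model = build_model()  # hypothetical factory returning an nn.Module

assert all(p.is_meta for p in model.parameters())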