refine
parent e4303d110b
commit f994c6d09e
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DistributedSampler


def set_seed(seed):
@@ -106,3 +108,45 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
        return get_ptb(nsamples, seed, seqlen, model)
    if 'c4' in name:
        return get_c4(nsamples, seed, seqlen, model)

def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
    # tokens: 1 N
    N = tokens.shape[1]
    num_drop = N % batch_seq_len
    if num_drop != 0:
        tokens = tokens[:, :-num_drop]
    tokens = tokens.reshape([-1, batch_seq_len])  # B N
    return tokens

class LanguageDataset(TorchDataset):

    def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
        super().__init__()
        # seq: 1, N
        self.seq_len = seq_len

        self.seq = fold_tokens(seq)  # B N

    def __len__(self) -> int:
        return self.seq.shape[0]

    def __getitem__(self, index):
        return self.seq[index]

def build_language_loader(testloader, world_size, rank, model, batch_size=128):
    val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
    distributed_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    batch_size = min(len(val_dataset) // world_size, batch_size)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
        sampler=distributed_sampler)
    return val_dataloader

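To make the intent of the datautils additions above concrete, here is a minimal usage sketch. It assumes the usual convention that get_loaders returns a calibration set plus a tokenized test set whose input_ids tensor has shape [1, N], and that get_model attaches a seqlen attribute to the loaded model; both are inferred from the call sites, not stated in this hunk. Note that LanguageDataset stores seq_len but calls fold_tokens with its default batch_seq_len of 2048.

# Illustrative only: the dataset name, model id and world_size/rank values
# are placeholders, not values used by this commit.
from datautils import build_language_loader, get_loaders
from opt_sparse_gpt import get_model

model = get_model('facebook/opt-125m')  # assumed to attach a .seqlen attribute
_, testloader = get_loaders('wikitext2', seed=0, seqlen=2048, model='facebook/opt-125m')
# testloader.input_ids has shape [1, N]; fold_tokens drops the trailing
# N % 2048 tokens and reshapes the rest to [B, 2048].
val_dataloader = build_language_loader(
    testloader, world_size=1, rank=0, model=model, batch_size=128)
for batch in val_dataloader:
    # batch: [batch_size, 2048] tensor of token ids
    pass
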
@@ -5,15 +5,14 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from datautils import get_loaders
from datautils import build_language_loader, get_loaders
from opt_sparse_gpt import get_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DistributedSampler
from utils import init_on_meta

from mmrazor.implementations.pruning import sparse_gpt
from mmrazor.utils import print_log
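For context on the FSDP imports kept by this hunk, the sketch below shows how FullyShardedDataParallel, ShardingStrategy, CPUOffload and size_based_auto_wrap_policy are commonly combined. The min_num_params threshold, offload choice and sharding strategy are illustrative assumptions, not this script's settings, and an initialized process group (as in setup below) is required before wrapping.

# Hypothetical FSDP wrapping sketch; the policy size and offload/sharding
# choices are illustrative defaults, not this script's configuration.
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def wrap_with_fsdp(model: nn.Module, rank: int) -> FSDP:
    # Treat every submodule with at least 1e6 parameters as its own FSDP unit.
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=int(1e6))
    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=True),
        device_id=rank)
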
@@ -25,32 +24,7 @@ def setup(rank, world_size):

    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
    # tokens: 1 N
    N = tokens.shape[1]
    num_drop = N % batch_seq_len
    if num_drop != 0:
        tokens = tokens[:, :-num_drop]
    tokens = tokens.reshape([-1, batch_seq_len])  # B N
    return tokens


class LanguageDataset(TorchDataset):

    def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
        super().__init__()
        # seq: 1, N
        self.seq_len = seq_len

        self.seq = fold_tokens(seq)  # B N

    def __len__(self) -> int:
        return self.seq.shape[0]

    def __getitem__(self, index):
        return self.seq[index]

    print_log(f'init {rank}/{world_size}')


@torch.no_grad()
@@ -112,37 +86,6 @@ def opt_infer(
            print_log(f'{(i+1)*B} / {len(dataloader.dataset)}')


def build_language_loader(testloader, world_size, rank, model, batch_size=128):
    val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
    distributed_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    batch_size = min(len(val_dataset) // world_size, batch_size)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
        sampler=distributed_sampler)
    return val_dataloader


class init_on_meta:

    def __init__(self, enable=True) -> None:
        self.enable = enable
        self.default_device = torch.ones([]).device

    def __enter__(self):
        if self.enable:
            torch.set_default_device('meta')

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enable:
            torch.set_default_device(self.default_device)


def _materialize_meta_module(module: nn.Module, ):
    # Run default meta device initialization
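The hunk above ends on the _materialize_meta_module context lines, so its body is not shown. As a general illustration only (the common torch pattern, not necessarily this function's actual implementation), turning a meta-initialized module into a usable one typically means allocating empty storage and re-running the default initializers:

# Illustrative sketch of materializing a meta-device module; not taken from
# this commit's _materialize_meta_module.
import torch.nn as nn


def materialize_on_device(module: nn.Module, device='cuda') -> nn.Module:
    # Allocate real (uninitialized) storage for every meta parameter/buffer.
    module = module.to_empty(device=device)
    # Re-run each submodule's default initialization where one is defined.
    for submodule in module.modules():
        if hasattr(submodule, 'reset_parameters'):
            submodule.reset_parameters()
    return module
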
@@ -78,3 +78,18 @@ def opt_infer(

        if (i + 1) * batch_size >= num_samples:
            break


class init_on_meta:

    def __init__(self, enable=True) -> None:
        self.enable = enable
        self.default_device = torch.ones([]).device

    def __enter__(self):
        if self.enable:
            torch.set_default_device('meta')

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enable:
            torch.set_default_device(self.default_device)
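A minimal usage sketch for the init_on_meta context manager added here (presumably in utils, given the `from utils import init_on_meta` line in the earlier hunk); the Linear layer is just a placeholder module:

import torch
from utils import init_on_meta

with init_on_meta(enable=True):
    # torch.set_default_device('meta') is active inside the block, so the
    # layer's parameters are allocated on the meta device (no real memory).
    layer = torch.nn.Linear(4096, 4096)
assert layer.weight.is_meta
# On exit the previous default device is restored; the parameters still need
# to be materialized (e.g. to_empty plus re-initialization, or by loading
# real weights) before the module can be used.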