refine · parent e4303d110b · commit f994c6d09e

This commit deduplicates helpers that the FSDP example script had defined locally: fold_tokens, LanguageDataset and build_language_loader move into datautils.py, and the init_on_meta context manager moves into utils.py, so the script now imports them instead of carrying its own copies.
datautils.py (the module the script imports from below):

@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
 import torch
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset as TorchDataset
+from torch.utils.data import DistributedSampler


 def set_seed(seed):
@@ -106,3 +108,45 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
         return get_ptb(nsamples, seed, seqlen, model)
     if 'c4' in name:
         return get_c4(nsamples, seed, seqlen, model)
+
+
+def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
+    # tokens: 1 N
+    N = tokens.shape[1]
+    num_drop = N % batch_seq_len
+    if num_drop != 0:
+        tokens = tokens[:, :-num_drop]
+    tokens = tokens.reshape([-1, batch_seq_len])  # B N
+    return tokens
+
+
+class LanguageDataset(TorchDataset):
+
+    def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
+        super().__init__()
+        # seq: 1, N
+        self.seq_len = seq_len
+
+        self.seq = fold_tokens(seq)  # B N
+
+    def __len__(self) -> int:
+        return self.seq.shape[0]
+
+    def __getitem__(self, index):
+        return self.seq[index]
+
+
+def build_language_loader(testloader, world_size, rank, model, batch_size=128):
+    val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
+    distributed_sampler = DistributedSampler(
+        val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
+    batch_size = min(len(val_dataset) // world_size, batch_size)
+    val_dataloader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=True,
+        drop_last=True,
+        sampler=distributed_sampler)
+    return val_dataloader
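For context, a minimal sketch of how these helpers compose. The token count, vocabulary size, and the SimpleNamespace stand-ins for `testloader` and `model` are illustrative assumptions, not part of the change; with world_size=1 and rank=0, DistributedSampler does not need an initialized process group, so this runs as a single-rank smoke test.

import torch
from types import SimpleNamespace

from datautils import build_language_loader, fold_tokens

# Hypothetical 1 x N token stream standing in for a tokenized eval set.
tokens = torch.randint(0, 50272, (1, 10000))

# fold_tokens drops the 10000 % 2048 = 1808 trailing tokens and folds the
# rest into fixed-length rows.
folded = fold_tokens(tokens, batch_seq_len=2048)
print(folded.shape)  # torch.Size([4, 2048])

# build_language_loader only touches `.input_ids` and `.seqlen`, so simple
# stand-ins suffice here.
testloader = SimpleNamespace(input_ids=tokens)
model = SimpleNamespace(seqlen=2048)
loader = build_language_loader(testloader, world_size=1, rank=0, model=model)
for batch in loader:
    # batch_size = min(len(dataset) // world_size, 128) = 4 on this toy input
    print(batch.shape)  # torch.Size([4, 2048]) -- one full batch, drop_last=True

Note that LanguageDataset.__init__ calls fold_tokens(seq) with the default batch_seq_len=2048, so the seq_len argument is stored but does not affect the folding.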
The FSDP example script (its path is not shown in this view):

@@ -5,15 +5,14 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-from datautils import get_loaders
+from datautils import build_language_loader, get_loaders
 from opt_sparse_gpt import get_model
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
 from torch.utils.data import DataLoader
-from torch.utils.data import Dataset as TorchDataset
-from torch.utils.data import DistributedSampler
+from utils import init_on_meta

 from mmrazor.implementations.pruning import sparse_gpt
 from mmrazor.utils import print_log
@@ -25,32 +24,7 @@ def setup(rank, world_size):
     dist.init_process_group('nccl', rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
     print_log(f'init {rank}/{world_size}')
-
-
-def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
-    # tokens: 1 N
-    N = tokens.shape[1]
-    num_drop = N % batch_seq_len
-    if num_drop != 0:
-        tokens = tokens[:, :-num_drop]
-    tokens = tokens.reshape([-1, batch_seq_len])  # B N
-    return tokens
-
-
-class LanguageDataset(TorchDataset):
-
-    def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
-        super().__init__()
-        # seq: 1, N
-        self.seq_len = seq_len
-
-        self.seq = fold_tokens(seq)  # B N
-
-    def __len__(self) -> int:
-        return self.seq.shape[0]
-
-    def __getitem__(self, index):
-        return self.seq[index]


 @torch.no_grad()
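The setup() context above is the standard per-rank NCCL bootstrap. A hedged sketch of how such a function is typically driven with the torch.multiprocessing import already present in this file; the `main` entry point and its arguments are assumptions, not shown in the diff:

def main(rank: int, world_size: int):
    setup(rank, world_size)       # init NCCL process group, pin this rank's GPU
    ...                           # build model, wrap in FSDP, run inference
    dist.destroy_process_group()  # clean per-rank shutdown


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    # mp.spawn passes the rank as the first argument to `main`.
    mp.spawn(main, args=(world_size, ), nprocs=world_size)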
@@ -112,37 +86,6 @@ def opt_infer(
         print_log(f'{(i+1)*B} / {len(dataloader.dataset)}')


-def build_language_loader(testloader, world_size, rank, model, batch_size=128):
-    val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
-    distributed_sampler = DistributedSampler(
-        val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
-    batch_size = min(len(val_dataset) // world_size, batch_size)
-    val_dataloader = DataLoader(
-        val_dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=0,
-        pin_memory=True,
-        drop_last=True,
-        sampler=distributed_sampler)
-    return val_dataloader
-
-
-class init_on_meta:
-
-    def __init__(self, enable=True) -> None:
-        self.enable = enable
-        self.default_device = torch.ones([]).device
-
-    def __enter__(self):
-        if self.enable:
-            torch.set_default_device('meta')
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        if self.enable:
-            torch.set_default_device(self.default_device)
-
-
 def _materialize_meta_module(module: nn.Module, ):
     # Run default meta device initialization
utils.py (the target of the new `from utils import init_on_meta` import):

@@ -78,3 +78,18 @@ def opt_infer(
         if (i + 1) * batch_size >= num_samples:
             break
+
+
+class init_on_meta:
+
+    def __init__(self, enable=True) -> None:
+        self.enable = enable
+        self.default_device = torch.ones([]).device
+
+    def __enter__(self):
+        if self.enable:
+            torch.set_default_device('meta')
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.enable:
+            torch.set_default_device(self.default_device)
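A short usage sketch of init_on_meta: inside the context, torch.set_default_device('meta') makes new tensors, and hence module parameters, shape-only allocations, so a large model can be constructed cheaply and materialized later (e.g. by FSDP). torch.set_default_device requires PyTorch >= 2.0; the layer sizes below are arbitrary:

import torch
import torch.nn as nn

from utils import init_on_meta

with init_on_meta(enable=True):
    # Parameters land on the meta device: shapes and dtypes only,
    # no real memory is allocated for the weights.
    block = nn.Sequential(nn.Linear(8192, 8192), nn.GELU(),
                          nn.Linear(8192, 8192))

print(next(block.parameters()).device)  # meta
n_params = sum(p.numel() for p in block.parameters())
print(f'{n_params / 1e6:.1f}M params, none materialized yet')

Capturing the previous default via torch.ones([]).device in __init__ is a small trick: it records whatever default device is active at construction time so that __exit__ can restore it.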