import numpy as np
import torch


def set_seed(seed):
    """Seed the numpy and torch RNGs for reproducible calibration sampling."""
    np.random.seed(seed)
    torch.random.manual_seed(seed)


def get_wikitext2(nsamples, seed, seqlen, model):
    """Sample `nsamples` windows of `seqlen` tokens from the WikiText-2 train
    split and return them together with the tokenized test split."""
    from datasets import load_dataset
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    trainenc = tokenizer(' '.join(traindata['text']), return_tensors='pt')
    testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Pick a random window of seqlen tokens from the concatenated corpus.
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        # Mask every position except the last with -100 so the loss ignores
        # the context tokens.
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_ptb(nsamples, seed, seqlen, model):
    """Sample `nsamples` windows of `seqlen` tokens from the Penn Treebank
    train split and return them together with the tokenized test split."""
    from datasets import load_dataset
    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt')
    testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt')

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Pick a random window of seqlen tokens from the concatenated corpus.
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        # Keep only the final token as a target; -100 is ignored by the loss.
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_c4(nsamples, seed, seqlen, model):
    """Sample `nsamples` windows of `seqlen` tokens from a single C4 training
    shard and return them together with a truncated validation encoding."""
    from datasets import load_dataset
    # Note: recent `datasets` releases may reject the extra 'allenai--c4'
    # config name; dropping that argument is the usual workaround.
    traindata = load_dataset(
        'allenai/c4',
        'allenai--c4',
        data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
        split='train')
    valdata = load_dataset(
        'allenai/c4',
        'allenai--c4',
        data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
        split='validation')

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Keep drawing documents until one is strictly longer than seqlen
        # tokens, so the random start index below has a non-empty range.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        # Keep only the final token as a target; -100 is ignored by the loss.
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]

    # Wrap the raw tensor so it exposes the same `.input_ids` attribute as the
    # tokenizer outputs returned by the other loaders.
    class TokenizerWrapper:

        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc


def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
    """Dispatch to the calibration-data builder that matches `name`."""
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, model)
    if 'ptb' in name:
        return get_ptb(nsamples, seed, seqlen, model)
    if 'c4' in name:
        return get_c4(nsamples, seed, seqlen, model)
    raise NotImplementedError(f'Unsupported calibration dataset: {name}')
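

# Minimal usage sketch (illustrative addition, not part of the upstream
# module): the model identifier below is only a placeholder, and building the
# loaders downloads the dataset and tokenizer, so network access is required.
if __name__ == '__main__':
    calib_loader, test_enc = get_loaders(
        'wikitext2',
        nsamples=4,
        seed=0,
        seqlen=2048,
        model='facebook/opt-125m')
    # Each calibration sample is an (input_ids, target_ids) pair of shape
    # [1, seqlen]; the targets mask everything except the final token.
    print(len(calib_loader), calib_loader[0][0].shape)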