import argparse
import copy
import json
import os
import random
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import yaml
from yaml.loader import SafeLoader
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models import create_model
from torchvision import datasets, transforms

import model_sparse  # imported for its side effect of registering the sparse model variants
import utils
from datasets import build_dataset
from engine import evaluate
from sparsity_factory import get_model_sparsity, weight_pruner_loader


class RandomCandGenerator:
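    """Random sampler over per-layer sparsity assignments.

    Enumerates, once at construction time, every way of distributing the
    ``config_length`` prunable layers among the available sparsity choices
    (see ``rec``), buckets the resulting count vectors in ``self.m``, and
    then draws random candidate configs from those buckets (see ``random``).
    """
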
    def __init__(self, sparsity_config):
        self.sparsity_config = sparsity_config
        # Assumes every block offers the same number of choices; this may
        # break if blocks have different numbers of choices.
        self.num_candidates_per_block = len(sparsity_config[0])
        # e.g., 48 for DeiT-S: 12 blocks, each with qkv, fc1, fc2, and a
        # linear projection.
        self.config_length = len(sparsity_config)
        self.m = defaultdict(list)  # {bucket index: list of count vectors}
        v = []  # temporary vector used by rec()
        self.rec(v, self.m)

    def calc(self, v):
        # Map a count vector to a bucket index. The mapping is not injective,
        # which is why self.m keeps a list of count vectors per bucket.
        res = 0
        for i in range(self.num_candidates_per_block):
            res += i * v[i]
        return res

    def rec(self, v, m, idx=0, cur=0):
        # Recursively enumerate every way of splitting config_length layers
        # among the sparsity choices and register each split in m.
        if idx == self.num_candidates_per_block - 1:
            v.append(self.config_length - cur)
            m[self.calc(v)].append(copy.copy(v))
            v.pop()
            return

        i = self.config_length - cur
        while i >= 0:
            v.append(i)
            self.rec(v, m, idx + 1, cur + i)
            v.pop()
            i -= 1

    def random(self):
        # Draw a random bucket, then a random count vector within it, and
        # expand the counts into a shuffled per-layer choice assignment.
        # NOTE: random.choice on a dict indexes by key; this relies on the
        # bucket keys being exactly the contiguous integers 0..len(m)-1.
        row = random.choice(random.choice(self.m))
        ratios = []
        for num, ratio in zip(row, range(self.num_candidates_per_block)):
            ratios += [ratio] * num
        random.shuffle(ratios)
        res = []
        for idx, ratio in enumerate(ratios):
            res.append(tuple(self.sparsity_config[idx][ratio]))  # FIXME
        return res  # a cand_config


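# Illustrative sketch (not executed; the values are hypothetical): with a
# two-layer config offering two N:M choices per layer, e.g.
#     sparsity_config = [[(1, 4), (2, 4)], [(1, 4), (2, 4)]]
# RandomCandGenerator(sparsity_config).random() returns one choice tuple per
# layer, such as [(2, 4), (1, 4)].

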
class EvolutionSearcher:
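    """Evolutionary search over layer-wise sparsity configurations.

    Each candidate is a tuple with one sparsity choice per prunable layer.
    search() seeds a random population, keeps the top-k candidates by
    validation accuracy, and refills each generation with mutations,
    crossovers, and fresh random candidates.
    """
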
    def __init__(self, args, model, model_without_ddp, sparsity_config, val_loader, output_dir, config):
        self.model = model
        self.model_without_ddp = model_without_ddp
        self.max_epochs = args.max_epochs
        self.select_num = args.select_num
        self.population_num = args.population_num
        self.m_prob = args.m_prob
        self.crossover_num = args.crossover_num
        self.mutation_num = args.mutation_num
        self.parameters_limits = args.param_limits
        self.min_parameters_limits = args.min_param_limits
        self.val_loader = val_loader
        self.output_dir = output_dir
        self.s_prob = args.s_prob
        self.memory = []
        self.vis_dict = {}
        self.keep_top_k = {self.select_num: [], 50: []}
        self.epoch = 0
        self.candidates = []
        self.top_accuracies = []
        self.cand_params = []
        self.sparsity_config = config['sparsity']['choices']

        self.rcg = RandomCandGenerator(self.sparsity_config)

    def save_checkpoint(self):
        info = {}
        info['top_accuracies'] = self.top_accuracies
        info['memory'] = self.memory
        info['candidates'] = self.candidates
        info['vis_dict'] = self.vis_dict
        info['keep_top_k'] = self.keep_top_k
        info['epoch'] = self.epoch
        checkpoint_path = os.path.join(self.output_dir, "checkpoint-{}.pth.tar".format(self.epoch))
        torch.save(info, checkpoint_path)
        print('save checkpoint to', checkpoint_path)

    def is_legal(self, cand):
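        """Check a candidate's parameter budget and evaluate it once.

        Returns False for already-visited candidates and for candidates
        outside [min_param_limits, param_limits]; otherwise records the
        candidate's top-1 accuracy in vis_dict and returns True.
        """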
        assert isinstance(cand, tuple)
        if cand not in self.vis_dict:
            self.vis_dict[cand] = {}
        info = self.vis_dict[cand]
        if 'visited' in info:
            return False

        self.model_without_ddp.set_sample_config(cand)
        print(cand)
        n_parameters = self.model_without_ddp.num_params() / 1e6
        info['params'] = n_parameters  # parameter count in millions (proxy for sparsity level)
        print(n_parameters)

        if info['params'] > self.parameters_limits:
            print('parameters limit exceeded')
            return False

        if info['params'] < self.min_parameters_limits:
            print('under minimum parameters limit')
            return False

        print("rank:", utils.get_rank(), cand, info['params'])
        eval_stats = evaluate(self.val_loader, self.model, 'cuda')

        info['acc'] = eval_stats['acc1']
        info['visited'] = True

        return True

    def update_top_k(self, candidates, *, k, key, reverse=True):
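        """Merge candidates into keep_top_k[k] and truncate it to the best k."""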
        assert k in self.keep_top_k
        print('select ......')
        t = self.keep_top_k[k]
        t += candidates
        t.sort(key=key, reverse=reverse)
        self.keep_top_k[k] = t[:k]

    def stack_random_cand(self, random_func, *, batchsize=10):
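        """Yield candidates from random_func in batches, indefinitely.

        Each candidate gets a vis_dict entry before being yielded.
        """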
        while True:
            cands = [random_func() for _ in range(batchsize)]
            for cand in cands:
                print(cand)
                if cand not in self.vis_dict:
                    self.vis_dict[cand] = {}
            for cand in cands:
                yield cand

    def get_random_cand(self):
        return tuple(self.rcg.random())

    def get_random(self, num):
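        """Grow self.candidates with legal random candidates until it holds num."""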
        print('random select ........')
        cand_iter = self.stack_random_cand(self.get_random_cand)
        while len(self.candidates) < num:
            cand = next(cand_iter)
            if not self.is_legal(cand):
                continue
            self.candidates.append(cand)
            print('random {}/{}'.format(len(self.candidates), num))
        print('random_num = {}'.format(len(self.candidates)))

    def get_mutation(self, k, mutation_num, m_prob, s_prob):
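        """Mutate top-k candidates: each layer's choice is resampled with
        probability m_prob. (s_prob is accepted but unused here.)
        """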
        assert k in self.keep_top_k
        print('mutation ......')
        res = []
        max_iters = mutation_num * 10

        def random_func():
            cand = list(random.choice(self.keep_top_k[k]))

            # resample each layer's sparsity ratio with probability m_prob
            for idx in range(len(self.sparsity_config)):
                if random.random() < m_prob:
                    cand[idx] = tuple(random.choice(self.sparsity_config[idx]))

            return tuple(cand)

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < mutation_num and max_iters > 0:
            max_iters -= 1
            cand = next(cand_iter)
            if not self.is_legal(cand):
                continue
            res.append(cand)
            print('mutation {}/{}'.format(len(res), mutation_num))

        print('mutation_num = {}'.format(len(res)))
        return res

    def get_crossover(self, k, crossover_num):
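        """Cross pairs of top-k candidates, taking each layer's choice from
        either parent with equal probability.
        """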
        assert k in self.keep_top_k
        print('crossover ......')
        res = []
        max_iters = 10 * crossover_num

        def random_func():
            p1 = random.choice(self.keep_top_k[k])
            p2 = random.choice(self.keep_top_k[k])
            max_iters_tmp = 50
            while len(p1) != len(p2) and max_iters_tmp > 0:
                max_iters_tmp -= 1
                p1 = random.choice(self.keep_top_k[k])
                p2 = random.choice(self.keep_top_k[k])
            return tuple(random.choice([i, j]) for i, j in zip(p1, p2))

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < crossover_num and max_iters > 0:
            max_iters -= 1
            cand = next(cand_iter)
            if not self.is_legal(cand):
                continue
            res.append(cand)
            print('crossover {}/{}'.format(len(res), crossover_num))

        print('crossover_num = {}'.format(len(res)))
        return res

    def search(self):
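        """Run the evolutionary loop for max_epochs generations.

        Each generation: record the population, update the top-k pools by
        accuracy, then build the next population from mutations, crossovers,
        and fresh random candidates.
        """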
        print(
            'population_num = {} select_num = {} mutation_num = {} crossover_num = {} random_num = {} max_epochs = {}'.format(
                self.population_num, self.select_num, self.mutation_num, self.crossover_num,
                self.population_num - self.mutation_num - self.crossover_num, self.max_epochs))

        # self.load_checkpoint()

        self.get_random(self.population_num)

        while self.epoch < self.max_epochs:
            print('epoch = {}'.format(self.epoch))

            self.memory.append([])
            for cand in self.candidates:
                self.memory[-1].append(cand)

            self.update_top_k(
                self.candidates, k=self.select_num, key=lambda x: self.vis_dict[x]['acc'])
            self.update_top_k(
                self.candidates, k=50, key=lambda x: self.vis_dict[x]['acc'])

            print('epoch = {} : top {} result'.format(
                self.epoch, len(self.keep_top_k[50])))
            tmp_accuracy = []
            for i, cand in enumerate(self.keep_top_k[50]):
                print('No.{} {} Top-1 val acc = {}, params = {}'.format(
                    i + 1, cand, self.vis_dict[cand]['acc'], self.vis_dict[cand]['params']))
                tmp_accuracy.append(self.vis_dict[cand]['acc'])
            self.top_accuracies.append(tmp_accuracy)

            mutation = self.get_mutation(
                self.select_num, self.mutation_num, self.m_prob, self.s_prob)
            crossover = self.get_crossover(self.select_num, self.crossover_num)

            self.candidates = mutation + crossover

            self.get_random(self.population_num)

            self.epoch += 1

            self.save_checkpoint()


def get_args_parser():
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=128, type=int)

    # evolution search parameters
    parser.add_argument('--max-epochs', type=int, default=20)
    parser.add_argument('--select-num', type=int, default=10)
    parser.add_argument('--population-num', type=int, default=50)
    parser.add_argument('--m_prob', type=float, default=0.2)
    parser.add_argument('--s_prob', type=float, default=0.4)
    parser.add_argument('--crossover-num', type=int, default=25)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--mutation-num', type=int, default=25)
    parser.add_argument('--param-limits', type=float, default=5.6)
    parser.add_argument('--min-param-limits', type=float, default=5)
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--num_workers', default=16, type=int)
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='dataset to use')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='Disable pinning CPU memory in DataLoader')
    parser.set_defaults(pin_mem=True)
    parser.add_argument('--model', default='', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # sparsity-related arguments
    parser.add_argument('--sparsity-config', default='', type=str, help='path to the sparsity yaml file')

    return parser


def main(args):
    utils.init_distributed_mode(args)

    print(args)

    cudnn.benchmark = True

    '''
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    ])

    dataset_val = datasets.ImageFolder(
        os.path.join(args.data_path, 'val'),
        transform=transform)
    '''

    dataset_val, _ = build_dataset(is_train=False, args=args)
    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=int(2 * args.batch_size),
        sampler=sampler_val, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False
    )

    with open(args.sparsity_config) as f:
        sparsity_config = yaml.load(f, Loader=SafeLoader)

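    # Expected YAML layout (a hedged sketch; the actual schema is defined by
    # the files under configs/, and these inner values are hypothetical): a
    # 'sparsity' section whose 'choices' entry lists the candidate patterns
    # per prunable layer, e.g.
    #
    #   sparsity:
    #     choices:
    #       - [[1, 4], [2, 4]]   # layer 0
    #       - [[1, 4], [2, 4]]   # layer 1
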
print(f"Creating model: {args.model}")
|
|
model = create_model(
|
|
args.model,
|
|
pretrained=True,
|
|
num_classes=1000,
|
|
drop_rate=0,
|
|
drop_path_rate=0,
|
|
img_size=args.input_size
|
|
)
|
|
|
|
model.cuda()
|
|
|
|
model_without_ddp = model
|
|
if args.distributed:
|
|
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
|
model_without_ddp = model.module
|
|
|
|
t = time.time()
|
|
searcher = EvolutionSearcher(args, model, model_without_ddp, sparsity_config, val_loader=data_loader_val, output_dir = args.output_dir, config=sparsity_config)
|
|
|
|
searcher.search()
|
|
|
|
print('total searching time = {:.2f} hours'.format(
|
|
(time.time() - t) / 3600))
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser('AutoFormer evolution search', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)


# Example invocations:
# CUDA_VISIBLE_DEVICES=3 python evolution_svd.py --data-path /dev/shm/imagenet/ --output_dir BASE_EA_13_16.5 --config sparsity_config/Vit_imnet_config_base.json --model deit_dist_base_p16_224_imnet_0416_wo_fc/checkpoint.pth --param-limits 16.5 --min-param-limits 13
# python -m torch.distributed.launch --nproc_per_node=2 evolution_search.py --data-path /home/yysung/imagenet --output_dir deit_small_nxm_ea_124 --sparsity-config configs/deit_small_nxm_ea124.yml --model Sparse_deit_small_patch16_224 --param-limits 13.2 --min-param-limits 8