Mirror of https://github.com/facebookresearch/deit.git, synced 2025-06-03 14:52:20 +08:00

Add the basic implementation of the sparsity factory

parent ee8893c806
commit ce08e408dd
3  configs/deit_small_nxm.yaml  Normal file
@ -0,0 +1,3 @@
sparsity:
  mode: nxm
  level: [[4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2]]
3  configs/deit_small_uniform.yaml  Normal file
@ -0,0 +1,3 @@
sparsity:
  pruner:
  level: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
8  configs/testing.py  Normal file
@ -0,0 +1,8 @@
import yaml
from yaml.loader import SafeLoader

# Open the file and load the file
with open('deit_small_nxm.yaml') as f:
    data = yaml.load(f, Loader=SafeLoader)
    print(data)
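For reference, loading configs/deit_small_nxm.yaml with the snippet above yields a plain nested dict; a sketch of the expected shape (illustrative only, not part of the committed files):

# SafeLoader turns the YAML into ordinary Python dicts/lists, so the custom
# pruner in sparsify_training.py below can index it directly as
# config['sparsity']['mode'] and config['sparsity']['level'].
data = {
    'sparsity': {
        'mode': 'nxm',
        'level': [[4, 2]] * 48,   # one [N, M] pair per pruned linear layer in the shipped config
    }
}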
@ -33,10 +33,10 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if args.bce_loss:
            targets = targets.gt(0.0).type(targets.dtype)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(samples, outputs, targets)
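The args.bce_loss branch shown above binarizes the (possibly mixup-smoothed) targets before they reach torch.nn.BCEWithLogitsLoss; a quick illustration (not from the diff):

import torch

targets = torch.tensor([[0.0, 0.9, 0.1],
                        [0.6, 0.0, 0.4]])      # e.g. soft targets after mixup
binary = targets.gt(0.0).type(targets.dtype)   # same expression as in the hunk
print(binary)                                  # tensor([[0., 1., 1.], [1., 0., 1.]])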
38  main.py
@ -105,15 +105,15 @@ def get_args_parser():
|
||||
parser.add_argument('--repeated-aug', action='store_true')
|
||||
parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
|
||||
parser.set_defaults(repeated_aug=True)
|
||||
|
||||
|
||||
parser.add_argument('--train-mode', action='store_true')
|
||||
parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
|
||||
parser.set_defaults(train_mode=True)
|
||||
|
||||
|
||||
parser.add_argument('--ThreeAugment', action='store_true') #3augment
|
||||
|
||||
|
||||
parser.add_argument('--src', action='store_true') #simple random crop
|
||||
|
||||
|
||||
# * Random Erase params
|
||||
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
|
||||
help='Random erase prob (default: 0.25)')
|
||||
@ -148,8 +148,8 @@ def get_args_parser():
|
||||
|
||||
# * Finetuning params
|
||||
parser.add_argument('--finetune', default='', help='finetune from checkpoint')
|
||||
parser.add_argument('--attn-only', action='store_true')
|
||||
|
||||
parser.add_argument('--attn-only', action='store_true')
|
||||
|
||||
# Dataset parameters
|
||||
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
|
||||
help='dataset path')
|
||||
@ -266,7 +266,9 @@ def main(args):
|
||||
img_size=args.input_size
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if args.finetune:
|
||||
if args.finetune.startswith('https'):
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
@ -302,7 +304,7 @@ def main(args):
|
||||
checkpoint_model['pos_embed'] = new_pos_embed
|
||||
|
||||
model.load_state_dict(checkpoint_model, strict=False)
|
||||
|
||||
|
||||
if args.attn_only:
|
||||
for name_p,p in model.named_parameters():
|
||||
if '.attn.' in name_p:
|
||||
@ -324,7 +326,7 @@ def main(args):
|
||||
p.requires_grad = False
|
||||
except:
|
||||
print('no patch embed')
|
||||
|
||||
|
||||
model.to(device)
|
||||
|
||||
model_ema = None
|
||||
@ -359,10 +361,10 @@ def main(args):
|
||||
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
|
||||
else:
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
|
||||
if args.bce_loss:
|
||||
criterion = torch.nn.BCEWithLogitsLoss()
|
||||
|
||||
|
||||
teacher_model = None
|
||||
if args.distillation_type != 'none':
|
||||
assert args.teacher_path, 'need to specify teacher-path when using distillation'
|
||||
@ -438,11 +440,11 @@ def main(args):
|
||||
'scaler': loss_scaler.state_dict(),
|
||||
'args': args,
|
||||
}, checkpoint_path)
|
||||
|
||||
|
||||
|
||||
test_stats = evaluate(data_loader_val, model, device)
|
||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||
|
||||
|
||||
if max_accuracy < test_stats["acc1"]:
|
||||
max_accuracy = test_stats["acc1"]
|
||||
if args.output_dir:
|
||||
@ -457,17 +459,17 @@ def main(args):
|
||||
'scaler': loss_scaler.state_dict(),
|
||||
'args': args,
|
||||
}, checkpoint_path)
|
||||
|
||||
|
||||
print(f'Max accuracy: {max_accuracy:.2f}%')
|
||||
|
||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
||||
**{f'test_{k}': v for k, v in test_stats.items()},
|
||||
'epoch': epoch,
|
||||
'n_parameters': n_parameters}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if args.output_dir and utils.is_main_process():
|
||||
with (output_dir / "log.txt").open("a") as f:
|
||||
f.write(json.dumps(log_stats) + "\n")
|
||||
|
513  sparsify_training.py  Normal file
@ -0,0 +1,513 @@
|
||||
# Copyright (c) 2015-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
import argparse
|
||||
import datetime
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import json
|
||||
import yaml
|
||||
from yaml.loader import SafeLoader
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from timm.data import Mixup
|
||||
from timm.models import create_model
|
||||
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
||||
from timm.scheduler import create_scheduler
|
||||
from timm.optim import create_optimizer
|
||||
from timm.utils import NativeScaler, get_state_dict, ModelEma
|
||||
|
||||
from datasets import build_dataset
|
||||
from engine import train_one_epoch, evaluate
|
||||
from losses import DistillationLoss
|
||||
from samplers import RASampler
|
||||
from augment import new_data_aug_generator
|
||||
|
||||
import models
|
||||
import models_v2
|
||||
|
||||
import utils
|
||||
|
||||
from sparsity_factory.pruners import weight_pruner_loader, prune_weights_reparam, check_valid_pruner
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
|
||||
parser.add_argument('--batch-size', default=64, type=int)
|
||||
parser.add_argument('--epochs', default=300, type=int)
|
||||
parser.add_argument('--bce-loss', action='store_true')
|
||||
parser.add_argument('--unscale-lr', action='store_true')
|
||||
|
||||
# Model parameters
|
||||
parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
|
||||
help='Name of model to train')
|
||||
parser.add_argument('--input-size', default=224, type=int, help='images input size')
|
||||
|
||||
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
|
||||
help='Dropout rate (default: 0.)')
|
||||
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
|
||||
help='Drop path rate (default: 0.1)')
|
||||
|
||||
parser.add_argument('--model-ema', action='store_true')
|
||||
parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
|
||||
parser.set_defaults(model_ema=False)
|
||||
parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
|
||||
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
|
||||
|
||||
# Optimizer parameters
|
||||
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
|
||||
help='Optimizer (default: "adamw"')
|
||||
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
|
||||
help='Optimizer Epsilon (default: 1e-8)')
|
||||
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
|
||||
help='Optimizer Betas (default: None, use opt default)')
|
||||
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
|
||||
help='Clip gradient norm (default: None, no clipping)')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
|
||||
help='SGD momentum (default: 0.9)')
|
||||
parser.add_argument('--weight-decay', type=float, default=0.05,
|
||||
help='weight decay (default: 0.05)')
|
||||
# Learning rate schedule parameters
|
||||
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
|
||||
help='LR scheduler (default: "cosine"')
|
||||
parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
|
||||
help='learning rate (default: 5e-4)')
|
||||
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
|
||||
help='learning rate noise on/off epoch percentages')
|
||||
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
|
||||
help='learning rate noise limit percent (default: 0.67)')
|
||||
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
|
||||
help='learning rate noise std-dev (default: 1.0)')
|
||||
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
|
||||
help='warmup learning rate (default: 1e-6)')
|
||||
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
|
||||
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
|
||||
|
||||
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
|
||||
help='epoch interval to decay LR')
|
||||
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
|
||||
help='epochs to warmup LR, if scheduler supports')
|
||||
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
|
||||
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
|
||||
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
|
||||
help='patience epochs for Plateau LR scheduler (default: 10')
|
||||
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
|
||||
help='LR decay rate (default: 0.1)')
|
||||
|
||||
# Augmentation parameters
|
||||
parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
|
||||
help='Color jitter factor (default: 0.3)')
|
||||
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
|
||||
help='Use AutoAugment policy. "v0" or "original". " + \
|
||||
"(default: rand-m9-mstd0.5-inc1)'),
|
||||
parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
|
||||
parser.add_argument('--train-interpolation', type=str, default='bicubic',
|
||||
help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
|
||||
|
||||
parser.add_argument('--repeated-aug', action='store_true')
|
||||
parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
|
||||
parser.set_defaults(repeated_aug=True)
|
||||
|
||||
parser.add_argument('--train-mode', action='store_true')
|
||||
parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
|
||||
parser.set_defaults(train_mode=True)
|
||||
|
||||
parser.add_argument('--ThreeAugment', action='store_true') #3augment
|
||||
|
||||
parser.add_argument('--src', action='store_true') #simple random crop
|
||||
|
||||
# * Random Erase params
|
||||
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
|
||||
help='Random erase prob (default: 0.25)')
|
||||
parser.add_argument('--remode', type=str, default='pixel',
|
||||
help='Random erase mode (default: "pixel")')
|
||||
parser.add_argument('--recount', type=int, default=1,
|
||||
help='Random erase count (default: 1)')
|
||||
parser.add_argument('--resplit', action='store_true', default=False,
|
||||
help='Do not random erase first (clean) augmentation split')
|
||||
|
||||
# * Mixup params
|
||||
parser.add_argument('--mixup', type=float, default=0.8,
|
||||
help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
|
||||
parser.add_argument('--cutmix', type=float, default=1.0,
|
||||
help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
|
||||
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
|
||||
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
|
||||
parser.add_argument('--mixup-prob', type=float, default=1.0,
|
||||
help='Probability of performing mixup or cutmix when either/both is enabled')
|
||||
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
|
||||
help='Probability of switching to cutmix when both mixup and cutmix enabled')
|
||||
parser.add_argument('--mixup-mode', type=str, default='batch',
|
||||
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
|
||||
|
||||
# Distillation parameters
|
||||
parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
|
||||
help='Name of teacher model to train (default: "regnety_160"')
|
||||
parser.add_argument('--teacher-path', type=str, default='')
|
||||
parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
|
||||
parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
|
||||
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
|
||||
|
||||
# * Finetuning params
|
||||
parser.add_argument('--finetune', default='', help='finetune from checkpoint')
|
||||
parser.add_argument('--attn-only', action='store_true')
|
||||
|
||||
# Dataset parameters
|
||||
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
|
||||
help='dataset path')
|
||||
parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
|
||||
type=str, help='Image Net dataset path')
|
||||
parser.add_argument('--inat-category', default='name',
|
||||
choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
|
||||
type=str, help='semantic granularity')
|
||||
|
||||
parser.add_argument('--output_dir', default='',
|
||||
help='path where to save, empty for no saving')
|
||||
parser.add_argument('--device', default='cuda',
|
||||
help='device to use for training / testing')
|
||||
parser.add_argument('--seed', default=0, type=int)
|
||||
parser.add_argument('--resume', default='', help='resume from checkpoint')
|
||||
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
||||
help='start epoch')
|
||||
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
|
||||
parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
|
||||
parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
|
||||
parser.add_argument('--num_workers', default=10, type=int)
|
||||
parser.add_argument('--pin-mem', action='store_true',
|
||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
||||
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
|
||||
help='')
|
||||
parser.set_defaults(pin_mem=True)
|
||||
|
||||
# distributed training parameters
|
||||
parser.add_argument('--world_size', default=1, type=int,
|
||||
help='number of distributed processes')
|
||||
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
||||
|
||||
    # sparsity parameters
    parser.add_argument('--pruner', type=str, help='pruning criterion')
    parser.add_argument('--sparsity', type=float, default=1.0, help='the sparsity level (ratio of unpruned weights)')
    parser.add_argument('--custom-config', type=str, help='customized configuration of the sparsity level for each linear layer')
    return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
utils.init_distributed_mode(args)
|
||||
|
||||
print(args)
|
||||
|
||||
if args.distillation_type != 'none' and args.finetune and not args.eval:
|
||||
raise NotImplementedError("Finetuning with distillation not yet supported")
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
# fix the seed for reproducibility
|
||||
seed = args.seed + utils.get_rank()
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
# random.seed(seed)
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
|
||||
dataset_val, _ = build_dataset(is_train=False, args=args)
|
||||
|
||||
if True: # args.distributed:
|
||||
num_tasks = utils.get_world_size()
|
||||
global_rank = utils.get_rank()
|
||||
if args.repeated_aug:
|
||||
sampler_train = RASampler(
|
||||
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||
)
|
||||
else:
|
||||
sampler_train = torch.utils.data.DistributedSampler(
|
||||
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||
)
|
||||
if args.dist_eval:
|
||||
if len(dataset_val) % num_tasks != 0:
|
||||
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
||||
'equal num of samples per-process.')
|
||||
sampler_val = torch.utils.data.DistributedSampler(
|
||||
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
|
||||
else:
|
||||
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
||||
else:
|
||||
sampler_train = torch.utils.data.RandomSampler(dataset_train)
|
||||
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
||||
|
||||
data_loader_train = torch.utils.data.DataLoader(
|
||||
dataset_train, sampler=sampler_train,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem,
|
||||
drop_last=True,
|
||||
)
|
||||
if args.ThreeAugment:
|
||||
data_loader_train.dataset.transform = new_data_aug_generator(args)
|
||||
|
||||
data_loader_val = torch.utils.data.DataLoader(
|
||||
dataset_val, sampler=sampler_val,
|
||||
batch_size=int(1.5 * args.batch_size),
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem,
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
mixup_fn = None
|
||||
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
|
||||
if mixup_active:
|
||||
mixup_fn = Mixup(
|
||||
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
||||
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
||||
label_smoothing=args.smoothing, num_classes=args.nb_classes)
|
||||
|
||||
print(f"Creating model: {args.model}")
|
||||
model = create_model(
|
||||
args.model,
|
||||
pretrained=True,
|
||||
num_classes=args.nb_classes,
|
||||
drop_rate=args.drop,
|
||||
drop_path_rate=args.drop_path,
|
||||
drop_block_rate=None,
|
||||
img_size=args.input_size
|
||||
)
|
||||
|
||||
|
||||
|
||||
    if args.pruner == 'custom':
        if args.custom_config:
            with open(args.custom_config) as f:
                config = yaml.load(f, Loader=SafeLoader)
        else:
            raise ValueError("Please provide the configuration file when using the custom mode")

        mode = config['sparsity']['mode']
        sparsity_config = config['sparsity']['level']

        pruner = weight_pruner_loader(args.pruner)
        pruner(model, mode, sparsity_config)
    elif check_valid_pruner(args.pruner):
        pruner = weight_pruner_loader(args.pruner)
        prune_weights_reparam(model)
        pruner(model, args.sparsity)
    else:
        raise ValueError(f"Pruner '{args.pruner}' is not supported")
|
||||
|
||||
|
||||
|
||||
if args.finetune:
|
||||
if args.finetune.startswith('https'):
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
args.finetune, map_location='cpu', check_hash=True)
|
||||
else:
|
||||
checkpoint = torch.load(args.finetune, map_location='cpu')
|
||||
|
||||
checkpoint_model = checkpoint['model']
|
||||
state_dict = model.state_dict()
|
||||
for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
|
||||
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
||||
print(f"Removing key {k} from pretrained checkpoint")
|
||||
del checkpoint_model[k]
|
||||
|
||||
# interpolate position embedding
|
||||
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_patches = model.patch_embed.num_patches
|
||||
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(num_patches ** 0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
checkpoint_model['pos_embed'] = new_pos_embed
|
||||
|
||||
model.load_state_dict(checkpoint_model, strict=False)
|
||||
|
||||
if args.attn_only:
|
||||
for name_p,p in model.named_parameters():
|
||||
if '.attn.' in name_p:
|
||||
p.requires_grad = True
|
||||
else:
|
||||
p.requires_grad = False
|
||||
try:
|
||||
model.head.weight.requires_grad = True
|
||||
model.head.bias.requires_grad = True
|
||||
except:
|
||||
model.fc.weight.requires_grad = True
|
||||
model.fc.bias.requires_grad = True
|
||||
try:
|
||||
model.pos_embed.requires_grad = True
|
||||
except:
|
||||
print('no position encoding')
|
||||
try:
|
||||
for p in model.patch_embed.parameters():
|
||||
p.requires_grad = False
|
||||
except:
|
||||
print('no patch embed')
|
||||
|
||||
model.to(device)
|
||||
|
||||
model_ema = None
|
||||
if args.model_ema:
|
||||
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
|
||||
model_ema = ModelEma(
|
||||
model,
|
||||
decay=args.model_ema_decay,
|
||||
device='cpu' if args.model_ema_force_cpu else '',
|
||||
resume='')
|
||||
|
||||
model_without_ddp = model
|
||||
if args.distributed:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
||||
model_without_ddp = model.module
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print('number of params:', n_parameters)
|
||||
if not args.unscale_lr:
|
||||
linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
|
||||
args.lr = linear_scaled_lr
|
||||
optimizer = create_optimizer(args, model_without_ddp)
|
||||
loss_scaler = NativeScaler()
|
||||
|
||||
lr_scheduler, _ = create_scheduler(args, optimizer)
|
||||
|
||||
criterion = LabelSmoothingCrossEntropy()
|
||||
|
||||
if mixup_active:
|
||||
# smoothing is handled with mixup label transform
|
||||
criterion = SoftTargetCrossEntropy()
|
||||
elif args.smoothing:
|
||||
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
|
||||
else:
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
if args.bce_loss:
|
||||
criterion = torch.nn.BCEWithLogitsLoss()
|
||||
|
||||
teacher_model = None
|
||||
if args.distillation_type != 'none':
|
||||
assert args.teacher_path, 'need to specify teacher-path when using distillation'
|
||||
print(f"Creating teacher model: {args.teacher_model}")
|
||||
teacher_model = create_model(
|
||||
args.teacher_model,
|
||||
pretrained=False,
|
||||
num_classes=args.nb_classes,
|
||||
global_pool='avg',
|
||||
)
|
||||
if args.teacher_path.startswith('https'):
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
args.teacher_path, map_location='cpu', check_hash=True)
|
||||
else:
|
||||
checkpoint = torch.load(args.teacher_path, map_location='cpu')
|
||||
teacher_model.load_state_dict(checkpoint['model'])
|
||||
teacher_model.to(device)
|
||||
teacher_model.eval()
|
||||
|
||||
# wrap the criterion in our custom DistillationLoss, which
|
||||
# just dispatches to the original criterion if args.distillation_type is 'none'
|
||||
criterion = DistillationLoss(
|
||||
criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
|
||||
)
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
if args.resume:
|
||||
if args.resume.startswith('https'):
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
args.resume, map_location='cpu', check_hash=True)
|
||||
else:
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
model_without_ddp.load_state_dict(checkpoint['model'])
|
||||
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
|
||||
args.start_epoch = checkpoint['epoch'] + 1
|
||||
if args.model_ema:
|
||||
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
|
||||
if 'scaler' in checkpoint:
|
||||
loss_scaler.load_state_dict(checkpoint['scaler'])
|
||||
lr_scheduler.step(args.start_epoch)
|
||||
if args.eval:
|
||||
test_stats = evaluate(data_loader_val, model, device)
|
||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||
return
|
||||
|
||||
print(f"Start training for {args.epochs} epochs")
|
||||
start_time = time.time()
|
||||
max_accuracy = 0.0
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
if args.distributed:
|
||||
data_loader_train.sampler.set_epoch(epoch)
|
||||
|
||||
train_stats = train_one_epoch(
|
||||
model, criterion, data_loader_train,
|
||||
optimizer, device, epoch, loss_scaler,
|
||||
args.clip_grad, model_ema, mixup_fn,
|
||||
set_training_mode=args.train_mode, # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
|
||||
args = args,
|
||||
)
|
||||
|
||||
lr_scheduler.step(epoch)
|
||||
if args.output_dir:
|
||||
checkpoint_paths = [output_dir / 'checkpoint.pth']
|
||||
for checkpoint_path in checkpoint_paths:
|
||||
utils.save_on_master({
|
||||
'model': model_without_ddp.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'lr_scheduler': lr_scheduler.state_dict(),
|
||||
'epoch': epoch,
|
||||
'scaler': loss_scaler.state_dict(),
|
||||
'args': args,
|
||||
}, checkpoint_path)
|
||||
|
||||
|
||||
test_stats = evaluate(data_loader_val, model, device)
|
||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||
|
||||
if max_accuracy < test_stats["acc1"]:
|
||||
max_accuracy = test_stats["acc1"]
|
||||
if args.output_dir:
|
||||
checkpoint_paths = [output_dir / 'best_checkpoint.pth']
|
||||
for checkpoint_path in checkpoint_paths:
|
||||
utils.save_on_master({
|
||||
'model': model_without_ddp.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'lr_scheduler': lr_scheduler.state_dict(),
|
||||
'epoch': epoch,
|
||||
'scaler': loss_scaler.state_dict(),
|
||||
'args': args,
|
||||
}, checkpoint_path)
|
||||
|
||||
print(f'Max accuracy: {max_accuracy:.2f}%')
|
||||
|
||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
||||
**{f'test_{k}': v for k, v in test_stats.items()},
|
||||
'epoch': epoch,
|
||||
'n_parameters': n_parameters}
|
||||
|
||||
|
||||
|
||||
|
||||
if args.output_dir and utils.is_main_process():
|
||||
with (output_dir / "log.txt").open("a") as f:
|
||||
f.write(json.dumps(log_stats) + "\n")
|
||||
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('Training time {}'.format(total_time_str))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
|
||||
args = parser.parse_args()
|
||||
if args.output_dir:
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
main(args)
|
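The pruner selection in main() above supports two paths: the named pruners accepted by check_valid_pruner (driven by --sparsity) and the 'custom' pruner (driven by a YAML config such as configs/deit_small_nxm.yaml). A minimal sketch of both paths on a toy model follows; it is illustrative only and assumes that sparsity_factory.utils.get_modules / get_modules_with_name (not shown in this diff) pick up nn.Linear layers:

import torch.nn as nn
from sparsity_factory.pruners import weight_pruner_loader, prune_weights_reparam

# Path 1: a named pruner ('lamp', 'glob', 'unif', 'unifplus', 'erk') with a global amount.
toy_a = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 8))
prune_weights_reparam(toy_a)               # attach weight masks first, as main() does
weight_pruner_loader('unif')(toy_a, 0.5)   # remove 50% of each layer's weights by L1 magnitude

# Path 2: the 'custom' pruner fed by a parsed YAML config (mode + per-layer levels).
toy_b = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 8))
custom = weight_pruner_loader('custom')
custom(toy_b, 'nxm', [[4, 2], [4, 2]])     # one [N, M] entry per pruned layer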
5  sparsity_factory/__init__.py  Normal file
@ -0,0 +1,5 @@
from .utils import get_weights, get_modules, get_copied_modules, get_sparsities, get_nnzs, get_model_sparsity
from .pruners import weight_pruner_loader, prune_weights_semistructured, check_valid_pruner
#from .dataloaders import dataset_loader
#from .modelloaders import model_and_opt_loader
#from .train import trainer_loader,initialize_weight
13  sparsity_factory/dataloaders.py  Normal file
@ -0,0 +1,13 @@
import torch
from .datasets import *   # package-relative import; the datasets module ships inside sparsity_factory/

data_route = './data/'
cifar10_strings = ['vgg16','resnet18','densenet','effnet']

def dataset_loader(model,batch_size=100,num_workers=5):
    if model in cifar10_strings:
        print("Loading CIFAR-10 with batch size "+str(batch_size))
        train_loader,test_loader = get_cifar10_loaders(data_route,batch_size,num_workers)
    else:
        raise ValueError('Model not implemented :P')
    return train_loader,test_loader
1  sparsity_factory/datasets/__init__.py  Normal file
@ -0,0 +1 @@
from .cifar10 import get_cifar10_loaders
18  sparsity_factory/datasets/cifar10.py  Normal file
@ -0,0 +1,18 @@
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as dts
import torchvision.transforms as T

cifar_nm = T.Normalize((0.4914,0.4822,0.4465),(0.247,0.243,0.261))

def get_cifar10_loaders(data_route,batch_size,num_workers):
    tfm_train = T.Compose([T.RandomCrop(32, padding=4),T.RandomHorizontalFlip(),T.ToTensor(),cifar_nm])
    tfm_test = T.Compose([T.ToTensor(),cifar_nm])

    train_set = dts.CIFAR10(data_route,train=True,download=True,transform=tfm_train)
    test_set = dts.CIFAR10(data_route,train=False,download=False,transform=tfm_test)

    train_loader = DataLoader(train_set,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=num_workers)
    test_loader = DataLoader(test_set,batch_size=batch_size,shuffle=False,drop_last=False,num_workers=num_workers)

    return train_loader,test_loader
70  sparsity_factory/modelloaders.py  Normal file
@ -0,0 +1,70 @@
|
||||
import torch.optim as optim
|
||||
import torch.optim.lr_scheduler as sched
|
||||
import torchvision.models as tmodels
|
||||
from functools import partial
|
||||
from .models import *                       # package-relative: models/ ships inside sparsity_factory/
from .pruners import prune_weights_reparam
|
||||
|
||||
def model_and_opt_loader(model_string,DEVICE):
|
||||
if DEVICE == None:
|
||||
raise ValueError('No cuda device!')
|
||||
if model_string == 'vgg16':
|
||||
model = VGG16().to(DEVICE)
|
||||
amount = 0.20
|
||||
batch_size = 100
|
||||
opt_pre = {
|
||||
"optimizer": partial(optim.AdamW,lr=0.0003),
|
||||
"steps": 50000,
|
||||
"scheduler": None
|
||||
}
|
||||
opt_post = {
|
||||
"optimizer": partial(optim.AdamW,lr=0.0003),
|
||||
"steps": 40000,
|
||||
"scheduler": None
|
||||
}
|
||||
elif model_string == 'resnet18':
|
||||
model = ResNet18().to(DEVICE)
|
||||
amount = 0.20
|
||||
batch_size = 100
|
||||
opt_pre = {
|
||||
"optimizer": partial(optim.AdamW,lr=0.0003),
|
||||
"steps": 50000,
|
||||
"scheduler": None
|
||||
}
|
||||
opt_post = {
|
||||
"optimizer": partial(optim.AdamW,lr=0.0003),
|
||||
"steps": 40000,
|
||||
"scheduler": None
|
||||
}
|
||||
elif model_string == 'densenet':
|
||||
model = DenseNet121().to(DEVICE)
|
||||
amount = 0.20
|
||||
batch_size = 100
|
||||
opt_pre = {
|
||||
"optimizer": partial(optim.AdamW,lr=0.0003),
|
||||
"steps": 80000,
|
||||
"scheduler": None
|
||||
}
|
||||
opt_post = {
|
||||
"optimizer": partial(optim.AdamW,lr=0.0003),
|
||||
"steps": 60000,
|
||||
"scheduler": None
|
||||
}
|
||||
elif model_string == 'effnet':
|
||||
model = EfficientNetB0().to(DEVICE)
|
||||
amount = 0.20
|
||||
batch_size = 100
|
||||
opt_pre = {
|
||||
"optimizer": partial(optim.AdamW,lr=0.0003),
|
||||
"steps": 50000,
|
||||
"scheduler": None
|
||||
}
|
||||
opt_post = {
|
||||
"optimizer": partial(optim.AdamW,lr=0.0003),
|
||||
"steps": 40000,
|
||||
"scheduler": None
|
||||
}
|
||||
else:
|
||||
raise ValueError('Unknown model')
|
||||
prune_weights_reparam(model)
|
||||
return model,amount,batch_size,opt_pre,opt_post
|
4  sparsity_factory/models/__init__.py  Normal file
@ -0,0 +1,4 @@
from .vgg import VGG16 #, VGG11, VGG13, VGG19
from .resnet import ResNet18 #, ResNet34, ResNet50, ResNet101, ResNet152
from .densenet import DenseNet121
from .efficientnet import EfficientNetB0
82  sparsity_factory/models/densenet.py  Normal file
@ -0,0 +1,82 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import BatchNorm2d, Conv2d, Linear
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(self, in_planes, growth_rate):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.bn1 = BatchNorm2d(in_planes)
|
||||
self.conv1 = Conv2d(in_planes,4*growth_rate, kernel_size=1, bias=False)
|
||||
self.bn2 = BatchNorm2d(4*growth_rate)
|
||||
self.conv2 = Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
|
||||
def forward(self, x):
|
||||
out = self.conv1(F.relu(self.bn1(x)))
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out = torch.cat([out,x], 1)
|
||||
return out
|
||||
|
||||
class Transition(nn.Module):
|
||||
def __init__(self, in_planes, out_planes):
|
||||
super(Transition, self).__init__()
|
||||
self.bn = BatchNorm2d(in_planes)
|
||||
self.conv = Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(F.relu(self.bn(x)))
|
||||
out = F.avg_pool2d(out, 2)
|
||||
return out
|
||||
|
||||
class DenseNet(nn.Module):
|
||||
def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
|
||||
super(DenseNet, self).__init__()
|
||||
self.growth_rate = growth_rate
|
||||
|
||||
num_planes = 2*growth_rate
|
||||
self.conv1 = Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
|
||||
num_planes += nblocks[0]*growth_rate
|
||||
out_planes = int(math.floor(num_planes*reduction))
|
||||
self.trans1 = Transition(num_planes, out_planes)
|
||||
num_planes = out_planes
|
||||
|
||||
self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
|
||||
num_planes += nblocks[1]*growth_rate
|
||||
out_planes = int(math.floor(num_planes*reduction))
|
||||
self.trans2 = Transition(num_planes, out_planes)
|
||||
num_planes = out_planes
|
||||
|
||||
self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
|
||||
num_planes += nblocks[2]*growth_rate
|
||||
out_planes = int(math.floor(num_planes*reduction))
|
||||
self.trans3 = Transition(num_planes, out_planes)
|
||||
num_planes = out_planes
|
||||
|
||||
self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
|
||||
num_planes += nblocks[3]*growth_rate
|
||||
|
||||
self.bn = BatchNorm2d(num_planes)
|
||||
self.linear = Linear(num_planes, num_classes)
|
||||
|
||||
def _make_dense_layers(self, block, in_planes, nblock):
|
||||
layers = []
|
||||
for i in range(nblock):
|
||||
layers.append(block(in_planes, self.growth_rate))
|
||||
in_planes += self.growth_rate
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.trans1(self.dense1(out))
|
||||
out = self.trans2(self.dense2(out))
|
||||
out = self.trans3(self.dense3(out))
|
||||
out = self.dense4(out)
|
||||
out = F.avg_pool2d(F.relu(self.bn(out)), 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
def DenseNet121():
|
||||
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
|
169  sparsity_factory/models/efficientnet.py  Normal file
@ -0,0 +1,169 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * x.sigmoid()
|
||||
|
||||
|
||||
def drop_connect(x, drop_ratio):
|
||||
keep_ratio = 1.0 - drop_ratio
|
||||
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
|
||||
mask.bernoulli_(keep_ratio)
|
||||
x.div_(keep_ratio)
|
||||
x.mul_(mask)
|
||||
return x
|
||||
|
||||
|
||||
class SE(nn.Module):
|
||||
'''Squeeze-and-Excitation block with Swish.'''
|
||||
|
||||
def __init__(self, in_channels, se_channels):
|
||||
super(SE, self).__init__()
|
||||
self.se1 = nn.Conv2d(in_channels, se_channels,
|
||||
kernel_size=1, bias=True)
|
||||
self.se2 = nn.Conv2d(se_channels, in_channels,
|
||||
kernel_size=1, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.adaptive_avg_pool2d(x, (1, 1))
|
||||
out = swish(self.se1(out))
|
||||
out = self.se2(out).sigmoid()
|
||||
out = x * out
|
||||
return out
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
'''expansion + depthwise + pointwise + squeeze-excitation'''
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
expand_ratio=1,
|
||||
se_ratio=0.,
|
||||
drop_rate=0.):
|
||||
super(Block, self).__init__()
|
||||
self.stride = stride
|
||||
self.drop_rate = drop_rate
|
||||
self.expand_ratio = expand_ratio
|
||||
|
||||
# Expansion
|
||||
channels = expand_ratio * in_channels
|
||||
self.conv1 = nn.Conv2d(in_channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(channels)
|
||||
|
||||
# Depthwise conv
|
||||
self.conv2 = nn.Conv2d(channels,
|
||||
channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(1 if kernel_size == 3 else 2),
|
||||
groups=channels,
|
||||
bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(channels)
|
||||
|
||||
# SE layers
|
||||
se_channels = int(in_channels * se_ratio)
|
||||
self.se = SE(channels, se_channels)
|
||||
|
||||
# Output
|
||||
self.conv3 = nn.Conv2d(channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
# Skip connection if in and out shapes are the same (MV-V2 style)
|
||||
self.has_skip = (stride == 1) and (in_channels == out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x)))
|
||||
out = swish(self.bn2(self.conv2(out)))
|
||||
out = self.se(out)
|
||||
out = self.bn3(self.conv3(out))
|
||||
if self.has_skip:
|
||||
if self.training and self.drop_rate > 0:
|
||||
out = drop_connect(out, self.drop_rate)
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
|
||||
class EfficientNet(nn.Module):
|
||||
def __init__(self, cfg, num_classes=10):
|
||||
super(EfficientNet, self).__init__()
|
||||
self.cfg = cfg
|
||||
self.conv1 = nn.Conv2d(3,
|
||||
32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(32)
|
||||
self.layers = self._make_layers(in_channels=32)
|
||||
self.linear = nn.Linear(cfg['out_channels'][-1], num_classes)
|
||||
|
||||
def _make_layers(self, in_channels):
|
||||
layers = []
|
||||
cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size',
|
||||
'stride']]
|
||||
b = 0
|
||||
blocks = sum(self.cfg['num_blocks'])
|
||||
for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
for stride in strides:
|
||||
drop_rate = self.cfg['drop_connect_rate'] * b / blocks
|
||||
layers.append(
|
||||
Block(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
expansion,
|
||||
se_ratio=0.25,
|
||||
drop_rate=drop_rate))
|
||||
in_channels = out_channels
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = swish(self.bn1(self.conv1(x)))
|
||||
out = self.layers(out)
|
||||
out = F.adaptive_avg_pool2d(out, 1)
|
||||
out = out.view(out.size(0), -1)
|
||||
dropout_rate = self.cfg['dropout_rate']
|
||||
if self.training and dropout_rate > 0:
|
||||
out = F.dropout(out, p=dropout_rate)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def EfficientNetB0():
|
||||
cfg = {
|
||||
'num_blocks': [1, 2, 2, 3, 3, 4, 1],
|
||||
'expansion': [1, 6, 6, 6, 6, 6, 6],
|
||||
'out_channels': [16, 24, 40, 80, 112, 192, 320],
|
||||
'kernel_size': [3, 3, 5, 3, 5, 5, 3],
|
||||
'stride': [1, 2, 2, 2, 1, 2, 1],
|
||||
'dropout_rate': 0.2,
|
||||
'drop_connect_rate': 0.2,
|
||||
}
|
||||
return EfficientNet(cfg)
|
||||
|
||||
|
||||
def test():
|
||||
net = EfficientNetB0()
|
||||
x = torch.randn(2, 3, 32, 32)
|
||||
y = net(x)
|
||||
print(y.shape)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
105  sparsity_factory/models/resnet.py  Normal file
@ -0,0 +1,105 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Conv2d, Linear,BatchNorm2d
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = BatchNorm2d(planes)
|
||||
self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = BatchNorm2d(planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
Conv2d(in_planes, self.expansion*planes,kernel_size=1,stride=stride,bias=False),
|
||||
BatchNorm2d(self.expansion*planes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = BatchNorm2d(planes)
|
||||
self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = BatchNorm2d(planes)
|
||||
self.conv3 = Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
||||
self.bn3 = BatchNorm2d(self.expansion*planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
BatchNorm2d(self.expansion*planes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, num_blocks, num_classes=10):
|
||||
super(ResNet, self).__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
self.conv1 = Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
self.linear = Linear(512*block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def ResNet18():
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2])
|
||||
|
||||
def ResNet34():
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3])
|
||||
|
||||
def ResNet50():
|
||||
return ResNet(Bottleneck, [3, 4, 6, 3])
|
||||
|
||||
def ResNet101():
|
||||
return ResNet(Bottleneck, [3, 4, 23, 3])
|
||||
|
||||
def ResNet152():
|
||||
return ResNet(Bottleneck, [3, 8, 36, 3])
|
48  sparsity_factory/models/vgg.py  Normal file
@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Conv2d, Linear, BatchNorm2d, MaxPool2d, AvgPool2d
|
||||
|
||||
cfg = {
|
||||
'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
|
||||
'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
|
||||
}
|
||||
|
||||
class VGG(nn.Module):
|
||||
def __init__(self, vgg_name, use_bn=True):
|
||||
super(VGG, self).__init__()
|
||||
self.features = self._make_layers(cfg[vgg_name], use_bn)
|
||||
self.classifier = Linear(512, 10)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.features(x)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.classifier(out)
|
||||
return out
|
||||
|
||||
def _make_layers(self, cfg, use_bn):
|
||||
layers = []
|
||||
in_channels = 3
|
||||
for x in cfg:
|
||||
if x == 'M':
|
||||
layers += [MaxPool2d(kernel_size=2, stride=2)]
|
||||
else:
|
||||
layers += [Conv2d(in_channels, x, kernel_size=3, padding=1)]
|
||||
if use_bn:
|
||||
layers += [BatchNorm2d(x)]
|
||||
layers += [nn.ReLU(inplace=True)]
|
||||
in_channels = x
|
||||
layers += [AvgPool2d(kernel_size=1, stride=1)]
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def VGG11(use_bn=True):
|
||||
return VGG('VGG11', use_bn)
|
||||
def VGG13(use_bn=True):
|
||||
return VGG('VGG13', use_bn)
|
||||
def VGG16(use_bn=True):
|
||||
return VGG('VGG16', use_bn)
|
||||
def VGG19(use_bn=True):
|
||||
return VGG('VGG19', use_bn)
|
279  sparsity_factory/pruners.py  Normal file
@ -0,0 +1,279 @@
|
||||
import torch
|
||||
from torch.nn.utils import prune
|
||||
from .utils import get_weights, get_modules, get_modules_with_name
|
||||
import numpy as np
|
||||
|
||||
ALL_PRUNERS = ['lamp', 'glob', 'unif', 'unifplus', 'erk', 'custom']
|
||||
def check_valid_pruner(name):
|
||||
return name in ALL_PRUNERS
|
||||
|
||||
def weight_pruner_loader(pruner_string):
|
||||
"""
|
||||
Gives you the pruning methods: LAMP, Glob, Unif, Unif+, ERK, and custom (config-driven nxm / unstructured)
|
||||
"""
|
||||
if pruner_string == 'lamp':
|
||||
return prune_weights_lamp
|
||||
elif pruner_string == 'glob':
|
||||
return prune_weights_global
|
||||
elif pruner_string == 'unif':
|
||||
return prune_weights_uniform
|
||||
elif pruner_string == 'unifplus':
|
||||
return prune_weights_unifplus
|
||||
elif pruner_string == 'erk':
|
||||
return prune_weights_erk
|
||||
elif pruner_string == 'custom':
|
||||
return prune_weights_from_config
|
||||
else:
|
||||
raise ValueError('Unknown pruner')
|
||||
|
||||
|
||||
"""
|
||||
prune_weights_reparam: Allocate an identity mask to every weight tensor.
|
||||
prune_weights_l1predefined: Perform layerwise pruning w.r.t. given amounts.
|
||||
"""
|
||||
|
||||
def prune_weights_semistructured(module, configs=None):
|
||||
"""
|
||||
Remove weights according to the given N:M configs.
The configs argument is a 2D list with the following format:
    configs = [[N1, M1], [N2, M2], ...]
|
||||
Please make sure that the len(configs) == number_of_pruning_layers
|
||||
"""
|
||||
def compute_mask(t, N, M):
|
||||
out_channel, in_channel = t.shape
|
||||
percentile = M / N
|
||||
t_reshaped = t.reshape(out_channel, -1, N)
|
||||
#print(t_reshaped.shape)
|
||||
mask = torch.ones_like(t)
|
||||
mask_reshaped = mask.reshape(out_channel, -1, N)
|
||||
|
||||
nparams_topprune = int(N * percentile)
|
||||
if nparams_topprune != 0:
|
||||
topk = torch.topk(torch.abs(t_reshaped), k=nparams_topprune, largest=False, dim = -1)
|
||||
#print(topk.indices)
|
||||
mask_reshaped = mask_reshaped.scatter(dim = -1, index = topk.indices, value = 0)
|
||||
|
||||
return mask_reshaped.reshape(out_channel, in_channel)
|
||||
|
||||
if configs is None:
    raise ValueError("Currently nxm pruning only supports a manual config. "
                     "Please provide the sparsity level config for each pruning target")
|
||||
|
||||
mlist = get_modules_with_name(module)
|
||||
for idx, (name, m) in enumerate(mlist):
|
||||
weight_tensor = m.weight
|
||||
config = configs[idx]
|
||||
N, M = config[0], config[1]
|
||||
print(f"module: {name}, N:M = ({N}, {M})")
|
||||
mask = compute_mask(weight_tensor, N, M)
|
||||
prune.custom_from_mask(m, name = 'weight', mask = mask)
|
||||
|
||||
def prune_weights_reparam(model):
|
||||
module_list = get_modules(model)
|
||||
for m in module_list:
|
||||
prune.identity(m,name="weight")
|
||||
|
||||
def prune_weights_l1predefined(model,amounts):
|
||||
mlist = get_modules_with_name(model)
|
||||
for idx,(name, m) in enumerate(mlist):
|
||||
print(f"module: {name}, amounts of removed weight: {float(amounts[idx])}")
|
||||
prune.l1_unstructured(m,name="weight",amount=float(amounts[idx]))
|
||||
|
||||
"""
|
||||
Methods: All weights
|
||||
"""
|
||||
|
||||
def prune_weights_global(model,amount):
|
||||
parameters_to_prune = _extract_weight_tuples(model)
|
||||
prune.global_unstructured(parameters_to_prune,pruning_method = prune.L1Unstructured,amount=amount)
|
||||
|
||||
def prune_weights_lamp(model,amount):
|
||||
assert amount <= 1
|
||||
amounts = _compute_lamp_amounts(model,amount)
|
||||
prune_weights_l1predefined(model,amounts)
|
||||
|
||||
def prune_weights_uniform(model,amount):
|
||||
module_list = get_modules(model)
|
||||
assert amount <= 1 # Can be updated later to handle > 1.
|
||||
for m in module_list:
|
||||
print("module:", m, " remove amount:", amount)
|
||||
prune.l1_unstructured(m,name="weight",amount=amount)
|
||||
|
||||
def prune_weights_unifplus(model,amount):
|
||||
assert amount <= 1
|
||||
amounts = _compute_unifplus_amounts(model,amount)
|
||||
prune_weights_l1predefined(model,amounts)
|
||||
|
||||
def prune_weights_erk(model,amount):
|
||||
assert amount <= 1
|
||||
amounts = _compute_erk_amounts(model,amount)
|
||||
prune_weights_l1predefined(model,amounts)
|
||||
|
||||
|
||||
def prune_weights_from_config(model, mode, configs):
|
||||
if mode == 'nxm':
|
||||
prune_weights_semistructured(model, configs)
|
||||
elif mode == 'unstructured':
|
||||
prune_weights_l1predefined(model, amounts=configs)
|
||||
|
||||
"""
|
||||
These are not intended to be exported.
|
||||
"""
|
||||
|
||||
def _extract_weight_tuples(model):
|
||||
"""
|
||||
Gives you well-packed weight tensors for global pruning.
|
||||
"""
|
||||
mlist = get_modules(model)
|
||||
return tuple([(m,'weight') for m in mlist])
|
||||
|
||||
def _compute_unifplus_amounts(model,amount):
|
||||
"""
|
||||
Compute # of weights to prune in each layer.
|
||||
"""
|
||||
amounts = []
|
||||
wlist = get_weights(model)
|
||||
unmaskeds = _count_unmasked_weights(model)
|
||||
totals = _count_total_weights(model)
|
||||
|
||||
last_layer_minimum = np.round(totals[-1]*0.2) # Minimum number of last-layer weights to keep
|
||||
total_to_prune = np.round(unmaskeds.sum()*amount)
|
||||
|
||||
if wlist[0].dim() == 4:
|
||||
amounts.append(0) # Leave the first layer unpruned.
|
||||
frac_to_prune = (total_to_prune*1.0)/(unmaskeds[1:].sum())
|
||||
if frac_to_prune > 1.0:
|
||||
raise ValueError("Cannot be pruned further by the Unif+ scheme! (first layer exception)")
|
||||
last_layer_to_surv_planned = np.round((1.0-frac_to_prune)*unmaskeds[-1])
|
||||
if last_layer_to_surv_planned < last_layer_minimum:
|
||||
last_layer_to_prune = unmaskeds[-1] - last_layer_minimum
|
||||
frac_to_prune_middle = ((total_to_prune-last_layer_to_prune)*1.0)/(unmaskeds[1:-1].sum())
|
||||
if frac_to_prune_middle > 1.0:
|
||||
raise ValueError("Cannot be pruned further by the Unif+ scheme! (first+last layer exception)")
|
||||
amounts.extend([frac_to_prune_middle]*(unmaskeds.size(0)-2))
|
||||
amounts.append((last_layer_to_prune*1.0)/unmaskeds[-1])
|
||||
else:
|
||||
amounts.extend([frac_to_prune]*(unmaskeds.size(0)-1))
|
||||
else:
|
||||
frac_to_prune = (total_to_prune*1.0)/(unmaskeds.sum())
|
||||
last_layer_to_surv_planned = np.round((1.0-frac_to_prune)*unmaskeds[-1])
|
||||
if last_layer_to_surv_planned < last_layer_minimum:
|
||||
last_layer_to_prune = unmaskeds[-1] - last_layer_minimum
|
||||
frac_to_prune_middle = ((total_to_prune-last_layer_to_prune)*1.0)/(unmaskeds[:-1].sum())
|
||||
if frac_to_prune_middle > 1.0:
|
||||
raise ValueError("Cannot be pruned further by the Unif+ scheme! (last layer exception)")
|
||||
amounts.extend([frac_to_prune_middle]*(unmaskeds.size(0)-1))
|
||||
amounts.append((last_layer_to_prune*1.0)/unmaskeds[-1])
|
||||
else:
|
||||
amounts.extend([frac_to_prune]*(unmaskeds.size(0)))
|
||||
return amounts
|
||||
|
||||
def _compute_erk_amounts(model,amount):
|
||||
unmaskeds = _count_unmasked_weights(model)
|
||||
erks = _compute_erks(model)
|
||||
|
||||
return _amounts_from_eps(unmaskeds,erks,amount)
|
||||
|
||||
def _amounts_from_eps(unmaskeds,ers,amount):
|
||||
num_layers = ers.size(0)
|
||||
layers_to_keep_dense = torch.zeros(num_layers)
|
||||
total_to_survive = (1.0-amount)*unmaskeds.sum() # Total to keep.
|
||||
|
||||
# Determine some layers to keep dense.
|
||||
is_eps_invalid = True
|
||||
while is_eps_invalid:
|
||||
unmasked_among_prunables = (unmaskeds*(1-layers_to_keep_dense)).sum()
|
||||
to_survive_among_prunables = total_to_survive - (layers_to_keep_dense*unmaskeds).sum()
|
||||
|
||||
ers_of_prunables = ers*(1.0-layers_to_keep_dense)
|
||||
survs_of_prunables = torch.round(to_survive_among_prunables*ers_of_prunables/ers_of_prunables.sum())
|
||||
|
||||
layer_to_make_dense = -1
|
||||
max_ratio = 1.0
|
||||
for idx in range(num_layers):
|
||||
if layers_to_keep_dense[idx] == 0:
|
||||
if survs_of_prunables[idx]/unmaskeds[idx] > max_ratio:
|
||||
layer_to_make_dense = idx
|
||||
max_ratio = survs_of_prunables[idx]/unmaskeds[idx]
|
||||
|
||||
if layer_to_make_dense == -1:
|
||||
is_eps_invalid = False
|
||||
else:
|
||||
layers_to_keep_dense[layer_to_make_dense] = 1
|
||||
|
||||
amounts = torch.zeros(num_layers)
|
||||
|
||||
for idx in range(num_layers):
|
||||
if layers_to_keep_dense[idx] == 1:
|
||||
amounts[idx] = 0.0
|
||||
else:
|
||||
amounts[idx] = 1.0 - (survs_of_prunables[idx]/unmaskeds[idx])
|
||||
return amounts
|
||||
|
||||
def _compute_lamp_amounts(model,amount):
    """
    Compute per-layer pruning amounts by thresholding LAMP-normalized scores globally.
    """
    unmaskeds = _count_unmasked_weights(model)
    num_surv = int(np.round(unmaskeds.sum()*(1.0-amount)))

    flattened_scores = [_normalize_scores(w**2).view(-1) for w in get_weights(model)]
    concat_scores = torch.cat(flattened_scores,dim=0)
    topks,_ = torch.topk(concat_scores,num_surv)
    threshold = topks[-1]

    # We don't care much about tiebreakers, for now.
    final_survs = [torch.ge(score,threshold*torch.ones(score.size()).to(score.device)).sum() for score in flattened_scores]
    amounts = []
    for idx,final_surv in enumerate(final_survs):
        amounts.append(1.0 - (final_surv/unmaskeds[idx]))

    return amounts

def _compute_erks(model):
    wlist = get_weights(model)
    erks = torch.zeros(len(wlist))
    for idx,w in enumerate(wlist):
        if w.dim() == 4:
            erks[idx] = w.size(0)+w.size(1)+w.size(2)+w.size(3)
        else:
            erks[idx] = w.size(0)+w.size(1)
    return erks

def _count_unmasked_weights(model):
    """
    Return a 1-dimensional tensor of #unmasked weights.
    """
    mlist = get_modules(model)
    unmaskeds = []
    for m in mlist:
        unmaskeds.append(m.weight_mask.sum())
    return torch.FloatTensor(unmaskeds)

def _count_total_weights(model):
    """
    Return a 1-dimensional tensor of #total weights.
    """
    wlist = get_weights(model)
    numels = []
    for w in wlist:
        numels.append(w.numel())
    return torch.FloatTensor(numels)

def _normalize_scores(scores):
    """
    Normalizing scheme for LAMP.
    """
    # sort scores in an ascending order
    sorted_scores,sorted_idx = scores.view(-1).sort(descending=False)
    # compute cumulative sum
    scores_cumsum_temp = sorted_scores.cumsum(dim=0)
    scores_cumsum = torch.zeros(scores_cumsum_temp.shape,device=scores.device)
    scores_cumsum[1:] = scores_cumsum_temp[:len(scores_cumsum_temp)-1]
    # normalize by cumulative sum
    sorted_scores /= (scores.sum() - scores_cumsum)
    # tidy up and output
    new_scores = torch.zeros(scores_cumsum.shape,device=scores.device)
    new_scores[sorted_idx] = sorted_scores

    return new_scores.view(scores.shape)
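The helpers above only produce per-layer pruning fractions; this part of the diff does not show how they are applied to a model. The sketch below is a hypothetical wiring, not part of the commit: the module path sparsity_factory.pruners, the toy model, and the 50% target are assumptions. It uses torch.nn.utils.prune to create the weight_mask buffers that _count_unmasked_weights expects, then prunes each nn.Linear by its LAMP amount.

# Hypothetical usage sketch; the pruners module path below is an assumption.
import torch.nn as nn
import torch.nn.utils.prune as prune

from sparsity_factory.pruners import _compute_lamp_amounts      # assumed location of the helpers above
from sparsity_factory.utils import get_modules, get_model_sparsity

# Stand-in for a real backbone: two prunable nn.Linear layers.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))

# prune.identity attaches an all-ones weight_mask, which _count_unmasked_weights reads.
for m in get_modules(model):
    prune.identity(m, name='weight')

# Ask LAMP for per-layer fractions that remove roughly 50% of all prunable weights.
amounts = _compute_lamp_amounts(model, amount=0.5)

# Apply the fractions layer by layer with magnitude (L1) pruning.
for m, amt in zip(get_modules(model), amounts):
    prune.l1_unstructured(m, name='weight', amount=float(amt))

print(f'surviving weight fraction: {get_model_sparsity(model):.3f}')   # roughly 0.5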
102
sparsity_factory/train.py
Normal file
102
sparsity_factory/train.py
Normal file
@ -0,0 +1,102 @@
import torch
import numpy as np
import torch.nn.functional as F

def trainer_loader():
    return train

def initialize_weight(model,loader):
    # Push one batch through the model with gradients disabled
    # (e.g. to trigger any shape-dependent or lazy initialization before training).
    batch = next(iter(loader))
    device = next(model.parameters()).device
    with torch.no_grad():
        model(batch[0].to(device))

def train(model,optpack,train_loader,test_loader,print_steps=-1,log_results=False,log_path='log.txt'):
    model.train()
    opt = optpack["optimizer"](model.parameters())
    if optpack["scheduler"] is not None:
        sched = optpack["scheduler"](opt)
    else:
        sched = None
    num_steps = optpack["steps"]
    device = next(model.parameters()).device

    results_log = []
    training_step = 0

    if sched is not None:
        while True:
            for i,(x,y) in enumerate(train_loader):
                training_step += 1
                x = x.to(device)
                y = y.to(device)

                opt.zero_grad()
                yhat = model(x)
                loss = F.cross_entropy(yhat,y)
                loss.backward()
                opt.step()
                sched.step()

                if print_steps != -1 and training_step%print_steps == 0:
                    train_acc,train_loss = test(model,train_loader)
                    test_acc,test_loss = test(model,test_loader)
                    print(f'Steps: {training_step}/{num_steps} \t Train acc: {train_acc:.2f} \t Test acc: {test_acc:.2f}', end='\r')
                    if log_results:
                        results_log.append([test_acc,test_loss,train_acc,train_loss])
                        np.savetxt(log_path,results_log)
                if training_step >= num_steps:
                    break
            if training_step >= num_steps:
                break
    else:
        while True:
            for i,(x,y) in enumerate(train_loader):
                training_step += 1
                x = x.to(device)
                y = y.to(device)

                opt.zero_grad()
                yhat = model(x)
                loss = F.cross_entropy(yhat,y)
                loss.backward()
                opt.step()

                if print_steps != -1 and training_step%print_steps == 0:
                    train_acc,train_loss = test(model,train_loader)
                    test_acc,test_loss = test(model,test_loader)
                    print(f'Steps: {training_step}/{num_steps} \t Train acc: {train_acc:.2f} \t Test acc: {test_acc:.2f}', end='\r')
                    if log_results:
                        results_log.append([test_acc,test_loss,train_acc,train_loss])
                        np.savetxt(log_path,results_log)
                if training_step >= num_steps:
                    break
            if training_step >= num_steps:
                break
    train_acc,train_loss = test(model,train_loader)
    test_acc,test_loss = test(model,test_loader)
    print(f'Train acc: {train_acc:.2f}\t Test acc: {test_acc:.2f}')
    return [test_acc,test_loss,train_acc,train_loss]

def test(model,loader):
    model.eval()
    device = next(model.parameters()).device

    correct = 0
    loss = 0
    total = 0
    for i,(x,y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            yhat = model(x)
        _,pred = yhat.max(1)
        correct += pred.eq(y).sum().item()
        # Accumulate the loss as a plain float so the results_log stays numpy-friendly.
        loss += F.cross_entropy(yhat,y).item()*len(x)
        total += len(x)
    acc = correct/total * 100.0
    loss = loss/total

    model.train()

    return acc,loss
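train() drives its loop entirely off the optpack dictionary, so the following is a minimal, hypothetical call showing the layout the function reads ("optimizer" and "scheduler" as factories, "steps" as the stopping point). The random TensorDataset and the tiny model are placeholders, not part of the commit.

# Hypothetical usage sketch for the train()/test() pair above.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from sparsity_factory.train import train

# Placeholder data and model, just to exercise the loop.
data = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,)))
train_loader = DataLoader(data, batch_size=64, shuffle=True)
test_loader = DataLoader(data, batch_size=64)

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

optpack = {
    "optimizer": lambda params: torch.optim.SGD(params, lr=0.1, momentum=0.9),           # factory: params -> optimizer
    "scheduler": lambda opt: torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100),  # or None
    "steps": 100,                                                                         # total optimizer steps
}

# Returns [test_acc, test_loss, train_acc, train_loss] after `steps` updates.
stats = train(model, optpack, train_loader, test_loader, print_steps=20)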
80
sparsity_factory/utils.py
Normal file
80
sparsity_factory/utils.py
Normal file
@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from copy import deepcopy

# Preliminaries. Not to be exported.

exclude_module_name = ['head']

def _is_prunable_module(m, name = None):
    if name is not None and name in exclude_module_name:
        return False
    return isinstance(m, nn.Linear)

def _get_sparsity(tsr):
    # Fraction of nonzero entries in the tensor (i.e. the surviving-weight ratio).
    total = tsr.numel()
    nnz = tsr.nonzero().size(0)
    return nnz/total

def _get_nnz(tsr):
    return tsr.nonzero().size(0)

# Modules

def get_weights(model):
    weights = []
    for name, m in model.named_modules():
        if _is_prunable_module(m, name):
            weights.append(m.weight)
    return weights

def get_convweights(model):
    weights = []
    for m in model.modules():
        if isinstance(m,nn.Conv2d):
            weights.append(m.weight)
    return weights

def get_modules(model):
    modules = []
    for name, m in model.named_modules():
        if _is_prunable_module(m, name):
            modules.append(m)
    return modules

def get_modules_with_name(model):
    modules = []
    for name, m in model.named_modules():
        if _is_prunable_module(m, name):
            modules.append((name, m))
    return modules


def get_convmodules(model):
    modules = []
    for m in model.modules():
        if isinstance(m,nn.Conv2d):
            modules.append(m)
    return modules

def get_copied_modules(model):
    modules = []
    for m in model.modules():
        if _is_prunable_module(m):
            modules.append(deepcopy(m).cpu())
    return modules

def get_model_sparsity(model):
    prunables = 0
    nnzs = 0
    for m in model.modules():
        if _is_prunable_module(m):
            prunables += m.weight.data.numel()
            nnzs += m.weight.data.nonzero().size(0)
    return nnzs/prunables

def get_sparsities(model):
    return [_get_sparsity(m.weight.data) for m in model.modules() if _is_prunable_module(m)]

def get_nnzs(model):
    return [_get_nnz(m.weight.data) for m in model.modules() if _is_prunable_module(m)]