diff --git a/configs/deit_small_nxm.yaml b/configs/deit_small_nxm.yaml new file mode 100644 index 0000000..2ebd159 --- /dev/null +++ b/configs/deit_small_nxm.yaml @@ -0,0 +1,3 @@ +sparsity: + mode: nxm + level: [[4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2], [4, 2]] \ No newline at end of file diff --git a/configs/deit_small_uniform.yaml b/configs/deit_small_uniform.yaml new file mode 100644 index 0000000..95c7023 --- /dev/null +++ b/configs/deit_small_uniform.yaml @@ -0,0 +1,3 @@ +sparsity: + pruner: + level: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5] \ No newline at end of file diff --git a/configs/testing.py b/configs/testing.py new file mode 100644 index 0000000..d6a5a53 --- /dev/null +++ b/configs/testing.py @@ -0,0 +1,8 @@ +import yaml +from yaml.loader import SafeLoader + +# Open the file and load the file +with open('deit_small_nxm.yaml') as f: + data = yaml.load(f, Loader=SafeLoader) + print(data) + diff --git a/engine.py b/engine.py index ed10cea..a6261bc 100644 --- a/engine.py +++ b/engine.py @@ -33,10 +33,10 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, if mixup_fn is not None: samples, targets = mixup_fn(samples, targets) - + if args.bce_loss: targets = targets.gt(0.0).type(targets.dtype) - + with torch.cuda.amp.autocast(): outputs = model(samples) loss = criterion(samples, outputs, targets) diff --git a/main.py b/main.py index bc8c418..d787f92 100644 --- a/main.py +++ b/main.py @@ -105,15 +105,15 @@ def get_args_parser(): parser.add_argument('--repeated-aug', action='store_true') parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') parser.set_defaults(repeated_aug=True) - + parser.add_argument('--train-mode', action='store_true') parser.add_argument('--no-train-mode', action='store_false', dest='train_mode') parser.set_defaults(train_mode=True) - + parser.add_argument('--ThreeAugment', action='store_true') #3augment - + parser.add_argument('--src', action='store_true') #simple random crop - + # * Random Erase params parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob (default: 0.25)') @@ -148,8 +148,8 @@ def get_args_parser(): # * Finetuning params parser.add_argument('--finetune', default='', help='finetune from checkpoint') - parser.add_argument('--attn-only', action='store_true') - + parser.add_argument('--attn-only', action='store_true') + # Dataset parameters parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path') @@ -266,7 +266,9 @@ def main(args): img_size=args.input_size ) - + + + if args.finetune: if args.finetune.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( @@ -302,7 +304,7 @@ def main(args): checkpoint_model['pos_embed'] = new_pos_embed model.load_state_dict(checkpoint_model, strict=False) - + if args.attn_only: for name_p,p in model.named_parameters(): if '.attn.' 
in name_p: @@ -324,7 +326,7 @@ def main(args): p.requires_grad = False except: print('no patch embed') - + model.to(device) model_ema = None @@ -359,10 +361,10 @@ def main(args): criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() - + if args.bce_loss: criterion = torch.nn.BCEWithLogitsLoss() - + teacher_model = None if args.distillation_type != 'none': assert args.teacher_path, 'need to specify teacher-path when using distillation' @@ -438,11 +440,11 @@ def main(args): 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) - + test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") - + if max_accuracy < test_stats["acc1"]: max_accuracy = test_stats["acc1"] if args.output_dir: @@ -457,17 +459,17 @@ def main(args): 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) - + print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} - - - - + + + + if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") diff --git a/sparsify_training.py b/sparsify_training.py new file mode 100644 index 0000000..01e8009 --- /dev/null +++ b/sparsify_training.py @@ -0,0 +1,513 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +import argparse +import datetime +import numpy as np +import time +import torch +import torch.backends.cudnn as cudnn +import json +import yaml +from yaml.loader import SafeLoader + +from pathlib import Path + +from timm.data import Mixup +from timm.models import create_model +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.scheduler import create_scheduler +from timm.optim import create_optimizer +from timm.utils import NativeScaler, get_state_dict, ModelEma + +from datasets import build_dataset +from engine import train_one_epoch, evaluate +from losses import DistillationLoss +from samplers import RASampler +from augment import new_data_aug_generator + +import models +import models_v2 + +import utils + +from sparsity_factory.pruners import weight_pruner_loader, prune_weights_reparam, check_valid_pruner + +def get_args_parser(): + parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) + parser.add_argument('--batch-size', default=64, type=int) + parser.add_argument('--epochs', default=300, type=int) + parser.add_argument('--bce-loss', action='store_true') + parser.add_argument('--unscale-lr', action='store_true') + + # Model parameters + parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', + help='Name of model to train') + parser.add_argument('--input-size', default=224, type=int, help='images input size') + + parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', + help='Dropout rate (default: 0.)') + parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', + help='Drop path rate (default: 0.1)') + + parser.add_argument('--model-ema', action='store_true') + parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') + parser.set_defaults(model_ema=False) + parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') + parser.add_argument('--model-ema-force-cpu', action='store_true', 
default=False, help='') + + # Optimizer parameters + parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "adamw"') + parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') + parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') + parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') + parser.add_argument('--weight-decay', type=float, default=0.05, + help='weight decay (default: 0.05)') + # Learning rate schedule parameters + parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "cosine"') + parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', + help='learning rate (default: 5e-4)') + parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') + parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') + parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', + help='learning rate noise std-dev (default: 1.0)') + parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', + help='warmup learning rate (default: 1e-6)') + parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') + + parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', + help='epoch interval to decay LR') + parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', + help='epochs to warmup LR, if scheduler supports') + parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') + parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') + parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') + + # Augmentation parameters + parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT', + help='Color jitter factor (default: 0.3)') + parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". 
" + \ + "(default: rand-m9-mstd0.5-inc1)'), + parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') + parser.add_argument('--train-interpolation', type=str, default='bicubic', + help='Training interpolation (random, bilinear, bicubic default: "bicubic")') + + parser.add_argument('--repeated-aug', action='store_true') + parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') + parser.set_defaults(repeated_aug=True) + + parser.add_argument('--train-mode', action='store_true') + parser.add_argument('--no-train-mode', action='store_false', dest='train_mode') + parser.set_defaults(train_mode=True) + + parser.add_argument('--ThreeAugment', action='store_true') #3augment + + parser.add_argument('--src', action='store_true') #simple random crop + + # * Random Erase params + parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', + help='Random erase prob (default: 0.25)') + parser.add_argument('--remode', type=str, default='pixel', + help='Random erase mode (default: "pixel")') + parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') + parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') + + # * Mixup params + parser.add_argument('--mixup', type=float, default=0.8, + help='mixup alpha, mixup enabled if > 0. (default: 0.8)') + parser.add_argument('--cutmix', type=float, default=1.0, + help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') + parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') + parser.add_argument('--mixup-prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') + parser.add_argument('--mixup-switch-prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') + parser.add_argument('--mixup-mode', type=str, default='batch', + help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') + + # Distillation parameters + parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', + help='Name of teacher model to train (default: "regnety_160"') + parser.add_argument('--teacher-path', type=str, default='') + parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") + parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") + parser.add_argument('--distillation-tau', default=1.0, type=float, help="") + + # * Finetuning params + parser.add_argument('--finetune', default='', help='finetune from checkpoint') + parser.add_argument('--attn-only', action='store_true') + + # Dataset parameters + parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, + help='dataset path') + parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], + type=str, help='Image Net dataset path') + parser.add_argument('--inat-category', default='name', + choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], + type=str, help='semantic granularity') + + parser.add_argument('--output_dir', default='', + help='path where to save, empty for no saving') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', help='Perform evaluation only') + parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation") + parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') + parser.add_argument('--num_workers', default=10, type=int) + parser.add_argument('--pin-mem', action='store_true', + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', + help='') + parser.set_defaults(pin_mem=True) + + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + + # sparsity parameters + parser.add_argument('--pruner', type=str, help='pruning criterion') + parser.add_argument('--sparsity', type=float, default=1.0, help = 'the sparisty level (ratio of unpruned weight)') + parser.add_argument('--custom-config', type=str, help='customized configuration of sparsity level for each linear layer') + return parser + + +def main(args): + utils.init_distributed_mode(args) + + print(args) + + if args.distillation_type != 'none' and args.finetune and not args.eval: + raise NotImplementedError("Finetuning with distillation not yet supported") + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + # random.seed(seed) + + cudnn.benchmark = True + + dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) + dataset_val, _ = build_dataset(is_train=False, args=args) + + if True: # args.distributed: + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + if args.repeated_aug: + 
sampler_train = RASampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + if args.ThreeAugment: + data_loader_train.dataset.transform = new_data_aug_generator(args) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + batch_size=int(1.5 * args.batch_size), + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False + ) + + mixup_fn = None + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + if mixup_active: + mixup_fn = Mixup( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + label_smoothing=args.smoothing, num_classes=args.nb_classes) + + print(f"Creating model: {args.model}") + model = create_model( + args.model, + pretrained=True, + num_classes=args.nb_classes, + drop_rate=args.drop, + drop_path_rate=args.drop_path, + drop_block_rate=None, + img_size=args.input_size + ) + + + + if args.pruner == 'custom': + if args.custom_config: + with open(args.custom_config) as f: + config = yaml.load(f, Loader=SafeLoader) + else: + raise ValueError("Please provide the configuration file when using the custom mode") + + mode = config['sparsity']['mode'] + sparsity_config = config['sparsity']['level'] + + pruner = weight_pruner_loader(args.pruner) + pruner(model, mode, sparsity_config) + elif check_valid_pruner(args.pruner): + pruner = weight_pruner_loader(args.pruner) + prune_weights_reparam(model) + pruner(model, args.sparsity) + else: + raise ValueError(f"Pruner '{args.pruner}' is not supported") + + + + if args.finetune: + if args.finetune.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.finetune, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.finetune, map_location='cpu') + + checkpoint_model = checkpoint['model'] + state_dict = model.state_dict() + for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: + if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + + # interpolate position embedding + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - 
num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + model.load_state_dict(checkpoint_model, strict=False) + + if args.attn_only: + for name_p,p in model.named_parameters(): + if '.attn.' in name_p: + p.requires_grad = True + else: + p.requires_grad = False + try: + model.head.weight.requires_grad = True + model.head.bias.requires_grad = True + except: + model.fc.weight.requires_grad = True + model.fc.bias.requires_grad = True + try: + model.pos_embed.requires_grad = True + except: + print('no position encoding') + try: + for p in model.patch_embed.parameters(): + p.requires_grad = False + except: + print('no patch embed') + + model.to(device) + + model_ema = None + if args.model_ema: + # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper + model_ema = ModelEma( + model, + decay=args.model_ema_decay, + device='cpu' if args.model_ema_force_cpu else '', + resume='') + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('number of params:', n_parameters) + if not args.unscale_lr: + linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 + args.lr = linear_scaled_lr + optimizer = create_optimizer(args, model_without_ddp) + loss_scaler = NativeScaler() + + lr_scheduler, _ = create_scheduler(args, optimizer) + + criterion = LabelSmoothingCrossEntropy() + + if mixup_active: + # smoothing is handled with mixup label transform + criterion = SoftTargetCrossEntropy() + elif args.smoothing: + criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) + else: + criterion = torch.nn.CrossEntropyLoss() + + if args.bce_loss: + criterion = torch.nn.BCEWithLogitsLoss() + + teacher_model = None + if args.distillation_type != 'none': + assert args.teacher_path, 'need to specify teacher-path when using distillation' + print(f"Creating teacher model: {args.teacher_model}") + teacher_model = create_model( + args.teacher_model, + pretrained=False, + num_classes=args.nb_classes, + global_pool='avg', + ) + if args.teacher_path.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.teacher_path, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.teacher_path, map_location='cpu') + teacher_model.load_state_dict(checkpoint['model']) + teacher_model.to(device) + teacher_model.eval() + + # wrap the criterion in our custom DistillationLoss, which + # just dispatches to the original criterion if args.distillation_type is 'none' + criterion = DistillationLoss( + criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau + ) + + output_dir = Path(args.output_dir) + if args.resume: + if 
args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + args.start_epoch = checkpoint['epoch'] + 1 + if args.model_ema: + utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + lr_scheduler.step(args.start_epoch) + if args.eval: + test_stats = evaluate(data_loader_val, model, device) + print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") + return + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + max_accuracy = 0.0 + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + data_loader_train.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + args.clip_grad, model_ema, mixup_fn, + set_training_mode=args.train_mode, # keep in eval mode for deit finetuning / train mode for training and deit III finetuning + args = args, + ) + + lr_scheduler.step(epoch) + if args.output_dir: + checkpoint_paths = [output_dir / 'checkpoint.pth'] + for checkpoint_path in checkpoint_paths: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'scaler': loss_scaler.state_dict(), + 'args': args, + }, checkpoint_path) + + + test_stats = evaluate(data_loader_val, model, device) + print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") + + if max_accuracy < test_stats["acc1"]: + max_accuracy = test_stats["acc1"] + if args.output_dir: + checkpoint_paths = [output_dir / 'best_checkpoint.pth'] + for checkpoint_path in checkpoint_paths: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'scaler': loss_scaler.state_dict(), + 'args': args, + }, checkpoint_path) + + print(f'Max accuracy: {max_accuracy:.2f}%') + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + + + + + if args.output_dir and utils.is_main_process(): + with (output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) + args = parser.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + main(args) diff --git a/sparsity_factory/__init__.py b/sparsity_factory/__init__.py new file mode 100644 index 0000000..bb01e0f --- /dev/null +++ b/sparsity_factory/__init__.py @@ -0,0 +1,5 @@ +from .utils import get_weights, get_modules, get_copied_modules, get_sparsities, get_nnzs, get_model_sparsity +from .pruners import 
weight_pruner_loader, prune_weights_semistructured, check_valid_pruner +#from .dataloaders import dataset_loader +#from .modelloaders import model_and_opt_loader +#from .train import trainer_loader,initialize_weight \ No newline at end of file diff --git a/sparsity_factory/dataloaders.py b/sparsity_factory/dataloaders.py new file mode 100644 index 0000000..3ccacfb --- /dev/null +++ b/sparsity_factory/dataloaders.py @@ -0,0 +1,13 @@ +import torch +from tools.datasets import * + +data_route = './data/' +cifar10_strings = ['vgg16','resnet18','densenet','effnet'] + +def dataset_loader(model,batch_size=100,num_workers=5): + if model in cifar10_strings: + print("Loading CIFAR-10 with batch size "+str(batch_size)) + train_loader,test_loader = get_cifar10_loaders(data_route,batch_size,num_workers) + else: + raise ValueError('Model not implemented :P') + return train_loader,test_loader diff --git a/sparsity_factory/datasets/__init__.py b/sparsity_factory/datasets/__init__.py new file mode 100644 index 0000000..19a9e4f --- /dev/null +++ b/sparsity_factory/datasets/__init__.py @@ -0,0 +1 @@ +from .cifar10 import get_cifar10_loaders diff --git a/sparsity_factory/datasets/cifar10.py b/sparsity_factory/datasets/cifar10.py new file mode 100644 index 0000000..f28ce05 --- /dev/null +++ b/sparsity_factory/datasets/cifar10.py @@ -0,0 +1,18 @@ +import torch +from torch.utils.data import DataLoader +import torchvision.datasets as dts +import torchvision.transforms as T + +cifar_nm = T.Normalize((0.4914,0.4822,0.4465),(0.247,0.243,0.261)) + +def get_cifar10_loaders(data_route,batch_size,num_workers): + tfm_train = T.Compose([T.RandomCrop(32, padding=4),T.RandomHorizontalFlip(),T.ToTensor(),cifar_nm]) + tfm_test = T.Compose([T.ToTensor(),cifar_nm]) + + train_set = dts.CIFAR10(data_route,train=True,download=True,transform=tfm_train) + test_set = dts.CIFAR10(data_route,train=False,download=False,transform=tfm_test) + + train_loader = DataLoader(train_set,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=num_workers) + test_loader = DataLoader(test_set,batch_size=batch_size,shuffle=False,drop_last=False,num_workers=num_workers) + + return train_loader,test_loader \ No newline at end of file diff --git a/sparsity_factory/modelloaders.py b/sparsity_factory/modelloaders.py new file mode 100644 index 0000000..b74fbda --- /dev/null +++ b/sparsity_factory/modelloaders.py @@ -0,0 +1,70 @@ +import torch.optim as optim +import torch.optim.lr_scheduler as sched +import torchvision.models as tmodels +from functools import partial +from tools.models import * +from tools.pruners import prune_weights_reparam + +def model_and_opt_loader(model_string,DEVICE): + if DEVICE == None: + raise ValueError('No cuda device!') + if model_string == 'vgg16': + model = VGG16().to(DEVICE) + amount = 0.20 + batch_size = 100 + opt_pre = { + "optimizer": partial(optim.AdamW,lr=0.0003), + "steps": 50000, + "scheduler": None + } + opt_post = { + "optimizer": partial(optim.AdamW,lr=0.0003), + "steps": 40000, + "scheduler": None + } + elif model_string == 'resnet18': + model = ResNet18().to(DEVICE) + amount = 0.20 + batch_size = 100 + opt_pre = { + "optimizer": partial(optim.AdamW,lr=0.0003), + "steps": 50000, + "scheduler": None + } + opt_post = { + "optimizer": partial(optim.AdamW,lr=0.0003), + "steps": 40000, + "scheduler": None + } + elif model_string == 'densenet': + model = DenseNet121().to(DEVICE) + amount = 0.20 + batch_size = 100 + opt_pre = { + "optimizer": partial(optim.AdamW,lr=0.0003), + "steps": 80000, + "scheduler": None + } 
+ opt_post = { + "optimizer": partial(optim.AdamW,lr=0.0003), + "steps": 60000, + "scheduler": None + } + elif model_string == 'effnet': + model = EfficientNetB0().to(DEVICE) + amount = 0.20 + batch_size = 100 + opt_pre = { + "optimizer": partial(optim.AdamW,lr=0.0003), + "steps": 50000, + "scheduler": None + } + opt_post = { + "optimizer": partial(optim.AdamW,lr=0.0003), + "steps": 40000, + "scheduler": None + } + else: + raise ValueError('Unknown model') + prune_weights_reparam(model) + return model,amount,batch_size,opt_pre,opt_post \ No newline at end of file diff --git a/sparsity_factory/models/__init__.py b/sparsity_factory/models/__init__.py new file mode 100644 index 0000000..af42b55 --- /dev/null +++ b/sparsity_factory/models/__init__.py @@ -0,0 +1,4 @@ +from .vgg import VGG16 #, VGG11, VGG13, VGG19 +from .resnet import ResNet18 #, ResNet34, ResNet50, ResNet101, ResNet152 +from .densenet import DenseNet121 +from .efficientnet import EfficientNetB0 diff --git a/sparsity_factory/models/densenet.py b/sparsity_factory/models/densenet.py new file mode 100644 index 0000000..4fc1efa --- /dev/null +++ b/sparsity_factory/models/densenet.py @@ -0,0 +1,82 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import BatchNorm2d, Conv2d, Linear + +class Bottleneck(nn.Module): + def __init__(self, in_planes, growth_rate): + super(Bottleneck, self).__init__() + self.bn1 = BatchNorm2d(in_planes) + self.conv1 = Conv2d(in_planes,4*growth_rate, kernel_size=1, bias=False) + self.bn2 = BatchNorm2d(4*growth_rate) + self.conv2 = Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) + def forward(self, x): + out = self.conv1(F.relu(self.bn1(x))) + out = self.conv2(F.relu(self.bn2(out))) + out = torch.cat([out,x], 1) + return out + +class Transition(nn.Module): + def __init__(self, in_planes, out_planes): + super(Transition, self).__init__() + self.bn = BatchNorm2d(in_planes) + self.conv = Conv2d(in_planes, out_planes, kernel_size=1, bias=False) + + def forward(self, x): + out = self.conv(F.relu(self.bn(x))) + out = F.avg_pool2d(out, 2) + return out + +class DenseNet(nn.Module): + def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): + super(DenseNet, self).__init__() + self.growth_rate = growth_rate + + num_planes = 2*growth_rate + self.conv1 = Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) + + self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) + num_planes += nblocks[0]*growth_rate + out_planes = int(math.floor(num_planes*reduction)) + self.trans1 = Transition(num_planes, out_planes) + num_planes = out_planes + + self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) + num_planes += nblocks[1]*growth_rate + out_planes = int(math.floor(num_planes*reduction)) + self.trans2 = Transition(num_planes, out_planes) + num_planes = out_planes + + self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) + num_planes += nblocks[2]*growth_rate + out_planes = int(math.floor(num_planes*reduction)) + self.trans3 = Transition(num_planes, out_planes) + num_planes = out_planes + + self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) + num_planes += nblocks[3]*growth_rate + + self.bn = BatchNorm2d(num_planes) + self.linear = Linear(num_planes, num_classes) + + def _make_dense_layers(self, block, in_planes, nblock): + layers = [] + for i in range(nblock): + layers.append(block(in_planes, self.growth_rate)) + in_planes += self.growth_rate + return 
nn.Sequential(*layers) + + def forward(self, x): + out = self.conv1(x) + out = self.trans1(self.dense1(out)) + out = self.trans2(self.dense2(out)) + out = self.trans3(self.dense3(out)) + out = self.dense4(out) + out = F.avg_pool2d(F.relu(self.bn(out)), 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + +def DenseNet121(): + return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) \ No newline at end of file diff --git a/sparsity_factory/models/efficientnet.py b/sparsity_factory/models/efficientnet.py new file mode 100644 index 0000000..6958502 --- /dev/null +++ b/sparsity_factory/models/efficientnet.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def swish(x): + return x * x.sigmoid() + + +def drop_connect(x, drop_ratio): + keep_ratio = 1.0 - drop_ratio + mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) + mask.bernoulli_(keep_ratio) + x.div_(keep_ratio) + x.mul_(mask) + return x + + +class SE(nn.Module): + '''Squeeze-and-Excitation block with Swish.''' + + def __init__(self, in_channels, se_channels): + super(SE, self).__init__() + self.se1 = nn.Conv2d(in_channels, se_channels, + kernel_size=1, bias=True) + self.se2 = nn.Conv2d(se_channels, in_channels, + kernel_size=1, bias=True) + + def forward(self, x): + out = F.adaptive_avg_pool2d(x, (1, 1)) + out = swish(self.se1(out)) + out = self.se2(out).sigmoid() + out = x * out + return out + + +class Block(nn.Module): + '''expansion + depthwise + pointwise + squeeze-excitation''' + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + expand_ratio=1, + se_ratio=0., + drop_rate=0.): + super(Block, self).__init__() + self.stride = stride + self.drop_rate = drop_rate + self.expand_ratio = expand_ratio + + # Expansion + channels = expand_ratio * in_channels + self.conv1 = nn.Conv2d(in_channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.bn1 = nn.BatchNorm2d(channels) + + # Depthwise conv + self.conv2 = nn.Conv2d(channels, + channels, + kernel_size=kernel_size, + stride=stride, + padding=(1 if kernel_size == 3 else 2), + groups=channels, + bias=False) + self.bn2 = nn.BatchNorm2d(channels) + + # SE layers + se_channels = int(in_channels * se_ratio) + self.se = SE(channels, se_channels) + + # Output + self.conv3 = nn.Conv2d(channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.bn3 = nn.BatchNorm2d(out_channels) + + # Skip connection if in and out shapes are the same (MV-V2 style) + self.has_skip = (stride == 1) and (in_channels == out_channels) + + def forward(self, x): + out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x))) + out = swish(self.bn2(self.conv2(out))) + out = self.se(out) + out = self.bn3(self.conv3(out)) + if self.has_skip: + if self.training and self.drop_rate > 0: + out = drop_connect(out, self.drop_rate) + out = out + x + return out + + +class EfficientNet(nn.Module): + def __init__(self, cfg, num_classes=10): + super(EfficientNet, self).__init__() + self.cfg = cfg + self.conv1 = nn.Conv2d(3, + 32, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.layers = self._make_layers(in_channels=32) + self.linear = nn.Linear(cfg['out_channels'][-1], num_classes) + + def _make_layers(self, in_channels): + layers = [] + cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size', + 'stride']] + b = 0 + blocks = sum(self.cfg['num_blocks']) + for expansion, out_channels, 
num_blocks, kernel_size, stride in zip(*cfg): + strides = [stride] + [1] * (num_blocks - 1) + for stride in strides: + drop_rate = self.cfg['drop_connect_rate'] * b / blocks + layers.append( + Block(in_channels, + out_channels, + kernel_size, + stride, + expansion, + se_ratio=0.25, + drop_rate=drop_rate)) + in_channels = out_channels + return nn.Sequential(*layers) + + def forward(self, x): + out = swish(self.bn1(self.conv1(x))) + out = self.layers(out) + out = F.adaptive_avg_pool2d(out, 1) + out = out.view(out.size(0), -1) + dropout_rate = self.cfg['dropout_rate'] + if self.training and dropout_rate > 0: + out = F.dropout(out, p=dropout_rate) + out = self.linear(out) + return out + + +def EfficientNetB0(): + cfg = { + 'num_blocks': [1, 2, 2, 3, 3, 4, 1], + 'expansion': [1, 6, 6, 6, 6, 6, 6], + 'out_channels': [16, 24, 40, 80, 112, 192, 320], + 'kernel_size': [3, 3, 5, 3, 5, 5, 3], + 'stride': [1, 2, 2, 2, 1, 2, 1], + 'dropout_rate': 0.2, + 'drop_connect_rate': 0.2, + } + return EfficientNet(cfg) + + +def test(): + net = EfficientNetB0() + x = torch.randn(2, 3, 32, 32) + y = net(x) + print(y.shape) + + +if __name__ == '__main__': + test() \ No newline at end of file diff --git a/sparsity_factory/models/resnet.py b/sparsity_factory/models/resnet.py new file mode 100644 index 0000000..96f24b7 --- /dev/null +++ b/sparsity_factory/models/resnet.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv2d, Linear,BatchNorm2d + +class BasicBlock(nn.Module): + expansion = 1 + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + Conv2d(in_planes, self.expansion*planes,kernel_size=1,stride=stride,bias=False), + BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], 
stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = Linear(512*block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1]*(num_blocks-1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+def ResNet18():
+    return ResNet(BasicBlock, [2, 2, 2, 2])
+
+def ResNet34():
+    return ResNet(BasicBlock, [3, 4, 6, 3])
+
+def ResNet50():
+    return ResNet(Bottleneck, [3, 4, 6, 3])
+
+def ResNet101():
+    return ResNet(Bottleneck, [3, 4, 23, 3])
+
+def ResNet152():
+    return ResNet(Bottleneck, [3, 8, 36, 3])
\ No newline at end of file
diff --git a/sparsity_factory/models/vgg.py b/sparsity_factory/models/vgg.py
new file mode 100644
index 0000000..ecbcda4
--- /dev/null
+++ b/sparsity_factory/models/vgg.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv2d, Linear, BatchNorm2d, MaxPool2d, AvgPool2d
+
+cfg = {
+    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+class VGG(nn.Module):
+    def __init__(self, vgg_name, use_bn=True):
+        super(VGG, self).__init__()
+        self.features = self._make_layers(cfg[vgg_name], use_bn)
+        self.classifier = Linear(512, 10)
+
+    def forward(self, x):
+        out = self.features(x)
+        out = out.view(out.size(0), -1)
+        out = self.classifier(out)
+        return out
+
+    def _make_layers(self, cfg, use_bn):
+        layers = []
+        in_channels = 3
+        for x in cfg:
+            if x == 'M':
+                layers += [MaxPool2d(kernel_size=2, stride=2)]
+            else:
+                layers += [Conv2d(in_channels, x, kernel_size=3, padding=1)]
+                if use_bn:
+                    layers += [BatchNorm2d(x)]
+                layers += [nn.ReLU(inplace=True)]
+                in_channels = x
+        layers += [AvgPool2d(kernel_size=1, stride=1)]
+        return nn.Sequential(*layers)
+
+
+def VGG11(use_bn=True):
+    return VGG('VGG11', use_bn)
+def VGG13(use_bn=True):
+    return VGG('VGG13', use_bn)
+def VGG16(use_bn=True):
+    return VGG('VGG16', use_bn)
+def VGG19(use_bn=True):
+    return VGG('VGG19', use_bn)
\ No newline at end of file
diff --git a/sparsity_factory/pruners.py b/sparsity_factory/pruners.py
new file mode 100644
index 0000000..04eff3f
--- /dev/null
+++ b/sparsity_factory/pruners.py
@@ -0,0 +1,279 @@
+import torch
+from torch.nn.utils import prune
+from .utils import get_weights, get_modules, get_modules_with_name
+import numpy as np
+
+ALL_PRUNERS = ['lamp', 'glob', 'unif', 'unifplus', 'erk', 'custom']
+
+def check_valid_pruner(name):
+    return name in ALL_PRUNERS
+
+def weight_pruner_loader(pruner_string):
+    """
+    Gives you the pruning method: LAMP, Glob, Unif, Unif+, ERK, or the
+    config-driven 'custom' mode (layer-wise unstructured or N:M).
+    """
+    if pruner_string == 'lamp':
+        return prune_weights_lamp
+    elif pruner_string == 'glob':
+        return prune_weights_global
+    elif pruner_string == 'unif':
+        return prune_weights_uniform
+    elif pruner_string == 'unifplus':
+        return prune_weights_unifplus
+    elif pruner_string == 'erk':
+        return prune_weights_erk
+    elif pruner_string == 'custom':
+        return prune_weights_from_config
+    else:
+        raise ValueError('Unknown pruner')
+
+
+"""
+prune_weights_reparam: Allocate an identity mask to every weight tensor.
+prune_weights_l1predefined: Perform layer-wise pruning w.r.t. the given amounts.
+"""
+
+def prune_weights_semistructured(module, configs=None):
+    """
+    Prune weights according to the given N:M configs.
+    configs is a list of [N, M] pairs, one per pruning target:
+        configs = [[N1, M1], [N2, M2], ...]
+    In each group of N consecutive input weights, the M smallest-magnitude
+    entries are pruned (N - M survive), so each layer's input dimension must
+    be divisible by its N, and len(configs) must equal the number of pruned layers.
+    """
+    def compute_mask(t, N, M):
+        out_channel, in_channel = t.shape
+        percentile = M / N
+        t_reshaped = t.reshape(out_channel, -1, N)
+        mask = torch.ones_like(t)
+        mask_reshaped = mask.reshape(out_channel, -1, N)
+
+        # Zero out the M smallest-magnitude weights in every group of N.
+        nparams_topprune = int(N * percentile)
+        if nparams_topprune != 0:
+            topk = torch.topk(torch.abs(t_reshaped), k=nparams_topprune, largest=False, dim = -1)
+            mask_reshaped = mask_reshaped.scatter(dim = -1, index = topk.indices, value = 0)
+
+        return mask_reshaped.reshape(out_channel, in_channel)
+
+    if configs is None:
+        raise ValueError("N:M pruning currently requires an explicit config. "
+                         "Please provide an [N, M] pair for each pruning target.")
+
+    mlist = get_modules_with_name(module)
+    for idx, (name, m) in enumerate(mlist):
+        weight_tensor = m.weight
+        config = configs[idx]
+        N, M = config[0], config[1]
+        print(f"module: {name}, N:M = ({N}, {M})")
+        mask = compute_mask(weight_tensor, N, M)
+        prune.custom_from_mask(m, name = 'weight', mask = mask)
+
+def prune_weights_reparam(model):
+    module_list = get_modules(model)
+    for m in module_list:
+        prune.identity(m,name="weight")
+
+def prune_weights_l1predefined(model,amounts):
+    mlist = get_modules_with_name(model)
+    for idx,(name, m) in enumerate(mlist):
+        print(f"module: {name}, amount of removed weight: {float(amounts[idx])}")
+        prune.l1_unstructured(m,name="weight",amount=float(amounts[idx]))
+
+"""
+Methods: All weights
+"""
+
+def prune_weights_global(model,amount):
+    parameters_to_prune = _extract_weight_tuples(model)
+    prune.global_unstructured(parameters_to_prune,pruning_method = prune.L1Unstructured,amount=amount)
+
+def prune_weights_lamp(model,amount):
+    assert amount <= 1
+    amounts = _compute_lamp_amounts(model,amount)
+    prune_weights_l1predefined(model,amounts)
+
+def prune_weights_uniform(model,amount):
+    module_list = get_modules(model)
+    assert amount <= 1 # Can be updated later to handle > 1.
+    for m in module_list:
+        print("module:", m, " remove amount:", amount)
+        prune.l1_unstructured(m,name="weight",amount=amount)
+
+def prune_weights_unifplus(model,amount):
+    assert amount <= 1
+    amounts = _compute_unifplus_amounts(model,amount)
+    prune_weights_l1predefined(model,amounts)
+
+def prune_weights_erk(model,amount):
+    assert amount <= 1
+    amounts = _compute_erk_amounts(model,amount)
+    prune_weights_l1predefined(model,amounts)
+
+
+def prune_weights_from_config(model, mode, configs):
+    if mode == 'nxm':
+        prune_weights_semistructured(model, configs)
+    elif mode == 'unstructured':
+        prune_weights_l1predefined(model, amounts=configs)
+    else:
+        raise ValueError(f"Unknown sparsity mode '{mode}' (expected 'nxm' or 'unstructured')")
+
+"""
+These are not intended to be exported.
+"""
+
+def _extract_weight_tuples(model):
+    """
+    Gives you well-packed weight tensors for global pruning.
+ """ + mlist = get_modules(model) + return tuple([(m,'weight') for m in mlist]) + +def _compute_unifplus_amounts(model,amount): + """ + Compute # of weights to prune in each layer. + """ + amounts = [] + wlist = get_weights(model) + unmaskeds = _count_unmasked_weights(model) + totals = _count_total_weights(model) + + last_layer_minimum = np.round(totals[-1]*0.2) # Minimum number of last-layer weights to keep + total_to_prune = np.round(unmaskeds.sum()*amount) + + if wlist[0].dim() == 4: + amounts.append(0) # Leave the first layer unpruned. + frac_to_prune = (total_to_prune*1.0)/(unmaskeds[1:].sum()) + if frac_to_prune > 1.0: + raise ValueError("Cannot be pruned further by the Unif+ scheme! (first layer exception)") + last_layer_to_surv_planned = np.round((1.0-frac_to_prune)*unmaskeds[-1]) + if last_layer_to_surv_planned < last_layer_minimum: + last_layer_to_prune = unmaskeds[-1] - last_layer_minimum + frac_to_prune_middle = ((total_to_prune-last_layer_to_prune)*1.0)/(unmaskeds[1:-1].sum()) + if frac_to_prune_middle > 1.0: + raise ValueError("Cannot be pruned further by the Unif+ scheme! (first+last layer exception)") + amounts.extend([frac_to_prune_middle]*(unmaskeds.size(0)-2)) + amounts.append((last_layer_to_prune*1.0)/unmaskeds[-1]) + else: + amounts.extend([frac_to_prune]*(unmaskeds.size(0)-1)) + else: + frac_to_prune = (total_to_prune*1.0)/(unmaskeds.sum()) + last_layer_to_surv_planned = np.round((1.0-frac_to_prune)*unmaskeds[-1]) + if last_layer_to_surv_planned < last_layer_minimum: + last_layer_to_prune = unmaskeds[-1] - last_layer_minimum + frac_to_prune_middle = ((total_to_prune-last_layer_to_prune)*1.0)/(unmaskeds[:-1].sum()) + if frac_to_prune_middle > 1.0: + raise ValueError("Cannot be pruned further by the Unif+ scheme! (last layer exception)") + amounts.extend([frac_to_prune_middle]*(unmaskeds.size(0)-1)) + amounts.append((last_layer_to_prune*1.0)/unmaskeds[-1]) + else: + amounts.extend([frac_to_prune]*(unmaskeds.size(0))) + return amounts + +def _compute_erk_amounts(model,amount): + unmaskeds = _count_unmasked_weights(model) + erks = _compute_erks(model) + + return _amounts_from_eps(unmaskeds,erks,amount) + +def _amounts_from_eps(unmaskeds,ers,amount): + num_layers = ers.size(0) + layers_to_keep_dense = torch.zeros(num_layers) + total_to_survive = (1.0-amount)*unmaskeds.sum() # Total to keep. + + # Determine some layers to keep dense. + is_eps_invalid = True + while is_eps_invalid: + unmasked_among_prunables = (unmaskeds*(1-layers_to_keep_dense)).sum() + to_survive_among_prunables = total_to_survive - (layers_to_keep_dense*unmaskeds).sum() + + ers_of_prunables = ers*(1.0-layers_to_keep_dense) + survs_of_prunables = torch.round(to_survive_among_prunables*ers_of_prunables/ers_of_prunables.sum()) + + layer_to_make_dense = -1 + max_ratio = 1.0 + for idx in range(num_layers): + if layers_to_keep_dense[idx] == 0: + if survs_of_prunables[idx]/unmaskeds[idx] > max_ratio: + layer_to_make_dense = idx + max_ratio = survs_of_prunables[idx]/unmaskeds[idx] + + if layer_to_make_dense == -1: + is_eps_invalid = False + else: + layers_to_keep_dense[layer_to_make_dense] = 1 + + amounts = torch.zeros(num_layers) + + for idx in range(num_layers): + if layers_to_keep_dense[idx] == 1: + amounts[idx] = 0.0 + else: + amounts[idx] = 1.0 - (survs_of_prunables[idx]/unmaskeds[idx]) + return amounts + +def _compute_lamp_amounts(model,amount): + """ + Compute normalization schemes. 
+ """ + unmaskeds = _count_unmasked_weights(model) + num_surv = int(np.round(unmaskeds.sum()*(1.0-amount))) + + flattened_scores = [_normalize_scores(w**2).view(-1) for w in get_weights(model)] + concat_scores = torch.cat(flattened_scores,dim=0) + topks,_ = torch.topk(concat_scores,num_surv) + threshold = topks[-1] + + # We don't care much about tiebreakers, for now. + final_survs = [torch.ge(score,threshold*torch.ones(score.size()).to(score.device)).sum() for score in flattened_scores] + amounts = [] + for idx,final_surv in enumerate(final_survs): + amounts.append(1.0 - (final_surv/unmaskeds[idx])) + + return amounts + +def _compute_erks(model): + wlist = get_weights(model) + erks = torch.zeros(len(wlist)) + for idx,w in enumerate(wlist): + if w.dim() == 4: + erks[idx] = w.size(0)+w.size(1)+w.size(2)+w.size(3) + else: + erks[idx] = w.size(0)+w.size(1) + return erks + +def _count_unmasked_weights(model): + """ + Return a 1-dimensional tensor of #unmasked weights. + """ + mlist = get_modules(model) + unmaskeds = [] + for m in mlist: + unmaskeds.append(m.weight_mask.sum()) + return torch.FloatTensor(unmaskeds) + +def _count_total_weights(model): + """ + Return a 1-dimensional tensor of #total weights. + """ + wlist = get_weights(model) + numels = [] + for w in wlist: + numels.append(w.numel()) + return torch.FloatTensor(numels) + +def _normalize_scores(scores): + """ + Normalizing scheme for LAMP. + """ + # sort scores in an ascending order + sorted_scores,sorted_idx = scores.view(-1).sort(descending=False) + # compute cumulative sum + scores_cumsum_temp = sorted_scores.cumsum(dim=0) + scores_cumsum = torch.zeros(scores_cumsum_temp.shape,device=scores.device) + scores_cumsum[1:] = scores_cumsum_temp[:len(scores_cumsum_temp)-1] + # normalize by cumulative sum + sorted_scores /= (scores.sum() - scores_cumsum) + # tidy up and output + new_scores = torch.zeros(scores_cumsum.shape,device=scores.device) + new_scores[sorted_idx] = sorted_scores + + return new_scores.view(scores.shape) diff --git a/sparsity_factory/train.py b/sparsity_factory/train.py new file mode 100644 index 0000000..6e967e9 --- /dev/null +++ b/sparsity_factory/train.py @@ -0,0 +1,102 @@ +import torch +import numpy as np +import torch.nn.functional as F + +def trainer_loader(): + return train + +def initialize_weight(model,loader): + batch = next(iter(loader)) + device = next(model.parameters()).device + with torch.no_grad(): + model(batch[0].to(device)) + +def train(model,optpack,train_loader,test_loader,print_steps=-1,log_results=False,log_path='log.txt'): + model.train() + opt = optpack["optimizer"](model.parameters()) + if optpack["scheduler"] is not None: + sched = optpack["scheduler"](opt) + else: + sched = None + num_steps = optpack["steps"] + device = next(model.parameters()).device + + results_log = [] + training_step = 0 + + if sched is not None: + while True: + for i,(x,y) in enumerate(train_loader): + training_step += 1 + x = x.to(device) + y = y.to(device) + + opt.zero_grad() + yhat = model(x) + loss = F.cross_entropy(yhat,y) + loss.backward() + opt.step() + sched.step() + + if print_steps != -1 and training_step%print_steps == 0: + train_acc,train_loss = test(model,train_loader) + test_acc,test_loss = test(model,test_loader) + print(f'Steps: {training_step}/{num_steps} \t Train acc: {train_acc:.2f} \t Test acc: {test_acc:.2f}', end='\r') + if log_results: + results_log.append([test_acc,test_loss,train_acc,train_loss]) + np.savetxt(log_path,results_log) + if training_step >= num_steps: + break + if training_step >= 
num_steps: + break + else: + while True: + for i,(x,y) in enumerate(train_loader): + training_step += 1 + x = x.to(device) + y = y.to(device) + + opt.zero_grad() + yhat = model(x) + loss = F.cross_entropy(yhat,y) + loss.backward() + opt.step() + + if print_steps != -1 and training_step%print_steps == 0: + train_acc,train_loss = test(model,train_loader) + test_acc,test_loss = test(model,test_loader) + print(f'Steps: {training_step}/{num_steps} \t Train acc: {train_acc:.2f} \t Test acc: {test_acc:.2f}', end='\r') + if log_results: + results_log.append([test_acc,test_loss,train_acc,train_loss]) + np.savetxt(log_path,results_log) + if training_step >= num_steps: + break + if training_step >= num_steps: + break + train_acc,train_loss = test(model,train_loader) + test_acc,test_loss = test(model,test_loader) + print(f'Train acc: {train_acc:.2f}\t Test acc: {test_acc:.2f}') + return [test_acc,test_loss,train_acc,train_loss] + +def test(model,loader): + model.eval() + device = next(model.parameters()).device + + correct = 0 + loss = 0 + total = 0 + for i,(x,y) in enumerate(loader): + x = x.to(device) + y = y.to(device) + with torch.no_grad(): + yhat = model(x) + _,pred = yhat.max(1) + correct += pred.eq(y).sum().item() + loss += F.cross_entropy(yhat,y)*len(x) + total += len(x) + acc = correct/total * 100.0 + loss = loss/total + + model.train() + + return acc,loss \ No newline at end of file diff --git a/sparsity_factory/utils.py b/sparsity_factory/utils.py new file mode 100644 index 0000000..6386069 --- /dev/null +++ b/sparsity_factory/utils.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn +from copy import deepcopy + +# Preliminaries. Not to be exported. + +exclude_module_name = ['head'] + +def _is_prunable_module(m, name = None): + if name is not None and name in exclude_module_name: + return False + return isinstance(m, nn.Linear) + +def _get_sparsity(tsr): + total = tsr.numel() + nnz = tsr.nonzero().size(0) + return nnz/total + +def _get_nnz(tsr): + return tsr.nonzero().size(0) + +# Modules + +def get_weights(model): + weights = [] + for name, m in model.named_modules(): + if _is_prunable_module(m, name): + weights.append(m.weight) + return weights + +def get_convweights(model): + weights = [] + for m in model.modules(): + if isinstance(m,nn.Conv2d): + weights.append(m.weight) + return weights + +def get_modules(model): + modules = [] + for name, m in model.named_modules(): + if _is_prunable_module(m, name): + modules.append(m) + return modules + +def get_modules_with_name(model): + modules = [] + for name, m in model.named_modules(): + if _is_prunable_module(m, name): + modules.append((name, m)) + return modules + + +def get_convmodules(model): + modules = [] + for m in model.modules(): + if isinstance(m,nn.Conv2d): + modules.append(m) + return modules + +def get_copied_modules(model): + modules = [] + for m in model.modules(): + if _is_prunable_module(m): + modules.append(deepcopy(m).cpu()) + return modules + +def get_model_sparsity(model): + prunables = 0 + nnzs = 0 + for m in model.modules(): + if _is_prunable_module(m): + prunables += m.weight.data.numel() + nnzs += m.weight.data.nonzero().size(0) + return nnzs/prunables + +def get_sparsities(model): + return [_get_sparsity(m.weight.data) for m in model.modules() if _is_prunable_module(m)] + +def get_nnzs(model): + return [_get_nnz(m.weight.data) for m in model.modules() if _is_prunable_module(m)]
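Usage sketch (illustrative only, not a file in this patch): the snippet below mirrors the call sequence in sparsify_training.py — custom/N:M pruning driven by a YAML config, or an unstructured pruner applied after prune_weights_reparam. It assumes the repository root is the working directory and on PYTHONPATH; the model name deit_small_patch16_224, pretrained=True, and the 0.5 LAMP ratio are assumptions for the example, not values taken from this diff.

import yaml
from yaml.loader import SafeLoader
from timm.models import create_model

from sparsity_factory.pruners import weight_pruner_loader, prune_weights_reparam
from sparsity_factory.utils import get_model_sparsity

# DeiT-Small has 48 prunable nn.Linear layers (the 'head' classifier is
# excluded in sparsity_factory.utils), matching the per-layer entries in the config.
model = create_model('deit_small_patch16_224', pretrained=True, num_classes=1000)

# Custom mode: each [N, M] pair prunes the M smallest of every N consecutive
# input weights of the corresponding layer, e.g. [4, 2] keeps 2 out of 4.
with open('configs/deit_small_nxm.yaml') as f:
    config = yaml.load(f, Loader=SafeLoader)
pruner = weight_pruner_loader('custom')
pruner(model, config['sparsity']['mode'], config['sparsity']['level'])

# Unstructured alternative (e.g. LAMP removing 50% of the weights); the masks
# must be registered first via prune_weights_reparam, as sparsify_training.py does.
# prune_weights_reparam(model)
# weight_pruner_loader('lamp')(model, 0.5)

print(f"fraction of weights remaining in nn.Linear layers: {get_model_sparsity(model):.3f}")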