From a172d027f3272fcc6105dd453bb525962a3ddc31 Mon Sep 17 00:00:00 2001
From: max410011 <410011max@gmail.com>
Date: Fri, 16 Jun 2023 17:09:58 +0000
Subject: [PATCH] Add KD

Enable knowledge distillation with a deit_small_patch16_224 teacher
(created with pretrained=True, so --teacher-path is now optional)
instead of regnety_160, and default --distillation-type to 'soft'.
Since the student has no separate dist token, DistillationLoss now
distills through the plain ViT logits. Also rename the LR*/Sparse*
model classes to their dense counterparts backed by nn.Linear, and
reset the --nas-weights and --wandb defaults.
---
 losses.py | 18 ++++++++++--------
 main.py   | 63 +++++++++++++++++++++++++++++++++++++++++++--------------------------
 models.py | 24 ++++++++++++------------
 3 files changed, 59 insertions(+), 46 deletions(-)

diff --git a/losses.py b/losses.py
index b386769..6a0393d 100644
--- a/losses.py
+++ b/losses.py
@@ -31,18 +31,20 @@ class DistillationLoss(torch.nn.Module):
             in the first position and the distillation predictions as the second output
             labels: the labels for the base criterion
         """
-        outputs_kd = None
-        if not isinstance(outputs, torch.Tensor):
-            # assume that the model outputs a tuple of [outputs, outputs_kd]
-            outputs, outputs_kd = outputs
         base_loss = self.base_criterion(outputs, labels)
         if self.distillation_type == 'none':
             return base_loss
+
+        outputs_kd = outputs  # use the plain ViT output (no separate dist token)
+        # outputs_kd = None
+        # if not isinstance(outputs, torch.Tensor):
+        #     # assume that the model outputs a tuple of [outputs, outputs_kd]
+        #     outputs, outputs_kd = outputs
 
-        if outputs_kd is None:
-            raise ValueError("When knowledge distillation is enabled, the model is "
-                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
-                             "class_token and the dist_token")
+        # if outputs_kd is None:
+        #     raise ValueError("When knowledge distillation is enabled, the model is "
+        #                      "expected to return a Tuple[Tensor, Tensor] with the output of the "
+        #                      "class_token and the dist_token")
         # don't backprop through the teacher
         with torch.no_grad():
             teacher_outputs = self.teacher_model(inputs)
diff --git a/main.py b/main.py
index 682965c..fe4ec32 100644
--- a/main.py
+++ b/main.py
@@ -142,22 +142,22 @@ def get_args_parser():
                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
Per "batch", "pair", or "elem"') # Distillation parameters - parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', - help='Name of teacher model to train (default: "regnety_160"') - parser.add_argument('--teacher-path', type=str, default='') - parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") - parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") - # parser.add_argument('--distillation-alpha', default=0, type=float, help="") + # parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', + # help='Name of teacher model to train (default: "regnety_160"') + parser.add_argument('--teacher-model', default='deit_small_patch16_224', type=str, metavar='MODEL') + parser.add_argument('--teacher-path', type=str, default=None) + # parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") + parser.add_argument('--distillation-type', default='soft', choices=['none', 'soft', 'hard'], type=str, help="") + # parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") + parser.add_argument('--distillation-alpha', default=0.0, type=float, help="") parser.add_argument('--distillation-tau', default=1.0, type=float, help="") # * Finetuning params - # parser.add_argument('--finetune', default='weights/deit_small_patch16_224-cd65a155.pth', help='finetune from checkpoint') - # parser.add_argument('--finetune', default='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', help='finetune from checkpoint') parser.add_argument('--finetune', default=None, help='finetune from checkpoint') parser.add_argument('--attn-only', action='store_true') # Dataset parameters - parser.add_argument('--data-path', default='/dev/shm/imagenet', type=str, + parser.add_argument('--data-path', default='/dataset/imagenet', type=str, help='dataset path') parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], type=str, help='Image Net dataset path') @@ -193,19 +193,20 @@ def get_args_parser(): # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29500 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --epochs 50 --output_dir result_sub_1:4_50epoch # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform24.yaml --epochs 50 --output_dir result_sub_2:4_50epoch # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --eval + # python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir twined_nas_124+13_150epoch parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL', help='Name of model to train') parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_nas_124.yaml', help='configuration for supernet training') parser.add_argument('--nas-mode', action='store_true', default=True) # parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight') # parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight') + # parser.add_argument('--nas-weights', 
+    # parser.add_argument('--nas-weights', default='result_sub_2:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_nas_124+13_150epoch/checkpoint.pth', help='load pretrained supernet weight')
-    # parser.add_argument('--nas-weights', default='result_nas_124_150epoch/best_checkpoint.pth', help='load pretrained supernet weight')
-    # parser.add_argument('--nas-weights', default='result_sub_14_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
-    parser.add_argument('--nas-weights', default='result_sub_24_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_nas_124+13_150epoch/best_checkpoint.pth', help='load pretrained supernet weight')
-    # parser.add_argument('--nas-weights', default=None, help='load pretrained supernet weight')
-    parser.add_argument('--wandb', action='store_true', default=True)
+    # parser.add_argument('--nas-weights', default='result_1:8_100epoch/best_checkpoint.pth', help='load pretrained supernet weight')
+    parser.add_argument('--nas-weights', default=None, help='load pretrained supernet weight')
+    parser.add_argument('--wandb', action='store_true')
     parser.add_argument('--output_dir', default='result', help='path where to save, empty for no saving')
 
     return parser
@@ -427,20 +428,30 @@ def main(args):
 
     teacher_model = None
     if args.distillation_type != 'none':
-        assert args.teacher_path, 'need to specify teacher-path when using distillation'
+        # assert args.teacher_path, 'need to specify teacher-path when using distillation'
         print(f"Creating teacher model: {args.teacher_model}")
-        teacher_model = create_model(
-            args.teacher_model,
-            pretrained=False,
-            num_classes=args.nb_classes,
-            global_pool='avg',
-        )
-        if args.teacher_path.startswith('https'):
-            checkpoint = torch.hub.load_state_dict_from_url(
+        # teacher_model = create_model(  # regnety_160
+        #     args.teacher_model,
+        #     pretrained=False,
+        #     num_classes=args.nb_classes,
+        #     global_pool='avg',
+        # )
+        teacher_model = create_model(  # deit-small
+            args.teacher_model,
+            pretrained=True,
+            num_classes=args.nb_classes,
+            drop_rate=args.drop,
+            drop_path_rate=args.drop_path,
+            drop_block_rate=None,
+            img_size=args.input_size
+        )
+        if args.teacher_path:
+            if args.teacher_path.startswith('https'):
+                checkpoint = torch.hub.load_state_dict_from_url(
                 args.teacher_path, map_location='cpu', check_hash=True)
-        else:
-            checkpoint = torch.load(args.teacher_path, map_location='cpu')
-        teacher_model.load_state_dict(checkpoint['model'])
+            else:
+                checkpoint = torch.load(args.teacher_path, map_location='cpu')
+            teacher_model.load_state_dict(checkpoint['model'])
         teacher_model.to(device)
         teacher_model.eval()
 
diff --git a/models.py b/models.py
index 3ecd17e..e40fbe9 100644
--- a/models.py
+++ b/models.py
@@ -87,15 +87,15 @@ default_cfgs = {
     'vit_base_resnet50d_224': _cfg(),
 }
 
-class LRMlpSuper(nn.Module):
+class MlpSuper(nn.Module):
     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         self.act = act_layer()
         self.drop = nn.Dropout(drop)
-        self.fc1 = SparseLinearSuper(in_features, hidden_features)
-        self.fc2 = SparseLinearSuper(hidden_features, out_features)
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.fc2 = nn.Linear(hidden_features, out_features)
 
     def forward(self, x):
         x = self.fc1(x)
@@ -105,15 +105,15 @@ class LRMlpSuper(nn.Module):
         x = self.drop(x)
         return x
 
-class LRAttentionSuper(nn.Module):
+class AttentionSuper(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
 
-        self.proj = SparseLinearSuper(dim, dim)
-        self.qkv = SparseLinearSuper(dim, dim * 3, bias = qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj_drop = nn.Dropout(proj_drop)
 
@@ -137,12 +137,12 @@ class Block(nn.Module):
                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.attn = LRAttentionSuper(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, )
+        self.attn = AttentionSuper(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = LRMlpSuper(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.mlp = MlpSuper(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
 
     def forward(self, x):
         x = x + self.drop_path(self.attn(self.norm1(x)))
@@ -178,7 +178,7 @@ class PatchEmbed(nn.Module):
 
 
 
-class SparseVisionTransformer(nn.Module):
+class VisionTransformer(nn.Module):
     """ Vision Transformer
 
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
@@ -323,7 +323,7 @@ def deit_base_patch16_224(pretrained=False, **kwargs):
     """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model = SparseVisionTransformer(
+    model = VisionTransformer(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = _cfg()
@@ -342,7 +342,7 @@ def deit_small_patch16_224(pretrained=False, **kwargs):
     """ DeiT small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model = SparseVisionTransformer(
+    model = VisionTransformer(
         patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = _cfg()
@@ -360,7 +360,7 @@ def deit_tiny_patch16_224(pretrained=False, **kwargs):
     """ DeiT tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model = SparseVisionTransformer(
+    model = VisionTransformer(
         patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = _cfg()
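
Note: for reference, a minimal sketch of what the patched DistillationLoss
computes (a simplified stand-in following the original DeiT soft/hard
formulation, not the exact repo code; class and argument names mirror
losses.py):

    import torch
    import torch.nn.functional as F

    class DistillationLoss(torch.nn.Module):
        """Base criterion plus soft (KL) or hard (CE) distillation from a frozen teacher."""
        def __init__(self, base_criterion, teacher_model, distillation_type, alpha, tau):
            super().__init__()
            assert distillation_type in ('none', 'soft', 'hard')
            self.base_criterion = base_criterion
            self.teacher_model = teacher_model
            self.distillation_type = distillation_type
            self.alpha = alpha
            self.tau = tau

        def forward(self, inputs, outputs, labels):
            base_loss = self.base_criterion(outputs, labels)
            if self.distillation_type == 'none':
                return base_loss
            # After this patch the student has no dist token, so the same
            # logits feed both the base loss and the distillation term.
            outputs_kd = outputs
            with torch.no_grad():  # don't backprop through the teacher
                teacher_outputs = self.teacher_model(inputs)
            if self.distillation_type == 'soft':
                T = self.tau
                # KL divergence between temperature-scaled student and teacher
                # distributions, scaled by T^2 as in Hinton et al.; log_target
                # lets kl_div take log-probabilities on both sides.
                distillation_loss = F.kl_div(
                    F.log_softmax(outputs_kd / T, dim=1),
                    F.log_softmax(teacher_outputs / T, dim=1),
                    reduction='sum',
                    log_target=True,
                ) * (T * T) / outputs_kd.numel()
            else:  # 'hard'
                distillation_loss = F.cross_entropy(
                    outputs_kd, teacher_outputs.argmax(dim=1))
            return base_loss * (1 - self.alpha) + distillation_loss * self.alpha

One caveat worth noting: with the new defaults (--distillation-type soft but
--distillation-alpha 0.0) the distillation term is computed and then weighted
to zero, so the final loss equals the base loss; pass a nonzero
--distillation-alpha for the teacher to actually influence training.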