pull/225/head
max410011 2023-06-16 17:09:58 +00:00
parent 2a19e7a999
commit a172d027f3
3 changed files with 59 additions and 46 deletions

@@ -31,18 +31,20 @@ class DistillationLoss(torch.nn.Module):
in the first position and the distillation predictions as the second output
labels: the labels for the base criterion
"""
outputs_kd = None
if not isinstance(outputs, torch.Tensor):
# assume that the model outputs a tuple of [outputs, outputs_kd]
outputs, outputs_kd = outputs
base_loss = self.base_criterion(outputs, labels)
if self.distillation_type == 'none':
return base_loss
outputs_kd = outputs # Use normal ViT version
# outputs_kd = None
# if not isinstance(outputs, torch.Tensor):
# # assume that the model outputs a tuple of [outputs, outputs_kd]
# outputs, outputs_kd = outputs
if outputs_kd is None:
raise ValueError("When knowledge distillation is enabled, the model is "
"expected to return a Tuple[Tensor, Tensor] with the output of the "
"class_token and the dist_token")
# if outputs_kd is None:
# raise ValueError("When knowledge distillation is enabled, the model is "
# "expected to return a Tuple[Tensor, Tensor] with the output of the "
# "class_token and the dist_token")
# don't backprop through the teacher
with torch.no_grad():
teacher_outputs = self.teacher_model(inputs)

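The hunk ends right after the teacher forward pass. For context, the upstream DeiT DistillationLoss.forward typically continues by computing either a temperature-scaled KL term ('soft') or a cross-entropy term against the teacher's argmax ('hard') and blending it with the base loss via alpha; the sketch below follows that reference implementation (assuming torch.nn.functional is imported as F) and may differ in detail from this repository's version. Note that with the new --distillation-alpha default of 0.0 (see main.py below), the blended loss reduces to base_loss and the teacher only adds forward-pass cost.

# Sketch of the usual continuation (based on the reference DeiT implementation,
# not shown in this hunk):
if self.distillation_type == 'soft':
    T = self.tau
    # KL divergence between temperature-scaled student (outputs_kd) and teacher logits
    distillation_loss = F.kl_div(
        F.log_softmax(outputs_kd / T, dim=1),
        F.log_softmax(teacher_outputs / T, dim=1),
        reduction='sum',
        log_target=True,
    ) * (T * T) / outputs_kd.numel()
elif self.distillation_type == 'hard':
    # Hard distillation: cross-entropy against the teacher's predicted class
    distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
return loss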
main.py

@@ -142,22 +142,22 @@ def get_args_parser():
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# Distillation parameters
parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
help='Name of teacher model to train (default: "regnety_160")')
parser.add_argument('--teacher-path', type=str, default='')
parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
# parser.add_argument('--distillation-alpha', default=0, type=float, help="")
# parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
# help='Name of teacher model to train (default: "regnety_160")')
parser.add_argument('--teacher-model', default='deit_small_patch16_224', type=str, metavar='MODEL')
parser.add_argument('--teacher-path', type=str, default=None)
# parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation-type', default='soft', choices=['none', 'soft', 'hard'], type=str, help="")
# parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
parser.add_argument('--distillation-alpha', default=0.0, type=float, help="")
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
# * Finetuning params
# parser.add_argument('--finetune', default='weights/deit_small_patch16_224-cd65a155.pth', help='finetune from checkpoint')
# parser.add_argument('--finetune', default='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', help='finetune from checkpoint')
parser.add_argument('--finetune', default=None, help='finetune from checkpoint')
parser.add_argument('--attn-only', action='store_true')
# Dataset parameters
parser.add_argument('--data-path', default='/dev/shm/imagenet', type=str,
parser.add_argument('--data-path', default='/dataset/imagenet', type=str,
help='dataset path')
parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
type=str, help='ImageNet dataset path')
@@ -193,19 +193,20 @@ def get_args_parser():
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29500 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --epochs 50 --output_dir result_sub_1:4_50epoch
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform24.yaml --epochs 50 --output_dir result_sub_2:4_50epoch
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --eval
# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir twined_nas_124+13_150epoch
parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_nas_124.yaml', help='configuration for supernet training')
parser.add_argument('--nas-mode', action='store_true', default=True)
# parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_sub_2:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_nas_124+13_150epoch/checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_nas_124_150epoch/best_checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_sub_14_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
parser.add_argument('--nas-weights', default='result_sub_24_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_nas_124+13_150epoch/best_checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default=None, help='load pretrained supernet weight')
parser.add_argument('--wandb', action='store_true', default=True)
# parser.add_argument('--nas-weights', default='result_1:8_100epoch/best_checkpoint.pth', help='load pretrained supernet weight')
parser.add_argument('--nas-weights', default=None, help='load pretrained supernet weight')
parser.add_argument('--wandb', action='store_true')
parser.add_argument('--output_dir', default='result',
help='path where to save, empty for no saving')
return parser
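Taken together, the new defaults above enable distillation with a pretrained deit_small_patch16_224 teacher (soft distillation, tau 1.0, alpha 0.0). Following the style of the launch commands commented earlier, a hypothetical invocation that makes these knobs explicit could look like the following (the output directory is a placeholder, and --distillation-alpha 0.5 overrides the new 0.0 default so the teacher term actually contributes):

# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
#     --nas-config configs/deit_small_nxm_nas_124.yaml \
#     --teacher-model deit_small_patch16_224 --distillation-type soft \
#     --distillation-alpha 0.5 --distillation-tau 1.0 \
#     --output_dir result_distill_soft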
@@ -427,20 +428,30 @@ def main(args):
teacher_model = None
if args.distillation_type != 'none':
assert args.teacher_path, 'need to specify teacher-path when using distillation'
# assert args.teacher_path, 'need to specify teacher-path when using distillation'
print(f"Creating teacher model: {args.teacher_model}")
teacher_model = create_model(
args.teacher_model,
pretrained=False,
num_classes=args.nb_classes,
global_pool='avg',
)
if args.teacher_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
# teacher_model = create_model( // regnety160
# args.teacher_model,
# pretrained=False,
# num_classes=args.nb_classes,
# global_pool='avg',
# )
teacher_model = create_model( # deit-small
args.teacher_model,
pretrained=True,
num_classes=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
drop_block_rate=None,
img_size=args.input_size
)
if args.teacher_path:
if args.teacher_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.teacher_path, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.teacher_path, map_location='cpu')
teacher_model.load_state_dict(checkpoint['model'])
else:
checkpoint = torch.load(args.teacher_path, map_location='cpu')
teacher_model.load_state_dict(checkpoint['model'])
teacher_model.to(device)
teacher_model.eval()
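Not shown in this diff: in the upstream DeiT main.py the teacher built here is then wrapped into the distillation criterion, and this repository presumably does the same. A sketch of that wiring, using the CLI arguments defined above:

# Sketch (following the reference DeiT main.py; the exact call is outside this diff):
criterion = DistillationLoss(
    criterion, teacher_model, args.distillation_type,
    args.distillation_alpha, args.distillation_tau,
)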

@@ -87,15 +87,15 @@ default_cfgs = {
'vit_base_resnet50d_224': _cfg(),
}
class LRMlpSuper(nn.Module):
class MlpSuper(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.act = act_layer()
self.drop = nn.Dropout(drop)
self.fc1 = SparseLinearSuper(in_features, hidden_features)
self.fc2 = SparseLinearSuper(hidden_features, out_features)
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
def forward(self, x):
x = self.fc1(x)
@@ -105,15 +105,15 @@ class LRMlpSuper(nn.Module):
x = self.drop(x)
return x
class LRAttentionSuper(nn.Module):
class AttentionSuper(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.proj = SparseLinearSuper(dim, dim)
self.qkv = SparseLinearSuper(dim, dim * 3, bias = qkv_bias)
self.proj = nn.Linear(dim, dim)
self.qkv = nn.Linear(dim, dim * 3, bias = qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
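Only the constructor of AttentionSuper is touched by this hunk; its forward pass lies outside the diff. For reference, a timm/DeiT-style multi-head self-attention forward consistent with the fields defined above would look roughly as follows (illustrative sketch, not this repository's exact code):

# Illustrative timm/DeiT-style forward for the module above (sketch only):
def forward(self, x):
    B, N, C = x.shape
    # (B, N, 3*C) -> (3, B, heads, N, head_dim)
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    attn = (q @ k.transpose(-2, -1)) * self.scale    # scaled dot-product scores
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads
    x = self.proj(x)
    x = self.proj_drop(x)
    return x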
@@ -137,12 +137,12 @@ class Block(nn.Module):
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = LRAttentionSuper(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, )
self.attn = AttentionSuper(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, )
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = LRMlpSuper(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = MlpSuper(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
@@ -178,7 +178,7 @@ class PatchEmbed(nn.Module):
class SparseVisionTransformer(nn.Module):
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
@@ -323,7 +323,7 @@ def deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model = SparseVisionTransformer(
model = VisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
@@ -342,7 +342,7 @@ def deit_small_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model = SparseVisionTransformer(
model = VisionTransformer(
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
@@ -360,7 +360,7 @@ def deit_tiny_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model = SparseVisionTransformer(
model = VisionTransformer(
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
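With the Sparse* wrappers replaced by plain nn.Linear layers and the class renamed to VisionTransformer, the registered factories build ordinary dense DeiT models. A minimal smoke test, assuming this file is importable as a module named models (the actual module name is not shown in the diff):

# Minimal smoke test (module name `models` is assumed for illustration):
import torch
from models import deit_small_patch16_224

model = deit_small_patch16_224(pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000]) with the default 1000-class head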