mirror of https://github.com/facebookresearch/deit

Add KD

parent 2a19e7a999
commit a172d027f3
losses.py (18 changed lines)
@@ -31,18 +31,20 @@ class DistillationLoss(torch.nn.Module):
                 in the first position and the distillation predictions as the second output
             labels: the labels for the base criterion
         """
-        outputs_kd = None
-        if not isinstance(outputs, torch.Tensor):
-            # assume that the model outputs a tuple of [outputs, outputs_kd]
-            outputs, outputs_kd = outputs
         base_loss = self.base_criterion(outputs, labels)
         if self.distillation_type == 'none':
             return base_loss
 
+        outputs_kd = outputs # Use normal ViT version
+        # outputs_kd = None
+        # if not isinstance(outputs, torch.Tensor):
+        #     # assume that the model outputs a tuple of [outputs, outputs_kd]
+        #     outputs, outputs_kd = outputs
+
-        if outputs_kd is None:
-            raise ValueError("When knowledge distillation is enabled, the model is "
-                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
-                             "class_token and the dist_token")
+        # if outputs_kd is None:
+        #     raise ValueError("When knowledge distillation is enabled, the model is "
+        #                      "expected to return a Tuple[Tensor, Tensor] with the output of the "
+        #                      "class_token and the dist_token")
         # don't backprop throught the teacher
         with torch.no_grad():
             teacher_outputs = self.teacher_model(inputs)
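This hunk drops the two-head (class_token / dist_token) contract: outputs_kd now simply aliases the single ViT classification output, so the distillation term is computed on the same logits as the base criterion. For context, a sketch of the branch that consumes outputs_kd and teacher_outputs just below this hunk, as it appears in upstream DeiT's losses.py (F is torch.nn.functional); this commit leaves it untouched:

        if self.distillation_type == 'soft':
            T = self.tau
            # KL between temperature-softened student and teacher distributions,
            # scaled by T^2 as in Hinton et al., "Distilling the Knowledge in a
            # Neural Network" (2015)
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
        elif self.distillation_type == 'hard':
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss

Note that with the --distillation-alpha 0.0 default introduced in main.py below, loss reduces to base_loss; the KD term only contributes once alpha is raised.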
main.py (63 changed lines)
@@ -142,22 +142,22 @@ def get_args_parser():
                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
 
     # Distillation parameters
-    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
-                        help='Name of teacher model to train (default: "regnety_160"')
-    parser.add_argument('--teacher-path', type=str, default='')
-    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
-    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
+    # parser.add_argument('--distillation-alpha', default=0, type=float, help="")
+    # parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
+    #                     help='Name of teacher model to train (default: "regnety_160"')
+    parser.add_argument('--teacher-model', default='deit_small_patch16_224', type=str, metavar='MODEL')
+    parser.add_argument('--teacher-path', type=str, default=None)
+    # parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
+    parser.add_argument('--distillation-type', default='soft', choices=['none', 'soft', 'hard'], type=str, help="")
+    # parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
+    parser.add_argument('--distillation-alpha', default=0.0, type=float, help="")
     parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
 
     # * Finetuning params
     # parser.add_argument('--finetune', default='weights/deit_small_patch16_224-cd65a155.pth', help='finetune from checkpoint')
     # parser.add_argument('--finetune', default='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', help='finetune from checkpoint')
     parser.add_argument('--finetune', default=None, help='finetune from checkpoint')
     parser.add_argument('--attn-only', action='store_true')
 
     # Dataset parameters
-    parser.add_argument('--data-path', default='/dev/shm/imagenet', type=str,
+    parser.add_argument('--data-path', default='/dataset/imagenet', type=str,
                         help='dataset path')
     parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                         type=str, help='Image Net dataset path')
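These flags feed the DistillationLoss wrapper from losses.py above. A sketch of the wiring as done in upstream DeiT's main.py (not changed by this commit):

    criterion = DistillationLoss(
        criterion, teacher_model,
        args.distillation_type,   # now 'soft' by default
        args.distillation_alpha,  # now 0.0: pure base loss until raised
        args.distillation_tau     # softmax temperature, still 1.0
    )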
|
@@ -193,19 +193,20 @@ def get_args_parser():
     # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29500 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --epochs 50 --output_dir result_sub_1:4_50epoch
     # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform24.yaml --epochs 50 --output_dir result_sub_2:4_50epoch
     # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --eval
     # python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir twined_nas_124+13_150epoch
     parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
                         help='Name of model to train')
     parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_nas_124.yaml', help='configuration for supernet training')
     parser.add_argument('--nas-mode', action='store_true', default=True)
     # parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_sub_2:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_nas_124+13_150epoch/checkpoint.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_nas_124_150epoch/best_checkpoint.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_sub_14_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
-    parser.add_argument('--nas-weights', default='result_sub_24_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
-    # parser.add_argument('--nas-weights', default='result_nas_124+13_150epoch/best_checkpoint.pth', help='load pretrained supernet weight')
-    # parser.add_argument('--nas-weights', default=None, help='load pretrained supernet weight')
-    parser.add_argument('--wandb', action='store_true', default=True)
+    # parser.add_argument('--nas-weights', default='result_1:8_100epoch/best_checkpoint.pth', help='load pretrained supernet weight')
+    parser.add_argument('--nas-weights', default=None, help='load pretrained supernet weight')
+    parser.add_argument('--wandb', action='store_true')
     parser.add_argument('--output_dir', default='result',
                         help='path where to save, empty for no saving')
     return parser
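Combined with the distillation defaults above, a KD training launch in the style of the commented commands earlier in this file could look like the following (the alpha value and output path are illustrative, not taken from the commit):

# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --nas-config configs/deit_small_nxm_nas_124.yaml --epochs 150 --distillation-alpha 0.5 --output_dir result_kd_150epoch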
|
@@ -427,20 +428,30 @@ def main(args):
 
     teacher_model = None
     if args.distillation_type != 'none':
-        assert args.teacher_path, 'need to specify teacher-path when using distillation'
+        # assert args.teacher_path, 'need to specify teacher-path when using distillation'
         print(f"Creating teacher model: {args.teacher_model}")
-        teacher_model = create_model(
-            args.teacher_model,
-            pretrained=False,
-            num_classes=args.nb_classes,
-            global_pool='avg',
-        )
-        if args.teacher_path.startswith('https'):
-            checkpoint = torch.hub.load_state_dict_from_url(
-                args.teacher_path, map_location='cpu', check_hash=True)
-        else:
-            checkpoint = torch.load(args.teacher_path, map_location='cpu')
-        teacher_model.load_state_dict(checkpoint['model'])
+        # teacher_model = create_model( // regnety160
+        #     args.teacher_model,
+        #     pretrained=False,
+        #     num_classes=args.nb_classes,
+        #     global_pool='avg',
+        # )
+        teacher_model = create_model( # deit-small
+            args.teacher_model,
+            pretrained=True,
+            num_classes=args.nb_classes,
+            drop_rate=args.drop,
+            drop_path_rate=args.drop_path,
+            drop_block_rate=None,
+            img_size=args.input_size
+        )
+        if args.teacher_path:
+            if args.teacher_path.startswith('https'):
+                checkpoint = torch.hub.load_state_dict_from_url(
+                    args.teacher_path, map_location='cpu', check_hash=True)
+            else:
+                checkpoint = torch.load(args.teacher_path, map_location='cpu')
+            teacher_model.load_state_dict(checkpoint['model'])
         teacher_model.to(device)
         teacher_model.eval()
 
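Since --teacher-path now defaults to None, the teacher's weights come entirely from pretrained=True, i.e. the timm-hosted deit_small checkpoint, unless a path is supplied. A minimal standalone sketch of the resulting default behaviour (assumes timm is installed; the 1000-class setting matches ImageNet-1k):

import torch
from timm import create_model

# Build and freeze the default teacher the way the new code path does.
teacher = create_model('deit_small_patch16_224', pretrained=True, num_classes=1000)
teacher.eval()  # disables dropout / drop-path for stable targets
for p in teacher.parameters():
    p.requires_grad_(False)  # redundant with torch.no_grad() in losses.py, but explicit

with torch.no_grad():
    logits = teacher(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 1000])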
models.py (24 changed lines)
@@ -87,15 +87,15 @@ default_cfgs = {
     'vit_base_resnet50d_224': _cfg(),
 }
 
-class LRMlpSuper(nn.Module):
+class MlpSuper(nn.Module):
     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         self.act = act_layer()
         self.drop = nn.Dropout(drop)
-        self.fc1 = SparseLinearSuper(in_features, hidden_features)
-        self.fc2 = SparseLinearSuper(hidden_features, out_features)
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.fc2 = nn.Linear(hidden_features, out_features)
 
     def forward(self, x):
         x = self.fc1(x)
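SparseLinearSuper is defined elsewhere in this fork and does not appear in the diff; judging by the N:M configs referenced in main.py (1:4, 2:4, 1:8, ...), it applies N:M structured sparsity to linear weights. A purely illustrative sketch of that pruning pattern, which keeps the N largest-magnitude weights in every group of M along the input dimension (the helper name and grouping are assumptions, not this repo's implementation):

import torch

def nm_mask(weight: torch.Tensor, n: int = 2, m: int = 4) -> torch.Tensor:
    out_f, in_f = weight.shape  # in_f must be divisible by m
    groups = weight.abs().reshape(out_f, in_f // m, m)
    # zero out the (m - n) smallest-magnitude entries in each group of m
    drop_idx = groups.topk(m - n, dim=-1, largest=False).indices
    mask = torch.ones_like(groups).scatter_(-1, drop_idx, 0.0)
    return mask.reshape(out_f, in_f)

w = torch.randn(8, 16)
print((nm_mask(w) * w).count_nonzero().item())  # 64 of 128 weights survive (2:4)

Swapping these layers for plain nn.Linear removes that masking, presumably so the pretrained DeiT teacher checkpoint loads into an ordinary dense model.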
|
@@ -105,15 +105,15 @@ class LRMlpSuper(nn.Module):
         x = self.drop(x)
         return x
 
-class LRAttentionSuper(nn.Module):
+class AttentionSuper(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
 
-        self.proj = SparseLinearSuper(dim, dim)
-        self.qkv = SparseLinearSuper(dim, dim * 3, bias = qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+        self.qkv = nn.Linear(dim, dim * 3, bias = qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj_drop = nn.Dropout(proj_drop)
 
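Only __init__ appears in this hunk; for reference, the standard timm-style forward that consumes self.qkv and self.proj looks like the sketch below (the fork's actual forward is not shown in the diff and may differ):

    def forward(self, x):
        B, N, C = x.shape
        # one fused linear produces q, k, v; then split into heads
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # scaled dot-product scores
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)  # output projection (now a dense nn.Linear)
        x = self.proj_drop(x)
        return x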
|
@@ -137,12 +137,12 @@ class Block(nn.Module):
                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.attn = LRAttentionSuper(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, )
+        self.attn = AttentionSuper(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, )
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = LRMlpSuper(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.mlp = MlpSuper(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
 
     def forward(self, x):
         x = x + self.drop_path(self.attn(self.norm1(x)))
|
@@ -178,7 +178,7 @@ class PatchEmbed(nn.Module):
 
 
 
-class SparseVisionTransformer(nn.Module):
+class VisionTransformer(nn.Module):
     """ Vision Transformer
 
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
@@ -323,7 +323,7 @@ def deit_base_patch16_224(pretrained=False, **kwargs):
     """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model = SparseVisionTransformer(
+    model = VisionTransformer(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = _cfg()

@@ -342,7 +342,7 @@ def deit_small_patch16_224(pretrained=False, **kwargs):
     """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model = SparseVisionTransformer(
+    model = VisionTransformer(
         patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = _cfg()

@@ -360,7 +360,7 @@ def deit_tiny_patch16_224(pretrained=False, **kwargs):
     """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model = SparseVisionTransformer(
+    model = VisionTransformer(
         patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = _cfg()
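Because the renamed VisionTransformer has a single classification head, each factory above returns one tensor per forward pass, which is exactly why losses.py now sets outputs_kd = outputs. A quick smoke test (a sketch; it assumes this repo's models.py registers the names above with timm, as upstream DeiT does):

import torch
from timm.models import create_model

import models  # noqa: F401 -- registers the deit_*_patch16_224 factories above

model = create_model('deit_tiny_patch16_224', pretrained=False, num_classes=1000)
out = model(torch.randn(1, 3, 224, 224))
print(type(out), out.shape)  # a single Tensor of shape [1, 1000]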