Mirror of https://github.com/facebookresearch/deit.git (synced 2025-06-03 14:52:20 +08:00)
Add Knowledge-Distillation (#42)
* Add knowledge distillation
* Bugfix
* Bugfix
* Make names more readable and use a single torch.cat call
* Remove criterion.train() in engine: the teacher should stay in eval mode
* Change default argument for teacher-model
* Return the average of the two classifiers during inference
* Clean up unused code
* Add docstring for DistillationLoss
* Remove warnings from newer PyTorch; also use the more numerically stable variant: call log_softmax directly instead of softmax followed by log
This commit is contained in:
parent 30eb3186da · commit 8eae3269da
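A note on the last bullet of the commit message: composing F.softmax with torch.log can underflow to -inf for strongly negative logits, whereas F.log_softmax evaluates the same quantity in one stable step (and F.kl_div with log_target=True accepts log-probabilities directly, which is what the new losses.py relies on). A minimal illustration, not part of the commit, with made-up values:

import torch
import torch.nn.functional as F

logits = torch.tensor([[100.0, -100.0, 0.0]])

naive = torch.log(F.softmax(logits, dim=1))   # exp(-200) underflows to 0, so log returns -inf
stable = F.log_softmax(logits, dim=1)         # evaluated as x - logsumexp(x), stays finite

print(naive)   # approx. tensor([[0., -inf, -100.]])
print(stable)  # approx. tensor([[0., -200., -100.]])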
engine.py

@@ -12,16 +12,16 @@ import torch
 from timm.data import Mixup
 from timm.utils import accuracy, ModelEma
 
+from losses import DistillationLoss
 import utils
 
 
-def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                     model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None):
     # TODO fix this for finetuning
     model.train()
-    criterion.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     header = 'Epoch: [{}]'.format(epoch)

@@ -36,7 +36,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
 
         with torch.cuda.amp.autocast():
             outputs = model(samples)
-            loss = criterion(outputs, targets)
+            loss = criterion(samples, outputs, targets)
 
         loss_value = loss.item()
 
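A side note on why criterion.train() was removed: DistillationLoss (new file below) stores the teacher as a submodule, so calling .train() on the criterion would recursively flip the frozen teacher back into training mode. A small illustration, not part of the commit, using a made-up toy teacher:

import torch
from losses import DistillationLoss

teacher = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8)).eval()
criterion = DistillationLoss(torch.nn.CrossEntropyLoss(), teacher, 'soft', 0.5, 1.0)

print(criterion.teacher_model.training)  # False: teacher is in eval mode
criterion.train()                        # nn.Module.train() recurses into all submodules
print(criterion.teacher_model.training)  # True: the teacher was silently switched back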
losses.py (new file, +59 lines)
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.
    """
    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are fed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
        elif self.distillation_type == 'hard':
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss
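Taken together, the loss is base_loss * (1 - alpha) + distillation_loss * alpha; for the 'soft' type the distillation term is the KL divergence between temperature-scaled log-probabilities, summed, multiplied by tau^2 and divided by outputs_kd.numel() (batch size times number of classes). A hedged smoke-test sketch, not part of the commit, with made-up shapes and a dummy stand-in teacher:

import torch
from losses import DistillationLoss

batch, num_classes = 4, 10
images = torch.randn(batch, 3, 32, 32)
labels = torch.randint(0, num_classes, (batch,))
# During training a distilled model returns (class-token logits, dist-token logits)
student_outputs = (torch.randn(batch, num_classes), torch.randn(batch, num_classes))

# Stand-in teacher: any module mapping images to logits is enough for the sketch
teacher = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, num_classes)).eval()

criterion = DistillationLoss(torch.nn.CrossEntropyLoss(), teacher,
                             distillation_type='soft', alpha=0.5, tau=3.0)
loss = criterion(images, student_outputs, labels)  # note the (inputs, outputs, labels) order
print(loss.item())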
main.py (34 changed lines)
@@ -19,6 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
 
 from datasets import build_dataset
 from engine import train_one_epoch, evaluate
+from losses import DistillationLoss
 from samplers import RASampler
 import models
 import utils

@@ -123,6 +124,14 @@ def get_args_parser():
     parser.add_argument('--mixup-mode', type=str, default='batch',
                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
 
+    # Distillation parameters
+    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
+                        help='Name of teacher model to train (default: "regnety_160")')
+    parser.add_argument('--teacher-path', type=str, default='')
+    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
+    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
+    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
+
     # Dataset parameters
     parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                         help='dataset path')

@@ -269,6 +278,31 @@ def main(args):
     else:
         criterion = torch.nn.CrossEntropyLoss()
 
+    teacher_model = None
+    if args.distillation_type != 'none':
+        assert args.teacher_path, 'need to specify teacher-path when using distillation'
+        print(f"Creating teacher model: {args.teacher_model}")
+        teacher_model = create_model(
+            args.teacher_model,
+            pretrained=False,
+            num_classes=args.nb_classes,
+            global_pool='avg',
+        )
+        if args.teacher_path.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.teacher_path, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.teacher_path, map_location='cpu')
+        teacher_model.load_state_dict(checkpoint['model'])
+        teacher_model.to(device)
+        teacher_model.eval()
+
+    # wrap the criterion in our custom DistillationLoss, which
+    # just dispatches to the original criterion if args.distillation_type is 'none'
+    criterion = DistillationLoss(
+        criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
+    )
+
     output_dir = Path(args.output_dir)
     if args.resume:
         if args.resume.startswith('https'):
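As the comment above notes, the criterion is always wrapped, and with --distillation-type none the wrapper simply falls through to the base criterion (the teacher-path assert only fires when a distillation type is actually selected). A minimal pass-through check, not part of the commit:

import torch
from losses import DistillationLoss

base = torch.nn.CrossEntropyLoss()
wrapped = DistillationLoss(base, teacher_model=None, distillation_type='none',
                           alpha=0.5, tau=1.0)

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
# With 'none', the wrapper returns exactly the base loss and never touches the teacher
assert torch.allclose(wrapped(None, logits, labels), base(logits, labels))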
models.py (88 changed lines)

@@ -6,6 +6,49 @@ from functools import partial
 
 from timm.models.vision_transformer import VisionTransformer, _cfg
 from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_
 
 
+class DistilledVisionTransformer(VisionTransformer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        num_patches = self.patch_embed.num_patches
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
+        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
+
+        trunc_normal_(self.dist_token, std=.02)
+        trunc_normal_(self.pos_embed, std=.02)
+        self.head_dist.apply(self._init_weights)
+
+    def forward_features(self, x):
+        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+        # with slight modifications to add the dist_token
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x = self.norm(x)
+        return x[:, 0], x[:, 1]
+
+    def forward(self, x):
+        x, x_dist = self.forward_features(x)
+        x = self.head(x)
+        x_dist = self.head_dist(x_dist)
+        if self.training:
+            return x, x_dist
+        else:
+            # during inference, return the average of both classifier predictions
+            return (x + x_dist) / 2
+
+
+@register_model
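A quick shape sanity check on the block above (my own arithmetic, not from the commit): with 224x224 inputs and patch size 16 there are 14 x 14 = 196 patches, so prepending the class and distillation tokens gives a sequence length of 198, matching the pos_embed of size num_patches + 2:

import torch

batch, num_patches, embed_dim = 2, 14 * 14, 192  # 224/16 = 14 patches per side (tiny-model width)
patch_tokens = torch.randn(batch, num_patches, embed_dim)
cls_token = torch.zeros(batch, 1, embed_dim)
dist_token = torch.zeros(batch, 1, embed_dim)

x = torch.cat((cls_token, dist_token, patch_tokens), dim=1)
pos_embed = torch.zeros(1, num_patches + 2, embed_dim)
print(x.shape)                # torch.Size([2, 198, 192])
print((x + pos_embed).shape)  # broadcasting works because the sequence lengths match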
@@ -51,3 +94,48 @@ def deit_base_patch16_224(pretrained=False, **kwargs):
         )
         model.load_state_dict(checkpoint["model"])
     return model
+
+
+@register_model
+def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
+    model = DistilledVisionTransformer(
+        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="",
+            map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+
+@register_model
+def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
+    model = DistilledVisionTransformer(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="",
+            map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+
+@register_model
+def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
+    model = DistilledVisionTransformer(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="",
+            map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    return model
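The three registrations above make the distilled variants available through the timm registry like the existing DeiT models; combined with DistilledVisionTransformer.forward, callers get two heads during training and a single averaged prediction at inference. A hedged sketch of that behaviour with random weights, not part of the commit (assumes models.py is importable from the repo root):

import torch
from timm import create_model
import models  # noqa: F401  # importing models registers the deit_*_distilled variants

net = create_model('deit_tiny_distilled_patch16_224', pretrained=False)
x = torch.randn(2, 3, 224, 224)

net.train()
logits_cls, logits_dist = net(x)            # Tuple[Tensor, Tensor] while training
print(logits_cls.shape, logits_dist.shape)

net.eval()
with torch.no_grad():
    averaged = net(x)                       # (x + x_dist) / 2 at inference time
print(averaged.shape)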