diff --git a/README.md b/README.md
index fcb86f8..a884759 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ We provide baseline DeiT models pretrained on ImageNet 2012.
 | DeiT-small distilled | 81.2 | 95.4 | 22M| [model](https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth) |
 | DeiT-base distilled | 83.4 | 96.5 | 87M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth) |
 | DeiT-base 384 | 82.9 | 96.2 | 87M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth) |
+| DeiT-base distilled 384 (1000 epochs) | 85.2 | 97.2 | 88M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth) |
 
 The models are also available via torch hub.
 
@@ -180,6 +181,22 @@ giving
+<details>
+
+<summary>
+deit_base_distilled_patch16_384
+</summary>
+
+```
+python main.py --eval --model deit_base_distilled_patch16_384 --input-size 384 --resume https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth
+```
+giving
+```
+* Acc@1 85.224 Acc@5 97.186 loss 0.636
+```
+
+</details>
+
 ## Training
 To train DeiT-small and Deit-tiny on ImageNet on a single node with 4 gpus for 300 epochs run:
diff --git a/main.py b/main.py
index 5ff017b..06679f6 100644
--- a/main.py
+++ b/main.py
@@ -173,7 +173,7 @@ def main(args):
 
     print(args)
 
-    if args.distillation_type != 'none' and args.finetune:
+    if args.distillation_type != 'none' and args.finetune and not args.eval:
         raise NotImplementedError("Finetuning with distillation not yet supported")
 
     device = torch.device(args.device)
diff --git a/models.py b/models.py
index 424b18c..5b22ef3 100644
--- a/models.py
+++ b/models.py
@@ -12,7 +12,8 @@ from timm.models.layers import trunc_normal_
 __all__ = [
     'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
     'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
-    'deit_base_distilled_patch16_224', 'deit_base_patch16_384'
+    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
+    'deit_base_distilled_patch16_384',
 ]
 
 
@@ -161,3 +162,18 @@ def deit_base_patch16_384(pretrained=False, **kwargs):
     )
     model.load_state_dict(checkpoint["model"])
     return model
+
+
+@register_model
+def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+    model = DistilledVisionTransformer(
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
+            map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    return model
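
Beyond `main.py --eval`, the new checkpoint should also be reachable through torch hub, since the README already advertises the models there. A minimal sketch, assuming the hub entrypoint mirrors the registered model name (and the `main` branch), as it does for the existing checkpoints:

```python
import torch

# Sketch only: load the new model via torch hub. Assumes the entrypoint name
# matches the @register_model function above, as for the other DeiT models.
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_base_distilled_patch16_384', pretrained=True)
model.eval()

# At inference, DeiT's distilled models average the class-token and
# distillation-token heads, so the call looks like any ordinary classifier.
x = torch.randn(1, 3, 384, 384)  # note the 384x384 input resolution
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])
```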
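Inside the repo itself, no extra wiring is needed: the `@register_model` decorator puts the new entrypoint into timm's model registry, which is how `main.py` resolves `--model`. A quick sketch of building it through timm's factory (run from the repo root so `models.py` is importable):

```python
from timm.models import create_model

import models  # noqa: F401 -- importing models.py executes the @register_model decorators

# create_model looks the name up in timm's registry and forwards `pretrained`
# to deit_base_distilled_patch16_384 above, which downloads the checkpoint.
model = create_model('deit_base_distilled_patch16_384', pretrained=True)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M params")  # ~88M, matching the README table
```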