Add deit_base_distilled_patch16_384 (#55)

pull/83/head
Francisco Massa 2021-01-18 12:19:47 +01:00 committed by GitHub
parent 84f7cf316f
commit ab5715372d
3 changed files with 35 additions and 2 deletions

README.md

@@ -32,6 +32,7 @@ We provide baseline DeiT models pretrained on ImageNet 2012.
 | DeiT-small distilled | 81.2 | 95.4 | 22M | [model](https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth) |
 | DeiT-base distilled | 83.4 | 96.5 | 87M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth) |
 | DeiT-base 384 | 82.9 | 96.2 | 87M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth) |
+| DeiT-base distilled 384 (1000 epochs) | 85.2 | 97.2 | 88M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth) |
 The models are also available via torch hub.
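Since the README points readers to torch hub, a minimal sketch of pulling the new checkpoint that way; the `facebookresearch/deit:main` spec and the hub entry-point name are assumptions, mirroring the model function registered in `models.py` below:

```python
import torch

# Hub entry-point name and branch are assumptions matching the
# function name added in models.py in this commit.
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_base_distilled_patch16_384', pretrained=True)
model.eval()
```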
@@ -180,6 +181,22 @@ giving
 </details>
+<details>
+<summary>
+deit_base_distilled_patch16_384
+</summary>
+```
+python main.py --eval --model deit_base_distilled_patch16_384 --input-size 384 --resume https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth
+```
+giving
+```
+* Acc@1 85.224 Acc@5 97.186 loss 0.636
+```
+</details>
 ## Training
 To train DeiT-small and DeiT-tiny on ImageNet on a single node with 4 GPUs for 300 epochs run:
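The hunk ends before the command itself, so for orientation only: a sketch of what that training invocation looks like in the README, with all flags treated as assumptions rather than part of this diff:

```
python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet --output_dir /path/to/save
```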

main.py

@@ -173,7 +173,7 @@ def main(args):
     print(args)
-    if args.distillation_type != 'none' and args.finetune:
+    if args.distillation_type != 'none' and args.finetune and not args.eval:
         raise NotImplementedError("Finetuning with distillation not yet supported")
     device = torch.device(args.device)

models.py

@@ -12,7 +12,8 @@ from timm.models.layers import trunc_normal_
 __all__ = [
     'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
     'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
-    'deit_base_distilled_patch16_224', 'deit_base_patch16_384'
+    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
+    'deit_base_distilled_patch16_384',
 ]
@@ -161,3 +162,18 @@ def deit_base_patch16_384(pretrained=False, **kwargs):
         )
         model.load_state_dict(checkpoint["model"])
     return model
+
+
+@register_model
+def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+    model = DistilledVisionTransformer(
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
+            map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    return model