mirror of https://github.com/facebookresearch/deit
Add deit_base_distilled_patch16_384 (#55)
parent
84f7cf316f
commit
ab5715372d
17
README.md
17
README.md
|
@@ -32,6 +32,7 @@ We provide baseline DeiT models pretrained on ImageNet 2012.
|
|||
| DeiT-small distilled | 81.2 | 95.4 | 22M| [model](https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth) |
|
||||
| DeiT-base distilled | 83.4 | 96.5 | 87M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth) |
|
||||
| DeiT-base 384 | 82.9 | 96.2 | 87M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth) |
|
||||
| DeiT-base distilled 384 (1000 epochs) | 85.2 | 97.2 | 88M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth) |
|
||||
|
||||
|
||||
The models are also available via torch hub.
|
||||
|
@@ -180,6 +181,22 @@ giving
|
|||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>
|
||||
deit_base_distilled_patch16_384
|
||||
</summary>
|
||||
|
||||
```
|
||||
python main.py --eval --model deit_base_distilled_patch16_384 --input-size 384 --resume https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth
|
||||
```
|
||||
giving
|
||||
```
|
||||
* Acc@1 85.224 Acc@5 97.186 loss 0.636
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Training
|
||||
To train DeiT-small and DeiT-tiny on ImageNet on a single node with 4 gpus for 300 epochs run:
|
||||
|
||||
|
|
2
main.py
2
main.py
|
@@ -173,7 +173,7 @@ def main(args):
|
|||
|
||||
print(args)
|
||||
|
||||
if args.distillation_type != 'none' and args.finetune:
|
||||
if args.distillation_type != 'none' and args.finetune and not args.eval:
|
||||
raise NotImplementedError("Finetuning with distillation not yet supported")
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
|
18
models.py
18
models.py
|
@@ -12,7 +12,8 @@ from timm.models.layers import trunc_normal_
|
|||
__all__ = [
|
||||
'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
|
||||
'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
|
||||
'deit_base_distilled_patch16_224', 'deit_base_patch16_384'
|
||||
'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
|
||||
'deit_base_distilled_patch16_384',
|
||||
]
|
||||
|
||||
|
||||
|
@@ -161,3 +162,18 @@ def deit_base_patch16_384(pretrained=False, **kwargs):
|
|||
)
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
return model
|
||||
|
||||
|
||||
@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    """DeiT-base distilled model at 384x384 input resolution.

    Instantiates a ``DistilledVisionTransformer`` with ViT-Base geometry
    (16x16 patches, 768-dim embeddings, 12 blocks, 12 heads) and, when
    requested, loads the officially released checkpoint.

    Args:
        pretrained: If True, download the fbaipublicfiles checkpoint and
            load its weights into the model.
        **kwargs: Extra keyword arguments forwarded to
            ``DistilledVisionTransformer``.

    Returns:
        The constructed (optionally pretrained) model.
    """
    net = DistilledVisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        # eps matches the value used by the released checkpoints
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    net.default_cfg = _cfg()
    if pretrained:
        state = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
            map_location="cpu",
            check_hash=True,
        )
        net.load_state_dict(state["model"])
    return net
|
||||
|
|
Loading…
Reference in New Issue