mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add edgenext_base model def & weight link, update to improve ONNX export #1385
This commit is contained in:
parent
56596e4e84
commit
13565aad50
@ -50,6 +50,12 @@ default_cfgs = dict(
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
|
||||
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||
),
|
||||
# edgenext_base=_cfg(
|
||||
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"),
|
||||
edgenext_base=_cfg( # USI weights
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
|
||||
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||
),
|
||||
|
||||
edgenext_small_rw=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
|
||||
@ -154,7 +160,7 @@ class CrossCovarianceAttn(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1)
|
||||
q, k, v = qkv.unbind(0)
|
||||
|
||||
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
|
||||
@ -217,7 +223,8 @@ class SplitTransposeBlock(nn.Module):
|
||||
shortcut = x
|
||||
|
||||
# scales code re-written for torchscript as per my res2net fixes -rw
|
||||
spx = torch.split(x, self.width, 1)
|
||||
# NOTE torch.split(x, self.width, 1) causing issues with ONNX export
|
||||
spx = x.chunk(len(self.convs) + 1, dim=1)
|
||||
spo = []
|
||||
sp = spx[0]
|
||||
for i, conv in enumerate(self.convs):
|
||||
@ -545,13 +552,19 @@ def edgenext_small(pretrained=False, **kwargs):
|
||||
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def edgenext_base(pretrained=False, **kwargs):
|
||||
# 18.51M & 3840.93M @ 256 resolution
|
||||
# 82.5% (normal) 83.7% (USI) Top-1 accuracy
|
||||
# AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
|
||||
# Jetson FPS=xx.xx versus xx.xx for MobileViT_S
|
||||
# For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx
|
||||
model_kwargs = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584], **kwargs)
|
||||
return _create_edgenext('edgenext_base', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def edgenext_small_rw(pretrained=False, **kwargs):
|
||||
# 5.59M & 1260.59M @ 256 resolution
|
||||
# 79.43% Top-1 accuracy
|
||||
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
|
||||
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
|
||||
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
|
||||
model_kwargs = dict(
|
||||
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
|
||||
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user