mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add edgenext_base model def & weight link, update to improve ONNX export #1385
This commit is contained in:
parent
56596e4e84
commit
13565aad50
@ -50,6 +50,12 @@ default_cfgs = dict(
|
|||||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
|
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
|
||||||
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||||
),
|
),
|
||||||
|
# edgenext_base=_cfg(
|
||||||
|
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"),
|
||||||
|
edgenext_base=_cfg( # USI weights
|
||||||
|
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
|
||||||
|
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||||
|
),
|
||||||
|
|
||||||
edgenext_small_rw=_cfg(
|
edgenext_small_rw=_cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
|
||||||
@ -154,7 +160,7 @@ class CrossCovarianceAttn(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, N, C = x.shape
|
B, N, C = x.shape
|
||||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1)
|
||||||
q, k, v = qkv.unbind(0)
|
q, k, v = qkv.unbind(0)
|
||||||
|
|
||||||
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
|
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
|
||||||
@ -217,7 +223,8 @@ class SplitTransposeBlock(nn.Module):
|
|||||||
shortcut = x
|
shortcut = x
|
||||||
|
|
||||||
# scales code re-written for torchscript as per my res2net fixes -rw
|
# scales code re-written for torchscript as per my res2net fixes -rw
|
||||||
spx = torch.split(x, self.width, 1)
|
# NOTE torch.split(x, self.width, 1) causing issues with ONNX export
|
||||||
|
spx = x.chunk(len(self.convs) + 1, dim=1)
|
||||||
spo = []
|
spo = []
|
||||||
sp = spx[0]
|
sp = spx[0]
|
||||||
for i, conv in enumerate(self.convs):
|
for i, conv in enumerate(self.convs):
|
||||||
@ -545,13 +552,19 @@ def edgenext_small(pretrained=False, **kwargs):
|
|||||||
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
|
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def edgenext_base(pretrained=False, **kwargs):
|
||||||
|
# 18.51M & 3840.93M @ 256 resolution
|
||||||
|
# 82.5% (normal) 83.7% (USI) Top-1 accuracy
|
||||||
|
# AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
|
||||||
|
# Jetson FPS=xx.xx versus xx.xx for MobileViT_S
|
||||||
|
# For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx
|
||||||
|
model_kwargs = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584], **kwargs)
|
||||||
|
return _create_edgenext('edgenext_base', pretrained=pretrained, **model_kwargs)
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def edgenext_small_rw(pretrained=False, **kwargs):
|
def edgenext_small_rw(pretrained=False, **kwargs):
|
||||||
# 5.59M & 1260.59M @ 256 resolution
|
|
||||||
# 79.43% Top-1 accuracy
|
|
||||||
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
|
|
||||||
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
|
|
||||||
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
|
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
|
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
|
||||||
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
|
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user