diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index a81bd9b0..422d4f2c 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -50,6 +50,12 @@ default_cfgs = dict( url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, ), + # edgenext_base=_cfg( + # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"), + edgenext_base=_cfg( # USI weights + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth", + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), edgenext_small_rw=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth', @@ -154,7 +160,7 @@ class CrossCovarianceAttn(nn.Module): def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1) q, k, v = qkv.unbind(0) # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map @@ -217,7 +223,8 @@ class SplitTransposeBlock(nn.Module): shortcut = x # scales code re-written for torchscript as per my res2net fixes -rw - spx = torch.split(x, self.width, 1) + # NOTE torch.split(x, self.width, 1) causing issues with ONNX export + spx = x.chunk(len(self.convs) + 1, dim=1) spo = [] sp = spx[0] for i, conv in enumerate(self.convs): @@ -545,13 +552,19 @@ def edgenext_small(pretrained=False, **kwargs): return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs) +@register_model +def edgenext_base(pretrained=False, **kwargs): + # 18.51M & 3840.93M @ 256 resolution + # 82.5% (normal) 83.7% (USI) Top-1 accuracy + # AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=xx.xx versus xx.xx for MobileViT_S + # For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx + model_kwargs = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584], **kwargs) + return _create_edgenext('edgenext_base', pretrained=pretrained, **model_kwargs) + + @register_model def edgenext_small_rw(pretrained=False, **kwargs): - # 5.59M & 1260.59M @ 256 resolution - # 79.43% Top-1 accuracy - # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler - # Jetson FPS=20.47 versus 18.86 for MobileViT_S - # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S model_kwargs = dict( depths=(3, 3, 9, 3), dims=(48, 96, 192, 384), downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)