diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 8e8f4428..6466f127 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -20,7 +20,7 @@ import torch import torch.nn.functional as F from torch import nn -from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath +from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module from ._registry import register_model @@ -564,6 +564,7 @@ class MobileVitV2Block(nn.Module): self.patch_size = to_2tuple(patch_size) self.patch_area = self.patch_size[0] * self.patch_size[1] + self.coreml_exportable = is_exportable() def forward(self, x: torch.Tensor) -> torch.Tensor: B, C, H, W = x.shape @@ -580,7 +581,10 @@ class MobileVitV2Block(nn.Module): # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N] C = x.shape[1] - x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4) + if self.coreml_exportable: + x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w)) + else: + x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4) x = x.reshape(B, C, -1, num_patches) # Global representations @@ -588,8 +592,14 @@ class MobileVitV2Block(nn.Module): x = self.norm(x) # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W] - x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3) - x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w) + if self.coreml_exportable: + # adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624 + x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w) + x = F.pixel_shuffle(x, upscale_factor=patch_h) + else: + x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3) + x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w) + x = self.conv_proj(x) return x