Merge pull request #1708 from pkluska/chore/mvitv2-coreml-exportable

chore: Modify the MobileVitV2Block to be coreml exportable
This commit is contained in:
Ross Wightman 2023-03-11 11:52:20 -10:00 committed by GitHub
commit 82cb47bcf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,7 +20,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module from ._features_fx import register_notrace_module
from ._registry import register_model from ._registry import register_model
@ -564,6 +564,7 @@ class MobileVitV2Block(nn.Module):
self.patch_size = to_2tuple(patch_size) self.patch_size = to_2tuple(patch_size)
self.patch_area = self.patch_size[0] * self.patch_size[1] self.patch_area = self.patch_size[0] * self.patch_size[1]
self.coreml_exportable = is_exportable()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape B, C, H, W = x.shape
@ -580,6 +581,9 @@ class MobileVitV2Block(nn.Module):
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N] # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
C = x.shape[1] C = x.shape[1]
if self.coreml_exportable:
x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
else:
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4) x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
x = x.reshape(B, C, -1, num_patches) x = x.reshape(B, C, -1, num_patches)
@ -588,9 +592,15 @@ class MobileVitV2Block(nn.Module):
x = self.norm(x) x = self.norm(x)
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W] # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
if self.coreml_exportable:
# adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
x = F.pixel_shuffle(x, upscale_factor=patch_h)
else:
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3) x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w) x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
x = self.conv_proj(x) x = self.conv_proj(x)
return x return x