chore: Modify the MobileVitV2Block to be coreml exportable

based on is_exportable() set variable controlling behaviour of the block
CoreMLTools support im2col from 6.2 version, unfortunately col2im
is still not supported.

Tested with exporting to ONNX, Torchscript, CoreML, and TVM.
This commit is contained in:
Piotr Sebastian Kluska 2023-03-03 09:27:00 +01:00
parent 4b8cfa6c0a
commit 992bf7c3d4

View File

@ -20,7 +20,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._registry import register_model
@ -564,6 +564,7 @@ class MobileVitV2Block(nn.Module):
self.patch_size = to_2tuple(patch_size)
self.patch_area = self.patch_size[0] * self.patch_size[1]
self.coreml_exportable = is_exportable()
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
@ -580,7 +581,10 @@ class MobileVitV2Block(nn.Module):
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
C = x.shape[1]
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
if self.coreml_exportable:
x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
else:
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
x = x.reshape(B, C, -1, num_patches)
# Global representations
@ -588,8 +592,14 @@ class MobileVitV2Block(nn.Module):
x = self.norm(x)
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
if self.coreml_exportable:
# adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
x = F.pixel_shuffle(x, upscale_factor=patch_h)
else:
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
x = self.conv_proj(x)
return x