mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
chore: Modify the MobileVitV2Block to be CoreML exportable
Based on is_exportable(), set a variable controlling the behaviour of the block. CoreMLTools supports im2col from version 6.2; unfortunately, col2im is still not supported. Tested with exporting to ONNX, TorchScript, CoreML, and TVM.
This commit is contained in:
parent
4b8cfa6c0a
commit
992bf7c3d4
@ -20,7 +20,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath
|
||||
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._registry import register_model
|
||||
@ -564,6 +564,7 @@ class MobileVitV2Block(nn.Module):
|
||||
|
||||
self.patch_size = to_2tuple(patch_size)
|
||||
self.patch_area = self.patch_size[0] * self.patch_size[1]
|
||||
self.coreml_exportable = is_exportable()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, C, H, W = x.shape
|
||||
@ -580,7 +581,10 @@ class MobileVitV2Block(nn.Module):
|
||||
|
||||
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
|
||||
C = x.shape[1]
|
||||
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
|
||||
if self.coreml_exportable:
|
||||
x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
|
||||
else:
|
||||
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
|
||||
x = x.reshape(B, C, -1, num_patches)
|
||||
|
||||
# Global representations
|
||||
@ -588,8 +592,14 @@ class MobileVitV2Block(nn.Module):
|
||||
x = self.norm(x)
|
||||
|
||||
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
|
||||
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
|
||||
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
|
||||
if self.coreml_exportable:
|
||||
# adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
|
||||
x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
|
||||
x = F.pixel_shuffle(x, upscale_factor=patch_h)
|
||||
else:
|
||||
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
|
||||
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
|
||||
|
||||
|
||||
x = self.conv_proj(x)
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user