mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #1708 from pkluska/chore/mvitv2-coreml-exportable
chore: Modify the MobileVitV2Block to be CoreML-exportable
This commit is contained in:
commit
82cb47bcf3
@ -20,7 +20,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath
|
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._features_fx import register_notrace_module
|
from ._features_fx import register_notrace_module
|
||||||
from ._registry import register_model
|
from ._registry import register_model
|
||||||
@ -564,6 +564,7 @@ class MobileVitV2Block(nn.Module):
|
|||||||
|
|
||||||
self.patch_size = to_2tuple(patch_size)
|
self.patch_size = to_2tuple(patch_size)
|
||||||
self.patch_area = self.patch_size[0] * self.patch_size[1]
|
self.patch_area = self.patch_size[0] * self.patch_size[1]
|
||||||
|
self.coreml_exportable = is_exportable()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
@ -580,7 +581,10 @@ class MobileVitV2Block(nn.Module):
|
|||||||
|
|
||||||
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
|
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
|
||||||
C = x.shape[1]
|
C = x.shape[1]
|
||||||
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
|
if self.coreml_exportable:
|
||||||
|
x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
|
||||||
|
else:
|
||||||
|
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
|
||||||
x = x.reshape(B, C, -1, num_patches)
|
x = x.reshape(B, C, -1, num_patches)
|
||||||
|
|
||||||
# Global representations
|
# Global representations
|
||||||
@ -588,8 +592,14 @@ class MobileVitV2Block(nn.Module):
|
|||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
|
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
|
||||||
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
|
if self.coreml_exportable:
|
||||||
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
|
# adapted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
|
||||||
|
x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
|
||||||
|
x = F.pixel_shuffle(x, upscale_factor=patch_h)
|
||||||
|
else:
|
||||||
|
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
|
||||||
|
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
|
||||||
|
|
||||||
|
|
||||||
x = self.conv_proj(x)
|
x = self.conv_proj(x)
|
||||||
return x
|
return x
|
||||||
|
Loading…
x
Reference in New Issue
Block a user