mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
Reduce mmcls version dependency (#635)
This commit is contained in:
parent
0cac5154a6
commit
ae47e9d188
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import torch
|
import torch
|
||||||
from mmcls.models.utils import channel_shuffle
|
|
||||||
|
|
||||||
from mmdeploy.core import FUNCTION_REWRITER
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
from mmdeploy.utils import Backend
|
from mmdeploy.utils import Backend
|
||||||
@ -29,6 +28,7 @@ def shufflenetv2_backbone__forward__ncnn(ctx, self, x):
|
|||||||
out (Tensor): A feature map output from InvertedResidual. The tensor
|
out (Tensor): A feature map output from InvertedResidual. The tensor
|
||||||
shape (N, Cout, H, W).
|
shape (N, Cout, H, W).
|
||||||
"""
|
"""
|
||||||
|
from mmcls.models.utils import channel_shuffle
|
||||||
if self.stride > 1:
|
if self.stride > 1:
|
||||||
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
|
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
|
||||||
else:
|
else:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import torch
|
import torch
|
||||||
from mmcls.models.utils import resize_pos_embed
|
|
||||||
|
|
||||||
from mmdeploy.core import FUNCTION_REWRITER
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
from mmdeploy.utils import Backend
|
from mmdeploy.utils import Backend
|
||||||
@ -25,6 +24,7 @@ def visiontransformer__forward__ncnn(ctx, self, x):
|
|||||||
out (Tensor): A feature map output from InvertedResidual. The tensor
|
out (Tensor): A feature map output from InvertedResidual. The tensor
|
||||||
shape (N, Cout, H, W).
|
shape (N, Cout, H, W).
|
||||||
"""
|
"""
|
||||||
|
from mmcls.models.utils import resize_pos_embed
|
||||||
B = x.shape[0]
|
B = x.shape[0]
|
||||||
x, patch_resolution = self.patch_embed(x)
|
x, patch_resolution = self.patch_embed(x)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user