Reduce mmcls version dependency (#635)

This commit is contained in:
q.yao 2022-06-27 09:41:36 +08:00 committed by GitHub
parent 0cac5154a6
commit ae47e9d188
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models.utils import channel_shuffle
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend
@ -29,6 +28,7 @@ def shufflenetv2_backbone__forward__ncnn(ctx, self, x):
out (Tensor): A feature map output from InvertedResidual. The tensor
shape (N, Cout, H, W).
"""
from mmcls.models.utils import channel_shuffle
if self.stride > 1:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
else:

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models.utils import resize_pos_embed
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend
@ -25,6 +24,7 @@ def visiontransformer__forward__ncnn(ctx, self, x):
out (Tensor): A feature map output from InvertedResidual. The tensor
shape (N, Cout, H, W).
"""
from mmcls.models.utils import resize_pos_embed
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)