[Enhance] Rewrite channel split operation in ShufflenetV2 (#632)

* replace chunk op

* shufflenetv2 config
pull/671/head
Ezra-Yu 2022-01-25 14:45:28 +08:00 committed by GitHub
parent b5bd87d7fa
commit bd397f790f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 1 deletions

View File

@ -115,7 +115,14 @@ class InvertedResidual(BaseModule):
if self.stride > 1:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
else:
x1, x2 = x.chunk(2, dim=1)
# Channel Split operation. using these lines of code to replace
# ``chunk(x, 2, dim=1)`` can make it easier to deploy a
# shufflenetv2 model by using mmdeploy.
channels = x.shape[1]
c = channels // 2 + channels % 2
x1 = x[:, :c, :, :]
x2 = x[:, c:, :, :]
out = torch.cat((x1, self.branch2(x2)), dim=1)
out = channel_shuffle(out, 2)