fix a bug when samples_per_gpu==1 (#311)

pull/338/head
Mingqiang Ning 2021-06-30 07:57:21 -05:00 committed by GitHub
parent 42b99be9b2
commit 4ebee155e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 0 deletions

View File

@ -94,4 +94,7 @@ class ImageClassifier(BaseClassifier):
def simple_test(self, img, img_metas):
"""Test without augmentation."""
x = self.extract_feat(img)
x_dims = len(x.shape)
if x_dims == 1:
x.unsqueeze_(0)
return self.head.simple_test(x)