fix a bug when samples_per_gpu==1 (#311)
parent
42b99be9b2
commit
4ebee155e8
|
@ -94,4 +94,7 @@ class ImageClassifier(BaseClassifier):
|
|||
def simple_test(self, img, img_metas):
|
||||
"""Test without augmentation."""
|
||||
x = self.extract_feat(img)
|
||||
x_dims = len(x.shape)
|
||||
if x_dims == 1:
|
||||
x.unsqueeze_(0)
|
||||
return self.head.simple_test(x)
|
||||
|
|
Loading…
Reference in New Issue