fix a bug when samples_per_gpu==1 (#311)
parent
42b99be9b2
commit
4ebee155e8
|
@ -94,4 +94,7 @@ class ImageClassifier(BaseClassifier):
|
||||||
def simple_test(self, img, img_metas):
|
def simple_test(self, img, img_metas):
|
||||||
"""Test without augmentation."""
|
"""Test without augmentation."""
|
||||||
x = self.extract_feat(img)
|
x = self.extract_feat(img)
|
||||||
|
x_dims = len(x.shape)
|
||||||
|
if x_dims == 1:
|
||||||
|
x.unsqueeze_(0)
|
||||||
return self.head.simple_test(x)
|
return self.head.simple_test(x)
|
||||||
|
|
Loading…
Reference in New Issue