diff --git a/mmcls/models/classifiers/image.py b/mmcls/models/classifiers/image.py index 043a9eb25..283d86a95 100644 --- a/mmcls/models/classifiers/image.py +++ b/mmcls/models/classifiers/image.py @@ -94,4 +94,7 @@ class ImageClassifier(BaseClassifier): def simple_test(self, img, img_metas): """Test without augmentation.""" x = self.extract_feat(img) + x_dims = len(x.shape) + if x_dims == 1: + x.unsqueeze_(0) return self.head.simple_test(x)