diff --git a/mmocr/models/textdet/postprocessors/db_postprocessor.py b/mmocr/models/textdet/postprocessors/db_postprocessor.py index dd4e123b..33b9560d 100644 --- a/mmocr/models/textdet/postprocessors/db_postprocessor.py +++ b/mmocr/models/textdet/postprocessors/db_postprocessor.py @@ -115,8 +115,10 @@ class DBPostprocessor(BaseTextDetPostProcessor): if len(poly) > 0: data_sample.pred_instances.polygons.append(poly) - data_sample.pred_instances.scores.append( - torch.FloatTensor([score])) + data_sample.pred_instances.scores.append(score) + + data_sample.pred_instances.scores = torch.FloatTensor( + data_sample.pred_instances.scores) return data_sample diff --git a/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py b/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py index 48896f93..169814fb 100644 --- a/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py +++ b/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py @@ -33,10 +33,14 @@ class TestDBPostProcessor(unittest.TestCase): results = postprocessor.get_text_instances(pred_result, data_sample) self.assertIn('polygons', results.pred_instances) self.assertIn('scores', results.pred_instances) + self.assertTrue( + isinstance(results.pred_instances['scores'], torch.FloatTensor)) postprocessor = DBPostprocessor( min_text_score=1, text_repr_type=text_repr_type) pred_result = dict(prob_map=torch.rand(4, 5) * 0.8) results = postprocessor.get_text_instances(pred_result, data_sample) self.assertEqual(results.pred_instances.polygons, []) - self.assertEqual(results.pred_instances.scores, []) + self.assertTrue( + isinstance(results.pred_instances['scores'], torch.FloatTensor)) + self.assertEqual(len(results.pred_instances.scores), 0)