Fix score field in DBPostprocessor

pull/1178/head
gaotongxiao 2022-05-31 14:11:50 +08:00
parent d34fad1451
commit b585dbcdd7
2 changed files with 9 additions and 3 deletions

View File

@ -115,8 +115,10 @@ class DBPostprocessor(BaseTextDetPostProcessor):
if len(poly) > 0:
data_sample.pred_instances.polygons.append(poly)
data_sample.pred_instances.scores.append(
torch.FloatTensor([score]))
data_sample.pred_instances.scores.append(score)
data_sample.pred_instances.scores = torch.FloatTensor(
data_sample.pred_instances.scores)
return data_sample

View File

@ -33,10 +33,14 @@ class TestDBPostProcessor(unittest.TestCase):
results = postprocessor.get_text_instances(pred_result, data_sample)
self.assertIn('polygons', results.pred_instances)
self.assertIn('scores', results.pred_instances)
self.assertTrue(
isinstance(results.pred_instances['scores'], torch.FloatTensor))
postprocessor = DBPostprocessor(
min_text_score=1, text_repr_type=text_repr_type)
pred_result = dict(prob_map=torch.rand(4, 5) * 0.8)
results = postprocessor.get_text_instances(pred_result, data_sample)
self.assertEqual(results.pred_instances.polygons, [])
self.assertEqual(results.pred_instances.scores, [])
self.assertTrue(
isinstance(results.pred_instances['scores'], torch.FloatTensor))
self.assertEqual(len(results.pred_instances.scores), 0)