mirror of https://github.com/open-mmlab/mmocr.git
Fix score field in DBPostprocessor
parent
d34fad1451
commit
b585dbcdd7
|
@ -115,8 +115,10 @@ class DBPostprocessor(BaseTextDetPostProcessor):
|
|||
|
||||
if len(poly) > 0:
|
||||
data_sample.pred_instances.polygons.append(poly)
|
||||
data_sample.pred_instances.scores.append(
|
||||
torch.FloatTensor([score]))
|
||||
data_sample.pred_instances.scores.append(score)
|
||||
|
||||
data_sample.pred_instances.scores = torch.FloatTensor(
|
||||
data_sample.pred_instances.scores)
|
||||
|
||||
return data_sample
|
||||
|
||||
|
|
|
@ -33,10 +33,14 @@ class TestDBPostProcessor(unittest.TestCase):
|
|||
results = postprocessor.get_text_instances(pred_result, data_sample)
|
||||
self.assertIn('polygons', results.pred_instances)
|
||||
self.assertIn('scores', results.pred_instances)
|
||||
self.assertTrue(
|
||||
isinstance(results.pred_instances['scores'], torch.FloatTensor))
|
||||
|
||||
postprocessor = DBPostprocessor(
|
||||
min_text_score=1, text_repr_type=text_repr_type)
|
||||
pred_result = dict(prob_map=torch.rand(4, 5) * 0.8)
|
||||
results = postprocessor.get_text_instances(pred_result, data_sample)
|
||||
self.assertEqual(results.pred_instances.polygons, [])
|
||||
self.assertEqual(results.pred_instances.scores, [])
|
||||
self.assertTrue(
|
||||
isinstance(results.pred_instances['scores'], torch.FloatTensor))
|
||||
self.assertEqual(len(results.pred_instances.scores), 0)
|
||||
|
|
Loading…
Reference in New Issue