Merge branch 'linfangjian/fix_pred_new' into 'refactor_dev'

[Fix] Fix pred

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!43
This commit is contained in:
zhengmiao 2022-06-14 02:53:46 +00:00
commit ec52a27299
3 changed files with 3 additions and 3 deletions

View File

@ -63,7 +63,7 @@ class IoUMetric(BaseMetric):
reduce_zero_label = self.dataset_meta['reduce_zero_label'] reduce_zero_label = self.dataset_meta['reduce_zero_label']
for data, pred in zip(data_batch, predictions): for data, pred in zip(data_batch, predictions):
label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy() label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy()
pred_label = pred['pred_sem_seg']['data'][0] pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
self.results.append( self.results.append(
self.intersect_and_union(pred_label, label, num_classes, self.intersect_and_union(pred_label, label, num_classes,
self.ignore_index, label_map, self.ignore_index, label_map,

View File

@ -256,8 +256,7 @@ class EncoderDecoder(BaseSegmentor):
results_dict = dict() results_dict = dict()
seg_logit = self.inference(batch_inputs, batch_img_metas, rescale) seg_logit = self.inference(batch_inputs, batch_img_metas, rescale)
results_dict['seg_logits'] = seg_logit results_dict['seg_logits'] = seg_logit
seg_pred = seg_logit.argmax(dim=1) seg_pred = seg_logit.argmax(dim=1, keepdim=True)
seg_pred = seg_pred.cpu().numpy()
results_dict['pred_sem_seg'] = seg_pred results_dict['pred_sem_seg'] = seg_pred
results_list = self.postprocess_result(results_dict) results_list = self.postprocess_result(results_dict)
return results_list return results_list

View File

@ -65,6 +65,7 @@ class TestIoUMetric(TestCase):
results_dict['seg_logits'] = seg_logit results_dict['seg_logits'] = seg_logit
seg_pred = np.random.randint( seg_pred = np.random.randint(
0, num_classes, (batch_size, h, w), dtype=np.uint8) 0, num_classes, (batch_size, h, w), dtype=np.uint8)
seg_pred = torch.LongTensor(seg_pred)
results_dict['pred_sem_seg'] = seg_pred results_dict['pred_sem_seg'] = seg_pred
batch_datasampes = [ batch_datasampes = [