mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Merge branch 'linfangjian/fix_pred_new' into 'refactor_dev'
[Fix] Fix pred See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!43
This commit is contained in:
commit
ec52a27299
@ -63,7 +63,7 @@ class IoUMetric(BaseMetric):
|
|||||||
reduce_zero_label = self.dataset_meta['reduce_zero_label']
|
reduce_zero_label = self.dataset_meta['reduce_zero_label']
|
||||||
for data, pred in zip(data_batch, predictions):
|
for data, pred in zip(data_batch, predictions):
|
||||||
label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy()
|
label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy()
|
||||||
pred_label = pred['pred_sem_seg']['data'][0]
|
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
|
||||||
self.results.append(
|
self.results.append(
|
||||||
self.intersect_and_union(pred_label, label, num_classes,
|
self.intersect_and_union(pred_label, label, num_classes,
|
||||||
self.ignore_index, label_map,
|
self.ignore_index, label_map,
|
||||||
|
@ -256,8 +256,7 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
results_dict = dict()
|
results_dict = dict()
|
||||||
seg_logit = self.inference(batch_inputs, batch_img_metas, rescale)
|
seg_logit = self.inference(batch_inputs, batch_img_metas, rescale)
|
||||||
results_dict['seg_logits'] = seg_logit
|
results_dict['seg_logits'] = seg_logit
|
||||||
seg_pred = seg_logit.argmax(dim=1)
|
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
|
||||||
seg_pred = seg_pred.cpu().numpy()
|
|
||||||
results_dict['pred_sem_seg'] = seg_pred
|
results_dict['pred_sem_seg'] = seg_pred
|
||||||
results_list = self.postprocess_result(results_dict)
|
results_list = self.postprocess_result(results_dict)
|
||||||
return results_list
|
return results_list
|
||||||
|
@ -65,6 +65,7 @@ class TestIoUMetric(TestCase):
|
|||||||
results_dict['seg_logits'] = seg_logit
|
results_dict['seg_logits'] = seg_logit
|
||||||
seg_pred = np.random.randint(
|
seg_pred = np.random.randint(
|
||||||
0, num_classes, (batch_size, h, w), dtype=np.uint8)
|
0, num_classes, (batch_size, h, w), dtype=np.uint8)
|
||||||
|
seg_pred = torch.LongTensor(seg_pred)
|
||||||
results_dict['pred_sem_seg'] = seg_pred
|
results_dict['pred_sem_seg'] = seg_pred
|
||||||
|
|
||||||
batch_datasampes = [
|
batch_datasampes = [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user