mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Merge branch 'linfangjian/fix_pred_new' into 'refactor_dev'
[Fix] Fix pred See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!43
This commit is contained in:
commit
ec52a27299
@ -63,7 +63,7 @@ class IoUMetric(BaseMetric):
|
||||
reduce_zero_label = self.dataset_meta['reduce_zero_label']
|
||||
for data, pred in zip(data_batch, predictions):
|
||||
label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy()
|
||||
pred_label = pred['pred_sem_seg']['data'][0]
|
||||
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
|
||||
self.results.append(
|
||||
self.intersect_and_union(pred_label, label, num_classes,
|
||||
self.ignore_index, label_map,
|
||||
|
@ -256,8 +256,7 @@ class EncoderDecoder(BaseSegmentor):
|
||||
results_dict = dict()
|
||||
seg_logit = self.inference(batch_inputs, batch_img_metas, rescale)
|
||||
results_dict['seg_logits'] = seg_logit
|
||||
seg_pred = seg_logit.argmax(dim=1)
|
||||
seg_pred = seg_pred.cpu().numpy()
|
||||
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
|
||||
results_dict['pred_sem_seg'] = seg_pred
|
||||
results_list = self.postprocess_result(results_dict)
|
||||
return results_list
|
||||
|
@ -65,6 +65,7 @@ class TestIoUMetric(TestCase):
|
||||
results_dict['seg_logits'] = seg_logit
|
||||
seg_pred = np.random.randint(
|
||||
0, num_classes, (batch_size, h, w), dtype=np.uint8)
|
||||
seg_pred = torch.LongTensor(seg_pred)
|
||||
results_dict['pred_sem_seg'] = seg_pred
|
||||
|
||||
batch_datasampes = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user