# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import pytest
import torch
from mmengine.structures import PixelData

from mmseg.structures import SegDataSample


def _equal(a, b):
    if isinstance(a, (torch.Tensor, np.ndarray)):
        return (a == b).all()
    else:
        return a == b


class TestSegDataSample(TestCase):

    def test_init(self):
        meta_info = dict(
            img_size=[256, 256],
            scale_factor=np.array([1.5, 1.5]),
            img_shape=torch.rand(4))

        seg_data_sample = SegDataSample(metainfo=meta_info)
        assert 'img_size' in seg_data_sample
        assert seg_data_sample.img_size == [256, 256]
        assert seg_data_sample.get('img_size') == [256, 256]

    def test_setter(self):
        seg_data_sample = SegDataSample()

        # test gt_sem_seg
        gt_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2))
        gt_sem_seg = PixelData(**gt_sem_seg_data)
        seg_data_sample.gt_sem_seg = gt_sem_seg
        assert 'gt_sem_seg' in seg_data_sample
        assert _equal(seg_data_sample.gt_sem_seg.sem_seg,
                      gt_sem_seg_data['sem_seg'])

        # test pred_sem_seg
        pred_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2))
        pred_sem_seg = PixelData(**pred_sem_seg_data)
        seg_data_sample.pred_sem_seg = pred_sem_seg
        assert 'pred_sem_seg' in seg_data_sample
        assert _equal(seg_data_sample.pred_sem_seg.sem_seg,
                      pred_sem_seg_data['sem_seg'])

        # test seg_logits
        seg_logits_data = dict(sem_seg=torch.rand(5, 4, 2))
        seg_logits = PixelData(**seg_logits_data)
        seg_data_sample.seg_logits = seg_logits
        assert 'seg_logits' in seg_data_sample
        assert _equal(seg_data_sample.seg_logits.sem_seg,
                      seg_logits_data['sem_seg'])

        # test type error
        with pytest.raises(AssertionError):
            seg_data_sample.gt_sem_seg = torch.rand(2, 4)

        with pytest.raises(AssertionError):
            seg_data_sample.pred_sem_seg = torch.rand(2, 4)

        with pytest.raises(AssertionError):
            seg_data_sample.seg_logits = torch.rand(2, 4)

    def test_deleter(self):
        seg_data_sample = SegDataSample()

        pred_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2))
        pred_sem_seg = PixelData(**pred_sem_seg_data)
        seg_data_sample.pred_sem_seg = pred_sem_seg
        assert 'pred_sem_seg' in seg_data_sample
        del seg_data_sample.pred_sem_seg
        assert 'pred_sem_seg' not in seg_data_sample