# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv import ConfigDict

from mmfewshot.detection.models.utils import (AggregationLayer,
                                              DepthWiseCorrelationAggregator,
                                              DifferenceAggregator,
                                              DotProductAggregator)


def test_depth_wise_aggregator():
    """Check the output shapes of ``DepthWiseCorrelationAggregator``.

    Two configurations are exercised: the default one (no fc layer) and
    one built with ``with_fc=True`` and ``out_channels=64``.
    """
    # Without fc: a (1, 256, 3, 3) support against a (1, 256, 7, 7) query
    # is expected to produce a (1, 256, 5, 5) output.
    aggregator = DepthWiseCorrelationAggregator(in_channels=256)
    query = torch.randn(1, 256, 7, 7)
    support = torch.randn(1, 256, 3, 3)
    result = aggregator(query, support)
    assert tuple(result.shape) == (1, 256, 5, 5)

    # With fc: query and support share the 7x7 spatial size and the
    # output is expected to be (batch, out_channels).
    aggregator = DepthWiseCorrelationAggregator(
        in_channels=256, out_channels=64, with_fc=True)
    query = torch.randn(2, 256, 7, 7)
    support = torch.randn(1, 256, 7, 7)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 64)


def test_diff_aggregator():
    """Check the output shapes of ``DifferenceAggregator``.

    Covers the default configuration (no fc layer) and the one built with
    ``with_fc=True`` and ``out_channels=64``.
    """
    # Without fc: the output is expected to keep the query's shape.
    aggregator = DifferenceAggregator(in_channels=256)
    query = torch.randn(2, 256, 7, 7)
    support = torch.randn(1, 256, 7, 7)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 256, 7, 7)

    # With fc: 1x1 features go in, (batch, out_channels) is expected out.
    aggregator = DifferenceAggregator(
        in_channels=256, out_channels=64, with_fc=True)
    query = torch.randn(2, 256, 1, 1)
    support = torch.randn(1, 256, 1, 1)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 64)


def test_dot_product_aggregator():
    """Check the output shapes of ``DotProductAggregator``.

    Covers the default configuration (no fc layer) and the one built with
    ``with_fc=True`` and ``out_channels=64``.
    """
    # Without fc: the output is expected to keep the query's shape.
    aggregator = DotProductAggregator(in_channels=256)
    query = torch.randn(2, 256, 7, 7)
    support = torch.randn(1, 256, 7, 7)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 256, 7, 7)

    # With fc: 1x1 features go in, (batch, out_channels) is expected out.
    aggregator = DotProductAggregator(
        in_channels=256, out_channels=64, with_fc=True)
    query = torch.randn(2, 256, 1, 1)
    support = torch.randn(1, 256, 1, 1)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 64)


def test_aggregation_layer():
    """Check ``AggregationLayer`` built from three aggregator configs.

    The layer is expected to return one output per configured aggregator,
    each shaped like the (2, 256, 1, 1) query feature.
    """
    aggregator_cfgs = [
        dict(type='DepthWiseCorrelationAggregator', in_channels=256),
        dict(type='DifferenceAggregator', in_channels=256),
        dict(type='DotProductAggregator', in_channels=256),
    ]
    layer_cfg = ConfigDict(aggregator_cfgs=aggregator_cfgs)
    layer = AggregationLayer(**layer_cfg)

    query = torch.randn(2, 256, 1, 1)
    support = torch.randn(1, 256, 1, 1)
    outputs = layer(query, support)

    assert len(outputs) == 3
    expected_shape = (2, 256, 1, 1)
    for index in range(3):
        assert tuple(outputs[index].shape) == expected_shape