71 lines
2.7 KiB
Python
71 lines
2.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
from mmcv import ConfigDict
|
|
|
|
from mmfewshot.detection.models.utils import (AggregationLayer,
|
|
DepthWiseCorrelationAggregator,
|
|
DifferenceAggregator,
|
|
DotProductAggregator)
|
|
|
|
|
|
def test_depth_wise_aggregator():
    """Check DepthWiseCorrelationAggregator output shapes.

    Runs one forward pass without the fc layer and one with it, and
    verifies the shape of each result.
    """
    # Forward without fc: 7x7 query, 3x3 support, 5x5 output expected.
    aggregator = DepthWiseCorrelationAggregator(in_channels=256)
    query = torch.randn(1, 256, 7, 7)
    support = torch.randn(1, 256, 3, 3)
    result = aggregator(query, support)
    assert result.shape == torch.Size([1, 256, 5, 5])

    # Forward with fc: output is expected to be (batch, out_channels).
    fc_kwargs = dict(in_channels=256, out_channels=64, with_fc=True)
    aggregator = DepthWiseCorrelationAggregator(**fc_kwargs)
    query = torch.randn(2, 256, 7, 7)
    support = torch.randn(1, 256, 7, 7)
    result = aggregator(query, support)
    assert result.shape == torch.Size([2, 64])
|
|
|
|
|
|
def test_diff_aggregator():
    """Check DifferenceAggregator output shapes.

    Runs one forward pass without the fc layer and one with it, and
    verifies the shape of each result.
    """
    # Forward without fc: output is expected to keep the query's shape.
    aggregator = DifferenceAggregator(in_channels=256)
    query = torch.randn(2, 256, 7, 7)
    support = torch.randn(1, 256, 7, 7)
    result = aggregator(query, support)
    assert result.shape == torch.Size([2, 256, 7, 7])

    # Forward with fc: 1x1 inputs, (batch, out_channels) output expected.
    fc_kwargs = dict(in_channels=256, out_channels=64, with_fc=True)
    aggregator = DifferenceAggregator(**fc_kwargs)
    query = torch.randn(2, 256, 1, 1)
    support = torch.randn(1, 256, 1, 1)
    result = aggregator(query, support)
    assert result.shape == torch.Size([2, 64])
|
|
|
|
|
|
def test_dot_product_aggregator():
    """Check DotProductAggregator output shapes.

    Runs one forward pass without the fc layer and one with it, and
    verifies the shape of each result.
    """
    # Forward without fc: output is expected to keep the query's shape.
    aggregator = DotProductAggregator(in_channels=256)
    query = torch.randn(2, 256, 7, 7)
    support = torch.randn(1, 256, 7, 7)
    result = aggregator(query, support)
    assert result.shape == torch.Size([2, 256, 7, 7])

    # Forward with fc: 1x1 inputs, (batch, out_channels) output expected.
    fc_kwargs = dict(in_channels=256, out_channels=64, with_fc=True)
    aggregator = DotProductAggregator(**fc_kwargs)
    query = torch.randn(2, 256, 1, 1)
    support = torch.randn(1, 256, 1, 1)
    result = aggregator(query, support)
    assert result.shape == torch.Size([2, 64])
|
|
|
|
|
|
def test_aggregation_layer():
    """Check that AggregationLayer returns one output per aggregator.

    The layer is configured with three aggregators; the test verifies
    that three outputs come back and that each has the expected shape.
    """
    aggregator_types = (
        'DepthWiseCorrelationAggregator',
        'DifferenceAggregator',
        'DotProductAggregator',
    )
    cfg = ConfigDict(aggregator_cfgs=[
        dict(type=type_name, in_channels=256)
        for type_name in aggregator_types
    ])
    layer = AggregationLayer(**cfg)

    query = torch.randn(2, 256, 1, 1)
    support = torch.randn(1, 256, 1, 1)
    results = layer(query, support)

    assert len(results) == 3
    expected_shape = torch.Size([2, 256, 1, 1])
    # Index explicitly so exactly positions 0, 1 and 2 are checked.
    for index in range(3):
        assert results[index].shape == expected_shape
|