# Source path: mmfewshot/tests/test_detection_models/test_detection_utils.py
# File-viewer metadata: 71 lines, 2.7 KiB, Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv import ConfigDict
from mmfewshot.detection.models.utils import (AggregationLayer,
DepthWiseCorrelationAggregator,
DifferenceAggregator,
DotProductAggregator)
def test_depth_wise_aggregator():
    """Check DepthWiseCorrelationAggregator output shapes w/o and w/ fc."""
    # Case 1: no fc. A (1, 256, 7, 7) query paired with a (1, 256, 3, 3)
    # support is expected to produce a (1, 256, 5, 5) output.
    aggregator = DepthWiseCorrelationAggregator(in_channels=256)
    query = torch.randn(1, 256, 7, 7)
    support = torch.randn(1, 256, 3, 3)
    result = aggregator(query, support)
    assert tuple(result.shape) == (1, 256, 5, 5)

    # Case 2: with fc. Query and support share the 7x7 spatial size and the
    # output is expected to be (num_queries, out_channels).
    aggregator = DepthWiseCorrelationAggregator(
        in_channels=256, out_channels=64, with_fc=True)
    query = torch.randn(2, 256, 7, 7)
    support = torch.randn(1, 256, 7, 7)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 64)
def test_diff_aggregator():
    """Check DifferenceAggregator output shapes w/o and w/ fc."""
    # Case 1: no fc. The output is expected to keep the query's shape even
    # though the support batch holds a single sample.
    aggregator = DifferenceAggregator(in_channels=256)
    query = torch.randn(2, 256, 7, 7)
    support = torch.randn(1, 256, 7, 7)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 256, 7, 7)

    # Case 2: with fc. 1x1 features go in and a (num_queries, out_channels)
    # output is expected.
    aggregator = DifferenceAggregator(
        in_channels=256, out_channels=64, with_fc=True)
    query = torch.randn(2, 256, 1, 1)
    support = torch.randn(1, 256, 1, 1)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 64)
def test_dot_product_aggregator():
    """Check DotProductAggregator output shapes w/o and w/ fc."""
    # Case 1: no fc. The output is expected to keep the query's shape even
    # though the support batch holds a single sample.
    aggregator = DotProductAggregator(in_channels=256)
    query = torch.randn(2, 256, 7, 7)
    support = torch.randn(1, 256, 7, 7)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 256, 7, 7)

    # Case 2: with fc. 1x1 features go in and a (num_queries, out_channels)
    # output is expected.
    aggregator = DotProductAggregator(
        in_channels=256, out_channels=64, with_fc=True)
    query = torch.randn(2, 256, 1, 1)
    support = torch.randn(1, 256, 1, 1)
    result = aggregator(query, support)
    assert tuple(result.shape) == (2, 64)
def test_aggregation_layer():
    """Check AggregationLayer returns one output per configured aggregator."""
    # Three aggregators are configured, all on 256-channel features.
    aggregator_cfgs = [
        dict(type='DepthWiseCorrelationAggregator', in_channels=256),
        dict(type='DifferenceAggregator', in_channels=256),
        dict(type='DotProductAggregator', in_channels=256),
    ]
    cfg = ConfigDict(aggregator_cfgs=aggregator_cfgs)
    layer = AggregationLayer(**cfg)

    query = torch.randn(2, 256, 1, 1)
    support = torch.randn(1, 256, 1, 1)
    outputs = layer(query, support)

    # One entry per configured aggregator, each expected to be (2, 256, 1, 1).
    assert len(outputs) == 3
    expected_shape = (2, 256, 1, 1)
    for idx in range(3):
        assert tuple(outputs[idx].shape) == expected_shape