61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from mmcls.models.utils import GRN, LayerNorm2d
|
|
|
|
|
|
class TestGRN(TestCase):
    """Unit tests for the ``GRN`` module from ``mmcls.models.utils``."""

    def test_init(self):
        """Check stored constructor args and the two learnable params."""
        module = GRN(in_channels=32, eps=1e-3)

        self.assertEqual(module.in_channels, 32)
        self.assertEqual(module.eps, 1e-3)

        self.assertTrue(module.gamma.requires_grad)
        self.assertTrue(module.beta.requires_grad)

        self.assertEqual(module.gamma.shape, (32, ))
        # ``assertTrue(expr, msg)`` would treat ``(32, )`` as the failure
        # message and pass for any truthy shape, so compare explicitly.
        self.assertEqual(module.beta.shape, (32, ))

    def test_forward(self):
        """Compare the module output with a hand-written reference.

        The reference is computed on a channel-last tensor and then
        permuted to check the channel-first code path on the same data.
        """
        eps = 1e-3
        module = GRN(in_channels=32, eps=eps)
        input_ = torch.rand(1, 28, 28, 32)

        # Reference: L2 norm over dims 1 and 2, divided by its mean over
        # the last dim (plus eps), then applied as a scaled residual.
        gx = torch.norm(input_, p=2, dim=(1, 2), keepdim=True)
        nx = gx / (gx.mean(dim=3, keepdim=True) + eps)
        expected_out = module.gamma * input_ * nx + module.beta + input_

        # NOTE(review): ``assert_allclose`` is reported as deprecated in
        # newer PyTorch in favour of ``assert_close`` - confirm against
        # the minimum torch version this project supports before
        # switching.
        torch.testing.assert_allclose(
            module(input_, data_format='channel_last'), expected_out)

        # Move the last dim to position 1 for the channel-first call; the
        # expected output is permuted the same way.
        input_ = input_.permute([0, 3, 1, 2])
        expected_out = expected_out.permute([0, 3, 1, 2])
        torch.testing.assert_allclose(
            module(input_, data_format='channel_first'), expected_out)
|
|
|
|
|
|
class TestLayerNorm2d(TestCase):
    """Unit tests for ``LayerNorm2d`` from ``mmcls.models.utils``."""

    def test_init(self):
        """Check stored constructor args and the affine parameters."""
        module = LayerNorm2d(num_channels=32, eps=1e-3)

        self.assertEqual(module.num_channels, 32)
        self.assertEqual(module.eps, 1e-3)

        self.assertTrue(module.weight.requires_grad)
        self.assertTrue(module.bias.requires_grad)

        self.assertEqual(module.weight.shape, (32, ))
        # ``assertTrue(expr, msg)`` would treat ``(32, )`` as the failure
        # message and pass for any truthy shape, so compare explicitly.
        self.assertEqual(module.bias.shape, (32, ))

    def test_forward(self):
        """Compare the module output with ``F.layer_norm``.

        The reference is computed on a channel-last tensor and then
        permuted to check the channel-first code path on the same data.
        """
        eps = 1e-3
        module = LayerNorm2d(num_channels=32, eps=eps)
        input_ = torch.rand(1, 28, 28, 32)

        # Reference uses the module's own normalized_shape, weight and
        # bias, so only the normalisation arithmetic is being compared.
        expected_out = F.layer_norm(input_, module.normalized_shape,
                                    module.weight, module.bias, eps)

        # NOTE(review): ``assert_allclose`` is reported as deprecated in
        # newer PyTorch in favour of ``assert_close`` - confirm against
        # the minimum torch version this project supports before
        # switching.
        torch.testing.assert_allclose(
            module(input_, data_format='channel_last'), expected_out)

        # Move the last dim to position 1 for the channel-first call; the
        # expected output is permuted the same way.
        input_ = input_.permute([0, 3, 1, 2])
        expected_out = expected_out.permute([0, 3, 1, 2])
        torch.testing.assert_allclose(
            module(input_, data_format='channel_first'), expected_out)
|