mmclassification/tests/test_models/test_utils/test_norm.py

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
import torch.nn.functional as F

from mmcls.models.utils import GRN, LayerNorm2d


class TestGRN(TestCase):
    def test_init(self):
        module = GRN(in_channels=32, eps=1e-3)
        self.assertEqual(module.in_channels, 32)
        self.assertEqual(module.eps, 1e-3)
        self.assertTrue(module.gamma.requires_grad)
        self.assertTrue(module.beta.requires_grad)
        self.assertEqual(module.gamma.shape, (32, ))
        self.assertEqual(module.beta.shape, (32, ))

    def test_forward(self):
        module = GRN(in_channels=32, eps=1e-3)
        # Channel-last input of shape (N, H, W, C).
        input_ = torch.rand(1, 28, 28, 32)
        # Global response normalization: take the spatial L2 norm of each
        # channel, then normalize it by the mean norm across channels.
        gx = torch.norm(input_, p=2, dim=(1, 2), keepdim=True)
        nx = gx / (gx.mean(dim=3, keepdim=True) + 1e-3)
        expected_out = module.gamma * input_ * nx + module.beta + input_
        torch.testing.assert_allclose(
            module(input_, data_format='channel_last'), expected_out)

        # The channel-first path should match after permuting to (N, C, H, W).
        input_ = input_.permute([0, 3, 1, 2])
        expected_out = expected_out.permute([0, 3, 1, 2])
        torch.testing.assert_allclose(
            module(input_, data_format='channel_first'), expected_out)


class TestLayerNorm2d(TestCase):

    def test_init(self):
        module = LayerNorm2d(num_channels=32, eps=1e-3)
        self.assertEqual(module.num_channels, 32)
        self.assertEqual(module.eps, 1e-3)
        self.assertTrue(module.weight.requires_grad)
        self.assertTrue(module.bias.requires_grad)
        self.assertEqual(module.weight.shape, (32, ))
        self.assertEqual(module.bias.shape, (32, ))

    def test_forward(self):
        module = LayerNorm2d(num_channels=32, eps=1e-3)
        # Channel-last input of shape (N, H, W, C); the module should match
        # torch.nn.functional.layer_norm applied over the channel dimension.
        input_ = torch.rand(1, 28, 28, 32)
        expected_out = F.layer_norm(input_, module.normalized_shape,
                                    module.weight, module.bias, 1e-3)
        torch.testing.assert_allclose(
            module(input_, data_format='channel_last'), expected_out)

        # The channel-first path should match after permuting to (N, C, H, W).
        input_ = input_.permute([0, 3, 1, 2])
        expected_out = expected_out.permute([0, 3, 1, 2])
        torch.testing.assert_allclose(
            module(input_, data_format='channel_first'), expected_out)