takuoko 1047daa28e
[Feature] Support HorNet Backbone. (#1013)
* add hornet

* add hornet

* add hornet

* add hornet

* add hornet

* add hornet

* add hornet

* fix test for torch before 1.7.0

* del timm

* fix readme

* fix readme

* Update mmcls/models/backbones/hornet.py

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>

* fix docs

* fix docs

* s -> scale

* fix dims and dpr impl

* fix layer scale

* refactor gnconv

* add dw_cfg

* add convert tools

* update code

* update docs

* update readme

* update URLs

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2022-09-27 10:37:49 +08:00

49 lines
1.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmcls.models.utils import LayerScale
class TestLayerScale(TestCase):
    """Unit tests for ``LayerScale`` construction and forward scaling."""

    def test_init(self):
        """Check ``data_format`` validation and the initial weight value."""
        # Constructing with data_format='BNC' is expected to fail with an
        # AssertionError whose message mentions 'data_format'.
        cfg = dict(
            dim=10,
            inplace=False,
            data_format='BNC',
        )
        with self.assertRaisesRegex(AssertionError, "'data_format' could"):
            LayerScale(**cfg)

        # With only ``dim`` given, the weight is expected to hold ``dim``
        # entries, each equal to 1e-5.
        cfg = dict(dim=10)
        ls = LayerScale(**cfg)
        self.assertTrue(
            torch.equal(ls.weight,
                        torch.ones(10, requires_grad=True) * 1e-5))

    def test_forward(self):
        """Check output shape and values for each layout and in-place mode.

        The method carries the ``test_`` prefix so that unittest/pytest
        discovery actually collects and runs it.
        """
        # Test channels_last
        cfg = dict(dim=256, inplace=False, data_format='channels_last')
        ls_channels_last = LayerScale(**cfg)
        x = torch.randn((4, 49, 256))
        out = ls_channels_last(x)
        self.assertEqual(tuple(out.size()), (4, 49, 256))
        self.assertTrue(torch.equal(x * 1e-5, out))

        # Test channels_first
        cfg = dict(dim=256, inplace=False, data_format='channels_first')
        ls_channels_first = LayerScale(**cfg)
        x = torch.randn((4, 256, 7, 7))
        out = ls_channels_first(x)
        self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
        self.assertTrue(torch.equal(x * 1e-5, out))

        # Test inplace True: the module must hand back the very same tensor
        # object it was given rather than a new one.
        cfg = dict(dim=256, inplace=True, data_format='channels_first')
        ls_channels_first = LayerScale(**cfg)
        x = torch.randn((4, 256, 7, 7))
        out = ls_channels_first(x)
        self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
        self.assertIs(x, out)