mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
* add twins backbone
* add position_encoding
* refactor twins
* Supplemental unit tests
* update docstring and readme
* update docstring and readme
* update docstring and readme
* update docstring
* update docstring
* update docstring
* update docstring
* remove note
* update doc and docstring
* update docstring
* update docstring
* use abstract pdf link and rename yamlfile
* Update model link

Co-authored-by: mzr1996 <mzr1996@163.com>
11 lines
352 B
Python
11 lines
352 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmcls.models.utils import ConditionalPositionEncoding
|
|
|
|
|
|
def test_conditional_position_encoding_module():
    """Check the output shape produced by ``ConditionalPositionEncoding``.

    The module is built with 32 input channels, 32 embedding dims and
    ``stride=2``, then called once with a random ``(1, 3136, 32)`` tensor
    and the tuple ``(56, 56)``. The result must have shape
    ``(1, 784, 32)``.
    """
    # Build the module first so that parameter initialisation happens
    # before the random input is drawn.
    encoder = ConditionalPositionEncoding(
        in_channels=32,
        embed_dims=32,
        stride=2,
    )

    # 3136 == 56 * 56; the second argument is presumably the (H, W) grid
    # the flattened sequence came from -- confirm against the module.
    tokens = torch.randn(1, 3136, 32)
    grid_size = (56, 56)
    encoded = encoder(tokens, grid_size)

    # 784 == 28 * 28, i.e. a quarter of the input sequence length.
    expected_shape = torch.Size([1, 784, 32])
    assert encoded.shape == expected_shape
|