mmpretrain/tests/test_models/test_utils/test_position_encoding.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmpretrain.models.utils import (ConditionalPositionEncoding,
                                     RotaryEmbeddingFast)


def test_conditional_position_encoding_module():
    CPE = ConditionalPositionEncoding(in_channels=32, embed_dims=32, stride=2)
    # 3136 tokens correspond to a 56x56 patch grid; with stride=2 the
    # encoding downsamples the grid to 28x28 = 784 tokens.
    outs = CPE(torch.randn(1, 3136, 32), (56, 56))
    assert outs.shape == torch.Size([1, 784, 32])


def test_rotary_embedding_fast_module():
    # Square patch resolution: the rotary embedding must preserve the
    # (batch, heads, tokens, dims) shape of its input.
    RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=24)
    outs = RoPE(torch.randn(1, 2, 24 * 24, 64), (24, 24))
    assert outs.shape == torch.Size([1, 2, 24 * 24, 64])

    # Non-square patch resolutions, e.g. 14x20, must also be handled.
    RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=(14, 20))
    outs = RoPE(torch.randn(1, 2, 14 * 20, 64), (14, 20))
    assert outs.shape == torch.Size([1, 2, 14 * 20, 64])
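
# For context, the sketch below illustrates the plain 1D rotary position
# embedding idea that the second test exercises: positions are encoded by
# rotating pairs of feature channels, so the output shape always matches the
# input shape. This is NOT mmpretrain's RotaryEmbeddingFast (which works on a
# 2D patch grid); `rope_1d` and its internals are illustrative assumptions.
#
# import torch
#
# def rope_1d(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
#     """Apply standard 1D RoPE to x of shape (batch, heads, tokens, dims)."""
#     _, _, num_tokens, dims = x.shape
#     half = dims // 2
#     # Per-channel rotation frequencies: base ** (-2i / dims).
#     inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
#     pos = torch.arange(num_tokens, dtype=torch.float32)
#     angles = torch.einsum('n,d->nd', pos, inv_freq)  # (tokens, dims/2)
#     sin = torch.cat([angles.sin(), angles.sin()], dim=-1)  # (tokens, dims)
#     cos = torch.cat([angles.cos(), angles.cos()], dim=-1)
#     x1, x2 = x[..., :half], x[..., half:]
#     rotated = torch.cat([-x2, x1], dim=-1)  # "rotate half" trick
#     return x * cos + rotated * sin
#
# out = rope_1d(torch.randn(1, 2, 24 * 24, 64))
# assert out.shape == torch.Size([1, 2, 24 * 24, 64])  # shape is preserved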