Mirror of https://github.com/open-mmlab/mmpretrain.git
* [CI] Add test mim CI. (#879)
* feat: add eva02 backbone
* update ci
* update readme, configs and metafile
* refactor eva02
* rename eva02 and its configs
* fix unit tests

Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
22 lines
831 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmpretrain.models.utils import (ConditionalPositionEncoding,
                                     RotaryEmbeddingFast)


def test_conditional_position_encoding_module():
    # stride=2 halves the 56x56 patch grid, so 3136 input tokens become 784.
    CPE = ConditionalPositionEncoding(in_channels=32, embed_dims=32, stride=2)
    outs = CPE(torch.randn(1, 3136, 32), (56, 56))
    assert outs.shape == torch.Size([1, 784, 32])


def test_rotary_embedding_fast_module():
    # Square patch resolution given as a single int.
    RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=24)
    outs = RoPE(torch.randn(1, 2, 24 * 24, 64), (24, 24))
    assert outs.shape == torch.Size([1, 2, 24 * 24, 64])

    # Non-square patch resolution given as a tuple.
    RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=(14, 20))
    outs = RoPE(torch.randn(1, 2, 14 * 20, 64), (14, 20))
    assert outs.shape == torch.Size([1, 2, 14 * 20, 64])
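A minimal interactive sketch of the same two modules, using only the calls exercised by the tests above (it assumes mmpretrain and torch are importable). It makes the shape arithmetic explicit: a 56x56 patch grid gives 3136 tokens and stride=2 downsamples the grid to 28x28 = 784 tokens, while RotaryEmbeddingFast returns a tensor with the same shape as its input.

import torch

from mmpretrain.models.utils import (ConditionalPositionEncoding,
                                     RotaryEmbeddingFast)

# 56 x 56 patch grid -> 3136 tokens; stride=2 downsamples to 28 x 28 = 784.
cpe = ConditionalPositionEncoding(in_channels=32, embed_dims=32, stride=2)
tokens = torch.randn(1, 56 * 56, 32)
print(cpe(tokens, (56, 56)).shape)  # torch.Size([1, 784, 32])

# Rotary embeddings preserve the input shape; the last dim is embed_dims.
rope = RotaryEmbeddingFast(embed_dims=64, patch_resolution=(14, 20))
x = torch.randn(1, 2, 14 * 20, 64)
print(rope(x, (14, 20)).shape)  # torch.Size([1, 2, 280, 64])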