mirrors/mmsegmentation
mirror of https://github.com/open-mmlab/mmsegmentation.git synced 2025-06-03 22:03:48 +08:00
mmsegmentation/tests/test_models/test_utils/test_drop.py

29 lines
483 B
Python

Adjust vision transformer backbone architectures (#524)

* Adjust vision transformer backbone architectures;
* Add DropPath, trunc_normal_ for VisionTransformer implementation;
* Add class token during intermediate period and remove it during final period;
* Fix some parameters loss bug;
* Store intermediate token features and impose no processes on them;
* Remove class token and reshape entire token feature from NLC to NCHW;
* Fix some doc error
* Add an arg for VisionTransformer backbone to control whether the class token is input into the transformer;
* Add stochastic depth decay rule for DropPath;
* Fix output bug when input_cls_token=False;
* Add related unit test;
* Add arg: out_indices to control model output;
* Add unit test for DropPath;
* Apply suggestions from code review

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
2021-05-01 01:37:47 +08:00
import torch

from mmseg.models.utils import DropPath


def test_drop_path():
    # zero drop (default drop probability)
    layer = DropPath()

    # input NLC format feature
    x = torch.randn((1, 16, 32))
    layer(x)

    # input NCHW format feature
    x = torch.randn((1, 32, 4, 4))
    layer(x)

    # non-zero drop
    layer = DropPath(0.1)

    # input NLC format feature
    x = torch.randn((1, 16, 32))
    layer(x)

    # input NCHW format feature
    x = torch.randn((1, 32, 4, 4))
    layer(x)
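
The test above only checks that `DropPath` accepts NLC and NCHW inputs without raising an error. For readers unfamiliar with the technique, the following is a minimal sketch of a stochastic-depth layer. It is not the `mmseg` implementation: the class name `DropPathSketch` and the helper `check_shapes` are hypothetical, and the per-sample mask with rescaling by the keep probability is an assumption based on how stochastic depth is commonly written.

```python
import torch
from torch import nn


class DropPathSketch(nn.Module):
    """Hypothetical stand-in for DropPath (not the mmseg class).

    During training, each sample in the batch is zeroed with probability
    `drop_prob`; surviving samples are rescaled by 1 / (1 - drop_prob).
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        if not 0.0 <= drop_prob < 1.0:
            raise ValueError("drop_prob must be in [0, 1)")
        self.drop_prob = drop_prob
        self.keep_prob = 1.0 - drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity when nothing is dropped or when the module is in eval mode.
        if self.drop_prob == 0.0 or not self.training:
            return x
        # One Bernoulli draw per sample, broadcast over all remaining dims,
        # so the same code handles NLC (3-D) and NCHW (4-D) inputs.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = x.new_empty(mask_shape).bernoulli_(self.keep_prob)
        # Rescale kept samples so the expected value matches the input.
        return x / self.keep_prob * mask


def check_shapes(layer: nn.Module) -> bool:
    """Return True if the layer preserves the NLC and NCHW shapes used in the test."""
    for shape in [(1, 16, 32), (1, 32, 4, 4)]:
        x = torch.randn(shape)
        if layer(x).shape != x.shape:
            return False
    return True
```

Under these assumptions, a stricter variant of the unit test could assert that the output shape equals the input shape and that the layer acts as the identity in eval mode, rather than only calling `layer(x)`.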