mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
29 lines
483 B
Python
29 lines
483 B
Python
|
import torch
|
||
|
|
||
|
from mmseg.models.utils import DropPath
|
||
|
|
||
|
|
||
|
def test_drop_path():
    """Check that ``DropPath`` runs on 3-D and 4-D inputs and preserves shape.

    Two configurations are exercised:

    * ``DropPath()`` -- the default constructor, which this test treats as
      "zero drop" (as the original comment states). With a drop probability
      of 0 the layer should return its input unchanged.
    * ``DropPath(0.1)`` -- a non-zero drop probability. The output is random,
      so only the shape is checked.
    """
    # Inputs exercised: an NLC-format token feature (batch, length, channels)
    # and a 4-D feature map (batch, channels, height, width).
    shapes = [(1, 16, 32), (1, 32, 4, 4)]

    # zero drop
    # NOTE(review): the identity check assumes DropPath() defaults to a drop
    # probability of 0, as the "# zero drop" comment claims -- confirm
    # against mmseg.models.utils.DropPath if this assertion ever fails.
    layer = DropPath()
    for shape in shapes:
        x = torch.randn(shape)
        out = layer(x)
        assert out.shape == x.shape
        assert torch.equal(out, x)

    # non-zero drop: samples may be dropped/rescaled at random, so only the
    # shape is deterministic enough to assert on.
    layer = DropPath(0.1)
    for shape in shapes:
        x = torch.randn(shape)
        out = layer(x)
        assert out.shape == x.shape