[Feature] Add with cp to mit and vit (#1431)
* add with cp to mit and vit * add test unit Co-authored-by: jiangyitong <jiangyitong1@sensetime.com>pull/1447/head
parent
17f8a96981
commit
be8f073c84
|
@ -4,6 +4,7 @@ import warnings
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
||||
|
@ -235,6 +236,8 @@ class TransformerEncoderLayer(BaseModule):
|
|||
Default:None.
|
||||
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
|
||||
Attention of Segformer. Default: 1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -248,7 +251,8 @@ class TransformerEncoderLayer(BaseModule):
|
|||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True,
|
||||
sr_ratio=1):
|
||||
sr_ratio=1,
|
||||
with_cp=False):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
|
@ -275,9 +279,19 @@ class TransformerEncoderLayer(BaseModule):
|
|||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.with_cp = with_cp
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
x = self.attn(self.norm1(x), hw_shape, identity=x)
|
||||
x = self.ffn(self.norm2(x), hw_shape, identity=x)
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.attn(self.norm1(x), hw_shape, identity=x)
|
||||
x = self.ffn(self.norm2(x), hw_shape, identity=x)
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -319,6 +333,8 @@ class MixVisionTransformer(BaseModule):
|
|||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -339,7 +355,8 @@ class MixVisionTransformer(BaseModule):
|
|||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
init_cfg=None,
|
||||
with_cp=False):
|
||||
super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
|
@ -358,8 +375,9 @@ class MixVisionTransformer(BaseModule):
|
|||
self.patch_sizes = patch_sizes
|
||||
self.strides = strides
|
||||
self.sr_ratios = sr_ratios
|
||||
self.with_cp = with_cp
|
||||
assert num_stages == len(num_layers) == len(num_heads) \
|
||||
== len(patch_sizes) == len(strides) == len(sr_ratios)
|
||||
== len(patch_sizes) == len(strides) == len(sr_ratios)
|
||||
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < self.num_stages
|
||||
|
@ -392,6 +410,7 @@ class MixVisionTransformer(BaseModule):
|
|||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
|
||||
])
|
||||
in_channels = embed_dims_i
|
||||
|
|
|
@ -4,6 +4,7 @@ import warnings
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
||||
|
@ -41,6 +42,8 @@ class TransformerEncoderLayer(BaseModule):
|
|||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: True.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -54,7 +57,8 @@ class TransformerEncoderLayer(BaseModule):
|
|||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True):
|
||||
batch_first=True,
|
||||
with_cp=False):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
|
@ -82,6 +86,8 @@ class TransformerEncoderLayer(BaseModule):
|
|||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.with_cp = with_cp
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
@ -91,8 +97,16 @@ class TransformerEncoderLayer(BaseModule):
|
|||
return getattr(self, self.norm2_name)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.attn(self.norm1(x), identity=x)
|
||||
x = self.ffn(self.norm2(x), identity=x)
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.attn(self.norm1(x), identity=x)
|
||||
x = self.ffn(self.norm2(x), identity=x)
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -251,6 +265,7 @@ class VisionTransformer(BaseModule):
|
|||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
batch_first=True))
|
||||
|
||||
self.final_norm = final_norm
|
||||
|
|
|
@ -3,7 +3,8 @@ import pytest
|
|||
import torch
|
||||
|
||||
from mmseg.models.backbones import MixVisionTransformer
|
||||
from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
|
||||
from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN,
|
||||
TransformerEncoderLayer)
|
||||
|
||||
|
||||
def test_mit():
|
||||
|
@ -56,6 +57,14 @@ def test_mit():
|
|||
outs = MHA(temp, hw_shape, temp)
|
||||
assert out.shape == (1, token_len, 64)
|
||||
|
||||
# Test TransformerEncoderLayer with checkpoint forward
|
||||
block = TransformerEncoderLayer(
|
||||
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 56 * 56, 64)
|
||||
x_out = block(x, (56, 56))
|
||||
assert x_out.shape == torch.Size([1, 56 * 56, 64])
|
||||
|
||||
|
||||
def test_mit_init():
|
||||
path = 'PATH_THAT_DO_NOT_EXIST'
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones.vit import VisionTransformer
|
||||
from mmseg.models.backbones.vit import (TransformerEncoderLayer,
|
||||
VisionTransformer)
|
||||
from .utils import check_norm_state
|
||||
|
||||
|
||||
|
@ -119,6 +120,14 @@ def test_vit_backbone():
|
|||
assert feat[0][0].shape == (1, 768, 14, 14)
|
||||
assert feat[0][1].shape == (1, 768)
|
||||
|
||||
# Test TransformerEncoderLayer with checkpoint forward
|
||||
block = TransformerEncoderLayer(
|
||||
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 56 * 56, 64)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 56 * 56, 64])
|
||||
|
||||
|
||||
def test_vit_init():
|
||||
path = 'PATH_THAT_DO_NOT_EXIST'
|
||||
|
|
Loading…
Reference in New Issue