[Feature] Add with cp to mit and vit (#1431)
* add with cp to mit and vit
* add test unit

Co-authored-by: jiangyitong <jiangyitong1@sensetime.com>

parent 17f8a96981
commit be8f073c84
mmseg/models/backbones/mit.py

@@ -4,6 +4,7 @@ import warnings
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint as cp
 from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
@@ -235,6 +236,8 @@ class TransformerEncoderLayer(BaseModule):
             Default:None.
         sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
             Attention of Segformer. Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed. Default: False.
     """

     def __init__(self,
@@ -248,7 +251,8 @@ class TransformerEncoderLayer(BaseModule):
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN'),
                  batch_first=True,
-                 sr_ratio=1):
+                 sr_ratio=1,
+                 with_cp=False):
         super(TransformerEncoderLayer, self).__init__()

         # The ret[0] of build_norm_layer is norm name.
@@ -275,11 +279,21 @@ class TransformerEncoderLayer(BaseModule):
             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
             act_cfg=act_cfg)

+        self.with_cp = with_cp
+
     def forward(self, x, hw_shape):
-        x = self.attn(self.norm1(x), hw_shape, identity=x)
-        x = self.ffn(self.norm2(x), hw_shape, identity=x)
-        return x
+
+        def _inner_forward(x):
+            x = self.attn(self.norm1(x), hw_shape, identity=x)
+            x = self.ffn(self.norm2(x), hw_shape, identity=x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
+        return x


 @BACKBONES.register_module()
 class MixVisionTransformer(BaseModule):
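The hunk above is the standard gradient-checkpointing pattern: the layer body is wrapped in an inner closure, and when `with_cp` is enabled (and gradients are needed) the closure is re-run during the backward pass instead of keeping its intermediate activations in memory. A minimal standalone sketch of the same pattern, for reference only; `SimpleBlock` and its sizes are illustrative and not part of this PR:

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class SimpleBlock(nn.Module):
    """Illustrative residual block using the same with_cp pattern as the PR."""

    def __init__(self, dim, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                 nn.Linear(4 * dim, dim))

    def forward(self, x):

        def _inner_forward(x):
            # Activations created in here are recomputed during backward when
            # checkpointing is active, instead of being stored for the whole
            # forward pass.
            return x + self.mlp(self.norm(x))

        if self.with_cp and x.requires_grad:
            # Only checkpoint when gradients are needed; otherwise there is
            # nothing to recompute and cp.checkpoint would only add overhead.
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)


block = SimpleBlock(dim=64, with_cp=True)
x = torch.randn(2, 196, 64, requires_grad=True)
block(x).sum().backward()  # _inner_forward is re-executed here to get grads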
@@ -319,6 +333,8 @@ class MixVisionTransformer(BaseModule):
         pretrained (str, optional): model pretrained path. Default: None.
         init_cfg (dict or list[dict], optional): Initialization config dict.
             Default: None.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed. Default: False.
     """

     def __init__(self,
@@ -339,7 +355,8 @@ class MixVisionTransformer(BaseModule):
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
-                 init_cfg=None):
+                 init_cfg=None,
+                 with_cp=False):
         super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)

         assert not (init_cfg and pretrained), \
@@ -358,6 +375,7 @@ class MixVisionTransformer(BaseModule):
         self.patch_sizes = patch_sizes
         self.strides = strides
         self.sr_ratios = sr_ratios
+        self.with_cp = with_cp
         assert num_stages == len(num_layers) == len(num_heads) \
             == len(patch_sizes) == len(strides) == len(sr_ratios)

@@ -392,6 +410,7 @@ class MixVisionTransformer(BaseModule):
                     qkv_bias=qkv_bias,
                     act_cfg=act_cfg,
                     norm_cfg=norm_cfg,
+                    with_cp=with_cp,
                     sr_ratio=sr_ratios[i]) for idx in range(num_layer)
             ])
             in_channels = embed_dims_i
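At the backbone level the flag is simply stored and forwarded to every TransformerEncoderLayer it builds. A rough usage sketch, assuming the backbone's default constructor arguments (input size and loss are arbitrary, not part of the PR):

import torch

from mmseg.models.backbones import MixVisionTransformer

# with_cp is threaded from the backbone down to each encoder layer.
backbone = MixVisionTransformer(with_cp=True)
backbone.train()

imgs = torch.randn(1, 3, 128, 128)
feats = backbone(imgs)              # multi-scale feature maps, one per stage
loss = sum(f.sum() for f in feats)
loss.backward()                     # layer bodies are recomputed here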
mmseg/models/backbones/vit.py

@@ -4,6 +4,7 @@ import warnings
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint as cp
 from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
@@ -41,6 +42,8 @@ class TransformerEncoderLayer(BaseModule):
         batch_first (bool): Key, Query and Value are shape of
             (batch, n, embed_dim)
             or (n, batch, embed_dim). Default: True.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed. Default: False.
     """

     def __init__(self,
@@ -54,7 +57,8 @@ class TransformerEncoderLayer(BaseModule):
                  qkv_bias=True,
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN'),
-                 batch_first=True):
+                 batch_first=True,
+                 with_cp=False):
         super(TransformerEncoderLayer, self).__init__()

         self.norm1_name, norm1 = build_norm_layer(
@@ -82,6 +86,8 @@ class TransformerEncoderLayer(BaseModule):
             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
             act_cfg=act_cfg)

+        self.with_cp = with_cp
+
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
@@ -91,10 +97,18 @@ class TransformerEncoderLayer(BaseModule):
         return getattr(self, self.norm2_name)

     def forward(self, x):
-        x = self.attn(self.norm1(x), identity=x)
-        x = self.ffn(self.norm2(x), identity=x)
-        return x
+
+        def _inner_forward(x):
+            x = self.attn(self.norm1(x), identity=x)
+            x = self.ffn(self.norm2(x), identity=x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
+        return x


 @BACKBONES.register_module()
 class VisionTransformer(BaseModule):
@@ -251,6 +265,7 @@ class VisionTransformer(BaseModule):
                     qkv_bias=qkv_bias,
                     act_cfg=act_cfg,
                     norm_cfg=norm_cfg,
+                    with_cp=with_cp,
                     batch_first=True))

         self.final_norm = final_norm
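Both forward hunks guard the checkpoint call with `x.requires_grad`, so inference under `torch.no_grad()` bypasses recomputation entirely. A small illustration of that guard, using a layer constructed the same way as in the new unit tests (the surrounding calls are illustrative, not part of the PR):

import torch

from mmseg.models.backbones.vit import TransformerEncoderLayer

block = TransformerEncoderLayer(
    embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
x = torch.randn(1, 56 * 56, 64)

# Training-style call: x arrives from earlier layers with requires_grad=True,
# so the checkpointed branch runs and activations are recomputed in backward.
out_train = block(x.clone().requires_grad_(True))

# Inference: nothing requires grad, the guard skips cp.checkpoint and the
# plain forward runs, avoiding pointless recomputation (and the warning torch
# would otherwise emit about no inputs requiring grad).
with torch.no_grad():
    out_eval = block(x)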
tests/test_models/test_backbones/test_mit.py

@@ -3,7 +3,8 @@ import pytest
 import torch

 from mmseg.models.backbones import MixVisionTransformer
-from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
+from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN,
+                                        TransformerEncoderLayer)


 def test_mit():
@@ -56,6 +57,14 @@ def test_mit():
     outs = MHA(temp, hw_shape, temp)
     assert out.shape == (1, token_len, 64)

+    # Test TransformerEncoderLayer with checkpoint forward
+    block = TransformerEncoderLayer(
+        embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 56 * 56, 64)
+    x_out = block(x, (56, 56))
+    assert x_out.shape == torch.Size([1, 56 * 56, 64])
+

 def test_mit_init():
     path = 'PATH_THAT_DO_NOT_EXIST'
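The new unit tests only verify output shapes with the checkpointed forward. A rough, hypothetical way to check the actual memory trade-off on a GPU; the batch size, input resolution, and helper below are arbitrary and not part of the PR:

import torch

from mmseg.models.backbones import MixVisionTransformer


def peak_memory_mb(with_cp):
    """Peak CUDA memory (MB) for one forward/backward pass."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    model = MixVisionTransformer(with_cp=with_cp).cuda().train()
    imgs = torch.randn(2, 3, 512, 512, device='cuda')
    sum(f.sum() for f in model(imgs)).backward()
    return torch.cuda.max_memory_allocated() / 1024 ** 2


if torch.cuda.is_available():
    print('with_cp=False:', peak_memory_mb(False))
    print('with_cp=True: ', peak_memory_mb(True))  # expected to be lower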
tests/test_models/test_backbones/test_vit.py

@@ -2,7 +2,8 @@
 import pytest
 import torch

-from mmseg.models.backbones.vit import VisionTransformer
+from mmseg.models.backbones.vit import (TransformerEncoderLayer,
+                                        VisionTransformer)
 from .utils import check_norm_state


@@ -119,6 +120,14 @@ def test_vit_backbone():
     assert feat[0][0].shape == (1, 768, 14, 14)
     assert feat[0][1].shape == (1, 768)

+    # Test TransformerEncoderLayer with checkpoint forward
+    block = TransformerEncoderLayer(
+        embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 56 * 56, 64)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 56 * 56, 64])
+

 def test_vit_init():
     path = 'PATH_THAT_DO_NOT_EXIST'
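In a training config the new flag would be toggled on the backbone dict and threaded down by the constructors modified in this PR. A hedged sketch, assuming an existing Segformer config is being extended; the file names below are hypothetical:

# configs/segformer/segformer_mit-b0_with_cp.py  (hypothetical file)
_base_ = ['./segformer_mit-b0_8x1_1024x1024_160k_cityscapes.py']  # assumed base config

# Only the backbone entry changes; everything else is inherited.
model = dict(backbone=dict(with_cp=True))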