[Feature] Add with cp to mit and vit (#1431)

* add with cp to mit and vit

* add test unit

Co-authored-by: jiangyitong <jiangyitong1@sensetime.com>
pull/1447/head
jiangyitong 2022-04-01 21:01:45 +08:00 committed by GitHub
parent 17f8a96981
commit be8f073c84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 10 deletions

View File

@ -4,6 +4,7 @@ import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
@ -235,6 +236,8 @@ class TransformerEncoderLayer(BaseModule):
Default:None.
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
Attention of Segformer. Default: 1.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
@ -248,7 +251,8 @@ class TransformerEncoderLayer(BaseModule):
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True,
sr_ratio=1):
sr_ratio=1,
with_cp=False):
super(TransformerEncoderLayer, self).__init__()
# The ret[0] of build_norm_layer is norm name.
@ -275,9 +279,19 @@ class TransformerEncoderLayer(BaseModule):
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
self.with_cp = with_cp
def forward(self, x, hw_shape):
x = self.attn(self.norm1(x), hw_shape, identity=x)
x = self.ffn(self.norm2(x), hw_shape, identity=x)
def _inner_forward(x):
x = self.attn(self.norm1(x), hw_shape, identity=x)
x = self.ffn(self.norm2(x), hw_shape, identity=x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
@ -319,6 +333,8 @@ class MixVisionTransformer(BaseModule):
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
@ -339,7 +355,8 @@ class MixVisionTransformer(BaseModule):
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN', eps=1e-6),
pretrained=None,
init_cfg=None):
init_cfg=None,
with_cp=False):
super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \
@ -358,8 +375,9 @@ class MixVisionTransformer(BaseModule):
self.patch_sizes = patch_sizes
self.strides = strides
self.sr_ratios = sr_ratios
self.with_cp = with_cp
assert num_stages == len(num_layers) == len(num_heads) \
== len(patch_sizes) == len(strides) == len(sr_ratios)
== len(patch_sizes) == len(strides) == len(sr_ratios)
self.out_indices = out_indices
assert max(out_indices) < self.num_stages
@ -392,6 +410,7 @@ class MixVisionTransformer(BaseModule):
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
])
in_channels = embed_dims_i

View File

@ -4,6 +4,7 @@ import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
@ -41,6 +42,8 @@ class TransformerEncoderLayer(BaseModule):
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default: True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
@ -54,7 +57,8 @@ class TransformerEncoderLayer(BaseModule):
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True):
batch_first=True,
with_cp=False):
super(TransformerEncoderLayer, self).__init__()
self.norm1_name, norm1 = build_norm_layer(
@ -82,6 +86,8 @@ class TransformerEncoderLayer(BaseModule):
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
self.with_cp = with_cp
@property
def norm1(self):
return getattr(self, self.norm1_name)
@ -91,8 +97,16 @@ class TransformerEncoderLayer(BaseModule):
return getattr(self, self.norm2_name)
def forward(self, x):
x = self.attn(self.norm1(x), identity=x)
x = self.ffn(self.norm2(x), identity=x)
def _inner_forward(x):
x = self.attn(self.norm1(x), identity=x)
x = self.ffn(self.norm2(x), identity=x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
@ -251,6 +265,7 @@ class VisionTransformer(BaseModule):
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
batch_first=True))
self.final_norm = final_norm

View File

@ -3,7 +3,8 @@ import pytest
import torch
from mmseg.models.backbones import MixVisionTransformer
from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN,
TransformerEncoderLayer)
def test_mit():
@ -56,6 +57,14 @@ def test_mit():
outs = MHA(temp, hw_shape, temp)
assert out.shape == (1, token_len, 64)
# Test TransformerEncoderLayer with checkpoint forward
block = TransformerEncoderLayer(
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
assert block.with_cp
x = torch.randn(1, 56 * 56, 64)
x_out = block(x, (56, 56))
assert x_out.shape == torch.Size([1, 56 * 56, 64])
def test_mit_init():
path = 'PATH_THAT_DO_NOT_EXIST'

View File

@ -2,7 +2,8 @@
import pytest
import torch
from mmseg.models.backbones.vit import VisionTransformer
from mmseg.models.backbones.vit import (TransformerEncoderLayer,
VisionTransformer)
from .utils import check_norm_state
@ -119,6 +120,14 @@ def test_vit_backbone():
assert feat[0][0].shape == (1, 768, 14, 14)
assert feat[0][1].shape == (1, 768)
# Test TransformerEncoderLayer with checkpoint forward
block = TransformerEncoderLayer(
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
assert block.with_cp
x = torch.randn(1, 56 * 56, 64)
x_out = block(x)
assert x_out.shape == torch.Size([1, 56 * 56, 64])
def test_vit_init():
path = 'PATH_THAT_DO_NOT_EXIST'