diff --git a/mmseg/models/backbones/mit.py b/mmseg/models/backbones/mit.py
index c97213a4a..4417cf113 100644
--- a/mmseg/models/backbones/mit.py
+++ b/mmseg/models/backbones/mit.py
@@ -4,6 +4,7 @@ import warnings
 
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint as cp
 from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
@@ -235,6 +236,8 @@ class TransformerEncoderLayer(BaseModule):
             Default:None.
         sr_ratio (int): The ratio of spatial reduction of Efficient
             Multi-head Attention of Segformer. Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed. Default: False.
     """
 
     def __init__(self,
@@ -248,7 +251,8 @@ class TransformerEncoderLayer(BaseModule):
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN'),
                  batch_first=True,
-                 sr_ratio=1):
+                 sr_ratio=1,
+                 with_cp=False):
         super(TransformerEncoderLayer, self).__init__()
 
         # The ret[0] of build_norm_layer is norm name.
@@ -275,9 +279,19 @@ class TransformerEncoderLayer(BaseModule):
             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
             act_cfg=act_cfg)
 
+        self.with_cp = with_cp
+
     def forward(self, x, hw_shape):
-        x = self.attn(self.norm1(x), hw_shape, identity=x)
-        x = self.ffn(self.norm2(x), hw_shape, identity=x)
+
+        def _inner_forward(x):
+            x = self.attn(self.norm1(x), hw_shape, identity=x)
+            x = self.ffn(self.norm2(x), hw_shape, identity=x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
         return x
 
 
@@ -319,6 +333,8 @@ class MixVisionTransformer(BaseModule):
         pretrained (str, optional): model pretrained path. Default: None.
         init_cfg (dict or list[dict], optional): Initialization config dict.
             Default: None.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed. Default: False.
     """
 
     def __init__(self,
@@ -339,7 +355,8 @@ class MixVisionTransformer(BaseModule):
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
-                 init_cfg=None):
+                 init_cfg=None,
+                 with_cp=False):
         super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)
 
         assert not (init_cfg and pretrained), \
@@ -358,8 +375,9 @@ class MixVisionTransformer(BaseModule):
         self.patch_sizes = patch_sizes
         self.strides = strides
         self.sr_ratios = sr_ratios
+        self.with_cp = with_cp
         assert num_stages == len(num_layers) == len(num_heads) \
-            == len(patch_sizes) == len(strides) == len(sr_ratios)
+               == len(patch_sizes) == len(strides) == len(sr_ratios)
 
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
@@ -392,6 +410,7 @@ class MixVisionTransformer(BaseModule):
                     qkv_bias=qkv_bias,
                     act_cfg=act_cfg,
                     norm_cfg=norm_cfg,
+                    with_cp=with_cp,
                     sr_ratio=sr_ratios[i]) for idx in range(num_layer)
             ])
             in_channels = embed_dims_i
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 9c920baa6..fe6503992 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -4,6 +4,7 @@ import warnings
 
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint as cp
 from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
@@ -41,6 +42,8 @@ class TransformerEncoderLayer(BaseModule):
         batch_first (bool): Key, Query and Value are shape of
             (batch, n, embed_dim)
             or (n, batch, embed_dim). Default: True.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed. Default: False.
     """
 
     def __init__(self,
@@ -54,7 +57,8 @@ class TransformerEncoderLayer(BaseModule):
                  qkv_bias=True,
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN'),
-                 batch_first=True):
+                 batch_first=True,
+                 with_cp=False):
         super(TransformerEncoderLayer, self).__init__()
 
         self.norm1_name, norm1 = build_norm_layer(
@@ -82,6 +86,8 @@ class TransformerEncoderLayer(BaseModule):
             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
             act_cfg=act_cfg)
 
+        self.with_cp = with_cp
+
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
@@ -91,8 +97,16 @@ class TransformerEncoderLayer(BaseModule):
         return getattr(self, self.norm2_name)
 
     def forward(self, x):
-        x = self.attn(self.norm1(x), identity=x)
-        x = self.ffn(self.norm2(x), identity=x)
+
+        def _inner_forward(x):
+            x = self.attn(self.norm1(x), identity=x)
+            x = self.ffn(self.norm2(x), identity=x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
         return x
 
 
@@ -251,6 +265,7 @@ class VisionTransformer(BaseModule):
                     qkv_bias=qkv_bias,
                     act_cfg=act_cfg,
                     norm_cfg=norm_cfg,
+                    with_cp=with_cp,
                     batch_first=True))
 
         self.final_norm = final_norm
diff --git a/tests/test_models/test_backbones/test_mit.py b/tests/test_models/test_backbones/test_mit.py
index 9eec1fa03..72f74fe20 100644
--- a/tests/test_models/test_backbones/test_mit.py
+++ b/tests/test_models/test_backbones/test_mit.py
@@ -3,7 +3,8 @@ import pytest
 import torch
 
 from mmseg.models.backbones import MixVisionTransformer
-from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
+from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN,
+                                        TransformerEncoderLayer)
 
 
 def test_mit():
@@ -56,6 +57,14 @@ def test_mit():
     outs = MHA(temp, hw_shape, temp)
     assert out.shape == (1, token_len, 64)
 
+    # Test TransformerEncoderLayer with checkpoint forward
+    block = TransformerEncoderLayer(
+        embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 56 * 56, 64)
+    x_out = block(x, (56, 56))
+    assert x_out.shape == torch.Size([1, 56 * 56, 64])
+
 
 def test_mit_init():
     path = 'PATH_THAT_DO_NOT_EXIST'
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index 4ce860c04..0d1ba7000 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -2,7 +2,8 @@
 import pytest
 import torch
 
-from mmseg.models.backbones.vit import VisionTransformer
+from mmseg.models.backbones.vit import (TransformerEncoderLayer,
+                                        VisionTransformer)
 from .utils import check_norm_state
 
 
@@ -119,6 +120,14 @@ def test_vit_backbone():
     assert feat[0][0].shape == (1, 768, 14, 14)
     assert feat[0][1].shape == (1, 768)
 
+    # Test TransformerEncoderLayer with checkpoint forward
+    block = TransformerEncoderLayer(
+        embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 56 * 56, 64)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 56 * 56, 64])
+
 
 def test_vit_init():
     path = 'PATH_THAT_DO_NOT_EXIST'
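Usage sketch (not part of the patch): with `with_cp=True`, each encoder layer recomputes its activations during the backward pass instead of caching them, trading extra forward compute for lower peak memory. A minimal example of the intended use, assuming the default `MixVisionTransformer` constructor arguments build a valid backbone; the input size and loss are illustrative only:

```python
import torch

from mmseg.models.backbones import MixVisionTransformer

# The new `with_cp` flag is plumbed from the backbone down to every
# TransformerEncoderLayer.
model = MixVisionTransformer(with_cp=True).train()

x = torch.randn(2, 3, 224, 224)
outs = model(x)  # multi-scale feature maps, one per output stage

# During backward, checkpointed blocks re-run their forward pass to
# rebuild the activations they skipped storing.
sum(o.sum() for o in outs).backward()
```

The `self.with_cp and x.requires_grad` guard keeps checkpointing disabled under `torch.no_grad()` (e.g. at inference), where `torch.utils.checkpoint` would otherwise emit a warning because no input requires grad and recomputation would buy nothing.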