[Feature] Support using checkpoint in Swin Transformer to save memory. (#557)

* add checkpoint in swin backbone

takuoko 2021-12-01 12:49:00 +09:00 committed by GitHub
parent f6076bfeca
commit 9d9dce69ad
2 changed files with 47 additions and 7 deletions
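
The new with_cp flag enables gradient checkpointing via torch.utils.checkpoint, trading extra forward recomputation during the backward pass for lower activation memory. A minimal usage sketch, mirroring the constructor call exercised in the new test below:

    import torch
    from mmcls.models.backbones import SwinTransformer

    # arch='B' selects the base arch; with_cp=True enables checkpointing
    # in every SwinBlock (active only while gradients are required).
    model = SwinTransformer(arch='B', with_cp=True)
    model.train()

    out = model(torch.randn(1, 3, 224, 224))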

mmcls/models/backbones/swin_transformer.py

@@ -4,6 +4,7 @@ from typing import Sequence
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint as cp
 from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN
 from mmcv.cnn.utils.weight_init import trunc_normal_
@@ -36,6 +37,9 @@ class SwinBlock(BaseModule):
             Defaults to empty dict.
         norm_cfg (dict, optional): The config of norm layers.
             Defaults to dict(type='LN').
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Defaults to False.
         auto_pad (bool, optional): Auto pad the feature map to be divisible by
             window_size, Defaults to False.
         init_cfg (dict, optional): The extra config for initialization.
@@ -53,10 +57,12 @@ class SwinBlock(BaseModule):
                  attn_cfgs=dict(),
                  ffn_cfgs=dict(),
                  norm_cfg=dict(type='LN'),
+                 with_cp=False,
                  auto_pad=False,
                  init_cfg=None):
         super(SwinBlock, self).__init__(init_cfg)
+        self.with_cp = with_cp

         _attn_cfgs = {
             'embed_dims': embed_dims,
@@ -84,14 +90,24 @@ class SwinBlock(BaseModule):
         self.ffn = FFN(**_ffn_cfgs)

     def forward(self, x):
-        identity = x
-        x = self.norm1(x)
-        x = self.attn(x)
-        x = x + identity
-
-        identity = x
-        x = self.norm2(x)
-        x = self.ffn(x, identity=identity)
+
+        def _inner_forward(x):
+            identity = x
+            x = self.norm1(x)
+            x = self.attn(x)
+            x = x + identity
+
+            identity = x
+            x = self.norm2(x)
+            x = self.ffn(x, identity=identity)
+
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)

         return x
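
cp.checkpoint only pays off when a backward pass will follow, hence the x.requires_grad guard: at inference the block calls _inner_forward directly and avoids pointless recomputation. A self-contained sketch of the same pattern (ToyBlock is illustrative, not part of mmcls):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint as cp

    class ToyBlock(nn.Module):
        # Hypothetical residual block using the same checkpointing
        # pattern as SwinBlock.forward above.
        def __init__(self, dim, with_cp=False):
            super().__init__()
            self.with_cp = with_cp
            self.norm = nn.LayerNorm(dim)
            self.fc = nn.Linear(dim, dim)

        def forward(self, x):

            def _inner_forward(x):
                # Activations created here are freed after the forward
                # pass and recomputed during backward when checkpointed.
                return x + self.fc(self.norm(x))

            if self.with_cp and x.requires_grad:
                return cp.checkpoint(_inner_forward, x)
            return _inner_forward(x)

    block = ToyBlock(32, with_cp=True)
    out = block(torch.randn(2, 32, requires_grad=True))
    out.sum().backward()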
@@ -112,6 +128,9 @@ class SwinBlockSequence(BaseModule):
             each block. Defaults to 0.
         block_cfgs (Sequence[dict] | dict, optional): The extra config of each
             block. Defaults to empty dicts.
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Defaults to False.
         auto_pad (bool, optional): Auto pad the feature map to be divisible by
             window_size, Defaults to False.
         init_cfg (dict, optional): The extra config for initialization.
@@ -127,6 +146,7 @@ class SwinBlockSequence(BaseModule):
                  downsample_cfg=dict(),
                  drop_paths=0.,
                  block_cfgs=dict(),
+                 with_cp=False,
                  auto_pad=False,
                  init_cfg=None):
         super().__init__(init_cfg)
@@ -147,6 +167,7 @@ class SwinBlockSequence(BaseModule):
                 'num_heads': num_heads,
                 'shift': False if i % 2 == 0 else True,
                 'drop_path': drop_paths[i],
+                'with_cp': with_cp,
                 'auto_pad': auto_pad,
                 **block_cfgs[i]
             }
@@ -211,6 +232,9 @@ class SwinTransformer(BaseBackbone):
             Defaults to 0.1.
         use_abs_pos_embed (bool): If True, add absolute position embedding to
             the patch embedding. Defaults to False.
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Defaults to False.
         auto_pad (bool): If True, auto pad feature map to fit window_size.
             Defaults to False.
         norm_cfg (dict, optional): Config dict for normalization layer at end
@@ -266,6 +290,7 @@ class SwinTransformer(BaseBackbone):
                  out_indices=(3, ),
                  use_abs_pos_embed=False,
                  auto_pad=False,
+                 with_cp=False,
                  norm_cfg=dict(type='LN'),
                  stage_cfgs=dict(),
                  patch_cfg=dict(),
@@ -333,6 +358,7 @@ class SwinTransformer(BaseBackbone):
                 'downsample': downsample,
                 'input_resolution': input_resolution,
                 'drop_paths': dpr[:depth],
+                'with_cp': with_cp,
                 'auto_pad': auto_pad,
                 **stage_cfg
             }
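
Because the flag is threaded through each stage's config above, it can also be switched on from a classifier config dict. A sketch under the usual mmcls config layout (the neck and head entries are illustrative boilerplate; only with_cp=True is the new knob):

    model = dict(
        type='ImageClassifier',
        backbone=dict(type='SwinTransformer', arch='base', with_cp=True),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(type='LinearClsHead', num_classes=1000, in_channels=1024))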

tests/test_models/test_backbones/test_swin_transformer.py

@@ -9,6 +9,7 @@ import torch
 from mmcv.runner import load_checkpoint, save_checkpoint

 from mmcls.models.backbones import SwinTransformer
+from mmcls.models.backbones.swin_transformer import SwinBlock


 def test_assertion():
@@ -77,6 +78,19 @@ def test_forward():
     assert len(output) == 1
     assert output[0].shape == (1, 1024, 12, 12)

+    # Test base arch forward with checkpoint
+    model = SwinTransformer(arch='B', with_cp=True)
+    for m in model.modules():
+        if isinstance(m, SwinBlock):
+            assert m.with_cp
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    output = model(imgs)
+    assert len(output) == 1
+    assert output[0].shape == (1, 1024, 7, 7)
+

 def test_structure():
     # Test small with use_abs_pos_embed = True
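
The test above verifies wiring and output shapes only; the memory saving itself can be spot-checked on a GPU with a sketch like the following (assumes a CUDA device; peak_mib is a hypothetical helper, and absolute numbers vary with arch and batch size):

    import torch
    from mmcls.models.backbones import SwinTransformer

    def peak_mib(with_cp):
        # Measure peak allocated memory for one forward/backward pass.
        torch.cuda.reset_peak_memory_stats()
        model = SwinTransformer(arch='B', with_cp=with_cp).cuda().train()
        out = model(torch.randn(2, 3, 224, 224, device='cuda'))[0]
        out.sum().backward()
        return torch.cuda.max_memory_allocated() / 1024**2

    print('with_cp=False:', peak_mib(False), 'MiB')
    print('with_cp=True: ', peak_mib(True), 'MiB')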