[Feature] Support using checkpoint in Swin Transformer to save memory. (#557)
* add checkpoint in swin backbone
parent f6076bfeca
commit 9d9dce69ad
The diff touches `mmcls/models/backbones/swin_transformer.py`:

```diff
@@ -4,6 +4,7 @@ from typing import Sequence
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint as cp
 from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN
 from mmcv.cnn.utils.weight_init import trunc_normal_
```
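The only functional change above is the new `torch.utils.checkpoint` import. As a refresher on the primitive (a standalone demo, not code from this commit): a checkpointed function does not keep its intermediate activations after the forward pass; it re-runs itself during backward to recompute them, trading extra compute for lower peak memory.

```python
import torch
import torch.utils.checkpoint as cp


def block(x):
    # stand-in for an expensive sub-network
    return torch.relu(x @ x.t())


x = torch.randn(64, 64, requires_grad=True)
y = cp.checkpoint(block, x)  # forward: activations inside `block` are not kept
y.sum().backward()           # backward: `block` is re-run to recompute them
print(x.grad.shape)          # torch.Size([64, 64])
```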
```diff
@@ -36,6 +37,9 @@ class SwinBlock(BaseModule):
             Defaults to empty dict.
         norm_cfg (dict, optional): The config of norm layers.
             Defaults to dict(type='LN').
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Defaults to False.
         auto_pad (bool, optional): Auto pad the feature map to be divisible by
             window_size, Defaults to False.
         init_cfg (dict, optional): The extra config for initialization.
```
```diff
@@ -53,10 +57,12 @@ class SwinBlock(BaseModule):
                  attn_cfgs=dict(),
                  ffn_cfgs=dict(),
                  norm_cfg=dict(type='LN'),
+                 with_cp=False,
                  auto_pad=False,
                  init_cfg=None):

         super(SwinBlock, self).__init__(init_cfg)
+        self.with_cp = with_cp

         _attn_cfgs = {
             'embed_dims': embed_dims,
```
```diff
@@ -84,14 +90,24 @@ class SwinBlock(BaseModule):
         self.ffn = FFN(**_ffn_cfgs)

     def forward(self, x):
-        identity = x
-        x = self.norm1(x)
-        x = self.attn(x)
-        x = x + identity
-
-        identity = x
-        x = self.norm2(x)
-        x = self.ffn(x, identity=identity)
+
+        def _inner_forward(x):
+            identity = x
+            x = self.norm1(x)
+            x = self.attn(x)
+            x = x + identity
+
+            identity = x
+            x = self.norm2(x)
+            x = self.ffn(x, identity=identity)
+
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)

         return x
```
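The `x.requires_grad` guard matters: under `torch.no_grad()` (typical inference) no activations are saved for backward anyway, so checkpointing buys nothing, and `cp.checkpoint` warns when none of its inputs require grad. A minimal sketch of the same guard in isolation (illustrative function body, not from this commit):

```python
import torch
import torch.utils.checkpoint as cp


def _inner_forward(x):
    return x * 2  # stand-in for norm -> attn -> FFN


x = torch.randn(4)                       # inference: requires_grad is False
y = _inner_forward(x)                    # the guard falls through to a plain call

x = torch.randn(4, requires_grad=True)   # training
y = cp.checkpoint(_inner_forward, x)     # recomputed during backward
```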
```diff
@@ -112,6 +128,9 @@ class SwinBlockSequence(BaseModule):
             each block. Defaults to 0.
         block_cfgs (Sequence[dict] | dict, optional): The extra config of each
             block. Defaults to empty dicts.
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Defaults to False.
         auto_pad (bool, optional): Auto pad the feature map to be divisible by
             window_size, Defaults to False.
         init_cfg (dict, optional): The extra config for initialization.
```
```diff
@@ -127,6 +146,7 @@ class SwinBlockSequence(BaseModule):
                  downsample_cfg=dict(),
                  drop_paths=0.,
                  block_cfgs=dict(),
+                 with_cp=False,
                  auto_pad=False,
                  init_cfg=None):
         super().__init__(init_cfg)
```
```diff
@@ -147,6 +167,7 @@ class SwinBlockSequence(BaseModule):
                 'num_heads': num_heads,
                 'shift': False if i % 2 == 0 else True,
                 'drop_path': drop_paths[i],
+                'with_cp': with_cp,
                 'auto_pad': auto_pad,
                 **block_cfgs[i]
             }
```
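Since `**block_cfgs[i]` is unpacked after the `'with_cp': with_cp` entry, a per-block `with_cp` supplied through `block_cfgs` would take precedence over the sequence-level flag. That follows from dict-literal unpack order rather than anything documented; the mechanics in isolation:

```python
# Later keys win in a dict literal, so the per-block override applies.
defaults = {'with_cp': True, 'auto_pad': False}   # sequence-level settings
per_block = {'with_cp': False}                    # e.g. block_cfgs[i]
merged = {**defaults, **per_block}
assert merged == {'with_cp': False, 'auto_pad': False}
```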
```diff
@@ -211,6 +232,9 @@ class SwinTransformer(BaseBackbone):
             Defaults to 0.1.
         use_abs_pos_embed (bool): If True, add absolute position embedding to
             the patch embedding. Defaults to False.
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Defaults to False.
         auto_pad (bool): If True, auto pad feature map to fit window_size.
             Defaults to False.
         norm_cfg (dict, optional): Config dict for normalization layer at end
```
```diff
@@ -266,6 +290,7 @@ class SwinTransformer(BaseBackbone):
                  out_indices=(3, ),
                  use_abs_pos_embed=False,
                  auto_pad=False,
+                 with_cp=False,
                  norm_cfg=dict(type='LN'),
                  stage_cfgs=dict(),
                  patch_cfg=dict(),
```
```diff
@@ -333,6 +358,7 @@ class SwinTransformer(BaseBackbone):
                 'downsample': downsample,
                 'input_resolution': input_resolution,
                 'drop_paths': dpr[:depth],
+                'with_cp': with_cp,
                 'auto_pad': auto_pad,
                 **stage_cfg
             }
```
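With the flag threaded through `SwinBlock`, `SwinBlockSequence`, and `SwinTransformer`, enabling checkpointing is a one-argument change for users. A minimal usage sketch, assuming mmcls's `build_backbone` helper (`arch='B'` matches the unit test below):

```python
import torch

from mmcls.models import build_backbone

# with_cp is just another backbone option once this commit is in.
backbone = build_backbone(
    dict(type='SwinTransformer', arch='B', with_cp=True))
backbone.init_weights()
backbone.train()

outs = backbone(torch.randn(1, 3, 224, 224))
print(outs[0].shape)  # (1, 1024, 7, 7), as in the test below
```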
The commit also extends the Swin Transformer unit tests:
```diff
@@ -9,6 +9,7 @@ import torch
 from mmcv.runner import load_checkpoint, save_checkpoint

 from mmcls.models.backbones import SwinTransformer
+from mmcls.models.backbones.swin_transformer import SwinBlock


 def test_assertion():
```
```diff
@@ -77,6 +78,19 @@ def test_forward():
     assert len(output) == 1
     assert output[0].shape == (1, 1024, 12, 12)

+    # Test base arch with checkpoint forward
+    model = SwinTransformer(arch='B', with_cp=True)
+    for m in model.modules():
+        if isinstance(m, SwinBlock):
+            assert m.with_cp
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    output = model(imgs)
+    assert len(output) == 1
+    assert output[0].shape == (1, 1024, 7, 7)
+

 def test_structure():
     # Test small with use_abs_pos_embed = True
```
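The test checks only that the flag propagates to every `SwinBlock` and that output shapes are unchanged; it does not measure the memory saving itself. A hypothetical way to eyeball the trade-off on a CUDA machine (not part of the test suite):

```python
import torch

from mmcls.models.backbones import SwinTransformer

for with_cp in (False, True):
    torch.cuda.reset_peak_memory_stats()
    model = SwinTransformer(arch='B', with_cp=with_cp).cuda().train()
    out = model(torch.randn(2, 3, 224, 224, device='cuda'))
    out[0].sum().backward()  # slower but lighter backward with with_cp=True
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f'with_cp={with_cp}: peak {peak:.0f} MiB')
    del model, out
    torch.cuda.empty_cache()
```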