[Feature] Support Activation Checkpointing for ConvNeXt. (#1153)
* Support Activation Checkpointing for ConvNeXt
* Add docstring.

Co-authored-by: mzr1996 <mzr1996@163.com>
parent 11cd88f39a
commit cf5879988d
@@ -6,6 +6,7 @@ from typing import Sequence
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint as cp
 from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
 from mmengine.model import BaseModule, ModuleList, Sequential
 from mmengine.registry import MODELS
@@ -78,8 +79,11 @@ class ConvNeXtBlock(BaseModule):
                  mlp_ratio=4.,
                  linear_pw_conv=True,
                  drop_path_rate=0.,
-                 layer_scale_init_value=1e-6):
+                 layer_scale_init_value=1e-6,
+                 with_cp=False):
         super().__init__()
+        self.with_cp = with_cp
 
         self.depthwise_conv = nn.Conv2d(
             in_channels, in_channels, groups=in_channels, **dw_conv_cfg)
@@ -105,24 +109,32 @@ class ConvNeXtBlock(BaseModule):
             drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
     def forward(self, x):
-        shortcut = x
-        x = self.depthwise_conv(x)
-        x = self.norm(x)
+
+        def _inner_forward(x):
+            shortcut = x
+            x = self.depthwise_conv(x)
+            x = self.norm(x)
 
-        if self.linear_pw_conv:
-            x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+            if self.linear_pw_conv:
+                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
 
-        x = self.pointwise_conv1(x)
-        x = self.act(x)
-        x = self.pointwise_conv2(x)
+            x = self.pointwise_conv1(x)
+            x = self.act(x)
+            x = self.pointwise_conv2(x)
 
-        if self.linear_pw_conv:
-            x = x.permute(0, 3, 1, 2)  # permute back
+            if self.linear_pw_conv:
+                x = x.permute(0, 3, 1, 2)  # permute back
 
-        if self.gamma is not None:
-            x = x.mul(self.gamma.view(1, -1, 1, 1))
+            if self.gamma is not None:
+                x = x.mul(self.gamma.view(1, -1, 1, 1))
 
-        x = shortcut + self.drop_path(x)
+            x = shortcut + self.drop_path(x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
         return x
 
 
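The rewrite above is the standard torch.utils.checkpoint recipe: the whole block body moves into a closure, and cp.checkpoint(_inner_forward, x) frees the closure's intermediate activations after the forward pass and recomputes them during backward, trading compute for memory. The x.requires_grad guard matters: under torch.no_grad() or in inference there is no backward pass to recompute for, so checkpointing would only add overhead. A minimal self-contained sketch of the same pattern on a toy block (all names here are illustrative, not part of this commit):

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class ToyBlock(nn.Module):
    """Toy residual block using the same checkpointing pattern."""

    def __init__(self, channels, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):

        def _inner_forward(x):
            # Everything inside this closure is recomputed during
            # backward instead of keeping its activations in memory.
            return x + self.conv(x)

        # Only checkpoint when gradients are actually needed; in
        # eval/no-grad mode checkpointing adds cost without benefit.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)


x = torch.randn(2, 8, 16, 16, requires_grad=True)
out = ToyBlock(8, with_cp=True)(x)
out.sum().backward()  # activations inside _inner_forward are recomputed here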
@@ -166,6 +178,8 @@ class ConvNeXt(BaseBackbone):
         gap_before_final_norm (bool): Whether to globally average the feature
             map before the final norm layer. In the official repo, it's only
             used in classification task. Defaults to True.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
         init_cfg (dict, optional): Initialization config dict
     """  # noqa: E501
     arch_settings = {
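With the flag documented and threaded through, turning checkpointing on from a config is a one-line change. A hypothetical mmcls-style config fragment (the keys around backbone are assumptions for illustration; only with_cp=True is what this commit enables):

# Hypothetical config fragment: with_cp=True is the new switch.
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ConvNeXt', arch='tiny', with_cp=True),
    head=dict(type='LinearClsHead', num_classes=1000, in_channels=768),
)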
@@ -203,6 +217,7 @@ class ConvNeXt(BaseBackbone):
                  out_indices=-1,
                  frozen_stages=0,
                  gap_before_final_norm=True,
+                 with_cp=False,
                  init_cfg=None):
         super().__init__(init_cfg=init_cfg)
 
@@ -285,8 +300,8 @@ class ConvNeXt(BaseBackbone):
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg,
                     linear_pw_conv=linear_pw_conv,
-                    layer_scale_init_value=layer_scale_init_value)
-                for j in range(depth)
+                    layer_scale_init_value=layer_scale_init_value,
+                    with_cp=with_cp) for j in range(depth)
             ])
             block_idx += depth
 
@@ -84,3 +84,13 @@ def test_convnext():
     for i in range(2, 4):
         assert model.downsample_layers[i].training
         assert model.stages[i].training
+
+    # Test Activation Checkpointing
+    model = ConvNeXt(arch='tiny', out_indices=-1, with_cp=True)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 1
+    assert feat[0].shape == torch.Size([1, 768])
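Beyond the shape check above, a useful property to keep in mind (and, if desired, to test) is that checkpointing must be numerically transparent: it changes memory and compute, never results. A sketch of such a check, assuming the mmcls import path below; note the input must require grad, otherwise the x.requires_grad guard silently skips the checkpointed path:

import torch
from mmcls.models.backbones import ConvNeXt  # assumed import path

ref = ConvNeXt(arch='tiny', out_indices=-1, with_cp=False)
ckpt = ConvNeXt(arch='tiny', out_indices=-1, with_cp=True)
ckpt.load_state_dict(ref.state_dict())  # identical weights
ref.train()
ckpt.train()

# requires_grad=True so the cp.checkpoint branch is actually exercised.
imgs = torch.randn(1, 3, 224, 224, requires_grad=True)
assert torch.allclose(ref(imgs)[0], ckpt(imgs)[0], atol=1e-6)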