[Feature] Support Activation Checkpointing for ConvNeXt. (#1153)

* Support Activation Checkpointing for ConvNeXt

* Add docstring.

Co-authored-by: mzr1996 <mzr1996@163.com>
Hakjin Lee 2022-11-14 16:04:28 +09:00 committed by GitHub
parent 11cd88f39a
commit cf5879988d
2 changed files with 41 additions and 16 deletions
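
The commit wires PyTorch's gradient checkpointing utility (torch.utils.checkpoint) into every ConvNeXtBlock behind a new with_cp switch. For background, here is a minimal standalone sketch of the same wrapper pattern; the ToyBlock below is a hypothetical illustration, not code from this repository:

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

class ToyBlock(nn.Module):
    """Toy residual block showing the with_cp pattern: when enabled, the
    intermediate activations produced inside _inner_forward are not stored
    during the forward pass and are recomputed during backward, trading
    extra compute for lower peak memory."""

    def __init__(self, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.conv = nn.Conv2d(8, 8, 3, padding=1)
        self.act = nn.GELU()

    def forward(self, x):

        def _inner_forward(x):
            return x + self.act(self.conv(x))

        # Skip checkpointing at inference time or when the input does not
        # require grad; checkpoint would only warn and return None gradients.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)

x = torch.randn(2, 8, 32, 32, requires_grad=True)
ToyBlock(with_cp=True)(x).sum().backward()
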

@@ -6,6 +6,7 @@ from typing import Sequence
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint as cp
 from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
 from mmengine.model import BaseModule, ModuleList, Sequential
 from mmengine.registry import MODELS
@@ -78,8 +79,11 @@ class ConvNeXtBlock(BaseModule):
                  mlp_ratio=4.,
                  linear_pw_conv=True,
                  drop_path_rate=0.,
-                 layer_scale_init_value=1e-6):
+                 layer_scale_init_value=1e-6,
+                 with_cp=False):
         super().__init__()
+        self.with_cp = with_cp
 
         self.depthwise_conv = nn.Conv2d(
             in_channels, in_channels, groups=in_channels, **dw_conv_cfg)
@@ -105,24 +109,32 @@
             drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
     def forward(self, x):
-        shortcut = x
-        x = self.depthwise_conv(x)
-        x = self.norm(x)
-
-        if self.linear_pw_conv:
-            x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
-
-        x = self.pointwise_conv1(x)
-        x = self.act(x)
-        x = self.pointwise_conv2(x)
-
-        if self.linear_pw_conv:
-            x = x.permute(0, 3, 1, 2)  # permute back
-
-        if self.gamma is not None:
-            x = x.mul(self.gamma.view(1, -1, 1, 1))
-
-        x = shortcut + self.drop_path(x)
-        return x
+
+        def _inner_forward(x):
+            shortcut = x
+            x = self.depthwise_conv(x)
+            x = self.norm(x)
+
+            if self.linear_pw_conv:
+                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+
+            x = self.pointwise_conv1(x)
+            x = self.act(x)
+            x = self.pointwise_conv2(x)
+
+            if self.linear_pw_conv:
+                x = x.permute(0, 3, 1, 2)  # permute back
+
+            if self.gamma is not None:
+                x = x.mul(self.gamma.view(1, -1, 1, 1))
+
+            x = shortcut + self.drop_path(x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
+        return x
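
As the updated docstring below states, checkpointing trades compute for memory: activations inside _inner_forward are discarded after the forward pass and recomputed during backward. A rough way to observe the effect is sketched here; this is an illustration only, it assumes a CUDA device and that ConvNeXt is importable from mmcls.models.backbones, which may differ in your checkout:

import torch
from mmcls.models.backbones import ConvNeXt

def peak_activation_mem(with_cp):
    # Peak CUDA memory (MiB) for one forward/backward pass.
    torch.cuda.reset_peak_memory_stats()
    model = ConvNeXt(arch='tiny', with_cp=with_cp).cuda().train()
    model.init_weights()
    out = model(torch.randn(8, 3, 224, 224, device='cuda'))
    out[-1].sum().backward()
    return torch.cuda.max_memory_allocated() / 1024**2

print(f'with_cp=False: {peak_activation_mem(False):.0f} MiB peak')
print(f'with_cp=True:  {peak_activation_mem(True):.0f} MiB peak')
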
@@ -166,6 +178,8 @@ class ConvNeXt(BaseBackbone):
         gap_before_final_norm (bool): Whether to globally average the feature
             map before the final norm layer. In the official repo, it's only
             used in classification task. Defaults to True.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
         init_cfg (dict, optional): Initialization config dict
     """  # noqa: E501
     arch_settings = {
@@ -203,6 +217,7 @@
                  out_indices=-1,
                  frozen_stages=0,
                  gap_before_final_norm=True,
+                 with_cp=False,
                  init_cfg=None):
         super().__init__(init_cfg=init_cfg)
 
@@ -285,8 +300,8 @@
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg,
                     linear_pw_conv=linear_pw_conv,
-                    layer_scale_init_value=layer_scale_init_value)
-                for j in range(depth)
+                    layer_scale_init_value=layer_scale_init_value,
+                    with_cp=with_cp) for j in range(depth)
             ])
             block_idx += depth

@@ -84,3 +84,13 @@ def test_convnext():
     for i in range(2, 4):
         assert model.downsample_layers[i].training
         assert model.stages[i].training
+
+    # Test Activation Checkpointing
+    model = ConvNeXt(arch='tiny', out_indices=-1, with_cp=True)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 1
+    assert feat[0].shape == torch.Size([1, 768])
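
For downstream use, the new flag can simply be forwarded from a config. A sketch assuming the usual mmcls ImageClassifier layout; every field except with_cp is illustrative and should be taken from an existing ConvNeXt config:

model = dict(
    type='ImageClassifier',
    backbone=dict(type='ConvNeXt', arch='tiny', with_cp=True),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(type='CrossEntropyLoss')))
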