From ccd17571cedbd48706fd55051a17715b6790cece Mon Sep 17 00:00:00 2001
From: Zeyuan
Date: Mon, 4 Sep 2023 19:29:24 +0400
Subject: [PATCH] [Feature] Implement gradient checkpointing (#1319)

---
 mmengine/runner/__init__.py                  |  4 +-
 mmengine/runner/activation_checkpointing.py  | 26 +++++++++
 mmengine/runner/runner.py                    |  8 +++
 .../test_activation_checkpointing.py         | 55 +++++++++++++++++++
 4 files changed, 92 insertions(+), 1 deletion(-)
 create mode 100644 mmengine/runner/activation_checkpointing.py
 create mode 100644 tests/test_runner/test_activation_checkpointing.py

diff --git a/mmengine/runner/__init__.py b/mmengine/runner/__init__.py
index 531212b9..b00f8e83 100644
--- a/mmengine/runner/__init__.py
+++ b/mmengine/runner/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from ._flexible_runner import FlexibleRunner
+from .activation_checkpointing import turn_on_activation_checkpointing
 from .amp import autocast
 from .base_loop import BaseLoop
 from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
@@ -19,5 +20,6 @@ __all__ = [
     'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
     'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
     'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint',
-    'autocast', 'LogProcessor', 'set_random_seed', 'FlexibleRunner'
+    'autocast', 'LogProcessor', 'set_random_seed', 'FlexibleRunner',
+    'turn_on_activation_checkpointing'
 ]
diff --git a/mmengine/runner/activation_checkpointing.py b/mmengine/runner/activation_checkpointing.py
new file mode 100644
index 00000000..3db67f05
--- /dev/null
+++ b/mmengine/runner/activation_checkpointing.py
@@ -0,0 +1,26 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import wraps
+from operator import attrgetter
+from typing import List, Union
+
+import torch
+from torch.utils.checkpoint import checkpoint
+
+
+def wrap_forward(forward):
+
+    @wraps(forward)
+    def wrapper(*args):
+        return checkpoint(forward, *args)
+
+    return wrapper
+
+
+def turn_on_activation_checkpointing(model: torch.nn.Module,
+                                     modules: Union[List[str], str]):
+
+    if isinstance(modules, str):
+        modules = [modules]
+    for module_name in modules:
+        module = attrgetter(module_name)(model)
+        module.forward = wrap_forward(module.forward)
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 12830cf4..bd6757a8 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -41,6 +41,7 @@ from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
 from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
                                      set_multi_processing)
 from mmengine.visualization import Visualizer
+from .activation_checkpointing import turn_on_activation_checkpointing
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
                          find_latest_checkpoint, save_checkpoint,
@@ -1722,6 +1723,13 @@ class Runner:
         # initialize the model weights
         self._init_model_weights()
 
+        # try to enable activation_checkpointing feature
+        modules = self.cfg.get('activation_checkpointing', None)
+        if modules is not None:
+            self.logger.info(f'Enabling the "activation_checkpointing" feature'
+                             f' for sub-modules: {modules}')
+            turn_on_activation_checkpointing(ori_model, modules)
+
         # try to enable efficient_conv_bn_eval feature
         modules = self.cfg.get('efficient_conv_bn_eval', None)
         if modules is not None:
diff --git a/tests/test_runner/test_activation_checkpointing.py b/tests/test_runner/test_activation_checkpointing.py
new file mode 100644
index 00000000..d48c027c
--- /dev/null
+++ b/tests/test_runner/test_activation_checkpointing.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from mmengine.runner.activation_checkpointing import \
+    turn_on_activation_checkpointing
+from mmengine.testing import assert_allclose
+
+
+class Model(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+        self.bn1 = nn.BatchNorm2d(16)
+        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+        self.bn2 = nn.BatchNorm2d(32)
+        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
+        self.bn3 = nn.BatchNorm2d(64)
+        self.pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(64, 10)
+
+    def forward(self, x):
+        x = self.bn1(self.conv1(x))
+        x = F.relu(x)
+        x = self.bn2(self.conv2(x))
+        x = F.relu(x)
+        x = self.bn3(self.conv3(x))
+        x = F.relu(x)
+        x = self.pool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return x
+
+
+class TestActivationCheckpointing(TestCase):
+
+    def test_activation_checkpointing(self):
+        model = Model()
+        input = torch.randn(16, 3, 224, 224)
+        input.requires_grad = True
+        output = model(input)
+        output.sum().backward()
+        grad = input.grad.clone()
+
+        turn_on_activation_checkpointing(model, ['conv1', 'conv2', 'conv3'])
+        output2 = model(input)
+        output2.sum().backward()
+        grad2 = input.grad.clone()
+
+        assert_allclose(output, output2)
+        assert_allclose(grad, grad2, rtol=1e-3, atol=1e-3)
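
Usage sketch (illustrative, not part of the patch): after this change,
activation checkpointing is switched on by a top-level
'activation_checkpointing' key in the runner config, which Runner reads via
self.cfg.get('activation_checkpointing', None) as shown above. The value is a
single sub-module name or a list of names; the paths below are hypothetical
placeholders that must match attribute paths in your own model:

    # Hypothetical MMEngine config fragment. Only the
    # 'activation_checkpointing' key comes from this patch;
    # 'backbone.layer1' and 'backbone.layer2' are stand-ins for real
    # sub-module paths in your model.
    activation_checkpointing = ['backbone.layer1', 'backbone.layer2']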
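
The helper can also be called directly on any nn.Module, as the test above
does. Two properties of the implementation are worth noting: attrgetter
resolves dotted paths, so nested sub-modules can be targeted, and the wrapper
forwards positional arguments only, so a checkpointed sub-module's forward()
must not rely on keyword arguments. A minimal sketch, assuming torchvision is
installed (it is not a dependency of this patch; any model with named
sub-modules works):

    import torch
    from torchvision.models import resnet18  # example model, not part of the patch

    from mmengine.runner import turn_on_activation_checkpointing

    model = resnet18()
    # Recompute the activations of these residual stages during backward
    # instead of storing them, trading extra compute for lower peak memory.
    turn_on_activation_checkpointing(model, ['layer3', 'layer4'])

    x = torch.randn(2, 3, 224, 224, requires_grad=True)
    model(x).sum().backward()

With the default reentrant torch.utils.checkpoint implementation, at least one
tensor input to the wrapped forward must require grad for gradients to flow
through the recomputed segment, which is why the test above sets
input.requires_grad = True before calling the model.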