From 232a459e36a69715723c1c6277eb367df84a460d Mon Sep 17 00:00:00 2001
From: brendanartley
Date: Wed, 7 Aug 2024 09:22:51 -0700
Subject: [PATCH] Added gradient checkpointing to hgnet

Import checkpoint_seq from ._manipulate and give HighPerfGpuStage a
grad_checkpointing attribute that is initialised to False.

HighPerfGpuStage.forward() still applies self.downsample first. After that:
  * if grad_checkpointing is truthy and torch.jit.is_scripting() is False,
    self.blocks is run through checkpoint_seq(self.blocks, x, flatten=False);
  * otherwise self.blocks(x) is called directly, exactly as before.

With the attribute left at its default of False the forward pass is
unchanged.
---
NOTE(review): nothing in this patch sets grad_checkpointing to True;
presumably a model-level set_grad_checkpointing() toggles it on each stage --
confirm against the rest of hgnet.py.
NOTE(review): the added line `self.grad_checkpointing= False` is missing a
space before `=` (PEP 8); left untouched here so the patch content stays
identical to the submitted commit.
NOTE(review): the second hunk ends with two blank context lines after
`return x`; they are required by the `-338,10 +339,14` line counts.

 timm/models/hgnet.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py
index 9482e107..ea0a92d9 100644
--- a/timm/models/hgnet.py
+++ b/timm/models/hgnet.py
@@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
 from ._builder import build_model_with_cfg
 from ._registry import register_model, generate_default_cfgs
+from ._manipulate import checkpoint_seq
 
 __all__ = ['HighPerfGpuNet']
 
@@ -338,10 +339,14 @@ class HighPerfGpuStage(nn.Module):
                 )
             )
         self.blocks = nn.Sequential(*blocks_list)
+        self.grad_checkpointing= False
 
     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=False)
+        else:
+            x = self.blocks(x)
         return x
 
 