Merge pull request #2253 from brendanartley/hgnet-grad-checkpointing

Add gradient checkpointing to hgnets
weights_only
Ross Wightman 2024-08-07 12:45:14 -07:00 committed by GitHub
commit 2d5c9bf60d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 1 deletions

View File

@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
from ._builder import build_model_with_cfg
from ._registry import register_model, generate_default_cfgs
from ._manipulate import checkpoint_seq
__all__ = ['HighPerfGpuNet']
@ -338,10 +339,14 @@ class HighPerfGpuStage(nn.Module):
)
)
self.blocks = nn.Sequential(*blocks_list)
self.grad_checkpointing= False
def forward(self, x):
x = self.downsample(x)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x, flatten=False)
else:
x = self.blocks(x)
return x