mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2253 from brendanartley/hgnet-grad-checkpointing
Add gradient checkpointing to hgnets
This commit is contained in:
commit
2d5c9bf60d
@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
|
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
from ._manipulate import checkpoint_seq
|
||||||
|
|
||||||
__all__ = ['HighPerfGpuNet']
|
__all__ = ['HighPerfGpuNet']
|
||||||
|
|
||||||
@ -338,9 +339,13 @@ class HighPerfGpuStage(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.blocks = nn.Sequential(*blocks_list)
|
self.blocks = nn.Sequential(*blocks_list)
|
||||||
|
self.grad_checkpointing= False
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.downsample(x)
|
x = self.downsample(x)
|
||||||
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
x = checkpoint_seq(self.blocks, x, flatten=False)
|
||||||
|
else:
|
||||||
x = self.blocks(x)
|
x = self.blocks(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user