is_scripting() guard on checkpoint_seq

This commit is contained in:
Ross Wightman 2023-02-04 14:21:49 -08:00
parent 95ec255f7f
commit 72fba669a8
2 changed files with 2 additions and 2 deletions

View File

@ -336,7 +336,7 @@ class EfficientFormerStage(nn.Module):
def forward(self, x):
x = self.downsample(x)
if self.grad_checkpointing:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)

View File

@ -499,7 +499,7 @@ class EfficientFormerV2Stage(nn.Module):
def forward(self, x):
x = self.downsample(x)
if self.grad_checkpointing:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)