diff --git a/mmcv/cnn/resnet.py b/mmcv/cnn/resnet.py index 6e1c62863..794206776 100644 --- a/mmcv/cnn/resnet.py +++ b/mmcv/cnn/resnet.py @@ -28,7 +28,8 @@ class BasicBlock(nn.Module): stride=1, dilation=1, downsample=None, - style='pytorch'): + style='pytorch', + with_cp=False): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride, dilation) self.bn1 = nn.BatchNorm2d(planes) @@ -38,6 +39,7 @@ class BasicBlock(nn.Module): self.downsample = downsample self.stride = stride self.dilation = dilation + assert not with_cp def forward(self, x): residual = x