[Fix] Check the version of torchvision in __init__ of DCN (#2556)

pull/2572/head
mengpenghui 2023-01-17 19:19:35 +08:00 committed by GitHub
parent 71ee2a61f2
commit e0323b4e4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 9 deletions

View File

@ -438,12 +438,9 @@ class DeformConv2dPack(DeformConv2d):
if IS_MLU_AVAILABLE:
import torchvision
from torchvision.ops import deform_conv2d as tv_deform_conv2d
from mmcv.utils import digit_version
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
from torchvision.ops import deform_conv2d as tv_deform_conv2d
@CONV_LAYERS.register_module('DCN', force=True)
class DeformConv2dPack_MLU(DeformConv2d):
@ -471,6 +468,8 @@ if IS_MLU_AVAILABLE:
"""
def __init__(self, *args, **kwargs):
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
@ -494,7 +493,6 @@ if IS_MLU_AVAILABLE:
) == 0, 'batch size must be divisible by im2col_step'
offset = self.conv_offset(x)
x = x.type_as(offset)
weight = self.weight
weight = weight.type_as(x)
weight = self.weight.type_as(x)
return tv_deform_conv2d(x, offset, weight, None, self.stride,
self.padding, self.dilation)

View File

@ -356,11 +356,9 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
if IS_MLU_AVAILABLE:
import torchvision
from torchvision.ops import deform_conv2d as tv_deform_conv2d
from mmcv.utils import digit_version
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
from torchvision.ops import deform_conv2d as tv_deform_conv2d
@CONV_LAYERS.register_module('DCNv2', force=True)
class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
@ -383,6 +381,8 @@ if IS_MLU_AVAILABLE:
"""
def __init__(self, *args, **kwargs):
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,