mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Check the version of torchvision in __init__ of DCN (#2556)
parent
71ee2a61f2
commit
e0323b4e4b
|
@ -438,12 +438,9 @@ class DeformConv2dPack(DeformConv2d):
|
|||
|
||||
if IS_MLU_AVAILABLE:
|
||||
import torchvision
|
||||
from torchvision.ops import deform_conv2d as tv_deform_conv2d
|
||||
|
||||
from mmcv.utils import digit_version
|
||||
assert digit_version(torchvision.__version__) >= digit_version(
|
||||
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
|
||||
|
||||
from torchvision.ops import deform_conv2d as tv_deform_conv2d
|
||||
|
||||
@CONV_LAYERS.register_module('DCN', force=True)
|
||||
class DeformConv2dPack_MLU(DeformConv2d):
|
||||
|
@ -471,6 +468,8 @@ if IS_MLU_AVAILABLE:
|
|||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert digit_version(torchvision.__version__) >= digit_version(
|
||||
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.conv_offset = nn.Conv2d(
|
||||
|
@ -494,7 +493,6 @@ if IS_MLU_AVAILABLE:
|
|||
) == 0, 'batch size must be divisible by im2col_step'
|
||||
offset = self.conv_offset(x)
|
||||
x = x.type_as(offset)
|
||||
weight = self.weight
|
||||
weight = weight.type_as(x)
|
||||
weight = self.weight.type_as(x)
|
||||
return tv_deform_conv2d(x, offset, weight, None, self.stride,
|
||||
self.padding, self.dilation)
|
||||
|
|
|
@ -356,11 +356,9 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
|
|||
|
||||
if IS_MLU_AVAILABLE:
|
||||
import torchvision
|
||||
from torchvision.ops import deform_conv2d as tv_deform_conv2d
|
||||
|
||||
from mmcv.utils import digit_version
|
||||
assert digit_version(torchvision.__version__) >= digit_version(
|
||||
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
|
||||
from torchvision.ops import deform_conv2d as tv_deform_conv2d
|
||||
|
||||
@CONV_LAYERS.register_module('DCNv2', force=True)
|
||||
class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
|
||||
|
@ -383,6 +381,8 @@ if IS_MLU_AVAILABLE:
|
|||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert digit_version(torchvision.__version__) >= digit_version(
|
||||
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_offset = nn.Conv2d(
|
||||
self.in_channels,
|
||||
|
|
Loading…
Reference in New Issue