[Fix] Fix some warnings in unittest (#1522)

* [Fix] fix some warnings in unittest

* [Impl] standardize some warnings

* [Fix] fix warning type in test_deprecation

* [Fix] fix warning type

* [Fix] continue fixing

* [Fix] fix some details

* [Fix] fix docstring

* [Fix] del useless statement

* [Fix] keep compatibility for torch < 1.5.0
pull/1605/head
Jiazhen Wang 2021-12-22 10:57:10 +08:00 committed by GitHub
parent f367d621c6
commit fb486b96fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 89 additions and 70 deletions

View File

@ -131,7 +131,7 @@ class GeneralizedAttention(nn.Module):
max_len_kv = int((max_len - 1.0) / self.kv_stride + 1) max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
local_constraint_map = np.ones( local_constraint_map = np.ones(
(max_len, max_len, max_len_kv, max_len_kv), dtype=np.int) (max_len, max_len, max_len_kv, max_len_kv), dtype=int)
for iy in range(max_len): for iy in range(max_len):
for ix in range(max_len): for ix in range(max_len):
local_constraint_map[ local_constraint_map[

View File

@ -436,10 +436,11 @@ class MultiheadAttention(BaseModule):
**kwargs): **kwargs):
super(MultiheadAttention, self).__init__(init_cfg) super(MultiheadAttention, self).__init__(init_cfg)
if 'dropout' in kwargs: if 'dropout' in kwargs:
warnings.warn('The arguments `dropout` in MultiheadAttention ' warnings.warn(
'The arguments `dropout` in MultiheadAttention '
'has been deprecated, now you can separately ' 'has been deprecated, now you can separately '
'set `attn_drop`(float), proj_drop(float), ' 'set `attn_drop`(float), proj_drop(float), '
'and `dropout_layer`(dict) ') 'and `dropout_layer`(dict) ', DeprecationWarning)
attn_drop = kwargs['dropout'] attn_drop = kwargs['dropout']
dropout_layer['drop_prob'] = kwargs.pop('dropout') dropout_layer['drop_prob'] = kwargs.pop('dropout')
@ -689,7 +690,7 @@ class BaseTransformerLayer(BaseModule):
f'The arguments `{ori_name}` in BaseTransformerLayer ' f'The arguments `{ori_name}` in BaseTransformerLayer '
f'has been deprecated, now you should set `{new_name}` ' f'has been deprecated, now you should set `{new_name}` '
f'and other FFN related arguments ' f'and other FFN related arguments '
f'to a dict named `ffn_cfgs`. ') f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
ffn_cfgs[new_name] = kwargs[ori_name] ffn_cfgs[new_name] = kwargs[ori_name]
super(BaseTransformerLayer, self).__init__(init_cfg) super(BaseTransformerLayer, self).__init__(init_cfg)

View File

@ -24,6 +24,7 @@
# SOFTWARE. # SOFTWARE.
import sys import sys
import warnings
from functools import partial from functools import partial
import numpy as np import numpy as np
@ -502,8 +503,7 @@ def batch_counter_hook(module, input, output):
input = input[0] input = input[0]
batch_size = len(input) batch_size = len(input)
else: else:
pass warnings.warn('No positional inputs found for a module, '
print('Warning! No positional inputs found for a module, '
'assuming batch size is 1.') 'assuming batch size is 1.')
module.__batch_counter__ += batch_size module.__batch_counter__ += batch_size
@ -530,7 +530,7 @@ def remove_batch_counter_hook_function(module):
def add_flops_counter_variable_or_reset(module): def add_flops_counter_variable_or_reset(module):
if is_supported_instance(module): if is_supported_instance(module):
if hasattr(module, '__flops__') or hasattr(module, '__params__'): if hasattr(module, '__flops__') or hasattr(module, '__params__'):
print('Warning: variables __flops__ or __params__ are already ' warnings.warn('variables __flops__ or __params__ are already '
'defined for the module' + type(module).__name__ + 'defined for the module' + type(module).__name__ +
' ptflops can affect your code!') ' ptflops can affect your code!')
module.__flops__ = 0 module.__flops__ = 0

View File

@ -64,7 +64,8 @@ class CephBackend(BaseStorageBackend):
raise ImportError('Please install ceph to enable CephBackend.') raise ImportError('Please install ceph to enable CephBackend.')
warnings.warn( warnings.warn(
'CephBackend will be deprecated, please use PetrelBackend instead') 'CephBackend will be deprecated, please use PetrelBackend instead',
DeprecationWarning)
self._client = ceph.S3Client() self._client = ceph.S3Client()
assert isinstance(path_mapping, dict) or path_mapping is None assert isinstance(path_mapping, dict) or path_mapping is None
self.path_mapping = path_mapping self.path_mapping = path_mapping

View File

@ -12,7 +12,8 @@ class Conv2d_deprecated(Conv2d):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
warnings.warn( warnings.warn(
'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in' 'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in'
' the future. Please import them from "mmcv.cnn" instead') ' the future. Please import them from "mmcv.cnn" instead',
DeprecationWarning)
class ConvTranspose2d_deprecated(ConvTranspose2d): class ConvTranspose2d_deprecated(ConvTranspose2d):
@ -22,7 +23,7 @@ class ConvTranspose2d_deprecated(ConvTranspose2d):
warnings.warn( warnings.warn(
'Importing ConvTranspose2d wrapper from "mmcv.ops" will be ' 'Importing ConvTranspose2d wrapper from "mmcv.ops" will be '
'deprecated in the future. Please import them from "mmcv.cnn" ' 'deprecated in the future. Please import them from "mmcv.cnn" '
'instead') 'instead', DeprecationWarning)
class MaxPool2d_deprecated(MaxPool2d): class MaxPool2d_deprecated(MaxPool2d):
@ -31,7 +32,8 @@ class MaxPool2d_deprecated(MaxPool2d):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
warnings.warn( warnings.warn(
'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in' 'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in'
' the future. Please import them from "mmcv.cnn" instead') ' the future. Please import them from "mmcv.cnn" instead',
DeprecationWarning)
class Linear_deprecated(Linear): class Linear_deprecated(Linear):
@ -40,4 +42,5 @@ class Linear_deprecated(Linear):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
warnings.warn( warnings.warn(
'Importing Linear wrapper from "mmcv.ops" will be deprecated in' 'Importing Linear wrapper from "mmcv.ops" will be deprecated in'
' the future. Please import them from "mmcv.cnn" instead') ' the future. Please import them from "mmcv.cnn" instead',
DeprecationWarning)

View File

@ -188,7 +188,7 @@ class FusedBiasLeakyReLUFunction(Function):
class FusedBiasLeakyReLU(nn.Module): class FusedBiasLeakyReLU(nn.Module):
"""Fused bias leaky ReLU. r"""Fused bias leaky ReLU.
This function is introduced in the StyleGAN2: This function is introduced in the StyleGAN2:
`Analyzing and Improving the Image Quality of StyleGAN `Analyzing and Improving the Image Quality of StyleGAN
@ -197,8 +197,8 @@ class FusedBiasLeakyReLU(nn.Module):
The bias term comes from the convolution operation. In addition, to keep The bias term comes from the convolution operation. In addition, to keep
the variance of the feature map or gradients unchanged, they also adopt a the variance of the feature map or gradients unchanged, they also adopt a
scale similarly with Kaiming initialization. However, since the scale similarly with Kaiming initialization. However, since the
:math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the :math:`1+{alpha}^2` is too small, we can just ignore it. Therefore, the
final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501 final scale is just :math:`\sqrt{2}`. Of course, you may change it with
your own scale. your own scale.
TODO: Implement the CPU version. TODO: Implement the CPU version.
@ -224,7 +224,7 @@ class FusedBiasLeakyReLU(nn.Module):
def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5): def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
"""Fused bias leaky ReLU function. r"""Fused bias leaky ReLU function.
This function is introduced in the StyleGAN2: This function is introduced in the StyleGAN2:
`Analyzing and Improving the Image Quality of StyleGAN `Analyzing and Improving the Image Quality of StyleGAN
@ -233,8 +233,8 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
The bias term comes from the convolution operation. In addition, to keep The bias term comes from the convolution operation. In addition, to keep
the variance of the feature map or gradients unchanged, they also adopt a the variance of the feature map or gradients unchanged, they also adopt a
scale similarly with Kaiming initialization. However, since the scale similarly with Kaiming initialization. However, since the
:math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the :math:`1+{alpha}^2` is too small, we can just ignore it. Therefore, the
final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501 final scale is just :math:`\sqrt{2}`. Of course, you may change it with
your own scale. your own scale.
Args: Args:

View File

@ -61,7 +61,6 @@ class MaskedConv2dFunction(Function):
kernel_w=kernel_w, kernel_w=kernel_w,
pad_h=pad_h, pad_h=pad_h,
pad_w=pad_w) pad_w=pad_w)
masked_output = torch.addmm(1, bias[:, None], 1, masked_output = torch.addmm(1, bias[:, None], 1,
weight.view(out_channel, -1), data_col) weight.view(out_channel, -1), data_col)
ext_module.masked_col2im_forward( ext_module.masked_col2im_forward(

View File

@ -389,7 +389,7 @@ def nms_match(dets, iou_threshold):
if isinstance(dets, torch.Tensor): if isinstance(dets, torch.Tensor):
return [dets.new_tensor(m, dtype=torch.long) for m in matched] return [dets.new_tensor(m, dtype=torch.long) for m in matched]
else: else:
return [np.array(m, dtype=np.int) for m in matched] return [np.array(m, dtype=int) for m in matched]
def nms_rotated(dets, scores, iou_threshold, labels=None): def nms_rotated(dets, scores, iou_threshold, labels=None):

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from torch.autograd import Function from torch.autograd import Function
@ -119,8 +120,7 @@ class DynamicScatter(nn.Module):
inds = torch.where(coors[:, 0] == i) inds = torch.where(coors[:, 0] == i)
voxel, voxel_coor = self.forward_single( voxel, voxel_coor = self.forward_single(
points[inds], coors[inds][:, 1:]) points[inds], coors[inds][:, 1:])
coor_pad = nn.functional.pad( coor_pad = F.pad(voxel_coor, (1, 0), mode='constant', value=i)
voxel_coor, (1, 0), mode='constant', value=i)
voxel_coors.append(coor_pad) voxel_coors.append(coor_pad)
voxels.append(voxel) voxels.append(voxel)
features = torch.cat(voxels, dim=0) features = torch.cat(voxels, dim=0)

View File

@ -61,8 +61,10 @@ class BaseRunner(metaclass=ABCMeta):
if not callable(batch_processor): if not callable(batch_processor):
raise TypeError('batch_processor must be callable, ' raise TypeError('batch_processor must be callable, '
f'but got {type(batch_processor)}') f'but got {type(batch_processor)}')
warnings.warn('batch_processor is deprecated, please implement ' warnings.warn(
'train_step() and val_step() in the model instead.') 'batch_processor is deprecated, please implement '
'train_step() and val_step() in the model instead.',
DeprecationWarning)
# raise an error is `batch_processor` is not None and # raise an error is `batch_processor` is not None and
# `model.train_step()` exists. # `model.train_step()` exists.
if is_module_wrapper(model): if is_module_wrapper(model):

View File

@ -358,7 +358,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
if backend == 'ceph': if backend == 'ceph':
warnings.warn( warnings.warn(
'CephBackend will be deprecated, please use PetrelBackend instead') 'CephBackend will be deprecated, please use PetrelBackend instead',
DeprecationWarning)
# CephClient and PetrelBackend have the same prefix 's3://' and the latter # CephClient and PetrelBackend have the same prefix 's3://' and the latter
# will be chosen as default. If PetrelBackend can not be instantiated # will be chosen as default. If PetrelBackend can not be instantiated
@ -389,8 +390,9 @@ def load_from_torchvision(filename, map_location=None):
""" """
model_urls = get_torchvision_models() model_urls = get_torchvision_models()
if filename.startswith('modelzoo://'): if filename.startswith('modelzoo://'):
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' warnings.warn(
'use "torchvision://" instead') 'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead', DeprecationWarning)
model_name = filename[11:] model_name = filename[11:]
else: else:
model_name = filename[14:] model_name = filename[14:]
@ -422,8 +424,10 @@ def load_from_openmmlab(filename, map_location=None):
deprecated_urls = get_deprecated_model_names() deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls: if model_name in deprecated_urls:
warnings.warn(f'{prefix_str}{model_name} is deprecated in favor ' warnings.warn(
f'of {prefix_str}{deprecated_urls[model_name]}') f'{prefix_str}{model_name} is deprecated in favor '
f'of {prefix_str}{deprecated_urls[model_name]}',
DeprecationWarning)
model_name = deprecated_urls[model_name] model_name = deprecated_urls[model_name]
model_url = model_urls[model_name] model_url = model_urls[model_name]
# check if is url # check if is url

View File

@ -183,5 +183,6 @@ class Runner(EpochBasedRunner):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn( warnings.warn(
'Runner was deprecated, please use EpochBasedRunner instead') 'Runner was deprecated, please use EpochBasedRunner instead',
DeprecationWarning)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@ -227,7 +227,8 @@ def force_fp32(apply_to=None, out_fp16=False):
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
warnings.warning( warnings.warning(
'"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be ' '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads') 'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads',
DeprecationWarning)
_allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb) _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)

View File

@ -231,5 +231,6 @@ class TRTWraper(TRTWrapper):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
warnings.warn('TRTWraper will be deprecated in' warnings.warn(
' future. Please use TRTWrapper instead') 'TRTWraper will be deprecated in'
' future. Please use TRTWrapper instead', DeprecationWarning)

View File

@ -229,7 +229,7 @@ class Config:
if 'reference' in deprecation_info: if 'reference' in deprecation_info:
warning_msg += ' More information can be found at ' \ warning_msg += ' More information can be found at ' \
f'{deprecation_info["reference"]}' f'{deprecation_info["reference"]}'
warnings.warn(warning_msg) warnings.warn(warning_msg, DeprecationWarning)
cfg_text = filename + '\n' cfg_text = filename + '\n'
with open(filename, 'r', encoding='utf-8') as f: with open(filename, 'r', encoding='utf-8') as f:

View File

@ -28,10 +28,11 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
return False return False
def _legacy_zip_load(filename, model_dir, map_location): def _legacy_zip_load(filename, model_dir, map_location):
warnings.warn('Falling back to the old format < 1.6. This support will' warnings.warn(
'Falling back to the old format < 1.6. This support will'
' be deprecated in favor of default zipfile format ' ' be deprecated in favor of default zipfile format '
'introduced in 1.6. Please redo torch.save() to save it ' 'introduced in 1.6. Please redo torch.save() to save it '
'in the new zipfile format.') 'in the new zipfile format.', DeprecationWarning)
# Note: extractall() defaults to overwrite file if exists. No need to # Note: extractall() defaults to overwrite file if exists. No need to
# clean up beforehand. We deliberately don't handle tarfile here # clean up beforehand. We deliberately don't handle tarfile here
# since our legacy serialization format was in tar. # since our legacy serialization format was in tar.
@ -84,8 +85,9 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
""" """
# Issue warning to move data if old env is set # Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'): if os.getenv('TORCH_MODEL_ZOO'):
warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env ' warnings.warn(
'TORCH_HOME instead') 'TORCH_MODEL_ZOO is deprecated, please use env '
'TORCH_HOME instead', DeprecationWarning)
if model_dir is None: if model_dir is None:
torch_home = _get_torch_home() torch_home = _get_torch_home()

View File

@ -315,7 +315,7 @@ def deprecated_api_warning(name_dict, cls_name=None):
warnings.warn( warnings.warn(
f'"{src_arg_name}" is deprecated in ' f'"{src_arg_name}" is deprecated in '
f'`{func_name}`, please use "{dst_arg_name}" ' f'`{func_name}`, please use "{dst_arg_name}" '
'instead') 'instead', DeprecationWarning)
arg_names[arg_names.index(src_arg_name)] = dst_arg_name arg_names[arg_names.index(src_arg_name)] = dst_arg_name
if kwargs: if kwargs:
for src_arg_name, dst_arg_name in name_dict.items(): for src_arg_name, dst_arg_name in name_dict.items():
@ -333,7 +333,7 @@ def deprecated_api_warning(name_dict, cls_name=None):
warnings.warn( warnings.warn(
f'"{src_arg_name}" is deprecated in ' f'"{src_arg_name}" is deprecated in '
f'`{func_name}`, please use "{dst_arg_name}" ' f'`{func_name}`, please use "{dst_arg_name}" '
'instead') 'instead', DeprecationWarning)
kwargs[dst_arg_name] = kwargs.pop(src_arg_name) kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
# apply converted arguments to the decorated method # apply converted arguments to the decorated method

View File

@ -251,7 +251,8 @@ class Registry:
warnings.warn( warnings.warn(
'The old API of register_module(module, force=False) ' 'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API ' 'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.') 'register_module(name=None, force=False, module=None) instead.',
DeprecationWarning)
if cls is None: if cls is None:
return partial(self.deprecated_register_module, force=force) return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force) self._register_module(cls, force=force)

View File

@ -1,5 +1,5 @@
import torch import torch
from torch.nn.functional import sigmoid import torch.nn.functional as F
from mmcv.cnn.bricks import Swish from mmcv.cnn.bricks import Swish
@ -7,7 +7,7 @@ from mmcv.cnn.bricks import Swish
def test_swish(): def test_swish():
act = Swish() act = Swish()
input = torch.randn(1, 3, 64, 64) input = torch.randn(1, 3, 64, 64)
expected_output = input * sigmoid(input) expected_output = input * F.sigmoid(input)
output = act(input) output = act(input)
# test output shape # test output shape
assert output.shape == expected_output.shape assert output.shape == expected_output.shape

View File

@ -333,7 +333,7 @@ class TestPhotometric:
input_img = np.array( input_img = np.array(
[[[0, 128, 255], [255, 128, 0]], [[0, 128, 255], [255, 128, 0]]], [[[0, 128, 255], [255, 128, 0]], [[0, 128, 255], [255, 128, 0]]],
dtype=np.float) dtype=float)
img = mmcv.lut_transform(input_img, lut_table) img = mmcv.lut_transform(input_img, lut_table)
baseline = cv2.LUT(np.array(input_img, dtype=np.uint8), lut_table) baseline = cv2.LUT(np.array(input_img, dtype=np.uint8), lut_table)
assert np.allclose(img, baseline) assert np.allclose(img, baseline)

View File

@ -1,6 +1,5 @@
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -15,7 +14,8 @@ class TestBilinearGridSample(object):
input = torch.rand(1, 1, 20, 20, dtype=dtype) input = torch.rand(1, 1, 20, 20, dtype=dtype)
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]]) grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input) grid = F.affine_grid(
grid, (1, 1, 15, 15), align_corners=align_corners).type_as(input)
grid *= multiplier grid *= multiplier
out = bilinear_grid_sample(input, grid, align_corners=align_corners) out = bilinear_grid_sample(input, grid, align_corners=align_corners)

View File

@ -8,6 +8,7 @@ import onnxruntime as rt
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from packaging import version from packaging import version
onnx_file = 'tmp.onnx' onnx_file = 'tmp.onnx'
@ -87,10 +88,11 @@ def test_grid_sample(mode, padding_mode, align_corners):
input = torch.rand(1, 1, 10, 10) input = torch.rand(1, 1, 10, 10)
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]]) grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input) grid = F.affine_grid(
grid, (1, 1, 15, 15), align_corners=align_corners).type_as(input)
def func(input, grid): def func(input, grid):
return nn.functional.grid_sample( return F.grid_sample(
input, input,
grid, grid,
mode=mode, mode=mode,
@ -110,7 +112,8 @@ def test_bilinear_grid_sample(align_corners):
input = torch.rand(1, 1, 10, 10) input = torch.rand(1, 1, 10, 10)
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]]) grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input) grid = F.affine_grid(
grid, (1, 1, 15, 15), align_corners=align_corners).type_as(input)
def func(input, grid): def func(input, grid):
return bilinear_grid_sample(input, grid, align_corners=align_corners) return bilinear_grid_sample(input, grid, align_corners=align_corners)
@ -462,7 +465,7 @@ def test_interpolate():
register_extra_symbolics(opset_version) register_extra_symbolics(opset_version)
def func(feat, scale_factor=2): def func(feat, scale_factor=2):
out = nn.functional.interpolate(feat, scale_factor=scale_factor) out = F.interpolate(feat, scale_factor=scale_factor)
return out return out
net = WrapFunction(func) net = WrapFunction(func)

View File

@ -7,6 +7,7 @@ import onnx
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
try: try:
from mmcv.tensorrt import (TRTWrapper, is_tensorrt_plugin_loaded, onnx2trt, from mmcv.tensorrt import (TRTWrapper, is_tensorrt_plugin_loaded, onnx2trt,
@ -487,11 +488,10 @@ def test_grid_sample(mode, padding_mode, align_corners):
input = torch.rand(1, 1, 10, 10).cuda() input = torch.rand(1, 1, 10, 10).cuda()
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]]) grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
grid = nn.functional.affine_grid(grid, grid = F.affine_grid(grid, (1, 1, 15, 15)).type_as(input).cuda()
(1, 1, 15, 15)).type_as(input).cuda()
def func(input, grid): def func(input, grid):
return nn.functional.grid_sample( return F.grid_sample(
input, input,
grid, grid,
mode=mode, mode=mode,

View File

@ -39,7 +39,7 @@ def test_voxelization(device_type):
device = torch.device(device_type) device = torch.device(device_type)
# test hard_voxelization on cpu/gpu # test hard_voxelization on cpu/gpu
points = torch.tensor(points).contiguous().to(device) points = points.contiguous().to(device)
coors, voxels, num_points_per_voxel = hard_voxelization.forward(points) coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)
coors = coors.cpu().detach().numpy() coors = coors.cpu().detach().numpy()
voxels = voxels.cpu().detach().numpy() voxels = voxels.cpu().detach().numpy()

View File

@ -57,7 +57,7 @@ def test_build_runner():
@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values()) @pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
def test_epoch_based_runner(runner_class): def test_epoch_based_runner(runner_class):
with pytest.warns(UserWarning): with pytest.warns(DeprecationWarning):
# batch_processor is deprecated # batch_processor is deprecated
model = OldStyleModel() model = OldStyleModel()

View File

@ -533,6 +533,6 @@ def test_deprecation():
] ]
for cfg_file in deprecated_cfg_files: for cfg_file in deprecated_cfg_files:
with pytest.warns(UserWarning): with pytest.warns(DeprecationWarning):
cfg = Config.fromfile(cfg_file) cfg = Config.fromfile(cfg_file)
assert cfg.item1 == 'expected' assert cfg.item1 == 'expected'

View File

@ -96,15 +96,15 @@ def test_registry():
pass pass
# begin: test old APIs # begin: test old APIs
with pytest.warns(UserWarning): with pytest.warns(DeprecationWarning):
CATS.register_module(SphynxCat) CATS.register_module(SphynxCat)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat' assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
with pytest.warns(UserWarning): with pytest.warns(DeprecationWarning):
CATS.register_module(SphynxCat, force=True) CATS.register_module(SphynxCat, force=True)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat' assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
with pytest.warns(UserWarning): with pytest.warns(DeprecationWarning):
@CATS.register_module @CATS.register_module
class NewCat: class NewCat:
@ -112,11 +112,11 @@ def test_registry():
assert CATS.get('NewCat').__name__ == 'NewCat' assert CATS.get('NewCat').__name__ == 'NewCat'
with pytest.warns(UserWarning): with pytest.warns(DeprecationWarning):
CATS.deprecated_register_module(SphynxCat, force=True) CATS.deprecated_register_module(SphynxCat, force=True)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat' assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
with pytest.warns(UserWarning): with pytest.warns(DeprecationWarning):
@CATS.deprecated_register_module @CATS.deprecated_register_module
class CuteCat: class CuteCat:
@ -124,7 +124,7 @@ def test_registry():
assert CATS.get('CuteCat').__name__ == 'CuteCat' assert CATS.get('CuteCat').__name__ == 'CuteCat'
with pytest.warns(UserWarning): with pytest.warns(DeprecationWarning):
@CATS.deprecated_register_module(force=True) @CATS.deprecated_register_module(force=True)
class NewCat2: class NewCat2:

View File

@ -10,7 +10,7 @@ def test_color():
assert mmcv.color_val('green') == (0, 255, 0) assert mmcv.color_val('green') == (0, 255, 0)
assert mmcv.color_val((1, 2, 3)) == (1, 2, 3) assert mmcv.color_val((1, 2, 3)) == (1, 2, 3)
assert mmcv.color_val(100) == (100, 100, 100) assert mmcv.color_val(100) == (100, 100, 100)
assert mmcv.color_val(np.zeros(3, dtype=np.int)) == (0, 0, 0) assert mmcv.color_val(np.zeros(3, dtype=int)) == (0, 0, 0)
with pytest.raises(TypeError): with pytest.raises(TypeError):
mmcv.color_val([255, 255, 255]) mmcv.color_val([255, 255, 255])
with pytest.raises(TypeError): with pytest.raises(TypeError):