[Fix] Fix some warnings in unittest (#1522)

* [Fix] fix some warnings in unittest

* [Impl] standardize some warnings

* [Fix] fix warning type in test_deprecation

* [Fix] fix warning type

* [Fix] continue fixing

* [Fix] fix some details

* [Fix] fix docstring

* [Fix] del useless statement

* [Fix] keep compatibility for torch < 1.5.0
Jiazhen Wang 2021-12-22 10:57:10 +08:00 committed by GitHub
parent f367d621c6
commit fb486b96fd
28 changed files with 89 additions and 70 deletions

View File

@@ -131,7 +131,7 @@ class GeneralizedAttention(nn.Module):
max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
local_constraint_map = np.ones(
-(max_len, max_len, max_len_kv, max_len_kv), dtype=np.int)
+(max_len, max_len, max_len_kv, max_len_kv), dtype=int)
for iy in range(max_len):
for ix in range(max_len):
local_constraint_map[
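`np.int` is a deprecated alias for the builtin `int` (NumPy warns about it since 1.20 and removes it in 1.24), so passing `int` keeps the resulting dtype unchanged. A minimal sketch of the equivalence (variable names are illustrative, not from the diff):

```python
import numpy as np

# The builtin int maps to NumPy's default integer type (np.int_),
# exactly as the deprecated np.int alias did.
a = np.ones((2, 2), dtype=int)
assert a.dtype == np.int_

# Prefer an explicit width when the exact size matters across platforms.
b = np.ones((2, 2), dtype=np.int64)
assert b.dtype == np.int64
```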

View File

@@ -436,10 +436,11 @@ class MultiheadAttention(BaseModule):
**kwargs):
super(MultiheadAttention, self).__init__(init_cfg)
if 'dropout' in kwargs:
-warnings.warn('The arguments `dropout` in MultiheadAttention '
-'has been deprecated, now you can separately '
-'set `attn_drop`(float), proj_drop(float), '
-'and `dropout_layer`(dict) ')
+warnings.warn(
+'The arguments `dropout` in MultiheadAttention '
+'has been deprecated, now you can separately '
+'set `attn_drop`(float), proj_drop(float), '
+'and `dropout_layer`(dict) ', DeprecationWarning)
attn_drop = kwargs['dropout']
dropout_layer['drop_prob'] = kwargs.pop('dropout')
@@ -689,7 +690,7 @@ class BaseTransformerLayer(BaseModule):
f'The arguments `{ori_name}` in BaseTransformerLayer '
f'has been deprecated, now you should set `{new_name}` '
f'and other FFN related arguments '
-f'to a dict named `ffn_cfgs`. ')
+f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
ffn_cfgs[new_name] = kwargs[ori_name]
super(BaseTransformerLayer, self).__init__(init_cfg)
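Passing an explicit category such as `DeprecationWarning` as the second argument to `warnings.warn` lets downstream code and test suites filter on the category instead of matching message text. A minimal sketch of the pattern, with a made-up function name:

```python
import warnings


def legacy_api(dropout=None):
    # Hypothetical deprecated keyword, mirroring the pattern above.
    if dropout is not None:
        warnings.warn('`dropout` is deprecated, use `attn_drop` instead',
                      DeprecationWarning)


with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter('always')
    legacy_api(dropout=0.1)
assert records[0].category is DeprecationWarning
```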

View File

@@ -24,6 +24,7 @@
# SOFTWARE.
import sys
+import warnings
from functools import partial
import numpy as np
@@ -502,9 +503,8 @@ def batch_counter_hook(module, input, output):
input = input[0]
batch_size = len(input)
else:
-pass
-print('Warning! No positional inputs found for a module, '
-'assuming batch size is 1.')
+warnings.warn('No positional inputs found for a module, '
+'assuming batch size is 1.')
module.__batch_counter__ += batch_size
@@ -530,9 +530,9 @@ def remove_batch_counter_hook_function(module):
def add_flops_counter_variable_or_reset(module):
if is_supported_instance(module):
if hasattr(module, '__flops__') or hasattr(module, '__params__'):
-print('Warning: variables __flops__ or __params__ are already '
-'defined for the module' + type(module).__name__ +
-' ptflops can affect your code!')
+warnings.warn('variables __flops__ or __params__ are already '
+'defined for the module' + type(module).__name__ +
+' ptflops can affect your code!')
module.__flops__ = 0
module.__params__ = get_model_parameters_number(module)
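Unlike `print`, `warnings.warn` goes through the warnings filter machinery: the message is reported once per call site under the default filter and callers can silence or escalate it. A small sketch, standard library only:

```python
import warnings

for _ in range(3):
    # Reported once per call site by default, instead of printing three times.
    warnings.warn('No positional inputs found for a module, '
                  'assuming batch size is 1.')

# Callers can suppress the message entirely if they choose to.
warnings.filterwarnings('ignore', message='No positional inputs found')
```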

View File

@@ -64,7 +64,8 @@ class CephBackend(BaseStorageBackend):
raise ImportError('Please install ceph to enable CephBackend.')
warnings.warn(
-'CephBackend will be deprecated, please use PetrelBackend instead')
+'CephBackend will be deprecated, please use PetrelBackend instead',
+DeprecationWarning)
self._client = ceph.S3Client()
assert isinstance(path_mapping, dict) or path_mapping is None
self.path_mapping = path_mapping

View File

@@ -12,7 +12,8 @@ class Conv2d_deprecated(Conv2d):
super().__init__(*args, **kwargs)
warnings.warn(
'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in'
-' the future. Please import them from "mmcv.cnn" instead')
+' the future. Please import them from "mmcv.cnn" instead',
+DeprecationWarning)
class ConvTranspose2d_deprecated(ConvTranspose2d):
@@ -22,7 +23,7 @@ class ConvTranspose2d_deprecated(ConvTranspose2d):
warnings.warn(
'Importing ConvTranspose2d wrapper from "mmcv.ops" will be '
'deprecated in the future. Please import them from "mmcv.cnn" '
-'instead')
+'instead', DeprecationWarning)
class MaxPool2d_deprecated(MaxPool2d):
@@ -31,7 +32,8 @@ class MaxPool2d_deprecated(MaxPool2d):
super().__init__(*args, **kwargs)
warnings.warn(
'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in'
-' the future. Please import them from "mmcv.cnn" instead')
+' the future. Please import them from "mmcv.cnn" instead',
+DeprecationWarning)
class Linear_deprecated(Linear):
@@ -40,4 +42,5 @@ class Linear_deprecated(Linear):
super().__init__(*args, **kwargs)
warnings.warn(
'Importing Linear wrapper from "mmcv.ops" will be deprecated in'
-' the future. Please import them from "mmcv.cnn" instead')
+' the future. Please import them from "mmcv.cnn" instead',
+DeprecationWarning)

View File

@@ -188,7 +188,7 @@ class FusedBiasLeakyReLUFunction(Function):
class FusedBiasLeakyReLU(nn.Module):
-"""Fused bias leaky ReLU.
+r"""Fused bias leaky ReLU.
This function is introduced in the StyleGAN2:
`Analyzing and Improving the Image Quality of StyleGAN
@@ -197,8 +197,8 @@ class FusedBiasLeakyReLU(nn.Module):
The bias term comes from the convolution operation. In addition, to keep
the variance of the feature map or gradients unchanged, they also adopt a
scale similarly with Kaiming initialization. However, since the
-:math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the
-final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
+:math:`1+{alpha}^2` is too small, we can just ignore it. Therefore, the
+final scale is just :math:`\sqrt{2}`. Of course, you may change it with
your own scale.
TODO: Implement the CPU version.
@@ -224,7 +224,7 @@ class FusedBiasLeakyReLU(nn.Module):
def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
-"""Fused bias leaky ReLU function.
+r"""Fused bias leaky ReLU function.
This function is introduced in the StyleGAN2:
`Analyzing and Improving the Image Quality of StyleGAN
@@ -233,8 +233,8 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
The bias term comes from the convolution operation. In addition, to keep
the variance of the feature map or gradients unchanged, they also adopt a
scale similarly with Kaiming initialization. However, since the
-:math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the
-final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
+:math:`1+{alpha}^2` is too small, we can just ignore it. Therefore, the
+final scale is just :math:`\sqrt{2}`. Of course, you may change it with
your own scale.
Args:
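The switch to `r"""` raw docstrings is what allows dropping the `# noqa: W605` comments: in a regular string the `\s` of `\sqrt` is an invalid escape sequence that newer Python versions warn about at compile time (flake8's W605). A minimal sketch with an illustrative function:

```python
def fused_op():
    r"""Scale the output by :math:`\sqrt{2}`.

    A raw docstring keeps the backslash literal; in a plain string
    '\s' is an invalid escape and triggers W605 / a deprecation warning.
    """


assert '\\sqrt{2}' in fused_op.__doc__
```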

View File

@@ -61,7 +61,6 @@ class MaskedConv2dFunction(Function):
kernel_w=kernel_w,
pad_h=pad_h,
pad_w=pad_w)
masked_output = torch.addmm(1, bias[:, None], 1,
weight.view(out_channel, -1), data_col)
ext_module.masked_col2im_forward(

View File

@@ -389,7 +389,7 @@ def nms_match(dets, iou_threshold):
if isinstance(dets, torch.Tensor):
return [dets.new_tensor(m, dtype=torch.long) for m in matched]
else:
-return [np.array(m, dtype=np.int) for m in matched]
+return [np.array(m, dtype=int) for m in matched]
def nms_rotated(dets, scores, iou_threshold, labels=None):

View File

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
+import torch.nn.functional as F
from torch import nn
from torch.autograd import Function
@@ -119,8 +120,7 @@ class DynamicScatter(nn.Module):
inds = torch.where(coors[:, 0] == i)
voxel, voxel_coor = self.forward_single(
points[inds], coors[inds][:, 1:])
-coor_pad = nn.functional.pad(
-voxel_coor, (1, 0), mode='constant', value=i)
+coor_pad = F.pad(voxel_coor, (1, 0), mode='constant', value=i)
voxel_coors.append(coor_pad)
voxels.append(voxel)
features = torch.cat(voxels, dim=0)
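`F.pad` with a `(1, 0)` pad on the last dimension is what prepends the batch index as a new leading column of the voxel coordinates; a small sketch with made-up values:

```python
import torch
import torch.nn.functional as F

voxel_coor = torch.tensor([[3, 4, 5], [6, 7, 8]])
batch_idx = 2
# Pad one element on the left of the last dim, filled with the batch index.
coor_pad = F.pad(voxel_coor, (1, 0), mode='constant', value=batch_idx)
# tensor([[2, 3, 4, 5],
#         [2, 6, 7, 8]])
```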

View File

@@ -61,8 +61,10 @@ class BaseRunner(metaclass=ABCMeta):
if not callable(batch_processor):
raise TypeError('batch_processor must be callable, '
f'but got {type(batch_processor)}')
-warnings.warn('batch_processor is deprecated, please implement '
-'train_step() and val_step() in the model instead.')
+warnings.warn(
+'batch_processor is deprecated, please implement '
+'train_step() and val_step() in the model instead.',
+DeprecationWarning)
# raise an error is `batch_processor` is not None and
# `model.train_step()` exists.
if is_module_wrapper(model):

View File

@@ -358,7 +358,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
if backend == 'ceph':
warnings.warn(
-'CephBackend will be deprecated, please use PetrelBackend instead')
+'CephBackend will be deprecated, please use PetrelBackend instead',
+DeprecationWarning)
# CephClient and PetrelBackend have the same prefix 's3://' and the latter
# will be chosen as default. If PetrelBackend can not be instantiated
@@ -389,8 +390,9 @@ def load_from_torchvision(filename, map_location=None):
"""
model_urls = get_torchvision_models()
if filename.startswith('modelzoo://'):
-warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
-'use "torchvision://" instead')
+warnings.warn(
+'The URL scheme of "modelzoo://" is deprecated, please '
+'use "torchvision://" instead', DeprecationWarning)
model_name = filename[11:]
else:
model_name = filename[14:]
@@ -422,8 +424,10 @@ def load_from_openmmlab(filename, map_location=None):
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
-warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
-f'of {prefix_str}{deprecated_urls[model_name]}')
+warnings.warn(
+f'{prefix_str}{model_name} is deprecated in favor '
+f'of {prefix_str}{deprecated_urls[model_name]}',
+DeprecationWarning)
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url

View File

@@ -183,5 +183,6 @@ class Runner(EpochBasedRunner):
def __init__(self, *args, **kwargs):
warnings.warn(
-'Runner was deprecated, please use EpochBasedRunner instead')
+'Runner was deprecated, please use EpochBasedRunner instead',
+DeprecationWarning)
super().__init__(*args, **kwargs)

View File

@@ -227,7 +227,8 @@ def force_fp32(apply_to=None, out_fp16=False):
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
warnings.warning(
'"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
-'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads')
+'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads',
+DeprecationWarning)
_allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)

View File

@@ -231,5 +231,6 @@ class TRTWraper(TRTWrapper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
-warnings.warn('TRTWraper will be deprecated in'
-' future. Please use TRTWrapper instead')
+warnings.warn(
+'TRTWraper will be deprecated in'
+' future. Please use TRTWrapper instead', DeprecationWarning)

View File

@@ -229,7 +229,7 @@ class Config:
if 'reference' in deprecation_info:
warning_msg += ' More information can be found at ' \
f'{deprecation_info["reference"]}'
-warnings.warn(warning_msg)
+warnings.warn(warning_msg, DeprecationWarning)
cfg_text = filename + '\n'
with open(filename, 'r', encoding='utf-8') as f:

View File

@@ -28,10 +28,11 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
return False
def _legacy_zip_load(filename, model_dir, map_location):
-warnings.warn('Falling back to the old format < 1.6. This support will'
-' be deprecated in favor of default zipfile format '
-'introduced in 1.6. Please redo torch.save() to save it '
-'in the new zipfile format.')
+warnings.warn(
+'Falling back to the old format < 1.6. This support will'
+' be deprecated in favor of default zipfile format '
+'introduced in 1.6. Please redo torch.save() to save it '
+'in the new zipfile format.', DeprecationWarning)
# Note: extractall() defaults to overwrite file if exists. No need to
# clean up beforehand. We deliberately don't handle tarfile here
# since our legacy serialization format was in tar.
@@ -84,8 +85,9 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
-warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env '
-'TORCH_HOME instead')
+warnings.warn(
+'TORCH_MODEL_ZOO is deprecated, please use env '
+'TORCH_HOME instead', DeprecationWarning)
if model_dir is None:
torch_home = _get_torch_home()

View File

@@ -315,7 +315,7 @@ def deprecated_api_warning(name_dict, cls_name=None):
warnings.warn(
f'"{src_arg_name}" is deprecated in '
f'`{func_name}`, please use "{dst_arg_name}" '
-'instead')
+'instead', DeprecationWarning)
arg_names[arg_names.index(src_arg_name)] = dst_arg_name
if kwargs:
for src_arg_name, dst_arg_name in name_dict.items():
@@ -333,7 +333,7 @@ def deprecated_api_warning(name_dict, cls_name=None):
warnings.warn(
f'"{src_arg_name}" is deprecated in '
f'`{func_name}`, please use "{dst_arg_name}" '
-'instead')
+'instead', DeprecationWarning)
kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
# apply converted arguments to the decorated method

View File

@@ -251,7 +251,8 @@ class Registry:
warnings.warn(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
-'register_module(name=None, force=False, module=None) instead.')
+'register_module(name=None, force=False, module=None) instead.',
+DeprecationWarning)
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)

View File

@@ -1,5 +1,5 @@
import torch
-from torch.nn.functional import sigmoid
+import torch.nn.functional as F
from mmcv.cnn.bricks import Swish
@@ -7,7 +7,7 @@ from mmcv.cnn.bricks import Swish
def test_swish():
act = Swish()
input = torch.randn(1, 3, 64, 64)
-expected_output = input * sigmoid(input)
+expected_output = input * F.sigmoid(input)
output = act(input)
# test output shape
assert output.shape == expected_output.shape

View File

@@ -333,7 +333,7 @@ class TestPhotometric:
input_img = np.array(
[[[0, 128, 255], [255, 128, 0]], [[0, 128, 255], [255, 128, 0]]],
-dtype=np.float)
+dtype=float)
img = mmcv.lut_transform(input_img, lut_table)
baseline = cv2.LUT(np.array(input_img, dtype=np.uint8), lut_table)
assert np.allclose(img, baseline)

View File

@@ -1,6 +1,5 @@
import numpy as np
import torch
-import torch.nn as nn
import torch.nn.functional as F
@@ -15,7 +14,8 @@ class TestBilinearGridSample(object):
input = torch.rand(1, 1, 20, 20, dtype=dtype)
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
-grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
+grid = F.affine_grid(
+grid, (1, 1, 15, 15), align_corners=align_corners).type_as(input)
grid *= multiplier
out = bilinear_grid_sample(input, grid, align_corners=align_corners)
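Since PyTorch 1.3, `affine_grid` and `grid_sample` emit a UserWarning when `align_corners` is left unspecified, because its default changed from True to False; passing it explicitly keeps the tests warning-free. A minimal standalone sketch:

```python
import torch
import torch.nn.functional as F

theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])
# Spelling out align_corners avoids the "default behavior has changed"
# UserWarning emitted when the argument is omitted.
grid = F.affine_grid(theta, (1, 1, 15, 15), align_corners=False)
out = F.grid_sample(torch.rand(1, 1, 10, 10), grid, align_corners=False)
assert out.shape == (1, 1, 15, 15)
```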

View File

@@ -8,6 +8,7 @@ import onnxruntime as rt
import pytest
import torch
import torch.nn as nn
+import torch.nn.functional as F
from packaging import version
onnx_file = 'tmp.onnx'
@@ -87,10 +88,11 @@ def test_grid_sample(mode, padding_mode, align_corners):
input = torch.rand(1, 1, 10, 10)
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
-grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
+grid = F.affine_grid(
+grid, (1, 1, 15, 15), align_corners=align_corners).type_as(input)
def func(input, grid):
-return nn.functional.grid_sample(
+return F.grid_sample(
input,
grid,
mode=mode,
@@ -110,7 +112,8 @@ def test_bilinear_grid_sample(align_corners):
input = torch.rand(1, 1, 10, 10)
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
-grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
+grid = F.affine_grid(
+grid, (1, 1, 15, 15), align_corners=align_corners).type_as(input)
def func(input, grid):
return bilinear_grid_sample(input, grid, align_corners=align_corners)
@@ -462,7 +465,7 @@ def test_interpolate():
register_extra_symbolics(opset_version)
def func(feat, scale_factor=2):
-out = nn.functional.interpolate(feat, scale_factor=scale_factor)
+out = F.interpolate(feat, scale_factor=scale_factor)
return out
net = WrapFunction(func)

View File

@@ -7,6 +7,7 @@ import onnx
import pytest
import torch
import torch.nn as nn
+import torch.nn.functional as F
try:
from mmcv.tensorrt import (TRTWrapper, is_tensorrt_plugin_loaded, onnx2trt,
@@ -487,11 +488,10 @@ def test_grid_sample(mode, padding_mode, align_corners):
input = torch.rand(1, 1, 10, 10).cuda()
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
-grid = nn.functional.affine_grid(grid,
-(1, 1, 15, 15)).type_as(input).cuda()
+grid = F.affine_grid(grid, (1, 1, 15, 15)).type_as(input).cuda()
def func(input, grid):
-return nn.functional.grid_sample(
+return F.grid_sample(
input,
grid,
mode=mode,

View File

@@ -39,7 +39,7 @@ def test_voxelization(device_type):
device = torch.device(device_type)
# test hard_voxelization on cpu/gpu
-points = torch.tensor(points).contiguous().to(device)
+points = points.contiguous().to(device)
coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)
coors = coors.cpu().detach().numpy()
voxels = voxels.cpu().detach().numpy()
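Calling `torch.tensor(points)` on something that is already a tensor triggers the "To copy construct from a tensor..." UserWarning; when the input is already a tensor, `contiguous()` (or `clone().detach()` if a copy is wanted) is the warning-free spelling. A small sketch:

```python
import torch

points = torch.rand(5, 4)
# torch.tensor(points) would warn and always copy; these do not warn:
same = points.contiguous()        # no copy if already contiguous
copy = points.clone().detach()    # explicit, warning-free copy
assert same.data_ptr() == points.data_ptr()
```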

View File

@@ -57,7 +57,7 @@ def test_build_runner():
@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
def test_epoch_based_runner(runner_class):
-with pytest.warns(UserWarning):
+with pytest.warns(DeprecationWarning):
# batch_processor is deprecated
model = OldStyleModel()
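`pytest.warns` matches on the warning category (including subclasses), and `DeprecationWarning` is not a subclass of `UserWarning`, so once the library switches categories the test expectations have to follow. A minimal sketch with a stand-in function:

```python
import warnings

import pytest


def build_old_style_model():
    # Stand-in for the deprecated batch_processor code path.
    warnings.warn('batch_processor is deprecated, please implement '
                  'train_step() and val_step() in the model instead.',
                  DeprecationWarning)


with pytest.warns(DeprecationWarning):
    build_old_style_model()
```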

View File

@@ -533,6 +533,6 @@ def test_deprecation():
]
for cfg_file in deprecated_cfg_files:
-with pytest.warns(UserWarning):
+with pytest.warns(DeprecationWarning):
cfg = Config.fromfile(cfg_file)
assert cfg.item1 == 'expected'

View File

@@ -96,15 +96,15 @@ def test_registry():
pass
# begin: test old APIs
-with pytest.warns(UserWarning):
+with pytest.warns(DeprecationWarning):
CATS.register_module(SphynxCat)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
-with pytest.warns(UserWarning):
+with pytest.warns(DeprecationWarning):
CATS.register_module(SphynxCat, force=True)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
-with pytest.warns(UserWarning):
+with pytest.warns(DeprecationWarning):
@CATS.register_module
class NewCat:
@@ -112,11 +112,11 @@ def test_registry():
assert CATS.get('NewCat').__name__ == 'NewCat'
-with pytest.warns(UserWarning):
+with pytest.warns(DeprecationWarning):
CATS.deprecated_register_module(SphynxCat, force=True)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
-with pytest.warns(UserWarning):
+with pytest.warns(DeprecationWarning):
@CATS.deprecated_register_module
class CuteCat:
@@ -124,7 +124,7 @@ def test_registry():
assert CATS.get('CuteCat').__name__ == 'CuteCat'
-with pytest.warns(UserWarning):
+with pytest.warns(DeprecationWarning):
@CATS.deprecated_register_module(force=True)
class NewCat2:

View File

@@ -10,7 +10,7 @@ def test_color():
assert mmcv.color_val('green') == (0, 255, 0)
assert mmcv.color_val((1, 2, 3)) == (1, 2, 3)
assert mmcv.color_val(100) == (100, 100, 100)
-assert mmcv.color_val(np.zeros(3, dtype=np.int)) == (0, 0, 0)
+assert mmcv.color_val(np.zeros(3, dtype=int)) == (0, 0, 0)
with pytest.raises(TypeError):
mmcv.color_val([255, 255, 255])
with pytest.raises(TypeError):