mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix some warnings in unittest (#1522)
* [Fix] fix some warnings in unittest * [Impl] standardize some warnings * [Fix] fix warning type in test_deprecation * [Fix] fix warning type * [Fix] continue fixing * [Fix] fix some details * [Fix] fix docstring * [Fix] del useless statement * [Fix] keep compatibility for torch < 1.5.0pull/1605/head
parent
f367d621c6
commit
fb486b96fd
|
@ -131,7 +131,7 @@ class GeneralizedAttention(nn.Module):
|
|||
|
||||
max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
|
||||
local_constraint_map = np.ones(
|
||||
(max_len, max_len, max_len_kv, max_len_kv), dtype=np.int)
|
||||
(max_len, max_len, max_len_kv, max_len_kv), dtype=int)
|
||||
for iy in range(max_len):
|
||||
for ix in range(max_len):
|
||||
local_constraint_map[
|
||||
|
|
|
@ -436,10 +436,11 @@ class MultiheadAttention(BaseModule):
|
|||
**kwargs):
|
||||
super(MultiheadAttention, self).__init__(init_cfg)
|
||||
if 'dropout' in kwargs:
|
||||
warnings.warn('The arguments `dropout` in MultiheadAttention '
|
||||
'has been deprecated, now you can separately '
|
||||
'set `attn_drop`(float), proj_drop(float), '
|
||||
'and `dropout_layer`(dict) ')
|
||||
warnings.warn(
|
||||
'The arguments `dropout` in MultiheadAttention '
|
||||
'has been deprecated, now you can separately '
|
||||
'set `attn_drop`(float), proj_drop(float), '
|
||||
'and `dropout_layer`(dict) ', DeprecationWarning)
|
||||
attn_drop = kwargs['dropout']
|
||||
dropout_layer['drop_prob'] = kwargs.pop('dropout')
|
||||
|
||||
|
@ -689,7 +690,7 @@ class BaseTransformerLayer(BaseModule):
|
|||
f'The arguments `{ori_name}` in BaseTransformerLayer '
|
||||
f'has been deprecated, now you should set `{new_name}` '
|
||||
f'and other FFN related arguments '
|
||||
f'to a dict named `ffn_cfgs`. ')
|
||||
f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
|
||||
ffn_cfgs[new_name] = kwargs[ori_name]
|
||||
|
||||
super(BaseTransformerLayer, self).__init__(init_cfg)
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
# SOFTWARE.
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
@ -502,9 +503,8 @@ def batch_counter_hook(module, input, output):
|
|||
input = input[0]
|
||||
batch_size = len(input)
|
||||
else:
|
||||
pass
|
||||
print('Warning! No positional inputs found for a module, '
|
||||
'assuming batch size is 1.')
|
||||
warnings.warn('No positional inputs found for a module, '
|
||||
'assuming batch size is 1.')
|
||||
module.__batch_counter__ += batch_size
|
||||
|
||||
|
||||
|
@ -530,9 +530,9 @@ def remove_batch_counter_hook_function(module):
|
|||
def add_flops_counter_variable_or_reset(module):
|
||||
if is_supported_instance(module):
|
||||
if hasattr(module, '__flops__') or hasattr(module, '__params__'):
|
||||
print('Warning: variables __flops__ or __params__ are already '
|
||||
'defined for the module' + type(module).__name__ +
|
||||
' ptflops can affect your code!')
|
||||
warnings.warn('variables __flops__ or __params__ are already '
|
||||
'defined for the module' + type(module).__name__ +
|
||||
' ptflops can affect your code!')
|
||||
module.__flops__ = 0
|
||||
module.__params__ = get_model_parameters_number(module)
|
||||
|
||||
|
|
|
@ -64,7 +64,8 @@ class CephBackend(BaseStorageBackend):
|
|||
raise ImportError('Please install ceph to enable CephBackend.')
|
||||
|
||||
warnings.warn(
|
||||
'CephBackend will be deprecated, please use PetrelBackend instead')
|
||||
'CephBackend will be deprecated, please use PetrelBackend instead',
|
||||
DeprecationWarning)
|
||||
self._client = ceph.S3Client()
|
||||
assert isinstance(path_mapping, dict) or path_mapping is None
|
||||
self.path_mapping = path_mapping
|
||||
|
|
|
@ -12,7 +12,8 @@ class Conv2d_deprecated(Conv2d):
|
|||
super().__init__(*args, **kwargs)
|
||||
warnings.warn(
|
||||
'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in'
|
||||
' the future. Please import them from "mmcv.cnn" instead')
|
||||
' the future. Please import them from "mmcv.cnn" instead',
|
||||
DeprecationWarning)
|
||||
|
||||
|
||||
class ConvTranspose2d_deprecated(ConvTranspose2d):
|
||||
|
@ -22,7 +23,7 @@ class ConvTranspose2d_deprecated(ConvTranspose2d):
|
|||
warnings.warn(
|
||||
'Importing ConvTranspose2d wrapper from "mmcv.ops" will be '
|
||||
'deprecated in the future. Please import them from "mmcv.cnn" '
|
||||
'instead')
|
||||
'instead', DeprecationWarning)
|
||||
|
||||
|
||||
class MaxPool2d_deprecated(MaxPool2d):
|
||||
|
@ -31,7 +32,8 @@ class MaxPool2d_deprecated(MaxPool2d):
|
|||
super().__init__(*args, **kwargs)
|
||||
warnings.warn(
|
||||
'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in'
|
||||
' the future. Please import them from "mmcv.cnn" instead')
|
||||
' the future. Please import them from "mmcv.cnn" instead',
|
||||
DeprecationWarning)
|
||||
|
||||
|
||||
class Linear_deprecated(Linear):
|
||||
|
@ -40,4 +42,5 @@ class Linear_deprecated(Linear):
|
|||
super().__init__(*args, **kwargs)
|
||||
warnings.warn(
|
||||
'Importing Linear wrapper from "mmcv.ops" will be deprecated in'
|
||||
' the future. Please import them from "mmcv.cnn" instead')
|
||||
' the future. Please import them from "mmcv.cnn" instead',
|
||||
DeprecationWarning)
|
||||
|
|
|
@ -188,7 +188,7 @@ class FusedBiasLeakyReLUFunction(Function):
|
|||
|
||||
|
||||
class FusedBiasLeakyReLU(nn.Module):
|
||||
"""Fused bias leaky ReLU.
|
||||
r"""Fused bias leaky ReLU.
|
||||
|
||||
This function is introduced in the StyleGAN2:
|
||||
`Analyzing and Improving the Image Quality of StyleGAN
|
||||
|
@ -197,8 +197,8 @@ class FusedBiasLeakyReLU(nn.Module):
|
|||
The bias term comes from the convolution operation. In addition, to keep
|
||||
the variance of the feature map or gradients unchanged, they also adopt a
|
||||
scale similarly with Kaiming initialization. However, since the
|
||||
:math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the
|
||||
final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
|
||||
:math:`1+{alpha}^2` is too small, we can just ignore it. Therefore, the
|
||||
final scale is just :math:`\sqrt{2}`. Of course, you may change it with
|
||||
your own scale.
|
||||
|
||||
TODO: Implement the CPU version.
|
||||
|
@ -224,7 +224,7 @@ class FusedBiasLeakyReLU(nn.Module):
|
|||
|
||||
|
||||
def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
|
||||
"""Fused bias leaky ReLU function.
|
||||
r"""Fused bias leaky ReLU function.
|
||||
|
||||
This function is introduced in the StyleGAN2:
|
||||
`Analyzing and Improving the Image Quality of StyleGAN
|
||||
|
@ -233,8 +233,8 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
|
|||
The bias term comes from the convolution operation. In addition, to keep
|
||||
the variance of the feature map or gradients unchanged, they also adopt a
|
||||
scale similarly with Kaiming initialization. However, since the
|
||||
:math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the
|
||||
final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
|
||||
:math:`1+{alpha}^2` is too small, we can just ignore it. Therefore, the
|
||||
final scale is just :math:`\sqrt{2}`. Of course, you may change it with
|
||||
your own scale.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -61,7 +61,6 @@ class MaskedConv2dFunction(Function):
|
|||
kernel_w=kernel_w,
|
||||
pad_h=pad_h,
|
||||
pad_w=pad_w)
|
||||
|
||||
masked_output = torch.addmm(1, bias[:, None], 1,
|
||||
weight.view(out_channel, -1), data_col)
|
||||
ext_module.masked_col2im_forward(
|
||||
|
|
|
@ -389,7 +389,7 @@ def nms_match(dets, iou_threshold):
|
|||
if isinstance(dets, torch.Tensor):
|
||||
return [dets.new_tensor(m, dtype=torch.long) for m in matched]
|
||||
else:
|
||||
return [np.array(m, dtype=np.int) for m in matched]
|
||||
return [np.array(m, dtype=int) for m in matched]
|
||||
|
||||
|
||||
def nms_rotated(dets, scores, iou_threshold, labels=None):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
|
@ -119,8 +120,7 @@ class DynamicScatter(nn.Module):
|
|||
inds = torch.where(coors[:, 0] == i)
|
||||
voxel, voxel_coor = self.forward_single(
|
||||
points[inds], coors[inds][:, 1:])
|
||||
coor_pad = nn.functional.pad(
|
||||
voxel_coor, (1, 0), mode='constant', value=i)
|
||||
coor_pad = F.pad(voxel_coor, (1, 0), mode='constant', value=i)
|
||||
voxel_coors.append(coor_pad)
|
||||
voxels.append(voxel)
|
||||
features = torch.cat(voxels, dim=0)
|
||||
|
|
|
@ -61,8 +61,10 @@ class BaseRunner(metaclass=ABCMeta):
|
|||
if not callable(batch_processor):
|
||||
raise TypeError('batch_processor must be callable, '
|
||||
f'but got {type(batch_processor)}')
|
||||
warnings.warn('batch_processor is deprecated, please implement '
|
||||
'train_step() and val_step() in the model instead.')
|
||||
warnings.warn(
|
||||
'batch_processor is deprecated, please implement '
|
||||
'train_step() and val_step() in the model instead.',
|
||||
DeprecationWarning)
|
||||
# raise an error is `batch_processor` is not None and
|
||||
# `model.train_step()` exists.
|
||||
if is_module_wrapper(model):
|
||||
|
|
|
@ -358,7 +358,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
|
|||
|
||||
if backend == 'ceph':
|
||||
warnings.warn(
|
||||
'CephBackend will be deprecated, please use PetrelBackend instead')
|
||||
'CephBackend will be deprecated, please use PetrelBackend instead',
|
||||
DeprecationWarning)
|
||||
|
||||
# CephClient and PetrelBackend have the same prefix 's3://' and the latter
|
||||
# will be chosen as default. If PetrelBackend can not be instantiated
|
||||
|
@ -389,8 +390,9 @@ def load_from_torchvision(filename, map_location=None):
|
|||
"""
|
||||
model_urls = get_torchvision_models()
|
||||
if filename.startswith('modelzoo://'):
|
||||
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
|
||||
'use "torchvision://" instead')
|
||||
warnings.warn(
|
||||
'The URL scheme of "modelzoo://" is deprecated, please '
|
||||
'use "torchvision://" instead', DeprecationWarning)
|
||||
model_name = filename[11:]
|
||||
else:
|
||||
model_name = filename[14:]
|
||||
|
@ -422,8 +424,10 @@ def load_from_openmmlab(filename, map_location=None):
|
|||
|
||||
deprecated_urls = get_deprecated_model_names()
|
||||
if model_name in deprecated_urls:
|
||||
warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
|
||||
f'of {prefix_str}{deprecated_urls[model_name]}')
|
||||
warnings.warn(
|
||||
f'{prefix_str}{model_name} is deprecated in favor '
|
||||
f'of {prefix_str}{deprecated_urls[model_name]}',
|
||||
DeprecationWarning)
|
||||
model_name = deprecated_urls[model_name]
|
||||
model_url = model_urls[model_name]
|
||||
# check if is url
|
||||
|
|
|
@ -183,5 +183,6 @@ class Runner(EpochBasedRunner):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
warnings.warn(
|
||||
'Runner was deprecated, please use EpochBasedRunner instead')
|
||||
'Runner was deprecated, please use EpochBasedRunner instead',
|
||||
DeprecationWarning)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
|
@ -227,7 +227,8 @@ def force_fp32(apply_to=None, out_fp16=False):
|
|||
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
|
||||
warnings.warning(
|
||||
'"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
|
||||
'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads')
|
||||
'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads',
|
||||
DeprecationWarning)
|
||||
_allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
|
||||
|
||||
|
||||
|
|
|
@ -231,5 +231,6 @@ class TRTWraper(TRTWrapper):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
warnings.warn('TRTWraper will be deprecated in'
|
||||
' future. Please use TRTWrapper instead')
|
||||
warnings.warn(
|
||||
'TRTWraper will be deprecated in'
|
||||
' future. Please use TRTWrapper instead', DeprecationWarning)
|
||||
|
|
|
@ -229,7 +229,7 @@ class Config:
|
|||
if 'reference' in deprecation_info:
|
||||
warning_msg += ' More information can be found at ' \
|
||||
f'{deprecation_info["reference"]}'
|
||||
warnings.warn(warning_msg)
|
||||
warnings.warn(warning_msg, DeprecationWarning)
|
||||
|
||||
cfg_text = filename + '\n'
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
|
|
|
@ -28,10 +28,11 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
|
|||
return False
|
||||
|
||||
def _legacy_zip_load(filename, model_dir, map_location):
|
||||
warnings.warn('Falling back to the old format < 1.6. This support will'
|
||||
' be deprecated in favor of default zipfile format '
|
||||
'introduced in 1.6. Please redo torch.save() to save it '
|
||||
'in the new zipfile format.')
|
||||
warnings.warn(
|
||||
'Falling back to the old format < 1.6. This support will'
|
||||
' be deprecated in favor of default zipfile format '
|
||||
'introduced in 1.6. Please redo torch.save() to save it '
|
||||
'in the new zipfile format.', DeprecationWarning)
|
||||
# Note: extractall() defaults to overwrite file if exists. No need to
|
||||
# clean up beforehand. We deliberately don't handle tarfile here
|
||||
# since our legacy serialization format was in tar.
|
||||
|
@ -84,8 +85,9 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
|
|||
"""
|
||||
# Issue warning to move data if old env is set
|
||||
if os.getenv('TORCH_MODEL_ZOO'):
|
||||
warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env '
|
||||
'TORCH_HOME instead')
|
||||
warnings.warn(
|
||||
'TORCH_MODEL_ZOO is deprecated, please use env '
|
||||
'TORCH_HOME instead', DeprecationWarning)
|
||||
|
||||
if model_dir is None:
|
||||
torch_home = _get_torch_home()
|
||||
|
|
|
@ -315,7 +315,7 @@ def deprecated_api_warning(name_dict, cls_name=None):
|
|||
warnings.warn(
|
||||
f'"{src_arg_name}" is deprecated in '
|
||||
f'`{func_name}`, please use "{dst_arg_name}" '
|
||||
'instead')
|
||||
'instead', DeprecationWarning)
|
||||
arg_names[arg_names.index(src_arg_name)] = dst_arg_name
|
||||
if kwargs:
|
||||
for src_arg_name, dst_arg_name in name_dict.items():
|
||||
|
@ -333,7 +333,7 @@ def deprecated_api_warning(name_dict, cls_name=None):
|
|||
warnings.warn(
|
||||
f'"{src_arg_name}" is deprecated in '
|
||||
f'`{func_name}`, please use "{dst_arg_name}" '
|
||||
'instead')
|
||||
'instead', DeprecationWarning)
|
||||
kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
|
||||
|
||||
# apply converted arguments to the decorated method
|
||||
|
|
|
@ -251,7 +251,8 @@ class Registry:
|
|||
warnings.warn(
|
||||
'The old API of register_module(module, force=False) '
|
||||
'is deprecated and will be removed, please use the new API '
|
||||
'register_module(name=None, force=False, module=None) instead.')
|
||||
'register_module(name=None, force=False, module=None) instead.',
|
||||
DeprecationWarning)
|
||||
if cls is None:
|
||||
return partial(self.deprecated_register_module, force=force)
|
||||
self._register_module(cls, force=force)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from torch.nn.functional import sigmoid
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmcv.cnn.bricks import Swish
|
||||
|
||||
|
@ -7,7 +7,7 @@ from mmcv.cnn.bricks import Swish
|
|||
def test_swish():
|
||||
act = Swish()
|
||||
input = torch.randn(1, 3, 64, 64)
|
||||
expected_output = input * sigmoid(input)
|
||||
expected_output = input * F.sigmoid(input)
|
||||
output = act(input)
|
||||
# test output shape
|
||||
assert output.shape == expected_output.shape
|
||||
|
|
|
@ -333,7 +333,7 @@ class TestPhotometric:
|
|||
|
||||
input_img = np.array(
|
||||
[[[0, 128, 255], [255, 128, 0]], [[0, 128, 255], [255, 128, 0]]],
|
||||
dtype=np.float)
|
||||
dtype=float)
|
||||
img = mmcv.lut_transform(input_img, lut_table)
|
||||
baseline = cv2.LUT(np.array(input_img, dtype=np.uint8), lut_table)
|
||||
assert np.allclose(img, baseline)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
|
@ -15,7 +14,8 @@ class TestBilinearGridSample(object):
|
|||
|
||||
input = torch.rand(1, 1, 20, 20, dtype=dtype)
|
||||
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
|
||||
grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
|
||||
grid = F.affine_grid(
|
||||
grid, (1, 1, 15, 15), align_corners=align_corners).type_as(input)
|
||||
grid *= multiplier
|
||||
|
||||
out = bilinear_grid_sample(input, grid, align_corners=align_corners)
|
||||
|
|
|
@ -8,6 +8,7 @@ import onnxruntime as rt
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from packaging import version
|
||||
|
||||
onnx_file = 'tmp.onnx'
|
||||
|
@ -87,10 +88,11 @@ def test_grid_sample(mode, padding_mode, align_corners):
|
|||
|
||||
input = torch.rand(1, 1, 10, 10)
|
||||
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
|
||||
grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
|
||||
grid = F.affine_grid(
|
||||
grid, (1, 1, 15, 15), align_corners=align_corners).type_as(input)
|
||||
|
||||
def func(input, grid):
|
||||
return nn.functional.grid_sample(
|
||||
return F.grid_sample(
|
||||
input,
|
||||
grid,
|
||||
mode=mode,
|
||||
|
@ -110,7 +112,8 @@ def test_bilinear_grid_sample(align_corners):
|
|||
|
||||
input = torch.rand(1, 1, 10, 10)
|
||||
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
|
||||
grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
|
||||
grid = F.affine_grid(
|
||||
grid, (1, 1, 15, 15), align_corners=align_corners).type_as(input)
|
||||
|
||||
def func(input, grid):
|
||||
return bilinear_grid_sample(input, grid, align_corners=align_corners)
|
||||
|
@ -462,7 +465,7 @@ def test_interpolate():
|
|||
register_extra_symbolics(opset_version)
|
||||
|
||||
def func(feat, scale_factor=2):
|
||||
out = nn.functional.interpolate(feat, scale_factor=scale_factor)
|
||||
out = F.interpolate(feat, scale_factor=scale_factor)
|
||||
return out
|
||||
|
||||
net = WrapFunction(func)
|
||||
|
|
|
@ -7,6 +7,7 @@ import onnx
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from mmcv.tensorrt import (TRTWrapper, is_tensorrt_plugin_loaded, onnx2trt,
|
||||
|
@ -487,11 +488,10 @@ def test_grid_sample(mode, padding_mode, align_corners):
|
|||
|
||||
input = torch.rand(1, 1, 10, 10).cuda()
|
||||
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
|
||||
grid = nn.functional.affine_grid(grid,
|
||||
(1, 1, 15, 15)).type_as(input).cuda()
|
||||
grid = F.affine_grid(grid, (1, 1, 15, 15)).type_as(input).cuda()
|
||||
|
||||
def func(input, grid):
|
||||
return nn.functional.grid_sample(
|
||||
return F.grid_sample(
|
||||
input,
|
||||
grid,
|
||||
mode=mode,
|
||||
|
|
|
@ -39,7 +39,7 @@ def test_voxelization(device_type):
|
|||
device = torch.device(device_type)
|
||||
|
||||
# test hard_voxelization on cpu/gpu
|
||||
points = torch.tensor(points).contiguous().to(device)
|
||||
points = points.contiguous().to(device)
|
||||
coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)
|
||||
coors = coors.cpu().detach().numpy()
|
||||
voxels = voxels.cpu().detach().numpy()
|
||||
|
|
|
@ -57,7 +57,7 @@ def test_build_runner():
|
|||
@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
|
||||
def test_epoch_based_runner(runner_class):
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
with pytest.warns(DeprecationWarning):
|
||||
# batch_processor is deprecated
|
||||
model = OldStyleModel()
|
||||
|
||||
|
|
|
@ -533,6 +533,6 @@ def test_deprecation():
|
|||
]
|
||||
|
||||
for cfg_file in deprecated_cfg_files:
|
||||
with pytest.warns(UserWarning):
|
||||
with pytest.warns(DeprecationWarning):
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
assert cfg.item1 == 'expected'
|
||||
|
|
|
@ -96,15 +96,15 @@ def test_registry():
|
|||
pass
|
||||
|
||||
# begin: test old APIs
|
||||
with pytest.warns(UserWarning):
|
||||
with pytest.warns(DeprecationWarning):
|
||||
CATS.register_module(SphynxCat)
|
||||
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
with pytest.warns(DeprecationWarning):
|
||||
CATS.register_module(SphynxCat, force=True)
|
||||
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
with pytest.warns(DeprecationWarning):
|
||||
|
||||
@CATS.register_module
|
||||
class NewCat:
|
||||
|
@ -112,11 +112,11 @@ def test_registry():
|
|||
|
||||
assert CATS.get('NewCat').__name__ == 'NewCat'
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
with pytest.warns(DeprecationWarning):
|
||||
CATS.deprecated_register_module(SphynxCat, force=True)
|
||||
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
with pytest.warns(DeprecationWarning):
|
||||
|
||||
@CATS.deprecated_register_module
|
||||
class CuteCat:
|
||||
|
@ -124,7 +124,7 @@ def test_registry():
|
|||
|
||||
assert CATS.get('CuteCat').__name__ == 'CuteCat'
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
with pytest.warns(DeprecationWarning):
|
||||
|
||||
@CATS.deprecated_register_module(force=True)
|
||||
class NewCat2:
|
||||
|
|
|
@ -10,7 +10,7 @@ def test_color():
|
|||
assert mmcv.color_val('green') == (0, 255, 0)
|
||||
assert mmcv.color_val((1, 2, 3)) == (1, 2, 3)
|
||||
assert mmcv.color_val(100) == (100, 100, 100)
|
||||
assert mmcv.color_val(np.zeros(3, dtype=np.int)) == (0, 0, 0)
|
||||
assert mmcv.color_val(np.zeros(3, dtype=int)) == (0, 0, 0)
|
||||
with pytest.raises(TypeError):
|
||||
mmcv.color_val([255, 255, 255])
|
||||
with pytest.raises(TypeError):
|
||||
|
|
Loading…
Reference in New Issue