Add pyupgrade pre-commit hook (#1937)

* add pyupgrade

* add options for pyupgrade

* minor refinement
pull/1968/head
Zaida Zhou 2022-05-18 11:47:14 +08:00 committed by GitHub
parent c561264d55
commit 45fa3e44a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
110 changed files with 339 additions and 360 deletions

View File

@ -42,7 +42,7 @@ def parse_args():
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)
def train_step(self, *args, **kwargs):
@ -159,13 +159,13 @@ def run(cfg, logger):
def plot_lr_curve(json_file, cfg):
data_dict = dict(LearningRate=[], Momentum=[])
assert os.path.isfile(json_file)
with open(json_file, 'r') as f:
with open(json_file) as f:
for line in f:
log = json.loads(line.strip())
data_dict['LearningRate'].append(log['lr'])
data_dict['Momentum'].append(log['momentum'])
wind_w, wind_h = [int(size) for size in cfg.window_size.split('*')]
wind_w, wind_h = (int(size) for size in cfg.window_size.split('*'))
# if legend is None, use {filename}_{key} as legend
fig, axes = plt.subplots(2, 1, figsize=(wind_w, wind_h))
plt.subplots_adjust(hspace=0.5)

View File

@ -43,7 +43,11 @@ repos:
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://github.com/asottile/pyupgrade
rev: v2.32.1
hooks:
- id: pyupgrade
args: ["--py36-plus"]
- repo: https://github.com/open-mmlab/pre-commit-hooks
rev: v0.2.0 # Use the ref you want to point at
hooks:

View File

@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
sys.path.insert(0, os.path.abspath('../..'))
version_file = '../../mmcv/version.py'
with open(version_file, 'r') as f:
with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec'))
__version__ = locals()['__version__']

View File

@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
sys.path.insert(0, os.path.abspath('../..'))
version_file = '../../mmcv/version.py'
with open(version_file, 'r') as f:
with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec'))
__version__ = locals()['__version__']

View File

@ -14,7 +14,7 @@ from mmcv.utils import get_logger
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)

View File

@ -12,7 +12,7 @@ class AlexNet(nn.Module):
"""
def __init__(self, num_classes=-1):
super(AlexNet, self).__init__()
super().__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),

View File

@ -29,7 +29,7 @@ class Clamp(nn.Module):
"""
def __init__(self, min=-1., max=1.):
super(Clamp, self).__init__()
super().__init__()
self.min = min
self.max = max

View File

@ -38,7 +38,7 @@ class ContextBlock(nn.Module):
ratio,
pooling_type='att',
fusion_types=('channel_add', )):
super(ContextBlock, self).__init__()
super().__init__()
assert pooling_type in ['avg', 'att']
assert isinstance(fusion_types, (list, tuple))
valid_fusion_types = ['channel_add', 'channel_mul']

View File

@ -83,7 +83,7 @@ class ConvModule(nn.Module):
with_spectral_norm=False,
padding_mode='zeros',
order=('conv', 'norm', 'act')):
super(ConvModule, self).__init__()
super().__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
assert act_cfg is None or isinstance(act_cfg, dict)
@ -96,7 +96,7 @@ class ConvModule(nn.Module):
self.with_explicit_padding = padding_mode not in official_padding_mode
self.order = order
assert isinstance(self.order, tuple) and len(self.order) == 3
assert set(order) == set(['conv', 'norm', 'act'])
assert set(order) == {'conv', 'norm', 'act'}
self.with_norm = norm_cfg is not None
self.with_activation = act_cfg is not None

View File

@ -35,7 +35,7 @@ class ConvWS2d(nn.Conv2d):
groups=1,
bias=True,
eps=1e-5):
super(ConvWS2d, self).__init__(
super().__init__(
in_channels,
out_channels,
kernel_size,

View File

@ -59,7 +59,7 @@ class DepthwiseSeparableConvModule(nn.Module):
pw_norm_cfg='default',
pw_act_cfg='default',
**kwargs):
super(DepthwiseSeparableConvModule, self).__init__()
super().__init__()
assert 'groups' not in kwargs, 'groups should not be specified'
# if norm/activation config of depthwise/pointwise ConvModule is not

View File

@ -37,7 +37,7 @@ class DropPath(nn.Module):
"""
def __init__(self, drop_prob=0.1):
super(DropPath, self).__init__()
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):

View File

@ -54,7 +54,7 @@ class GeneralizedAttention(nn.Module):
q_stride=1,
attention_type='1111'):
super(GeneralizedAttention, self).__init__()
super().__init__()
# hard range means local range for non-local operation
self.position_embedding_dim = (

View File

@ -27,7 +27,7 @@ class HSigmoid(nn.Module):
"""
def __init__(self, bias=3.0, divisor=6.0, min_value=0.0, max_value=1.0):
super(HSigmoid, self).__init__()
super().__init__()
warnings.warn(
'In MMCV v1.4.4, we modified the default value of args to align '
'with PyTorch official. Previous Implementation: '

View File

@ -22,7 +22,7 @@ class HSwish(nn.Module):
"""
def __init__(self, inplace=False):
super(HSwish, self).__init__()
super().__init__()
self.act = nn.ReLU6(inplace)
def forward(self, x):

View File

@ -40,7 +40,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
norm_cfg=None,
mode='embedded_gaussian',
**kwargs):
super(_NonLocalNd, self).__init__()
super().__init__()
self.in_channels = in_channels
self.reduction = reduction
self.use_scale = use_scale
@ -228,8 +228,7 @@ class NonLocal1d(_NonLocalNd):
sub_sample=False,
conv_cfg=dict(type='Conv1d'),
**kwargs):
super(NonLocal1d, self).__init__(
in_channels, conv_cfg=conv_cfg, **kwargs)
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
self.sub_sample = sub_sample
@ -262,8 +261,7 @@ class NonLocal2d(_NonLocalNd):
sub_sample=False,
conv_cfg=dict(type='Conv2d'),
**kwargs):
super(NonLocal2d, self).__init__(
in_channels, conv_cfg=conv_cfg, **kwargs)
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
self.sub_sample = sub_sample
@ -293,8 +291,7 @@ class NonLocal3d(_NonLocalNd):
sub_sample=False,
conv_cfg=dict(type='Conv3d'),
**kwargs):
super(NonLocal3d, self).__init__(
in_channels, conv_cfg=conv_cfg, **kwargs)
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
self.sub_sample = sub_sample
if sub_sample:

View File

@ -14,7 +14,7 @@ class Scale(nn.Module):
"""
def __init__(self, scale=1.0):
super(Scale, self).__init__()
super().__init__()
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
def forward(self, x):

View File

@ -19,7 +19,7 @@ class Swish(nn.Module):
"""
def __init__(self):
super(Swish, self).__init__()
super().__init__()
def forward(self, x):
return x * torch.sigmoid(x)

View File

@ -96,7 +96,7 @@ class AdaptivePadding(nn.Module):
"""
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
super(AdaptivePadding, self).__init__()
super().__init__()
assert padding in ('same', 'corner')
kernel_size = to_2tuple(kernel_size)
@ -190,7 +190,7 @@ class PatchEmbed(BaseModule):
norm_cfg=None,
input_size=None,
init_cfg=None):
super(PatchEmbed, self).__init__(init_cfg=init_cfg)
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
if stride is None:
@ -435,7 +435,7 @@ class MultiheadAttention(BaseModule):
init_cfg=None,
batch_first=False,
**kwargs):
super(MultiheadAttention, self).__init__(init_cfg)
super().__init__(init_cfg)
if 'dropout' in kwargs:
warnings.warn(
'The arguments `dropout` in MultiheadAttention '
@ -590,7 +590,7 @@ class FFN(BaseModule):
add_identity=True,
init_cfg=None,
**kwargs):
super(FFN, self).__init__(init_cfg)
super().__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
self.embed_dims = embed_dims
@ -694,12 +694,12 @@ class BaseTransformerLayer(BaseModule):
f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
ffn_cfgs[new_name] = kwargs[ori_name]
super(BaseTransformerLayer, self).__init__(init_cfg)
super().__init__(init_cfg)
self.batch_first = batch_first
assert set(operation_order) & set(
['self_attn', 'norm', 'ffn', 'cross_attn']) == \
assert set(operation_order) & {
'self_attn', 'norm', 'ffn', 'cross_attn'} == \
set(operation_order), f'The operation_order of' \
f' {self.__class__.__name__} should ' \
f'contains all four operation type ' \
@ -880,7 +880,7 @@ class TransformerLayerSequence(BaseModule):
"""
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
super(TransformerLayerSequence, self).__init__(init_cfg)
super().__init__(init_cfg)
if isinstance(transformerlayers, dict):
transformerlayers = [
copy.deepcopy(transformerlayers) for _ in range(num_layers)

View File

@ -26,7 +26,7 @@ class PixelShufflePack(nn.Module):
def __init__(self, in_channels, out_channels, scale_factor,
upsample_kernel):
super(PixelShufflePack, self).__init__()
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.scale_factor = scale_factor

View File

@ -30,7 +30,7 @@ class BasicBlock(nn.Module):
downsample=None,
style='pytorch',
with_cp=False):
super(BasicBlock, self).__init__()
super().__init__()
assert style in ['pytorch', 'caffe']
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes)
@ -77,7 +77,7 @@ class Bottleneck(nn.Module):
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
super().__init__()
assert style in ['pytorch', 'caffe']
if style == 'pytorch':
conv1_stride = 1
@ -218,7 +218,7 @@ class ResNet(nn.Module):
bn_eval=True,
bn_frozen=False,
with_cp=False):
super(ResNet, self).__init__()
super().__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
assert num_stages >= 1 and num_stages <= 4
@ -293,7 +293,7 @@ class ResNet(nn.Module):
return tuple(outs)
def train(self, mode=True):
super(ResNet, self).train(mode)
super().train(mode)
if self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):

View File

@ -277,10 +277,10 @@ def print_model_with_flops(model,
return ', '.join([
params_to_string(
accumulated_num_params, units='M', precision=precision),
'{:.3%} Params'.format(accumulated_num_params / total_params),
f'{accumulated_num_params / total_params:.3%} Params',
flops_to_string(
accumulated_flops_cost, units=units, precision=precision),
'{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
f'{accumulated_flops_cost / total_flops:.3%} FLOPs',
self.original_extra_repr()
])

View File

@ -129,7 +129,7 @@ def _get_bases_name(m):
return [b.__name__ for b in m.__class__.__bases__]
class BaseInit(object):
class BaseInit:
def __init__(self, *, bias=0, bias_prob=None, layer=None):
self.wholemodule = False
@ -461,7 +461,7 @@ class Caffe2XavierInit(KaimingInit):
@INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object):
class PretrainedInit:
"""Initialize module by loading a pretrained model.
Args:

View File

@ -70,7 +70,7 @@ class VGG(nn.Module):
bn_frozen=False,
ceil_mode=False,
with_last_pool=True):
super(VGG, self).__init__()
super().__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for vgg')
assert num_stages >= 1 and num_stages <= 5
@ -157,7 +157,7 @@ class VGG(nn.Module):
return tuple(outs)
def train(self, mode=True):
super(VGG, self).train(mode)
super().train(mode)
if self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):

View File

@ -33,7 +33,7 @@ class MLUDataParallel(MMDataParallel):
"""
def __init__(self, *args, dim=0, **kwargs):
super(MLUDataParallel, self).__init__(*args, dim=dim, **kwargs)
super().__init__(*args, dim=dim, **kwargs)
self.device_ids = [0]
self.src_device_obj = torch.device('mlu:0')

View File

@ -210,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
"""
if not has_method(self._client, 'delete'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `delete` method, please use a higher version or dev'
' branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `delete` method, please use a higher version or dev'
' branch instead.')
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
@ -230,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
if not (has_method(self._client, 'contains')
and has_method(self._client, 'isdir')):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `contains` and `isdir` methods, please use a higher'
'version or dev branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `contains` and `isdir` methods, please use a higher'
'version or dev branch instead.')
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
@ -251,9 +251,9 @@ class PetrelBackend(BaseStorageBackend):
"""
if not has_method(self._client, 'isdir'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `isdir` method, please use a higher version or dev'
' branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `isdir` method, please use a higher version or dev'
' branch instead.')
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
@ -271,9 +271,9 @@ class PetrelBackend(BaseStorageBackend):
"""
if not has_method(self._client, 'contains'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `contains` method, please use a higher version or '
'dev branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `contains` method, please use a higher version or '
'dev branch instead.')
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
@ -366,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
"""
if not has_method(self._client, 'list'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `list` method, please use a higher version or dev'
' branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `list` method, please use a higher version or dev'
' branch instead.')
dir_path = self._map_path(dir_path)
dir_path = self._format_path(dir_path)
@ -549,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
Returns:
str: Expected text reading from ``filepath``.
"""
with open(filepath, 'r', encoding=encoding) as f:
with open(filepath, encoding=encoding) as f:
value_buf = f.read()
return value_buf

View File

@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
return pickle.load(file, **kwargs)
def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path(
filepath, mode='rb', **kwargs)
return super().load_from_path(filepath, mode='rb', **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('protocol', 2)
@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path(
obj, filepath, mode='wb', **kwargs)
super().dump_to_path(obj, filepath, mode='wb', **kwargs)

View File

@ -157,7 +157,7 @@ def imresize_to_multiple(img,
size = _scale_size((w, h), scale_factor)
divisor = to_2tuple(divisor)
size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
resized_img, w_scale, h_scale = imresize(
img,
size,

View File

@ -59,7 +59,7 @@ def _parse_arg(value, desc):
raise RuntimeError(
"ONNX symbolic doesn't know to interpret ListConstruct node")
raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
raise RuntimeError(f'Unexpected node type: {value.node().kind()}')
def _maybe_get_const(value, desc):

View File

@ -86,7 +86,7 @@ class BorderAlign(nn.Module):
"""
def __init__(self, pool_size):
super(BorderAlign, self).__init__()
super().__init__()
self.pool_size = pool_size
def forward(self, input, boxes):

View File

@ -131,7 +131,7 @@ def box_iou_rotated(bboxes1,
if aligned:
ious = bboxes1.new_zeros(rows)
else:
ious = bboxes1.new_zeros((rows * cols))
ious = bboxes1.new_zeros(rows * cols)
if not clockwise:
flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
flip_mat[-1] = -1

View File

@ -85,7 +85,7 @@ carafe_naive = CARAFENaiveFunction.apply
class CARAFENaive(Module):
def __init__(self, kernel_size, group_size, scale_factor):
super(CARAFENaive, self).__init__()
super().__init__()
assert isinstance(kernel_size, int) and isinstance(
group_size, int) and isinstance(scale_factor, int)
@ -195,7 +195,7 @@ class CARAFE(Module):
"""
def __init__(self, kernel_size, group_size, scale_factor):
super(CARAFE, self).__init__()
super().__init__()
assert isinstance(kernel_size, int) and isinstance(
group_size, int) and isinstance(scale_factor, int)
@ -238,7 +238,7 @@ class CARAFEPack(nn.Module):
encoder_kernel=3,
encoder_dilation=1,
compressed_channels=64):
super(CARAFEPack, self).__init__()
super().__init__()
self.channels = channels
self.scale_factor = scale_factor
self.up_kernel = up_kernel

View File

@ -125,7 +125,7 @@ class CornerPool(nn.Module):
}
def __init__(self, mode):
super(CornerPool, self).__init__()
super().__init__()
assert mode in self.pool_functions
self.mode = mode
self.corner_pool = self.pool_functions[mode]

View File

@ -236,7 +236,7 @@ class DeformConv2d(nn.Module):
deform_groups: int = 1,
bias: bool = False,
im2col_step: int = 32) -> None:
super(DeformConv2d, self).__init__()
super().__init__()
assert not bias, \
f'bias={bias} is not supported in DeformConv2d.'
@ -356,7 +356,7 @@ class DeformConv2dPack(DeformConv2d):
_version = 2
def __init__(self, *args, **kwargs):
super(DeformConv2dPack, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],

View File

@ -96,7 +96,7 @@ class DeformRoIPool(nn.Module):
spatial_scale=1.0,
sampling_ratio=0,
gamma=0.1):
super(DeformRoIPool, self).__init__()
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
self.sampling_ratio = int(sampling_ratio)
@ -117,8 +117,7 @@ class DeformRoIPoolPack(DeformRoIPool):
spatial_scale=1.0,
sampling_ratio=0,
gamma=0.1):
super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale,
sampling_ratio, gamma)
super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
self.output_channels = output_channels
self.deform_fc_channels = deform_fc_channels
@ -158,8 +157,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool):
spatial_scale=1.0,
sampling_ratio=0,
gamma=0.1):
super(ModulatedDeformRoIPoolPack,
self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
self.output_channels = output_channels
self.deform_fc_channels = deform_fc_channels

View File

@ -89,7 +89,7 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply
class SigmoidFocalLoss(nn.Module):
def __init__(self, gamma, alpha, weight=None, reduction='mean'):
super(SigmoidFocalLoss, self).__init__()
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.register_buffer('weight', weight)
@ -195,7 +195,7 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply
class SoftmaxFocalLoss(nn.Module):
def __init__(self, gamma, alpha, weight=None, reduction='mean'):
super(SoftmaxFocalLoss, self).__init__()
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.register_buffer('weight', weight)

View File

@ -212,7 +212,7 @@ class FusedBiasLeakyReLU(nn.Module):
"""
def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
super(FusedBiasLeakyReLU, self).__init__()
super().__init__()
self.bias = nn.Parameter(torch.zeros(num_channels))
self.negative_slope = negative_slope

View File

@ -98,13 +98,12 @@ class MaskedConv2d(nn.Conv2d):
dilation=1,
groups=1,
bias=True):
super(MaskedConv2d,
self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
super().__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
def forward(self, input, mask=None):
if mask is None: # fallback to the normal Conv2d
return super(MaskedConv2d, self).forward(input)
return super().forward(input)
else:
return masked_conv2d(input, mask, self.weight, self.bias,
self.padding)

View File

@ -53,7 +53,7 @@ class BaseMergeCell(nn.Module):
input_conv_cfg=None,
input_norm_cfg=None,
upsample_mode='nearest'):
super(BaseMergeCell, self).__init__()
super().__init__()
assert upsample_mode in ['nearest', 'bilinear']
self.with_out_conv = with_out_conv
self.with_input1_conv = with_input1_conv
@ -121,7 +121,7 @@ class BaseMergeCell(nn.Module):
class SumCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs):
super(SumCell, self).__init__(in_channels, out_channels, **kwargs)
super().__init__(in_channels, out_channels, **kwargs)
def _binary_op(self, x1, x2):
return x1 + x2
@ -130,8 +130,7 @@ class SumCell(BaseMergeCell):
class ConcatCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs):
super(ConcatCell, self).__init__(in_channels * 2, out_channels,
**kwargs)
super().__init__(in_channels * 2, out_channels, **kwargs)
def _binary_op(self, x1, x2):
ret = torch.cat([x1, x2], dim=1)

View File

@ -168,7 +168,7 @@ class ModulatedDeformConv2d(nn.Module):
groups=1,
deform_groups=1,
bias=True):
super(ModulatedDeformConv2d, self).__init__()
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
@ -227,7 +227,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
_version = 2
def __init__(self, *args, **kwargs):
super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
@ -239,7 +239,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
self.init_weights()
def init_weights(self):
super(ModulatedDeformConv2dPack, self).init_weights()
super().init_weights()
if hasattr(self, 'conv_offset'):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()

View File

@ -296,7 +296,7 @@ class SimpleRoIAlign(nn.Module):
If True, align the results more perfectly.
"""
super(SimpleRoIAlign, self).__init__()
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
# to be consistent with other RoI ops

View File

@ -72,7 +72,7 @@ psa_mask = PSAMaskFunction.apply
class PSAMask(nn.Module):
def __init__(self, psa_type, mask_size=None):
super(PSAMask, self).__init__()
super().__init__()
assert psa_type in ['collect', 'distribute']
if psa_type == 'collect':
psa_type_enum = 0

View File

@ -116,7 +116,7 @@ class RiRoIAlignRotated(nn.Module):
num_samples=0,
num_orientations=8,
clockwise=False):
super(RiRoIAlignRotated, self).__init__()
super().__init__()
self.out_size = out_size
self.spatial_scale = float(spatial_scale)

View File

@ -181,7 +181,7 @@ class RoIAlign(nn.Module):
pool_mode='avg',
aligned=True,
use_torchvision=False):
super(RoIAlign, self).__init__()
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)

View File

@ -156,7 +156,7 @@ class RoIAlignRotated(nn.Module):
sampling_ratio=0,
aligned=True,
clockwise=False):
super(RoIAlignRotated, self).__init__()
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)

View File

@ -71,7 +71,7 @@ roi_pool = RoIPoolFunction.apply
class RoIPool(nn.Module):
def __init__(self, output_size, spatial_scale=1.0):
super(RoIPool, self).__init__()
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)

View File

@ -64,7 +64,7 @@ class SparseConvolution(SparseModule):
inverse=False,
indice_key=None,
fused_bn=False):
super(SparseConvolution, self).__init__()
super().__init__()
assert groups == 1
if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size] * ndim
@ -217,7 +217,7 @@ class SparseConv2d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConv2d, self).__init__(
super().__init__(
2,
in_channels,
out_channels,
@ -243,7 +243,7 @@ class SparseConv3d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConv3d, self).__init__(
super().__init__(
3,
in_channels,
out_channels,
@ -269,7 +269,7 @@ class SparseConv4d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConv4d, self).__init__(
super().__init__(
4,
in_channels,
out_channels,
@ -295,7 +295,7 @@ class SparseConvTranspose2d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConvTranspose2d, self).__init__(
super().__init__(
2,
in_channels,
out_channels,
@ -322,7 +322,7 @@ class SparseConvTranspose3d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConvTranspose3d, self).__init__(
super().__init__(
3,
in_channels,
out_channels,
@ -345,7 +345,7 @@ class SparseInverseConv2d(SparseConvolution):
kernel_size,
indice_key=None,
bias=True):
super(SparseInverseConv2d, self).__init__(
super().__init__(
2,
in_channels,
out_channels,
@ -364,7 +364,7 @@ class SparseInverseConv3d(SparseConvolution):
kernel_size,
indice_key=None,
bias=True):
super(SparseInverseConv3d, self).__init__(
super().__init__(
3,
in_channels,
out_channels,
@ -387,7 +387,7 @@ class SubMConv2d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SubMConv2d, self).__init__(
super().__init__(
2,
in_channels,
out_channels,
@ -414,7 +414,7 @@ class SubMConv3d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SubMConv3d, self).__init__(
super().__init__(
3,
in_channels,
out_channels,
@ -441,7 +441,7 @@ class SubMConv4d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SubMConv4d, self).__init__(
super().__init__(
4,
in_channels,
out_channels,

View File

@ -86,7 +86,7 @@ class SparseSequential(SparseModule):
"""
def __init__(self, *args, **kwargs):
super(SparseSequential, self).__init__()
super().__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
@ -103,7 +103,7 @@ class SparseSequential(SparseModule):
def __getitem__(self, idx):
if not (-len(self) <= idx < len(self)):
raise IndexError('index {} is out of range'.format(idx))
raise IndexError(f'index {idx} is out of range')
if idx < 0:
idx += len(self)
it = iter(self._modules.values())

View File

@ -29,7 +29,7 @@ class SparseMaxPool(SparseModule):
padding=0,
dilation=1,
subm=False):
super(SparseMaxPool, self).__init__()
super().__init__()
if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size] * ndim
if not isinstance(stride, (list, tuple)):
@ -77,12 +77,10 @@ class SparseMaxPool(SparseModule):
class SparseMaxPool2d(SparseMaxPool):
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding,
dilation)
super().__init__(2, kernel_size, stride, padding, dilation)
class SparseMaxPool3d(SparseMaxPool):
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding,
dilation)
super().__init__(3, kernel_size, stride, padding, dilation)

View File

@ -18,7 +18,7 @@ def scatter_nd(indices, updates, shape):
return ret
class SparseConvTensor(object):
class SparseConvTensor:
def __init__(self,
features,

View File

@ -198,7 +198,7 @@ class SyncBatchNorm(Module):
track_running_stats=True,
group=None,
stats_mode='default'):
super(SyncBatchNorm, self).__init__()
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum

View File

@ -32,7 +32,7 @@ class MMDataParallel(DataParallel):
"""
def __init__(self, *args, dim=0, **kwargs):
super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
super().__init__(*args, dim=dim, **kwargs)
self.dim = dim
def forward(self, *inputs, **kwargs):

View File

@ -18,7 +18,7 @@ class MMDistributedDataParallel(nn.Module):
dim=0,
broadcast_buffers=True,
bucket_cap_mb=25):
super(MMDistributedDataParallel, self).__init__()
super().__init__()
self.module = module
self.dim = dim
self.broadcast_buffers = broadcast_buffers

View File

@ -35,7 +35,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
# NOTE init_cfg can be defined in different levels, but init_cfg
# in low levels has a higher priority.
super(BaseModule, self).__init__()
super().__init__()
# define default value of init_cfg instead of hard code
# in init_weights() function
self._is_init = False

View File

@ -83,8 +83,8 @@ class CheckpointHook(Hook):
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.'))
runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.')
# disable the create_symlink option because some file backends do not
# allow to create a symlink
@ -93,9 +93,9 @@ class CheckpointHook(Hook):
'create_symlink'] and not self.file_client.allow_symlink:
self.args['create_symlink'] = False
warnings.warn(
('create_symlink is set as True by the user but is changed'
'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}'))
'create_symlink is set as True by the user but is changed'
'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}')
else:
self.args['create_symlink'] = self.file_client.allow_symlink

View File

@ -214,8 +214,8 @@ class EvalHook(Hook):
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info(
(f'The best checkpoint will be saved to {self.out_dir} by '
f'{self.file_client.name}'))
f'The best checkpoint will be saved to {self.out_dir} by '
f'{self.file_client.name}')
if self.save_best is not None:
if runner.meta is None:
@ -335,8 +335,8 @@ class EvalHook(Hook):
self.best_ckpt_path):
self.file_client.remove(self.best_ckpt_path)
runner.logger.info(
(f'The previous best checkpoint {self.best_ckpt_path} was '
'removed'))
f'The previous best checkpoint {self.best_ckpt_path} was '
'removed')
best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
self.best_ckpt_path = self.file_client.join_path(

View File

@ -34,8 +34,7 @@ class ClearMLLoggerHook(LoggerHook):
ignore_last=True,
reset_flag=False,
by_epoch=True):
super(ClearMLLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_clearml()
self.init_kwargs = init_kwargs
@ -49,7 +48,7 @@ class ClearMLLoggerHook(LoggerHook):
@master_only
def before_run(self, runner):
super(ClearMLLoggerHook, self).before_run(runner)
super().before_run(runner)
task_kwargs = self.init_kwargs if self.init_kwargs else {}
self.task = self.clearml.Task.init(**task_kwargs)
self.task_logger = self.task.get_logger()

View File

@ -40,8 +40,7 @@ class MlflowLoggerHook(LoggerHook):
ignore_last=True,
reset_flag=False,
by_epoch=True):
super(MlflowLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_mlflow()
self.exp_name = exp_name
self.tags = tags
@ -59,7 +58,7 @@ class MlflowLoggerHook(LoggerHook):
@master_only
def before_run(self, runner):
super(MlflowLoggerHook, self).before_run(runner)
super().before_run(runner)
if self.exp_name is not None:
self.mlflow.set_experiment(self.exp_name)
if self.tags is not None:

View File

@ -49,8 +49,7 @@ class NeptuneLoggerHook(LoggerHook):
with_step=True,
by_epoch=True):
super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_neptune()
self.init_kwargs = init_kwargs
self.with_step = with_step

View File

@ -40,8 +40,7 @@ class PaviLoggerHook(LoggerHook):
reset_flag=False,
by_epoch=True,
img_key='img_info'):
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.init_kwargs = init_kwargs
self.add_graph = add_graph
self.add_last_ckpt = add_last_ckpt
@ -49,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
@master_only
def before_run(self, runner):
super(PaviLoggerHook, self).before_run(runner)
super().before_run(runner)
try:
from pavi import SummaryWriter
except ImportError:

View File

@ -27,8 +27,7 @@ class SegmindLoggerHook(LoggerHook):
ignore_last=True,
reset_flag=False,
by_epoch=True):
super(SegmindLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_segmind()
def import_segmind(self):

View File

@ -28,13 +28,12 @@ class TensorboardLoggerHook(LoggerHook):
ignore_last=True,
reset_flag=False,
by_epoch=True):
super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.log_dir = log_dir
@master_only
def before_run(self, runner):
super(TensorboardLoggerHook, self).before_run(runner)
super().before_run(runner)
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.1')):
try:

View File

@ -62,8 +62,7 @@ class TextLoggerHook(LoggerHook):
out_suffix=('.log.json', '.log', '.py'),
keep_local=True,
file_client_args=None):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.by_epoch = by_epoch
self.time_sec_tot = 0
self.interval_exp_name = interval_exp_name
@ -87,7 +86,7 @@ class TextLoggerHook(LoggerHook):
self.out_dir)
def before_run(self, runner):
super(TextLoggerHook, self).before_run(runner)
super().before_run(runner)
if self.out_dir is not None:
self.file_client = FileClient.infer_client(self.file_client_args,
@ -97,8 +96,8 @@ class TextLoggerHook(LoggerHook):
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info(
(f'Text logs will be saved to {self.out_dir} by '
f'{self.file_client.name} after the training process.'))
f'Text logs will be saved to {self.out_dir} by '
f'{self.file_client.name} after the training process.')
self.start_iter = runner.iter
self.json_log_path = osp.join(runner.work_dir,
@ -242,15 +241,15 @@ class TextLoggerHook(LoggerHook):
local_filepath = osp.join(runner.work_dir, filename)
out_filepath = self.file_client.join_path(
self.out_dir, filename)
with open(local_filepath, 'r') as f:
with open(local_filepath) as f:
self.file_client.put_text(f.read(), out_filepath)
runner.logger.info(
(f'The file {local_filepath} has been uploaded to '
f'{out_filepath}.'))
f'The file {local_filepath} has been uploaded to '
f'{out_filepath}.')
if not self.keep_local:
os.remove(local_filepath)
runner.logger.info(
(f'{local_filepath} was removed due to the '
'`self.keep_local=False`'))
f'{local_filepath} was removed due to the '
'`self.keep_local=False`')

View File

@ -57,8 +57,7 @@ class WandbLoggerHook(LoggerHook):
with_step=True,
log_artifact=True,
out_suffix=('.log.json', '.log', '.py')):
super(WandbLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_wandb()
self.init_kwargs = init_kwargs
self.commit = commit
@ -76,7 +75,7 @@ class WandbLoggerHook(LoggerHook):
@master_only
def before_run(self, runner):
super(WandbLoggerHook, self).before_run(runner)
super().before_run(runner)
if self.wandb is None:
self.import_wandb()
if self.init_kwargs:

View File

@ -157,7 +157,7 @@ class LrUpdaterHook(Hook):
class FixedLrUpdaterHook(LrUpdaterHook):
def __init__(self, **kwargs):
super(FixedLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
return base_lr
@ -188,7 +188,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
self.step = step
self.gamma = gamma
self.min_lr = min_lr
super(StepLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
@ -215,7 +215,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, **kwargs):
self.gamma = gamma
super(ExpLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
@ -228,7 +228,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
def __init__(self, power=1., min_lr=0., **kwargs):
self.power = power
self.min_lr = min_lr
super(PolyLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
if self.by_epoch:
@ -247,7 +247,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, power=1., **kwargs):
self.gamma = gamma
self.power = power
super(InvLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
@ -269,7 +269,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
assert (min_lr is None) ^ (min_lr_ratio is None)
self.min_lr = min_lr
self.min_lr_ratio = min_lr_ratio
super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
if self.by_epoch:
@ -317,7 +317,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
self.start_percent = start_percent
self.min_lr = min_lr
self.min_lr_ratio = min_lr_ratio
super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
if self.by_epoch:
@ -367,7 +367,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
self.restart_weights = restart_weights
assert (len(self.periods) == len(self.restart_weights)
), 'periods and restart_weights should have the same length.'
super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
self.cumulative_periods = [
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
@ -484,10 +484,10 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
assert not by_epoch, \
'currently only support "by_epoch" = False'
super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
super().__init__(by_epoch, **kwargs)
def before_run(self, runner):
super(CyclicLrUpdaterHook, self).before_run(runner)
super().before_run(runner)
# initiate lr_phases
# total lr_phases are separated as up and down
self.max_iter_per_phase = runner.max_iters // self.cyclic_times
@ -598,7 +598,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
self.final_div_factor = final_div_factor
self.three_phase = three_phase
self.lr_phases = [] # init lr_phases
super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def before_run(self, runner):
if hasattr(self, 'total_steps'):
@ -668,7 +668,7 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
assert (min_lr is None) ^ (min_lr_ratio is None)
self.min_lr = min_lr
self.min_lr_ratio = min_lr_ratio
super(LinearAnnealingLrUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_lr(self, runner, base_lr):
if self.by_epoch:

View File

@ -176,7 +176,7 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
self.step = step
self.gamma = gamma
self.min_momentum = min_momentum
super(StepMomentumUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_momentum(self, runner, base_momentum):
progress = runner.epoch if self.by_epoch else runner.iter
@ -214,7 +214,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
assert (min_momentum is None) ^ (min_momentum_ratio is None)
self.min_momentum = min_momentum
self.min_momentum_ratio = min_momentum_ratio
super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_momentum(self, runner, base_momentum):
if self.by_epoch:
@ -247,7 +247,7 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
assert (min_momentum is None) ^ (min_momentum_ratio is None)
self.min_momentum = min_momentum
self.min_momentum_ratio = min_momentum_ratio
super(LinearAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def get_momentum(self, runner, base_momentum):
if self.by_epoch:
@ -328,10 +328,10 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
# currently only support by_epoch=False
assert not by_epoch, \
'currently only support "by_epoch" = False'
super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
super().__init__(by_epoch, **kwargs)
def before_run(self, runner):
super(CyclicMomentumUpdaterHook, self).before_run(runner)
super().before_run(runner)
# initiate momentum_phases
# total momentum_phases are separated as up and down
max_iter_per_phase = runner.max_iters // self.cyclic_times
@ -439,7 +439,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
self.anneal_func = annealing_linear
self.three_phase = three_phase
self.momentum_phases = [] # init momentum_phases
super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
super().__init__(**kwargs)
def before_run(self, runner):
if isinstance(runner.optimizer, dict):

View File

@ -110,7 +110,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
"""
def __init__(self, cumulative_iters=1, **kwargs):
super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
super().__init__(**kwargs)
assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
f'cumulative_iters only accepts positive int, but got ' \
@ -297,8 +297,7 @@ if (TORCH_VERSION != 'parrots'
"""
def __init__(self, *args, **kwargs):
super(GradientCumulativeFp16OptimizerHook,
self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
def after_train_iter(self, runner):
if not self.initialized:
@ -490,8 +489,7 @@ else:
iters gradient cumulating."""
def __init__(self, *args, **kwargs):
super(GradientCumulativeFp16OptimizerHook,
self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
def after_train_iter(self, runner):
if not self.initialized:

View File

@ -263,7 +263,7 @@ class IterBasedRunner(BaseRunner):
if log_config is not None:
for info in log_config['hooks']:
info.setdefault('by_epoch', False)
super(IterBasedRunner, self).register_training_hooks(
super().register_training_hooks(
lr_config=lr_config,
momentum_config=momentum_config,
optimizer_config=optimizer_config,

View File

@ -54,7 +54,7 @@ def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
msg += reset_style
warnings.warn(msg)
device = torch.device('cuda:{}'.format(device_id))
device = torch.device(f'cuda:{device_id}')
# create builder and network
logger = trt.Logger(log_level)
builder = trt.Builder(logger)
@ -209,7 +209,7 @@ class TRTWrapper(torch.nn.Module):
msg += reset_style
warnings.warn(msg)
super(TRTWrapper, self).__init__()
super().__init__()
self.engine = engine
if isinstance(self.engine, str):
self.engine = load_trt_engine(engine)

View File

@ -39,7 +39,7 @@ class ConfigDict(Dict):
def __getattr__(self, name):
try:
value = super(ConfigDict, self).__getattr__(name)
value = super().__getattr__(name)
except KeyError:
ex = AttributeError(f"'{self.__class__.__name__}' object has no "
f"attribute '{name}'")
@ -96,7 +96,7 @@ class Config:
@staticmethod
def _validate_py_syntax(filename):
with open(filename, 'r', encoding='utf-8') as f:
with open(filename, encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
content = f.read()
try:
@ -116,7 +116,7 @@ class Config:
fileBasename=file_basename,
fileBasenameNoExtension=file_basename_no_extension,
fileExtname=file_extname)
with open(filename, 'r', encoding='utf-8') as f:
with open(filename, encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
for key, value in support_templates.items():
@ -130,7 +130,7 @@ class Config:
def _pre_substitute_base_vars(filename, temp_config_name):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with open(filename, 'r', encoding='utf-8') as f:
with open(filename, encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
base_var_dict = {}
@ -183,7 +183,7 @@ class Config:
check_file_exist(filename)
fileExtname = osp.splitext(filename)[1]
if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
raise IOError('Only py/yml/yaml/json type are supported now!')
raise OSError('Only py/yml/yaml/json type are supported now!')
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
@ -236,7 +236,7 @@ class Config:
warnings.warn(warning_msg, DeprecationWarning)
cfg_text = filename + '\n'
with open(filename, 'r', encoding='utf-8') as f:
with open(filename, encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
cfg_text += f.read()
@ -356,7 +356,7 @@ class Config:
:obj:`Config`: Config obj.
"""
if file_format not in ['.py', '.json', '.yaml', '.yml']:
raise IOError('Only py/yml/yaml/json type are supported now!')
raise OSError('Only py/yml/yaml/json type are supported now!')
if file_format != '.py' and 'dict(' in cfg_str:
# check if users specify a wrong suffix for python
warnings.warn(
@ -396,16 +396,16 @@ class Config:
if isinstance(filename, Path):
filename = str(filename)
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super(Config, self).__setattr__('_filename', filename)
super().__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super().__setattr__('_filename', filename)
if cfg_text:
text = cfg_text
elif filename:
with open(filename, 'r') as f:
with open(filename) as f:
text = f.read()
else:
text = ''
super(Config, self).__setattr__('_text', text)
super().__setattr__('_text', text)
@property
def filename(self):
@ -556,9 +556,9 @@ class Config:
def __setstate__(self, state):
_cfg_dict, _filename, _text = state
super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
super(Config, self).__setattr__('_filename', _filename)
super(Config, self).__setattr__('_text', _text)
super().__setattr__('_cfg_dict', _cfg_dict)
super().__setattr__('_filename', _filename)
super().__setattr__('_text', _text)
def dump(self, file=None):
"""Dumps config into a file or returns a string representation of the
@ -584,7 +584,7 @@ class Config:
will be dumped. Defaults to None.
"""
import mmcv
cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
if file is None:
if self.filename is None or self.filename.endswith('.py'):
return self.pretty_text
@ -638,8 +638,8 @@ class Config:
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
super(Config, self).__setattr__(
cfg_dict = super().__getattribute__('_cfg_dict')
super().__setattr__(
'_cfg_dict',
Config._merge_a_into_b(
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))

View File

@ -6,7 +6,7 @@ class TimerError(Exception):
def __init__(self, message):
self.message = message
super(TimerError, self).__init__(message)
super().__init__(message)
class Timer:

View File

@ -40,10 +40,10 @@ def flowread(flow_or_path: Union[np.ndarray, str],
try:
header = f.read(4).decode('utf-8')
except Exception:
raise IOError(f'Invalid flow file: {flow_or_path}')
raise OSError(f'Invalid flow file: {flow_or_path}')
else:
if header != 'PIEH':
raise IOError(f'Invalid flow file: {flow_or_path}, '
raise OSError(f'Invalid flow file: {flow_or_path}, '
'header does not contain PIEH')
w = np.fromfile(f, np.int32, 1).squeeze()
@ -53,7 +53,7 @@ def flowread(flow_or_path: Union[np.ndarray, str],
assert concat_axis in [0, 1]
cat_flow = imread(flow_or_path, flag='unchanged')
if cat_flow.ndim != 2:
raise IOError(
raise OSError(
f'{flow_or_path} is not a valid quantized flow file, '
f'its dimension is {cat_flow.ndim}.')
assert cat_flow.shape[concat_axis] % 2 == 0
@ -86,7 +86,7 @@ def flowwrite(flow: np.ndarray,
"""
if not quantize:
with open(filename, 'wb') as f:
f.write('PIEH'.encode('utf-8'))
f.write(b'PIEH')
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
flow = flow.astype(np.float32)
flow.tofile(f)
@ -146,7 +146,7 @@ def dequantize_flow(dx: np.ndarray,
assert dx.shape == dy.shape
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
dx, dy = (dequantize(d, -max_val, max_val, 255) for d in [dx, dy])
if denorm:
dx *= dx.shape[1]

View File

@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
from typing import Optional, Union
import numpy as np

View File

@ -39,7 +39,7 @@ def choose_requirement(primary, secondary):
def get_version():
version_file = 'mmcv/version.py'
with open(version_file, 'r', encoding='utf-8') as f:
with open(version_file, encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
@ -94,12 +94,11 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True):
yield info
def parse_require_file(fpath):
with open(fpath, 'r') as f:
with open(fpath) as f:
for line in f.readlines():
line = line.strip()
if line and not line.startswith('#'):
for info in parse_line(line):
yield info
yield from parse_line(line)
def gen_packages_items():
if exists(require_fpath):

View File

@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
import numpy as np
import pytest

View File

@ -23,7 +23,7 @@ class ExampleConv(nn.Module):
groups=1,
bias=True,
norm_cfg=None):
super(ExampleConv, self).__init__()
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size

View File

@ -202,21 +202,22 @@ class TestFileClient:
# test `list_dir_or_file`
with build_temporary_directory() as tmp_dir:
# 1. list directories and files
assert set(disk_backend.list_dir_or_file(tmp_dir)) == set(
['dir1', 'dir2', 'text1.txt', 'text2.txt'])
assert set(disk_backend.list_dir_or_file(tmp_dir)) == {
'dir1', 'dir2', 'text1.txt', 'text2.txt'
}
# 2. list directories and files recursively
assert set(disk_backend.list_dir_or_file(
tmp_dir, recursive=True)) == set([
tmp_dir, recursive=True)) == {
'dir1',
osp.join('dir1', 'text3.txt'), 'dir2',
osp.join('dir2', 'dir3'),
osp.join('dir2', 'dir3', 'text4.txt'),
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
])
}
# 3. only list directories
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
tmp_dir, list_file=False)) == {'dir1', 'dir2'}
with pytest.raises(
TypeError,
match='`suffix` should be None when `list_dir` is True'):
@ -227,30 +228,30 @@ class TestFileClient:
# 4. only list directories recursively
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_file=False, recursive=True)) == set(
['dir1', 'dir2',
osp.join('dir2', 'dir3')])
tmp_dir, list_file=False, recursive=True)) == {
'dir1', 'dir2',
osp.join('dir2', 'dir3')
}
# 5. only list files
assert set(disk_backend.list_dir_or_file(
tmp_dir, list_dir=False)) == set(['text1.txt', 'text2.txt'])
tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
# 6. only list files recursively
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_dir=False, recursive=True)) == set([
tmp_dir, list_dir=False, recursive=True)) == {
osp.join('dir1', 'text3.txt'),
osp.join('dir2', 'dir3', 'text4.txt'),
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
])
}
# 7. only list files ending with suffix
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_dir=False,
suffix='.txt')) == set(['text1.txt', 'text2.txt'])
suffix='.txt')) == {'text1.txt', 'text2.txt'}
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_dir=False,
suffix=('.txt',
'.jpg'))) == set(['text1.txt', 'text2.txt'])
suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
with pytest.raises(
TypeError,
match='`suffix` must be a string or tuple of strings'):
@ -260,22 +261,22 @@ class TestFileClient:
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_dir=False, suffix='.txt',
recursive=True)) == set([
recursive=True)) == {
osp.join('dir1', 'text3.txt'),
osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
'text2.txt'
])
}
# 7. only list files ending with suffix
assert set(
disk_backend.list_dir_or_file(
tmp_dir,
list_dir=False,
suffix=('.txt', '.jpg'),
recursive=True)) == set([
recursive=True)) == {
osp.join('dir1', 'text3.txt'),
osp.join('dir2', 'dir3', 'text4.txt'),
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
])
}
@patch('ceph.S3Client', MockS3Client)
def test_ceph_backend(self):
@ -463,21 +464,21 @@ class TestFileClient:
with build_temporary_directory() as tmp_dir:
# 1. list directories and files
assert set(petrel_backend.list_dir_or_file(tmp_dir)) == set(
['dir1', 'dir2', 'text1.txt', 'text2.txt'])
assert set(petrel_backend.list_dir_or_file(tmp_dir)) == {
'dir1', 'dir2', 'text1.txt', 'text2.txt'
}
# 2. list directories and files recursively
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, recursive=True)) == set([
'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2',
'/'.join(('dir2', 'dir3')), '/'.join(
petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == {
'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join(
('dir2', 'dir3')), '/'.join(
('dir2', 'dir3', 'text4.txt')), '/'.join(
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
])
}
# 3. only list directories
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
tmp_dir, list_file=False)) == {'dir1', 'dir2'}
with pytest.raises(
TypeError,
match=('`list_dir` should be False when `suffix` is not '
@ -489,31 +490,30 @@ class TestFileClient:
# 4. only list directories recursively
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_file=False, recursive=True)) == set(
['dir1', 'dir2', '/'.join(('dir2', 'dir3'))])
tmp_dir, list_file=False, recursive=True)) == {
'dir1', 'dir2', '/'.join(('dir2', 'dir3'))
}
# 5. only list files
assert set(
petrel_backend.list_dir_or_file(tmp_dir,
list_dir=False)) == set(
['text1.txt', 'text2.txt'])
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
# 6. only list files recursively
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False, recursive=True)) == set([
tmp_dir, list_dir=False, recursive=True)) == {
'/'.join(('dir1', 'text3.txt')), '/'.join(
('dir2', 'dir3', 'text4.txt')), '/'.join(
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
])
}
# 7. only list files ending with suffix
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False,
suffix='.txt')) == set(['text1.txt', 'text2.txt'])
suffix='.txt')) == {'text1.txt', 'text2.txt'}
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False,
suffix=('.txt',
'.jpg'))) == set(['text1.txt', 'text2.txt'])
suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
with pytest.raises(
TypeError,
match='`suffix` must be a string or tuple of strings'):
@ -523,22 +523,22 @@ class TestFileClient:
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False, suffix='.txt',
recursive=True)) == set([
recursive=True)) == {
'/'.join(('dir1', 'text3.txt')), '/'.join(
('dir2', 'dir3', 'text4.txt')), 'text1.txt',
'text2.txt'
])
}
# 7. only list files ending with suffix
assert set(
petrel_backend.list_dir_or_file(
tmp_dir,
list_dir=False,
suffix=('.txt', '.jpg'),
recursive=True)) == set([
recursive=True)) == {
'/'.join(('dir1', 'text3.txt')), '/'.join(
('dir2', 'dir3', 'text4.txt')), '/'.join(
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
])
}
@patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
@patch('mc.pyvector', MagicMock)

View File

@ -128,7 +128,7 @@ def test_register_handler():
assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
mmcv.dump(content, tmp_filename)
with open(tmp_filename, 'r') as f:
with open(tmp_filename) as f:
written = f.read()
os.remove(tmp_filename)
assert written == '\n' + content

View File

@ -6,7 +6,7 @@ import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
class TestBBox(object):
class TestBBox:
def _test_bbox_overlaps(self, device='cpu', dtype=torch.float):
from mmcv.ops import bbox_overlaps

View File

@ -4,7 +4,7 @@ import torch
import torch.nn.functional as F
class TestBilinearGridSample(object):
class TestBilinearGridSample:
def _test_bilinear_grid_sample(self,
dtype=torch.float,

View File

@ -4,7 +4,7 @@ import pytest
import torch
class TestBoxIoURotated(object):
class TestBoxIoURotated:
def test_box_iou_rotated_cpu(self):
from mmcv.ops import box_iou_rotated

View File

@ -3,7 +3,7 @@ import torch
from torch.autograd import gradcheck
class TestCarafe(object):
class TestCarafe:
def test_carafe_naive_gradcheck(self):
if not torch.cuda.is_available():

View File

@ -15,7 +15,7 @@ class Loss(nn.Module):
return torch.mean(input - target)
class TestCrissCrossAttention(object):
class TestCrissCrossAttention:
def test_cc_attention(self):
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

View File

@ -35,7 +35,7 @@ gt_offset_bias_grad = [1.44, -0.72, 0., 0., -0.10, -0.08, -0.54, -0.54],
gt_deform_weight_grad = [[[[3.62, 0.], [0.40, 0.18]]]]
class TestDeformconv(object):
class TestDeformconv:
def _test_deformconv(self,
dtype=torch.float,

View File

@ -35,7 +35,7 @@ outputs = [([[[[1, 1.25], [1.5, 1.75]]]], [[[[3.0625, 0.4375],
0.00390625]]]])]
class TestDeformRoIPool(object):
class TestDeformRoIPool:
def test_deform_roi_pool_gradcheck(self):
if not torch.cuda.is_available():

View File

@ -37,7 +37,7 @@ sigmoid_outputs = [(0.13562961, [[-0.00657264, 0.11185755],
[-0.02462499, 0.08277918, 0.18050370]])]
class Testfocalloss(object):
class Testfocalloss:
def _test_softmax(self, dtype=torch.float):
if not torch.cuda.is_available():

View File

@ -10,7 +10,7 @@ except ImportError:
_USING_PARROTS = False
class TestFusedBiasLeakyReLU(object):
class TestFusedBiasLeakyReLU:
@classmethod
def setup_class(cls):

View File

@ -2,7 +2,7 @@
import torch
class TestInfo(object):
class TestInfo:
def test_info(self):
if not torch.cuda.is_available():

View File

@ -2,7 +2,7 @@
import torch
class TestMaskedConv2d(object):
class TestMaskedConv2d:
def test_masked_conv2d(self):
if not torch.cuda.is_available():

View File

@ -37,7 +37,7 @@ dcn_offset_b_grad = [
]
class TestMdconv(object):
class TestMdconv:
def _test_mdconv(self, dtype=torch.float, device='cuda'):
if not torch.cuda.is_available() and device == 'cuda':

View File

@ -55,7 +55,7 @@ def test_forward_multi_scale_deformable_attn_pytorch():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
S = sum([(H * W).item() for H, W in shapes])
S = sum((H * W).item() for H, W in shapes)
torch.manual_seed(3)
value = torch.rand(N, S, M, D) * 0.01
@ -78,7 +78,7 @@ def test_forward_equal_with_pytorch_double():
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
S = sum((H * W).item() for H, W in shapes)
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
@ -111,7 +111,7 @@ def test_forward_equal_with_pytorch_float():
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
S = sum((H * W).item() for H, W in shapes)
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
@ -155,7 +155,7 @@ def test_gradient_numerical(channels,
shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
S = sum((H * W).item() for H, W in shapes)
value = torch.rand(N, S, M, channels).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()

View File

@ -6,7 +6,7 @@ import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
class Testnms(object):
class Testnms:
@pytest.mark.parametrize('device', [
pytest.param(
@ -129,8 +129,7 @@ class Testnms(object):
scores = tensor_dets[:, 4]
nms_keep_inds = nms(boxes.contiguous(), scores.contiguous(),
iou_thr)[1]
assert set([g[0].item()
for g in np_groups]) == set(nms_keep_inds.tolist())
assert {g[0].item() for g in np_groups} == set(nms_keep_inds.tolist())
# non empty tensor input
tensor_dets = torch.from_numpy(np_dets)

View File

@ -33,7 +33,7 @@ def run_before_and_after_test():
class WrapFunction(nn.Module):
def __init__(self, wrapped_function):
super(WrapFunction, self).__init__()
super().__init__()
self.wrapped_function = wrapped_function
def forward(self, *args, **kwargs):
@ -662,7 +662,7 @@ def test_cummax_cummin(key, opset=11):
input_list = [
# arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
torch.rand((2, 3, 4, 1, 5)),
torch.rand((1)),
torch.rand(1),
torch.rand((2, 0, 1)), # tensor.numel() is 0
torch.FloatTensor(), # empty tensor
]

View File

@ -15,7 +15,7 @@ class Loss(nn.Module):
return torch.mean(input - target)
class TestPSAMask(object):
class TestPSAMask:
def test_psa_mask_collect(self):
if not torch.cuda.is_available():

View File

@ -29,7 +29,7 @@ outputs = [([[[[1., 2.], [3., 4.]]]], [[[[1., 1.], [1., 1.]]]]),
1.]]]])]
class TestRoiPool(object):
class TestRoiPool:
def test_roipool_gradcheck(self):
if not torch.cuda.is_available():

View File

@ -14,7 +14,7 @@ else:
import re
class TestSyncBN(object):
class TestSyncBN:
def dist_init(self):
rank = int(os.environ['SLURM_PROCID'])

View File

@ -30,7 +30,7 @@ if not is_tensorrt_plugin_loaded():
class WrapFunction(nn.Module):
def __init__(self, wrapped_function):
super(WrapFunction, self).__init__()
super().__init__()
self.wrapped_function = wrapped_function
def forward(self, *args, **kwargs):
@ -576,7 +576,7 @@ def test_cummin_cummax(func: Callable):
input_list = [
# arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
torch.rand((2, 3, 4, 1, 5)).cuda(),
torch.rand((1)).cuda()
torch.rand(1).cuda()
]
input_names = ['input']
@ -756,7 +756,7 @@ def test_corner_pool(mode):
class CornerPoolWrapper(CornerPool):
def __init__(self, mode):
super(CornerPoolWrapper, self).__init__(mode)
super().__init__(mode)
def forward(self, x):
# no use `torch.cummax`, instead `corner_pool` is used

View File

@ -10,7 +10,7 @@ except ImportError:
_USING_PARROTS = False
class TestUpFirDn2d(object):
class TestUpFirDn2d:
"""Unit test for UpFirDn2d.
Here, we just test the basic case of upsample version. More gerneal tests

View File

@ -96,8 +96,8 @@ def test_voxelization_nondeterministic():
coors_all = dynamic_voxelization.forward(points)
coors_all = coors_all.cpu().detach().numpy().tolist()
coors_set = set([tuple(c) for c in coors])
coors_all_set = set([tuple(c) for c in coors_all])
coors_set = {tuple(c) for c in coors}
coors_all_set = {tuple(c) for c in coors_all}
assert len(coors_set) == len(coors)
assert len(coors_set - coors_all_set) == 0
@ -112,7 +112,7 @@ def test_voxelization_nondeterministic():
for c, ps, n in zip(coors, voxels, num_points_per_voxel):
ideal_voxel_points_set = coors_points_dict[tuple(c)]
voxel_points_set = set([tuple(p) for p in ps[:n]])
voxel_points_set = {tuple(p) for p in ps[:n]}
assert len(voxel_points_set) == n
if n < max_num_points:
assert voxel_points_set == ideal_voxel_points_set
@ -133,7 +133,7 @@ def test_voxelization_nondeterministic():
voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
coors = coors.cpu().detach().numpy().tolist()
coors_set = set([tuple(c) for c in coors])
coors_all_set = set([tuple(c) for c in coors_all])
coors_set = {tuple(c) for c in coors}
coors_all_set = {tuple(c) for c in coors_all}
assert len(coors_set) == len(coors) == len(coors_all_set)

View File

@ -63,7 +63,7 @@ def test_is_module_wrapper():
# test module wrapper registry
@MODULE_WRAPPERS.register_module()
class ModuleWrapper(object):
class ModuleWrapper:
def __init__(self, module):
self.module = module

Some files were not shown because too many files have changed in this diff Show More