Add pyupgrade pre-commit hook (#1937)

* add pyupgrade

* add options for pyupgrade

* minor refinement
Zaida Zhou 2022-05-18 11:47:14 +08:00 committed by GitHub
parent c561264d55
commit 45fa3e44a2
110 changed files with 339 additions and 360 deletions

@@ -42,7 +42,7 @@ def parse_args():
 class SimpleModel(nn.Module):

     def __init__(self):
-        super(SimpleModel, self).__init__()
+        super().__init__()
         self.conv = nn.Conv2d(1, 1, 1)

     def train_step(self, *args, **kwargs):
@@ -159,13 +159,13 @@ def run(cfg, logger):
 def plot_lr_curve(json_file, cfg):
     data_dict = dict(LearningRate=[], Momentum=[])
     assert os.path.isfile(json_file)
-    with open(json_file, 'r') as f:
+    with open(json_file) as f:
         for line in f:
             log = json.loads(line.strip())
             data_dict['LearningRate'].append(log['lr'])
             data_dict['Momentum'].append(log['momentum'])

-    wind_w, wind_h = [int(size) for size in cfg.window_size.split('*')]
+    wind_w, wind_h = (int(size) for size in cfg.window_size.split('*'))
     # if legend is None, use {filename}_{key} as legend
     fig, axes = plt.subplots(2, 1, figsize=(wind_w, wind_h))
     plt.subplots_adjust(hspace=0.5)

@@ -43,7 +43,11 @@ repos:
     hooks:
       - id: docformatter
         args: ["--in-place", "--wrap-descriptions", "79"]
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v2.32.1
+    hooks:
+      - id: pyupgrade
+        args: ["--py36-plus"]
   - repo: https://github.com/open-mmlab/pre-commit-hooks
     rev: v0.2.0  # Use the ref you want to point at
     hooks:
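The hook's effect: with --py36-plus, pyupgrade rewrites pre-Python-3.6 idioms in place whenever the hook fires (or across the whole tree via `pre-commit run pyupgrade --all-files`), and the rest of this commit is exactly those mechanical rewrites. A minimal runnable sketch of the patterns involved; the file and class names here are illustrative, not taken from the repo:

# sketch.py -- legacy idioms that pyupgrade --py36-plus rewrites in place
class Toy(object):                           # -> class Toy:

    def __init__(self, name):
        super(Toy, self).__init__()          # -> super().__init__()
        self.name = name

    def describe(self):
        return '{} ready'.format(self.name)  # -> f'{self.name} ready'


VALID = set(['conv', 'norm', 'act'])         # -> VALID = {'conv', 'norm', 'act'}

with open(__file__, 'r') as f:               # -> with open(__file__) as f:
    first_line = f.readline()

print(Toy('demo').describe(), VALID, first_line.strip())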

@@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
 sys.path.insert(0, os.path.abspath('../..'))

 version_file = '../../mmcv/version.py'
-with open(version_file, 'r') as f:
+with open(version_file) as f:
     exec(compile(f.read(), version_file, 'exec'))
 __version__ = locals()['__version__']

@@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
 sys.path.insert(0, os.path.abspath('../..'))

 version_file = '../../mmcv/version.py'
-with open(version_file, 'r') as f:
+with open(version_file) as f:
     exec(compile(f.read(), version_file, 'exec'))
 __version__ = locals()['__version__']

@@ -14,7 +14,7 @@ from mmcv.utils import get_logger
 class Model(nn.Module):

     def __init__(self):
-        super(Model, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(3, 6, 5)
         self.pool = nn.MaxPool2d(2, 2)
         self.conv2 = nn.Conv2d(6, 16, 5)

@@ -12,7 +12,7 @@ class AlexNet(nn.Module):
     """

     def __init__(self, num_classes=-1):
-        super(AlexNet, self).__init__()
+        super().__init__()
         self.num_classes = num_classes
         self.features = nn.Sequential(
             nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),

@@ -29,7 +29,7 @@ class Clamp(nn.Module):
     """

     def __init__(self, min=-1., max=1.):
-        super(Clamp, self).__init__()
+        super().__init__()
         self.min = min
         self.max = max

@@ -38,7 +38,7 @@ class ContextBlock(nn.Module):
                  ratio,
                  pooling_type='att',
                  fusion_types=('channel_add', )):
-        super(ContextBlock, self).__init__()
+        super().__init__()
         assert pooling_type in ['avg', 'att']
         assert isinstance(fusion_types, (list, tuple))
         valid_fusion_types = ['channel_add', 'channel_mul']

@@ -83,7 +83,7 @@ class ConvModule(nn.Module):
                  with_spectral_norm=False,
                  padding_mode='zeros',
                  order=('conv', 'norm', 'act')):
-        super(ConvModule, self).__init__()
+        super().__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
         assert norm_cfg is None or isinstance(norm_cfg, dict)
         assert act_cfg is None or isinstance(act_cfg, dict)
@@ -96,7 +96,7 @@ class ConvModule(nn.Module):
         self.with_explicit_padding = padding_mode not in official_padding_mode
         self.order = order
         assert isinstance(self.order, tuple) and len(self.order) == 3
-        assert set(order) == set(['conv', 'norm', 'act'])
+        assert set(order) == {'conv', 'norm', 'act'}

         self.with_norm = norm_cfg is not None
         self.with_activation = act_cfg is not None

@@ -35,7 +35,7 @@ class ConvWS2d(nn.Conv2d):
                  groups=1,
                  bias=True,
                  eps=1e-5):
-        super(ConvWS2d, self).__init__(
+        super().__init__(
             in_channels,
             out_channels,
             kernel_size,

@@ -59,7 +59,7 @@ class DepthwiseSeparableConvModule(nn.Module):
                  pw_norm_cfg='default',
                  pw_act_cfg='default',
                  **kwargs):
-        super(DepthwiseSeparableConvModule, self).__init__()
+        super().__init__()
         assert 'groups' not in kwargs, 'groups should not be specified'

         # if norm/activation config of depthwise/pointwise ConvModule is not

@@ -37,7 +37,7 @@ class DropPath(nn.Module):
     """

     def __init__(self, drop_prob=0.1):
-        super(DropPath, self).__init__()
+        super().__init__()
         self.drop_prob = drop_prob

     def forward(self, x):

@@ -54,7 +54,7 @@ class GeneralizedAttention(nn.Module):
                  q_stride=1,
                  attention_type='1111'):
-        super(GeneralizedAttention, self).__init__()
+        super().__init__()

         # hard range means local range for non-local operation
         self.position_embedding_dim = (

@@ -27,7 +27,7 @@ class HSigmoid(nn.Module):
     """

     def __init__(self, bias=3.0, divisor=6.0, min_value=0.0, max_value=1.0):
-        super(HSigmoid, self).__init__()
+        super().__init__()
         warnings.warn(
             'In MMCV v1.4.4, we modified the default value of args to align '
             'with PyTorch official. Previous Implementation: '

@@ -22,7 +22,7 @@ class HSwish(nn.Module):
     """

     def __init__(self, inplace=False):
-        super(HSwish, self).__init__()
+        super().__init__()
         self.act = nn.ReLU6(inplace)

     def forward(self, x):

@@ -40,7 +40,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
                  norm_cfg=None,
                  mode='embedded_gaussian',
                  **kwargs):
-        super(_NonLocalNd, self).__init__()
+        super().__init__()
         self.in_channels = in_channels
         self.reduction = reduction
         self.use_scale = use_scale
@@ -228,8 +228,7 @@ class NonLocal1d(_NonLocalNd):
                  sub_sample=False,
                  conv_cfg=dict(type='Conv1d'),
                  **kwargs):
-        super(NonLocal1d, self).__init__(
-            in_channels, conv_cfg=conv_cfg, **kwargs)
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)

         self.sub_sample = sub_sample
@@ -262,8 +261,7 @@ class NonLocal2d(_NonLocalNd):
                  sub_sample=False,
                  conv_cfg=dict(type='Conv2d'),
                  **kwargs):
-        super(NonLocal2d, self).__init__(
-            in_channels, conv_cfg=conv_cfg, **kwargs)
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)

         self.sub_sample = sub_sample
@@ -293,8 +291,7 @@ class NonLocal3d(_NonLocalNd):
                  sub_sample=False,
                  conv_cfg=dict(type='Conv3d'),
                  **kwargs):
-        super(NonLocal3d, self).__init__(
-            in_channels, conv_cfg=conv_cfg, **kwargs)
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
         self.sub_sample = sub_sample

         if sub_sample:

@@ -14,7 +14,7 @@ class Scale(nn.Module):
     """

     def __init__(self, scale=1.0):
-        super(Scale, self).__init__()
+        super().__init__()
         self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

     def forward(self, x):

@@ -19,7 +19,7 @@ class Swish(nn.Module):
     """

     def __init__(self):
-        super(Swish, self).__init__()
+        super().__init__()

     def forward(self, x):
         return x * torch.sigmoid(x)

@@ -96,7 +96,7 @@ class AdaptivePadding(nn.Module):
     """

     def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
-        super(AdaptivePadding, self).__init__()
+        super().__init__()
         assert padding in ('same', 'corner')

         kernel_size = to_2tuple(kernel_size)
@@ -190,7 +190,7 @@ class PatchEmbed(BaseModule):
                  norm_cfg=None,
                  input_size=None,
                  init_cfg=None):
-        super(PatchEmbed, self).__init__(init_cfg=init_cfg)
+        super().__init__(init_cfg=init_cfg)
         self.embed_dims = embed_dims

         if stride is None:
@@ -435,7 +435,7 @@ class MultiheadAttention(BaseModule):
                  init_cfg=None,
                  batch_first=False,
                  **kwargs):
-        super(MultiheadAttention, self).__init__(init_cfg)
+        super().__init__(init_cfg)
         if 'dropout' in kwargs:
             warnings.warn(
                 'The arguments `dropout` in MultiheadAttention '
@@ -590,7 +590,7 @@ class FFN(BaseModule):
                  add_identity=True,
                  init_cfg=None,
                  **kwargs):
-        super(FFN, self).__init__(init_cfg)
+        super().__init__(init_cfg)
         assert num_fcs >= 2, 'num_fcs should be no less ' \
             f'than 2. got {num_fcs}.'
         self.embed_dims = embed_dims
@@ -694,12 +694,12 @@ class BaseTransformerLayer(BaseModule):
                     f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
                 ffn_cfgs[new_name] = kwargs[ori_name]

-        super(BaseTransformerLayer, self).__init__(init_cfg)
+        super().__init__(init_cfg)

         self.batch_first = batch_first

-        assert set(operation_order) & set(
-            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
+        assert set(operation_order) & {
+            'self_attn', 'norm', 'ffn', 'cross_attn'} == \
             set(operation_order), f'The operation_order of' \
             f' {self.__class__.__name__} should ' \
             f'contains all four operation type ' \
@@ -880,7 +880,7 @@ class TransformerLayerSequence(BaseModule):
     """

     def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
-        super(TransformerLayerSequence, self).__init__(init_cfg)
+        super().__init__(init_cfg)
         if isinstance(transformerlayers, dict):
             transformerlayers = [
                 copy.deepcopy(transformerlayers) for _ in range(num_layers)

@@ -26,7 +26,7 @@ class PixelShufflePack(nn.Module):

     def __init__(self, in_channels, out_channels, scale_factor,
                  upsample_kernel):
-        super(PixelShufflePack, self).__init__()
+        super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.scale_factor = scale_factor

@@ -30,7 +30,7 @@ class BasicBlock(nn.Module):
                  downsample=None,
                  style='pytorch',
                  with_cp=False):
-        super(BasicBlock, self).__init__()
+        super().__init__()
         assert style in ['pytorch', 'caffe']
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
         self.bn1 = nn.BatchNorm2d(planes)
@@ -77,7 +77,7 @@ class Bottleneck(nn.Module):
         If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
         it is "caffe", the stride-two layer is the first 1x1 conv layer.
         """
-        super(Bottleneck, self).__init__()
+        super().__init__()
         assert style in ['pytorch', 'caffe']
         if style == 'pytorch':
             conv1_stride = 1
@@ -218,7 +218,7 @@ class ResNet(nn.Module):
                  bn_eval=True,
                  bn_frozen=False,
                  with_cp=False):
-        super(ResNet, self).__init__()
+        super().__init__()
         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for resnet')
         assert num_stages >= 1 and num_stages <= 4
@@ -293,7 +293,7 @@ class ResNet(nn.Module):
         return tuple(outs)

     def train(self, mode=True):
-        super(ResNet, self).train(mode)
+        super().train(mode)
         if self.bn_eval:
             for m in self.modules():
                 if isinstance(m, nn.BatchNorm2d):

@@ -277,10 +277,10 @@ def print_model_with_flops(model,
             return ', '.join([
                 params_to_string(
                     accumulated_num_params, units='M', precision=precision),
-                '{:.3%} Params'.format(accumulated_num_params / total_params),
+                f'{accumulated_num_params / total_params:.3%} Params',
                 flops_to_string(
                     accumulated_flops_cost, units=units, precision=precision),
-                '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
+                f'{accumulated_flops_cost / total_flops:.3%} FLOPs',
                 self.original_extra_repr()
             ])

@@ -129,7 +129,7 @@ def _get_bases_name(m):
     return [b.__name__ for b in m.__class__.__bases__]


-class BaseInit(object):
+class BaseInit:

     def __init__(self, *, bias=0, bias_prob=None, layer=None):
         self.wholemodule = False
@@ -461,7 +461,7 @@ class Caffe2XavierInit(KaimingInit):

 @INITIALIZERS.register_module(name='Pretrained')
-class PretrainedInit(object):
+class PretrainedInit:
     """Initialize module by loading a pretrained model.

     Args:

@@ -70,7 +70,7 @@ class VGG(nn.Module):
                  bn_frozen=False,
                  ceil_mode=False,
                  with_last_pool=True):
-        super(VGG, self).__init__()
+        super().__init__()
         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for vgg')
         assert num_stages >= 1 and num_stages <= 5
@@ -157,7 +157,7 @@ class VGG(nn.Module):
         return tuple(outs)

     def train(self, mode=True):
-        super(VGG, self).train(mode)
+        super().train(mode)
         if self.bn_eval:
             for m in self.modules():
                 if isinstance(m, nn.BatchNorm2d):

@@ -33,7 +33,7 @@ class MLUDataParallel(MMDataParallel):
     """

     def __init__(self, *args, dim=0, **kwargs):
-        super(MLUDataParallel, self).__init__(*args, dim=dim, **kwargs)
+        super().__init__(*args, dim=dim, **kwargs)
         self.device_ids = [0]
         self.src_device_obj = torch.device('mlu:0')

@@ -210,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
         """
         if not has_method(self._client, 'delete'):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `delete` method, please use a higher version or dev'
-                 ' branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `delete` method, please use a higher version or dev'
+                ' branch instead.')

         filepath = self._map_path(filepath)
         filepath = self._format_path(filepath)
@@ -230,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
         if not (has_method(self._client, 'contains')
                 and has_method(self._client, 'isdir')):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `contains` and `isdir` methods, please use a higher'
-                 'version or dev branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `contains` and `isdir` methods, please use a higher'
+                'version or dev branch instead.')

         filepath = self._map_path(filepath)
         filepath = self._format_path(filepath)
@@ -251,9 +251,9 @@ class PetrelBackend(BaseStorageBackend):
         """
         if not has_method(self._client, 'isdir'):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `isdir` method, please use a higher version or dev'
-                 ' branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `isdir` method, please use a higher version or dev'
+                ' branch instead.')

         filepath = self._map_path(filepath)
         filepath = self._format_path(filepath)
@@ -271,9 +271,9 @@ class PetrelBackend(BaseStorageBackend):
         """
         if not has_method(self._client, 'contains'):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `contains` method, please use a higher version or '
-                 'dev branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `contains` method, please use a higher version or '
+                'dev branch instead.')

         filepath = self._map_path(filepath)
         filepath = self._format_path(filepath)
@@ -366,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
         """
         if not has_method(self._client, 'list'):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `list` method, please use a higher version or dev'
-                 ' branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `list` method, please use a higher version or dev'
+                ' branch instead.')

         dir_path = self._map_path(dir_path)
         dir_path = self._format_path(dir_path)
@@ -549,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
         Returns:
             str: Expected text reading from ``filepath``.
         """
-        with open(filepath, 'r', encoding=encoding) as f:
+        with open(filepath, encoding=encoding) as f:
             value_buf = f.read()
         return value_buf

@@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
         return pickle.load(file, **kwargs)

     def load_from_path(self, filepath, **kwargs):
-        return super(PickleHandler, self).load_from_path(
-            filepath, mode='rb', **kwargs)
+        return super().load_from_path(filepath, mode='rb', **kwargs)

     def dump_to_str(self, obj, **kwargs):
         kwargs.setdefault('protocol', 2)
@@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
             pickle.dump(obj, file, **kwargs)

     def dump_to_path(self, obj, filepath, **kwargs):
-        super(PickleHandler, self).dump_to_path(
-            obj, filepath, mode='wb', **kwargs)
+        super().dump_to_path(obj, filepath, mode='wb', **kwargs)

@@ -157,7 +157,7 @@ def imresize_to_multiple(img,
         size = _scale_size((w, h), scale_factor)

     divisor = to_2tuple(divisor)
-    size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
+    size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
     resized_img, w_scale, h_scale = imresize(
         img,
         size,

@@ -59,7 +59,7 @@ def _parse_arg(value, desc):
         raise RuntimeError(
             "ONNX symbolic doesn't know to interpret ListConstruct node")

-    raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
+    raise RuntimeError(f'Unexpected node type: {value.node().kind()}')


 def _maybe_get_const(value, desc):

@@ -86,7 +86,7 @@ class BorderAlign(nn.Module):
     """

     def __init__(self, pool_size):
-        super(BorderAlign, self).__init__()
+        super().__init__()
         self.pool_size = pool_size

     def forward(self, input, boxes):

@@ -131,7 +131,7 @@ def box_iou_rotated(bboxes1,
     if aligned:
         ious = bboxes1.new_zeros(rows)
     else:
-        ious = bboxes1.new_zeros((rows * cols))
+        ious = bboxes1.new_zeros(rows * cols)
     if not clockwise:
         flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
         flip_mat[-1] = -1

@@ -85,7 +85,7 @@ carafe_naive = CARAFENaiveFunction.apply
 class CARAFENaive(Module):

     def __init__(self, kernel_size, group_size, scale_factor):
-        super(CARAFENaive, self).__init__()
+        super().__init__()

         assert isinstance(kernel_size, int) and isinstance(
             group_size, int) and isinstance(scale_factor, int)
@@ -195,7 +195,7 @@ class CARAFE(Module):
     """

     def __init__(self, kernel_size, group_size, scale_factor):
-        super(CARAFE, self).__init__()
+        super().__init__()

         assert isinstance(kernel_size, int) and isinstance(
             group_size, int) and isinstance(scale_factor, int)
@@ -238,7 +238,7 @@ class CARAFEPack(nn.Module):
                  encoder_kernel=3,
                  encoder_dilation=1,
                  compressed_channels=64):
-        super(CARAFEPack, self).__init__()
+        super().__init__()
         self.channels = channels
         self.scale_factor = scale_factor
         self.up_kernel = up_kernel

@@ -125,7 +125,7 @@ class CornerPool(nn.Module):
     }

     def __init__(self, mode):
-        super(CornerPool, self).__init__()
+        super().__init__()
         assert mode in self.pool_functions
         self.mode = mode
         self.corner_pool = self.pool_functions[mode]

@@ -236,7 +236,7 @@ class DeformConv2d(nn.Module):
                  deform_groups: int = 1,
                  bias: bool = False,
                  im2col_step: int = 32) -> None:
-        super(DeformConv2d, self).__init__()
+        super().__init__()

         assert not bias, \
             f'bias={bias} is not supported in DeformConv2d.'
@@ -356,7 +356,7 @@ class DeformConv2dPack(DeformConv2d):
     _version = 2

     def __init__(self, *args, **kwargs):
-        super(DeformConv2dPack, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
         self.conv_offset = nn.Conv2d(
             self.in_channels,
             self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],

@@ -96,7 +96,7 @@ class DeformRoIPool(nn.Module):
                  spatial_scale=1.0,
                  sampling_ratio=0,
                  gamma=0.1):
-        super(DeformRoIPool, self).__init__()
+        super().__init__()
         self.output_size = _pair(output_size)
         self.spatial_scale = float(spatial_scale)
         self.sampling_ratio = int(sampling_ratio)
@@ -117,8 +117,7 @@ class DeformRoIPoolPack(DeformRoIPool):
                  spatial_scale=1.0,
                  sampling_ratio=0,
                  gamma=0.1):
-        super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale,
-                                                sampling_ratio, gamma)
+        super().__init__(output_size, spatial_scale, sampling_ratio, gamma)

         self.output_channels = output_channels
         self.deform_fc_channels = deform_fc_channels
@@ -158,8 +157,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool):
                  spatial_scale=1.0,
                  sampling_ratio=0,
                  gamma=0.1):
-        super(ModulatedDeformRoIPoolPack,
-              self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
+        super().__init__(output_size, spatial_scale, sampling_ratio, gamma)

         self.output_channels = output_channels
         self.deform_fc_channels = deform_fc_channels

@@ -89,7 +89,7 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply
 class SigmoidFocalLoss(nn.Module):

     def __init__(self, gamma, alpha, weight=None, reduction='mean'):
-        super(SigmoidFocalLoss, self).__init__()
+        super().__init__()
         self.gamma = gamma
         self.alpha = alpha
         self.register_buffer('weight', weight)
@@ -195,7 +195,7 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply
 class SoftmaxFocalLoss(nn.Module):

     def __init__(self, gamma, alpha, weight=None, reduction='mean'):
-        super(SoftmaxFocalLoss, self).__init__()
+        super().__init__()
         self.gamma = gamma
         self.alpha = alpha
         self.register_buffer('weight', weight)

@@ -212,7 +212,7 @@ class FusedBiasLeakyReLU(nn.Module):
     """

     def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
-        super(FusedBiasLeakyReLU, self).__init__()
+        super().__init__()
         self.bias = nn.Parameter(torch.zeros(num_channels))
         self.negative_slope = negative_slope

@@ -98,13 +98,12 @@ class MaskedConv2d(nn.Conv2d):
                  dilation=1,
                  groups=1,
                  bias=True):
-        super(MaskedConv2d,
-              self).__init__(in_channels, out_channels, kernel_size, stride,
-                             padding, dilation, groups, bias)
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias)

     def forward(self, input, mask=None):
         if mask is None:  # fallback to the normal Conv2d
-            return super(MaskedConv2d, self).forward(input)
+            return super().forward(input)
         else:
             return masked_conv2d(input, mask, self.weight, self.bias,
                                  self.padding)

@@ -53,7 +53,7 @@ class BaseMergeCell(nn.Module):
                  input_conv_cfg=None,
                  input_norm_cfg=None,
                  upsample_mode='nearest'):
-        super(BaseMergeCell, self).__init__()
+        super().__init__()
         assert upsample_mode in ['nearest', 'bilinear']
         self.with_out_conv = with_out_conv
         self.with_input1_conv = with_input1_conv
@@ -121,7 +121,7 @@ class BaseMergeCell(nn.Module):
 class SumCell(BaseMergeCell):

     def __init__(self, in_channels, out_channels, **kwargs):
-        super(SumCell, self).__init__(in_channels, out_channels, **kwargs)
+        super().__init__(in_channels, out_channels, **kwargs)

     def _binary_op(self, x1, x2):
         return x1 + x2
@@ -130,8 +130,7 @@ class SumCell(BaseMergeCell):
 class ConcatCell(BaseMergeCell):

     def __init__(self, in_channels, out_channels, **kwargs):
-        super(ConcatCell, self).__init__(in_channels * 2, out_channels,
-                                         **kwargs)
+        super().__init__(in_channels * 2, out_channels, **kwargs)

     def _binary_op(self, x1, x2):
         ret = torch.cat([x1, x2], dim=1)

@@ -168,7 +168,7 @@ class ModulatedDeformConv2d(nn.Module):
                  groups=1,
                  deform_groups=1,
                  bias=True):
-        super(ModulatedDeformConv2d, self).__init__()
+        super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = _pair(kernel_size)
@@ -227,7 +227,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
     _version = 2

     def __init__(self, *args, **kwargs):
-        super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
         self.conv_offset = nn.Conv2d(
             self.in_channels,
             self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
@@ -239,7 +239,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
         self.init_weights()

     def init_weights(self):
-        super(ModulatedDeformConv2dPack, self).init_weights()
+        super().init_weights()
         if hasattr(self, 'conv_offset'):
             self.conv_offset.weight.data.zero_()
             self.conv_offset.bias.data.zero_()

@@ -296,7 +296,7 @@ class SimpleRoIAlign(nn.Module):
                 If True, align the results more perfectly.
         """
-        super(SimpleRoIAlign, self).__init__()
+        super().__init__()
         self.output_size = _pair(output_size)
         self.spatial_scale = float(spatial_scale)
         # to be consistent with other RoI ops

@@ -72,7 +72,7 @@ psa_mask = PSAMaskFunction.apply
 class PSAMask(nn.Module):

     def __init__(self, psa_type, mask_size=None):
-        super(PSAMask, self).__init__()
+        super().__init__()
         assert psa_type in ['collect', 'distribute']
         if psa_type == 'collect':
             psa_type_enum = 0

@@ -116,7 +116,7 @@ class RiRoIAlignRotated(nn.Module):
                  num_samples=0,
                  num_orientations=8,
                  clockwise=False):
-        super(RiRoIAlignRotated, self).__init__()
+        super().__init__()
         self.out_size = out_size
         self.spatial_scale = float(spatial_scale)

@@ -181,7 +181,7 @@ class RoIAlign(nn.Module):
                  pool_mode='avg',
                  aligned=True,
                  use_torchvision=False):
-        super(RoIAlign, self).__init__()
+        super().__init__()
         self.output_size = _pair(output_size)
         self.spatial_scale = float(spatial_scale)

@@ -156,7 +156,7 @@ class RoIAlignRotated(nn.Module):
                  sampling_ratio=0,
                  aligned=True,
                  clockwise=False):
-        super(RoIAlignRotated, self).__init__()
+        super().__init__()
         self.output_size = _pair(output_size)
         self.spatial_scale = float(spatial_scale)

@@ -71,7 +71,7 @@ roi_pool = RoIPoolFunction.apply
 class RoIPool(nn.Module):

     def __init__(self, output_size, spatial_scale=1.0):
-        super(RoIPool, self).__init__()
+        super().__init__()
         self.output_size = _pair(output_size)
         self.spatial_scale = float(spatial_scale)

@@ -64,7 +64,7 @@ class SparseConvolution(SparseModule):
                  inverse=False,
                  indice_key=None,
                  fused_bn=False):
-        super(SparseConvolution, self).__init__()
+        super().__init__()
         assert groups == 1
         if not isinstance(kernel_size, (list, tuple)):
             kernel_size = [kernel_size] * ndim
@@ -217,7 +217,7 @@ class SparseConv2d(SparseConvolution):
                  groups=1,
                  bias=True,
                  indice_key=None):
-        super(SparseConv2d, self).__init__(
+        super().__init__(
             2,
             in_channels,
             out_channels,
@@ -243,7 +243,7 @@ class SparseConv3d(SparseConvolution):
                  groups=1,
                  bias=True,
                  indice_key=None):
-        super(SparseConv3d, self).__init__(
+        super().__init__(
             3,
             in_channels,
             out_channels,
@@ -269,7 +269,7 @@ class SparseConv4d(SparseConvolution):
                  groups=1,
                  bias=True,
                  indice_key=None):
-        super(SparseConv4d, self).__init__(
+        super().__init__(
             4,
             in_channels,
             out_channels,
@@ -295,7 +295,7 @@ class SparseConvTranspose2d(SparseConvolution):
                  groups=1,
                  bias=True,
                  indice_key=None):
-        super(SparseConvTranspose2d, self).__init__(
+        super().__init__(
             2,
             in_channels,
             out_channels,
@@ -322,7 +322,7 @@ class SparseConvTranspose3d(SparseConvolution):
                  groups=1,
                  bias=True,
                  indice_key=None):
-        super(SparseConvTranspose3d, self).__init__(
+        super().__init__(
             3,
             in_channels,
             out_channels,
@@ -345,7 +345,7 @@ class SparseInverseConv2d(SparseConvolution):
                  kernel_size,
                  indice_key=None,
                  bias=True):
-        super(SparseInverseConv2d, self).__init__(
+        super().__init__(
             2,
             in_channels,
             out_channels,
@@ -364,7 +364,7 @@ class SparseInverseConv3d(SparseConvolution):
                  kernel_size,
                  indice_key=None,
                  bias=True):
-        super(SparseInverseConv3d, self).__init__(
+        super().__init__(
             3,
             in_channels,
             out_channels,
@@ -387,7 +387,7 @@ class SubMConv2d(SparseConvolution):
                  groups=1,
                  bias=True,
                  indice_key=None):
-        super(SubMConv2d, self).__init__(
+        super().__init__(
             2,
             in_channels,
             out_channels,
@@ -414,7 +414,7 @@ class SubMConv3d(SparseConvolution):
                  groups=1,
                  bias=True,
                  indice_key=None):
-        super(SubMConv3d, self).__init__(
+        super().__init__(
             3,
             in_channels,
             out_channels,
@@ -441,7 +441,7 @@ class SubMConv4d(SparseConvolution):
                  groups=1,
                  bias=True,
                  indice_key=None):
-        super(SubMConv4d, self).__init__(
+        super().__init__(
             4,
             in_channels,
             out_channels,

@@ -86,7 +86,7 @@ class SparseSequential(SparseModule):
     """

     def __init__(self, *args, **kwargs):
-        super(SparseSequential, self).__init__()
+        super().__init__()
         if len(args) == 1 and isinstance(args[0], OrderedDict):
             for key, module in args[0].items():
                 self.add_module(key, module)
@@ -103,7 +103,7 @@ class SparseSequential(SparseModule):

     def __getitem__(self, idx):
         if not (-len(self) <= idx < len(self)):
-            raise IndexError('index {} is out of range'.format(idx))
+            raise IndexError(f'index {idx} is out of range')
         if idx < 0:
             idx += len(self)
         it = iter(self._modules.values())

@@ -29,7 +29,7 @@ class SparseMaxPool(SparseModule):
                  padding=0,
                  dilation=1,
                  subm=False):
-        super(SparseMaxPool, self).__init__()
+        super().__init__()
         if not isinstance(kernel_size, (list, tuple)):
             kernel_size = [kernel_size] * ndim
         if not isinstance(stride, (list, tuple)):
@@ -77,12 +77,10 @@ class SparseMaxPool(SparseModule):
 class SparseMaxPool2d(SparseMaxPool):

     def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
-        super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding,
-                                              dilation)
+        super().__init__(2, kernel_size, stride, padding, dilation)


 class SparseMaxPool3d(SparseMaxPool):

     def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
-        super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding,
-                                              dilation)
+        super().__init__(3, kernel_size, stride, padding, dilation)

@@ -18,7 +18,7 @@ def scatter_nd(indices, updates, shape):
     return ret


-class SparseConvTensor(object):
+class SparseConvTensor:

     def __init__(self,
                  features,

@@ -198,7 +198,7 @@ class SyncBatchNorm(Module):
                  track_running_stats=True,
                  group=None,
                  stats_mode='default'):
-        super(SyncBatchNorm, self).__init__()
+        super().__init__()
         self.num_features = num_features
         self.eps = eps
         self.momentum = momentum

@@ -32,7 +32,7 @@ class MMDataParallel(DataParallel):
     """

     def __init__(self, *args, dim=0, **kwargs):
-        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
+        super().__init__(*args, dim=dim, **kwargs)
         self.dim = dim

     def forward(self, *inputs, **kwargs):

@@ -18,7 +18,7 @@ class MMDistributedDataParallel(nn.Module):
                  dim=0,
                  broadcast_buffers=True,
                  bucket_cap_mb=25):
-        super(MMDistributedDataParallel, self).__init__()
+        super().__init__()
         self.module = module
         self.dim = dim
         self.broadcast_buffers = broadcast_buffers

@@ -35,7 +35,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         # NOTE init_cfg can be defined in different levels, but init_cfg
         # in low levels has a higher priority.

-        super(BaseModule, self).__init__()
+        super().__init__()
         # define default value of init_cfg instead of hard code
         # in init_weights() function
         self._is_init = False

@@ -83,8 +83,8 @@ class CheckpointHook(Hook):
             basename = osp.basename(runner.work_dir.rstrip(osp.sep))
             self.out_dir = self.file_client.join_path(self.out_dir, basename)

-        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
-                            f'{self.file_client.name}.'))
+        runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
+                           f'{self.file_client.name}.')

         # disable the create_symlink option because some file backends do not
         # allow to create a symlink
@@ -93,9 +93,9 @@ class CheckpointHook(Hook):
                     'create_symlink'] and not self.file_client.allow_symlink:
                 self.args['create_symlink'] = False
                 warnings.warn(
-                    ('create_symlink is set as True by the user but is changed'
-                     'to be False because creating symbolic link is not '
-                     f'allowed in {self.file_client.name}'))
+                    'create_symlink is set as True by the user but is changed'
+                    'to be False because creating symbolic link is not '
+                    f'allowed in {self.file_client.name}')
             else:
                 self.args['create_symlink'] = self.file_client.allow_symlink

@@ -214,8 +214,8 @@ class EvalHook(Hook):
             basename = osp.basename(runner.work_dir.rstrip(osp.sep))
             self.out_dir = self.file_client.join_path(self.out_dir, basename)
             runner.logger.info(
-                (f'The best checkpoint will be saved to {self.out_dir} by '
-                 f'{self.file_client.name}'))
+                f'The best checkpoint will be saved to {self.out_dir} by '
+                f'{self.file_client.name}')

         if self.save_best is not None:
             if runner.meta is None:
@@ -335,8 +335,8 @@ class EvalHook(Hook):
                     self.best_ckpt_path):
                 self.file_client.remove(self.best_ckpt_path)
                 runner.logger.info(
-                    (f'The previous best checkpoint {self.best_ckpt_path} was '
-                     'removed'))
+                    f'The previous best checkpoint {self.best_ckpt_path} was '
+                    'removed')

             best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
             self.best_ckpt_path = self.file_client.join_path(

@@ -34,8 +34,7 @@ class ClearMLLoggerHook(LoggerHook):
                  ignore_last=True,
                  reset_flag=False,
                  by_epoch=True):
-        super(ClearMLLoggerHook, self).__init__(interval, ignore_last,
-                                                reset_flag, by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.import_clearml()
         self.init_kwargs = init_kwargs
@@ -49,7 +48,7 @@ class ClearMLLoggerHook(LoggerHook):

     @master_only
     def before_run(self, runner):
-        super(ClearMLLoggerHook, self).before_run(runner)
+        super().before_run(runner)
         task_kwargs = self.init_kwargs if self.init_kwargs else {}
         self.task = self.clearml.Task.init(**task_kwargs)
         self.task_logger = self.task.get_logger()

@@ -40,8 +40,7 @@ class MlflowLoggerHook(LoggerHook):
                  ignore_last=True,
                  reset_flag=False,
                  by_epoch=True):
-        super(MlflowLoggerHook, self).__init__(interval, ignore_last,
-                                               reset_flag, by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.import_mlflow()
         self.exp_name = exp_name
         self.tags = tags
@@ -59,7 +58,7 @@ class MlflowLoggerHook(LoggerHook):

     @master_only
     def before_run(self, runner):
-        super(MlflowLoggerHook, self).before_run(runner)
+        super().before_run(runner)
         if self.exp_name is not None:
             self.mlflow.set_experiment(self.exp_name)
         if self.tags is not None:

@@ -49,8 +49,7 @@ class NeptuneLoggerHook(LoggerHook):
                  with_step=True,
                  by_epoch=True):
-        super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
-                                                reset_flag, by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.import_neptune()
         self.init_kwargs = init_kwargs
         self.with_step = with_step

@@ -40,8 +40,7 @@ class PaviLoggerHook(LoggerHook):
                  reset_flag=False,
                  by_epoch=True,
                  img_key='img_info'):
-        super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
-                                             by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.init_kwargs = init_kwargs
         self.add_graph = add_graph
         self.add_last_ckpt = add_last_ckpt
@@ -49,7 +48,7 @@ class PaviLoggerHook(LoggerHook):

     @master_only
     def before_run(self, runner):
-        super(PaviLoggerHook, self).before_run(runner)
+        super().before_run(runner)
         try:
             from pavi import SummaryWriter
         except ImportError:

@@ -27,8 +27,7 @@ class SegmindLoggerHook(LoggerHook):
                  ignore_last=True,
                  reset_flag=False,
                  by_epoch=True):
-        super(SegmindLoggerHook, self).__init__(interval, ignore_last,
-                                                reset_flag, by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.import_segmind()

     def import_segmind(self):

@@ -28,13 +28,12 @@ class TensorboardLoggerHook(LoggerHook):
                  ignore_last=True,
                  reset_flag=False,
                  by_epoch=True):
-        super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
-                                                    reset_flag, by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.log_dir = log_dir

     @master_only
     def before_run(self, runner):
-        super(TensorboardLoggerHook, self).before_run(runner)
+        super().before_run(runner)
         if (TORCH_VERSION == 'parrots'
                 or digit_version(TORCH_VERSION) < digit_version('1.1')):
             try:

@@ -62,8 +62,7 @@ class TextLoggerHook(LoggerHook):
                  out_suffix=('.log.json', '.log', '.py'),
                  keep_local=True,
                  file_client_args=None):
-        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
-                                             by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.by_epoch = by_epoch
         self.time_sec_tot = 0
         self.interval_exp_name = interval_exp_name
@@ -87,7 +86,7 @@ class TextLoggerHook(LoggerHook):
                                self.out_dir)

     def before_run(self, runner):
-        super(TextLoggerHook, self).before_run(runner)
+        super().before_run(runner)

         if self.out_dir is not None:
             self.file_client = FileClient.infer_client(self.file_client_args,
@@ -97,8 +96,8 @@ class TextLoggerHook(LoggerHook):
             basename = osp.basename(runner.work_dir.rstrip(osp.sep))
             self.out_dir = self.file_client.join_path(self.out_dir, basename)
             runner.logger.info(
-                (f'Text logs will be saved to {self.out_dir} by '
-                 f'{self.file_client.name} after the training process.'))
+                f'Text logs will be saved to {self.out_dir} by '
+                f'{self.file_client.name} after the training process.')

         self.start_iter = runner.iter
         self.json_log_path = osp.join(runner.work_dir,
@@ -242,15 +241,15 @@ class TextLoggerHook(LoggerHook):
                 local_filepath = osp.join(runner.work_dir, filename)
                 out_filepath = self.file_client.join_path(
                     self.out_dir, filename)
-                with open(local_filepath, 'r') as f:
+                with open(local_filepath) as f:
                     self.file_client.put_text(f.read(), out_filepath)

                 runner.logger.info(
-                    (f'The file {local_filepath} has been uploaded to '
-                     f'{out_filepath}.'))
+                    f'The file {local_filepath} has been uploaded to '
+                    f'{out_filepath}.')

                 if not self.keep_local:
                     os.remove(local_filepath)
                     runner.logger.info(
-                        (f'{local_filepath} was removed due to the '
-                         '`self.keep_local=False`'))
+                        f'{local_filepath} was removed due to the '
+                        '`self.keep_local=False`')

@@ -57,8 +57,7 @@ class WandbLoggerHook(LoggerHook):
                  with_step=True,
                  log_artifact=True,
                  out_suffix=('.log.json', '.log', '.py')):
-        super(WandbLoggerHook, self).__init__(interval, ignore_last,
-                                              reset_flag, by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.import_wandb()
         self.init_kwargs = init_kwargs
         self.commit = commit
@@ -76,7 +75,7 @@ class WandbLoggerHook(LoggerHook):

     @master_only
     def before_run(self, runner):
-        super(WandbLoggerHook, self).before_run(runner)
+        super().before_run(runner)
         if self.wandb is None:
             self.import_wandb()
         if self.init_kwargs:

View File

@ -157,7 +157,7 @@ class LrUpdaterHook(Hook):
class FixedLrUpdaterHook(LrUpdaterHook): class FixedLrUpdaterHook(LrUpdaterHook):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(FixedLrUpdaterHook, self).__init__(**kwargs) super().__init__(**kwargs)
def get_lr(self, runner, base_lr): def get_lr(self, runner, base_lr):
return base_lr return base_lr
@ -188,7 +188,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
self.step = step self.step = step
self.gamma = gamma self.gamma = gamma
         self.min_lr = min_lr
-        super(StepLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_lr(self, runner, base_lr):
         progress = runner.epoch if self.by_epoch else runner.iter

@@ -215,7 +215,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):

     def __init__(self, gamma, **kwargs):
         self.gamma = gamma
-        super(ExpLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_lr(self, runner, base_lr):
         progress = runner.epoch if self.by_epoch else runner.iter

@@ -228,7 +228,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
     def __init__(self, power=1., min_lr=0., **kwargs):
         self.power = power
         self.min_lr = min_lr
-        super(PolyLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_lr(self, runner, base_lr):
         if self.by_epoch:

@@ -247,7 +247,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
     def __init__(self, gamma, power=1., **kwargs):
         self.gamma = gamma
         self.power = power
-        super(InvLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_lr(self, runner, base_lr):
         progress = runner.epoch if self.by_epoch else runner.iter

@@ -269,7 +269,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
         assert (min_lr is None) ^ (min_lr_ratio is None)
         self.min_lr = min_lr
         self.min_lr_ratio = min_lr_ratio
-        super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_lr(self, runner, base_lr):
         if self.by_epoch:

@@ -317,7 +317,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
         self.start_percent = start_percent
         self.min_lr = min_lr
         self.min_lr_ratio = min_lr_ratio
-        super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_lr(self, runner, base_lr):
         if self.by_epoch:

@@ -367,7 +367,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
         self.restart_weights = restart_weights
         assert (len(self.periods) == len(self.restart_weights)
                 ), 'periods and restart_weights should have the same length.'
-        super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

         self.cumulative_periods = [
             sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))

@@ -484,10 +484,10 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
         assert not by_epoch, \
             'currently only support "by_epoch" = False'
-        super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
+        super().__init__(by_epoch, **kwargs)

     def before_run(self, runner):
-        super(CyclicLrUpdaterHook, self).before_run(runner)
+        super().before_run(runner)
         # initiate lr_phases
         # total lr_phases are separated as up and down
         self.max_iter_per_phase = runner.max_iters // self.cyclic_times

@@ -598,7 +598,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
         self.final_div_factor = final_div_factor
         self.three_phase = three_phase
         self.lr_phases = []  # init lr_phases
-        super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def before_run(self, runner):
         if hasattr(self, 'total_steps'):

@@ -668,7 +668,7 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
         assert (min_lr is None) ^ (min_lr_ratio is None)
         self.min_lr = min_lr
         self.min_lr_ratio = min_lr_ratio
-        super(LinearAnnealingLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_lr(self, runner, base_lr):
         if self.by_epoch:
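Every hunk in this file is the same mechanical rewrite: pyupgrade's `--py36-plus` mode replaces the Python 2 compatible `super(Class, self)` call with the zero-argument form. A minimal sketch of the rule (the class names here are illustrative, not from this diff):

    class Base:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class Child(Base):
        def __init__(self, **kwargs):
            # Python 2 compatible form, flagged by pyupgrade:
            #     super(Child, self).__init__(**kwargs)
            # zero-argument form, emitted by pyupgrade:
            super().__init__(**kwargs)

    Child(lr=0.1)  # both forms invoke Base.__init__ identically

The two spellings are equivalent at runtime; the zero-argument call also survives class renames, which is why the rewrite is safe to apply in bulk.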

View File

@@ -176,7 +176,7 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
         self.step = step
         self.gamma = gamma
         self.min_momentum = min_momentum
-        super(StepMomentumUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_momentum(self, runner, base_momentum):
         progress = runner.epoch if self.by_epoch else runner.iter

@@ -214,7 +214,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
         assert (min_momentum is None) ^ (min_momentum_ratio is None)
         self.min_momentum = min_momentum
         self.min_momentum_ratio = min_momentum_ratio
-        super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_momentum(self, runner, base_momentum):
         if self.by_epoch:

@@ -247,7 +247,7 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
         assert (min_momentum is None) ^ (min_momentum_ratio is None)
         self.min_momentum = min_momentum
         self.min_momentum_ratio = min_momentum_ratio
-        super(LinearAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def get_momentum(self, runner, base_momentum):
         if self.by_epoch:

@@ -328,10 +328,10 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
         # currently only support by_epoch=False
         assert not by_epoch, \
             'currently only support "by_epoch" = False'
-        super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
+        super().__init__(by_epoch, **kwargs)

     def before_run(self, runner):
-        super(CyclicMomentumUpdaterHook, self).before_run(runner)
+        super().before_run(runner)
         # initiate momentum_phases
         # total momentum_phases are separated as up and down
         max_iter_per_phase = runner.max_iters // self.cyclic_times

@@ -439,7 +439,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
         self.anneal_func = annealing_linear
         self.three_phase = three_phase
         self.momentum_phases = []  # init momentum_phases
-        super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     def before_run(self, runner):
         if isinstance(runner.optimizer, dict):

View File

@@ -110,7 +110,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
     """

     def __init__(self, cumulative_iters=1, **kwargs):
-        super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)

         assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
             f'cumulative_iters only accepts positive int, but got ' \

@@ -297,8 +297,7 @@ if (TORCH_VERSION != 'parrots'
         """

         def __init__(self, *args, **kwargs):
-            super(GradientCumulativeFp16OptimizerHook,
-                  self).__init__(*args, **kwargs)
+            super().__init__(*args, **kwargs)

         def after_train_iter(self, runner):
             if not self.initialized:

@@ -490,8 +489,7 @@ else:
         iters gradient cumulating."""

         def __init__(self, *args, **kwargs):
-            super(GradientCumulativeFp16OptimizerHook,
-                  self).__init__(*args, **kwargs)
+            super().__init__(*args, **kwargs)

         def after_train_iter(self, runner):
             if not self.initialized:

View File

@@ -263,7 +263,7 @@ class IterBasedRunner(BaseRunner):
         if log_config is not None:
             for info in log_config['hooks']:
                 info.setdefault('by_epoch', False)
-        super(IterBasedRunner, self).register_training_hooks(
+        super().register_training_hooks(
             lr_config=lr_config,
             momentum_config=momentum_config,
             optimizer_config=optimizer_config,

View File

@@ -54,7 +54,7 @@ def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
         msg += reset_style
         warnings.warn(msg)

-    device = torch.device('cuda:{}'.format(device_id))
+    device = torch.device(f'cuda:{device_id}')
     # create builder and network
     logger = trt.Logger(log_level)
     builder = trt.Builder(logger)

@@ -209,7 +209,7 @@ class TRTWrapper(torch.nn.Module):
             msg += reset_style
             warnings.warn(msg)

-        super(TRTWrapper, self).__init__()
+        super().__init__()
         self.engine = engine
         if isinstance(self.engine, str):
            self.engine = load_trt_engine(engine)
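Besides the `super()` cleanup, the first hunk shows pyupgrade converting `str.format` calls with simple arguments into f-strings (available since Python 3.6). A small illustrative sketch:

    device_id = 0
    # before: flagged by pyupgrade
    old = 'cuda:{}'.format(device_id)
    # after: f-string emitted by pyupgrade
    new = f'cuda:{device_id}'
    assert old == new == 'cuda:0'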

View File

@@ -39,7 +39,7 @@ class ConfigDict(Dict):

     def __getattr__(self, name):
         try:
-            value = super(ConfigDict, self).__getattr__(name)
+            value = super().__getattr__(name)
         except KeyError:
             ex = AttributeError(f"'{self.__class__.__name__}' object has no "
                                 f"attribute '{name}'")

@@ -96,7 +96,7 @@ class Config:
     @staticmethod
     def _validate_py_syntax(filename):
-        with open(filename, 'r', encoding='utf-8') as f:
+        with open(filename, encoding='utf-8') as f:
             # Setting encoding explicitly to resolve coding issue on windows
             content = f.read()
         try:

@@ -116,7 +116,7 @@ class Config:
             fileBasename=file_basename,
             fileBasenameNoExtension=file_basename_no_extension,
             fileExtname=file_extname)
-        with open(filename, 'r', encoding='utf-8') as f:
+        with open(filename, encoding='utf-8') as f:
             # Setting encoding explicitly to resolve coding issue on windows
             config_file = f.read()
         for key, value in support_templates.items():

@@ -130,7 +130,7 @@ class Config:
     def _pre_substitute_base_vars(filename, temp_config_name):
         """Substitute base variable placehoders to string, so that parsing
         would work."""
-        with open(filename, 'r', encoding='utf-8') as f:
+        with open(filename, encoding='utf-8') as f:
             # Setting encoding explicitly to resolve coding issue on windows
             config_file = f.read()
         base_var_dict = {}

@@ -183,7 +183,7 @@ class Config:
         check_file_exist(filename)
         fileExtname = osp.splitext(filename)[1]
         if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
-            raise IOError('Only py/yml/yaml/json type are supported now!')
+            raise OSError('Only py/yml/yaml/json type are supported now!')

         with tempfile.TemporaryDirectory() as temp_config_dir:
             temp_config_file = tempfile.NamedTemporaryFile(

@@ -236,7 +236,7 @@ class Config:
                 warnings.warn(warning_msg, DeprecationWarning)

         cfg_text = filename + '\n'
-        with open(filename, 'r', encoding='utf-8') as f:
+        with open(filename, encoding='utf-8') as f:
             # Setting encoding explicitly to resolve coding issue on windows
             cfg_text += f.read()

@@ -356,7 +356,7 @@ class Config:
             :obj:`Config`: Config obj.
         """
         if file_format not in ['.py', '.json', '.yaml', '.yml']:
-            raise IOError('Only py/yml/yaml/json type are supported now!')
+            raise OSError('Only py/yml/yaml/json type are supported now!')
         if file_format != '.py' and 'dict(' in cfg_str:
             # check if users specify a wrong suffix for python
             warnings.warn(

@@ -396,16 +396,16 @@ class Config:
         if isinstance(filename, Path):
             filename = str(filename)

-        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
-        super(Config, self).__setattr__('_filename', filename)
+        super().__setattr__('_cfg_dict', ConfigDict(cfg_dict))
+        super().__setattr__('_filename', filename)
         if cfg_text:
             text = cfg_text
         elif filename:
-            with open(filename, 'r') as f:
+            with open(filename) as f:
                 text = f.read()
         else:
             text = ''
-        super(Config, self).__setattr__('_text', text)
+        super().__setattr__('_text', text)

     @property
     def filename(self):

@@ -556,9 +556,9 @@ class Config:

     def __setstate__(self, state):
         _cfg_dict, _filename, _text = state
-        super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
-        super(Config, self).__setattr__('_filename', _filename)
-        super(Config, self).__setattr__('_text', _text)
+        super().__setattr__('_cfg_dict', _cfg_dict)
+        super().__setattr__('_filename', _filename)
+        super().__setattr__('_text', _text)

     def dump(self, file=None):
         """Dumps config into a file or returns a string representation of the

@@ -584,7 +584,7 @@ class Config:
                 will be dumped. Defaults to None.
         """
         import mmcv
-        cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
+        cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
         if file is None:
             if self.filename is None or self.filename.endswith('.py'):
                 return self.pretty_text

@@ -638,8 +638,8 @@ class Config:
                 subkey = key_list[-1]
             d[subkey] = v
-        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
-        super(Config, self).__setattr__(
+        cfg_dict = super().__getattribute__('_cfg_dict')
+        super().__setattr__(
             '_cfg_dict',
             Config._merge_a_into_b(
                 option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
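Two more rules run throughout this file: `open()` already defaults to mode `'r'`, so the explicit argument is redundant, and since Python 3.3 `IOError` is merely an alias of `OSError`, so raising `OSError` is the canonical spelling. A small self-contained sketch of both (the temp file here is purely illustrative):

    import os
    import tempfile

    fd, path = tempfile.mkstemp(suffix='.txt')
    os.close(fd)
    with open(path, 'w') as f:   # 'w' is still required; only 'r' is the default
        f.write('hello')
    with open(path) as f:        # pyupgrade drops the redundant 'r' argument
        assert f.read() == 'hello'
    assert IOError is OSError    # alias since Python 3.3, hence the rename
    os.remove(path)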

View File

@@ -6,7 +6,7 @@ class TimerError(Exception):

     def __init__(self, message):
         self.message = message
-        super(TimerError, self).__init__(message)
+        super().__init__(message)


 class Timer:

View File

@@ -40,10 +40,10 @@ def flowread(flow_or_path: Union[np.ndarray, str],
             try:
                 header = f.read(4).decode('utf-8')
             except Exception:
-                raise IOError(f'Invalid flow file: {flow_or_path}')
+                raise OSError(f'Invalid flow file: {flow_or_path}')
             else:
                 if header != 'PIEH':
-                    raise IOError(f'Invalid flow file: {flow_or_path}, '
+                    raise OSError(f'Invalid flow file: {flow_or_path}, '
                                   'header does not contain PIEH')

             w = np.fromfile(f, np.int32, 1).squeeze()

@@ -53,7 +53,7 @@ def flowread(flow_or_path: Union[np.ndarray, str],
         assert concat_axis in [0, 1]
         cat_flow = imread(flow_or_path, flag='unchanged')
         if cat_flow.ndim != 2:
-            raise IOError(
+            raise OSError(
                 f'{flow_or_path} is not a valid quantized flow file, '
                 f'its dimension is {cat_flow.ndim}.')
         assert cat_flow.shape[concat_axis] % 2 == 0

@@ -86,7 +86,7 @@ def flowwrite(flow: np.ndarray,
     """
     if not quantize:
         with open(filename, 'wb') as f:
-            f.write('PIEH'.encode('utf-8'))
+            f.write(b'PIEH')
             np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
             flow = flow.astype(np.float32)
             flow.tofile(f)

@@ -146,7 +146,7 @@ def dequantize_flow(dx: np.ndarray,
     assert dx.shape == dy.shape
     assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)

-    dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+    dx, dy = (dequantize(d, -max_val, max_val, 255) for d in [dx, dy])

     if denorm:
         dx *= dx.shape[1]
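Two further rules show up here: `'PIEH'.encode('utf-8')` is a constant expression that pyupgrade folds into the bytes literal `b'PIEH'`, and a list comprehension that is immediately unpacked can become a generator expression, skipping the intermediate list. A sketch (the `scale` helper is hypothetical, standing in for `dequantize`):

    assert 'PIEH'.encode('utf-8') == b'PIEH'  # constant fold: same four bytes

    def scale(v, k=2):
        return v * k

    # before: builds a throwaway list, then unpacks it
    dx, dy = [scale(d) for d in [1, 2]]
    # after: generator expression, unpacked directly
    dx2, dy2 = (scale(d) for d in [1, 2])
    assert (dx, dy) == (dx2, dy2) == (2, 4)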

View File

@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from __future__ import division
 from typing import Optional, Union

 import numpy as np
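On Python 3, `from __future__ import division` is a no-op: true division is already the default, so pyupgrade simply deletes the import. The behaviour it used to opt into on Python 2:

    # Python 3 semantics, always in effect with no __future__ import:
    assert 7 / 2 == 3.5    # true division
    assert 7 // 2 == 3     # floor division must be requested explicitly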

View File

@@ -39,7 +39,7 @@ def choose_requirement(primary, secondary):

 def get_version():
     version_file = 'mmcv/version.py'
-    with open(version_file, 'r', encoding='utf-8') as f:
+    with open(version_file, encoding='utf-8') as f:
         exec(compile(f.read(), version_file, 'exec'))
     return locals()['__version__']

@@ -94,12 +94,11 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True):
                     yield info

     def parse_require_file(fpath):
-        with open(fpath, 'r') as f:
+        with open(fpath) as f:
             for line in f.readlines():
                 line = line.strip()
                 if line and not line.startswith('#'):
-                    for info in parse_line(line):
-                        yield info
+                    yield from parse_line(line)

     def gen_packages_items():
         if exists(require_fpath):
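The last hunk applies another rule: a `for` loop whose body only re-yields each item collapses to `yield from` (PEP 380). A minimal sketch with hypothetical stand-ins for the functions above:

    def parse_line(line):
        yield line.upper()

    def read_lines(lines):
        for line in lines:
            if line:
                # before: for info in parse_line(line): yield info
                yield from parse_line(line)  # delegates to the sub-generator

    assert list(read_lines(['a', '', 'b'])) == ['A', 'B']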

View File

@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from __future__ import division
 import numpy as np
 import pytest

View File

@@ -23,7 +23,7 @@ class ExampleConv(nn.Module):
                  groups=1,
                  bias=True,
                  norm_cfg=None):
-        super(ExampleConv, self).__init__()
+        super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size

View File

@@ -202,21 +202,22 @@ class TestFileClient:
         # test `list_dir_or_file`
         with build_temporary_directory() as tmp_dir:
             # 1. list directories and files
-            assert set(disk_backend.list_dir_or_file(tmp_dir)) == set(
-                ['dir1', 'dir2', 'text1.txt', 'text2.txt'])
+            assert set(disk_backend.list_dir_or_file(tmp_dir)) == {
+                'dir1', 'dir2', 'text1.txt', 'text2.txt'
+            }
             # 2. list directories and files recursively
             assert set(disk_backend.list_dir_or_file(
-                tmp_dir, recursive=True)) == set([
+                tmp_dir, recursive=True)) == {
                     'dir1',
                     osp.join('dir1', 'text3.txt'), 'dir2',
                     osp.join('dir2', 'dir3'),
                     osp.join('dir2', 'dir3', 'text4.txt'),
                     osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
-                ])
+                }
             # 3. only list directories
             assert set(
                 disk_backend.list_dir_or_file(
-                    tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
+                    tmp_dir, list_file=False)) == {'dir1', 'dir2'}
             with pytest.raises(
                     TypeError,
                     match='`suffix` should be None when `list_dir` is True'):

@@ -227,30 +228,30 @@ class TestFileClient:
             # 4. only list directories recursively
             assert set(
                 disk_backend.list_dir_or_file(
-                    tmp_dir, list_file=False, recursive=True)) == set(
-                        ['dir1', 'dir2',
-                         osp.join('dir2', 'dir3')])
+                    tmp_dir, list_file=False, recursive=True)) == {
+                        'dir1', 'dir2',
+                        osp.join('dir2', 'dir3')
+                    }
             # 5. only list files
             assert set(disk_backend.list_dir_or_file(
-                tmp_dir, list_dir=False)) == set(['text1.txt', 'text2.txt'])
+                tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
             # 6. only list files recursively
             assert set(
                 disk_backend.list_dir_or_file(
-                    tmp_dir, list_dir=False, recursive=True)) == set([
+                    tmp_dir, list_dir=False, recursive=True)) == {
                         osp.join('dir1', 'text3.txt'),
                         osp.join('dir2', 'dir3', 'text4.txt'),
                         osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
-                    ])
+                    }
             # 7. only list files ending with suffix
             assert set(
                 disk_backend.list_dir_or_file(
                     tmp_dir, list_dir=False,
-                    suffix='.txt')) == set(['text1.txt', 'text2.txt'])
+                    suffix='.txt')) == {'text1.txt', 'text2.txt'}
             assert set(
                 disk_backend.list_dir_or_file(
                     tmp_dir, list_dir=False,
-                    suffix=('.txt',
-                            '.jpg'))) == set(['text1.txt', 'text2.txt'])
+                    suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
             with pytest.raises(
                     TypeError,
                     match='`suffix` must be a string or tuple of strings'):

@@ -260,22 +261,22 @@ class TestFileClient:
             assert set(
                 disk_backend.list_dir_or_file(
                     tmp_dir, list_dir=False, suffix='.txt',
-                    recursive=True)) == set([
+                    recursive=True)) == {
                         osp.join('dir1', 'text3.txt'),
                         osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
                         'text2.txt'
-                    ])
+                    }
             # 7. only list files ending with suffix
             assert set(
                 disk_backend.list_dir_or_file(
                     tmp_dir,
                     list_dir=False,
                     suffix=('.txt', '.jpg'),
-                    recursive=True)) == set([
+                    recursive=True)) == {
                         osp.join('dir1', 'text3.txt'),
                         osp.join('dir2', 'dir3', 'text4.txt'),
                         osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
-                    ])
+                    }

     @patch('ceph.S3Client', MockS3Client)
     def test_ceph_backend(self):

@@ -463,21 +464,21 @@ class TestFileClient:
         with build_temporary_directory() as tmp_dir:
             # 1. list directories and files
-            assert set(petrel_backend.list_dir_or_file(tmp_dir)) == set(
-                ['dir1', 'dir2', 'text1.txt', 'text2.txt'])
+            assert set(petrel_backend.list_dir_or_file(tmp_dir)) == {
+                'dir1', 'dir2', 'text1.txt', 'text2.txt'
+            }
             # 2. list directories and files recursively
             assert set(
-                petrel_backend.list_dir_or_file(
-                    tmp_dir, recursive=True)) == set([
-                        'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2',
-                        '/'.join(('dir2', 'dir3')), '/'.join(
-                            ('dir2', 'dir3', 'text4.txt')), '/'.join(
-                                ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
-                    ])
+                petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == {
+                    'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join(
+                        ('dir2', 'dir3')), '/'.join(
+                            ('dir2', 'dir3', 'text4.txt')), '/'.join(
+                                ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
+                }
             # 3. only list directories
             assert set(
                 petrel_backend.list_dir_or_file(
-                    tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
+                    tmp_dir, list_file=False)) == {'dir1', 'dir2'}
             with pytest.raises(
                     TypeError,
                     match=('`list_dir` should be False when `suffix` is not '

@@ -489,31 +490,30 @@ class TestFileClient:
             # 4. only list directories recursively
             assert set(
                 petrel_backend.list_dir_or_file(
-                    tmp_dir, list_file=False, recursive=True)) == set(
-                        ['dir1', 'dir2', '/'.join(('dir2', 'dir3'))])
+                    tmp_dir, list_file=False, recursive=True)) == {
+                        'dir1', 'dir2', '/'.join(('dir2', 'dir3'))
+                    }
             # 5. only list files
             assert set(
-                petrel_backend.list_dir_or_file(tmp_dir,
-                                                list_dir=False)) == set(
-                                                    ['text1.txt', 'text2.txt'])
+                petrel_backend.list_dir_or_file(
+                    tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
             # 6. only list files recursively
             assert set(
                 petrel_backend.list_dir_or_file(
-                    tmp_dir, list_dir=False, recursive=True)) == set([
+                    tmp_dir, list_dir=False, recursive=True)) == {
                         '/'.join(('dir1', 'text3.txt')), '/'.join(
                             ('dir2', 'dir3', 'text4.txt')), '/'.join(
                                 ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
-                    ])
+                    }
             # 7. only list files ending with suffix
             assert set(
                 petrel_backend.list_dir_or_file(
                     tmp_dir, list_dir=False,
-                    suffix='.txt')) == set(['text1.txt', 'text2.txt'])
+                    suffix='.txt')) == {'text1.txt', 'text2.txt'}
             assert set(
                 petrel_backend.list_dir_or_file(
                     tmp_dir, list_dir=False,
-                    suffix=('.txt',
-                            '.jpg'))) == set(['text1.txt', 'text2.txt'])
+                    suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
             with pytest.raises(
                     TypeError,
                     match='`suffix` must be a string or tuple of strings'):

@@ -523,22 +523,22 @@ class TestFileClient:
             assert set(
                 petrel_backend.list_dir_or_file(
                     tmp_dir, list_dir=False, suffix='.txt',
-                    recursive=True)) == set([
+                    recursive=True)) == {
                         '/'.join(('dir1', 'text3.txt')), '/'.join(
                             ('dir2', 'dir3', 'text4.txt')), 'text1.txt',
                         'text2.txt'
-                    ])
+                    }
             # 7. only list files ending with suffix
             assert set(
                 petrel_backend.list_dir_or_file(
                     tmp_dir,
                     list_dir=False,
                     suffix=('.txt', '.jpg'),
-                    recursive=True)) == set([
+                    recursive=True)) == {
                         '/'.join(('dir1', 'text3.txt')), '/'.join(
                             ('dir2', 'dir3', 'text4.txt')), '/'.join(
                                 ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
-                    ])
+                    }

     @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
     @patch('mc.pyvector', MagicMock)
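The dominant rewrite in this test file is `set([...])` to a set literal `{...}`: the literal is built directly by the interpreter instead of constructing a throwaway list and passing it to `set()`. A sketch:

    files = ['text1.txt', 'text2.txt', 'text1.txt']
    # before: wraps a list display in a set() call
    old = set(['text1.txt', 'text2.txt'])
    # after: set literal, one less intermediate object
    new = {'text1.txt', 'text2.txt'}
    assert old == new == set(files)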

View File

@@ -128,7 +128,7 @@ def test_register_handler():
     assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
     tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
     mmcv.dump(content, tmp_filename)
-    with open(tmp_filename, 'r') as f:
+    with open(tmp_filename) as f:
         written = f.read()
     os.remove(tmp_filename)
     assert written == '\n' + content

View File

@@ -6,7 +6,7 @@ import torch
 from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE


-class TestBBox(object):
+class TestBBox:

     def _test_bbox_overlaps(self, device='cpu', dtype=torch.float):
         from mmcv.ops import bbox_overlaps
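The test modules below are all touched by the same rule: on Python 3 every class is new-style, so inheriting from `object` explicitly is redundant and pyupgrade strips it. A sketch (the class name is illustrative):

    # before: class TestFoo(object):   -- Python 2 new-style class idiom
    class TestFoo:                     # after: bare class, same MRO on Python 3
        pass

    assert TestFoo.__mro__ == (TestFoo, object)  # object is still the base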

View File

@@ -4,7 +4,7 @@ import torch
 import torch.nn.functional as F


-class TestBilinearGridSample(object):
+class TestBilinearGridSample:

     def _test_bilinear_grid_sample(self,
                                    dtype=torch.float,

View File

@@ -4,7 +4,7 @@ import pytest
 import torch


-class TestBoxIoURotated(object):
+class TestBoxIoURotated:

     def test_box_iou_rotated_cpu(self):
         from mmcv.ops import box_iou_rotated

View File

@@ -3,7 +3,7 @@ import torch
 from torch.autograd import gradcheck


-class TestCarafe(object):
+class TestCarafe:

     def test_carafe_naive_gradcheck(self):
         if not torch.cuda.is_available():

View File

@@ -15,7 +15,7 @@ class Loss(nn.Module):
         return torch.mean(input - target)


-class TestCrissCrossAttention(object):
+class TestCrissCrossAttention:

     def test_cc_attention(self):
         device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

View File

@@ -35,7 +35,7 @@ gt_offset_bias_grad = [1.44, -0.72, 0., 0., -0.10, -0.08, -0.54, -0.54],
 gt_deform_weight_grad = [[[[3.62, 0.], [0.40, 0.18]]]]


-class TestDeformconv(object):
+class TestDeformconv:

     def _test_deformconv(self,
                          dtype=torch.float,

View File

@@ -35,7 +35,7 @@ outputs = [([[[[1, 1.25], [1.5, 1.75]]]], [[[[3.0625, 0.4375],
                                              0.00390625]]]])]


-class TestDeformRoIPool(object):
+class TestDeformRoIPool:

     def test_deform_roi_pool_gradcheck(self):
         if not torch.cuda.is_available():

View File

@@ -37,7 +37,7 @@ sigmoid_outputs = [(0.13562961, [[-0.00657264, 0.11185755],
                                  [-0.02462499, 0.08277918, 0.18050370]])]


-class Testfocalloss(object):
+class Testfocalloss:

     def _test_softmax(self, dtype=torch.float):
         if not torch.cuda.is_available():

View File

@@ -10,7 +10,7 @@ except ImportError:
     _USING_PARROTS = False


-class TestFusedBiasLeakyReLU(object):
+class TestFusedBiasLeakyReLU:

     @classmethod
     def setup_class(cls):

View File

@@ -2,7 +2,7 @@
 import torch


-class TestInfo(object):
+class TestInfo:

     def test_info(self):
         if not torch.cuda.is_available():

View File

@@ -2,7 +2,7 @@
 import torch


-class TestMaskedConv2d(object):
+class TestMaskedConv2d:

     def test_masked_conv2d(self):
         if not torch.cuda.is_available():

View File

@@ -37,7 +37,7 @@ dcn_offset_b_grad = [
 ]


-class TestMdconv(object):
+class TestMdconv:

     def _test_mdconv(self, dtype=torch.float, device='cuda'):
         if not torch.cuda.is_available() and device == 'cuda':

View File

@@ -55,7 +55,7 @@ def test_forward_multi_scale_deformable_attn_pytorch():
     N, M, D = 1, 2, 2
     Lq, L, P = 2, 2, 2
     shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
-    S = sum([(H * W).item() for H, W in shapes])
+    S = sum((H * W).item() for H, W in shapes)

     torch.manual_seed(3)
     value = torch.rand(N, S, M, D) * 0.01

@@ -78,7 +78,7 @@ def test_forward_equal_with_pytorch_double():
     shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
     level_start_index = torch.cat((shapes.new_zeros(
         (1, )), shapes.prod(1).cumsum(0)[:-1]))
-    S = sum([(H * W).item() for H, W in shapes])
+    S = sum((H * W).item() for H, W in shapes)

     torch.manual_seed(3)
     value = torch.rand(N, S, M, D).cuda() * 0.01

@@ -111,7 +111,7 @@ def test_forward_equal_with_pytorch_float():
     shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
     level_start_index = torch.cat((shapes.new_zeros(
         (1, )), shapes.prod(1).cumsum(0)[:-1]))
-    S = sum([(H * W).item() for H, W in shapes])
+    S = sum((H * W).item() for H, W in shapes)

     torch.manual_seed(3)
     value = torch.rand(N, S, M, D).cuda() * 0.01

@@ -155,7 +155,7 @@ def test_gradient_numerical(channels,
     shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
     level_start_index = torch.cat((shapes.new_zeros(
         (1, )), shapes.prod(1).cumsum(0)[:-1]))
-    S = sum([(H * W).item() for H, W in shapes])
+    S = sum((H * W).item() for H, W in shapes)

     value = torch.rand(N, S, M, channels).cuda() * 0.01
     sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
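All four hunks apply one rule: when a list comprehension is consumed immediately by `sum()` (likewise `any`, `all`, `min`, `max`), the brackets can be dropped so values are generated lazily instead of materialised into a list first. A sketch with plain integers standing in for the tensor shapes:

    shapes = [(6, 4), (3, 2)]
    # before: allocates the full list before summing
    s_old = sum([h * w for h, w in shapes])
    # after: generator expression feeds sum() one value at a time
    s_new = sum(h * w for h, w in shapes)
    assert s_old == s_new == 30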

View File

@@ -6,7 +6,7 @@ import torch
 from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE


-class Testnms(object):
+class Testnms:

     @pytest.mark.parametrize('device', [
         pytest.param(

@@ -129,8 +129,7 @@ class Testnms(object):
         scores = tensor_dets[:, 4]
         nms_keep_inds = nms(boxes.contiguous(), scores.contiguous(),
                             iou_thr)[1]
-        assert set([g[0].item()
-                    for g in np_groups]) == set(nms_keep_inds.tolist())
+        assert {g[0].item() for g in np_groups} == set(nms_keep_inds.tolist())

         # non empty tensor input
         tensor_dets = torch.from_numpy(np_dets)

View File

@@ -33,7 +33,7 @@ def run_before_and_after_test():
 class WrapFunction(nn.Module):

     def __init__(self, wrapped_function):
-        super(WrapFunction, self).__init__()
+        super().__init__()
         self.wrapped_function = wrapped_function

     def forward(self, *args, **kwargs):

@@ -662,7 +662,7 @@ def test_cummax_cummin(key, opset=11):
     input_list = [
         # arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
         torch.rand((2, 3, 4, 1, 5)),
-        torch.rand((1)),
+        torch.rand(1),
         torch.rand((2, 0, 1)),  # tensor.numel() is 0
         torch.FloatTensor(),  # empty tensor
     ]
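The `torch.rand((1))` change is pyupgrade's extraneous-parentheses rule: `(1)` is the integer `1` wrapped in grouping parentheses, not a 1-tuple, so the inner pair is dropped, while `(2, 3, 4, 1, 5)` really is a tuple and stays. A torch-free sketch (`rand` is a hypothetical stand-in for torch.rand's `*size` signature):

    assert (1) == 1              # grouping parens, not a tuple
    assert (1,) == tuple([1])    # a 1-tuple needs the trailing comma

    def rand(*size):
        return size              # echoes the sizes it receives

    assert rand((1)) == rand(1)  # identical call after the rewrite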

View File

@@ -15,7 +15,7 @@ class Loss(nn.Module):
         return torch.mean(input - target)


-class TestPSAMask(object):
+class TestPSAMask:

     def test_psa_mask_collect(self):
         if not torch.cuda.is_available():

View File

@@ -29,7 +29,7 @@ outputs = [([[[[1., 2.], [3., 4.]]]], [[[[1., 1.], [1., 1.]]]]),
                                          1.]]]])]


-class TestRoiPool(object):
+class TestRoiPool:

     def test_roipool_gradcheck(self):
         if not torch.cuda.is_available():

View File

@@ -14,7 +14,7 @@ else:
 import re


-class TestSyncBN(object):
+class TestSyncBN:

     def dist_init(self):
         rank = int(os.environ['SLURM_PROCID'])

View File

@@ -30,7 +30,7 @@ if not is_tensorrt_plugin_loaded():
 class WrapFunction(nn.Module):

     def __init__(self, wrapped_function):
-        super(WrapFunction, self).__init__()
+        super().__init__()
         self.wrapped_function = wrapped_function

     def forward(self, *args, **kwargs):

@@ -576,7 +576,7 @@ def test_cummin_cummax(func: Callable):
     input_list = [
         # arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
         torch.rand((2, 3, 4, 1, 5)).cuda(),
-        torch.rand((1)).cuda()
+        torch.rand(1).cuda()
     ]

     input_names = ['input']

@@ -756,7 +756,7 @@ def test_corner_pool(mode):

     class CornerPoolWrapper(CornerPool):

         def __init__(self, mode):
-            super(CornerPoolWrapper, self).__init__(mode)
+            super().__init__(mode)

         def forward(self, x):
             # no use `torch.cummax`, instead `corner_pool` is used

View File

@@ -10,7 +10,7 @@ except ImportError:
     _USING_PARROTS = False


-class TestUpFirDn2d(object):
+class TestUpFirDn2d:
     """Unit test for UpFirDn2d.

     Here, we just test the basic case of upsample version. More gerneal tests

View File

@@ -96,8 +96,8 @@ def test_voxelization_nondeterministic():
     coors_all = dynamic_voxelization.forward(points)
     coors_all = coors_all.cpu().detach().numpy().tolist()

-    coors_set = set([tuple(c) for c in coors])
-    coors_all_set = set([tuple(c) for c in coors_all])
+    coors_set = {tuple(c) for c in coors}
+    coors_all_set = {tuple(c) for c in coors_all}

     assert len(coors_set) == len(coors)
     assert len(coors_set - coors_all_set) == 0

@@ -112,7 +112,7 @@ def test_voxelization_nondeterministic():
     for c, ps, n in zip(coors, voxels, num_points_per_voxel):
         ideal_voxel_points_set = coors_points_dict[tuple(c)]
-        voxel_points_set = set([tuple(p) for p in ps[:n]])
+        voxel_points_set = {tuple(p) for p in ps[:n]}
         assert len(voxel_points_set) == n
         if n < max_num_points:
             assert voxel_points_set == ideal_voxel_points_set

@@ -133,7 +133,7 @@ def test_voxelization_nondeterministic():
     voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
     coors = coors.cpu().detach().numpy().tolist()

-    coors_set = set([tuple(c) for c in coors])
-    coors_all_set = set([tuple(c) for c in coors_all])
+    coors_set = {tuple(c) for c in coors}
+    coors_all_set = {tuple(c) for c in coors_all}
     assert len(coors_set) == len(coors) == len(coors_all_set)
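Here the pattern is `set([f(x) for x in xs])` collapsing to a set comprehension `{f(x) for x in xs}`, skipping the intermediate list entirely. A sketch with coordinates like the ones in this test:

    coors = [[0, 1], [2, 3], [0, 1]]
    # before: list comprehension materialised, then converted
    old = set([tuple(c) for c in coors])
    # after: set comprehension builds the set directly
    new = {tuple(c) for c in coors}
    assert old == new == {(0, 1), (2, 3)}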

View File

@@ -63,7 +63,7 @@ def test_is_module_wrapper():

     # test module wrapper registry
     @MODULE_WRAPPERS.register_module()
-    class ModuleWrapper(object):
+    class ModuleWrapper:

         def __init__(self, module):
             self.module = module

Some files were not shown because too many files have changed in this diff.