mirror of https://github.com/open-mmlab/mmcv.git
Add pyupgrade pre-commit hook (#1937)
* add pyupgrade * add options for pyupgrade * minor refinementpull/1968/head
parent
c561264d55
commit
45fa3e44a2
|
@ -42,7 +42,7 @@ def parse_args():
|
|||
class SimpleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(SimpleModel, self).__init__()
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(1, 1, 1)
|
||||
|
||||
def train_step(self, *args, **kwargs):
|
||||
|
@ -159,13 +159,13 @@ def run(cfg, logger):
|
|||
def plot_lr_curve(json_file, cfg):
|
||||
data_dict = dict(LearningRate=[], Momentum=[])
|
||||
assert os.path.isfile(json_file)
|
||||
with open(json_file, 'r') as f:
|
||||
with open(json_file) as f:
|
||||
for line in f:
|
||||
log = json.loads(line.strip())
|
||||
data_dict['LearningRate'].append(log['lr'])
|
||||
data_dict['Momentum'].append(log['momentum'])
|
||||
|
||||
wind_w, wind_h = [int(size) for size in cfg.window_size.split('*')]
|
||||
wind_w, wind_h = (int(size) for size in cfg.window_size.split('*'))
|
||||
# if legend is None, use {filename}_{key} as legend
|
||||
fig, axes = plt.subplots(2, 1, figsize=(wind_w, wind_h))
|
||||
plt.subplots_adjust(hspace=0.5)
|
||||
|
|
|
@ -43,7 +43,11 @@ repos:
|
|||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.32.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: ["--py36-plus"]
|
||||
- repo: https://github.com/open-mmlab/pre-commit-hooks
|
||||
rev: v0.2.0 # Use the ref you want to point at
|
||||
hooks:
|
||||
|
|
|
@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
|
|||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
|
||||
version_file = '../../mmcv/version.py'
|
||||
with open(version_file, 'r') as f:
|
||||
with open(version_file) as f:
|
||||
exec(compile(f.read(), version_file, 'exec'))
|
||||
__version__ = locals()['__version__']
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
|
|||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
|
||||
version_file = '../../mmcv/version.py'
|
||||
with open(version_file, 'r') as f:
|
||||
with open(version_file) as f:
|
||||
exec(compile(f.read(), version_file, 'exec'))
|
||||
__version__ = locals()['__version__']
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ from mmcv.utils import get_logger
|
|||
class Model(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 6, 5)
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
self.conv2 = nn.Conv2d(6, 16, 5)
|
||||
|
|
|
@ -12,7 +12,7 @@ class AlexNet(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, num_classes=-1):
|
||||
super(AlexNet, self).__init__()
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
|
||||
|
|
|
@ -29,7 +29,7 @@ class Clamp(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, min=-1., max=1.):
|
||||
super(Clamp, self).__init__()
|
||||
super().__init__()
|
||||
self.min = min
|
||||
self.max = max
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ class ContextBlock(nn.Module):
|
|||
ratio,
|
||||
pooling_type='att',
|
||||
fusion_types=('channel_add', )):
|
||||
super(ContextBlock, self).__init__()
|
||||
super().__init__()
|
||||
assert pooling_type in ['avg', 'att']
|
||||
assert isinstance(fusion_types, (list, tuple))
|
||||
valid_fusion_types = ['channel_add', 'channel_mul']
|
||||
|
|
|
@ -83,7 +83,7 @@ class ConvModule(nn.Module):
|
|||
with_spectral_norm=False,
|
||||
padding_mode='zeros',
|
||||
order=('conv', 'norm', 'act')):
|
||||
super(ConvModule, self).__init__()
|
||||
super().__init__()
|
||||
assert conv_cfg is None or isinstance(conv_cfg, dict)
|
||||
assert norm_cfg is None or isinstance(norm_cfg, dict)
|
||||
assert act_cfg is None or isinstance(act_cfg, dict)
|
||||
|
@ -96,7 +96,7 @@ class ConvModule(nn.Module):
|
|||
self.with_explicit_padding = padding_mode not in official_padding_mode
|
||||
self.order = order
|
||||
assert isinstance(self.order, tuple) and len(self.order) == 3
|
||||
assert set(order) == set(['conv', 'norm', 'act'])
|
||||
assert set(order) == {'conv', 'norm', 'act'}
|
||||
|
||||
self.with_norm = norm_cfg is not None
|
||||
self.with_activation = act_cfg is not None
|
||||
|
|
|
@ -35,7 +35,7 @@ class ConvWS2d(nn.Conv2d):
|
|||
groups=1,
|
||||
bias=True,
|
||||
eps=1e-5):
|
||||
super(ConvWS2d, self).__init__(
|
||||
super().__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
|
|
|
@ -59,7 +59,7 @@ class DepthwiseSeparableConvModule(nn.Module):
|
|||
pw_norm_cfg='default',
|
||||
pw_act_cfg='default',
|
||||
**kwargs):
|
||||
super(DepthwiseSeparableConvModule, self).__init__()
|
||||
super().__init__()
|
||||
assert 'groups' not in kwargs, 'groups should not be specified'
|
||||
|
||||
# if norm/activation config of depthwise/pointwise ConvModule is not
|
||||
|
|
|
@ -37,7 +37,7 @@ class DropPath(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, drop_prob=0.1):
|
||||
super(DropPath, self).__init__()
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -54,7 +54,7 @@ class GeneralizedAttention(nn.Module):
|
|||
q_stride=1,
|
||||
attention_type='1111'):
|
||||
|
||||
super(GeneralizedAttention, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
# hard range means local range for non-local operation
|
||||
self.position_embedding_dim = (
|
||||
|
|
|
@ -27,7 +27,7 @@ class HSigmoid(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, bias=3.0, divisor=6.0, min_value=0.0, max_value=1.0):
|
||||
super(HSigmoid, self).__init__()
|
||||
super().__init__()
|
||||
warnings.warn(
|
||||
'In MMCV v1.4.4, we modified the default value of args to align '
|
||||
'with PyTorch official. Previous Implementation: '
|
||||
|
|
|
@ -22,7 +22,7 @@ class HSwish(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, inplace=False):
|
||||
super(HSwish, self).__init__()
|
||||
super().__init__()
|
||||
self.act = nn.ReLU6(inplace)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -40,7 +40,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
|||
norm_cfg=None,
|
||||
mode='embedded_gaussian',
|
||||
**kwargs):
|
||||
super(_NonLocalNd, self).__init__()
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.reduction = reduction
|
||||
self.use_scale = use_scale
|
||||
|
@ -228,8 +228,7 @@ class NonLocal1d(_NonLocalNd):
|
|||
sub_sample=False,
|
||||
conv_cfg=dict(type='Conv1d'),
|
||||
**kwargs):
|
||||
super(NonLocal1d, self).__init__(
|
||||
in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||
|
||||
self.sub_sample = sub_sample
|
||||
|
||||
|
@ -262,8 +261,7 @@ class NonLocal2d(_NonLocalNd):
|
|||
sub_sample=False,
|
||||
conv_cfg=dict(type='Conv2d'),
|
||||
**kwargs):
|
||||
super(NonLocal2d, self).__init__(
|
||||
in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||
|
||||
self.sub_sample = sub_sample
|
||||
|
||||
|
@ -293,8 +291,7 @@ class NonLocal3d(_NonLocalNd):
|
|||
sub_sample=False,
|
||||
conv_cfg=dict(type='Conv3d'),
|
||||
**kwargs):
|
||||
super(NonLocal3d, self).__init__(
|
||||
in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||
self.sub_sample = sub_sample
|
||||
|
||||
if sub_sample:
|
||||
|
|
|
@ -14,7 +14,7 @@ class Scale(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, scale=1.0):
|
||||
super(Scale, self).__init__()
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -19,7 +19,7 @@ class Swish(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Swish, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
|
|
@ -96,7 +96,7 @@ class AdaptivePadding(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
|
||||
super(AdaptivePadding, self).__init__()
|
||||
super().__init__()
|
||||
assert padding in ('same', 'corner')
|
||||
|
||||
kernel_size = to_2tuple(kernel_size)
|
||||
|
@ -190,7 +190,7 @@ class PatchEmbed(BaseModule):
|
|||
norm_cfg=None,
|
||||
input_size=None,
|
||||
init_cfg=None):
|
||||
super(PatchEmbed, self).__init__(init_cfg=init_cfg)
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
if stride is None:
|
||||
|
@ -435,7 +435,7 @@ class MultiheadAttention(BaseModule):
|
|||
init_cfg=None,
|
||||
batch_first=False,
|
||||
**kwargs):
|
||||
super(MultiheadAttention, self).__init__(init_cfg)
|
||||
super().__init__(init_cfg)
|
||||
if 'dropout' in kwargs:
|
||||
warnings.warn(
|
||||
'The arguments `dropout` in MultiheadAttention '
|
||||
|
@ -590,7 +590,7 @@ class FFN(BaseModule):
|
|||
add_identity=True,
|
||||
init_cfg=None,
|
||||
**kwargs):
|
||||
super(FFN, self).__init__(init_cfg)
|
||||
super().__init__(init_cfg)
|
||||
assert num_fcs >= 2, 'num_fcs should be no less ' \
|
||||
f'than 2. got {num_fcs}.'
|
||||
self.embed_dims = embed_dims
|
||||
|
@ -694,12 +694,12 @@ class BaseTransformerLayer(BaseModule):
|
|||
f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
|
||||
ffn_cfgs[new_name] = kwargs[ori_name]
|
||||
|
||||
super(BaseTransformerLayer, self).__init__(init_cfg)
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.batch_first = batch_first
|
||||
|
||||
assert set(operation_order) & set(
|
||||
['self_attn', 'norm', 'ffn', 'cross_attn']) == \
|
||||
assert set(operation_order) & {
|
||||
'self_attn', 'norm', 'ffn', 'cross_attn'} == \
|
||||
set(operation_order), f'The operation_order of' \
|
||||
f' {self.__class__.__name__} should ' \
|
||||
f'contains all four operation type ' \
|
||||
|
@ -880,7 +880,7 @@ class TransformerLayerSequence(BaseModule):
|
|||
"""
|
||||
|
||||
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
|
||||
super(TransformerLayerSequence, self).__init__(init_cfg)
|
||||
super().__init__(init_cfg)
|
||||
if isinstance(transformerlayers, dict):
|
||||
transformerlayers = [
|
||||
copy.deepcopy(transformerlayers) for _ in range(num_layers)
|
||||
|
|
|
@ -26,7 +26,7 @@ class PixelShufflePack(nn.Module):
|
|||
|
||||
def __init__(self, in_channels, out_channels, scale_factor,
|
||||
upsample_kernel):
|
||||
super(PixelShufflePack, self).__init__()
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.scale_factor = scale_factor
|
||||
|
|
|
@ -30,7 +30,7 @@ class BasicBlock(nn.Module):
|
|||
downsample=None,
|
||||
style='pytorch',
|
||||
with_cp=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
super().__init__()
|
||||
assert style in ['pytorch', 'caffe']
|
||||
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
|
@ -77,7 +77,7 @@ class Bottleneck(nn.Module):
|
|||
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
|
||||
it is "caffe", the stride-two layer is the first 1x1 conv layer.
|
||||
"""
|
||||
super(Bottleneck, self).__init__()
|
||||
super().__init__()
|
||||
assert style in ['pytorch', 'caffe']
|
||||
if style == 'pytorch':
|
||||
conv1_stride = 1
|
||||
|
@ -218,7 +218,7 @@ class ResNet(nn.Module):
|
|||
bn_eval=True,
|
||||
bn_frozen=False,
|
||||
with_cp=False):
|
||||
super(ResNet, self).__init__()
|
||||
super().__init__()
|
||||
if depth not in self.arch_settings:
|
||||
raise KeyError(f'invalid depth {depth} for resnet')
|
||||
assert num_stages >= 1 and num_stages <= 4
|
||||
|
@ -293,7 +293,7 @@ class ResNet(nn.Module):
|
|||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super(ResNet, self).train(mode)
|
||||
super().train(mode)
|
||||
if self.bn_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
|
|
|
@ -277,10 +277,10 @@ def print_model_with_flops(model,
|
|||
return ', '.join([
|
||||
params_to_string(
|
||||
accumulated_num_params, units='M', precision=precision),
|
||||
'{:.3%} Params'.format(accumulated_num_params / total_params),
|
||||
f'{accumulated_num_params / total_params:.3%} Params',
|
||||
flops_to_string(
|
||||
accumulated_flops_cost, units=units, precision=precision),
|
||||
'{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
|
||||
f'{accumulated_flops_cost / total_flops:.3%} FLOPs',
|
||||
self.original_extra_repr()
|
||||
])
|
||||
|
||||
|
|
|
@ -129,7 +129,7 @@ def _get_bases_name(m):
|
|||
return [b.__name__ for b in m.__class__.__bases__]
|
||||
|
||||
|
||||
class BaseInit(object):
|
||||
class BaseInit:
|
||||
|
||||
def __init__(self, *, bias=0, bias_prob=None, layer=None):
|
||||
self.wholemodule = False
|
||||
|
@ -461,7 +461,7 @@ class Caffe2XavierInit(KaimingInit):
|
|||
|
||||
|
||||
@INITIALIZERS.register_module(name='Pretrained')
|
||||
class PretrainedInit(object):
|
||||
class PretrainedInit:
|
||||
"""Initialize module by loading a pretrained model.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -70,7 +70,7 @@ class VGG(nn.Module):
|
|||
bn_frozen=False,
|
||||
ceil_mode=False,
|
||||
with_last_pool=True):
|
||||
super(VGG, self).__init__()
|
||||
super().__init__()
|
||||
if depth not in self.arch_settings:
|
||||
raise KeyError(f'invalid depth {depth} for vgg')
|
||||
assert num_stages >= 1 and num_stages <= 5
|
||||
|
@ -157,7 +157,7 @@ class VGG(nn.Module):
|
|||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super(VGG, self).train(mode)
|
||||
super().train(mode)
|
||||
if self.bn_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
|
|
|
@ -33,7 +33,7 @@ class MLUDataParallel(MMDataParallel):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, dim=0, **kwargs):
|
||||
super(MLUDataParallel, self).__init__(*args, dim=dim, **kwargs)
|
||||
super().__init__(*args, dim=dim, **kwargs)
|
||||
self.device_ids = [0]
|
||||
self.src_device_obj = torch.device('mlu:0')
|
||||
|
||||
|
|
|
@ -210,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
|
|||
"""
|
||||
if not has_method(self._client, 'delete'):
|
||||
raise NotImplementedError(
|
||||
('Current version of Petrel Python SDK has not supported '
|
||||
'the `delete` method, please use a higher version or dev'
|
||||
' branch instead.'))
|
||||
'Current version of Petrel Python SDK has not supported '
|
||||
'the `delete` method, please use a higher version or dev'
|
||||
' branch instead.')
|
||||
|
||||
filepath = self._map_path(filepath)
|
||||
filepath = self._format_path(filepath)
|
||||
|
@ -230,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
|
|||
if not (has_method(self._client, 'contains')
|
||||
and has_method(self._client, 'isdir')):
|
||||
raise NotImplementedError(
|
||||
('Current version of Petrel Python SDK has not supported '
|
||||
'the `contains` and `isdir` methods, please use a higher'
|
||||
'version or dev branch instead.'))
|
||||
'Current version of Petrel Python SDK has not supported '
|
||||
'the `contains` and `isdir` methods, please use a higher'
|
||||
'version or dev branch instead.')
|
||||
|
||||
filepath = self._map_path(filepath)
|
||||
filepath = self._format_path(filepath)
|
||||
|
@ -251,9 +251,9 @@ class PetrelBackend(BaseStorageBackend):
|
|||
"""
|
||||
if not has_method(self._client, 'isdir'):
|
||||
raise NotImplementedError(
|
||||
('Current version of Petrel Python SDK has not supported '
|
||||
'the `isdir` method, please use a higher version or dev'
|
||||
' branch instead.'))
|
||||
'Current version of Petrel Python SDK has not supported '
|
||||
'the `isdir` method, please use a higher version or dev'
|
||||
' branch instead.')
|
||||
|
||||
filepath = self._map_path(filepath)
|
||||
filepath = self._format_path(filepath)
|
||||
|
@ -271,9 +271,9 @@ class PetrelBackend(BaseStorageBackend):
|
|||
"""
|
||||
if not has_method(self._client, 'contains'):
|
||||
raise NotImplementedError(
|
||||
('Current version of Petrel Python SDK has not supported '
|
||||
'the `contains` method, please use a higher version or '
|
||||
'dev branch instead.'))
|
||||
'Current version of Petrel Python SDK has not supported '
|
||||
'the `contains` method, please use a higher version or '
|
||||
'dev branch instead.')
|
||||
|
||||
filepath = self._map_path(filepath)
|
||||
filepath = self._format_path(filepath)
|
||||
|
@ -366,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
|
|||
"""
|
||||
if not has_method(self._client, 'list'):
|
||||
raise NotImplementedError(
|
||||
('Current version of Petrel Python SDK has not supported '
|
||||
'the `list` method, please use a higher version or dev'
|
||||
' branch instead.'))
|
||||
'Current version of Petrel Python SDK has not supported '
|
||||
'the `list` method, please use a higher version or dev'
|
||||
' branch instead.')
|
||||
|
||||
dir_path = self._map_path(dir_path)
|
||||
dir_path = self._format_path(dir_path)
|
||||
|
@ -549,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
|
|||
Returns:
|
||||
str: Expected text reading from ``filepath``.
|
||||
"""
|
||||
with open(filepath, 'r', encoding=encoding) as f:
|
||||
with open(filepath, encoding=encoding) as f:
|
||||
value_buf = f.read()
|
||||
return value_buf
|
||||
|
||||
|
|
|
@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
|
|||
return pickle.load(file, **kwargs)
|
||||
|
||||
def load_from_path(self, filepath, **kwargs):
|
||||
return super(PickleHandler, self).load_from_path(
|
||||
filepath, mode='rb', **kwargs)
|
||||
return super().load_from_path(filepath, mode='rb', **kwargs)
|
||||
|
||||
def dump_to_str(self, obj, **kwargs):
|
||||
kwargs.setdefault('protocol', 2)
|
||||
|
@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
|
|||
pickle.dump(obj, file, **kwargs)
|
||||
|
||||
def dump_to_path(self, obj, filepath, **kwargs):
|
||||
super(PickleHandler, self).dump_to_path(
|
||||
obj, filepath, mode='wb', **kwargs)
|
||||
super().dump_to_path(obj, filepath, mode='wb', **kwargs)
|
||||
|
|
|
@ -157,7 +157,7 @@ def imresize_to_multiple(img,
|
|||
size = _scale_size((w, h), scale_factor)
|
||||
|
||||
divisor = to_2tuple(divisor)
|
||||
size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
|
||||
size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
|
||||
resized_img, w_scale, h_scale = imresize(
|
||||
img,
|
||||
size,
|
||||
|
|
|
@ -59,7 +59,7 @@ def _parse_arg(value, desc):
|
|||
raise RuntimeError(
|
||||
"ONNX symbolic doesn't know to interpret ListConstruct node")
|
||||
|
||||
raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
|
||||
raise RuntimeError(f'Unexpected node type: {value.node().kind()}')
|
||||
|
||||
|
||||
def _maybe_get_const(value, desc):
|
||||
|
|
|
@ -86,7 +86,7 @@ class BorderAlign(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, pool_size):
|
||||
super(BorderAlign, self).__init__()
|
||||
super().__init__()
|
||||
self.pool_size = pool_size
|
||||
|
||||
def forward(self, input, boxes):
|
||||
|
|
|
@ -131,7 +131,7 @@ def box_iou_rotated(bboxes1,
|
|||
if aligned:
|
||||
ious = bboxes1.new_zeros(rows)
|
||||
else:
|
||||
ious = bboxes1.new_zeros((rows * cols))
|
||||
ious = bboxes1.new_zeros(rows * cols)
|
||||
if not clockwise:
|
||||
flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
|
||||
flip_mat[-1] = -1
|
||||
|
|
|
@ -85,7 +85,7 @@ carafe_naive = CARAFENaiveFunction.apply
|
|||
class CARAFENaive(Module):
|
||||
|
||||
def __init__(self, kernel_size, group_size, scale_factor):
|
||||
super(CARAFENaive, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(kernel_size, int) and isinstance(
|
||||
group_size, int) and isinstance(scale_factor, int)
|
||||
|
@ -195,7 +195,7 @@ class CARAFE(Module):
|
|||
"""
|
||||
|
||||
def __init__(self, kernel_size, group_size, scale_factor):
|
||||
super(CARAFE, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(kernel_size, int) and isinstance(
|
||||
group_size, int) and isinstance(scale_factor, int)
|
||||
|
@ -238,7 +238,7 @@ class CARAFEPack(nn.Module):
|
|||
encoder_kernel=3,
|
||||
encoder_dilation=1,
|
||||
compressed_channels=64):
|
||||
super(CARAFEPack, self).__init__()
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.scale_factor = scale_factor
|
||||
self.up_kernel = up_kernel
|
||||
|
|
|
@ -125,7 +125,7 @@ class CornerPool(nn.Module):
|
|||
}
|
||||
|
||||
def __init__(self, mode):
|
||||
super(CornerPool, self).__init__()
|
||||
super().__init__()
|
||||
assert mode in self.pool_functions
|
||||
self.mode = mode
|
||||
self.corner_pool = self.pool_functions[mode]
|
||||
|
|
|
@ -236,7 +236,7 @@ class DeformConv2d(nn.Module):
|
|||
deform_groups: int = 1,
|
||||
bias: bool = False,
|
||||
im2col_step: int = 32) -> None:
|
||||
super(DeformConv2d, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
assert not bias, \
|
||||
f'bias={bias} is not supported in DeformConv2d.'
|
||||
|
@ -356,7 +356,7 @@ class DeformConv2dPack(DeformConv2d):
|
|||
_version = 2
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DeformConv2dPack, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_offset = nn.Conv2d(
|
||||
self.in_channels,
|
||||
self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
|
||||
|
|
|
@ -96,7 +96,7 @@ class DeformRoIPool(nn.Module):
|
|||
spatial_scale=1.0,
|
||||
sampling_ratio=0,
|
||||
gamma=0.1):
|
||||
super(DeformRoIPool, self).__init__()
|
||||
super().__init__()
|
||||
self.output_size = _pair(output_size)
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
self.sampling_ratio = int(sampling_ratio)
|
||||
|
@ -117,8 +117,7 @@ class DeformRoIPoolPack(DeformRoIPool):
|
|||
spatial_scale=1.0,
|
||||
sampling_ratio=0,
|
||||
gamma=0.1):
|
||||
super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale,
|
||||
sampling_ratio, gamma)
|
||||
super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
|
||||
|
||||
self.output_channels = output_channels
|
||||
self.deform_fc_channels = deform_fc_channels
|
||||
|
@ -158,8 +157,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool):
|
|||
spatial_scale=1.0,
|
||||
sampling_ratio=0,
|
||||
gamma=0.1):
|
||||
super(ModulatedDeformRoIPoolPack,
|
||||
self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
|
||||
super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
|
||||
|
||||
self.output_channels = output_channels
|
||||
self.deform_fc_channels = deform_fc_channels
|
||||
|
|
|
@ -89,7 +89,7 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply
|
|||
class SigmoidFocalLoss(nn.Module):
|
||||
|
||||
def __init__(self, gamma, alpha, weight=None, reduction='mean'):
|
||||
super(SigmoidFocalLoss, self).__init__()
|
||||
super().__init__()
|
||||
self.gamma = gamma
|
||||
self.alpha = alpha
|
||||
self.register_buffer('weight', weight)
|
||||
|
@ -195,7 +195,7 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply
|
|||
class SoftmaxFocalLoss(nn.Module):
|
||||
|
||||
def __init__(self, gamma, alpha, weight=None, reduction='mean'):
|
||||
super(SoftmaxFocalLoss, self).__init__()
|
||||
super().__init__()
|
||||
self.gamma = gamma
|
||||
self.alpha = alpha
|
||||
self.register_buffer('weight', weight)
|
||||
|
|
|
@ -212,7 +212,7 @@ class FusedBiasLeakyReLU(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
|
||||
super(FusedBiasLeakyReLU, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.negative_slope = negative_slope
|
||||
|
|
|
@ -98,13 +98,12 @@ class MaskedConv2d(nn.Conv2d):
|
|||
dilation=1,
|
||||
groups=1,
|
||||
bias=True):
|
||||
super(MaskedConv2d,
|
||||
self).__init__(in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, groups, bias)
|
||||
super().__init__(in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, groups, bias)
|
||||
|
||||
def forward(self, input, mask=None):
|
||||
if mask is None: # fallback to the normal Conv2d
|
||||
return super(MaskedConv2d, self).forward(input)
|
||||
return super().forward(input)
|
||||
else:
|
||||
return masked_conv2d(input, mask, self.weight, self.bias,
|
||||
self.padding)
|
||||
|
|
|
@ -53,7 +53,7 @@ class BaseMergeCell(nn.Module):
|
|||
input_conv_cfg=None,
|
||||
input_norm_cfg=None,
|
||||
upsample_mode='nearest'):
|
||||
super(BaseMergeCell, self).__init__()
|
||||
super().__init__()
|
||||
assert upsample_mode in ['nearest', 'bilinear']
|
||||
self.with_out_conv = with_out_conv
|
||||
self.with_input1_conv = with_input1_conv
|
||||
|
@ -121,7 +121,7 @@ class BaseMergeCell(nn.Module):
|
|||
class SumCell(BaseMergeCell):
|
||||
|
||||
def __init__(self, in_channels, out_channels, **kwargs):
|
||||
super(SumCell, self).__init__(in_channels, out_channels, **kwargs)
|
||||
super().__init__(in_channels, out_channels, **kwargs)
|
||||
|
||||
def _binary_op(self, x1, x2):
|
||||
return x1 + x2
|
||||
|
@ -130,8 +130,7 @@ class SumCell(BaseMergeCell):
|
|||
class ConcatCell(BaseMergeCell):
|
||||
|
||||
def __init__(self, in_channels, out_channels, **kwargs):
|
||||
super(ConcatCell, self).__init__(in_channels * 2, out_channels,
|
||||
**kwargs)
|
||||
super().__init__(in_channels * 2, out_channels, **kwargs)
|
||||
|
||||
def _binary_op(self, x1, x2):
|
||||
ret = torch.cat([x1, x2], dim=1)
|
||||
|
|
|
@ -168,7 +168,7 @@ class ModulatedDeformConv2d(nn.Module):
|
|||
groups=1,
|
||||
deform_groups=1,
|
||||
bias=True):
|
||||
super(ModulatedDeformConv2d, self).__init__()
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = _pair(kernel_size)
|
||||
|
@ -227,7 +227,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
|
|||
_version = 2
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_offset = nn.Conv2d(
|
||||
self.in_channels,
|
||||
self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
|
||||
|
@ -239,7 +239,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
|
|||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
super(ModulatedDeformConv2dPack, self).init_weights()
|
||||
super().init_weights()
|
||||
if hasattr(self, 'conv_offset'):
|
||||
self.conv_offset.weight.data.zero_()
|
||||
self.conv_offset.bias.data.zero_()
|
||||
|
|
|
@ -296,7 +296,7 @@ class SimpleRoIAlign(nn.Module):
|
|||
If True, align the results more perfectly.
|
||||
"""
|
||||
|
||||
super(SimpleRoIAlign, self).__init__()
|
||||
super().__init__()
|
||||
self.output_size = _pair(output_size)
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
# to be consistent with other RoI ops
|
||||
|
|
|
@ -72,7 +72,7 @@ psa_mask = PSAMaskFunction.apply
|
|||
class PSAMask(nn.Module):
|
||||
|
||||
def __init__(self, psa_type, mask_size=None):
|
||||
super(PSAMask, self).__init__()
|
||||
super().__init__()
|
||||
assert psa_type in ['collect', 'distribute']
|
||||
if psa_type == 'collect':
|
||||
psa_type_enum = 0
|
||||
|
|
|
@ -116,7 +116,7 @@ class RiRoIAlignRotated(nn.Module):
|
|||
num_samples=0,
|
||||
num_orientations=8,
|
||||
clockwise=False):
|
||||
super(RiRoIAlignRotated, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.out_size = out_size
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
|
|
|
@ -181,7 +181,7 @@ class RoIAlign(nn.Module):
|
|||
pool_mode='avg',
|
||||
aligned=True,
|
||||
use_torchvision=False):
|
||||
super(RoIAlign, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.output_size = _pair(output_size)
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
|
|
|
@ -156,7 +156,7 @@ class RoIAlignRotated(nn.Module):
|
|||
sampling_ratio=0,
|
||||
aligned=True,
|
||||
clockwise=False):
|
||||
super(RoIAlignRotated, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.output_size = _pair(output_size)
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
|
|
|
@ -71,7 +71,7 @@ roi_pool = RoIPoolFunction.apply
|
|||
class RoIPool(nn.Module):
|
||||
|
||||
def __init__(self, output_size, spatial_scale=1.0):
|
||||
super(RoIPool, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.output_size = _pair(output_size)
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
|
|
|
@ -64,7 +64,7 @@ class SparseConvolution(SparseModule):
|
|||
inverse=False,
|
||||
indice_key=None,
|
||||
fused_bn=False):
|
||||
super(SparseConvolution, self).__init__()
|
||||
super().__init__()
|
||||
assert groups == 1
|
||||
if not isinstance(kernel_size, (list, tuple)):
|
||||
kernel_size = [kernel_size] * ndim
|
||||
|
@ -217,7 +217,7 @@ class SparseConv2d(SparseConvolution):
|
|||
groups=1,
|
||||
bias=True,
|
||||
indice_key=None):
|
||||
super(SparseConv2d, self).__init__(
|
||||
super().__init__(
|
||||
2,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -243,7 +243,7 @@ class SparseConv3d(SparseConvolution):
|
|||
groups=1,
|
||||
bias=True,
|
||||
indice_key=None):
|
||||
super(SparseConv3d, self).__init__(
|
||||
super().__init__(
|
||||
3,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -269,7 +269,7 @@ class SparseConv4d(SparseConvolution):
|
|||
groups=1,
|
||||
bias=True,
|
||||
indice_key=None):
|
||||
super(SparseConv4d, self).__init__(
|
||||
super().__init__(
|
||||
4,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -295,7 +295,7 @@ class SparseConvTranspose2d(SparseConvolution):
|
|||
groups=1,
|
||||
bias=True,
|
||||
indice_key=None):
|
||||
super(SparseConvTranspose2d, self).__init__(
|
||||
super().__init__(
|
||||
2,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -322,7 +322,7 @@ class SparseConvTranspose3d(SparseConvolution):
|
|||
groups=1,
|
||||
bias=True,
|
||||
indice_key=None):
|
||||
super(SparseConvTranspose3d, self).__init__(
|
||||
super().__init__(
|
||||
3,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -345,7 +345,7 @@ class SparseInverseConv2d(SparseConvolution):
|
|||
kernel_size,
|
||||
indice_key=None,
|
||||
bias=True):
|
||||
super(SparseInverseConv2d, self).__init__(
|
||||
super().__init__(
|
||||
2,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -364,7 +364,7 @@ class SparseInverseConv3d(SparseConvolution):
|
|||
kernel_size,
|
||||
indice_key=None,
|
||||
bias=True):
|
||||
super(SparseInverseConv3d, self).__init__(
|
||||
super().__init__(
|
||||
3,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -387,7 +387,7 @@ class SubMConv2d(SparseConvolution):
|
|||
groups=1,
|
||||
bias=True,
|
||||
indice_key=None):
|
||||
super(SubMConv2d, self).__init__(
|
||||
super().__init__(
|
||||
2,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -414,7 +414,7 @@ class SubMConv3d(SparseConvolution):
|
|||
groups=1,
|
||||
bias=True,
|
||||
indice_key=None):
|
||||
super(SubMConv3d, self).__init__(
|
||||
super().__init__(
|
||||
3,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -441,7 +441,7 @@ class SubMConv4d(SparseConvolution):
|
|||
groups=1,
|
||||
bias=True,
|
||||
indice_key=None):
|
||||
super(SubMConv4d, self).__init__(
|
||||
super().__init__(
|
||||
4,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
|
|
@ -86,7 +86,7 @@ class SparseSequential(SparseModule):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SparseSequential, self).__init__()
|
||||
super().__init__()
|
||||
if len(args) == 1 and isinstance(args[0], OrderedDict):
|
||||
for key, module in args[0].items():
|
||||
self.add_module(key, module)
|
||||
|
@ -103,7 +103,7 @@ class SparseSequential(SparseModule):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
if not (-len(self) <= idx < len(self)):
|
||||
raise IndexError('index {} is out of range'.format(idx))
|
||||
raise IndexError(f'index {idx} is out of range')
|
||||
if idx < 0:
|
||||
idx += len(self)
|
||||
it = iter(self._modules.values())
|
||||
|
|
|
@ -29,7 +29,7 @@ class SparseMaxPool(SparseModule):
|
|||
padding=0,
|
||||
dilation=1,
|
||||
subm=False):
|
||||
super(SparseMaxPool, self).__init__()
|
||||
super().__init__()
|
||||
if not isinstance(kernel_size, (list, tuple)):
|
||||
kernel_size = [kernel_size] * ndim
|
||||
if not isinstance(stride, (list, tuple)):
|
||||
|
@ -77,12 +77,10 @@ class SparseMaxPool(SparseModule):
|
|||
class SparseMaxPool2d(SparseMaxPool):
|
||||
|
||||
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
|
||||
super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding,
|
||||
dilation)
|
||||
super().__init__(2, kernel_size, stride, padding, dilation)
|
||||
|
||||
|
||||
class SparseMaxPool3d(SparseMaxPool):
|
||||
|
||||
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
|
||||
super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding,
|
||||
dilation)
|
||||
super().__init__(3, kernel_size, stride, padding, dilation)
|
||||
|
|
|
@ -18,7 +18,7 @@ def scatter_nd(indices, updates, shape):
|
|||
return ret
|
||||
|
||||
|
||||
class SparseConvTensor(object):
|
||||
class SparseConvTensor:
|
||||
|
||||
def __init__(self,
|
||||
features,
|
||||
|
|
|
@ -198,7 +198,7 @@ class SyncBatchNorm(Module):
|
|||
track_running_stats=True,
|
||||
group=None,
|
||||
stats_mode='default'):
|
||||
super(SyncBatchNorm, self).__init__()
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.momentum = momentum
|
||||
|
|
|
@ -32,7 +32,7 @@ class MMDataParallel(DataParallel):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, dim=0, **kwargs):
|
||||
super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
|
||||
super().__init__(*args, dim=dim, **kwargs)
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
|
|
|
@ -18,7 +18,7 @@ class MMDistributedDataParallel(nn.Module):
|
|||
dim=0,
|
||||
broadcast_buffers=True,
|
||||
bucket_cap_mb=25):
|
||||
super(MMDistributedDataParallel, self).__init__()
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.dim = dim
|
||||
self.broadcast_buffers = broadcast_buffers
|
||||
|
|
|
@ -35,7 +35,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
|
|||
# NOTE init_cfg can be defined in different levels, but init_cfg
|
||||
# in low levels has a higher priority.
|
||||
|
||||
super(BaseModule, self).__init__()
|
||||
super().__init__()
|
||||
# define default value of init_cfg instead of hard code
|
||||
# in init_weights() function
|
||||
self._is_init = False
|
||||
|
|
|
@ -83,8 +83,8 @@ class CheckpointHook(Hook):
|
|||
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
||||
self.out_dir = self.file_client.join_path(self.out_dir, basename)
|
||||
|
||||
runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
|
||||
f'{self.file_client.name}.'))
|
||||
runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
|
||||
f'{self.file_client.name}.')
|
||||
|
||||
# disable the create_symlink option because some file backends do not
|
||||
# allow to create a symlink
|
||||
|
@ -93,9 +93,9 @@ class CheckpointHook(Hook):
|
|||
'create_symlink'] and not self.file_client.allow_symlink:
|
||||
self.args['create_symlink'] = False
|
||||
warnings.warn(
|
||||
('create_symlink is set as True by the user but is changed'
|
||||
'to be False because creating symbolic link is not '
|
||||
f'allowed in {self.file_client.name}'))
|
||||
'create_symlink is set as True by the user but is changed'
|
||||
'to be False because creating symbolic link is not '
|
||||
f'allowed in {self.file_client.name}')
|
||||
else:
|
||||
self.args['create_symlink'] = self.file_client.allow_symlink
|
||||
|
||||
|
|
|
@ -214,8 +214,8 @@ class EvalHook(Hook):
|
|||
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
||||
self.out_dir = self.file_client.join_path(self.out_dir, basename)
|
||||
runner.logger.info(
|
||||
(f'The best checkpoint will be saved to {self.out_dir} by '
|
||||
f'{self.file_client.name}'))
|
||||
f'The best checkpoint will be saved to {self.out_dir} by '
|
||||
f'{self.file_client.name}')
|
||||
|
||||
if self.save_best is not None:
|
||||
if runner.meta is None:
|
||||
|
@ -335,8 +335,8 @@ class EvalHook(Hook):
|
|||
self.best_ckpt_path):
|
||||
self.file_client.remove(self.best_ckpt_path)
|
||||
runner.logger.info(
|
||||
(f'The previous best checkpoint {self.best_ckpt_path} was '
|
||||
'removed'))
|
||||
f'The previous best checkpoint {self.best_ckpt_path} was '
|
||||
'removed')
|
||||
|
||||
best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
|
||||
self.best_ckpt_path = self.file_client.join_path(
|
||||
|
|
|
@ -34,8 +34,7 @@ class ClearMLLoggerHook(LoggerHook):
|
|||
ignore_last=True,
|
||||
reset_flag=False,
|
||||
by_epoch=True):
|
||||
super(ClearMLLoggerHook, self).__init__(interval, ignore_last,
|
||||
reset_flag, by_epoch)
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.import_clearml()
|
||||
self.init_kwargs = init_kwargs
|
||||
|
||||
|
@ -49,7 +48,7 @@ class ClearMLLoggerHook(LoggerHook):
|
|||
|
||||
@master_only
|
||||
def before_run(self, runner):
|
||||
super(ClearMLLoggerHook, self).before_run(runner)
|
||||
super().before_run(runner)
|
||||
task_kwargs = self.init_kwargs if self.init_kwargs else {}
|
||||
self.task = self.clearml.Task.init(**task_kwargs)
|
||||
self.task_logger = self.task.get_logger()
|
||||
|
|
|
@ -40,8 +40,7 @@ class MlflowLoggerHook(LoggerHook):
|
|||
ignore_last=True,
|
||||
reset_flag=False,
|
||||
by_epoch=True):
|
||||
super(MlflowLoggerHook, self).__init__(interval, ignore_last,
|
||||
reset_flag, by_epoch)
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.import_mlflow()
|
||||
self.exp_name = exp_name
|
||||
self.tags = tags
|
||||
|
@ -59,7 +58,7 @@ class MlflowLoggerHook(LoggerHook):
|
|||
|
||||
@master_only
|
||||
def before_run(self, runner):
|
||||
super(MlflowLoggerHook, self).before_run(runner)
|
||||
super().before_run(runner)
|
||||
if self.exp_name is not None:
|
||||
self.mlflow.set_experiment(self.exp_name)
|
||||
if self.tags is not None:
|
||||
|
|
|
@ -49,8 +49,7 @@ class NeptuneLoggerHook(LoggerHook):
|
|||
with_step=True,
|
||||
by_epoch=True):
|
||||
|
||||
super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
|
||||
reset_flag, by_epoch)
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.import_neptune()
|
||||
self.init_kwargs = init_kwargs
|
||||
self.with_step = with_step
|
||||
|
|
|
@ -40,8 +40,7 @@ class PaviLoggerHook(LoggerHook):
|
|||
reset_flag=False,
|
||||
by_epoch=True,
|
||||
img_key='img_info'):
|
||||
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
|
||||
by_epoch)
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.init_kwargs = init_kwargs
|
||||
self.add_graph = add_graph
|
||||
self.add_last_ckpt = add_last_ckpt
|
||||
|
@ -49,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
|
|||
|
||||
@master_only
|
||||
def before_run(self, runner):
|
||||
super(PaviLoggerHook, self).before_run(runner)
|
||||
super().before_run(runner)
|
||||
try:
|
||||
from pavi import SummaryWriter
|
||||
except ImportError:
|
||||
|
|
|
@ -27,8 +27,7 @@ class SegmindLoggerHook(LoggerHook):
|
|||
ignore_last=True,
|
||||
reset_flag=False,
|
||||
by_epoch=True):
|
||||
super(SegmindLoggerHook, self).__init__(interval, ignore_last,
|
||||
reset_flag, by_epoch)
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.import_segmind()
|
||||
|
||||
def import_segmind(self):
|
||||
|
|
|
@ -28,13 +28,12 @@ class TensorboardLoggerHook(LoggerHook):
|
|||
ignore_last=True,
|
||||
reset_flag=False,
|
||||
by_epoch=True):
|
||||
super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
|
||||
reset_flag, by_epoch)
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.log_dir = log_dir
|
||||
|
||||
@master_only
|
||||
def before_run(self, runner):
|
||||
super(TensorboardLoggerHook, self).before_run(runner)
|
||||
super().before_run(runner)
|
||||
if (TORCH_VERSION == 'parrots'
|
||||
or digit_version(TORCH_VERSION) < digit_version('1.1')):
|
||||
try:
|
||||
|
|
|
@ -62,8 +62,7 @@ class TextLoggerHook(LoggerHook):
|
|||
out_suffix=('.log.json', '.log', '.py'),
|
||||
keep_local=True,
|
||||
file_client_args=None):
|
||||
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
|
||||
by_epoch)
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.by_epoch = by_epoch
|
||||
self.time_sec_tot = 0
|
||||
self.interval_exp_name = interval_exp_name
|
||||
|
@ -87,7 +86,7 @@ class TextLoggerHook(LoggerHook):
|
|||
self.out_dir)
|
||||
|
||||
def before_run(self, runner):
|
||||
super(TextLoggerHook, self).before_run(runner)
|
||||
super().before_run(runner)
|
||||
|
||||
if self.out_dir is not None:
|
||||
self.file_client = FileClient.infer_client(self.file_client_args,
|
||||
|
@ -97,8 +96,8 @@ class TextLoggerHook(LoggerHook):
|
|||
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
||||
self.out_dir = self.file_client.join_path(self.out_dir, basename)
|
||||
runner.logger.info(
|
||||
(f'Text logs will be saved to {self.out_dir} by '
|
||||
f'{self.file_client.name} after the training process.'))
|
||||
f'Text logs will be saved to {self.out_dir} by '
|
||||
f'{self.file_client.name} after the training process.')
|
||||
|
||||
self.start_iter = runner.iter
|
||||
self.json_log_path = osp.join(runner.work_dir,
|
||||
|
@ -242,15 +241,15 @@ class TextLoggerHook(LoggerHook):
|
|||
local_filepath = osp.join(runner.work_dir, filename)
|
||||
out_filepath = self.file_client.join_path(
|
||||
self.out_dir, filename)
|
||||
with open(local_filepath, 'r') as f:
|
||||
with open(local_filepath) as f:
|
||||
self.file_client.put_text(f.read(), out_filepath)
|
||||
|
||||
runner.logger.info(
|
||||
(f'The file {local_filepath} has been uploaded to '
|
||||
f'{out_filepath}.'))
|
||||
f'The file {local_filepath} has been uploaded to '
|
||||
f'{out_filepath}.')
|
||||
|
||||
if not self.keep_local:
|
||||
os.remove(local_filepath)
|
||||
runner.logger.info(
|
||||
(f'{local_filepath} was removed due to the '
|
||||
'`self.keep_local=False`'))
|
||||
f'{local_filepath} was removed due to the '
|
||||
'`self.keep_local=False`')
|
||||
|
|
|
@ -57,8 +57,7 @@ class WandbLoggerHook(LoggerHook):
|
|||
with_step=True,
|
||||
log_artifact=True,
|
||||
out_suffix=('.log.json', '.log', '.py')):
|
||||
super(WandbLoggerHook, self).__init__(interval, ignore_last,
|
||||
reset_flag, by_epoch)
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.import_wandb()
|
||||
self.init_kwargs = init_kwargs
|
||||
self.commit = commit
|
||||
|
@ -76,7 +75,7 @@ class WandbLoggerHook(LoggerHook):
|
|||
|
||||
@master_only
|
||||
def before_run(self, runner):
|
||||
super(WandbLoggerHook, self).before_run(runner)
|
||||
super().before_run(runner)
|
||||
if self.wandb is None:
|
||||
self.import_wandb()
|
||||
if self.init_kwargs:
|
||||
|
|
|
@ -157,7 +157,7 @@ class LrUpdaterHook(Hook):
|
|||
class FixedLrUpdaterHook(LrUpdaterHook):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(FixedLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_lr(self, runner, base_lr):
|
||||
return base_lr
|
||||
|
@ -188,7 +188,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
|
|||
self.step = step
|
||||
self.gamma = gamma
|
||||
self.min_lr = min_lr
|
||||
super(StepLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_lr(self, runner, base_lr):
|
||||
progress = runner.epoch if self.by_epoch else runner.iter
|
||||
|
@ -215,7 +215,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
|
|||
|
||||
def __init__(self, gamma, **kwargs):
|
||||
self.gamma = gamma
|
||||
super(ExpLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_lr(self, runner, base_lr):
|
||||
progress = runner.epoch if self.by_epoch else runner.iter
|
||||
|
@ -228,7 +228,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
|
|||
def __init__(self, power=1., min_lr=0., **kwargs):
|
||||
self.power = power
|
||||
self.min_lr = min_lr
|
||||
super(PolyLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_lr(self, runner, base_lr):
|
||||
if self.by_epoch:
|
||||
|
@ -247,7 +247,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
|
|||
def __init__(self, gamma, power=1., **kwargs):
|
||||
self.gamma = gamma
|
||||
self.power = power
|
||||
super(InvLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_lr(self, runner, base_lr):
|
||||
progress = runner.epoch if self.by_epoch else runner.iter
|
||||
|
@ -269,7 +269,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
|
|||
assert (min_lr is None) ^ (min_lr_ratio is None)
|
||||
self.min_lr = min_lr
|
||||
self.min_lr_ratio = min_lr_ratio
|
||||
super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_lr(self, runner, base_lr):
|
||||
if self.by_epoch:
|
||||
|
@ -317,7 +317,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
|
|||
self.start_percent = start_percent
|
||||
self.min_lr = min_lr
|
||||
self.min_lr_ratio = min_lr_ratio
|
||||
super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_lr(self, runner, base_lr):
|
||||
if self.by_epoch:
|
||||
|
@ -367,7 +367,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
|
|||
self.restart_weights = restart_weights
|
||||
assert (len(self.periods) == len(self.restart_weights)
|
||||
), 'periods and restart_weights should have the same length.'
|
||||
super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.cumulative_periods = [
|
||||
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
|
||||
|
@ -484,10 +484,10 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
|
|||
|
||||
assert not by_epoch, \
|
||||
'currently only support "by_epoch" = False'
|
||||
super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
|
||||
super().__init__(by_epoch, **kwargs)
|
||||
|
||||
def before_run(self, runner):
|
||||
super(CyclicLrUpdaterHook, self).before_run(runner)
|
||||
super().before_run(runner)
|
||||
# initiate lr_phases
|
||||
# total lr_phases are separated as up and down
|
||||
self.max_iter_per_phase = runner.max_iters // self.cyclic_times
|
||||
|
@ -598,7 +598,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
|
|||
self.final_div_factor = final_div_factor
|
||||
self.three_phase = three_phase
|
||||
self.lr_phases = [] # init lr_phases
|
||||
super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def before_run(self, runner):
|
||||
if hasattr(self, 'total_steps'):
|
||||
|
@ -668,7 +668,7 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
|
|||
assert (min_lr is None) ^ (min_lr_ratio is None)
|
||||
self.min_lr = min_lr
|
||||
self.min_lr_ratio = min_lr_ratio
|
||||
super(LinearAnnealingLrUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_lr(self, runner, base_lr):
|
||||
if self.by_epoch:
|
||||
|
|
|
@ -176,7 +176,7 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
|
|||
self.step = step
|
||||
self.gamma = gamma
|
||||
self.min_momentum = min_momentum
|
||||
super(StepMomentumUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_momentum(self, runner, base_momentum):
|
||||
progress = runner.epoch if self.by_epoch else runner.iter
|
||||
|
@ -214,7 +214,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
|
|||
assert (min_momentum is None) ^ (min_momentum_ratio is None)
|
||||
self.min_momentum = min_momentum
|
||||
self.min_momentum_ratio = min_momentum_ratio
|
||||
super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_momentum(self, runner, base_momentum):
|
||||
if self.by_epoch:
|
||||
|
@ -247,7 +247,7 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
|
|||
assert (min_momentum is None) ^ (min_momentum_ratio is None)
|
||||
self.min_momentum = min_momentum
|
||||
self.min_momentum_ratio = min_momentum_ratio
|
||||
super(LinearAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_momentum(self, runner, base_momentum):
|
||||
if self.by_epoch:
|
||||
|
@ -328,10 +328,10 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
|
|||
# currently only support by_epoch=False
|
||||
assert not by_epoch, \
|
||||
'currently only support "by_epoch" = False'
|
||||
super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
|
||||
super().__init__(by_epoch, **kwargs)
|
||||
|
||||
def before_run(self, runner):
|
||||
super(CyclicMomentumUpdaterHook, self).before_run(runner)
|
||||
super().before_run(runner)
|
||||
# initiate momentum_phases
|
||||
# total momentum_phases are separated as up and down
|
||||
max_iter_per_phase = runner.max_iters // self.cyclic_times
|
||||
|
@ -439,7 +439,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
|
|||
self.anneal_func = annealing_linear
|
||||
self.three_phase = three_phase
|
||||
self.momentum_phases = [] # init momentum_phases
|
||||
super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def before_run(self, runner):
|
||||
if isinstance(runner.optimizer, dict):
|
||||
|
|
|
@ -110,7 +110,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
|
|||
"""
|
||||
|
||||
def __init__(self, cumulative_iters=1, **kwargs):
|
||||
super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
|
||||
f'cumulative_iters only accepts positive int, but got ' \
|
||||
|
@ -297,8 +297,7 @@ if (TORCH_VERSION != 'parrots'
|
|||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GradientCumulativeFp16OptimizerHook,
|
||||
self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
if not self.initialized:
|
||||
|
@ -490,8 +489,7 @@ else:
|
|||
iters gradient cumulating."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GradientCumulativeFp16OptimizerHook,
|
||||
self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
if not self.initialized:
|
||||
|
|
|
@ -263,7 +263,7 @@ class IterBasedRunner(BaseRunner):
|
|||
if log_config is not None:
|
||||
for info in log_config['hooks']:
|
||||
info.setdefault('by_epoch', False)
|
||||
super(IterBasedRunner, self).register_training_hooks(
|
||||
super().register_training_hooks(
|
||||
lr_config=lr_config,
|
||||
momentum_config=momentum_config,
|
||||
optimizer_config=optimizer_config,
|
||||
|
|
|
@ -54,7 +54,7 @@ def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
|
|||
msg += reset_style
|
||||
warnings.warn(msg)
|
||||
|
||||
device = torch.device('cuda:{}'.format(device_id))
|
||||
device = torch.device(f'cuda:{device_id}')
|
||||
# create builder and network
|
||||
logger = trt.Logger(log_level)
|
||||
builder = trt.Builder(logger)
|
||||
|
@ -209,7 +209,7 @@ class TRTWrapper(torch.nn.Module):
|
|||
msg += reset_style
|
||||
warnings.warn(msg)
|
||||
|
||||
super(TRTWrapper, self).__init__()
|
||||
super().__init__()
|
||||
self.engine = engine
|
||||
if isinstance(self.engine, str):
|
||||
self.engine = load_trt_engine(engine)
|
||||
|
|
|
@ -39,7 +39,7 @@ class ConfigDict(Dict):
|
|||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
value = super(ConfigDict, self).__getattr__(name)
|
||||
value = super().__getattr__(name)
|
||||
except KeyError:
|
||||
ex = AttributeError(f"'{self.__class__.__name__}' object has no "
|
||||
f"attribute '{name}'")
|
||||
|
@ -96,7 +96,7 @@ class Config:
|
|||
|
||||
@staticmethod
|
||||
def _validate_py_syntax(filename):
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
content = f.read()
|
||||
try:
|
||||
|
@ -116,7 +116,7 @@ class Config:
|
|||
fileBasename=file_basename,
|
||||
fileBasenameNoExtension=file_basename_no_extension,
|
||||
fileExtname=file_extname)
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
config_file = f.read()
|
||||
for key, value in support_templates.items():
|
||||
|
@ -130,7 +130,7 @@ class Config:
|
|||
def _pre_substitute_base_vars(filename, temp_config_name):
|
||||
"""Substitute base variable placehoders to string, so that parsing
|
||||
would work."""
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
config_file = f.read()
|
||||
base_var_dict = {}
|
||||
|
@ -183,7 +183,7 @@ class Config:
|
|||
check_file_exist(filename)
|
||||
fileExtname = osp.splitext(filename)[1]
|
||||
if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
|
||||
raise IOError('Only py/yml/yaml/json type are supported now!')
|
||||
raise OSError('Only py/yml/yaml/json type are supported now!')
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_config_dir:
|
||||
temp_config_file = tempfile.NamedTemporaryFile(
|
||||
|
@ -236,7 +236,7 @@ class Config:
|
|||
warnings.warn(warning_msg, DeprecationWarning)
|
||||
|
||||
cfg_text = filename + '\n'
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
cfg_text += f.read()
|
||||
|
||||
|
@ -356,7 +356,7 @@ class Config:
|
|||
:obj:`Config`: Config obj.
|
||||
"""
|
||||
if file_format not in ['.py', '.json', '.yaml', '.yml']:
|
||||
raise IOError('Only py/yml/yaml/json type are supported now!')
|
||||
raise OSError('Only py/yml/yaml/json type are supported now!')
|
||||
if file_format != '.py' and 'dict(' in cfg_str:
|
||||
# check if users specify a wrong suffix for python
|
||||
warnings.warn(
|
||||
|
@ -396,16 +396,16 @@ class Config:
|
|||
if isinstance(filename, Path):
|
||||
filename = str(filename)
|
||||
|
||||
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
|
||||
super(Config, self).__setattr__('_filename', filename)
|
||||
super().__setattr__('_cfg_dict', ConfigDict(cfg_dict))
|
||||
super().__setattr__('_filename', filename)
|
||||
if cfg_text:
|
||||
text = cfg_text
|
||||
elif filename:
|
||||
with open(filename, 'r') as f:
|
||||
with open(filename) as f:
|
||||
text = f.read()
|
||||
else:
|
||||
text = ''
|
||||
super(Config, self).__setattr__('_text', text)
|
||||
super().__setattr__('_text', text)
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
|
@ -556,9 +556,9 @@ class Config:
|
|||
|
||||
def __setstate__(self, state):
|
||||
_cfg_dict, _filename, _text = state
|
||||
super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
|
||||
super(Config, self).__setattr__('_filename', _filename)
|
||||
super(Config, self).__setattr__('_text', _text)
|
||||
super().__setattr__('_cfg_dict', _cfg_dict)
|
||||
super().__setattr__('_filename', _filename)
|
||||
super().__setattr__('_text', _text)
|
||||
|
||||
def dump(self, file=None):
|
||||
"""Dumps config into a file or returns a string representation of the
|
||||
|
@ -584,7 +584,7 @@ class Config:
|
|||
will be dumped. Defaults to None.
|
||||
"""
|
||||
import mmcv
|
||||
cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
|
||||
cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
|
||||
if file is None:
|
||||
if self.filename is None or self.filename.endswith('.py'):
|
||||
return self.pretty_text
|
||||
|
@ -638,8 +638,8 @@ class Config:
|
|||
subkey = key_list[-1]
|
||||
d[subkey] = v
|
||||
|
||||
cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
|
||||
super(Config, self).__setattr__(
|
||||
cfg_dict = super().__getattribute__('_cfg_dict')
|
||||
super().__setattr__(
|
||||
'_cfg_dict',
|
||||
Config._merge_a_into_b(
|
||||
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
|
||||
|
|
|
@ -6,7 +6,7 @@ class TimerError(Exception):
|
|||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
super(TimerError, self).__init__(message)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class Timer:
|
||||
|
|
|
@ -40,10 +40,10 @@ def flowread(flow_or_path: Union[np.ndarray, str],
|
|||
try:
|
||||
header = f.read(4).decode('utf-8')
|
||||
except Exception:
|
||||
raise IOError(f'Invalid flow file: {flow_or_path}')
|
||||
raise OSError(f'Invalid flow file: {flow_or_path}')
|
||||
else:
|
||||
if header != 'PIEH':
|
||||
raise IOError(f'Invalid flow file: {flow_or_path}, '
|
||||
raise OSError(f'Invalid flow file: {flow_or_path}, '
|
||||
'header does not contain PIEH')
|
||||
|
||||
w = np.fromfile(f, np.int32, 1).squeeze()
|
||||
|
@ -53,7 +53,7 @@ def flowread(flow_or_path: Union[np.ndarray, str],
|
|||
assert concat_axis in [0, 1]
|
||||
cat_flow = imread(flow_or_path, flag='unchanged')
|
||||
if cat_flow.ndim != 2:
|
||||
raise IOError(
|
||||
raise OSError(
|
||||
f'{flow_or_path} is not a valid quantized flow file, '
|
||||
f'its dimension is {cat_flow.ndim}.')
|
||||
assert cat_flow.shape[concat_axis] % 2 == 0
|
||||
|
@ -86,7 +86,7 @@ def flowwrite(flow: np.ndarray,
|
|||
"""
|
||||
if not quantize:
|
||||
with open(filename, 'wb') as f:
|
||||
f.write('PIEH'.encode('utf-8'))
|
||||
f.write(b'PIEH')
|
||||
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
|
||||
flow = flow.astype(np.float32)
|
||||
flow.tofile(f)
|
||||
|
@ -146,7 +146,7 @@ def dequantize_flow(dx: np.ndarray,
|
|||
assert dx.shape == dy.shape
|
||||
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
|
||||
|
||||
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
|
||||
dx, dy = (dequantize(d, -max_val, max_val, 255) for d in [dx, dy])
|
||||
|
||||
if denorm:
|
||||
dx *= dx.shape[1]
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from __future__ import division
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
|
7
setup.py
7
setup.py
|
@ -39,7 +39,7 @@ def choose_requirement(primary, secondary):
|
|||
|
||||
def get_version():
|
||||
version_file = 'mmcv/version.py'
|
||||
with open(version_file, 'r', encoding='utf-8') as f:
|
||||
with open(version_file, encoding='utf-8') as f:
|
||||
exec(compile(f.read(), version_file, 'exec'))
|
||||
return locals()['__version__']
|
||||
|
||||
|
@ -94,12 +94,11 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True):
|
|||
yield info
|
||||
|
||||
def parse_require_file(fpath):
|
||||
with open(fpath, 'r') as f:
|
||||
with open(fpath) as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
for info in parse_line(line):
|
||||
yield info
|
||||
yield from parse_line(line)
|
||||
|
||||
def gen_packages_items():
|
||||
if exists(require_fpath):
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from __future__ import division
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
|
|
@ -23,7 +23,7 @@ class ExampleConv(nn.Module):
|
|||
groups=1,
|
||||
bias=True,
|
||||
norm_cfg=None):
|
||||
super(ExampleConv, self).__init__()
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
|
|
|
@ -202,21 +202,22 @@ class TestFileClient:
|
|||
# test `list_dir_or_file`
|
||||
with build_temporary_directory() as tmp_dir:
|
||||
# 1. list directories and files
|
||||
assert set(disk_backend.list_dir_or_file(tmp_dir)) == set(
|
||||
['dir1', 'dir2', 'text1.txt', 'text2.txt'])
|
||||
assert set(disk_backend.list_dir_or_file(tmp_dir)) == {
|
||||
'dir1', 'dir2', 'text1.txt', 'text2.txt'
|
||||
}
|
||||
# 2. list directories and files recursively
|
||||
assert set(disk_backend.list_dir_or_file(
|
||||
tmp_dir, recursive=True)) == set([
|
||||
tmp_dir, recursive=True)) == {
|
||||
'dir1',
|
||||
osp.join('dir1', 'text3.txt'), 'dir2',
|
||||
osp.join('dir2', 'dir3'),
|
||||
osp.join('dir2', 'dir3', 'text4.txt'),
|
||||
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
}
|
||||
# 3. only list directories
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
|
||||
tmp_dir, list_file=False)) == {'dir1', 'dir2'}
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match='`suffix` should be None when `list_dir` is True'):
|
||||
|
@ -227,30 +228,30 @@ class TestFileClient:
|
|||
# 4. only list directories recursively
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_file=False, recursive=True)) == set(
|
||||
['dir1', 'dir2',
|
||||
osp.join('dir2', 'dir3')])
|
||||
tmp_dir, list_file=False, recursive=True)) == {
|
||||
'dir1', 'dir2',
|
||||
osp.join('dir2', 'dir3')
|
||||
}
|
||||
# 5. only list files
|
||||
assert set(disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False)) == set(['text1.txt', 'text2.txt'])
|
||||
tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
|
||||
# 6. only list files recursively
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, recursive=True)) == set([
|
||||
tmp_dir, list_dir=False, recursive=True)) == {
|
||||
osp.join('dir1', 'text3.txt'),
|
||||
osp.join('dir2', 'dir3', 'text4.txt'),
|
||||
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
}
|
||||
# 7. only list files ending with suffix
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False,
|
||||
suffix='.txt')) == set(['text1.txt', 'text2.txt'])
|
||||
suffix='.txt')) == {'text1.txt', 'text2.txt'}
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False,
|
||||
suffix=('.txt',
|
||||
'.jpg'))) == set(['text1.txt', 'text2.txt'])
|
||||
suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match='`suffix` must be a string or tuple of strings'):
|
||||
|
@ -260,22 +261,22 @@ class TestFileClient:
|
|||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, suffix='.txt',
|
||||
recursive=True)) == set([
|
||||
recursive=True)) == {
|
||||
osp.join('dir1', 'text3.txt'),
|
||||
osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
|
||||
'text2.txt'
|
||||
])
|
||||
}
|
||||
# 7. only list files ending with suffix
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir,
|
||||
list_dir=False,
|
||||
suffix=('.txt', '.jpg'),
|
||||
recursive=True)) == set([
|
||||
recursive=True)) == {
|
||||
osp.join('dir1', 'text3.txt'),
|
||||
osp.join('dir2', 'dir3', 'text4.txt'),
|
||||
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
}
|
||||
|
||||
@patch('ceph.S3Client', MockS3Client)
|
||||
def test_ceph_backend(self):
|
||||
|
@ -463,21 +464,21 @@ class TestFileClient:
|
|||
|
||||
with build_temporary_directory() as tmp_dir:
|
||||
# 1. list directories and files
|
||||
assert set(petrel_backend.list_dir_or_file(tmp_dir)) == set(
|
||||
['dir1', 'dir2', 'text1.txt', 'text2.txt'])
|
||||
assert set(petrel_backend.list_dir_or_file(tmp_dir)) == {
|
||||
'dir1', 'dir2', 'text1.txt', 'text2.txt'
|
||||
}
|
||||
# 2. list directories and files recursively
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, recursive=True)) == set([
|
||||
'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2',
|
||||
'/'.join(('dir2', 'dir3')), '/'.join(
|
||||
petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == {
|
||||
'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join(
|
||||
('dir2', 'dir3')), '/'.join(
|
||||
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
||||
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
}
|
||||
# 3. only list directories
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
|
||||
tmp_dir, list_file=False)) == {'dir1', 'dir2'}
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=('`list_dir` should be False when `suffix` is not '
|
||||
|
@ -489,31 +490,30 @@ class TestFileClient:
|
|||
# 4. only list directories recursively
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_file=False, recursive=True)) == set(
|
||||
['dir1', 'dir2', '/'.join(('dir2', 'dir3'))])
|
||||
tmp_dir, list_file=False, recursive=True)) == {
|
||||
'dir1', 'dir2', '/'.join(('dir2', 'dir3'))
|
||||
}
|
||||
# 5. only list files
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(tmp_dir,
|
||||
list_dir=False)) == set(
|
||||
['text1.txt', 'text2.txt'])
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
|
||||
# 6. only list files recursively
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, recursive=True)) == set([
|
||||
tmp_dir, list_dir=False, recursive=True)) == {
|
||||
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
||||
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
||||
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
}
|
||||
# 7. only list files ending with suffix
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False,
|
||||
suffix='.txt')) == set(['text1.txt', 'text2.txt'])
|
||||
suffix='.txt')) == {'text1.txt', 'text2.txt'}
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False,
|
||||
suffix=('.txt',
|
||||
'.jpg'))) == set(['text1.txt', 'text2.txt'])
|
||||
suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match='`suffix` must be a string or tuple of strings'):
|
||||
|
@ -523,22 +523,22 @@ class TestFileClient:
|
|||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, suffix='.txt',
|
||||
recursive=True)) == set([
|
||||
recursive=True)) == {
|
||||
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
||||
('dir2', 'dir3', 'text4.txt')), 'text1.txt',
|
||||
'text2.txt'
|
||||
])
|
||||
}
|
||||
# 7. only list files ending with suffix
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir,
|
||||
list_dir=False,
|
||||
suffix=('.txt', '.jpg'),
|
||||
recursive=True)) == set([
|
||||
recursive=True)) == {
|
||||
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
||||
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
||||
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
}
|
||||
|
||||
@patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
|
||||
@patch('mc.pyvector', MagicMock)
|
||||
|
|
|
@ -128,7 +128,7 @@ def test_register_handler():
|
|||
assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
|
||||
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
|
||||
mmcv.dump(content, tmp_filename)
|
||||
with open(tmp_filename, 'r') as f:
|
||||
with open(tmp_filename) as f:
|
||||
written = f.read()
|
||||
os.remove(tmp_filename)
|
||||
assert written == '\n' + content
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
|
||||
|
||||
class TestBBox(object):
|
||||
class TestBBox:
|
||||
|
||||
def _test_bbox_overlaps(self, device='cpu', dtype=torch.float):
|
||||
from mmcv.ops import bbox_overlaps
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class TestBilinearGridSample(object):
|
||||
class TestBilinearGridSample:
|
||||
|
||||
def _test_bilinear_grid_sample(self,
|
||||
dtype=torch.float,
|
||||
|
|
|
@ -4,7 +4,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
|
||||
class TestBoxIoURotated(object):
|
||||
class TestBoxIoURotated:
|
||||
|
||||
def test_box_iou_rotated_cpu(self):
|
||||
from mmcv.ops import box_iou_rotated
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
from torch.autograd import gradcheck
|
||||
|
||||
|
||||
class TestCarafe(object):
|
||||
class TestCarafe:
|
||||
|
||||
def test_carafe_naive_gradcheck(self):
|
||||
if not torch.cuda.is_available():
|
||||
|
|
|
@ -15,7 +15,7 @@ class Loss(nn.Module):
|
|||
return torch.mean(input - target)
|
||||
|
||||
|
||||
class TestCrissCrossAttention(object):
|
||||
class TestCrissCrossAttention:
|
||||
|
||||
def test_cc_attention(self):
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
|
|
|
@ -35,7 +35,7 @@ gt_offset_bias_grad = [1.44, -0.72, 0., 0., -0.10, -0.08, -0.54, -0.54],
|
|||
gt_deform_weight_grad = [[[[3.62, 0.], [0.40, 0.18]]]]
|
||||
|
||||
|
||||
class TestDeformconv(object):
|
||||
class TestDeformconv:
|
||||
|
||||
def _test_deformconv(self,
|
||||
dtype=torch.float,
|
||||
|
|
|
@ -35,7 +35,7 @@ outputs = [([[[[1, 1.25], [1.5, 1.75]]]], [[[[3.0625, 0.4375],
|
|||
0.00390625]]]])]
|
||||
|
||||
|
||||
class TestDeformRoIPool(object):
|
||||
class TestDeformRoIPool:
|
||||
|
||||
def test_deform_roi_pool_gradcheck(self):
|
||||
if not torch.cuda.is_available():
|
||||
|
|
|
@ -37,7 +37,7 @@ sigmoid_outputs = [(0.13562961, [[-0.00657264, 0.11185755],
|
|||
[-0.02462499, 0.08277918, 0.18050370]])]
|
||||
|
||||
|
||||
class Testfocalloss(object):
|
||||
class Testfocalloss:
|
||||
|
||||
def _test_softmax(self, dtype=torch.float):
|
||||
if not torch.cuda.is_available():
|
||||
|
|
|
@ -10,7 +10,7 @@ except ImportError:
|
|||
_USING_PARROTS = False
|
||||
|
||||
|
||||
class TestFusedBiasLeakyReLU(object):
|
||||
class TestFusedBiasLeakyReLU:
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import torch
|
||||
|
||||
|
||||
class TestInfo(object):
|
||||
class TestInfo:
|
||||
|
||||
def test_info(self):
|
||||
if not torch.cuda.is_available():
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import torch
|
||||
|
||||
|
||||
class TestMaskedConv2d(object):
|
||||
class TestMaskedConv2d:
|
||||
|
||||
def test_masked_conv2d(self):
|
||||
if not torch.cuda.is_available():
|
||||
|
|
|
@ -37,7 +37,7 @@ dcn_offset_b_grad = [
|
|||
]
|
||||
|
||||
|
||||
class TestMdconv(object):
|
||||
class TestMdconv:
|
||||
|
||||
def _test_mdconv(self, dtype=torch.float, device='cuda'):
|
||||
if not torch.cuda.is_available() and device == 'cuda':
|
||||
|
|
|
@ -55,7 +55,7 @@ def test_forward_multi_scale_deformable_attn_pytorch():
|
|||
N, M, D = 1, 2, 2
|
||||
Lq, L, P = 2, 2, 2
|
||||
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
|
||||
S = sum([(H * W).item() for H, W in shapes])
|
||||
S = sum((H * W).item() for H, W in shapes)
|
||||
|
||||
torch.manual_seed(3)
|
||||
value = torch.rand(N, S, M, D) * 0.01
|
||||
|
@ -78,7 +78,7 @@ def test_forward_equal_with_pytorch_double():
|
|||
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
||||
level_start_index = torch.cat((shapes.new_zeros(
|
||||
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
||||
S = sum([(H * W).item() for H, W in shapes])
|
||||
S = sum((H * W).item() for H, W in shapes)
|
||||
|
||||
torch.manual_seed(3)
|
||||
value = torch.rand(N, S, M, D).cuda() * 0.01
|
||||
|
@ -111,7 +111,7 @@ def test_forward_equal_with_pytorch_float():
|
|||
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
||||
level_start_index = torch.cat((shapes.new_zeros(
|
||||
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
||||
S = sum([(H * W).item() for H, W in shapes])
|
||||
S = sum((H * W).item() for H, W in shapes)
|
||||
|
||||
torch.manual_seed(3)
|
||||
value = torch.rand(N, S, M, D).cuda() * 0.01
|
||||
|
@ -155,7 +155,7 @@ def test_gradient_numerical(channels,
|
|||
shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
|
||||
level_start_index = torch.cat((shapes.new_zeros(
|
||||
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
||||
S = sum([(H * W).item() for H, W in shapes])
|
||||
S = sum((H * W).item() for H, W in shapes)
|
||||
|
||||
value = torch.rand(N, S, M, channels).cuda() * 0.01
|
||||
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
|
||||
|
||||
class Testnms(object):
|
||||
class Testnms:
|
||||
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
|
@ -129,8 +129,7 @@ class Testnms(object):
|
|||
scores = tensor_dets[:, 4]
|
||||
nms_keep_inds = nms(boxes.contiguous(), scores.contiguous(),
|
||||
iou_thr)[1]
|
||||
assert set([g[0].item()
|
||||
for g in np_groups]) == set(nms_keep_inds.tolist())
|
||||
assert {g[0].item() for g in np_groups} == set(nms_keep_inds.tolist())
|
||||
|
||||
# non empty tensor input
|
||||
tensor_dets = torch.from_numpy(np_dets)
|
||||
|
|
|
@ -33,7 +33,7 @@ def run_before_and_after_test():
|
|||
class WrapFunction(nn.Module):
|
||||
|
||||
def __init__(self, wrapped_function):
|
||||
super(WrapFunction, self).__init__()
|
||||
super().__init__()
|
||||
self.wrapped_function = wrapped_function
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
@ -662,7 +662,7 @@ def test_cummax_cummin(key, opset=11):
|
|||
input_list = [
|
||||
# arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
|
||||
torch.rand((2, 3, 4, 1, 5)),
|
||||
torch.rand((1)),
|
||||
torch.rand(1),
|
||||
torch.rand((2, 0, 1)), # tensor.numel() is 0
|
||||
torch.FloatTensor(), # empty tensor
|
||||
]
|
||||
|
|
|
@ -15,7 +15,7 @@ class Loss(nn.Module):
|
|||
return torch.mean(input - target)
|
||||
|
||||
|
||||
class TestPSAMask(object):
|
||||
class TestPSAMask:
|
||||
|
||||
def test_psa_mask_collect(self):
|
||||
if not torch.cuda.is_available():
|
||||
|
|
|
@ -29,7 +29,7 @@ outputs = [([[[[1., 2.], [3., 4.]]]], [[[[1., 1.], [1., 1.]]]]),
|
|||
1.]]]])]
|
||||
|
||||
|
||||
class TestRoiPool(object):
|
||||
class TestRoiPool:
|
||||
|
||||
def test_roipool_gradcheck(self):
|
||||
if not torch.cuda.is_available():
|
||||
|
|
|
@ -14,7 +14,7 @@ else:
|
|||
import re
|
||||
|
||||
|
||||
class TestSyncBN(object):
|
||||
class TestSyncBN:
|
||||
|
||||
def dist_init(self):
|
||||
rank = int(os.environ['SLURM_PROCID'])
|
||||
|
|
|
@ -30,7 +30,7 @@ if not is_tensorrt_plugin_loaded():
|
|||
class WrapFunction(nn.Module):
|
||||
|
||||
def __init__(self, wrapped_function):
|
||||
super(WrapFunction, self).__init__()
|
||||
super().__init__()
|
||||
self.wrapped_function = wrapped_function
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
@ -576,7 +576,7 @@ def test_cummin_cummax(func: Callable):
|
|||
input_list = [
|
||||
# arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
|
||||
torch.rand((2, 3, 4, 1, 5)).cuda(),
|
||||
torch.rand((1)).cuda()
|
||||
torch.rand(1).cuda()
|
||||
]
|
||||
|
||||
input_names = ['input']
|
||||
|
@ -756,7 +756,7 @@ def test_corner_pool(mode):
|
|||
class CornerPoolWrapper(CornerPool):
|
||||
|
||||
def __init__(self, mode):
|
||||
super(CornerPoolWrapper, self).__init__(mode)
|
||||
super().__init__(mode)
|
||||
|
||||
def forward(self, x):
|
||||
# no use `torch.cummax`, instead `corner_pool` is used
|
||||
|
|
|
@ -10,7 +10,7 @@ except ImportError:
|
|||
_USING_PARROTS = False
|
||||
|
||||
|
||||
class TestUpFirDn2d(object):
|
||||
class TestUpFirDn2d:
|
||||
"""Unit test for UpFirDn2d.
|
||||
|
||||
Here, we just test the basic case of upsample version. More gerneal tests
|
||||
|
|
|
@ -96,8 +96,8 @@ def test_voxelization_nondeterministic():
|
|||
coors_all = dynamic_voxelization.forward(points)
|
||||
coors_all = coors_all.cpu().detach().numpy().tolist()
|
||||
|
||||
coors_set = set([tuple(c) for c in coors])
|
||||
coors_all_set = set([tuple(c) for c in coors_all])
|
||||
coors_set = {tuple(c) for c in coors}
|
||||
coors_all_set = {tuple(c) for c in coors_all}
|
||||
|
||||
assert len(coors_set) == len(coors)
|
||||
assert len(coors_set - coors_all_set) == 0
|
||||
|
@ -112,7 +112,7 @@ def test_voxelization_nondeterministic():
|
|||
|
||||
for c, ps, n in zip(coors, voxels, num_points_per_voxel):
|
||||
ideal_voxel_points_set = coors_points_dict[tuple(c)]
|
||||
voxel_points_set = set([tuple(p) for p in ps[:n]])
|
||||
voxel_points_set = {tuple(p) for p in ps[:n]}
|
||||
assert len(voxel_points_set) == n
|
||||
if n < max_num_points:
|
||||
assert voxel_points_set == ideal_voxel_points_set
|
||||
|
@ -133,7 +133,7 @@ def test_voxelization_nondeterministic():
|
|||
voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
|
||||
coors = coors.cpu().detach().numpy().tolist()
|
||||
|
||||
coors_set = set([tuple(c) for c in coors])
|
||||
coors_all_set = set([tuple(c) for c in coors_all])
|
||||
coors_set = {tuple(c) for c in coors}
|
||||
coors_all_set = {tuple(c) for c in coors_all}
|
||||
|
||||
assert len(coors_set) == len(coors) == len(coors_all_set)
|
||||
|
|
|
@ -63,7 +63,7 @@ def test_is_module_wrapper():
|
|||
|
||||
# test module wrapper registry
|
||||
@MODULE_WRAPPERS.register_module()
|
||||
class ModuleWrapper(object):
|
||||
class ModuleWrapper:
|
||||
|
||||
def __init__(self, module):
|
||||
self.module = module
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue