mirror of https://github.com/open-mmlab/mmcv.git
Add pyupgrade pre-commit hook (#1937)
* add pyupgrade * add options for pyupgrade * minor refinementpull/1968/head
parent
c561264d55
commit
45fa3e44a2
|
@ -42,7 +42,7 @@ def parse_args():
|
||||||
class SimpleModel(nn.Module):
|
class SimpleModel(nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SimpleModel, self).__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(1, 1, 1)
|
self.conv = nn.Conv2d(1, 1, 1)
|
||||||
|
|
||||||
def train_step(self, *args, **kwargs):
|
def train_step(self, *args, **kwargs):
|
||||||
|
@ -159,13 +159,13 @@ def run(cfg, logger):
|
||||||
def plot_lr_curve(json_file, cfg):
|
def plot_lr_curve(json_file, cfg):
|
||||||
data_dict = dict(LearningRate=[], Momentum=[])
|
data_dict = dict(LearningRate=[], Momentum=[])
|
||||||
assert os.path.isfile(json_file)
|
assert os.path.isfile(json_file)
|
||||||
with open(json_file, 'r') as f:
|
with open(json_file) as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
log = json.loads(line.strip())
|
log = json.loads(line.strip())
|
||||||
data_dict['LearningRate'].append(log['lr'])
|
data_dict['LearningRate'].append(log['lr'])
|
||||||
data_dict['Momentum'].append(log['momentum'])
|
data_dict['Momentum'].append(log['momentum'])
|
||||||
|
|
||||||
wind_w, wind_h = [int(size) for size in cfg.window_size.split('*')]
|
wind_w, wind_h = (int(size) for size in cfg.window_size.split('*'))
|
||||||
# if legend is None, use {filename}_{key} as legend
|
# if legend is None, use {filename}_{key} as legend
|
||||||
fig, axes = plt.subplots(2, 1, figsize=(wind_w, wind_h))
|
fig, axes = plt.subplots(2, 1, figsize=(wind_w, wind_h))
|
||||||
plt.subplots_adjust(hspace=0.5)
|
plt.subplots_adjust(hspace=0.5)
|
||||||
|
|
|
@ -43,7 +43,11 @@ repos:
|
||||||
hooks:
|
hooks:
|
||||||
- id: docformatter
|
- id: docformatter
|
||||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||||
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
|
rev: v2.32.1
|
||||||
|
hooks:
|
||||||
|
- id: pyupgrade
|
||||||
|
args: ["--py36-plus"]
|
||||||
- repo: https://github.com/open-mmlab/pre-commit-hooks
|
- repo: https://github.com/open-mmlab/pre-commit-hooks
|
||||||
rev: v0.2.0 # Use the ref you want to point at
|
rev: v0.2.0 # Use the ref you want to point at
|
||||||
hooks:
|
hooks:
|
||||||
|
|
|
@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
|
||||||
sys.path.insert(0, os.path.abspath('../..'))
|
sys.path.insert(0, os.path.abspath('../..'))
|
||||||
|
|
||||||
version_file = '../../mmcv/version.py'
|
version_file = '../../mmcv/version.py'
|
||||||
with open(version_file, 'r') as f:
|
with open(version_file) as f:
|
||||||
exec(compile(f.read(), version_file, 'exec'))
|
exec(compile(f.read(), version_file, 'exec'))
|
||||||
__version__ = locals()['__version__']
|
__version__ = locals()['__version__']
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder
|
||||||
sys.path.insert(0, os.path.abspath('../..'))
|
sys.path.insert(0, os.path.abspath('../..'))
|
||||||
|
|
||||||
version_file = '../../mmcv/version.py'
|
version_file = '../../mmcv/version.py'
|
||||||
with open(version_file, 'r') as f:
|
with open(version_file) as f:
|
||||||
exec(compile(f.read(), version_file, 'exec'))
|
exec(compile(f.read(), version_file, 'exec'))
|
||||||
__version__ = locals()['__version__']
|
__version__ = locals()['__version__']
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ from mmcv.utils import get_logger
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Model, self).__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv2d(3, 6, 5)
|
self.conv1 = nn.Conv2d(3, 6, 5)
|
||||||
self.pool = nn.MaxPool2d(2, 2)
|
self.pool = nn.MaxPool2d(2, 2)
|
||||||
self.conv2 = nn.Conv2d(6, 16, 5)
|
self.conv2 = nn.Conv2d(6, 16, 5)
|
||||||
|
|
|
@ -12,7 +12,7 @@ class AlexNet(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_classes=-1):
|
def __init__(self, num_classes=-1):
|
||||||
super(AlexNet, self).__init__()
|
super().__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.features = nn.Sequential(
|
self.features = nn.Sequential(
|
||||||
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
|
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
|
||||||
|
|
|
@ -29,7 +29,7 @@ class Clamp(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, min=-1., max=1.):
|
def __init__(self, min=-1., max=1.):
|
||||||
super(Clamp, self).__init__()
|
super().__init__()
|
||||||
self.min = min
|
self.min = min
|
||||||
self.max = max
|
self.max = max
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ class ContextBlock(nn.Module):
|
||||||
ratio,
|
ratio,
|
||||||
pooling_type='att',
|
pooling_type='att',
|
||||||
fusion_types=('channel_add', )):
|
fusion_types=('channel_add', )):
|
||||||
super(ContextBlock, self).__init__()
|
super().__init__()
|
||||||
assert pooling_type in ['avg', 'att']
|
assert pooling_type in ['avg', 'att']
|
||||||
assert isinstance(fusion_types, (list, tuple))
|
assert isinstance(fusion_types, (list, tuple))
|
||||||
valid_fusion_types = ['channel_add', 'channel_mul']
|
valid_fusion_types = ['channel_add', 'channel_mul']
|
||||||
|
|
|
@ -83,7 +83,7 @@ class ConvModule(nn.Module):
|
||||||
with_spectral_norm=False,
|
with_spectral_norm=False,
|
||||||
padding_mode='zeros',
|
padding_mode='zeros',
|
||||||
order=('conv', 'norm', 'act')):
|
order=('conv', 'norm', 'act')):
|
||||||
super(ConvModule, self).__init__()
|
super().__init__()
|
||||||
assert conv_cfg is None or isinstance(conv_cfg, dict)
|
assert conv_cfg is None or isinstance(conv_cfg, dict)
|
||||||
assert norm_cfg is None or isinstance(norm_cfg, dict)
|
assert norm_cfg is None or isinstance(norm_cfg, dict)
|
||||||
assert act_cfg is None or isinstance(act_cfg, dict)
|
assert act_cfg is None or isinstance(act_cfg, dict)
|
||||||
|
@ -96,7 +96,7 @@ class ConvModule(nn.Module):
|
||||||
self.with_explicit_padding = padding_mode not in official_padding_mode
|
self.with_explicit_padding = padding_mode not in official_padding_mode
|
||||||
self.order = order
|
self.order = order
|
||||||
assert isinstance(self.order, tuple) and len(self.order) == 3
|
assert isinstance(self.order, tuple) and len(self.order) == 3
|
||||||
assert set(order) == set(['conv', 'norm', 'act'])
|
assert set(order) == {'conv', 'norm', 'act'}
|
||||||
|
|
||||||
self.with_norm = norm_cfg is not None
|
self.with_norm = norm_cfg is not None
|
||||||
self.with_activation = act_cfg is not None
|
self.with_activation = act_cfg is not None
|
||||||
|
|
|
@ -35,7 +35,7 @@ class ConvWS2d(nn.Conv2d):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
eps=1e-5):
|
eps=1e-5):
|
||||||
super(ConvWS2d, self).__init__(
|
super().__init__(
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
|
|
|
@ -59,7 +59,7 @@ class DepthwiseSeparableConvModule(nn.Module):
|
||||||
pw_norm_cfg='default',
|
pw_norm_cfg='default',
|
||||||
pw_act_cfg='default',
|
pw_act_cfg='default',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(DepthwiseSeparableConvModule, self).__init__()
|
super().__init__()
|
||||||
assert 'groups' not in kwargs, 'groups should not be specified'
|
assert 'groups' not in kwargs, 'groups should not be specified'
|
||||||
|
|
||||||
# if norm/activation config of depthwise/pointwise ConvModule is not
|
# if norm/activation config of depthwise/pointwise ConvModule is not
|
||||||
|
|
|
@ -37,7 +37,7 @@ class DropPath(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, drop_prob=0.1):
|
def __init__(self, drop_prob=0.1):
|
||||||
super(DropPath, self).__init__()
|
super().__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -54,7 +54,7 @@ class GeneralizedAttention(nn.Module):
|
||||||
q_stride=1,
|
q_stride=1,
|
||||||
attention_type='1111'):
|
attention_type='1111'):
|
||||||
|
|
||||||
super(GeneralizedAttention, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
# hard range means local range for non-local operation
|
# hard range means local range for non-local operation
|
||||||
self.position_embedding_dim = (
|
self.position_embedding_dim = (
|
||||||
|
|
|
@ -27,7 +27,7 @@ class HSigmoid(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bias=3.0, divisor=6.0, min_value=0.0, max_value=1.0):
|
def __init__(self, bias=3.0, divisor=6.0, min_value=0.0, max_value=1.0):
|
||||||
super(HSigmoid, self).__init__()
|
super().__init__()
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'In MMCV v1.4.4, we modified the default value of args to align '
|
'In MMCV v1.4.4, we modified the default value of args to align '
|
||||||
'with PyTorch official. Previous Implementation: '
|
'with PyTorch official. Previous Implementation: '
|
||||||
|
|
|
@ -22,7 +22,7 @@ class HSwish(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, inplace=False):
|
def __init__(self, inplace=False):
|
||||||
super(HSwish, self).__init__()
|
super().__init__()
|
||||||
self.act = nn.ReLU6(inplace)
|
self.act = nn.ReLU6(inplace)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -40,7 +40,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
||||||
norm_cfg=None,
|
norm_cfg=None,
|
||||||
mode='embedded_gaussian',
|
mode='embedded_gaussian',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(_NonLocalNd, self).__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
self.use_scale = use_scale
|
self.use_scale = use_scale
|
||||||
|
@ -228,8 +228,7 @@ class NonLocal1d(_NonLocalNd):
|
||||||
sub_sample=False,
|
sub_sample=False,
|
||||||
conv_cfg=dict(type='Conv1d'),
|
conv_cfg=dict(type='Conv1d'),
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(NonLocal1d, self).__init__(
|
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||||
in_channels, conv_cfg=conv_cfg, **kwargs)
|
|
||||||
|
|
||||||
self.sub_sample = sub_sample
|
self.sub_sample = sub_sample
|
||||||
|
|
||||||
|
@ -262,8 +261,7 @@ class NonLocal2d(_NonLocalNd):
|
||||||
sub_sample=False,
|
sub_sample=False,
|
||||||
conv_cfg=dict(type='Conv2d'),
|
conv_cfg=dict(type='Conv2d'),
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(NonLocal2d, self).__init__(
|
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||||
in_channels, conv_cfg=conv_cfg, **kwargs)
|
|
||||||
|
|
||||||
self.sub_sample = sub_sample
|
self.sub_sample = sub_sample
|
||||||
|
|
||||||
|
@ -293,8 +291,7 @@ class NonLocal3d(_NonLocalNd):
|
||||||
sub_sample=False,
|
sub_sample=False,
|
||||||
conv_cfg=dict(type='Conv3d'),
|
conv_cfg=dict(type='Conv3d'),
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(NonLocal3d, self).__init__(
|
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
|
||||||
in_channels, conv_cfg=conv_cfg, **kwargs)
|
|
||||||
self.sub_sample = sub_sample
|
self.sub_sample = sub_sample
|
||||||
|
|
||||||
if sub_sample:
|
if sub_sample:
|
||||||
|
|
|
@ -14,7 +14,7 @@ class Scale(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, scale=1.0):
|
def __init__(self, scale=1.0):
|
||||||
super(Scale, self).__init__()
|
super().__init__()
|
||||||
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
|
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -19,7 +19,7 @@ class Swish(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Swish, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x * torch.sigmoid(x)
|
return x * torch.sigmoid(x)
|
||||||
|
|
|
@ -96,7 +96,7 @@ class AdaptivePadding(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
|
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
|
||||||
super(AdaptivePadding, self).__init__()
|
super().__init__()
|
||||||
assert padding in ('same', 'corner')
|
assert padding in ('same', 'corner')
|
||||||
|
|
||||||
kernel_size = to_2tuple(kernel_size)
|
kernel_size = to_2tuple(kernel_size)
|
||||||
|
@ -190,7 +190,7 @@ class PatchEmbed(BaseModule):
|
||||||
norm_cfg=None,
|
norm_cfg=None,
|
||||||
input_size=None,
|
input_size=None,
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
super(PatchEmbed, self).__init__(init_cfg=init_cfg)
|
super().__init__(init_cfg=init_cfg)
|
||||||
|
|
||||||
self.embed_dims = embed_dims
|
self.embed_dims = embed_dims
|
||||||
if stride is None:
|
if stride is None:
|
||||||
|
@ -435,7 +435,7 @@ class MultiheadAttention(BaseModule):
|
||||||
init_cfg=None,
|
init_cfg=None,
|
||||||
batch_first=False,
|
batch_first=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(MultiheadAttention, self).__init__(init_cfg)
|
super().__init__(init_cfg)
|
||||||
if 'dropout' in kwargs:
|
if 'dropout' in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'The arguments `dropout` in MultiheadAttention '
|
'The arguments `dropout` in MultiheadAttention '
|
||||||
|
@ -590,7 +590,7 @@ class FFN(BaseModule):
|
||||||
add_identity=True,
|
add_identity=True,
|
||||||
init_cfg=None,
|
init_cfg=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(FFN, self).__init__(init_cfg)
|
super().__init__(init_cfg)
|
||||||
assert num_fcs >= 2, 'num_fcs should be no less ' \
|
assert num_fcs >= 2, 'num_fcs should be no less ' \
|
||||||
f'than 2. got {num_fcs}.'
|
f'than 2. got {num_fcs}.'
|
||||||
self.embed_dims = embed_dims
|
self.embed_dims = embed_dims
|
||||||
|
@ -694,12 +694,12 @@ class BaseTransformerLayer(BaseModule):
|
||||||
f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
|
f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
|
||||||
ffn_cfgs[new_name] = kwargs[ori_name]
|
ffn_cfgs[new_name] = kwargs[ori_name]
|
||||||
|
|
||||||
super(BaseTransformerLayer, self).__init__(init_cfg)
|
super().__init__(init_cfg)
|
||||||
|
|
||||||
self.batch_first = batch_first
|
self.batch_first = batch_first
|
||||||
|
|
||||||
assert set(operation_order) & set(
|
assert set(operation_order) & {
|
||||||
['self_attn', 'norm', 'ffn', 'cross_attn']) == \
|
'self_attn', 'norm', 'ffn', 'cross_attn'} == \
|
||||||
set(operation_order), f'The operation_order of' \
|
set(operation_order), f'The operation_order of' \
|
||||||
f' {self.__class__.__name__} should ' \
|
f' {self.__class__.__name__} should ' \
|
||||||
f'contains all four operation type ' \
|
f'contains all four operation type ' \
|
||||||
|
@ -880,7 +880,7 @@ class TransformerLayerSequence(BaseModule):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
|
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
|
||||||
super(TransformerLayerSequence, self).__init__(init_cfg)
|
super().__init__(init_cfg)
|
||||||
if isinstance(transformerlayers, dict):
|
if isinstance(transformerlayers, dict):
|
||||||
transformerlayers = [
|
transformerlayers = [
|
||||||
copy.deepcopy(transformerlayers) for _ in range(num_layers)
|
copy.deepcopy(transformerlayers) for _ in range(num_layers)
|
||||||
|
|
|
@ -26,7 +26,7 @@ class PixelShufflePack(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, scale_factor,
|
def __init__(self, in_channels, out_channels, scale_factor,
|
||||||
upsample_kernel):
|
upsample_kernel):
|
||||||
super(PixelShufflePack, self).__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
|
|
|
@ -30,7 +30,7 @@ class BasicBlock(nn.Module):
|
||||||
downsample=None,
|
downsample=None,
|
||||||
style='pytorch',
|
style='pytorch',
|
||||||
with_cp=False):
|
with_cp=False):
|
||||||
super(BasicBlock, self).__init__()
|
super().__init__()
|
||||||
assert style in ['pytorch', 'caffe']
|
assert style in ['pytorch', 'caffe']
|
||||||
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
|
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
|
||||||
self.bn1 = nn.BatchNorm2d(planes)
|
self.bn1 = nn.BatchNorm2d(planes)
|
||||||
|
@ -77,7 +77,7 @@ class Bottleneck(nn.Module):
|
||||||
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
|
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
|
||||||
it is "caffe", the stride-two layer is the first 1x1 conv layer.
|
it is "caffe", the stride-two layer is the first 1x1 conv layer.
|
||||||
"""
|
"""
|
||||||
super(Bottleneck, self).__init__()
|
super().__init__()
|
||||||
assert style in ['pytorch', 'caffe']
|
assert style in ['pytorch', 'caffe']
|
||||||
if style == 'pytorch':
|
if style == 'pytorch':
|
||||||
conv1_stride = 1
|
conv1_stride = 1
|
||||||
|
@ -218,7 +218,7 @@ class ResNet(nn.Module):
|
||||||
bn_eval=True,
|
bn_eval=True,
|
||||||
bn_frozen=False,
|
bn_frozen=False,
|
||||||
with_cp=False):
|
with_cp=False):
|
||||||
super(ResNet, self).__init__()
|
super().__init__()
|
||||||
if depth not in self.arch_settings:
|
if depth not in self.arch_settings:
|
||||||
raise KeyError(f'invalid depth {depth} for resnet')
|
raise KeyError(f'invalid depth {depth} for resnet')
|
||||||
assert num_stages >= 1 and num_stages <= 4
|
assert num_stages >= 1 and num_stages <= 4
|
||||||
|
@ -293,7 +293,7 @@ class ResNet(nn.Module):
|
||||||
return tuple(outs)
|
return tuple(outs)
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
super(ResNet, self).train(mode)
|
super().train(mode)
|
||||||
if self.bn_eval:
|
if self.bn_eval:
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.BatchNorm2d):
|
if isinstance(m, nn.BatchNorm2d):
|
||||||
|
|
|
@ -277,10 +277,10 @@ def print_model_with_flops(model,
|
||||||
return ', '.join([
|
return ', '.join([
|
||||||
params_to_string(
|
params_to_string(
|
||||||
accumulated_num_params, units='M', precision=precision),
|
accumulated_num_params, units='M', precision=precision),
|
||||||
'{:.3%} Params'.format(accumulated_num_params / total_params),
|
f'{accumulated_num_params / total_params:.3%} Params',
|
||||||
flops_to_string(
|
flops_to_string(
|
||||||
accumulated_flops_cost, units=units, precision=precision),
|
accumulated_flops_cost, units=units, precision=precision),
|
||||||
'{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
|
f'{accumulated_flops_cost / total_flops:.3%} FLOPs',
|
||||||
self.original_extra_repr()
|
self.original_extra_repr()
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
|
@ -129,7 +129,7 @@ def _get_bases_name(m):
|
||||||
return [b.__name__ for b in m.__class__.__bases__]
|
return [b.__name__ for b in m.__class__.__bases__]
|
||||||
|
|
||||||
|
|
||||||
class BaseInit(object):
|
class BaseInit:
|
||||||
|
|
||||||
def __init__(self, *, bias=0, bias_prob=None, layer=None):
|
def __init__(self, *, bias=0, bias_prob=None, layer=None):
|
||||||
self.wholemodule = False
|
self.wholemodule = False
|
||||||
|
@ -461,7 +461,7 @@ class Caffe2XavierInit(KaimingInit):
|
||||||
|
|
||||||
|
|
||||||
@INITIALIZERS.register_module(name='Pretrained')
|
@INITIALIZERS.register_module(name='Pretrained')
|
||||||
class PretrainedInit(object):
|
class PretrainedInit:
|
||||||
"""Initialize module by loading a pretrained model.
|
"""Initialize module by loading a pretrained model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -70,7 +70,7 @@ class VGG(nn.Module):
|
||||||
bn_frozen=False,
|
bn_frozen=False,
|
||||||
ceil_mode=False,
|
ceil_mode=False,
|
||||||
with_last_pool=True):
|
with_last_pool=True):
|
||||||
super(VGG, self).__init__()
|
super().__init__()
|
||||||
if depth not in self.arch_settings:
|
if depth not in self.arch_settings:
|
||||||
raise KeyError(f'invalid depth {depth} for vgg')
|
raise KeyError(f'invalid depth {depth} for vgg')
|
||||||
assert num_stages >= 1 and num_stages <= 5
|
assert num_stages >= 1 and num_stages <= 5
|
||||||
|
@ -157,7 +157,7 @@ class VGG(nn.Module):
|
||||||
return tuple(outs)
|
return tuple(outs)
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
super(VGG, self).train(mode)
|
super().train(mode)
|
||||||
if self.bn_eval:
|
if self.bn_eval:
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.BatchNorm2d):
|
if isinstance(m, nn.BatchNorm2d):
|
||||||
|
|
|
@ -33,7 +33,7 @@ class MLUDataParallel(MMDataParallel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, dim=0, **kwargs):
|
def __init__(self, *args, dim=0, **kwargs):
|
||||||
super(MLUDataParallel, self).__init__(*args, dim=dim, **kwargs)
|
super().__init__(*args, dim=dim, **kwargs)
|
||||||
self.device_ids = [0]
|
self.device_ids = [0]
|
||||||
self.src_device_obj = torch.device('mlu:0')
|
self.src_device_obj = torch.device('mlu:0')
|
||||||
|
|
||||||
|
|
|
@ -210,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
|
||||||
"""
|
"""
|
||||||
if not has_method(self._client, 'delete'):
|
if not has_method(self._client, 'delete'):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
('Current version of Petrel Python SDK has not supported '
|
'Current version of Petrel Python SDK has not supported '
|
||||||
'the `delete` method, please use a higher version or dev'
|
'the `delete` method, please use a higher version or dev'
|
||||||
' branch instead.'))
|
' branch instead.')
|
||||||
|
|
||||||
filepath = self._map_path(filepath)
|
filepath = self._map_path(filepath)
|
||||||
filepath = self._format_path(filepath)
|
filepath = self._format_path(filepath)
|
||||||
|
@ -230,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
|
||||||
if not (has_method(self._client, 'contains')
|
if not (has_method(self._client, 'contains')
|
||||||
and has_method(self._client, 'isdir')):
|
and has_method(self._client, 'isdir')):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
('Current version of Petrel Python SDK has not supported '
|
'Current version of Petrel Python SDK has not supported '
|
||||||
'the `contains` and `isdir` methods, please use a higher'
|
'the `contains` and `isdir` methods, please use a higher'
|
||||||
'version or dev branch instead.'))
|
'version or dev branch instead.')
|
||||||
|
|
||||||
filepath = self._map_path(filepath)
|
filepath = self._map_path(filepath)
|
||||||
filepath = self._format_path(filepath)
|
filepath = self._format_path(filepath)
|
||||||
|
@ -251,9 +251,9 @@ class PetrelBackend(BaseStorageBackend):
|
||||||
"""
|
"""
|
||||||
if not has_method(self._client, 'isdir'):
|
if not has_method(self._client, 'isdir'):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
('Current version of Petrel Python SDK has not supported '
|
'Current version of Petrel Python SDK has not supported '
|
||||||
'the `isdir` method, please use a higher version or dev'
|
'the `isdir` method, please use a higher version or dev'
|
||||||
' branch instead.'))
|
' branch instead.')
|
||||||
|
|
||||||
filepath = self._map_path(filepath)
|
filepath = self._map_path(filepath)
|
||||||
filepath = self._format_path(filepath)
|
filepath = self._format_path(filepath)
|
||||||
|
@ -271,9 +271,9 @@ class PetrelBackend(BaseStorageBackend):
|
||||||
"""
|
"""
|
||||||
if not has_method(self._client, 'contains'):
|
if not has_method(self._client, 'contains'):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
('Current version of Petrel Python SDK has not supported '
|
'Current version of Petrel Python SDK has not supported '
|
||||||
'the `contains` method, please use a higher version or '
|
'the `contains` method, please use a higher version or '
|
||||||
'dev branch instead.'))
|
'dev branch instead.')
|
||||||
|
|
||||||
filepath = self._map_path(filepath)
|
filepath = self._map_path(filepath)
|
||||||
filepath = self._format_path(filepath)
|
filepath = self._format_path(filepath)
|
||||||
|
@ -366,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
|
||||||
"""
|
"""
|
||||||
if not has_method(self._client, 'list'):
|
if not has_method(self._client, 'list'):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
('Current version of Petrel Python SDK has not supported '
|
'Current version of Petrel Python SDK has not supported '
|
||||||
'the `list` method, please use a higher version or dev'
|
'the `list` method, please use a higher version or dev'
|
||||||
' branch instead.'))
|
' branch instead.')
|
||||||
|
|
||||||
dir_path = self._map_path(dir_path)
|
dir_path = self._map_path(dir_path)
|
||||||
dir_path = self._format_path(dir_path)
|
dir_path = self._format_path(dir_path)
|
||||||
|
@ -549,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
|
||||||
Returns:
|
Returns:
|
||||||
str: Expected text reading from ``filepath``.
|
str: Expected text reading from ``filepath``.
|
||||||
"""
|
"""
|
||||||
with open(filepath, 'r', encoding=encoding) as f:
|
with open(filepath, encoding=encoding) as f:
|
||||||
value_buf = f.read()
|
value_buf = f.read()
|
||||||
return value_buf
|
return value_buf
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
|
||||||
return pickle.load(file, **kwargs)
|
return pickle.load(file, **kwargs)
|
||||||
|
|
||||||
def load_from_path(self, filepath, **kwargs):
|
def load_from_path(self, filepath, **kwargs):
|
||||||
return super(PickleHandler, self).load_from_path(
|
return super().load_from_path(filepath, mode='rb', **kwargs)
|
||||||
filepath, mode='rb', **kwargs)
|
|
||||||
|
|
||||||
def dump_to_str(self, obj, **kwargs):
|
def dump_to_str(self, obj, **kwargs):
|
||||||
kwargs.setdefault('protocol', 2)
|
kwargs.setdefault('protocol', 2)
|
||||||
|
@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
|
||||||
pickle.dump(obj, file, **kwargs)
|
pickle.dump(obj, file, **kwargs)
|
||||||
|
|
||||||
def dump_to_path(self, obj, filepath, **kwargs):
|
def dump_to_path(self, obj, filepath, **kwargs):
|
||||||
super(PickleHandler, self).dump_to_path(
|
super().dump_to_path(obj, filepath, mode='wb', **kwargs)
|
||||||
obj, filepath, mode='wb', **kwargs)
|
|
||||||
|
|
|
@ -157,7 +157,7 @@ def imresize_to_multiple(img,
|
||||||
size = _scale_size((w, h), scale_factor)
|
size = _scale_size((w, h), scale_factor)
|
||||||
|
|
||||||
divisor = to_2tuple(divisor)
|
divisor = to_2tuple(divisor)
|
||||||
size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
|
size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
|
||||||
resized_img, w_scale, h_scale = imresize(
|
resized_img, w_scale, h_scale = imresize(
|
||||||
img,
|
img,
|
||||||
size,
|
size,
|
||||||
|
|
|
@ -59,7 +59,7 @@ def _parse_arg(value, desc):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"ONNX symbolic doesn't know to interpret ListConstruct node")
|
"ONNX symbolic doesn't know to interpret ListConstruct node")
|
||||||
|
|
||||||
raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
|
raise RuntimeError(f'Unexpected node type: {value.node().kind()}')
|
||||||
|
|
||||||
|
|
||||||
def _maybe_get_const(value, desc):
|
def _maybe_get_const(value, desc):
|
||||||
|
|
|
@ -86,7 +86,7 @@ class BorderAlign(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pool_size):
|
def __init__(self, pool_size):
|
||||||
super(BorderAlign, self).__init__()
|
super().__init__()
|
||||||
self.pool_size = pool_size
|
self.pool_size = pool_size
|
||||||
|
|
||||||
def forward(self, input, boxes):
|
def forward(self, input, boxes):
|
||||||
|
|
|
@ -131,7 +131,7 @@ def box_iou_rotated(bboxes1,
|
||||||
if aligned:
|
if aligned:
|
||||||
ious = bboxes1.new_zeros(rows)
|
ious = bboxes1.new_zeros(rows)
|
||||||
else:
|
else:
|
||||||
ious = bboxes1.new_zeros((rows * cols))
|
ious = bboxes1.new_zeros(rows * cols)
|
||||||
if not clockwise:
|
if not clockwise:
|
||||||
flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
|
flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
|
||||||
flip_mat[-1] = -1
|
flip_mat[-1] = -1
|
||||||
|
|
|
@ -85,7 +85,7 @@ carafe_naive = CARAFENaiveFunction.apply
|
||||||
class CARAFENaive(Module):
|
class CARAFENaive(Module):
|
||||||
|
|
||||||
def __init__(self, kernel_size, group_size, scale_factor):
|
def __init__(self, kernel_size, group_size, scale_factor):
|
||||||
super(CARAFENaive, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert isinstance(kernel_size, int) and isinstance(
|
assert isinstance(kernel_size, int) and isinstance(
|
||||||
group_size, int) and isinstance(scale_factor, int)
|
group_size, int) and isinstance(scale_factor, int)
|
||||||
|
@ -195,7 +195,7 @@ class CARAFE(Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, kernel_size, group_size, scale_factor):
|
def __init__(self, kernel_size, group_size, scale_factor):
|
||||||
super(CARAFE, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert isinstance(kernel_size, int) and isinstance(
|
assert isinstance(kernel_size, int) and isinstance(
|
||||||
group_size, int) and isinstance(scale_factor, int)
|
group_size, int) and isinstance(scale_factor, int)
|
||||||
|
@ -238,7 +238,7 @@ class CARAFEPack(nn.Module):
|
||||||
encoder_kernel=3,
|
encoder_kernel=3,
|
||||||
encoder_dilation=1,
|
encoder_dilation=1,
|
||||||
compressed_channels=64):
|
compressed_channels=64):
|
||||||
super(CARAFEPack, self).__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
self.up_kernel = up_kernel
|
self.up_kernel = up_kernel
|
||||||
|
|
|
@ -125,7 +125,7 @@ class CornerPool(nn.Module):
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, mode):
|
def __init__(self, mode):
|
||||||
super(CornerPool, self).__init__()
|
super().__init__()
|
||||||
assert mode in self.pool_functions
|
assert mode in self.pool_functions
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.corner_pool = self.pool_functions[mode]
|
self.corner_pool = self.pool_functions[mode]
|
||||||
|
|
|
@ -236,7 +236,7 @@ class DeformConv2d(nn.Module):
|
||||||
deform_groups: int = 1,
|
deform_groups: int = 1,
|
||||||
bias: bool = False,
|
bias: bool = False,
|
||||||
im2col_step: int = 32) -> None:
|
im2col_step: int = 32) -> None:
|
||||||
super(DeformConv2d, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert not bias, \
|
assert not bias, \
|
||||||
f'bias={bias} is not supported in DeformConv2d.'
|
f'bias={bias} is not supported in DeformConv2d.'
|
||||||
|
@ -356,7 +356,7 @@ class DeformConv2dPack(DeformConv2d):
|
||||||
_version = 2
|
_version = 2
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(DeformConv2dPack, self).__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.conv_offset = nn.Conv2d(
|
self.conv_offset = nn.Conv2d(
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
|
self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
|
||||||
|
|
|
@ -96,7 +96,7 @@ class DeformRoIPool(nn.Module):
|
||||||
spatial_scale=1.0,
|
spatial_scale=1.0,
|
||||||
sampling_ratio=0,
|
sampling_ratio=0,
|
||||||
gamma=0.1):
|
gamma=0.1):
|
||||||
super(DeformRoIPool, self).__init__()
|
super().__init__()
|
||||||
self.output_size = _pair(output_size)
|
self.output_size = _pair(output_size)
|
||||||
self.spatial_scale = float(spatial_scale)
|
self.spatial_scale = float(spatial_scale)
|
||||||
self.sampling_ratio = int(sampling_ratio)
|
self.sampling_ratio = int(sampling_ratio)
|
||||||
|
@ -117,8 +117,7 @@ class DeformRoIPoolPack(DeformRoIPool):
|
||||||
spatial_scale=1.0,
|
spatial_scale=1.0,
|
||||||
sampling_ratio=0,
|
sampling_ratio=0,
|
||||||
gamma=0.1):
|
gamma=0.1):
|
||||||
super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale,
|
super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
|
||||||
sampling_ratio, gamma)
|
|
||||||
|
|
||||||
self.output_channels = output_channels
|
self.output_channels = output_channels
|
||||||
self.deform_fc_channels = deform_fc_channels
|
self.deform_fc_channels = deform_fc_channels
|
||||||
|
@ -158,8 +157,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool):
|
||||||
spatial_scale=1.0,
|
spatial_scale=1.0,
|
||||||
sampling_ratio=0,
|
sampling_ratio=0,
|
||||||
gamma=0.1):
|
gamma=0.1):
|
||||||
super(ModulatedDeformRoIPoolPack,
|
super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
|
||||||
self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
|
|
||||||
|
|
||||||
self.output_channels = output_channels
|
self.output_channels = output_channels
|
||||||
self.deform_fc_channels = deform_fc_channels
|
self.deform_fc_channels = deform_fc_channels
|
||||||
|
|
|
@ -89,7 +89,7 @@ sigmoid_focal_loss = SigmoidFocalLossFunction.apply
|
||||||
class SigmoidFocalLoss(nn.Module):
|
class SigmoidFocalLoss(nn.Module):
|
||||||
|
|
||||||
def __init__(self, gamma, alpha, weight=None, reduction='mean'):
|
def __init__(self, gamma, alpha, weight=None, reduction='mean'):
|
||||||
super(SigmoidFocalLoss, self).__init__()
|
super().__init__()
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.register_buffer('weight', weight)
|
self.register_buffer('weight', weight)
|
||||||
|
@ -195,7 +195,7 @@ softmax_focal_loss = SoftmaxFocalLossFunction.apply
|
||||||
class SoftmaxFocalLoss(nn.Module):
|
class SoftmaxFocalLoss(nn.Module):
|
||||||
|
|
||||||
def __init__(self, gamma, alpha, weight=None, reduction='mean'):
|
def __init__(self, gamma, alpha, weight=None, reduction='mean'):
|
||||||
super(SoftmaxFocalLoss, self).__init__()
|
super().__init__()
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.register_buffer('weight', weight)
|
self.register_buffer('weight', weight)
|
||||||
|
|
|
@ -212,7 +212,7 @@ class FusedBiasLeakyReLU(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
|
def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
|
||||||
super(FusedBiasLeakyReLU, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||||
self.negative_slope = negative_slope
|
self.negative_slope = negative_slope
|
||||||
|
|
|
@ -98,13 +98,12 @@ class MaskedConv2d(nn.Conv2d):
|
||||||
dilation=1,
|
dilation=1,
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True):
|
bias=True):
|
||||||
super(MaskedConv2d,
|
super().__init__(in_channels, out_channels, kernel_size, stride,
|
||||||
self).__init__(in_channels, out_channels, kernel_size, stride,
|
padding, dilation, groups, bias)
|
||||||
padding, dilation, groups, bias)
|
|
||||||
|
|
||||||
def forward(self, input, mask=None):
|
def forward(self, input, mask=None):
|
||||||
if mask is None: # fallback to the normal Conv2d
|
if mask is None: # fallback to the normal Conv2d
|
||||||
return super(MaskedConv2d, self).forward(input)
|
return super().forward(input)
|
||||||
else:
|
else:
|
||||||
return masked_conv2d(input, mask, self.weight, self.bias,
|
return masked_conv2d(input, mask, self.weight, self.bias,
|
||||||
self.padding)
|
self.padding)
|
||||||
|
|
|
@ -53,7 +53,7 @@ class BaseMergeCell(nn.Module):
|
||||||
input_conv_cfg=None,
|
input_conv_cfg=None,
|
||||||
input_norm_cfg=None,
|
input_norm_cfg=None,
|
||||||
upsample_mode='nearest'):
|
upsample_mode='nearest'):
|
||||||
super(BaseMergeCell, self).__init__()
|
super().__init__()
|
||||||
assert upsample_mode in ['nearest', 'bilinear']
|
assert upsample_mode in ['nearest', 'bilinear']
|
||||||
self.with_out_conv = with_out_conv
|
self.with_out_conv = with_out_conv
|
||||||
self.with_input1_conv = with_input1_conv
|
self.with_input1_conv = with_input1_conv
|
||||||
|
@ -121,7 +121,7 @@ class BaseMergeCell(nn.Module):
|
||||||
class SumCell(BaseMergeCell):
|
class SumCell(BaseMergeCell):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, **kwargs):
|
def __init__(self, in_channels, out_channels, **kwargs):
|
||||||
super(SumCell, self).__init__(in_channels, out_channels, **kwargs)
|
super().__init__(in_channels, out_channels, **kwargs)
|
||||||
|
|
||||||
def _binary_op(self, x1, x2):
|
def _binary_op(self, x1, x2):
|
||||||
return x1 + x2
|
return x1 + x2
|
||||||
|
@ -130,8 +130,7 @@ class SumCell(BaseMergeCell):
|
||||||
class ConcatCell(BaseMergeCell):
|
class ConcatCell(BaseMergeCell):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, **kwargs):
|
def __init__(self, in_channels, out_channels, **kwargs):
|
||||||
super(ConcatCell, self).__init__(in_channels * 2, out_channels,
|
super().__init__(in_channels * 2, out_channels, **kwargs)
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
def _binary_op(self, x1, x2):
|
def _binary_op(self, x1, x2):
|
||||||
ret = torch.cat([x1, x2], dim=1)
|
ret = torch.cat([x1, x2], dim=1)
|
||||||
|
|
|
@ -168,7 +168,7 @@ class ModulatedDeformConv2d(nn.Module):
|
||||||
groups=1,
|
groups=1,
|
||||||
deform_groups=1,
|
deform_groups=1,
|
||||||
bias=True):
|
bias=True):
|
||||||
super(ModulatedDeformConv2d, self).__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.kernel_size = _pair(kernel_size)
|
self.kernel_size = _pair(kernel_size)
|
||||||
|
@ -227,7 +227,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
|
||||||
_version = 2
|
_version = 2
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.conv_offset = nn.Conv2d(
|
self.conv_offset = nn.Conv2d(
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
|
self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
|
||||||
|
@ -239,7 +239,7 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
super(ModulatedDeformConv2dPack, self).init_weights()
|
super().init_weights()
|
||||||
if hasattr(self, 'conv_offset'):
|
if hasattr(self, 'conv_offset'):
|
||||||
self.conv_offset.weight.data.zero_()
|
self.conv_offset.weight.data.zero_()
|
||||||
self.conv_offset.bias.data.zero_()
|
self.conv_offset.bias.data.zero_()
|
||||||
|
|
|
@ -296,7 +296,7 @@ class SimpleRoIAlign(nn.Module):
|
||||||
If True, align the results more perfectly.
|
If True, align the results more perfectly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super(SimpleRoIAlign, self).__init__()
|
super().__init__()
|
||||||
self.output_size = _pair(output_size)
|
self.output_size = _pair(output_size)
|
||||||
self.spatial_scale = float(spatial_scale)
|
self.spatial_scale = float(spatial_scale)
|
||||||
# to be consistent with other RoI ops
|
# to be consistent with other RoI ops
|
||||||
|
|
|
@ -72,7 +72,7 @@ psa_mask = PSAMaskFunction.apply
|
||||||
class PSAMask(nn.Module):
|
class PSAMask(nn.Module):
|
||||||
|
|
||||||
def __init__(self, psa_type, mask_size=None):
|
def __init__(self, psa_type, mask_size=None):
|
||||||
super(PSAMask, self).__init__()
|
super().__init__()
|
||||||
assert psa_type in ['collect', 'distribute']
|
assert psa_type in ['collect', 'distribute']
|
||||||
if psa_type == 'collect':
|
if psa_type == 'collect':
|
||||||
psa_type_enum = 0
|
psa_type_enum = 0
|
||||||
|
|
|
@ -116,7 +116,7 @@ class RiRoIAlignRotated(nn.Module):
|
||||||
num_samples=0,
|
num_samples=0,
|
||||||
num_orientations=8,
|
num_orientations=8,
|
||||||
clockwise=False):
|
clockwise=False):
|
||||||
super(RiRoIAlignRotated, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.out_size = out_size
|
self.out_size = out_size
|
||||||
self.spatial_scale = float(spatial_scale)
|
self.spatial_scale = float(spatial_scale)
|
||||||
|
|
|
@ -181,7 +181,7 @@ class RoIAlign(nn.Module):
|
||||||
pool_mode='avg',
|
pool_mode='avg',
|
||||||
aligned=True,
|
aligned=True,
|
||||||
use_torchvision=False):
|
use_torchvision=False):
|
||||||
super(RoIAlign, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.output_size = _pair(output_size)
|
self.output_size = _pair(output_size)
|
||||||
self.spatial_scale = float(spatial_scale)
|
self.spatial_scale = float(spatial_scale)
|
||||||
|
|
|
@ -156,7 +156,7 @@ class RoIAlignRotated(nn.Module):
|
||||||
sampling_ratio=0,
|
sampling_ratio=0,
|
||||||
aligned=True,
|
aligned=True,
|
||||||
clockwise=False):
|
clockwise=False):
|
||||||
super(RoIAlignRotated, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.output_size = _pair(output_size)
|
self.output_size = _pair(output_size)
|
||||||
self.spatial_scale = float(spatial_scale)
|
self.spatial_scale = float(spatial_scale)
|
||||||
|
|
|
@ -71,7 +71,7 @@ roi_pool = RoIPoolFunction.apply
|
||||||
class RoIPool(nn.Module):
|
class RoIPool(nn.Module):
|
||||||
|
|
||||||
def __init__(self, output_size, spatial_scale=1.0):
|
def __init__(self, output_size, spatial_scale=1.0):
|
||||||
super(RoIPool, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.output_size = _pair(output_size)
|
self.output_size = _pair(output_size)
|
||||||
self.spatial_scale = float(spatial_scale)
|
self.spatial_scale = float(spatial_scale)
|
||||||
|
|
|
@ -64,7 +64,7 @@ class SparseConvolution(SparseModule):
|
||||||
inverse=False,
|
inverse=False,
|
||||||
indice_key=None,
|
indice_key=None,
|
||||||
fused_bn=False):
|
fused_bn=False):
|
||||||
super(SparseConvolution, self).__init__()
|
super().__init__()
|
||||||
assert groups == 1
|
assert groups == 1
|
||||||
if not isinstance(kernel_size, (list, tuple)):
|
if not isinstance(kernel_size, (list, tuple)):
|
||||||
kernel_size = [kernel_size] * ndim
|
kernel_size = [kernel_size] * ndim
|
||||||
|
@ -217,7 +217,7 @@ class SparseConv2d(SparseConvolution):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
indice_key=None):
|
indice_key=None):
|
||||||
super(SparseConv2d, self).__init__(
|
super().__init__(
|
||||||
2,
|
2,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -243,7 +243,7 @@ class SparseConv3d(SparseConvolution):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
indice_key=None):
|
indice_key=None):
|
||||||
super(SparseConv3d, self).__init__(
|
super().__init__(
|
||||||
3,
|
3,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -269,7 +269,7 @@ class SparseConv4d(SparseConvolution):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
indice_key=None):
|
indice_key=None):
|
||||||
super(SparseConv4d, self).__init__(
|
super().__init__(
|
||||||
4,
|
4,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -295,7 +295,7 @@ class SparseConvTranspose2d(SparseConvolution):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
indice_key=None):
|
indice_key=None):
|
||||||
super(SparseConvTranspose2d, self).__init__(
|
super().__init__(
|
||||||
2,
|
2,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -322,7 +322,7 @@ class SparseConvTranspose3d(SparseConvolution):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
indice_key=None):
|
indice_key=None):
|
||||||
super(SparseConvTranspose3d, self).__init__(
|
super().__init__(
|
||||||
3,
|
3,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -345,7 +345,7 @@ class SparseInverseConv2d(SparseConvolution):
|
||||||
kernel_size,
|
kernel_size,
|
||||||
indice_key=None,
|
indice_key=None,
|
||||||
bias=True):
|
bias=True):
|
||||||
super(SparseInverseConv2d, self).__init__(
|
super().__init__(
|
||||||
2,
|
2,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -364,7 +364,7 @@ class SparseInverseConv3d(SparseConvolution):
|
||||||
kernel_size,
|
kernel_size,
|
||||||
indice_key=None,
|
indice_key=None,
|
||||||
bias=True):
|
bias=True):
|
||||||
super(SparseInverseConv3d, self).__init__(
|
super().__init__(
|
||||||
3,
|
3,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -387,7 +387,7 @@ class SubMConv2d(SparseConvolution):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
indice_key=None):
|
indice_key=None):
|
||||||
super(SubMConv2d, self).__init__(
|
super().__init__(
|
||||||
2,
|
2,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -414,7 +414,7 @@ class SubMConv3d(SparseConvolution):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
indice_key=None):
|
indice_key=None):
|
||||||
super(SubMConv3d, self).__init__(
|
super().__init__(
|
||||||
3,
|
3,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -441,7 +441,7 @@ class SubMConv4d(SparseConvolution):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
indice_key=None):
|
indice_key=None):
|
||||||
super(SubMConv4d, self).__init__(
|
super().__init__(
|
||||||
4,
|
4,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
|
|
@ -86,7 +86,7 @@ class SparseSequential(SparseModule):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(SparseSequential, self).__init__()
|
super().__init__()
|
||||||
if len(args) == 1 and isinstance(args[0], OrderedDict):
|
if len(args) == 1 and isinstance(args[0], OrderedDict):
|
||||||
for key, module in args[0].items():
|
for key, module in args[0].items():
|
||||||
self.add_module(key, module)
|
self.add_module(key, module)
|
||||||
|
@ -103,7 +103,7 @@ class SparseSequential(SparseModule):
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
if not (-len(self) <= idx < len(self)):
|
if not (-len(self) <= idx < len(self)):
|
||||||
raise IndexError('index {} is out of range'.format(idx))
|
raise IndexError(f'index {idx} is out of range')
|
||||||
if idx < 0:
|
if idx < 0:
|
||||||
idx += len(self)
|
idx += len(self)
|
||||||
it = iter(self._modules.values())
|
it = iter(self._modules.values())
|
||||||
|
|
|
@ -29,7 +29,7 @@ class SparseMaxPool(SparseModule):
|
||||||
padding=0,
|
padding=0,
|
||||||
dilation=1,
|
dilation=1,
|
||||||
subm=False):
|
subm=False):
|
||||||
super(SparseMaxPool, self).__init__()
|
super().__init__()
|
||||||
if not isinstance(kernel_size, (list, tuple)):
|
if not isinstance(kernel_size, (list, tuple)):
|
||||||
kernel_size = [kernel_size] * ndim
|
kernel_size = [kernel_size] * ndim
|
||||||
if not isinstance(stride, (list, tuple)):
|
if not isinstance(stride, (list, tuple)):
|
||||||
|
@ -77,12 +77,10 @@ class SparseMaxPool(SparseModule):
|
||||||
class SparseMaxPool2d(SparseMaxPool):
|
class SparseMaxPool2d(SparseMaxPool):
|
||||||
|
|
||||||
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
|
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
|
||||||
super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding,
|
super().__init__(2, kernel_size, stride, padding, dilation)
|
||||||
dilation)
|
|
||||||
|
|
||||||
|
|
||||||
class SparseMaxPool3d(SparseMaxPool):
|
class SparseMaxPool3d(SparseMaxPool):
|
||||||
|
|
||||||
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
|
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
|
||||||
super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding,
|
super().__init__(3, kernel_size, stride, padding, dilation)
|
||||||
dilation)
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ def scatter_nd(indices, updates, shape):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class SparseConvTensor(object):
|
class SparseConvTensor:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
features,
|
features,
|
||||||
|
|
|
@ -198,7 +198,7 @@ class SyncBatchNorm(Module):
|
||||||
track_running_stats=True,
|
track_running_stats=True,
|
||||||
group=None,
|
group=None,
|
||||||
stats_mode='default'):
|
stats_mode='default'):
|
||||||
super(SyncBatchNorm, self).__init__()
|
super().__init__()
|
||||||
self.num_features = num_features
|
self.num_features = num_features
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.momentum = momentum
|
self.momentum = momentum
|
||||||
|
|
|
@ -32,7 +32,7 @@ class MMDataParallel(DataParallel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, dim=0, **kwargs):
|
def __init__(self, *args, dim=0, **kwargs):
|
||||||
super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
|
super().__init__(*args, dim=dim, **kwargs)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
||||||
def forward(self, *inputs, **kwargs):
|
def forward(self, *inputs, **kwargs):
|
||||||
|
|
|
@ -18,7 +18,7 @@ class MMDistributedDataParallel(nn.Module):
|
||||||
dim=0,
|
dim=0,
|
||||||
broadcast_buffers=True,
|
broadcast_buffers=True,
|
||||||
bucket_cap_mb=25):
|
bucket_cap_mb=25):
|
||||||
super(MMDistributedDataParallel, self).__init__()
|
super().__init__()
|
||||||
self.module = module
|
self.module = module
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.broadcast_buffers = broadcast_buffers
|
self.broadcast_buffers = broadcast_buffers
|
||||||
|
|
|
@ -35,7 +35,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
|
||||||
# NOTE init_cfg can be defined in different levels, but init_cfg
|
# NOTE init_cfg can be defined in different levels, but init_cfg
|
||||||
# in low levels has a higher priority.
|
# in low levels has a higher priority.
|
||||||
|
|
||||||
super(BaseModule, self).__init__()
|
super().__init__()
|
||||||
# define default value of init_cfg instead of hard code
|
# define default value of init_cfg instead of hard code
|
||||||
# in init_weights() function
|
# in init_weights() function
|
||||||
self._is_init = False
|
self._is_init = False
|
||||||
|
|
|
@ -83,8 +83,8 @@ class CheckpointHook(Hook):
|
||||||
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
||||||
self.out_dir = self.file_client.join_path(self.out_dir, basename)
|
self.out_dir = self.file_client.join_path(self.out_dir, basename)
|
||||||
|
|
||||||
runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
|
runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
|
||||||
f'{self.file_client.name}.'))
|
f'{self.file_client.name}.')
|
||||||
|
|
||||||
# disable the create_symlink option because some file backends do not
|
# disable the create_symlink option because some file backends do not
|
||||||
# allow to create a symlink
|
# allow to create a symlink
|
||||||
|
@ -93,9 +93,9 @@ class CheckpointHook(Hook):
|
||||||
'create_symlink'] and not self.file_client.allow_symlink:
|
'create_symlink'] and not self.file_client.allow_symlink:
|
||||||
self.args['create_symlink'] = False
|
self.args['create_symlink'] = False
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
('create_symlink is set as True by the user but is changed'
|
'create_symlink is set as True by the user but is changed'
|
||||||
'to be False because creating symbolic link is not '
|
'to be False because creating symbolic link is not '
|
||||||
f'allowed in {self.file_client.name}'))
|
f'allowed in {self.file_client.name}')
|
||||||
else:
|
else:
|
||||||
self.args['create_symlink'] = self.file_client.allow_symlink
|
self.args['create_symlink'] = self.file_client.allow_symlink
|
||||||
|
|
||||||
|
|
|
@ -214,8 +214,8 @@ class EvalHook(Hook):
|
||||||
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
||||||
self.out_dir = self.file_client.join_path(self.out_dir, basename)
|
self.out_dir = self.file_client.join_path(self.out_dir, basename)
|
||||||
runner.logger.info(
|
runner.logger.info(
|
||||||
(f'The best checkpoint will be saved to {self.out_dir} by '
|
f'The best checkpoint will be saved to {self.out_dir} by '
|
||||||
f'{self.file_client.name}'))
|
f'{self.file_client.name}')
|
||||||
|
|
||||||
if self.save_best is not None:
|
if self.save_best is not None:
|
||||||
if runner.meta is None:
|
if runner.meta is None:
|
||||||
|
@ -335,8 +335,8 @@ class EvalHook(Hook):
|
||||||
self.best_ckpt_path):
|
self.best_ckpt_path):
|
||||||
self.file_client.remove(self.best_ckpt_path)
|
self.file_client.remove(self.best_ckpt_path)
|
||||||
runner.logger.info(
|
runner.logger.info(
|
||||||
(f'The previous best checkpoint {self.best_ckpt_path} was '
|
f'The previous best checkpoint {self.best_ckpt_path} was '
|
||||||
'removed'))
|
'removed')
|
||||||
|
|
||||||
best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
|
best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
|
||||||
self.best_ckpt_path = self.file_client.join_path(
|
self.best_ckpt_path = self.file_client.join_path(
|
||||||
|
|
|
@ -34,8 +34,7 @@ class ClearMLLoggerHook(LoggerHook):
|
||||||
ignore_last=True,
|
ignore_last=True,
|
||||||
reset_flag=False,
|
reset_flag=False,
|
||||||
by_epoch=True):
|
by_epoch=True):
|
||||||
super(ClearMLLoggerHook, self).__init__(interval, ignore_last,
|
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||||
reset_flag, by_epoch)
|
|
||||||
self.import_clearml()
|
self.import_clearml()
|
||||||
self.init_kwargs = init_kwargs
|
self.init_kwargs = init_kwargs
|
||||||
|
|
||||||
|
@ -49,7 +48,7 @@ class ClearMLLoggerHook(LoggerHook):
|
||||||
|
|
||||||
@master_only
|
@master_only
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
super(ClearMLLoggerHook, self).before_run(runner)
|
super().before_run(runner)
|
||||||
task_kwargs = self.init_kwargs if self.init_kwargs else {}
|
task_kwargs = self.init_kwargs if self.init_kwargs else {}
|
||||||
self.task = self.clearml.Task.init(**task_kwargs)
|
self.task = self.clearml.Task.init(**task_kwargs)
|
||||||
self.task_logger = self.task.get_logger()
|
self.task_logger = self.task.get_logger()
|
||||||
|
|
|
@ -40,8 +40,7 @@ class MlflowLoggerHook(LoggerHook):
|
||||||
ignore_last=True,
|
ignore_last=True,
|
||||||
reset_flag=False,
|
reset_flag=False,
|
||||||
by_epoch=True):
|
by_epoch=True):
|
||||||
super(MlflowLoggerHook, self).__init__(interval, ignore_last,
|
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||||
reset_flag, by_epoch)
|
|
||||||
self.import_mlflow()
|
self.import_mlflow()
|
||||||
self.exp_name = exp_name
|
self.exp_name = exp_name
|
||||||
self.tags = tags
|
self.tags = tags
|
||||||
|
@ -59,7 +58,7 @@ class MlflowLoggerHook(LoggerHook):
|
||||||
|
|
||||||
@master_only
|
@master_only
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
super(MlflowLoggerHook, self).before_run(runner)
|
super().before_run(runner)
|
||||||
if self.exp_name is not None:
|
if self.exp_name is not None:
|
||||||
self.mlflow.set_experiment(self.exp_name)
|
self.mlflow.set_experiment(self.exp_name)
|
||||||
if self.tags is not None:
|
if self.tags is not None:
|
||||||
|
|
|
@ -49,8 +49,7 @@ class NeptuneLoggerHook(LoggerHook):
|
||||||
with_step=True,
|
with_step=True,
|
||||||
by_epoch=True):
|
by_epoch=True):
|
||||||
|
|
||||||
super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
|
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||||
reset_flag, by_epoch)
|
|
||||||
self.import_neptune()
|
self.import_neptune()
|
||||||
self.init_kwargs = init_kwargs
|
self.init_kwargs = init_kwargs
|
||||||
self.with_step = with_step
|
self.with_step = with_step
|
||||||
|
|
|
@ -40,8 +40,7 @@ class PaviLoggerHook(LoggerHook):
|
||||||
reset_flag=False,
|
reset_flag=False,
|
||||||
by_epoch=True,
|
by_epoch=True,
|
||||||
img_key='img_info'):
|
img_key='img_info'):
|
||||||
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
|
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||||
by_epoch)
|
|
||||||
self.init_kwargs = init_kwargs
|
self.init_kwargs = init_kwargs
|
||||||
self.add_graph = add_graph
|
self.add_graph = add_graph
|
||||||
self.add_last_ckpt = add_last_ckpt
|
self.add_last_ckpt = add_last_ckpt
|
||||||
|
@ -49,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
|
||||||
|
|
||||||
@master_only
|
@master_only
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
super(PaviLoggerHook, self).before_run(runner)
|
super().before_run(runner)
|
||||||
try:
|
try:
|
||||||
from pavi import SummaryWriter
|
from pavi import SummaryWriter
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
|
@ -27,8 +27,7 @@ class SegmindLoggerHook(LoggerHook):
|
||||||
ignore_last=True,
|
ignore_last=True,
|
||||||
reset_flag=False,
|
reset_flag=False,
|
||||||
by_epoch=True):
|
by_epoch=True):
|
||||||
super(SegmindLoggerHook, self).__init__(interval, ignore_last,
|
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||||
reset_flag, by_epoch)
|
|
||||||
self.import_segmind()
|
self.import_segmind()
|
||||||
|
|
||||||
def import_segmind(self):
|
def import_segmind(self):
|
||||||
|
|
|
@ -28,13 +28,12 @@ class TensorboardLoggerHook(LoggerHook):
|
||||||
ignore_last=True,
|
ignore_last=True,
|
||||||
reset_flag=False,
|
reset_flag=False,
|
||||||
by_epoch=True):
|
by_epoch=True):
|
||||||
super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
|
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||||
reset_flag, by_epoch)
|
|
||||||
self.log_dir = log_dir
|
self.log_dir = log_dir
|
||||||
|
|
||||||
@master_only
|
@master_only
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
super(TensorboardLoggerHook, self).before_run(runner)
|
super().before_run(runner)
|
||||||
if (TORCH_VERSION == 'parrots'
|
if (TORCH_VERSION == 'parrots'
|
||||||
or digit_version(TORCH_VERSION) < digit_version('1.1')):
|
or digit_version(TORCH_VERSION) < digit_version('1.1')):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -62,8 +62,7 @@ class TextLoggerHook(LoggerHook):
|
||||||
out_suffix=('.log.json', '.log', '.py'),
|
out_suffix=('.log.json', '.log', '.py'),
|
||||||
keep_local=True,
|
keep_local=True,
|
||||||
file_client_args=None):
|
file_client_args=None):
|
||||||
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
|
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||||
by_epoch)
|
|
||||||
self.by_epoch = by_epoch
|
self.by_epoch = by_epoch
|
||||||
self.time_sec_tot = 0
|
self.time_sec_tot = 0
|
||||||
self.interval_exp_name = interval_exp_name
|
self.interval_exp_name = interval_exp_name
|
||||||
|
@ -87,7 +86,7 @@ class TextLoggerHook(LoggerHook):
|
||||||
self.out_dir)
|
self.out_dir)
|
||||||
|
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
super(TextLoggerHook, self).before_run(runner)
|
super().before_run(runner)
|
||||||
|
|
||||||
if self.out_dir is not None:
|
if self.out_dir is not None:
|
||||||
self.file_client = FileClient.infer_client(self.file_client_args,
|
self.file_client = FileClient.infer_client(self.file_client_args,
|
||||||
|
@ -97,8 +96,8 @@ class TextLoggerHook(LoggerHook):
|
||||||
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
||||||
self.out_dir = self.file_client.join_path(self.out_dir, basename)
|
self.out_dir = self.file_client.join_path(self.out_dir, basename)
|
||||||
runner.logger.info(
|
runner.logger.info(
|
||||||
(f'Text logs will be saved to {self.out_dir} by '
|
f'Text logs will be saved to {self.out_dir} by '
|
||||||
f'{self.file_client.name} after the training process.'))
|
f'{self.file_client.name} after the training process.')
|
||||||
|
|
||||||
self.start_iter = runner.iter
|
self.start_iter = runner.iter
|
||||||
self.json_log_path = osp.join(runner.work_dir,
|
self.json_log_path = osp.join(runner.work_dir,
|
||||||
|
@ -242,15 +241,15 @@ class TextLoggerHook(LoggerHook):
|
||||||
local_filepath = osp.join(runner.work_dir, filename)
|
local_filepath = osp.join(runner.work_dir, filename)
|
||||||
out_filepath = self.file_client.join_path(
|
out_filepath = self.file_client.join_path(
|
||||||
self.out_dir, filename)
|
self.out_dir, filename)
|
||||||
with open(local_filepath, 'r') as f:
|
with open(local_filepath) as f:
|
||||||
self.file_client.put_text(f.read(), out_filepath)
|
self.file_client.put_text(f.read(), out_filepath)
|
||||||
|
|
||||||
runner.logger.info(
|
runner.logger.info(
|
||||||
(f'The file {local_filepath} has been uploaded to '
|
f'The file {local_filepath} has been uploaded to '
|
||||||
f'{out_filepath}.'))
|
f'{out_filepath}.')
|
||||||
|
|
||||||
if not self.keep_local:
|
if not self.keep_local:
|
||||||
os.remove(local_filepath)
|
os.remove(local_filepath)
|
||||||
runner.logger.info(
|
runner.logger.info(
|
||||||
(f'{local_filepath} was removed due to the '
|
f'{local_filepath} was removed due to the '
|
||||||
'`self.keep_local=False`'))
|
'`self.keep_local=False`')
|
||||||
|
|
|
@ -57,8 +57,7 @@ class WandbLoggerHook(LoggerHook):
|
||||||
with_step=True,
|
with_step=True,
|
||||||
log_artifact=True,
|
log_artifact=True,
|
||||||
out_suffix=('.log.json', '.log', '.py')):
|
out_suffix=('.log.json', '.log', '.py')):
|
||||||
super(WandbLoggerHook, self).__init__(interval, ignore_last,
|
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||||
reset_flag, by_epoch)
|
|
||||||
self.import_wandb()
|
self.import_wandb()
|
||||||
self.init_kwargs = init_kwargs
|
self.init_kwargs = init_kwargs
|
||||||
self.commit = commit
|
self.commit = commit
|
||||||
|
@ -76,7 +75,7 @@ class WandbLoggerHook(LoggerHook):
|
||||||
|
|
||||||
@master_only
|
@master_only
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
super(WandbLoggerHook, self).before_run(runner)
|
super().before_run(runner)
|
||||||
if self.wandb is None:
|
if self.wandb is None:
|
||||||
self.import_wandb()
|
self.import_wandb()
|
||||||
if self.init_kwargs:
|
if self.init_kwargs:
|
||||||
|
|
|
@ -157,7 +157,7 @@ class LrUpdaterHook(Hook):
|
||||||
class FixedLrUpdaterHook(LrUpdaterHook):
|
class FixedLrUpdaterHook(LrUpdaterHook):
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super(FixedLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_lr(self, runner, base_lr):
|
def get_lr(self, runner, base_lr):
|
||||||
return base_lr
|
return base_lr
|
||||||
|
@ -188,7 +188,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
|
||||||
self.step = step
|
self.step = step
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.min_lr = min_lr
|
self.min_lr = min_lr
|
||||||
super(StepLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_lr(self, runner, base_lr):
|
def get_lr(self, runner, base_lr):
|
||||||
progress = runner.epoch if self.by_epoch else runner.iter
|
progress = runner.epoch if self.by_epoch else runner.iter
|
||||||
|
@ -215,7 +215,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
|
||||||
|
|
||||||
def __init__(self, gamma, **kwargs):
|
def __init__(self, gamma, **kwargs):
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
super(ExpLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_lr(self, runner, base_lr):
|
def get_lr(self, runner, base_lr):
|
||||||
progress = runner.epoch if self.by_epoch else runner.iter
|
progress = runner.epoch if self.by_epoch else runner.iter
|
||||||
|
@ -228,7 +228,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
|
||||||
def __init__(self, power=1., min_lr=0., **kwargs):
|
def __init__(self, power=1., min_lr=0., **kwargs):
|
||||||
self.power = power
|
self.power = power
|
||||||
self.min_lr = min_lr
|
self.min_lr = min_lr
|
||||||
super(PolyLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_lr(self, runner, base_lr):
|
def get_lr(self, runner, base_lr):
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
|
@ -247,7 +247,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
|
||||||
def __init__(self, gamma, power=1., **kwargs):
|
def __init__(self, gamma, power=1., **kwargs):
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.power = power
|
self.power = power
|
||||||
super(InvLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_lr(self, runner, base_lr):
|
def get_lr(self, runner, base_lr):
|
||||||
progress = runner.epoch if self.by_epoch else runner.iter
|
progress = runner.epoch if self.by_epoch else runner.iter
|
||||||
|
@ -269,7 +269,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
|
||||||
assert (min_lr is None) ^ (min_lr_ratio is None)
|
assert (min_lr is None) ^ (min_lr_ratio is None)
|
||||||
self.min_lr = min_lr
|
self.min_lr = min_lr
|
||||||
self.min_lr_ratio = min_lr_ratio
|
self.min_lr_ratio = min_lr_ratio
|
||||||
super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_lr(self, runner, base_lr):
|
def get_lr(self, runner, base_lr):
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
|
@ -317,7 +317,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
|
||||||
self.start_percent = start_percent
|
self.start_percent = start_percent
|
||||||
self.min_lr = min_lr
|
self.min_lr = min_lr
|
||||||
self.min_lr_ratio = min_lr_ratio
|
self.min_lr_ratio = min_lr_ratio
|
||||||
super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_lr(self, runner, base_lr):
|
def get_lr(self, runner, base_lr):
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
|
@ -367,7 +367,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
|
||||||
self.restart_weights = restart_weights
|
self.restart_weights = restart_weights
|
||||||
assert (len(self.periods) == len(self.restart_weights)
|
assert (len(self.periods) == len(self.restart_weights)
|
||||||
), 'periods and restart_weights should have the same length.'
|
), 'periods and restart_weights should have the same length.'
|
||||||
super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.cumulative_periods = [
|
self.cumulative_periods = [
|
||||||
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
|
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
|
||||||
|
@ -484,10 +484,10 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
|
||||||
|
|
||||||
assert not by_epoch, \
|
assert not by_epoch, \
|
||||||
'currently only support "by_epoch" = False'
|
'currently only support "by_epoch" = False'
|
||||||
super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
|
super().__init__(by_epoch, **kwargs)
|
||||||
|
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
super(CyclicLrUpdaterHook, self).before_run(runner)
|
super().before_run(runner)
|
||||||
# initiate lr_phases
|
# initiate lr_phases
|
||||||
# total lr_phases are separated as up and down
|
# total lr_phases are separated as up and down
|
||||||
self.max_iter_per_phase = runner.max_iters // self.cyclic_times
|
self.max_iter_per_phase = runner.max_iters // self.cyclic_times
|
||||||
|
@ -598,7 +598,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
|
||||||
self.final_div_factor = final_div_factor
|
self.final_div_factor = final_div_factor
|
||||||
self.three_phase = three_phase
|
self.three_phase = three_phase
|
||||||
self.lr_phases = [] # init lr_phases
|
self.lr_phases = [] # init lr_phases
|
||||||
super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
if hasattr(self, 'total_steps'):
|
if hasattr(self, 'total_steps'):
|
||||||
|
@ -668,7 +668,7 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
|
||||||
assert (min_lr is None) ^ (min_lr_ratio is None)
|
assert (min_lr is None) ^ (min_lr_ratio is None)
|
||||||
self.min_lr = min_lr
|
self.min_lr = min_lr
|
||||||
self.min_lr_ratio = min_lr_ratio
|
self.min_lr_ratio = min_lr_ratio
|
||||||
super(LinearAnnealingLrUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_lr(self, runner, base_lr):
|
def get_lr(self, runner, base_lr):
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
|
|
|
@ -176,7 +176,7 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
|
||||||
self.step = step
|
self.step = step
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.min_momentum = min_momentum
|
self.min_momentum = min_momentum
|
||||||
super(StepMomentumUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_momentum(self, runner, base_momentum):
|
def get_momentum(self, runner, base_momentum):
|
||||||
progress = runner.epoch if self.by_epoch else runner.iter
|
progress = runner.epoch if self.by_epoch else runner.iter
|
||||||
|
@ -214,7 +214,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
|
||||||
assert (min_momentum is None) ^ (min_momentum_ratio is None)
|
assert (min_momentum is None) ^ (min_momentum_ratio is None)
|
||||||
self.min_momentum = min_momentum
|
self.min_momentum = min_momentum
|
||||||
self.min_momentum_ratio = min_momentum_ratio
|
self.min_momentum_ratio = min_momentum_ratio
|
||||||
super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_momentum(self, runner, base_momentum):
|
def get_momentum(self, runner, base_momentum):
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
|
@ -247,7 +247,7 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
|
||||||
assert (min_momentum is None) ^ (min_momentum_ratio is None)
|
assert (min_momentum is None) ^ (min_momentum_ratio is None)
|
||||||
self.min_momentum = min_momentum
|
self.min_momentum = min_momentum
|
||||||
self.min_momentum_ratio = min_momentum_ratio
|
self.min_momentum_ratio = min_momentum_ratio
|
||||||
super(LinearAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def get_momentum(self, runner, base_momentum):
|
def get_momentum(self, runner, base_momentum):
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
|
@ -328,10 +328,10 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
|
||||||
# currently only support by_epoch=False
|
# currently only support by_epoch=False
|
||||||
assert not by_epoch, \
|
assert not by_epoch, \
|
||||||
'currently only support "by_epoch" = False'
|
'currently only support "by_epoch" = False'
|
||||||
super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
|
super().__init__(by_epoch, **kwargs)
|
||||||
|
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
super(CyclicMomentumUpdaterHook, self).before_run(runner)
|
super().before_run(runner)
|
||||||
# initiate momentum_phases
|
# initiate momentum_phases
|
||||||
# total momentum_phases are separated as up and down
|
# total momentum_phases are separated as up and down
|
||||||
max_iter_per_phase = runner.max_iters // self.cyclic_times
|
max_iter_per_phase = runner.max_iters // self.cyclic_times
|
||||||
|
@ -439,7 +439,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
|
||||||
self.anneal_func = annealing_linear
|
self.anneal_func = annealing_linear
|
||||||
self.three_phase = three_phase
|
self.three_phase = three_phase
|
||||||
self.momentum_phases = [] # init momentum_phases
|
self.momentum_phases = [] # init momentum_phases
|
||||||
super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def before_run(self, runner):
|
def before_run(self, runner):
|
||||||
if isinstance(runner.optimizer, dict):
|
if isinstance(runner.optimizer, dict):
|
||||||
|
|
|
@ -110,7 +110,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cumulative_iters=1, **kwargs):
|
def __init__(self, cumulative_iters=1, **kwargs):
|
||||||
super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
|
assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
|
||||||
f'cumulative_iters only accepts positive int, but got ' \
|
f'cumulative_iters only accepts positive int, but got ' \
|
||||||
|
@ -297,8 +297,7 @@ if (TORCH_VERSION != 'parrots'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(GradientCumulativeFp16OptimizerHook,
|
super().__init__(*args, **kwargs)
|
||||||
self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def after_train_iter(self, runner):
|
def after_train_iter(self, runner):
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
|
@ -490,8 +489,7 @@ else:
|
||||||
iters gradient cumulating."""
|
iters gradient cumulating."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(GradientCumulativeFp16OptimizerHook,
|
super().__init__(*args, **kwargs)
|
||||||
self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def after_train_iter(self, runner):
|
def after_train_iter(self, runner):
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
|
|
|
@ -263,7 +263,7 @@ class IterBasedRunner(BaseRunner):
|
||||||
if log_config is not None:
|
if log_config is not None:
|
||||||
for info in log_config['hooks']:
|
for info in log_config['hooks']:
|
||||||
info.setdefault('by_epoch', False)
|
info.setdefault('by_epoch', False)
|
||||||
super(IterBasedRunner, self).register_training_hooks(
|
super().register_training_hooks(
|
||||||
lr_config=lr_config,
|
lr_config=lr_config,
|
||||||
momentum_config=momentum_config,
|
momentum_config=momentum_config,
|
||||||
optimizer_config=optimizer_config,
|
optimizer_config=optimizer_config,
|
||||||
|
|
|
@ -54,7 +54,7 @@ def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
|
||||||
msg += reset_style
|
msg += reset_style
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
|
|
||||||
device = torch.device('cuda:{}'.format(device_id))
|
device = torch.device(f'cuda:{device_id}')
|
||||||
# create builder and network
|
# create builder and network
|
||||||
logger = trt.Logger(log_level)
|
logger = trt.Logger(log_level)
|
||||||
builder = trt.Builder(logger)
|
builder = trt.Builder(logger)
|
||||||
|
@ -209,7 +209,7 @@ class TRTWrapper(torch.nn.Module):
|
||||||
msg += reset_style
|
msg += reset_style
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
|
|
||||||
super(TRTWrapper, self).__init__()
|
super().__init__()
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
if isinstance(self.engine, str):
|
if isinstance(self.engine, str):
|
||||||
self.engine = load_trt_engine(engine)
|
self.engine = load_trt_engine(engine)
|
||||||
|
|
|
@ -39,7 +39,7 @@ class ConfigDict(Dict):
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
try:
|
try:
|
||||||
value = super(ConfigDict, self).__getattr__(name)
|
value = super().__getattr__(name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
ex = AttributeError(f"'{self.__class__.__name__}' object has no "
|
ex = AttributeError(f"'{self.__class__.__name__}' object has no "
|
||||||
f"attribute '{name}'")
|
f"attribute '{name}'")
|
||||||
|
@ -96,7 +96,7 @@ class Config:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _validate_py_syntax(filename):
|
def _validate_py_syntax(filename):
|
||||||
with open(filename, 'r', encoding='utf-8') as f:
|
with open(filename, encoding='utf-8') as f:
|
||||||
# Setting encoding explicitly to resolve coding issue on windows
|
# Setting encoding explicitly to resolve coding issue on windows
|
||||||
content = f.read()
|
content = f.read()
|
||||||
try:
|
try:
|
||||||
|
@ -116,7 +116,7 @@ class Config:
|
||||||
fileBasename=file_basename,
|
fileBasename=file_basename,
|
||||||
fileBasenameNoExtension=file_basename_no_extension,
|
fileBasenameNoExtension=file_basename_no_extension,
|
||||||
fileExtname=file_extname)
|
fileExtname=file_extname)
|
||||||
with open(filename, 'r', encoding='utf-8') as f:
|
with open(filename, encoding='utf-8') as f:
|
||||||
# Setting encoding explicitly to resolve coding issue on windows
|
# Setting encoding explicitly to resolve coding issue on windows
|
||||||
config_file = f.read()
|
config_file = f.read()
|
||||||
for key, value in support_templates.items():
|
for key, value in support_templates.items():
|
||||||
|
@ -130,7 +130,7 @@ class Config:
|
||||||
def _pre_substitute_base_vars(filename, temp_config_name):
|
def _pre_substitute_base_vars(filename, temp_config_name):
|
||||||
"""Substitute base variable placehoders to string, so that parsing
|
"""Substitute base variable placehoders to string, so that parsing
|
||||||
would work."""
|
would work."""
|
||||||
with open(filename, 'r', encoding='utf-8') as f:
|
with open(filename, encoding='utf-8') as f:
|
||||||
# Setting encoding explicitly to resolve coding issue on windows
|
# Setting encoding explicitly to resolve coding issue on windows
|
||||||
config_file = f.read()
|
config_file = f.read()
|
||||||
base_var_dict = {}
|
base_var_dict = {}
|
||||||
|
@ -183,7 +183,7 @@ class Config:
|
||||||
check_file_exist(filename)
|
check_file_exist(filename)
|
||||||
fileExtname = osp.splitext(filename)[1]
|
fileExtname = osp.splitext(filename)[1]
|
||||||
if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
|
if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
|
||||||
raise IOError('Only py/yml/yaml/json type are supported now!')
|
raise OSError('Only py/yml/yaml/json type are supported now!')
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_config_dir:
|
with tempfile.TemporaryDirectory() as temp_config_dir:
|
||||||
temp_config_file = tempfile.NamedTemporaryFile(
|
temp_config_file = tempfile.NamedTemporaryFile(
|
||||||
|
@ -236,7 +236,7 @@ class Config:
|
||||||
warnings.warn(warning_msg, DeprecationWarning)
|
warnings.warn(warning_msg, DeprecationWarning)
|
||||||
|
|
||||||
cfg_text = filename + '\n'
|
cfg_text = filename + '\n'
|
||||||
with open(filename, 'r', encoding='utf-8') as f:
|
with open(filename, encoding='utf-8') as f:
|
||||||
# Setting encoding explicitly to resolve coding issue on windows
|
# Setting encoding explicitly to resolve coding issue on windows
|
||||||
cfg_text += f.read()
|
cfg_text += f.read()
|
||||||
|
|
||||||
|
@ -356,7 +356,7 @@ class Config:
|
||||||
:obj:`Config`: Config obj.
|
:obj:`Config`: Config obj.
|
||||||
"""
|
"""
|
||||||
if file_format not in ['.py', '.json', '.yaml', '.yml']:
|
if file_format not in ['.py', '.json', '.yaml', '.yml']:
|
||||||
raise IOError('Only py/yml/yaml/json type are supported now!')
|
raise OSError('Only py/yml/yaml/json type are supported now!')
|
||||||
if file_format != '.py' and 'dict(' in cfg_str:
|
if file_format != '.py' and 'dict(' in cfg_str:
|
||||||
# check if users specify a wrong suffix for python
|
# check if users specify a wrong suffix for python
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -396,16 +396,16 @@ class Config:
|
||||||
if isinstance(filename, Path):
|
if isinstance(filename, Path):
|
||||||
filename = str(filename)
|
filename = str(filename)
|
||||||
|
|
||||||
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
|
super().__setattr__('_cfg_dict', ConfigDict(cfg_dict))
|
||||||
super(Config, self).__setattr__('_filename', filename)
|
super().__setattr__('_filename', filename)
|
||||||
if cfg_text:
|
if cfg_text:
|
||||||
text = cfg_text
|
text = cfg_text
|
||||||
elif filename:
|
elif filename:
|
||||||
with open(filename, 'r') as f:
|
with open(filename) as f:
|
||||||
text = f.read()
|
text = f.read()
|
||||||
else:
|
else:
|
||||||
text = ''
|
text = ''
|
||||||
super(Config, self).__setattr__('_text', text)
|
super().__setattr__('_text', text)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def filename(self):
|
def filename(self):
|
||||||
|
@ -556,9 +556,9 @@ class Config:
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
_cfg_dict, _filename, _text = state
|
_cfg_dict, _filename, _text = state
|
||||||
super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
|
super().__setattr__('_cfg_dict', _cfg_dict)
|
||||||
super(Config, self).__setattr__('_filename', _filename)
|
super().__setattr__('_filename', _filename)
|
||||||
super(Config, self).__setattr__('_text', _text)
|
super().__setattr__('_text', _text)
|
||||||
|
|
||||||
def dump(self, file=None):
|
def dump(self, file=None):
|
||||||
"""Dumps config into a file or returns a string representation of the
|
"""Dumps config into a file or returns a string representation of the
|
||||||
|
@ -584,7 +584,7 @@ class Config:
|
||||||
will be dumped. Defaults to None.
|
will be dumped. Defaults to None.
|
||||||
"""
|
"""
|
||||||
import mmcv
|
import mmcv
|
||||||
cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
|
cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
|
||||||
if file is None:
|
if file is None:
|
||||||
if self.filename is None or self.filename.endswith('.py'):
|
if self.filename is None or self.filename.endswith('.py'):
|
||||||
return self.pretty_text
|
return self.pretty_text
|
||||||
|
@ -638,8 +638,8 @@ class Config:
|
||||||
subkey = key_list[-1]
|
subkey = key_list[-1]
|
||||||
d[subkey] = v
|
d[subkey] = v
|
||||||
|
|
||||||
cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
|
cfg_dict = super().__getattribute__('_cfg_dict')
|
||||||
super(Config, self).__setattr__(
|
super().__setattr__(
|
||||||
'_cfg_dict',
|
'_cfg_dict',
|
||||||
Config._merge_a_into_b(
|
Config._merge_a_into_b(
|
||||||
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
|
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
|
||||||
|
|
|
@ -6,7 +6,7 @@ class TimerError(Exception):
|
||||||
|
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
self.message = message
|
self.message = message
|
||||||
super(TimerError, self).__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
|
|
|
@ -40,10 +40,10 @@ def flowread(flow_or_path: Union[np.ndarray, str],
|
||||||
try:
|
try:
|
||||||
header = f.read(4).decode('utf-8')
|
header = f.read(4).decode('utf-8')
|
||||||
except Exception:
|
except Exception:
|
||||||
raise IOError(f'Invalid flow file: {flow_or_path}')
|
raise OSError(f'Invalid flow file: {flow_or_path}')
|
||||||
else:
|
else:
|
||||||
if header != 'PIEH':
|
if header != 'PIEH':
|
||||||
raise IOError(f'Invalid flow file: {flow_or_path}, '
|
raise OSError(f'Invalid flow file: {flow_or_path}, '
|
||||||
'header does not contain PIEH')
|
'header does not contain PIEH')
|
||||||
|
|
||||||
w = np.fromfile(f, np.int32, 1).squeeze()
|
w = np.fromfile(f, np.int32, 1).squeeze()
|
||||||
|
@ -53,7 +53,7 @@ def flowread(flow_or_path: Union[np.ndarray, str],
|
||||||
assert concat_axis in [0, 1]
|
assert concat_axis in [0, 1]
|
||||||
cat_flow = imread(flow_or_path, flag='unchanged')
|
cat_flow = imread(flow_or_path, flag='unchanged')
|
||||||
if cat_flow.ndim != 2:
|
if cat_flow.ndim != 2:
|
||||||
raise IOError(
|
raise OSError(
|
||||||
f'{flow_or_path} is not a valid quantized flow file, '
|
f'{flow_or_path} is not a valid quantized flow file, '
|
||||||
f'its dimension is {cat_flow.ndim}.')
|
f'its dimension is {cat_flow.ndim}.')
|
||||||
assert cat_flow.shape[concat_axis] % 2 == 0
|
assert cat_flow.shape[concat_axis] % 2 == 0
|
||||||
|
@ -86,7 +86,7 @@ def flowwrite(flow: np.ndarray,
|
||||||
"""
|
"""
|
||||||
if not quantize:
|
if not quantize:
|
||||||
with open(filename, 'wb') as f:
|
with open(filename, 'wb') as f:
|
||||||
f.write('PIEH'.encode('utf-8'))
|
f.write(b'PIEH')
|
||||||
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
|
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
|
||||||
flow = flow.astype(np.float32)
|
flow = flow.astype(np.float32)
|
||||||
flow.tofile(f)
|
flow.tofile(f)
|
||||||
|
@ -146,7 +146,7 @@ def dequantize_flow(dx: np.ndarray,
|
||||||
assert dx.shape == dy.shape
|
assert dx.shape == dy.shape
|
||||||
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
|
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
|
||||||
|
|
||||||
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
|
dx, dy = (dequantize(d, -max_val, max_val, 255) for d in [dx, dy])
|
||||||
|
|
||||||
if denorm:
|
if denorm:
|
||||||
dx *= dx.shape[1]
|
dx *= dx.shape[1]
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from __future__ import division
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
7
setup.py
7
setup.py
|
@ -39,7 +39,7 @@ def choose_requirement(primary, secondary):
|
||||||
|
|
||||||
def get_version():
|
def get_version():
|
||||||
version_file = 'mmcv/version.py'
|
version_file = 'mmcv/version.py'
|
||||||
with open(version_file, 'r', encoding='utf-8') as f:
|
with open(version_file, encoding='utf-8') as f:
|
||||||
exec(compile(f.read(), version_file, 'exec'))
|
exec(compile(f.read(), version_file, 'exec'))
|
||||||
return locals()['__version__']
|
return locals()['__version__']
|
||||||
|
|
||||||
|
@ -94,12 +94,11 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True):
|
||||||
yield info
|
yield info
|
||||||
|
|
||||||
def parse_require_file(fpath):
|
def parse_require_file(fpath):
|
||||||
with open(fpath, 'r') as f:
|
with open(fpath) as f:
|
||||||
for line in f.readlines():
|
for line in f.readlines():
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line and not line.startswith('#'):
|
if line and not line.startswith('#'):
|
||||||
for info in parse_line(line):
|
yield from parse_line(line)
|
||||||
yield info
|
|
||||||
|
|
||||||
def gen_packages_items():
|
def gen_packages_items():
|
||||||
if exists(require_fpath):
|
if exists(require_fpath):
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from __future__ import division
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
|
@ -23,7 +23,7 @@ class ExampleConv(nn.Module):
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
norm_cfg=None):
|
norm_cfg=None):
|
||||||
super(ExampleConv, self).__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
|
|
|
@ -202,21 +202,22 @@ class TestFileClient:
|
||||||
# test `list_dir_or_file`
|
# test `list_dir_or_file`
|
||||||
with build_temporary_directory() as tmp_dir:
|
with build_temporary_directory() as tmp_dir:
|
||||||
# 1. list directories and files
|
# 1. list directories and files
|
||||||
assert set(disk_backend.list_dir_or_file(tmp_dir)) == set(
|
assert set(disk_backend.list_dir_or_file(tmp_dir)) == {
|
||||||
['dir1', 'dir2', 'text1.txt', 'text2.txt'])
|
'dir1', 'dir2', 'text1.txt', 'text2.txt'
|
||||||
|
}
|
||||||
# 2. list directories and files recursively
|
# 2. list directories and files recursively
|
||||||
assert set(disk_backend.list_dir_or_file(
|
assert set(disk_backend.list_dir_or_file(
|
||||||
tmp_dir, recursive=True)) == set([
|
tmp_dir, recursive=True)) == {
|
||||||
'dir1',
|
'dir1',
|
||||||
osp.join('dir1', 'text3.txt'), 'dir2',
|
osp.join('dir1', 'text3.txt'), 'dir2',
|
||||||
osp.join('dir2', 'dir3'),
|
osp.join('dir2', 'dir3'),
|
||||||
osp.join('dir2', 'dir3', 'text4.txt'),
|
osp.join('dir2', 'dir3', 'text4.txt'),
|
||||||
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
||||||
])
|
}
|
||||||
# 3. only list directories
|
# 3. only list directories
|
||||||
assert set(
|
assert set(
|
||||||
disk_backend.list_dir_or_file(
|
disk_backend.list_dir_or_file(
|
||||||
tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
|
tmp_dir, list_file=False)) == {'dir1', 'dir2'}
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError,
|
TypeError,
|
||||||
match='`suffix` should be None when `list_dir` is True'):
|
match='`suffix` should be None when `list_dir` is True'):
|
||||||
|
@ -227,30 +228,30 @@ class TestFileClient:
|
||||||
# 4. only list directories recursively
|
# 4. only list directories recursively
|
||||||
assert set(
|
assert set(
|
||||||
disk_backend.list_dir_or_file(
|
disk_backend.list_dir_or_file(
|
||||||
tmp_dir, list_file=False, recursive=True)) == set(
|
tmp_dir, list_file=False, recursive=True)) == {
|
||||||
['dir1', 'dir2',
|
'dir1', 'dir2',
|
||||||
osp.join('dir2', 'dir3')])
|
osp.join('dir2', 'dir3')
|
||||||
|
}
|
||||||
# 5. only list files
|
# 5. only list files
|
||||||
assert set(disk_backend.list_dir_or_file(
|
assert set(disk_backend.list_dir_or_file(
|
||||||
tmp_dir, list_dir=False)) == set(['text1.txt', 'text2.txt'])
|
tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
|
||||||
# 6. only list files recursively
|
# 6. only list files recursively
|
||||||
assert set(
|
assert set(
|
||||||
disk_backend.list_dir_or_file(
|
disk_backend.list_dir_or_file(
|
||||||
tmp_dir, list_dir=False, recursive=True)) == set([
|
tmp_dir, list_dir=False, recursive=True)) == {
|
||||||
osp.join('dir1', 'text3.txt'),
|
osp.join('dir1', 'text3.txt'),
|
||||||
osp.join('dir2', 'dir3', 'text4.txt'),
|
osp.join('dir2', 'dir3', 'text4.txt'),
|
||||||
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
||||||
])
|
}
|
||||||
# 7. only list files ending with suffix
|
# 7. only list files ending with suffix
|
||||||
assert set(
|
assert set(
|
||||||
disk_backend.list_dir_or_file(
|
disk_backend.list_dir_or_file(
|
||||||
tmp_dir, list_dir=False,
|
tmp_dir, list_dir=False,
|
||||||
suffix='.txt')) == set(['text1.txt', 'text2.txt'])
|
suffix='.txt')) == {'text1.txt', 'text2.txt'}
|
||||||
assert set(
|
assert set(
|
||||||
disk_backend.list_dir_or_file(
|
disk_backend.list_dir_or_file(
|
||||||
tmp_dir, list_dir=False,
|
tmp_dir, list_dir=False,
|
||||||
suffix=('.txt',
|
suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
|
||||||
'.jpg'))) == set(['text1.txt', 'text2.txt'])
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError,
|
TypeError,
|
||||||
match='`suffix` must be a string or tuple of strings'):
|
match='`suffix` must be a string or tuple of strings'):
|
||||||
|
@ -260,22 +261,22 @@ class TestFileClient:
|
||||||
assert set(
|
assert set(
|
||||||
disk_backend.list_dir_or_file(
|
disk_backend.list_dir_or_file(
|
||||||
tmp_dir, list_dir=False, suffix='.txt',
|
tmp_dir, list_dir=False, suffix='.txt',
|
||||||
recursive=True)) == set([
|
recursive=True)) == {
|
||||||
osp.join('dir1', 'text3.txt'),
|
osp.join('dir1', 'text3.txt'),
|
||||||
osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
|
osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
|
||||||
'text2.txt'
|
'text2.txt'
|
||||||
])
|
}
|
||||||
# 7. only list files ending with suffix
|
# 7. only list files ending with suffix
|
||||||
assert set(
|
assert set(
|
||||||
disk_backend.list_dir_or_file(
|
disk_backend.list_dir_or_file(
|
||||||
tmp_dir,
|
tmp_dir,
|
||||||
list_dir=False,
|
list_dir=False,
|
||||||
suffix=('.txt', '.jpg'),
|
suffix=('.txt', '.jpg'),
|
||||||
recursive=True)) == set([
|
recursive=True)) == {
|
||||||
osp.join('dir1', 'text3.txt'),
|
osp.join('dir1', 'text3.txt'),
|
||||||
osp.join('dir2', 'dir3', 'text4.txt'),
|
osp.join('dir2', 'dir3', 'text4.txt'),
|
||||||
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
||||||
])
|
}
|
||||||
|
|
||||||
@patch('ceph.S3Client', MockS3Client)
|
@patch('ceph.S3Client', MockS3Client)
|
||||||
def test_ceph_backend(self):
|
def test_ceph_backend(self):
|
||||||
|
@ -463,21 +464,21 @@ class TestFileClient:
|
||||||
|
|
||||||
with build_temporary_directory() as tmp_dir:
|
with build_temporary_directory() as tmp_dir:
|
||||||
# 1. list directories and files
|
# 1. list directories and files
|
||||||
assert set(petrel_backend.list_dir_or_file(tmp_dir)) == set(
|
assert set(petrel_backend.list_dir_or_file(tmp_dir)) == {
|
||||||
['dir1', 'dir2', 'text1.txt', 'text2.txt'])
|
'dir1', 'dir2', 'text1.txt', 'text2.txt'
|
||||||
|
}
|
||||||
# 2. list directories and files recursively
|
# 2. list directories and files recursively
|
||||||
assert set(
|
assert set(
|
||||||
petrel_backend.list_dir_or_file(
|
petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == {
|
||||||
tmp_dir, recursive=True)) == set([
|
'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join(
|
||||||
'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2',
|
('dir2', 'dir3')), '/'.join(
|
||||||
'/'.join(('dir2', 'dir3')), '/'.join(
|
|
||||||
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
||||||
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
||||||
])
|
}
|
||||||
# 3. only list directories
|
# 3. only list directories
|
||||||
assert set(
|
assert set(
|
||||||
petrel_backend.list_dir_or_file(
|
petrel_backend.list_dir_or_file(
|
||||||
tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
|
tmp_dir, list_file=False)) == {'dir1', 'dir2'}
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError,
|
TypeError,
|
||||||
match=('`list_dir` should be False when `suffix` is not '
|
match=('`list_dir` should be False when `suffix` is not '
|
||||||
|
@ -489,31 +490,30 @@ class TestFileClient:
|
||||||
# 4. only list directories recursively
|
# 4. only list directories recursively
|
||||||
assert set(
|
assert set(
|
||||||
petrel_backend.list_dir_or_file(
|
petrel_backend.list_dir_or_file(
|
||||||
tmp_dir, list_file=False, recursive=True)) == set(
|
tmp_dir, list_file=False, recursive=True)) == {
|
||||||
['dir1', 'dir2', '/'.join(('dir2', 'dir3'))])
|
'dir1', 'dir2', '/'.join(('dir2', 'dir3'))
|
||||||
|
}
|
||||||
# 5. only list files
|
# 5. only list files
|
||||||
assert set(
|
assert set(
|
||||||
petrel_backend.list_dir_or_file(tmp_dir,
|
petrel_backend.list_dir_or_file(
|
||||||
list_dir=False)) == set(
|
tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
|
||||||
['text1.txt', 'text2.txt'])
|
|
||||||
# 6. only list files recursively
|
# 6. only list files recursively
|
||||||
assert set(
|
assert set(
|
||||||
petrel_backend.list_dir_or_file(
|
petrel_backend.list_dir_or_file(
|
||||||
tmp_dir, list_dir=False, recursive=True)) == set([
|
tmp_dir, list_dir=False, recursive=True)) == {
|
||||||
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
||||||
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
||||||
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
||||||
])
|
}
|
||||||
# 7. only list files ending with suffix
|
# 7. only list files ending with suffix
|
||||||
assert set(
|
assert set(
|
||||||
petrel_backend.list_dir_or_file(
|
petrel_backend.list_dir_or_file(
|
||||||
tmp_dir, list_dir=False,
|
tmp_dir, list_dir=False,
|
||||||
suffix='.txt')) == set(['text1.txt', 'text2.txt'])
|
suffix='.txt')) == {'text1.txt', 'text2.txt'}
|
||||||
assert set(
|
assert set(
|
||||||
petrel_backend.list_dir_or_file(
|
petrel_backend.list_dir_or_file(
|
||||||
tmp_dir, list_dir=False,
|
tmp_dir, list_dir=False,
|
||||||
suffix=('.txt',
|
suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
|
||||||
'.jpg'))) == set(['text1.txt', 'text2.txt'])
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError,
|
TypeError,
|
||||||
match='`suffix` must be a string or tuple of strings'):
|
match='`suffix` must be a string or tuple of strings'):
|
||||||
|
@ -523,22 +523,22 @@ class TestFileClient:
|
||||||
assert set(
|
assert set(
|
||||||
petrel_backend.list_dir_or_file(
|
petrel_backend.list_dir_or_file(
|
||||||
tmp_dir, list_dir=False, suffix='.txt',
|
tmp_dir, list_dir=False, suffix='.txt',
|
||||||
recursive=True)) == set([
|
recursive=True)) == {
|
||||||
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
||||||
('dir2', 'dir3', 'text4.txt')), 'text1.txt',
|
('dir2', 'dir3', 'text4.txt')), 'text1.txt',
|
||||||
'text2.txt'
|
'text2.txt'
|
||||||
])
|
}
|
||||||
# 7. only list files ending with suffix
|
# 7. only list files ending with suffix
|
||||||
assert set(
|
assert set(
|
||||||
petrel_backend.list_dir_or_file(
|
petrel_backend.list_dir_or_file(
|
||||||
tmp_dir,
|
tmp_dir,
|
||||||
list_dir=False,
|
list_dir=False,
|
||||||
suffix=('.txt', '.jpg'),
|
suffix=('.txt', '.jpg'),
|
||||||
recursive=True)) == set([
|
recursive=True)) == {
|
||||||
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
||||||
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
||||||
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
||||||
])
|
}
|
||||||
|
|
||||||
@patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
|
@patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
|
||||||
@patch('mc.pyvector', MagicMock)
|
@patch('mc.pyvector', MagicMock)
|
||||||
|
|
|
@ -128,7 +128,7 @@ def test_register_handler():
|
||||||
assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
|
assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
|
||||||
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
|
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
|
||||||
mmcv.dump(content, tmp_filename)
|
mmcv.dump(content, tmp_filename)
|
||||||
with open(tmp_filename, 'r') as f:
|
with open(tmp_filename) as f:
|
||||||
written = f.read()
|
written = f.read()
|
||||||
os.remove(tmp_filename)
|
os.remove(tmp_filename)
|
||||||
assert written == '\n' + content
|
assert written == '\n' + content
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
class TestBBox(object):
|
class TestBBox:
|
||||||
|
|
||||||
def _test_bbox_overlaps(self, device='cpu', dtype=torch.float):
|
def _test_bbox_overlaps(self, device='cpu', dtype=torch.float):
|
||||||
from mmcv.ops import bbox_overlaps
|
from mmcv.ops import bbox_overlaps
|
||||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class TestBilinearGridSample(object):
|
class TestBilinearGridSample:
|
||||||
|
|
||||||
def _test_bilinear_grid_sample(self,
|
def _test_bilinear_grid_sample(self,
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
|
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class TestBoxIoURotated(object):
|
class TestBoxIoURotated:
|
||||||
|
|
||||||
def test_box_iou_rotated_cpu(self):
|
def test_box_iou_rotated_cpu(self):
|
||||||
from mmcv.ops import box_iou_rotated
|
from mmcv.ops import box_iou_rotated
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
from torch.autograd import gradcheck
|
from torch.autograd import gradcheck
|
||||||
|
|
||||||
|
|
||||||
class TestCarafe(object):
|
class TestCarafe:
|
||||||
|
|
||||||
def test_carafe_naive_gradcheck(self):
|
def test_carafe_naive_gradcheck(self):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
|
|
|
@ -15,7 +15,7 @@ class Loss(nn.Module):
|
||||||
return torch.mean(input - target)
|
return torch.mean(input - target)
|
||||||
|
|
||||||
|
|
||||||
class TestCrissCrossAttention(object):
|
class TestCrissCrossAttention:
|
||||||
|
|
||||||
def test_cc_attention(self):
|
def test_cc_attention(self):
|
||||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
|
@ -35,7 +35,7 @@ gt_offset_bias_grad = [1.44, -0.72, 0., 0., -0.10, -0.08, -0.54, -0.54],
|
||||||
gt_deform_weight_grad = [[[[3.62, 0.], [0.40, 0.18]]]]
|
gt_deform_weight_grad = [[[[3.62, 0.], [0.40, 0.18]]]]
|
||||||
|
|
||||||
|
|
||||||
class TestDeformconv(object):
|
class TestDeformconv:
|
||||||
|
|
||||||
def _test_deformconv(self,
|
def _test_deformconv(self,
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
|
|
|
@ -35,7 +35,7 @@ outputs = [([[[[1, 1.25], [1.5, 1.75]]]], [[[[3.0625, 0.4375],
|
||||||
0.00390625]]]])]
|
0.00390625]]]])]
|
||||||
|
|
||||||
|
|
||||||
class TestDeformRoIPool(object):
|
class TestDeformRoIPool:
|
||||||
|
|
||||||
def test_deform_roi_pool_gradcheck(self):
|
def test_deform_roi_pool_gradcheck(self):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
|
|
|
@ -37,7 +37,7 @@ sigmoid_outputs = [(0.13562961, [[-0.00657264, 0.11185755],
|
||||||
[-0.02462499, 0.08277918, 0.18050370]])]
|
[-0.02462499, 0.08277918, 0.18050370]])]
|
||||||
|
|
||||||
|
|
||||||
class Testfocalloss(object):
|
class Testfocalloss:
|
||||||
|
|
||||||
def _test_softmax(self, dtype=torch.float):
|
def _test_softmax(self, dtype=torch.float):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
|
|
|
@ -10,7 +10,7 @@ except ImportError:
|
||||||
_USING_PARROTS = False
|
_USING_PARROTS = False
|
||||||
|
|
||||||
|
|
||||||
class TestFusedBiasLeakyReLU(object):
|
class TestFusedBiasLeakyReLU:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class TestInfo(object):
|
class TestInfo:
|
||||||
|
|
||||||
def test_info(self):
|
def test_info(self):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class TestMaskedConv2d(object):
|
class TestMaskedConv2d:
|
||||||
|
|
||||||
def test_masked_conv2d(self):
|
def test_masked_conv2d(self):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
|
|
|
@ -37,7 +37,7 @@ dcn_offset_b_grad = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestMdconv(object):
|
class TestMdconv:
|
||||||
|
|
||||||
def _test_mdconv(self, dtype=torch.float, device='cuda'):
|
def _test_mdconv(self, dtype=torch.float, device='cuda'):
|
||||||
if not torch.cuda.is_available() and device == 'cuda':
|
if not torch.cuda.is_available() and device == 'cuda':
|
||||||
|
|
|
@ -55,7 +55,7 @@ def test_forward_multi_scale_deformable_attn_pytorch():
|
||||||
N, M, D = 1, 2, 2
|
N, M, D = 1, 2, 2
|
||||||
Lq, L, P = 2, 2, 2
|
Lq, L, P = 2, 2, 2
|
||||||
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
|
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
|
||||||
S = sum([(H * W).item() for H, W in shapes])
|
S = sum((H * W).item() for H, W in shapes)
|
||||||
|
|
||||||
torch.manual_seed(3)
|
torch.manual_seed(3)
|
||||||
value = torch.rand(N, S, M, D) * 0.01
|
value = torch.rand(N, S, M, D) * 0.01
|
||||||
|
@ -78,7 +78,7 @@ def test_forward_equal_with_pytorch_double():
|
||||||
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
||||||
level_start_index = torch.cat((shapes.new_zeros(
|
level_start_index = torch.cat((shapes.new_zeros(
|
||||||
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
||||||
S = sum([(H * W).item() for H, W in shapes])
|
S = sum((H * W).item() for H, W in shapes)
|
||||||
|
|
||||||
torch.manual_seed(3)
|
torch.manual_seed(3)
|
||||||
value = torch.rand(N, S, M, D).cuda() * 0.01
|
value = torch.rand(N, S, M, D).cuda() * 0.01
|
||||||
|
@ -111,7 +111,7 @@ def test_forward_equal_with_pytorch_float():
|
||||||
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
||||||
level_start_index = torch.cat((shapes.new_zeros(
|
level_start_index = torch.cat((shapes.new_zeros(
|
||||||
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
||||||
S = sum([(H * W).item() for H, W in shapes])
|
S = sum((H * W).item() for H, W in shapes)
|
||||||
|
|
||||||
torch.manual_seed(3)
|
torch.manual_seed(3)
|
||||||
value = torch.rand(N, S, M, D).cuda() * 0.01
|
value = torch.rand(N, S, M, D).cuda() * 0.01
|
||||||
|
@ -155,7 +155,7 @@ def test_gradient_numerical(channels,
|
||||||
shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
|
shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
|
||||||
level_start_index = torch.cat((shapes.new_zeros(
|
level_start_index = torch.cat((shapes.new_zeros(
|
||||||
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
||||||
S = sum([(H * W).item() for H, W in shapes])
|
S = sum((H * W).item() for H, W in shapes)
|
||||||
|
|
||||||
value = torch.rand(N, S, M, channels).cuda() * 0.01
|
value = torch.rand(N, S, M, channels).cuda() * 0.01
|
||||||
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
class Testnms(object):
|
class Testnms:
|
||||||
|
|
||||||
@pytest.mark.parametrize('device', [
|
@pytest.mark.parametrize('device', [
|
||||||
pytest.param(
|
pytest.param(
|
||||||
|
@ -129,8 +129,7 @@ class Testnms(object):
|
||||||
scores = tensor_dets[:, 4]
|
scores = tensor_dets[:, 4]
|
||||||
nms_keep_inds = nms(boxes.contiguous(), scores.contiguous(),
|
nms_keep_inds = nms(boxes.contiguous(), scores.contiguous(),
|
||||||
iou_thr)[1]
|
iou_thr)[1]
|
||||||
assert set([g[0].item()
|
assert {g[0].item() for g in np_groups} == set(nms_keep_inds.tolist())
|
||||||
for g in np_groups]) == set(nms_keep_inds.tolist())
|
|
||||||
|
|
||||||
# non empty tensor input
|
# non empty tensor input
|
||||||
tensor_dets = torch.from_numpy(np_dets)
|
tensor_dets = torch.from_numpy(np_dets)
|
||||||
|
|
|
@ -33,7 +33,7 @@ def run_before_and_after_test():
|
||||||
class WrapFunction(nn.Module):
|
class WrapFunction(nn.Module):
|
||||||
|
|
||||||
def __init__(self, wrapped_function):
|
def __init__(self, wrapped_function):
|
||||||
super(WrapFunction, self).__init__()
|
super().__init__()
|
||||||
self.wrapped_function = wrapped_function
|
self.wrapped_function = wrapped_function
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
@ -662,7 +662,7 @@ def test_cummax_cummin(key, opset=11):
|
||||||
input_list = [
|
input_list = [
|
||||||
# arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
|
# arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
|
||||||
torch.rand((2, 3, 4, 1, 5)),
|
torch.rand((2, 3, 4, 1, 5)),
|
||||||
torch.rand((1)),
|
torch.rand(1),
|
||||||
torch.rand((2, 0, 1)), # tensor.numel() is 0
|
torch.rand((2, 0, 1)), # tensor.numel() is 0
|
||||||
torch.FloatTensor(), # empty tensor
|
torch.FloatTensor(), # empty tensor
|
||||||
]
|
]
|
||||||
|
|
|
@ -15,7 +15,7 @@ class Loss(nn.Module):
|
||||||
return torch.mean(input - target)
|
return torch.mean(input - target)
|
||||||
|
|
||||||
|
|
||||||
class TestPSAMask(object):
|
class TestPSAMask:
|
||||||
|
|
||||||
def test_psa_mask_collect(self):
|
def test_psa_mask_collect(self):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
|
|
|
@ -29,7 +29,7 @@ outputs = [([[[[1., 2.], [3., 4.]]]], [[[[1., 1.], [1., 1.]]]]),
|
||||||
1.]]]])]
|
1.]]]])]
|
||||||
|
|
||||||
|
|
||||||
class TestRoiPool(object):
|
class TestRoiPool:
|
||||||
|
|
||||||
def test_roipool_gradcheck(self):
|
def test_roipool_gradcheck(self):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
|
|
|
@ -14,7 +14,7 @@ else:
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
class TestSyncBN(object):
|
class TestSyncBN:
|
||||||
|
|
||||||
def dist_init(self):
|
def dist_init(self):
|
||||||
rank = int(os.environ['SLURM_PROCID'])
|
rank = int(os.environ['SLURM_PROCID'])
|
||||||
|
|
|
@ -30,7 +30,7 @@ if not is_tensorrt_plugin_loaded():
|
||||||
class WrapFunction(nn.Module):
|
class WrapFunction(nn.Module):
|
||||||
|
|
||||||
def __init__(self, wrapped_function):
|
def __init__(self, wrapped_function):
|
||||||
super(WrapFunction, self).__init__()
|
super().__init__()
|
||||||
self.wrapped_function = wrapped_function
|
self.wrapped_function = wrapped_function
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
@ -576,7 +576,7 @@ def test_cummin_cummax(func: Callable):
|
||||||
input_list = [
|
input_list = [
|
||||||
# arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
|
# arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
|
||||||
torch.rand((2, 3, 4, 1, 5)).cuda(),
|
torch.rand((2, 3, 4, 1, 5)).cuda(),
|
||||||
torch.rand((1)).cuda()
|
torch.rand(1).cuda()
|
||||||
]
|
]
|
||||||
|
|
||||||
input_names = ['input']
|
input_names = ['input']
|
||||||
|
@ -756,7 +756,7 @@ def test_corner_pool(mode):
|
||||||
class CornerPoolWrapper(CornerPool):
|
class CornerPoolWrapper(CornerPool):
|
||||||
|
|
||||||
def __init__(self, mode):
|
def __init__(self, mode):
|
||||||
super(CornerPoolWrapper, self).__init__(mode)
|
super().__init__(mode)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# no use `torch.cummax`, instead `corner_pool` is used
|
# no use `torch.cummax`, instead `corner_pool` is used
|
||||||
|
|
|
@ -10,7 +10,7 @@ except ImportError:
|
||||||
_USING_PARROTS = False
|
_USING_PARROTS = False
|
||||||
|
|
||||||
|
|
||||||
class TestUpFirDn2d(object):
|
class TestUpFirDn2d:
|
||||||
"""Unit test for UpFirDn2d.
|
"""Unit test for UpFirDn2d.
|
||||||
|
|
||||||
Here, we just test the basic case of upsample version. More gerneal tests
|
Here, we just test the basic case of upsample version. More gerneal tests
|
||||||
|
|
|
@ -96,8 +96,8 @@ def test_voxelization_nondeterministic():
|
||||||
coors_all = dynamic_voxelization.forward(points)
|
coors_all = dynamic_voxelization.forward(points)
|
||||||
coors_all = coors_all.cpu().detach().numpy().tolist()
|
coors_all = coors_all.cpu().detach().numpy().tolist()
|
||||||
|
|
||||||
coors_set = set([tuple(c) for c in coors])
|
coors_set = {tuple(c) for c in coors}
|
||||||
coors_all_set = set([tuple(c) for c in coors_all])
|
coors_all_set = {tuple(c) for c in coors_all}
|
||||||
|
|
||||||
assert len(coors_set) == len(coors)
|
assert len(coors_set) == len(coors)
|
||||||
assert len(coors_set - coors_all_set) == 0
|
assert len(coors_set - coors_all_set) == 0
|
||||||
|
@ -112,7 +112,7 @@ def test_voxelization_nondeterministic():
|
||||||
|
|
||||||
for c, ps, n in zip(coors, voxels, num_points_per_voxel):
|
for c, ps, n in zip(coors, voxels, num_points_per_voxel):
|
||||||
ideal_voxel_points_set = coors_points_dict[tuple(c)]
|
ideal_voxel_points_set = coors_points_dict[tuple(c)]
|
||||||
voxel_points_set = set([tuple(p) for p in ps[:n]])
|
voxel_points_set = {tuple(p) for p in ps[:n]}
|
||||||
assert len(voxel_points_set) == n
|
assert len(voxel_points_set) == n
|
||||||
if n < max_num_points:
|
if n < max_num_points:
|
||||||
assert voxel_points_set == ideal_voxel_points_set
|
assert voxel_points_set == ideal_voxel_points_set
|
||||||
|
@ -133,7 +133,7 @@ def test_voxelization_nondeterministic():
|
||||||
voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
|
voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
|
||||||
coors = coors.cpu().detach().numpy().tolist()
|
coors = coors.cpu().detach().numpy().tolist()
|
||||||
|
|
||||||
coors_set = set([tuple(c) for c in coors])
|
coors_set = {tuple(c) for c in coors}
|
||||||
coors_all_set = set([tuple(c) for c in coors_all])
|
coors_all_set = {tuple(c) for c in coors_all}
|
||||||
|
|
||||||
assert len(coors_set) == len(coors) == len(coors_all_set)
|
assert len(coors_set) == len(coors) == len(coors_all_set)
|
||||||
|
|
|
@ -63,7 +63,7 @@ def test_is_module_wrapper():
|
||||||
|
|
||||||
# test module wrapper registry
|
# test module wrapper registry
|
||||||
@MODULE_WRAPPERS.register_module()
|
@MODULE_WRAPPERS.register_module()
|
||||||
class ModuleWrapper(object):
|
class ModuleWrapper:
|
||||||
|
|
||||||
def __init__(self, module):
|
def __init__(self, module):
|
||||||
self.module = module
|
self.module = module
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue