[Refactor] Add pyupgrade pre-commit hook (#2078)

* add pyupgrade hook

* run pyupgrade precommit hook
This commit is contained in:
谢昕辰 2022-09-19 14:06:29 +08:00 committed by GitHub
parent eef38883c8
commit 230246f557
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
101 changed files with 242 additions and 251 deletions

View File

@ -56,8 +56,7 @@ def main():
for model_name, yml_path in yml_list: for model_name, yml_path in yml_list:
# Default yaml loader unsafe. # Default yaml loader unsafe.
model_infos = yml.load( model_infos = yml.load(open(yml_path), Loader=yml.CLoader)['Models']
open(yml_path, 'r'), Loader=yml.CLoader)['Models']
for model_info in model_infos: for model_info in model_infos:
config_name = model_info['Name'] config_name = model_info['Name']
checkpoint_url = model_info['Weights'] checkpoint_url = model_info['Weights']

View File

@ -35,7 +35,7 @@ def process_checkpoint(in_file, out_file):
# The hash code calculation and rename command differ on different system # The hash code calculation and rename command differ on different system
# platform. # platform.
sha = calculate_file_sha256(out_file) sha = calculate_file_sha256(out_file)
final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8]) final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
os.rename(out_file, final_file) os.rename(out_file, final_file)
# Remove prefix and suffix # Remove prefix and suffix
@ -54,7 +54,7 @@ def get_final_iter(config):
def get_final_results(log_json_path, iter_num): def get_final_results(log_json_path, iter_num):
result_dict = dict() result_dict = dict()
last_iter = 0 last_iter = 0
with open(log_json_path, 'r') as f: with open(log_json_path) as f:
for line in f.readlines(): for line in f.readlines():
log_line = json.loads(line) log_line = json.loads(line)
if 'mode' not in log_line.keys(): if 'mode' not in log_line.keys():
@ -125,7 +125,7 @@ def main():
exp_dir = osp.join(work_dir, config_name) exp_dir = osp.join(work_dir, config_name)
# check whether the exps is finished # check whether the exps is finished
final_iter = get_final_iter(used_config) final_iter = get_final_iter(used_config)
final_model = 'iter_{}.pth'.format(final_iter) final_model = f'iter_{final_iter}.pth'
model_path = osp.join(exp_dir, final_model) model_path = osp.join(exp_dir, final_model)
# skip if the model is still training # skip if the model is still training

View File

@ -74,7 +74,7 @@ def main():
commands.append('\n') commands.append('\n')
commands.append('\n') commands.append('\n')
with open(args.txt_path, 'r') as f: with open(args.txt_path) as f:
model_cfgs = f.readlines() model_cfgs = f.readlines()
for i, cfg in enumerate(model_cfgs): for i, cfg in enumerate(model_cfgs):
create_train_bash_info(commands, cfg, script_name, '$PARTITION', create_train_bash_info(commands, cfg, script_name, '$PARTITION',

View File

@ -86,7 +86,7 @@ def main():
val_list = [] val_list = []
last_iter = 0 last_iter = 0
for log_name in log_list: for log_name in log_list:
with open(os.path.join(preceding_path, log_name), 'r') as f: with open(os.path.join(preceding_path, log_name)) as f:
# ignore the info line # ignore the info line
f.readline() f.readline()
all_lines = f.readlines() all_lines = f.readlines()

View File

@ -15,7 +15,7 @@ import sys
from lxml import etree from lxml import etree
from mmengine.fileio import dump from mmengine.fileio import dump
MMSEG_ROOT = osp.dirname(osp.dirname((osp.dirname(__file__)))) MMSEG_ROOT = osp.dirname(osp.dirname(osp.dirname(__file__)))
COLLECTIONS = [ COLLECTIONS = [
'ANN', 'APCNet', 'BiSeNetV1', 'BiSeNetV2', 'CCNet', 'CGNet', 'DANet', 'ANN', 'APCNet', 'BiSeNetV1', 'BiSeNetV2', 'CCNet', 'CGNet', 'DANet',
@ -42,7 +42,7 @@ def dump_yaml_and_check_difference(obj, filename, sort_keys=False):
str_dump = dump(obj, None, file_format='yaml', sort_keys=sort_keys) str_dump = dump(obj, None, file_format='yaml', sort_keys=sort_keys)
if osp.isfile(filename): if osp.isfile(filename):
file_exists = True file_exists = True
with open(filename, 'r', encoding='utf-8') as f: with open(filename, encoding='utf-8') as f:
str_orig = f.read() str_orig = f.read()
else: else:
file_exists = False file_exists = False
@ -97,7 +97,7 @@ def parse_md(md_file):
# should be set with head or neck of this config file. # should be set with head or neck of this config file.
is_backbone = None is_backbone = None
with open(md_file, 'r', encoding='UTF-8') as md: with open(md_file, encoding='UTF-8') as md:
lines = md.readlines() lines = md.readlines()
i = 0 i = 0
current_dataset = '' current_dataset = ''

View File

@ -52,6 +52,11 @@ repos:
language: python language: python
files: ^configs/.*\.md$ files: ^configs/.*\.md$
require_serial: true require_serial: true
- repo: https://github.com/asottile/pyupgrade
rev: v2.32.1
hooks:
- id: pyupgrade
args: ["--py36-plus"]
- repo: https://github.com/open-mmlab/pre-commit-hooks - repo: https://github.com/open-mmlab/pre-commit-hooks
rev: v0.2.0 # Use the rev to fix revision rev: v0.2.0 # Use the rev to fix revision
hooks: hooks:

View File

@ -28,7 +28,7 @@ version_file = '../../mmseg/version.py'
def get_version(): def get_version():
with open(version_file, 'r') as f: with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec')) exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__'] return locals()['__version__']

View File

@ -18,13 +18,15 @@ num_ckpts = 0
for f in files: for f in files:
url = osp.dirname(f.replace('../../', url_prefix)) url = osp.dirname(f.replace('../../', url_prefix))
with open(f, 'r') as content_file: with open(f) as content_file:
content = content_file.read() content = content_file.read()
title = content.split('\n')[0].replace('#', '').strip() title = content.split('\n')[0].replace('#', '').strip()
ckpts = set(x.lower().strip() ckpts = {
x.lower().strip()
for x in re.findall(r'https?://download.*\.pth', content) for x in re.findall(r'https?://download.*\.pth', content)
if 'mmsegmentation' in x) if 'mmsegmentation' in x
}
if len(ckpts) == 0: if len(ckpts) == 0:
continue continue
@ -34,7 +36,7 @@ for f in files:
assert len(_papertype) > 0 assert len(_papertype) > 0
papertype = _papertype[0] papertype = _papertype[0]
paper = set([(papertype, title)]) paper = {(papertype, title)}
titles.append(title) titles.append(title)
num_ckpts += len(ckpts) num_ckpts += len(ckpts)

View File

@ -28,7 +28,7 @@ version_file = '../../mmseg/version.py'
def get_version(): def get_version():
with open(version_file, 'r') as f: with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec')) exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__'] return locals()['__version__']

View File

@ -18,13 +18,15 @@ num_ckpts = 0
for f in files: for f in files:
url = osp.dirname(f.replace('../../', url_prefix)) url = osp.dirname(f.replace('../../', url_prefix))
with open(f, 'r') as content_file: with open(f) as content_file:
content = content_file.read() content = content_file.read()
title = content.split('\n')[0].replace('#', '').strip() title = content.split('\n')[0].replace('#', '').strip()
ckpts = set(x.lower().strip() ckpts = {
x.lower().strip()
for x in re.findall(r'https?://download.*\.pth', content) for x in re.findall(r'https?://download.*\.pth', content)
if 'mmsegmentation' in x) if 'mmsegmentation' in x
}
if len(ckpts) == 0: if len(ckpts) == 0:
continue continue
@ -34,7 +36,7 @@ for f in files:
assert len(_papertype) > 0 assert len(_papertype) > 0
papertype = _papertype[0] papertype = _papertype[0]
paper = set([(papertype, title)]) paper = {(papertype, title)}
titles.append(title) titles.append(title)
num_ckpts += len(ckpts) num_ckpts += len(ckpts)

View File

@ -204,5 +204,4 @@ class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
warnings.warn('DeprecationWarning: Layer_decay_rate will ' warnings.warn('DeprecationWarning: Layer_decay_rate will '
'be deleted, please use decay_rate instead.') 'be deleted, please use decay_rate instead.')
paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate') paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
super(LayerDecayOptimizerConstructor, super().__init__(optim_wrapper_cfg, paramwise_cfg)
self).__init__(optim_wrapper_cfg, paramwise_cfg)

View File

@ -212,7 +212,7 @@ class IoUMetric(BaseMetric):
metrics = [metrics] metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice', 'mFscore'] allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metrics).issubset(set(allowed_metrics)): if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError('metrics {} is not supported'.format(metrics)) raise KeyError(f'metrics {metrics} is not supported')
all_acc = total_area_intersect.sum() / total_area_label.sum() all_acc = total_area_intersect.sum() / total_area_label.sum()
ret_metrics = OrderedDict({'aAcc': all_acc}) ret_metrics = OrderedDict({'aAcc': all_acc})

View File

@ -194,7 +194,7 @@ class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
init_values=None): init_values=None):
attn_cfg.update(dict(window_size=window_size, qk_scale=None)) attn_cfg.update(dict(window_size=window_size, qk_scale=None))
super(BEiTTransformerEncoderLayer, self).__init__( super().__init__(
embed_dims=embed_dims, embed_dims=embed_dims,
num_heads=num_heads, num_heads=num_heads,
feedforward_channels=feedforward_channels, feedforward_channels=feedforward_channels,
@ -214,9 +214,9 @@ class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
self.drop_path = build_dropout( self.drop_path = build_dropout(
dropout_layer) if dropout_layer else nn.Identity() dropout_layer) if dropout_layer else nn.Identity()
self.gamma_1 = nn.Parameter( self.gamma_1 = nn.Parameter(
init_values * torch.ones((embed_dims)), requires_grad=True) init_values * torch.ones(embed_dims), requires_grad=True)
self.gamma_2 = nn.Parameter( self.gamma_2 = nn.Parameter(
init_values * torch.ones((embed_dims)), requires_grad=True) init_values * torch.ones(embed_dims), requires_grad=True)
def build_attn(self, attn_cfg): def build_attn(self, attn_cfg):
self.attn = BEiTAttention(**attn_cfg) self.attn = BEiTAttention(**attn_cfg)
@ -287,7 +287,7 @@ class BEiT(BaseModule):
pretrained=None, pretrained=None,
init_values=0.1, init_values=0.1,
init_cfg=None): init_cfg=None):
super(BEiT, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
if isinstance(img_size, int): if isinstance(img_size, int):
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple): elif isinstance(img_size, tuple):
@ -505,7 +505,7 @@ class BEiT(BaseModule):
state_dict = self.resize_rel_pos_embed(checkpoint) state_dict = self.resize_rel_pos_embed(checkpoint)
self.load_state_dict(state_dict, False) self.load_state_dict(state_dict, False)
elif self.init_cfg is not None: elif self.init_cfg is not None:
super(BEiT, self).init_weights() super().init_weights()
else: else:
# We only implement the 'jax_impl' initialization implemented at # We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
@ -551,7 +551,7 @@ class BEiT(BaseModule):
return tuple(outs) return tuple(outs)
def train(self, mode=True): def train(self, mode=True):
super(BEiT, self).train(mode) super().train(mode)
if mode and self.norm_eval: if mode and self.norm_eval:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.LayerNorm): if isinstance(m, nn.LayerNorm):

View File

@ -29,7 +29,7 @@ class SpatialPath(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(SpatialPath, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert len(num_channels) == 4, 'Length of input channels \ assert len(num_channels) == 4, 'Length of input channels \
of Spatial Path must be 4!' of Spatial Path must be 4!'
@ -98,7 +98,7 @@ class AttentionRefinementModule(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(AttentionRefinementModule, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.conv_layer = ConvModule( self.conv_layer = ConvModule(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channel, out_channels=out_channel,
@ -152,7 +152,7 @@ class ContextPath(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(ContextPath, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert len(context_channels) == 3, 'Length of input channels \ assert len(context_channels) == 3, 'Length of input channels \
of Context Path must be 3!' of Context Path must be 3!'
@ -228,7 +228,7 @@ class FeatureFusionModule(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(FeatureFusionModule, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.conv1 = ConvModule( self.conv1 = ConvModule(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
@ -304,7 +304,7 @@ class BiSeNetV1(BaseModule):
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(BiSeNetV1, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert len(spatial_channels) == 4, 'Length of input channels \ assert len(spatial_channels) == 4, 'Length of input channels \
of Spatial Path must be 4!' of Spatial Path must be 4!'

View File

@ -37,7 +37,7 @@ class DetailBranch(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(DetailBranch, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
detail_branch = [] detail_branch = []
for i in range(len(detail_channels)): for i in range(len(detail_channels)):
if i == 0: if i == 0:
@ -126,7 +126,7 @@ class StemBlock(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(StemBlock, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.conv_first = ConvModule( self.conv_first = ConvModule(
in_channels=in_channels, in_channels=in_channels,
@ -207,7 +207,7 @@ class GELayer(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(GELayer, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
mid_channel = in_channels * exp_ratio mid_channel = in_channels * exp_ratio
self.conv1 = ConvModule( self.conv1 = ConvModule(
in_channels=in_channels, in_channels=in_channels,
@ -326,7 +326,7 @@ class CEBlock(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(CEBlock, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.gap = nn.Sequential( self.gap = nn.Sequential(
@ -385,7 +385,7 @@ class SemanticBranch(BaseModule):
in_channels=3, in_channels=3,
exp_ratio=6, exp_ratio=6,
init_cfg=None): init_cfg=None):
super(SemanticBranch, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels self.in_channels = in_channels
self.semantic_channels = semantic_channels self.semantic_channels = semantic_channels
self.semantic_stages = [] self.semantic_stages = []
@ -458,7 +458,7 @@ class BGALayer(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(BGALayer, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.out_channels = out_channels self.out_channels = out_channels
self.align_corners = align_corners self.align_corners = align_corners
self.detail_dwconv = nn.Sequential( self.detail_dwconv = nn.Sequential(
@ -594,7 +594,7 @@ class BiSeNetV2(BaseModule):
dict( dict(
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
] ]
super(BiSeNetV2, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels self.in_channels = in_channels
self.out_indices = out_indices self.out_indices = out_indices
self.detail_channels = detail_channels self.detail_channels = detail_channels

View File

@ -25,7 +25,7 @@ class GlobalContextExtractor(nn.Module):
""" """
def __init__(self, channel, reduction=16, with_cp=False): def __init__(self, channel, reduction=16, with_cp=False):
super(GlobalContextExtractor, self).__init__() super().__init__()
self.channel = channel self.channel = channel
self.reduction = reduction self.reduction = reduction
assert reduction >= 1 and channel >= reduction assert reduction >= 1 and channel >= reduction
@ -87,7 +87,7 @@ class ContextGuidedBlock(nn.Module):
norm_cfg=dict(type='BN', requires_grad=True), norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='PReLU'), act_cfg=dict(type='PReLU'),
with_cp=False): with_cp=False):
super(ContextGuidedBlock, self).__init__() super().__init__()
self.with_cp = with_cp self.with_cp = with_cp
self.downsample = downsample self.downsample = downsample
@ -172,7 +172,7 @@ class InputInjection(nn.Module):
"""Downsampling module for CGNet.""" """Downsampling module for CGNet."""
def __init__(self, num_downsampling): def __init__(self, num_downsampling):
super(InputInjection, self).__init__() super().__init__()
self.pool = nn.ModuleList() self.pool = nn.ModuleList()
for i in range(num_downsampling): for i in range(num_downsampling):
self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
@ -230,7 +230,7 @@ class CGNet(BaseModule):
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(CGNet, self).__init__(init_cfg) super().__init__(init_cfg)
assert not (init_cfg and pretrained), \ assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time' 'init_cfg and pretrained cannot be setting at the same time'
@ -364,7 +364,7 @@ class CGNet(BaseModule):
def train(self, mode=True): def train(self, mode=True):
"""Convert the model into training mode will keeping the normalization """Convert the model into training mode will keeping the normalization
layer freezed.""" layer freezed."""
super(CGNet, self).train(mode) super().train(mode)
if mode and self.norm_eval: if mode and self.norm_eval:
for m in self.modules(): for m in self.modules():
# trick: eval have effect on BatchNorm only # trick: eval have effect on BatchNorm only

View File

@ -35,7 +35,7 @@ class DownsamplerBlock(BaseModule):
norm_cfg=dict(type='BN', eps=1e-3), norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(DownsamplerBlock, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
self.act_cfg = act_cfg self.act_cfg = act_cfg
@ -95,7 +95,7 @@ class NonBottleneck1d(BaseModule):
norm_cfg=dict(type='BN', eps=1e-3), norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(NonBottleneck1d, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
@ -168,7 +168,7 @@ class UpsamplerBlock(BaseModule):
norm_cfg=dict(type='BN', eps=1e-3), norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(UpsamplerBlock, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
self.act_cfg = act_cfg self.act_cfg = act_cfg
@ -242,7 +242,7 @@ class ERFNet(BaseModule):
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(ERFNet, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert len(enc_downsample_channels) \ assert len(enc_downsample_channels) \
== len(dec_upsample_channels)+1, 'Number of downsample\ == len(dec_upsample_channels)+1, 'Number of downsample\
block of encoder does not \ block of encoder does not \

View File

@ -36,7 +36,7 @@ class LearningToDownsample(nn.Module):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
dw_act_cfg=None): dw_act_cfg=None):
super(LearningToDownsample, self).__init__() super().__init__()
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
self.act_cfg = act_cfg self.act_cfg = act_cfg
@ -124,7 +124,7 @@ class GlobalFeatureExtractor(nn.Module):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
align_corners=False): align_corners=False):
super(GlobalFeatureExtractor, self).__init__() super().__init__()
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
self.act_cfg = act_cfg self.act_cfg = act_cfg
@ -220,7 +220,7 @@ class FeatureFusionModule(nn.Module):
dwconv_act_cfg=dict(type='ReLU'), dwconv_act_cfg=dict(type='ReLU'),
conv_act_cfg=None, conv_act_cfg=None,
align_corners=False): align_corners=False):
super(FeatureFusionModule, self).__init__() super().__init__()
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
self.dwconv_act_cfg = dwconv_act_cfg self.dwconv_act_cfg = dwconv_act_cfg
@ -340,7 +340,7 @@ class FastSCNN(BaseModule):
dw_act_cfg=None, dw_act_cfg=None,
init_cfg=None): init_cfg=None):
super(FastSCNN, self).__init__(init_cfg) super().__init__(init_cfg)
if init_cfg is None: if init_cfg is None:
self.init_cfg = [ self.init_cfg = [

View File

@ -30,7 +30,7 @@ class HRModule(BaseModule):
norm_cfg=dict(type='BN', requires_grad=True), norm_cfg=dict(type='BN', requires_grad=True),
block_init_cfg=None, block_init_cfg=None,
init_cfg=None): init_cfg=None):
super(HRModule, self).__init__(init_cfg) super().__init__(init_cfg)
self.block_init_cfg = block_init_cfg self.block_init_cfg = block_init_cfg
self._check_branches(num_branches, num_blocks, in_channels, self._check_branches(num_branches, num_blocks, in_channels,
num_channels) num_channels)
@ -308,7 +308,7 @@ class HRNet(BaseModule):
multiscale_output=True, multiscale_output=True,
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(HRNet, self).__init__(init_cfg) super().__init__(init_cfg)
self.pretrained = pretrained self.pretrained = pretrained
self.zero_init_residual = zero_init_residual self.zero_init_residual = zero_init_residual
@ -633,7 +633,7 @@ class HRNet(BaseModule):
def train(self, mode=True): def train(self, mode=True):
"""Convert the model into training mode will keeping the normalization """Convert the model into training mode will keeping the normalization
layer freezed.""" layer freezed."""
super(HRNet, self).train(mode) super().train(mode)
self._freeze_stages() self._freeze_stages()
if mode and self.norm_eval: if mode and self.norm_eval:
for m in self.modules(): for m in self.modules():

View File

@ -64,7 +64,7 @@ class ICNet(BaseModule):
dict(type='Constant', val=1, layer='_BatchNorm'), dict(type='Constant', val=1, layer='_BatchNorm'),
dict(type='Normal', mean=0.01, layer='Linear') dict(type='Normal', mean=0.01, layer='Linear')
] ]
super(ICNet, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.align_corners = align_corners self.align_corners = align_corners
self.backbone = MODELS.build(backbone_cfg) self.backbone = MODELS.build(backbone_cfg)

View File

@ -100,7 +100,7 @@ class MAE(BEiT):
pretrained=None, pretrained=None,
init_values=0.1, init_values=0.1,
init_cfg=None): init_cfg=None):
super(MAE, self).__init__( super().__init__(
img_size=img_size, img_size=img_size,
patch_size=patch_size, patch_size=patch_size,
in_channels=in_channels, in_channels=in_channels,
@ -186,7 +186,7 @@ class MAE(BEiT):
state_dict = self.resize_abs_pos_embed(state_dict) state_dict = self.resize_abs_pos_embed(state_dict)
self.load_state_dict(state_dict, False) self.load_state_dict(state_dict, False)
elif self.init_cfg is not None: elif self.init_cfg is not None:
super(MAE, self).init_weights() super().init_weights()
else: else:
# We only implement the 'jax_impl' initialization implemented at # We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501

View File

@ -44,7 +44,7 @@ class MixFFN(BaseModule):
ffn_drop=0., ffn_drop=0.,
dropout_layer=None, dropout_layer=None,
init_cfg=None): init_cfg=None):
super(MixFFN, self).__init__(init_cfg) super().__init__(init_cfg)
self.embed_dims = embed_dims self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels self.feedforward_channels = feedforward_channels
@ -253,7 +253,7 @@ class TransformerEncoderLayer(BaseModule):
batch_first=True, batch_first=True,
sr_ratio=1, sr_ratio=1,
with_cp=False): with_cp=False):
super(TransformerEncoderLayer, self).__init__() super().__init__()
# The ret[0] of build_norm_layer is norm name. # The ret[0] of build_norm_layer is norm name.
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
@ -357,7 +357,7 @@ class MixVisionTransformer(BaseModule):
pretrained=None, pretrained=None,
init_cfg=None, init_cfg=None,
with_cp=False): with_cp=False):
super(MixVisionTransformer, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \ assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time' 'init_cfg and pretrained cannot be set at the same time'
@ -433,7 +433,7 @@ class MixVisionTransformer(BaseModule):
normal_init( normal_init(
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
else: else:
super(MixVisionTransformer, self).init_weights() super().init_weights()
def forward(self, x): def forward(self, x):
outs = [] outs = []

View File

@ -63,7 +63,7 @@ class MobileNetV2(BaseModule):
with_cp=False, with_cp=False,
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(MobileNetV2, self).__init__(init_cfg) super().__init__(init_cfg)
self.pretrained = pretrained self.pretrained = pretrained
assert not (init_cfg and pretrained), \ assert not (init_cfg and pretrained), \
@ -189,7 +189,7 @@ class MobileNetV2(BaseModule):
param.requires_grad = False param.requires_grad = False
def train(self, mode=True): def train(self, mode=True):
super(MobileNetV2, self).train(mode) super().train(mode)
self._freeze_stages() self._freeze_stages()
if mode and self.norm_eval: if mode and self.norm_eval:
for m in self.modules(): for m in self.modules():

View File

@ -81,7 +81,7 @@ class MobileNetV3(BaseModule):
with_cp=False, with_cp=False,
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(MobileNetV3, self).__init__(init_cfg) super().__init__(init_cfg)
self.pretrained = pretrained self.pretrained = pretrained
assert not (init_cfg and pretrained), \ assert not (init_cfg and pretrained), \
@ -175,7 +175,7 @@ class MobileNetV3(BaseModule):
act_cfg=dict(type=act), act_cfg=dict(type=act),
with_cp=self.with_cp) with_cp=self.with_cp)
in_channels = out_channels in_channels = out_channels
layer_name = 'layer{}'.format(i + 1) layer_name = f'layer{i + 1}'
self.add_module(layer_name, layer) self.add_module(layer_name, layer)
layers.append(layer_name) layers.append(layer_name)
@ -192,7 +192,7 @@ class MobileNetV3(BaseModule):
conv_cfg=self.conv_cfg, conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
act_cfg=dict(type='HSwish')) act_cfg=dict(type='HSwish'))
layer_name = 'layer{}'.format(len(layer_setting) + 1) layer_name = f'layer{len(layer_setting) + 1}'
self.add_module(layer_name, layer) self.add_module(layer_name, layer)
layers.append(layer_name) layers.append(layer_name)
@ -259,7 +259,7 @@ class MobileNetV3(BaseModule):
param.requires_grad = False param.requires_grad = False
def train(self, mode=True): def train(self, mode=True):
super(MobileNetV3, self).train(mode) super().train(mode)
self._freeze_stages() self._freeze_stages()
if mode and self.norm_eval: if mode and self.norm_eval:
for m in self.modules(): for m in self.modules():

View File

@ -69,7 +69,7 @@ class SplitAttentionConv2d(nn.Module):
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
dcn=None): dcn=None):
super(SplitAttentionConv2d, self).__init__() super().__init__()
inter_channels = max(in_channels * radix // reduction_factor, 32) inter_channels = max(in_channels * radix // reduction_factor, 32)
self.radix = radix self.radix = radix
self.groups = groups self.groups = groups
@ -174,7 +174,7 @@ class Bottleneck(_Bottleneck):
avg_down_stride=True, avg_down_stride=True,
**kwargs): **kwargs):
"""Bottleneck block for ResNeSt.""" """Bottleneck block for ResNeSt."""
super(Bottleneck, self).__init__(inplanes, planes, **kwargs) super().__init__(inplanes, planes, **kwargs)
if groups == 1: if groups == 1:
width = self.planes width = self.planes
@ -304,7 +304,7 @@ class ResNeSt(ResNetV1d):
self.radix = radix self.radix = radix
self.reduction_factor = reduction_factor self.reduction_factor = reduction_factor
self.avg_down_stride = avg_down_stride self.avg_down_stride = avg_down_stride
super(ResNeSt, self).__init__(**kwargs) super().__init__(**kwargs)
def make_res_layer(self, **kwargs): def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer``.""" """Pack all blocks in a stage into a ``ResLayer``."""

View File

@ -29,7 +29,7 @@ class BasicBlock(BaseModule):
dcn=None, dcn=None,
plugins=None, plugins=None,
init_cfg=None): init_cfg=None):
super(BasicBlock, self).__init__(init_cfg) super().__init__(init_cfg)
assert dcn is None, 'Not implemented yet.' assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.'
@ -118,7 +118,7 @@ class Bottleneck(BaseModule):
dcn=None, dcn=None,
plugins=None, plugins=None,
init_cfg=None): init_cfg=None):
super(Bottleneck, self).__init__(init_cfg) super().__init__(init_cfg)
assert style in ['pytorch', 'caffe'] assert style in ['pytorch', 'caffe']
assert dcn is None or isinstance(dcn, dict) assert dcn is None or isinstance(dcn, dict)
assert plugins is None or isinstance(plugins, list) assert plugins is None or isinstance(plugins, list)
@ -418,7 +418,7 @@ class ResNet(BaseModule):
zero_init_residual=True, zero_init_residual=True,
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(ResNet, self).__init__(init_cfg) super().__init__(init_cfg)
if depth not in self.arch_settings: if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet') raise KeyError(f'invalid depth {depth} for resnet')
@ -676,7 +676,7 @@ class ResNet(BaseModule):
def train(self, mode=True): def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer """Convert the model into training mode while keep normalization layer
freezed.""" freezed."""
super(ResNet, self).train(mode) super().train(mode)
self._freeze_stages() self._freeze_stages()
if mode and self.norm_eval: if mode and self.norm_eval:
for m in self.modules(): for m in self.modules():
@ -696,8 +696,7 @@ class ResNetV1c(ResNet):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(ResNetV1c, self).__init__( super().__init__(deep_stem=True, avg_down=False, **kwargs)
deep_stem=True, avg_down=False, **kwargs)
@MODELS.register_module() @MODELS.register_module()
@ -710,5 +709,4 @@ class ResNetV1d(ResNet):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(ResNetV1d, self).__init__( super().__init__(deep_stem=True, avg_down=True, **kwargs)
deep_stem=True, avg_down=True, **kwargs)

View File

@ -23,7 +23,7 @@ class Bottleneck(_Bottleneck):
base_width=4, base_width=4,
base_channels=64, base_channels=64,
**kwargs): **kwargs):
super(Bottleneck, self).__init__(inplanes, planes, **kwargs) super().__init__(inplanes, planes, **kwargs)
if groups == 1: if groups == 1:
width = self.planes width = self.planes
@ -139,7 +139,7 @@ class ResNeXt(ResNet):
def __init__(self, groups=1, base_width=4, **kwargs): def __init__(self, groups=1, base_width=4, **kwargs):
self.groups = groups self.groups = groups
self.base_width = base_width self.base_width = base_width
super(ResNeXt, self).__init__(**kwargs) super().__init__(**kwargs)
def make_res_layer(self, **kwargs): def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer``""" """Pack all blocks in a stage into a ``ResLayer``"""

View File

@ -35,7 +35,7 @@ class STDCModule(BaseModule):
num_convs=4, num_convs=4,
fusion_type='add', fusion_type='add',
init_cfg=None): init_cfg=None):
super(STDCModule, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert num_convs > 1 assert num_convs > 1
assert fusion_type in ['add', 'cat'] assert fusion_type in ['add', 'cat']
self.stride = stride self.stride = stride
@ -155,7 +155,7 @@ class FeatureFusionModule(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(FeatureFusionModule, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
channels = out_channels // scale_factor channels = out_channels // scale_factor
self.conv0 = ConvModule( self.conv0 = ConvModule(
in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
@ -240,7 +240,7 @@ class STDCNet(BaseModule):
with_final_conv=False, with_final_conv=False,
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(STDCNet, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert stdc_type in self.arch_settings, \ assert stdc_type in self.arch_settings, \
f'invalid structure {stdc_type} for STDCNet.' f'invalid structure {stdc_type} for STDCNet.'
assert bottleneck_type in ['add', 'cat'],\ assert bottleneck_type in ['add', 'cat'],\
@ -370,7 +370,7 @@ class STDCContextPathNet(BaseModule):
align_corners=None, align_corners=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
init_cfg=None): init_cfg=None):
super(STDCContextPathNet, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.backbone = MODELS.build(backbone_cfg) self.backbone = MODELS.build(backbone_cfg)
self.arms = ModuleList() self.arms = ModuleList()
self.convs = ModuleList() self.convs = ModuleList()

View File

@ -326,7 +326,7 @@ class SwinBlock(BaseModule):
with_cp=False, with_cp=False,
init_cfg=None): init_cfg=None):
super(SwinBlock, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.with_cp = with_cp self.with_cp = with_cp
@ -561,7 +561,7 @@ class SwinTransformer(BaseModule):
else: else:
raise TypeError('pretrained must be a str or None') raise TypeError('pretrained must be a str or None')
super(SwinTransformer, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
num_layers = len(depths) num_layers = len(depths)
self.out_indices = out_indices self.out_indices = out_indices
@ -636,7 +636,7 @@ class SwinTransformer(BaseModule):
def train(self, mode=True): def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed.""" """Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode) super().train(mode)
self._freeze_stages() self._freeze_stages()
def _freeze_stages(self): def _freeze_stages(self):

View File

@ -37,7 +37,7 @@ class TIMMBackbone(BaseModule):
): ):
if timm is None: if timm is None:
raise RuntimeError('timm is not installed') raise RuntimeError('timm is not installed')
super(TIMMBackbone, self).__init__(init_cfg) super().__init__(init_cfg)
if 'norm_layer' in kwargs: if 'norm_layer' in kwargs:
kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer']) kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer'])
self.timm_model = timm.create_model( self.timm_model = timm.create_model(

View File

@ -62,7 +62,7 @@ class GlobalSubsampledAttention(EfficientMultiheadAttention):
norm_cfg=dict(type='LN'), norm_cfg=dict(type='LN'),
sr_ratio=1, sr_ratio=1,
init_cfg=None): init_cfg=None):
super(GlobalSubsampledAttention, self).__init__( super().__init__(
embed_dims, embed_dims,
num_heads, num_heads,
attn_drop=attn_drop, attn_drop=attn_drop,
@ -112,7 +112,7 @@ class GSAEncoderLayer(BaseModule):
norm_cfg=dict(type='LN'), norm_cfg=dict(type='LN'),
sr_ratio=1., sr_ratio=1.,
init_cfg=None): init_cfg=None):
super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
self.attn = GlobalSubsampledAttention( self.attn = GlobalSubsampledAttention(
@ -172,7 +172,7 @@ class LocallyGroupedSelfAttention(BaseModule):
proj_drop_rate=0., proj_drop_rate=0.,
window_size=1, window_size=1,
init_cfg=None): init_cfg=None):
super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \ assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \
f'divided by num_heads ' \ f'divided by num_heads ' \
@ -284,7 +284,7 @@ class LSAEncoderLayer(BaseModule):
window_size=1, window_size=1,
init_cfg=None): init_cfg=None):
super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads, self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
@ -325,7 +325,7 @@ class ConditionalPositionEncoding(BaseModule):
""" """
def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None): def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.proj = nn.Conv2d( self.proj = nn.Conv2d(
in_channels, in_channels,
embed_dims, embed_dims,
@ -401,7 +401,7 @@ class PCPVT(BaseModule):
norm_after_stage=False, norm_after_stage=False,
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(PCPVT, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \ assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time' 'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str): if isinstance(pretrained, str):
@ -471,7 +471,7 @@ class PCPVT(BaseModule):
def init_weights(self): def init_weights(self):
if self.init_cfg is not None: if self.init_cfg is not None:
super(PCPVT, self).init_weights() super().init_weights()
else: else:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):
@ -563,11 +563,11 @@ class SVT(PCPVT):
norm_after_stage=True, norm_after_stage=True,
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(SVT, self).__init__(in_channels, embed_dims, patch_sizes, super().__init__(in_channels, embed_dims, patch_sizes, strides,
strides, num_heads, mlp_ratios, out_indices, num_heads, mlp_ratios, out_indices, qkv_bias,
qkv_bias, drop_rate, attn_drop_rate, drop_rate, attn_drop_rate, drop_path_rate, norm_cfg,
drop_path_rate, norm_cfg, depths, sr_ratios, depths, sr_ratios, norm_after_stage, pretrained,
norm_after_stage, pretrained, init_cfg) init_cfg)
# transformer encoder # transformer encoder
dpr = [ dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))

View File

@ -53,7 +53,7 @@ class BasicConvBlock(nn.Module):
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
dcn=None, dcn=None,
plugins=None): plugins=None):
super(BasicConvBlock, self).__init__() super().__init__()
assert dcn is None, 'Not implemented yet.' assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.'
@ -112,7 +112,7 @@ class DeconvModule(nn.Module):
*, *,
kernel_size=4, kernel_size=4,
scale_factor=2): scale_factor=2):
super(DeconvModule, self).__init__() super().__init__()
assert (kernel_size - scale_factor >= 0) and\ assert (kernel_size - scale_factor >= 0) and\
(kernel_size - scale_factor) % 2 == 0,\ (kernel_size - scale_factor) % 2 == 0,\
@ -191,7 +191,7 @@ class InterpConv(nn.Module):
padding=0, padding=0,
upsample_cfg=dict( upsample_cfg=dict(
scale_factor=2, mode='bilinear', align_corners=False)): scale_factor=2, mode='bilinear', align_corners=False)):
super(InterpConv, self).__init__() super().__init__()
self.with_cp = with_cp self.with_cp = with_cp
conv = ConvModule( conv = ConvModule(
@ -298,7 +298,7 @@ class UNet(BaseModule):
plugins=None, plugins=None,
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(UNet, self).__init__(init_cfg) super().__init__(init_cfg)
self.pretrained = pretrained self.pretrained = pretrained
assert not (init_cfg and pretrained), \ assert not (init_cfg and pretrained), \
@ -396,7 +396,7 @@ class UNet(BaseModule):
act_cfg=act_cfg, act_cfg=act_cfg,
dcn=None, dcn=None,
plugins=None)) plugins=None))
self.encoder.append((nn.Sequential(*enc_conv_block))) self.encoder.append(nn.Sequential(*enc_conv_block))
in_channels = base_channels * 2**i in_channels = base_channels * 2**i
def forward(self, x): def forward(self, x):
@ -415,7 +415,7 @@ class UNet(BaseModule):
def train(self, mode=True): def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer """Convert the model into training mode while keep normalization layer
freezed.""" freezed."""
super(UNet, self).train(mode) super().train(mode)
if mode and self.norm_eval: if mode and self.norm_eval:
for m in self.modules(): for m in self.modules():
# trick: eval have effect on BatchNorm only # trick: eval have effect on BatchNorm only

View File

@ -60,7 +60,7 @@ class TransformerEncoderLayer(BaseModule):
attn_cfg=dict(), attn_cfg=dict(),
ffn_cfg=dict(), ffn_cfg=dict(),
with_cp=False): with_cp=False):
super(TransformerEncoderLayer, self).__init__() super().__init__()
self.norm1_name, norm1 = build_norm_layer( self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1) norm_cfg, embed_dims, postfix=1)
@ -197,7 +197,7 @@ class VisionTransformer(BaseModule):
with_cp=False, with_cp=False,
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super(VisionTransformer, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
if isinstance(img_size, int): if isinstance(img_size, int):
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
@ -315,7 +315,7 @@ class VisionTransformer(BaseModule):
load_state_dict(self, state_dict, strict=False, logger=None) load_state_dict(self, state_dict, strict=False, logger=None)
elif self.init_cfg is not None: elif self.init_cfg is not None:
super(VisionTransformer, self).init_weights() super().init_weights()
else: else:
# We only implement the 'jax_impl' initialization implemented at # We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
@ -431,7 +431,7 @@ class VisionTransformer(BaseModule):
return tuple(outs) return tuple(outs)
def train(self, mode=True): def train(self, mode=True):
super(VisionTransformer, self).train(mode) super().train(mode)
if mode and self.norm_eval: if mode and self.norm_eval:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.LayerNorm): if isinstance(m, nn.LayerNorm):

View File

@ -17,7 +17,7 @@ class PPMConcat(nn.ModuleList):
""" """
def __init__(self, pool_scales=(1, 3, 6, 8)): def __init__(self, pool_scales=(1, 3, 6, 8)):
super(PPMConcat, self).__init__( super().__init__(
[nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales]) [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
def forward(self, feats): def forward(self, feats):
@ -58,7 +58,7 @@ class SelfAttentionBlock(_SelfAttentionBlock):
query_downsample = nn.MaxPool2d(kernel_size=query_scale) query_downsample = nn.MaxPool2d(kernel_size=query_scale)
else: else:
query_downsample = None query_downsample = None
super(SelfAttentionBlock, self).__init__( super().__init__(
key_in_channels=low_in_channels, key_in_channels=low_in_channels,
query_in_channels=high_in_channels, query_in_channels=high_in_channels,
channels=channels, channels=channels,
@ -100,7 +100,7 @@ class AFNB(nn.Module):
def __init__(self, low_in_channels, high_in_channels, channels, def __init__(self, low_in_channels, high_in_channels, channels,
out_channels, query_scales, key_pool_scales, conv_cfg, out_channels, query_scales, key_pool_scales, conv_cfg,
norm_cfg, act_cfg): norm_cfg, act_cfg):
super(AFNB, self).__init__() super().__init__()
self.stages = nn.ModuleList() self.stages = nn.ModuleList()
for query_scale in query_scales: for query_scale in query_scales:
self.stages.append( self.stages.append(
@ -150,7 +150,7 @@ class APNB(nn.Module):
def __init__(self, in_channels, channels, out_channels, query_scales, def __init__(self, in_channels, channels, out_channels, query_scales,
key_pool_scales, conv_cfg, norm_cfg, act_cfg): key_pool_scales, conv_cfg, norm_cfg, act_cfg):
super(APNB, self).__init__() super().__init__()
self.stages = nn.ModuleList() self.stages = nn.ModuleList()
for query_scale in query_scales: for query_scale in query_scales:
self.stages.append( self.stages.append(
@ -201,8 +201,7 @@ class ANNHead(BaseDecodeHead):
query_scales=(1, ), query_scales=(1, ),
key_pool_scales=(1, 3, 6, 8), key_pool_scales=(1, 3, 6, 8),
**kwargs): **kwargs):
super(ANNHead, self).__init__( super().__init__(input_transform='multiple_select', **kwargs)
input_transform='multiple_select', **kwargs)
assert len(self.in_channels) == 2 assert len(self.in_channels) == 2
low_in_channels, high_in_channels = self.in_channels low_in_channels, high_in_channels = self.in_channels
self.project_channels = project_channels self.project_channels = project_channels

View File

@ -25,7 +25,7 @@ class ACM(nn.Module):
def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
norm_cfg, act_cfg): norm_cfg, act_cfg):
super(ACM, self).__init__() super().__init__()
self.pool_scale = pool_scale self.pool_scale = pool_scale
self.fusion = fusion self.fusion = fusion
self.in_channels = in_channels self.in_channels = in_channels
@ -123,7 +123,7 @@ class APCHead(BaseDecodeHead):
""" """
def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
super(APCHead, self).__init__(**kwargs) super().__init__(**kwargs)
assert isinstance(pool_scales, (list, tuple)) assert isinstance(pool_scales, (list, tuple))
self.pool_scales = pool_scales self.pool_scales = pool_scales
self.fusion = fusion self.fusion = fusion

View File

@ -22,7 +22,7 @@ class ASPPModule(nn.ModuleList):
def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
act_cfg): act_cfg):
super(ASPPModule, self).__init__() super().__init__()
self.dilations = dilations self.dilations = dilations
self.in_channels = in_channels self.in_channels = in_channels
self.channels = channels self.channels = channels
@ -63,7 +63,7 @@ class ASPPHead(BaseDecodeHead):
""" """
def __init__(self, dilations=(1, 6, 12, 18), **kwargs): def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
super(ASPPHead, self).__init__(**kwargs) super().__init__(**kwargs)
assert isinstance(dilations, (list, tuple)) assert isinstance(dilations, (list, tuple))
self.dilations = dilations self.dilations = dilations
self.image_pool = nn.Sequential( self.image_pool = nn.Sequential(

View File

@ -13,7 +13,7 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
:class:`CascadeEncoderDecoder.""" :class:`CascadeEncoderDecoder."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@abstractmethod @abstractmethod
def forward(self, inputs, prev_output): def forward(self, inputs, prev_output):

View File

@ -26,7 +26,7 @@ class CCHead(FCNHead):
if CrissCrossAttention is None: if CrissCrossAttention is None:
raise RuntimeError('Please install mmcv-full for ' raise RuntimeError('Please install mmcv-full for '
'CrissCrossAttention ops') 'CrissCrossAttention ops')
super(CCHead, self).__init__(num_convs=2, **kwargs) super().__init__(num_convs=2, **kwargs)
self.recurrence = recurrence self.recurrence = recurrence
self.cca = CrissCrossAttention(self.channels) self.cca = CrissCrossAttention(self.channels)

View File

@ -21,7 +21,7 @@ class PAM(_SelfAttentionBlock):
""" """
def __init__(self, in_channels, channels): def __init__(self, in_channels, channels):
super(PAM, self).__init__( super().__init__(
key_in_channels=in_channels, key_in_channels=in_channels,
query_in_channels=in_channels, query_in_channels=in_channels,
channels=channels, channels=channels,
@ -43,7 +43,7 @@ class PAM(_SelfAttentionBlock):
def forward(self, x): def forward(self, x):
"""Forward function.""" """Forward function."""
out = super(PAM, self).forward(x, x) out = super().forward(x, x)
out = self.gamma(out) + x out = self.gamma(out) + x
return out return out
@ -53,7 +53,7 @@ class CAM(nn.Module):
"""Channel Attention Module (CAM)""" """Channel Attention Module (CAM)"""
def __init__(self): def __init__(self):
super(CAM, self).__init__() super().__init__()
self.gamma = Scale(0) self.gamma = Scale(0)
def forward(self, x): def forward(self, x):
@ -86,7 +86,7 @@ class DAHead(BaseDecodeHead):
""" """
def __init__(self, pam_channels, **kwargs): def __init__(self, pam_channels, **kwargs):
super(DAHead, self).__init__(**kwargs) super().__init__(**kwargs)
self.pam_channels = pam_channels self.pam_channels = pam_channels
self.pam_in_conv = ConvModule( self.pam_in_conv = ConvModule(
self.in_channels, self.in_channels,
@ -173,15 +173,12 @@ class DAHead(BaseDecodeHead):
loss = dict() loss = dict()
loss.update( loss.update(
add_prefix( add_prefix(
super(DAHead, self).loss_by_feat(pam_cam_seg_logit, super().loss_by_feat(pam_cam_seg_logit, batch_data_samples),
batch_data_samples),
'pam_cam')) 'pam_cam'))
loss.update( loss.update(
add_prefix( add_prefix(super().loss_by_feat(pam_seg_logit, batch_data_samples),
super(DAHead, self).loss_by_feat(pam_seg_logit, 'pam'))
batch_data_samples), 'pam'))
loss.update( loss.update(
add_prefix( add_prefix(super().loss_by_feat(cam_seg_logit, batch_data_samples),
super(DAHead, self).loss_by_feat(cam_seg_logit, 'cam'))
batch_data_samples), 'cam'))
return loss return loss

View File

@ -97,7 +97,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
align_corners=False, align_corners=False,
init_cfg=dict( init_cfg=dict(
type='Normal', std=0.01, override=dict(name='conv_seg'))): type='Normal', std=0.01, override=dict(name='conv_seg'))):
super(BaseDecodeHead, self).__init__(init_cfg) super().__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform) self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels self.channels = channels
self.num_classes = num_classes self.num_classes = num_classes

View File

@ -24,7 +24,7 @@ class DCM(nn.Module):
def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
norm_cfg, act_cfg): norm_cfg, act_cfg):
super(DCM, self).__init__() super().__init__()
self.filter_size = filter_size self.filter_size = filter_size
self.fusion = fusion self.fusion = fusion
self.in_channels = in_channels self.in_channels = in_channels
@ -105,7 +105,7 @@ class DMHead(BaseDecodeHead):
""" """
def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
super(DMHead, self).__init__(**kwargs) super().__init__(**kwargs)
assert isinstance(filter_sizes, (list, tuple)) assert isinstance(filter_sizes, (list, tuple))
self.filter_sizes = filter_sizes self.filter_sizes = filter_sizes
self.fusion = fusion self.fusion = fusion

View File

@ -111,7 +111,7 @@ class DNLHead(FCNHead):
mode='embedded_gaussian', mode='embedded_gaussian',
temperature=0.05, temperature=0.05,
**kwargs): **kwargs):
super(DNLHead, self).__init__(num_convs=2, **kwargs) super().__init__(num_convs=2, **kwargs)
self.reduction = reduction self.reduction = reduction
self.use_scale = use_scale self.use_scale = use_scale
self.mode = mode self.mode = mode

View File

@ -30,7 +30,7 @@ class ReassembleBlocks(BaseModule):
readout_type='ignore', readout_type='ignore',
patch_size=16, patch_size=16,
init_cfg=None): init_cfg=None):
super(ReassembleBlocks, self).__init__(init_cfg) super().__init__(init_cfg)
assert readout_type in ['ignore', 'add', 'project'] assert readout_type in ['ignore', 'add', 'project']
self.readout_type = readout_type self.readout_type = readout_type
@ -116,7 +116,7 @@ class PreActResidualConvUnit(BaseModule):
stride=1, stride=1,
dilation=1, dilation=1,
init_cfg=None): init_cfg=None):
super(PreActResidualConvUnit, self).__init__(init_cfg) super().__init__(init_cfg)
self.conv1 = ConvModule( self.conv1 = ConvModule(
in_channels, in_channels,
@ -168,7 +168,7 @@ class FeatureFusionBlock(BaseModule):
expand=False, expand=False,
align_corners=True, align_corners=True,
init_cfg=None): init_cfg=None):
super(FeatureFusionBlock, self).__init__(init_cfg) super().__init__(init_cfg)
self.in_channels = in_channels self.in_channels = in_channels
self.expand = expand self.expand = expand
@ -242,7 +242,7 @@ class DPTHead(BaseDecodeHead):
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
**kwargs): **kwargs):
super(DPTHead, self).__init__(**kwargs) super().__init__(**kwargs)
self.in_channels = self.in_channels self.in_channels = self.in_channels
self.expand_channels = expand_channels self.expand_channels = expand_channels

View File

@ -30,7 +30,7 @@ class EMAModule(nn.Module):
""" """
def __init__(self, channels, num_bases, num_stages, momentum): def __init__(self, channels, num_bases, num_stages, momentum):
super(EMAModule, self).__init__() super().__init__()
assert num_stages >= 1, 'num_stages must be at least 1!' assert num_stages >= 1, 'num_stages must be at least 1!'
self.num_bases = num_bases self.num_bases = num_bases
self.num_stages = num_stages self.num_stages = num_stages
@ -99,7 +99,7 @@ class EMAHead(BaseDecodeHead):
concat_input=True, concat_input=True,
momentum=0.1, momentum=0.1,
**kwargs): **kwargs):
super(EMAHead, self).__init__(**kwargs) super().__init__(**kwargs)
self.ema_channels = ema_channels self.ema_channels = ema_channels
self.num_bases = num_bases self.num_bases = num_bases
self.num_stages = num_stages self.num_stages = num_stages

View File

@ -26,7 +26,7 @@ class EncModule(nn.Module):
""" """
def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
super(EncModule, self).__init__() super().__init__()
self.encoding_project = ConvModule( self.encoding_project = ConvModule(
in_channels, in_channels,
in_channels, in_channels,
@ -90,8 +90,7 @@ class EncHead(BaseDecodeHead):
use_sigmoid=True, use_sigmoid=True,
loss_weight=0.2), loss_weight=0.2),
**kwargs): **kwargs):
super(EncHead, self).__init__( super().__init__(input_transform='multiple_select', **kwargs)
input_transform='multiple_select', **kwargs)
self.use_se_loss = use_se_loss self.use_se_loss = use_se_loss
self.add_lateral = add_lateral self.add_lateral = add_lateral
self.num_codes = num_codes self.num_codes = num_codes
@ -188,8 +187,7 @@ class EncHead(BaseDecodeHead):
"""Compute segmentation and semantic encoding loss.""" """Compute segmentation and semantic encoding loss."""
seg_logit, se_seg_logit = seg_logit seg_logit, se_seg_logit = seg_logit
loss = dict() loss = dict()
loss.update( loss.update(super().loss_by_feat(seg_logit, batch_data_samples))
super(EncHead, self).loss_by_feat(seg_logit, batch_data_samples))
seg_label = self._stack_batch_gt(batch_data_samples) seg_label = self._stack_batch_gt(batch_data_samples)
se_loss = self.loss_se_decode( se_loss = self.loss_se_decode(

View File

@ -31,7 +31,7 @@ class FCNHead(BaseDecodeHead):
self.num_convs = num_convs self.num_convs = num_convs
self.concat_input = concat_input self.concat_input = concat_input
self.kernel_size = kernel_size self.kernel_size = kernel_size
super(FCNHead, self).__init__(**kwargs) super().__init__(**kwargs)
if num_convs == 0: if num_convs == 0:
assert self.in_channels == self.channels assert self.in_channels == self.channels

View File

@ -22,8 +22,7 @@ class FPNHead(BaseDecodeHead):
""" """
def __init__(self, feature_strides, **kwargs): def __init__(self, feature_strides, **kwargs):
super(FPNHead, self).__init__( super().__init__(input_transform='multiple_select', **kwargs)
input_transform='multiple_select', **kwargs)
assert len(feature_strides) == len(self.in_channels) assert len(feature_strides) == len(self.in_channels)
assert min(feature_strides) == feature_strides[0] assert min(feature_strides) == feature_strides[0]
self.feature_strides = feature_strides self.feature_strides = feature_strides

View File

@ -26,7 +26,7 @@ class GCHead(FCNHead):
pooling_type='att', pooling_type='att',
fusion_types=('channel_add', ), fusion_types=('channel_add', ),
**kwargs): **kwargs):
super(GCHead, self).__init__(num_convs=2, **kwargs) super().__init__(num_convs=2, **kwargs)
self.ratio = ratio self.ratio = ratio
self.pooling_type = pooling_type self.pooling_type = pooling_type
self.fusion_types = fusion_types self.fusion_types = fusion_types

View File

@ -22,7 +22,7 @@ class SelfAttentionBlock(_SelfAttentionBlock):
""" """
def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg): def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg):
super(SelfAttentionBlock, self).__init__( super().__init__(
key_in_channels=in_channels, key_in_channels=in_channels,
query_in_channels=in_channels, query_in_channels=in_channels,
channels=channels, channels=channels,
@ -51,7 +51,7 @@ class SelfAttentionBlock(_SelfAttentionBlock):
def forward(self, x): def forward(self, x):
"""Forward function.""" """Forward function."""
context = super(SelfAttentionBlock, self).forward(x, x) context = super().forward(x, x)
return self.output_project(context) return self.output_project(context)
@ -68,7 +68,7 @@ class ISAHead(BaseDecodeHead):
""" """
def __init__(self, isa_channels, down_factor=(8, 8), **kwargs): def __init__(self, isa_channels, down_factor=(8, 8), **kwargs):
super(ISAHead, self).__init__(**kwargs) super().__init__(**kwargs)
self.down_factor = down_factor self.down_factor = down_factor
self.in_conv = ConvModule( self.in_conv = ConvModule(

View File

@ -48,7 +48,7 @@ class KernelUpdator(nn.Module):
norm_cfg=dict(type='LN'), norm_cfg=dict(type='LN'),
act_cfg=dict(type='ReLU', inplace=True), act_cfg=dict(type='ReLU', inplace=True),
): ):
super(KernelUpdator, self).__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.feat_channels = feat_channels self.feat_channels = feat_channels
self.out_channels_raw = out_channels self.out_channels_raw = out_channels
@ -213,7 +213,7 @@ class KernelUpdateHead(nn.Module):
out_channels=256, out_channels=256,
act_cfg=dict(type='ReLU', inplace=True), act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'))): norm_cfg=dict(type='LN'))):
super(KernelUpdateHead, self).__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels

View File

@ -22,7 +22,7 @@ class LRASPPHead(BaseDecodeHead):
""" """
def __init__(self, branch_channels=(32, 64), **kwargs): def __init__(self, branch_channels=(32, 64), **kwargs):
super(LRASPPHead, self).__init__(**kwargs) super().__init__(**kwargs)
if self.input_transform != 'multiple_select': if self.input_transform != 'multiple_select':
raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
f'must be \'multiple_select\'. But received ' f'must be \'multiple_select\'. But received '

View File

@ -26,7 +26,7 @@ class NLHead(FCNHead):
use_scale=True, use_scale=True,
mode='embedded_gaussian', mode='embedded_gaussian',
**kwargs): **kwargs):
super(NLHead, self).__init__(num_convs=2, **kwargs) super().__init__(num_convs=2, **kwargs)
self.reduction = reduction self.reduction = reduction
self.use_scale = use_scale self.use_scale = use_scale
self.mode = mode self.mode = mode

View File

@ -18,7 +18,7 @@ class SpatialGatherModule(nn.Module):
""" """
def __init__(self, scale): def __init__(self, scale):
super(SpatialGatherModule, self).__init__() super().__init__()
self.scale = scale self.scale = scale
def forward(self, feats, probs): def forward(self, feats, probs):
@ -46,7 +46,7 @@ class ObjectAttentionBlock(_SelfAttentionBlock):
query_downsample = nn.MaxPool2d(kernel_size=scale) query_downsample = nn.MaxPool2d(kernel_size=scale)
else: else:
query_downsample = None query_downsample = None
super(ObjectAttentionBlock, self).__init__( super().__init__(
key_in_channels=in_channels, key_in_channels=in_channels,
query_in_channels=in_channels, query_in_channels=in_channels,
channels=channels, channels=channels,
@ -73,8 +73,7 @@ class ObjectAttentionBlock(_SelfAttentionBlock):
def forward(self, query_feats, key_feats): def forward(self, query_feats, key_feats):
"""Forward function.""" """Forward function."""
context = super(ObjectAttentionBlock, context = super().forward(query_feats, key_feats)
self).forward(query_feats, key_feats)
output = self.bottleneck(torch.cat([context, query_feats], dim=1)) output = self.bottleneck(torch.cat([context, query_feats], dim=1))
if self.query_downsample is not None: if self.query_downsample is not None:
output = resize(query_feats) output = resize(query_feats)
@ -96,7 +95,7 @@ class OCRHead(BaseCascadeDecodeHead):
""" """
def __init__(self, ocr_channels, scale=1, **kwargs): def __init__(self, ocr_channels, scale=1, **kwargs):
super(OCRHead, self).__init__(**kwargs) super().__init__(**kwargs)
self.ocr_channels = ocr_channels self.ocr_channels = ocr_channels
self.scale = scale self.scale = scale
self.object_context_block = ObjectAttentionBlock( self.object_context_block = ObjectAttentionBlock(

View File

@ -74,7 +74,7 @@ class PointHead(BaseCascadeDecodeHead):
norm_cfg=None, norm_cfg=None,
act_cfg=dict(type='ReLU', inplace=False), act_cfg=dict(type='ReLU', inplace=False),
**kwargs): **kwargs):
super(PointHead, self).__init__( super().__init__(
input_transform='multiple_select', input_transform='multiple_select',
conv_cfg=conv_cfg, conv_cfg=conv_cfg,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,

View File

@ -43,7 +43,7 @@ class PSAHead(BaseDecodeHead):
**kwargs): **kwargs):
if PSAMask is None: if PSAMask is None:
raise RuntimeError('Please install mmcv-full for PSAMask ops') raise RuntimeError('Please install mmcv-full for PSAMask ops')
super(PSAHead, self).__init__(**kwargs) super().__init__(**kwargs)
assert psa_type in ['collect', 'distribute', 'bi-direction'] assert psa_type in ['collect', 'distribute', 'bi-direction']
self.psa_type = psa_type self.psa_type = psa_type
self.compact = compact self.compact = compact

View File

@ -24,7 +24,7 @@ class PPM(nn.ModuleList):
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
act_cfg, align_corners, **kwargs): act_cfg, align_corners, **kwargs):
super(PPM, self).__init__() super().__init__()
self.pool_scales = pool_scales self.pool_scales = pool_scales
self.align_corners = align_corners self.align_corners = align_corners
self.in_channels = in_channels self.in_channels = in_channels
@ -72,7 +72,7 @@ class PSPHead(BaseDecodeHead):
""" """
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(PSPHead, self).__init__(**kwargs) super().__init__(**kwargs)
assert isinstance(pool_scales, (list, tuple)) assert isinstance(pool_scales, (list, tuple))
self.pool_scales = pool_scales self.pool_scales = pool_scales
self.psp_modules = PPM( self.psp_modules = PPM(

View File

@ -61,8 +61,7 @@ class SegmenterMaskTransformerHead(BaseDecodeHead):
init_std=0.02, init_std=0.02,
**kwargs, **kwargs,
): ):
super(SegmenterMaskTransformerHead, self).__init__( super().__init__(in_channels=in_channels, **kwargs)
in_channels=in_channels, **kwargs)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
self.layers = ModuleList() self.layers = ModuleList()

View File

@ -13,7 +13,7 @@ class DepthwiseSeparableASPPModule(ASPPModule):
conv.""" conv."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) super().__init__(**kwargs)
for i, dilation in enumerate(self.dilations): for i, dilation in enumerate(self.dilations):
if dilation > 1: if dilation > 1:
self[i] = DepthwiseSeparableConvModule( self[i] = DepthwiseSeparableConvModule(
@ -41,7 +41,7 @@ class DepthwiseSeparableASPPHead(ASPPHead):
""" """
def __init__(self, c1_in_channels, c1_channels, **kwargs): def __init__(self, c1_in_channels, c1_channels, **kwargs):
super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) super().__init__(**kwargs)
assert c1_in_channels >= 0 assert c1_in_channels >= 0
self.aspp_modules = DepthwiseSeparableASPPModule( self.aspp_modules = DepthwiseSeparableASPPModule(
dilations=self.dilations, dilations=self.dilations,

View File

@ -32,7 +32,7 @@ class DepthwiseSeparableFCNHead(FCNHead):
""" """
def __init__(self, dw_act_cfg=None, **kwargs): def __init__(self, dw_act_cfg=None, **kwargs):
super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) super().__init__(**kwargs)
self.convs[0] = DepthwiseSeparableConvModule( self.convs[0] = DepthwiseSeparableConvModule(
self.in_channels, self.in_channels,
self.channels, self.channels,

View File

@ -21,8 +21,7 @@ class SETRMLAHead(BaseDecodeHead):
""" """
def __init__(self, mla_channels=128, up_scale=4, **kwargs): def __init__(self, mla_channels=128, up_scale=4, **kwargs):
super(SETRMLAHead, self).__init__( super().__init__(input_transform='multiple_select', **kwargs)
input_transform='multiple_select', **kwargs)
self.mla_channels = mla_channels self.mla_channels = mla_channels
num_inputs = len(self.in_channels) num_inputs = len(self.in_channels)

View File

@ -41,7 +41,7 @@ class SETRUPHead(BaseDecodeHead):
assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'
super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs) super().__init__(init_cfg=init_cfg, **kwargs)
assert isinstance(self.in_channels, int) assert isinstance(self.in_channels, int)

View File

@ -21,7 +21,7 @@ class STDCHead(FCNHead):
""" """
def __init__(self, boundary_threshold=0.1, **kwargs): def __init__(self, boundary_threshold=0.1, **kwargs):
super(STDCHead, self).__init__(**kwargs) super().__init__(**kwargs)
self.boundary_threshold = boundary_threshold self.boundary_threshold = boundary_threshold
# Using register buffer to make laplacian kernel on the same # Using register buffer to make laplacian kernel on the same
# device of `seg_label`. # device of `seg_label`.
@ -93,6 +93,5 @@ class STDCHead(FCNHead):
seg_data_sample.gt_sem_seg = PixelData(data=label) seg_data_sample.gt_sem_seg = PixelData(data=label)
batch_sample_list.append(seg_data_sample) batch_sample_list.append(seg_data_sample)
loss = super(STDCHead, self).loss_by_feat(seg_logits, loss = super().loss_by_feat(seg_logits, batch_sample_list)
batch_sample_list)
return loss return loss

View File

@ -22,8 +22,7 @@ class UPerHead(BaseDecodeHead):
""" """
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(UPerHead, self).__init__( super().__init__(input_transform='multiple_select', **kwargs)
input_transform='multiple_select', **kwargs)
# PSP Module # PSP Module
self.psp_modules = PPM( self.psp_modules = PPM(
pool_scales, pool_scales,

View File

@ -223,7 +223,7 @@ class CrossEntropyLoss(nn.Module):
loss_weight=1.0, loss_weight=1.0,
loss_name='loss_ce', loss_name='loss_ce',
avg_non_ignore=False): avg_non_ignore=False):
super(CrossEntropyLoss, self).__init__() super().__init__()
assert (use_sigmoid is False) or (use_mask is False) assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid self.use_sigmoid = use_sigmoid
self.use_mask = use_mask self.use_mask = use_mask

View File

@ -80,7 +80,7 @@ class DiceLoss(nn.Module):
ignore_index=255, ignore_index=255,
loss_name='loss_dice', loss_name='loss_dice',
**kwards): **kwards):
super(DiceLoss, self).__init__() super().__init__()
self.smooth = smooth self.smooth = smooth
self.exponent = exponent self.exponent = exponent
self.reduction = reduction self.reduction = reduction

View File

@ -172,7 +172,7 @@ class FocalLoss(nn.Module):
loss item to be included into the backward graph, `loss_` must loss item to be included into the backward graph, `loss_` must
be the prefix of the name. Defaults to 'loss_focal'. be the prefix of the name. Defaults to 'loss_focal'.
""" """
super(FocalLoss, self).__init__() super().__init__()
assert use_sigmoid is True, \ assert use_sigmoid is True, \
'AssertionError: Only sigmoid focal loss supported now.' 'AssertionError: Only sigmoid focal loss supported now.'
assert reduction in ('none', 'mean', 'sum'), \ assert reduction in ('none', 'mean', 'sum'), \

View File

@ -257,7 +257,7 @@ class LovaszLoss(nn.Module):
class_weight=None, class_weight=None,
loss_weight=1.0, loss_weight=1.0,
loss_name='loss_lovasz'): loss_name='loss_lovasz'):
super(LovaszLoss, self).__init__() super().__init__()
assert loss_type in ('binary', 'multi_class'), "loss_type should be \ assert loss_type in ('binary', 'multi_class'), "loss_type should be \
'binary' or 'multi_class'." 'binary' or 'multi_class'."

View File

@ -88,7 +88,7 @@ class TverskyLoss(nn.Module):
alpha=0.3, alpha=0.3,
beta=0.7, beta=0.7,
loss_name='loss_tversky'): loss_name='loss_tversky'):
super(TverskyLoss, self).__init__() super().__init__()
self.smooth = smooth self.smooth = smooth
self.class_weight = get_class_weight(class_weight) self.class_weight = get_class_weight(class_weight)
self.loss_weight = loss_weight self.loss_weight = loss_weight

View File

@ -23,7 +23,7 @@ class Feature2Pyramid(nn.Module):
embed_dim, embed_dim,
rescales=[4, 2, 1, 0.5], rescales=[4, 2, 1, 0.5],
norm_cfg=dict(type='SyncBN', requires_grad=True)): norm_cfg=dict(type='SyncBN', requires_grad=True)):
super(Feature2Pyramid, self).__init__() super().__init__()
self.rescales = rescales self.rescales = rescales
self.upsample_4x = None self.upsample_4x = None
for k in self.rescales: for k in self.rescales:

View File

@ -80,7 +80,7 @@ class FPN(BaseModule):
upsample_cfg=dict(mode='nearest'), upsample_cfg=dict(mode='nearest'),
init_cfg=dict( init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')): type='Xavier', layer='Conv2d', distribution='uniform')):
super(FPN, self).__init__(init_cfg) super().__init__(init_cfg)
assert isinstance(in_channels, list) assert isinstance(in_channels, list)
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels

View File

@ -42,7 +42,7 @@ class CascadeFeatureFusion(BaseModule):
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
align_corners=False, align_corners=False,
init_cfg=None): init_cfg=None):
super(CascadeFeatureFusion, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.align_corners = align_corners self.align_corners = align_corners
self.conv_low = ConvModule( self.conv_low = ConvModule(
low_channels, low_channels,
@ -108,7 +108,7 @@ class ICNeck(BaseModule):
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
align_corners=False, align_corners=False,
init_cfg=None): init_cfg=None):
super(ICNeck, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert len(in_channels) == 3, 'Length of input channels \ assert len(in_channels) == 3, 'Length of input channels \
must be 3!' must be 3!'

View File

@ -51,7 +51,7 @@ class JPU(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
init_cfg=None): init_cfg=None):
super(JPU, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
assert isinstance(in_channels, tuple) assert isinstance(in_channels, tuple)
assert isinstance(dilations, tuple) assert isinstance(dilations, tuple)
self.in_channels = in_channels self.in_channels = in_channels

View File

@ -12,7 +12,7 @@ class MLAModule(nn.Module):
out_channels=256, out_channels=256,
norm_cfg=None, norm_cfg=None,
act_cfg=None): act_cfg=None):
super(MLAModule, self).__init__() super().__init__()
self.channel_proj = nn.ModuleList() self.channel_proj = nn.ModuleList()
for i in range(len(in_channels)): for i in range(len(in_channels)):
self.channel_proj.append( self.channel_proj.append(
@ -83,7 +83,7 @@ class MLANeck(nn.Module):
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
norm_cfg=None, norm_cfg=None,
act_cfg=None): act_cfg=None):
super(MLANeck, self).__init__() super().__init__()
assert isinstance(in_channels, list) assert isinstance(in_channels, list)
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels

View File

@ -29,7 +29,7 @@ class MultiLevelNeck(nn.Module):
scales=[0.5, 1, 2, 4], scales=[0.5, 1, 2, 4],
norm_cfg=None, norm_cfg=None,
act_cfg=None): act_cfg=None):
super(MultiLevelNeck, self).__init__() super().__init__()
assert isinstance(in_channels, list) assert isinstance(in_channels, list)
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels

View File

@ -27,7 +27,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
def __init__(self, def __init__(self,
data_preprocessor: OptConfigType = None, data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None): init_cfg: OptMultiConfig = None):
super(BaseSegmentor, self).__init__( super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg) data_preprocessor=data_preprocessor, init_cfg=init_cfg)
@property @property

View File

@ -48,7 +48,7 @@ class CascadeEncoderDecoder(EncoderDecoder):
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
init_cfg: OptMultiConfig = None): init_cfg: OptMultiConfig = None):
self.num_stages = num_stages self.num_stages = num_stages
super(CascadeEncoderDecoder, self).__init__( super().__init__(
backbone=backbone, backbone=backbone,
decode_head=decode_head, decode_head=decode_head,
neck=neck, neck=neck,

View File

@ -78,7 +78,7 @@ class EncoderDecoder(BaseSegmentor):
data_preprocessor: OptConfigType = None, data_preprocessor: OptConfigType = None,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
init_cfg: OptMultiConfig = None): init_cfg: OptMultiConfig = None):
super(EncoderDecoder, self).__init__( super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg) data_preprocessor=data_preprocessor, init_cfg=init_cfg)
if pretrained is not None: if pretrained is not None:
assert backbone.get('pretrained') is None, \ assert backbone.get('pretrained') is None, \

View File

@ -42,7 +42,7 @@ class AdaptivePadding(nn.Module):
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
super(AdaptivePadding, self).__init__() super().__init__()
assert padding in ('same', 'corner') assert padding in ('same', 'corner')
@ -120,7 +120,7 @@ class PatchEmbed(BaseModule):
norm_cfg=None, norm_cfg=None,
input_size=None, input_size=None,
init_cfg=None): init_cfg=None):
super(PatchEmbed, self).__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims self.embed_dims = embed_dims
if stride is None: if stride is None:

View File

@ -16,7 +16,7 @@ class Encoding(nn.Module):
""" """
def __init__(self, channels, num_codes): def __init__(self, channels, num_codes):
super(Encoding, self).__init__() super().__init__()
# init codewords and smoothing factor # init codewords and smoothing factor
self.channels, self.num_codes = channels, num_codes self.channels, self.num_codes = channels, num_codes
std = 1. / ((num_codes * channels)**0.5) std = 1. / ((num_codes * channels)**0.5)

View File

@ -40,7 +40,7 @@ class InvertedResidual(nn.Module):
act_cfg=dict(type='ReLU6'), act_cfg=dict(type='ReLU6'),
with_cp=False, with_cp=False,
**kwargs): **kwargs):
super(InvertedResidual, self).__init__() super().__init__()
self.stride = stride self.stride = stride
assert stride in [1, 2], f'stride must in [1, 2]. ' \ assert stride in [1, 2], f'stride must in [1, 2]. ' \
f'But received {stride}.' f'But received {stride}.'
@ -138,7 +138,7 @@ class InvertedResidualV3(nn.Module):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
with_cp=False): with_cp=False):
super(InvertedResidualV3, self).__init__() super().__init__()
self.with_res_shortcut = (stride == 1 and in_channels == out_channels) self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
assert stride in [1, 2] assert stride in [1, 2]
self.with_cp = with_cp self.with_cp = with_cp

View File

@ -93,4 +93,4 @@ class ResLayer(Sequential):
conv_cfg=conv_cfg, conv_cfg=conv_cfg,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
**kwargs)) **kwargs))
super(ResLayer, self).__init__(*layers) super().__init__(*layers)

View File

@ -30,7 +30,7 @@ class SELayer(nn.Module):
conv_cfg=None, conv_cfg=None,
act_cfg=(dict(type='ReLU'), act_cfg=(dict(type='ReLU'),
dict(type='HSigmoid', bias=3.0, divisor=6.0))): dict(type='HSigmoid', bias=3.0, divisor=6.0))):
super(SELayer, self).__init__() super().__init__()
if isinstance(act_cfg, dict): if isinstance(act_cfg, dict):
act_cfg = (act_cfg, act_cfg) act_cfg = (act_cfg, act_cfg)
assert len(act_cfg) == 2 assert len(act_cfg) == 2

View File

@ -36,7 +36,7 @@ class SelfAttentionBlock(nn.Module):
key_downsample, key_query_num_convs, value_out_num_convs, key_downsample, key_query_num_convs, value_out_num_convs,
key_query_norm, value_out_norm, matmul_norm, with_out, key_query_norm, value_out_norm, matmul_norm, with_out,
conv_cfg, norm_cfg, act_cfg): conv_cfg, norm_cfg, act_cfg):
super(SelfAttentionBlock, self).__init__() super().__init__()
if share_key_query: if share_key_query:
assert key_in_channels == query_in_channels assert key_in_channels == query_in_channels
self.key_in_channels = key_in_channels self.key_in_channels = key_in_channels

View File

@ -57,7 +57,7 @@ class UpConvBlock(nn.Module):
upsample_cfg=dict(type='InterpConv'), upsample_cfg=dict(type='InterpConv'),
dcn=None, dcn=None,
plugins=None): plugins=None):
super(UpConvBlock, self).__init__() super().__init__()
assert dcn is None, 'Not implemented yet.' assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.'

View File

@ -34,7 +34,7 @@ class Upsample(nn.Module):
scale_factor=None, scale_factor=None,
mode='nearest', mode='nearest',
align_corners=None): align_corners=None):
super(Upsample, self).__init__() super().__init__()
self.size = size self.size = size
if isinstance(scale_factor, tuple): if isinstance(scale_factor, tuple):
self.scale_factor = tuple(float(factor) for factor in scale_factor) self.scale_factor = tuple(float(factor) for factor in scale_factor)

View File

@ -23,7 +23,7 @@ class OHEMPixelSampler(BasePixelSampler):
""" """
def __init__(self, context, thresh=None, min_kept=100000): def __init__(self, context, thresh=None, min_kept=100000):
super(OHEMPixelSampler, self).__init__() super().__init__()
self.context = context self.context = context
assert min_kept > 1 assert min_kept > 1
self.thresh = thresh self.thresh = thresh

View File

@ -15,4 +15,4 @@ def collect_env():
if __name__ == '__main__': if __name__ == '__main__':
for name, val in collect_env().items(): for name, val in collect_env().items():
print('{}: {}'.format(name, val)) print(f'{name}: {val}')

View File

@ -52,12 +52,12 @@ def stack_batch(inputs: List[torch.Tensor],
""" """
assert isinstance(inputs, list), \ assert isinstance(inputs, list), \
f'Expected input type to be list, but got {type(inputs)}' f'Expected input type to be list, but got {type(inputs)}'
assert len(set([tensor.ndim for tensor in inputs])) == 1, \ assert len({tensor.ndim for tensor in inputs}) == 1, \
f'Expected the dimensions of all inputs must be the same, ' \ f'Expected the dimensions of all inputs must be the same, ' \
f'but got {[tensor.ndim for tensor in inputs]}' f'but got {[tensor.ndim for tensor in inputs]}'
assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \ assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
f'but got {inputs[0].ndim}' f'but got {inputs[0].ndim}'
assert len(set([tensor.shape[0] for tensor in inputs])) == 1, \ assert len({tensor.shape[0] for tensor in inputs}) == 1, \
f'Expected the channels of all inputs must be the same, ' \ f'Expected the channels of all inputs must be the same, ' \
f'but got {[tensor.shape[0] for tensor in inputs]}' f'but got {[tensor.shape[0] for tensor in inputs]}'

View File

@ -18,7 +18,7 @@ version_file = 'mmseg/version.py'
def get_version(): def get_version():
with open(version_file, 'r') as f: with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec')) exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__'] return locals()['__version__']
@ -74,12 +74,11 @@ def parse_requirements(fname='requirements.txt', with_version=True):
yield info yield info
def parse_require_file(fpath): def parse_require_file(fpath):
with open(fpath, 'r') as f: with open(fpath) as f:
for line in f.readlines(): for line in f.readlines():
line = line.strip() line = line.strip()
if line and not line.startswith('#'): if line and not line.startswith('#'):
for info in parse_line(line): yield from parse_line(line)
yield info
def gen_packages_items(): def gen_packages_items():
if exists(require_fpath): if exists(require_fpath):

View File

@ -31,7 +31,7 @@ def test_config_build_segmentor():
"""Test that all segmentation models defined in the configs can be """Test that all segmentation models defined in the configs can be
initialized.""" initialized."""
config_dpath = _get_config_directory() config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath)) print(f'Found config_dpath = {config_dpath!r}')
config_fpaths = [] config_fpaths = []
# one config each sub folder # one config each sub folder
@ -42,20 +42,20 @@ def test_config_build_segmentor():
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1] config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
config_names = [relpath(p, config_dpath) for p in config_fpaths] config_names = [relpath(p, config_dpath) for p in config_fpaths]
print('Using {} config files'.format(len(config_names))) print(f'Using {len(config_names)} config files')
for config_fname in config_names: for config_fname in config_names:
config_fpath = join(config_dpath, config_fname) config_fpath = join(config_dpath, config_fname)
config_mod = Config.fromfile(config_fpath) config_mod = Config.fromfile(config_fpath)
config_mod.model config_mod.model
print('Building segmentor, config_fpath = {!r}'.format(config_fpath)) print(f'Building segmentor, config_fpath = {config_fpath!r}')
# Remove pretrained keys to allow for testing in an offline environment # Remove pretrained keys to allow for testing in an offline environment
if 'pretrained' in config_mod.model: if 'pretrained' in config_mod.model:
config_mod.model['pretrained'] = None config_mod.model['pretrained'] = None
print('building {}'.format(config_fname)) print(f'building {config_fname}')
segmentor = build_segmentor(config_mod.model) segmentor = build_segmentor(config_mod.model)
assert segmentor is not None assert segmentor is not None
@ -72,19 +72,18 @@ def test_config_data_pipeline():
register_all_modules() register_all_modules()
config_dpath = _get_config_directory() config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath)) print(f'Found config_dpath = {config_dpath!r}')
import glob import glob
config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py'))) config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1] config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
config_names = [relpath(p, config_dpath) for p in config_fpaths] config_names = [relpath(p, config_dpath) for p in config_fpaths]
print('Using {} config files'.format(len(config_names))) print(f'Using {len(config_names)} config files')
for config_fname in config_names: for config_fname in config_names:
config_fpath = join(config_dpath, config_fname) config_fpath = join(config_dpath, config_fname)
print( print(f'Building data pipeline, config_fpath = {config_fpath!r}')
'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
config_mod = Config.fromfile(config_fpath) config_mod = Config.fromfile(config_fpath)
# remove loading pipeline # remove loading pipeline
@ -112,7 +111,7 @@ def test_config_data_pipeline():
gt_seg_map=seg) gt_seg_map=seg)
results['seg_fields'] = ['gt_seg_map'] results['seg_fields'] = ['gt_seg_map']
print('Test training data pipeline: \n{!r}'.format(train_pipeline)) print(f'Test training data pipeline: \n{train_pipeline!r}')
output_results = train_pipeline(results) output_results = train_pipeline(results)
assert output_results is not None assert output_results is not None
@ -123,7 +122,7 @@ def test_config_data_pipeline():
img_shape=img.shape, img_shape=img.shape,
ori_shape=img.shape, ori_shape=img.shape,
) )
print('Test testing data pipeline: \n{!r}'.format(test_pipeline)) print(f'Test testing data pipeline: \n{test_pipeline!r}')
output_results = test_pipeline(results) output_results = test_pipeline(results)
assert output_results is not None assert output_results is not None

View File

@ -11,7 +11,7 @@ register_all_modules()
@DATASETS.register_module() @DATASETS.register_module()
class ToyDataset(object): class ToyDataset:
def __init__(self, cnt=0): def __init__(self, cnt=0):
self.cnt = cnt self.cnt = cnt

View File

@ -10,7 +10,7 @@ from mmcv.transforms import LoadImageFromFile
from mmseg.datasets.transforms import LoadAnnotations, LoadImageFromNDArray from mmseg.datasets.transforms import LoadAnnotations, LoadImageFromNDArray
class TestLoading(object): class TestLoading:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):

View File

@ -331,7 +331,7 @@ def test_resnet_backbone():
for param in layer.parameters(): for param in layer.parameters():
assert param.requires_grad is False assert param.requires_grad is False
for i in range(1, frozen_stages + 1): for i in range(1, frozen_stages + 1):
layer = getattr(model, 'layer{}'.format(i)) layer = getattr(model, f'layer{i}')
for mod in layer.modules(): for mod in layer.modules():
if isinstance(mod, _BatchNorm): if isinstance(mod, _BatchNorm):
assert mod.training is False assert mod.training is False
@ -347,7 +347,7 @@ def test_resnet_backbone():
for param in model.stem.parameters(): for param in model.stem.parameters():
assert param.requires_grad is False assert param.requires_grad is False
for i in range(1, frozen_stages + 1): for i in range(1, frozen_stages + 1):
layer = getattr(model, 'layer{}'.format(i)) layer = getattr(model, f'layer{i}')
for mod in layer.modules(): for mod in layer.modules():
if isinstance(mod, _BatchNorm): if isinstance(mod, _BatchNorm):
assert mod.training is False assert mod.training is False

View File

@ -101,7 +101,7 @@ def load_json_logs(json_logs):
log_dicts = [dict() for _ in json_logs] log_dicts = [dict() for _ in json_logs]
prev_step = 0 prev_step = 0
for json_log, log_dict in zip(json_logs, log_dicts): for json_log, log_dict in zip(json_logs, log_dicts):
with open(json_log, 'r') as log_file: with open(json_log) as log_file:
for line in log_file: for line in log_file:
log = json.loads(line.strip()) log = json.loads(line.strip())
# the final step in json file is 0. # the final step in json file is 0.

View File

@ -47,7 +47,7 @@ def main():
print('Generating training dataset...') print('Generating training dataset...')
assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \ assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \
'len(os.listdir(tmp_dir)) != {}'.format(CHASE_DB1_LEN) f'len(os.listdir(tmp_dir)) != {CHASE_DB1_LEN}'
for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(tmp_dir, img_name)) img = mmcv.imread(osp.join(tmp_dir, img_name))

View File

@ -63,7 +63,7 @@ def main():
zip_file.extractall(tmp_dir) zip_file.extractall(tmp_dir)
assert len(os.listdir(tmp_dir)) == HRF_LEN, \ assert len(os.listdir(tmp_dir)) == HRF_LEN, \
'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) f'len(os.listdir(tmp_dir)) != {HRF_LEN}'
for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(tmp_dir, filename)) img = mmcv.imread(osp.join(tmp_dir, filename))
@ -85,7 +85,7 @@ def main():
zip_file.extractall(tmp_dir) zip_file.extractall(tmp_dir)
assert len(os.listdir(tmp_dir)) == HRF_LEN, \ assert len(os.listdir(tmp_dir)) == HRF_LEN, \
'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) f'len(os.listdir(tmp_dir)) != {HRF_LEN}'
for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(tmp_dir, filename)) img = mmcv.imread(osp.join(tmp_dir, filename))

View File

@ -188,17 +188,17 @@ def main():
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test')) mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))
assert os.path.exists(os.path.join(dataset_path, 'train')), \ assert os.path.exists(os.path.join(dataset_path, 'train')), \
'train is not in {}'.format(dataset_path) f'train is not in {dataset_path}'
assert os.path.exists(os.path.join(dataset_path, 'val')), \ assert os.path.exists(os.path.join(dataset_path, 'val')), \
'val is not in {}'.format(dataset_path) f'val is not in {dataset_path}'
assert os.path.exists(os.path.join(dataset_path, 'test')), \ assert os.path.exists(os.path.join(dataset_path, 'test')), \
'test is not in {}'.format(dataset_path) f'test is not in {dataset_path}'
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
for dataset_mode in ['train', 'val', 'test']: for dataset_mode in ['train', 'val', 'test']:
# for dataset_mode in [ 'test']: # for dataset_mode in [ 'test']:
print('Extracting {}ing.zip...'.format(dataset_mode)) print(f'Extracting {dataset_mode}ing.zip...')
img_zipp_list = glob.glob( img_zipp_list = glob.glob(
os.path.join(dataset_path, dataset_mode, 'images', '*.zip')) os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
print('Find the data', img_zipp_list) print('Find the data', img_zipp_list)

View File

@ -38,11 +38,11 @@ def main():
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
assert 'Train.zip' in os.listdir(dataset_path), \ assert 'Train.zip' in os.listdir(dataset_path), \
'Train.zip is not in {}'.format(dataset_path) f'Train.zip is not in {dataset_path}'
assert 'Val.zip' in os.listdir(dataset_path), \ assert 'Val.zip' in os.listdir(dataset_path), \
'Val.zip is not in {}'.format(dataset_path) f'Val.zip is not in {dataset_path}'
assert 'Test.zip' in os.listdir(dataset_path), \ assert 'Test.zip' in os.listdir(dataset_path), \
'Test.zip is not in {}'.format(dataset_path) f'Test.zip is not in {dataset_path}'
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
for dataset in ['Train', 'Val', 'Test']: for dataset in ['Train', 'Val', 'Test']:

View File

@ -68,7 +68,7 @@ def main():
now_dir = osp.join(tmp_dir, 'files') now_dir = osp.join(tmp_dir, 'files')
assert len(os.listdir(now_dir)) == STARE_LEN, \ assert len(os.listdir(now_dir)) == STARE_LEN, \
'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) f'len(os.listdir(now_dir)) != {STARE_LEN}'
for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(now_dir, filename)) img = mmcv.imread(osp.join(now_dir, filename))
@ -103,7 +103,7 @@ def main():
now_dir = osp.join(tmp_dir, 'files') now_dir = osp.join(tmp_dir, 'files')
assert len(os.listdir(now_dir)) == STARE_LEN, \ assert len(os.listdir(now_dir)) == STARE_LEN, \
'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) f'len(os.listdir(now_dir)) != {STARE_LEN}'
for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(now_dir, filename)) img = mmcv.imread(osp.join(now_dir, filename))
@ -142,7 +142,7 @@ def main():
now_dir = osp.join(tmp_dir, 'files') now_dir = osp.join(tmp_dir, 'files')
assert len(os.listdir(now_dir)) == STARE_LEN, \ assert len(os.listdir(now_dir)) == STARE_LEN, \
'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) f'len(os.listdir(now_dir)) != {STARE_LEN}'
for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(now_dir, filename)) img = mmcv.imread(osp.join(now_dir, filename))

View File

@ -126,7 +126,7 @@ def pytorch2libtorch(model,
print(traced_model.graph) print(traced_model.graph)
traced_model.save(output_file) traced_model.save(output_file)
print('Successfully exported TorchScript model: {}'.format(output_file)) print(f'Successfully exported TorchScript model: {output_file}')
def parse_args(): def parse_args():

Some files were not shown because too many files have changed in this diff Show More