[Enhance] Migrate to MMCV DepthwiseSeparableConv (#158)

* Add D16-MG124 models

* Use MMCV DepthSepConv

* add OHEM

* add warmup

* fixed test

* fixed test

* change to bs 16

* revert config

* add models

* separate
Jerry Jiarui XU 2020-09-25 19:56:10 +08:00 committed by GitHub
parent efc5c20cd0
commit f3f443ff71
14 changed files with 55 additions and 556 deletions
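
The core of the change is an import migration: DepthwiseSeparableConvModule now
lives in mmcv.cnn instead of mmseg.ops, as the diffs below show. A minimal
before/after sketch (the channel sizes are made up):

# before this commit
from mmseg.ops import DepthwiseSeparableConvModule
# after this commit
from mmcv.cnn import DepthwiseSeparableConvModule

# 3x3 depthwise conv over 64 channels, then a 1x1 pointwise conv to 128
conv = DepthwiseSeparableConvModule(64, 128, 3, padding=1)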

View File

@@ -1,125 +0,0 @@
import argparse
import glob
import json
import os
import os.path as osp
import mmcv
# build schedule look-up table to automatically find the final model
SCHEDULES_LUT = {
'20ki': 20000,
'40ki': 40000,
'60ki': 60000,
'80ki': 80000,
'160ki': 160000
}
RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc']
def get_final_iter(config):
iter_num = SCHEDULES_LUT[config.split('_')[-2]]
return iter_num
def get_final_results(log_json_path, iter_num):
result_dict = dict()
with open(log_json_path, 'r') as f:
for line in f.readlines():
log_line = json.loads(line)
if 'mode' not in log_line.keys():
continue
if log_line['mode'] == 'train' and log_line['iter'] == iter_num:
result_dict['memory'] = log_line['memory']
if log_line['iter'] == iter_num:
result_dict.update({
key: log_line[key]
for key in RESULTS_LUT if key in log_line
})
return result_dict
def parse_args():
parser = argparse.ArgumentParser(description='Gather benchmarked models')
parser.add_argument(
'root',
type=str,
help='root path of benchmarked models to be gathered')
parser.add_argument(
'config',
type=str,
help='root path of benchmarked configs to be gathered')
args = parser.parse_args()
return args
def main():
args = parse_args()
models_root = args.root
config_name = args.config
# find all models in the root directory to be gathered
raw_configs = list(mmcv.scandir(config_name, '.py', recursive=True))
    # keep only configs that have been trained in the experiments dir
used_configs = []
for raw_config in raw_configs:
work_dir = osp.splitext(osp.basename(raw_config))[0]
if osp.exists(osp.join(models_root, work_dir)):
used_configs.append(work_dir)
    print(f'Found {len(used_configs)} models to be gathered')
    # find the final ckpt and log file for each trained config
# and parse the best performance
model_infos = []
for used_config in used_configs:
exp_dir = osp.join(models_root, used_config)
        # check whether the experiment is finished
final_iter = get_final_iter(used_config)
final_model = 'iter_{}.pth'.format(final_iter)
model_path = osp.join(exp_dir, final_model)
# skip if the model is still training
if not osp.exists(model_path):
print(f'{used_config} not finished yet')
continue
# get logs
log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0]
log_txt_path = glob.glob(osp.join(exp_dir, '*.log'))[0]
model_performance = get_final_results(log_json_path, final_iter)
        if not model_performance:  # no metrics parsed from the log
print(f'{used_config} does not have performance')
continue
model_time = osp.split(log_txt_path)[-1].split('.')[0]
model_infos.append(
dict(
config=used_config,
results=model_performance,
iters=final_iter,
model_time=model_time,
log_json_path=osp.split(log_json_path)[-1]))
# publish model for each checkpoint
for model in model_infos:
model_name = osp.split(model['config'])[-1].split('.')[0]
model_name += '_' + model['model_time']
for checkpoints in mmcv.scandir(
osp.join(models_root, model['config']), suffix='.pth'):
if checkpoints.endswith(f"iter_{model['iters']}.pth"
) or checkpoints.endswith('latest.pth'):
continue
print('removing {}'.format(
osp.join(models_root, model['config'], checkpoints)))
os.remove(osp.join(models_root, model['config'], checkpoints))
if __name__ == '__main__':
main()

View File

@@ -1,152 +0,0 @@
import argparse
import csv
import glob
import json
import os.path as osp
from collections import OrderedDict
import mmcv
# build results look-up table for parsing the final evaluation metrics
RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc']
def get_final_iter(config):
iter_num = config.split('_')[-2]
assert iter_num.endswith('ki')
return int(iter_num[:-2]) * 1000
def get_final_results(log_json_path, iter_num):
result_dict = dict()
with open(log_json_path, 'r') as f:
for line in f.readlines():
log_line = json.loads(line)
if 'mode' not in log_line.keys():
continue
if log_line['mode'] == 'train' and log_line[
'iter'] == iter_num - 50:
result_dict['memory'] = log_line['memory']
if log_line['iter'] == iter_num:
result_dict.update({
key: log_line[key] * 100
for key in RESULTS_LUT if key in log_line
})
return result_dict
def get_total_time(log_json_path, iter_num):
def convert(seconds):
hour = seconds // 3600
seconds %= 3600
minutes = seconds // 60
seconds %= 60
        return f'{hour:d}:{minutes:02d}:{seconds:02d}'
time_dict = dict()
with open(log_json_path, 'r') as f:
last_iter = 0
total_sec = 0
for line in f.readlines():
log_line = json.loads(line)
if 'mode' not in log_line.keys():
continue
if log_line['mode'] == 'train':
cur_iter = log_line['iter']
total_sec += (cur_iter - last_iter) * log_line['time']
last_iter = cur_iter
time_dict['time'] = convert(int(total_sec))
return time_dict
def parse_args():
parser = argparse.ArgumentParser(description='Gather benchmarked models')
parser.add_argument(
'root',
type=str,
help='root path of benchmarked models to be gathered')
parser.add_argument(
'config',
type=str,
help='root path of benchmarked configs to be gathered')
parser.add_argument(
'out', type=str, help='output path of gathered models to be stored')
args = parser.parse_args()
return args
def main():
args = parse_args()
models_root = args.root
models_out = args.out
config_name = args.config
mmcv.mkdir_or_exist(models_out)
# find all models in the root directory to be gathered
raw_configs = list(mmcv.scandir(config_name, '.py', recursive=True))
    # keep only configs that have been trained in the experiments dir
exp_dirs = []
for raw_config in raw_configs:
work_dir = osp.splitext(osp.basename(raw_config))[0]
if osp.exists(osp.join(models_root, work_dir)):
exp_dirs.append(work_dir)
    print(f'Found {len(exp_dirs)} models to be gathered')
    # find the final ckpt and log file for each trained config
# and parse the best performance
model_infos = []
for work_dir in exp_dirs:
exp_dir = osp.join(models_root, work_dir)
        # check whether the experiment is finished
final_iter = get_final_iter(work_dir)
final_model = 'iter_{}.pth'.format(final_iter)
model_path = osp.join(exp_dir, final_model)
# skip if the model is still training
if not osp.exists(model_path):
print(f'{model_path} not finished yet')
continue
# get logs
log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0]
model_performance = get_final_results(log_json_path, final_iter)
        if not model_performance:  # no metrics parsed from the log
            continue
head = work_dir.split('_')[0]
backbone = work_dir.split('_')[1]
crop_size = work_dir.split('_')[-3]
dataset = work_dir.split('_')[-1]
model_info = OrderedDict(
head=head,
backbone=backbone,
crop_size=crop_size,
dataset=dataset,
iters=f'{final_iter//1000}ki')
model_info.update(model_performance)
model_time = get_total_time(log_json_path, final_iter)
model_info.update(model_time)
model_info['config'] = work_dir
model_infos.append(model_info)
with open(
osp.join(models_out, 'models_table.csv'), 'w',
newline='') as csvfile:
writer = csv.writer(
csvfile, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
writer.writerow(model_infos[0].keys())
for model_info in model_infos:
writer.writerow(model_info.values())
if __name__ == '__main__':
main()

View File

@@ -1,58 +0,0 @@
import argparse
import os
import os.path as osp
import mmcv
from pytablewriter import Align, MarkdownTableWriter
def parse_args():
parser = argparse.ArgumentParser(description='Gather benchmarked models')
parser.add_argument('table_cache', type=str, help='table_cache input')
parser.add_argument('out', type=str, help='output path md')
args = parser.parse_args()
return args
def main():
args = parse_args()
table_cache = mmcv.load(args.table_cache)
output_dir = args.out
writer = MarkdownTableWriter()
writer.headers = [
'Method', 'Backbone', 'Crop Size', 'Lr schd', 'Mem (GB)',
'Inf time (fps)', 'mIoU', 'mIoU(ms+flip)', 'download'
]
writer.margin = 1
writer.align_list = [Align.CENTER] * len(writer.headers)
dataset_maps = {
'cityscapes': 'Cityscapes',
'ade20k': 'ADE20K',
'voc12aug': 'Pascal VOC 2012 + Aug'
}
for directory in table_cache:
for dataset in table_cache[directory]:
table = table_cache[directory][dataset][0]
writer.table_name = dataset_maps[dataset]
writer.value_matrix = table
for i in range(len(table)):
if table[i][-4] != '-':
table[i][-4] = f'{table[i][-4]:.2f}'
mmcv.mkdir_or_exist(osp.join(output_dir, directory))
writer.dump(
osp.join(output_dir, directory, f'README_{dataset}.md'))
with open(osp.join(output_dir, directory, 'README.md'), 'w') as dst_f:
for dataset in dataset_maps:
dataset_md_file = osp.join(output_dir, directory,
f'README_{dataset}.md')
with open(dataset_md_file) as src_f:
for line in src_f:
dst_f.write(line)
dst_f.write('\n')
os.remove(dataset_md_file)
if __name__ == '__main__':
main()

.dev/upload_modelzoo.py Normal file
View File

@@ -0,0 +1,44 @@
import argparse
import os
import os.path as osp
import oss2
ACCESS_KEY_ID = os.getenv('OSS_ACCESS_KEY_ID', None)
ACCESS_KEY_SECRET = os.getenv('OSS_ACCESS_KEY_SECRET', None)
BUCKET_NAME = 'openmmlab'
ENDPOINT = 'https://oss-accelerate.aliyuncs.com'
def parse_args():
parser = argparse.ArgumentParser(description='Upload models to OSS')
parser.add_argument('model_zoo', type=str, help='model_zoo input')
parser.add_argument(
'--dst-folder',
type=str,
default='mmsegmentation/v0.5',
help='destination folder')
args = parser.parse_args()
return args
def main():
args = parse_args()
model_zoo = args.model_zoo
dst_folder = args.dst_folder
bucket = oss2.Bucket(
oss2.Auth(ACCESS_KEY_ID, ACCESS_KEY_SECRET), ENDPOINT, BUCKET_NAME)
for root, dirs, files in os.walk(model_zoo):
for file in files:
file_path = osp.relpath(osp.join(root, file), model_zoo)
print(f'Uploading {file_path}')
oss2.resumable_upload(bucket, osp.join(dst_folder, file_path),
osp.join(model_zoo, file_path))
bucket.put_object_acl(
osp.join(dst_folder, file_path), oss2.OBJECT_ACL_PUBLIC_READ)
if __name__ == '__main__':
main()
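
The script reads OSS credentials from the environment and takes the gathered
model directory as its only positional argument; a hypothetical invocation
(the local directory name is made up):

export OSS_ACCESS_KEY_ID=xxx
export OSS_ACCESS_KEY_SECRET=xxx
python .dev/upload_modelzoo.py gathered_models --dst-folder mmsegmentation/v0.5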

View File

@@ -1,4 +1,3 @@
-from .dist_utils import allreduce_grads
 from .misc import add_prefix
-__all__ = ['add_prefix', 'allreduce_grads']
+__all__ = ['add_prefix']

View File

@@ -1,49 +0,0 @@
from collections import OrderedDict
import torch.distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
else:
buckets = OrderedDict()
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
buckets = buckets.values()
for bucket in buckets:
flat_tensors = _flatten_dense_tensors(bucket)
dist.all_reduce(flat_tensors)
flat_tensors.div_(world_size)
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
tensor.copy_(synced)
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
"""Allreduce gradients.
Args:
        params (list[torch.nn.Parameter]): List of parameters of a model.
coalesce (bool, optional): Whether allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
Defaults to -1.
"""
grads = [
param.grad.data for param in params
if param.requires_grad and param.grad is not None
]
world_size = dist.get_world_size()
if coalesce:
_allreduce_coalesced(grads, world_size, bucket_size_mb)
else:
for tensor in grads:
dist.all_reduce(tensor.div_(world_size))
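
A minimal usage sketch of the removed helper, assuming torch.distributed is
already initialized and a backward pass has populated the gradients (model,
inputs and optimizer are placeholders):

loss = model(inputs).sum()  # placeholder forward pass
loss.backward()
allreduce_grads(list(model.parameters()), coalesce=True, bucket_size_mb=-1)
optimizer.step()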

View File

@@ -1,10 +1,11 @@
 import torch
 import torch.nn as nn
-from mmcv.cnn import ConvModule, constant_init, kaiming_init
+from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
+                      kaiming_init)
 from torch.nn.modules.batchnorm import _BatchNorm
 from mmseg.models.decode_heads.psp_head import PPM
-from mmseg.ops import DepthwiseSeparableConvModule, resize
+from mmseg.ops import resize
 from mmseg.utils import InvertedResidual
 from ..builder import BACKBONES

View File

@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
-from mmcv.cnn import ConvModule
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
-from mmseg.ops import DepthwiseSeparableConvModule, resize
+from mmseg.ops import resize
 from ..builder import HEADS
 from .aspp_head import ASPPHead, ASPPModule

View File

@@ -1,4 +1,5 @@
-from mmseg.ops import DepthwiseSeparableConvModule
+from mmcv.cnn import DepthwiseSeparableConvModule
 from ..builder import HEADS
 from .fcn_head import FCNHead

View File

@@ -1,5 +1,4 @@
 from .encoding import Encoding
-from .separable_conv_module import DepthwiseSeparableConvModule
 from .wrappers import Upsample, resize
-__all__ = ['Upsample', 'resize', 'DepthwiseSeparableConvModule', 'Encoding']
+__all__ = ['Upsample', 'resize', 'Encoding']

View File

@@ -1,88 +0,0 @@
import torch.nn as nn
from mmcv.cnn import ConvModule
class DepthwiseSeparableConvModule(nn.Module):
"""Depthwise separable convolution module.
See https://arxiv.org/pdf/1704.04861.pdf for details.
    This module can replace a ConvModule, with its conv block split into two
    conv blocks: a depthwise conv block and a pointwise conv block. The
    depthwise conv block contains depthwise-conv/norm/activation layers, and
    the pointwise conv block contains pointwise-conv/norm/activation layers.
    Note that norm/activation layers are only added to each block if
    `norm_cfg` and `act_cfg` are specified.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d. Default: 1.
padding (int or tuple[int]): Same as nn.Conv2d. Default: 0.
dilation (int or tuple[int]): Same as nn.Conv2d. Default: 1.
norm_cfg (dict): Default norm config for both depthwise ConvModule and
pointwise ConvModule. Default: None.
act_cfg (dict): Default activation config for both depthwise ConvModule
and pointwise ConvModule. Default: dict(type='ReLU').
dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
'default', it will be the same as `norm_cfg`. Default: 'default'.
dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
'default', it will be the same as `act_cfg`. Default: 'default'.
pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
'default', it will be the same as `norm_cfg`. Default: 'default'.
pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
'default', it will be the same as `act_cfg`. Default: 'default'.
kwargs (optional): Other shared arguments for depthwise and pointwise
ConvModule. See ConvModule for ref.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
dw_norm_cfg='default',
dw_act_cfg='default',
pw_norm_cfg='default',
pw_act_cfg='default',
**kwargs):
super(DepthwiseSeparableConvModule, self).__init__()
assert 'groups' not in kwargs, 'groups should not be specified'
# if norm/activation config of depthwise/pointwise ConvModule is not
# specified, use default config.
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
# depthwise convolution
self.depthwise_conv = ConvModule(
in_channels,
in_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
norm_cfg=dw_norm_cfg,
act_cfg=dw_act_cfg,
**kwargs)
self.pointwise_conv = ConvModule(
in_channels,
out_channels,
1,
norm_cfg=pw_norm_cfg,
act_cfg=pw_act_cfg,
**kwargs)
def forward(self, x):
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x
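
A minimal usage sketch of the module, mirroring the expectations in the tests
below (the input shape is arbitrary); it applies equally to the mmcv.cnn
replacement:

import torch
from mmcv.cnn import DepthwiseSeparableConvModule

# 3x3 depthwise conv (one filter per input channel) + 1x1 pointwise conv
conv = DepthwiseSeparableConvModule(3, 8, 3, padding=1, norm_cfg=dict(type='BN'))
x = torch.rand(1, 3, 64, 64)
assert conv(x).shape == (1, 8, 64, 64)
assert conv.depthwise_conv.conv.groups == 3
assert conv.pointwise_conv.conv.kernel_size == (1, 1)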

View File

@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,detail,matplotlib,mmcv,numpy,onnxruntime,pytablewriter,pytest,scipy,torch
+known_third_party = PIL,cityscapesscripts,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY

View File

@@ -2,7 +2,7 @@ from unittest.mock import patch
 import pytest
 import torch
-from mmcv.cnn import ConvModule
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
 from mmcv.utils import ConfigDict
 from mmcv.utils.parrots_wrapper import SyncBatchNorm

@@ -557,7 +557,6 @@ def test_sep_fcn_head():
     output = head(x)
     assert output.shape == (2, head.num_classes, 32, 32)
     assert not head.concat_input
-    from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
     assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
     assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
     assert head.conv_seg.kernel_size == (1, 1)

@@ -573,7 +572,6 @@ def test_sep_fcn_head():
     output = head(x)
     assert output.shape == (3, head.num_classes, 32, 32)
     assert head.concat_input
-    from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
     assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
     assert isinstance(head.convs[1], DepthwiseSeparableConvModule)

View File

@@ -1,71 +0,0 @@
import pytest
import torch
import torch.nn as nn
from mmseg.ops import DepthwiseSeparableConvModule
def test_depthwise_separable_conv():
with pytest.raises(AssertionError):
        # groups should not be specified; the module sets groups=in_channels itself
DepthwiseSeparableConvModule(4, 8, 2, groups=2)
# test default config
conv = DepthwiseSeparableConvModule(3, 8, 2)
assert conv.depthwise_conv.conv.groups == 3
assert conv.pointwise_conv.conv.kernel_size == (1, 1)
assert not conv.depthwise_conv.with_norm
assert not conv.pointwise_conv.with_norm
assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
    # test dw_norm_cfg: norm applied to the depthwise conv only
conv = DepthwiseSeparableConvModule(3, 8, 2, dw_norm_cfg=dict(type='BN'))
assert conv.depthwise_conv.norm_name == 'bn'
assert not conv.pointwise_conv.with_norm
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
conv = DepthwiseSeparableConvModule(3, 8, 2, pw_norm_cfg=dict(type='BN'))
assert not conv.depthwise_conv.with_norm
assert conv.pointwise_conv.norm_name == 'bn'
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
    # test a custom layer order: ('norm', 'conv', 'act')
conv = DepthwiseSeparableConvModule(3, 8, 2, order=('norm', 'conv', 'act'))
x = torch.rand(1, 3, 256, 256)
output = conv(x)
assert output.shape == (1, 8, 255, 255)
conv = DepthwiseSeparableConvModule(
3, 8, 3, padding=1, with_spectral_norm=True)
assert hasattr(conv.depthwise_conv.conv, 'weight_orig')
assert hasattr(conv.pointwise_conv.conv, 'weight_orig')
output = conv(x)
assert output.shape == (1, 8, 256, 256)
conv = DepthwiseSeparableConvModule(
3, 8, 3, padding=1, padding_mode='reflect')
assert isinstance(conv.depthwise_conv.padding_layer, nn.ReflectionPad2d)
output = conv(x)
assert output.shape == (1, 8, 256, 256)
conv = DepthwiseSeparableConvModule(
3, 8, 3, padding=1, dw_act_cfg=dict(type='LeakyReLU'))
assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
output = conv(x)
assert output.shape == (1, 8, 256, 256)
conv = DepthwiseSeparableConvModule(
3, 8, 3, padding=1, pw_act_cfg=dict(type='LeakyReLU'))
assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
output = conv(x)
assert output.shape == (1, 8, 256, 256)