EasyCV/easycv/models/backbones/genet.py

1683 lines
56 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import uuid
import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import constant_init, kaiming_init
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from ..modelzoo import genet as model_urls
from ..registry import BACKBONES
GENET_LARGE = 'ConvKX(uuid9d1dca0f098143aaa1a947acf1100787|3,32,3,2)\
BN(uuid7d10ba10dc524ffb8863ae97c4a21797|32)RELU(uuidccd810d3d10a48158ccfa48ca975915c|32)\
SuperResKXKX(uuid5ba1db21fce64b16a34ad577c258fd6c|32,128,3,2,1.0,1)\
SuperResKXKX(uuida09fc4e4946444bf9b912f8c666c4b12|128,192,3,2,1.0,2)\
SuperResK1KXK1(uuidfa45c5f5cc96435dbd54801f31c83ca8|192,640,3,2,0.25,6)\
SuperResK1DWK1(uuid99bf6442b33643579dc680045da7549d|640,640,3,2,3.0,5)\
SuperResK1DWK1(uuid615cbfd4ed284cbc8589d84cbe9b0e92|640,640,3,1,3.0,4)\
ConvKX(uuid002fa25f74f14cdeb89a5aacd6ce64ff|640,2560,1,1)\
BN(uuidc5d6c88c326343efa2a8700907f87732|2560)RELU(uuidd2b39caab4cb4ac2b6905b18858c0037|2560)AdaptiveAvgPool(2560,1)'
GENET_NORMAL = 'ConvKX(uuid70de938099844017bd745349f7a1d35a|3,32,3,2)\
BN(uuid10f8a99f83294067bfdf5fc5a5c9bffd|32)\
RELU(uuideffe03bd73254e7c8027364ba71d25cd|32)\
SuperResKXKX(uuidb023bea8c7b34c22a1650e07dfc8e2c1|32,128,3,2,1.0,1)\
SuperResKXKX(uuidf829740023044b879eefaf7fc7d1ad8e|128,192,3,2,1.0,2)\
SuperResK1KXK1(uuid33bfe77cb8864357a840ca3341ea629a|192,640,3,2,0.25,6)\
SuperResK1DWK1(uuide2c948d819fb4869980e30d67a773244|640,640,3,2,3.0,4)\
SuperResK1DWK1(uuid53c308e481c24154b7a81fcbaf99edbf|640,640,3,1,3.0,1)\
ConvKX(uuidbc6953bfd8de45fc8534787a66b96430|640,2560,1,1)\
BN(uuida8acaaae74ed47a4a7514b41c643eb23|2560)RELU(uuida5d71c4fd5d24a7b848472f0383df467|2560)AdaptiveAvgPool(2560,1)'
GENET_SMALL = 'ConvKX(uuid46ff2328b77f40ff88aed69a5318d771|3,13,3,2)\
BN(uuid43b72f65311c42d9a1af485c594a6ab4|13)RELU(uuid282901aaa7f84b028e3c5bd7d37ae056|13)\
SuperResKXKX(uuiddb56d6f9a60b4455966e13b06a8ff723|13,48,3,2,1.0,1)\
SuperResKXKX(uuidd964406e6fdf4e9abac225afaeb1fe0b|48,48,3,2,1.0,3)\
SuperResK1KXK1(uuid39819ad4f4da405583de614af437b568|48,384,3,2,0.25,7)\
SuperResK1DWK1(uuid420593fe7b1e46f690b76bac3786d4b7|384,560,3,2,3.0,2)\
SuperResK1DWK1(uuid96236b3c50774f1ab2d3049d6aca6d85|560,256,3,1,3.0,1)\
ConvKX(uuid89ed263767a14f21b7426cccb120ad1d|256,1920,1,1)\
BN(uuidd6ad568b290544be9f4b47dc3fa271c9|1920)RELU(uuid823ced7441394fb9b3a96a5f7c40da2b|1920)AdaptiveAvgPool(1920,1)'
plainnet_struct_dict = {
'normal': GENET_NORMAL,
'large': GENET_LARGE,
'small': GENET_SMALL
}
# ------------ Fuse BN ------
def _fuse_convkx_and_bn_(convkx, bn):
the_weight_scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
convkx.weight[:] = convkx.weight * the_weight_scale.view((-1, 1, 1, 1))
the_bias_shift = (bn.weight * bn.running_mean) / \
torch.sqrt(bn.running_var + bn.eps)
bn.weight[:] = 1
bn.bias[:] = bn.bias - the_bias_shift
bn.running_var[:] = 1.0 - bn.eps
bn.running_mean[:] = 0.0
convkx.bias = nn.Parameter(bn.bias)
def remove_bn_in_superblock(super_block):
new_shortcut_list = []
for the_seq_list in super_block.shortcut_list:
assert isinstance(the_seq_list, nn.Sequential)
new_seq_list = []
last_block = None
for block in the_seq_list:
if isinstance(block, nn.BatchNorm2d):
_fuse_convkx_and_bn_(last_block, block)
else:
new_seq_list.append(block)
last_block = block
new_shortcut_list.append(nn.Sequential(*new_seq_list))
super_block.shortcut_list = nn.ModuleList(new_shortcut_list)
new_conv_list = []
for the_seq_list in super_block.conv_list:
assert isinstance(the_seq_list, nn.Sequential)
new_seq_list = []
last_block = None
for block in the_seq_list:
if isinstance(block, nn.BatchNorm2d):
_fuse_convkx_and_bn_(last_block, block)
else:
new_seq_list.append(block)
last_block = block
new_conv_list.append(nn.Sequential(*new_seq_list))
super_block.conv_list = nn.ModuleList(new_conv_list)
def fuse_bn(model):
the_block_list = model.block_list
last_block = the_block_list[0]
new_block_list = [last_block]
for the_block in the_block_list[1:]:
if isinstance(the_block, BN):
_fuse_convkx_and_bn_(last_block.netblock, the_block.netblock)
else:
new_block_list.append(the_block)
last_block = the_block
pass
the_block_list = new_block_list
for the_block in the_block_list:
if hasattr(the_block, 'shortcut_list'):
remove_bn_in_superblock(the_block)
else:
continue
model.block_list = new_block_list
model.module_list = nn.ModuleList(new_block_list)
return model
# ------------ end of fuse bn --------
def _create_netblock_list_from_str_(s, no_create=False):
block_list = []
while len(s) > 0:
is_found_block_class = False
for the_block_class_name in _all_netblocks_dict_.keys():
if s.startswith(the_block_class_name):
is_found_block_class = True
the_block_class = _all_netblocks_dict_[the_block_class_name]
the_block, remaining_s = the_block_class.create_from_str(
s, no_create=no_create)
if the_block is not None:
block_list.append(the_block)
s = remaining_s
if len(s) > 0 and s[0] == ';':
return block_list, s[1:]
break
pass # end if
pass # end for
assert is_found_block_class
pass # end while
return block_list, ''
def _get_right_parentheses_index_(s):
# assert s[0] == '('
left_paren_count = 0
for index, x in enumerate(s):
if x == '(':
left_paren_count += 1
elif x == ')':
left_paren_count -= 1
if left_paren_count == 0:
return index
else:
pass
return None
'''
-------------------- GENet Blocks --------------------
'''
class PlainNetBasicBlockClass(nn.Module):
def __init__(self,
in_channels=0,
out_channels=0,
stride=1,
no_create=False,
block_name=None,
**kwargs):
super(PlainNetBasicBlockClass, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.no_create = no_create
self.block_name = block_name
def forward(self, x):
return x
@staticmethod
def create_from_str(s, no_create=False):
assert PlainNetBasicBlockClass.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('PlainNetBasicBlockClass('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
param_str_split = param_str.split(',')
in_channels = int(param_str_split[0])
out_channels = int(param_str_split[1])
stride = int(param_str_split[2])
return PlainNetBasicBlockClass(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('PlainNetBasicBlockClass(') and s[-1] == ')':
return True
else:
return False
class AdaptiveAvgPool(PlainNetBasicBlockClass):
def __init__(self,
out_channels,
output_size,
no_create=False,
block_name=None,
**kwargs):
super(AdaptiveAvgPool, self).__init__(**kwargs)
self.in_channels = out_channels
self.out_channels = out_channels * output_size**2
self.output_size = output_size
self.block_name = block_name
if not no_create:
self.netblock = nn.AdaptiveAvgPool2d(
output_size=(self.output_size, self.output_size))
def forward(self, x):
return self.netblock(x)
@staticmethod
def create_from_str(s, no_create=False):
assert AdaptiveAvgPool.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('AdaptiveAvgPool('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
param_str_split = param_str.split(',')
out_channels = int(param_str_split[0])
output_size = int(param_str_split[1])
return AdaptiveAvgPool(
out_channels=out_channels,
output_size=output_size,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('AdaptiveAvgPool(') and s[-1] == ')':
return True
else:
return False
class BN(PlainNetBasicBlockClass):
def __init__(self,
out_channels=None,
copy_from=None,
no_create=False,
block_name=None,
**kwargs):
super(BN, self).__init__(**kwargs)
self.block_name = block_name
if copy_from is not None:
assert isinstance(copy_from, nn.BatchNorm2d)
self.in_channels = copy_from.weight.shape[0]
self.out_channels = copy_from.weight.shape[0]
assert out_channels is None or out_channels == self.out_channels
self.netblock = copy_from
else:
self.in_channels = out_channels
self.out_channels = out_channels
if no_create:
return
else:
self.netblock = nn.BatchNorm2d(num_features=self.out_channels)
def forward(self, x):
return self.netblock(x)
@staticmethod
def create_from_str(s, no_create=False):
assert BN.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('BN('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
out_channels = int(param_str)
return BN(
out_channels=out_channels,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('BN(') and s[-1] == ')':
return True
else:
return False
class ConvDW(PlainNetBasicBlockClass):
def __init__(self,
out_channels=None,
kernel_size=None,
stride=None,
copy_from=None,
no_create=False,
block_name=None,
**kwargs):
super(ConvDW, self).__init__(**kwargs)
self.block_name = block_name
self.use_weight_mean_zero_constrain = False
if copy_from is not None:
assert isinstance(copy_from, nn.Conv2d)
self.in_channels = copy_from.in_channels
self.out_channels = copy_from.out_channels
self.kernel_size = copy_from.kernel_size[0]
self.stride = copy_from.stride[0]
assert self.in_channels == self.out_channels
assert out_channels is None or out_channels == self.out_channels
assert kernel_size is None or kernel_size == self.kernel_size
assert stride is None or stride == self.stride
self.netblock = copy_from
else:
self.in_channels = out_channels
self.out_channels = out_channels
self.stride = stride
self.kernel_size = kernel_size
self.padding = (self.kernel_size - 1) // 2
if no_create or self.in_channels == 0 or self.out_channels == 0 or self.kernel_size == 0 \
or self.stride == 0:
return
else:
self.netblock = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
bias=False,
groups=self.in_channels)
def forward(self, x):
output = self.netblock(x)
return output
@staticmethod
def create_from_str(s, no_create=False):
assert ConvDW.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('ConvDW('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
split_str = param_str.split(',')
out_channels = int(split_str[0])
kernel_size = int(split_str[1])
stride = int(split_str[2])
return ConvDW(
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
no_create=no_create,
block_name=tmp_block_name), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('ConvDW(') and s[-1] == ')':
return True
else:
return False
class ConvKX(PlainNetBasicBlockClass):
def __init__(self,
in_channels=None,
out_channels=None,
kernel_size=None,
stride=None,
copy_from=None,
no_create=False,
block_name=None,
**kwargs):
super(ConvKX, self).__init__(**kwargs)
self.block_name = block_name
self.use_weight_mean_zero_constrain = False
if copy_from is not None:
assert isinstance(copy_from, nn.Conv2d)
self.in_channels = copy_from.in_channels
self.out_channels = copy_from.out_channels
self.kernel_size = copy_from.kernel_size[0]
self.stride = copy_from.stride[0]
assert in_channels is None or in_channels == self.in_channels
assert out_channels is None or out_channels == self.out_channels
assert kernel_size is None or kernel_size == self.kernel_size
assert stride is None or stride == self.stride
self.netblock = copy_from
else:
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.kernel_size = kernel_size
self.padding = (self.kernel_size - 1) // 2
if no_create or self.in_channels == 0 or self.out_channels == 0 or self.kernel_size == 0 \
or self.stride == 0:
return
else:
self.netblock = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
bias=False)
def forward(self, x):
output = self.netblock(x)
return output
@staticmethod
def create_from_str(s, no_create=False):
assert ConvKX.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('ConvKX('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
split_str = param_str.split(',')
in_channels = int(split_str[0])
out_channels = int(split_str[1])
kernel_size = int(split_str[2])
stride = int(split_str[3])
return ConvKX(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
no_create=no_create,
block_name=tmp_block_name), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('ConvKX(') and s[-1] == ')':
return True
else:
return False
class Flatten(PlainNetBasicBlockClass):
def __init__(self,
out_channels,
no_create=False,
block_name=None,
**kwargs):
super(Flatten, self).__init__(**kwargs)
self.block_name = block_name
self.in_channels = out_channels
self.out_channels = out_channels
def forward(self, x):
return torch.flatten(x, 1)
@staticmethod
def create_from_str(s, no_create=False):
assert Flatten.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('Flatten('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
out_channels = int(param_str)
return Flatten(
out_channels=out_channels,
no_create=no_create,
block_name=tmp_block_name), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('Flatten(') and s[-1] == ')':
return True
else:
return False
class Linear(PlainNetBasicBlockClass):
def __init__(self,
in_channels=None,
out_channels=None,
bias=None,
copy_from=None,
no_create=False,
block_name=None,
**kwargs):
super(Linear, self).__init__(**kwargs)
self.block_name = block_name
if copy_from is not None:
assert isinstance(copy_from, nn.Linear)
self.in_channels = copy_from.in_channels
self.out_channels = copy_from.out_channels
self.bias = copy_from.bias
assert in_channels is None or in_channels == self.in_channels
assert out_channels is None or out_channels == self.out_channels
assert bias is None or bias == self.bias
self.netblock = copy_from
else:
self.in_channels = in_channels
self.out_channels = out_channels
self.bias = bias
if not no_create:
self.netblock = nn.Linear(
self.in_channels, self.out_channels, bias=self.bias)
def forward(self, x):
return self.netblock(x)
@staticmethod
def create_from_str(s, no_create=False):
assert Linear.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('Linear('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
split_str = param_str.split(',')
in_channels = int(split_str[0])
out_channels = int(split_str[1])
bias = int(split_str[2])
return Linear(
in_channels=in_channels,
out_channels=out_channels,
bias=bias == 1,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('Linear(') and s[-1] == ')':
return True
else:
return False
class MaxPool(PlainNetBasicBlockClass):
def __init__(self,
out_channels,
kernel_size,
stride,
no_create=False,
block_name=None,
**kwargs):
super(MaxPool, self).__init__(**kwargs)
self.block_name = block_name
self.in_channels = out_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = (kernel_size - 1) // 2
if not no_create:
self.netblock = nn.MaxPool2d(
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding)
def forward(self, x):
return self.netblock(x)
@staticmethod
def create_from_str(s, no_create=False):
assert MaxPool.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('MaxPool('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
param_str_split = param_str.split(',')
out_channels = int(param_str_split[0])
kernel_size = int(param_str_split[1])
stride = int(param_str_split[2])
return MaxPool(
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
no_create=no_create,
block_name=tmp_block_name), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('MaxPool(') and s[-1] == ')':
return True
else:
return False
class MultiSumBlock(PlainNetBasicBlockClass):
def __init__(self,
inner_block_list,
no_create=False,
block_name=None,
**kwargs):
super(MultiSumBlock, self).__init__(**kwargs)
self.block_name = block_name
self.inner_block_list = inner_block_list
if not no_create:
self.inner_module_list = nn.ModuleList(inner_block_list)
self.in_channels = np.max([x.in_channels for x in inner_block_list])
self.out_channels = np.max([x.out_channels for x in inner_block_list])
res = 1024
res = self.inner_block_list[0].get_output_resolution(res)
self.stride = 1024 // res
def forward(self, x):
output = self.inner_block_list[0](x)
for inner_block in self.inner_block_list[1:]:
output2 = inner_block(x)
output = output + output2
return output
@staticmethod
def create_from_str(s, no_create=False):
assert MultiSumBlock.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('MultiSumBlock('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
the_s = param_str
the_inner_block_list = []
while len(the_s) > 0:
tmp_block_list, remaining_s = _create_netblock_list_from_str_(
the_s, no_create=no_create)
the_s = remaining_s
if tmp_block_list is None:
pass
elif len(tmp_block_list) == 1:
the_inner_block_list.append(tmp_block_list[0])
else:
the_inner_block_list.append(
Sequential(
inner_block_list=tmp_block_list, no_create=no_create))
pass # end while
if len(the_inner_block_list) == 0:
return None, s[idx + 1:]
return MultiSumBlock(
inner_block_list=the_inner_block_list,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('MultiSumBlock(') and s[-1] == ')':
return True
else:
return False
class RELU(PlainNetBasicBlockClass):
def __init__(self,
out_channels,
no_create=False,
block_name=None,
**kwargs):
super(RELU, self).__init__(**kwargs)
self.block_name = block_name
self.in_channels = out_channels
self.out_channels = out_channels
def forward(self, x):
return F.relu(x)
@staticmethod
def create_from_str(s, no_create=False):
assert RELU.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('RELU('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
out_channels = int(param_str)
return RELU(
out_channels=out_channels,
no_create=no_create,
block_name=tmp_block_name), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('RELU(') and s[-1] == ')':
return True
else:
return False
class ResBlock(PlainNetBasicBlockClass):
'''
ResBlock(in_channles, inner_blocks_str). If in_channels is missing, use inner_block_list[0].in_channels as in_channels
'''
def __init__(self,
inner_block_list,
in_channels=None,
stride=None,
no_create=False,
block_name=None,
**kwargs):
super(ResBlock, self).__init__(**kwargs)
self.block_name = block_name
self.inner_block_list = inner_block_list
self.stride = stride
if not no_create:
self.inner_module_list = nn.ModuleList(inner_block_list)
if in_channels is None:
self.in_channels = inner_block_list[0].in_channels
else:
self.in_channels = in_channels
self.out_channels = max(self.in_channels,
inner_block_list[-1].out_channels)
if self.stride is None:
tmp_input_res = 1024
tmp_output_res = self.get_output_resolution(tmp_input_res)
self.stride = tmp_input_res // tmp_output_res
def forward(self, x):
if self.stride > 1:
downsampled_x = F.avg_pool2d(
x,
kernel_size=self.stride + 1,
stride=self.stride,
padding=self.stride // 2)
else:
downsampled_x = x
if len(self.inner_block_list) == 0:
return downsampled_x
output = x
for inner_block in self.inner_block_list:
output = inner_block(output)
output = output + downsampled_x
return output
@staticmethod
def create_from_str(s, no_create=False):
assert ResBlock.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
the_stride = None
param_str = s[len('ResBlock('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
first_comma_index = param_str.find(',')
if first_comma_index < 0 or not param_str[0:first_comma_index].isdigit(
): # cannot parse in_channels, missing, use default
in_channels = None
the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
param_str, no_create=no_create)
else:
in_channels = int(param_str[0:first_comma_index])
param_str = param_str[first_comma_index + 1:]
second_comma_index = param_str.find(',')
if second_comma_index < 0 or not param_str[
0:second_comma_index].isdigit():
the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
param_str, no_create=no_create)
else:
the_stride = int(param_str[0:second_comma_index])
param_str = param_str[second_comma_index + 1:]
the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
param_str, no_create=no_create)
pass
pass
assert len(remaining_s) == 0
if the_inner_block_list is None or len(the_inner_block_list) == 0:
return None, s[idx + 1:]
return ResBlock(
inner_block_list=the_inner_block_list,
in_channels=in_channels,
stride=the_stride,
no_create=no_create,
block_name=tmp_block_name), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('ResBlock(') and s[-1] == ')':
return True
else:
return False
class Sequential(PlainNetBasicBlockClass):
def __init__(self,
inner_block_list,
no_create=False,
block_name=None,
**kwargs):
super(Sequential, self).__init__(**kwargs)
self.block_name = block_name
self.inner_block_list = inner_block_list
if not no_create:
self.inner_module_list = nn.ModuleList(inner_block_list)
self.in_channels = inner_block_list[0].in_channels
self.out_channels = inner_block_list[-1].out_channels
res = 1024
for block in self.inner_block_list:
res = block.get_output_resolution(res)
self.stride = 1024 // res
def forward(self, x):
output = x
for inner_block in self.inner_block_list:
output = inner_block(output)
return output
@staticmethod
def create_from_str(s, no_create=False):
assert Sequential.is_instance_from_str(s)
the_right_paraen_idx = _get_right_parentheses_index_(s)
param_str = s[len('Sequential(') + 1:the_right_paraen_idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
param_str, no_create=no_create)
assert len(remaining_s) == 0
if the_inner_block_list is None or len(the_inner_block_list) == 0:
return None, ''
return Sequential(
inner_block_list=the_inner_block_list,
no_create=no_create,
block_name=tmp_block_name), ''
@staticmethod
def is_instance_from_str(s):
if s.startswith('Sequential('):
return True
else:
return False
'''
Super Blocks
'''
class SuperResKXKX(PlainNetBasicBlockClass):
def __init__(self,
in_channels=0,
out_channels=0,
kernel_size=3,
stride=1,
expansion=1.0,
sublayers=1,
no_create=False,
block_name=None,
**kwargs):
super(SuperResKXKX, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.expansion = expansion
self.stride = stride
self.sublayers = sublayers
self.no_create = no_create
self.block_name = block_name
self.shortcut_list = nn.ModuleList()
self.conv_list = nn.ModuleList()
for layerID in range(self.sublayers):
if layerID == 0:
current_in_channels = self.in_channels
current_out_channels = self.out_channels
current_stride = self.stride
current_kernel_size = self.kernel_size
else:
current_in_channels = self.out_channels
current_out_channels = self.out_channels
current_stride = 1
current_kernel_size = self.kernel_size
current_expansion_channel = int(
round(current_out_channels * self.expansion))
the_conv_block = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_expansion_channel,
kernel_size=current_kernel_size,
stride=current_stride,
padding=(current_kernel_size - 1) // 2,
bias=False),
nn.BatchNorm2d(current_expansion_channel),
nn.ReLU(),
nn.Conv2d(
current_expansion_channel,
current_out_channels,
kernel_size=current_kernel_size,
stride=1,
padding=(current_kernel_size - 1) // 2,
bias=False),
nn.BatchNorm2d(current_out_channels),
)
self.conv_list.append(the_conv_block)
if current_stride == 1 and current_in_channels == current_out_channels:
shortcut = nn.Sequential()
else:
shortcut = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_out_channels,
kernel_size=1,
stride=current_stride,
padding=0,
bias=False), nn.BatchNorm2d(current_out_channels))
self.shortcut_list.append(shortcut)
pass # end for
def forward(self, x):
output = x
for block, shortcut in zip(self.conv_list, self.shortcut_list):
conv_output = block(output)
output = conv_output + shortcut(output)
output = F.relu(output)
return output
@staticmethod
def create_from_str(s, no_create=False):
assert SuperResKXKX.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('SuperResKXKX('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
param_str_split = param_str.split(',')
in_channels = int(param_str_split[0])
out_channels = int(param_str_split[1])
kernel_size = int(param_str_split[2])
stride = int(param_str_split[3])
expansion = float(param_str_split[4])
sublayers = int(param_str_split[5])
return SuperResKXKX(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
expansion=expansion,
sublayers=sublayers,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('SuperResKXKX(') and s[-1] == ')':
return True
else:
return False
class SuperResK1KX(PlainNetBasicBlockClass):
def __init__(self,
in_channels=0,
out_channels=0,
kernel_size=3,
stride=1,
expansion=1.0,
sublayers=1,
no_create=False,
block_name=None,
**kwargs):
super(SuperResK1KX, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.expansion = expansion
self.stride = stride
self.sublayers = sublayers
self.no_create = no_create
self.block_name = block_name
self.shortcut_list = nn.ModuleList()
self.conv_list = nn.ModuleList()
for layerID in range(self.sublayers):
if layerID == 0:
current_in_channels = self.in_channels
current_out_channels = self.out_channels
current_stride = self.stride
current_kernel_size = self.kernel_size
else:
current_in_channels = self.out_channels
current_out_channels = self.out_channels
current_stride = 1
current_kernel_size = self.kernel_size
current_expansion_channel = int(
round(current_out_channels * self.expansion))
the_conv_block = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_expansion_channel,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(current_expansion_channel),
nn.ReLU(),
nn.Conv2d(
current_expansion_channel,
current_out_channels,
kernel_size=current_kernel_size,
stride=current_stride,
padding=(current_kernel_size - 1) // 2,
bias=False),
nn.BatchNorm2d(current_out_channels),
)
self.conv_list.append(the_conv_block)
if current_stride == 1 and current_in_channels == current_out_channels:
shortcut = nn.Sequential()
else:
shortcut = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_out_channels,
kernel_size=1,
stride=current_stride,
padding=0,
bias=False), nn.BatchNorm2d(current_out_channels))
self.shortcut_list.append(shortcut)
pass # end for
def forward(self, x):
output = x
for block, shortcut in zip(self.conv_list, self.shortcut_list):
conv_output = block(output)
output = conv_output + shortcut(output)
output = F.relu(output)
return output
@staticmethod
def create_from_str(s, no_create=False):
assert SuperResK1KX.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('SuperResK1KX('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
param_str_split = param_str.split(',')
in_channels = int(param_str_split[0])
out_channels = int(param_str_split[1])
kernel_size = int(param_str_split[2])
stride = int(param_str_split[3])
expansion = float(param_str_split[4])
sublayers = int(param_str_split[5])
return SuperResK1KX(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
expansion=expansion,
sublayers=sublayers,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('SuperResK1KX(') and s[-1] == ')':
return True
else:
return False
class SuperResK1KXK1(PlainNetBasicBlockClass):
def __init__(self,
in_channels=0,
out_channels=0,
kernel_size=3,
stride=1,
expansion=1.0,
sublayers=1,
no_create=False,
block_name=None,
**kwargs):
super(SuperResK1KXK1, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.expansion = expansion
self.stride = stride
self.sublayers = sublayers
self.no_create = no_create
self.block_name = block_name
self.shortcut_list = nn.ModuleList()
self.conv_list = nn.ModuleList()
for layerID in range(self.sublayers):
if layerID == 0:
current_in_channels = self.in_channels
current_out_channels = self.out_channels
current_stride = self.stride
current_kernel_size = self.kernel_size
else:
current_in_channels = self.out_channels
current_out_channels = self.out_channels
current_stride = 1
current_kernel_size = self.kernel_size
current_expansion_channel = int(
round(current_out_channels * self.expansion))
the_conv_block = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_expansion_channel,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(current_expansion_channel),
nn.ReLU(),
nn.Conv2d(
current_expansion_channel,
current_expansion_channel,
kernel_size=current_kernel_size,
stride=current_stride,
padding=(current_kernel_size - 1) // 2,
bias=False),
nn.BatchNorm2d(current_expansion_channel),
nn.ReLU(),
nn.Conv2d(
current_expansion_channel,
current_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(current_out_channels),
)
self.conv_list.append(the_conv_block)
if current_stride == 1 and current_in_channels == current_out_channels:
shortcut = nn.Sequential()
else:
shortcut = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_out_channels,
kernel_size=1,
stride=current_stride,
padding=0,
bias=False), nn.BatchNorm2d(current_out_channels))
self.shortcut_list.append(shortcut)
pass # end for
def forward(self, x):
output = x
for block, shortcut in zip(self.conv_list, self.shortcut_list):
conv_output = block(output)
output = conv_output + shortcut(output)
output = F.relu(output)
return output
@staticmethod
def create_from_str(s, no_create=False):
assert SuperResK1KXK1.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('SuperResK1KXK1('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
param_str_split = param_str.split(',')
in_channels = int(param_str_split[0])
out_channels = int(param_str_split[1])
kernel_size = int(param_str_split[2])
stride = int(param_str_split[3])
expansion = float(param_str_split[4])
sublayers = int(param_str_split[5])
return SuperResK1KXK1(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
expansion=expansion,
sublayers=sublayers,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('SuperResK1KXK1(') and s[-1] == ')':
return True
else:
return False
class SuperResK1DWK1(PlainNetBasicBlockClass):
def __init__(self,
in_channels=0,
out_channels=0,
kernel_size=3,
stride=1,
expansion=1.0,
sublayers=1,
no_create=False,
block_name=None,
**kwargs):
super(SuperResK1DWK1, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.expansion = expansion
self.stride = stride
self.sublayers = sublayers
self.no_create = no_create
self.block_name = block_name
self.shortcut_list = nn.ModuleList()
self.conv_list = nn.ModuleList()
for layerID in range(self.sublayers):
if layerID == 0:
current_in_channels = self.in_channels
current_out_channels = self.out_channels
current_stride = self.stride
current_kernel_size = self.kernel_size
else:
current_in_channels = self.out_channels
current_out_channels = self.out_channels
current_stride = 1
current_kernel_size = self.kernel_size
current_expansion_channel = int(
round(current_out_channels * self.expansion))
the_conv_block = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_expansion_channel,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(current_expansion_channel),
nn.ReLU(),
nn.Conv2d(
current_expansion_channel,
current_expansion_channel,
kernel_size=current_kernel_size,
stride=current_stride,
padding=(current_kernel_size - 1) // 2,
bias=False,
groups=current_expansion_channel),
nn.BatchNorm2d(current_expansion_channel),
nn.ReLU(),
nn.Conv2d(
current_expansion_channel,
current_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(current_out_channels),
)
self.conv_list.append(the_conv_block)
if current_stride == 1 and current_in_channels == current_out_channels:
shortcut = nn.Sequential()
else:
shortcut = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_out_channels,
kernel_size=1,
stride=current_stride,
padding=0,
bias=False), nn.BatchNorm2d(current_out_channels))
self.shortcut_list.append(shortcut)
pass # end for
def forward(self, x):
output = x
for block, shortcut in zip(self.conv_list, self.shortcut_list):
conv_output = block(output)
output = conv_output + shortcut(output)
output = F.relu(output)
return output
@staticmethod
def create_from_str(s, no_create=False):
assert SuperResK1DWK1.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('SuperResK1DWK1('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
param_str_split = param_str.split(',')
in_channels = int(param_str_split[0])
out_channels = int(param_str_split[1])
kernel_size = int(param_str_split[2])
stride = int(param_str_split[3])
expansion = float(param_str_split[4])
sublayers = int(param_str_split[5])
return SuperResK1DWK1(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
expansion=expansion,
sublayers=sublayers,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('SuperResK1DWK1(') and s[-1] == ')':
return True
else:
return False
class SuperResK1DW(PlainNetBasicBlockClass):
def __init__(self,
in_channels=0,
out_channels=0,
kernel_size=3,
stride=1,
expansion=1.0,
sublayers=1,
no_create=False,
block_name=None,
**kwargs):
super(SuperResK1DW, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.expansion = expansion
assert abs(expansion - 1) < 1e-6
self.stride = stride
self.sublayers = sublayers
self.no_create = no_create
self.block_name = block_name
self.shortcut_list = nn.ModuleList()
self.conv_list = nn.ModuleList()
for layerID in range(self.sublayers):
if layerID == 0:
current_in_channels = self.in_channels
current_out_channels = self.out_channels
current_stride = self.stride
current_kernel_size = self.kernel_size
else:
current_in_channels = self.out_channels
current_out_channels = self.out_channels
current_stride = 1
current_kernel_size = self.kernel_size
current_expansion_channel = int(
round(current_out_channels * self.expansion))
the_conv_block = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(current_expansion_channel),
nn.ReLU(),
nn.Conv2d(
current_out_channels,
current_out_channels,
kernel_size=current_kernel_size,
stride=current_stride,
padding=(current_kernel_size - 1) // 2,
bias=False,
groups=current_out_channels),
nn.BatchNorm2d(current_out_channels),
)
self.conv_list.append(the_conv_block)
if current_stride == 1 and current_in_channels == current_out_channels:
shortcut = nn.Sequential()
else:
shortcut = nn.Sequential(
nn.Conv2d(
current_in_channels,
current_out_channels,
kernel_size=1,
stride=current_stride,
padding=0,
bias=False), nn.BatchNorm2d(current_out_channels))
self.shortcut_list.append(shortcut)
pass # end for
def forward(self, x):
output = x
for block, shortcut in zip(self.conv_list, self.shortcut_list):
conv_output = block(output)
output = conv_output + shortcut(output)
output = F.relu(output)
return output
@staticmethod
def create_from_str(s, no_create=False):
assert SuperResK1DW.is_instance_from_str(s)
idx = _get_right_parentheses_index_(s)
assert idx is not None
param_str = s[len('SuperResK1DW('):idx]
# find block_name
tmp_idx = param_str.find('|')
if tmp_idx < 0:
tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
else:
tmp_block_name = param_str[0:tmp_idx]
param_str = param_str[tmp_idx + 1:]
param_str_split = param_str.split(',')
in_channels = int(param_str_split[0])
out_channels = int(param_str_split[1])
kernel_size = int(param_str_split[2])
stride = int(param_str_split[3])
expansion = float(param_str_split[4])
sublayers = int(param_str_split[5])
return SuperResK1DW(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
expansion=expansion,
sublayers=sublayers,
block_name=tmp_block_name,
no_create=no_create), s[idx + 1:]
@staticmethod
def is_instance_from_str(s):
if s.startswith('SuperResK1DW(') and s[-1] == ')':
return True
else:
return False
_all_netblocks_dict_ = {
'AdaptiveAvgPool': AdaptiveAvgPool,
'BN': BN,
'ConvDW': ConvDW,
'ConvKX': ConvKX,
'Flatten': Flatten,
'Linear': Linear,
'MaxPool': MaxPool,
'MultiSumBlock': MultiSumBlock,
'PlainNetBasicBlockClass': PlainNetBasicBlockClass,
'RELU': RELU,
'ResBlock': ResBlock,
'Sequential': Sequential,
'SuperResKXKX': SuperResKXKX,
'SuperResK1KXK1': SuperResK1KXK1,
'SuperResK1DWK1': SuperResK1DWK1,
'SuperResK1KX': SuperResK1KX,
'SuperResK1DW': SuperResK1DW,
}
@BACKBONES.register_module
class PlainNet(nn.Module):
def __init__(self,
plainnet_struct_idx=None,
num_classes=0,
no_create=False,
**kwargs):
super(PlainNet, self).__init__(**kwargs)
self.num_classes = num_classes
self.plainnet_struct = plainnet_struct_dict[plainnet_struct_idx]
the_s = self.plainnet_struct # type: str
block_list, remaining_s = _create_netblock_list_from_str_(
the_s, no_create=no_create)
assert len(remaining_s) == 0
if isinstance(block_list[-1], AdaptiveAvgPool):
self.adptive_avg_pool = block_list[-1]
block_list.pop(-1)
else:
self.adptive_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.block_list = block_list
if not no_create:
self.module_list = nn.ModuleList(block_list) # register
self.last_channels = self.adptive_avg_pool.out_channels
if num_classes > 0:
self.fc_linear = nn.Linear(
self.last_channels, self.num_classes, bias=True)
else:
self.fc_linear = None
self.plainnet_struct = str(self) + str(self.adptive_avg_pool)
self.zero_init_residual = False
self.default_pretrained_model_path = model_urls[self.__class__.__name__
+ plainnet_struct_idx]
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', nonlinearity='relu')
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
def forward(self, x):
output = x
for the_block in self.block_list:
output = the_block(output)
if self.fc_linear is not None:
bs = output.size(0)
output = self.adptive_avg_pool(output)
output = output.view(bs, -1)
output = self.fc_linear(output)
return [output]