Make drop_connect rate scaling match official impl. Fixes #14
parent
13c19e213d
commit
9d653b68a2
|
@ -318,7 +318,11 @@ class _BlockBuilder:
|
||||||
self.folded_bn = folded_bn
|
self.folded_bn = folded_bn
|
||||||
self.padding_same = padding_same
|
self.padding_same = padding_same
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
|
# updated during build
|
||||||
self.in_chs = None
|
self.in_chs = None
|
||||||
|
self.block_idx = 0
|
||||||
|
self.block_count = 0
|
||||||
|
|
||||||
def _round_channels(self, chs):
|
def _round_channels(self, chs):
|
||||||
return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
|
return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
|
||||||
|
@ -334,35 +338,40 @@ class _BlockBuilder:
|
||||||
# block act fn overrides the model default
|
# block act fn overrides the model default
|
||||||
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
|
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
|
||||||
assert ba['act_fn'] is not None
|
assert ba['act_fn'] is not None
|
||||||
if self.verbose:
|
|
||||||
logging.info(' Args: {}'.format(str(ba)))
|
|
||||||
# could replace this if with lambdas or functools binding if variety increases
|
|
||||||
if bt == 'ir':
|
if bt == 'ir':
|
||||||
ba['drop_connect_rate'] = self.drop_connect_rate
|
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||||
ba['se_gate_fn'] = self.se_gate_fn
|
ba['se_gate_fn'] = self.se_gate_fn
|
||||||
ba['se_reduce_mid'] = self.se_reduce_mid
|
ba['se_reduce_mid'] = self.se_reduce_mid
|
||||||
|
if self.verbose:
|
||||||
|
logging.info(' InvertedResidual {}, Args: {}'.format(self.block_idx, str(ba)))
|
||||||
block = InvertedResidual(**ba)
|
block = InvertedResidual(**ba)
|
||||||
elif bt == 'ds' or bt == 'dsa':
|
elif bt == 'ds' or bt == 'dsa':
|
||||||
ba['drop_connect_rate'] = self.drop_connect_rate
|
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
||||||
|
if self.verbose:
|
||||||
|
logging.info(' DepthwiseSeparable {}, Args: {}'.format(self.block_idx, str(ba)))
|
||||||
block = DepthwiseSeparableConv(**ba)
|
block = DepthwiseSeparableConv(**ba)
|
||||||
elif bt == 'cn':
|
elif bt == 'cn':
|
||||||
|
if self.verbose:
|
||||||
|
logging.info(' ConvBnAct {}, Args: {}'.format(self.block_idx, str(ba)))
|
||||||
block = ConvBnAct(**ba)
|
block = ConvBnAct(**ba)
|
||||||
else:
|
else:
|
||||||
assert False, 'Uknkown block type (%s) while building model.' % bt
|
assert False, 'Uknkown block type (%s) while building model.' % bt
|
||||||
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
|
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
|
||||||
|
|
||||||
return block
|
return block
|
||||||
|
|
||||||
def _make_stack(self, stack_args):
|
def _make_stack(self, stack_args):
|
||||||
blocks = []
|
blocks = []
|
||||||
# each stack (stage) contains a list of block arguments
|
# each stack (stage) contains a list of block arguments
|
||||||
for block_idx, ba in enumerate(stack_args):
|
for i, ba in enumerate(stack_args):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info(' Block: {}'.format(block_idx))
|
logging.info(' Block: {}'.format(i))
|
||||||
if block_idx >= 1:
|
if i >= 1:
|
||||||
# only the first block in any stack/stage can have a stride > 1
|
# only the first block in any stack can have a stride > 1
|
||||||
ba['stride'] = 1
|
ba['stride'] = 1
|
||||||
block = self._make_block(ba)
|
block = self._make_block(ba)
|
||||||
blocks.append(block)
|
blocks.append(block)
|
||||||
|
self.block_idx += 1 # incr global idx (across all stacks)
|
||||||
return nn.Sequential(*blocks)
|
return nn.Sequential(*blocks)
|
||||||
|
|
||||||
def __call__(self, in_chs, block_args):
|
def __call__(self, in_chs, block_args):
|
||||||
|
@ -377,6 +386,8 @@ class _BlockBuilder:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info('Building model trunk with %d stages...' % len(block_args))
|
logging.info('Building model trunk with %d stages...' % len(block_args))
|
||||||
self.in_chs = in_chs
|
self.in_chs = in_chs
|
||||||
|
self.block_count = sum([len(x) for x in block_args])
|
||||||
|
self.block_idx = 0
|
||||||
blocks = []
|
blocks = []
|
||||||
# outer list of block_args defines the stacks ('stages' by some conventions)
|
# outer list of block_args defines the stacks ('stages' by some conventions)
|
||||||
for stack_idx, stack in enumerate(block_args):
|
for stack_idx, stack in enumerate(block_args):
|
||||||
|
@ -1404,6 +1415,7 @@ def efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
""" EfficientNet-B0 """
|
""" EfficientNet-B0 """
|
||||||
default_cfg = default_cfgs['efficientnet_b0']
|
default_cfg = default_cfgs['efficientnet_b0']
|
||||||
# NOTE for train, drop_rate should be 0.2
|
# NOTE for train, drop_rate should be 0.2
|
||||||
|
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
||||||
model = _gen_efficientnet(
|
model = _gen_efficientnet(
|
||||||
channel_multiplier=1.0, depth_multiplier=1.0,
|
channel_multiplier=1.0, depth_multiplier=1.0,
|
||||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
@ -1418,6 +1430,7 @@ def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
""" EfficientNet-B1 """
|
""" EfficientNet-B1 """
|
||||||
default_cfg = default_cfgs['efficientnet_b1']
|
default_cfg = default_cfgs['efficientnet_b1']
|
||||||
# NOTE for train, drop_rate should be 0.2
|
# NOTE for train, drop_rate should be 0.2
|
||||||
|
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
||||||
model = _gen_efficientnet(
|
model = _gen_efficientnet(
|
||||||
channel_multiplier=1.0, depth_multiplier=1.1,
|
channel_multiplier=1.0, depth_multiplier=1.1,
|
||||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
@ -1432,6 +1445,7 @@ def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
""" EfficientNet-B2 """
|
""" EfficientNet-B2 """
|
||||||
default_cfg = default_cfgs['efficientnet_b2']
|
default_cfg = default_cfgs['efficientnet_b2']
|
||||||
# NOTE for train, drop_rate should be 0.3
|
# NOTE for train, drop_rate should be 0.3
|
||||||
|
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
||||||
model = _gen_efficientnet(
|
model = _gen_efficientnet(
|
||||||
channel_multiplier=1.1, depth_multiplier=1.2,
|
channel_multiplier=1.1, depth_multiplier=1.2,
|
||||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
@ -1446,6 +1460,7 @@ def efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
""" EfficientNet-B3 """
|
""" EfficientNet-B3 """
|
||||||
default_cfg = default_cfgs['efficientnet_b3']
|
default_cfg = default_cfgs['efficientnet_b3']
|
||||||
# NOTE for train, drop_rate should be 0.3
|
# NOTE for train, drop_rate should be 0.3
|
||||||
|
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
||||||
model = _gen_efficientnet(
|
model = _gen_efficientnet(
|
||||||
channel_multiplier=1.2, depth_multiplier=1.4,
|
channel_multiplier=1.2, depth_multiplier=1.4,
|
||||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
@ -1460,6 +1475,7 @@ def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
""" EfficientNet-B4 """
|
""" EfficientNet-B4 """
|
||||||
default_cfg = default_cfgs['efficientnet_b4']
|
default_cfg = default_cfgs['efficientnet_b4']
|
||||||
# NOTE for train, drop_rate should be 0.4
|
# NOTE for train, drop_rate should be 0.4
|
||||||
|
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
||||||
model = _gen_efficientnet(
|
model = _gen_efficientnet(
|
||||||
channel_multiplier=1.4, depth_multiplier=1.8,
|
channel_multiplier=1.4, depth_multiplier=1.8,
|
||||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
@ -1473,6 +1489,7 @@ def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
""" EfficientNet-B5 """
|
""" EfficientNet-B5 """
|
||||||
# NOTE for train, drop_rate should be 0.4
|
# NOTE for train, drop_rate should be 0.4
|
||||||
|
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
||||||
default_cfg = default_cfgs['efficientnet_b5']
|
default_cfg = default_cfgs['efficientnet_b5']
|
||||||
model = _gen_efficientnet(
|
model = _gen_efficientnet(
|
||||||
channel_multiplier=1.6, depth_multiplier=2.2,
|
channel_multiplier=1.6, depth_multiplier=2.2,
|
||||||
|
|
Loading…
Reference in New Issue