mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Finishing adding stochastic depth support to BiT ResNetV2 models
This commit is contained in:
parent
0a1668f63e
commit
d55bcc0fee
@ -249,6 +249,7 @@ class Bottleneck(nn.Module):
|
|||||||
x = self.norm2(x)
|
x = self.norm2(x)
|
||||||
x = self.conv3(x)
|
x = self.conv3(x)
|
||||||
x = self.norm3(x)
|
x = self.norm3(x)
|
||||||
|
x = self.drop_path(x)
|
||||||
x = self.act3(x + shortcut)
|
x = self.act3(x + shortcut)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -366,9 +367,10 @@ class ResNetV2(nn.Module):
|
|||||||
prev_chs = stem_chs
|
prev_chs = stem_chs
|
||||||
curr_stride = 4
|
curr_stride = 4
|
||||||
dilation = 1
|
dilation = 1
|
||||||
|
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
|
||||||
block_fn = PreActBottleneck if preact else Bottleneck
|
block_fn = PreActBottleneck if preact else Bottleneck
|
||||||
self.stages = nn.Sequential()
|
self.stages = nn.Sequential()
|
||||||
for stage_idx, (d, c) in enumerate(zip(layers, channels)):
|
for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
|
||||||
out_chs = make_div(c * wf)
|
out_chs = make_div(c * wf)
|
||||||
stride = 1 if stage_idx == 0 else 2
|
stride = 1 if stage_idx == 0 else 2
|
||||||
if curr_stride >= output_stride:
|
if curr_stride >= output_stride:
|
||||||
@ -376,7 +378,7 @@ class ResNetV2(nn.Module):
|
|||||||
stride = 1
|
stride = 1
|
||||||
stage = ResNetStage(
|
stage = ResNetStage(
|
||||||
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
|
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
|
||||||
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_fn=block_fn)
|
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
|
||||||
prev_chs = out_chs
|
prev_chs = out_chs
|
||||||
curr_stride *= stride
|
curr_stride *= stride
|
||||||
feat_name = f'stages.{stage_idx}'
|
feat_name = f'stages.{stage_idx}'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user