mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Finishing adding stochastic depth support to BiT ResNetV2 models
This commit is contained in:
parent
0a1668f63e
commit
d55bcc0fee
@ -249,6 +249,7 @@ class Bottleneck(nn.Module):
|
||||
x = self.norm2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.norm3(x)
|
||||
x = self.drop_path(x)
|
||||
x = self.act3(x + shortcut)
|
||||
return x
|
||||
|
||||
@ -366,9 +367,10 @@ class ResNetV2(nn.Module):
|
||||
prev_chs = stem_chs
|
||||
curr_stride = 4
|
||||
dilation = 1
|
||||
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
|
||||
block_fn = PreActBottleneck if preact else Bottleneck
|
||||
self.stages = nn.Sequential()
|
||||
for stage_idx, (d, c) in enumerate(zip(layers, channels)):
|
||||
for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
|
||||
out_chs = make_div(c * wf)
|
||||
stride = 1 if stage_idx == 0 else 2
|
||||
if curr_stride >= output_stride:
|
||||
@ -376,7 +378,7 @@ class ResNetV2(nn.Module):
|
||||
stride = 1
|
||||
stage = ResNetStage(
|
||||
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
|
||||
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_fn=block_fn)
|
||||
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
|
||||
prev_chs = out_chs
|
||||
curr_stride *= stride
|
||||
feat_name = f'stages.{stage_idx}'
|
||||
|
Loading…
x
Reference in New Issue
Block a user