fix feature_info.reduction

This commit is contained in:
Ryan 2024-12-18 21:12:40 +08:00
parent ea231079f5
commit ab0a70dfff
3 changed files with 9 additions and 9 deletions

View File

@@ -556,19 +556,19 @@ class DaVit(nn.Module):
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
for stage_idx in range(num_stages):
out_chs = embed_dims[stage_idx]
for i in range(num_stages):
out_chs = embed_dims[i]
stage = DaVitStage(
in_chs,
out_chs,
depth=depths[stage_idx],
downsample=stage_idx > 0,
depth=depths[i],
downsample=i > 0,
attn_types=attn_types,
num_heads=num_heads[stage_idx],
num_heads=num_heads[i],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rates=dpr[stage_idx],
drop_path_rates=dpr[i],
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
ffn=ffn,
@@ -579,7 +579,7 @@ class DaVit(nn.Module):
)
in_chs = out_chs
stages.append(stage)
self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
self.feature_info += [dict(num_chs=out_chs, reduction=2**(i+2), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)

View File

@@ -407,7 +407,7 @@ class EfficientFormer(nn.Module):
)
prev_dim = embed_dims[i]
stages.append(stage)
self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(i+2), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
# Classifier head

View File

@@ -541,7 +541,7 @@ class MetaFormer(nn.Module):
**kwargs,
)]
prev_dim = dims[i]
self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
self.feature_info += [dict(num_chs=dims[i], reduction=2**(i+2), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)