Merge pull request #2369 from brianhou0208/fix_reduction

Fix feature_info.reduction
openclip_weight_move
Ross Wightman 2024-12-18 16:51:53 -08:00 committed by GitHub
commit a02b1a8e79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 9 deletions

View File

@@ -556,19 +556,19 @@ class DaVit(nn.Module):
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
- for stage_idx in range(num_stages):
-     out_chs = embed_dims[stage_idx]
+ for i in range(num_stages):
+     out_chs = embed_dims[i]
stage = DaVitStage(
in_chs,
out_chs,
- depth=depths[stage_idx],
- downsample=stage_idx > 0,
+ depth=depths[i],
+ downsample=i > 0,
attn_types=attn_types,
- num_heads=num_heads[stage_idx],
+ num_heads=num_heads[i],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
- drop_path_rates=dpr[stage_idx],
+ drop_path_rates=dpr[i],
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
ffn=ffn,
@@ -579,7 +579,7 @@ class DaVit(nn.Module):
)
in_chs = out_chs
stages.append(stage)
- self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
+ self.feature_info += [dict(num_chs=out_chs, reduction=2**(i+2), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)

View File

@@ -407,7 +407,7 @@ class EfficientFormer(nn.Module):
)
prev_dim = embed_dims[i]
stages.append(stage)
- self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
+ self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(i+2), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
# Classifier head

View File

@@ -541,7 +541,7 @@ class MetaFormer(nn.Module):
**kwargs,
)]
prev_dim = dims[i]
- self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
+ self.feature_info += [dict(num_chs=dims[i], reduction=2**(i+2), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)