Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
fix feature_info.reduction

commit ab0a70dfff (parent ea231079f5)
@@ -556,19 +556,19 @@ class DaVit(nn.Module):
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         stages = []
-        for stage_idx in range(num_stages):
-            out_chs = embed_dims[stage_idx]
+        for i in range(num_stages):
+            out_chs = embed_dims[i]
             stage = DaVitStage(
                 in_chs,
                 out_chs,
-                depth=depths[stage_idx],
-                downsample=stage_idx > 0,
+                depth=depths[i],
+                downsample=i > 0,
                 attn_types=attn_types,
-                num_heads=num_heads[stage_idx],
+                num_heads=num_heads[i],
                 window_size=window_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
-                drop_path_rates=dpr[stage_idx],
+                drop_path_rates=dpr[i],
                 norm_layer=norm_layer,
                 norm_layer_cl=norm_layer_cl,
                 ffn=ffn,
@@ -579,7 +579,7 @@ class DaVit(nn.Module):
             )
             in_chs = out_chs
             stages.append(stage)
-            self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
+            self.feature_info += [dict(num_chs=out_chs, reduction=2**(i+2), module=f'stages.{i}')]

         self.stages = nn.Sequential(*stages)
|
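Why 2**(i+2): these hierarchical backbones downsample by 4 in the stem and by 2 at the entry of every stage after the first (visible as downsample=i > 0 in the hunk above), so stage i runs at 1/2**(i+2) of the input resolution. A minimal sketch of that arithmetic, assuming a stride-4 stem (the stem itself is not shown in this diff):

# Effective input stride ("reduction") per stage, assuming a stride-4 stem
# and a 2x downsample before every stage except the first.
stem_stride = 4
num_stages = 4
reductions = []
stride = stem_stride
for i in range(num_stages):
    if i > 0:
        stride *= 2
    reductions.append(stride)
print(reductions)                                 # [4, 8, 16, 32]
print([2 ** (i + 2) for i in range(num_stages)])  # [4, 8, 16, 32], the fixed formula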
@@ -407,7 +407,7 @@ class EfficientFormer(nn.Module):
             )
             prev_dim = embed_dims[i]
             stages.append(stage)
-            self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
+            self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(i+2), module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)

         # Classifier head
|
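Note the EfficientFormer case is an off-by-one rather than a constant: the old 2**(1+i) already grew per stage, but it reported every stage at half its actual stride, assuming the corrected values reflect the true downsampling:

# Old vs. fixed reduction values for a 4-stage model.
print([2 ** (1 + i) for i in range(4)])  # [2, 4, 8, 16], one stride level low
print([2 ** (i + 2) for i in range(4)])  # [4, 8, 16, 32], matches stem + stage strides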
@@ -541,7 +541,7 @@ class MetaFormer(nn.Module):
                 **kwargs,
             )]
             prev_dim = dims[i]
-            self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
+            self.feature_info += [dict(num_chs=dims[i], reduction=2**(i+2), module=f'stages.{i}')]

         self.stages = nn.Sequential(*stages)
|
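These feature_info entries feed timm's feature-extraction API, which is how FPN-style consumers learn each stage's stride. A usage sketch, assuming a timm build that includes this fix; create_model(..., features_only=True) and the FeatureInfo accessors are timm's standard interface, and davit_tiny is one of the registered DaVit configs:

import timm
import torch

# Expose the backbone's intermediate feature maps instead of logits.
model = timm.create_model('davit_tiny', features_only=True, pretrained=False)
print(model.feature_info.reduction())  # expected after this fix: [4, 8, 16, 32]
print(model.feature_info.channels())   # per-stage channel counts

# Each feature map's spatial size should be input_size / reduction.
x = torch.randn(1, 3, 224, 224)
for r, f in zip(model.feature_info.reduction(), model(x)):
    assert f.shape[-1] == 224 // r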