mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix two fastvit issues
This commit is contained in:
parent
5242ba6edc
commit
16334e4bec
@ -1171,7 +1171,7 @@ class FastVit(nn.Module):
|
||||
self.add_module(layer_name, layer)
|
||||
else:
|
||||
# Classifier head
|
||||
final_features = int(embed_dims[-1] * cls_ratio)
|
||||
self.num_features = final_features = int(embed_dims[-1] * cls_ratio)
|
||||
self.final_conv = MobileOneBlock(
|
||||
in_chs=embed_dims[-1],
|
||||
out_chs=final_features,
|
||||
@ -1182,7 +1182,6 @@ class FastVit(nn.Module):
|
||||
use_se=True,
|
||||
num_conv_branches=1,
|
||||
)
|
||||
self.num_features = final_features
|
||||
self.head = ClassifierHead(
|
||||
final_features,
|
||||
num_classes,
|
||||
@ -1241,11 +1240,10 @@ class FastVit(nn.Module):
|
||||
if self.fork_feat:
|
||||
# output the features of four stages for dense prediction
|
||||
return outs
|
||||
# output only the features of last layer for image classification
|
||||
x = self.final_conv(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
|
||||
x = self.final_conv(x)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -1266,6 +1264,7 @@ def _cfg(url="", **kwargs):
|
||||
"interpolation": "bicubic",
|
||||
"mean": IMAGENET_DEFAULT_MEAN,
|
||||
"std": IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.0.conv_kxk.0.conv',
|
||||
"classifier": "head.fc",
|
||||
**kwargs,
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user