Fix two fastvit issues

This commit is contained in:
Ross Wightman 2023-08-23 16:20:25 -07:00 committed by Ross Wightman
parent 5242ba6edc
commit 16334e4bec

View File

@ -1171,7 +1171,7 @@ class FastVit(nn.Module):
self.add_module(layer_name, layer)
else:
# Classifier head
final_features = int(embed_dims[-1] * cls_ratio)
self.num_features = final_features = int(embed_dims[-1] * cls_ratio)
self.final_conv = MobileOneBlock(
in_chs=embed_dims[-1],
out_chs=final_features,
@ -1182,7 +1182,6 @@ class FastVit(nn.Module):
use_se=True,
num_conv_branches=1,
)
self.num_features = final_features
self.head = ClassifierHead(
final_features,
num_classes,
@ -1241,11 +1240,10 @@ class FastVit(nn.Module):
if self.fork_feat:
# output the features of four stages for dense prediction
return outs
# output only the features of last layer for image classification
x = self.final_conv(x)
return x
def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
x = self.final_conv(x)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -1266,6 +1264,7 @@ def _cfg(url="", **kwargs):
"interpolation": "bicubic",
"mean": IMAGENET_DEFAULT_MEAN,
"std": IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0.conv_kxk.0.conv',
"classifier": "head.fc",
**kwargs,
}