Fix MobileNetV3 crash with global_pool='', output consistent with other models but not equivalent due to efficient head.
parent
fc8b8afb6f
commit
470220b1f4
|
@ -99,10 +99,11 @@ def test_model_default_cfgs(model_name, batch_size):
|
|||
assert outputs.shape[-1] == model.num_features
|
||||
|
||||
# test model forward without pooling and classifier
|
||||
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
|
||||
outputs = model.forward(input_tensor)
|
||||
assert len(outputs.shape) == 4
|
||||
if not isinstance(model, timm.models.MobileNetV3):
|
||||
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
|
||||
outputs = model.forward(input_tensor)
|
||||
assert len(outputs.shape) == 4
|
||||
# FIXME mobilenetv3 forward_features vs removed pooling differ
|
||||
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
||||
|
||||
# check classifier and first convolution names match those in default_cfg
|
||||
|
|
|
@ -101,7 +101,7 @@ class MobileNetV3(nn.Module):
|
|||
head_chs = builder.in_chs
|
||||
|
||||
# Head + Pooling
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if global_pool else nn.Identity()
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
num_pooled_chs = head_chs * self.global_pool.feat_mult()
|
||||
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
|
@ -122,7 +122,7 @@ class MobileNetV3(nn.Module):
|
|||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.num_classes = num_classes
|
||||
# cannot meaningfully change pooling of efficient head after creation
|
||||
assert global_pool == self.global_pool.pool_type
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
|
@ -136,7 +136,9 @@ class MobileNetV3(nn.Module):
|
|||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x).flatten(1)
|
||||
x = self.forward_features(x)
|
||||
if not self.global_pool.is_identity():
|
||||
x = x.flatten(1)
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
|
|
Loading…
Reference in New Issue