mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
support CNNs
This commit is contained in:
parent
d6701d8a81
commit
c11f4c3218
@ -7,21 +7,28 @@ from torch.nn.modules.transformer import _get_activation_fn
|
||||
|
||||
|
||||
def add_ml_decoder_head(model):
|
||||
if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # resnet50
|
||||
if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50
|
||||
model.global_pool = nn.Identity()
|
||||
del model.fc
|
||||
num_classes = model.num_classes
|
||||
num_features = model.num_features
|
||||
model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
||||
elif hasattr(model, 'head'): # tresnet
|
||||
elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet
|
||||
model.global_pool = nn.Identity()
|
||||
del model.classifier
|
||||
num_classes = model.num_classes
|
||||
num_features = model.num_features
|
||||
model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
||||
elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head')
|
||||
del model.head
|
||||
num_classes = model.num_classes
|
||||
num_features = model.num_features
|
||||
model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
||||
else:
|
||||
print("model is not suited for ml-decoder")
|
||||
print("Model code-writing is not aligned currently with ml-decoder")
|
||||
exit(-1)
|
||||
|
||||
if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout
|
||||
model.drop_rate = 0
|
||||
return model
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user