mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
support CNNs
This commit is contained in:
parent
d6701d8a81
commit
c11f4c3218
@ -7,21 +7,28 @@ from torch.nn.modules.transformer import _get_activation_fn
|
|||||||
|
|
||||||
|
|
||||||
def add_ml_decoder_head(model):
|
def add_ml_decoder_head(model):
|
||||||
if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # resnet50
|
if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50
|
||||||
model.global_pool = nn.Identity()
|
model.global_pool = nn.Identity()
|
||||||
del model.fc
|
del model.fc
|
||||||
num_classes = model.num_classes
|
num_classes = model.num_classes
|
||||||
num_features = model.num_features
|
num_features = model.num_features
|
||||||
model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
||||||
elif hasattr(model, 'head'): # tresnet
|
elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet
|
||||||
|
model.global_pool = nn.Identity()
|
||||||
|
del model.classifier
|
||||||
|
num_classes = model.num_classes
|
||||||
|
num_features = model.num_features
|
||||||
|
model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
||||||
|
elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head')
|
||||||
del model.head
|
del model.head
|
||||||
num_classes = model.num_classes
|
num_classes = model.num_classes
|
||||||
num_features = model.num_features
|
num_features = model.num_features
|
||||||
model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
||||||
else:
|
else:
|
||||||
print("model is not suited for ml-decoder")
|
print("Model code-writing is not aligned currently with ml-decoder")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout
|
||||||
|
model.drop_rate = 0
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user