# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib

import timm
import torch
import torch.nn as nn
from timm.models.helpers import load_pretrained
from timm.models.hub import download_cached_file

from easycv.utils.logger import get_root_logger, print_log

from ..modelzoo import timm_models as model_urls
from ..registry import BACKBONES
from .shuffle_transformer import (shuffletrans_base_p4_w7_224,
                                  shuffletrans_small_p4_w7_224,
                                  shuffletrans_tiny_p4_w7_224)
from .swin_transformer_dynamic import (dynamic_swin_base_p4_w7_224,
                                       dynamic_swin_small_p4_w7_224,
                                       dynamic_swin_tiny_p4_w7_224)
from .vit_transfomer_dynamic import (dynamic_deit_small_p16,
                                     dynamic_deit_tiny_p16,
                                     dynamic_vit_base_p16,
                                     dynamic_vit_huge_p14,
                                     dynamic_vit_large_p16)
from .xcit_transformer import (xcit_large_24_p8, xcit_medium_24_p8,
                               xcit_medium_24_p16, xcit_small_12_p8,
                               xcit_small_12_p16)

_MODEL_MAP = {
    # shuffle_transformer
    'shuffletrans_tiny_p4_w7_224': shuffletrans_tiny_p4_w7_224,
    'shuffletrans_base_p4_w7_224': shuffletrans_base_p4_w7_224,
    'shuffletrans_small_p4_w7_224': shuffletrans_small_p4_w7_224,

    # swin_transformer_dynamic
    'dynamic_swin_tiny_p4_w7_224': dynamic_swin_tiny_p4_w7_224,
    'dynamic_swin_small_p4_w7_224': dynamic_swin_small_p4_w7_224,
    'dynamic_swin_base_p4_w7_224': dynamic_swin_base_p4_w7_224,

    # vit_transfomer_dynamic
    'dynamic_deit_small_p16': dynamic_deit_small_p16,
    'dynamic_deit_tiny_p16': dynamic_deit_tiny_p16,
    'dynamic_vit_base_p16': dynamic_vit_base_p16,
    'dynamic_vit_large_p16': dynamic_vit_large_p16,
    'dynamic_vit_huge_p14': dynamic_vit_huge_p14,

    # xcit_transformer
    'xcit_small_12_p16': xcit_small_12_p16,
    'xcit_small_12_p8': xcit_small_12_p8,
    'xcit_medium_24_p16': xcit_medium_24_p16,
    'xcit_medium_24_p8': xcit_medium_24_p8,
    'xcit_large_24_p8': xcit_large_24_p8
}
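
# Illustrative note (not part of the original module): names listed in
# _MODEL_MAP above are constructed directly by their EasyCV constructors,
# e.g. (hypothetical usage)
#   backbone = _MODEL_MAP['dynamic_vit_base_p16'](num_classes=0)
# while any other valid timm name such as 'resnet50' is routed through
# timm.create_model inside PytorchImageModelWrapper.__init__.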


@BACKBONES.register_module
class PytorchImageModelWrapper(nn.Module):
    """Support backbones from pytorch-image-models (timm).

    The PyTorch community has contributed many excellent image models.
    PyTorch Image Models (timm) is a collection of these models that aims to
    pull together a wide variety of SOTA models with the ability to reproduce
    ImageNet training results.

    Model pages can be found at
    https://rwightman.github.io/pytorch-image-models/models/

    References: https://github.com/rwightman/pytorch-image-models
    """

    def __init__(self,
                 model_name='resnet50',
                 scriptable=None,
                 exportable=None,
                 no_jit=None,
                 **kwargs):
        """Init PytorchImageModelWrapper via timm.create_model, or via the
        corresponding EasyCV constructor for names in _MODEL_MAP.

        Args:
            model_name (str): name of model to instantiate
            scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
            exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
            no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
        """
        super(PytorchImageModelWrapper, self).__init__()

        self.model_name = model_name

        timm_model_names = timm.list_models(pretrained=False)
        self.timm_model_names = timm_model_names
        assert model_name in timm_model_names or model_name in _MODEL_MAP, \
            f'{model_name} is not in the model list of timm or _MODEL_MAP, please check the model_name!'

        # Default to use backbone without head from timm
        if 'num_classes' not in kwargs:
            kwargs['num_classes'] = 0

        # create model by timm (keyword arguments avoid positional mismatches
        # across timm versions)
        if model_name in timm_model_names:
            self.model = timm.create_model(
                model_name,
                pretrained=False,
                checkpoint_path='',
                scriptable=scriptable,
                exportable=exportable,
                no_jit=no_jit,
                **kwargs)
        elif model_name in _MODEL_MAP:
            self.model = _MODEL_MAP[model_name](**kwargs)
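
    # Illustrative note (not part of the original module): extra keyword
    # arguments given to the wrapper are forwarded unchanged to the underlying
    # constructor, e.g. (hypothetical usage)
    #   backbone = PytorchImageModelWrapper('resnet50', in_chans=1)
    # builds a single-channel ResNet-50 through timm, with the classification
    # head removed because num_classes defaults to 0 above.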

    def init_weights(self, pretrained=None):
        """Initialize or load the backbone weights.

        Args:
            pretrained: if True, load weights from the default pretrained
                path; if False or None, fall back to the model's own weight
                initialization.

        If model_name is in timm_model_names, weights are loaded from timm's
        default path; if model_name is in _MODEL_MAP, weights are loaded from
        EasyCV's default path.
        """
        logger = get_root_logger()
        if pretrained:
            if self.model_name in self.timm_model_names:
                default_pretrained_model_path = model_urls[self.model_name]
                print_log(
                    'load model from default path: {}'.format(
                        default_pretrained_model_path), logger)
                if default_pretrained_model_path.endswith('.npz'):
                    # .npz checkpoints (e.g. the original JAX ViT weights) are
                    # handled by timm's own npz loader.
                    pretrained_loc = download_cached_file(
                        default_pretrained_model_path,
                        check_hash=False,
                        progress=False)
                    return self.model.load_pretrained(pretrained_loc)
                else:
                    backbone_module = importlib.import_module(
                        self.model.__module__)
                    return load_pretrained(
                        self.model,
                        default_cfg={'url': default_pretrained_model_path},
                        filter_fn=backbone_module.checkpoint_filter_fn
                        if hasattr(backbone_module, 'checkpoint_filter_fn')
                        else None)
            elif self.model_name in _MODEL_MAP:
                if self.model_name in model_urls.keys():
                    default_pretrained_model_path = model_urls[self.model_name]
                    print_log(
                        'load model from default path: {}'.format(
                            default_pretrained_model_path), logger)
                    # Retry a few times to tolerate transient download failures.
                    try_max = 3
                    state_dict = {}
                    for try_idx in range(try_max):
                        try:
                            state_dict = torch.hub.load_state_dict_from_url(
                                url=default_pretrained_model_path,
                                map_location='cpu',
                            )
                            break
                        except Exception:
                            state_dict = {}
                            if try_idx == try_max - 1:
                                print_log(
                                    f'failed to load pretrained model from {model_urls[self.model_name]} after {try_max} attempts',
                                    logger)

                    if 'model' in state_dict:
                        state_dict = state_dict['model']
                    self.model.load_state_dict(state_dict, strict=False)
                else:
                    raise ValueError('{} not in easycv modelzoo!'.format(
                        self.model_name))
            else:
                raise ValueError(
                    'Error: failed to create {} with pretrained={}'.format(
                        self.model_name, pretrained))
        else:
            # Pretrained weights were not requested: rely on the wrapped
            # model's own init_weights() for initialization.
            print_log('load model from init weights')
            self.model.init_weights()

    def forward(self, x):
        """Run the wrapped backbone and normalize its outputs to a tuple of
        feature tensors, each padded to 4 dimensions."""
        o = self.model(x)
        if isinstance(o, (tuple, list)):
            features = []
            for feature in o:
                while feature.dim() < 4:
                    feature = feature.unsqueeze(-1)
                features.append(feature)
        else:
            while o.dim() < 4:
                o = o.unsqueeze(-1)
            features = [o]

        return tuple(features)
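

# The block below is an illustrative usage sketch, not part of the original
# module. It assumes a working EasyCV install with timm available and that
# 'resnet50' is a valid timm model name.
if __name__ == '__main__':
    # num_classes defaults to 0, so the classification head is stripped and
    # forward() returns backbone features.
    backbone = PytorchImageModelWrapper(model_name='resnet50')
    feats = backbone(torch.randn(2, 3, 224, 224))
    # forward() always returns a tuple of tensors padded to 4 dimensions,
    # e.g. [torch.Size([2, 2048, 1, 1])] for resnet50.
    print([f.shape for f in feats])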