EasyCV/easycv/models/backbones/pytorch_image_models_wrappe...

186 lines
7.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
from distutils.version import LooseVersion
import timm
import torch
import torch.nn as nn
from timm.models.helpers import load_pretrained
from timm.models.hub import download_cached_file
from easycv.utils.logger import get_root_logger, print_log
from ..modelzoo import timm_models as model_urls
from ..registry import BACKBONES
from .shuffle_transformer import (shuffletrans_base_p4_w7_224,
shuffletrans_small_p4_w7_224,
shuffletrans_tiny_p4_w7_224)
from .swin_transformer_dynamic import (dynamic_swin_base_p4_w7_224,
dynamic_swin_small_p4_w7_224,
dynamic_swin_tiny_p4_w7_224)
from .vit_transfomer_dynamic import (dynamic_deit_small_p16,
dynamic_deit_tiny_p16,
dynamic_vit_base_p16,
dynamic_vit_huge_p14,
dynamic_vit_large_p16)
from .xcit_transformer import (xcit_large_24_p8, xcit_medium_24_p8,
xcit_medium_24_p16, xcit_small_12_p8,
xcit_small_12_p16)
# Lookup table of backbones implemented inside this package (i.e. not provided
# by timm): model name -> factory callable. PytorchImageModelWrapper consults
# it when the requested name is not a timm model.
_MODEL_MAP = {
    # --- shuffle_transformer ---
    'shuffletrans_tiny_p4_w7_224': shuffletrans_tiny_p4_w7_224,
    'shuffletrans_base_p4_w7_224': shuffletrans_base_p4_w7_224,
    'shuffletrans_small_p4_w7_224': shuffletrans_small_p4_w7_224,

    # --- swin_transformer_dynamic ---
    'dynamic_swin_tiny_p4_w7_224': dynamic_swin_tiny_p4_w7_224,
    'dynamic_swin_small_p4_w7_224': dynamic_swin_small_p4_w7_224,
    'dynamic_swin_base_p4_w7_224': dynamic_swin_base_p4_w7_224,

    # --- vit_transfomer_dynamic (module name is spelled this way) ---
    'dynamic_deit_small_p16': dynamic_deit_small_p16,
    'dynamic_deit_tiny_p16': dynamic_deit_tiny_p16,
    'dynamic_vit_base_p16': dynamic_vit_base_p16,
    'dynamic_vit_large_p16': dynamic_vit_large_p16,
    'dynamic_vit_huge_p14': dynamic_vit_huge_p14,

    # --- xcit_transformer ---
    'xcit_small_12_p16': xcit_small_12_p16,
    'xcit_small_12_p8': xcit_small_12_p8,
    'xcit_medium_24_p16': xcit_medium_24_p16,
    'xcit_medium_24_p8': xcit_medium_24_p8,
    'xcit_large_24_p8': xcit_large_24_p8,
}
@BACKBONES.register_module
class PytorchImageModelWrapper(nn.Module):
    """Support backbones from pytorch-image-models.

    The PyTorch community has lots of awesome contributions for image models.
    PyTorch Image Models (timm) is a collection of image models that aims to
    pull together a wide variety of SOTA models with the ability to reproduce
    ImageNet training results.

    Besides timm models, the names listed in ``_MODEL_MAP`` (backbones
    implemented in this package) are accepted as well.

    Model pages can be found at https://rwightman.github.io/pytorch-image-models/models/

    References: https://github.com/rwightman/pytorch-image-models
    """

    def __init__(self,
                 model_name='resnet50',
                 scriptable=None,
                 exportable=None,
                 no_jit=None,
                 **kwargs):
        """
        Inits PytorchImageModelWrapper by timm.create_model or by a factory
        from ``_MODEL_MAP``.

        Args:
            model_name (str): name of model to instantiate
            scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
            exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
            no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
            **kwargs: forwarded to the model factory. ``num_classes``
                defaults to 0 when it is not given.
        """
        super(PytorchImageModelWrapper, self).__init__()

        self.model_name = model_name

        timm_model_names = timm.list_models(pretrained=False)
        self.timm_model_names = timm_model_names
        assert model_name in timm_model_names or model_name in _MODEL_MAP, \
            f'{model_name} is not in model_list of timm/fair, please check the model_name!'

        # Default to use backbone without head from timm
        kwargs.setdefault('num_classes', 0)

        if model_name in timm_model_names:
            # Pass everything after the model name by keyword: the positional
            # order of timm.create_model is not the same in every timm
            # release, so positional arguments can bind to the wrong
            # parameters.
            self.model = timm.create_model(
                model_name,
                pretrained=False,
                checkpoint_path='',
                scriptable=scriptable,
                exportable=exportable,
                no_jit=no_jit,
                **kwargs)
        elif model_name in _MODEL_MAP:
            self.model = _MODEL_MAP[model_name](**kwargs)

    def init_weights(self, pretrained=None):
        """Initialize the wrapped model.

        Args:
            pretrained: if truthy, load pretrained weights from the default
                path registered for ``model_name`` in the EasyCV modelzoo;
                if False or None, call the model's own ``init_weights()``
                when it has one.

        Returns:
            For timm models with ``pretrained`` set, whatever the underlying
            loading routine returns; ``None`` otherwise.

        Raises:
            KeyError: timm model without an entry in the modelzoo.
            ValueError: ``_MODEL_MAP`` model without an entry in the
                modelzoo, or a model name that belongs to neither family.
        """
        logger = get_root_logger()

        if not pretrained:
            print_log('load model from init weights', logger)
            if hasattr(self.model, 'init_weights'):
                self.model.init_weights()
            else:
                # Not every backbone defines init_weights(); in that case the
                # weights produced by the constructor are left untouched.
                print_log(
                    f'{self.model_name} has no init_weights(), keep the '
                    'initialization done at construction', logger)
            return None

        if self.model_name in self.timm_model_names:
            return self._load_timm_pretrained(logger)

        if self.model_name in _MODEL_MAP:
            self._load_easycv_pretrained(logger)
            return None

        raise ValueError(
            'Error: Fail to create {} with (pretrained={}...)'.format(
                self.model_name, pretrained))

    def _load_timm_pretrained(self, logger):
        """Load the modelzoo checkpoint registered for a timm model."""
        default_pretrained_model_path = model_urls[self.model_name]
        print_log(
            'load model from default path: {}'.format(
                default_pretrained_model_path), logger)

        if default_pretrained_model_path.endswith('.npz'):
            # .npz checkpoints are fetched to the local cache and handed to
            # the model's own load_pretrained method.
            pretrained_loc = download_cached_file(
                default_pretrained_model_path,
                check_hash=False,
                progress=False)
            return self.model.load_pretrained(pretrained_loc)

        # Use the checkpoint filter defined next to the model class, if the
        # defining module provides one.
        backbone_module = importlib.import_module(self.model.__module__)
        return load_pretrained(
            self.model,
            default_cfg={'url': default_pretrained_model_path},
            filter_fn=getattr(backbone_module, 'checkpoint_filter_fn', None))

    def _load_easycv_pretrained(self, logger, try_max=3):
        """Load the modelzoo checkpoint registered for a ``_MODEL_MAP`` model.

        The download is attempted up to ``try_max`` times. If every attempt
        fails the failure is logged and an empty state dict is loaded
        (``strict=False``), i.e. the model keeps its current weights.
        """
        if self.model_name not in model_urls:
            raise ValueError('{} not in evtorch modelzoo!'.format(
                self.model_name))

        default_pretrained_model_path = model_urls[self.model_name]
        print_log(
            'load model from default path: {}'.format(
                default_pretrained_model_path), logger)

        state_dict = {}
        for attempt in range(1, try_max + 1):
            try:
                state_dict = torch.hub.load_state_dict_from_url(
                    url=default_pretrained_model_path,
                    map_location='cpu',
                )
                break
            except Exception as err:
                # Best effort: report the failure and retry instead of
                # aborting.
                print_log(
                    f'download attempt {attempt}/{try_max} failed: {err!r}',
                    logger)
                if attempt == try_max:
                    print_log(
                        f'load from url failed ! oh my DLC & OSS, you boys really good! {model_urls[self.model_name]}',
                        logger)

        # Some checkpoints wrap the weights in a top-level 'model' entry.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        self.model.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        """Run the wrapped model and return its output(s) as a tuple.

        A single output is wrapped into a one-element tuple. Every output
        with fewer than 4 dimensions gets trailing singleton dimensions
        appended until it is 4-D.
        """
        outputs = self.model(x)
        # isinstance also accepts tuple/list subclasses (e.g. namedtuples),
        # which an exact type comparison would treat as a single tensor.
        if not isinstance(outputs, (tuple, list)):
            outputs = [outputs]

        features = []
        for feature in outputs:
            while feature.dim() < 4:
                feature = feature.unsqueeze(-1)
            features.append(feature)

        return tuple(features)