# Copyright (c) Alibaba, Inc. and its affiliates.
from distutils.version import LooseVersion

import timm
import torch
import torch.nn as nn

from easycv.utils import get_root_logger, load_checkpoint

from ..modelzoo import timm_models as model_urls
from ..registry import BACKBONES
from .shuffle_transformer import (shuffletrans_base_p4_w7_224,
                                  shuffletrans_small_p4_w7_224,
                                  shuffletrans_tiny_p4_w7_224)
from .swin_transformer_dynamic import (dynamic_swin_base_p4_w7_224,
                                       dynamic_swin_small_p4_w7_224,
                                       dynamic_swin_tiny_p4_w7_224)
from .vit_transfomer_dynamic import (dynamic_deit_small_p16,
                                     dynamic_deit_tiny_p16,
                                     dynamic_vit_base_p16,
                                     dynamic_vit_huge_p14,
                                     dynamic_vit_large_p16)
from .xcit_transformer import (xcit_large_24_p8, xcit_medium_24_p8,
                               xcit_medium_24_p16, xcit_small_12_p8,
                               xcit_small_12_p16)

# Mapping from model name to the local (non-timm) constructor.
_MODEL_MAP = {
    # shuffle_transformer
    'shuffletrans_tiny_p4_w7_224': shuffletrans_tiny_p4_w7_224,
    'shuffletrans_base_p4_w7_224': shuffletrans_base_p4_w7_224,
    'shuffletrans_small_p4_w7_224': shuffletrans_small_p4_w7_224,

    # swin_transformer_dynamic
    'dynamic_swin_tiny_p4_w7_224': dynamic_swin_tiny_p4_w7_224,
    'dynamic_swin_small_p4_w7_224': dynamic_swin_small_p4_w7_224,
    'dynamic_swin_base_p4_w7_224': dynamic_swin_base_p4_w7_224,

    # vit_transfomer_dynamic
    'dynamic_deit_small_p16': dynamic_deit_small_p16,
    'dynamic_deit_tiny_p16': dynamic_deit_tiny_p16,
    'dynamic_vit_base_p16': dynamic_vit_base_p16,
    'dynamic_vit_large_p16': dynamic_vit_large_p16,
    'dynamic_vit_huge_p14': dynamic_vit_huge_p14,

    # xcit_transformer
    'xcit_small_12_p16': xcit_small_12_p16,
    'xcit_small_12_p8': xcit_small_12_p8,
    'xcit_medium_24_p16': xcit_medium_24_p16,
    'xcit_medium_24_p8': xcit_medium_24_p8,
    'xcit_large_24_p8': xcit_large_24_p8,
}


@BACKBONES.register_module
class PytorchImageModelWrapper(nn.Module):
    """Support backbones from pytorch-image-models (timm).

    The PyTorch community has lots of awesome contributions for image models.
    PyTorch Image Models (timm) is a collection of image models that aims to
    pull together a wide variety of SOTA models with the ability to reproduce
    ImageNet training results.

    Model pages can be found at
    https://rwightman.github.io/pytorch-image-models/models/

    References: https://github.com/rwightman/pytorch-image-models
    """

    def __init__(self,
                 model_name='resnet50',
                 pretrained=False,
                 checkpoint_path=None,
                 scriptable=None,
                 exportable=None,
                 no_jit=None,
                 **kwargs):
        """Inits PytorchImageModelWrapper by timm.create_model.

        Args:
            model_name (str): name of the model to instantiate
            pretrained (bool): load pretrained ImageNet-1k weights if true
            checkpoint_path (str): path of a checkpoint to load after the
                model is initialized
            scriptable (bool): set layer config so that the model is jit
                scriptable (not working for all models yet)
            exportable (bool): set layer config so that the model is
                traceable / ONNX exportable (not fully implemented/obeyed yet)
            no_jit (bool): set layer config so that the model doesn't utilize
                jit scripted layers (so far activations only)
        """
        super(PytorchImageModelWrapper, self).__init__()

        timm_model_names = timm.list_models(pretrained=False)
        assert model_name in timm_model_names or model_name in _MODEL_MAP, \
            f'{model_name} is not in the model list of timm/fair, please check the model_name!'
        # Default to using the backbone without a classification head from timm
        if 'num_classes' not in kwargs:
            kwargs['num_classes'] = 0

        # create the model by timm
        if model_name in timm_model_names:
            try:
                if pretrained and (model_name in model_urls):
                    # build without timm's pretrained weights, then load them
                    # from the easycv modelzoo url
                    self.model = timm.create_model(model_name, False, '',
                                                   scriptable, exportable,
                                                   no_jit, **kwargs)
                    self.init_weights(model_urls[model_name])
                    print('Info: Load model from %s' % model_urls[model_name])

                    if checkpoint_path is not None:
                        self.init_weights(checkpoint_path)
                else:
                    # load from timm
                    if pretrained and model_name.startswith('swin_') and (
                            LooseVersion(
                                torch.__version__) <= LooseVersion('1.6.0')):
                        print(
                            'Warning: pretrained SwinTransformer weights from timm'
                            ' may hit a zipfile extraction error with torch<=1.6.0')
                    self.model = timm.create_model(model_name, pretrained,
                                                   checkpoint_path, scriptable,
                                                   exportable, no_jit, **kwargs)

            # TODO: delete this except branch after pytorch is upgraded to 1.7
            # in all production environments (dlc, dsw, studio, ev_predict_py3)
            except Exception:
                print(
                    f'Error: failed to create {model_name} with (pretrained={pretrained}, '
                    f'checkpoint_path={checkpoint_path} ...)')
                print(
                    f'Retry creating {model_name} with pretrained=False, '
                    'checkpoint_path=None and default params')
                self.model = timm.create_model(model_name, False, '', None,
                                               None, None, **kwargs)

        # facebook model wrapper
        if model_name in _MODEL_MAP:
            self.model = _MODEL_MAP[model_name](**kwargs)
            if pretrained:
                if model_name in model_urls.keys():
                    # retry the download a few times before giving up
                    try_max = 3
                    try_idx = 0
                    while try_idx < try_max:
                        try:
                            state_dict = torch.hub.load_state_dict_from_url(
                                url=model_urls[model_name],
                                map_location='cpu',
                            )
                            try_idx += try_max
                        except Exception:
                            try_idx += 1
                            state_dict = {}
                            if try_idx == try_max:
                                print(
                                    'Failed to load pretrained weights from url: %s'
                                    % model_urls[model_name])

                    # for some models, loading may still fail even with
                    # strict=False when the state dict does not exactly match
                    try:
                        self.model.load_state_dict(state_dict, strict=False)
                    except Exception:
                        print('Warning: pretrained state_dict of %s could not be'
                              ' fully loaded' % model_name)
                else:
                    print('%s not in evtorch modelzoo!' % model_name)

    def init_weights(self, pretrained=None):
        """Load weights from a checkpoint path or url offered by easycv."""
        if pretrained is not None:
            logger = get_root_logger()
            load_checkpoint(
                self.model,
                pretrained,
                map_location=torch.device('cpu'),
                strict=False,
                logger=logger)
        else:
            # weights are already initialized by timm
            pass

    def forward(self, x):
        o = self.model(x)
        if isinstance(o, (tuple, list)):
            features = []
            for feature in o:
                # pad every feature to 4 dims (N, C, H, W) so downstream
                # heads receive a consistent shape
                while feature.dim() < 4:
                    feature = feature.unsqueeze(-1)
                features.append(feature)
        else:
            while o.dim() < 4:
                o = o.unsqueeze(-1)
            features = [o]

        return tuple(features)
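

# A minimal usage sketch, not part of the original module: the choice of
# 'resnet50' and the 224x224 input size are assumptions. It builds the wrapper
# directly and runs a forward pass; in EasyCV configs the backbone is normally
# constructed through the BACKBONES registry rather than instantiated by hand.
if __name__ == '__main__':
    # num_classes defaults to 0 above, so only backbone features are returned
    backbone = PytorchImageModelWrapper(model_name='resnet50', pretrained=False)
    dummy_input = torch.randn(2, 3, 224, 224)  # assumed (N, C, H, W) image batch
    features = backbone(dummy_input)
    # forward() always yields a tuple of tensors padded to 4 dims
    print([f.shape for f in features])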