# Copyright (c) OpenMMLab. All rights reserved.
try:
    import timm
except ImportError:
    # timm is optional at import time; its absence is reported when a
    # TIMMBackbone is actually constructed.
    timm = None

from ..builder import BACKBONES
from .base_backbone import BaseBackbone


@BACKBONES.register_module()
class TIMMBackbone(BaseBackbone):
    """Backbone wrapper around models from the timm library.

    The wrapped model is built with ``timm.create_model`` and stored as
    ``self.timm_model``. More details can be found in
    `timm <https://github.com/huggingface/pytorch-image-models>`_ .

    Args:
        model_name (str): Name of timm model to instantiate.
        pretrained (bool): Load pretrained weights if True.
            Default: False.
        checkpoint_path (str): Path of checkpoint to load after the model
            is initialized. Default: ''.
        in_channels (int): Number of input image channels. Forwarded to
            timm as ``in_chans``. Default: 3.
        init_cfg (dict, optional): Initialization config dict, passed to
            the base class constructor. Default: None.
        **kwargs: Other timm & model specific arguments, forwarded
            unchanged to ``timm.create_model``.

    Raises:
        RuntimeError: If timm could not be imported.
    """

    def __init__(
        self,
        model_name,
        pretrained=False,
        checkpoint_path='',
        in_channels=3,
        init_cfg=None,
        **kwargs,
    ):
        # Fail before any base-class setup when the optional dependency
        # is missing.
        if timm is None:
            raise RuntimeError('timm is not installed')

        super().__init__(init_cfg)

        backbone = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            checkpoint_path=checkpoint_path,
            in_chans=in_channels,
            **kwargs,
        )
        self.timm_model = backbone

        # Make unused parameters None: forward() only calls
        # forward_features(), so these head-related attributes are cleared.
        for unused_attr in ('global_pool', 'fc', 'classifier'):
            setattr(backbone, unused_attr, None)

        # Hack to use pretrained weights from timm: flag the module as
        # already initialized when weights were requested.
        # NOTE(review): presumably this keeps a later weight-initialization
        # step from overwriting the loaded weights; confirm against the
        # base class.
        if pretrained or checkpoint_path:
            self._is_init = True

    def forward(self, x):
        """Return the result of the timm model's ``forward_features``.

        Args:
            x: Input passed straight through to
                ``self.timm_model.forward_features``.

        Returns:
            Whatever ``forward_features`` returns for ``x``.
        """
        return self.timm_model.forward_features(x)