@contextmanager
def make_meta_init(*classes):
    """Temporarily patch ``__init__`` of ``classes`` to construct on the meta device.

    While the context is active, every call to ``cls.__init__`` for a patched
    class has ``device="meta"`` injected into its keyword arguments. This
    overrides any ``device`` the caller passed, so each patched class must
    accept a ``device`` keyword argument.

    The original constructors are restored on exit, including when the body of
    the ``with`` statement raises.

    Args:
        *classes: Classes whose ``__init__`` should be patched. Duplicates are
            ignored.

    Yields:
        None.
    """
    # Local import so this block does not depend on a new file-level import.
    import functools

    # Sentinel for "this class did not define __init__ itself (it was inherited)".
    missing = object()

    def create_new_init(old_init):
        @functools.wraps(old_init)
        def new_init(self, *args, **kwargs):
            # Force the meta device regardless of what the caller requested.
            kwargs.update(device="meta")
            old_init(self, *args, **kwargs)
        return new_init

    saved = {}
    try:
        for cls in classes:
            if cls in saved:
                # Patching twice would record the wrapper as the "original".
                continue
            # Record only what the class itself defines. Reading cls.__init__
            # could return an already-patched parent wrapper via inheritance,
            # and restoring that would leave the subclass patched forever.
            saved[cls] = cls.__dict__.get("__init__", missing)
            cls.__init__ = create_new_init(cls.__init__)

        yield
    finally:
        # Restore original __init__() even if the with-body raised.
        for cls, own_init in saved.items():
            if own_init is missing:
                # The class inherited __init__; drop our override so normal
                # attribute lookup resolves to the parent again.
                del cls.__init__
            else:
                cls.__init__ = own_init
+ num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"]) + use_meta_init = ( + pretrained + and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"]) + ) + # Instantiate the model - if model_cfg is None: - model = model_cls(**kwargs) - else: - model = model_cls(cfg=model_cfg, **kwargs) + base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm] + with make_meta_init(*base_classes) if use_meta_init else nullcontext(): + if model_cfg is None: + model = model_cls(**kwargs) + else: + model = model_cls(cfg=model_cfg, **kwargs) + + # convert meta-device tensors to concrete tensors + device = kwargs.get("device", torch.get_default_device()) + model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t)) + model.pretrained_cfg = pretrained_cfg model.default_cfg = model.pretrained_cfg # alias for backwards compat