diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index 482d370a..cf80b455 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -1,10 +1,12 @@
+from contextlib import contextmanager, nullcontext
 import dataclasses
 import logging
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
+import torch
 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
 
@@ -360,6 +362,39 @@ def resolve_pretrained_cfg(
     return pretrained_cfg
 
 
+@contextmanager
+def make_meta_init(*classes):
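+    """Context manager that patches __init__() of the given module classes so
+    their parameters and buffers are constructed on the meta device (shape and
+    dtype only, no storage), skipping allocation and weight-init work.
+
+    A minimal usage sketch (MyModel stands in for any nn.Module subclass):
+
+        with make_meta_init(nn.Linear, nn.Conv2d):
+            model = MyModel()  # Linear/Conv2d params are meta tensors
+    """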
+    def create_new_init(cls):
+        old_init = cls.__init__
+        def new_init(self, *args, **kwargs):
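+            # force device='meta': tensors carry shape/dtype but no storage,
+            # so parameter allocation and weight init become no-ops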
+            kwargs.update(device="meta")
+            old_init(self, *args, **kwargs)
+        return new_init
+
+    original_dict = dict()
+    for cls in classes:
+        original_dict[cls] = cls.__init__
+        cls.__init__ = create_new_init(cls)
+
+    try:
+        yield
+    finally:
+        # restore the original __init__() even if model construction raises
+        for cls, old_init in original_dict.items():
+            cls.__init__ = old_init
+
+
 def build_model_with_cfg(
         model_cls: Callable,
         variant: str,
@@ -419,11 +454,31 @@ def build_model_with_cfg(
         if 'feature_cls' in kwargs:
             feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
 
+    # Use meta-device init to speed up loading pretrained weights: patched layers
+    # are created without allocating or initializing real tensors. When num_classes
+    # differs from the pretrained config we can't use meta-device init, since the
+    # original __init__() is needed to initialize the replacement head from scratch.
+    num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"])
+    use_meta_init = pretrained and num_classes in (0, pretrained_cfg["num_classes"])
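+    # e.g. create_model('resnet50', pretrained=True, num_classes=10) takes the
+    # normal init path so the new 10-class head gets properly initialized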
+
     # Instantiate the model
-    if model_cfg is None:
-        model = model_cls(**kwargs)
-    else:
-        model = model_cls(cfg=model_cfg, **kwargs)
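+    # NOTE: only these common parameterized layer types are patched for meta-device
+    # init; any other layer types are constructed and initialized as usual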
+    base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm]
+    with make_meta_init(*base_classes) if use_meta_init else nullcontext():
+        if model_cfg is None:
+            model = model_cls(**kwargs)
+        else:
+            model = model_cls(cfg=model_cfg, **kwargs)
+
+    if use_meta_init:
+        # materialize meta tensors as empty (uninitialized) tensors on the target
+        # device; real values are filled in when the pretrained weights are loaded
+        device = kwargs.get("device", torch.get_default_device())
+        model._apply(lambda t: torch.empty_like(t, device=device) if t.is_meta else t)
+
     model.pretrained_cfg = pretrained_cfg
     model.default_cfg = model.pretrained_cfg  # alias for backwards compat