pull/2350/merge
Thien Tran 2025-04-23 22:31:03 +03:00 committed by GitHub
commit c0a54c4caa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 45 additions and 5 deletions
timm/models

View File

@ -1,10 +1,12 @@
from contextlib import contextmanager, nullcontext
import dataclasses
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
from torch import nn as nn
from torch.hub import load_state_dict_from_url
@ -370,6 +372,27 @@ def resolve_pretrained_cfg(
return pretrained_cfg
@contextmanager
def make_meta_init(*classes):
def create_new_init(cls):
old_init = cls.__init__
def new_init(self, *args, **kwargs):
kwargs.update(device="meta")
old_init(self, *args, **kwargs)
return new_init
original_dict = dict()
for cls in classes:
original_dict[cls] = cls.__init__
cls.__init__ = create_new_init(cls)
yield
# restore original __init__()
for cls, old_init in original_dict.items():
cls.__init__ = old_init
def build_model_with_cfg(
model_cls: Callable,
variant: str,
@ -429,11 +452,28 @@ def build_model_with_cfg(
if 'feature_cls' in kwargs:
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
# use meta-device init to speed up loading pretrained weights.
# when num_classes is changed, we can't use meta device init since we need
# the original __init__() to initialize head from scratch.
num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"])
use_meta_init = (
pretrained
and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"])
)
# Instantiate the model
if model_cfg is None:
model = model_cls(**kwargs)
else:
model = model_cls(cfg=model_cfg, **kwargs)
base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm]
with make_meta_init(*base_classes) if use_meta_init else nullcontext():
if model_cfg is None:
model = model_cls(**kwargs)
else:
model = model_cls(cfg=model_cfg, **kwargs)
# convert meta-device tensors to concrete tensors
default_device = torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
device = kwargs.get("device", default_device)
model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t))
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat