Merge c3445e9d75
into c8c4f256b8
commit
c0a54c4caa
|
@ -1,10 +1,12 @@
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.hub import load_state_dict_from_url
|
from torch.hub import load_state_dict_from_url
|
||||||
|
|
||||||
|
@ -370,6 +372,27 @@ def resolve_pretrained_cfg(
|
||||||
return pretrained_cfg
|
return pretrained_cfg
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def make_meta_init(*classes):
|
||||||
|
def create_new_init(cls):
|
||||||
|
old_init = cls.__init__
|
||||||
|
def new_init(self, *args, **kwargs):
|
||||||
|
kwargs.update(device="meta")
|
||||||
|
old_init(self, *args, **kwargs)
|
||||||
|
return new_init
|
||||||
|
|
||||||
|
original_dict = dict()
|
||||||
|
for cls in classes:
|
||||||
|
original_dict[cls] = cls.__init__
|
||||||
|
cls.__init__ = create_new_init(cls)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# restore original __init__()
|
||||||
|
for cls, old_init in original_dict.items():
|
||||||
|
cls.__init__ = old_init
|
||||||
|
|
||||||
|
|
||||||
def build_model_with_cfg(
|
def build_model_with_cfg(
|
||||||
model_cls: Callable,
|
model_cls: Callable,
|
||||||
variant: str,
|
variant: str,
|
||||||
|
@ -429,11 +452,28 @@ def build_model_with_cfg(
|
||||||
if 'feature_cls' in kwargs:
|
if 'feature_cls' in kwargs:
|
||||||
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
|
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
|
||||||
|
|
||||||
|
# use meta-device init to speed up loading pretrained weights.
|
||||||
|
# when num_classes is changed, we can't use meta device init since we need
|
||||||
|
# the original __init__() to initialize head from scratch.
|
||||||
|
num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"])
|
||||||
|
use_meta_init = (
|
||||||
|
pretrained
|
||||||
|
and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"])
|
||||||
|
)
|
||||||
|
|
||||||
# Instantiate the model
|
# Instantiate the model
|
||||||
|
base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm]
|
||||||
|
with make_meta_init(*base_classes) if use_meta_init else nullcontext():
|
||||||
if model_cfg is None:
|
if model_cfg is None:
|
||||||
model = model_cls(**kwargs)
|
model = model_cls(**kwargs)
|
||||||
else:
|
else:
|
||||||
model = model_cls(cfg=model_cfg, **kwargs)
|
model = model_cls(cfg=model_cfg, **kwargs)
|
||||||
|
|
||||||
|
# convert meta-device tensors to concrete tensors
|
||||||
|
default_device = torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
|
||||||
|
device = kwargs.get("device", default_device)
|
||||||
|
model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t))
|
||||||
|
|
||||||
model.pretrained_cfg = pretrained_cfg
|
model.pretrained_cfg = pretrained_cfg
|
||||||
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue