override device kwargs of base nn classes
parent
e44f14d7d2
commit
d70c481179
|
@ -1,10 +1,12 @@
|
|||
from contextlib import contextmanager, nullcontext
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.hub import load_state_dict_from_url
|
||||
|
||||
|
@ -360,6 +362,27 @@ def resolve_pretrained_cfg(
|
|||
return pretrained_cfg
|
||||
|
||||
|
||||
@contextmanager
|
||||
def make_meta_init(*classes):
|
||||
def create_new_init(cls):
|
||||
old_init = cls.__init__
|
||||
def new_init(self, *args, **kwargs):
|
||||
kwargs.update(device="meta")
|
||||
old_init(self, *args, **kwargs)
|
||||
return new_init
|
||||
|
||||
original_dict = dict()
|
||||
for cls in classes:
|
||||
original_dict[cls] = cls.__init__
|
||||
cls.__init__ = create_new_init(cls)
|
||||
|
||||
yield
|
||||
|
||||
# restore original __init__()
|
||||
for cls, old_init in original_dict.items():
|
||||
cls.__init__ = old_init
|
||||
|
||||
|
||||
def build_model_with_cfg(
|
||||
model_cls: Callable,
|
||||
variant: str,
|
||||
|
@ -419,11 +442,27 @@ def build_model_with_cfg(
|
|||
if 'feature_cls' in kwargs:
|
||||
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
|
||||
|
||||
# use meta-device init to speed up loading pretrained weights.
|
||||
# when num_classes is changed, we can't use meta device init since we need
|
||||
# the original __init__() to initialize head from scratch.
|
||||
num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"])
|
||||
use_meta_init = (
|
||||
pretrained
|
||||
and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"])
|
||||
)
|
||||
|
||||
# Instantiate the model
|
||||
if model_cfg is None:
|
||||
model = model_cls(**kwargs)
|
||||
else:
|
||||
model = model_cls(cfg=model_cfg, **kwargs)
|
||||
base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm]
|
||||
with make_meta_init(*base_classes) if use_meta_init else nullcontext():
|
||||
if model_cfg is None:
|
||||
model = model_cls(**kwargs)
|
||||
else:
|
||||
model = model_cls(cfg=model_cfg, **kwargs)
|
||||
|
||||
# convert meta-device tensors to concrete tensors
|
||||
device = kwargs.get("device", torch.get_default_device())
|
||||
model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t))
|
||||
|
||||
model.pretrained_cfg = pretrained_cfg
|
||||
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
||||
|
||||
|
|
Loading…
Reference in New Issue