override device kwargs of base nn classes

pull/2350/head
Thien Tran 2025-03-30 17:05:25 +08:00
parent e44f14d7d2
commit d70c481179
1 changed files with 44 additions and 5 deletions

View File

@ -1,10 +1,12 @@
from contextlib import contextmanager, nullcontext
import dataclasses
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
from torch import nn as nn
from torch.hub import load_state_dict_from_url
@ -360,6 +362,27 @@ def resolve_pretrained_cfg(
return pretrained_cfg
@contextmanager
def make_meta_init(*classes):
def create_new_init(cls):
old_init = cls.__init__
def new_init(self, *args, **kwargs):
kwargs.update(device="meta")
old_init(self, *args, **kwargs)
return new_init
original_dict = dict()
for cls in classes:
original_dict[cls] = cls.__init__
cls.__init__ = create_new_init(cls)
yield
# restore original __init__()
for cls, old_init in original_dict.items():
cls.__init__ = old_init
def build_model_with_cfg(
model_cls: Callable,
variant: str,
@ -419,11 +442,27 @@ def build_model_with_cfg(
if 'feature_cls' in kwargs:
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
# use meta-device init to speed up loading pretrained weights.
# when num_classes is changed, we can't use meta device init since we need
# the original __init__() to initialize head from scratch.
num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"])
use_meta_init = (
pretrained
and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"])
)
# Instantiate the model
if model_cfg is None:
model = model_cls(**kwargs)
else:
model = model_cls(cfg=model_cfg, **kwargs)
base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm]
with make_meta_init(*base_classes) if use_meta_init else nullcontext():
if model_cfg is None:
model = model_cls(**kwargs)
else:
model = model_cls(cfg=model_cfg, **kwargs)
# convert meta-device tensors to concrete tensors
device = kwargs.get("device", torch.get_default_device())
model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t))
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat