mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Finish timm model API for efficientformer_v2, add grad checkpointing support to both efficientformers

parent 9d03c6f526
commit 95ec255f7f
timm/models/efficientformer.py

@@ -20,6 +20,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
 from ._pretrained import generate_default_cfgs
 from ._registry import register_model
@@ -335,6 +336,9 @@ class EfficientFormerStage(nn.Module):
 
     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing:
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x
 
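Both stage forwards now route their block sequence through checkpoint_seq when grad_checkpointing is set, recomputing block activations during backward instead of storing them. A minimal sketch of the idea using plain torch.utils.checkpoint (a simplified stand-in for checkpoint_seq, with made-up toy blocks and shapes, not timm's actual implementation):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def checkpoint_blocks(blocks: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
    # Run each block under activation checkpointing: outputs are not kept,
    # they are recomputed when backward reaches the block.
    for blk in blocks:
        x = checkpoint(blk, x)
    return x

# Toy stage-like usage mirroring the forward() above.
blocks = nn.Sequential(*[nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.GELU()) for _ in range(4)])
x = torch.randn(2, 32, 56, 56, requires_grad=True)
out = checkpoint_blocks(blocks, x)  # same values as blocks(x), lower peak activation memory
out.mean().backward()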
timm/models/efficientformer_v2.py

@@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
 from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
 from ._pretrained import generate_default_cfgs
 from ._registry import register_model
@@ -498,6 +499,9 @@ class EfficientFormerV2Stage(nn.Module):
 
     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing:
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x
 
@@ -508,6 +512,7 @@ class EfficientFormerV2(nn.Module):
             depths,
             in_chans=3,
             img_size=224,
+            global_pool='avg',
             embed_dims=None,
             downsamples=None,
             mlp_ratios=4,
@@ -522,7 +527,9 @@ class EfficientFormerV2(nn.Module):
             distillation=True,
     ):
         super().__init__()
+        assert global_pool in ('avg', '')
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.feature_info = []
         img_size = to_2tuple(img_size)
         norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
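With global_pool exposed as a constructor argument and stored on the model, the pooling behaviour can be chosen at creation time. A hedged usage sketch, assuming the efficientformerv2_s0 entrypoint from the timm registry and that create_model passes the keyword through to the constructor:

import timm

# 'avg' (the default) mean-pools the final feature map before the classifier heads;
# '' disables pooling, which is only useful together with num_classes=0.
model = timm.create_model('efficientformerv2_s0', pretrained=False, global_pool='avg', num_classes=10)

# reset_classifier (added in the next hunk) can later change both the head size and the pooling.
model.reset_classifier(num_classes=100, global_pool='avg')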
@@ -583,11 +590,49 @@ class EfficientFormerV2(nn.Module):
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
 
-    def forward(self, x):
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^stem',  # stem and embed
+            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head, self.head_dist
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.distilled_training = enable
+
+    def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
         x = self.norm(x)
-        x = x.mean(dim=(2, 3))
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(2, 3))
+        if pre_logits:
+            return x
         x, x_dist = self.head(x), self.head_dist(x)
         if self.distilled_training and self.training and not torch.jit.is_scripting():
             # only return separate classification predictions when training in distilled mode
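The no_weight_decay and group_matcher hooks added above feed timm's optimizer and scheduler helpers: the first names parameters to exclude from weight decay (the learned attention biases), the second maps parameter names to groups, e.g. for layer-wise learning-rate decay. A short hedged illustration of what they return, assuming the efficientformerv2_s0 entrypoint from the timm registry:

import timm

model = timm.create_model('efficientformerv2_s0', pretrained=False)

skip = model.no_weight_decay()
# set of parameter names containing 'attention_biases'

groups = model.group_matcher(coarse=False)
# {'stem': '^stem', 'blocks': [('^stages\\.(\\d+)', None), ('^norm', (99999,))]}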
@@ -596,6 +641,11 @@ class EfficientFormerV2(nn.Module):
             # during standard train/finetune, inference average the classifier predictions
             return (x + x_dist) / 2
 
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
 
 def _cfg(url='', **kwargs):
     return {
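Taken together these methods complete the standard timm model interface for EfficientFormerV2. A hedged usage sketch (the efficientformerv2_s0 model name is assumed from the timm registry; random weights are enough to exercise the API):

import timm
import torch

model = timm.create_model('efficientformerv2_s0', pretrained=False)
model.set_grad_checkpointing(True)  # sets grad_checkpointing on every stage (see the stage forward above)

x = torch.randn(2, 3, 224, 224)
feats = model.forward_features(x)                    # unpooled feature map, shape (B, C, H, W)
pooled = model.forward_head(feats, pre_logits=True)  # globally pooled features, classifier heads skipped
logits = model(x)                                    # outside distilled training: (head + head_dist) / 2

model.set_distilled_training(True)
model.train()
cls_logits, dist_logits = model(x)  # distilled training mode returns both heads separately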