mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Finish timm mode api for efficientformer_v2, add grad checkpointing support to both efficientformers
This commit is contained in:
parent
9d03c6f526
commit
95ec255f7f
@ -20,6 +20,7 @@ import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
|
||||
@ -335,7 +336,10 @@ class EfficientFormerStage(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downsample(x)
|
||||
x = self.blocks(x)
|
||||
if self.grad_checkpointing:
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
x = self.blocks(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
|
||||
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
|
||||
@ -498,7 +499,10 @@ class EfficientFormerV2Stage(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downsample(x)
|
||||
x = self.blocks(x)
|
||||
if self.grad_checkpointing:
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
x = self.blocks(x)
|
||||
return x
|
||||
|
||||
|
||||
@ -508,6 +512,7 @@ class EfficientFormerV2(nn.Module):
|
||||
depths,
|
||||
in_chans=3,
|
||||
img_size=224,
|
||||
global_pool='avg',
|
||||
embed_dims=None,
|
||||
downsamples=None,
|
||||
mlp_ratios=4,
|
||||
@ -522,7 +527,9 @@ class EfficientFormerV2(nn.Module):
|
||||
distillation=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert global_pool in ('avg', '')
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.feature_info = []
|
||||
img_size = to_2tuple(img_size)
|
||||
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
|
||||
@ -583,11 +590,49 @@ class EfficientFormerV2(nn.Module):
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
matcher = dict(
|
||||
stem=r'^stem', # stem and embed
|
||||
blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
|
||||
)
|
||||
return matcher
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
for s in self.stages:
|
||||
s.grad_checkpointing = enable
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head, self.head_dist
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_distilled_training(self, enable=True):
|
||||
self.distilled_training = enable
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
x = self.norm(x)
|
||||
x = x.mean(dim=(2, 3))
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=(2, 3))
|
||||
if pre_logits:
|
||||
return x
|
||||
x, x_dist = self.head(x), self.head_dist(x)
|
||||
if self.distilled_training and self.training and not torch.jit.is_scripting():
|
||||
# only return separate classification predictions when training in distilled mode
|
||||
@ -596,6 +641,11 @@ class EfficientFormerV2(nn.Module):
|
||||
# during standard train/finetune, inference average the classifier predictions
|
||||
return (x + x_dist) / 2
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
|
Loading…
x
Reference in New Issue
Block a user