Mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

update group_matcher

parent 7fc0692843
commit 89d2952375
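The group_matcher() methods updated in this commit return regex specs that timm uses to bucket parameters into ordered layer groups (for layer-wise LR decay, selective freezing, and similar). Below is a minimal sketch of how such a spec can be matched against parameter names; the bucket_params helper is illustrative only, not timm's own grouping API:

import re
from typing import Any, Dict, List

def bucket_params(param_names: List[str], matcher: Dict[str, Any]) -> Dict[str, List[str]]:
    # Illustrative only: show which group each parameter name falls into.
    groups: Dict[str, List[str]] = {'stem': [], 'blocks': [], 'other': []}
    stem_re = re.compile(matcher['stem'])
    blocks = matcher['blocks']
    block_res = [re.compile(blocks)] if isinstance(blocks, str) else [re.compile(p) for p, _ in blocks]
    for name in param_names:
        if stem_re.match(name):
            groups['stem'].append(name)
        elif any(r.match(name) for r in block_res):
            groups['blocks'].append(name)
        else:
            groups['other'].append(name)
    return groups

# Example with the fine-grained (coarse=False) FasterNet matcher from this commit:
matcher = dict(
    stem=r'^patch_embed',
    blocks=[
        (r'^stages\.(\d+).downsample', (0,)),
        (r'^stages\.(\d+)\.blocks\.(\d+)', None),
        (r'^conv_head', (99999,)),
    ],
)
print(bucket_params(['patch_embed.proj.weight', 'stages.1.downsample.norm.weight', 'head.fc.weight'], matcher))

In the fine-grained spec the stage and block indices come from the regex capture groups, which is what gives the per-block entries a usable ordering for layer decay.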
@@ -16,7 +16,7 @@ Modifications by / Copyright 2025 Ryan Hou & Ross Wightman, original copyrights
 # Licensed under the MIT License.

 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

 import torch
 import torch.nn as nn
@@ -118,6 +118,7 @@ class Block(nn.Module):
             merge_size: Union[int, Tuple[int, int]] = 2,
     ):
         super().__init__()
+        self.grad_checkpointing = False
         self.blocks = nn.Sequential(*[
             MLPBlock(
                 dim=dim,
@@ -127,18 +128,22 @@ class Block(nn.Module):
                 layer_scale_init_value=layer_scale_init_value,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
-                pconv_fw_type=pconv_fw_type
+                pconv_fw_type=pconv_fw_type,
             )
             for i in range(depth)
         ])
-        self.down = PatchMerging(
+        self.downsample = PatchMerging(
             dim=dim // 2,
             patch_size=merge_size,
             norm_layer=norm_layer,
         ) if use_merge else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.blocks(self.down(x))
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x

@@ -202,7 +207,6 @@ class FasterNet(nn.Module):
             depths = (depths)  # it means the model has only one stage
         self.num_stages = len(depths)
         self.feature_info = []
-        self.grad_checkpointing = False

         self.patch_embed = PatchEmbed(
             in_chans=in_chans,
@@ -255,20 +259,26 @@ class FasterNet(nn.Module):
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)

+    @torch.jit.ignore
+    def no_weight_decay(self) -> Set:
+        return set()
+
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
-            stem=r'patch_embed',
-            blocks=[
-                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
-                (r'conv_head', (99999,))
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^conv_head', (99999,)),
             ]
         )
         return matcher

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
@@ -339,9 +349,6 @@ class FasterNet(nn.Module):

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.stages, x, flatten=True)
-        else:
-            x = self.stages(x)
+        x = self.stages(x)
         return x

@@ -371,11 +378,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     }

     stage_mapping = {
-        'stages.1.': 'stages.1.down.',
+        'stages.1.': 'stages.1.downsample.',
         'stages.2.': 'stages.1.',
-        'stages.3.': 'stages.2.down.',
+        'stages.3.': 'stages.2.downsample.',
         'stages.4.': 'stages.2.',
-        'stages.5.': 'stages.3.down.',
+        'stages.5.': 'stages.3.downsample.',
         'stages.6.': 'stages.3.'
     }

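The stage_mapping above exists so that checkpoints saved against the older module names ('down', flat stage indexing) still load after the rename to 'downsample'. A self-contained sketch of that style of prefix remapping, with names chosen for illustration rather than copied from checkpoint_filter_fn:

from typing import Dict

import torch

def remap_prefixes(state_dict: Dict[str, torch.Tensor], mapping: Dict[str, str]) -> Dict[str, torch.Tensor]:
    # Replace the longest matching key prefix; keys without a match pass through unchanged.
    out = {}
    for k, v in state_dict.items():
        for old, new in sorted(mapping.items(), key=lambda kv: -len(kv[0])):
            if k.startswith(old):
                k = new + k[len(old):]
                break
        out[k] = v
    return out

# Example: an old-style key picks up the new 'downsample' prefix.
sd = {'stages.1.norm.weight': torch.zeros(1)}
print(remap_prefixes(sd, {'stages.1.': 'stages.1.downsample.'}))
# {'stages.1.downsample.norm.weight': tensor([0.])}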
@@ -58,7 +58,7 @@ class Conv2d_BN(nn.Sequential):
             stride: int = 1,
             padding: int = 0,
             bn_weight_init: int = 1,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('c', nn.Conv2d(
@@ -229,7 +229,8 @@ class StageBlock(nn.Module):
             act_layer: LayerType = nn.ReLU,
     ):
         super().__init__()
-        self.down = nn.Sequential(
+        self.grad_checkpointing = False
+        self.downsample = nn.Sequential(
             Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
             Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
             PatchMerging(prev_dim, dim, act_layer),
@@ -237,13 +238,16 @@ class StageBlock(nn.Module):
             Residule(FFN(dim, int(dim * 2), act_layer)),
         ) if prev_dim != dim else nn.Identity()

-        self.block = nn.Sequential(*[
+        self.blocks = nn.Sequential(*[
             BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth)
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.down(x)
-        x = self.block(x)
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x

@@ -265,7 +269,6 @@ class SHViT(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        self.grad_checkpointing = False
         self.feature_info = []

         # Patch embedding
@@ -281,10 +284,10 @@ class SHViT(nn.Module):
         )

         # Build SHViT blocks
-        blocks = []
+        stages = []
         prev_chs = stem_chs
         for i in range(len(embed_dim)):
-            blocks.append(StageBlock(
+            stages.append(StageBlock(
                 prev_dim=prev_chs,
                 dim=embed_dim[i],
                 qk_dim=qk_dim[i],
@@ -295,9 +298,9 @@ class SHViT(nn.Module):
                 act_layer=act_layer,
             ))
             prev_chs = embed_dim[i]
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'blocks.{i}'))
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'stages.{i}'))
+        self.stages = nn.Sequential(*stages)

-        self.blocks = nn.Sequential(*blocks)
         # Classifier head
         self.num_features = self.head_hidden_size = embed_dim[-1]
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -310,12 +313,19 @@ class SHViT(nn.Module):

     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
-        matcher = dict(stem=r'^patch_embed', blocks=[(r'^blocks\.(\d+)', None)])
+        matcher = dict(
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
         return matcher

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
@@ -351,14 +361,14 @@ class SHViT(nn.Module):
         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
         intermediates = []
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)

         # forward pass
         x = self.patch_embed(x)
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
-            stages = self.blocks
+            stages = self.stages
         else:
-            stages = self.blocks[:max_index + 1]
+            stages = self.stages[:max_index + 1]

         for feat_idx, stage in enumerate(stages):
             x = stage(x)
@@ -378,18 +388,15 @@ class SHViT(nn.Module):
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
-        self.blocks = self.blocks[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
         if prune_head:
             self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x, flatten=True)
-        else:
-            x = self.blocks(x)
+        x = self.stages(x)
         return x

     def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
@@ -424,19 +431,19 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     out_dict = {}

     replace_rules = [
-        (re.compile(r'^blocks1\.'), 'blocks.0.block.'),
-        (re.compile(r'^blocks2\.'), 'blocks.1.block.'),
-        (re.compile(r'^blocks3\.'), 'blocks.2.block.'),
+        (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
+        (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
+        (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
     ]
     downsample_mapping = {}
     for i in range(1, 3):
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.0\\.'] = f'blocks.{i}.down.0.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.1\\.'] = f'blocks.{i}.down.1.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.1\\.'] = f'blocks.{i}.down.2.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.0\\.'] = f'blocks.{i}.down.3.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.1\\.'] = f'blocks.{i}.down.4.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
         for j in range(3, 10):
-            downsample_mapping[f'^blocks\\.{i}\\.block\\.{j}\\.'] = f'blocks.{i}.block.{j - 3}.'
+            downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'

     downsample_patterns = [
         (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
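Across these files the commit moves gradient checkpointing down into each stage: the stage owns a grad_checkpointing flag and wraps its blocks with checkpoint_seq in forward(). A rough sketch of the same idea using plain torch.utils.checkpoint (timm's checkpoint_seq adds flatten/chunking options on top of this):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Stage(nn.Module):
    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.grad_checkpointing = False  # toggled by the model's set_grad_checkpointing()
        self.blocks = nn.Sequential(*[nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Recompute each block's activations during backward instead of storing them.
            for blk in self.blocks:
                x = checkpoint(blk, x, use_reentrant=False)
        else:
            x = self.blocks(x)
        return x

stage = Stage(dim=8, depth=2)
stage.grad_checkpointing = True
out = stage(torch.randn(4, 8, requires_grad=True))
out.sum().backward()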
@@ -34,7 +34,7 @@ class ConvBN(nn.Sequential):
             stride: int = 1,
             padding: int = 0,
             with_bn: bool = True,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('conv', nn.Conv2d(
@@ -141,7 +141,10 @@ class StarNet(nn.Module):
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem\.\d+',
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
+                (r'norm', (99999,))
+            ]
         )
         return matcher

@@ -206,6 +209,7 @@ class StarNet(nn.Module):
         if intermediates_only:
             return intermediates

-        x = self.norm(x)
+        if feat_idx == last_idx:
+            x = self.norm(x)

         return x, intermediates
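The StarNet change above makes forward_intermediates() apply the final norm only when the last stage is among the selected outputs. Calling it is unchanged; a hedged example follows, where the model name is a placeholder for any registered variant of the architectures touched here:

import timm
import torch

# 'starnet_s1' is assumed to be a registered variant; substitute any FasterNet /
# SHViT / StarNet / SwiftFormer model name available in your timm version.
model = timm.create_model('starnet_s1', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    # Final feature map plus per-stage intermediates in NCHW order.
    final, intermediates = model.forward_intermediates(x)
for t in intermediates:
    print(tuple(t.shape))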
@@ -402,7 +402,11 @@ class SwiftFormer(nn.Module):
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^norm', (99999,)),
+            ]
         )
         return matcher
