[Refactor] Support resizing pos_embed while loading ckpt and format output (#1488)

* support resizing pos_embed while loading ckpt

* update
Yixiao Fang 2023-04-14 19:08:35 +08:00 committed by GitHub
parent 02571fe4b8
commit e93d124ad4
1 changed file with 86 additions and 3 deletions


@@ -344,6 +344,14 @@ class ViTSAM(BaseBackbone):
            channel reduction layer is disabled. Defaults to 256.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        out_type (str): The type of output features. Please choose from

            - ``"raw"`` or ``"featmap"``: The feature map tensor from the
              patch tokens with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).

            Defaults to ``"raw"``.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
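For intuition, a small sketch of the two shapes these options produce. The concrete numbers assume the 'base' SAM settings (1024x1024 input, 16x16 patches, i.e. a 64x64 token grid, reduced to 256 channels); they are not pinned down by this hunk:

```python
import torch

# Shapes assume ViTSAM 'base' defaults: img_size=1024, patch_size=16
# (a 64x64 patch grid) and out_channels=256 after channel reduction.
B, C, H, W = 2, 256, 64, 64
featmap = torch.randn(B, C, H, W)             # out_type='raw' / 'featmap': (B, C, H, W)
avg_featmap = featmap.flatten(2).mean(dim=2)  # out_type='avg_featmap': (B, C)
assert avg_featmap.shape == (B, C)
```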
@@ -392,6 +400,7 @@ class ViTSAM(BaseBackbone):
                'global_attn_indexes': [7, 15, 23, 31]
            }),
    }
    OUT_TYPES = {'raw', 'featmap', 'avg_featmap'}

    def __init__(self,
                 arch: str = 'base',
@@ -400,6 +409,7 @@ class ViTSAM(BaseBackbone):
                 in_channels: int = 3,
                 out_channels: int = 256,
                 out_indices: int = -1,
                 out_type: str = 'raw',
                 drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 qkv_bias: bool = True,
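A hedged usage sketch of the new argument; the import path and defaults are assumed from mmpretrain, not shown in this diff:

```python
import torch
from mmpretrain.models import ViTSAM  # import path assumed

# 'base' ViTSAM with the new out_type; avg_featmap yields (B, C) vectors.
model = ViTSAM(arch='base', out_type='avg_featmap')
feats = model(torch.randn(1, 3, 1024, 1024))
print(feats[-1].shape)  # expected: torch.Size([1, 256])
```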
@@ -444,7 +454,12 @@ class ViTSAM(BaseBackbone):
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
-       # num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        self.use_abs_pos = use_abs_pos
        self.interpolate_mode = interpolate_mode
@@ -453,6 +468,11 @@ class ViTSAM(BaseBackbone):
            self.pos_embed = nn.Parameter(
                torch.zeros(1, *self.patch_resolution, self.embed_dims))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
        if use_rel_pos:
            self._register_load_state_dict_pre_hook(
                self._prepare_relative_position)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
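These hooks use PyTorch's private but long-standing `_register_load_state_dict_pre_hook`, which runs before checkpoint tensors are copied into the module, so a hook can rewrite entries of `state_dict` in place. A minimal sketch of the mechanism; the toy module and its resizing rule are made up for illustration:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    """Toy module illustrating a load_state_dict pre-hook."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4))
        # The hook fires before parameters are copied in, so it can
        # rewrite mismatched checkpoint tensors in place.
        self._register_load_state_dict_pre_hook(self._fix_weight)

    def _fix_weight(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'weight'
        if name in state_dict and state_dict[name].numel() > 4:
            # Truncate an oversized checkpoint tensor to fit (toy rule).
            state_dict[name] = state_dict[name].flatten()[:4]

m = Toy()
m.load_state_dict({'weight': torch.ones(8)})  # loads without a shape error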
@@ -565,8 +585,71 @@ class ViTSAM(BaseBackbone):
            x = layer(x)

            if i in self.out_indices:
                # (B, H, W, C) -> (B, C, H, W)
-               x = self.channel_reduction(x.permute(0, 3, 1, 2))
-               outs.append(x)
+               x = x.permute(0, 3, 1, 2)
+               if self.out_channels > 0:
+                   x = self.channel_reduction(x)
+               outs.append(self._format_output(x))

        return tuple(outs)

    def _format_output(self, x) -> torch.Tensor:
        if self.out_type == 'raw' or self.out_type == 'featmap':
            return x
        elif self.out_type == 'avg_featmap':
            # (B, C, H, W) -> (B, C, N) -> (B, N, C)
            x = x.flatten(2).permute(0, 2, 1)
            return x.mean(dim=1)
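As an illustrative sanity check (not part of the patch), the `'avg_featmap'` branch is equivalent to global average pooling over the spatial grid:

```python
import torch

x = torch.randn(2, 256, 64, 64)                  # (B, C, H, W) feature map
avg = x.flatten(2).permute(0, 2, 1).mean(dim=1)  # the _format_output path
assert torch.allclose(avg, x.mean(dim=(2, 3)), atol=1e-5)
```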
    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            ckpt_pos_embed_shape = ckpt_pos_embed_shape[1:3]
            pos_embed_shape = self.patch_embed.init_out_size

            flattened_pos_embed = state_dict[name].flatten(1, 2)
            resized_pos_embed = resize_pos_embed(flattened_pos_embed,
                                                 ckpt_pos_embed_shape,
                                                 pos_embed_shape,
                                                 self.interpolate_mode, 0)
            state_dict[name] = resized_pos_embed.view(1, *pos_embed_shape,
                                                      self.embed_dims)
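`resize_pos_embed` is a helper from mmpretrain's model utilities; its import sits outside this hunk. For intuition, a rough stand-in for the `num_extra_tokens=0` case used here (the real helper also handles class/extra tokens):

```python
import torch
import torch.nn.functional as F

def resize_pos_embed_2d(pos_embed, src_hw, dst_hw, mode='bicubic'):
    """Rough stand-in for resize_pos_embed with num_extra_tokens=0."""
    _, _, C = pos_embed.shape                                 # (1, src_h*src_w, C)
    x = pos_embed.reshape(1, *src_hw, C).permute(0, 3, 1, 2)  # (1, C, h, w)
    x = F.interpolate(x, size=dst_hw, mode=mode, align_corners=False)
    return x.permute(0, 2, 3, 1).reshape(1, -1, C)

ckpt_pe = torch.randn(1, 64 * 64, 768)  # checkpoint trained on a 64x64 grid
new_pe = resize_pos_embed_2d(ckpt_pe, (64, 64), (32, 32))
assert new_pe.shape == (1, 32 * 32, 768)
```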
    def _prepare_relative_position(self, state_dict, prefix, *args, **kwargs):
        state_dict_model = self.state_dict()
        all_keys = list(state_dict_model.keys())
        for key in all_keys:
            if 'rel_pos_' in key:
                ckpt_key = prefix + key
                if ckpt_key not in state_dict:
                    continue

                relative_position_pretrained = state_dict[ckpt_key]
                relative_position_current = state_dict_model[key]
                L1, _ = relative_position_pretrained.size()
                L2, _ = relative_position_current.size()
                if L1 != L2:
                    new_rel_pos = F.interpolate(
                        relative_position_pretrained.reshape(
                            1, L1, -1).permute(0, 2, 1),
                        size=L2,
                        mode='linear',
                    )
                    new_rel_pos = new_rel_pos.reshape(-1, L2).permute(1, 0)

                    from mmengine.logging import MMLogger
                    logger = MMLogger.get_current_instance()
                    logger.info(f'Resize the {ckpt_key} from '
                                f'{state_dict[ckpt_key].shape} to '
                                f'{new_rel_pos.shape}')
                    state_dict[ckpt_key] = new_rel_pos
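The `rel_pos_*` tables are 2D tensors whose row count depends on the attention window: for SAM-style decomposed relative positions, a grid of side S uses 2*S - 1 rows. Loading at a new resolution therefore needs the 1D resize above. An illustrative run of the same interpolation, with made-up sizes:

```python
import torch
import torch.nn.functional as F

rel_pos = torch.randn(127, 64)  # checkpoint table: 2*64 - 1 rows, head_dim 64
L1, C = rel_pos.shape
L2 = 63                         # target: 2*32 - 1 rows for a 32x32 grid
new_rel_pos = F.interpolate(
    rel_pos.reshape(1, L1, C).permute(0, 2, 1),  # (1, C, L1) for 1D interp
    size=L2,
    mode='linear',
).reshape(C, L2).permute(1, 0)                   # back to (L2, C)
assert new_rel_pos.shape == (L2, C)
```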