[Refactor] Support resizing pos_embed while loading ckpt and format output (#1488)
* support resize pos_embed while loading ckpt
* update

pull/1490/head
parent
02571fe4b8
commit
e93d124ad4
|
@ -344,6 +344,14 @@ class ViTSAM(BaseBackbone):
|
|||
channel reduction layer is disabled. Defaults to 256.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
out_type (str): The type of output features. Please choose from
|
||||
|
||||
- ``"raw"`` or ``"featmap"``: The feature map tensor from the
|
||||
patch tokens with shape (B, C, H, W).
|
||||
- ``"avg_featmap"``: The global averaged feature map tensor
|
||||
with shape (B, C).
|
||||
|
||||
Defaults to ``"raw"``.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
|
@ -392,6 +400,7 @@ class ViTSAM(BaseBackbone):
|
|||
'global_attn_indexes': [7, 15, 23, 31]
|
||||
}),
|
||||
}
|
||||
OUT_TYPES = {'raw', 'featmap', 'avg_featmap'}
|
||||
|
||||
def __init__(self,
|
||||
arch: str = 'base',
|
||||
|
@ -400,6 +409,7 @@ class ViTSAM(BaseBackbone):
|
|||
in_channels: int = 3,
|
||||
out_channels: int = 256,
|
||||
out_indices: int = -1,
|
||||
out_type: str = 'raw',
|
||||
drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
qkv_bias: bool = True,
|
||||
|
@ -444,7 +454,12 @@ class ViTSAM(BaseBackbone):
|
|||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
# num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
# Set out type
|
||||
if out_type not in self.OUT_TYPES:
|
||||
raise ValueError(f'Unsupported `out_type` {out_type}, please '
|
||||
f'choose from {self.OUT_TYPES}')
|
||||
self.out_type = out_type
|
||||
|
||||
self.use_abs_pos = use_abs_pos
|
||||
self.interpolate_mode = interpolate_mode
|
||||
|
@ -453,6 +468,11 @@ class ViTSAM(BaseBackbone):
|
|||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, *self.patch_resolution, self.embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
|
||||
|
||||
if use_rel_pos:
|
||||
self._register_load_state_dict_pre_hook(
|
||||
self._prepare_relative_position)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
|
@ -565,8 +585,71 @@ class ViTSAM(BaseBackbone):
|
|||
x = layer(x)
|
||||
|
||||
if i in self.out_indices:
|
||||
# (B, H, W, C) -> (B, C, H, W)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
|
||||
if self.out_channels > 0:
|
||||
x = self.channel_reduction(x.permute(0, 3, 1, 2))
|
||||
outs.append(x)
|
||||
x = self.channel_reduction(x)
|
||||
outs.append(self._format_output(x))
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def _format_output(self, x) -> torch.Tensor:
    """Convert one stage's feature map into the configured output format.

    Args:
        x (torch.Tensor): Feature map of an output stage. The
            ``'avg_featmap'`` branch flattens everything from dim 2
            onward and averages over it, i.e. it treats ``x`` as
            (B, C, H, W).

    Returns:
        torch.Tensor: ``x`` itself when ``self.out_type`` is ``'raw'`` or
        ``'featmap'``; the (B, C) average over all spatial positions when
        it is ``'avg_featmap'``.

    Raises:
        ValueError: If ``self.out_type`` is not one of the supported
            output types.
    """
    if self.out_type in ('raw', 'featmap'):
        return x

    if self.out_type == 'avg_featmap':
        # (B, C, H, W) -> (B, C, N) -> (B, N, C)
        x = x.flatten(2).permute(0, 2, 1)
        return x.mean(dim=1)

    # Fail loudly instead of silently handing ``None`` back to the caller
    # when ``out_type`` was changed to an unsupported value after init.
    raise ValueError(f'Unsupported `out_type` {self.out_type}, please '
                     f'choose from {self.OUT_TYPES}')
|
||||
|
||||
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
    """Resize a checkpoint's ``pos_embed`` to this model's patch grid.

    Registered as a ``load_state_dict`` pre-hook, so it runs on the raw
    checkpoint before parameters are copied. ``state_dict`` is modified
    in place; nothing is returned.

    Args:
        state_dict (dict): Checkpoint state dict being loaded.
        prefix (str): Key prefix of this module inside ``state_dict``.
        *args: Remaining pre-hook arguments, unused.
        **kwargs: Remaining pre-hook keyword arguments, unused.
    """
    name = prefix + 'pos_embed'
    if name not in state_dict:
        return

    ckpt_pos_embed = state_dict[name]
    if self.pos_embed.shape == ckpt_pos_embed.shape:
        # Shapes already agree; leave the checkpoint entry untouched.
        return

    # Imported lazily: the logger is only needed when a resize happens.
    from mmengine.logging import MMLogger
    logger = MMLogger.get_current_instance()
    logger.info(
        f'Resize the pos_embed shape from {ckpt_pos_embed.shape} '
        f'to {self.pos_embed.shape}.')

    # ``pos_embed`` is stored as (1, H, W, C); dims 1-2 are the patch grid.
    ckpt_grid_shape = ckpt_pos_embed.shape[1:3]
    model_grid_shape = self.patch_embed.init_out_size

    # The helper is fed the grid flattened to (1, H*W, C). The trailing 0
    # is presumably the number of extra (non-patch) tokens -- confirm
    # against ``resize_pos_embed``'s signature.
    resized_pos_embed = resize_pos_embed(ckpt_pos_embed.flatten(1, 2),
                                         ckpt_grid_shape, model_grid_shape,
                                         self.interpolate_mode, 0)

    # ``reshape`` rather than ``view`` so this does not depend on the
    # helper returning a contiguous tensor.
    state_dict[name] = resized_pos_embed.reshape(1, *model_grid_shape,
                                                 self.embed_dims)
|
||||
|
||||
def _prepare_relative_position(self, state_dict, prefix, *args, **kwargs):
    """Interpolate checkpoint ``rel_pos_*`` tables to the model's length.

    Registered as a ``load_state_dict`` pre-hook when relative position
    embeddings are enabled. For every model parameter whose name contains
    ``'rel_pos_'`` and that is also present in the checkpoint, the
    checkpoint tensor is linearly interpolated along its first dimension
    whenever that dimension differs from the model's. ``state_dict`` is
    modified in place; nothing is returned.

    Args:
        state_dict (dict): Checkpoint state dict being loaded.
        prefix (str): Key prefix of this module inside ``state_dict``.
        *args: Remaining pre-hook arguments, unused.
        **kwargs: Remaining pre-hook keyword arguments, unused.
    """
    for key, rel_pos_current in self.state_dict().items():
        if 'rel_pos_' not in key:
            continue

        ckpt_key = prefix + key
        if ckpt_key not in state_dict:
            continue

        rel_pos_pretrained = state_dict[ckpt_key]
        src_len, _ = rel_pos_pretrained.size()
        dst_len, _ = rel_pos_current.size()
        if src_len == dst_len:
            continue

        # (src_len, C) -> (1, src_len, C) -> (1, C, src_len) so that
        # ``F.interpolate`` resamples along the length axis.
        new_rel_pos = F.interpolate(
            rel_pos_pretrained.reshape(1, src_len, -1).permute(0, 2, 1),
            size=dst_len,
            mode='linear',
        )
        # (1, C, dst_len) -> (C, dst_len) -> (dst_len, C)
        new_rel_pos = new_rel_pos.reshape(-1, dst_len).permute(1, 0)

        # Imported lazily: the logger is only needed when a resize happens.
        from mmengine.logging import MMLogger
        logger = MMLogger.get_current_instance()
        logger.info(f'Resize the {ckpt_key} from '
                    f'{state_dict[ckpt_key].shape} to '
                    f'{new_rel_pos.shape}')
        state_dict[ckpt_key] = new_rel_pos
|
||||
|
|
Loading…
Reference in New Issue