[Fix] Incorrect stage freeze on RIFormer Model (#1573)

* [Doc] RIFormer's README did not link to its paper properly

* [Fix] Incorrect code for reproducing RIFormer

The default value of `frozen_stages` was set to 0, and the docstring says this means no stage will be frozen. In the actual code, however, the `patch_embed` gets frozen anyway.

This can silently corrupt training and influence the results.

I suggest a careful review.
ZhangYiqin 2023-05-22 16:01:32 +08:00 committed by GitHub
parent b058912c0c
commit 023d6869bd
2 changed files with 4 additions and 4 deletions


```diff
@@ -1,6 +1,6 @@
 # RIFormer

-> [RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer](https://arxiv.org/abs/xxxx.xxxxx)
+> [RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer](https://arxiv.org/abs/2304.05659)

 <!-- [ALGORITHM] -->
```


```diff
@@ -202,7 +202,7 @@ class RIFormer(BaseBackbone):
             [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4]
             Defaults to -1, means the last stage.
         frozen_stages (int): Stages to be frozen (all param fixed).
-            Defaults to 0, which means not freezing any parameters.
+            Defaults to -1, which means not freezing any parameters.
         deploy (bool): Whether to switch the model structure to
             deployment mode. Default: False.
         init_cfg (dict, optional): Initialization config dict
@@ -259,7 +259,7 @@ class RIFormer(BaseBackbone):
                  drop_rate=0.,
                  drop_path_rate=0.,
                  out_indices=-1,
-                 frozen_stages=0,
+                 frozen_stages=-1,
                  init_cfg=None,
                  deploy=False):
@@ -366,7 +366,7 @@ class RIFormer(BaseBackbone):
         for param in self.patch_embed.parameters():
             param.requires_grad = False
-        for i in range(self.frozen_stages):
+        for i in range(0, self.frozen_stages + 1):
             # Include both block and downsample layer.
             module = self.network[i]
             module.eval()
```
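The corrected loop's semantics can be sketched in isolation. Below is a minimal, hypothetical helper (`frozen_stage_indices` is not part of the patch) that mirrors which indices of `self.network` the new `range(0, self.frozen_stages + 1)` loop would freeze, assuming the freezing code only runs when `frozen_stages >= 0`:

```python
def frozen_stage_indices(frozen_stages: int) -> list:
    """Stage indices that the corrected freezing loop would touch.

    Mirrors ``for i in range(0, self.frozen_stages + 1)``: with the new
    default of -1 nothing is frozen (patch_embed included, assuming the
    surrounding guard skips freezing entirely for negative values);
    with frozen_stages=k, stages 0..k inclusive are frozen.
    """
    if frozen_stages < 0:
        # New default -1: freezing is skipped altogether.
        return []
    return list(range(0, frozen_stages + 1))
```

With the old default of 0, `range(self.frozen_stages)` was empty, yet `patch_embed` was still frozen before the loop ran, which is exactly the mismatch with the docstring that this commit fixes.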