mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add group_matcher to focalnet for proper layer-wise LR decay
This commit is contained in:
parent
b271dc0e16
commit
33ada0cbca
@ -436,6 +436,20 @@ class FocalNet(nn.Module):
|
||||
def no_weight_decay(self):
|
||||
return {''}
|
||||
|
||||
@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return the regex spec used to split parameters into depth-ordered groups.

    The spec is what enables layer-wise learning-rate decay for this model.

    Args:
        coarse: When true, group at stage granularity (one group per
            ``layers.<i>``). Otherwise group per block inside each stage.

    Returns:
        A dict with two entries:
          * ``stem``: a single regex matching the stem parameters.
          * ``blocks``: a list of ``(regex, suffix)`` pairs, where ``suffix``
            is either ``None`` or a tuple of ints. How the suffix is combined
            with the regex capture groups is decided by the consumer of this
            spec, which is not visible here.
    """
    if coarse:
        blocks = [
            (r'^layers\.(\d+)', None),
            # Large ordinal; presumably sorts the final norm after every stage.
            (r'^norm', (99999,)),
        ]
    else:
        blocks = [
            # The separator before `downsample` is an escaped literal dot, in
            # line with the other patterns; a bare `.` would match any
            # character there.
            # NOTE(review): the (0,) suffix presumably places a stage's
            # downsample with that stage's first block; confirm against the
            # grouping helper.
            (r'^layers\.(\d+)\.downsample', (0,)),
            (r'^layers\.(\d+)\.\w+\.(\d+)', None),
            (r'^norm', (99999,)),
        ]
    return dict(stem=r'^stem', blocks=blocks)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
|
Loading…
x
Reference in New Issue
Block a user