Add group_matcher to focalnet for proper layer-wise LR decay

This commit is contained in:
Ross Wightman 2023-03-23 23:21:49 -07:00
parent b271dc0e16
commit 33ada0cbca

View File

@ -436,6 +436,20 @@ class FocalNet(nn.Module):
def no_weight_decay(self):
    """Return the set of parameter names to exclude from weight decay."""
    skip = set(('',))
    return skip
@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return a parameter-group matcher used for layer-wise LR decay.

    Args:
        coarse: If True, group all parameters of a stage together;
            otherwise emit finer patterns that separate each stage's
            downsample layer and individual blocks.

    Returns:
        Dict with a ``stem`` pattern and a ``blocks`` list of
        ``(regex, group_index)`` tuples; ``(99999,)`` sorts the final
        norm after all block groups.
    """
    return dict(
        stem=r'^stem',
        blocks=[
            (r'^layers\.(\d+)', None),
            (r'^norm', (99999,)),
        ] if coarse else [
            # Fix: escape the dot before 'downsample' so the pattern only
            # matches the literal 'layers.<i>.downsample' prefix (the sibling
            # pattern below already escapes its dots).
            (r'^layers\.(\d+)\.downsample', (0,)),
            (r'^layers\.(\d+)\.\w+\.(\d+)', None),
            (r'^norm', (99999,)),
        ]
    )
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    """Turn gradient checkpointing on or off by storing the flag on the model."""
    setattr(self, 'grad_checkpointing', enable)