From 33ada0cbca3a7af58b17c1645cdac27da6b07516 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 23 Mar 2023 23:21:49 -0700
Subject: [PATCH] Add group_matcher to focalnet for proper layer-wise LR decay

---
 timm/models/focalnet.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py
index 6feeac65..57e2352f 100644
--- a/timm/models/focalnet.py
+++ b/timm/models/focalnet.py
@@ -436,6 +436,20 @@ class FocalNet(nn.Module):
     def no_weight_decay(self):
         return {''}
 
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=[
+                (r'^layers\.(\d+)', None),
+                (r'^norm', (99999,))
+            ] if coarse else [
+                (r'^layers\.(\d+)\.downsample', (0,)),
+                (r'^layers\.(\d+)\.\w+\.(\d+)', None),
+                (r'^norm', (99999,)),
+            ]
+        )
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
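
---

Usage note (illustrative, not part of the patch): once group_matcher is
defined, timm's optimizer factory can derive a layer id for each parameter
from the matcher's regex groups and apply layer-wise LR decay. A minimal
sketch, assuming the stock create_optimizer_v2 API; the model name and
hyper-parameters below are placeholders:

    import timm
    from timm.optim import create_optimizer_v2

    # Build a FocalNet variant; any timm model exposing group_matcher works.
    model = timm.create_model('focalnet_tiny_srf', pretrained=False)

    # layer_decay scales each parameter group's lr by
    # layer_decay ** (max_layer_id - layer_id), where layer ids come from
    # the group_matcher groupings (stem lowest, final norm/head highest).
    optimizer = create_optimizer_v2(
        model,
        opt='adamw',
        lr=1e-3,
        weight_decay=0.05,
        layer_decay=0.9,
    )

The (99999,) group index on the final norm pins it to the last (highest-lr)
group alongside the head, while downsample layers share the index of the
stage they precede.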