From b71d60cdb77211addb4be4abbfd44a638acfe784 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 26 Jul 2023 13:18:49 -0700 Subject: [PATCH] Two small fixes, num_classes in base class, add model tag --- timm/models/repvit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/repvit.py b/timm/models/repvit.py index e5e32880..b0199b89 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -178,7 +178,6 @@ class RepViTClassifier(nn.Module): self.distillation = distillation if distillation: self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() - self.num_classes = num_classes def forward(self, x): if self.distillation: @@ -248,6 +247,7 @@ class RepViT(nn.Module): self.grad_checkpointing = False self.global_pool = global_pool self.embed_dim = embed_dim + self.num_classes = num_classes in_dim = embed_dim[0] self.stem = RepViTStem(in_chans, in_dim, act_layer) @@ -356,13 +356,13 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs( { - 'repvit_m1': _cfg( + 'repvit_m1.dist_in1k': _cfg( url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth' ), - 'repvit_m2': _cfg( + 'repvit_m2.dist_in1k': _cfg( url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth' ), - 'repvit_m3': _cfg( + 'repvit_m3.dist_in1k': _cfg( url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth' ), }