Two small fixes, num_classes in base class, add model tag

This commit is contained in:
Ross Wightman 2023-07-26 13:18:49 -07:00
parent 3318e7614d
commit b71d60cdb7

View File

@ -178,7 +178,6 @@ class RepViTClassifier(nn.Module):
self.distillation = distillation
if distillation:
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
self.num_classes = num_classes
def forward(self, x):
if self.distillation:
@ -248,6 +247,7 @@ class RepViT(nn.Module):
self.grad_checkpointing = False
self.global_pool = global_pool
self.embed_dim = embed_dim
self.num_classes = num_classes
in_dim = embed_dim[0]
self.stem = RepViTStem(in_chans, in_dim, act_layer)
@ -356,13 +356,13 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs(
{
'repvit_m1': _cfg(
'repvit_m1.dist_in1k': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
),
'repvit_m2': _cfg(
'repvit_m2.dist_in1k': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
),
'repvit_m3': _cfg(
'repvit_m3.dist_in1k': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
),
}