mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Two small fixes, num_classes in base class, add model tag
This commit is contained in:
parent
3318e7614d
commit
b71d60cdb7
@ -178,7 +178,6 @@ class RepViTClassifier(nn.Module):
|
||||
self.distillation = distillation
|
||||
if distillation:
|
||||
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, x):
|
||||
if self.distillation:
|
||||
@ -248,6 +247,7 @@ class RepViT(nn.Module):
|
||||
self.grad_checkpointing = False
|
||||
self.global_pool = global_pool
|
||||
self.embed_dim = embed_dim
|
||||
self.num_classes = num_classes
|
||||
|
||||
in_dim = embed_dim[0]
|
||||
self.stem = RepViTStem(in_chans, in_dim, act_layer)
|
||||
@ -356,13 +356,13 @@ def _cfg(url='', **kwargs):
|
||||
|
||||
default_cfgs = generate_default_cfgs(
|
||||
{
|
||||
'repvit_m1': _cfg(
|
||||
'repvit_m1.dist_in1k': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
|
||||
),
|
||||
'repvit_m2': _cfg(
|
||||
'repvit_m2.dist_in1k': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
|
||||
),
|
||||
'repvit_m3': _cfg(
|
||||
'repvit_m3.dist_in1k': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
|
||||
),
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user