mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Two small fixes, num_classes in base class, add model tag
This commit is contained in:
parent
3318e7614d
commit
b71d60cdb7
@ -178,7 +178,6 @@ class RepViTClassifier(nn.Module):
|
|||||||
self.distillation = distillation
|
self.distillation = distillation
|
||||||
if distillation:
|
if distillation:
|
||||||
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
self.num_classes = num_classes
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.distillation:
|
if self.distillation:
|
||||||
@ -248,6 +247,7 @@ class RepViT(nn.Module):
|
|||||||
self.grad_checkpointing = False
|
self.grad_checkpointing = False
|
||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
|
self.num_classes = num_classes
|
||||||
|
|
||||||
in_dim = embed_dim[0]
|
in_dim = embed_dim[0]
|
||||||
self.stem = RepViTStem(in_chans, in_dim, act_layer)
|
self.stem = RepViTStem(in_chans, in_dim, act_layer)
|
||||||
@ -356,13 +356,13 @@ def _cfg(url='', **kwargs):
|
|||||||
|
|
||||||
default_cfgs = generate_default_cfgs(
|
default_cfgs = generate_default_cfgs(
|
||||||
{
|
{
|
||||||
'repvit_m1': _cfg(
|
'repvit_m1.dist_in1k': _cfg(
|
||||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
|
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
|
||||||
),
|
),
|
||||||
'repvit_m2': _cfg(
|
'repvit_m2.dist_in1k': _cfg(
|
||||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
|
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
|
||||||
),
|
),
|
||||||
'repvit_m3': _cfg(
|
'repvit_m3.dist_in1k': _cfg(
|
||||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
|
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user