Merge pull request #1903 from twmht/fix_num_classes

fix num_classes not found in repvit
This commit is contained in:
Ross Wightman 2023-08-07 16:35:36 -07:00 committed by GitHub
commit f6771909ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -176,6 +176,7 @@ class RepViTClassifier(nn.Module):
super().__init__() super().__init__()
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
self.distillation = distillation self.distillation = distillation
self.num_classes=num_classes
if distillation: if distillation:
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()