fix num_classes not found

This commit is contained in:
alec.tu 2023-08-07 15:16:03 +08:00
parent 81089b10a2
commit bb2b6b5f09

View File

@ -176,6 +176,7 @@ class RepViTClassifier(nn.Module):
super().__init__()
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
self.distillation = distillation
self.num_classes=num_classes
if distillation:
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()