fix num_classes not found
parent
81089b10a2
commit
bb2b6b5f09
|
@ -176,6 +176,7 @@ class RepViTClassifier(nn.Module):
|
|||
super().__init__()
|
||||
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.distillation = distillation
|
||||
self.num_classes=num_classes
|
||||
if distillation:
|
||||
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
|
|
Loading…
Reference in New Issue