fix num_classes not found
parent
81089b10a2
commit
bb2b6b5f09
|
@ -176,6 +176,7 @@ class RepViTClassifier(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
self.distillation = distillation
|
self.distillation = distillation
|
||||||
|
self.num_classes=num_classes
|
||||||
if distillation:
|
if distillation:
|
||||||
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue