fix num_classes not found

pull/1903/head
alec.tu 2023-08-07 15:16:03 +08:00
parent 81089b10a2
commit bb2b6b5f09
1 changed files with 1 additions and 0 deletions

View File

@ -176,6 +176,7 @@ class RepViTClassifier(nn.Module):
super().__init__() super().__init__()
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
self.distillation = distillation self.distillation = distillation
self.num_classes=num_classes
if distillation: if distillation:
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()