mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix bug for reset classifier and fix for validating the dimension
This commit is contained in:
parent
3718c5a5bd
commit
9fe5798bee
@ -243,7 +243,8 @@ class CrossViT(nn.Module):
|
||||
|
||||
num_patches = _compute_num_patches(img_size, patch_size)
|
||||
self.num_branches = len(patch_size)
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.num_features = embed_dim[0] # to pass the tests
|
||||
self.patch_embed = nn.ModuleList()
|
||||
|
||||
# hard-coded for torch jit script
|
||||
@ -274,7 +275,6 @@ class CrossViT(nn.Module):
|
||||
|
||||
for i in range(self.num_branches):
|
||||
if hasattr(self, f'pos_embed_{i}'):
|
||||
# if self.pos_embed[i].requires_grad:
|
||||
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
|
||||
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
|
||||
|
||||
@ -301,7 +301,7 @@ class CrossViT(nn.Module):
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])
|
||||
|
||||
def forward_features(self, x):
|
||||
B, C, H, W = x.shape
|
||||
|
Loading…
x
Reference in New Issue
Block a user