mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix bug for reset classifier and fix for validating the dimension
This commit is contained in:
parent
3718c5a5bd
commit
9fe5798bee
@ -243,7 +243,8 @@ class CrossViT(nn.Module):
|
|||||||
|
|
||||||
num_patches = _compute_num_patches(img_size, patch_size)
|
num_patches = _compute_num_patches(img_size, patch_size)
|
||||||
self.num_branches = len(patch_size)
|
self.num_branches = len(patch_size)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_features = embed_dim[0] # to pass the tests
|
||||||
self.patch_embed = nn.ModuleList()
|
self.patch_embed = nn.ModuleList()
|
||||||
|
|
||||||
# hard-coded for torch jit script
|
# hard-coded for torch jit script
|
||||||
@ -274,7 +275,6 @@ class CrossViT(nn.Module):
|
|||||||
|
|
||||||
for i in range(self.num_branches):
|
for i in range(self.num_branches):
|
||||||
if hasattr(self, f'pos_embed_{i}'):
|
if hasattr(self, f'pos_embed_{i}'):
|
||||||
# if self.pos_embed[i].requires_grad:
|
|
||||||
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
|
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
|
||||||
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
|
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
|
||||||
|
|
||||||
@ -301,7 +301,7 @@ class CrossViT(nn.Module):
|
|||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool=''):
|
def reset_classifier(self, num_classes, global_pool=''):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
|
Loading…
x
Reference in New Issue
Block a user