Rethink name of patch embed grid info
parent
b2c305c2aa
commit
715519a5ef
|
@ -490,7 +490,7 @@ class CoaT(nn.Module):
|
|||
|
||||
# Serial blocks 1.
|
||||
x1 = self.patch_embed1(x0)
|
||||
H1, W1 = self.patch_embed1.out_size
|
||||
H1, W1 = self.patch_embed1.grid_size
|
||||
x1 = self.insert_cls(x1, self.cls_token1)
|
||||
for blk in self.serial_blocks1:
|
||||
x1 = blk(x1, size=(H1, W1))
|
||||
|
@ -499,7 +499,7 @@ class CoaT(nn.Module):
|
|||
|
||||
# Serial blocks 2.
|
||||
x2 = self.patch_embed2(x1_nocls)
|
||||
H2, W2 = self.patch_embed2.out_size
|
||||
H2, W2 = self.patch_embed2.grid_size
|
||||
x2 = self.insert_cls(x2, self.cls_token2)
|
||||
for blk in self.serial_blocks2:
|
||||
x2 = blk(x2, size=(H2, W2))
|
||||
|
@ -508,7 +508,7 @@ class CoaT(nn.Module):
|
|||
|
||||
# Serial blocks 3.
|
||||
x3 = self.patch_embed3(x2_nocls)
|
||||
H3, W3 = self.patch_embed3.out_size
|
||||
H3, W3 = self.patch_embed3.grid_size
|
||||
x3 = self.insert_cls(x3, self.cls_token3)
|
||||
for blk in self.serial_blocks3:
|
||||
x3 = blk(x3, size=(H3, W3))
|
||||
|
@ -517,7 +517,7 @@ class CoaT(nn.Module):
|
|||
|
||||
# Serial blocks 4.
|
||||
x4 = self.patch_embed4(x3_nocls)
|
||||
H4, W4 = self.patch_embed4.out_size
|
||||
H4, W4 = self.patch_embed4.grid_size
|
||||
x4 = self.insert_cls(x4, self.cls_token4)
|
||||
for blk in self.serial_blocks4:
|
||||
x4 = blk(x4, size=(H4, W4))
|
||||
|
|
|
@ -21,8 +21,8 @@ class PatchEmbed(nn.Module):
|
|||
patch_size = to_2tuple(patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.out_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.out_size[0] * self.out_size[1]
|
||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
|
|
@ -467,7 +467,7 @@ class SwinTransformer(nn.Module):
|
|||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.patch_grid = self.patch_embed.out_size
|
||||
self.patch_grid = self.patch_embed.grid_size
|
||||
|
||||
# absolute position embedding
|
||||
if self.ape:
|
||||
|
|
Loading…
Reference in New Issue