diff --git a/vits.py b/vits.py
index c3fb6ee..391c843 100644
--- a/vits.py
+++ b/vits.py
@@ -22,48 +22,6 @@ __all__ = [
 ]


-class ConvStem(nn.Module):
-    """
-    ConvStem, follow the design in https://arxiv.org/abs/2106.14881
-    """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
-        super().__init__()
-
-        assert patch_size == 16, 'ConvStem only supports patch size of 16'
-        assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'
-
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
-        self.num_patches = self.grid_size[0] * self.grid_size[1]
-        self.flatten = flatten
-
-        # build stem
-        stem = []
-        input_dim, output_dim = 3, embed_dim // 8
-        for l in range(4):
-            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
-            stem.append(nn.BatchNorm2d(output_dim))
-            stem.append(nn.ReLU(inplace=True))
-            input_dim = output_dim
-            output_dim *= 2
-        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
-        self.proj = nn.Sequential(*stem)
-
-        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-        x = self.proj(x)
-        if self.flatten:
-            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
-        x = self.norm(x)
-        return x
-

 class VisionTransformerMoCo(VisionTransformer):
     def __init__(self, stop_grad_conv1=False, **kwargs):
@@ -112,6 +70,49 @@ class VisionTransformerMoCo(VisionTransformer):
             self.pos_embed.requires_grad = False


+class ConvStem(nn.Module):
+    """
+    ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+        super().__init__()
+
+        assert patch_size == 16, 'ConvStem only supports patch size of 16'
+        assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'
+
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.flatten = flatten
+
+        # build stem, similar to the design in https://arxiv.org/abs/2106.14881
+        stem = []
+        input_dim, output_dim = 3, embed_dim // 8
+        for l in range(4):
+            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
+            stem.append(nn.BatchNorm2d(output_dim))
+            stem.append(nn.ReLU(inplace=True))
+            input_dim = output_dim
+            output_dim *= 2
+        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
+        self.proj = nn.Sequential(*stem)
+
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+        x = self.norm(x)
+        return x
+
+
 def vit_small(**kwargs):
     model = VisionTransformerMoCo(
         patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
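
For reviewers, a minimal smoke test of the relocated ConvStem (not part of the patch; assumes the imports already present in vits.py, i.e. torch, torch.nn as nn, and timm's to_2tuple). ConvStem's constructor mirrors a patch-embedding signature, so a 224x224 input should come out as 14x14 = 196 tokens of width embed_dim:

    # Hypothetical check, run from a Python shell with vits.py importable.
    import torch
    from vits import ConvStem

    stem = ConvStem(img_size=224, patch_size=16, embed_dim=768)  # four stride-2 convs: 224 -> 14
    x = torch.randn(2, 3, 224, 224)                              # (B, C, H, W)
    tokens = stem(x)                                             # BCHW -> BNC via flatten + transpose
    assert tokens.shape == (2, stem.num_patches, 768)            # 14 * 14 = 196 patches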