diff --git a/patchconvnet_models.py b/patchconvnet_models.py index f4945bb..b4bd046 100644 --- a/patchconvnet_models.py +++ b/patchconvnet_models.py @@ -194,7 +194,7 @@ class ConvStem(nn.Module): self.num_patches = num_patches self.proj = torch.nn.Sequential( - conv3x3(3, embed_dim // 8, 2), + conv3x3(in_chans, embed_dim // 8, 2), nn.GELU(), conv3x3(embed_dim // 8, embed_dim // 4, 2), nn.GELU(),