Update simmim_neck.py

Update SimMIMLinearDecoder with `target_channels`. The downstream loss for SimMIM i.e. the `PixelReconstructionLoss` already allows user to set the number of channels through the `channel` argument. 
Useful in cases when reconstructing non-rgb images.
pull/1875/head
Ashutosh Singh 2024-02-25 23:00:24 +01:00 committed by GitHub
parent 17a886cb58
commit c088a191ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 2 deletions

View File

@ -15,14 +15,15 @@ class SimMIMLinearDecoder(BaseModule):
Args:
in_channels (int): Channel dimension of the feature map.
encoder_stride (int): The total stride of the encoder.
target_channels (int): Channel dimensions of original image.
"""
def __init__(self, in_channels: int, encoder_stride: int) -> None:
def __init__(self, in_channels: int, encoder_stride: int, target_channels: int = 3) -> None:
super().__init__()
self.decoder = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=encoder_stride**2 * 3,
out_channels=encoder_stride**2 * target_channels,
kernel_size=1),
nn.PixelShuffle(encoder_stride),
)