Support input of non-three-channel images

Add an in_chans argument and pass it to the PatchEmbed layer.
pull/1866/head
ZhangYiqin 2024-01-16 19:05:28 +08:00 committed by GitHub
parent 17a886cb58
commit f378f3614c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 1 deletions

View File

@ -220,6 +220,7 @@ class PoolFormer(BaseBackbone):
Defaults to ``dict(type='LN2d', eps=1e-6)``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
in_chans (int): The number of channels of the input image.
Defaults to 3.
in_patch_size (int): The patch size of input image patch embedding.
Defaults to 7.
in_stride (int): The stride of input image patch embedding.
@ -285,6 +286,7 @@ class PoolFormer(BaseBackbone):
pool_size=3,
norm_cfg=dict(type='GN', num_groups=1),
act_cfg=dict(type='GELU'),
in_chans=3,
in_patch_size=7,
in_stride=4,
in_pad=2,
@ -320,7 +322,7 @@ class PoolFormer(BaseBackbone):
patch_size=in_patch_size,
stride=in_stride,
padding=in_pad,
in_chans=3,
in_chans=in_chans,
embed_dim=embed_dims[0])
# set the main block in network