ViTamin in_chans !=3 weight load fix

This commit is contained in:
Ross Wightman 2024-06-07 20:39:23 -07:00
parent 5517b054dd
commit 7702d9afa1

View File

@ -281,7 +281,7 @@ class GeGluMlp(nn.Module):
def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs):
out_indices = kwargs.pop('out_indices', 3)
assert embed_cfg is not None
backbone = MbConvStages(cfg=embed_cfg)
backbone = MbConvStages(cfg=embed_cfg, in_chans=kwargs.get('in_chans', 3))
kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set