diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 6e0c28f0..db1f2669 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -281,7 +281,7 @@ class GeGluMlp(nn.Module): def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs): out_indices = kwargs.pop('out_indices', 3) assert embed_cfg is not None - backbone = MbConvStages(cfg=embed_cfg) + backbone = MbConvStages(cfg=embed_cfg, in_chans=kwargs.get('in_chans', 3)) kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False) kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set