diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py index d1e6774c..188c6b53 100644 --- a/timm/layers/mlp.py +++ b/timm/layers/mlp.py @@ -12,6 +12,8 @@ from .helpers import to_2tuple class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks + + NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected. """ def __init__( self, @@ -51,6 +53,8 @@ class Mlp(nn.Module): class GluMlp(nn.Module): """ MLP w/ GLU style gating See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 + + NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected. """ def __init__( self, @@ -192,7 +196,7 @@ class GatedMlp(nn.Module): class ConvMlp(nn.Module): - """ MLP using 1x1 convs that keeps spatial dims + """ MLP using 1x1 convs that keeps spatial dims (for 2D NCHW tensors) """ def __init__( self, @@ -226,6 +230,8 @@ class ConvMlp(nn.Module): class GlobalResponseNormMlp(nn.Module): """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d + + NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts """ def __init__( self,