diff --git a/timm/layers/adaptive_avgmax_pool.py b/timm/layers/adaptive_avgmax_pool.py index 16af4afd..d0dd58d9 100644 --- a/timm/layers/adaptive_avgmax_pool.py +++ b/timm/layers/adaptive_avgmax_pool.py @@ -134,6 +134,7 @@ class SelectAdaptivePool2d(nn.Module): super(SelectAdaptivePool2d, self).__init__() assert input_fmt in ('NCHW', 'NHWC') self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing + pool_type = pool_type.lower() if not pool_type: self.pool = nn.Identity() # pass through self.flatten = nn.Flatten(1) if flatten else nn.Identity() @@ -145,8 +146,10 @@ class SelectAdaptivePool2d(nn.Module): self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt) elif pool_type.endswith('max'): self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt) - else: + elif pool_type == 'fast' or pool_type.endswith('avg'): self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt) + else: + assert False, 'Invalid pool type: %s' % pool_type self.flatten = nn.Identity() else: assert input_fmt == 'NCHW' @@ -156,8 +159,10 @@ class SelectAdaptivePool2d(nn.Module): self.pool = AdaptiveCatAvgMaxPool2d(output_size) elif pool_type == 'max': self.pool = nn.AdaptiveMaxPool2d(output_size) - else: + elif pool_type == 'avg': self.pool = nn.AdaptiveAvgPool2d(output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type self.flatten = nn.Flatten(1) if flatten else nn.Identity() def is_identity(self):