From 7f19a4cce7004eee11704956d89c94566424f5ee Mon Sep 17 00:00:00 2001 From: kalazus <30507262+kalazus@users.noreply.github.com> Date: Tue, 16 Jan 2024 16:59:59 +0300 Subject: [PATCH] fix fast catavgmax selection --- timm/layers/adaptive_avgmax_pool.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/layers/adaptive_avgmax_pool.py b/timm/layers/adaptive_avgmax_pool.py index 7e4b762d..16af4afd 100644 --- a/timm/layers/adaptive_avgmax_pool.py +++ b/timm/layers/adaptive_avgmax_pool.py @@ -139,10 +139,10 @@ class SelectAdaptivePool2d(nn.Module): self.flatten = nn.Flatten(1) if flatten else nn.Identity() elif pool_type.startswith('fast') or input_fmt != 'NCHW': assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.' - if pool_type.endswith('avgmax'): - self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt) - elif pool_type.endswith('catavgmax'): + if pool_type.endswith('catavgmax'): self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt) + elif pool_type.endswith('avgmax'): + self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt) elif pool_type.endswith('max'): self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt) else: