fix fast catavgmax selection

This commit is contained in:
kalazus 2024-01-16 16:59:59 +03:00 committed by Ross Wightman
parent 8c663c4b86
commit 7f19a4cce7

View File

@ -139,10 +139,10 @@ class SelectAdaptivePool2d(nn.Module):
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
elif pool_type.startswith('fast') or input_fmt != 'NCHW':
assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.'
if pool_type.endswith('avgmax'):
self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('catavgmax'):
if pool_type.endswith('catavgmax'):
self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('avgmax'):
self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('max'):
self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
else: