From 7d83749207ee258a3e5fbec225b534dd0af1ccfd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 17 Aug 2024 08:27:13 -0700 Subject: [PATCH] pool size test fixes --- timm/models/efficientnet.py | 2 +- timm/models/hieradet_sam2.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 09d6c66c..2cf4130d 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -1294,7 +1294,7 @@ default_cfgs = generate_default_cfgs({ 'efficientnet_b1.ra4_e3600_r240_in1k': _cfg( hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 240, 240), crop_pct=0.9, + input_size=(3, 240, 240), crop_pct=0.9, pool_size=(8, 8), test_input_size=(3, 288, 288), test_crop_pct=1.0), 'efficientnet_b1.ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 652a59e0..4d57f5c3 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -294,6 +294,7 @@ class HieraDet(nn.Module): assert len(stages) == len(window_spec) self.num_classes = num_classes self.window_spec = window_spec + self.output_fmt = 'NHWC' depth = sum(stages) self.q_stride = q_stride