[Fix] Fix bugs when out_channels==1 (#2911)

This commit is contained in:
李开宇 2023-04-18 18:26:59 +08:00 committed by GitHub
parent ced29fcaf8
commit 4371ba5db6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 2 deletions

View File

@ -187,6 +187,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
if C > 1:
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
else:
i_seg_logits = i_seg_logits.sigmoid()
i_seg_pred = (i_seg_logits >
self.decode_head.threshold).to(i_seg_logits)
data_samples[i].set_data({

View File

@ -260,10 +260,10 @@ class EncoderDecoder(BaseSegmentor):
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = inputs.size()
num_classes = self.num_classes
out_channels = self.out_channels
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img))
preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):