mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix bugs when out_channels==1 (#2911)
This commit is contained in:
parent
ced29fcaf8
commit
4371ba5db6
@ -187,6 +187,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
if C > 1:
|
||||
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
|
||||
else:
|
||||
i_seg_logits = i_seg_logits.sigmoid()
|
||||
i_seg_pred = (i_seg_logits >
|
||||
self.decode_head.threshold).to(i_seg_logits)
|
||||
data_samples[i].set_data({
|
||||
|
@ -260,10 +260,10 @@ class EncoderDecoder(BaseSegmentor):
|
||||
h_stride, w_stride = self.test_cfg.stride
|
||||
h_crop, w_crop = self.test_cfg.crop_size
|
||||
batch_size, _, h_img, w_img = inputs.size()
|
||||
num_classes = self.num_classes
|
||||
out_channels = self.out_channels
|
||||
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
||||
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
||||
preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img))
|
||||
preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
|
||||
count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
|
||||
for h_idx in range(h_grids):
|
||||
for w_idx in range(w_grids):
|
||||
|
Loading…
x
Reference in New Issue
Block a user