mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix binary C=1 focal loss & dataset fileio (#2935)
This commit is contained in:
parent
757f4a583e
commit
04f7ec60d8
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
@ -27,4 +28,5 @@ class ChaseDB1Dataset(BaseSegDataset):
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.data_prefix['img_path'])
|
||||
assert fileio.exists(
|
||||
self.data_prefix['img_path'], backend_args=self.backend_args)
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
@ -27,4 +28,5 @@ class DRIVEDataset(BaseSegDataset):
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.data_prefix['img_path'])
|
||||
assert fileio.exists(
|
||||
self.data_prefix['img_path'], backend_args=self.backend_args)
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
@ -27,4 +28,5 @@ class HRFDataset(BaseSegDataset):
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.data_prefix['img_path'])
|
||||
assert fileio.exists(
|
||||
self.data_prefix['img_path'], backend_args=self.backend_args)
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
@ -26,4 +28,5 @@ class STAREDataset(BaseSegDataset):
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.data_prefix['img_path'])
|
||||
assert fileio.exists(
|
||||
self.data_prefix['img_path'], backend_args=self.backend_args)
|
||||
|
@ -271,7 +271,13 @@ class FocalLoss(nn.Module):
|
||||
num_classes = pred.size(1)
|
||||
if torch.cuda.is_available() and pred.is_cuda:
|
||||
if target.dim() == 1:
|
||||
one_hot_target = F.one_hot(target, num_classes=num_classes)
|
||||
one_hot_target = F.one_hot(
|
||||
target, num_classes=num_classes + 1)
|
||||
if num_classes == 1:
|
||||
one_hot_target = one_hot_target[:, 1]
|
||||
target = 1 - target
|
||||
else:
|
||||
one_hot_target = one_hot_target[:, :num_classes]
|
||||
else:
|
||||
one_hot_target = target
|
||||
target = target.argmax(dim=1)
|
||||
@ -280,7 +286,11 @@ class FocalLoss(nn.Module):
|
||||
else:
|
||||
one_hot_target = None
|
||||
if target.dim() == 1:
|
||||
target = F.one_hot(target, num_classes=num_classes)
|
||||
target = F.one_hot(target, num_classes=num_classes + 1)
|
||||
if num_classes == 1:
|
||||
target = target[:, 1]
|
||||
else:
|
||||
target = target[:, num_classes]
|
||||
else:
|
||||
valid_mask = (target.argmax(dim=1) != ignore_index).view(
|
||||
-1, 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user