[Fix] Fix binary C=1 focal loss & dataset fileio (#2935)

This commit is contained in:
CSH 2023-04-23 15:02:18 +08:00 committed by GitHub
parent 757f4a583e
commit 04f7ec60d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 25 additions and 6 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@ -27,4 +28,5 @@ class ChaseDB1Dataset(BaseSegDataset):
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
assert self.file_client.exists(self.data_prefix['img_path'])
assert fileio.exists(
self.data_prefix['img_path'], backend_args=self.backend_args)

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@ -27,4 +28,5 @@ class DRIVEDataset(BaseSegDataset):
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
assert self.file_client.exists(self.data_prefix['img_path'])
assert fileio.exists(
self.data_prefix['img_path'], backend_args=self.backend_args)

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@ -27,4 +28,5 @@ class HRFDataset(BaseSegDataset):
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
assert self.file_client.exists(self.data_prefix['img_path'])
assert fileio.exists(
self.data_prefix['img_path'], backend_args=self.backend_args)

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@ -26,4 +28,5 @@ class STAREDataset(BaseSegDataset):
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
assert self.file_client.exists(self.data_prefix['img_path'])
assert fileio.exists(
self.data_prefix['img_path'], backend_args=self.backend_args)

View File

@ -271,7 +271,13 @@ class FocalLoss(nn.Module):
num_classes = pred.size(1)
if torch.cuda.is_available() and pred.is_cuda:
if target.dim() == 1:
one_hot_target = F.one_hot(target, num_classes=num_classes)
one_hot_target = F.one_hot(
target, num_classes=num_classes + 1)
if num_classes == 1:
one_hot_target = one_hot_target[:, 1]
target = 1 - target
else:
one_hot_target = one_hot_target[:, :num_classes]
else:
one_hot_target = target
target = target.argmax(dim=1)
@ -280,7 +286,11 @@ class FocalLoss(nn.Module):
else:
one_hot_target = None
if target.dim() == 1:
target = F.one_hot(target, num_classes=num_classes)
target = F.one_hot(target, num_classes=num_classes + 1)
if num_classes == 1:
target = target[:, 1]
else:
target = target[:, num_classes]
else:
valid_mask = (target.argmax(dim=1) != ignore_index).view(
-1, 1)