From e954cf0aaf4966fb7083220ca2737dc891a97418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wangbo=20Zhao=28=E9=BB=91=E8=89=B2=E6=9E=B7=E9=94=81=29?= <56866854+wangbo-zhao@users.noreply.github.com> Date: Wed, 19 Apr 2023 13:53:31 +0800 Subject: [PATCH] [Fix] Fix the bug in binary cross entropy loss (#1499) * [Fix] Fix the bug in binary cross entropy loss Fix the bug in binary cross entropy loss when using multi-label datasets e.g.VOC2007 * update ci --------- Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> --- .github/workflows/pr_stage_test.yml | 6 +++--- .github/workflows/test_mim.yml | 2 +- mmpretrain/models/heads/multi_label_cls_head.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml index e3460471..8a7afea3 100644 --- a/.github/workflows/pr_stage_test.yml +++ b/.github/workflows/pr_stage_test.yml @@ -34,7 +34,7 @@ jobs: - name: Upgrade pip run: pip install pip --upgrade - name: Install PyTorch - run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html + run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - name: Install mmpretrain dependencies run: | pip install git+https://github.com/open-mmlab/mmengine.git@main @@ -42,7 +42,7 @@ jobs: mim install 'mmcv >= 2.0.0rc4' pip install -r requirements.txt - name: Build and install - run: pip install -e . + run: mim install . - name: Run unittests and generate coverage report run: | coverage run --branch --source mmpretrain -m pytest tests/ @@ -129,7 +129,7 @@ jobs: mim install 'mmcv >= 2.0.0rc4' pip install -r requirements.txt - name: Build and install - run: pip install -e . -v + run: mim install . - name: Run unittests run: | pytest tests/ --ignore tests/test_models/test_backbones diff --git a/.github/workflows/test_mim.yml b/.github/workflows/test_mim.yml index 5b092385..96e83633 100644 --- a/.github/workflows/test_mim.yml +++ b/.github/workflows/test_mim.yml @@ -39,6 +39,6 @@ jobs: - name: Install openmim run: pip install openmim - name: Build and install - run: mim install -e . + run: mim install . - name: test commands of mim run: mim search mmpretrain diff --git a/mmpretrain/models/heads/multi_label_cls_head.py b/mmpretrain/models/heads/multi_label_cls_head.py index e69b5277..ca36bfe0 100644 --- a/mmpretrain/models/heads/multi_label_cls_head.py +++ b/mmpretrain/models/heads/multi_label_cls_head.py @@ -93,7 +93,7 @@ class MultiLabelClsHead(BaseModule): num_classes = cls_score.size()[-1] # Unpack data samples and pack targets if 'gt_score' in data_samples[0]: - target = torch.stack([i.gt_score for i in data_samples]) + target = torch.stack([i.gt_score.float() for i in data_samples]) else: target = torch.stack([ label_to_onehot(i.gt_label, num_classes) for i in data_samples