[Fix] Fix the bug in binary cross entropy loss (#1499)

* [Fix] Fix the bug in binary cross entropy loss

 Fix the bug in binary cross entropy loss when using multi-label datasets, e.g. VOC2007, by casting gt_score targets to float before stacking.

* update ci

---------

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Wangbo Zhao(黑色枷锁) 2023-04-19 13:53:31 +08:00 committed by GitHub
parent fec3da781f
commit e954cf0aaf
3 changed files with 5 additions and 5 deletions


@@ -34,7 +34,7 @@ jobs:
- name: Upgrade pip
run: pip install pip --upgrade
- name: Install PyTorch
- run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Install mmpretrain dependencies
run: |
pip install git+https://github.com/open-mmlab/mmengine.git@main
@@ -42,7 +42,7 @@ jobs:
mim install 'mmcv >= 2.0.0rc4'
pip install -r requirements.txt
- name: Build and install
- run: pip install -e .
+ run: mim install .
- name: Run unittests and generate coverage report
run: |
coverage run --branch --source mmpretrain -m pytest tests/
@@ -129,7 +129,7 @@ jobs:
mim install 'mmcv >= 2.0.0rc4'
pip install -r requirements.txt
- name: Build and install
- run: pip install -e . -v
+ run: mim install .
- name: Run unittests
run: |
pytest tests/ --ignore tests/test_models/test_backbones


@@ -39,6 +39,6 @@ jobs:
- name: Install openmim
run: pip install openmim
- name: Build and install
- run: mim install -e .
+ run: mim install .
- name: test commands of mim
run: mim search mmpretrain


@@ -93,7 +93,7 @@ class MultiLabelClsHead(BaseModule):
num_classes = cls_score.size()[-1]
# Unpack data samples and pack targets
if 'gt_score' in data_samples[0]:
- target = torch.stack([i.gt_score for i in data_samples])
+ target = torch.stack([i.gt_score.float() for i in data_samples])
else:
target = torch.stack([
label_to_onehot(i.gt_label, num_classes) for i in data_samples
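
Why the cast matters (a minimal sketch, not code from the repo): PyTorch's F.binary_cross_entropy_with_logits, which the multi-label binary cross entropy ultimately relies on, expects a floating-point target, while gt_score on multi-label datasets such as VOC2007 can arrive as an integer tensor. The 20-class setup, tensor shapes, and variable names below are illustrative assumptions, not taken from mmpretrain:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch of 2 samples, 20 classes (VOC-style).
cls_score = torch.randn(2, 20)
# gt_score stored as an integer tensor, as it may be for multi-label data.
gt_score = torch.randint(0, 2, (2, 20))

# Without the cast, the integer target dtype is rejected, typically with
# "RuntimeError: result type Float can't be cast to the desired output type Long".
try:
    F.binary_cross_entropy_with_logits(cls_score, gt_score)
except RuntimeError as err:
    print("dtype mismatch:", err)

# Casting the target to float, as the fix does per sample before torch.stack,
# gives the loss a floating-point target and it computes normally.
loss = F.binary_cross_entropy_with_logits(cls_score, gt_score.float())
print(loss)

This mirrors the one-line change above: each sample's gt_score is cast to float before stacking, so the packed target tensor already has the dtype the loss expects.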