[Fix] Fix the bug in binary cross entropy loss (#1499)
* [Fix] Fix the bug in binary cross entropy loss Fix the bug in binary cross entropy loss when using multi-label datasets e.g.VOC2007 * update ci --------- Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>pull/1518/head
parent
fec3da781f
commit
e954cf0aaf
|
@ -34,7 +34,7 @@ jobs:
|
|||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
- name: Install mmpretrain dependencies
|
||||
run: |
|
||||
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||
|
@ -42,7 +42,7 @@ jobs:
|
|||
mim install 'mmcv >= 2.0.0rc4'
|
||||
pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: pip install -e .
|
||||
run: mim install .
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmpretrain -m pytest tests/
|
||||
|
@ -129,7 +129,7 @@ jobs:
|
|||
mim install 'mmcv >= 2.0.0rc4'
|
||||
pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: pip install -e . -v
|
||||
run: mim install .
|
||||
- name: Run unittests
|
||||
run: |
|
||||
pytest tests/ --ignore tests/test_models/test_backbones
|
||||
|
|
|
@ -39,6 +39,6 @@ jobs:
|
|||
- name: Install openmim
|
||||
run: pip install openmim
|
||||
- name: Build and install
|
||||
run: mim install -e .
|
||||
run: mim install .
|
||||
- name: test commands of mim
|
||||
run: mim search mmpretrain
|
||||
|
|
|
@ -93,7 +93,7 @@ class MultiLabelClsHead(BaseModule):
|
|||
num_classes = cls_score.size()[-1]
|
||||
# Unpack data samples and pack targets
|
||||
if 'gt_score' in data_samples[0]:
|
||||
target = torch.stack([i.gt_score for i in data_samples])
|
||||
target = torch.stack([i.gt_score.float() for i in data_samples])
|
||||
else:
|
||||
target = torch.stack([
|
||||
label_to_onehot(i.gt_label, num_classes) for i in data_samples
|
||||
|
|
Loading…
Reference in New Issue