fix mmseg exportation for out_channels=1 (#1997)

master
RunningLeon 2023-05-04 12:51:05 +08:00 committed by GitHub
parent c73756366e
commit 335ef8648d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 20 deletions

View File

@ -28,7 +28,6 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: [3.7]
torch: [1.8.0, 1.9.0]
mmcv: [1.4.2]
include:
@ -40,22 +39,22 @@ jobs:
torchvision: 0.10.0
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
run: |
python -m pip install --upgrade pip
python -V
python -m pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install MMCV
run: |
pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cpu/${{matrix.torch_version}}/index.html
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cpu/${{matrix.torch_version}}/index.html
python -c 'import mmcv; print(mmcv.__version__)'
- name: Install unittest dependencies
run: |
pip install -r requirements.txt
pip install -U numpy
python -m pip install -U numpy
python -m pip install rapidfuzz==2.15.1
python -m pip install -r requirements.txt
- name: Build and install
run: rm -rf .eggs && pip install -e .
run: rm -rf .eggs && python -m pip install -e .
- name: Run python unittests and generate coverage report
run: |
coverage run --branch --source mmdeploy -m pytest -rsE tests
@ -139,6 +138,7 @@ jobs:
python -m pip install -U pip
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cu102/${{matrix.torch_version}}/index.html
python -m pip install -r requirements.txt
python -m pip install rapidfuzz==2.15.1
- name: Build and install
run: |
rm -rf .eggs && python -m pip install -e .
@ -174,6 +174,7 @@ jobs:
python -m pip install -U pip
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cu111/${{matrix.torch_version}}/index.html
python -m pip install -r requirements.txt
python -m pip install rapidfuzz==2.15.1
- name: Build and install
run: |
rm -rf .eggs && python -m pip install -e .

View File

@ -47,6 +47,7 @@ jobs:
python -m pip install -U pip
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cu111/${{matrix.torch_version}}/index.html
python -m pip install -r requirements.txt
python -m pip install rapidfuzz==2.15.1
- name: Install mmcls
run: |
cd ~

View File

@ -286,8 +286,13 @@ class Segmentation(BaseTask):
postprocess = self.model_cfg.model.decode_head
if isinstance(postprocess, list):
postprocess = postprocess[-1]
postprocess = postprocess.copy()
with_argmax = get_codebase_config(self.deploy_cfg).get(
'with_argmax', True)
# set with_argmax=True for this special case
if postprocess['num_classes'] == 2 and \
postprocess['out_channels'] == 1:
with_argmax = True
postprocess['with_argmax'] = with_argmax
return postprocess

View File

@ -25,10 +25,14 @@ def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs):
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
"""
seg_logit = self.encode_decode(img, img_meta)
seg_logit = F.softmax(seg_logit, dim=1)
if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
return seg_logit
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
if self.out_channels == 1:
seg_logit = F.sigmoid(seg_logit)
seg_pred = (seg_logit > self.decode_head.threshold).to(seg_logit)
else:
seg_pred = F.softmax(seg_logit, dim=1)
if get_codebase_config(ctx.cfg).get('with_argmax', True):
seg_pred = seg_pred.argmax(dim=1, keepdim=True)
return seg_pred
@ -51,5 +55,10 @@ def encoder_decoder__simple_test__rknn(ctx, self, img, img_meta, **kwargs):
torch.Tensor: Output segmentation map pf shape [N, C, H, W].
"""
seg_logit = self.encode_decode(img, img_meta)
seg_logit = F.softmax(seg_logit, dim=1)
return seg_logit
if self.out_channels == 1:
seg_logit = F.sigmoid(seg_logit)
seg_pred = (seg_logit > self.decode_head.threshold).to(seg_logit)
else:
seg_pred = F.softmax(seg_logit, dim=1)
return seg_pred

View File

@ -4,5 +4,5 @@ mmdet>=2.19.0,<=2.20.0
mmedit<1.0.0
mmocr>=0.3.0,<=0.4.1
mmpose>=0.24.0,<=0.25.1
mmrazor>=0.3.0
mmrazor>=0.3.0,<=0.3.1
mmsegmentation<1.0.0

View File

@ -1,11 +1,11 @@
h5py
mmcls>=0.21.0,<=0.23.0
mmdet>=2.19.0,<=2.20.0
mmedit
mmedit<1.0.0
mmocr>=0.3.0,<=0.4.1
mmpose>=0.24.0,<=0.25.1
mmrazor>=0.3.0
mmsegmentation
mmrazor>=0.3.0,<=0.3.1
mmsegmentation<1.0.0
onnxruntime>=1.8.0
openvino-dev>=2022.3.0
tqdm