fix mmseg exportation for out_channels=1 (#1997)
parent
c73756366e
commit
335ef8648d
|
@ -28,7 +28,6 @@ jobs:
|
|||
runs-on: ubuntu-20.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
torch: [1.8.0, 1.9.0]
|
||||
mmcv: [1.4.2]
|
||||
include:
|
||||
|
@ -40,22 +39,22 @@ jobs:
|
|||
torchvision: 0.10.0
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -V
|
||||
python -m pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install MMCV
|
||||
run: |
|
||||
pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cpu/${{matrix.torch_version}}/index.html
|
||||
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cpu/${{matrix.torch_version}}/index.html
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Install unittest dependencies
|
||||
run: |
|
||||
pip install -r requirements.txt
|
||||
pip install -U numpy
|
||||
python -m pip install -U numpy
|
||||
python -m pip install rapidfuzz==2.15.1
|
||||
python -m pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: rm -rf .eggs && pip install -e .
|
||||
run: rm -rf .eggs && python -m pip install -e .
|
||||
- name: Run python unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmdeploy -m pytest -rsE tests
|
||||
|
@ -139,6 +138,7 @@ jobs:
|
|||
python -m pip install -U pip
|
||||
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cu102/${{matrix.torch_version}}/index.html
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install rapidfuzz==2.15.1
|
||||
- name: Build and install
|
||||
run: |
|
||||
rm -rf .eggs && python -m pip install -e .
|
||||
|
@ -174,6 +174,7 @@ jobs:
|
|||
python -m pip install -U pip
|
||||
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cu111/${{matrix.torch_version}}/index.html
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install rapidfuzz==2.15.1
|
||||
- name: Build and install
|
||||
run: |
|
||||
rm -rf .eggs && python -m pip install -e .
|
||||
|
|
|
@ -47,6 +47,7 @@ jobs:
|
|||
python -m pip install -U pip
|
||||
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cu111/${{matrix.torch_version}}/index.html
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install rapidfuzz==2.15.1
|
||||
- name: Install mmcls
|
||||
run: |
|
||||
cd ~
|
||||
|
|
|
@ -286,8 +286,13 @@ class Segmentation(BaseTask):
|
|||
postprocess = self.model_cfg.model.decode_head
|
||||
if isinstance(postprocess, list):
|
||||
postprocess = postprocess[-1]
|
||||
postprocess = postprocess.copy()
|
||||
with_argmax = get_codebase_config(self.deploy_cfg).get(
|
||||
'with_argmax', True)
|
||||
# set with_argmax=True for this special case
|
||||
if postprocess['num_classes'] == 2 and \
|
||||
postprocess['out_channels'] == 1:
|
||||
with_argmax = True
|
||||
postprocess['with_argmax'] = with_argmax
|
||||
return postprocess
|
||||
|
||||
|
|
|
@ -25,10 +25,14 @@ def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs):
|
|||
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
|
||||
"""
|
||||
seg_logit = self.encode_decode(img, img_meta)
|
||||
seg_logit = F.softmax(seg_logit, dim=1)
|
||||
if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
|
||||
return seg_logit
|
||||
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
|
||||
if self.out_channels == 1:
|
||||
seg_logit = F.sigmoid(seg_logit)
|
||||
seg_pred = (seg_logit > self.decode_head.threshold).to(seg_logit)
|
||||
else:
|
||||
seg_pred = F.softmax(seg_logit, dim=1)
|
||||
if get_codebase_config(ctx.cfg).get('with_argmax', True):
|
||||
seg_pred = seg_pred.argmax(dim=1, keepdim=True)
|
||||
|
||||
return seg_pred
|
||||
|
||||
|
||||
|
@ -51,5 +55,10 @@ def encoder_decoder__simple_test__rknn(ctx, self, img, img_meta, **kwargs):
|
|||
torch.Tensor: Output segmentation map pf shape [N, C, H, W].
|
||||
"""
|
||||
seg_logit = self.encode_decode(img, img_meta)
|
||||
seg_logit = F.softmax(seg_logit, dim=1)
|
||||
return seg_logit
|
||||
if self.out_channels == 1:
|
||||
seg_logit = F.sigmoid(seg_logit)
|
||||
seg_pred = (seg_logit > self.decode_head.threshold).to(seg_logit)
|
||||
else:
|
||||
seg_pred = F.softmax(seg_logit, dim=1)
|
||||
|
||||
return seg_pred
|
||||
|
|
|
@ -4,5 +4,5 @@ mmdet>=2.19.0,<=2.20.0
|
|||
mmedit<1.0.0
|
||||
mmocr>=0.3.0,<=0.4.1
|
||||
mmpose>=0.24.0,<=0.25.1
|
||||
mmrazor>=0.3.0
|
||||
mmrazor>=0.3.0,<=0.3.1
|
||||
mmsegmentation<1.0.0
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
h5py
|
||||
mmcls>=0.21.0,<=0.23.0
|
||||
mmdet>=2.19.0,<=2.20.0
|
||||
mmedit
|
||||
mmedit<1.0.0
|
||||
mmocr>=0.3.0,<=0.4.1
|
||||
mmpose>=0.24.0,<=0.25.1
|
||||
mmrazor>=0.3.0
|
||||
mmsegmentation
|
||||
mmrazor>=0.3.0,<=0.3.1
|
||||
mmsegmentation<1.0.0
|
||||
onnxruntime>=1.8.0
|
||||
openvino-dev>=2022.3.0
|
||||
tqdm
|
||||
|
|
Loading…
Reference in New Issue