Fixed slide inference (#90)

Jerry Jiarui XU 2020-08-25 20:01:01 +08:00 committed by GitHub
parent 03ba9c6c26
commit bafc0e5db6
5 changed files with 9 additions and 7 deletions

@@ -229,7 +229,7 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_c
 python tools/train.py ${CONFIG_FILE} [optional arguments]
 ```
-If you want to specify the working directory in the command, you can add an argument `--work_dir ${YOUR_WORK_DIR}`.
+If you want to specify the working directory in the command, you can add an argument `--work-dir ${YOUR_WORK_DIR}`.
 ### Train with multiple GPUs
@@ -253,7 +253,7 @@ Difference between `resume-from` and `load-from`:
 If you run MMSegmentation on a cluster managed with [slurm](https://slurm.schedmd.com/), you can use the script `slurm_train.sh`. (This script also supports single machine training.)
 ```shell
-[GPUS=${GPUS}] ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR}
+[GPUS=${GPUS}] ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} --work-dir ${WORK_DIR}
 ```
 Here is an example of using 16 GPUs to train PSPNet on the dev partition.
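As a usage sketch of the renamed flag (the config path, partition name, and work directory below are placeholders, not values taken from this commit):

```shell
# single-machine training; logs and checkpoints are written to --work-dir
python tools/train.py configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py \
    --work-dir work_dirs/pspnet_r50

# slurm-managed cluster, 16 GPUs on the dev partition
GPUS=16 ./tools/slurm_train.sh dev pspnet \
    configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py \
    --work-dir work_dirs/pspnet_r50
```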

@@ -30,7 +30,7 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
     prog_bar = mmcv.ProgressBar(len(dataset))
     for i, data in enumerate(data_loader):
         with torch.no_grad():
-            result = model(return_loss=False, rescale=not show, **data)
+            result = model(return_loss=False, **data)
         if isinstance(result, list):
             results.extend(result)
         else:
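After this change the test loop no longer passes an explicit `rescale=not show` argument; rescaling back to the input resolution is handled inside the segmentor itself (note the `if rescale:` branch in the encoder-decoder hunk further down). A self-contained sketch of the simplified loop, using a stand-in model and loader since the real ones are not part of this hunk:

```python
import torch
import mmcv


class DummySegmentor(torch.nn.Module):
    """Stand-in for the real segmentor; returns one prediction per image."""

    def forward(self, return_loss=False, img=None, **kwargs):
        return [torch.zeros(img.shape[-2:], dtype=torch.long)]


model = DummySegmentor().eval()
data_loader = [dict(img=torch.rand(1, 3, 32, 32)) for _ in range(4)]
dataset = data_loader  # stand-in; only len() is needed for the progress bar

results = []
prog_bar = mmcv.ProgressBar(len(dataset))
for i, data in enumerate(data_loader):
    with torch.no_grad():
        # no explicit rescale flag any more; the segmentor's own
        # test-time forward decides how its outputs are rescaled
        result = model(return_loss=False, **data)
    if isinstance(result, list):
        results.extend(result)
    else:
        results.append(result)
    prog_bar.update()
```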

@@ -173,7 +173,7 @@ class CityscapesDataset(CustomDataset):
         try:
             import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
         except ImportError:
-            raise ImportError('Please run "pip install citscapesscripts" to '
+            raise ImportError('Please run "pip install cityscapesscripts" to '
                               'install cityscapesscripts first.')
         msg = 'Evaluating in Cityscapes style'
         if logger is None:
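The corrected message now matches the actual package name on PyPI, so the suggested command works as written:

```shell
pip install cityscapesscripts
```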

@@ -195,8 +195,10 @@ class EncoderDecoder(BaseSegmentor):
                 count_mat[:, :, y1:y2, x1:x2] += 1
         assert (count_mat == 0).sum() == 0
-        # We want to regard count_mat as a constant while exporting to ONNX
-        count_mat = torch.from_numpy(count_mat.detach().numpy())
+        if torch.onnx.is_in_onnx_export():
+            # cast count_mat to constant while exporting to ONNX
+            count_mat = torch.from_numpy(
+                count_mat.cpu().detach().numpy()).to(device=img.device)
         preds = preds / count_mat
         if rescale:
             preds = resize(
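This hunk is the heart of the fix: the old code always round-tripped `count_mat` through NumPy, and `.numpy()` fails on CUDA tensors, so ordinary slide inference on GPU would break. The round-trip is now applied only under `torch.onnx.is_in_onnx_export()`, with an explicit `.cpu()` before the conversion and a move back to `img.device` afterwards, so the averaging denominator is baked into the exported graph as a constant while eager GPU runs stay untouched. A minimal, self-contained sketch of the same pattern (the `SlideAverage` module and the shapes below are made up for illustration):

```python
import torch


class SlideAverage(torch.nn.Module):
    """Averages overlapping window predictions, mimicking slide inference."""

    def forward(self, preds, count_mat):
        if torch.onnx.is_in_onnx_export():
            # Freeze count_mat into a graph constant: round-trip through
            # NumPy on CPU, then move back to the predictions' device.
            count_mat = torch.from_numpy(
                count_mat.cpu().detach().numpy()).to(device=preds.device)
        return preds / count_mat


if __name__ == '__main__':
    m = SlideAverage()
    preds = torch.rand(1, 19, 64, 64)      # summed per-window logits
    count = torch.ones(1, 1, 64, 64) * 2   # windows covering each pixel
    out = m(preds, count)                  # eager run skips the NumPy branch
    print(out.shape)
```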

@@ -19,7 +19,7 @@ from mmseg.utils import collect_env, get_root_logger
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a segmentor')
     parser.add_argument('config', help='train config file path')
-    parser.add_argument('--work_dir', help='the dir to save logs and models')
+    parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--load-from', help='the checkpoint file to load weights from')
     parser.add_argument(
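Only the user-facing flag changes here: argparse derives the attribute name from the long option by replacing dashes with underscores, so existing references to `args.work_dir` keep working. A quick sketch (the parsed value below is a made-up example):

```python
import argparse

parser = argparse.ArgumentParser(description='Train a segmentor')
parser.add_argument('--work-dir', help='the dir to save logs and models')
args = parser.parse_args(['--work-dir', 'work_dirs/demo'])

# argparse turns dashes into underscores for the attribute name,
# so downstream code still reads args.work_dir unchanged.
print(args.work_dir)  # -> work_dirs/demo
```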