mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] add onnxruntime test tool (#498)
* add onnxruntime test tool, update pytorch2onnx to support slice export * onnx convert with custom output shape, update test code * update pytorch2onnx, add rescale_shape support, add document * update doc for lint error fixing * remove cpu flag in ort_test.py * change class name, fix cuda error * remote comment * fix bug of torch2onnx * mIOU to mIoU
This commit is contained in:
parent
fb031c59c8
commit
6ccb1c0fe5
@ -53,6 +53,7 @@ python tools/pytorch2onnx.py \
|
|||||||
--output-file ${ONNX_FILE} \
|
--output-file ${ONNX_FILE} \
|
||||||
--input-img ${INPUT_IMG} \
|
--input-img ${INPUT_IMG} \
|
||||||
--shape ${INPUT_SHAPE} \
|
--shape ${INPUT_SHAPE} \
|
||||||
|
--rescale-shape ${RESCALE_SHAPE} \
|
||||||
--show \
|
--show \
|
||||||
--verify \
|
--verify \
|
||||||
--dynamic-export \
|
--dynamic-export \
|
||||||
@ -66,7 +67,8 @@ Description of arguments:
|
|||||||
- `--checkpoint` : The path of a model checkpoint file.
|
- `--checkpoint` : The path of a model checkpoint file.
|
||||||
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
|
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
|
||||||
- `--input-img` : The path of an input image for conversion and visualize.
|
- `--input-img` : The path of an input image for conversion and visualize.
|
||||||
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `256 256`.
|
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to img_scale of testpipeline.
|
||||||
|
- `--rescale-shape`: rescale shape of output, set this value to avoid OOM, only work on `slide` mode.
|
||||||
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
|
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
|
||||||
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
|
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
|
||||||
- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
|
- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
|
||||||
@ -74,6 +76,55 @@ Description of arguments:
|
|||||||
|
|
||||||
**Note**: This tool is still experimental. Some customized operators are not supported for now.
|
**Note**: This tool is still experimental. Some customized operators are not supported for now.
|
||||||
|
|
||||||
|
### Evaluate ONNX model with ONNXRuntime
|
||||||
|
|
||||||
|
We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.
|
||||||
|
|
||||||
|
#### Prerequisite
|
||||||
|
|
||||||
|
- Install onnx and onnxruntime-gpu
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install onnx onnxruntime-gpu
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
python tools/ort_test.py \
|
||||||
|
${CONFIG_FILE} \
|
||||||
|
${ONNX_FILE} \
|
||||||
|
--out ${OUTPUT_FILE} \
|
||||||
|
--eval ${EVALUATION_METRICS} \
|
||||||
|
--show \
|
||||||
|
--show-dir ${SHOW_DIRECTORY} \
|
||||||
|
--options ${CFG_OPTIONS} \
|
||||||
|
--eval-options ${EVALUATION_OPTIONS} \
|
||||||
|
--opacity ${OPACITY} \
|
||||||
|
```
|
||||||
|
|
||||||
|
Description of all arguments
|
||||||
|
|
||||||
|
- `config`: The path of a model config file.
|
||||||
|
- `model`: The path of a ONNX model file.
|
||||||
|
- `--out`: The path of output result file in pickle format.
|
||||||
|
- `--format-only` : Format the output results without perform evaluation. It is useful when you want to format the result to a specific format and submit it to the test server. If not specified, it will be set to `False`. Note that this argument is **mutually exclusive** with `--eval`.
|
||||||
|
- `--eval`: Evaluation metrics, which depends on the dataset, e.g., "mIoU" for generic datasets, and "cityscapes" for Cityscapes. Note that this argument is **mutually exclusive** with `--format-only`.
|
||||||
|
- `--show`: Show results flag.
|
||||||
|
- `--show-dir`: Directory where painted images will be saved
|
||||||
|
- `--options`: Override some settings in the used config file, the key-value pair in `xxx=yyy` format will be merged into config file.
|
||||||
|
- `--eval-options`: Custom options for evaluation, the key-value pair in `xxx=yyy` format will be kwargs for `dataset.evaluate()` function
|
||||||
|
- `--opacity`: Opacity of painted segmentation map. In (0, 1] range.
|
||||||
|
|
||||||
|
#### Results and Models
|
||||||
|
|
||||||
|
| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime |
|
||||||
|
| :--------: | :--------------------------------------------: | :--------: | :----: | :-----: | :---------: |
|
||||||
|
| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 |
|
||||||
|
| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 |
|
||||||
|
| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 |
|
||||||
|
| deeplabv3+ | deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 |
|
||||||
|
|
||||||
### Convert to TorchScript (experimental)
|
### Convert to TorchScript (experimental)
|
||||||
|
|
||||||
We also provide a script to convert model to [TorchScript](https://pytorch.org/docs/stable/jit.html) format. You can use the pytorch C++ API [LibTorch](https://pytorch.org/docs/stable/cpp_index.html) inference the trained model. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and TorchScript model.
|
We also provide a script to convert model to [TorchScript](https://pytorch.org/docs/stable/jit.html) format. You can use the pytorch C++ API [LibTorch](https://pytorch.org/docs/stable/cpp_index.html) inference the trained model. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and TorchScript model.
|
||||||
|
191
tools/ort_test.py
Normal file
191
tools/ort_test.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
from mmcv.parallel import MMDataParallel
|
||||||
|
from mmcv.runner import get_dist_info
|
||||||
|
from mmcv.utils import DictAction
|
||||||
|
|
||||||
|
from mmseg.apis import single_gpu_test
|
||||||
|
from mmseg.datasets import build_dataloader, build_dataset
|
||||||
|
from mmseg.models.segmentors.base import BaseSegmentor
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXRuntimeSegmentor(BaseSegmentor):
|
||||||
|
|
||||||
|
def __init__(self, onnx_file, cfg, device_id):
|
||||||
|
super(ONNXRuntimeSegmentor, self).__init__()
|
||||||
|
# get the custom op path
|
||||||
|
ort_custom_op_path = ''
|
||||||
|
try:
|
||||||
|
from mmcv.ops import get_onnxruntime_op_path
|
||||||
|
ort_custom_op_path = get_onnxruntime_op_path()
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
warnings.warn('If input model has custom op from mmcv, \
|
||||||
|
you may have to build mmcv with ONNXRuntime from source.')
|
||||||
|
session_options = ort.SessionOptions()
|
||||||
|
# register custom op for onnxruntime
|
||||||
|
if osp.exists(ort_custom_op_path):
|
||||||
|
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||||
|
sess = ort.InferenceSession(onnx_file, session_options)
|
||||||
|
providers = ['CPUExecutionProvider']
|
||||||
|
options = [{}]
|
||||||
|
is_cuda_available = ort.get_device() == 'GPU'
|
||||||
|
if is_cuda_available:
|
||||||
|
providers.insert(0, 'CUDAExecutionProvider')
|
||||||
|
options.insert(0, {'device_id': device_id})
|
||||||
|
|
||||||
|
sess.set_providers(providers, options)
|
||||||
|
|
||||||
|
self.sess = sess
|
||||||
|
self.device_id = device_id
|
||||||
|
self.io_binding = sess.io_binding()
|
||||||
|
self.output_names = [_.name for _ in sess.get_outputs()]
|
||||||
|
for name in self.output_names:
|
||||||
|
self.io_binding.bind_output(name)
|
||||||
|
self.cfg = cfg
|
||||||
|
self.test_mode = cfg.model.test_cfg.mode
|
||||||
|
|
||||||
|
def extract_feat(self, imgs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def encode_decode(self, img, img_metas):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def forward_train(self, imgs, img_metas, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def simple_test(self, img, img_meta, **kwargs):
|
||||||
|
device_type = img.device.type
|
||||||
|
self.io_binding.bind_input(
|
||||||
|
name='input',
|
||||||
|
device_type=device_type,
|
||||||
|
device_id=self.device_id,
|
||||||
|
element_type=np.float32,
|
||||||
|
shape=img.shape,
|
||||||
|
buffer_ptr=img.data_ptr())
|
||||||
|
self.sess.run_with_iobinding(self.io_binding)
|
||||||
|
seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
|
||||||
|
# whole might support dynamic reshape
|
||||||
|
ori_shape = img_meta[0]['ori_shape']
|
||||||
|
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||||
|
and ori_shape[1] == seg_pred.shape[-1]):
|
||||||
|
seg_pred = torch.from_numpy(seg_pred).float()
|
||||||
|
seg_pred = torch.nn.functional.interpolate(
|
||||||
|
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||||
|
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||||
|
seg_pred = seg_pred[0]
|
||||||
|
seg_pred = list(seg_pred)
|
||||||
|
return seg_pred
|
||||||
|
|
||||||
|
def aug_test(self, imgs, img_metas, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='mmseg onnxruntime backend test (and eval) a model')
|
||||||
|
parser.add_argument('config', help='test config file path')
|
||||||
|
parser.add_argument('model', help='Input model file')
|
||||||
|
parser.add_argument('--out', help='output result file in pickle format')
|
||||||
|
parser.add_argument(
|
||||||
|
'--format-only',
|
||||||
|
action='store_true',
|
||||||
|
help='Format the output results without perform evaluation. It is'
|
||||||
|
'useful when you want to format the result to a specific format and '
|
||||||
|
'submit it to the test server')
|
||||||
|
parser.add_argument(
|
||||||
|
'--eval',
|
||||||
|
type=str,
|
||||||
|
nargs='+',
|
||||||
|
help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
|
||||||
|
' for generic datasets, and "cityscapes" for Cityscapes')
|
||||||
|
parser.add_argument('--show', action='store_true', help='show results')
|
||||||
|
parser.add_argument(
|
||||||
|
'--show-dir', help='directory where painted images will be saved')
|
||||||
|
parser.add_argument(
|
||||||
|
'--options', nargs='+', action=DictAction, help='custom options')
|
||||||
|
parser.add_argument(
|
||||||
|
'--eval-options',
|
||||||
|
nargs='+',
|
||||||
|
action=DictAction,
|
||||||
|
help='custom options for evaluation')
|
||||||
|
parser.add_argument(
|
||||||
|
'--opacity',
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help='Opacity of painted segmentation map. In (0, 1] range.')
|
||||||
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
args = parser.parse_args()
|
||||||
|
if 'LOCAL_RANK' not in os.environ:
|
||||||
|
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
assert args.out or args.eval or args.format_only or args.show \
|
||||||
|
or args.show_dir, \
|
||||||
|
('Please specify at least one operation (save/eval/format/show the '
|
||||||
|
'results / save the results) with the argument "--out", "--eval"'
|
||||||
|
', "--format-only", "--show" or "--show-dir"')
|
||||||
|
|
||||||
|
if args.eval and args.format_only:
|
||||||
|
raise ValueError('--eval and --format_only cannot be both specified')
|
||||||
|
|
||||||
|
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
|
||||||
|
raise ValueError('The output file must be a pkl file.')
|
||||||
|
|
||||||
|
cfg = mmcv.Config.fromfile(args.config)
|
||||||
|
if args.options is not None:
|
||||||
|
cfg.merge_from_dict(args.options)
|
||||||
|
cfg.model.pretrained = None
|
||||||
|
cfg.data.test.test_mode = True
|
||||||
|
|
||||||
|
# init distributed env first, since logger depends on the dist info.
|
||||||
|
distributed = False
|
||||||
|
|
||||||
|
# build the dataloader
|
||||||
|
# TODO: support multiple images per gpu (only minor changes are needed)
|
||||||
|
dataset = build_dataset(cfg.data.test)
|
||||||
|
data_loader = build_dataloader(
|
||||||
|
dataset,
|
||||||
|
samples_per_gpu=1,
|
||||||
|
workers_per_gpu=cfg.data.workers_per_gpu,
|
||||||
|
dist=distributed,
|
||||||
|
shuffle=False)
|
||||||
|
|
||||||
|
# load onnx config and meta
|
||||||
|
cfg.model.train_cfg = None
|
||||||
|
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
|
||||||
|
model.CLASSES = dataset.CLASSES
|
||||||
|
model.PALETTE = dataset.PALETTE
|
||||||
|
|
||||||
|
efficient_test = False
|
||||||
|
if args.eval_options is not None:
|
||||||
|
efficient_test = args.eval_options.get('efficient_test', False)
|
||||||
|
|
||||||
|
model = MMDataParallel(model, device_ids=[0])
|
||||||
|
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
|
||||||
|
efficient_test, args.opacity)
|
||||||
|
|
||||||
|
rank, _ = get_dist_info()
|
||||||
|
if rank == 0:
|
||||||
|
if args.out:
|
||||||
|
print(f'\nwriting results to {args.out}')
|
||||||
|
mmcv.dump(outputs, args.out)
|
||||||
|
kwargs = {} if args.eval_options is None else args.eval_options
|
||||||
|
if args.format_only:
|
||||||
|
dataset.format_results(outputs, **kwargs)
|
||||||
|
if args.eval:
|
||||||
|
dataset.evaluate(outputs, args.eval, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -71,10 +71,13 @@ def _demo_mm_inputs(input_shape, num_classes):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
def _prepare_input_img(img_path, test_pipeline, shape=None):
|
def _prepare_input_img(img_path,
|
||||||
|
test_pipeline,
|
||||||
|
shape=None,
|
||||||
|
rescale_shape=None):
|
||||||
# build the data pipeline
|
# build the data pipeline
|
||||||
if shape is not None:
|
if shape is not None:
|
||||||
test_pipeline[1]['img_scale'] = shape
|
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
|
||||||
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
|
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
|
||||||
test_pipeline = [LoadImage()] + test_pipeline[1:]
|
test_pipeline = [LoadImage()] + test_pipeline[1:]
|
||||||
test_pipeline = Compose(test_pipeline)
|
test_pipeline = Compose(test_pipeline)
|
||||||
@ -84,6 +87,10 @@ def _prepare_input_img(img_path, test_pipeline, shape=None):
|
|||||||
imgs = data['img']
|
imgs = data['img']
|
||||||
img_metas = [i.data for i in data['img_metas']]
|
img_metas = [i.data for i in data['img_metas']]
|
||||||
|
|
||||||
|
if rescale_shape is not None:
|
||||||
|
for img_meta in img_metas:
|
||||||
|
img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
|
||||||
|
|
||||||
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
|
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
|
||||||
|
|
||||||
return mm_inputs
|
return mm_inputs
|
||||||
@ -91,15 +98,24 @@ def _prepare_input_img(img_path, test_pipeline, shape=None):
|
|||||||
|
|
||||||
def _update_input_img(img_list, img_meta_list):
|
def _update_input_img(img_list, img_meta_list):
|
||||||
# update img and its meta list
|
# update img and its meta list
|
||||||
N, C, H, W = img_list[0].shape
|
N = img_list[0].size(0)
|
||||||
img_meta = img_meta_list[0][0]
|
img_meta = img_meta_list[0][0]
|
||||||
|
img_shape = img_meta['img_shape']
|
||||||
|
ori_shape = img_meta['ori_shape']
|
||||||
|
pad_shape = img_meta['pad_shape']
|
||||||
new_img_meta_list = [[{
|
new_img_meta_list = [[{
|
||||||
'img_shape': (H, W, C),
|
'img_shape':
|
||||||
'ori_shape': (H, W, C),
|
img_shape,
|
||||||
'pad_shape': (H, W, C),
|
'ori_shape':
|
||||||
'filename': img_meta['filename'],
|
ori_shape,
|
||||||
'scale_factor': 1.,
|
'pad_shape':
|
||||||
'flip': False,
|
pad_shape,
|
||||||
|
'filename':
|
||||||
|
img_meta['filename'],
|
||||||
|
'scale_factor':
|
||||||
|
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
|
||||||
|
'flip':
|
||||||
|
False,
|
||||||
} for _ in range(N)]]
|
} for _ in range(N)]]
|
||||||
|
|
||||||
return img_list, new_img_meta_list
|
return img_list, new_img_meta_list
|
||||||
@ -128,6 +144,7 @@ def pytorch2onnx(model,
|
|||||||
Default: False.
|
Default: False.
|
||||||
"""
|
"""
|
||||||
model.cpu().eval()
|
model.cpu().eval()
|
||||||
|
test_mode = model.test_cfg.mode
|
||||||
|
|
||||||
if isinstance(model.decode_head, nn.ModuleList):
|
if isinstance(model.decode_head, nn.ModuleList):
|
||||||
num_classes = model.decode_head[-1].num_classes
|
num_classes = model.decode_head[-1].num_classes
|
||||||
@ -136,18 +153,24 @@ def pytorch2onnx(model,
|
|||||||
|
|
||||||
imgs = mm_inputs.pop('imgs')
|
imgs = mm_inputs.pop('imgs')
|
||||||
img_metas = mm_inputs.pop('img_metas')
|
img_metas = mm_inputs.pop('img_metas')
|
||||||
ori_shape = img_metas[0]['ori_shape']
|
|
||||||
|
|
||||||
img_list = [img[None, :] for img in imgs]
|
img_list = [img[None, :] for img in imgs]
|
||||||
img_meta_list = [[img_meta] for img_meta in img_metas]
|
img_meta_list = [[img_meta] for img_meta in img_metas]
|
||||||
|
# update img_meta
|
||||||
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
|
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
|
||||||
|
|
||||||
# replace original forward function
|
# replace original forward function
|
||||||
origin_forward = model.forward
|
origin_forward = model.forward
|
||||||
model.forward = partial(
|
model.forward = partial(
|
||||||
model.forward, img_metas=img_meta_list, return_loss=False)
|
model.forward,
|
||||||
|
img_metas=img_meta_list,
|
||||||
|
return_loss=False,
|
||||||
|
rescale=True)
|
||||||
dynamic_axes = None
|
dynamic_axes = None
|
||||||
if dynamic_export:
|
if dynamic_export:
|
||||||
|
if test_mode == 'slide':
|
||||||
|
dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
|
||||||
|
else:
|
||||||
dynamic_axes = {
|
dynamic_axes = {
|
||||||
'input': {
|
'input': {
|
||||||
0: 'batch',
|
0: 'batch',
|
||||||
@ -182,7 +205,7 @@ def pytorch2onnx(model,
|
|||||||
onnx_model = onnx.load(output_file)
|
onnx_model = onnx.load(output_file)
|
||||||
onnx.checker.check_model(onnx_model)
|
onnx.checker.check_model(onnx_model)
|
||||||
|
|
||||||
if dynamic_export:
|
if dynamic_export and test_mode == 'whole':
|
||||||
# scale image for dynamic shape test
|
# scale image for dynamic shape test
|
||||||
img_list = [
|
img_list = [
|
||||||
nn.functional.interpolate(_, scale_factor=1.5)
|
nn.functional.interpolate(_, scale_factor=1.5)
|
||||||
@ -223,6 +246,10 @@ def pytorch2onnx(model,
|
|||||||
if not osp.exists(img):
|
if not osp.exists(img):
|
||||||
img = imgs[0][:3, ...].permute(1, 2, 0) * 255
|
img = imgs[0][:3, ...].permute(1, 2, 0) * 255
|
||||||
img = img.detach().numpy().astype(np.uint8)
|
img = img.detach().numpy().astype(np.uint8)
|
||||||
|
ori_shape = img.shape[:2]
|
||||||
|
else:
|
||||||
|
ori_shape = LoadImage()({'img': img})['ori_shape']
|
||||||
|
|
||||||
# resize onnx_result to ori_shape
|
# resize onnx_result to ori_shape
|
||||||
onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
|
onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
|
||||||
(ori_shape[1], ori_shape[0]))
|
(ori_shape[1], ori_shape[0]))
|
||||||
@ -271,8 +298,14 @@ def parse_args():
|
|||||||
'--shape',
|
'--shape',
|
||||||
type=int,
|
type=int,
|
||||||
nargs='+',
|
nargs='+',
|
||||||
default=[256, 256],
|
default=None,
|
||||||
help='input image size')
|
help='input image height and width.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--rescale_shape',
|
||||||
|
type=int,
|
||||||
|
nargs='+',
|
||||||
|
default=None,
|
||||||
|
help='output image rescale height and width, work for slide mode.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--cfg-options',
|
'--cfg-options',
|
||||||
nargs='+',
|
nargs='+',
|
||||||
@ -294,7 +327,15 @@ def parse_args():
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
if len(args.shape) == 1:
|
cfg = mmcv.Config.fromfile(args.config)
|
||||||
|
if args.cfg_options is not None:
|
||||||
|
cfg.merge_from_dict(args.cfg_options)
|
||||||
|
cfg.model.pretrained = None
|
||||||
|
|
||||||
|
if args.shape is None:
|
||||||
|
img_scale = cfg.test_pipeline[1]['img_scale']
|
||||||
|
input_shape = (1, 3, img_scale[1], img_scale[0])
|
||||||
|
elif len(args.shape) == 1:
|
||||||
input_shape = (1, 3, args.shape[0], args.shape[0])
|
input_shape = (1, 3, args.shape[0], args.shape[0])
|
||||||
elif len(args.shape) == 2:
|
elif len(args.shape) == 2:
|
||||||
input_shape = (
|
input_shape = (
|
||||||
@ -304,10 +345,7 @@ if __name__ == '__main__':
|
|||||||
else:
|
else:
|
||||||
raise ValueError('invalid input shape')
|
raise ValueError('invalid input shape')
|
||||||
|
|
||||||
cfg = mmcv.Config.fromfile(args.config)
|
test_mode = cfg.model.test_cfg.mode
|
||||||
if args.cfg_options is not None:
|
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
|
||||||
cfg.model.pretrained = None
|
|
||||||
|
|
||||||
# build the model and load checkpoint
|
# build the model and load checkpoint
|
||||||
cfg.model.train_cfg = None
|
cfg.model.train_cfg = None
|
||||||
@ -324,8 +362,15 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# read input or create dummpy input
|
# read input or create dummpy input
|
||||||
if args.input_img is not None:
|
if args.input_img is not None:
|
||||||
mm_inputs = _prepare_input_img(args.input_img, cfg.data.test.pipeline,
|
preprocess_shape = (input_shape[2], input_shape[3])
|
||||||
(input_shape[3], input_shape[2]))
|
rescale_shape = None
|
||||||
|
if args.rescale_shape is not None:
|
||||||
|
rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
|
||||||
|
mm_inputs = _prepare_input_img(
|
||||||
|
args.input_img,
|
||||||
|
cfg.data.test.pipeline,
|
||||||
|
shape=preprocess_shape,
|
||||||
|
rescale_shape=rescale_shape)
|
||||||
else:
|
else:
|
||||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
if isinstance(segmentor.decode_head, nn.ModuleList):
|
||||||
num_classes = segmentor.decode_head[-1].num_classes
|
num_classes = segmentor.decode_head[-1].num_classes
|
||||||
|
Loading…
x
Reference in New Issue
Block a user