Add mmseg2torchserve tool (#552)
* Add docker/serve * Add handler * Add mmseg2torchserve * Fix mmv minimum version * Update docs with model serving section * Update useful_tools.md * pre-commit * Update useful_tools.md * Add 3dogs to resources * Move mask to resourcespull/676/head
parent
e6a8791ab0
commit
420783d007
|
@ -0,0 +1,47 @@
|
|||
ARG PYTORCH="1.6.0"
|
||||
ARG CUDA="10.1"
|
||||
ARG CUDNN="7"
|
||||
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
||||
|
||||
ARG MMCV="1.3.1"
|
||||
ARG MMSEG="0.13.0"
|
||||
|
||||
ENV PYTHONUNBUFFERED TRUE
|
||||
|
||||
RUN apt-get update && \
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
|
||||
ca-certificates \
|
||||
g++ \
|
||||
openjdk-11-jre-headless \
|
||||
# MMDet Requirements
|
||||
ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV PATH="/opt/conda/bin:$PATH"
|
||||
RUN export FORCE_CUDA=1
|
||||
|
||||
# TORCHSEVER
|
||||
RUN pip install torchserve torch-model-archiver
|
||||
|
||||
# MMLAB
|
||||
RUN pip install mmcv-full==${MMCV} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
|
||||
RUN pip install mmsegmentation==${MMSEG}
|
||||
|
||||
RUN useradd -m model-server \
|
||||
&& mkdir -p /home/model-server/tmp
|
||||
|
||||
COPY entrypoint.sh /usr/local/bin/entrypoint.sh
|
||||
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh \
|
||||
&& chown -R model-server /home/model-server
|
||||
|
||||
COPY config.properties /home/model-server/config.properties
|
||||
RUN mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store
|
||||
|
||||
EXPOSE 8080 8081 8082
|
||||
|
||||
USER model-server
|
||||
WORKDIR /home/model-server
|
||||
ENV TEMP=/home/model-server/tmp
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["serve"]
|
|
@ -0,0 +1,5 @@
|
|||
inference_address=http://0.0.0.0:8080
|
||||
management_address=http://0.0.0.0:8081
|
||||
metrics_address=http://0.0.0.0:8082
|
||||
model_store=/home/model-server/model-store
|
||||
load_models=all
|
|
@ -0,0 +1,12 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if [[ "$1" = "serve" ]]; then
|
||||
shift 1
|
||||
torchserve --start --ts-config /home/model-server/config.properties
|
||||
else
|
||||
eval "$@"
|
||||
fi
|
||||
|
||||
# prevent docker exit
|
||||
tail -f /dev/null
|
|
@ -254,3 +254,64 @@ Examples:
|
|||
```shell
|
||||
python tools/analyze_logs.py log.json --keys loss --legend loss
|
||||
```
|
||||
|
||||
## Model Serving
|
||||
|
||||
In order to serve an `MMSegmentation` model with [`TorchServe`](https://pytorch.org/serve/), you can follow the steps:
|
||||
|
||||
### 1. Convert model from MMSegmentation to TorchServe
|
||||
|
||||
```shell
|
||||
python tools/mmseg2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
|
||||
--output-folder ${MODEL_STORE} \
|
||||
--model-name ${MODEL_NAME}
|
||||
```
|
||||
|
||||
**Note**: ${MODEL_STORE} needs to be an absolute path to a folder.
|
||||
|
||||
### 2. Build `mmseg-serve` docker image
|
||||
|
||||
```shell
|
||||
docker build -t mmseg-serve:latest docker/serve/
|
||||
```
|
||||
|
||||
### 3. Run `mmseg-serve`
|
||||
|
||||
Check the official docs for [running TorchServe with docker](https://github.com/pytorch/serve/blob/master/docker/README.md#running-torchserve-in-a-production-docker-environment).
|
||||
|
||||
In order to run in GPU, you need to install [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). You can omit the `--gpus` argument in order to run in CPU.
|
||||
|
||||
Example:
|
||||
|
||||
```shell
|
||||
docker run --rm \
|
||||
--cpus 8 \
|
||||
--gpus device=0 \
|
||||
-p8080:8080 -p8081:8081 -p8082:8082 \
|
||||
--mount type=bind,source=$MODEL_STORE,target=/home/model-server/model-store \
|
||||
mmseg-serve:latest
|
||||
```
|
||||
|
||||
[Read the docs](https://github.com/pytorch/serve/blob/072f5d088cce9bb64b2a18af065886c9b01b317b/docs/rest_api.md) about the Inference (8080), Management (8081) and Metrics (8082) APis
|
||||
|
||||
### 4. Test deployment
|
||||
|
||||
```shell
|
||||
curl -O https://raw.githubusercontent.com/open-mmlab/mmsegmentation/master/resources/3dogs.jpg
|
||||
curl http://127.0.0.1:8080/predictions/${MODEL_NAME} -T 3dogs.jpg -o 3dogs_mask.png
|
||||
```
|
||||
|
||||
The response will be a ".png" mask.
|
||||
|
||||
You can visualize the output as follows:
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
plt.imshow(mmcv.imread("3dogs_mask.png", "grayscale"))
|
||||
plt.show()
|
||||
```
|
||||
|
||||
You should see something similar to:
|
||||
|
||||

|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 181 KiB |
Binary file not shown.
After Width: | Height: | Size: 19 KiB |
|
@ -8,6 +8,6 @@ line_length = 79
|
|||
multi_line_output = 0
|
||||
known_standard_library = setuptools
|
||||
known_first_party = mmseg
|
||||
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch
|
||||
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import mmcv
|
||||
|
||||
try:
|
||||
from model_archiver.model_packaging import package_model
|
||||
from model_archiver.model_packaging_utils import ModelExportUtils
|
||||
except ImportError:
|
||||
package_model = None
|
||||
|
||||
|
||||
def mmseg2torchserve(
|
||||
config_file: str,
|
||||
checkpoint_file: str,
|
||||
output_folder: str,
|
||||
model_name: str,
|
||||
model_version: str = '1.0',
|
||||
force: bool = False,
|
||||
):
|
||||
"""Converts mmsegmentation model (config + checkpoint) to TorchServe
|
||||
`.mar`.
|
||||
|
||||
Args:
|
||||
config_file:
|
||||
In MMSegmentation config format.
|
||||
The contents vary for each task repository.
|
||||
checkpoint_file:
|
||||
In MMSegmentation checkpoint format.
|
||||
The contents vary for each task repository.
|
||||
output_folder:
|
||||
Folder where `{model_name}.mar` will be created.
|
||||
The file created will be in TorchServe archive format.
|
||||
model_name:
|
||||
If not None, used for naming the `{model_name}.mar` file
|
||||
that will be created under `output_folder`.
|
||||
If None, `{Path(checkpoint_file).stem}` will be used.
|
||||
model_version:
|
||||
Model's version.
|
||||
force:
|
||||
If True, if there is an existing `{model_name}.mar`
|
||||
file under `output_folder` it will be overwritten.
|
||||
"""
|
||||
mmcv.mkdir_or_exist(output_folder)
|
||||
|
||||
config = mmcv.Config.fromfile(config_file)
|
||||
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
config.dump(f'{tmpdir}/config.py')
|
||||
|
||||
args = Namespace(
|
||||
**{
|
||||
'model_file': f'{tmpdir}/config.py',
|
||||
'serialized_file': checkpoint_file,
|
||||
'handler': f'{Path(__file__).parent}/mmseg_handler.py',
|
||||
'model_name': model_name or Path(checkpoint_file).stem,
|
||||
'version': model_version,
|
||||
'export_path': output_folder,
|
||||
'force': force,
|
||||
'requirements_file': None,
|
||||
'extra_files': None,
|
||||
'runtime': 'python',
|
||||
'archive_format': 'default'
|
||||
})
|
||||
manifest = ModelExportUtils.generate_manifest_json(args)
|
||||
package_model(args, manifest)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(
|
||||
description='Convert mmseg models to TorchServe `.mar` format.')
|
||||
parser.add_argument('config', type=str, help='config file path')
|
||||
parser.add_argument('checkpoint', type=str, help='checkpoint file path')
|
||||
parser.add_argument(
|
||||
'--output-folder',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Folder where `{model_name}.mar` will be created.')
|
||||
parser.add_argument(
|
||||
'--model-name',
|
||||
type=str,
|
||||
default=None,
|
||||
help='If not None, used for naming the `{model_name}.mar`'
|
||||
'file that will be created under `output_folder`.'
|
||||
'If None, `{Path(checkpoint_file).stem}` will be used.')
|
||||
parser.add_argument(
|
||||
'--model-version',
|
||||
type=str,
|
||||
default='1.0',
|
||||
help='Number used for versioning.')
|
||||
parser.add_argument(
|
||||
'-f',
|
||||
'--force',
|
||||
action='store_true',
|
||||
help='overwrite the existing `{model_name}.mar`')
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
if package_model is None:
|
||||
raise ImportError('`torch-model-archiver` is required.'
|
||||
'Try: pip install torch-model-archiver')
|
||||
|
||||
mmseg2torchserve(args.config, args.checkpoint, args.output_folder,
|
||||
args.model_name, args.model_version, args.force)
|
|
@ -0,0 +1,53 @@
|
|||
import base64
|
||||
import io
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import torch
|
||||
from ts.torch_handler.base_handler import BaseHandler
|
||||
|
||||
from mmseg.apis import inference_segmentor, init_segmentor
|
||||
|
||||
|
||||
class MMsegHandler(BaseHandler):
|
||||
|
||||
def initialize(self, context):
|
||||
properties = context.system_properties
|
||||
self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
self.device = torch.device(self.map_location + ':' +
|
||||
str(properties.get('gpu_id')) if torch.cuda.
|
||||
is_available() else self.map_location)
|
||||
self.manifest = context.manifest
|
||||
|
||||
model_dir = properties.get('model_dir')
|
||||
serialized_file = self.manifest['model']['serializedFile']
|
||||
checkpoint = os.path.join(model_dir, serialized_file)
|
||||
self.config_file = os.path.join(model_dir, 'config.py')
|
||||
|
||||
self.model = init_segmentor(self.config_file, checkpoint, self.device)
|
||||
self.initialized = True
|
||||
|
||||
def preprocess(self, data):
|
||||
images = []
|
||||
|
||||
for row in data:
|
||||
image = row.get('data') or row.get('body')
|
||||
if isinstance(image, str):
|
||||
image = base64.b64decode(image)
|
||||
image = mmcv.imfrombytes(image)
|
||||
images.append(image)
|
||||
|
||||
return images
|
||||
|
||||
def inference(self, data, *args, **kwargs):
|
||||
results = [inference_segmentor(self.model, img) for img in data]
|
||||
return results
|
||||
|
||||
def postprocess(self, data):
|
||||
output = []
|
||||
for image_result in data:
|
||||
buffer = io.BytesIO()
|
||||
_, buffer = cv2.imencode('.png', image_result[0].astype('uint8'))
|
||||
output.append(buffer.tobytes())
|
||||
return output
|
Loading…
Reference in New Issue