diff --git a/docker/serve/Dockerfile b/docker/serve/Dockerfile new file mode 100644 index 00000000..cd28f50f --- /dev/null +++ b/docker/serve/Dockerfile @@ -0,0 +1,47 @@ +ARG PYTORCH="1.6.0" +ARG CUDA="10.1" +ARG CUDNN="7" +FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel + +ARG MMCV="1.3.1" +ARG MMCLS="0.12.0" + +ENV PYTHONUNBUFFERED TRUE + +RUN apt-get update && \ + DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ + ca-certificates \ + g++ \ + openjdk-11-jre-headless \ + # MMDet Requirements + ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \ + && rm -rf /var/lib/apt/lists/* + +ENV PATH="/opt/conda/bin:$PATH" +RUN export FORCE_CUDA=1 + +# TORCHSEVER +RUN pip install torchserve torch-model-archiver + +# MMLAB +RUN pip install mmcv-full==${MMCV} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html +RUN pip install mmcls==${MMCLS} + +RUN useradd -m model-server \ + && mkdir -p /home/model-server/tmp + +COPY entrypoint.sh /usr/local/bin/entrypoint.sh + +RUN chmod +x /usr/local/bin/entrypoint.sh \ + && chown -R model-server /home/model-server + +COPY config.properties /home/model-server/config.properties +RUN mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store + +EXPOSE 8080 8081 8082 + +USER model-server +WORKDIR /home/model-server +ENV TEMP=/home/model-server/tmp +ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] +CMD ["serve"] diff --git a/docker/serve/config.properties b/docker/serve/config.properties new file mode 100644 index 00000000..efb9c47e --- /dev/null +++ b/docker/serve/config.properties @@ -0,0 +1,5 @@ +inference_address=http://0.0.0.0:8080 +management_address=http://0.0.0.0:8081 +metrics_address=http://0.0.0.0:8082 +model_store=/home/model-server/model-store +load_models=all diff --git a/docker/serve/entrypoint.sh b/docker/serve/entrypoint.sh new file mode 100644 index 00000000..41ba00b0 --- /dev/null +++ b/docker/serve/entrypoint.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +if [[ "$1" = "serve" ]]; then + shift 1 + torchserve --start --ts-config /home/model-server/config.properties +else + eval "$@" +fi + +# prevent docker exit +tail -f /dev/null diff --git a/docs/tutorials/model_serving.md b/docs/tutorials/model_serving.md new file mode 100644 index 00000000..253b7db7 --- /dev/null +++ b/docs/tutorials/model_serving.md @@ -0,0 +1,55 @@ +# Model Serving + +In order to serve an `MMClassification` model with [`TorchServe`](https://pytorch.org/serve/), you can follow the steps: + +## 1. Convert model from MMClassification to TorchServe + +```shell +python tools/deployment/mmcls2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \ +--output-folder ${MODEL_STORE} \ +--model-name ${MODEL_NAME} +``` + +**Note**: ${MODEL_STORE} needs to be an absolute path to a folder. + +## 2. Build `mmcls-serve` docker image + +```shell +docker build -t mmcls-serve:latest docker/serve/ +``` + +## 3. Run `mmcls-serve` + +Check the official docs for [running TorchServe with docker](https://github.com/pytorch/serve/blob/master/docker/README.md#running-torchserve-in-a-production-docker-environment). + +In order to run in GPU, you need to install [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). You can omit the `--gpus` argument in order to run in CPU. + +Example: + +```shell +docker run --rm \ +--cpus 8 \ +--gpus device=0 \ +-p8080:8080 -p8081:8081 -p8082:8082 \ +--mount type=bind,source=$MODEL_STORE,target=/home/model-server/model-store \ +mmcls-serve:latest +``` + +[Read the docs](https://github.com/pytorch/serve/blob/072f5d088cce9bb64b2a18af065886c9b01b317b/docs/rest_api.md) about the Inference (8080), Management (8081) and Metrics (8082) APis + +## 4. Test deployment + +```shell +curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/3dogs.jpg +curl http://127.0.0.1:8080/predictions/${MODEL_NAME} -T 3dogs.jpg +``` + +You should obtain a respose similar to: + +```json +{ + "pred_label": 245, + "pred_score": 0.5536593794822693, + "pred_class": "French bulldog" +} +``` diff --git a/setup.cfg b/setup.cfg index f050a8bb..ce016361 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = mmcls -known_third_party = PIL,cv2,matplotlib,mmcv,numpy,onnxruntime,pytest,torch,torchvision +known_third_party = PIL,cv2,matplotlib,mmcv,numpy,onnxruntime,pytest,torch,torchvision,ts no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tools/deployment/mmcls2torchserve.py b/tools/deployment/mmcls2torchserve.py new file mode 100644 index 00000000..6124291f --- /dev/null +++ b/tools/deployment/mmcls2torchserve.py @@ -0,0 +1,110 @@ +from argparse import ArgumentParser, Namespace +from pathlib import Path +from tempfile import TemporaryDirectory + +import mmcv + +try: + from model_archiver.model_packaging import package_model + from model_archiver.model_packaging_utils import ModelExportUtils +except ImportError: + package_model = None + + +def mmcls2torchserve( + config_file: str, + checkpoint_file: str, + output_folder: str, + model_name: str, + model_version: str = '1.0', + force: bool = False, +): + """Converts mmclassification model (config + checkpoint) to TorchServe + `.mar`. + + Args: + config_file: + In MMClassification config format. + The contents vary for each task repository. + checkpoint_file: + In MMClassification checkpoint format. + The contents vary for each task repository. + output_folder: + Folder where `{model_name}.mar` will be created. + The file created will be in TorchServe archive format. + model_name: + If not None, used for naming the `{model_name}.mar` file + that will be created under `output_folder`. + If None, `{Path(checkpoint_file).stem}` will be used. + model_version: + Model's version. + force: + If True, if there is an existing `{model_name}.mar` + file under `output_folder` it will be overwritten. + """ + mmcv.mkdir_or_exist(output_folder) + + config = mmcv.Config.fromfile(config_file) + + with TemporaryDirectory() as tmpdir: + config.dump(f'{tmpdir}/config.py') + + args = Namespace( + **{ + 'model_file': f'{tmpdir}/config.py', + 'serialized_file': checkpoint_file, + 'handler': f'{Path(__file__).parent}/mmcls_handler.py', + 'model_name': model_name or Path(checkpoint_file).stem, + 'version': model_version, + 'export_path': output_folder, + 'force': force, + 'requirements_file': None, + 'extra_files': None, + 'runtime': 'python', + 'archive_format': 'default' + }) + manifest = ModelExportUtils.generate_manifest_json(args) + package_model(args, manifest) + + +def parse_args(): + parser = ArgumentParser( + description='Convert mmcls models to TorchServe `.mar` format.') + parser.add_argument('config', type=str, help='config file path') + parser.add_argument('checkpoint', type=str, help='checkpoint file path') + parser.add_argument( + '--output-folder', + type=str, + required=True, + help='Folder where `{model_name}.mar` will be created.') + parser.add_argument( + '--model-name', + type=str, + default=None, + help='If not None, used for naming the `{model_name}.mar`' + 'file that will be created under `output_folder`.' + 'If None, `{Path(checkpoint_file).stem}` will be used.') + parser.add_argument( + '--model-version', + type=str, + default='1.0', + help='Number used for versioning.') + parser.add_argument( + '-f', + '--force', + action='store_true', + help='overwrite the existing `{model_name}.mar`') + args = parser.parse_args() + + return args + + +if __name__ == '__main__': + args = parse_args() + + if package_model is None: + raise ImportError('`torch-model-archiver` is required.' + 'Try: pip install torch-model-archiver') + + mmcls2torchserve(args.config, args.checkpoint, args.output_folder, + args.model_name, args.model_version, args.force) diff --git a/tools/deployment/mmcls_handler.py b/tools/deployment/mmcls_handler.py new file mode 100644 index 00000000..8a728f47 --- /dev/null +++ b/tools/deployment/mmcls_handler.py @@ -0,0 +1,50 @@ +import base64 +import os + +import mmcv +import torch +from ts.torch_handler.base_handler import BaseHandler + +from mmcls.apis import inference_model, init_model + + +class MMclsHandler(BaseHandler): + + def initialize(self, context): + properties = context.system_properties + self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = torch.device(self.map_location + ':' + + str(properties.get('gpu_id')) if torch.cuda. + is_available() else self.map_location) + self.manifest = context.manifest + + model_dir = properties.get('model_dir') + serialized_file = self.manifest['model']['serializedFile'] + checkpoint = os.path.join(model_dir, serialized_file) + self.config_file = os.path.join(model_dir, 'config.py') + + self.model = init_model(self.config_file, checkpoint, self.device) + self.initialized = True + + def preprocess(self, data): + images = [] + + for row in data: + image = row.get('data') or row.get('body') + if isinstance(image, str): + image = base64.b64decode(image) + image = mmcv.imfrombytes(image) + images.append(image) + + return images + + def inference(self, data, *args, **kwargs): + results = [] + for image in data: + results.append(inference_model(self.model, image)) + return results + + def postprocess(self, data): + for result in data: + result['pred_label'] = int(result['pred_label']) + return data