Update en doc of Hubserving (#900)
parent
017ce2d511
commit
27849b7a76
|
@ -4,7 +4,7 @@ English | [简体中文](readme.md)
|
|||
|
||||
HubServing service pack contains 3 files, the directory is as follows:
|
||||
```
|
||||
deploy/hubserving/clas/
|
||||
hubserving/clas/
|
||||
└─ __init__.py Empty file, required
|
||||
└─ config.json Configuration file, optional, passed in as a parameter when using configuration to start the service
|
||||
└─ module.py Main module file, required, contains the complete logic of the service
|
||||
|
@ -22,22 +22,28 @@ pip3 install paddlehub==2.0.0b1 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/s
|
|||
Before installing the service module, you need to prepare the inference model and put it in the correct path. The default model path is:
|
||||
|
||||
```
|
||||
Model structure file: ./inference/cls_infer.pdmodel
|
||||
Model parameters file: ./inference/cls_infer.pdiparams
|
||||
Model structure file: PaddleClas/inference/inference.pdmodel
|
||||
Model parameters file: PaddleClas/inference/inference.pdiparams
|
||||
```
|
||||
|
||||
**The model path can be found and modified in `params.py`.** More models provided by PaddleClas can be obtained from the [model library](../../docs/en/models/models_intro_en.md). You can also use models trained by yourself.
|
||||
* The model file path can be viewed and modified in `PaddleClas/deploy/hubserving/clas/params.py`.
|
||||
|
||||
It should be noted that the prefix of model structure file and model parameters file must be `inference`.
|
||||
|
||||
* More models provided by PaddleClas can be obtained from the [model library](../../docs/en/models/models_intro_en.md). You can also use models trained by yourself.
|
||||
|
||||
### 3. Install Service Module
|
||||
|
||||
* On Linux platform, the examples are as follows.
|
||||
```shell
|
||||
hub install deploy/hubserving/clas/
|
||||
cd PaddleClas/deploy
|
||||
hub install hubserving/clas/
|
||||
```
|
||||
|
||||
* On Windows platform, the examples are as follows.
|
||||
```shell
|
||||
hub install deploy\hubserving\clas\
|
||||
cd PaddleClas\deploy
|
||||
hub install hubserving\clas\
|
||||
```
|
||||
|
||||
### 4. Start service
|
||||
|
@ -103,32 +109,30 @@ Wherein, the format of `config.json` is as follows:
|
|||
|
||||
For example, use GPU card No. 3 to start the 2-stage series service:
|
||||
```shell
|
||||
cd PaddleClas/deploy
|
||||
export CUDA_VISIBLE_DEVICES=3
|
||||
hub serving start -c deploy/hubserving/clas/config.json
|
||||
hub serving start -c hubserving/clas/config.json
|
||||
```
|
||||
|
||||
## Send prediction requests
|
||||
After the service starts, you can use the following command to send a prediction request to obtain the prediction result:
|
||||
```shell
|
||||
python tools/test_hubserving.py server_url image_path
|
||||
cd PaddleClas/deploy
|
||||
python hubserving/test_hubserving.py server_url image_path
|
||||
```
|
||||
|
||||
Two required parameters need to be passed to the script:
|
||||
- **server_url**: service address,format of which is
|
||||
`http://[ip_address]:[port]/predict/[module_name]`
|
||||
- **image_path**: Test image path, can be a single image path or an image directory path
|
||||
- **top_k**: [**Optional**] Return the top `top_k` 's scores ,default by `1`.
|
||||
- **batch_size**: [**Optional**] batch_size. Default by `1`.
|
||||
- **resize_short**: [**Optional**] Resize the input image according to short size. Default by `256`.
|
||||
- **resize**: [**Optional**] Resize the input image. Default by `224`.
|
||||
- **normalize**: [**Optional**] Whether normalize the input image. Default by `True`.
|
||||
|
||||
**Notice**:
|
||||
If you want to use `Transformer series models`, such as `DeiT_***_384`, `ViT_***_384`, etc., please pay attention to the input size of model, and need to set `--resize_short=384`, `--resize=384`.
|
||||
|
||||
**Eg.**
|
||||
```shell
|
||||
python tools/test_hubserving.py --server_url http://127.0.0.1:8866/predict/clas_system --image_file ./deploy/hubserving/ILSVRC2012_val_00006666.JPEG --top_k 5
|
||||
python hubserving/test_hubserving.py --server_url http://127.0.0.1:8866/predict/clas_system --image_file ./hubserving/ILSVRC2012_val_00006666.JPEG --batch_size 8
|
||||
```
|
||||
|
||||
### Returned result format
|
||||
|
@ -142,7 +146,7 @@ list: The returned results
|
|||
└─ float: The time cost of predicting the picture, unit second
|
||||
```
|
||||
|
||||
**Note:** If you need to add, delete or modify the returned fields, you can modify the file `module.py` of the corresponding module. For the complete process, refer to the user-defined modification service module in the next section.
|
||||
**Note:** If you need to add, delete or modify the returned fields, you can modify the corresponding module. For the details, refer to the user-defined modification service module in the next section.
|
||||
|
||||
## User defined service module modification
|
||||
If you need to modify the service logic, the following steps are generally required:
|
||||
|
@ -151,20 +155,41 @@ If you need to modify the service logic, the following steps are generally requi
|
|||
```shell
|
||||
hub serving stop --port/-p XXXX
|
||||
```
|
||||
2. Modify the code in the corresponding files, like `module.py` and `params.py`, according to the actual needs.
|
||||
For example, if you need to replace the model used by the deployed service, you need to modify model path parameters `cfg.model_file` and `cfg.params_file` in `params.py`. Of course, other related parameters may need to be modified at the same time. Please modify and debug according to the actual situation.
|
||||
|
||||
After modifying and installing (`hub install deploy/hubserving/clas/`) and before deploying, you can use `python deploy/hubserving/clas/test.py` to test the installed service module.
|
||||
2. Modify the code in the corresponding files, like `module.py` and `params.py`, according to the actual needs. You need re-install(hub install hubserving/clas/) and re-deploy after modifing `module.py`.
|
||||
After modifying and installing and before deploying, you can use `python hubserving/clas/module.py` to test the installed service module.
|
||||
|
||||
For example, if you need to replace the model used by the deployed service, you need to modify model path parameters `cfg.model_file` and `cfg.params_file` in `params.py`. Of course, other related parameters may need to be modified at the same time. Please modify and debug according to the actual situation.
|
||||
|
||||
3. Uninstall old service module
|
||||
```shell
|
||||
hub uninstall clas_system
|
||||
```
|
||||
|
||||
4. Install modified service module
|
||||
```shell
|
||||
hub install deploy/hubserving/clas/
|
||||
hub install hubserving/clas/
|
||||
```
|
||||
|
||||
5. Restart service
|
||||
```shell
|
||||
hub serving start -m clas_system
|
||||
```
|
||||
|
||||
**Note**:
|
||||
|
||||
Common parameters can be modified in params.py:
|
||||
* Directory of model files(include model structure file and model parameters file):
|
||||
```python
|
||||
"inference_model_dir":
|
||||
```
|
||||
* The number of Top-k results returned during post-processing:
|
||||
```python
|
||||
'topk':
|
||||
```
|
||||
* Mapping file corresponding to label and class ID during post-processing:
|
||||
```python
|
||||
'class_id_map_file':
|
||||
```
|
||||
|
||||
In order to avoid unnecessary delay and be able to predict in batch, the preprocessing (include resize, crop and other) is completed in the client, so modify [test_hubserving.py](./test_hubserving.py#L35-L52) if necessary.
|
||||
|
|
|
@ -69,8 +69,8 @@ def main(args):
|
|||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
logger.warning(
|
||||
"Image file failed to read and has been skipped. The path: {}".
|
||||
format(img_path))
|
||||
f"Image file failed to read and has been skipped. The path: {img_path}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
for ops in preprocess_ops:
|
||||
|
@ -101,8 +101,7 @@ def main(args):
|
|||
msg = r.json()["msg"]
|
||||
raise Exception(msg)
|
||||
except Exception as e:
|
||||
logger.error("{}, in file(s): {} etc.".format(e, img_name_list[
|
||||
0]))
|
||||
logger.error(f"{e}, in file(s): {img_name_list[0]} etc.")
|
||||
continue
|
||||
else:
|
||||
results = r.json()["results"]
|
||||
|
@ -120,8 +119,9 @@ def main(args):
|
|||
result_list["class_ids"][i],
|
||||
result_list["scores"][i])
|
||||
|
||||
logger.info("File:{}, The result(s): {}".format(
|
||||
img_name_list[number], result_str))
|
||||
logger.info(
|
||||
f"File:{img_name_list[number]}, The result(s): {result_str}"
|
||||
)
|
||||
|
||||
finally:
|
||||
img_data_list = []
|
||||
|
|
|
@ -22,20 +22,28 @@ from paddle.nn.initializer import TruncatedNormal, Constant, Normal
|
|||
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
MODEL_URLS = {
|
||||
"ViT_small_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_small_patch16_224_pretrained.pdparams",
|
||||
"ViT_base_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_224_pretrained.pdparams",
|
||||
"ViT_base_patch16_384": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_384_pretrained.pdparams",
|
||||
"ViT_base_patch32_384": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch32_384_pretrained.pdparams",
|
||||
"ViT_large_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_224_pretrained.pdparams",
|
||||
"ViT_large_patch16_384": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_384_pretrained.pdparams",
|
||||
"ViT_large_patch32_384": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch32_384_pretrained.pdparams",
|
||||
"ViT_huge_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_huge_patch16_224_pretrained.pdparams",
|
||||
"ViT_huge_patch32_384": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_huge_patch32_384_pretrained.pdparams"
|
||||
}
|
||||
"ViT_small_patch16_224":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_small_patch16_224_pretrained.pdparams",
|
||||
"ViT_base_patch16_224":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_224_pretrained.pdparams",
|
||||
"ViT_base_patch16_384":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_384_pretrained.pdparams",
|
||||
"ViT_base_patch32_384":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch32_384_pretrained.pdparams",
|
||||
"ViT_large_patch16_224":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_224_pretrained.pdparams",
|
||||
"ViT_large_patch16_384":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_384_pretrained.pdparams",
|
||||
"ViT_large_patch32_384":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch32_384_pretrained.pdparams",
|
||||
"ViT_huge_patch16_224":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_huge_patch16_224_pretrained.pdparams",
|
||||
"ViT_huge_patch32_384":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_huge_patch32_384_pretrained.pdparams"
|
||||
}
|
||||
|
||||
__all__ = list(MODEL_URLS.keys())
|
||||
|
||||
|
||||
trunc_normal_ = TruncatedNormal(std=.02)
|
||||
normal_ = Normal
|
||||
zeros_ = Constant(value=0.)
|
||||
|
@ -209,7 +217,7 @@ class PatchEmbed(nn.Layer):
|
|||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
|
||||
x = self.proj(x).flatten(2).transpose((0, 2, 1))
|
||||
return x
|
||||
|
@ -323,8 +331,11 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
|
|||
)
|
||||
|
||||
|
||||
|
||||
def ViT_small_patch16_224(pretrained, model, model_url, use_ssld=False, **kwargs):
|
||||
def ViT_small_patch16_224(pretrained,
|
||||
model,
|
||||
model_url,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16,
|
||||
embed_dim=768,
|
||||
|
@ -333,12 +344,19 @@ def ViT_small_patch16_224(pretrained, model, model_url, use_ssld=False, **kwargs
|
|||
mlp_ratio=3,
|
||||
qk_scale=768**-0.5,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ViT_small_patch16_224"], use_ssld=use_ssld)
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
MODEL_URLS["ViT_small_patch16_224"],
|
||||
use_ssld=use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def ViT_base_patch16_224(pretrained, model, model_url, use_ssld=False, **kwargs):
|
||||
def ViT_base_patch16_224(pretrained,
|
||||
model,
|
||||
model_url,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16,
|
||||
embed_dim=768,
|
||||
|
@ -348,11 +366,19 @@ def ViT_base_patch16_224(pretrained, model, model_url, use_ssld=False, **kwargs)
|
|||
qkv_bias=True,
|
||||
epsilon=1e-6,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ViT_base_patch16_224"], use_ssld=use_ssld)
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
MODEL_URLS["ViT_base_patch16_224"],
|
||||
use_ssld=use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ViT_base_patch16_384(pretrained, model, model_url, use_ssld=False, **kwargs):
|
||||
def ViT_base_patch16_384(pretrained,
|
||||
model,
|
||||
model_url,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384,
|
||||
patch_size=16,
|
||||
|
@ -363,11 +389,19 @@ def ViT_base_patch16_384(pretrained, model, model_url, use_ssld=False, **kwargs)
|
|||
qkv_bias=True,
|
||||
epsilon=1e-6,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ViT_base_patch16_384"], use_ssld=use_ssld)
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
MODEL_URLS["ViT_base_patch16_384"],
|
||||
use_ssld=use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ViT_base_patch32_384(pretrained, model, model_url, use_ssld=False, **kwargs):
|
||||
def ViT_base_patch32_384(pretrained,
|
||||
model,
|
||||
model_url,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384,
|
||||
patch_size=32,
|
||||
|
@ -378,11 +412,19 @@ def ViT_base_patch32_384(pretrained, model, model_url, use_ssld=False, **kwargs)
|
|||
qkv_bias=True,
|
||||
epsilon=1e-6,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ViT_base_patch32_384"], use_ssld=use_ssld)
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
MODEL_URLS["ViT_base_patch32_384"],
|
||||
use_ssld=use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ViT_large_patch16_224(pretrained, model, model_url, use_ssld=False, **kwargs):
|
||||
def ViT_large_patch16_224(pretrained,
|
||||
model,
|
||||
model_url,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16,
|
||||
embed_dim=1024,
|
||||
|
@ -392,11 +434,19 @@ def ViT_large_patch16_224(pretrained, model, model_url, use_ssld=False, **kwargs
|
|||
qkv_bias=True,
|
||||
epsilon=1e-6,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ViT_large_patch16_224"], use_ssld=use_ssld)
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
MODEL_URLS["ViT_large_patch16_224"],
|
||||
use_ssld=use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ViT_large_patch16_384(pretrained, model, model_url, use_ssld=False, **kwargs):
|
||||
def ViT_large_patch16_384(pretrained,
|
||||
model,
|
||||
model_url,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384,
|
||||
patch_size=16,
|
||||
|
@ -407,11 +457,19 @@ def ViT_large_patch16_384(pretrained, model, model_url, use_ssld=False, **kwargs
|
|||
qkv_bias=True,
|
||||
epsilon=1e-6,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ViT_large_patch16_384"], use_ssld=use_ssld)
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
MODEL_URLS["ViT_large_patch16_384"],
|
||||
use_ssld=use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ViT_large_patch32_384(pretrained, model, model_url, use_ssld=False, **kwargs):
|
||||
def ViT_large_patch32_384(pretrained,
|
||||
model,
|
||||
model_url,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384,
|
||||
patch_size=32,
|
||||
|
@ -422,11 +480,19 @@ def ViT_large_patch32_384(pretrained, model, model_url, use_ssld=False, **kwargs
|
|||
qkv_bias=True,
|
||||
epsilon=1e-6,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ViT_large_patch32_384"], use_ssld=use_ssld)
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
MODEL_URLS["ViT_large_patch32_384"],
|
||||
use_ssld=use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ViT_huge_patch16_224(pretrained, model, model_url, use_ssld=False, **kwargs):
|
||||
def ViT_huge_patch16_224(pretrained,
|
||||
model,
|
||||
model_url,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16,
|
||||
embed_dim=1280,
|
||||
|
@ -434,11 +500,19 @@ def ViT_huge_patch16_224(pretrained, model, model_url, use_ssld=False, **kwargs)
|
|||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ViT_huge_patch16_224"], use_ssld=use_ssld)
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
MODEL_URLS["ViT_huge_patch16_224"],
|
||||
use_ssld=use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ViT_huge_patch32_384(pretrained, model, model_url, use_ssld=False, **kwargs):
|
||||
def ViT_huge_patch32_384(pretrained,
|
||||
model,
|
||||
model_url,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384,
|
||||
patch_size=32,
|
||||
|
@ -447,5 +521,9 @@ def ViT_huge_patch32_384(pretrained, model, model_url, use_ssld=False, **kwargs)
|
|||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ViT_huge_patch32_384"], use_ssld=use_ssld)
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
MODEL_URLS["ViT_huge_patch32_384"],
|
||||
use_ssld=use_ssld)
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue