cp update shitu whl
parent
8eba0f0b24
commit
10376757ce
|
@ -0,0 +1,130 @@
|
|||
# PP-ShiTu Whl 使用说明
|
||||
|
||||
PaddleClas 支持 Python Whl 包方式进行预测。
|
||||
|
||||
---
|
||||
|
||||
## 目录

- [1. 安装 paddleclas](#1)
- [2. 快速开始](#2)
  - [2.1 构建索引库](#2.1)
  - [2.2 瓶装饮料识别](#2.2)
|
||||
|
||||
|
||||
<a name="1"></a>
|
||||
|
||||
## 1. 安装 paddleclas
|
||||
|
||||
* **[推荐]** 直接 pip 安装:
|
||||
|
||||
```bash
|
||||
pip3 install paddleclas
|
||||
```
|
||||
|
||||
* 如需使用 PaddleClas develop 分支体验最新功能,或是需要基于 PaddleClas 进行二次开发,请本地构建安装:
|
||||
|
||||
```bash
|
||||
python3 setup.py install
|
||||
```
|
||||
|
||||
<a name="2"></a>
|
||||
|
||||
## 2. 快速开始
|
||||
|
||||
<a name="2.1"></a>
|
||||
|
||||
### 2.1 构建索引库
|
||||
|
||||
下载 demo 数据集,命令如下:
|
||||
```shell
|
||||
# 下载 demo 数据并解压
|
||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar && tar -xf drink_dataset_v2.0.tar
|
||||
```
|
||||
|
||||
解压完毕后,`drink_dataset_v2.0/` 文件夹下应有如下文件结构:
|
||||
|
||||
```log
|
||||
├── drink_dataset_v2.0/
|
||||
│ ├── gallery/
|
||||
│ ├── index/
|
||||
│ ├── index_all/
|
||||
│ └── test_images/
|
||||
├── ...
|
||||
```
|
||||
|
||||
其中 `gallery` 文件夹中存放的是用于构建索引库的原始图像,`index` 表示基于原始图像构建得到的索引库信息,`test_images` 文件夹中存放的是用于测试识别效果的图像列表。
|
||||
|
||||
|
||||
|
||||
|
||||
**在Python代码中构建索引库**
|
||||
```python
|
||||
from paddleclas import PaddleClas
|
||||
build = PaddleClas(
|
||||
build_gallery=True,
|
||||
gallery_image_root='./drink_dataset_v2.0/gallery/',
|
||||
gallery_data_file='./drink_dataset_v2.0/gallery/drink_label.txt',
|
||||
index_dir='./drink_dataset_v2.0/index')
|
||||
```
|
||||
参数说明:
|
||||
- build_gallery:是否使用索引库构建模式,默认为`False`。
|
||||
- gallery_image_root:构建索引库使用的`gallery`图像地址。
|
||||
- gallery_data_file:构建索引库图像的真值文件。
|
||||
- index_dir:索引库存放地址。
|
||||
|
||||
|
||||
**在命令行中构建索引库**
|
||||
```shell
|
||||
paddleclas --build_gallery=True --model_name="PP-ShiTuV2" \
|
||||
-o IndexProcess.image_root=./drink_dataset_v2.0/gallery/ \
|
||||
-o IndexProcess.index_dir=./drink_dataset_v2.0/index \
|
||||
-o IndexProcess.data_file=./drink_dataset_v2.0/gallery/drink_label.txt
|
||||
```
|
||||
其中参数`build_gallery(bool)`控制是否使用索引库构建模式,默认为`False`。
|
||||
|
||||
同时可以通过`-o`指令更改构建索引库使用的配置,字段说明如下:
|
||||
|
||||
- IndexProcess.image_root(str): 构建索引库使用的`gallery`图像地址。
|
||||
- IndexProcess.index_dir(str): 索引库存放地址。
|
||||
- IndexProcess.data_file(str): 构建索引库图像的真值文件。
|
||||
|
||||
|
||||
<a name="2.2"></a>
|
||||
|
||||
### 2.2 瓶装饮料识别
|
||||
|
||||
体验瓶装饮料识别,对图像`./drink_dataset_v2.0/test_images/001.jpeg`进行识别与检索。
|
||||
|
||||
待检索图像如下:
|
||||

|
||||
|
||||
**在Python代码中进行识别和检索**
|
||||
```python
|
||||
from paddleclas import PaddleClas
|
||||
clas = PaddleClas(model_name='PP-ShiTuV2',
|
||||
index_dir='./drink_dataset_v2.0/index')
|
||||
infer_imgs='./drink_dataset_v2.0/test_images/001.jpeg'
|
||||
result=clas.predict(infer_imgs, predict_type='shitu')
|
||||
print(next(result))
|
||||
```
|
||||
参数说明:
|
||||
- model_name(str):用于检索和识别的模型。
|
||||
- index_dir(str):用于检索的索引库地址。
- predict_type(str):`predict` 方法的预测类型,可选 `'cls'` 或 `'shitu'`,默认为 `'cls'`;进行识别与检索时需设置为 `'shitu'`。
|
||||
|
||||
最终输出结果如下:
|
||||
```
|
||||
[{'bbox': [437, 71, 660, 728], 'rec_docs': '元气森林', 'rec_scores': 0.7740249}, {'bbox': [221, 72, 449, 701], 'rec_docs': '元气森林', 'rec_scores': 0.6950992}, {'bbox': [794, 104, 979, 652], 'rec_docs': '元气森林', 'rec_scores': 0.6305153}]
|
||||
```
|
||||
|
||||
**在命令行中进行识别和检索**
|
||||
```shell
|
||||
paddleclas --model_name=PP-ShiTuV2 --predict_type=shitu \
|
||||
-o Global.infer_imgs='./drink_dataset_v2.0/test_images/001.jpeg' \
|
||||
-o IndexProcess.index_dir='./drink_dataset_v2.0/index'
|
||||
```
|
||||
其中参数 `model_name` 为用于检索和识别的模型,`predict_type` 需设置为 `shitu` 模式。
|
||||
|
||||
同时可以通过`-o`指令更改检索图像以及索引库,字段说明如下:
|
||||
- Global.infer_imgs(str):待检索图像地址。
|
||||
- IndexProcess.index_dir(str): 索引库存放地址。
|
||||
|
||||
最终输出结果如下:
|
||||
```
|
||||
[{'bbox': [437, 71, 660, 728], 'rec_docs': '元气森林', 'rec_scores': 0.7740249}, {'bbox': [221, 72, 449, 701], 'rec_docs': '元气森林', 'rec_scores': 0.6950992}, {'bbox': [794, 104, 979, 652], 'rec_docs': '元气森林', 'rec_scores': 0.6305153}], filename: ./drink_dataset_v2.0/test_images/001.jpeg
|
||||
```
|
111
paddleclas.py
111
paddleclas.py
|
@ -33,6 +33,7 @@ from .ppcls.utils import logger
|
|||
|
||||
from .deploy.python.predict_cls import ClsPredictor
|
||||
from .deploy.python.predict_system import SystemPredictor
|
||||
from .deploy.python.build_gallery import GalleryBuilder
|
||||
from .deploy.utils.get_image_list import get_image_list
|
||||
from .deploy.utils import config
|
||||
|
||||
|
@ -196,7 +197,8 @@ PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/inf
|
|||
PULC_MODELS = [
|
||||
"car_exists", "language_classification", "person_attribute",
|
||||
"person_exists", "safety_helmet", "text_image_orientation",
|
||||
"textline_orientation", "traffic_sign", "vehicle_attribute"
|
||||
"textline_orientation", "traffic_sign", "vehicle_attribute",
|
||||
"table_attribute"
|
||||
]
|
||||
|
||||
SHITU_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar"
|
||||
|
@ -226,7 +228,9 @@ class InputModelError(Exception):
|
|||
|
||||
def init_config(model_type, model_name, inference_model_dir, **kwargs):
|
||||
|
||||
if model_type == "pulc":
|
||||
if kwargs.get("build_gallery", False):
|
||||
cfg_path = "deploy/configs/inference_general.yaml"
|
||||
elif model_type == "pulc":
|
||||
cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml"
|
||||
elif model_type == "shitu":
|
||||
cfg_path = "deploy/configs/inference_general.yaml"
|
||||
|
@ -235,7 +239,8 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
|
|||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
cfg_path = os.path.join(__dir__, cfg_path)
|
||||
cfg = config.get_config(cfg_path, show=False)
|
||||
cfg = config.get_config(
|
||||
cfg_path, overrides=kwargs.get("override", None), show=False)
|
||||
if cfg.Global.get("inference_model_dir"):
|
||||
cfg.Global.inference_model_dir = inference_model_dir
|
||||
else:
|
||||
|
@ -282,6 +287,7 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
|
|||
if "thresh" in kwargs and kwargs[
|
||||
"thresh"] and "ThreshOutput" in cfg.PostProcess:
|
||||
cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
|
||||
|
||||
if cfg.get("PostProcess"):
|
||||
if "Topk" in cfg.PostProcess:
|
||||
if "topk" in kwargs and kwargs["topk"]:
|
||||
|
@ -301,7 +307,26 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
|
|||
if "type_threshold" in kwargs and kwargs["type_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
|
||||
"type_threshold"]
|
||||
|
||||
if "TableAttribute" in cfg.PostProcess:
|
||||
if "source_threshold" in kwargs and kwargs["source_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
|
||||
"source_threshold"]
|
||||
if "number_threshold" in kwargs and kwargs["number_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
|
||||
"number_threshold"]
|
||||
if "color_threshold" in kwargs and kwargs["color_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
|
||||
"color_threshold"]
|
||||
if "clarity_threshold" in kwargs and kwargs["clarity_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
|
||||
"clarity_threshold"]
|
||||
if "obstruction_threshold" in kwargs and kwargs[
|
||||
"obstruction_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
|
||||
"obstruction_threshold"]
|
||||
if "angle_threshold" in kwargs and kwargs["angle_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
|
||||
"angle_threshold"]
|
||||
if "save_dir" in kwargs and kwargs["save_dir"]:
|
||||
cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
|
||||
|
||||
|
@ -316,10 +341,15 @@ def args_cfg():
|
|||
parser.add_argument(
|
||||
"--infer_imgs",
|
||||
type=str,
|
||||
required=True,
|
||||
required=False,
|
||||
help="The image(s) to be predicted.")
|
||||
parser.add_argument(
|
||||
"--model_name", type=str, help="The model name to be used.")
|
||||
parser.add_argument(
|
||||
"--predict_type",
|
||||
type=str,
|
||||
default="cls",
|
||||
help="The predict type to be selected.")
|
||||
parser.add_argument(
|
||||
"--inference_model_dir",
|
||||
type=str,
|
||||
|
@ -374,7 +404,17 @@ def args_cfg():
|
|||
parser.add_argument(
|
||||
"--resize_short", type=int, help="Resize according to short size.")
|
||||
parser.add_argument("--crop_size", type=int, help="Centor crop size.")
|
||||
|
||||
parser.add_argument(
|
||||
"--build_gallery",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether build gallery.")
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--override',
|
||||
action='append',
|
||||
default=[],
|
||||
help='config options to be overridden')
|
||||
args = parser.parse_args()
|
||||
return vars(args)
|
||||
|
||||
|
@ -514,6 +554,10 @@ class PaddleClas(object):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
build_gallery: bool=False,
|
||||
gallery_image_root: str=None,
|
||||
gallery_data_file: str=None,
|
||||
index_dir: str=None,
|
||||
model_name: str=None,
|
||||
inference_model_dir: str=None,
|
||||
**kwargs):
|
||||
|
@ -528,14 +572,35 @@ class PaddleClas(object):
|
|||
"""
|
||||
super().__init__()
|
||||
|
||||
self.model_type, inference_model_dir = self._check_input_model(
|
||||
model_name, inference_model_dir)
|
||||
self._config = init_config(self.model_type, model_name,
|
||||
inference_model_dir, **kwargs)
|
||||
if self.model_type == "shitu":
|
||||
self.predictor = SystemPredictor(self._config)
|
||||
if build_gallery:
|
||||
self.model_type, inference_model_dir = self._check_input_model(
|
||||
model_name
|
||||
if model_name else "PP-ShiTuV2", inference_model_dir)
|
||||
self._config = init_config(self.model_type, model_name
|
||||
if model_name else "PP-ShiTuV2",
|
||||
inference_model_dir, **kwargs)
|
||||
if gallery_image_root:
|
||||
self._config.IndexProcess.image_root = gallery_image_root
|
||||
if gallery_data_file:
|
||||
self._config.IndexProcess.data_file = gallery_data_file
|
||||
if index_dir:
|
||||
self._config.IndexProcess.index_dir = index_dir
|
||||
|
||||
logger.info("Building Gallery...")
|
||||
GalleryBuilder(self._config)
|
||||
|
||||
else:
|
||||
self.predictor = ClsPredictor(self._config)
|
||||
self.model_type, inference_model_dir = self._check_input_model(
|
||||
model_name, inference_model_dir)
|
||||
self._config = init_config(self.model_type, model_name,
|
||||
inference_model_dir, **kwargs)
|
||||
|
||||
if self.model_type == "shitu":
|
||||
if index_dir:
|
||||
self._config.IndexProcess.index_dir = index_dir
|
||||
self.predictor = SystemPredictor(self._config)
|
||||
else:
|
||||
self.predictor = ClsPredictor(self._config)
|
||||
|
||||
def get_config(self):
|
||||
"""Get the config.
|
||||
|
@ -679,6 +744,9 @@ class PaddleClas(object):
|
|||
prediction result(s) is zipped as a dict, that includes topk "class_ids", "scores" and "label_names".
|
||||
The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
|
||||
"""
|
||||
if input_data == None and self._config.Global.infer_imgs:
|
||||
input_data = self._config.Global.infer_imgs
|
||||
|
||||
if isinstance(input_data, np.ndarray):
|
||||
yield self.predictor.predict(input_data)
|
||||
elif isinstance(input_data, str):
|
||||
|
@ -721,6 +789,8 @@ class PaddleClas(object):
|
|||
input_data: Union[str, np.array],
|
||||
print_pred: bool=False,
|
||||
predict_type="cls"):
|
||||
assert predict_type in ["cls", "shitu"
|
||||
], "Predict type should be 'cls' or 'shitu'."
|
||||
if predict_type == "cls":
|
||||
return self.predict_cls(input_data, print_pred)
|
||||
elif predict_type == "shitu":
|
||||
|
@ -739,13 +809,14 @@ def main():
|
|||
print_info()
|
||||
cfg = args_cfg()
|
||||
clas_engine = PaddleClas(**cfg)
|
||||
res = clas_engine.predict(
|
||||
cfg["infer_imgs"],
|
||||
print_pred=True,
|
||||
predict_type="cls" if "PP-ShiTu" not in cfg["model_name"] else "shitu")
|
||||
for _ in res:
|
||||
pass
|
||||
logger.info("Predict complete!")
|
||||
if cfg["build_gallery"] == False:
|
||||
res = clas_engine.predict(
|
||||
cfg["infer_imgs"],
|
||||
print_pred=True,
|
||||
predict_type=cfg["predict_type"])
|
||||
for _ in res:
|
||||
pass
|
||||
logger.info("Predict complete!")
|
||||
return
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue