EasyCV/easycv/toolkit/modelscope/pipelines/hand_2d_keypoints_pipeline.py
Cathy0908 5b487e4977
add easycv plugin to modelscope (#303)
* add plugin for modelscope
2023-05-09 11:20:04 +08:00

53 lines
2.0 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import ModelFile, Tasks
from easycv.toolkit.modelscope.metainfo import EasyCVPipelines as Pipelines
from .base import EasyCVPipeline
@PIPELINES.register_module(
    Tasks.hand_2d_keypoints, module_name=Pipelines.hand_2d_keypoints)
class Hand2DKeypointsPipeline(EasyCVPipeline):
    """Pipeline for hand pose keypoint task.

    Besides the keypoint model itself, the predictor is given a detection
    predictor configured from the ``DETECTION`` section of ``self.cfg``.
    """

    def __init__(self,
                 model: str,
                 model_file_pattern=ModelFile.TORCH_MODEL_FILE,
                 *args,
                 **kwargs):
        """Create the hand 2D keypoints pipeline.

        Args:
            model (str): model id on modelscope hub or local model path.
            model_file_pattern (str): model file pattern.
            *args: extra positional arguments forwarded to ``EasyCVPipeline``.
            **kwargs: extra keyword arguments forwarded to
                ``EasyCVPipeline``.
        """
        # Forward ``model`` and ``model_file_pattern`` positionally, ahead of
        # ``*args``. Passing them as keywords alongside ``*args`` makes the
        # first extra positional argument bind to ``model`` as well, which
        # raises "got multiple values for argument 'model'".
        # NOTE(review): assumes EasyCVPipeline.__init__ takes
        # (model, model_file_pattern, *args, **kwargs) in that order.
        super(Hand2DKeypointsPipeline, self).__init__(
            model, model_file_pattern, *args, **kwargs)

    def _build_predict_op(self, **kwargs):
        """Build the EasyCV predictor for this pipeline.

        Side effect: writes a ``detection_predictor_config`` entry, built from
        the ``DETECTION`` section of ``self.cfg``, into
        ``self.cfg.pipeline.predictor_config`` before building the predictor.

        Args:
            **kwargs: extra default arguments for ``build_predictor``; because
                they are merged last, they can override ``model_path`` and
                ``config_file``.

        Returns:
            The predictor object returned by ``build_predictor``.
        """
        from easycv.predictors.builder import build_predictor

        detection_cfg = self.cfg['DETECTION']
        predictor_config = self.cfg.pipeline.predictor_config
        # The detection model/config paths in the cfg are relative to
        # ``self.model_dir``; resolve them before handing them to EasyCV.
        predictor_config['detection_predictor_config'] = dict(
            type=detection_cfg['type'],
            model_path=os.path.join(self.model_dir,
                                    detection_cfg['model_path']),
            config_file=os.path.join(self.model_dir,
                                     detection_cfg['config_file']),
            score_threshold=detection_cfg['score_threshold'])

        # Convert only after the detection entry has been injected, matching
        # the original ordering in case the conversion reads the updated cfg.
        easycv_config = self._to_easycv_config()
        default_args = {
            'model_path': self.model_path,
            'config_file': easycv_config,
            **kwargs
        }
        # Re-read predictor_config here rather than reusing the local alias,
        # in case _to_easycv_config() rebinds it.
        return build_predictor(self.cfg.pipeline.predictor_config,
                               default_args)