# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path

from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import ModelFile, Tasks

from easycv.toolkit.modelscope.metainfo import EasyCVPipelines as Pipelines
from .base import EasyCVPipeline


@PIPELINES.register_module(
    Tasks.hand_2d_keypoints, module_name=Pipelines.hand_2d_keypoints)
class Hand2DKeypointsPipeline(EasyCVPipeline):
    """Pipeline for the hand 2D keypoints task.

    Registered in ``PIPELINES`` under ``Tasks.hand_2d_keypoints`` with the
    module name ``Pipelines.hand_2d_keypoints``.
    """

    def __init__(self,
                 model: str,
                 model_file_pattern=ModelFile.TORCH_MODEL_FILE,
                 *args,
                 **kwargs):
        """Initialize the pipeline by delegating to ``EasyCVPipeline``.

        Args:
            model (str): model id on modelscope hub or local model path.
            model_file_pattern (str): model file pattern; defaults to
                ``ModelFile.TORCH_MODEL_FILE``.
            *args: extra positional arguments forwarded to the base class.
            **kwargs: extra keyword arguments forwarded to the base class.
        """
        super(Hand2DKeypointsPipeline, self).__init__(
            *args,
            model=model,
            model_file_pattern=model_file_pattern,
            **kwargs)

    def _build_predict_op(self, **kwargs):
        """Build the EasyCV predictor used by this pipeline.

        The ``DETECTION`` section of ``self.cfg`` is copied into
        ``self.cfg.pipeline.predictor_config['detection_predictor_config']``,
        with its ``model_path`` and ``config_file`` entries joined onto
        ``self.model_dir``. The predictor is then built from the updated
        predictor config.

        Args:
            **kwargs: extra entries for the argument dict handed to
                ``build_predictor``; they take precedence over the
                ``model_path`` and ``config_file`` entries set here.

        Returns:
            The object returned by ``build_predictor``.
        """
        # Imported lazily, as in the original, so easycv is only required
        # when a predictor is actually built.
        from easycv.predictors.builder import build_predictor

        detection = self.cfg['DETECTION']
        detection_predictor_config = {
            'type': detection['type'],
            # Paths in the DETECTION section are joined onto the model dir.
            'model_path': os.path.join(self.model_dir,
                                       detection['model_path']),
            'config_file': os.path.join(self.model_dir,
                                        detection['config_file']),
            'score_threshold': detection['score_threshold'],
        }
        self.cfg.pipeline.predictor_config['detection_predictor_config'] = (
            detection_predictor_config)

        easycv_config = self._to_easycv_config()
        predictor_config = self.cfg.pipeline.predictor_config
        extra_args = {
            'model_path': self.model_path,
            'config_file': easycv_config,
        }
        # Caller-supplied kwargs win over the two entries above.
        extra_args.update(kwargs)
        return build_predictor(predictor_config, extra_args)