mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
53 lines
2.0 KiB
Python
53 lines
2.0 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os.path
|
|
|
|
from modelscope.pipelines.builder import PIPELINES
|
|
from modelscope.utils.constant import ModelFile, Tasks
|
|
|
|
from easycv.toolkit.modelscope.metainfo import EasyCVPipelines as Pipelines
|
|
from .base import EasyCVPipeline
|
|
|
|
|
|
@PIPELINES.register_module(
    Tasks.hand_2d_keypoints, module_name=Pipelines.hand_2d_keypoints)
class Hand2DKeypointsPipeline(EasyCVPipeline):
    """EasyCV pipeline registered for the hand 2D keypoints task.

    The predictor built by this pipeline is configured with an additional
    detection predictor whose settings come from the ``DETECTION`` section
    of the pipeline configuration.
    """

    def __init__(self,
                 model: str,
                 model_file_pattern=ModelFile.TORCH_MODEL_FILE,
                 *args,
                 **kwargs):
        """Initialize the pipeline by delegating to the base class.

        Args:
            model (str): model id on modelscope hub or local model path.
            model_file_pattern (str): model file pattern.
            *args: extra positional arguments forwarded to the base class.
            **kwargs: extra keyword arguments forwarded to the base class.
        """
        super(Hand2DKeypointsPipeline, self).__init__(
            model=model,
            model_file_pattern=model_file_pattern,
            *args,
            **kwargs)

    def _build_predict_op(self, **kwargs):
        """Build EasyCV predictor.

        Reads the ``DETECTION`` section of ``self.cfg``, resolves its
        ``model_path`` and ``config_file`` entries against
        ``self.model_dir``, and stores the result under the
        ``'detection_predictor_config'`` key of
        ``self.cfg.pipeline.predictor_config`` before building the predictor.

        Args:
            **kwargs: extra default arguments handed to ``build_predictor``;
                they take precedence over the ``model_path`` and
                ``config_file`` defaults set here.

        Returns:
            The object returned by ``build_predictor``.
        """
        # Imported lazily, inside the method, exactly as the pipeline
        # originally did.
        from easycv.predictors.builder import build_predictor

        detection_section = self.cfg['DETECTION']

        # Paths in the DETECTION section are joined onto the model directory.
        detection_settings = {
            'type': detection_section['type'],
            'model_path': os.path.join(self.model_dir,
                                       detection_section['model_path']),
            'config_file': os.path.join(self.model_dir,
                                        detection_section['config_file']),
            'score_threshold': detection_section['score_threshold'],
        }
        self.cfg.pipeline.predictor_config[
            'detection_predictor_config'] = detection_settings

        # The EasyCV config is generated only after the detection settings
        # have been injected into the predictor config.
        easycv_config = self._to_easycv_config()

        predictor_config = self.cfg.pipeline.predictor_config
        default_args = {
            'model_path': self.model_path,
            'config_file': easycv_config,
        }
        # Caller-supplied keyword arguments override the defaults above.
        default_args.update(kwargs)

        return build_predictor(predictor_config, default_args)
|