EasyCV/easycv/toolkit/modelscope/pipelines/segmentation_pipeline.py
Cathy0908 5b487e4977
add easycv plugin to modelscope (#303)
* add plugin for modelscope
2023-05-09 11:20:04 +08:00

48 lines
1.6 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any
import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from easycv.toolkit.modelscope.metainfo import EasyCVPipelines as Pipelines
from .base import EasyCVPipeline
@PIPELINES.register_module(
    Tasks.image_segmentation, module_name=Pipelines.easycv_segmentation)
class EasyCVSegmentationPipeline(EasyCVPipeline):
    """Pipeline for easycv segmentation task.

    Runs the EasyCV predictor held in ``self.predict_op`` and converts its
    per-pixel class map into the ModelScope segmentation output format:
    one binary mask, one class name and one (constant) score for every
    class that appears in the prediction.
    """

    # Placeholder confidence reported for every mask: the per-pixel class map
    # used in ``__call__`` carries no per-segment score.
    _DEFAULT_SCORE = 0.999

    def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs):
        """
        Args:
            model (str): model id on modelscope hub or local model path.
            model_file_pattern (str): model file pattern.
        """
        # NOTE(review): Python always binds ``*args`` positionally, even when
        # written after keyword arguments, so any extra positional argument
        # presumably collides with ``model`` in the base ``__init__`` and
        # raises TypeError. Pass extras as keywords; confirm against
        # EasyCVPipeline's signature before changing this call.
        super(EasyCVSegmentationPipeline, self).__init__(
            model=model,
            model_file_pattern=model_file_pattern,
            *args,
            **kwargs)

    def __call__(self, inputs) -> Any:
        """Run segmentation on ``inputs`` and format the result.

        Args:
            inputs: anything accepted by ``self.predict_op``; passed through
                unchanged.

        Returns:
            dict with
              - ``OutputKeys.MASKS``: list of integer arrays, 1 where the
                pixel is predicted as that class and 0 elsewhere;
              - ``OutputKeys.LABELS``: class names, in the same order as
                the masks (descending class index);
              - ``OutputKeys.SCORES``: a constant placeholder score per mask.
        """
        outputs = self.predict_op(inputs)
        # Only the first prediction is used. The ``ids[:, None, None]``
        # broadcast below implies ``seg_pred`` is a 2-D (H, W) map of
        # integer class indices.
        semantic_result = outputs[0]['seg_pred']
        num_classes = len(self.predict_op.CLASSES)

        # Class ids present in the prediction, highest index first.
        ids = np.unique(semantic_result)[::-1]
        # Keep only ids that name a real class. ``num_classes`` is the VOID
        # label. Any other id outside [0, num_classes) (e.g. an ignore index
        # such as 255) would otherwise raise IndexError in the label lookup
        # or, if negative, silently wrap around to the wrong class name.
        ids = ids[(ids >= 0) & (ids < num_classes)]

        # (K, H, W) boolean stack: one mask per kept class id.
        segms = semantic_result[None] == ids[:, None, None]
        # ``np.int`` was removed in NumPy 1.24; it was an alias of the
        # builtin ``int``, so the resulting dtype is unchanged.
        masks = [segm.astype(int) for segm in segms]
        labels_txt = np.array(self.predict_op.CLASSES)[ids].tolist()

        return {
            OutputKeys.MASKS: masks,
            OutputKeys.LABELS: labels_txt,
            OutputKeys.SCORES: [self._DEFAULT_SCORE] * len(labels_txt),
        }