mirror of https://github.com/alibaba/EasyCV.git
update some predcitors, support batch inference (#195)
update some predcitors, support batch inferencepull/198/head
parent
bb53e066be
commit
5dfe7b2898
|
@ -137,6 +137,3 @@ pai_jobs/easycv/resources/
|
|||
*.tar.gz
|
||||
thirdparty/test
|
||||
scripts/test
|
||||
|
||||
# easycv default cache dir
|
||||
.easycv_cache
|
||||
|
|
|
@ -86,3 +86,13 @@ checkpoint_config = dict(interval=10)
|
|||
|
||||
# runtime settings
|
||||
total_epochs = 100
|
||||
|
||||
predict = dict(
|
||||
type='ClassificationPredictor',
|
||||
pipelines=[
|
||||
dict(type='Resize', size=256),
|
||||
dict(type='CenterCrop', size=224),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Collect', keys=['img'])
|
||||
])
|
||||
|
|
|
@ -27,7 +27,7 @@ def load_image(img_path):
|
|||
|
||||
|
||||
def load_seg_map(seg_path, reduce_zero_label):
|
||||
gt_semantic_seg = _load_img(seg_path, mode='RGB')
|
||||
gt_semantic_seg = _load_img(seg_path, mode='P')
|
||||
# reduce zero_label
|
||||
if reduce_zero_label:
|
||||
# avoid using underflow conversion
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
@ -6,10 +7,10 @@ import cv2
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from easycv.file import io
|
||||
from easycv import file
|
||||
from easycv.framework.errors import IOError
|
||||
from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES
|
||||
from .utils import is_oss_path
|
||||
from .utils import is_oss_path, is_url_path
|
||||
|
||||
|
||||
def load_image(img_path, mode='BGR', max_try_times=MAX_READ_IMAGE_TRY_TIMES):
|
||||
|
@ -20,16 +21,31 @@ def load_image(img_path, mode='BGR', max_try_times=MAX_READ_IMAGE_TRY_TIMES):
|
|||
img = None
|
||||
while try_cnt < max_try_times:
|
||||
try:
|
||||
with io.open(img_path, 'rb') as infile:
|
||||
# cv2.imdecode may corrupt when the img is broken
|
||||
image = Image.open(infile) # RGB
|
||||
if is_url_path(img_path):
|
||||
from mmcv.fileio.file_client import HTTPBackend
|
||||
client = HTTPBackend()
|
||||
img_bytes = client.get(img_path)
|
||||
buff = io.BytesIO(img_bytes)
|
||||
image = Image.open(buff)
|
||||
if mode.upper() != 'BGR' and image.mode.upper() != mode.upper(
|
||||
):
|
||||
image = image.convert(mode.upper())
|
||||
img = np.asarray(image, dtype=np.uint8)
|
||||
if mode.upper() == 'BGR':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||
assert mode.upper() in ['RGB', 'BGR'
|
||||
], 'Only support `RGB` and `BGR` mode!'
|
||||
assert img is not None
|
||||
break
|
||||
else:
|
||||
with file.io.open(img_path, 'rb') as infile:
|
||||
# cv2.imdecode may corrupt when the img is broken
|
||||
image = Image.open(infile)
|
||||
if mode.upper() != 'BGR' and image.mode.upper(
|
||||
) != mode.upper():
|
||||
image = image.convert(mode.upper())
|
||||
img = np.asarray(image, dtype=np.uint8)
|
||||
|
||||
if mode.upper() == 'BGR':
|
||||
if image.mode.upper() != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||
assert img is not None
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logging.warning('Read file {} fault, try count : {}'.format(
|
||||
|
|
|
@ -13,7 +13,7 @@ from tqdm import tqdm
|
|||
from easycv.framework.errors import ValueError
|
||||
|
||||
OSS_PREFIX = 'oss://'
|
||||
URL_PREFIX = 'https://'
|
||||
URL_PREFIX = ('https://', 'http://')
|
||||
|
||||
|
||||
def create_namedtuple(**kwargs):
|
||||
|
@ -33,6 +33,7 @@ def url_path_exists(url):
|
|||
urllib.request.urlopen(url).code
|
||||
except Exception as err:
|
||||
print(err)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
@ -9,5 +9,4 @@ from .feature_extractor import (TorchFaceAttrExtractor,
|
|||
from .hand_keypoints_predictor import HandKeypointsPredictor
|
||||
from .pose_predictor import (TorchPoseTopDownPredictor,
|
||||
TorchPoseTopDownPredictorWithDetector)
|
||||
from .segmentation import (Mask2formerPredictor, SegFormerPredictor,
|
||||
SegmentationPredictor)
|
||||
from .segmentation import Mask2formerPredictor, SegmentationPredictor
|
||||
|
|
|
@ -1,19 +1,23 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import collate, scatter_kwargs
|
||||
from PIL import Image
|
||||
from torch.hub import load_state_dict_from_url
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from easycv.datasets.registry import PIPELINES
|
||||
from easycv.file import io
|
||||
from easycv.file.utils import is_url_path
|
||||
from easycv.framework.errors import ValueError
|
||||
from easycv.models.builder import build_model
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
from easycv.utils.config_tools import Config, mmcv_config_fromfile
|
||||
from easycv.utils.constant import CACHE_DIR
|
||||
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
|
||||
remove_adapt_for_mmlab)
|
||||
|
@ -107,7 +111,9 @@ class PredictorV2(object):
|
|||
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
|
||||
save_results (bool): Whether to save predict results.
|
||||
save_path (str): File path for saving results, only valid when `save_results` is True.
|
||||
pipelines (list[dict]): Data pipeline configs.
|
||||
"""
|
||||
INPUT_IMAGE_MODE = 'BGR' # the image mode into the model
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
|
@ -116,30 +122,51 @@ class PredictorV2(object):
|
|||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
mode='rgb',
|
||||
pipelines=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
self.model_path = model_path
|
||||
self.batch_size = batch_size
|
||||
self.save_results = save_results
|
||||
self.save_path = save_path
|
||||
self.config_file = config_file
|
||||
if self.save_results:
|
||||
assert self.save_path is not None
|
||||
self.device = device
|
||||
if self.device is None:
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
self.cfg = None
|
||||
if config_file is not None:
|
||||
if isinstance(config_file, str):
|
||||
self.cfg = mmcv_config_fromfile(config_file)
|
||||
else:
|
||||
self.cfg = config_file
|
||||
else:
|
||||
self.cfg = self._load_cfg_from_ckpt(self.model_path)
|
||||
|
||||
if self.cfg is None:
|
||||
raise ValueError('Please provide "config_file"!')
|
||||
|
||||
self.model = self.prepare_model()
|
||||
self.pipelines = pipelines
|
||||
self.processor = self.build_processor()
|
||||
self._load_op = None
|
||||
self.mode = mode
|
||||
|
||||
def _load_cfg_from_ckpt(self, model_path):
|
||||
if is_url_path(model_path):
|
||||
ckpt = load_state_dict_from_url(model_path)
|
||||
else:
|
||||
with io.open(model_path, 'rb') as infile:
|
||||
ckpt = torch.load(infile, map_location='cpu')
|
||||
|
||||
cfg = None
|
||||
if 'meta' in ckpt and 'config' in ckpt['meta']:
|
||||
cfg = ckpt['meta']['config']
|
||||
if isinstance(cfg, dict):
|
||||
cfg = Config(cfg)
|
||||
elif isinstance(cfg, str):
|
||||
cfg = Config(json.loads(cfg))
|
||||
return cfg
|
||||
|
||||
def prepare_model(self):
|
||||
"""Build model from config file by default.
|
||||
|
@ -152,8 +179,6 @@ class PredictorV2(object):
|
|||
return model
|
||||
|
||||
def _build_model(self):
|
||||
if self.cfg is None:
|
||||
raise ValueError('Please provide "config_file"!')
|
||||
# Use mmdet model
|
||||
dynamic_adapt_for_mmlab(self.cfg)
|
||||
model = build_model(self.cfg.model)
|
||||
|
@ -165,16 +190,15 @@ class PredictorV2(object):
|
|||
"""Build processor to process loaded input.
|
||||
If you need custom preprocessing ops, you need to reimplement it.
|
||||
"""
|
||||
if self.cfg is None:
|
||||
pipeline = []
|
||||
if self.pipelines is not None:
|
||||
pipelines = self.pipelines
|
||||
else:
|
||||
pipeline = [
|
||||
build_from_cfg(p, PIPELINES)
|
||||
for p in self.cfg.get('test_pipeline', [])
|
||||
]
|
||||
pipelines = self.cfg.get('test_pipeline', [])
|
||||
|
||||
pipelines = [build_from_cfg(p, PIPELINES) for p in pipelines]
|
||||
|
||||
from easycv.datasets.shared.pipelines.transforms import Compose
|
||||
processor = Compose(pipeline)
|
||||
processor = Compose(pipelines)
|
||||
return processor
|
||||
|
||||
def _load_input(self, input):
|
||||
|
@ -190,10 +214,13 @@ class PredictorV2(object):
|
|||
}
|
||||
"""
|
||||
if self._load_op is None:
|
||||
load_cfg = dict(type='LoadImage', mode=self.mode)
|
||||
load_cfg = dict(type='LoadImage', mode=self.INPUT_IMAGE_MODE)
|
||||
self._load_op = build_from_cfg(load_cfg, PIPELINES)
|
||||
|
||||
if not isinstance(input, str):
|
||||
if isinstance(input, np.ndarray):
|
||||
# Only support RGB mode if input is np.ndarray.
|
||||
input = cv2.cvtColor(input, cv2.COLOR_RGB2BGR)
|
||||
sample = self._load_op({'img': input})
|
||||
else:
|
||||
sample = self._load_op({'filename': input})
|
||||
|
@ -229,8 +256,32 @@ class PredictorV2(object):
|
|||
return outputs
|
||||
|
||||
def postprocess(self, inputs, *args, **kwargs):
|
||||
"""Process model outputs.
|
||||
If you need add some processing ops to process model outputs, you need to reimplement it.
|
||||
"""Process model batch outputs.
|
||||
"""
|
||||
outputs = []
|
||||
out_i = {}
|
||||
batch_size = 1
|
||||
# get current batch size
|
||||
for k, batch_v in inputs.items():
|
||||
if batch_v is not None:
|
||||
batch_size = len(batch_v)
|
||||
break
|
||||
|
||||
for i in range(batch_size):
|
||||
for k, batch_v in inputs.items():
|
||||
if batch_v is not None:
|
||||
out_i[k] = batch_v[i]
|
||||
else:
|
||||
out_i[k] = None
|
||||
|
||||
out_i = self.postprocess_single(out_i)
|
||||
outputs.append(out_i)
|
||||
|
||||
return outputs
|
||||
|
||||
def postprocess_single(self, inputs):
|
||||
"""Process outputs of single sample.
|
||||
If you need add some processing ops, you need to reimplement it.
|
||||
"""
|
||||
return inputs
|
||||
|
||||
|
@ -260,16 +311,22 @@ class PredictorV2(object):
|
|||
|
||||
results_list = []
|
||||
for i in range(0, len(inputs), self.batch_size):
|
||||
batch = inputs[i:max(len(inputs) - 1, i + self.batch_size)]
|
||||
batch = inputs[i:min(len(inputs), i + self.batch_size)]
|
||||
batch_outputs = self.preprocess(batch)
|
||||
batch_outputs = self.forward(batch_outputs)
|
||||
results = self.postprocess(batch_outputs)
|
||||
assert len(results) == len(
|
||||
batch), f'Mismatch size {len(results)} != {len(batch)}'
|
||||
if keep_inputs:
|
||||
results = {'inputs': batch, 'results': results}
|
||||
for i in range(len(batch)):
|
||||
results[i].update({'inputs': batch[i]})
|
||||
# if dump, the outputs will not added to the return value to prevent taking up too much memory
|
||||
if self.save_results:
|
||||
self.dump([results], self.save_path, mode='ab+')
|
||||
self.dump(results, self.save_path, mode='ab+')
|
||||
else:
|
||||
results_list.append(results)
|
||||
if isinstance(results, list):
|
||||
results_list.extend(results)
|
||||
else:
|
||||
results_list.append(results)
|
||||
|
||||
return results_list
|
||||
|
|
|
@ -3,17 +3,130 @@ import math
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageFile
|
||||
|
||||
from easycv.file import io
|
||||
from easycv.framework.errors import ValueError
|
||||
from .base import Predictor
|
||||
from easycv.utils.misc import deprecated
|
||||
from .base import Predictor, PredictorV2
|
||||
from .builder import PREDICTORS
|
||||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
class ClassificationPredictor(PredictorV2):
|
||||
"""Predictor for classification.
|
||||
Args:
|
||||
model_path (str): Path of model path.
|
||||
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
|
||||
batch_size (int): batch size for forward.
|
||||
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
|
||||
save_results (bool): Whether to save predict results.
|
||||
save_path (str): File path for saving results, only valid when `save_results` is True.
|
||||
pipelines (list[dict]): Data pipeline configs.
|
||||
topk (int): Return top-k results. Default: 1.
|
||||
pil_input (bool): Whether use PIL image. If processor need PIL input, set true, default false.
|
||||
label_map_path (str): File path of saving labels list.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
config_file=None,
|
||||
batch_size=1,
|
||||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
pipelines=[],
|
||||
topk=1,
|
||||
pil_input=True,
|
||||
label_map_path=[],
|
||||
*args,
|
||||
**kwargs):
|
||||
super(ClassificationPredictor, self).__init__(
|
||||
model_path,
|
||||
config_file=config_file,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
save_results=save_results,
|
||||
save_path=save_path,
|
||||
pipelines=pipelines,
|
||||
*args,
|
||||
**kwargs)
|
||||
self.topk = topk
|
||||
self.pil_input = pil_input
|
||||
|
||||
# Adapt to torchvision transforms which process PIL inputs.
|
||||
if self.pil_input:
|
||||
self.INPUT_IMAGE_MODE = 'RGB'
|
||||
|
||||
if label_map_path is None:
|
||||
class_list = self.cfg.get('CLASSES', [])
|
||||
else:
|
||||
with io.open(label_map_path, 'r') as f:
|
||||
class_list = f.readlines()
|
||||
self.label_map = [i.strip() for i in class_list]
|
||||
|
||||
def _load_input(self, input):
|
||||
"""Load image from file or numpy or PIL object.
|
||||
Args:
|
||||
input: File path or numpy or PIL object.
|
||||
Returns:
|
||||
{
|
||||
'filename': filename,
|
||||
'img': img,
|
||||
'img_shape': img_shape,
|
||||
'img_fields': ['img']
|
||||
}
|
||||
"""
|
||||
if self.pil_input:
|
||||
results = {}
|
||||
if isinstance(input, str):
|
||||
img = Image.open(input)
|
||||
if img.mode.upper() != self.INPUT_IMAGE_MODE.upper():
|
||||
img = img.convert(self.INPUT_IMAGE_MODE.upper())
|
||||
results['filename'] = input
|
||||
else:
|
||||
assert isinstance(input, ImageFile.ImageFile)
|
||||
img = input
|
||||
results['filename'] = None
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.size
|
||||
results['ori_shape'] = img.size
|
||||
results['img_fields'] = ['img']
|
||||
return results
|
||||
|
||||
return super()._load_input(input)
|
||||
|
||||
def postprocess(self, inputs, *args, **kwargs):
|
||||
"""Return top-k results."""
|
||||
output_prob = inputs['prob'].data.cpu()
|
||||
topk_class = torch.topk(output_prob, self.topk).indices.numpy()
|
||||
output_prob = output_prob.numpy()
|
||||
batch_results = []
|
||||
batch_size = output_prob.shape[0]
|
||||
for i in range(batch_size):
|
||||
result = {'class': np.squeeze(topk_class[i]).tolist()}
|
||||
if isinstance(result['class'], int):
|
||||
result['class'] = [result['class']]
|
||||
|
||||
if len(self.label_map) > 0:
|
||||
result['class_name'] = [
|
||||
self.label_map[i] for i in result['class']
|
||||
]
|
||||
result['class_probs'] = {}
|
||||
for l_idx, l_name in enumerate(self.label_map):
|
||||
result['class_probs'][l_name] = output_prob[i][l_idx]
|
||||
|
||||
batch_results.append(result)
|
||||
return batch_results
|
||||
|
||||
|
||||
try:
|
||||
from easy_vision.python.inference.predictor import PredictorInterface
|
||||
except:
|
||||
from .interface import PredictorInterface
|
||||
|
||||
|
||||
@deprecated(reason='Please use ClassificationPredictor.')
|
||||
@PREDICTORS.register_module()
|
||||
class TorchClassifier(PredictorInterface):
|
||||
|
||||
|
|
|
@ -5,9 +5,6 @@ from glob import glob
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.ops import RoIPool
|
||||
from mmcv.parallel import collate, scatter
|
||||
from torch.hub import load_state_dict_from_url
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from easycv.apis.export import reparameterize_models
|
||||
|
@ -15,16 +12,12 @@ from easycv.core.visualization import imshow_bboxes
|
|||
from easycv.datasets.registry import PIPELINES
|
||||
from easycv.datasets.utils import replace_ImageToTensor
|
||||
from easycv.file import io
|
||||
from easycv.file.utils import is_url_path, url_path_exists
|
||||
from easycv.framework.errors import TypeError
|
||||
from easycv.models import build_model
|
||||
from easycv.models.detection.utils import postprocess
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
from easycv.utils.constant import CACHE_DIR
|
||||
from easycv.utils.logger import get_root_logger
|
||||
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
|
||||
remove_adapt_for_mmlab)
|
||||
from easycv.utils.misc import deprecated
|
||||
from easycv.utils.registry import build_from_cfg
|
||||
from .base import PredictorV2
|
||||
from .builder import PREDICTORS
|
||||
|
@ -47,14 +40,16 @@ class DetectionPredictor(PredictorV2):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_path=None,
|
||||
model_path,
|
||||
config_file=None,
|
||||
batch_size=1,
|
||||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
mode='rgb',
|
||||
score_threshold=0.5):
|
||||
pipelines=None,
|
||||
score_threshold=0.5,
|
||||
*arg,
|
||||
**kwargs):
|
||||
super(DetectionPredictor, self).__init__(
|
||||
model_path,
|
||||
config_file=config_file,
|
||||
|
@ -62,194 +57,55 @@ class DetectionPredictor(PredictorV2):
|
|||
device=device,
|
||||
save_results=save_results,
|
||||
save_path=save_path,
|
||||
mode=mode,
|
||||
pipelines=pipelines,
|
||||
)
|
||||
self.score_thresh = score_threshold
|
||||
self.CLASSES = self.cfg.get('CLASSES', None)
|
||||
|
||||
def build_processor(self):
|
||||
if self.pipelines is not None:
|
||||
pipelines = self.pipelines
|
||||
elif self.cfg is None:
|
||||
pipelines = []
|
||||
else:
|
||||
pipelines = self.cfg.get('test_pipeline', [])
|
||||
|
||||
# for batch inference
|
||||
self.pipelines = replace_ImageToTensor(pipelines)
|
||||
|
||||
return super().build_processor()
|
||||
|
||||
def postprocess_single(self, inputs, *args, **kwargs):
|
||||
if inputs['detection_scores'] is None or len(
|
||||
inputs['detection_scores']) < 1:
|
||||
return inputs
|
||||
|
||||
scores = inputs['detection_scores']
|
||||
if scores is not None and self.score_thresh > 0:
|
||||
keeped_ids = scores > self.score_thresh
|
||||
inputs['detection_scores'] = inputs['detection_scores'][keeped_ids]
|
||||
inputs['detection_boxes'] = inputs['detection_boxes'][keeped_ids]
|
||||
inputs['detection_classes'] = inputs['detection_classes'][
|
||||
keeped_ids]
|
||||
|
||||
class_names = []
|
||||
for _, classes_id in enumerate(inputs['detection_classes']):
|
||||
if classes_id is None:
|
||||
class_names.append(None)
|
||||
elif self.CLASSES is not None and len(self.CLASSES) > 0:
|
||||
class_names.append(self.CLASSES[int(classes_id)])
|
||||
else:
|
||||
class_names.append(classes_id)
|
||||
|
||||
inputs['detection_class_names'] = class_names
|
||||
|
||||
def postprocess(self, inputs, *args, **kwargs):
|
||||
for batch_index in range(self.batch_size):
|
||||
this_detection_scores = inputs['detection_scores'][batch_index]
|
||||
sel_ids = this_detection_scores > self.score_thresh
|
||||
inputs['detection_scores'][batch_index] = inputs[
|
||||
'detection_scores'][batch_index][sel_ids]
|
||||
inputs['detection_boxes'][batch_index] = inputs['detection_boxes'][
|
||||
batch_index][sel_ids]
|
||||
inputs['detection_classes'][batch_index] = inputs[
|
||||
'detection_classes'][batch_index][sel_ids]
|
||||
# TODO class label remapping
|
||||
return inputs
|
||||
|
||||
|
||||
class DetrPredictor(PredictorInterface):
|
||||
"""Inference image(s) with the detector.
|
||||
Args:
|
||||
model_path (str): checkpoint model and export model are shared.
|
||||
config_path (str): If config_path is specified, both checkpoint model and export model can be used; if config_path=None, the export model is used by default.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path, config_path=None):
|
||||
|
||||
self.model_path = model_path
|
||||
|
||||
if config_path is not None:
|
||||
self.cfg = mmcv_config_fromfile(config_path)
|
||||
else:
|
||||
logger = get_root_logger()
|
||||
logger.warning('please use export model!')
|
||||
if is_url_path(self.model_path) and url_path_exists(
|
||||
self.model_path):
|
||||
checkpoint = load_state_dict_from_url(model_path)
|
||||
else:
|
||||
assert io.exists(
|
||||
self.model_path), f'{self.model_path} does not exists'
|
||||
|
||||
with io.open(self.model_path, 'rb') as infile:
|
||||
checkpoint = torch.load(infile, map_location='cpu')
|
||||
|
||||
assert 'meta' in checkpoint and 'config' in checkpoint[
|
||||
'meta'], 'meta.config is missing from checkpoint'
|
||||
|
||||
config_str = checkpoint['meta']['config']
|
||||
if isinstance(config_str, dict):
|
||||
config_str = json.dumps(config_str)
|
||||
|
||||
# get config
|
||||
basename = os.path.basename(self.model_path)
|
||||
fname, _ = os.path.splitext(basename)
|
||||
self.local_config_file = os.path.join(CACHE_DIR,
|
||||
f'{fname}_config.json')
|
||||
if not os.path.exists(CACHE_DIR):
|
||||
os.makedirs(CACHE_DIR)
|
||||
with open(self.local_config_file, 'w') as ofile:
|
||||
ofile.write(config_str)
|
||||
self.cfg = mmcv_config_fromfile(self.local_config_file)
|
||||
|
||||
# dynamic adapt mmdet models
|
||||
dynamic_adapt_for_mmlab(self.cfg)
|
||||
|
||||
# build model
|
||||
self.model = build_model(self.cfg.model)
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
map_location = 'cpu' if self.device == 'cpu' else 'cuda'
|
||||
self.ckpt = load_checkpoint(
|
||||
self.model, self.model_path, map_location=map_location)
|
||||
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
self.CLASSES = self.cfg.CLASSES
|
||||
|
||||
def predict(self, imgs):
|
||||
"""
|
||||
Args:
|
||||
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
|
||||
Either image files or loaded images.
|
||||
Returns:
|
||||
If imgs is a list or tuple, the same length list type results
|
||||
will be returned, otherwise return the detection results directly.
|
||||
"""
|
||||
|
||||
if isinstance(imgs, (list, tuple)):
|
||||
is_batch = True
|
||||
else:
|
||||
imgs = [imgs]
|
||||
is_batch = False
|
||||
|
||||
cfg = self.cfg
|
||||
device = next(self.model.parameters()).device # model device
|
||||
|
||||
if isinstance(imgs[0], np.ndarray):
|
||||
cfg = cfg.copy()
|
||||
# set loading pipeline type
|
||||
cfg.data.val.pipeline.insert(0, dict(type='LoadImageFromWebcam'))
|
||||
else:
|
||||
cfg = cfg.copy()
|
||||
# set loading pipeline type
|
||||
cfg.data.val.pipeline.insert(
|
||||
0,
|
||||
dict(
|
||||
type='LoadImageFromFile',
|
||||
file_client_args=dict(
|
||||
backend=('http' if imgs[0].startswith('http'
|
||||
) else 'disk'))))
|
||||
|
||||
cfg.data.val.pipeline = replace_ImageToTensor(cfg.data.val.pipeline)
|
||||
|
||||
transforms = []
|
||||
for transform in cfg.data.val.pipeline:
|
||||
if 'img_scale' in transform:
|
||||
transform['img_scale'] = tuple(transform['img_scale'])
|
||||
if isinstance(transform, dict):
|
||||
transform = build_from_cfg(transform, PIPELINES)
|
||||
transforms.append(transform)
|
||||
elif callable(transform):
|
||||
transforms.append(transform)
|
||||
else:
|
||||
raise TypeError('transform must be callable or a dict')
|
||||
test_pipeline = Compose(transforms)
|
||||
|
||||
datas = []
|
||||
for img in imgs:
|
||||
# prepare data
|
||||
if isinstance(img, np.ndarray):
|
||||
# directly add img
|
||||
data = dict(img=img)
|
||||
else:
|
||||
# add information into dict
|
||||
data = dict(img_info=dict(filename=img), img_prefix=None)
|
||||
# build the data pipeline
|
||||
data = test_pipeline(data)
|
||||
datas.append(data)
|
||||
|
||||
data = collate(datas, samples_per_gpu=len(imgs))
|
||||
# just get the actual data from DataContainer
|
||||
data['img_metas'] = [
|
||||
img_metas.data[0] for img_metas in data['img_metas']
|
||||
]
|
||||
data['img'] = [img.data[0] for img in data['img']]
|
||||
if next(self.model.parameters()).is_cuda:
|
||||
# scatter to specified GPU
|
||||
data = scatter(data, [device])[0]
|
||||
else:
|
||||
for m in self.model.modules():
|
||||
assert not isinstance(
|
||||
m, RoIPool
|
||||
), 'CPU inference with RoIPool is not supported currently.'
|
||||
|
||||
# forward the model
|
||||
with torch.no_grad():
|
||||
results = self.model(mode='test', **data)
|
||||
|
||||
return results
|
||||
|
||||
def visualize(self,
|
||||
img,
|
||||
results,
|
||||
score_thr=0.3,
|
||||
show=False,
|
||||
out_file=None):
|
||||
bboxes = results['detection_boxes'][0]
|
||||
scores = results['detection_scores'][0]
|
||||
labels = results['detection_classes'][0].tolist()
|
||||
|
||||
# If self.CLASSES is not None, class_id will be converted to self.CLASSES for visualization,
|
||||
# otherwise the class_id will be displayed.
|
||||
# And don't try to modify the value in results, it may cause some bugs or even precision problems,
|
||||
# because `self.evaluate` will also use the results, refer to: https://github.com/alibaba/EasyCV/pull/67
|
||||
|
||||
if self.CLASSES is not None and len(self.CLASSES) > 0:
|
||||
for i, classes_id in enumerate(labels):
|
||||
if classes_id is None:
|
||||
labels[i] = None
|
||||
else:
|
||||
labels[i] = self.CLASSES[int(classes_id)]
|
||||
|
||||
if scores is not None and score_thr > 0:
|
||||
inds = scores > score_thr
|
||||
bboxes = bboxes[inds]
|
||||
labels = np.array(labels)[inds]
|
||||
|
||||
def visualize(self, img, results, show=False, out_file=None):
|
||||
"""Only support show one sample now."""
|
||||
bboxes = results['detection_boxes']
|
||||
labels = results['detection_class_names']
|
||||
img = self._load_input(img)['img']
|
||||
imshow_bboxes(
|
||||
img,
|
||||
bboxes,
|
||||
|
@ -263,6 +119,12 @@ class DetrPredictor(PredictorInterface):
|
|||
out_file=out_file)
|
||||
|
||||
|
||||
@deprecated(reason='Please use DetectionPredictor.')
|
||||
@PREDICTORS.register_module()
|
||||
class DetrPredictor(DetectionPredictor):
|
||||
""""""
|
||||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
class TorchYoloXPredictor(PredictorInterface):
|
||||
|
||||
|
|
|
@ -25,6 +25,11 @@ class FaceKeypointsPredictor(PredictorV2):
|
|||
Args:
|
||||
model_path (str): Path of model path
|
||||
config_file (str): config file path for model and processor to init. Defaults to None.
|
||||
batch_size (int): batch size for forward.
|
||||
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
|
||||
save_results (bool): Whether to save predict results.
|
||||
save_path (str): File path for saving results, only valid when `save_results` is True.
|
||||
pipelines (list[dict]): Data pipeline configs.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -34,7 +39,7 @@ class FaceKeypointsPredictor(PredictorV2):
|
|||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
mode='bgr'):
|
||||
pipelines=None):
|
||||
super(FaceKeypointsPredictor, self).__init__(
|
||||
model_path,
|
||||
config_file,
|
||||
|
@ -42,7 +47,7 @@ class FaceKeypointsPredictor(PredictorV2):
|
|||
device=device,
|
||||
save_results=save_results,
|
||||
save_path=save_path,
|
||||
mode=mode)
|
||||
pipelines=pipelines)
|
||||
|
||||
self.input_size = self.cfg.IMAGE_SIZE
|
||||
self.point_number = self.cfg.POINT_NUMBER
|
||||
|
|
|
@ -25,9 +25,11 @@ class HandKeypointsPredictor(PredictorV2):
|
|||
config_file: path or ``Config`` of config file
|
||||
detection_model_config: dict of hand detection model predictor config,
|
||||
example like ``dict(type="", model_path="", config_file="", ......)``
|
||||
batch_size: batch_size to infer
|
||||
save_results: bool
|
||||
save_path: path of result image
|
||||
batch_size (int): batch size for forward.
|
||||
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
|
||||
save_results (bool): Whether to save predict results.
|
||||
save_path (str): File path for saving results, only valid when `save_results` is True.
|
||||
pipelines (list[dict]): Data pipeline configs.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -38,7 +40,7 @@ class HandKeypointsPredictor(PredictorV2):
|
|||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
mode='rgb',
|
||||
pipelines=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super(HandKeypointsPredictor, self).__init__(
|
||||
|
@ -48,7 +50,7 @@ class HandKeypointsPredictor(PredictorV2):
|
|||
device=device,
|
||||
save_results=save_results,
|
||||
save_path=save_path,
|
||||
mode=mode,
|
||||
pipelines=pipelines,
|
||||
*args,
|
||||
**kwargs)
|
||||
self.dataset_info = DatasetInfo(COCO_WHOLEBODY_HAND_DATASET_INFO)
|
||||
|
@ -70,52 +72,48 @@ class HandKeypointsPredictor(PredictorV2):
|
|||
}
|
||||
}
|
||||
"""
|
||||
image_paths = input['inputs']
|
||||
batch_data = []
|
||||
image_path = input['inputs']
|
||||
data_list = []
|
||||
box_id = 0
|
||||
for batch_index, image_path in enumerate(image_paths):
|
||||
det_bbox_result = input['results']['detection_boxes'][batch_index]
|
||||
det_bbox_scores = input['results']['detection_scores'][batch_index]
|
||||
img = mmcv.imread(image_path, 'color', self.mode)
|
||||
for bbox, score in zip(det_bbox_result, det_bbox_scores):
|
||||
center, scale = _box2cs(self.cfg.data_cfg['image_size'], bbox)
|
||||
# prepare data
|
||||
data = {
|
||||
'image_file':
|
||||
image_path,
|
||||
'img':
|
||||
img,
|
||||
'image_id':
|
||||
batch_index,
|
||||
'center':
|
||||
center,
|
||||
'scale':
|
||||
scale,
|
||||
'bbox_score':
|
||||
score,
|
||||
'bbox_id':
|
||||
box_id, # need to be assigned if batch_size > 1
|
||||
'dataset':
|
||||
'coco_wholebody_hand',
|
||||
'joints_3d':
|
||||
np.zeros((self.cfg.data_cfg.num_joints, 3),
|
||||
dtype=np.float32),
|
||||
'joints_3d_visible':
|
||||
np.zeros((self.cfg.data_cfg.num_joints, 3),
|
||||
dtype=np.float32),
|
||||
'rotation':
|
||||
0,
|
||||
'flip_pairs':
|
||||
self.dataset_info.flip_pairs,
|
||||
'ann_info': {
|
||||
'image_size':
|
||||
np.array(self.cfg.data_cfg['image_size']),
|
||||
'num_joints': self.cfg.data_cfg['num_joints']
|
||||
}
|
||||
det_bbox_result = input['detection_boxes']
|
||||
det_bbox_scores = input['detection_scores']
|
||||
img = mmcv.imread(image_path, 'color', self.INPUT_IMAGE_MODE)
|
||||
for bbox, score in zip(det_bbox_result, det_bbox_scores):
|
||||
center, scale = _box2cs(self.cfg.data_cfg['image_size'], bbox)
|
||||
# prepare data
|
||||
data = {
|
||||
'image_file':
|
||||
image_path,
|
||||
'img':
|
||||
img,
|
||||
'image_id':
|
||||
0,
|
||||
'center':
|
||||
center,
|
||||
'scale':
|
||||
scale,
|
||||
'bbox_score':
|
||||
score,
|
||||
'bbox_id':
|
||||
box_id, # need to be assigned if batch_size > 1
|
||||
'dataset':
|
||||
'coco_wholebody_hand',
|
||||
'joints_3d':
|
||||
np.zeros((self.cfg.data_cfg.num_joints, 3), dtype=np.float32),
|
||||
'joints_3d_visible':
|
||||
np.zeros((self.cfg.data_cfg.num_joints, 3), dtype=np.float32),
|
||||
'rotation':
|
||||
0,
|
||||
'flip_pairs':
|
||||
self.dataset_info.flip_pairs,
|
||||
'ann_info': {
|
||||
'image_size': np.array(self.cfg.data_cfg['image_size']),
|
||||
'num_joints': self.cfg.data_cfg['num_joints']
|
||||
}
|
||||
batch_data.append(data)
|
||||
box_id += 1
|
||||
return batch_data
|
||||
}
|
||||
data_list.append(data)
|
||||
box_id += 1
|
||||
return data_list
|
||||
|
||||
def preprocess_single(self, input):
|
||||
results = []
|
||||
|
@ -128,8 +126,11 @@ class HandKeypointsPredictor(PredictorV2):
|
|||
"""Process all inputs list. And collate to batch and put to target device.
|
||||
If you need custom ops to load or process a batch samples, you need to reimplement it.
|
||||
"""
|
||||
# hand det and return source image
|
||||
det_results = self.detection_predictor(inputs, keep_inputs=True)
|
||||
|
||||
batch_outputs = []
|
||||
for i in inputs:
|
||||
for i in det_results:
|
||||
for res in self.preprocess_single(i, *args, **kwargs):
|
||||
batch_outputs.append(res)
|
||||
batch_outputs = self._collate_fn(batch_outputs)
|
||||
|
@ -137,37 +138,25 @@ class HandKeypointsPredictor(PredictorV2):
|
|||
return batch_outputs
|
||||
|
||||
def postprocess(self, inputs, *args, **kwargs):
|
||||
output = {}
|
||||
output['keypoints'] = inputs['preds']
|
||||
output['boxes'] = inputs['boxes']
|
||||
for i, bbox in enumerate(output['boxes']):
|
||||
keypoints = inputs['preds']
|
||||
boxes = inputs['boxes']
|
||||
for i, bbox in enumerate(boxes):
|
||||
center, scale = bbox[:2], bbox[2:4]
|
||||
output['boxes'][i][:4] = bbox_cs2xyxy(center, scale)
|
||||
output['boxes'] = output['boxes'][:, :4]
|
||||
return output
|
||||
|
||||
def __call__(self, inputs, keep_inputs=False):
|
||||
if isinstance(inputs, str):
|
||||
inputs = [inputs]
|
||||
|
||||
results_list = []
|
||||
for i in range(0, len(inputs), self.batch_size):
|
||||
batch = inputs[i:max(len(inputs) - 1, i + self.batch_size)]
|
||||
# hand det and return source image
|
||||
det_results = self.detection_predictor(batch, keep_inputs=True)
|
||||
# hand keypoints
|
||||
batch_outputs = self.preprocess(det_results)
|
||||
batch_outputs = self.forward(batch_outputs)
|
||||
results = self.postprocess(batch_outputs)
|
||||
if keep_inputs:
|
||||
results = {'inputs': batch, 'results': results}
|
||||
# if dump, the outputs will not added to the return value to prevent taking up too much memory
|
||||
if self.save_results:
|
||||
self.dump([results], self.save_path, mode='ab+')
|
||||
else:
|
||||
results_list.append(results)
|
||||
|
||||
return results_list
|
||||
boxes[i][:4] = bbox_cs2xyxy(center, scale)
|
||||
boxes = boxes[:, :4]
|
||||
# TODO: support multi bboxes for a single sample
|
||||
assert len(keypoints.shape) == 3
|
||||
assert len(boxes.shape) == 2
|
||||
batch_outputs = []
|
||||
batch_size = keypoints.shape[0]
|
||||
keypoints = np.split(keypoints, batch_size)
|
||||
boxes = np.split(boxes, batch_size)
|
||||
for i in range(batch_size):
|
||||
batch_outputs.append({
|
||||
'keypoints': keypoints[i],
|
||||
'boxes': boxes[i]
|
||||
})
|
||||
return batch_outputs
|
||||
|
||||
def show_result(self,
|
||||
image_path,
|
||||
|
|
|
@ -5,22 +5,25 @@ import numpy as np
|
|||
import torch
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import Polygon
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from easycv.core.visualization.image import imshow_bboxes
|
||||
from easycv.datasets.registry import PIPELINES
|
||||
from easycv.file import io
|
||||
from easycv.models import build_model
|
||||
from easycv.predictors.builder import PREDICTORS
|
||||
from easycv.predictors.interface import PredictorInterface
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
from easycv.utils.registry import build_from_cfg
|
||||
from .base import PredictorV2
|
||||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
class SegmentationPredictor(PredictorV2):
|
||||
"""Predictor for Segmentation.
|
||||
|
||||
Args:
|
||||
model_path (str): Path of model path.
|
||||
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
|
||||
batch_size (int): batch size for forward.
|
||||
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
|
||||
save_results (bool): Whether to save predict results.
|
||||
save_path (str): File path for saving results, only valid when `save_results` is True.
|
||||
pipelines (list[dict]): Data pipeline configs.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
|
@ -28,20 +31,21 @@ class SegmentationPredictor(PredictorV2):
|
|||
batch_size=1,
|
||||
device=None,
|
||||
save_results=False,
|
||||
save_path=None):
|
||||
"""Predict pipeline for Segmentation
|
||||
save_path=None,
|
||||
pipelines=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
|
||||
Args:
|
||||
model_path (str): Path of model path
|
||||
config_file (str): config file path for model and processor to init. Defaults to None.
|
||||
"""
|
||||
super(SegmentationPredictor, self).__init__(
|
||||
model_path,
|
||||
config_file,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
save_results=save_results,
|
||||
save_path=save_path)
|
||||
save_path=save_path,
|
||||
pipelines=pipelines,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
self.CLASSES = self.cfg.CLASSES
|
||||
self.PALETTE = self.cfg.PALETTE
|
||||
|
@ -123,71 +127,61 @@ class SegmentationPredictor(PredictorV2):
|
|||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
class Mask2formerPredictor(PredictorInterface):
|
||||
class Mask2formerPredictor(SegmentationPredictor):
|
||||
"""Predictor for Mask2former.
|
||||
|
||||
def __init__(self, model_path, model_config=None):
|
||||
"""init model
|
||||
Args:
|
||||
model_path (str): Path of model path.
|
||||
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
|
||||
batch_size (int): batch size for forward.
|
||||
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
|
||||
save_results (bool): Whether to save predict results.
|
||||
save_path (str): File path for saving results, only valid when `save_results` is True.
|
||||
pipelines (list[dict]): Data pipeline configs.
|
||||
"""
|
||||
|
||||
Args:
|
||||
model_path (str): Path of model path
|
||||
model_config (config, optional): config string for model to init. Defaults to None.
|
||||
def __init__(self,
|
||||
model_path,
|
||||
config_file=None,
|
||||
batch_size=1,
|
||||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
pipelines=None,
|
||||
task_mode='panoptic',
|
||||
*args,
|
||||
**kwargs):
|
||||
super(Mask2formerPredictor, self).__init__(
|
||||
model_path,
|
||||
config_file,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
save_results=save_results,
|
||||
save_path=save_path,
|
||||
pipelines=pipelines,
|
||||
*args,
|
||||
**kwargs)
|
||||
self.task_mode = task_mode
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Model forward.
|
||||
"""
|
||||
self.model_path = model_path
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs, mode='test', encode=False)
|
||||
return outputs
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
self.model = None
|
||||
with io.open(self.model_path, 'rb') as infile:
|
||||
checkpoint = torch.load(infile, map_location='cpu')
|
||||
|
||||
assert 'meta' in checkpoint and 'config' in checkpoint[
|
||||
'meta'], 'meta.config is missing from checkpoint'
|
||||
|
||||
self.cfg = checkpoint['meta']['config']
|
||||
self.classes = len(self.cfg.PALETTE)
|
||||
self.class_name = self.cfg.CLASSES
|
||||
# build model
|
||||
self.model = build_model(self.cfg.model)
|
||||
|
||||
self.ckpt = load_checkpoint(
|
||||
self.model, self.model_path, map_location=self.device)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# build pipeline
|
||||
test_pipeline = self.cfg.test_pipeline
|
||||
pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
|
||||
self.pipeline = Compose(pipeline)
|
||||
|
||||
def predict(self, input_data_list, mode='panoptic'):
|
||||
"""
|
||||
Args:
|
||||
input_data_list: a list of numpy array(in rgb order), each array is a sample
|
||||
to be predicted
|
||||
"""
|
||||
output_list = []
|
||||
for idx, img in enumerate(input_data_list):
|
||||
output = {}
|
||||
if not isinstance(img, np.ndarray):
|
||||
img = np.asarray(img)
|
||||
data_dict = {'img': img}
|
||||
ori_shape = img.shape
|
||||
data_dict = self.pipeline(data_dict)
|
||||
img = data_dict['img']
|
||||
img[0] = torch.unsqueeze(img[0], 0).to(self.device)
|
||||
img_metas = [[
|
||||
img_meta._data for img_meta in data_dict['img_metas']
|
||||
]]
|
||||
img_metas[0][0]['ori_shape'] = ori_shape
|
||||
res = self.model.forward_test(img, img_metas, encode=False)
|
||||
if mode == 'panoptic':
|
||||
output['pan'] = res['pan_results'][0]
|
||||
elif mode == 'instance':
|
||||
output['segms'] = res['detection_masks'][0]
|
||||
output['bboxes'] = res['detection_boxes'][0]
|
||||
output['scores'] = res['detection_scores'][0]
|
||||
output['labels'] = res['detection_classes'][0]
|
||||
output_list.append(output)
|
||||
return output_list
|
||||
def postprocess(self, inputs):
|
||||
output = {}
|
||||
if self.task_mode == 'panoptic':
|
||||
output['pan'] = inputs['pan_results'][0]
|
||||
elif self.task_mode == 'instance':
|
||||
output['segms'] = inputs['detection_masks'][0]
|
||||
output['bboxes'] = inputs['detection_boxes'][0]
|
||||
output['scores'] = inputs['detection_scores'][0]
|
||||
output['labels'] = inputs['detection_classes'][0]
|
||||
else:
|
||||
raise ValueError(f'Not support model {self.task_mode}')
|
||||
return output
|
||||
|
||||
def show_panoptic(self, img, pan_mask):
|
||||
pan_label = np.unique(pan_mask)
|
||||
|
@ -214,147 +208,6 @@ class Mask2formerPredictor(PredictorInterface):
|
|||
return instance_result
|
||||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
class SegFormerPredictor(PredictorInterface):
|
||||
|
||||
def __init__(self, model_path, model_config):
|
||||
"""init model
|
||||
|
||||
Args:
|
||||
model_path (str): Path of model path
|
||||
model_config (config): config string for model to init. Defaults to None.
|
||||
"""
|
||||
self.model_path = model_path
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
self.model = None
|
||||
with io.open(self.model_path, 'rb') as infile:
|
||||
checkpoint = torch.load(infile, map_location='cpu')
|
||||
|
||||
self.cfg = mmcv_config_fromfile(model_config)
|
||||
self.CLASSES = self.cfg.CLASSES
|
||||
self.PALETTE = self.cfg.PALETTE
|
||||
# build model
|
||||
self.model = build_model(self.cfg.model)
|
||||
|
||||
self.ckpt = load_checkpoint(
|
||||
self.model, self.model_path, map_location=self.device)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# build pipeline
|
||||
test_pipeline = self.cfg.test_pipeline
|
||||
pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
|
||||
self.pipeline = Compose(pipeline)
|
||||
|
||||
def predict(self, input_data_list):
|
||||
"""
|
||||
using session run predict a number of samples using batch_size
|
||||
|
||||
Args:
|
||||
input_data_list: a list of numpy array(in rgb order), each array is a sample
|
||||
to be predicted
|
||||
use a fixed number if you do not want to adjust batch_size in runtime
|
||||
"""
|
||||
output_list = []
|
||||
for idx, img in enumerate(input_data_list):
|
||||
if type(img) is not np.ndarray:
|
||||
img = np.asarray(img)
|
||||
|
||||
ori_img_shape = img.shape[:2]
|
||||
|
||||
data_dict = {'img': img}
|
||||
data_dict['ori_shape'] = ori_img_shape
|
||||
data_dict = self.pipeline(data_dict)
|
||||
img = data_dict['img']
|
||||
img = torch.unsqueeze(img[0], 0).to(self.device)
|
||||
data_dict.pop('img')
|
||||
|
||||
with torch.no_grad():
|
||||
out = self.model([img],
|
||||
mode='test',
|
||||
img_metas=[[data_dict['img_metas'][0]._data]])
|
||||
|
||||
output_list.append(out)
|
||||
|
||||
return output_list
|
||||
|
||||
def show_result(self,
|
||||
img,
|
||||
result,
|
||||
palette=None,
|
||||
win_name='',
|
||||
show=False,
|
||||
wait_time=0,
|
||||
out_file=None,
|
||||
opacity=0.5):
|
||||
"""Draw `result` over `img`.
|
||||
|
||||
Args:
|
||||
img (str or Tensor): The image to be displayed.
|
||||
result (Tensor): The semantic segmentation results to draw over
|
||||
`img`.
|
||||
palette (list[list[int]]] | np.ndarray | None): The palette of
|
||||
segmentation map. If None is given, random palette will be
|
||||
generated. Default: None
|
||||
win_name (str): The window name.
|
||||
wait_time (int): Value of waitKey param.
|
||||
Default: 0.
|
||||
show (bool): Whether to show the image.
|
||||
Default: False.
|
||||
out_file (str or None): The filename to write the image.
|
||||
Default: None.
|
||||
opacity(float): Opacity of painted segmentation map.
|
||||
Default 0.5.
|
||||
Must be in (0, 1] range.
|
||||
Returns:
|
||||
img (Tensor): Only if not `show` or `out_file`
|
||||
"""
|
||||
|
||||
img = mmcv.imread(img)
|
||||
img = img.copy()
|
||||
seg = result[0]
|
||||
if palette is None:
|
||||
if self.PALETTE is None:
|
||||
# Get random state before set seed,
|
||||
# and restore random state later.
|
||||
# It will prevent loss of randomness, as the palette
|
||||
# may be different in each iteration if not specified.
|
||||
# See: https://github.com/open-mmlab/mmdetection/issues/5844
|
||||
state = np.random.get_state()
|
||||
np.random.seed(42)
|
||||
# random palette
|
||||
palette = np.random.randint(
|
||||
0, 255, size=(len(self.CLASSES), 3))
|
||||
np.random.set_state(state)
|
||||
else:
|
||||
palette = self.PALETTE
|
||||
palette = np.array(palette)
|
||||
assert palette.shape[0] == len(self.CLASSES)
|
||||
assert palette.shape[1] == 3
|
||||
assert len(palette.shape) == 2
|
||||
assert 0 < opacity <= 1.0
|
||||
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
||||
for label, color in enumerate(palette):
|
||||
color_seg[seg == label, :] = color
|
||||
# convert to BGR
|
||||
color_seg = color_seg[..., ::-1]
|
||||
|
||||
img = img * (1 - opacity) + color_seg * opacity
|
||||
img = img.astype(np.uint8)
|
||||
# if out_file specified, do not show image in window
|
||||
if out_file is not None:
|
||||
show = False
|
||||
|
||||
if show:
|
||||
mmcv.imshow(img, win_name, wait_time)
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(img, out_file)
|
||||
|
||||
if not (show or out_file):
|
||||
return img
|
||||
|
||||
|
||||
def _get_bias_color(base, max_dist=30):
|
||||
"""Get different colors for each masks.
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
@ -8,6 +9,7 @@ from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
|
|||
from torch.optim import Optimizer
|
||||
|
||||
from easycv.file import io
|
||||
from easycv.file.utils import is_url_path
|
||||
from easycv.framework.errors import TypeError
|
||||
from easycv.utils.constant import CACHE_DIR
|
||||
|
||||
|
@ -32,28 +34,40 @@ def load_checkpoint(model,
|
|||
Returns:
|
||||
dict or OrderedDict: The loaded checkpoint.
|
||||
"""
|
||||
if not filename.startswith('oss://'):
|
||||
return mmcv_load_checkpoint(
|
||||
model,
|
||||
filename,
|
||||
map_location=map_location,
|
||||
strict=strict,
|
||||
logger=logger)
|
||||
else:
|
||||
if filename.startswith('oss://'):
|
||||
_, fname = os.path.split(filename)
|
||||
cache_file = os.path.join(CACHE_DIR, fname)
|
||||
if not os.path.exists(CACHE_DIR):
|
||||
os.makedirs(CACHE_DIR)
|
||||
if not os.path.exists(cache_file):
|
||||
print(f'download checkpoint from {filename} to {cache_file}')
|
||||
logging.info(
|
||||
f'download checkpoint from {filename} to {cache_file}')
|
||||
io.copy(filename, cache_file)
|
||||
if torch.distributed.is_available(
|
||||
) and torch.distributed.is_initialized():
|
||||
torch.distributed.barrier()
|
||||
return mmcv_load_checkpoint(
|
||||
model,
|
||||
cache_file,
|
||||
map_location=map_location,
|
||||
strict=strict,
|
||||
logger=logger)
|
||||
filename = cache_file
|
||||
elif is_url_path(filename):
|
||||
from torch.hub import urlparse, download_url_to_file
|
||||
parts = urlparse(filename)
|
||||
base_name = os.path.basename(parts.path)
|
||||
cache_file = os.path.join(CACHE_DIR, base_name)
|
||||
if not os.path.exists(CACHE_DIR):
|
||||
os.makedirs(CACHE_DIR)
|
||||
if not os.path.exists(cache_file):
|
||||
logging.info(
|
||||
f'download checkpoint from {filename} to {cache_file}')
|
||||
download_url_to_file(filename, cache_file)
|
||||
if torch.distributed.is_available(
|
||||
) and torch.distributed.is_initialized():
|
||||
torch.distributed.barrier()
|
||||
filename = cache_file
|
||||
return mmcv_load_checkpoint(
|
||||
model,
|
||||
filename,
|
||||
map_location=map_location,
|
||||
strict=strict,
|
||||
logger=logger)
|
||||
|
||||
|
||||
def save_checkpoint(model, filename, optimizer=None, meta=None):
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
CACHE_DIR = '.easycv_cache'
|
||||
import os
|
||||
|
||||
CACHE_DIR = os.path.expanduser('~/.cache/easycv/')
|
||||
|
||||
MAX_READ_IMAGE_TRY_TIMES = 10
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from six.moves import map, zip
|
||||
|
||||
from easycv.models.backbones.repvgg_yolox_backbone import RepVGGBlock
|
||||
|
||||
|
||||
def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
|
||||
|
@ -79,6 +79,8 @@ def reparameterize_models(model):
|
|||
Args:
|
||||
model: nn.Module
|
||||
"""
|
||||
from easycv.models.backbones.repvgg_yolox_backbone import RepVGGBlock
|
||||
|
||||
reparameterize_count = 0
|
||||
for layer in model.modules():
|
||||
if isinstance(layer, RepVGGBlock):
|
||||
|
@ -89,3 +91,31 @@ def reparameterize_models(model):
|
|||
.format(reparameterize_count))
|
||||
print('reparam:', reparameterize_count)
|
||||
return model
|
||||
|
||||
|
||||
def deprecated(reason):
|
||||
"""
|
||||
This is a decorator which can be used to mark functions
|
||||
as deprecated. It will result in a warning being emitted
|
||||
when the function is used.
|
||||
"""
|
||||
|
||||
def decorator(func1):
|
||||
if inspect.isclass(func1):
|
||||
fmt1 = 'Call to deprecated class {name} ({reason}).'
|
||||
else:
|
||||
fmt1 = 'Call to deprecated function {name} ({reason}).'
|
||||
|
||||
@functools.wraps(func1)
|
||||
def new_func1(*args, **kwargs):
|
||||
warnings.simplefilter('always', DeprecationWarning)
|
||||
warnings.warn(
|
||||
fmt1.format(name=func1.__name__, reason=reason),
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2)
|
||||
warnings.simplefilter('default', DeprecationWarning)
|
||||
return func1(*args, **kwargs)
|
||||
|
||||
return new_func1
|
||||
|
||||
return decorator
|
||||
|
|
|
@ -7,9 +7,7 @@ from tests.ut_config import (IMG_NORM_CFG_255, SEG_DATA_SMALL_RAW_LOCAL,
|
|||
|
||||
from easycv.core.evaluation.builder import build_evaluator
|
||||
from easycv.datasets.builder import build_datasource
|
||||
from easycv.datasets.segmentation.data_sources.raw import SegSourceRaw
|
||||
from easycv.datasets.segmentation.raw import SegDataset
|
||||
from easycv.file import io
|
||||
|
||||
|
||||
class SegDatasetTest(unittest.TestCase):
|
||||
|
|
|
@ -8,14 +8,57 @@ import unittest
|
|||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from easycv.predictors.classifier import TorchClassifier
|
||||
|
||||
from easycv.predictors.builder import build_predictor
|
||||
from easycv.utils.test_util import clean_up, get_tmp_dir
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD,
|
||||
IMAGENET_LABEL_TXT, TEST_IMAGES_DIR)
|
||||
|
||||
|
||||
class ClassificationPredictorTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def test_single(self):
|
||||
checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
|
||||
config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
|
||||
cfg = mmcv_config_fromfile(config_file)
|
||||
predict_op = build_predictor(
|
||||
dict(
|
||||
**cfg.predict,
|
||||
model_path=checkpoint,
|
||||
config_file=config_file,
|
||||
label_map_path=IMAGENET_LABEL_TXT))
|
||||
img_path = os.path.join(TEST_IMAGES_DIR, 'catb.jpg')
|
||||
|
||||
results = predict_op([img_path])[0]
|
||||
self.assertListEqual(results['class'], [283])
|
||||
self.assertListEqual(results['class_name'], ['"Persian cat",'])
|
||||
self.assertEqual(len(results['class_probs']), 1000)
|
||||
|
||||
def test_batch(self):
|
||||
checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
|
||||
config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
|
||||
cfg = mmcv_config_fromfile(config_file)
|
||||
predict_op = build_predictor(
|
||||
dict(
|
||||
**cfg.predict,
|
||||
model_path=checkpoint,
|
||||
config_file=config_file,
|
||||
label_map_path=IMAGENET_LABEL_TXT,
|
||||
batch_size=3))
|
||||
img_path = os.path.join(TEST_IMAGES_DIR, 'catb.jpg')
|
||||
|
||||
num_imgs = 4
|
||||
results = predict_op([img_path] * num_imgs)
|
||||
self.assertEqual(len(results), num_imgs)
|
||||
for res in results:
|
||||
self.assertListEqual(res['class'], [283])
|
||||
self.assertListEqual(res['class_name'], ['"Persian cat",'])
|
||||
self.assertEqual(len(res['class_probs']), 1000)
|
||||
|
||||
|
||||
class TorchClassifierTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -62,6 +105,8 @@ class TorchClassifierTest(unittest.TestCase):
|
|||
output_ckpt = f'{self.tmp_dir}/export.pth'
|
||||
torch.save(output_dict, output_ckpt)
|
||||
|
||||
from easycv.predictors.classifier import TorchClassifier
|
||||
|
||||
fe = TorchClassifier(
|
||||
output_ckpt, topk=topk, label_map_path=IMAGENET_LABEL_TXT)
|
||||
|
||||
|
|
|
@ -4,11 +4,11 @@ isort:skip_file
|
|||
"""
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from easycv.predictors.detector import TorchYoloXPredictor, DetrPredictor
|
||||
from easycv.predictors.detector import TorchYoloXPredictor, DetectionPredictor
|
||||
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
|
||||
PRETRAINED_MODEL_YOLOXS_EXPORT_OLD,
|
||||
PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT,
|
||||
|
@ -154,25 +154,18 @@ class DetectorTest(unittest.TestCase):
|
|||
[510.37033, 268.4982, 527.67017, 273.04935]]),
|
||||
decimal=1)
|
||||
|
||||
def test_vitdet_detector(self):
|
||||
model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
|
||||
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
|
||||
out_file = './result.jpg'
|
||||
vitdet = DetrPredictor(model_path)
|
||||
output = vitdet.predict(img)
|
||||
vitdet.visualize(img, output, out_file=out_file)
|
||||
|
||||
def _detection_detector_assert(self, output):
|
||||
self.assertIn('detection_boxes', output)
|
||||
self.assertIn('detection_scores', output)
|
||||
self.assertIn('detection_classes', output)
|
||||
self.assertIn('detection_masks', output)
|
||||
self.assertIn('img_metas', output)
|
||||
self.assertEqual(len(output['detection_boxes'][0]), 33)
|
||||
self.assertEqual(len(output['detection_scores'][0]), 33)
|
||||
self.assertEqual(len(output['detection_classes'][0]), 33)
|
||||
self.assertEqual(len(output['detection_boxes']), 33)
|
||||
self.assertEqual(len(output['detection_scores']), 33)
|
||||
self.assertEqual(len(output['detection_classes']), 33)
|
||||
|
||||
self.assertListEqual(
|
||||
output['detection_classes'][0].tolist(),
|
||||
output['detection_classes'].tolist(),
|
||||
np.array([
|
||||
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||
2, 2, 2, 2, 2, 2, 7, 7, 13, 13, 13, 56
|
||||
|
@ -180,7 +173,7 @@ class DetectorTest(unittest.TestCase):
|
|||
dtype=np.int32).tolist())
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_scores'][0],
|
||||
output['detection_scores'],
|
||||
np.array([
|
||||
0.9975854158401489, 0.9965696334838867, 0.9922919869422913,
|
||||
0.9833580851554871, 0.983080267906189, 0.970454752445221,
|
||||
|
@ -198,7 +191,7 @@ class DetectorTest(unittest.TestCase):
|
|||
decimal=2)
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_boxes'][0],
|
||||
output['detection_boxes'],
|
||||
np.array([[
|
||||
294.22674560546875, 116.6078109741211, 379.4328918457031,
|
||||
150.14097595214844
|
||||
|
@ -333,6 +326,32 @@ class DetectorTest(unittest.TestCase):
|
|||
]]),
|
||||
decimal=1)
|
||||
|
||||
def test_detection_detector_single(self):
|
||||
model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
|
||||
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
|
||||
vitdet = DetectionPredictor(model_path, score_threshold=0.0)
|
||||
output = vitdet(img)
|
||||
output = output[0]
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg') as tmp_file:
|
||||
tmp_save_path = tmp_file.name
|
||||
vitdet.visualize(img, output, out_file=tmp_save_path)
|
||||
self._detection_detector_assert(output)
|
||||
|
||||
def test_detection_detector_batch(self):
|
||||
model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
|
||||
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
|
||||
vitdet = DetectionPredictor(
|
||||
model_path, score_threshold=0.0, batch_size=2)
|
||||
num_samples = 3
|
||||
images = [img] * num_samples
|
||||
outputs = vitdet(images)
|
||||
self.assertEqual(len(outputs), num_samples)
|
||||
for output in outputs:
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg') as tmp_file:
|
||||
tmp_save_path = tmp_file.name
|
||||
vitdet.visualize(img, output, out_file=tmp_save_path)
|
||||
self._detection_detector_assert(output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -3,22 +3,14 @@
|
|||
isort:skip_file
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from easycv.predictors.detector import TorchYoloXPredictor
|
||||
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
|
||||
PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT,
|
||||
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT,
|
||||
PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE,
|
||||
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE,
|
||||
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE,
|
||||
DET_DATA_SMALL_COCO_LOCAL)
|
||||
|
||||
from easycv.utils.test_util import benchmark
|
||||
import logging
|
||||
import pandas as pd
|
||||
import torch
|
||||
from numpy.testing import assert_array_almost_equal
|
||||
|
||||
|
@ -37,7 +29,6 @@ class DetectorTest(unittest.TestCase):
|
|||
input_data_list = [np.asarray(Image.open(img))]
|
||||
|
||||
blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE
|
||||
# blade_path = '/home/zouxinyi.zxy/easycv_nfs/pretrained_models/detection/infer_yolox/debug_blade.pt.blade'
|
||||
predictor_blade = TorchYoloXPredictor(
|
||||
model_path=blade_path, score_thresh=0.5)
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ class FaceKeypointsPredictorWithoutDetectorTest(unittest.TestCase):
|
|||
def test_single(self):
|
||||
predict_pipeline = FaceKeypointsPredictor(
|
||||
model_path=self.model_path, config_file=self.model_config_path)
|
||||
output = predict_pipeline(self.image_path)[0][0]
|
||||
output = predict_pipeline(self.image_path)[0]
|
||||
output_keypoints = output['point']
|
||||
output_pose = output['pose']
|
||||
img = cv2.imread(self.image_path)
|
||||
|
@ -38,18 +38,10 @@ class FaceKeypointsPredictorWithoutDetectorTest(unittest.TestCase):
|
|||
total_samples = 3
|
||||
output = predict_pipeline([self.image_path] * total_samples)
|
||||
|
||||
self.assertEqual(len(output), 2)
|
||||
self.assertEqual(len(output[0]), 2)
|
||||
self.assertEqual(len(output[1]), 1)
|
||||
self.assertEqual(output[0][0]['point'].shape[0], 106)
|
||||
self.assertEqual(output[0][0]['point'].shape[1], 2)
|
||||
self.assertEqual(output[0][0]['pose'].shape[0], 3)
|
||||
self.assertEqual(output[0][1]['point'].shape[0], 106)
|
||||
self.assertEqual(output[0][1]['point'].shape[1], 2)
|
||||
self.assertEqual(output[0][1]['pose'].shape[0], 3)
|
||||
self.assertEqual(output[1][0]['point'].shape[0], 106)
|
||||
self.assertEqual(output[1][0]['point'].shape[1], 2)
|
||||
self.assertEqual(output[1][0]['pose'].shape[0], 3)
|
||||
self.assertEqual(len(output), total_samples)
|
||||
for out in output:
|
||||
self.assertEqual(out['point'].shape, (106, 2))
|
||||
self.assertEqual(out['pose'].shape, (3, ))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -39,6 +39,37 @@ class HandKeypointsPredictorTest(unittest.TestCase):
|
|||
self.assertEqual(keypoints.shape[1], 21)
|
||||
self.assertEqual(keypoints.shape[2], 3)
|
||||
|
||||
def test_batch(self):
|
||||
config = mmcv_config_fromfile(self.model_config_path)
|
||||
predict_pipeline = HandKeypointsPredictor(
|
||||
model_path=self.model_path,
|
||||
config_file=config,
|
||||
batch_size=2,
|
||||
detection_predictor_config=dict(
|
||||
type='DetectionPredictor',
|
||||
model_path=MM_DEFAULT_HAND_DETECTION_SSDLITE_MODEL_PATH,
|
||||
config_file=MM_DEFAULT_HAND_DETECTION_SSDLITE_CONFIG_FILE,
|
||||
score_threshold=0.5))
|
||||
|
||||
num_samples = 4
|
||||
outputs = predict_pipeline(
|
||||
[self.image_path] * num_samples, keep_inputs=True)
|
||||
base_keypoints = outputs[0]['keypoints']
|
||||
base_boxes = outputs[0]['boxes']
|
||||
for output in outputs:
|
||||
keypoints = output['keypoints']
|
||||
boxes = output['boxes']
|
||||
image_show = predict_pipeline.show_result(
|
||||
self.image_path,
|
||||
keypoints,
|
||||
boxes,
|
||||
save_path=self.save_image_path)
|
||||
self.assertEqual(keypoints.shape, (1, 21, 3))
|
||||
self.assertEqual(boxes.shape, (1, 4))
|
||||
self.assertListEqual(keypoints.tolist(), base_keypoints.tolist())
|
||||
self.assertListEqual(boxes.tolist(), base_boxes.tolist())
|
||||
self.assertEqual(output['inputs'], self.image_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -8,6 +8,7 @@ import unittest
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
from tests.ut_config import (MODEL_CONFIG_SEGFORMER,
|
||||
PRETRAINED_MODEL_MASK2FORMER_DIR,
|
||||
PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR)
|
||||
|
||||
from easycv.predictors.segmentation import SegmentationPredictor
|
||||
|
@ -31,14 +32,14 @@ class SegmentationPredictorTest(unittest.TestCase):
|
|||
|
||||
outputs = predict_pipeline(img_path, keep_inputs=True)
|
||||
self.assertEqual(len(outputs), 1)
|
||||
self.assertEqual(outputs[0]['inputs'], [img_path])
|
||||
results = outputs[0]
|
||||
self.assertEqual(results['inputs'], img_path)
|
||||
|
||||
results = outputs[0]['results']
|
||||
self.assertListEqual(
|
||||
list(img.shape)[:2], list(results['seg_pred'][0].shape))
|
||||
self.assertListEqual(results['seg_pred'][0][1, :10].tolist(),
|
||||
list(img.shape)[:2], list(results['seg_pred'].shape))
|
||||
self.assertListEqual(results['seg_pred'][1, :10].tolist(),
|
||||
[161 for i in range(10)])
|
||||
self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(),
|
||||
self.assertListEqual(results['seg_pred'][-1, -10:].tolist(),
|
||||
[133 for i in range(10)])
|
||||
|
||||
def test_batch(self):
|
||||
|
@ -56,19 +57,15 @@ class SegmentationPredictorTest(unittest.TestCase):
|
|||
total_samples = 3
|
||||
outputs = predict_pipeline(
|
||||
[img_path] * total_samples, keep_inputs=True)
|
||||
self.assertEqual(len(outputs), 2)
|
||||
self.assertEqual(len(outputs), 3)
|
||||
|
||||
self.assertEqual(outputs[0]['inputs'], [img_path] * 2)
|
||||
self.assertEqual(outputs[1]['inputs'], [img_path] * 1)
|
||||
self.assertEqual(len(outputs[0]['results']['seg_pred']), 2)
|
||||
self.assertEqual(len(outputs[1]['results']['seg_pred']), 1)
|
||||
|
||||
for result in [outputs[0]['results'], outputs[1]['results']]:
|
||||
for i in range(len(outputs)):
|
||||
self.assertEqual(outputs[i]['inputs'], img_path)
|
||||
self.assertListEqual(
|
||||
list(img.shape)[:2], list(result['seg_pred'][0].shape))
|
||||
self.assertListEqual(result['seg_pred'][0][1, :10].tolist(),
|
||||
list(img.shape)[:2], list(outputs[i]['seg_pred'].shape))
|
||||
self.assertListEqual(outputs[i]['seg_pred'][1, :10].tolist(),
|
||||
[161 for i in range(10)])
|
||||
self.assertListEqual(result['seg_pred'][0][-1, -10:].tolist(),
|
||||
self.assertListEqual(outputs[i]['seg_pred'][-1, -10:].tolist(),
|
||||
[133 for i in range(10)])
|
||||
|
||||
def test_dump(self):
|
||||
|
@ -91,17 +88,47 @@ class SegmentationPredictorTest(unittest.TestCase):
|
|||
|
||||
total_samples = 3
|
||||
outputs = predict_pipeline(
|
||||
[img_path] * total_samples, keep_inputs=True)
|
||||
[img_path] * total_samples, keep_inputs=False)
|
||||
self.assertEqual(outputs, [])
|
||||
|
||||
with open(tmp_path, 'rb') as f:
|
||||
results = pickle.loads(f.read())
|
||||
|
||||
self.assertIn('inputs', results[0])
|
||||
self.assertIn('results', results[0])
|
||||
for res in results:
|
||||
self.assertNotIn('inputs', res)
|
||||
self.assertIn('seg_pred', res)
|
||||
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
@unittest.skipIf(True, 'WIP')
|
||||
class Mask2formerPredictorTest(unittest.TestCase):
|
||||
|
||||
def test_single(self):
|
||||
import cv2
|
||||
from easycv.predictors.segmentation import Mask2formerPredictor
|
||||
pan_ckpt = os.path.join(PRETRAINED_MODEL_MASK2FORMER_DIR,
|
||||
'mask2former_pan_export.pth')
|
||||
instance_ckpt = os.path.join(PRETRAINED_MODEL_MASK2FORMER_DIR,
|
||||
'mask2former_r50_instance.pth')
|
||||
img_path = os.path.join(TEST_IMAGES_DIR, 'mask2former.jpg')
|
||||
|
||||
# panop
|
||||
predictor = Mask2formerPredictor(
|
||||
model_path=pan_ckpt, output_mode='panoptic')
|
||||
img = cv2.imread(img_path)
|
||||
predict_out = predictor([img])
|
||||
pan_img = predictor.show_panoptic(img, predict_out[0]['pan'])
|
||||
cv2.imwrite('pan_out.jpg', pan_img)
|
||||
|
||||
# instance
|
||||
predictor = Mask2formerPredictor(
|
||||
model_path=instance_ckpt, output_mode='instance')
|
||||
img = cv2.imread(img_path)
|
||||
predict_out = predictor.predict([img], mode='instance')
|
||||
instance_img = predictor.show_instance(img, **predict_out[0])
|
||||
cv2.imwrite('instance_out.jpg', instance_img)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
"""
|
||||
isort:skip_file
|
||||
"""
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from tests.ut_config import TEST_IMAGES_DIR
|
||||
from tests.ut_config import (PRETRAINED_MODEL_SEGFORMER,
|
||||
MODEL_CONFIG_SEGFORMER)
|
||||
from easycv.predictors.segmentation import SegFormerPredictor
|
||||
|
||||
|
||||
class SegmentorTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def test_segformer_detector(self):
|
||||
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
|
||||
segmentation_model_config = MODEL_CONFIG_SEGFORMER
|
||||
|
||||
img = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
|
||||
if not os.path.exists(img):
|
||||
img = './data/test/segmentation/coco_stuff_164k/val2017/000000289059.jpg'
|
||||
|
||||
input_data_list = [np.asarray(Image.open(img))]
|
||||
predictor = SegFormerPredictor(
|
||||
model_path=segmentation_model_path,
|
||||
model_config=segmentation_model_config)
|
||||
|
||||
output = predictor.predict(input_data_list)[0]
|
||||
self.assertIn('seg_pred', output)
|
||||
|
||||
self.assertListEqual(
|
||||
list(input_data_list[0].shape)[:2],
|
||||
list(output['seg_pred'][0].shape))
|
||||
self.assertListEqual(output['seg_pred'][0][1, :10].tolist(),
|
||||
[161 for i in range(10)])
|
||||
self.assertListEqual(output['seg_pred'][0][-1, -10:].tolist(),
|
||||
[133 for i in range(10)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -120,10 +120,10 @@ PRETRAINED_MODEL_YOLOX_COMPRESSION = os.path.join(
|
|||
BASE_LOCAL_PATH, 'pretrained_models/compression/yolox_compression.pth')
|
||||
PRETRAINED_MODEL_MAE = os.path.join(
|
||||
BASE_LOCAL_PATH, 'pretrained_models/classification/vit/mae_vit_b_1600.pth')
|
||||
PRETRAINED_MODEL_MASK2FORMER = os.path.join(
|
||||
BASE_LOCAL_PATH,
|
||||
'pretrained_models/segmentation/mask2former/mask2former_r50_instance.pth')
|
||||
|
||||
PRETRAINED_MODEL_MASK2FORMER_DIR = os.path.join(
|
||||
BASE_LOCAL_PATH, 'pretrained_models/segmentation/mask2former/')
|
||||
PRETRAINED_MODEL_MASK2FORMER = os.path.join(PRETRAINED_MODEL_MASK2FORMER_DIR,
|
||||
'mask2former_r50_instance.pth')
|
||||
PRETRAINED_MODEL_SEGFORMER = os.path.join(
|
||||
BASE_LOCAL_PATH,
|
||||
'pretrained_models/segmentation/segformer/segformer_b0/SegmentationEvaluator_mIoU_best.pth'
|
||||
|
|
|
@ -21,6 +21,7 @@ except:
|
|||
|
||||
|
||||
from easycv.predictors.builder import build_predictor, PREDICTORS
|
||||
from easycv.utils.constant import CACHE_DIR
|
||||
|
||||
|
||||
def normPRED(d):
|
||||
|
@ -47,8 +48,8 @@ class SODPredictor(object):
|
|||
"""
|
||||
|
||||
def load_url_weights(name, url_index="http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/evtorch_thirdparty/u2net_sod/", map_location=None):
|
||||
os.makedirs('.easycv_cache', exist_ok=True)
|
||||
local_model = os.path.join('.easycv_cache', name+'.pth')
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
local_model = os.path.join(CACHE_DIR, name+'.pth')
|
||||
if os.path.exists(local_model):
|
||||
weights = torch.load(local_model)
|
||||
if weights is not None:
|
||||
|
|
Loading…
Reference in New Issue