[2.0] Support SDK for det and cls (#1073)

* init

* support SDK mmcls test

* align resize

* typo

* fix cls data_preprocessor
pull/1083/head^2
AllentDan 2022-09-21 14:17:23 +08:00 committed by GitHub
parent 5c87dd9565
commit 1d46bd752c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 140 additions and 56 deletions

View File

@ -2,7 +2,6 @@ _base_ = ['./classification_dynamic.py', '../_base_/backends/sdk.py']
codebase_config = dict(model_type='sdk')
backend_config = dict(pipeline=[
dict(type='LoadImageFromFile'),
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
])
backend_config = dict(
pipeline=[dict(type='LoadImageFromFile'),
dict(type='PackClsInputs')])

View File

@ -4,5 +4,8 @@ codebase_config = dict(model_type='sdk')
backend_config = dict(pipeline=[
dict(type='LoadImageFromFile'),
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape'))
])

View File

@ -4,5 +4,8 @@ codebase_config = dict(model_type='sdk', has_mask=True)
backend_config = dict(pipeline=[
dict(type='LoadImageFromFile'),
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape'))
])

View File

@ -96,6 +96,8 @@ class BaseTask(metaclass=ABCMeta):
model = deepcopy(self.model_cfg.model)
preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
preprocess_cfg.update(
deepcopy(self.model_cfg.get('data_preprocessor', {})))
model.setdefault('data_preprocessor', preprocess_cfg)
model = MODELS.build(model)
if model_checkpoint is not None:

View File

@ -47,9 +47,6 @@ def process_model_config(model_cfg: Config,
if cfg.test_pipeline[0]['type'] == 'LoadImageFromFile':
cfg.test_pipeline.pop(0)
# check whether input_shape is valid
if 'data_preprocessor' in cfg:
cfg.test_pipeline.insert(
3, dict(type='Normalize', **cfg['data_preprocessor']))
if input_shape is not None:
if 'crop_size' in cfg.test_pipeline[2]:
crop_size = cfg.test_pipeline[2]['crop_size']
@ -139,7 +136,7 @@ class Classification(BaseTask):
"""
from .classification_model import build_classification_model
data_preprocessor = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
data_preprocessor = self.model_cfg.data_preprocessor
data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
model = build_classification_model(
@ -168,18 +165,11 @@ class Classification(BaseTask):
Returns:
tuple: (data, img), meta information for the input image and input.
"""
assert 'test_pipeline' in self.model_cfg, \
f'test_pipeline not found in {self.model_cfg}.'
model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
assert 'test_pipeline' in model_cfg, \
f'test_pipeline not found in {model_cfg}.'
from mmengine.dataset import Compose
pipeline = deepcopy(model_cfg.test_pipeline)
if isinstance(imgs, str):
if pipeline[0]['type'] != 'LoadImageFromFile':
pipeline.insert(0, dict(type='LoadImageFromFile'))
else:
if pipeline[0]['type'] == 'LoadImageFromFile':
pipeline.pop(0)
pipeline = Compose(pipeline)
if isinstance(imgs, str):
@ -249,8 +239,40 @@ class Classification(BaseTask):
"""
input_shape = get_input_shape(self.deploy_cfg)
cfg = process_model_config(self.model_cfg, '', input_shape)
preprocess = cfg.test_pipeline
return preprocess
pipeline = cfg.test_pipeline
meta_keys = [
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
'valid_ratio'
]
transforms = [
item for item in pipeline if 'Random' not in item['type']
]
for i, transform in enumerate(transforms):
if transform['type'] == 'PackClsInputs':
meta_keys += transform[
'meta_keys'] if 'meta_keys' in transform else []
transform['meta_keys'] = list(set(meta_keys))
transform['keys'] = ['img']
transforms[i]['type'] = 'Collect'
if transform['type'] == 'Resize':
transforms[i]['size'] = transforms[i].pop('scale')
if transform['type'] == 'ResizeEdge':
transforms[i] = dict(
type='Resize',
keep_ratio=True,
size=(transform['scale'], -1))
data_preprocessor = self.model_cfg.data_preprocessor
transforms.insert(-1, dict(type='ImageToTensor', keys=['img']))
transforms.insert(
-2,
dict(
type='Normalize',
to_rgb=data_preprocessor.get('to_rgb', False),
mean=data_preprocessor.get('mean', [0, 0, 0]),
std=data_preprocessor.get('std', [1, 1, 1])))
return transforms
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
@ -267,7 +289,7 @@ class Classification(BaseTask):
else:
topk = postprocess.topk
postprocess.topk = max(topk)
return postprocess
return dict(type=postprocess.pop('type'), params=postprocess)
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.

View File

@ -90,22 +90,36 @@ class End2EndModel(BaseBackendModel):
class SDKEnd2EndModel(End2EndModel):
"""SDK inference class, converts SDK output to mmcls format."""
def forward(self, img: List[torch.Tensor], *args, **kwargs) -> list:
def forward(self,
inputs: Sequence[torch.Tensor],
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'predict',
*args,
**kwargs) -> list:
"""Run forward inference.
Args:
img (List[torch.Tensor]): A list contains input image(s)
in [N x C x H x W] format.
in [C x H x W] format.
*args: Other arguments.
**kwargs: Other key-pair arguments.
Returns:
list: A list contains predictions.
"""
cls_score = []
for input in inputs:
pred = self.wrapper.invoke(
input.permute(1, 2, 0).contiguous().detach().cpu().numpy())
pred = np.array(pred, dtype=np.float32)
pred = pred[np.argsort(pred[:, 0])][np.newaxis, :, 1]
cls_score.append(torch.from_numpy(pred).to(self.device))
pred = self.wrapper.invoke(img[0].contiguous().detach().cpu().numpy())
pred = np.array(pred, dtype=np.float32)
return pred[np.argsort(pred[:, 0])][np.newaxis, :, 1]
cls_score = torch.cat(cls_score, 0)
from mmcls.models.heads.cls_head import ClsHead
predict = ClsHead._get_predictions(
None, cls_score, data_samples=data_samples)
return predict
def build_classification_model(

View File

@ -214,8 +214,41 @@ class ObjectDetection(BaseTask):
"""
input_shape = get_input_shape(self.deploy_cfg)
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
preprocess = model_cfg.test_pipeline
return preprocess
pipeline = model_cfg.test_pipeline
meta_keys = [
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
'valid_ratio'
]
transforms = [
item for item in pipeline if 'Random' not in item['type']
and 'Annotation' not in item['type']
]
for i, transform in enumerate(transforms):
if transform['type'] == 'PackDetInputs':
meta_keys += transform[
'meta_keys'] if 'meta_keys' in transform else []
transform['meta_keys'] = list(set(meta_keys))
transform['keys'] = ['img']
transforms[i]['type'] = 'Collect'
if transform['type'] == 'Resize':
transforms[i]['size'] = transforms[i].pop('scale')
data_preprocessor = model_cfg.model.data_preprocessor
transforms.insert(-1, dict(type='DefaultFormatBundle'))
transforms.insert(
-2,
dict(
type='Pad',
size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
transforms.insert(
-3,
dict(
type='Normalize',
to_rgb=data_preprocessor.get('bgr_to_rgb', False),
mean=data_preprocessor.get('mean', [0, 0, 0]),
std=data_preprocessor.get('std', [1, 1, 1])))
return transforms
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get the postprocess information for SDK.
@ -223,15 +256,16 @@ class ObjectDetection(BaseTask):
Return:
dict: Composed of the postprocess information.
"""
postprocess = self.model_cfg.model.test_cfg
if 'rpn' in postprocess:
postprocess['min_bbox_size'] = postprocess['rpn']['min_bbox_size']
if 'rcnn' in postprocess:
postprocess['score_thr'] = postprocess['rcnn']['score_thr']
if 'mask_thr_binary' in postprocess['rcnn']:
postprocess['mask_thr_binary'] = postprocess['rcnn'][
'mask_thr_binary']
return postprocess
params = self.model_cfg.model.test_cfg
type = 'ResizeBBox' # default for object detection
if 'rpn' in params:
params['min_bbox_size'] = params['rpn']['min_bbox_size']
if 'rcnn' in params:
params['score_thr'] = params['rcnn']['score_thr']
if 'mask_thr_binary' in params['rcnn']:
params['mask_thr_binary'] = params['rcnn']['mask_thr_binary']
type = 'ResizeInstanceMask' # for instance-seg
return dict(type=type, params=params)
def get_model_name(self, *args, **kwargs) -> str:
"""Get the model name.

View File

@ -584,6 +584,7 @@ class SDKEnd2EndModel(End2EndModel):
def __init__(self, *args, **kwargs):
kwargs['data_preprocessor'] = None
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
self.has_mask = self.deploy_cfg.codebase_config.get('has_mask', False)
def forward(self,
inputs: torch.Tensor,
@ -607,23 +608,29 @@ class SDKEnd2EndModel(End2EndModel):
inputs = inputs.permute(1, 2, 0)
dets, labels, masks = self.wrapper.invoke(
inputs.contiguous().detach().cpu().numpy())
dets = torch.from_numpy(dets).to(self.device).unsqueeze(0)
labels = torch.from_numpy(labels).to(torch.int64).to(
self.device).unsqueeze(0)
predictions = []
masks = np.concatenate(masks, 0)
for det, label, mask, data_sample in zip(dets, labels, masks,
data_samples):
pred_instances = InstanceData()
pred_instances.scores = det[..., 4]
pred_instances.bboxes = det[..., :4]
pred_instances.labels = label
pred_instances.masks = torch.from_numpy(mask).\
to(self.device).unsqueeze(0)
result = InstanceData()
if self.has_mask:
segm_results = []
ori_h, ori_w = data_samples[0].ori_shape[:2]
for bbox, mask in zip(dets, masks):
img_mask = np.zeros((ori_h, ori_w), dtype=np.uint8)
left = int(max(np.floor(bbox[0]) - 1, 0))
top = int(max(np.floor(bbox[1]) - 1, 0))
img_mask[top:top + mask.shape[0],
left:left + mask.shape[1]] = mask
segm_results.append(torch.from_numpy(img_mask))
if len(segm_results) > 0:
result.masks = torch.stack(segm_results, 0).to(self.device)
else:
result.masks = torch.zeros([0, ori_h, ori_w]).to(self.device)
dets = torch.from_numpy(dets).to(self.device)
labels = torch.from_numpy(labels).to(torch.int64).to(self.device)
result.bboxes = dets[:, :4]
result.scores = dets[:, 4]
result.labels = labels
data_samples[0].pred_instances = result
data_sample.pred_instances = pred_instances
predictions.append(data_sample)
return predictions
return data_samples
def build_object_detection_model(

View File

@ -238,7 +238,7 @@ class TextDetection(BaseTask):
transform['keys'] = ['img']
transforms[i]['type'] = 'Collect'
if transform['type'] == 'Resize':
transforms[i]['size'] = transforms[i]['scale']
transforms[i]['size'] = transforms[i].pop('scale')
data_preprocessor = model_cfg.model.data_preprocessor
transforms.insert(-1, dict(type='DefaultFormatBundle'))

View File

@ -244,7 +244,7 @@ class TextRecognition(BaseTask):
transform['keys'] = ['img']
transforms[i]['type'] = 'Collect'
if transform['type'] == 'Resize':
transforms[i]['size'] = transforms[i]['scale']
transforms[i]['size'] = transforms[i].pop('scale')
data_preprocessor = model_cfg.model.data_preprocessor
transforms.insert(-1, dict(type='DefaultFormatBundle'))