[2.0] Support SDK for det and cls (#1073)
* init * support SDK mmcls test * align resize * typo * fix cls data_preprocessorpull/1083/head^2
parent
5c87dd9565
commit
1d46bd752c
|
@ -2,7 +2,6 @@ _base_ = ['./classification_dynamic.py', '../_base_/backends/sdk.py']
|
|||
|
||||
codebase_config = dict(model_type='sdk')
|
||||
|
||||
backend_config = dict(pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
|
||||
])
|
||||
backend_config = dict(
|
||||
pipeline=[dict(type='LoadImageFromFile'),
|
||||
dict(type='PackClsInputs')])
|
||||
|
|
|
@ -4,5 +4,8 @@ codebase_config = dict(model_type='sdk')
|
|||
|
||||
backend_config = dict(pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
|
||||
dict(type='LoadAnnotations', with_bbox=True),
|
||||
dict(
|
||||
type='PackDetInputs',
|
||||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape'))
|
||||
])
|
||||
|
|
|
@ -4,5 +4,8 @@ codebase_config = dict(model_type='sdk', has_mask=True)
|
|||
|
||||
backend_config = dict(pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
|
||||
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
|
||||
dict(
|
||||
type='PackDetInputs',
|
||||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape'))
|
||||
])
|
||||
|
|
|
@ -96,6 +96,8 @@ class BaseTask(metaclass=ABCMeta):
|
|||
|
||||
model = deepcopy(self.model_cfg.model)
|
||||
preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
|
||||
preprocess_cfg.update(
|
||||
deepcopy(self.model_cfg.get('data_preprocessor', {})))
|
||||
model.setdefault('data_preprocessor', preprocess_cfg)
|
||||
model = MODELS.build(model)
|
||||
if model_checkpoint is not None:
|
||||
|
|
|
@ -47,9 +47,6 @@ def process_model_config(model_cfg: Config,
|
|||
if cfg.test_pipeline[0]['type'] == 'LoadImageFromFile':
|
||||
cfg.test_pipeline.pop(0)
|
||||
# check whether input_shape is valid
|
||||
if 'data_preprocessor' in cfg:
|
||||
cfg.test_pipeline.insert(
|
||||
3, dict(type='Normalize', **cfg['data_preprocessor']))
|
||||
if input_shape is not None:
|
||||
if 'crop_size' in cfg.test_pipeline[2]:
|
||||
crop_size = cfg.test_pipeline[2]['crop_size']
|
||||
|
@ -139,7 +136,7 @@ class Classification(BaseTask):
|
|||
"""
|
||||
from .classification_model import build_classification_model
|
||||
|
||||
data_preprocessor = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
|
||||
data_preprocessor = self.model_cfg.data_preprocessor
|
||||
data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
|
||||
|
||||
model = build_classification_model(
|
||||
|
@ -168,18 +165,11 @@ class Classification(BaseTask):
|
|||
Returns:
|
||||
tuple: (data, img), meta information for the input image and input.
|
||||
"""
|
||||
|
||||
assert 'test_pipeline' in self.model_cfg, \
|
||||
f'test_pipeline not found in {self.model_cfg}.'
|
||||
model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
|
||||
assert 'test_pipeline' in model_cfg, \
|
||||
f'test_pipeline not found in {model_cfg}.'
|
||||
from mmengine.dataset import Compose
|
||||
pipeline = deepcopy(model_cfg.test_pipeline)
|
||||
if isinstance(imgs, str):
|
||||
if pipeline[0]['type'] != 'LoadImageFromFile':
|
||||
pipeline.insert(0, dict(type='LoadImageFromFile'))
|
||||
else:
|
||||
if pipeline[0]['type'] == 'LoadImageFromFile':
|
||||
pipeline.pop(0)
|
||||
pipeline = Compose(pipeline)
|
||||
|
||||
if isinstance(imgs, str):
|
||||
|
@ -249,8 +239,40 @@ class Classification(BaseTask):
|
|||
"""
|
||||
input_shape = get_input_shape(self.deploy_cfg)
|
||||
cfg = process_model_config(self.model_cfg, '', input_shape)
|
||||
preprocess = cfg.test_pipeline
|
||||
return preprocess
|
||||
pipeline = cfg.test_pipeline
|
||||
meta_keys = [
|
||||
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
|
||||
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
|
||||
'valid_ratio'
|
||||
]
|
||||
transforms = [
|
||||
item for item in pipeline if 'Random' not in item['type']
|
||||
]
|
||||
for i, transform in enumerate(transforms):
|
||||
if transform['type'] == 'PackClsInputs':
|
||||
meta_keys += transform[
|
||||
'meta_keys'] if 'meta_keys' in transform else []
|
||||
transform['meta_keys'] = list(set(meta_keys))
|
||||
transform['keys'] = ['img']
|
||||
transforms[i]['type'] = 'Collect'
|
||||
if transform['type'] == 'Resize':
|
||||
transforms[i]['size'] = transforms[i].pop('scale')
|
||||
if transform['type'] == 'ResizeEdge':
|
||||
transforms[i] = dict(
|
||||
type='Resize',
|
||||
keep_ratio=True,
|
||||
size=(transform['scale'], -1))
|
||||
|
||||
data_preprocessor = self.model_cfg.data_preprocessor
|
||||
transforms.insert(-1, dict(type='ImageToTensor', keys=['img']))
|
||||
transforms.insert(
|
||||
-2,
|
||||
dict(
|
||||
type='Normalize',
|
||||
to_rgb=data_preprocessor.get('to_rgb', False),
|
||||
mean=data_preprocessor.get('mean', [0, 0, 0]),
|
||||
std=data_preprocessor.get('std', [1, 1, 1])))
|
||||
return transforms
|
||||
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
@ -267,7 +289,7 @@ class Classification(BaseTask):
|
|||
else:
|
||||
topk = postprocess.topk
|
||||
postprocess.topk = max(topk)
|
||||
return postprocess
|
||||
return dict(type=postprocess.pop('type'), params=postprocess)
|
||||
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
|
|
@ -90,22 +90,36 @@ class End2EndModel(BaseBackendModel):
|
|||
class SDKEnd2EndModel(End2EndModel):
|
||||
"""SDK inference class, converts SDK output to mmcls format."""
|
||||
|
||||
def forward(self, img: List[torch.Tensor], *args, **kwargs) -> list:
|
||||
def forward(self,
|
||||
inputs: Sequence[torch.Tensor],
|
||||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
mode: str = 'predict',
|
||||
*args,
|
||||
**kwargs) -> list:
|
||||
"""Run forward inference.
|
||||
|
||||
Args:
|
||||
img (List[torch.Tensor]): A list contains input image(s)
|
||||
in [N x C x H x W] format.
|
||||
in [C x H x W] format.
|
||||
*args: Other arguments.
|
||||
**kwargs: Other key-pair arguments.
|
||||
|
||||
Returns:
|
||||
list: A list contains predictions.
|
||||
"""
|
||||
cls_score = []
|
||||
for input in inputs:
|
||||
pred = self.wrapper.invoke(
|
||||
input.permute(1, 2, 0).contiguous().detach().cpu().numpy())
|
||||
pred = np.array(pred, dtype=np.float32)
|
||||
pred = pred[np.argsort(pred[:, 0])][np.newaxis, :, 1]
|
||||
cls_score.append(torch.from_numpy(pred).to(self.device))
|
||||
|
||||
pred = self.wrapper.invoke(img[0].contiguous().detach().cpu().numpy())
|
||||
pred = np.array(pred, dtype=np.float32)
|
||||
return pred[np.argsort(pred[:, 0])][np.newaxis, :, 1]
|
||||
cls_score = torch.cat(cls_score, 0)
|
||||
from mmcls.models.heads.cls_head import ClsHead
|
||||
predict = ClsHead._get_predictions(
|
||||
None, cls_score, data_samples=data_samples)
|
||||
return predict
|
||||
|
||||
|
||||
def build_classification_model(
|
||||
|
|
|
@ -214,8 +214,41 @@ class ObjectDetection(BaseTask):
|
|||
"""
|
||||
input_shape = get_input_shape(self.deploy_cfg)
|
||||
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
|
||||
preprocess = model_cfg.test_pipeline
|
||||
return preprocess
|
||||
pipeline = model_cfg.test_pipeline
|
||||
meta_keys = [
|
||||
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
|
||||
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
|
||||
'valid_ratio'
|
||||
]
|
||||
transforms = [
|
||||
item for item in pipeline if 'Random' not in item['type']
|
||||
and 'Annotation' not in item['type']
|
||||
]
|
||||
for i, transform in enumerate(transforms):
|
||||
if transform['type'] == 'PackDetInputs':
|
||||
meta_keys += transform[
|
||||
'meta_keys'] if 'meta_keys' in transform else []
|
||||
transform['meta_keys'] = list(set(meta_keys))
|
||||
transform['keys'] = ['img']
|
||||
transforms[i]['type'] = 'Collect'
|
||||
if transform['type'] == 'Resize':
|
||||
transforms[i]['size'] = transforms[i].pop('scale')
|
||||
|
||||
data_preprocessor = model_cfg.model.data_preprocessor
|
||||
transforms.insert(-1, dict(type='DefaultFormatBundle'))
|
||||
transforms.insert(
|
||||
-2,
|
||||
dict(
|
||||
type='Pad',
|
||||
size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
|
||||
transforms.insert(
|
||||
-3,
|
||||
dict(
|
||||
type='Normalize',
|
||||
to_rgb=data_preprocessor.get('bgr_to_rgb', False),
|
||||
mean=data_preprocessor.get('mean', [0, 0, 0]),
|
||||
std=data_preprocessor.get('std', [1, 1, 1])))
|
||||
return transforms
|
||||
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK.
|
||||
|
@ -223,15 +256,16 @@ class ObjectDetection(BaseTask):
|
|||
Return:
|
||||
dict: Composed of the postprocess information.
|
||||
"""
|
||||
postprocess = self.model_cfg.model.test_cfg
|
||||
if 'rpn' in postprocess:
|
||||
postprocess['min_bbox_size'] = postprocess['rpn']['min_bbox_size']
|
||||
if 'rcnn' in postprocess:
|
||||
postprocess['score_thr'] = postprocess['rcnn']['score_thr']
|
||||
if 'mask_thr_binary' in postprocess['rcnn']:
|
||||
postprocess['mask_thr_binary'] = postprocess['rcnn'][
|
||||
'mask_thr_binary']
|
||||
return postprocess
|
||||
params = self.model_cfg.model.test_cfg
|
||||
type = 'ResizeBBox' # default for object detection
|
||||
if 'rpn' in params:
|
||||
params['min_bbox_size'] = params['rpn']['min_bbox_size']
|
||||
if 'rcnn' in params:
|
||||
params['score_thr'] = params['rcnn']['score_thr']
|
||||
if 'mask_thr_binary' in params['rcnn']:
|
||||
params['mask_thr_binary'] = params['rcnn']['mask_thr_binary']
|
||||
type = 'ResizeInstanceMask' # for instance-seg
|
||||
return dict(type=type, params=params)
|
||||
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get the model name.
|
||||
|
|
|
@ -584,6 +584,7 @@ class SDKEnd2EndModel(End2EndModel):
|
|||
def __init__(self, *args, **kwargs):
|
||||
kwargs['data_preprocessor'] = None
|
||||
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
|
||||
self.has_mask = self.deploy_cfg.codebase_config.get('has_mask', False)
|
||||
|
||||
def forward(self,
|
||||
inputs: torch.Tensor,
|
||||
|
@ -607,23 +608,29 @@ class SDKEnd2EndModel(End2EndModel):
|
|||
inputs = inputs.permute(1, 2, 0)
|
||||
dets, labels, masks = self.wrapper.invoke(
|
||||
inputs.contiguous().detach().cpu().numpy())
|
||||
dets = torch.from_numpy(dets).to(self.device).unsqueeze(0)
|
||||
labels = torch.from_numpy(labels).to(torch.int64).to(
|
||||
self.device).unsqueeze(0)
|
||||
predictions = []
|
||||
masks = np.concatenate(masks, 0)
|
||||
for det, label, mask, data_sample in zip(dets, labels, masks,
|
||||
data_samples):
|
||||
pred_instances = InstanceData()
|
||||
pred_instances.scores = det[..., 4]
|
||||
pred_instances.bboxes = det[..., :4]
|
||||
pred_instances.labels = label
|
||||
pred_instances.masks = torch.from_numpy(mask).\
|
||||
to(self.device).unsqueeze(0)
|
||||
result = InstanceData()
|
||||
if self.has_mask:
|
||||
segm_results = []
|
||||
ori_h, ori_w = data_samples[0].ori_shape[:2]
|
||||
for bbox, mask in zip(dets, masks):
|
||||
img_mask = np.zeros((ori_h, ori_w), dtype=np.uint8)
|
||||
left = int(max(np.floor(bbox[0]) - 1, 0))
|
||||
top = int(max(np.floor(bbox[1]) - 1, 0))
|
||||
img_mask[top:top + mask.shape[0],
|
||||
left:left + mask.shape[1]] = mask
|
||||
segm_results.append(torch.from_numpy(img_mask))
|
||||
if len(segm_results) > 0:
|
||||
result.masks = torch.stack(segm_results, 0).to(self.device)
|
||||
else:
|
||||
result.masks = torch.zeros([0, ori_h, ori_w]).to(self.device)
|
||||
dets = torch.from_numpy(dets).to(self.device)
|
||||
labels = torch.from_numpy(labels).to(torch.int64).to(self.device)
|
||||
result.bboxes = dets[:, :4]
|
||||
result.scores = dets[:, 4]
|
||||
result.labels = labels
|
||||
data_samples[0].pred_instances = result
|
||||
|
||||
data_sample.pred_instances = pred_instances
|
||||
predictions.append(data_sample)
|
||||
return predictions
|
||||
return data_samples
|
||||
|
||||
|
||||
def build_object_detection_model(
|
||||
|
|
|
@ -238,7 +238,7 @@ class TextDetection(BaseTask):
|
|||
transform['keys'] = ['img']
|
||||
transforms[i]['type'] = 'Collect'
|
||||
if transform['type'] == 'Resize':
|
||||
transforms[i]['size'] = transforms[i]['scale']
|
||||
transforms[i]['size'] = transforms[i].pop('scale')
|
||||
|
||||
data_preprocessor = model_cfg.model.data_preprocessor
|
||||
transforms.insert(-1, dict(type='DefaultFormatBundle'))
|
||||
|
|
|
@ -244,7 +244,7 @@ class TextRecognition(BaseTask):
|
|||
transform['keys'] = ['img']
|
||||
transforms[i]['type'] = 'Collect'
|
||||
if transform['type'] == 'Resize':
|
||||
transforms[i]['size'] = transforms[i]['scale']
|
||||
transforms[i]['size'] = transforms[i].pop('scale')
|
||||
|
||||
data_preprocessor = model_cfg.model.data_preprocessor
|
||||
transforms.insert(-1, dict(type='DefaultFormatBundle'))
|
||||
|
|
Loading…
Reference in New Issue